Как обучить нейросеть на маленьком датасете: методы аугментации и transfer learning
-
Маленький датасет — частая проблема в компьютерном зрении и других задачах машинного обучения. Модель может быстро переобучиться на ограниченных данных, теряя обобщающую способность. В этом гайде мы расскажем, как бороться с этим с помощью аугментации данных и transfer learning, а также покажем рабочие примеры на Python.
Проблема малых данных
Если ваш датасет содержит меньше 10 000 изображений (или аналогичный объем данных для другой задачи), стандартные архитектуры нейросетей будут переобучаться. Например, сверточная сеть VGG16 имеет 138 миллионов параметров — обучать её на 1000 изображениях бессмысленно. Решение:
- Аугментация — искусственное увеличение датасета.
- Transfer learning — использование предобученных моделей.
Метод 1: Аугментация данных
Аугментация создает новые примеры из существующих данных с помощью трансформаций. Для изображений это повороты, отражения, изменение яркости и т.д.
Пример на PyTorch
from torchvision import transforms from torch.utils.data import DataLoader from torchvision.datasets import ImageFolder # Определяем аугментации transform = transforms.Compose([ transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(10), transforms.ColorJitter(brightness=0.2), transforms.ToTensor(), ]) # Загружаем датасет dataset = ImageFolder(root='path/to/train', transform=transform) dataloader = DataLoader(dataset, batch_size=32, shuffle=True) # Теперь dataloader выдает аугментированные изображения
Пример на TensorFlow/Keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator datagen = ImageDataGenerator( rotation_range=10, width_shift_range=0.1, height_shift_range=0.1, horizontal_flip=True, brightness_range=(0.8, 1.2) ) # Пример использования с моделью model.fit(datagen.flow(x_train, y_train, batch_size=32), epochs=10)
Совет: Не переборщите с аугментациями! Например, для медицинских изображений повороты могут искажать важные детали.
Метод 2: Transfer Learning
Используем предобученные модели (например, ResNet, EfficientNet) и дообучаем их на своем датасете.
Пример на PyTorch
import torch import torchvision.models as models # Загружаем предобученную модель model = models.resnet18(pretrained=True) # Замораживаем все слои for param in model.parameters(): param.requires_grad = False # Заменяем последний слой под нашу задачу num_ftrs = model.fc.in_features model.fc = torch.nn.Linear(num_ftrs, 2) # 2 класса # Обучаем только последний слой criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3) # Обучение for epoch in range(10): for inputs, labels in dataloader: outputs = model(inputs) loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step()
Пример на TensorFlow/Keras
from tensorflow.keras.applications import ResNet50 from tensorflow.keras import layers, Model # Загружаем предобученную модель без верхнего слоя base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3)) # Замораживаем базу base_model.trainable = False # Добавляем свои слои x = layers.GlobalAveragePooling2D()(base_model.output) x = layers.Dense(128, activation='relu')(x) output = layers.Dense(1, activation='sigmoid')(x) model = Model(base_model.input, output) # Компилируем и обучаем model.compile(optimizer='adam', loss='binary_crossentropy') model.fit(train_dataset, epochs=10)
Совет: Если датасет очень мал (менее 1000 изображений), замораживайте все слои, кроме последних. Для больших датасетов можно разморозить и дообучить все слои.
Комбинация методов: Аугментация + Transfer Learning
Объедините оба подхода для максимального эффекта:
# PyTorch: Добавьте аугментации в преобразования transform = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Transfer learning с аугментациями model = models.resnet18(pretrained=True) for param in model.parameters(): param.requires_grad = False model.fc = torch.nn.Linear(512, 10) # Новый классификатор
Дополнительные советы
- Регуляризация: Используйте Dropout и L2-регуляризацию.
- Early Stopping: Прерывайте обучение при остановке роста метрики на валидации.
- Batch Size: Уменьшайте batch size, если данные малы (например, 8–16).
- Кросс-валидация: Используйте для оценки модели на малых данных.
Маленький датасет — не приговор. С помощью аугментации и transfer learning вы можете обучить качественную модель даже на 1000 изображениях. Экспериментируйте с комбинациями методов, следите за переобучением и используйте предобученные архитектуры.
Вопросы? Задавайте в комментариях на форуме!
Полезные ресурсы
- PyTorch Documentation
- Keras Applications
- Albumentations: Альтернатива torchvision.transforms
- Fast.ai Course: Transfer Learning
Готовы к экспериментам? Попробуйте применить эти методы в своих проектах!
© 2024 - 2025 RosDesk, Inc. Все права защищены.