
A Practical Guide to Image Classification on Imbalanced Datasets with PyTorch

Author: carzy · 2025-09-18 17:02

Summary: This article addresses image classification on imbalanced datasets, walking through a complete PyTorch-based solution covering data preprocessing, model construction, loss-function optimization, and evaluation, to give developers practical, actionable guidance.


I. The Challenge of Imbalanced Datasets, and an Overview of Solutions

In real-world settings, image classification tasks often face huge disparities in per-class sample counts. In medical imaging, for example, lesion samples may make up less than 10% of the data; in autonomous driving, samples of rare obstacles are scarce. Such imbalance biases the model toward the majority classes and severely degrades recognition of the minority classes.

PyTorch, as a core deep-learning framework, provides a flexible toolchain for this problem. This article lays out solutions systematically at the data, algorithm, and evaluation levels, with code examples showing a complete implementation path.

II. Data Preprocessing and Augmentation Strategies

1. Computing Class Weights

通过统计各类样本数量,计算类别权重用于后续损失函数调整:

```python
from collections import Counter

def calculate_class_weights(labels):
    counter = Counter(labels)
    majority = max(counter.values())
    # Weight of each class = majority count / its own count
    return {cls: majority / count for cls, count in counter.items()}

# Example: per-class weights for CIFAR-100
# labels = [...]  # list of sample labels
# class_weights = calculate_class_weights(labels)
```
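The weight dictionary produced above can also feed PyTorch's built-in weighted loss directly: order the weights by class index, build a tensor, and pass it to `nn.CrossEntropyLoss(weight=...)`. A minimal sketch on a made-up three-class label list:

```python
import torch
import torch.nn as nn
from collections import Counter

def calculate_class_weights(labels):
    counter = Counter(labels)
    majority = max(counter.values())
    return {cls: majority / count for cls, count in counter.items()}

# Made-up label list: class 0 dominates
labels = [0] * 8 + [1] * 4 + [2] * 2
weight_map = calculate_class_weights(labels)

# Order the weights by class index before building the tensor
weights = torch.tensor([weight_map[c] for c in sorted(weight_map)],
                       dtype=torch.float32)
criterion = nn.CrossEntropyLoss(weight=weights)

torch.manual_seed(0)
logits = torch.randn(4, 3)
targets = torch.tensor([0, 1, 2, 2])
loss = criterion(logits, targets)  # scalar; minority-class errors count more
```

This achieves the same effect as the custom weighted loss defined later, using only built-in components.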

2. Targeted Data Augmentation

Apply more aggressive augmentation to the minority classes:

```python
import torch
import torchvision.transforms as transforms
from torchvision.transforms import RandomRotation, RandomHorizontalFlip

class ImbalancedDataset(torch.utils.data.Dataset):
    def __init__(self, data, targets, is_minority):
        self.data = data
        self.targets = targets
        self.is_minority = is_minority  # per-sample flag: minority class?
        # Aggressive augmentation for minority classes
        self.minority_transform = transforms.Compose([
            RandomRotation(30),
            RandomHorizontalFlip(p=0.8),
            transforms.ColorJitter(brightness=0.3, contrast=0.3)
        ])
        # Basic augmentation for majority classes
        self.majority_transform = transforms.Compose([
            RandomHorizontalFlip(p=0.5)
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, target = self.data[idx], self.targets[idx]
        if self.is_minority[idx]:
            img = self.minority_transform(img)
        else:
            img = self.majority_transform(img)
        return img, target
```

3. Resampling Techniques

  • Oversampling: repeat minority-class samples, or synthesize new ones with SMOTE:

    ```python
    from imblearn.over_sampling import SMOTE

    def oversample_features(features, labels):
        # SMOTE expects 2-D input: flatten each sample to a vector,
        # resample, then restore the original sample shape
        smote = SMOTE(random_state=42)
        features_resampled, labels_resampled = smote.fit_resample(
            features.reshape(len(features), -1),
            labels
        )
        return (features_resampled.reshape(-1, *features.shape[1:]),
                labels_resampled)
    ```

  • Undersampling: randomly drop majority-class samples; combine with cross-validation so the discarded samples do not cost too much information
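Both resampling directions can also be implemented natively in PyTorch with `WeightedRandomSampler`, which draws each sample with probability inversely proportional to its class frequency, so minority samples appear more often without changing the stored dataset. A minimal sketch on a toy two-class dataset (shapes and counts are made up for illustration):

```python
import torch
from collections import Counter
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy imbalanced dataset: 90 samples of class 0, 10 of class 1
labels = [0] * 90 + [1] * 10
data = torch.randn(100, 3, 8, 8)
dataset = TensorDataset(data, torch.tensor(labels))

# Per-sample weight = inverse frequency of the sample's class
counts = Counter(labels)
sample_weights = [1.0 / counts[lbl] for lbl in labels]

sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(labels),
                                replacement=True)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

# Each epoch now yields roughly class-balanced batches
_, batch_labels = next(iter(loader))
```

Because sampling happens at load time, this combines naturally with the augmentation dataset above.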
三、 Model Architecture Optimization Strategies

1. Improved Loss Functions

Imbalance can be tackled at the loss level; PyTorch's cross-entropy supports per-class weights out of the box, and variants such as Focal Loss are easy to build as custom modules:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Weighted cross-entropy
class WeightedCrossEntropy(nn.Module):
    def __init__(self, class_weights):
        super().__init__()
        # class_weights: per-class weights ordered by class index;
        # register_buffer so .to(device) moves them with the module
        self.register_buffer('class_weights',
                             torch.tensor(class_weights, dtype=torch.float32))

    def forward(self, outputs, targets):
        log_probs = F.log_softmax(outputs, dim=1)
        weights = self.class_weights[targets]
        loss = F.nll_loss(log_probs, targets, reduction='none')
        return (weights * loss).mean()

# Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, outputs, targets):
        ce_loss = F.cross_entropy(outputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability of the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()
```
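As a quick sanity check of Focal Loss: with gamma = 0 it must reduce to alpha times plain cross-entropy, and with gamma > 0 the loss can only shrink, since easy examples are down-weighted. The snippet repeats the class definition so it runs standalone:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, outputs, targets):
        ce_loss = F.cross_entropy(outputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability of the true class
        return (self.alpha * (1 - pt) ** self.gamma * ce_loss).mean()

torch.manual_seed(0)
logits = torch.randn(16, 5)
targets = torch.randint(0, 5, (16,))

# gamma = 0: Focal Loss collapses to alpha * cross-entropy
focal_g0 = FocalLoss(alpha=0.25, gamma=0.0)(logits, targets)
plain_ce = F.cross_entropy(logits, targets)

# gamma > 0 down-weights easy examples, so the mean loss decreases
focal_g2 = FocalLoss(alpha=0.25, gamma=2.0)(logits, targets)
```

Checks like these catch sign and reduction mistakes early, before a full training run.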

2. Dual-Branch Network Architecture

Add an auxiliary branch dedicated to the minority classes:

```python
class DualBranchCNN(nn.Module):
    def __init__(self, base_model, num_classes):
        super().__init__()
        # Shared feature extractor (assumes a VGG-style base_model.features)
        self.shared_features = base_model.features[:-2]
        # Majority-class branch
        self.majority_branch = nn.Sequential(
            base_model.features[-2:],
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, num_classes)
        )
        # Dedicated minority-class branch (extra depth)
        self.minority_branch = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        features = self.shared_features(x)
        majority_out = self.majority_branch(features)
        minority_out = self.minority_branch(features)
        # Fixed fusion weight; could instead be made learnable
        alpha = 0.7
        return alpha * majority_out + (1 - alpha) * minority_out
```
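The fixed fusion weight in the forward pass can be promoted to a learnable parameter, letting training decide how much to trust each branch. A minimal sketch using a hypothetical standalone fusion module (parameterized through a sigmoid so the weight stays in (0, 1)):

```python
import math
import torch
import torch.nn as nn

class LearnableFusion(nn.Module):
    """Fuses two branch outputs with a trainable mixing weight."""
    def __init__(self, init_alpha=0.7):
        super().__init__()
        # Store the logit of alpha; sigmoid keeps the weight in (0, 1)
        self.raw_alpha = nn.Parameter(
            torch.tensor(math.log(init_alpha / (1 - init_alpha))))

    def forward(self, majority_out, minority_out):
        alpha = torch.sigmoid(self.raw_alpha)
        return alpha * majority_out + (1 - alpha) * minority_out

fusion = LearnableFusion()
a = torch.randn(4, 10)
b = torch.randn(4, 10)
out = fusion(a, b)  # same shape as each branch output
```

Since `raw_alpha` is a parameter, the optimizer updates it along with the rest of the network.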

IV. Training-Process Optimization

1. Dynamic Sampling Strategy

A sampling scheme based on hard-example mining:

```python
def dynamic_sampling(dataset, hard_indices, batch_size, hard_ratio=0.3):
    # hard_indices: indices of previously identified hard examples
    num_hard = int(batch_size * hard_ratio)
    num_easy = batch_size - num_hard
    # Sample hard and easy examples separately
    hard_sampler = torch.utils.data.SubsetRandomSampler(
        hard_indices[:num_hard]
    )
    easy_sampler = torch.utils.data.RandomSampler(
        dataset,
        replacement=True,
        num_samples=num_easy
    )
    # Merging the two samplers requires a custom BatchSampler
    # ...
```
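One way to fill in the elided merge step is a custom batch sampler that fills a fixed quota of each batch with hard-example indices and tops it up with random ones. The `MixedBatchSampler` below is a hypothetical sketch (the name and details are not from the original):

```python
import random
from torch.utils.data import Sampler

class MixedBatchSampler(Sampler):
    """Yields index batches mixing a quota of hard examples with random ones."""
    def __init__(self, dataset_size, hard_indices, batch_size, hard_ratio=0.3):
        self.all_indices = list(range(dataset_size))
        self.hard_indices = list(hard_indices)
        self.batch_size = batch_size
        self.num_hard = int(batch_size * hard_ratio)

    def __iter__(self):
        for _ in range(len(self)):
            # Hard quota first, then random fill from the whole dataset
            hard = random.sample(self.hard_indices,
                                 min(self.num_hard, len(self.hard_indices)))
            easy = random.choices(self.all_indices,
                                  k=self.batch_size - len(hard))
            yield hard + easy

    def __len__(self):
        return len(self.all_indices) // self.batch_size

sampler = MixedBatchSampler(dataset_size=100, hard_indices=[3, 7, 11, 42],
                            batch_size=10, hard_ratio=0.3)
first_batch = next(iter(sampler))
```

An instance can be passed to `DataLoader` via its `batch_sampler=` argument; the hard-index list would be refreshed between epochs from per-sample losses.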

2. Learning-Rate Adjustment Strategy

Give different parts of the model different optimizer settings, e.g. a higher learning rate for the minority branch:

```python
def create_optimizer(model, base_lr=0.001):
    param_groups = []
    for name, param in model.named_parameters():
        # Pick a learning rate based on the module the parameter belongs to
        if 'minority_branch' in name:
            # Higher learning rate for the minority branch
            lr = base_lr * 2
        else:
            lr = base_lr
        param_groups.append({
            'params': param,
            'lr': lr
        })
    return torch.optim.Adam(param_groups)
```

V. Evaluation Metrics and Visualization

1. A Multi-Dimensional Evaluation Suite

```python
import torch
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

def evaluate_model(model, test_loader, class_names):
    model.eval()
    all_preds, all_targets = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())
    # Per-class precision/recall/F1 report
    print(classification_report(all_targets, all_preds,
                                target_names=class_names))
    # Confusion-matrix heatmap
    cm = confusion_matrix(all_targets, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()
```
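Beyond the report above, a single number that respects imbalance is useful for model selection: overall accuracy can look excellent while a minority class is ignored entirely, whereas balanced accuracy and macro-averaged F1 expose that failure. A small illustration on toy predictions:

```python
from sklearn.metrics import (accuracy_score, balanced_accuracy_score,
                             f1_score)

# Toy setup: 90 negatives, 10 positives; the model predicts all-negative
y_true = [0] * 90 + [1] * 10
y_pred = [0] * 100

plain = accuracy_score(y_true, y_pred)              # looks great: 0.90
balanced = balanced_accuracy_score(y_true, y_pred)  # mean recall: 0.50
macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
```

Balanced accuracy is the unweighted mean of per-class recalls, so a degenerate all-majority classifier scores no better than chance.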

2. Per-Class Performance Tracking

A training logger that monitors metrics class by class:

```python
import matplotlib.pyplot as plt

class ClassWiseLogger:
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.class_metrics = {
            'accuracy': [[] for _ in range(num_classes)],
            'loss': [[] for _ in range(num_classes)]
        }

    def update(self, epoch, class_idx, accuracy, loss):
        self.class_metrics['accuracy'][class_idx].append((epoch, accuracy))
        self.class_metrics['loss'][class_idx].append((epoch, loss))

    def plot_metrics(self):
        # One accuracy curve per class
        for cls in range(self.num_classes):
            epochs, accs = zip(*self.class_metrics['accuracy'][cls])
            plt.plot(epochs, accs, label=f'Class {cls}')
        plt.legend()
        plt.show()
```

VI. Complete Example: Imbalanced CIFAR-100 Classification

1. Data Preparation

```python
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
full_dataset = CIFAR100(root='./data', train=True, download=True,
                        transform=transform)

# Build an imbalanced subset: per-class counts decay exponentially.
# CIFAR-100 has only 500 training images per class, so the first few
# counts are capped at 500; clamp to at least 1 so no class vanishes.
class_counts = [max(5000 // (2 ** i), 1) for i in range(100)]
imbalanced_data = []
imbalanced_targets = []
for cls, count in enumerate(class_counts):
    cls_indices = [i for i, label in enumerate(full_dataset.targets)
                   if label == cls][:count]
    imbalanced_data.extend([full_dataset.data[i] for i in cls_indices])
    imbalanced_targets.extend([cls] * len(cls_indices))

# Wrap in a PyTorch Dataset
# (simplified here: full_dataset.data holds HWC uint8 arrays, so a real
#  implementation must first convert them to tensors / PIL images)
```

2. Training Procedure

```python
from torch.utils.data import DataLoader
from torchvision import models

def train_model():
    # Device setup
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Model: pretrained ResNet-18 with a 100-class head
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 100)
    model = model.to(device)
    # Loss with class weights (dict -> list ordered by class index)
    weight_map = calculate_class_weights(imbalanced_targets)
    class_weights = [weight_map[c] for c in sorted(weight_map)]
    criterion = WeightedCrossEntropy(class_weights).to(device)
    # Optimizer and scheduler
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
    # Data loading (the augmentation dataset from Section II)
    dataset = ImbalancedDataset(...)
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
    # Training loop
    for epoch in range(100):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Validation phase (omitted)
        # scheduler.step(val_loss)
```

VII. Best-Practice Recommendations

  1. Work incrementally: try data augmentation and resampling first; adjust the model architecture only if those fall short
  2. Group classes: merge similar classes to mitigate extreme imbalance
  3. Monitor continuously: build a per-class training dashboard to catch performance regressions early
  4. Calibrate afterwards: use temperature scaling to adjust predicted probabilities
  5. Ensemble: combine predictions from multiple models to lift minority-class recognition
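The temperature-scaling idea in point 4 fits in a few lines: divide the validation logits by a scalar T chosen to minimize negative log-likelihood; predictions keep the same argmax but become less overconfident. The sketch below fits T by a simple grid search, one of several possible fitting strategies (the data is synthetic):

```python
import torch
import torch.nn.functional as F

def fit_temperature(logits, targets, grid=None):
    """Pick the temperature minimizing validation NLL by grid search."""
    grid = grid or [x / 10 for x in range(5, 51)]  # 0.5 .. 5.0
    best_t, best_nll = 1.0, float('inf')
    for t in grid:
        nll = F.cross_entropy(logits / t, targets).item()
        if nll < best_nll:
            best_t, best_nll = t, nll
    return best_t

torch.manual_seed(0)
val_logits = 5 * torch.randn(200, 10)   # deliberately overconfident
val_targets = torch.randint(0, 10, (200,))
T = fit_temperature(val_logits, val_targets)

# Calibrated probabilities; the argmax (and accuracy) is unchanged
probs = F.softmax(val_logits / T, dim=1)
```

In practice T is fitted on a held-out validation set and then applied unchanged at test time.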

VIII. Summary and Outlook

PyTorch offers a flexible, powerful toolchain for classification on imbalanced data. Combining data augmentation, improved loss functions, and architectural changes can substantially improve minority-class recognition. Promising directions for future work include:

  • Further refinement of adaptive sampling algorithms
  • Meta-learning approaches to minority-class learning
  • Transfer learning for imbalance problems across datasets

Developers should pick a combination of methods suited to their specific scenario and validate the final choice with thorough experiments.
