# A Practical Guide to Image Classification on Imbalanced Datasets with PyTorch
2025.09.18 17:02 · Summary: This article addresses image classification on imbalanced datasets and walks through a complete PyTorch-based solution, covering data preprocessing, model construction, loss-function optimization, and evaluation, with the goal of giving developers actionable technical guidance.
## 1. The Challenge of Imbalanced Datasets and an Overview of Solutions
In real-world scenarios, image classification tasks often face a huge disparity in the number of samples per class. For example, lesion samples may account for less than 10% of a medical imaging dataset, and rare obstacles are scarce in autonomous-driving data. This imbalance biases the model toward the majority classes and severely degrades recognition of the minority classes.
As a core framework in deep learning, PyTorch provides a flexible toolchain for tackling this problem. This article systematically presents solutions at the data, algorithm, and evaluation levels, with code examples illustrating a complete implementation path.
## 2. Data Preprocessing and Augmentation Strategies
### 1. Class Weight Calculation
Count the number of samples per class and compute class weights for later loss-function adjustment:
```python
from collections import Counter

def calculate_class_weights(labels):
    """Return a weight per class, proportional to the inverse class frequency."""
    counter = Counter(labels)
    majority = max(counter.values())
    return {cls: majority / count for cls, count in counter.items()}

# Example: compute per-class weights for CIFAR-100
# labels = [...]  # list of sample labels
# class_weights = calculate_class_weights(labels)
```
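The helper above returns a Python dict keyed by class index, while PyTorch loss functions expect the weights as a tensor ordered by class index. A small conversion sketch (assuming classes are labeled `0 .. num_classes-1`; the helper name is from the snippet above) bridges the two:

```python
import torch

def weights_dict_to_tensor(weights_dict, num_classes):
    # Position i of the result holds the weight of class i.
    return torch.tensor([weights_dict[c] for c in range(num_classes)],
                        dtype=torch.float32)

# Usage (hypothetical labels):
# class_weights = calculate_class_weights(labels)
# weight_tensor = weights_dict_to_tensor(class_weights, num_classes=100)
```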
### 2. Smart Data Augmentation
Apply a more aggressive augmentation policy to the minority classes:
```python
import torch
import torchvision.transforms as transforms
from torchvision.transforms import RandomRotation, RandomHorizontalFlip

class ImbalancedDataset(torch.utils.data.Dataset):
    def __init__(self, data, targets, is_minority):
        self.data = data                 # PIL images (or tensors)
        self.targets = targets
        self.is_minority = is_minority   # one boolean flag per sample
        # Aggressive augmentation for minority classes
        self.minority_transform = transforms.Compose([
            RandomRotation(30),
            RandomHorizontalFlip(p=0.8),
            transforms.ColorJitter(brightness=0.3, contrast=0.3)
        ])
        # Light augmentation for majority classes
        self.majority_transform = transforms.Compose([
            RandomHorizontalFlip(p=0.5)
        ])

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        img, target = self.data[idx], self.targets[idx]
        if self.is_minority[idx]:
            img = self.minority_transform(img)
        else:
            img = self.majority_transform(img)
        return img, target
```
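A minimal usage sketch follows, assuming `data` (a list of PIL images) and `targets` (a list of integer labels) already exist; the "half of the largest class" threshold for deciding which classes count as minority is an illustrative assumption, not part of the original recipe:

```python
from collections import Counter
from torch.utils.data import DataLoader

# Hypothetical inputs: `data` is a list of PIL images, `targets` a list of int labels.
counts = Counter(targets)
minority_classes = {cls for cls, n in counts.items()
                    if n < 0.5 * max(counts.values())}   # example threshold
is_minority = [t in minority_classes for t in targets]

train_set = ImbalancedDataset(data, targets, is_minority)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
```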
### 3. Resampling Techniques
- **Oversampling**: duplicate minority-class samples, or synthesize new ones with SMOTE:
```python
from imblearn.over_sampling import SMOTE

def oversample_features(features, labels):
    """Apply SMOTE on flattened feature vectors, then restore the original shape."""
    smote = SMOTE(random_state=42)
    # Flatten each sample to a 1-D feature vector (SMOTE expects a 2-D array)
    flat = features.reshape(features.shape[0], -1)
    flat_resampled, labels_resampled = smote.fit_resample(flat, labels)
    return flat_resampled.reshape(-1, *features.shape[1:]), labels_resampled
```
- **Undersampling**: randomly drop majority-class samples; combine with cross-validation to avoid losing information (a sketch follows below).
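As a concrete reference, here is a minimal random-undersampling sketch (one possible approach, under the assumption that the label list is accessible as a `targets` attribute): it caps every class at the size of the smallest class and keeps only the retained indices via a `Subset`.

```python
import random
from collections import defaultdict
from torch.utils.data import Subset

def undersample_indices(targets, seed=42):
    """Return indices that keep at most `min_count` samples of every class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)
    min_count = min(len(v) for v in by_class.values())
    kept = []
    for indices in by_class.values():
        rng.shuffle(indices)
        kept.extend(indices[:min_count])
    return sorted(kept)

# Usage (hypothetical `full_dataset` exposing a `targets` attribute):
# balanced_subset = Subset(full_dataset, undersample_indices(full_dataset.targets))
```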
## 3. Model Architecture Optimization Strategies
### 1. Improved Loss Functions
PyTorch's built-in losses accept per-class weights, and imbalance-oriented losses such as Focal Loss are straightforward to implement on top of them:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Weighted cross-entropy (expects weights ordered by class index)
class WeightedCrossEntropy(nn.Module):
    def __init__(self, class_weights):
        super().__init__()
        # Registered as a buffer so .to(device) moves the weights with the module
        self.register_buffer(
            'class_weights',
            torch.as_tensor(class_weights, dtype=torch.float32)
        )

    def forward(self, outputs, targets):
        log_probs = F.log_softmax(outputs, dim=1)
        per_sample_loss = F.nll_loss(log_probs, targets, reduction='none')
        weights = self.class_weights[targets]   # weight of each sample's true class
        return (weights * per_sample_loss).mean()

# Focal Loss
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, outputs, targets):
        ce_loss = F.cross_entropy(outputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)                 # probability of the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        return focal_loss.mean()
```
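For the weighted case, the built-in `nn.CrossEntropyLoss` already accepts a `weight` tensor and matches the custom class above up to the averaging convention (the built-in loss normalizes by the sum of sample weights, the custom one by the batch size). A short comparison sketch with made-up weights for a 5-class problem:

```python
import torch
import torch.nn as nn

# Hypothetical per-class weights for a 5-class problem
weight_tensor = torch.tensor([1.0, 2.0, 4.0, 8.0, 16.0])

builtin_criterion = nn.CrossEntropyLoss(weight=weight_tensor)
custom_criterion = WeightedCrossEntropy(weight_tensor)

logits = torch.randn(8, 5)
targets = torch.randint(0, 5, (8,))

# Values differ only in how the weighted losses are averaged
print(builtin_criterion(logits, targets), custom_criterion(logits, targets))
```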
### 2. Dual-Branch Network Architecture
Design an auxiliary branch dedicated to the minority classes:
```python
import torch.nn as nn

class DualBranchCNN(nn.Module):
    """Assumes a VGG-style backbone exposing a `.features` Sequential whose
    truncated output has 256 channels; adjust channel sizes to your backbone."""
    def __init__(self, base_model, num_classes):
        super().__init__()
        self.shared_features = base_model.features[:-2]   # shared feature extractor
        # Majority-class branch
        self.majority_branch = nn.Sequential(
            base_model.features[-2:],
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, num_classes)
        )
        # Dedicated minority-class branch (extra convolutional stage)
        self.minority_branch = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        features = self.shared_features(x)
        majority_out = self.majority_branch(features)
        minority_out = self.minority_branch(features)
        # Weighted fusion; fixed here, but could be made learnable (see below)
        alpha = 0.7
        return alpha * majority_out + (1 - alpha) * minority_out
```
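The fusion weight above is fixed. If a learnable weight is preferred, one small variation (a sketch, not part of the original design) stores it as an `nn.Parameter` in logit space and squashes it into (0, 1) with a sigmoid:

```python
import torch
import torch.nn as nn

class LearnableFusion(nn.Module):
    """Fuses two logit tensors with a single learnable weight in (0, 1)."""
    def __init__(self, init_alpha=0.7):
        super().__init__()
        # Stored in logit space so the sigmoid keeps alpha inside (0, 1)
        self.raw_alpha = nn.Parameter(torch.logit(torch.tensor(init_alpha)))

    def forward(self, majority_out, minority_out):
        alpha = torch.sigmoid(self.raw_alpha)
        return alpha * majority_out + (1 - alpha) * minority_out
```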
## 4. Training-Process Optimization
### 1. Dynamic Sampling Strategy
Implement a sampling method based on hard-example mining; the fragment below outlines the idea, and a complete `BatchSampler` sketch follows it.
```python
import torch

def dynamic_sampling(dataset, hard_indices, batch_size, hard_ratio=0.3):
    """Fragment: mix a fixed ratio of hard examples into every batch."""
    num_hard = int(batch_size * hard_ratio)
    num_easy = batch_size - num_hard
    # Sampler over the pre-computed hard-example indices (num_hard drawn per batch)
    hard_sampler = torch.utils.data.SubsetRandomSampler(hard_indices)
    # Sampler over the whole dataset for the remaining num_easy slots
    easy_sampler = torch.utils.data.RandomSampler(dataset, num_samples=num_easy)
    # Combining the two per batch requires a custom BatchSampler (see the sketch below)
    # ...
```
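One way to complete the idea (a sketch, assuming the hard-example indices are refreshed between epochs, e.g. from per-sample losses) is a custom batch sampler that fills every batch with a fixed quota of hard examples; `DataLoader` accepts any iterable of index lists via its `batch_sampler` argument:

```python
import random
from torch.utils.data import DataLoader

class HardMiningBatchSampler:
    """Iterable of index batches; each batch holds a fixed quota of hard examples."""
    def __init__(self, dataset_size, hard_indices, batch_size, hard_ratio=0.3):
        hard_set = set(hard_indices)
        self.hard_indices = list(hard_set)
        self.easy_indices = [i for i in range(dataset_size) if i not in hard_set]
        self.num_hard = int(batch_size * hard_ratio)
        self.num_easy = batch_size - self.num_hard

    def __iter__(self):
        easy = self.easy_indices.copy()
        random.shuffle(easy)
        for start in range(0, len(easy) - self.num_easy + 1, self.num_easy):
            batch = easy[start:start + self.num_easy]
            # Hard examples are drawn with replacement so they recur across batches
            batch = batch + random.choices(self.hard_indices, k=self.num_hard)
            random.shuffle(batch)
            yield batch

    def __len__(self):
        return len(self.easy_indices) // self.num_easy

# Usage (hypothetical `train_set` and `hard_indices`):
# sampler = HardMiningBatchSampler(len(train_set), hard_indices, batch_size=64)
# train_loader = DataLoader(train_set, batch_sampler=sampler)
```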
### 2. Learning-Rate Adjustment Strategy
Give different parts of the model their own learning rates, for example a higher rate for the minority-specific branch:
```python
import torch

def create_optimizer(model, base_lr=0.001):
    minority_params, other_params = [], []
    for name, param in model.named_parameters():
        # Route each parameter by the sub-module it belongs to
        if 'minority_branch' in name:
            minority_params.append(param)
        else:
            other_params.append(param)
    param_groups = [
        {'params': other_params, 'lr': base_lr},
        # The minority branch gets a higher learning rate
        {'params': minority_params, 'lr': base_lr * 2},
    ]
    return torch.optim.Adam(param_groups)
```
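A quick sanity check (using a toy stand-in module with a sub-module literally named `minority_branch`; the names here are illustrative) confirms that the two groups carry different learning rates and that any scheduler attached afterwards scales both groups together:

```python
import torch
import torch.nn as nn

# Toy stand-in with a sub-module named `minority_branch`
toy = nn.ModuleDict({
    'backbone': nn.Linear(8, 8),
    'minority_branch': nn.Linear(8, 2),
})
optimizer = create_optimizer(toy, base_lr=0.001)
print([group['lr'] for group in optimizer.param_groups])   # [0.001, 0.002]

# A scheduler attached to this optimizer scales every group proportionally
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
```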
## 5. Evaluation Metrics and Visualization
### 1. A Multi-Dimensional Evaluation Setup
```python
import torch
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

def evaluate_model(model, test_loader, class_names):
    model.eval()
    device = next(model.parameters()).device   # run inference on the model's device
    all_preds, all_targets = [], []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images.to(device))
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())
    # Per-class precision/recall/F1 report
    print(classification_report(all_targets, all_preds, target_names=class_names))
    # Confusion-matrix heatmap
    cm = confusion_matrix(all_targets, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.show()
```
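On imbalanced test sets, overall accuracy can hide poor minority-class performance. Two scalar metrics worth tracking alongside the report above are per-class recall and balanced accuracy; a small sketch using scikit-learn:

```python
from sklearn.metrics import balanced_accuracy_score, recall_score

def imbalance_metrics(all_targets, all_preds):
    # Recall computed independently for every class (average=None disables averaging)
    per_class_recall = recall_score(all_targets, all_preds, average=None)
    # Balanced accuracy is the mean of the per-class recalls
    bal_acc = balanced_accuracy_score(all_targets, all_preds)
    return per_class_recall, bal_acc

# Example with dummy predictions:
# per_class_recall, bal_acc = imbalance_metrics([0, 0, 1, 2], [0, 1, 1, 2])
```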
### 2. Per-Class Performance Tracking
A training logger that monitors metrics separately for each class:
```python
import matplotlib.pyplot as plt

class ClassWiseLogger:
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.class_metrics = {
            'accuracy': [[] for _ in range(num_classes)],
            'loss': [[] for _ in range(num_classes)]
        }

    def update(self, epoch, class_idx, accuracy, loss):
        self.class_metrics['accuracy'][class_idx].append((epoch, accuracy))
        self.class_metrics['loss'][class_idx].append((epoch, loss))

    def plot_metrics(self):
        for cls in range(self.num_classes):
            # Accuracy curve for each class
            epochs, accs = zip(*self.class_metrics['accuracy'][cls])
            plt.plot(epochs, accs, label=f'Class {cls}')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.show()
```
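A sketch of how the logger might be fed at the end of each validation epoch, assuming flat prediction, target, and per-sample loss arrays like those collected in `evaluate_model` (the helper name and inputs are illustrative):

```python
import numpy as np

def log_epoch(logger, epoch, all_targets, all_preds, all_losses):
    """Compute per-class accuracy and mean loss from flat arrays and log them."""
    targets = np.asarray(all_targets)
    preds = np.asarray(all_preds)
    losses = np.asarray(all_losses)   # one loss value per sample
    for cls in range(logger.num_classes):
        mask = targets == cls
        if mask.any():
            acc = float((preds[mask] == cls).mean())
            logger.update(epoch, cls, acc, float(losses[mask].mean()))

# logger = ClassWiseLogger(num_classes=100)
# log_epoch(logger, epoch=0, all_targets=targets, all_preds=preds, all_losses=losses)
```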
## 6. A Complete Case Study: Imbalanced CIFAR-100 Classification
### 1. Data Preparation
```python
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms

# Build an imbalanced dataset (example: per-class counts decay exponentially)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
full_dataset = CIFAR100(root='./data', train=True, download=True, transform=transform)

# CIFAR-100 has 500 training images per class; let the per-class count decay
# exponentially so that the largest/smallest ratio (imbalance factor) is 100
class_counts = [int(500 * (0.01 ** (i / 99))) for i in range(100)]

imbalanced_data = []
imbalanced_targets = []
for cls, count in enumerate(class_counts):
    cls_indices = [i for i, label in enumerate(full_dataset.targets)
                   if label == cls][:count]
    imbalanced_data.extend([full_dataset.data[i] for i in cls_indices])
    imbalanced_targets.extend([cls] * len(cls_indices))

# `full_dataset.data` is a uint8 NumPy array of shape (N, 32, 32, 3); wrapping the
# selected images in a PyTorch Dataset still requires converting them to normalized
# tensors (see the conversion sketch below)
```
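The conversion step the snippet leaves out can be sketched as follows (one possible approach: a tiny `Dataset` that applies the same normalization as the `transform` above to the selected NumPy images):

```python
import numpy as np
import torch
from torch.utils.data import Dataset

class ArrayDataset(Dataset):
    """Wraps (N, 32, 32, 3) uint8 images and integer labels as normalized tensors."""
    def __init__(self, images, labels):
        self.images = np.stack(images)   # (N, 32, 32, 3), uint8
        self.labels = list(labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img = torch.from_numpy(self.images[idx]).float() / 255.0   # scale to [0, 1]
        img = img.permute(2, 0, 1)                                  # HWC -> CHW
        img = (img - 0.5) / 0.5           # match Normalize((0.5,)*3, (0.5,)*3)
        return img, self.labels[idx]

# imbalanced_dataset = ArrayDataset(imbalanced_data, imbalanced_targets)
```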
### 2. Training Procedure
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models

def train_model():
    # Device configuration
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Model initialization (newer torchvision versions prefer the `weights=` argument)
    model = models.resnet18(pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 100)
    model = model.to(device)
    # Loss function with class weights, ordered by class index
    class_weights = calculate_class_weights(imbalanced_targets)
    weight_list = [class_weights[c] for c in range(100)]
    criterion = WeightedCrossEntropy(weight_list).to(device)
    # Optimizer and scheduler
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
    # Data loading (dataset with the class-aware augmentation from Section 2)
    dataset = ImbalancedDataset(...)
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
    # Training loop
    for epoch in range(100):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        # Validation phase (omitted here; see the sketch below)
        # scheduler.step(val_loss)
```
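The omitted validation step could look roughly like this (a sketch assuming a `val_loader` built the same way as `train_loader`); its loss is what drives the `ReduceLROnPlateau` scheduler:

```python
import torch

def validate(model, val_loader, criterion, device):
    """Average validation loss over the loader."""
    model.eval()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            total_loss += criterion(model(inputs), labels).item()
            num_batches += 1
    return total_loss / max(num_batches, 1)

# Inside the epoch loop:
# val_loss = validate(model, val_loader, criterion, device)
# scheduler.step(val_loss)
```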
## 7. Best-Practice Recommendations
- Progressive solutions: try data augmentation and resampling first; adjust the model architecture only if those prove insufficient
- Class grouping: merge similar classes to ease extreme imbalance
- Continuous monitoring: maintain a per-class training dashboard so performance anomalies are caught early
- Post-hoc calibration: adjust predicted probabilities with temperature scaling (a sketch follows this list)
- Ensembling: combine the predictions of several models to improve minority-class recognition
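For reference, a minimal temperature-scaling sketch (one common formulation: a single scalar T fitted on held-out validation logits by minimizing cross-entropy; the helper names are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TemperatureScaler(nn.Module):
    """Divides logits by a learned temperature T > 0 to calibrate probabilities."""
    def __init__(self):
        super().__init__()
        self.log_t = nn.Parameter(torch.zeros(1))   # T = exp(log_t) starts at 1.0

    def forward(self, logits):
        return logits / torch.exp(self.log_t)

def fit_temperature(scaler, val_logits, val_labels, steps=200, lr=0.01):
    """Fit T on held-out logits/labels by minimizing cross-entropy."""
    optimizer = torch.optim.Adam(scaler.parameters(), lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()
        loss = F.cross_entropy(scaler(val_logits), val_labels)
        loss.backward()
        optimizer.step()
    return scaler

# calibrated_probs = F.softmax(scaler(test_logits), dim=1)
```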
## 8. Summary and Outlook
PyTorch provides a flexible and powerful toolchain for classification on imbalanced data. Combining data augmentation, improved loss functions, and architectural adjustments can substantially improve recognition of minority classes. Promising directions for future work include:
- Further refinement of adaptive sampling algorithms
- Meta-learning approaches to learning minority classes
- Transfer learning for imbalance problems across datasets
Developers should choose a suitable combination of methods for their specific scenario and confirm the final choice with thorough experiments.