logo

使用PyTorch从零实现图像分类:完整代码与深度解析

作者:半吊子全栈工匠2025.09.18 17:43浏览量:2

简介:本文提供基于PyTorch的图像分类完整实现,包含数据加载、模型构建、训练循环和评估全流程代码,每行均附详细注释说明,适合开发者快速上手并深入理解深度学习图像分类技术。

使用PyTorch实现图像分类:完整代码与详细解析

图像分类是计算机视觉领域的核心任务之一,PyTorch作为主流深度学习框架,提供了灵活高效的工具来实现这一目标。本文将详细介绍如何使用PyTorch从零实现一个完整的图像分类系统,包含数据准备、模型构建、训练过程和结果评估,所有代码均附详细注释。

一、环境准备与依赖安装

首先需要安装必要的Python库:

  1. # 推荐使用conda创建虚拟环境
  2. # conda create -n pytorch_img_cls python=3.8
  3. # conda activate pytorch_img_cls
  4. # 安装基础依赖
  5. !pip install torch torchvision matplotlib numpy tqdm

关键依赖说明:

  • torch:PyTorch核心库,提供张量计算和自动微分功能
  • torchvision:包含计算机视觉常用数据集、模型架构和图像转换工具
  • matplotlib:用于可视化训练过程和样本图像
  • tqdm:提供进度条显示,提升训练体验

二、数据准备与预处理

1. 使用内置数据集

PyTorch提供了多个标准数据集,这里以CIFAR-10为例:

  1. import torchvision
  2. import torchvision.transforms as transforms
  3. # 定义数据转换流程
  4. transform = transforms.Compose([
  5. transforms.ToTensor(), # 将PIL图像或numpy数组转为Tensor,并缩放到[0,1]
  6. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化到[-1,1]
  7. ])
  8. # 加载训练集和测试集
  9. trainset = torchvision.datasets.CIFAR10(
  10. root='./data',
  11. train=True,
  12. download=True,
  13. transform=transform
  14. )
  15. testset = torchvision.datasets.CIFAR10(
  16. root='./data',
  17. train=False,
  18. download=True,
  19. transform=transform
  20. )

2. 自定义数据加载器

使用DataLoader实现批量加载和随机打乱:

  1. from torch.utils.data import DataLoader
  2. batch_size = 32
  3. trainloader = DataLoader(
  4. trainset,
  5. batch_size=batch_size,
  6. shuffle=True, # 每个epoch打乱数据顺序
  7. num_workers=2 # 使用2个子进程加载数据
  8. )
  9. testloader = DataLoader(
  10. testset,
  11. batch_size=batch_size,
  12. shuffle=False,
  13. num_workers=2
  14. )
  15. # CIFAR-10类别名称
  16. classes = ('plane', 'car', 'bird', 'cat', 'deer',
  17. 'dog', 'frog', 'horse', 'ship', 'truck')

3. 数据可视化验证

  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. def imshow(img):
  4. # 反标准化并转换显示格式
  5. img = img / 2 + 0.5 # 从[-1,1]恢复到[0,1]
  6. npimg = img.numpy()
  7. plt.imshow(np.transpose(npimg, (1, 2, 0)))
  8. plt.show()
  9. # 获取一个批次的图像
  10. dataiter = iter(trainloader)
  11. images, labels = next(dataiter)
  12. # 显示图像和标签
  13. imshow(torchvision.utils.make_grid(images))
  14. print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

三、模型架构设计

1. 基础CNN实现

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. class Net(nn.Module):
  4. def __init__(self):
  5. super(Net, self).__init__()
  6. # 卷积层1:输入3通道,输出6通道,5x5卷积核
  7. self.conv1 = nn.Conv2d(3, 6, 5)
  8. # 池化层:2x2最大池化
  9. self.pool = nn.MaxPool2d(2, 2)
  10. # 卷积层2:输入6通道,输出16通道,5x5卷积核
  11. self.conv2 = nn.Conv2d(6, 16, 5)
  12. # 全连接层1:输入16*5*5(经过两次池化后尺寸),输出120
  13. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  14. # 全连接层2:输入120,输出84
  15. self.fc2 = nn.Linear(120, 84)
  16. # 输出层:输入84,输出10(CIFAR-10类别数)
  17. self.fc3 = nn.Linear(84, 10)
  18. def forward(self, x):
  19. # 卷积1 -> 激活 -> 池化
  20. x = self.pool(F.relu(self.conv1(x)))
  21. # 卷积2 -> 激活 -> 池化
  22. x = self.pool(F.relu(self.conv2(x)))
  23. # 展平特征图
  24. x = x.view(-1, 16 * 5 * 5)
  25. # 全连接层
  26. x = F.relu(self.fc1(x))
  27. x = F.relu(self.fc2(x))
  28. x = self.fc3(x)
  29. return x
  30. net = Net()
  31. print(net) # 打印模型结构

2. 使用预训练模型(可选)

  1. # 加载预训练的ResNet18
  2. model = torchvision.models.resnet18(pretrained=True)
  3. # 修改最后一层全连接层
  4. num_ftrs = model.fc.in_features
  5. model.fc = nn.Linear(num_ftrs, 10) # 适配CIFAR-10的10个类别

四、训练过程实现

1. 定义损失函数和优化器

  1. import torch.optim as optim
  2. # 交叉熵损失函数
  3. criterion = nn.CrossEntropyLoss()
  4. # 随机梯度下降优化器,学习率0.001,动量0.9
  5. optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

2. 设备配置

  1. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  2. print(f"Using device: {device}")
  3. net.to(device) # 将模型移动到GPU

3. 完整训练循环

  1. from tqdm import tqdm
  2. def train_model(num_epochs=10):
  3. for epoch in range(num_epochs):
  4. running_loss = 0.0
  5. correct = 0
  6. total = 0
  7. # 训练模式
  8. net.train()
  9. pbar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}')
  10. for i, data in enumerate(pbar, 0):
  11. inputs, labels = data
  12. inputs, labels = inputs.to(device), labels.to(device)
  13. # 梯度清零
  14. optimizer.zero_grad()
  15. # 前向传播
  16. outputs = net(inputs)
  17. # 计算损失
  18. loss = criterion(outputs, labels)
  19. # 反向传播和优化
  20. loss.backward()
  21. optimizer.step()
  22. # 统计信息
  23. running_loss += loss.item()
  24. _, predicted = torch.max(outputs.data, 1)
  25. total += labels.size(0)
  26. correct += (predicted == labels).sum().item()
  27. # 每200个batch打印一次统计
  28. if i % 200 == 199:
  29. avg_loss = running_loss / 200
  30. acc = 100 * correct / total
  31. pbar.set_postfix(loss=f'{avg_loss:.3f}', acc=f'{acc:.2f}%')
  32. running_loss = 0.0
  33. print('Finished Training')
  34. train_model(num_epochs=10)

五、模型评估与可视化

1. 测试集评估

  1. def evaluate_model():
  2. correct = 0
  3. total = 0
  4. class_correct = list(0. for i in range(10))
  5. class_total = list(0. for i in range(10))
  6. # 评估模式(关闭dropout等)
  7. net.eval()
  8. with torch.no_grad():
  9. for data in testloader:
  10. images, labels = data
  11. images, labels = images.to(device), labels.to(device)
  12. outputs = net(images)
  13. _, predicted = torch.max(outputs.data, 1)
  14. total += labels.size(0)
  15. correct += (predicted == labels).sum().item()
  16. # 按类别统计
  17. c = (predicted == labels).squeeze()
  18. for i in range(batch_size):
  19. label = labels[i]
  20. class_correct[label] += c[i].item()
  21. class_total[label] += 1
  22. # 打印总体准确率
  23. print(f'Accuracy on test set: {100 * correct / total:.2f}%')
  24. # 打印各类别准确率
  25. for i in range(10):
  26. print(f'Accuracy of {classes[i]}: {100 * class_correct[i] / class_total[i]:.2f}%')
  27. evaluate_model()

2. 错误样本可视化

  1. def visualize_errors():
  2. net.eval()
  3. dataiter = iter(testloader)
  4. images, labels = next(dataiter)
  5. images, labels = images.to(device), labels.to(device)
  6. with torch.no_grad():
  7. outputs = net(images)
  8. _, predicted = torch.max(outputs, 1)
  9. # 显示错误预测的样本
  10. images = images.cpu() # 移回CPU用于显示
  11. imshow(torchvision.utils.make_grid(images[predicted != labels]))
  12. # 打印错误信息
  13. print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
  14. print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(batch_size)))
  15. visualize_errors()

六、模型保存与加载

  1. # 保存模型参数
  2. PATH = './cifar_net.pth'
  3. torch.save(net.state_dict(), PATH)
  4. # 加载模型
  5. def load_model():
  6. loaded_net = Net()
  7. loaded_net.load_state_dict(torch.load(PATH))
  8. loaded_net.to(device)
  9. return loaded_net
  10. # 测试加载的模型
  11. loaded_net = load_model()
  12. evaluate_model(loaded_net) # 需要修改evaluate_model函数接收net参数

七、进阶优化建议

  1. 数据增强:在transform中添加随机裁剪、水平翻转等操作提升模型泛化能力

    1. transform_train = transforms.Compose([
    2. transforms.RandomHorizontalFlip(),
    3. transforms.RandomCrop(32, padding=4),
    4. transforms.ToTensor(),
    5. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    6. ])
  2. 学习率调度:使用StepLR或ReduceLROnPlateau动态调整学习率

    1. scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    2. # 在每个epoch后调用scheduler.step()
  3. 模型改进:尝试更深的网络结构如ResNet,或使用注意力机制

  4. 分布式训练:对于大规模数据集,可使用torch.nn.DataParallel实现多GPU训练

八、完整代码整合

将上述所有部分整合为一个完整的可运行脚本,包含必要的注释和错误处理。建议将代码组织为:

  • data.py:数据加载和预处理
  • model.py:模型定义
  • train.py:训练过程
  • utils.py:辅助函数

九、总结与展望

本文详细介绍了使用PyTorch实现图像分类的全流程,从数据准备到模型评估。关键点包括:

  1. 合理的数据预处理和增强
  2. 适当的模型架构选择
  3. 有效的训练策略和超参数调整
  4. 全面的模型评估方法

未来工作可以探索:

  • 更先进的网络架构(如EfficientNet、Vision Transformer)
  • 自监督学习预训练方法
  • 面向特定领域的微调技术

通过理解这个完整实现,读者可以建立起对PyTorch图像分类的深入认识,并能够基于此扩展到更复杂的计算机视觉任务。

相关文章推荐

发表评论