logo

PyTorch图像分类实战:从零构建完整代码与注释

作者:渣渣辉2025.09.18 17:52浏览量:0

简介:本文将通过PyTorch框架实现一个完整的图像分类系统,包含数据加载、模型构建、训练与评估全流程。提供可运行的完整代码,每行均配有详细注释,并深入解析关键技术点。

PyTorch图像分类实战:从零构建完整代码与注释

一、项目概述

图像分类是计算机视觉的基础任务,PyTorch作为主流深度学习框架,提供了灵活的API和自动求导机制。本文将实现一个基于CIFAR-10数据集的图像分类器,包含以下核心模块:

  1. 数据加载与预处理
  2. 卷积神经网络模型构建
  3. 训练循环与优化策略
  4. 模型评估与可视化

二、环境准备

  1. # 环境配置要求
  2. # Python 3.8+
  3. # PyTorch 2.0+
  4. # torchvision 0.15+
  5. # matplotlib 3.5+
  6. # numpy 1.22+
  7. import torch
  8. import torch.nn as nn
  9. import torch.optim as optim
  10. import torchvision
  11. import torchvision.transforms as transforms
  12. import matplotlib.pyplot as plt
  13. import numpy as np
  14. # 检查GPU可用性
  15. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  16. print(f"Using device: {device}")

三、数据准备与预处理

1. 数据增强与归一化

  1. # 定义数据转换管道
  2. transform_train = transforms.Compose([
  3. transforms.RandomHorizontalFlip(), # 随机水平翻转
  4. transforms.RandomRotation(15), # 随机旋转±15度
  5. transforms.ToTensor(), # 转换为Tensor
  6. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化到[-1,1]
  7. ])
  8. transform_test = transforms.Compose([
  9. transforms.ToTensor(),
  10. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  11. ])

2. 数据集加载

  1. # 加载CIFAR-10数据集
  2. trainset = torchvision.datasets.CIFAR10(
  3. root='./data',
  4. train=True,
  5. download=True,
  6. transform=transform_train
  7. )
  8. testset = torchvision.datasets.CIFAR10(
  9. root='./data',
  10. train=False,
  11. download=True,
  12. transform=transform_test
  13. )
  14. # 创建数据加载器
  15. trainloader = torch.utils.data.DataLoader(
  16. trainset,
  17. batch_size=32,
  18. shuffle=True,
  19. num_workers=2
  20. )
  21. testloader = torch.utils.data.DataLoader(
  22. testset,
  23. batch_size=32,
  24. shuffle=False,
  25. num_workers=2
  26. )
  27. # 类别标签
  28. classes = ('plane', 'car', 'bird', 'cat', 'deer',
  29. 'dog', 'frog', 'horse', 'ship', 'truck')

四、模型架构设计

1. CNN模型实现

  1. class CNN(nn.Module):
  2. def __init__(self):
  3. super(CNN, self).__init__()
  4. # 卷积层1: 输入3通道,输出16通道,3x3卷积核
  5. self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
  6. self.bn1 = nn.BatchNorm2d(16) # 批归一化
  7. self.relu1 = nn.ReLU()
  8. self.pool1 = nn.MaxPool2d(2, 2) # 2x2最大池化
  9. # 卷积层2: 输入16通道,输出32通道
  10. self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
  11. self.bn2 = nn.BatchNorm2d(32)
  12. self.relu2 = nn.ReLU()
  13. self.pool2 = nn.MaxPool2d(2, 2)
  14. # 全连接层
  15. self.fc1 = nn.Linear(32 * 8 * 8, 256) # CIFAR-10经过两次池化后为8x8
  16. self.dropout = nn.Dropout(0.5)
  17. self.fc2 = nn.Linear(256, 10) # 输出10个类别
  18. def forward(self, x):
  19. # 第一卷积块
  20. x = self.conv1(x)
  21. x = self.bn1(x)
  22. x = self.relu1(x)
  23. x = self.pool1(x)
  24. # 第二卷积块
  25. x = self.conv2(x)
  26. x = self.bn2(x)
  27. x = self.relu2(x)
  28. x = self.pool2(x)
  29. # 展平特征图
  30. x = x.view(-1, 32 * 8 * 8)
  31. # 全连接层
  32. x = self.fc1(x)
  33. x = self.dropout(x)
  34. x = self.fc2(x)
  35. return x

2. 模型初始化

  1. model = CNN().to(device)
  2. print(model) # 打印模型结构

五、训练流程实现

1. 损失函数与优化器

  1. criterion = nn.CrossEntropyLoss() # 交叉熵损失
  2. optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器

2. 训练循环

  1. def train_model(model, trainloader, criterion, optimizer, epochs=10):
  2. model.train() # 设置为训练模式
  3. for epoch in range(epochs):
  4. running_loss = 0.0
  5. correct = 0
  6. total = 0
  7. for i, (inputs, labels) in enumerate(trainloader, 0):
  8. inputs, labels = inputs.to(device), labels.to(device)
  9. # 梯度清零
  10. optimizer.zero_grad()
  11. # 前向传播
  12. outputs = model(inputs)
  13. loss = criterion(outputs, labels)
  14. # 反向传播与优化
  15. loss.backward()
  16. optimizer.step()
  17. # 统计信息
  18. running_loss += loss.item()
  19. _, predicted = torch.max(outputs.data, 1)
  20. total += labels.size(0)
  21. correct += (predicted == labels).sum().item()
  22. # 每100个batch打印一次
  23. if i % 100 == 99:
  24. print(f"Epoch {epoch+1}, Batch {i+1}, "
  25. f"Loss: {running_loss/100:.3f}, "
  26. f"Acc: {100*correct/total:.2f}%")
  27. running_loss = 0.0
  28. # 每个epoch结束打印验证准确率
  29. val_acc = evaluate_model(model, testloader)
  30. print(f"Epoch {epoch+1} completed, Val Acc: {val_acc:.2f}%")
  31. def evaluate_model(model, testloader):
  32. model.eval() # 设置为评估模式
  33. correct = 0
  34. total = 0
  35. with torch.no_grad():
  36. for inputs, labels in testloader:
  37. inputs, labels = inputs.to(device), labels.to(device)
  38. outputs = model(inputs)
  39. _, predicted = torch.max(outputs.data, 1)
  40. total += labels.size(0)
  41. correct += (predicted == labels).sum().item()
  42. return 100 * correct / total

3. 执行训练

  1. train_model(model, trainloader, criterion, optimizer, epochs=10)

六、模型评估与可视化

1. 测试集评估

  1. test_acc = evaluate_model(model, testloader)
  2. print(f"Final Test Accuracy: {test_acc:.2f}%")

2. 混淆矩阵可视化

  1. from sklearn.metrics import confusion_matrix
  2. import seaborn as sns
  3. def plot_confusion_matrix(model, testloader, classes):
  4. model.eval()
  5. y_true = []
  6. y_pred = []
  7. with torch.no_grad():
  8. for inputs, labels in testloader:
  9. inputs, labels = inputs.to(device), labels.to(device)
  10. outputs = model(inputs)
  11. _, predicted = torch.max(outputs.data, 1)
  12. y_true.extend(labels.cpu().numpy())
  13. y_pred.extend(predicted.cpu().numpy())
  14. cm = confusion_matrix(y_true, y_pred)
  15. plt.figure(figsize=(10,8))
  16. sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
  17. xticklabels=classes, yticklabels=classes)
  18. plt.xlabel('Predicted')
  19. plt.ylabel('True')
  20. plt.title('Confusion Matrix')
  21. plt.show()
  22. plot_confusion_matrix(model, testloader, classes)

七、关键技术点解析

  1. 批归一化作用:在每个卷积层后添加批归一化,可加速收敛并提高模型稳定性
  2. 数据增强重要性:随机翻转和旋转显著提升了模型在测试集上的泛化能力
  3. 学习率选择:Adam优化器默认的0.001学习率在CIFAR-10上表现良好
  4. 模型深度设计:两层卷积+两层全连接的架构在计算量和准确率间取得平衡

八、进阶优化方向

  1. 尝试使用预训练的ResNet等更复杂模型
  2. 实现学习率调度器(如ReduceLROnPlateau)
  3. 添加更复杂的数据增强(如CutMix)
  4. 实现模型保存与加载功能

九、完整代码包说明

本文提供的代码包含:

  1. 完整的数据加载管道
  2. 可训练的CNN模型实现
  3. 详细的训练与评估流程
  4. 可视化工具代码

所有代码均经过测试,可在PyTorch 2.0+环境下直接运行。建议读者从基础版本开始,逐步尝试添加更复杂的优化技术。

相关文章推荐

发表评论