logo

使用PyTorch实现图像分类:完整代码与深度解析

作者:热心市民鹿先生2025.09.18 17:43浏览量:0

简介:本文详细介绍如何使用PyTorch框架实现图像分类任务,包含从数据加载到模型训练的全流程代码,并附有详细注释。适合PyTorch初学者和有一定基础的开发者参考。

使用PyTorch实现图像分类:完整代码与深度解析

引言

图像分类是计算机视觉领域的核心任务之一,广泛应用于人脸识别、医学影像分析、自动驾驶等场景。PyTorch作为主流深度学习框架,以其动态计算图和简洁的API设计,成为实现图像分类任务的理想选择。本文将通过一个完整的CIFAR-10分类案例,详细展示如何使用PyTorch实现图像分类,包含所有关键代码和详细注释。

环境准备

在开始之前,请确保已安装以下环境:

  • Python 3.8+
  • PyTorch 1.12+
  • torchvision 0.13+
  • NumPy 1.21+
  • Matplotlib 3.5+

推荐使用conda创建虚拟环境:

  1. conda create -n pytorch_cls python=3.8
  2. conda activate pytorch_cls
  3. pip install torch torchvision numpy matplotlib

1. 数据准备与预处理

1.1 数据集加载

我们使用torchvision内置的CIFAR-10数据集,该数据集包含10个类别的60000张32x32彩色图像。

  1. import torch
  2. from torchvision import datasets, transforms
  3. from torch.utils.data import DataLoader
  4. # 定义数据预处理流程
  5. transform = transforms.Compose([
  6. transforms.ToTensor(), # 将PIL图像或numpy数组转为Tensor,并缩放到[0,1]
  7. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化到[-1,1]
  8. ])
  9. # 加载训练集和测试集
  10. train_dataset = datasets.CIFAR10(
  11. root='./data',
  12. train=True,
  13. download=True,
  14. transform=transform
  15. )
  16. test_dataset = datasets.CIFAR10(
  17. root='./data',
  18. train=False,
  19. download=True,
  20. transform=transform
  21. )
  22. # 创建数据加载器
  23. batch_size = 64
  24. train_loader = DataLoader(
  25. train_dataset,
  26. batch_size=batch_size,
  27. shuffle=True, # 每个epoch打乱数据
  28. num_workers=2 # 使用2个子进程加载数据
  29. )
  30. test_loader = DataLoader(
  31. test_dataset,
  32. batch_size=batch_size,
  33. shuffle=False,
  34. num_workers=2
  35. )

关键点说明

  • transforms.Compose:将多个预处理操作组合在一起
  • ToTensor():自动将HWC格式的图像转为CHW格式的Tensor
  • Normalize:使用(mean, std)参数进行标准化,这里使用CIFAR-10的均值和标准差
  • DataLoader:提供批量加载、多线程加载和打乱数据等功能

1.2 数据可视化

为了验证数据加载是否正确,我们可以可视化部分训练样本:

  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. def imshow(img):
  4. # 反归一化
  5. img = img / 2 + 0.5 # 从[-1,1]恢复到[0,1]
  6. npimg = img.numpy()
  7. plt.imshow(np.transpose(npimg, (1, 2, 0)))
  8. plt.show()
  9. # 获取一个批次的图像
  10. dataiter = iter(train_loader)
  11. images, labels = next(dataiter)
  12. # 显示图像
  13. imshow(torchvision.utils.make_grid(images))
  14. # 打印标签
  15. print(' '.join(f'{datasets.CIFAR10.classes[labels[j]]}' for j in range(batch_size)))

2. 模型定义

2.1 基础CNN模型

我们定义一个简单的CNN模型,包含3个卷积层和2个全连接层:

  1. import torch.nn as nn
  2. import torch.nn.functional as F
  3. class SimpleCNN(nn.Module):
  4. def __init__(self):
  5. super(SimpleCNN, self).__init__()
  6. # 卷积层1:输入通道3,输出通道32,3x3卷积核
  7. self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
  8. # 卷积层2:输入通道32,输出通道64,3x3卷积核
  9. self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
  10. # 卷积层3:输入通道64,输出通道128,3x3卷积核
  11. self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
  12. # 计算展平后的维度
  13. # CIFAR-10图像原始大小32x32,经过3次池化(每次尺寸减半)后为4x4
  14. # 128个通道,所以展平后维度为128*4*4=2048
  15. self.fc1 = nn.Linear(128 * 4 * 4, 512)
  16. self.fc2 = nn.Linear(512, 10) # 10个类别
  17. # 最大池化层
  18. self.pool = nn.MaxPool2d(2, 2)
  19. def forward(self, x):
  20. # 卷积1 -> ReLU -> 池化
  21. x = self.pool(F.relu(self.conv1(x)))
  22. # 卷积2 -> ReLU -> 池化
  23. x = self.pool(F.relu(self.conv2(x)))
  24. # 卷积3 -> ReLU -> 池化
  25. x = self.pool(F.relu(self.conv3(x)))
  26. # 展平特征图
  27. x = x.view(-1, 128 * 4 * 4)
  28. # 全连接层
  29. x = F.relu(self.fc1(x))
  30. x = self.fc2(x)
  31. return x

模型结构说明

  1. 3个卷积层,每层后接ReLU激活和2x2最大池化
  2. 卷积核大小均为3x3,使用padding=1保持空间尺寸
  3. 两个全连接层,第一个有512个神经元,第二个输出10个类别
  4. 输入图像32x32,经过3次池化后变为4x4

2.2 更先进的模型(可选)

对于更高性能的需求,可以使用ResNet等现代架构:

  1. from torchvision import models
  2. def get_resnet18(pretrained=False):
  3. model = models.resnet18(pretrained=pretrained)
  4. # 修改第一个卷积层以适应3通道输入(ResNet默认用于ImageNet的3通道)
  5. # 实际上resnet18已经支持3通道,这里只是示例
  6. # 修改最后的全连接层以适应10个类别
  7. num_ftrs = model.fc.in_features
  8. model.fc = nn.Linear(num_ftrs, 10)
  9. return model

3. 训练过程

3.1 初始化模型和优化器

  1. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  2. print(f"Using device: {device}")
  3. model = SimpleCNN().to(device)
  4. criterion = nn.CrossEntropyLoss() # 交叉熵损失
  5. optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Adam优化器

3.2 训练函数

  1. def train_model(model, criterion, optimizer, num_epochs=10):
  2. for epoch in range(num_epochs):
  3. model.train() # 设置为训练模式
  4. running_loss = 0.0
  5. correct = 0
  6. total = 0
  7. for i, (inputs, labels) in enumerate(train_loader):
  8. inputs, labels = inputs.to(device), labels.to(device)
  9. # 梯度清零
  10. optimizer.zero_grad()
  11. # 前向传播
  12. outputs = model(inputs)
  13. loss = criterion(outputs, labels)
  14. # 反向传播和优化
  15. loss.backward()
  16. optimizer.step()
  17. # 统计信息
  18. running_loss += loss.item()
  19. _, predicted = torch.max(outputs.data, 1)
  20. total += labels.size(0)
  21. correct += (predicted == labels).sum().item()
  22. # 每100个batch打印一次统计信息
  23. if i % 100 == 99:
  24. print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], '
  25. f'Loss: {running_loss/100:.3f}, Acc: {100*correct/total:.2f}%')
  26. running_loss = 0.0
  27. # 每个epoch结束后计算并打印测试集准确率
  28. test_acc = test_model(model, test_loader)
  29. print(f'Epoch [{epoch+1}/{num_epochs}] completed, Test Acc: {test_acc:.2f}%')
  30. def test_model(model, test_loader):
  31. model.eval() # 设置为评估模式
  32. correct = 0
  33. total = 0
  34. with torch.no_grad(): # 不计算梯度
  35. for inputs, labels in test_loader:
  36. inputs, labels = inputs.to(device), labels.to(device)
  37. outputs = model(inputs)
  38. _, predicted = torch.max(outputs.data, 1)
  39. total += labels.size(0)
  40. correct += (predicted == labels).sum().item()
  41. return 100 * correct / total

3.3 启动训练

  1. num_epochs = 10
  2. train_model(model, criterion, optimizer, num_epochs)

4. 模型评估与保存

4.1 评估模型

训练完成后,我们可以在测试集上评估模型性能:

  1. test_acc = test_model(model, test_loader)
  2. print(f'Final Test Accuracy: {test_acc:.2f}%')

4.2 保存模型

  1. # 保存整个模型(包括结构和参数)
  2. torch.save(model.state_dict(), 'cifar10_cnn.pth')
  3. # 加载模型的代码示例
  4. def load_model():
  5. model = SimpleCNN().to(device)
  6. model.load_state_dict(torch.load('cifar10_cnn.pth'))
  7. model.eval()
  8. return model

5. 完整代码整合

以下是整合后的完整代码:

  1. # 完整实现PyTorch图像分类
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import torch.optim as optim
  6. from torchvision import datasets, transforms, models
  7. from torch.utils.data import DataLoader
  8. import matplotlib.pyplot as plt
  9. import numpy as np
  10. # 1. 数据准备
  11. transform = transforms.Compose([
  12. transforms.ToTensor(),
  13. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  14. ])
  15. train_dataset = datasets.CIFAR10(
  16. root='./data', train=True, download=True, transform=transform
  17. )
  18. test_dataset = datasets.CIFAR10(
  19. root='./data', train=False, download=True, transform=transform
  20. )
  21. batch_size = 64
  22. train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
  23. test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
  24. # 2. 模型定义
  25. class SimpleCNN(nn.Module):
  26. def __init__(self):
  27. super(SimpleCNN, self).__init__()
  28. self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
  29. self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
  30. self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
  31. self.fc1 = nn.Linear(128 * 4 * 4, 512)
  32. self.fc2 = nn.Linear(512, 10)
  33. self.pool = nn.MaxPool2d(2, 2)
  34. def forward(self, x):
  35. x = self.pool(F.relu(self.conv1(x)))
  36. x = self.pool(F.relu(self.conv2(x)))
  37. x = self.pool(F.relu(self.conv3(x)))
  38. x = x.view(-1, 128 * 4 * 4)
  39. x = F.relu(self.fc1(x))
  40. x = self.fc2(x)
  41. return x
  42. # 3. 训练设置
  43. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  44. model = SimpleCNN().to(device)
  45. criterion = nn.CrossEntropyLoss()
  46. optimizer = optim.Adam(model.parameters(), lr=0.001)
  47. # 4. 训练和测试函数
  48. def train_model(model, criterion, optimizer, num_epochs=10):
  49. for epoch in range(num_epochs):
  50. model.train()
  51. running_loss = 0.0
  52. correct = 0
  53. total = 0
  54. for i, (inputs, labels) in enumerate(train_loader):
  55. inputs, labels = inputs.to(device), labels.to(device)
  56. optimizer.zero_grad()
  57. outputs = model(inputs)
  58. loss = criterion(outputs, labels)
  59. loss.backward()
  60. optimizer.step()
  61. running_loss += loss.item()
  62. _, predicted = torch.max(outputs.data, 1)
  63. total += labels.size(0)
  64. correct += (predicted == labels).sum().item()
  65. if i % 100 == 99:
  66. print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], '
  67. f'Loss: {running_loss/100:.3f}, Acc: {100*correct/total:.2f}%')
  68. running_loss = 0.0
  69. test_acc = test_model(model, test_loader)
  70. print(f'Epoch [{epoch+1}/{num_epochs}] completed, Test Acc: {test_acc:.2f}%')
  71. def test_model(model, test_loader):
  72. model.eval()
  73. correct = 0
  74. total = 0
  75. with torch.no_grad():
  76. for inputs, labels in test_loader:
  77. inputs, labels = inputs.to(device), labels.to(device)
  78. outputs = model(inputs)
  79. _, predicted = torch.max(outputs.data, 1)
  80. total += labels.size(0)
  81. correct += (predicted == labels).sum().item()
  82. return 100 * correct / total
  83. # 5. 启动训练
  84. num_epochs = 10
  85. train_model(model, criterion, optimizer, num_epochs)
  86. # 6. 保存模型
  87. torch.save(model.state_dict(), 'cifar10_cnn.pth')

6. 性能优化建议

  1. 数据增强:在训练时使用随机裁剪、水平翻转等增强技术

    1. transform_train = transforms.Compose([
    2. transforms.RandomHorizontalFlip(),
    3. transforms.RandomCrop(32, padding=4),
    4. transforms.ToTensor(),
    5. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    6. ])
  2. 学习率调度:使用学习率衰减策略

    1. scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    2. # 在每个epoch后调用scheduler.step()
  3. 批归一化:在卷积层后添加批归一化层

    1. self.bn1 = nn.BatchNorm2d(32)
    2. # 在forward中使用
    3. x = self.pool(F.relu(self.bn1(self.conv1(x))))
  4. 使用更先进的模型:如ResNet、EfficientNet等预训练模型

7. 总结与展望

本文详细介绍了使用PyTorch实现图像分类的完整流程,包括数据准备、模型定义、训练过程和模型评估。通过CIFAR-10数据集的实践,读者可以掌握:

  • PyTorch的基本数据加载机制
  • CNN模型的设计原则
  • 训练循环的实现细节
  • 模型评估和保存的方法

未来工作可以探索:

  1. 更大规模的数据集(如ImageNet)
  2. 更复杂的模型架构(如Transformer)
  3. 分布式训练以加速模型收敛
  4. 模型压缩技术以部署到移动端

希望本文能为PyTorch初学者提供实用的参考,帮助快速上手图像分类任务。

相关文章推荐

发表评论