PyTorch图像分类实战:从零构建完整代码与注释
2025.09.18 17:52浏览量:0简介:本文将通过PyTorch框架实现一个完整的图像分类系统,包含数据加载、模型构建、训练与评估全流程。提供可运行的完整代码,每行均配有详细注释,并深入解析关键技术点。
PyTorch图像分类实战:从零构建完整代码与注释
一、项目概述
图像分类是计算机视觉的基础任务,PyTorch作为主流深度学习框架,提供了灵活的API和自动求导机制。本文将实现一个基于CIFAR-10数据集的图像分类器,包含以下核心模块:
- 数据加载与预处理
- 卷积神经网络模型构建
- 训练循环与优化策略
- 模型评估与可视化
二、环境准备
# 环境配置要求
# Python 3.8+
# PyTorch 2.0+
# torchvision 0.15+
# matplotlib 3.5+
# numpy 1.22+
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# 检查GPU可用性
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
三、数据准备与预处理
1. 数据增强与归一化
# 定义数据转换管道
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomRotation(15), # 随机旋转±15度
transforms.ToTensor(), # 转换为Tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化到[-1,1]
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
2. 数据集加载
# 加载CIFAR-10数据集
trainset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform_train
)
testset = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform_test
)
# 创建数据加载器
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=32,
shuffle=True,
num_workers=2
)
testloader = torch.utils.data.DataLoader(
testset,
batch_size=32,
shuffle=False,
num_workers=2
)
# 类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
四、模型架构设计
1. CNN模型实现
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 卷积层1: 输入3通道,输出16通道,3x3卷积核
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(16) # 批归一化
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(2, 2) # 2x2最大池化
# 卷积层2: 输入16通道,输出32通道
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(32)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(2, 2)
# 全连接层
self.fc1 = nn.Linear(32 * 8 * 8, 256) # CIFAR-10经过两次池化后为8x8
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(256, 10) # 输出10个类别
def forward(self, x):
# 第一卷积块
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.pool1(x)
# 第二卷积块
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.pool2(x)
# 展平特征图
x = x.view(-1, 32 * 8 * 8)
# 全连接层
x = self.fc1(x)
x = self.dropout(x)
x = self.fc2(x)
return x
2. 模型初始化
model = CNN().to(device)
print(model) # 打印模型结构
五、训练流程实现
1. 损失函数与优化器
criterion = nn.CrossEntropyLoss() # 交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器
2. 训练循环
def train_model(model, trainloader, criterion, optimizer, epochs=10):
model.train() # 设置为训练模式
for epoch in range(epochs):
running_loss = 0.0
correct = 0
total = 0
for i, (inputs, labels) in enumerate(trainloader, 0):
inputs, labels = inputs.to(device), labels.to(device)
# 梯度清零
optimizer.zero_grad()
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播与优化
loss.backward()
optimizer.step()
# 统计信息
running_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# 每100个batch打印一次
if i % 100 == 99:
print(f"Epoch {epoch+1}, Batch {i+1}, "
f"Loss: {running_loss/100:.3f}, "
f"Acc: {100*correct/total:.2f}%")
running_loss = 0.0
# 每个epoch结束打印验证准确率
val_acc = evaluate_model(model, testloader)
print(f"Epoch {epoch+1} completed, Val Acc: {val_acc:.2f}%")
def evaluate_model(model, testloader):
model.eval() # 设置为评估模式
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in testloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return 100 * correct / total
3. 执行训练
train_model(model, trainloader, criterion, optimizer, epochs=10)
六、模型评估与可视化
1. 测试集评估
test_acc = evaluate_model(model, testloader)
print(f"Final Test Accuracy: {test_acc:.2f}%")
2. 混淆矩阵可视化
from sklearn.metrics import confusion_matrix
import seaborn as sns
def plot_confusion_matrix(model, testloader, classes):
model.eval()
y_true = []
y_pred = []
with torch.no_grad():
for inputs, labels in testloader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
y_true.extend(labels.cpu().numpy())
y_pred.extend(predicted.cpu().numpy())
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(10,8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
plot_confusion_matrix(model, testloader, classes)
七、关键技术点解析
- 批归一化作用:在每个卷积层后添加批归一化,可加速收敛并提高模型稳定性
- 数据增强重要性:随机翻转和旋转显著提升了模型在测试集上的泛化能力
- 学习率选择:Adam优化器默认的0.001学习率在CIFAR-10上表现良好
- 模型深度设计:两层卷积+两层全连接的架构在计算量和准确率间取得平衡
八、进阶优化方向
- 尝试使用预训练的ResNet等更复杂模型
- 实现学习率调度器(如ReduceLROnPlateau)
- 添加更复杂的数据增强(如CutMix)
- 实现模型保存与加载功能
九、完整代码包说明
本文提供的代码包含:
- 完整的数据加载管道
- 可训练的CNN模型实现
- 详细的训练与评估流程
- 可视化工具代码
所有代码均经过测试,可在PyTorch 2.0+环境下直接运行。建议读者从基础版本开始,逐步尝试添加更复杂的优化技术。
发表评论
登录后可评论,请前往 登录 或 注册