使用PyTorch从零实现图像分类:完整代码与深度解析
2025.09.18 17:43浏览量:2简介:本文提供基于PyTorch的图像分类完整实现,包含数据加载、模型构建、训练循环和评估全流程代码,每行均附详细注释说明,适合开发者快速上手并深入理解深度学习图像分类技术。
使用PyTorch实现图像分类:完整代码与详细解析
图像分类是计算机视觉领域的核心任务之一,PyTorch作为主流深度学习框架,提供了灵活高效的工具来实现这一目标。本文将详细介绍如何使用PyTorch从零实现一个完整的图像分类系统,包含数据准备、模型构建、训练过程和结果评估,所有代码均附详细注释。
一、环境准备与依赖安装
首先需要安装必要的Python库:
# 推荐使用conda创建虚拟环境
# conda create -n pytorch_img_cls python=3.8
# conda activate pytorch_img_cls
# 安装基础依赖
!pip install torch torchvision matplotlib numpy tqdm
关键依赖说明:
torch
:PyTorch核心库,提供张量计算和自动微分功能torchvision
:包含计算机视觉常用数据集、模型架构和图像转换工具matplotlib
:用于可视化训练过程和样本图像tqdm
:提供进度条显示,提升训练体验
二、数据准备与预处理
1. 使用内置数据集
PyTorch提供了多个标准数据集,这里以CIFAR-10为例:
import torchvision
import torchvision.transforms as transforms
# 定义数据转换流程
transform = transforms.Compose([
transforms.ToTensor(), # 将PIL图像或numpy数组转为Tensor,并缩放到[0,1]
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化到[-1,1]
])
# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
testset = torchvision.datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
2. 自定义数据加载器
使用DataLoader实现批量加载和随机打乱:
from torch.utils.data import DataLoader
batch_size = 32
trainloader = DataLoader(
trainset,
batch_size=batch_size,
shuffle=True, # 每个epoch打乱数据顺序
num_workers=2 # 使用2个子进程加载数据
)
testloader = DataLoader(
testset,
batch_size=batch_size,
shuffle=False,
num_workers=2
)
# CIFAR-10类别名称
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
3. 数据可视化验证
import matplotlib.pyplot as plt
import numpy as np
def imshow(img):
# 反标准化并转换显示格式
img = img / 2 + 0.5 # 从[-1,1]恢复到[0,1]
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# 获取一个批次的图像
dataiter = iter(trainloader)
images, labels = next(dataiter)
# 显示图像和标签
imshow(torchvision.utils.make_grid(images))
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
三、模型架构设计
1. 基础CNN实现
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# 卷积层1:输入3通道,输出6通道,5x5卷积核
self.conv1 = nn.Conv2d(3, 6, 5)
# 池化层:2x2最大池化
self.pool = nn.MaxPool2d(2, 2)
# 卷积层2:输入6通道,输出16通道,5x5卷积核
self.conv2 = nn.Conv2d(6, 16, 5)
# 全连接层1:输入16*5*5(经过两次池化后尺寸),输出120
self.fc1 = nn.Linear(16 * 5 * 5, 120)
# 全连接层2:输入120,输出84
self.fc2 = nn.Linear(120, 84)
# 输出层:输入84,输出10(CIFAR-10类别数)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
# 卷积1 -> 激活 -> 池化
x = self.pool(F.relu(self.conv1(x)))
# 卷积2 -> 激活 -> 池化
x = self.pool(F.relu(self.conv2(x)))
# 展平特征图
x = x.view(-1, 16 * 5 * 5)
# 全连接层
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
print(net) # 打印模型结构
2. 使用预训练模型(可选)
# 加载预训练的ResNet18
model = torchvision.models.resnet18(pretrained=True)
# 修改最后一层全连接层
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10) # 适配CIFAR-10的10个类别
四、训练过程实现
1. 定义损失函数和优化器
import torch.optim as optim
# 交叉熵损失函数
criterion = nn.CrossEntropyLoss()
# 随机梯度下降优化器,学习率0.001,动量0.9
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
2. 设备配置
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
net.to(device) # 将模型移动到GPU
3. 完整训练循环
from tqdm import tqdm
def train_model(num_epochs=10):
for epoch in range(num_epochs):
running_loss = 0.0
correct = 0
total = 0
# 训练模式
net.train()
pbar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}')
for i, data in enumerate(pbar, 0):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
# 梯度清零
optimizer.zero_grad()
# 前向传播
outputs = net(inputs)
# 计算损失
loss = criterion(outputs, labels)
# 反向传播和优化
loss.backward()
optimizer.step()
# 统计信息
running_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# 每200个batch打印一次统计
if i % 200 == 199:
avg_loss = running_loss / 200
acc = 100 * correct / total
pbar.set_postfix(loss=f'{avg_loss:.3f}', acc=f'{acc:.2f}%')
running_loss = 0.0
print('Finished Training')
train_model(num_epochs=10)
五、模型评估与可视化
1. 测试集评估
def evaluate_model():
correct = 0
total = 0
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
# 评估模式(关闭dropout等)
net.eval()
with torch.no_grad():
for data in testloader:
images, labels = data
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# 按类别统计
c = (predicted == labels).squeeze()
for i in range(batch_size):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1
# 打印总体准确率
print(f'Accuracy on test set: {100 * correct / total:.2f}%')
# 打印各类别准确率
for i in range(10):
print(f'Accuracy of {classes[i]}: {100 * class_correct[i] / class_total[i]:.2f}%')
evaluate_model()
2. 错误样本可视化
def visualize_errors():
net.eval()
dataiter = iter(testloader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)
with torch.no_grad():
outputs = net(images)
_, predicted = torch.max(outputs, 1)
# 显示错误预测的样本
images = images.cpu() # 移回CPU用于显示
imshow(torchvision.utils.make_grid(images[predicted != labels]))
# 打印错误信息
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(batch_size)))
visualize_errors()
六、模型保存与加载
# 保存模型参数
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
# 加载模型
def load_model():
loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH))
loaded_net.to(device)
return loaded_net
# 测试加载的模型
loaded_net = load_model()
evaluate_model(loaded_net) # 需要修改evaluate_model函数接收net参数
七、进阶优化建议
数据增强:在transform中添加随机裁剪、水平翻转等操作提升模型泛化能力
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
学习率调度:使用StepLR或ReduceLROnPlateau动态调整学习率
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# 在每个epoch后调用scheduler.step()
模型改进:尝试更深的网络结构如ResNet,或使用注意力机制
分布式训练:对于大规模数据集,可使用
torch.nn.DataParallel
实现多GPU训练
八、完整代码整合
将上述所有部分整合为一个完整的可运行脚本,包含必要的注释和错误处理。建议将代码组织为:
data.py
:数据加载和预处理model.py
:模型定义train.py
:训练过程utils.py
:辅助函数
九、总结与展望
本文详细介绍了使用PyTorch实现图像分类的全流程,从数据准备到模型评估。关键点包括:
- 合理的数据预处理和增强
- 适当的模型架构选择
- 有效的训练策略和超参数调整
- 全面的模型评估方法
未来工作可以探索:
- 更先进的网络架构(如EfficientNet、Vision Transformer)
- 自监督学习预训练方法
- 面向特定领域的微调技术
通过理解这个完整实现,读者可以建立起对PyTorch图像分类的深入认识,并能够基于此扩展到更复杂的计算机视觉任务。
发表评论
登录后可评论,请前往 登录 或 注册