使用PyTorch从零实现图像分类:完整代码与深度解析
2025.09.18 17:43浏览量:2简介:本文提供基于PyTorch的图像分类完整实现,包含数据加载、模型构建、训练循环和评估全流程代码,每行均附详细注释说明,适合开发者快速上手并深入理解深度学习图像分类技术。
使用PyTorch实现图像分类:完整代码与详细解析
图像分类是计算机视觉领域的核心任务之一,PyTorch作为主流深度学习框架,提供了灵活高效的工具来实现这一目标。本文将详细介绍如何使用PyTorch从零实现一个完整的图像分类系统,包含数据准备、模型构建、训练过程和结果评估,所有代码均附详细注释。
一、环境准备与依赖安装
首先需要安装必要的Python库:
# 推荐使用conda创建虚拟环境# conda create -n pytorch_img_cls python=3.8# conda activate pytorch_img_cls# 安装基础依赖!pip install torch torchvision matplotlib numpy tqdm
关键依赖说明:
torch:PyTorch核心库,提供张量计算和自动微分功能torchvision:包含计算机视觉常用数据集、模型架构和图像转换工具matplotlib:用于可视化训练过程和样本图像tqdm:提供进度条显示,提升训练体验
二、数据准备与预处理
1. 使用内置数据集
PyTorch提供了多个标准数据集,这里以CIFAR-10为例:
import torchvisionimport torchvision.transforms as transforms# 定义数据转换流程transform = transforms.Compose([transforms.ToTensor(), # 将PIL图像或numpy数组转为Tensor,并缩放到[0,1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化到[-1,1]])# 加载训练集和测试集trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
2. 自定义数据加载器
使用DataLoader实现批量加载和随机打乱:
from torch.utils.data import DataLoaderbatch_size = 32trainloader = DataLoader(trainset,batch_size=batch_size,shuffle=True, # 每个epoch打乱数据顺序num_workers=2 # 使用2个子进程加载数据)testloader = DataLoader(testset,batch_size=batch_size,shuffle=False,num_workers=2)# CIFAR-10类别名称classes = ('plane', 'car', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck')
3. 数据可视化验证
import matplotlib.pyplot as pltimport numpy as npdef imshow(img):# 反标准化并转换显示格式img = img / 2 + 0.5 # 从[-1,1]恢复到[0,1]npimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()# 获取一个批次的图像dataiter = iter(trainloader)images, labels = next(dataiter)# 显示图像和标签imshow(torchvision.utils.make_grid(images))print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
三、模型架构设计
1. 基础CNN实现
import torch.nn as nnimport torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()# 卷积层1:输入3通道,输出6通道,5x5卷积核self.conv1 = nn.Conv2d(3, 6, 5)# 池化层:2x2最大池化self.pool = nn.MaxPool2d(2, 2)# 卷积层2:输入6通道,输出16通道,5x5卷积核self.conv2 = nn.Conv2d(6, 16, 5)# 全连接层1:输入16*5*5(经过两次池化后尺寸),输出120self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全连接层2:输入120,输出84self.fc2 = nn.Linear(120, 84)# 输出层:输入84,输出10(CIFAR-10类别数)self.fc3 = nn.Linear(84, 10)def forward(self, x):# 卷积1 -> 激活 -> 池化x = self.pool(F.relu(self.conv1(x)))# 卷积2 -> 激活 -> 池化x = self.pool(F.relu(self.conv2(x)))# 展平特征图x = x.view(-1, 16 * 5 * 5)# 全连接层x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()print(net) # 打印模型结构
2. 使用预训练模型(可选)
# 加载预训练的ResNet18model = torchvision.models.resnet18(pretrained=True)# 修改最后一层全连接层num_ftrs = model.fc.in_featuresmodel.fc = nn.Linear(num_ftrs, 10) # 适配CIFAR-10的10个类别
四、训练过程实现
1. 定义损失函数和优化器
import torch.optim as optim# 交叉熵损失函数criterion = nn.CrossEntropyLoss()# 随机梯度下降优化器,学习率0.001,动量0.9optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
2. 设备配置
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")net.to(device) # 将模型移动到GPU
3. 完整训练循环
from tqdm import tqdmdef train_model(num_epochs=10):for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0# 训练模式net.train()pbar = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}')for i, data in enumerate(pbar, 0):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 梯度清零optimizer.zero_grad()# 前向传播outputs = net(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播和优化loss.backward()optimizer.step()# 统计信息running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 每200个batch打印一次统计if i % 200 == 199:avg_loss = running_loss / 200acc = 100 * correct / totalpbar.set_postfix(loss=f'{avg_loss:.3f}', acc=f'{acc:.2f}%')running_loss = 0.0print('Finished Training')train_model(num_epochs=10)
五、模型评估与可视化
1. 测试集评估
def evaluate_model():correct = 0total = 0class_correct = list(0. for i in range(10))class_total = list(0. for i in range(10))# 评估模式(关闭dropout等)net.eval()with torch.no_grad():for data in testloader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 按类别统计c = (predicted == labels).squeeze()for i in range(batch_size):label = labels[i]class_correct[label] += c[i].item()class_total[label] += 1# 打印总体准确率print(f'Accuracy on test set: {100 * correct / total:.2f}%')# 打印各类别准确率for i in range(10):print(f'Accuracy of {classes[i]}: {100 * class_correct[i] / class_total[i]:.2f}%')evaluate_model()
2. 错误样本可视化
def visualize_errors():net.eval()dataiter = iter(testloader)images, labels = next(dataiter)images, labels = images.to(device), labels.to(device)with torch.no_grad():outputs = net(images)_, predicted = torch.max(outputs, 1)# 显示错误预测的样本images = images.cpu() # 移回CPU用于显示imshow(torchvision.utils.make_grid(images[predicted != labels]))# 打印错误信息print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(batch_size)))visualize_errors()
六、模型保存与加载
# 保存模型参数PATH = './cifar_net.pth'torch.save(net.state_dict(), PATH)# 加载模型def load_model():loaded_net = Net()loaded_net.load_state_dict(torch.load(PATH))loaded_net.to(device)return loaded_net# 测试加载的模型loaded_net = load_model()evaluate_model(loaded_net) # 需要修改evaluate_model函数接收net参数
七、进阶优化建议
数据增强:在transform中添加随机裁剪、水平翻转等操作提升模型泛化能力
transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
学习率调度:使用StepLR或ReduceLROnPlateau动态调整学习率
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)# 在每个epoch后调用scheduler.step()
模型改进:尝试更深的网络结构如ResNet,或使用注意力机制
分布式训练:对于大规模数据集,可使用
torch.nn.DataParallel实现多GPU训练
八、完整代码整合
将上述所有部分整合为一个完整的可运行脚本,包含必要的注释和错误处理。建议将代码组织为:
data.py:数据加载和预处理model.py:模型定义train.py:训练过程utils.py:辅助函数
九、总结与展望
本文详细介绍了使用PyTorch实现图像分类的全流程,从数据准备到模型评估。关键点包括:
- 合理的数据预处理和增强
- 适当的模型架构选择
- 有效的训练策略和超参数调整
- 全面的模型评估方法
未来工作可以探索:
- 更先进的网络架构(如EfficientNet、Vision Transformer)
- 自监督学习预训练方法
- 面向特定领域的微调技术
通过理解这个完整实现,读者可以建立起对PyTorch图像分类的深入认识,并能够基于此扩展到更复杂的计算机视觉任务。

发表评论
登录后可评论,请前往 登录 或 注册