使用PyTorch实现图像分类:完整代码与深度解析
2025.09.18 17:43浏览量:0简介:本文详细介绍如何使用PyTorch框架实现图像分类任务,包含从数据加载到模型训练的全流程代码,并附有详细注释。适合PyTorch初学者和有一定基础的开发者参考。
使用PyTorch实现图像分类:完整代码与深度解析
引言
图像分类是计算机视觉领域的核心任务之一,广泛应用于人脸识别、医学影像分析、自动驾驶等场景。PyTorch作为主流深度学习框架,以其动态计算图和简洁的API设计,成为实现图像分类任务的理想选择。本文将通过一个完整的CIFAR-10分类案例,详细展示如何使用PyTorch实现图像分类,包含所有关键代码和详细注释。
环境准备
在开始之前,请确保已安装以下环境:
- Python 3.8+
- PyTorch 1.12+
- torchvision 0.13+
- NumPy 1.21+
- Matplotlib 3.5+
推荐使用conda创建虚拟环境:
conda create -n pytorch_cls python=3.8
conda activate pytorch_cls
pip install torch torchvision numpy matplotlib
1. 数据准备与预处理
1.1 数据集加载
我们使用torchvision内置的CIFAR-10数据集,该数据集包含10个类别的60000张32x32彩色图像。
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义数据预处理流程
transform = transforms.Compose([
transforms.ToTensor(), # 将PIL图像或numpy数组转为Tensor,并缩放到[0,1]
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化到[-1,1]
])
# 加载训练集和测试集
train_dataset = datasets.CIFAR10(
root='./data',
train=True,
download=True,
transform=transform
)
test_dataset = datasets.CIFAR10(
root='./data',
train=False,
download=True,
transform=transform
)
# 创建数据加载器
batch_size = 64
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True, # 每个epoch打乱数据
num_workers=2 # 使用2个子进程加载数据
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=2
)
关键点说明:
transforms.Compose
:将多个预处理操作组合在一起ToTensor()
:自动将HWC格式的图像转为CHW格式的TensorNormalize
:使用(mean, std)参数进行标准化,这里使用CIFAR-10的均值和标准差DataLoader
:提供批量加载、多线程加载和打乱数据等功能
1.2 数据可视化
为了验证数据加载是否正确,我们可以可视化部分训练样本:
import matplotlib.pyplot as plt
import numpy as np
def imshow(img):
# 反归一化
img = img / 2 + 0.5 # 从[-1,1]恢复到[0,1]
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# 获取一个批次的图像
dataiter = iter(train_loader)
images, labels = next(dataiter)
# 显示图像
imshow(torchvision.utils.make_grid(images))
# 打印标签
print(' '.join(f'{datasets.CIFAR10.classes[labels[j]]}' for j in range(batch_size)))
2. 模型定义
2.1 基础CNN模型
我们定义一个简单的CNN模型,包含3个卷积层和2个全连接层:
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
# 卷积层1:输入通道3,输出通道32,3x3卷积核
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
# 卷积层2:输入通道32,输出通道64,3x3卷积核
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
# 卷积层3:输入通道64,输出通道128,3x3卷积核
self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
# 计算展平后的维度
# CIFAR-10图像原始大小32x32,经过3次池化(每次尺寸减半)后为4x4
# 128个通道,所以展平后维度为128*4*4=2048
self.fc1 = nn.Linear(128 * 4 * 4, 512)
self.fc2 = nn.Linear(512, 10) # 10个类别
# 最大池化层
self.pool = nn.MaxPool2d(2, 2)
def forward(self, x):
# 卷积1 -> ReLU -> 池化
x = self.pool(F.relu(self.conv1(x)))
# 卷积2 -> ReLU -> 池化
x = self.pool(F.relu(self.conv2(x)))
# 卷积3 -> ReLU -> 池化
x = self.pool(F.relu(self.conv3(x)))
# 展平特征图
x = x.view(-1, 128 * 4 * 4)
# 全连接层
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
模型结构说明:
- 3个卷积层,每层后接ReLU激活和2x2最大池化
- 卷积核大小均为3x3,使用padding=1保持空间尺寸
- 两个全连接层,第一个有512个神经元,第二个输出10个类别
- 输入图像32x32,经过3次池化后变为4x4
2.2 更先进的模型(可选)
对于更高性能的需求,可以使用ResNet等现代架构:
from torchvision import models
def get_resnet18(pretrained=False):
model = models.resnet18(pretrained=pretrained)
# 修改第一个卷积层以适应3通道输入(ResNet默认用于ImageNet的3通道)
# 实际上resnet18已经支持3通道,这里只是示例
# 修改最后的全连接层以适应10个类别
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)
return model
3. 训练过程
3.1 初始化模型和优化器
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Adam优化器
3.2 训练函数
def train_model(model, criterion, optimizer, num_epochs=10):
for epoch in range(num_epochs):
model.train() # 设置为训练模式
running_loss = 0.0
correct = 0
total = 0
for i, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.to(device), labels.to(device)
# 梯度清零
optimizer.zero_grad()
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
loss.backward()
optimizer.step()
# 统计信息
running_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# 每100个batch打印一次统计信息
if i % 100 == 99:
print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], '
f'Loss: {running_loss/100:.3f}, Acc: {100*correct/total:.2f}%')
running_loss = 0.0
# 每个epoch结束后计算并打印测试集准确率
test_acc = test_model(model, test_loader)
print(f'Epoch [{epoch+1}/{num_epochs}] completed, Test Acc: {test_acc:.2f}%')
def test_model(model, test_loader):
model.eval() # 设置为评估模式
correct = 0
total = 0
with torch.no_grad(): # 不计算梯度
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return 100 * correct / total
3.3 启动训练
num_epochs = 10
train_model(model, criterion, optimizer, num_epochs)
4. 模型评估与保存
4.1 评估模型
训练完成后,我们可以在测试集上评估模型性能:
test_acc = test_model(model, test_loader)
print(f'Final Test Accuracy: {test_acc:.2f}%')
4.2 保存模型
# 保存整个模型(包括结构和参数)
torch.save(model.state_dict(), 'cifar10_cnn.pth')
# 加载模型的代码示例
def load_model():
model = SimpleCNN().to(device)
model.load_state_dict(torch.load('cifar10_cnn.pth'))
model.eval()
return model
5. 完整代码整合
以下是整合后的完整代码:
# 完整实现PyTorch图像分类
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# 1. 数据准备
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform
)
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
# 2. 模型定义
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
self.fc1 = nn.Linear(128 * 4 * 4, 512)
self.fc2 = nn.Linear(512, 10)
self.pool = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = self.pool(F.relu(self.conv3(x)))
x = x.view(-1, 128 * 4 * 4)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
# 3. 训练设置
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 4. 训练和测试函数
def train_model(model, criterion, optimizer, num_epochs=10):
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
correct = 0
total = 0
for i, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
if i % 100 == 99:
print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], '
f'Loss: {running_loss/100:.3f}, Acc: {100*correct/total:.2f}%')
running_loss = 0.0
test_acc = test_model(model, test_loader)
print(f'Epoch [{epoch+1}/{num_epochs}] completed, Test Acc: {test_acc:.2f}%')
def test_model(model, test_loader):
model.eval()
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
return 100 * correct / total
# 5. 启动训练
num_epochs = 10
train_model(model, criterion, optimizer, num_epochs)
# 6. 保存模型
torch.save(model.state_dict(), 'cifar10_cnn.pth')
6. 性能优化建议
数据增强:在训练时使用随机裁剪、水平翻转等增强技术
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
学习率调度:使用学习率衰减策略
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# 在每个epoch后调用scheduler.step()
批归一化:在卷积层后添加批归一化层
self.bn1 = nn.BatchNorm2d(32)
# 在forward中使用
x = self.pool(F.relu(self.bn1(self.conv1(x))))
使用更先进的模型:如ResNet、EfficientNet等预训练模型
7. 总结与展望
本文详细介绍了使用PyTorch实现图像分类的完整流程,包括数据准备、模型定义、训练过程和模型评估。通过CIFAR-10数据集的实践,读者可以掌握:
- PyTorch的基本数据加载机制
- CNN模型的设计原则
- 训练循环的实现细节
- 模型评估和保存的方法
未来工作可以探索:
- 更大规模的数据集(如ImageNet)
- 更复杂的模型架构(如Transformer)
- 分布式训练以加速模型收敛
- 模型压缩技术以部署到移动端
希望本文能为PyTorch初学者提供实用的参考,帮助快速上手图像分类任务。
发表评论
登录后可评论,请前往 登录 或 注册