使用PyTorch构建图像分类系统:完整代码与深度解析
2025.09.19 17:05浏览量:2简介:本文详细介绍如何使用PyTorch框架实现图像分类任务,包含从数据加载到模型训练的全流程代码,每行代码均附有详细注释,适合PyTorch初学者及有一定基础的开发者参考。
使用PyTorch构建图像分类系统:完整代码与深度解析
图像分类是计算机视觉领域的核心任务之一,PyTorch作为主流深度学习框架,提供了简洁高效的API支持。本文将通过完整代码示例,展示如何使用PyTorch实现从数据准备到模型部署的全流程,所有代码均包含详细注释,确保读者能够理解每个步骤的实现原理。
一、环境准备与依赖安装
首先需要安装PyTorch及相关依赖库。推荐使用conda创建虚拟环境:
conda create -n pytorch_img_cls python=3.8conda activate pytorch_img_clspip install torch torchvision matplotlib numpy
关键依赖说明:
torch:PyTorch核心库torchvision:提供计算机视觉常用数据集和模型架构matplotlib:用于可视化训练过程numpy:数值计算基础库
二、数据集准备与预处理
1. 使用CIFAR-10数据集
CIFAR-10包含10个类别的60000张32x32彩色图像,分为50000张训练集和10000张测试集。
import torchvisionimport torchvision.transforms as transforms# 定义数据预处理流程transform = transforms.Compose([transforms.ToTensor(), # 将PIL图像转换为Tensor,并归一化到[0,1]transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化到[-1,1]])# 加载训练集trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)trainloader = torch.utils.data.DataLoader(trainset,batch_size=32,shuffle=True,num_workers=2)# 加载测试集testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)testloader = torch.utils.data.DataLoader(testset,batch_size=32,shuffle=False,num_workers=2)# 类别名称classes = ('plane', 'car', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck')
关键点解析:
transforms.Compose:组合多个数据预处理操作ToTensor():将HWC格式的PIL图像转换为CHW格式的TensorNormalize:使用均值和标准差进行标准化,这里使用(0.5,0.5,0.5)将像素值映射到[-1,1]区间DataLoader:实现批量加载、数据打乱和多线程加载
2. 自定义数据集加载
对于自定义数据集,可以继承torch.utils.data.Dataset类:
from torch.utils.data import Datasetimport osfrom PIL import Imageclass CustomImageDataset(Dataset):def __init__(self, img_dir, transform=None):self.img_labels = []self.img_paths = []self.transform = transform# 遍历目录,假设子目录名为类别名for class_name in os.listdir(img_dir):class_path = os.path.join(img_dir, class_name)if os.path.isdir(class_path):for img_name in os.listdir(class_path):self.img_paths.append(os.path.join(class_path, img_name))self.img_labels.append(classes.index(class_name))def __len__(self):return len(self.img_paths)def __getitem__(self, idx):img_path = self.img_paths[idx]image = Image.open(img_path)label = self.img_labels[idx]if self.transform:image = self.transform(image)return image, label
三、模型架构设计
1. 基础CNN模型
import torch.nn as nnimport torch.nn.functional as Fclass SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 输入通道3(RGB),输出通道32,3x3卷积核self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2) # 2x2最大池化self.fc1 = nn.Linear(64 * 8 * 8, 512) # CIFAR-10经过两次池化后为8x8self.fc2 = nn.Linear(512, 10) # 10个输出类别self.dropout = nn.Dropout(0.25)def forward(self, x):# 第一层卷积+ReLU+池化x = self.pool(F.relu(self.conv1(x)))# 第二层卷积+ReLU+池化x = self.pool(F.relu(self.conv2(x)))# 展平特征图x = x.view(-1, 64 * 8 * 8)# 全连接层+ReLU+Dropoutx = self.dropout(F.relu(self.fc1(x)))# 输出层x = self.fc2(x)return x
架构解析:
- 两个卷积层提取空间特征,每个卷积层后接ReLU激活函数和最大池化
- 两个全连接层完成分类,中间加入Dropout防止过拟合
- 输入32x32x3图像,经过两次2x2池化后变为8x8x64特征图
2. 使用预训练模型
PyTorch提供了多种预训练模型,可通过torchvision.models加载:
import torchvision.models as modelsdef get_pretrained_model(model_name='resnet18', pretrained=True, num_classes=10):if model_name == 'resnet18':model = models.resnet18(pretrained=pretrained)# 修改最后一层全连接网络num_ftrs = model.fc.in_featuresmodel.fc = nn.Linear(num_ftrs, num_classes)elif model_name == 'vgg16':model = models.vgg16(pretrained=pretrained)num_ftrs = model.classifier[6].in_featuresmodel.classifier[6] = nn.Linear(num_ftrs, num_classes)else:raise ValueError("Unsupported model name")return model
四、训练流程实现
1. 完整训练代码
import torchimport torch.optim as optimfrom tqdm import tqdm # 进度条库def train_model(model, trainloader, testloader, criterion, optimizer, num_epochs=10):device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model.to(device)for epoch in range(num_epochs):# 训练阶段model.train()running_loss = 0.0correct = 0total = 0# 使用tqdm显示进度条train_loop = tqdm(trainloader, desc=f'Epoch {epoch+1}/{num_epochs}')for inputs, labels in train_loop:inputs, labels = inputs.to(device), labels.to(device)# 梯度清零optimizer.zero_grad()# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播和优化loss.backward()optimizer.step()# 统计信息running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 更新进度条信息train_loop.set_postfix(loss=running_loss/(train_loop.n+1),acc=100.*correct/total)# 测试阶段test_loss, test_acc = evaluate_model(model, testloader, criterion, device)print(f'Epoch {epoch+1}, Train Loss: {running_loss/len(trainloader):.4f}, 'f'Train Acc: {100*correct/total:.2f}%, 'f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')def evaluate_model(model, testloader, criterion, device):model.eval()test_loss = 0.0correct = 0total = 0with torch.no_grad():for inputs, labels in testloader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()return test_loss/len(testloader), 100*correct/total# 初始化模型model = SimpleCNN()# 或者使用预训练模型# model = get_pretrained_model('resnet18')# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)# 开始训练train_model(model, trainloader, testloader, criterion, optimizer, num_epochs=10)
2. 关键训练参数说明
- 学习率:控制参数更新步长,常用值为0.001(Adam)或0.01(SGD)
- 批量大小:影响内存使用和梯度估计稳定性,CIFAR-10常用32或64
- 优化器选择:
- Adam:自适应学习率,收敛快
- SGD+Momentum:可能获得更好泛化性能
- 损失函数:分类任务通常使用交叉熵损失
五、模型评估与可视化
1. 混淆矩阵实现
import numpy as npimport matplotlib.pyplot as pltfrom sklearn.metrics import confusion_matriximport seaborn as snsdef plot_confusion_matrix(model, testloader, classes, device):model.eval()all_labels = []all_preds = []with torch.no_grad():for inputs, labels in testloader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs, 1)all_labels.extend(labels.cpu().numpy())all_preds.extend(predicted.cpu().numpy())cm = confusion_matrix(all_labels, all_preds)plt.figure(figsize=(10,8))sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',xticklabels=classes, yticklabels=classes)plt.xlabel('Predicted')plt.ylabel('True')plt.title('Confusion Matrix')plt.show()# 调用示例plot_confusion_matrix(model, testloader, classes, device)
2. 训练过程可视化
def plot_training_curve(train_losses, test_losses, train_accs, test_accs):plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(train_losses, label='Train Loss')plt.plot(test_losses, label='Test Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.subplot(1, 2, 2)plt.plot(train_accs, label='Train Accuracy')plt.plot(test_accs, label='Test Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.tight_layout()plt.show()# 需要在训练过程中记录这些指标# 示例数据epochs = range(1, 11)train_losses = [2.3, 1.8, 1.5, 1.2, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5]test_losses = [2.1, 1.7, 1.4, 1.1, 0.95, 0.85, 0.78, 0.72, 0.68, 0.65]train_accs = [45, 58, 65, 70, 75, 78, 80, 82, 84, 85]test_accs = [50, 62, 68, 72, 76, 78, 80, 81, 82, 83]plot_training_curve(train_losses, test_losses, train_accs, test_accs)
六、模型部署建议
模型导出:使用
torch.save保存模型参数torch.save(model.state_dict(), 'cifar_classifier.pth')
推理脚本示例:
def predict_image(image_path, model, transform, classes, device):image = Image.open(image_path)image = transform(image).unsqueeze(0).to(device)model.eval()with torch.no_grad():output = model(image)_, predicted = torch.max(output.data, 1)return classes[predicted.item()]
性能优化技巧:
- 使用混合精度训练(
torch.cuda.amp) - 模型量化减少内存占用
- 使用TensorRT加速推理
- 使用混合精度训练(
七、常见问题解决方案
训练不收敛:
- 检查学习率是否过大
- 确认数据预处理是否正确
- 尝试不同的优化器
过拟合问题:
- 增加数据增强
- 添加Dropout层
- 使用L2正则化
GPU内存不足:
- 减小批量大小
- 使用梯度累积
- 清理缓存(
torch.cuda.empty_cache())
本文完整代码可在GitHub获取,建议读者从简单CNN开始实践,逐步尝试预训练模型和更复杂的架构。通过调整超参数和观察训练曲线,可以深入理解深度学习模型的工作原理。

发表评论
登录后可评论,请前往 登录 或 注册