logo

从零开始:UNet图像语义分割实战指南——数据集制作、训练与推理全流程解析

作者:公子世无双2025.09.18 17:15浏览量:0

简介:本文详细介绍如何使用UNet模型进行图像语义分割,从自制数据集的标注与预处理,到模型训练与推理测试,提供完整的代码实现与操作指南,适合初学者与进阶开发者。

一、引言:为什么选择UNet进行语义分割?

UNet(U-shaped Network)由Ronneberger等人在2015年提出,是一种专为医学图像分割设计的全卷积网络(FCN)。其核心思想是通过编码器-解码器结构跳跃连接实现特征的高效提取与空间信息的保留,尤其适用于小数据集与高分辨率图像。相较于其他模型,UNet具有以下优势:

  1. 轻量化设计:参数量少,适合边缘设备部署。
  2. 跳跃连接:融合浅层(空间细节)与深层(语义信息)特征,提升分割精度。
  3. 扩展性强:可轻松适配不同任务(如卫星图像、工业缺陷检测等)。

本文将围绕UNet展开,从数据集制作到模型训练与推理,提供一套完整的解决方案。

二、数据集制作:从原始图像到标注数据

1. 数据收集与预处理

步骤1:图像采集

  • 使用手机或相机拍摄目标场景(如道路、植物、工业零件等),确保光照与角度一致。
  • 示例:若需分割道路裂缝,需采集不同光照条件下的裂缝图像。

步骤2:图像标准化

  • 统一分辨率(如512×512),避免尺寸差异导致训练不稳定。
  • 归一化像素值至[0,1]范围,加速模型收敛。
  1. import cv2
  2. import numpy as np
  3. def preprocess_image(image_path, target_size=(512, 512)):
  4. image = cv2.imread(image_path)
  5. image = cv2.resize(image, target_size)
  6. image = image.astype(np.float32) / 255.0 # 归一化
  7. return image

2. 标注工具与格式转换

工具选择

  • Labelme:开源标注工具,支持多边形、矩形标注,导出JSON格式。
  • CVAT:在线标注平台,适合团队协作。

标注流程

  1. 使用Labelme标注目标区域(如裂缝、植物),生成JSON文件。
  2. 转换为掩码(Mask)图像:
    ```python
    import json
    import os
    from PIL import Image, ImageDraw

def json_to_mask(json_path, output_path, image_shape=(512, 512)):
with open(json_path) as f:
data = json.load(f)

  1. mask = Image.new('L', image_shape, 0) # 'L'表示灰度图
  2. draw = ImageDraw.Draw(mask)
  3. for shape in data['shapes']:
  4. points = shape['points']
  5. if shape['shape_type'] == 'polygon':
  6. draw.polygon(points, fill=255) # 填充为白色(255)
  7. mask.save(output_path)
  1. **数据集结构**:

dataset/
├── images/
│ ├── img1.jpg
│ └── img2.jpg
└── masks/
├── img1_mask.png
└── img2_mask.png

  1. # 三、UNet模型实现:从编码器到解码器
  2. ## 1. 模型架构解析
  3. UNet的核心是**对称的U型结构**,包含:
  4. - **编码器(下采样)**:通过卷积与池化提取高级特征。
  5. - **解码器(上采样)**:通过转置卷积恢复空间分辨率。
  6. - **跳跃连接**:将编码器的特征图与解码器的上采样结果拼接。
  7. ## 2. 代码实现(PyTorch
  8. ```python
  9. import torch
  10. import torch.nn as nn
  11. import torch.nn.functional as F
  12. class DoubleConv(nn.Module):
  13. """双卷积块:Conv2d + ReLU + Conv2d + ReLU"""
  14. def __init__(self, in_channels, out_channels):
  15. super().__init__()
  16. self.double_conv = nn.Sequential(
  17. nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
  18. nn.ReLU(inplace=True),
  19. nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
  20. nn.ReLU(inplace=True)
  21. )
  22. def forward(self, x):
  23. return self.double_conv(x)
  24. class UNet(nn.Module):
  25. def __init__(self, n_classes=1):
  26. super().__init__()
  27. # 编码器
  28. self.dconv1 = DoubleConv(3, 64)
  29. self.dconv2 = DoubleConv(64, 128)
  30. self.dconv3 = DoubleConv(128, 256)
  31. self.dconv4 = DoubleConv(256, 512)
  32. # 解码器
  33. self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
  34. self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
  35. self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
  36. # 输出层
  37. self.final_conv = nn.Conv2d(64, n_classes, kernel_size=1)
  38. def forward(self, x):
  39. # 编码器
  40. conv1 = self.dconv1(x)
  41. pool1 = F.max_pool2d(conv1, 2)
  42. conv2 = self.dconv2(pool1)
  43. pool2 = F.max_pool2d(conv2, 2)
  44. conv3 = self.dconv3(pool2)
  45. pool3 = F.max_pool2d(conv3, 2)
  46. conv4 = self.dconv4(pool3)
  47. # 解码器 + 跳跃连接
  48. up3 = self.upconv3(conv4)
  49. up3 = torch.cat([up3, conv3], dim=1) # 拼接特征图
  50. up2 = self.upconv2(up3)
  51. up2 = torch.cat([up2, conv2], dim=1)
  52. up1 = self.upconv1(up2)
  53. up1 = torch.cat([up1, conv1], dim=1)
  54. # 输出
  55. out = self.final_conv(up1)
  56. return out

四、模型训练:从数据加载到优化

1. 数据加载器(DataLoader)

  1. from torch.utils.data import Dataset, DataLoader
  2. from torchvision import transforms
  3. class CustomDataset(Dataset):
  4. def __init__(self, image_paths, mask_paths, transform=None):
  5. self.image_paths = image_paths
  6. self.mask_paths = mask_paths
  7. self.transform = transform
  8. def __len__(self):
  9. return len(self.image_paths)
  10. def __getitem__(self, idx):
  11. image = cv2.imread(self.image_paths[idx])
  12. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  13. mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)
  14. if self.transform:
  15. image = self.transform(image)
  16. mask = self.transform(mask)
  17. return image, mask
  18. # 示例:创建DataLoader
  19. transform = transforms.Compose([
  20. transforms.ToTensor(),
  21. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  22. ])
  23. dataset = CustomDataset(image_paths, mask_paths, transform=transform)
  24. dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

2. 训练循环

  1. import torch.optim as optim
  2. from tqdm import tqdm
  3. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  4. model = UNet(n_classes=1).to(device)
  5. criterion = nn.BCEWithLogitsLoss() # 二分类任务
  6. optimizer = optim.Adam(model.parameters(), lr=1e-4)
  7. def train_model(model, dataloader, epochs=50):
  8. model.train()
  9. for epoch in range(epochs):
  10. running_loss = 0.0
  11. for images, masks in tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}'):
  12. images = images.to(device)
  13. masks = masks.float().unsqueeze(1).to(device) # 增加通道维度
  14. optimizer.zero_grad()
  15. outputs = model(images)
  16. loss = criterion(outputs, masks)
  17. loss.backward()
  18. optimizer.step()
  19. running_loss += loss.item()
  20. print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader):.4f}')
  21. train_model(model, dataloader)

五、推理测试:从模型加载到可视化

1. 模型保存与加载

  1. # 保存模型
  2. torch.save(model.state_dict(), 'unet_model.pth')
  3. # 加载模型
  4. model = UNet(n_classes=1).to(device)
  5. model.load_state_dict(torch.load('unet_model.pth'))
  6. model.eval()

2. 推理与可视化

  1. import matplotlib.pyplot as plt
  2. def predict_and_visualize(model, image_path, mask_path=None):
  3. image = preprocess_image(image_path)
  4. image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)
  5. with torch.no_grad():
  6. output = model(image_tensor)
  7. pred_mask = torch.sigmoid(output).squeeze().cpu().numpy()
  8. # 可视化
  9. plt.figure(figsize=(10, 5))
  10. plt.subplot(1, 2, 1)
  11. plt.imshow(image)
  12. plt.title('Original Image')
  13. plt.subplot(1, 2, 2)
  14. plt.imshow(pred_mask, cmap='gray')
  15. plt.title('Predicted Mask')
  16. if mask_path:
  17. true_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
  18. plt.figure(figsize=(5, 5))
  19. plt.imshow(true_mask, cmap='gray')
  20. plt.title('True Mask')
  21. plt.show()
  22. # 示例调用
  23. predict_and_visualize('test_image.jpg', 'test_mask.png')

六、总结与优化建议

  1. 数据增强:通过旋转、翻转增加数据多样性,提升模型泛化能力。
  2. 超参数调优:调整学习率、批次大小以优化训练效果。
  3. 模型轻量化:使用MobileUNet等变体,适配移动端部署。

通过本文的完整流程,读者可快速掌握UNet从数据集制作到模型推理的全过程,为实际项目提供技术支撑。

相关文章推荐

发表评论