从零开始:UNet图像语义分割实战指南——数据集制作、训练与推理全流程解析
2025.09.18 17:15浏览量:0简介:本文详细介绍如何使用UNet模型进行图像语义分割,从自制数据集的标注与预处理,到模型训练与推理测试,提供完整的代码实现与操作指南,适合初学者与进阶开发者。
一、引言:为什么选择UNet进行语义分割?
UNet(U-shaped Network)由Ronneberger等人在2015年提出,是一种专为医学图像分割设计的全卷积网络(FCN)。其核心思想是通过编码器-解码器结构和跳跃连接实现特征的高效提取与空间信息的保留,尤其适用于小数据集与高分辨率图像。相较于其他模型,UNet具有以下优势:
- 轻量化设计:参数量少,适合边缘设备部署。
- 跳跃连接:融合浅层(空间细节)与深层(语义信息)特征,提升分割精度。
- 扩展性强:可轻松适配不同任务(如卫星图像、工业缺陷检测等)。
本文将围绕UNet展开,从数据集制作到模型训练与推理,提供一套完整的解决方案。
二、数据集制作:从原始图像到标注数据
1. 数据收集与预处理
步骤1:图像采集
- 使用手机或相机拍摄目标场景(如道路、植物、工业零件等),确保光照与角度一致。
- 示例:若需分割道路裂缝,需采集不同光照条件下的裂缝图像。
步骤2:图像标准化
- 统一分辨率(如512×512),避免尺寸差异导致训练不稳定。
- 归一化像素值至[0,1]范围,加速模型收敛。
import cv2
import numpy as np
def preprocess_image(image_path, target_size=(512, 512)):
image = cv2.imread(image_path)
image = cv2.resize(image, target_size)
image = image.astype(np.float32) / 255.0 # 归一化
return image
2. 标注工具与格式转换
工具选择:
- Labelme:开源标注工具,支持多边形、矩形标注,导出JSON格式。
- CVAT:在线标注平台,适合团队协作。
标注流程:
- 使用Labelme标注目标区域(如裂缝、植物),生成JSON文件。
- 转换为掩码(Mask)图像:
```python
import json
import os
from PIL import Image, ImageDraw
def json_to_mask(json_path, output_path, image_shape=(512, 512)):
with open(json_path) as f:
data = json.load(f)
mask = Image.new('L', image_shape, 0) # 'L'表示灰度图
draw = ImageDraw.Draw(mask)
for shape in data['shapes']:
points = shape['points']
if shape['shape_type'] == 'polygon':
draw.polygon(points, fill=255) # 填充为白色(255)
mask.save(output_path)
**数据集结构**:
dataset/
├── images/
│ ├── img1.jpg
│ └── img2.jpg
└── masks/
├── img1_mask.png
└── img2_mask.png
# 三、UNet模型实现:从编码器到解码器
## 1. 模型架构解析
UNet的核心是**对称的U型结构**,包含:
- **编码器(下采样)**:通过卷积与池化提取高级特征。
- **解码器(上采样)**:通过转置卷积恢复空间分辨率。
- **跳跃连接**:将编码器的特征图与解码器的上采样结果拼接。
## 2. 代码实现(PyTorch)
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""双卷积块:Conv2d + ReLU + Conv2d + ReLU"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class UNet(nn.Module):
def __init__(self, n_classes=1):
super().__init__()
# 编码器
self.dconv1 = DoubleConv(3, 64)
self.dconv2 = DoubleConv(64, 128)
self.dconv3 = DoubleConv(128, 256)
self.dconv4 = DoubleConv(256, 512)
# 解码器
self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
# 输出层
self.final_conv = nn.Conv2d(64, n_classes, kernel_size=1)
def forward(self, x):
# 编码器
conv1 = self.dconv1(x)
pool1 = F.max_pool2d(conv1, 2)
conv2 = self.dconv2(pool1)
pool2 = F.max_pool2d(conv2, 2)
conv3 = self.dconv3(pool2)
pool3 = F.max_pool2d(conv3, 2)
conv4 = self.dconv4(pool3)
# 解码器 + 跳跃连接
up3 = self.upconv3(conv4)
up3 = torch.cat([up3, conv3], dim=1) # 拼接特征图
up2 = self.upconv2(up3)
up2 = torch.cat([up2, conv2], dim=1)
up1 = self.upconv1(up2)
up1 = torch.cat([up1, conv1], dim=1)
# 输出
out = self.final_conv(up1)
return out
四、模型训练:从数据加载到优化
1. 数据加载器(DataLoader)
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class CustomDataset(Dataset):
def __init__(self, image_paths, mask_paths, transform=None):
self.image_paths = image_paths
self.mask_paths = mask_paths
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image = cv2.imread(self.image_paths[idx])
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask = cv2.imread(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)
if self.transform:
image = self.transform(image)
mask = self.transform(mask)
return image, mask
# 示例:创建DataLoader
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset = CustomDataset(image_paths, mask_paths, transform=transform)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
2. 训练循环
import torch.optim as optim
from tqdm import tqdm
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(n_classes=1).to(device)
criterion = nn.BCEWithLogitsLoss() # 二分类任务
optimizer = optim.Adam(model.parameters(), lr=1e-4)
def train_model(model, dataloader, epochs=50):
model.train()
for epoch in range(epochs):
running_loss = 0.0
for images, masks in tqdm(dataloader, desc=f'Epoch {epoch+1}/{epochs}'):
images = images.to(device)
masks = masks.float().unsqueeze(1).to(device) # 增加通道维度
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, masks)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader):.4f}')
train_model(model, dataloader)
五、推理测试:从模型加载到可视化
1. 模型保存与加载
# 保存模型
torch.save(model.state_dict(), 'unet_model.pth')
# 加载模型
model = UNet(n_classes=1).to(device)
model.load_state_dict(torch.load('unet_model.pth'))
model.eval()
2. 推理与可视化
import matplotlib.pyplot as plt
def predict_and_visualize(model, image_path, mask_path=None):
image = preprocess_image(image_path)
image_tensor = transforms.ToTensor()(image).unsqueeze(0).to(device)
with torch.no_grad():
output = model(image_tensor)
pred_mask = torch.sigmoid(output).squeeze().cpu().numpy()
# 可视化
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title('Original Image')
plt.subplot(1, 2, 2)
plt.imshow(pred_mask, cmap='gray')
plt.title('Predicted Mask')
if mask_path:
true_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
plt.figure(figsize=(5, 5))
plt.imshow(true_mask, cmap='gray')
plt.title('True Mask')
plt.show()
# 示例调用
predict_and_visualize('test_image.jpg', 'test_mask.png')
六、总结与优化建议
- 数据增强:通过旋转、翻转增加数据多样性,提升模型泛化能力。
- 超参数调优:调整学习率、批次大小以优化训练效果。
- 模型轻量化:使用MobileUNet等变体,适配移动端部署。
通过本文的完整流程,读者可快速掌握UNet从数据集制作到模型推理的全过程,为实际项目提供技术支撑。
发表评论
登录后可评论,请前往 登录 或 注册