TensorFlow教程:Keras基础实战①服装图像分类全解
2025.09.18 17:02浏览量:0简介:本文通过TensorFlow与Keras框架,系统讲解服装图像分类的完整流程,涵盖数据加载、模型构建、训练优化及评估部署,适合机器学习初学者及进阶开发者。
TensorFlow教程汇总—Keras机器学习基础① 对服装图像进行分类
一、引言:为什么选择服装图像分类?
服装图像分类是计算机视觉领域的经典任务,其应用场景涵盖电商推荐、智能试衣、库存管理等。相较于传统图像分类任务(如手写数字识别),服装图像具有更复杂的纹理、形状和姿态变化,对模型的特征提取能力提出更高要求。本文以TensorFlow 2.x中的Keras API为核心,通过完整的代码示例,从数据加载到模型部署,系统讲解服装图像分类的实现流程。
二、环境准备与数据集介绍
1. 环境配置
建议使用Python 3.7+环境,安装TensorFlow 2.x版本(如pip install tensorflow
)。Keras作为TensorFlow的高级API,无需单独安装。
2. 数据集选择:Fashion MNIST
Fashion MNIST是MNIST的升级版,包含10类服装图像(T-shirt/top、Trouser、Pullover等),每类7000张灰度图(28×28像素)。其优势在于:
- 数据量适中(6万训练集,1万测试集)
- 类别多样性覆盖基础服装类型
- 无需复杂预处理
import tensorflow as tf
from tensorflow.keras.datasets import fashion_mnist
# 加载数据集
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
三、数据预处理与可视化
1. 像素值归一化
原始图像像素值为0-255,需归一化至0-1范围以加速模型收敛:
train_images = train_images / 255.0
test_images = test_images / 255.0
2. 标签编码
将类别标签转换为独热编码(One-Hot Encoding):
from tensorflow.keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
3. 数据可视化
使用Matplotlib查看部分样本:
import matplotlib.pyplot as plt
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
plt.figure(figsize=(10,10))
for i in range(25):
plt.subplot(5,5,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(train_images[i], cmap=plt.cm.binary)
plt.xlabel(class_names[train_labels[i].argmax()])
plt.show()
四、模型构建:Keras Sequential API
1. 基础CNN模型
服装图像分类推荐使用卷积神经网络(CNN),其结构如下:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
model = Sequential([
# 输入层:28×28灰度图,需扩展为4D张量(样本数,高,宽,通道数)
tf.keras.layers.Reshape((28,28,1), input_shape=(28,28)),
# 第一卷积层:32个3×3卷积核,ReLU激活
Conv2D(32, (3,3), activation='relu', input_shape=(28,28,1)),
MaxPooling2D((2,2)),
# 第二卷积层:64个3×3卷积核
Conv2D(64, (3,3), activation='relu'),
MaxPooling2D((2,2)),
# 展平层
Flatten(),
# 全连接层:128个神经元
Dense(128, activation='relu'),
# 输出层:10个类别,Softmax激活
Dense(10, activation='softmax')
])
2. 模型编译
配置损失函数、优化器和评估指标:
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
五、模型训练与优化
1. 基础训练
history = model.fit(train_images, train_labels,
epochs=10,
batch_size=64,
validation_data=(test_images, test_labels))
2. 训练过程可视化
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0, 1])
plt.legend(loc='lower right')
plt.show()
3. 优化策略
- 数据增强:通过旋转、缩放等操作扩充数据集
```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(
rotation_range=10,
zoom_range=0.1,
width_shift_range=0.1,
height_shift_range=0.1)
生成增强数据并训练
datagen.fit(train_images)
model.fit(datagen.flow(train_images, train_labels, batch_size=64),
epochs=20)
- **模型调参**:调整卷积核数量、全连接层神经元数
- **正则化**:添加Dropout层防止过拟合
```python
from tensorflow.keras.layers import Dropout
model.add(Dropout(0.5)) # 在全连接层后添加50%丢弃率
六、模型评估与部署
1. 测试集评估
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f'Test accuracy: {test_acc}')
2. 单张图像预测
import numpy as np
def predict_image(img):
img_array = tf.expand_dims(img, 0) # 扩展为batch维度
predictions = model.predict(img_array)
predicted_class = np.argmax(predictions[0])
return class_names[predicted_class]
# 示例预测
sample_img = test_images[0]
print(predict_image(sample_img)) # 输出预测类别
3. 模型保存与加载
# 保存模型
model.save('fashion_mnist_model.h5')
# 加载模型
loaded_model = tf.keras.models.load_model('fashion_mnist_model.h5')
七、进阶方向
- 迁移学习:使用预训练模型(如MobileNetV2)进行特征提取
base_model = tf.keras.applications.MobileNetV2(input_shape=(224,224,3),
include_top=False,
weights='imagenet')
base_model.trainable = False # 冻结预训练层
- 多标签分类:修改输出层为Sigmoid激活,适用于同时标注多个类别的场景
- 实时分类:结合OpenCV实现摄像头实时服装识别
八、常见问题与解决方案
- 过拟合问题:
- 增加训练数据量
- 添加L2正则化或Dropout层
- 早停法(Early Stopping)
```python
from tensorflow.keras.callbacks import EarlyStopping
early_stop = EarlyStopping(monitor=’val_loss’, patience=3)
model.fit(…, callbacks=[early_stop])
```
训练速度慢:
- 使用GPU加速(
tf.config.list_physical_devices('GPU')
) - 减小batch_size(但需平衡内存占用)
- 使用GPU加速(
模型性能瓶颈:
- 尝试更深的网络结构(如ResNet)
- 调整学习率(使用
tf.keras.optimizers.Adam(learning_rate=0.0001)
)
九、总结与学习资源
本文通过Fashion MNIST数据集,系统演示了使用TensorFlow Keras实现服装图像分类的全流程。关键步骤包括数据预处理、CNN模型构建、训练优化和部署。对于进阶学习者,推荐以下资源:
- TensorFlow官方文档:tensorflow.org/tutorials
- Keras中文文档:keras.io/zh/
- 经典论文:《Deep Learning for Generic Object Detection: A Survey》
通过实践本文代码,读者可快速掌握Keras机器学习基础,并为后续复杂视觉任务打下坚实基础。
发表评论
登录后可评论,请前往 登录 或 注册