使用 TensorFlow 中的 Keras API 训练一个模型 --知识铺
使用 TensorFlow 中的 Keras API 训练一个模型
👶 终极目标:训练一个模型
我们的目标是训练一个模型,你给它一张 28x28 像素的服装图片(例如运动鞋或衬衫),它会告诉你这张图片属于 10 个类别中的哪一个。
📋 完整流程概览
我们将严格遵循以下步骤,并详细解释每一行代码:
- 环境准备:安装 TensorFlow。
- 导入库:加载我们需要的所有工具。
- 加载数据:获取 Fashion MNIST 数据集。
- 探索数据:查看数据长什么样。
- 预处理数据:将数据“格式化”以适应模型。
- 构建模型:定义神经网络的“架构”。
- 编译模型:选择模型的“优化器”、“损失函数”和“评估指标”。
- 训练模型:将数据“喂”给模型进行学习。
- 评估模型:检查模型在“未见过”的数据上的表现。
- 进行预测:使用我们训练好的模型。
💻 步骤 1: 环境准备 (命令行)
在开始编写 Python 代码之前,你需要在你的计算机上安装 TensorFlow。
Bash
# 在你的终端(Terminal)或 Anaconda Prompt 中运行此命令
pip install tensorflow
pip:是 Python 的包安装器 (Package Installer for Python) 的缩写。install:是pip的一个命令,告诉它你要“安装”一个东西。tensorflow:是我们要安装的库的名称。
解释:这行命令会从 Python 包索引 (PyPI) 下载 TensorFlow 库及其所有依赖项(它需要一起工作的其他库),并将其安装到你的 Python 环境中。你只需要运行这个命令一次。
🐍 步骤 2: 导入所需库 (Python 脚本)
现在,打开你最喜欢的代码编辑器(例如 VS Code, Jupyter Notebook, PyCharm),创建一个新的 Python 文件(例如 first_nn.py),然后开始编写代码。
Python
# 导入 TensorFlow 库
import tensorflow as tf
# 导入 NumPy,一个用于科学计算的流行库
import numpy as np
# 导入 Matplotlib,一个用于绘图的库
import matplotlib.pyplot as plt
import tensorflow as tfimport tensorflow:告诉 Python 我们想要使用 TensorFlow 库中的功能。as tf:为tensorflow库创建一个“别名”tf。这是一种广泛接受的惯例,让我们不必每次都输入tensorflow这么长的名字,只需输入tf即可。
import numpy as np- 导入 NumPy 库,并使用
np作为别名。我们将用它来处理和操作数据数组。
- 导入 NumPy 库,并使用
import matplotlib.pyplot as plt- 从
matplotlib库中导入pyplot模块,并使用plt作为别名。我们将用它来可视化我们的图像数据。
- 从
📥 步骤 3: 加载数据集
我们将使用 Keras 中内置的 Fashion MNIST 数据集。
Python
# 从 tf.keras.datasets 中加载 Fashion MNIST 数据集
fashion_mnist = tf.keras.datasets.fashion_mnist
# 调用 load_data() 方法,它会返回两组数据:训练集和测试集
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
fashion_mnist = tf.keras.datasets.fashion_mnist:这行代码没有下载数据,它只是获取了fashion_mnist数据集在 Keras 库中的“模块”或“句柄”。(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()fashion_mnist.load_data():这是真正执行下载(如果本地没有缓存的话)和加载数据的函数。- 它返回两个“元组”(tuple)。
- 第一个元组
(train_images, train_labels)是训练集。这是模型将用来“学习”的数据。 - 第二个元组
(test_images, test_labels)是测试集。这是模型“从未见过”的数据,我们将用它来“测试”模型的真实表现。 train_images:包含所有训练图像的 NumPy 数组。train_labels:包含train_images中每张图像对应标签的 NumPy 数组。test_images,test_labels:同上,但用于测试。
🧐 步骤 4: 探索数据
在训练之前,我们必须了解我们的数据。
Python
# 查看训练集图像的“形状”(维度)
print(f"训练集图像形状: {train_images.shape}")
# 输出: 训练集图像形状: (60000, 28, 28)
# 查看训练集标签的“形状”
print(f"训练集标签数量: {len(train_labels)}")
# 输出: 训练集标签数量: 60000
# 查看测试集图像的“形状”
print(f"测试集图像形状: {test_images.shape}")
# 输出: 测试集图像形状: (10000, 28, 28)
# 查看测试集标签的“形状”
print(f"测试集标签数量: {len(test_labels)}")
# 输出: 测试集标签数量: 10000
代码解释:
train_images.shape:.shape是 NumPy 数组的一个属性,它返回一个元组,描述了数组的维度。(60000, 28, 28)意味着:我们有 60,000 张图像,每张图像都是 28 像素高 x 28 像素宽。
len(train_labels):len()是 Python 的内置函数,用于获取一个集合(如列表或数组)的长度。60000意味着:我们有 60,000 个标签,与 60,000 张图像一一对应。
标签是什么样子的?
Python
# 查看第一个训练标签
print(f"第一个标签: {train_labels[0]}")
# 输出: 第一个标签: 9
train_labels[0]:我们查看训练标签数组中的第一个元素。它是一个数字9。- 在 Fashion MNIST 数据集中,每个数字代表一个服装类别。例如,
9代表“Ankle boot”(踝靴)。 - 我们需要一个“类别名称”列表来将数字映射回文本。
Python
# 定义 10 个类别的名称
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
- 这行代码创建了一个 Python 列表。索引
0对应T-shirt/top,索引1对应Trouser,…,索引9对应Ankle boot。
图像是什么样子的?
Python
# 显示第一张训练图像
plt.figure()
plt.imshow(train_images[0], cmap=plt.cm.binary)
plt.colorbar()
plt.grid(False)
plt.show()
plt.figure():告诉matplotlib:“我要开始画一张新图了”。plt.imshow(train_images[0], cmap=plt.cm.binary):plt.imshow():是“image show”(图像显示)的缩写。train_images[0]:我们告诉它要显示的具体数据——训练集中的第一张图像。cmap=plt.cm.binary:cmap指 “color map”(颜色映射)。我们告诉它使用binary(黑白)颜色图来显示这张灰度图。
plt.colorbar():在图像旁边添加一个颜色条,显示像素值(例如 0 到 255)与颜色的对应关系。plt.grid(False):关闭图像上的网格线,让它看起来更清晰。plt.show():将我们“画”好的图像显示在屏幕上。
你会看到一张踝靴的图片,像素值范围是 0(黑色)到 255(白色)。
🛠️ 步骤 5: 数据预处理
我们的像素值在 0 到 255 之间。神经网络在处理 0 到 1 之间的小数值时表现最好。因此,我们需要进行归一化 (Normalization)。
Python
# 将训练集图像的像素值从 0-255 缩放到 0-1
train_images = train_images / 255.0
# 将测试集图像的像素值从 0-255 缩放到 0-1
test_images = test_images / 255.0
train_images = train_images / 255.0:- 我们使用
255.0(一个浮点数)而不是255(一个整数),以确保除法的结果是浮点数。 - NumPy 的一个很棒的功能叫做“广播”(broadcasting)。这行代码会自动将
train_images数组中的每一个像素值(总共有60000*28*28个)都除以 255.0。
- 我们使用
- 为什么?
- 更快的收敛:较大的输入值(如 255)可能会导致训练初期的梯度(学习信号)非常大,使得模型“摇摆不定”。较小的值(0-1)使训练过程更平滑、更稳定。
- 激活函数的敏感区:许多激活函数(如
sigmoid或tanh)在 0 附近最“敏感”。将输入数据保持在这个范围内有助于模型更快地学习。
🧠 步骤 6: 构建模型 (定义架构)
这是最核心的部分。我们将使用 tf.keras.Sequential 模型,它是一个“顺序”的层堆栈。
Python
# 初始化一个 Sequential 模型
model = tf.keras.Sequential([
# 第 1 层:Flatten (展平层)
tf.keras.layers.Flatten(input_shape=(28, 28)),
# 第 2 层:Dense (全连接层) - 隐藏层
tf.keras.layers.Dense(128, activation='relu'),
# 第 3 层:Dense (全连接层) - 输出层
tf.keras.layers.Dense(10)
])
让我们逐层分解:
model = tf.keras.Sequential([...])- 我们创建了一个
Sequential模型的实例。我们传入一个 Python 列表[...],列表中的每一项都是一个网络层。
- 我们创建了一个
tf.keras.layers.Flatten(input_shape=(28, 28))- 这是第一层,也叫输入层。
Flatten:这一层不做任何“学习”。它的唯一工作是将我们的 2D 图像(28x28 像素的矩阵)“展平”或“拉直”成一个 1D 数组。28x28的矩阵会变成一个28 * 28 = 784的一维向量。input_shape=(28, 28):只在第一层需要这个参数。它告诉模型:“请做好准备,你即将接收到的数据是 (28, 28) 形状的。”
tf.keras.layers.Dense(128, activation='relu')- 这是第二层,一个隐藏层。
Dense:意思是“全连接层”。它意味着前一层(展平后的 784 个节点)中的每一个节点都连接到这一层(128 个节点)中的每一个节点。这是神经网络中“学习”发生的主要地方。128:这是该层中“神经元”或“单元”的数量。这是一个“超参数”,你可以自己选择(例如 64, 256, 512)。它决定了模型的“容量”(能学习多复杂模式)。activation='relu':这是“激活函数”。- Relu (Rectified Linear Unit, 修正线性单元) 是目前最流行、最标准的激活函数。
- 它的工作很简单:
f(x) = max(0, x)。如果输入是负数,它输出 0;如果输入是正数,它原样输出。 - 为什么需要它? 它向网络中引入了“非线性”。没有它,你的神经网络(即使有 100 层)也只不过是一个复杂的“线性回归”,无法学习像服装形状这样的复杂模式。
tf.keras.layers.Dense(10)- 这是第三层,也是输出层。
Dense:它也是一个全连接层。10:它有 10 个神经元。这个数字不能随便选。它的数量必须等于我们类别的数量(我们有 10 种服装)。- 没有激活函数? 我们没有在这里指定
activation。这意味着它使用默认的“线性”激活。它将为 10 个类别中的每一个输出一个原始的、未归一化的分数(称为 logits)。例如[1.2, 0.5, 9.8, ...]. - 分数越高的类别,模型认为它越“可能”。在下一步(编译)中,我们将告诉损失函数如何正确解释这些“logits”。
⚙️ 步骤 7: 编译模型 (配置学习过程)
在模型准备好训练之前,它还需要三样东西:优化器、损失函数和评估指标。
Python
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.compile(...):这个函数用于“配置”模型,为训练做准备。optimizer='adam'- 优化器 (Optimizer):这是模型用来更新其内部权重(参数)以“学习”的算法。
adam:是一种非常流行且通常表现很好的优化器。它是一种“自适应”学习率算法,你通常不需要担心如何调整学习率。对于初学者来说,它是一个安全且强大的默认选项。
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)- 损失函数 (Loss Function):这是模型用来衡量“我错得有多离谱?”的数学公式。训练的目标就是最小化这个损失值。
SparseCategoricalCrossentropy:这是用于多类别分类的“黄金标准”损失函数。CategoricalCrossentropy:意味着我们正在处理多个类别。Sparse:意味着我们的“真实标签” (train_labels) 是稀疏的整数(例如9),而不是“one-hot”编码的向量(例如[0,0,0,0,0,0,0,0,0,1])。这非常方便!
from_logits=True:极其重要! 这行代码告诉损失函数:“嘿,我上一步(Dense(10))给你的值是原始的 ’logits’ 分数,不是 0 到 1 之间的概率。请你在计算损失之前,自己内部先应用 Softmax 激活。” 这种方式在数值上更稳定。
metrics=['accuracy']- 评估指标 (Metrics):这是我们用来监控训练和测试过程的“人类可读”的指标。
'accuracy':我们告诉模型:“在训练时,请除了报告那个复杂的损失值之外,也请告诉我,我猜对了百分之几?”(即,准确率)。
🏋️ 步骤 8: 训练模型 (拟合数据)
现在,是时候把我们的训练数据“喂”给模型了。
Python
# 训练模型
history = model.fit(train_images, train_labels, epochs=10)
model.fit(...):这是启动训练过程的函数,它的名字“fit”意味着“使模型去拟合训练数据”。train_images, train_labels:这是我们的训练数据和训练标签。模型将查看train_images,进行预测,然后将其预测与train_labels进行比较,以计算损失并更新其权重。epochs=10:- 一个 Epoch(时代)意味着模型已经“看”过了整个训练数据集一次。
epochs=10意味着模型将完整地遍历 60,000 张图像,10 次。- 在每个 epoch 结束时,你会看到 Keras 打印出
loss和accuracy。
history = ...:fit函数会返回一个History对象。这个对象以字典的形式记录了训练过程中每个 epoch 的loss和accuracy。这对于以后绘制“学习曲线”非常有用。
当你运行这行代码时,你将看到 TensorFlow 开始工作!
Epoch 1/10
1875/1875 [==============================] - 5s 3ms/step - loss: 0.4990 - accuracy: 0.8245
Epoch 2/10
1875/1875 [==============================] - 5s 3ms/step - loss: 0.3752 - accuracy: 0.8647
...
Epoch 10/10
1875/1875 [==============================] - 5s 3ms/step - loss: 0.2227 - accuracy: 0.9160
1875/1875:这表示“批次”(batches)。Keras 默认的批次大小是 32。它不会一次性处理 60,000 张图片,而是每次处理 32 张。60000 / 32 = 1875。所以它需要 1875 步来完成一个 epoch。loss: 0.2227:训练损失在下降,这是好事。accuracy: 0.9160:训练准确率在上升,达到 91.6%,这也是好事。
📊 步骤 9: 评估模型 (检查泛化能力)
模型在训练数据上表现很好 (91.6%),但这并不意味着什么。它可能只是“背住”了训练集。我们需要在它从未见过的测试集上评估它。
Python
# 在测试集上评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
# 打印测试准确率
print(f'\nTest accuracy: {test_acc}')
-
model.evaluate(...):此函数在“评估模式”下运行模型(例如,关闭 dropout),并计算你指定的指标。 -
test_images, test_labels:我们传入测试数据。 -
verbose=2:这只是设置日志的详细程度(0=无,1=进度条,2=每个 epoch 一行)。 -
test_loss, test_acc = ...:该函数返回我们编译时指定的loss和metrics。 -
print(…):你可能会看到如下输出:
313/313 - 1s - loss: 0.3340 - accuracy: 0.8805
Test accuracy: 0.8805
-
分析:
- 训练准确率: 91.6%
- 测试准确率: 88.0%
-
训练准确率略高于测试准确率,这被称为过拟合 (Overfitting),是完全正常的。这表明我们的模型在测试数据上(真实世界)的准确率约为 88%。对于这个简单的模型来说,这已经相当不错了!
🔮 步骤 10: 进行预测
最后,让我们用这个模型来预测一些事情。
我们的原始 model 输出的是 logits(原始分数)。为了获得人类可读的“概率”(例如“80% 的概率是踝靴”),我们可以在模型之后附加一个 Softmax 层。
Python
# 创建一个新模型,它包含我们的原始模型,并在最后附加一个 Softmax 层
probability_model = tf.keras.Sequential([model,
tf.keras.layers.Softmax()])
# 对测试集中的所有图像进行预测
predictions = probability_model.predict(test_images)
probability_model = ...:我们创建了一个新的Sequential模型。它的第一层是…我们整个训练好的model!第二层是一个Softmax层。tf.keras.layers.Softmax():这个层会将 logits(例如[1.2, 0.5, 9.8])转换成概率(例如[0.01, 0.00, 0.99]),并且所有概率的总和为 1。predictions = probability_model.predict(test_images):- 我们对所有 10,000 张测试图像运行
predict。 predictions将是一个形状为(10000, 10)的 NumPy 数组。predictions[0]将是第一张测试图像的 10 个概率组成的数组。
- 我们对所有 10,000 张测试图像运行
让我们检查第一个预测:
Python
# 查看对第一张测试图像的预测结果(一个包含 10 个概率的数组)
print(f"第一个预测的概率分布: {predictions[0]}")
# 找出哪个类别的概率最高
predicted_label_index = np.argmax(predictions[0])
print(f"预测的类别索引: {predicted_label_index}")
# 使用 class_names 列表将索引转换为名称
print(f"预测的类别名称: {class_names[predicted_label_index]}")
# 检查真实标签是什么
print(f"真实的类别名称: {class_names[test_labels[0]]}")
-
predictions[0]:打印出一个包含 10 个数字的数组,例如[1.2e-07, ... , 9.9e-01, ...]。 -
np.argmax(predictions[0]):NumPy 的argmax函数会返回数组中最大值的索引。 -
输出示例:
预测的类别索引: 9
预测的类别名称: Ankle boot
真实的类别名称: Ankle boot
-
成功! 我们的模型正确地将第一张测试图像分类为“Ankle boot”。
🎉 总结
恭喜你!你已经从头到尾使用 TensorFlow/Keras 构建、训练、评估和使用了一个神经网络。
你已经学会了:
- 加载和预处理数据(归一化)。
- 使用
tf.keras.Sequential构建模型架构(Flatten,Dense,relu)。 - 编译模型(
adam优化器,SparseCategoricalCrossentropy损失)。 - 训练模型(
model.fit)。 - 评估模型(
model.evaluate)。 - 预测新数据(
model.predict和Softmax)。
希望这个逐行解释能帮助你打下坚实的基础!
下一步,你可以尝试调整隐藏层中的神经元数量(例如 Dense(64, ...)),或者增加 epochs 的数量,看看你是否能提高那 88% 的测试准确率。
- 原文作者:知识铺
- 原文链接:https://index.zshipu.com/ai002/post/20251029/%E4%BD%BF%E7%94%A8-TensorFlow-%E4%B8%AD%E7%9A%84-Keras-API-%E8%AE%AD%E7%BB%83%E4%B8%80%E4%B8%AA%E6%A8%A1%E5%9E%8B/
- 版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议进行许可,非商业转载请注明出处(作者,原文链接),商业转载请联系作者获得授权。
- 免责声明:本页面内容均来源于站内编辑发布,部分信息来源互联网,并不意味着本站赞同其观点或者证实其内容的真实性,如涉及版权等问题,请立即联系客服进行更改或删除,保证您的合法权益。转载请注明来源,欢迎对文章中的引用来源进行考证,欢迎指出任何有错误或不够清晰的表达。也可以邮件至 sblig@126.com