TensorFlow进行不同模型和数据集之间的迁移学习和模型微调

在这里插入图片描述

迁移学习包括获取从一个问题中学习到的特征,然后将这些特征用于新的类似问题。例如,来自已学会识别浣熊的模型的特征可能对建立旨在识别狸猫的模型十分有用。
在这里插入图片描述

对于数据集中的数据太少而无法从头开始训练完整模型的任务,通常会执行迁移学习。

在深度学习情境中,迁移学习最常见的形式是以下工作流:

  • 从之前训练的模型中获取层。
  • 冻结这些层,以避免在后续训练轮次中破坏它们包含的任何信息。
  • 在已冻结层的顶部添加一些新的可训练层。这些层会学习将旧特征转换为对新数据集的预测。
  • 在您的数据集上训练新层。
  • 最后一个可选步骤是微调,包括解冻上面获得的整个模型(或模型的一部分),然后在新数据上以极低的学习率对该模型进行重新训练。以增量方式使预训练特征适应新数据,有可能实现有意义的改进。

典型的迁移学习工作流
下面将介绍如何在 Keras 中实现典型的迁移学习工作流:

  • 实例化一个基础模型并加载预训练权重。

  • 通过设置 trainable = False 冻结基础模型中的所有层。

  • 根据基础模型中一个(或多个)层的输出创建一个新模型。

  • 在您的新数据集上训练新模型。
    请注意,另一种更轻量的工作流如下:

  • 实例化一个基础模型并加载预训练权重。

  • 通过该模型运行新的数据集,并记录基础模型中一个(或多个)层的输出。这一过程称为特征提取。

  • 使用该输出作为新的较小模型的输入数据。

第二种工作流有一个关键优势,即您只需在自己的数据上运行一次基础模型,而不是每个训练周期都运行一次。因此,它的速度更快,开销也更低。

但是,第二种工作流存在一个问题,即它不允许您在训练期间动态修改新模型的输入数据,在进行数据扩充时,这种修改必不可少。当新数据集的数据太少而无法从头开始训练完整模型时,任务通常会使用迁移学习,在这种情况下,数据扩充非常重要。因此,在接下来的篇幅中,我们将专注于第一种工作流。

完整代码

"""
 * Created with PyCharm
 * 作者: 阿光
 * 日期: 2022/1/4
 * 时间: 13:25
 * 描述:
"""
import numpy as np
import tensorflow as tf
from keras import Model
from tensorflow import keras
from tensorflow.keras import layers

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False
)

base_model.trainable = False

data_augmentation = keras.Sequential([
    layers.experimental.preprocessing.RandomFlip('horizontal'),
    layers.experimental.preprocessing.RandomRotation(0.1)
])

inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)

norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
x = norm_layer(x)
# norm_layer.set_weights([mean, var])
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.2)(x)
outputs = layers.Dense(1)(x)
model = Model(inputs, outputs)

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)

base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)

猜你喜欢

转载自blog.csdn.net/m0_47256162/article/details/122301217