版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u012435142/article/details/82986595
Keras源码结构
keras源码非常简洁,学习源码可以对深度学习整体框架有更清楚的认识。通过对keras源码的阅读,我简单总结了一些笔记,类似于大纲一样的东西。具体的实现细节和步骤,直接看源码会更容易理解。
keras搭建网络和运行的一般过程
model=Sequential() model=Model() |
建立模型 |
model.add() |
模型编辑 |
model.compile() |
模型编译 |
model.fit() |
模型拟合 |
model.evaluate() |
模型估计 |
model.save() |
|
实际使用时,增加下面功能:
model.summary()
plot_model(model,””)
fit()函数增加callbacks实现模型自动保存,早期停止,学习率变化
fit
是用来对训练集和标签直接训练,训练过程中不再对训练集和标签做处理。
fit_generator
需要对训练集和标签做处理,然后再训练。这种处理包括:简单数据集的数据增强;复杂任务的样本处理。相比fit,增加了线程数的设置等。
Keras源码文件结构
Keras核心类
主要类 |
成员函数 |
目录 |
功能 |
Layer(object) |
get_config、get_weights、set_weight |
keras/engine/base_layer.py |
基础类 |
InputLayer(Layer) |
|
keras/engine/input_layer.py |
|
Dense(Layer) Dropout(Layer) Activation(Layer) Flatten(Layer) |
|
keras/layers/core.py |
|
_Pooling1D/2D/3D(Layer) _GlobalPooling1D/2D/3D(Layer) AveragePooling2D(_Pooling2D) MaxPooling2D(_Pooling2D) |
compute_output_shape _pooling_function get_config call |
keras/layers/pooling.py |
保护类 |
重载_pooling_function |
|
|
|
BatchNormalization(Layer) |
build、call、get_config、compute_output_shape |
keras/layers/normalization.py |
|
Network(Layer) |
|
keras/engine/network.py |
|
Model(Network) |
compile、fit、evaluate、prediction train/test/predict fit/evaluate/predict_generator |
|
|
Sequential(Model) |
layers、model、add、pop、build、 predict_proba、predict_classes |
|
|
卷积层
主要类 |
成员函数 |
目录 |
功能 |
_Conv(Layer)_Cropping(Layer) _UpSampling(Layer) _ZeroPadding(Layer) _SeparableConv(_Conv)Conv1D/2D/3D(_Conv) Conv2D/3DTranspose(Conv2D) Cropping1D/2D/3D(_Cropping)UpSampling1D/2D/3D(_UpSampling) ZeroPadding1D/2D/3D(_ZeroPadding) SeparableConv1D/2D(_SeparableConv) Conv1D/2D/3DTranspose(Conv1D/2D/3D) |
get_config |
keras/layers/convolutional.py |
|
优化函数
主要类 |
成员函数 |
目录 |
功能 |
Optimizer(object) |
get_updates get_gradients get/set_weights get_config from_config |
keras/optimizers.py |
|
Adagrad(Optimizer) |
重载get_updates、get_config |
|
|
SGD(Optimizer) |
|
|
|
RMSprop(Optimizer) |
|
|
|
Adadelta(Optimizer) |
|
|
|
Adam(Optimizer) |
|
|
|
Adamax(Optimizer) |
|
|
|
Nadam(Optimizer) |
|
|
|
TFOptimizer(Optimizer) |
|
|
|
训练过程中的回调函数
实现训练超参数的更改、性能度量、记录等功能
callbacks |
|
功能 |
Callback(object) |
on_epoch_begin/end on_batch_begin/end on_train_begin/end |
|
CallbackList(object) |
|
|
LambdaCallback(Callback) |
|
自定义功能 |
BaseLogger(Callback) |
|
|
CSVLogger(Callback) |
|
将每个epochs的结果保存到表格 |
ReduceLROnPlateau(Callback) |
|
当性能不在提升时改变学习率 |
TensorBoard(Callback) |
|
|
LearningRateScheduler(Callback) |
|
学习率按照设定规则变化 |
TerminateOnNaN(Callback) |
|
当损失变成NaN时停止 |
ProgbarLogger(Callback) |
|
输出度量metrics |
History(Callback) |
|
保存历史记录 |
ModelCheckpoint(Callback) |
|
每隔epochs保存模型 |
EarlyStopping(Callback) |
|
当监视的性能指标不再提升时停止 |
RemoteMonitor(Callback) |
|
|
损失函数
keras损失函数 |
两个参数:y(预测)、y_(真实) |
mean_squared_error |
平均平方损失 |
mean_absolute_error |
|
mean_absolute_percentage_error |
平均百分比损失mean(|y-y_hat|/y_hat) |
mean_squared_logarithmic_error |
平均对数平方损失 |
squared_hinge |
|
squared_hinge |
|
hinge |
铰链损失,最大间隔分类(SVM)L=max(0,1-y*y_) |
categorical_hinge |
|
logcosh |
预测误差的双曲余弦的对数 |
categorical_crossentropy |
交叉熵 |
sparse_categorical_crossentropy |
|
binary_crossentropy |
|
kullback_leibler_divergence |
K-L散度,相对熵,信息增益 |
poisson |
|
cosine_proximity |
|
权重初始化方法
初始化类or函数 |
参数or功能 |
Initializer(object) |
|
Zeros(Initializer)、Ones(Initializer)、Constant(Initializer) |
|
RandomNormal(Initializer) |
mean、stddev |
RandomUniform(Initializer) |
minval、maxval |
TruncatedNormal(Initializer) |
mean、stddev 偏离均值两倍标准差的值被丢弃【推荐】 |
VarianceScaling(Initializer) |
scale、mode、distribution,该初始化方法能够自适应目标张量的shape |
Orthogonal(Initializer) |
随机初始化一个正交矩阵 |
Identity(Initializer) |
随机初始化一个单位矩阵 |
函数 |
|
lecun_uniform(seed=None) |
LeCun均匀分布 [-sqrt(3 / fan_in), sqrt(3 / fan_in)] |
lecun_normal(seed=None) |
mean=0, stddev = sqrt(1 / fan_in) |
glorot_normal(seed=None) |
mean=0, stddev=sqrt(2 / (fan_in + fan_out)) |
glorot_uniform(seed=None) |
[-limit, limit], limit=sqrt(6 / (fan_in + fan_out)) |
he_normal(seed=None) |
mean=0, stddev=sqrt(2 / fan_in) |
he_uniform(seed=None) |
[-sqrt(6 / fan_in), sqrt(6 / fan_in)] |