版权声明:本文为博主原创文章,欢迎交流分享,未经博主允许不得转载。 https://blog.csdn.net/HHTNAN/article/details/82493952
» 嵌入层 Embedding
Embedding
keras.layers.Embedding(input_dim, output_dim,
embeddings_initializer='uniform', embeddings_regularizer=None,
activity_regularizer=None, embeddings_constraint=None, mask_zero=False,
input_length=None)
将正整数(索引值)转换为固定尺寸的稠密向量。 例如: [[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]
该层只能用作模型中的第一层。
参数
input_dim: int > 0。词汇表大小, 即,最大整数 index + 1。
output_dim: int >= 0。词向量的维度。
embeddings_initializer: embeddings 矩阵的初始化方法 (详见 initializers)。
embeddings_regularizer: embeddings matrix 的正则化方法 (详见 regularizer)。
embeddings_constraint: embeddings matrix 的约束函数 (详见 constraints)。
mask_zero: 是否把 0 看作为一个应该被遮蔽的特殊的 "padding" 值。 这对于可变长的 循环神经网络层 十分有用。 如果设定为 True,那么接下来的所有层都必须支持 masking,否则就会抛出异常。 如果 mask_zero 为 True,作为结果,索引 0 就不能被用于词汇表中 (input_dim 应该与 vocabulary + 1 大小相同)。
input_length: 输入序列的长度,当它是固定的时。 如果你需要连接 Flatten 和 Dense 层,则这个参数是必须的 (没有它,dense 层的输出尺寸就无法计算)。
输入尺寸
尺寸为 (batch_size, sequence_length) 的 2D 张量。
输出尺寸
尺寸为 (batch_size, sequence_length, output_dim) 的 3D 张量。
参考文献
A Theoretically Grounded Application of Dropout in Recurrent Neural Networks
案例:
from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM
from keras_contrib.layers import CRF
import numpy as np
model = Sequential()
model.add(Embedding(input_dim=1000, output_dim=60, input_length=10))
# 模型将输入一个大小为 (batch, input_length) 的整数矩阵。
# 输入中最大的整数(即词索引)不应该大于 999 (词汇表大小)
# 现在 model.output_shape == (None, 10, 60),其中 None 是 batch 的维度。
input_array = np.random.randint(1000, size=(32, 10))
print("input_array.shape={},len(input_array)={}".format(input_array.shape,len(input_array)))
print(input_array)
model.compile('rmsprop', 'mse')
output_array = model.predict(input_array)
assert output_array.shape == (32, 10, 60)
input_array.shape=(32, 10),len(input_array)=32
[[ 21 772 551 347 451 993 593 219 923 117]
[711 600 601 66 984 581 671 292 963 39]
[810 978 800 377 224 68 113 526 466 258]
[908 145 471 724 519 795 926 904 879 29]
[475 230 469 157 0 715 274 680 880 820]
[344 889 34 938 915 563 384 947 752 405]
[302 371 427 77 861 99 352 467 438 653]
[682 536 321 221 137 48 387 380 36 409]
[569 812 825 751 850 8 704 532 443 973]
[226 634 491 294 512 65 434 88 653 76]
[229 419 633 426 751 966 599 794 404 488]
[792 259 833 130 65 561 361 282 815 372]
[733 282 692 434 949 939 221 847 425 341]
[666 510 690 842 801 981 556 777 10 438]
[156 71 338 705 475 548 48 766 317 237]
[109 919 138 640 508 522 236 17 444 604]
[869 817 372 725 369 24 78 330 910 684]
[573 579 409 41 83 310 591 617 0 56]
[669 327 353 92 238 741 429 692 626 174]
[924 328 43 529 329 409 929 44 204 114]
[981 408 10 212 999 150 233 384 911 557]
[ 14 615 573 565 422 899 35 498 204 534]
[126 906 160 352 690 405 427 422 657 693]
[821 520 896 164 898 539 450 355 236 292]
[390 970 631 93 112 589 506 625 76 436]
[732 790 494 874 113 131 657 426 558 398]
[753 748 146 554 255 849 824 766 954 809]
[ 96 997 313 376 986 839 378 959 689 395]
[ 98 502 699 400 131 718 20 619 909 385]
[867 757 430 605 63 172 964 344 835 309]
[637 746 759 790 382 811 647 899 867 580]
[478 284 838 146 428 637 311 221 175 849]]