使用tf.tile()和tf.sparse_to_dense()构建多标签
原文地址:https://blog.csdn.net/qq_31456593/article/details/89448262
多标签分类时需要构建类似one-hot的多标签y_hot,在tensorflow中可以用tf.tile()和tf.sparse_to_dense()构建
tf.tile()
tf.tile()应用于需要张量扩展的场景,具体说来就是:
如果现有一个形状如[width, height]的张量,需要得到一个基于原张量的,形状如[batch_size,width,height]的张量,其中每一个batch的内容都和原张量一模一样
tf.sparse_to_dense
tf.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value, name=None)
除去name参数用以指定该操作的name,与方法有关的一共四个参数:
第一个参数sparse_indices:稀疏矩阵中那些个别元素对应的索引值。
第二个参数output_shape:输出的稀疏矩阵的shape
第三个参数sparse_values:个别元素的值。
构造onehot
batch_size=6
label=tf.expand_dims(tf.constant([0,2,3,6,7,9]),1)
index=tf.expand_dims(tf.range(0, batch_size),1)
concated = tf.concat(1, [index, label])
构造多标签
ids = tf.range(batch_size)
ids = tf.expand_dims(ids, -1)
m_ids = tf.tile(ids, [1, size])
# 构建稀疏矩阵
indices_one = tf.reshape(indices, [-1, 1])
m_ids_one = tf.reshape(m_ids, [-1, 1])
sparse_indices = tf.concat([m_ids_one, indices_one], -1)
mul_hots = tf.sparse_to_dense(sparse_indices, [batch_size, n_labels], 1, 0,validate_indices=False)
#print(mul_hots[0])
报错:InvalidArgumentError: indices[8] = [2,1] is out of order
原因:在不设置参数validate_indices=False时,tf.sparse_to_dense要求indices必须是递增的。这个主要是为了方便函数检查indices是否有重复的。
解决:设置validate_indices=False,关闭这个功能