使用tf.tile()和tf.sparse_to_dense()构建多标签

使用tf.tile()和tf.sparse_to_dense()构建多标签

原文地址:https://blog.csdn.net/qq_31456593/article/details/89448262
多标签分类时需要构建类似one-hot的多标签y_hot,在tensorflow中可以用tf.tile()和tf.sparse_to_dense()构建

tf.tile()

tf.tile()应用于需要张量扩展的场景,具体说来就是:
如果现有一个形状如[width, height]的张量,需要得到一个基于原张量的,形状如[batch_size,width,height]的张量,其中每一个batch的内容都和原张量一模一样

tf.sparse_to_dense

tf.sparse_to_dense(sparse_indices, output_shape, sparse_values, default_value, name=None)
除去name参数用以指定该操作的name,与方法有关的一共四个参数:
第一个参数sparse_indices:稀疏矩阵中那些个别元素对应的索引值。

第二个参数output_shape:输出的稀疏矩阵的shape

第三个参数sparse_values:个别元素的值。

构造onehot

batch_size=6
label=tf.expand_dims(tf.constant([0,2,3,6,7,9]),1)
index=tf.expand_dims(tf.range(0, batch_size),1)
concated = tf.concat(1, [index, label])

构造多标签

		ids = tf.range(batch_size)
        ids = tf.expand_dims(ids, -1)
        m_ids = tf.tile(ids, [1, size])


        # 构建稀疏矩阵
        indices_one = tf.reshape(indices, [-1, 1])
        m_ids_one = tf.reshape(m_ids, [-1, 1])

        sparse_indices = tf.concat([m_ids_one, indices_one], -1)

        mul_hots = tf.sparse_to_dense(sparse_indices, [batch_size, n_labels], 1, 0,validate_indices=False)
        #print(mul_hots[0])

报错:InvalidArgumentError: indices[8] = [2,1] is out of order

原因:在不设置参数validate_indices=False时,tf.sparse_to_dense要求indices必须是递增的。这个主要是为了方便函数检查indices是否有重复的。

解决:设置validate_indices=False,关闭这个功能

发布了143 篇原创文章 · 获赞 345 · 访问量 47万+

猜你喜欢

转载自blog.csdn.net/qq_31456593/article/details/89448262