In real-world TensorFlow applications, when we need to store a sparse matrix we use a `SparseTensor`, which saves a great deal of space.
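When we need to split a sparse matrix, however, the intuitive approach is to convert the `SparseTensor` to a dense tensor, split that, and then convert the pieces back to `SparseTensor`s. Besides being slow, the conversion to dense defeats our original goal of saving space.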
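To put rough numbers on the space argument, here is a back-of-the-envelope sketch. It assumes the usual `SparseTensor` layout (an int64 `indices` matrix of shape `[nnz, rank]`, a `values` vector, and an int64 `dense_shape` vector). The helper names and the example shape are made up for illustration.

```python
def dense_bytes(dense_shape, itemsize):
    """Bytes needed to store every element of a dense tensor."""
    total = 1
    for dim in dense_shape:
        total *= dim
    return total * itemsize


def sparse_bytes(dense_shape, nnz, itemsize):
    """Approximate bytes for a COO-style sparse tensor (assumed layout)."""
    rank = len(dense_shape)
    indices = nnz * rank * 8  # int64 coordinates, one row per non-zero
    values = nnz * itemsize   # the non-zero values themselves
    shape = rank * 8          # int64 dense_shape vector
    return indices + values + shape


# Hypothetical example: a 100000 x 100000 float32 matrix with 1000 non-zeros.
shape = (100000, 100000)
print(dense_bytes(shape, itemsize=4))             # 40000000000
print(sparse_bytes(shape, nnz=1000, itemsize=4))  # 20016
```

Even a temporary dense copy costs the full first figure, which is why the round trip through a dense tensor is worth avoiding.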
- The correct solution is `tf.sparse_split`, which is defined as:
def sparse_split(keyword_required=KeywordRequired(),
                 sp_input=None,
                 num_split=None,
                 axis=None,
                 name=None,
                 split_dim=None):
  """Split a `SparseTensor` into `num_split` tensors along `axis`.

  If the `sp_input.dense_shape[axis]` is not an integer multiple of `num_split`
  each slice starting from 0:`shape[axis] % num_split` gets extra one
  dimension. For example, if `axis = 1` and `num_split = 2` and the
  input is:

      input_tensor = shape = [2, 7]
      [    a   d e  ]
      [b c          ]

  Graphically the output tensors are:

      output_tensor[0] =
      [    a   ]
      [b c     ]

      output_tensor[1] =
      [ d e  ]
      [      ]

  Args:
    keyword_required: Python 2 standin for * (temporary for argument reorder)
    sp_input: The `SparseTensor` to split.
    num_split: A Python integer. The number of ways to split.
    axis: A 0-D `int32` `Tensor`. The dimension along which to split.
    name: A name for the operation (optional).
    split_dim: Deprecated old name for axis.

  Returns:
    `num_split` `SparseTensor` objects resulting from splitting `value`.

  Raises:
    TypeError: If `sp_input` is not a `SparseTensor`.
    ValueError: If the deprecated `split_dim` and `axis` are both non None.
  """
- Example usage:
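import tensorflow as tf

# A 2x2 sparse matrix with non-zeros at (0, 0) and (1, 1).
a = tf.SparseTensor(indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=(2, 2))
# Split along axis 1 (columns) into two 2x1 sparse tensors.
b, c = tf.sparse_split(sp_input=a, num_split=2, axis=1)

with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(b))
    print(sess.run(c))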
Output:
SparseTensorValue(indices=array([[0, 0],
[1, 1]]), values=array([1, 2], dtype=int32), dense_shape=array([2, 2]))
SparseTensorValue(indices=array([[0, 0]]), values=array([1], dtype=int32), dense_shape=array([2, 1]))
SparseTensorValue(indices=array([[1, 0]]), values=array([2], dtype=int32), dense_shape=array([2, 1]))
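To make the splitting rule concrete, here is a minimal pure-Python sketch of what the docstring describes. It is not TensorFlow's implementation: `split_sparse` is a hypothetical helper that works on plain lists of indices and values, and the coordinates used in the uneven `[2, 7]` case are assumed for illustration.

```python
def split_sparse(indices, values, dense_shape, num_split, axis):
    """Split COO data into `num_split` pieces along `axis`.

    Illustrative only: follows the rule in the docstring, where the first
    `dense_shape[axis] % num_split` pieces are one element longer.
    """
    base, extra = divmod(dense_shape[axis], num_split)

    # Start offset and length of every piece along `axis`.
    bounds = []
    start = 0
    for i in range(num_split):
        length = base + (1 if i < extra else 0)
        bounds.append((start, length))
        start += length

    pieces = []
    for start, length in bounds:
        piece_indices, piece_values = [], []
        for idx, val in zip(indices, values):
            if start <= idx[axis] < start + length:
                shifted = list(idx)
                shifted[axis] -= start  # re-base the coordinate inside the piece
                piece_indices.append(shifted)
                piece_values.append(val)
        piece_shape = list(dense_shape)
        piece_shape[axis] = length
        pieces.append((piece_indices, piece_values, piece_shape))
    return pieces


# Same data as the TensorFlow example: a 2x2 matrix split by columns.
b, c = split_sparse([[0, 0], [1, 1]], [1, 2], [2, 2], num_split=2, axis=1)
print(b)  # ([[0, 0]], [1], [2, 1])
print(c)  # ([[1, 0]], [2], [2, 1])

# Uneven case: 7 columns split two ways gives widths 4 and 3.
indices = [[0, 2], [0, 4], [0, 5], [1, 0], [1, 1]]
values = ["a", "d", "e", "b", "c"]
left, right = split_sparse(indices, values, [2, 7], num_split=2, axis=1)
print(left)   # ([[0, 2], [1, 0], [1, 1]], ['a', 'b', 'c'], [2, 4])
print(right)  # ([[0, 0], [0, 1]], ['d', 'e'], [2, 3])
```

Note that the indices in each piece are re-based along the split axis, which is why the second result above reports its non-zero at `[1, 0]` rather than `[1, 1]`.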
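- Note:
The example above was run under Python 3. Under Python 2, we had to pass the `keyword_required` argument explicitly; otherwise the call fails with the error: Keyword arguments are required for this function. Because `keyword_required` is the first parameter in the signature, passing `sp_input` positionally lands in that slot, so keep every argument as a keyword.
The Python 2 call looks like this: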
from tensorflow.python.ops.sparse_ops import KeywordRequired
import tensorflow as tf

a = tf.SparseTensor(indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=(2, 2))
b, c = tf.sparse_split(keyword_required=KeywordRequired(), sp_input=a, num_split=2, axis=1)

with tf.Session() as sess:
    print(sess.run(a))
    print(sess.run(b))
    print(sess.run(c))
This resolves the error.