出现这种问题的原因是,一般都是数据维度的原因,导致参数不匹配,一定要检查参数维度。
你比如,我报错的原因:
我的 image_placeholder声明的是[10,16,112,112,3],但是后面我在调用的时候,却这样用了
logits = c3d_model.inference(image_placeholder[10,:, :, :, :], 。。。。。)
这里的这样(注意10后面逗号)的话就会导致维度错误,有的同学会说,哦!我懂了,把逗号去了就行了,然而并不是这样的。
接下来我们来看一下代码:
import numpy as np
a = np.ones([5,6,7,7,9])
#print(a)
print("加逗号:",a[0,:,:,:,:].shape) #[6,7,7,9]
print("加逗号:",a[1,:,:,:,:].shape) #[6,7,7,9]
print("加逗号:",a[2,:,:,:,:].shape) #[6,7,7,9]
print("加逗号:",a[3,:,:,:,:].shape) #[6,7,7,9]
print("加逗号:",a[4,:,:,:,:].shape) #[6,7,7,9]
#print("加逗号:",a[5,:,:,:,:].shape) #报错:
print("--------------------------")
#上述都等价于a[i]
print("不加逗号:",a[0:,:,:,:].shape) #[5,6,7,7,9]
print("不加逗号:",a[1:,:,:,:].shape) #[4,6,7,7,9]
print("不加逗号: ",a[2:,:,:,:].shape) #[3,6,7,7,9]
print("不加逗号: ",a[3:,:,:,:].shape) #[2,6,7,7,9]
print("不加逗号: ",a[4:,:,:,:].shape) #[1,6,7,7,9]
print("不加逗号: ",a[9:,:,:,:].shape) #不报错 #[0,6,7,7,9],这样的维度的是什么东西,其实就是一个空list :[].
#上述都等价于啊a[i:]
ps:只要维度里出现0,则这都是一个空列表。
相信看到这里不少小伙伴大概明白了,我们稍加解释:
(1)加逗号:即当前维度与下一维度之间有逗号,表示提取第几个slice,这样产生的数据会降低一个维度,并且对数据有要求,只能取值为0-dims(当前维度) -1。否则会报错:ValueError: slice index xxx of dimension 0 out of bounds。
(2)不加逗号:即当前维度与下一维度之间无逗号,表示从当前维度的第几个值进行切片,不会改变原始数据的维度。当取值>dims(当前维度) -1 时,不报错,但是会产生空列表。(小技巧,查逗号,看数值出现在第几个逗号处,则对哪一个维度切片)
接下来有一些例子来加深我们的理解,大家可以先手动计算,然后再写代码验证:
import numpy as np
a = np.ones([5,6,7,8,9])
#print(a)
print("加逗号:",a[:,0,:,:,:].shape)
print('加逗号:',a[3,:,:,1,:].shape)
print('不加逗号:',a[3:,:,2:].shape) #小技巧,查逗号,看数值出现在第几个逗号处,则对哪一个维度取切片。
加逗号: (5, 7, 8, 9)
加逗号: (6, 7, 9)
不加逗号: (2, 6, 5, 8, 9)
熟悉:c3d_model的朋友,或许对个语句不陌生:
images_placeholder[gpu_index * FLAGS.batch_size:(gpu_index + 1) * FLAGS.batch_size,:,:,:,:],
其实很容易理解,这又涉及到另外一种情况:我们来看代码:
import numpy as np
batch_size = 8
gpu_index = 0
a = np.ones([4*8,6,7,8,9])
print(a[gpu_index * batch_size:(gpu_index + 1) * batch_size,:,:,:,:].shape)
for i in range (4):
print(": ",a[:(i + 1)*batch_size,:,:,:,:].shape)
print("i*batch_size: ",a[i * batch_size:(i + 1)*batch_size,:,:,:,:].shape)
(8, 6, 7, 8, 9)
: (8, 6, 7, 8, 9)
i*batch_size: (8, 6, 7, 8, 9)
: (16, 6, 7, 8, 9)
i*batch_size: (8, 6, 7, 8, 9)
: (24, 6, 7, 8, 9)
i*batch_size: (8, 6, 7, 8, 9)
: (32, 6, 7, 8, 9)
i*batch_size: (8, 6, 7, 8, 9)
分析,这里表示对第最高维是(i + 1)*batch_size之前的切片,从i * batch_size处取切片。添加一个:且与当前数值维度无逗号间隔,相当于添加一个维度。
这里我们详细的讲解了上知识,具体的情况大家也可以编写代码验证,希望维度问题这种问题不再影响大家心情。