有时想要在一台有GPU的机器上(前提是这台机器已经装好了CUDA)训练好模型之后将模型转到CPU型的,这样就可以在没有GPU的机器上(或者没装cuda加速的机器)导入这个模型了。但是可能会遇到奇怪的错误,这里就记录一些贫僧遇到的奇怪的错误。
奇怪错误之读取模型
在进入了torch的交互命令行环境(就是用th
来进入的那个环境)之后,如果发现用m = torch.load('x.t7')
遇到了这种unknown Torch class <nn.gModule>
错误的话,那么很有可能没有导入全要导入的包,例如这里就少导入了
stack traceback:
[C]: in function 'error'require 'nngraph'
包,导入之后就可以读取模型了。
奇怪错误之nil什么什么的
使用m = torch.load('x.t7')
来读取模型(t7文件,鬼知道里面存了什么。。。)成功后,使用m = m:float()
来转化模型的时候,如果遇到了这个错误:
attempt to call method 'float' (a nil value) 。。。后面略
那么就是时候用torch.type(m)
来确定这个模型的类型了的说。通常会发现模型的类型不是名字含有tensor
的类型(并且很有可能是nil类型,毕竟错误信息都提示了)。遇到这种情况就要用type(m)
来确定这个东东到底是什么东西,例如贫僧在发动了type(m)
技能之后发现m
居然是个lua的table!
这种情况通常都是因为原模型训练者为了方便把模型(nn.gModule
类型)作为table的一部分,和其他一些附加信息(例如模型的设置、作者之类的)一起存在了这个.t7
文件里面。
那么怎么提取出真正的模型呢?
首先要做的就是确定这个读取到的table里面到底有哪些键值:
for key, value in pairs(protos) do
print(key)
end
上面的命令也是直接在交互环境下敲并运行就可以的了。
得到了键值之后可以依次用torch.type(xxx)
来确定这个内容是什么类型的,例如有一个键值是doc
那么就用torch.type(m.doc)
,或者torch.type(m["doc"])
,两个指令其实是一个意思。
如果全都不是的话可能原训练者很逗逼地把模型藏在了table的table里面。。。这时就在table里面找到的talbe上重复上面的步骤吧。。。
通过这种方式找到真正的模型之后(就是nn.gModule
类型的,可能也有别的类型,贫僧还是Torch7小白,不太清楚的说)再用回xxx:float()
转换,就可以了。例如m.doc
就是那部分模型:
model_to_save = m.doc:float()
torch.save('model.t7', model_to_save)
这样就可以保存好已经转化成cpu型的模型了。