CTCLoss backward error

Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/hzhj2007/article/details/81078248

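While using this project to train on Chinese text, the program failed during the backward pass of CTCLoss. To check whether CTCLoss itself was working, I found a piece of test code online.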

import torch
from torch.autograd import Variable
from warpctc_pytorch import CTCLoss
ctc_loss = CTCLoss()
# expected shape of seqLength x batchSize x alphabet_size
probs = torch.FloatTensor([[[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]]).transpose(0, 1).contiguous()
labels = Variable(torch.IntTensor([1, 2]))
label_sizes = Variable(torch.IntTensor([2]))
probs_sizes = Variable(torch.IntTensor([2]))
probs = Variable(probs, requires_grad=True) # tells autograd to compute gradients for probs
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()

The code above is a test snippet for warpctc. Running it raises the following error.

>>> cost.backward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib64/python2.7/site-packages/torch/autograd/variable.py", line 152, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
  File "/usr/lib64/python2.7/site-packages/torch/autograd/__init__.py", line 98, in backward
    variables, grad_variables, retain_graph)
RuntimeError: expected Variable or None (got torch.FloatTensor)

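The official PyTorch forum has an answer to this problem. The traceback shows that autograd expected a Variable (or None) but received a plain torch.FloatTensor, so the fix is to change the return value of the backward() function (line 50) in $PATH_TO_warp-ctc/pytorch_binding/warpctc_pytorch/__init__.py so that the gradient is wrapped in a Variable. After the change it looks like this.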

 48     @staticmethod
 49     def backward(ctx, grad_output):
 50         #return ctx.grads, None, None, None, None, None
 51         return torch.autograd.Variable(ctx.grads), None, None, None, None, None
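Once backward() runs, the forward value of the test case can also be checked by hand. The following is a minimal pure-Python sketch of the CTC forward (alpha) recursion, not warp-ctc's own code. It assumes, as the warp-ctc binding documents, that the input holds unnormalized activations and that a softmax is applied per time step, with index 0 as the blank label. The function names softmax and ctc_neg_log_likelihood are hypothetical helpers introduced for illustration only.

```python
import math


def softmax(row):
    """Numerically stable softmax over one time step of activations."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def ctc_neg_log_likelihood(activations, labels, blank=0):
    """Reference CTC loss for a single sequence (hypothetical helper).

    activations: list of T rows of unnormalized scores (assumed pre-softmax).
    labels: target label indices, none of which equals `blank`.
    """
    probs = [softmax(row) for row in activations]

    # Extended target: blank, l1, blank, l2, ..., blank
    ext = [blank]
    for label in labels:
        ext += [label, blank]
    size = len(ext)

    # alpha[s] = probability of all partial alignments ending in ext[s]
    alpha = [0.0] * size
    alpha[0] = probs[0][ext[0]]
    if size > 1:
        alpha[1] = probs[0][ext[1]]

    for t in range(1, len(probs)):
        new_alpha = [0.0] * size
        for s in range(size):
            total = alpha[s]
            if s >= 1:
                total += alpha[s - 1]
            # Skipping a blank is allowed only between two different labels
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                total += alpha[s - 2]
            new_alpha[s] = total * probs[t][ext[s]]
        alpha = new_alpha

    # Valid alignments end on the last label or on the trailing blank
    likelihood = alpha[-1] + (alpha[-2] if size > 1 else 0.0)
    return -math.log(likelihood)


acts = [[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]]
print(ctc_neg_log_likelihood(acts, [1, 2]))
```

With two time steps and two distinct labels there is only one valid alignment, so the result is simply the negative log of the product of two softmax outputs, roughly 2.46. If cost from the test code is far from that value, the problem is likely in the forward pass rather than in backward().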

Here are the other problems I ran into when testing the project, that is, when running demo.py:

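  • A KeyError for the network weights. When the cuda option is enabled, the training script wraps the model with crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu)), which changes the keys of the saved network (each key gets a "module." prefix, as the traceback shows). Adding the corresponding statement to demo.py fixes it: model = torch.nn.DataParallel(model, device_ids=range(1)). The network keys also need to be modified when fine-tuning.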
Traceback (most recent call last):
  File "demo.py", line 29, in <module>
    model.load_state_dict(pre_model)
  File "/usr/lib64/python2.7/site-packages/torch/nn/modules/module.py", line 332, in load_state_dict
    .format(name))
KeyError: 'unexpected key "module.cnn.conv0.weight" in state_dict'
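  • A dimension error. First, set the last numeric argument of CRNN (the network dimension) to the same value used in training. Second, comment out preds = preds.squeeze(2).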
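For the key mismatch in the first item, an alternative to wrapping the model in DataParallel is to rename the keys of the loaded state dict, which is also one way to do the key modification needed for fine-tuning. The sketch below is only an illustration: strip_module_prefix is a hypothetical helper, not part of the project, and it assumes the only difference between the two key sets is the leading "module." prefix seen in the traceback.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Hypothetical helper: works on any mapping of parameter names to values,
    so it can be tried on a plain dict before using it on a real checkpoint.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Stand-in for a checkpoint saved from a DataParallel-wrapped model
saved = OrderedDict([("module.cnn.conv0.weight", "w"), ("module.cnn.conv0.bias", "b")])
print(list(strip_module_prefix(saved).keys()))
```

In demo.py the cleaned mapping would then be passed to model.load_state_dict in place of pre_model, assuming pre_model is the dictionary returned by torch.load.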

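Training and testing finally worked, though the results are not worth mentioning. I later switched to the TensorFlow version of crnn, where only the number of output classes has to be changed, which is convenient and efficient.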

References:

  1. https://discuss.pytorch.org/t/ctcloss-backward-path/13576
