Copyright notice: this is an original article by the blogger; please contact the author for permission before reposting. https://blog.csdn.net/xingchenbingbuyu/article/details/77899270
1. Environment
- Windows 7
- Python 3.6 (Anaconda3)
- Keras 2 API
2. The problem when saving the model
The saving code is as follows:
model_name = 'ssd7_0'
model.save("ssd7_0.h5")
model.save_weights(r'ssd7_0_weights.h5')
Running it produced the following error:
---------------------------------------------------------------------------
UnicodeDecodeError Traceback (most recent call last)
<ipython-input-31-bd7fdea3d294> in <module>()
1 model_name = 'ssd7_0'
----> 2 model.save("ssd7_0.h5")
3 model.save_weights(r'ssd7_0_weights.h5')
4
5 print()
D:\Software\Anaconda3\lib\site-packages\keras\engine\topology.py in save(self, filepath, overwrite, include_optimizer)
2551 """
2552 from ..models import save_model
-> 2553 save_model(self, filepath, overwrite, include_optimizer)
2554
2555 def save_weights(self, filepath, overwrite=True):
D:\Software\Anaconda3\lib\site-packages\keras\models.py in save_model(model, filepath, overwrite, include_optimizer)
105 f.attrs['model_config'] = json.dumps({
106 'class_name': model.__class__.__name__,
--> 107 'config': model.get_config()
108 }, default=get_json_type).encode('utf8')
109
D:\Software\Anaconda3\lib\site-packages\keras\engine\topology.py in get_config(self)
2324 for layer in self.layers: # From the earliest layers on.
2325 layer_class_name = layer.__class__.__name__
-> 2326 layer_config = layer.get_config()
2327 filtered_inbound_nodes = []
2328 for original_node_index, node in enumerate(layer.inbound_nodes):
D:\Software\Anaconda3\lib\site-packages\keras\layers\core.py in get_config(self)
657 def get_config(self):
658 if isinstance(self.function, python_types.LambdaType):
--> 659 function = func_dump(self.function)
660 function_type = 'lambda'
661 else:
D:\Software\Anaconda3\lib\site-packages\keras\utils\generic_utils.py in func_dump(func)
173 A tuple `(code, defaults, closure)`.
174 """
--> 175 code = marshal.dumps(func.__code__).decode('raw_unicode_escape')
176 defaults = func.__defaults__
177 if func.__closure__:
UnicodeDecodeError: 'rawunicodeescape' codec can't decode bytes in position 80-81: truncated \UXXXXXXXX escape
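The mechanism can be reproduced without Keras. Below is a minimal sketch. It assumes, as is typical, that the offending backslash comes from a Windows path recorded in the lambda's code object. The path D:\Users\demo\model.py and the lambda are hypothetical stand-ins, and the decode call mirrors line 175 in the traceback above.

```python
import marshal

# Hypothetical Windows path standing in for the module that defines the Lambda's function.
WINDOWS_PATH = r"D:\Users\demo\model.py"

# Compile a lambda "from" that file; its code object records the path in co_filename.
func = eval(compile("lambda x: x + x", WINDOWS_PATH, "eval"))
raw = marshal.dumps(func.__code__)

# The path is stored verbatim in the marshaled bytes, backslashes included.
print(WINDOWS_PATH.encode("ascii") in raw)  # True

error_reason = None
try:
    # Same decode that Keras 2.0.x func_dump performs (line 175 in the traceback).
    raw.decode("raw_unicode_escape")
except UnicodeDecodeError as exc:
    # "\U" from "\Users" opens a \UXXXXXXXX escape, but "sers" are not hex digits.
    error_reason = exc.reason
print(error_reason)
```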
3. Solution
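This is a platform-specific problem caused by Windows paths. For a Lambda layer, Keras serializes the function's code object with marshal and then decodes the resulting bytes with the raw_unicode_escape codec. The marshaled code object contains the path of the file in which the function was defined. On Windows, a backslash sequence in that path such as \U (as in C:\Users) is read as the start of a \UXXXXXXXX escape, so the decoding fails. Several issues about this error have been filed in the Keras GitHub repository, and the same question has been asked on Stack Overflow.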
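Following the fix suggested there, I modified the Keras source code and restarted the IPython notebook, after which the model saved successfully. The price, however, was retraining the model from scratch...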
The fix: in D:\Software\Anaconda3\Lib\site-packages\keras\utils\generic_utils.py, change line 175 as follows:
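# Original line 175:
#code = marshal.dumps(func.__code__).decode('raw_unicode_escape')
# Patched: replace backslashes (e.g. from Windows paths) with '/' before decoding
code = marshal.dumps(func.__code__).replace(b'\\', b'/').decode('raw_unicode_escape')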
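Below is a minimal sketch of what the patched line does, using the same hypothetical path and lambda. The loading side assumes that Keras 2.0.x func_load reverses the dump with encode('raw_unicode_escape') followed by marshal.loads, so check this against your own generic_utils.py. The sketch also shows two caveats. First, the replace rewrites every backslash byte in the marshaled code, not only those in file paths, so a string constant containing a backslash changes too. Second, base64-encoding the bytes avoids the text decode entirely. As far as I know, later Keras releases switched to that approach.

```python
import codecs
import marshal
import types

# Hypothetical example: a Lambda-style function "defined" in a file under a Windows path.
WINDOWS_PATH = r"D:\Users\demo\model.py"
func = eval(compile("lambda x: x + x", WINDOWS_PATH, "eval"))
raw = marshal.dumps(func.__code__)

# Patched line 175: every backslash byte becomes '/', so raw_unicode_escape cannot fail.
patched = raw.replace(b"\\", b"/").decode("raw_unicode_escape")

# Loading side (assumed to mirror Keras 2.0.x func_load): encode back, then unmarshal.
restored = types.FunctionType(marshal.loads(patched.encode("raw_unicode_escape")), {})
print(restored(3))                    # 6
print(restored.__code__.co_filename)  # D:/Users/demo/model.py

# Caveat: the replace is not limited to paths; a backslash in a string constant changes too.
const_func = eval(compile(r'lambda: "C:\\data"', WINDOWS_PATH, "eval"))
const_raw = marshal.dumps(const_func.__code__)
const_patched = const_raw.replace(b"\\", b"/").decode("raw_unicode_escape")
const_restored = types.FunctionType(
    marshal.loads(const_patched.encode("raw_unicode_escape")), {}
)
print(const_func(), const_restored())  # C:\data C:/data

# Lossless alternative: base64-encode the marshaled bytes instead of decoding them as text.
encoded = codecs.encode(raw, "base64").decode("ascii")
b64_code = marshal.loads(codecs.decode(encoded.encode("ascii"), "base64"))
b64_restored = types.FunctionType(b64_code, {})
print(b64_restored(3), b64_restored.__code__.co_filename)  # 6 D:\Users\demo\model.py
```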