文章目录
Pytorch中的常用函数与操作
PYtorch中的各种函数就像英语单词一样,见得多用得多就慢慢掌握了,这里以DQN代码为例,记录我经常用的和碰见的函数方法。(还会有一些python操作)
.detach().cpu()
out = model(inputs)
ls.append(out.detach().cpu().numpy())
detach
阻断反向传播的,经过detach()方法后,变量仍然在GPU上,再利用.cpu()
将数据移至CPU中进行后续操作,如tensor变量转numpy。
np.array与np.ndarray的区别
import numpy as np
# numpy.array() 和 numpy.ndarray()的区别?
mat1 = np.array([[1,2,3],[4,5,6]])
print("mat1 data:{}".format(mat1))
print("mat1 type:{}".format(type(mat1)))
print("mat1 dtype:{}".format(mat1.dtype))
mat2 = np.ndarray(shape=(2,3), dtype=np.int32)
print("mat2 data:{}".format(mat2))
print("mat2 type:{}".format(type(mat2)))
print("mat2 dtype:{}".format(mat2.dtype))
>>>output
mat1 data:[[1 2 3]
[4 5 6]]
mat1 type:<class 'numpy.ndarray'>
mat1 dtype:int32
mat2 data:[[ -153199152 440 0]
[ 0 131074 -2147483648]]
mat2 type:<class 'numpy.ndarray'>
mat2 dtype:int32
ndarray是一个类,其默认构造函数是ndarray()。
array是一个函数,便于创建一个ndarray对象。
np.ndarray()构造函数相对更low-level一些,使用默认构造函数创建的ndarray对象的数组元素是随机值,而numpy提供了一系列的创建ndarray对象的函数,array()就是其中的一种;通常使用这些上层一点的函数来构造ndarray对象会更方便一些3
torch.load
torch.save(net.state_dict(), 'test.pth') # save的是net的state_dict
net.load_state_dict(torch.load('test.pth')) # 加载的也是state_dict
因为state_dict本质上Python字典对象,所以可以很好地进行保存、更新、修改和恢复操作(python字典结构的特性),从而为PyTorch模型和优化器增加了大量的模块化。
.modules()和.children()
model.modules()能够迭代地遍历模型的所有子层,而model.children()只会遍历模型的子层。4
注意这两实现的时候用的是set,相同的网络只输出一次,但是使用nn.Sequential就不会有这种困扰
def named_modules(self, memo: Optional[Set['Module']] = None, prefix: str = '', remove_duplicate: bool = True):
r"""Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(string, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
Example::
>>> l = nn.Linear(2, 2)
>>> net = nn.Sequential(l, l)
>>> for idx, m in enumerate(net.named_modules()):
print(idx, '->', m)
0 -> ('', Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
))
1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
"""
if memo is None:
memo = set()
if self not in memo:
if remove_duplicate:
memo.add(self)
yield prefix, self
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ('.' if prefix else '') + name
for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
yield m
关键词yield
的解释:5
题外话:foo函数的命名由来
在计算机程序设计与计算机技术的相关文档中,术语foobar是一个常见的无名氏化名,常被作为“伪变量”使用。 从技术上讲,“foobar”很可能在1960年代至1970年代初通过迪吉多的系统手册传播开来。另一种说法是,“foobar”可能来源于电子学中反转的foo信号;这是因为如果一个数字信号是低电平有效,那么在信号标记上方一般会标有一根水平横线,而横线的英文即为“bar”。在《新黑客辞典》中,还提到“foo”可能早于“FUBAR”出现。http://zh.wikipedia.org/zh-cn/
permute和reshape/view的区别
permute作用为调换Tensor的维度,参数为调换的维度。例如对于一个二维Tensor来说,调用tensor.permute(1,0)意为将1轴(列轴)与0轴(行轴)调换,相当于进行转置。使用view或者reshape,得到的tensor并不是转置的效果,而是相当于将原tensor的元素按行取出,然后按行放入到新形状的tensor中。6
In [20]: a
Out[20]:
tensor([[0, 1, 2],
[3, 4, 5]])
In [21]: a.permute(1,0)
Out[21]:
tensor([[0, 3],
[1, 4],
[2, 5]])
In [22]: a.reshape(3,2)
Out[22]:
tensor([[0, 1],
[2, 3],
[4, 5]])
In [23]: a.view(3,2)
Out[23]:
tensor([[0, 1],
[2, 3],
[4, 5]])
可以理解为,对于一个高维的Tensor执行permute,我们没有改变数据的相对位置,而只是旋转了一下这个(超)立方体。或者也可以说,改变了我们对这个(超)立方体的“观察角度”而已。
.parameters()
#网络参数数量
def get_parameter_number(net):
total_num = sum(p.numel() for p in net.parameters())
trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('Total:{}, Trainable:{}'.format(total_num, trainable_num))
.diag()
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(a)
# output:
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
print(torch.diag(a))
# output:
# tensor([1, 5, 9])
print(torch.diag(a, 1))
# output:
# tensor([2, 6])
print(torch.diag(a, -1))
# output:
# tensor([4, 8])
print(torch.diag(a, 2))
# output:
# tensor([3])
print(torch.diag(a, -2))
# output:
# tensor([7])