pytorch主要分为以下几个模块来训练模型:
tensor:tensor为基本结构,可以直接创建,从list创建以及由numpy数组得到,torch还提供一套运算以及shape变换方式。
Variable:自动求导机制,利用Variable包装tensor后,便可以使用其求导的功能了,有点像个装饰器。
nn:nn模块是整个pytorch的核心,自己设计的Net(),继承nn.Model后可以提取模型参数,进行前向forward()运算(自己设计),以及后向运算(自动),nn提供基本网络结构单元,例如nn.Linear(),nn.Conv2d()等,还提供基本损失函数nn.CrossEntropyLoss等。
torch.optim:该模块提供自动求导更新参数等功能,用它封装模型参数nn.parameter()后,loss求导后,可以用.step来更新整个参数。
torch.utils.data.DataSet:该模块提供加载数据初始化的方式,完善好getitem和len的接口后,便可以利用DataLoader多进程加载数据。
参考: