本章将开始进行机器学习中的分类问题,虽然名字叫回归,但他是做分类的。分类问题最基本的一个思维方式:比如识别手写数字,模型会将输入的图片进行十个类别的预测,并给出10个概率。从当中选取概率最高的一个作为预测结果,即为多分类。
这里我们采用一个经典的数据集——minist,该数据集可从pytorch中下载,代码如下:
import torchvision
train_set = torchvision.datasets.MNIST(root=’../dataset/mnist',train=FTrue,download=True)
test_set = torchvision.datasets.MNIST(root='../dataset/mnist',train=False,download=True)
pytorch中还提供了CIFAR-10数据集,就是一堆32×32的小图片,包含50000个训练集与10000个测试集,有10个类别。
由于预测y=wx+b中,y∈R,但输出的概率需要是[0, 1],所以我们需要将预测结果映射到[0, 1]上。这里我们就要用到了logistics函数,该函数图像如图1所示,位于[0, 1]。
用这个函数即可将y_hat映射到需要的区间上,而logistics又称sigmoid,在pytorch的库中就把logistics叫做sigmoid。在深度学习论文中,如果看到σ(),那就是在用sigmoid函数在激活。它与线性回归的唯一区别就是加了一个σ偏置。在代码上的区别如图2所示:
在二分类问题中的loss函数需要用到的公式为:
该函数称为BCE Loss,在代码中的使用为:
criterion = torch.nn.BCELoss(size_average = False)
整个的代码变化只有两处,这样一个框架结构可以编写大量的模型。
下一节将进行处理多维特征输入。