PytorchCNN项目搭建4---常见的卷积神经网络cnn

PytorchCNN项目搭建4---常见的卷积神经网络cnn


整体的代码在我的github上面可以查阅


常见的卷积神经网络

今日的主要目的不是介绍各个CNN的原理等,原理性的介绍很多博主已经解释的很详细了。本文主要目的是如何使用这些models,之后进行训练学习,所以,我会列举几个重要的CNN,给出文章及基本的介绍,大家可以自行查阅。

 
我们把需要的 cnn_models 写入新建的models文件夹中,统一调用

常见的CNN有:Lenet、Alexnet、VGG、Inception、Resnet等。具体的文章和代码详见参考文献[2]


联合Argparse调用训练的网络

之前,我们学习了argparse的使用,我们把每次使用的网络模型也作为参数进行输入,需要一个函数来将输入的参数和实际的models联系起来,在utils文件夹下,写一个get_net.py 函数,达到需要的要求。

import os
import sys
from models import resnet,vgg,inception,squeezenet
from models.resnet import resnet18,resnet34,resnet50,resnet101,resnet152
from models.vgg import vgg11,vgg11_bn,vgg13,vgg13_bn,vgg16,vgg16_bn,vgg19,vgg19_bn
from models.inception import inception_v3
from models.squeezenet import squeezenet1_0
from models.alexnet import alexnet
from models.lenet5 import LeNet5
import pdb

def get_network(args,cfg):
    """ return given network
    """
    # pdb.set_trace()
    if args.net == 'lenet5':
        net = LeNet5().cuda()
    elif args.net == 'alexnet':
        net = alexnet(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg16':
        net = vgg16(pretrained=args.pretrain, num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg13':
        net = vgg13(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg11':
        net = vgg11(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg19':
        net = vgg19(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg16_bn':
        net = vgg16_bn(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg13_bn':
        net = vgg13_bn(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg11_bn':
        net = vgg11_bn(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'vgg19_bn':
        net = vgg19_bn(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net =='inceptionv3':
        net = inception_v3().cuda()
    elif args.net == 'resnet18':
        net = resnet18(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda(args.gpuid)
    elif args.net == 'resnet34':
        net = resnet34(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'resnet50':
        net = resnet50(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda(args.gpuid)
    elif args.net == 'resnet101':
        net = resnet101(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'resnet152':
        net = resnet152(pretrained=args.pretrain,num_classes=cfg.PARA.train.num_classes).cuda()
    elif args.net == 'squeezenet':
        net = squeezenet1_0().cuda()
    else:
        print('the network name you have entered is not supported yet')
        sys.exit()

    return net


参考文献

1. Pytorch_Torchvision官方CNN Github代码

2. 深度学习的经典算法的论文、解读和代码实现

最后感谢我的师兄,是他手把手教我搭建了整个项目,还有实验室一起学习的小伙伴~ 希望他们万事胜意,鹏程万里!

猜你喜欢

转载自blog.csdn.net/qq_44783177/article/details/113770624