torchvision.models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型
调用pytorch内置的模型的方法
import torchvision model = torchvision.models.resnet50(pretrained=True)
这样就导入了resnet50的预训练模型了。如果只需要网络结构,不需要用预训练模型的参数来初始化
那么就是:
model = torchvision.models.resnet50(pretrained=False)
如果要导入densenet模型也是同样的道理
比如导入densenet169,且不需要是预训练的模型:
model = torchvision.models.densenet169(pretrained=False)
由于pretrained参数默认是False,所以等价于:
model = torchvision.models.densenet169()
不过为了代码清晰,最好还是加上参数赋值。
解读模型源码Resnet.py
包含的库文件
import torch.nn as nn import math import torch.utils.model_zoo as model_zoo
该库定义了6种Resnet的网络结构
包括
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
每种网络都有训练好的可以直接用的.pth参数文件
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
Resnet中大多使用3*3的卷积定义如下
def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
该函数继承自nn网络中的2维卷积,这样做主要是为了方便,少写参数参数由原来的6个变成了3个
输出图与输入图长宽保持一致
如何定义不同大小的Resnet网络
Resnet类是一个基类,
所谓的”Resnet18″, ‘resnet34′, ‘resnet50′, ‘resnet101′, ‘resnet152’只是Resnet类初始化的时候使用了不同的参数,理论上我们可以根据Resnet类定义任意大小的Resnet网络
下面先看看这些不同大小的Resnet网络是如何定义的
定义Resnet18
def resnet18(pretrained=False, **kwargs): """ Constructs a ResNet-18 model. Args: pretrained (bool):If True, returns a model pre-trained on ImageNet """ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) return model
定义Resnet34
def resnet34(pretrained=False, **kwargs): """Constructs a ResNet-34 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) return model
我们发现Resnet18和Resnet34的定义几乎是一样的,下面我们把Resnet18,Resnet34,Resnet50,Resnet101,Resnet152,不一样的部分写在一块进行对比
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) #Resnet18 model = R<span style="color:transparent">来源gaodai#ma#com搞*!代#%^码$网</span>esNet(BasicBlock, [3, 4, 6, 3], **kwargs) #Resnet34 model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) #Eesnt50 model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) #Resnet101 model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) #Resnet152
代码看起来非常的简洁工整,
其他resnet18、resnet101等函数和resnet18基本类似,差别主要是在:
1、构建网络结构的时候block的参数不一样,比如resnet18中是[2, 2, 2, 2],resnet101中是[3, 4, 23, 3]。
2、调用的block类不一样,比如在resnet50、resnet101、resnet152中调用的是Bottleneck类,而在resnet18和resnet34中调用的是BasicBlock类,这两个类的区别主要是在residual结果中卷积层的数量不同,这个是和网络结构相关的,后面会详细介绍。
3、如果下载预训练模型的话,model_urls字典的键不一样,对应不同的预训练模型。因此接下来分别看看如何构建网络结构和如何导入预训练模型。
Resnet类
构建ResNet
网络是通过ResNet这个类进行的。ResNet类是继承PyTorch
中网络的基类:torch.nn.Module。
构建Resnet类主要在于重写 init()
和 forward()
方法。
我们构建的所有网络比如:VGG
,Alexnet
等都需要重写这两个方法,这两个方法很重要
看起来Resne类是整个文档的核心
下面我们就要研究一下Resnet基类
以上就是pytorch教程resnet.py的实现文件源码分析的详细内容,更多请关注gaodaima搞代码网其它相关文章!