• 欢迎访问搞代码网站,推荐使用最新版火狐浏览器和Chrome浏览器访问本网站!
  • 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏搞代码吧

pytorch教程resnet.py的实现文件源码分析

python 搞代码 4年前 (2022-01-08) 34次浏览 已收录 0个评论
文章目录[隐藏]

torchvision.models这个包中包含alexnet、densenet、inception、resnet、squeezenet、vgg等常用的网络结构,并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型

调用pytorch内置的模型的方法

 import torchvision model = torchvision.models.resnet50(pretrained=True) 

这样就导入了resnet50的预训练模型了。如果只需要网络结构,不需要用预训练模型的参数来初始化

那么就是:

 model = torchvision.models.resnet50(pretrained=False) 

如果要导入densenet模型也是同样的道理

比如导入densenet169,且不需要是预训练的模型:

 model = torchvision.models.densenet169(pretrained=False) 

由于pretrained参数默认是False,所以等价于:

 model = torchvision.models.densenet169() 

不过为了代码清晰,最好还是加上参数赋值。

解读模型源码Resnet.py

包含的库文件

 import torch.nn as nn import math import torch.utils.model_zoo as model_zoo 

该库定义了6种Resnet的网络结构

包括

 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50',  'resnet101',  'resnet152'] 

每种网络都有训练好的可以直接用的.pth参数文件

 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50',  'resnet101',  'resnet152'] 

Resnet中大多使用3*3的卷积定义如下

 def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 

该函数继承自nn网络中的2维卷积,这样做主要是为了方便,少写参数参数由原来的6个变成了3个

输出图与输入图长宽保持一致

如何定义不同大小的Resnet网络

Resnet类是一个基类,
所谓的”Resnet18″, ‘resnet34′, ‘resnet50′, ‘resnet101′, ‘resnet152’只是Resnet类初始化的时候使用了不同的参数,理论上我们可以根据Resnet类定义任意大小的Resnet网络
下面先看看这些不同大小的Resnet网络是如何定义的

定义Resnet18

 def resnet18(pretrained=False, **kwargs): """ Constructs a ResNet-18 model. Args: pretrained (bool):If True, returns a model pre-trained on ImageNet """ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) return model 

定义Resnet34

 def resnet34(pretrained=False, **kwargs): """Constructs a ResNet-34 model. Args:        pretrained (bool): If True, returns a model pre-trained on ImageNet    """ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) return model 

我们发现Resnet18和Resnet34的定义几乎是一样的,下面我们把Resnet18,Resnet34,Resnet50,Resnet101,Resnet152,不一样的部分写在一块进行对比

 model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)    #Resnet18 model = R<span style="color:transparent">来源gaodai#ma#com搞*!代#%^码$网</span>esNet(BasicBlock, [3, 4, 6, 3], **kwargs)    #Resnet34 model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)    #Eesnt50 model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)  #Resnet101 model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)  #Resnet152 

代码看起来非常的简洁工整,

其他resnet18、resnet101等函数和resnet18基本类似,差别主要是在:

1、构建网络结构的时候block的参数不一样,比如resnet18中是[2, 2, 2, 2],resnet101中是[3, 4, 23, 3]。

2、调用的block类不一样,比如在resnet50、resnet101、resnet152中调用的是Bottleneck类,而在resnet18和resnet34中调用的是BasicBlock类,这两个类的区别主要是在residual结果中卷积层的数量不同,这个是和网络结构相关的,后面会详细介绍。

3、如果下载预训练模型的话,model_urls字典的键不一样,对应不同的预训练模型。因此接下来分别看看如何构建网络结构和如何导入预训练模型。

Resnet类

构建ResNet网络是通过ResNet这个类进行的。ResNet类是继承PyTorch中网络的基类:torch.nn.Module。

构建Resnet类主要在于重写 init() forward() 方法。

我们构建的所有网络比如:VGGAlexnet等都需要重写这两个方法,这两个方法很重要

看起来Resne类是整个文档的核心

下面我们就要研究一下Resnet基类

以上就是pytorch教程resnet.py的实现文件源码分析的详细内容,更多请关注gaodaima搞代码网其它相关文章!


搞代码网(gaodaima.com)提供的所有资源部分来自互联网,如果有侵犯您的版权或其他权益,请说明详细缘由并提供版权或权益证明然后发送到邮箱[email protected],我们会在看到邮件的第一时间内为您处理,或直接联系QQ:872152909。本网站采用BY-NC-SA协议进行授权
转载请注明原文链接:pytorch教程resnet.py的实现文件源码分析

喜欢 (0)
[搞代码]
分享 (0)
发表我的评论
取消评论

表情 贴图 加粗 删除线 居中 斜体 签到

Hi,您需要填写昵称和邮箱!

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址