• 欢迎访问搞代码网站,推荐使用最新版火狐浏览器和Chrome浏览器访问本网站!
  • 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏搞代码吧

Python Pytorch深度学习之图像分类器

python 搞代码 4年前 (2022-01-09) 21次浏览 已收录 0个评论
文章目录[隐藏]

一、简介

通常,当处理图像、文本、语音或视频数据时,可以使用标准Pyt来源gaodai#ma#com搞@代~码$网hon将数据加载到numpy数组格式,然后将这个数组转换成torch.*Tensor

  • 对于图像,可以用Pillow,OpenCV
  • 对于语音,可以用scipy,librosa
  • 对于文本,可以直接用Python或Cython基础数据加载模块,或者用NLTK和SpaCy

特别是对于视觉,Pytorch已经创建了一个叫torchvision的package,该报包含了支持加载类似Imagenet,CIFAR10,MNIST等公共数据集的数据加载模快torchvision.datasets和支持加载图像数据数据转换模块torch.utils.data.DataLoader。这提供了极大地便利,并避免了编写“样板代码”

二、数据集

对于本小节,使用CIFAR10数据集,它包含了是个类别:airplane,automobile,bird,cat,deer,dog,frog,horse,ship,truck。CIFAR10中的图像尺寸是33232,也就是RGB的3层颜色通道,每层通道内的尺寸为32*32

三、训练一个图像分类器

训练图像分类器的步骤

  • 使用torchvision加载并且归一化CIFAR10的训练和测试数据集
  • 定义一个卷积神经网络
  • 定义一个损失函数
  • 在训练样本数据上训练网络
  • 在测试样本数据上测试网络

1、导入package吧

# 使用torchvision,加载并归一化CIFAR10
import torch
import torchvision
import torchvision.transforms as transforms

2、归一化处理+贴标签吧

# torchvision数据集的输出是范围在[0,1]之间的PILImage,将他们转换成归一化范围为[-1,1]之间的张量Tensor
transform=transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]
    )
# 训练集
trainset=torchvision.datasets.CIFAR10(root='./data',train=True,download=False,transform=transform)
trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2)
# 测试集
testset=torchvision.datasets.CIFAR10(root='./data',train=False,download=False,transform=transform)
testloader=torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=2)
classes=("plane","car","bird","cat","deer","dog","frog","horse","ship","truck")

3、先来康康训练集中的照片吧

# 展示其中的训练照片
import matplotlib.pyplot as plt
import numpy as np
# 定义图片显示的function
def imshow(img):
    img=img/2+0.5
    npimg=img.numpy()
    plt.imshow(np.transpose(npimg,(1,2,0)))
    plt.show()
# 得到随机训练图像
dataiter=iter(trainloader)
images,labels=dataiter.next()
# 展示图片
imshow(torchvision.utils.make_grid(images))
#打印标签labels
print(' '.join("%5s"%classes[labels[j]] for j in range(4)))

搞代码网(gaodaima.com)提供的所有资源部分来自互联网,如果有侵犯您的版权或其他权益,请说明详细缘由并提供版权或权益证明然后发送到邮箱[email protected],我们会在看到邮件的第一时间内为您处理,或直接联系QQ:872152909。本网站采用BY-NC-SA协议进行授权
转载请注明原文链接:Python Pytorch深度学习之图像分类器
喜欢 (0)
[搞代码]
分享 (0)
发表我的评论
取消评论

表情 贴图 加粗 删除线 居中 斜体 签到

Hi,您需要填写昵称和邮箱!

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址