ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

CNN 猫狗图像分类

2021-12-23 15:02:48  阅读:185  来源: 互联网

标签:分类 torch labels train 图像 CNN import data correct


 

 

导入基本要的库 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.models as models
import PIL.Image as Image

图形变tensor的转化 

image_size = (224,224)
data_transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize(image_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
train_data=dset.ImageFolder(root="data/cat-dog/training_set",transform=data_transform)
# 数据集长度
totallen = len(train_data)
print('train data length:',totallen)
test_data=dset.ImageFolder(root="data/cat-dog/test_set",transform=data_transform)
# 数据集长度
testtotallen = len(test_data)
print('test data length:',testtotallen)

trainlen = int(totallen*0.7)
vallen = totallen - trainlen
train_db,val_db=torch.utils.data.random_split(train_data,[trainlen,vallen])
print('train:',len(train_db),'validation:',len(val_db))
# batch size
bs=16
# 训练集
train_loader=torch.utils.data.DataLoader(train_db,batch_size=bs, shuffle=True,num_workers=2)
# 验证集
val_loader=torch.utils.data.DataLoader(val_db,batch_size=bs, shuffle=True,num_workers=2)

 关键的来了

resnet18 = models.resnet18(pretrained=True)
model = resnet18    #下载使用resent18模型
n_classes = len(train_data.classes)
model.fc = nn.Linear(512, n_classes)
import torch.nn.init as init
for name, module in model._modules.items():
    if(name=='fc'):
        # print(module.weight.shape)
        init.kaiming_uniform_(module.weight, a=0, mode='fan_in')
def get_num_correct(out, labels):
    return out.argmax(dim=1).eq(labels).sum().item()
optimizer=torch.optim.SGD(model.parameters(),lr=0.01)
epoch_num = 5
for epoch in range(epoch_num):    #开始反复训练 epoch为次数
    total_loss=0
    total_correct=0
    val_correct=0
    for batch in train_loader:#GetBatch
        images,labels=batch
        outs=model(images)#PassBatch
        loss=F.cross_entropy(outs,labels)#CalculateLoss
        optimizer.zero_grad()
        loss.backward()#CalculateGradients
        optimizer.step()#UpdateWeights
        total_loss+=loss.item()
        total_correct+=get_num_correct(outs,labels)
    for batch in val_loader:
        images,labels=batch
        outs=model(images)
        val_correct+=get_num_correct(outs,labels)
        print("loss:",total_loss,"train_correct:",total_correct/trainlen,        "    val_correct:",val_correct/vallen)

保存训练的模型,开始预测 

torch.save(model, 'catvsdog.pkl')
model = torch.load('catvsdog.pkl')
model.eval()
test_loader = torch.utils.data.DataLoader(dataset = test_data
,batch_size=100
,shuffle=True
)
batch = next(iter(test_loader))
images, labels = batch 
out = model(images)
grid = torchvision.utils.make_grid(images, nrow=10)
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(figsize=(10,10))
plt.imshow(np.transpose(grid, (1,2,0)))
print('labels:', labels)
print('predicts:', out.argmax(dim=1))

 

标签:分类,torch,labels,train,图像,CNN,import,data,correct
来源: https://blog.csdn.net/long_songs/article/details/122104681

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有