ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

3 softmax

2021-07-06 01:03:16  阅读:234  来源: 互联网

标签:pred torch iter train softmax test net


import torch
import torchvision

def get_data(batch_size=50):
trans = torchvision.transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True,
transform=trans,
download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False,
transform=trans, download=True)
train = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False)
return train, test

train_iter, test_iter = get_data()

from d2l import torch as d2l

lr = 0.03
epoch = 50

1.net

def para_init(m):
if type(m) == torch.nn.Linear:
torch.nn.init.normal_(m.weight, mean=0.0, std=0.1)
net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
net.apply(para_init)

2 loss

loss = torch.nn.CrossEntropyLoss()

3 optimzer

op = torch.optim.SGD(net.parameters(), lr= lr)

def accuracy(y_pred, y):

print(y_pred)

if len(y_pred.shape) > 1 and y_pred.shape[1] > 1:
    y_pred = y_pred.argmax(axis=1)
cmp = (y_pred.type(y.dtype) == y)
return float(cmp.type(y.dtype).sum())

def evaluate_accuracy(net, test_iter):
if isinstance(net, torch.nn.Module):
net.eval()
all_accuracy, all_data = 0, 0
for X, y in test_iter:
all_accuracy += accuracy(net(X), y)
all_data += y.numel()
return all_accuracy / all_data

4 train

def train_epoch(net, train_iter, loss, op):
#/**返回平均loss 和 平均准确率 **/
if isinstance(net, torch.nn.Module):
net.train()

metric = d2l.Accumulator(3)    
for X, y in train_iter:
    y_pred = net(X)
    l = loss(y_pred, y)
    op.zero_grad()
    l.backward()
    op.step()
    metric.add(y.numel(), l * y.numel(), accuracy(y_pred, y))
return metric[1]/metric[0], metric[2]/metric[0]

def train(net, train_iter, test_iter, epoch, loss, op):
animator = d2l.Animator(xlabel='epoch', xlim=[1, epoch], ylim=[0.3, 0.9],
legend=['train loss', 'train acc', 'test acc'])
for i in range(epoch):
train_metric = train_epoch(net, train_iter, loss, op)
acc = evaluate_accuracy(net, test_iter)
animator.add(i + 1, train_metric + (acc,))

train(net, train_iter, test_iter, epoch, loss, op)

预测

def predict_ch3(net, test_iter, n=6): #@save
"""预测标签(定义见第3章)。"""
for X, y in test_iter:
break
trues = d2l.get_fashion_mnist_labels(y)
preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
print(title)
d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])

predict_ch3(net, test_iter)

3.7.6

1.尝试调整超参数,例如批量大小、迭代周期数和学习率,并查看结果。

2.增加迭代周期的数量。为什么测试准确率会在一段时间后降低?我们怎么解决这个问题?
过拟合,可以使用早停法

标签:pred,torch,iter,train,softmax,test,net
来源: https://www.cnblogs.com/pyclq/p/14975037.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有