标签:loss torch nn 分类 神经网络 train test 鸢尾花 net
import torch from torch import nn from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split import numpy as np import matplotlib.pyplot as plt X = torch.tensor(load_iris().data, dtype=torch.float32) y = torch.tensor(load_iris().target, dtype=torch.long) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
导入鸢尾花数据集,这里注意数据和标签类型的设置:dtype=torch.float32,dtype=torch.long,否则会报错
net = nn.Sequential(nn.Linear(4, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 3)) def init_weights(m): if type(m) == nn.Linear: nn.init.normal_(m.weights, std=0.01) loss = nn.CrossEntropyLoss(reduction="none") trainer = torch.optim.Adam(net.parameters(), lr=0.05) train_loss = [] test_loss = [] train_l = sum(loss(net(X_train), y_train)).detach().numpy() test_l = sum(loss(net(X_test), y_test)).detach().numpy() train_loss.append(train_l) test_loss.append(test_l) epochs = 1000 for i in range(epochs): trainer.zero_grad() l = sum(loss(net(X_train), y_train)) l.backward() trainer.step() l = sum(loss(net(X), y)) train_l = sum(loss(net(X_train), y_train)).detach().numpy() test_l = sum(loss(net(X_test), y_test)).detach().numpy() train_loss.append(train_l) test_loss.append(test_l) epoch_index = range(epochs + 1) plt.plot(epoch_index, train_loss, 'green', epoch_index, test_loss, 'blue') plt.show()
使用交叉熵损失函数时, 定义神经网络架构的时候不需要用Softmax ! (我一开始在神经网络最后一层加了nn.Softmax有报错)
关于交叉熵损失函数,nn.CrossEntropyLoss(),有一些需要注意的点
贴篇网上介绍的博客,后面看自己有没有时间总结下。https://blog.csdn.net/geter_CS/article/details/84857220
有些场合(例如用matplotlib绘图)需要用numpy的数组,使用能求梯度的tensor是会报错的!
这里用.detach().numpy()来完成,例子可以见上面的代码
实验结果:
标签:loss,torch,nn,分类,神经网络,train,test,鸢尾花,net 来源: https://www.cnblogs.com/kyfishing/p/15918401.html
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。