ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

基于Pytorch1.8.0+Win10+RTX3070的MNIST网络构建与训练

2021-06-28 20:58:55  阅读:189  来源: 互联网

标签:obj network losses loader test train RTX3070 Win10 Pytorch1.8


直接上代码

先上整个的代码

import torch
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

#  参考:https://blog.csdn.net/sxf1061700625/article/details/105870851?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522162486393316780265489114%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=162486393316780265489114&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-105870851.first_rank_v2_pc_rank_v29_1&utm_term=pytorch++mnist&spm=1018.2226.3001.4187

class Mnist_Net(nn.Module):
    '''
    定义网络
    '''
    def __init__(self):
        super(Mnist_Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # 激活函数
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # 激活函数
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        # 激活函数
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # 返回结果
        return F.log_softmax(x)

def training_net(epoch,network,train_loader,optimizer,train_losses, train_counter,log_interval):
    '''
    一个种群训练一代
    :param epoch: 用于现实到第几个代了
    :param network: 模型对象
    :param train_loader: 训练数据对象
    :param optimizer: 优化器对象
    :param train_losses:
    :param train_counter:
    :param log_interval:
    :return:
    '''
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        # 将一个图片传入到网络中,得到out结果
        output = network(data)
        # 计算LOSS
        loss = F.nll_loss(output, target)
        # 反向传播LOSS
        loss.backward()
        # 优化器
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append((batch_idx * 64) + ((epoch - 1) * len(train_loader.dataset)))
            # 保存网络模型
            torch.save(network.state_dict(), './model.pth')
            # 保存优化器结果
            torch.save(optimizer.state_dict(), './optimizer.pth')


def testing_net(network, test_loader,test_losses):
    '''
    测试集执行
    :param network:
    :param test_loader:
    :param test_losses:
    :return:
    '''
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            # 首先得到out结果
            output = network(data)
            # 计算LOSS
            test_loss += F.nll_loss(output, target, size_average=False).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))


def view_dataset_figure(test_loader):
    '''
    展示训练和测试的数据图
    :param test_loader:
    :return:
    '''
    # 让我们看看一批测试数据由什么组成。
    examples = enumerate(test_loader)
    batch_idx, (example_data, example_targets) = next(examples)
    print(example_targets)
    print(example_data.shape)
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
        plt.title("Ground Truth: {}".format(example_targets[i]))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def show_loss_line_figure(train_counter,train_losses,test_counter,test_losses):
    '''
    展示LOSS曲线
    :param train_counter:
    :param train_losses:
    :param test_counter:
    :param test_losses:
    :return:
    '''
    fig = plt.figure()
    plt.plot(train_counter, train_losses, color='blue')
    plt.scatter(test_counter, test_losses, color='red')
    plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
    plt.xlabel('number of training examples seen')
    plt.ylabel('negative log likelihood loss')
    plt.show()


def show_predict_result(network,test_loader):
    '''
    展示预测数据的结果,目前是用的test数据集中的数据
    :param network:
    :param test_loader:
    :return:
    '''
    examples = enumerate(test_loader)
    batch_idx, (example_data, example_targets) = next(examples)
    with torch.no_grad():
        output = network(example_data)
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
        plt.title("Prediction: {}".format(
            output.data.max(1, keepdim=True)[1][i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


def execute_through_new():
    '''
    新的执行训练
    :return:
    '''
    n_epochs = 3
    batch_size_train = 64
    batch_size_test = 1000
    learning_rate = 0.01
    momentum = 0.5
    log_interval = 10
    random_seed = 1
    torch.manual_seed(random_seed)
    train_loader_obj = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=True, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                           (0.1307,), (0.3081,))
                                   ])),
        batch_size=batch_size_train, shuffle=True)
    test_loader_obj = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=False, download=True, transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))])
                                   ), batch_size=batch_size_test, shuffle=True
    )
    view_dataset_figure(test_loader_obj)
    network_obj = Mnist_Net()
    optimizer_obj = optim.SGD(network_obj.parameters(), lr=learning_rate,momentum=momentum)
    train_losses_obj = []
    train_counter_obj = []
    test_losses_obj = []
    test_counter_obj = [i * len(train_loader_obj.dataset) for i in range(n_epochs + 1)]
    testing_net(network_obj, test_loader_obj, test_losses_obj)
    for epoch in range(1, n_epochs + 1):
        # 训练一代
        training_net(epoch, network_obj, train_loader_obj, optimizer_obj, train_losses_obj, train_counter_obj,log_interval)
        # 测试一代
        testing_net(network_obj, test_loader_obj, test_losses_obj)
    #画一下训练曲线
    show_loss_line_figure(train_counter_obj,train_losses_obj,test_counter_obj,test_losses_obj)
    #做预测的可视化
    show_predict_result(network_obj,test_loader_obj)


def execute_through_checkpoint():
    '''
    基于断点的执行训练
    :return:
    '''
    n_epochs = 30
    batch_size_train = 64
    batch_size_test = 1000
    learning_rate = 0.01
    momentum = 0.5
    log_interval = 10
    random_seed = 1
    torch.manual_seed(random_seed)
    # 加载数据
    train_loader_obj = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=True, download=True,
                                   transform=torchvision.transforms.Compose([
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize(
                                           (0.1307,), (0.3081,))
                                   ])),batch_size=batch_size_train, shuffle=True)
    # 加载数据
    test_loader_obj = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST('./data/', train=False, download=True, transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))])
                                   ), batch_size=batch_size_test, shuffle=True
    )
    # 查看数据
    view_dataset_figure(test_loader_obj)
    # 形成网络对象
    continued_network_obj = Mnist_Net()
    # 形成优化器对象
    continued_optimizer_obj = optim.SGD(continued_network_obj.parameters(), lr=learning_rate,momentum=momentum)
    # 重载断点
    network_state_dict = torch.load('model.pth')
    continued_network_obj.load_state_dict(network_state_dict)
    optimizer_state_dict = torch.load('optimizer.pth')
    continued_optimizer_obj.load_state_dict(optimizer_state_dict)

    train_losses_obj = []
    train_counter_obj = []
    test_losses_obj = []
    test_counter_obj = [i * len(train_loader_obj.dataset) for i in range(n_epochs + 1)]
    # 测试一下测试集 Test set: Avg. loss: 0.0347, Accuracy: 9887/10000 (99%)
    testing_net(continued_network_obj, test_loader_obj, test_losses_obj)
    for epoch in range(1, n_epochs + 1):
        # 每个epoch,test一下
        # 训练网络
        training_net(epoch, continued_network_obj, train_loader_obj, continued_optimizer_obj, train_losses_obj, train_counter_obj,log_interval)
        testing_net(continued_network_obj, test_loader_obj, test_losses_obj)
    #画一下训练曲线
    show_loss_line_figure(train_counter_obj,train_losses_obj,test_counter_obj,test_losses_obj)
    #做预测的可视化
    show_predict_result(continued_network_obj,test_loader_obj)

### 主入口
if __name__ == '__main__':
    # 情况一:训练全新的模型;
    # execute_through_new()
    # 情况二:在断点的基础上,接着训练
    execute_through_checkpoint()

算法流程

口号:2【加数据、定模型】+2【训练4、测试2】
在这里插入图片描述
这是主体流程,主要是训练和测试2大步骤,其中训练主要包括了4个环节:网络运行、LOSS计算、反向传播、优化;测试包括了2个环节:网络运行、计算LOSS;

讨论网络模型定义

构建5层,包括两个卷积层,一个Dropout层(降低过拟合),两个线性层,最后返回F.log_softmax(x)。其中,需要去了解Net是集成自nn.Module。
关于nn.Module的详细介绍会在后面的章节展开。

主要参考资料

https://blog.csdn.net/sxf1061700625/article/details/105870851?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522162486393316780265489114%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=162486393316780265489114&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-105870851.first_rank_v2_pc_rank_v29_1&utm_term=pytorch++mnist&spm=1018.2226.3001.4187

标签:obj,network,losses,loader,test,train,RTX3070,Win10,Pytorch1.8
来源: https://blog.csdn.net/zhouxinxin111/article/details/118311641

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有