ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

线性回归实现

2022-08-26 22:31:10  阅读:141  来源: 互联网

标签:定义 实现 梯度 回归 线性 net data 模型


深度学习第一章:最简单的线性回归实现

1. 引言

AI领域的线性回归和其他领域不太一样,包括了名词和实现方式,所以必须先认识重要名词,再把所有步骤熟悉一边,并建立在之前学习线性回归的基础上

2. 学习目的:

  1. 知道线性回归是什么
  2. 知道线性回归在深度学习领域怎么在python上实现
  3. 代码实现,运行结果
  4. 知道每行代码怎么来的
  5. 了解与后面的联系

3. 线性回归是什么

4. python 实现步骤

  1. 数据部分
    1. 数据生成
    2. 数据处理(小批量提取,生成迭代器)
  2. 模型初始化
    1. 模型定义
    2. 模型初始化
    3. 损失函数定义
  3. 更新规则:
    1. 优化函数定义:SGD 梯度下降
  4. 训练(包含求出损失,反向传递,梯度下降,梯度清零)

5. 代码部分+解读:

import torch as tc
import numpy as np
from torch.utils import data
from  LimuAi.Linear_regression import synthetic_data
from torch import nn
#处理数据:
'''
定义实际数据
'''
true_k=tc.tensor([2,-3.4])
true_b=4.2
feature,lable=synthetic_data(true_k,true_b,1000) #数据初始化
'''
定义读取数据的方法
'''
def read_data(sample,batch_size,is_train=True):#python是一个缩进控制组块的语言
    dataset=data.TensorDataset(*sample) #将sample变成元组之后,经过TensorDataset变成dataset对象,方便传入dataloader函数进行小批量的抽取(我猜的,还没求证)
    return data.DataLoader(dataset,batch_size,shuffle=is_train) #返回一个迭代器,小批量的返回样本数据


batch_size=10
item=read_data((feature,lable),batch_size) #生成item作为下方训练用的迭代器,进行小批量随机梯度下降
print(next(iter(item))) # 使用next得出第一个小批次

'''
定义模型
'''



net = nn.Sequential(nn.Linear(2, 1)) #搭建一个单层神经网络,并且神经元使用的是线性结构,且有两个输入,一个输出


'''
初始化模型参数
'''

net[0].weight.data.normal_(0, 0.01)#对net实力初始化模型即使用[0]来定位,weight/bias .data来初始化,
net[0].bias.data.fill_(0)
'''
定义损失函数
'''
loss = nn.MSELoss() #使用nn底下的包即可实现计算MSE

'''
定义优化算法
'''

trainer = tc.optim.SGD(net.parameters(), lr=0.03)
#把优化算法也定义成对象,通过torch的optim包的SGD来实例化,SGD实例化需要模型参数和学习率(也是梯度下降所必须的)
#参数通过net的parameter可直接输入


'''
训练
'''
num_epochs = 3 #把数据集遍历三遍
for epoch in range(num_epochs): #迭代器必须是一个list/元组之类的
    for X, y in item: #取出随机小批次,用来梯度下降
        l=loss(net(X),y)#比较计算出的yhat和真实的y的RMSE
        trainer.zero_grad() #用来清除模型的累计梯度
        l.backward() #反向传递,回调
        trainer.step() #更新模型参数
    l=loss(net(feature),lable)
    print(f'epoch {epoch + 1}, loss {l:f}')

标签:定义,实现,梯度,回归,线性,net,data,模型
来源: https://www.cnblogs.com/yujiesun-818/p/16629449.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有