ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

Pytorch学习

2021-11-30 10:59:34  阅读:154  来源: 互联网

标签:dim nn state mid 学习 Pytorch action net


Pytorch学习 -- 深度学习

一级目录

二级目录

三级目录

Pytorch必须在init初始化网络结构 forward中做feed forward网络的前馈
创建网络结构代码
待更新知识点学习:

  1. 张量tensor的各种操作.argmax() add() 等等 link
  2. nn.module 父类
  3. nn.Sequential()
  4. nn.module中的各层 nn.Linear()
  5. 激活函数 nn.ReLU() nn.LeackyReLU() nn.ELU() 等等
  6. 损失函数
  7. 优化器
import torch
import torch.nn as nn
from collections import OrderedDict

class Net(nn.Module):  # nn.Module is standard PyTorch Network
    def __init__(self,state_dim, mid_dim,  action_dim):
        '''
        相当于create model
        :param mid_dim: 中间层神经元数
        :param state_dim: 状态层神经元数 输入
        :param action_dim: 动作层神经元数 输出
        '''
        super().__init__()  # 第一句话,调用父类的构造函数
        self.net = nn.Sequential(
            nn.Linear(state_dim, mid_dim),
            nn.ELU(),
            nn.Linear(mid_dim, mid_dim),
            nn.ELU(),
            nn.Linear(mid_dim, mid_dim),
            nn.ELU(),
            nn.Linear(mid_dim, action_dim)
        )

    def forward(self, state):
        return self.net(state)  # 计算Q-value  直接返回action-dim的张量

主函数【注:该代码无法运行,目前用于学习流程】

if __name__ == '__main__':
    '''
    听说 pytorch的训练需要自己写?
    好吧 学习ing
    '''
    # 初始化 模型的类
    net = QNet(13, 7, 8)  # 输入13维向量 隐藏层7维 输出8维向量
    # 选择 损失函数和优化器
    criterion = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(net.parameters(),lr=1e-4)

    # 可以开始训练了
    state = torch.Tensor([20,15,14,12,20.00,50,7,4.98,0.4,0.8,0.1,0,5])
    # y是目标值 
    for t in range(500):
        y_pred = net(state)  # 必须传入tensor 
        loss = criterion(y_pred,y)   # 计算损失函数 但是强化学习没有y呀???疑惑??
        optimizer.zero_grad() # 梯度置零
        loss.backward()
        optimizer.step()

总结:使用pytorch训练强化学习算法DQN模型
主训练过程

max_episodes
max_steps

for episode in max_episodes:
	state = get_state() # 获得初始状态
	for t in max_steps:
		action = choose_action(state) #根据当前状态选择动作
		_,reward,done,_ = env.step(action)  # 执行动作 获得奖励
		next_state = get_state() # 观察获得新状态
		memory.push(state,action,reward,next_state)  # 将transition存入经验缓冲池
		optimize_model()  # 优化模型
	if episode%target_update == 0: # 如果到了target_net更新的轮次,更新target_net
		target_net.update()  

其中optimize_model() 是本次记录的重点 pytorch构建的模型是如何优化更新的呢
经验缓冲池功能:存储经验,随机采样,
agent功能:choose_action() target_net.update()
环境功能:

def optimize_model():
	'''
	step1:首先从经验缓冲池中进行随机采样,将其拼接??
	'''
	if len(memory)<batch_size:
		return 
	transitions = memeory.sample()
	# 拼接操作  torch.cat()
	# 计算Q(st,a)  得到采取行动的列
 	state_action_values = policy_net(state_batch).gather(1,action_batch)
	# 计算下一个状态
	next_state_values = target_net(state).max(1)[0].detach()
	# 计算期望q值
	excepted_state_action_value = (next_state_values * gamma) + reward_batch
	
	# 计算loss
	loss = torch.nn.functional.smooth_l1_loss(state_action_values ,excepted_state_action_value )
	
	# 优化模型
	optimizer.zero_grad()
	loss.backward()
	
	

标签:dim,nn,state,mid,学习,Pytorch,action,net
来源: https://blog.csdn.net/qq_36930921/article/details/121269779

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有