ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

nn.TransformerDecoderLayer

2021-12-20 15:33:43  阅读:342  来源: 互联网

标签:layer nn torch decoder nhead model TransformerDecoderLayer


import torch
import torch.nn as nn

decode_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)  # d_model is the input feature, nhead is the number of head in the multiheadattention
memory = torch.ones(10,32,512)  # the sequence from the last layer of the encoder ; 可以类比为: batch_size * seqence_length * hidden_size
tgt = torch.zeros(20,20,512)  # the sequence to the decoder layer
out = decode_layer(tgt,memory)
print(out.shape)# 20*20*512

Details: TransformerDecoderLayer — PyTorch 1.10.0 documentation

如下面一个网络: 选用了Roberta 作为 encoder and the decoder is 6-layers Transformer.

encoder = model_class.from_pretrained(args.model_name_or_path,config=config)  # RobertaModel 当作一个 encoder, 加载的model为: roberta
decoder_layer = nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads) # d_model = 768, nhead= 12---the number of heads in the multiheadattention models
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

Details for TRANSFORMERDECODER: TransformerDecoder — PyTorch 1.10.0 documentation

标签:layer,nn,torch,decoder,nhead,model,TransformerDecoderLayer
来源: https://blog.csdn.net/weixin_44219178/article/details/122042431

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有