标签:10 Embedding nn 单词 pytorch 0.1371 embedding input
记得在代码的开始引入
import torch
import torch.nn as nn
举个常用的例子
#以下代码为pytorch的python代码
embedding = nn.Embedding(10, 3)
print(embedding.weight)
input = torch.LongTensor([[0, 2, 0, 5]])
print(input)
print(embedding(input))
embedding的参数为
input的内容为
输出的结果为
对以上的代码和输出的解释
embedding相当于创建一个能翻译10个单词的工具,其中这10个单词为0~9,每个单词对应一个长度为3的向量
input就是一个单词,它由0, 2, 0, 5四个单词组成的
当代码做词嵌入的时候,就相当于把0, 2, 0, 5作为下标,到embedding里面找对应下标的向量。
比如在embedding里面,下标为0的3维度向量为[ 1.5013, -0.1371, 0.4321]
,所以最后的输出会把0替换成[ 1.5013, -0.1371, 0.4321]
依次类推
2会替换成[ 0.6691, 0.9784, -0.1510]
5会替换成[-0.8694, 0.8183, 1.8619]
最后的结果就是
[
[ 1.5013, -0.1371, 0.4321],
[ 0.6691, 0.9784, -0.1510],
[ 1.5013, -0.1371, 0.4321],
[-0.8694, 0.8183, 1.8619]
]
padding_idx的用法(mask)
padding的意思是“填充”
写法
embed = nn.Embedding(10,3,padding_idx=0)
意思就是说当单词为0的时候,进行词嵌入的时候的输出为[0,0,0]
embed = nn.Embedding(10,3,padding_idx=3)
意思就是说当单词为3的时候,进行词嵌入的时候的输出为[0,0,0]
标签:10,Embedding,nn,单词,pytorch,0.1371,embedding,input 来源: https://www.cnblogs.com/lanhongfu/p/16492453.html
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。