标签:... nn 模型 DataParallel pytorch device 显卡 model
torch.nn.DataParallel是一种能够将数据分散到多张显卡上从而加快模型训练的方法。
它的原理是首先在指定的每张显卡上拷贝一份模型,然后将输入的数据分散到各张显卡上,计算梯度,回传到第一张显卡上,然后再对模型进行参数优化。
所以,第一张显卡的负载往往更高,但由于该方法集成度高,书写简便,使用仍十分广泛。
示例:
import torch import torch.nn as nn ... gpu_num = x # 可用的gpu数量 model = Model() if gpu_num == 1: # 单卡 model = model.cuda(0) else: # 多卡 device_ids = list(range(gpu_num)) model = nn.DataParallel(model, device_ids=device_ids).cuda(device=device_ids[0]) ... # 所有数据都需要先放到指定的第一张显卡上才能进行多卡训练 data = data.cuda(0) ... # train ...
***注意使用nn.DataParellel时,模型后会自动添加一个.module的属性,在save的时候会将其保存下来,所以在load该模型时需要去掉字典key中的'.module'字符串
***在使用nn.DataParellel时,由于自动添加了module模型,因此需要分块训练模型的时候,也需要将模型块名更改。
例如:
# 原optimizer定义 optimizer = optim.Adam(params=model.part.parameters(), lr=0.00001) # 使用多卡训练后 optimizer = optim.Adam(params=model.module.part.parameters(), lr=0.00001)
标签:...,nn,模型,DataParallel,pytorch,device,显卡,model 来源: https://www.cnblogs.com/s-tyou/p/16558996.html
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。