ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

风格迁移训练实践

2022-02-17 15:01:02  阅读:150  来源: 互联网

标签:style 训练 img self 实践 param content 迁移 model


前一篇文章分享了 PyTorch 简单风格迁移的代码。本着不跑挂服务器不死心的态度,我不停地增加计算步骤(优化迭代次数),观察图片融合生成的效果。

为了方便一次性执行,把代码简单改造了一下,与前一篇文章大同小异:

  1 import torch
  2 import torch.nn as nn
  3 import torch.nn.functional as F
  4 import torch.optim as optim
  5 
  6 from PIL import Image
  7 import matplotlib.pyplot as plt
  8 
  9 import torchvision.transforms as transforms
 10 import torchvision.models as models
 11 import datetime
 12 
 13 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 14 
 15 
 16 def get_img_size(img_name):
 17     """
 18     获取图像大小
 19     :param img_name:
 20     :return:
 21     """
 22     im = Image.open(img_name).convert('RGB')
 23     return im, im.height, im.width
 24 
 25 
 26 def image_loader(img, im_h, im_w):
 27     """
 28     加载图像
 29     :param img:
 30     :param im_h:
 31     :param im_w:
 32     :return:
 33     """
 34 
 35     # loader = transforms.Compose([transforms.Resize([im_h, im_w]), transforms.ToTensor()])
 36     loader = transforms.Compose([transforms.Resize([1000, 1000]), transforms.ToTensor()])
 37     im_l = loader(img).unsqueeze(0)
 38     return im_l.to(device, torch.float)
 39 
 40 
 41 def im_show(tensor, save_file_path):
 42     """
 43     显示保存图片
 44     :param tensor:
 45     :param save_file_path:
 46     :return:
 47     """
 48     image = tensor.cpu().clone()
 49     image = image.squeeze(0)
 50     image = transforms.ToPILImage()(image)
 51     plt.imshow(image, aspect='equal')
 52     plt.axis('off')
 53     plt.savefig(save_file_path, bbox_inches='tight', pad_inches=0.0)
 54     plt.pause(0.001)
 55 
 56 
 57 class ContentLoss(nn.Module):
 58     """
 59     内容损失
 60     """
 61 
 62     def __init__(self, target,):
 63         super(ContentLoss, self).__init__()
 64         self.target = target.detach()
 65 
 66     def forward(self, cl_input):
 67         self.loss = F.mse_loss(cl_input, self.target)
 68         return cl_input
 69 
 70 
 71 def gram_matrix(gm_input):
 72     """
 73     风格损失矩阵
 74     :param gm_input:
 75     :return:
 76     """
 77     a, b, c, d = gm_input.size()
 78     features = gm_input.view(a * b, c * d)
 79     G = torch.mm(features, features.t())
 80 
 81     return G.div(a * b * c * d)
 82 
 83 
 84 class StyleLoss(nn.Module):
 85     """
 86     风格损失
 87     """
 88 
 89     def __init__(self, target_feature):
 90         super(StyleLoss, self).__init__()
 91         self.target = gram_matrix(target_feature).detach()
 92 
 93     def forward(self, fw_input):
 94         G = gram_matrix(fw_input)
 95         self.loss = F.mse_loss(G, self.target)
 96         return fw_input
 97 
 98 
 99 # 使用19层的VGG神经网络模型
100 cnn = models.vgg19(pretrained=True).features.to(device).eval()
101 
102 
103 cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
104 cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
105 
106 
class Normalization(nn.Module):
    """Normalise an image channel-wise: ``(img - mean) / std``.

    ``mean`` and ``std`` are reshaped to ``(C, 1, 1)`` so they broadcast over
    the spatial dimensions of the input.
    """

    def __init__(self, mean, std):
        super(Normalization, self).__init__()
        # Registered as non-persistent buffers so Module.to()/cuda()/double()
        # move and cast them with the module (plain attributes would be left
        # behind), without adding entries to state_dict().
        self.register_buffer('mean', mean.clone().detach().view(-1, 1, 1), persistent=False)
        self.register_buffer('std', std.clone().detach().view(-1, 1, 1), persistent=False)

    def forward(self, img):
        return (img - self.mean) / self.std
118 
119 
def get_style_model_and_losses(cn, normalization_mean, normalization_std, style_i, content_i, cld, sld):
    """Build a truncated copy of ``cn`` with content/style loss probes inserted.

    The returned ``nn.Sequential`` starts with a ``Normalization`` layer,
    followed by the children of ``cn`` renamed ``conv_k`` / ``relu_k`` /
    ``pool_k`` / ``bn_k`` (``k`` counts the convolutions seen so far). A
    ``ContentLoss`` is inserted after every layer named in ``cld`` and a
    ``StyleLoss`` after every layer named in ``sld``. Layers after the last
    inserted loss are dropped.

    :param cn: module whose children are Conv2d / ReLU / MaxPool2d / BatchNorm2d.
    :param normalization_mean: per-channel mean passed to ``Normalization``.
    :param normalization_std: per-channel std passed to ``Normalization``.
    :param style_i: style image; its activations become the StyleLoss targets.
    :param content_i: content image; its activations become the ContentLoss targets.
    :param cld: layer names after which a ContentLoss is inserted.
    :param sld: layer names after which a StyleLoss is inserted.
    :return: ``(model, style_losses, content_losses)``.
    :raises RuntimeError: if ``cn`` contains a layer of an unsupported type.
    """
    normalization = Normalization(normalization_mean, normalization_std).to(device)
    content_losses = []
    style_losses = []

    model = nn.Sequential(normalization)
    # Index in ``model`` of the last inserted loss module; 0 keeps only the
    # normalization layer when no requested layer name is found.
    last_loss_idx = 0

    i = 0
    for layer in cn.children():
        if isinstance(layer, nn.Conv2d):
            i += 1
            name = 'conv_{}'.format(i)
        elif isinstance(layer, nn.ReLU):
            name = 'relu_{}'.format(i)
            # The loss modules return their input tensor as-is, so an in-place
            # ReLU right after them would overwrite a tensor autograd saved for
            # the loss; use an out-of-place ReLU instead.
            layer = nn.ReLU(inplace=False)
        elif isinstance(layer, nn.MaxPool2d):
            name = 'pool_{}'.format(i)
        elif isinstance(layer, nn.BatchNorm2d):
            name = 'bn_{}'.format(i)
        else:
            raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))

        model.add_module(name, layer)

        # Loss modules are keyed by the layer name, not the conv counter:
        # probes after e.g. conv_4 and relu_4 would otherwise share a name and
        # add_module would silently replace the earlier one.
        if name in cld:
            # Run the content image through the prefix built so far; the
            # detached activation is the fixed content target.
            target = model(content_i).detach()
            content_loss = ContentLoss(target)
            model.add_module("content_loss_{}".format(name), content_loss)
            content_losses.append(content_loss)
            last_loss_idx = len(model) - 1

        if name in sld:
            target_feature = model(style_i).detach()
            style_loss = StyleLoss(target_feature)
            model.add_module("style_loss_{}".format(name), style_loss)
            style_losses.append(style_loss)
            last_loss_idx = len(model) - 1

    # Layers past the last loss module cannot influence any loss.
    model = model[:last_loss_idx + 1]

    return model, style_losses, content_losses
175 
176 
def get_input_optimizer(input_i):
    """Create the L-BFGS optimizer used to minimise the style/content loss.

    The image tensor ``input_i`` is the only parameter handed to the
    optimizer, so each step updates the image pixels.

    :param input_i: tensor to optimise.
    :return: a ``torch.optim.LBFGS`` instance over ``[input_i]``.
    """
    return optim.LBFGS(params=[input_i])
186 
187 
def run_style_transfer(cn, norma_mean, normalization_std, ct_img, sl_img, in_img, steps, style_weight, content_weight,
                       content_layers=None, style_layers=None):
    """Optimise ``in_img`` in place towards the content of ``ct_img`` and the style of ``sl_img``.

    :param cn: feature extractor passed to ``get_style_model_and_losses``.
    :param norma_mean: per-channel normalization mean.
    :param normalization_std: per-channel normalization std.
    :param ct_img: content image tensor.
    :param sl_img: style image tensor.
    :param in_img: starting image; it is modified in place and returned.
    :param steps: budget of loss evaluations. L-BFGS may evaluate the closure
        several times per ``optimizer.step``, so the final count can exceed
        ``steps``; with ``steps <= 0`` no optimisation step runs.
    :param style_weight: multiplier for the summed style losses.
    :param content_weight: multiplier for the summed content losses.
    :param content_layers: layer names for ContentLoss; default ``['conv_4']``.
    :param style_layers: layer names for StyleLoss; default ``conv_1`` .. ``conv_5``.
    :return: ``in_img``, clamped to [0, 1].
    """
    if content_layers is None:
        content_layers = ['conv_4']
    if style_layers is None:
        style_layers = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
    model, style_losses, content_losses = get_style_model_and_losses(
        cn, norma_mean, normalization_std, sl_img, ct_img, content_layers, style_layers)
    # Only the image is optimised; the network stays frozen.
    in_img.requires_grad_(True)
    model.requires_grad_(False)

    optimizer = get_input_optimizer(in_img)
    print('Optimizing..')
    run = 0

    def closure():
        nonlocal run
        # Keep pixel values inside [0, 1] before every evaluation.
        with torch.no_grad():
            in_img.clamp_(0, 1)

        optimizer.zero_grad()
        # The forward pass sets ``.loss`` on every ContentLoss/StyleLoss module.
        model(in_img)

        # Start from a scalar tensor so an empty loss list still yields a tensor
        # (a plain 0 would break ``.item()`` below).
        zero = in_img.new_zeros(())
        style_score = sum((sl.loss for sl in style_losses), zero) * style_weight
        content_score = sum((cl.loss for cl in content_losses), zero) * content_weight

        loss = style_score + content_score
        loss.backward()

        run += 1
        if run % 50 == 0:
            print("run {}:".format(run))
            print('Style Loss : {:4f} Content Loss: {:4f}'.format(style_score.item(), content_score.item()))
        return loss

    while run < steps:
        optimizer.step(closure)

    with torch.no_grad():
        in_img.clamp_(0, 1)
    return in_img
243 
244 
def style_transfer(content_image_path, style_image_path, image_save_path, run_steps,
                   style_weight=None, content_weight=None):
    """Main entry point: load both images, run the optimisation, save the result.

    :param content_image_path: path of the content image.
    :param style_image_path: path of the style image.
    :param image_save_path: path the resulting image is saved to.
    :param run_steps: evaluation budget forwarded to ``run_style_transfer``.
    :param style_weight: style-loss weight. When None, the module-level
        ``s_weight`` is used if defined, otherwise 1000000.
    :param content_weight: content-loss weight. When None, the module-level
        ``c_weight`` is used if defined, otherwise 1.
    :return: the optimised image tensor.
    :raises ValueError: if the loaded style and content tensors differ in size.
    """
    # s_weight/c_weight are only defined when this file runs as a script; fall
    # back to defaults so the function also works after a plain import.
    if style_weight is None:
        style_weight = globals().get('s_weight', 1000000)
    if content_weight is None:
        content_weight = globals().get('c_weight', 1)

    c_image, c_im_h, c_im_w = get_img_size(content_image_path)
    s_image, _, _ = get_img_size(style_image_path)
    content_img = image_loader(c_image, c_im_h, c_im_w)
    # The style image is loaded with the content image's height/width.
    style_img = image_loader(s_image, c_im_h, c_im_w)
    if style_img.size() != content_img.size():
        raise ValueError("style and content images must have the same size, got {} and {}".format(
            tuple(style_img.size()), tuple(content_img.size())))

    # Optimisation starts from a copy of the content image.
    input_img = content_img.clone()
    begin_time = datetime.datetime.now()
    print("******************开始时间*****************", begin_time)
    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std, content_img, style_img,
                                input_img, run_steps, style_weight, content_weight)
    try:
        im_show(output, image_save_path)
    except Exception as e:
        # Best effort: report and continue so one failed display/save does not
        # abort a caller that runs several transfers in a loop.
        print(e)
    # Sample the end time once so the printed end time and duration agree.
    end_time = datetime.datetime.now()
    print("******************结束时间*****************", end_time)
    print("******************耗时*****************", end_time - begin_time)
    return output
270 
271 
if __name__ == '__main__':
    # Loss weights; kept as module-level names because style_transfer() reads them.
    s_weight = 1000000
    c_weight = 1
    # Alternative inputs: content "data/drew/img/512.png", style "data/drew/img/512r.png"
    content_img_path = "/data/drew/img/dancing.jpg"
    style_img_path = "/data/drew/img/picasso.jpg"
    # Repeat the transfer with a growing step budget: 100, 300, ..., 3100.
    # Alternative output pattern: "data/drew/img/end_%s_%s.jpg"
    for steps in range(100, 3200, 200):
        stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        save_path = "/data/drew/img/end_%s_%s.jpg" % (steps, stamp)
        style_transfer(content_img_path, style_img_path, save_path, steps)
View Code

 

标签:style,训练,img,self,实践,param,content,迁移,model
来源: https://www.cnblogs.com/drewgg/p/15904654.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有