• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

迁移pytorch工程至matlab

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

好久没写博客了,险些以为自己找不到密码了。

最近抽空参与了个小项目,很惭愧,只做了三件小事


1. 基于PyTorch训练了一系列单图像超分辨神经网络

 

基于PyTorch训练了一系列单图像超分辨神经网络,超分辨系数从2-10。
该部分的实现参考了pytorch官方repo中的SR例程,训练程序包含于`./train`文件夹。该项目
基于高效子像素卷积层[1]进行空间分辨率提升操作,训练速度极快。

[1] ["Shi W, Caballero J, Huszar F, et al. Real-Time Single Image and
    Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
    Neural Network[J]. 2016:1874-1883.](https://arxiv.org/abs/1609.05158)

 

2. 把训练好的模型权值转存为MATLAB文件。


简单粗暴,异常直接,只要把对应卷积层的权值全部提取出来就可以了。

提取的时候注意一点,要把pytorch中的Variable格式转换为Tensor,再转换为CPU模式,最终转换为numpy数组。

这一系列过程合并起来就是:

Var.data.cpu().numpy()

具体实现如下:

 1 from __future__ import print_function
 2 
 3 import torch
 4 import numpy as np
 5 import scipy.io as sio
 6 
 7 for i in [2, 3, 4, 5, 6, 7, 8, 9, 10]:
 8 
 9     model_name = 'model_upscale_{}_epoch_101.pth'.format(i)
10     model = torch.load(model_name)
11     print(model._modules)
12 
13     weight = dict()
14     weight['conv1_w'] = model._modules['conv1']._parameters['weight'].data.cpu().numpy()
15     weight['conv2_w'] = model._modules['conv2']._parameters['weight'].data.cpu().numpy()
16     weight['conv3_w'] = model._modules['conv3']._parameters['weight'].data.cpu().numpy()
17     weight['conv4_w'] = model._modules['conv4']._parameters['weight'].data.cpu().numpy()
18 
19     weight['conv1_b'] = model._modules['conv1']._parameters['bias'].data.cpu().numpy()
20     weight['conv2_b'] = model._modules['conv2']._parameters['bias'].data.cpu().numpy()
21     weight['conv3_b'] = model._modules['conv3']._parameters['bias'].data.cpu().numpy()
22     weight['conv4_b'] = model._modules['conv4']._parameters['bias'].data.cpu().numpy()
23 
24     sio.savemat('model_upscale_{}.mat'.format(i), mdict=weight)

 

3. 把网络的test过程移植到了MATLAB平台,并撰写了测试代码。

把卷积层和pixelshuffle层用matlab重写了一下。

复现pixelshuffle层的时候遇到了一些麻烦,又回头看了下pytorch里的测试代码

`https://github.com/pytorch/pytorch/blob/master/test/test_nn.py `

# https://github.com/pytorch/pytorch/blob/master/test/test_nn.py
def _verify_pixel_shuffle(self, input, output, upscale_factor):
    for c in range(output.size(1)):
        for h in range(output.size(2)):
            for w in range(output.size(3)):
                height_idx = h // upscale_factor
                weight_idx = w // upscale_factor
                channel_idx = (upscale_factor * (h % upscale_factor)) + (w % upscale_factor) + \
                              (c * upscale_factor ** 2)
    self.assertEqual(output[:, c, h, w], input[:, channel_idx, height_idx, weight_idx])

理了理思路,改写成MATLAB代码:

 1 function [ outputs ] = PixelShuffle( inputs, upscale_factor )
 2 %    PixelShuffle :
 3 %
 4 %   input : N, upscale_factor ** 2, H, W
 5 %   output : N, 1, H*upscale_factor, W*upscale_factor
 6 
 7 [N, ~, H, W] = size(inputs);
 8 H_out = H*upscale_factor;
 9 W_out = W*upscale_factor;
10 outputs = zeros([N, 1, H_out, W_out]);
11 for i = 1:N
12     for h = 1: H_out
13         for w = 1:W_out
14             height_idx = floor(h / upscale_factor+0.5);
15             weight_idx = floor(w / upscale_factor+0.5);
16             channel_idx = (upscale_factor * mod(h-1, upscale_factor)) + mod(w-1, upscale_factor)+1;
17             outputs(i, 1, h, w) = inputs(i, channel_idx, height_idx, weight_idx);
18         end
19     end
20 end
21 end

4. 完整工程github链接。

https://github.com/JiJingYu/super-resolution-by-subpixel-convolution

模型权值已保存为matlab权值,直接在matlab中运行`demo.m`文件即可验证

 


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
matlab复数计算发布时间:2022-07-22
下一篇:
Matlab数理统计工具箱应用简介发布时间:2022-07-22
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap