• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

MatConvNet的CNN卷积网络目标函数定义,优化和反向传播及其Matlab代码实现 ...

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

MatConvNet的CNN卷积网络目标函数定义,优化和反向传播及其Matlab代码实现

 

在CNN卷积网络中,从输入前向传播到误差反向传播,如下图所示:

 

其中,在反向传播过程中,还需要把网络的输出经过一个目标函数(Objective function),通常也称为损失函数(Loss function),把网络的输出映射为一个实数,反向传播就是去优化这个损失函数,具体如下图所示:

在上图中,f 表示网络的计算模块,y 为网络输出,g 表示损失函数,网络的输出 y 经 g 映射为一个实数 z 。反向传播需要将误差 z 反向传播,那么就需要计算 z 关于 x 和 w 的偏导和,但是根据链式求导法则需要先求得。

下面设 f 为卷积模块, p 用正态分布随机数填充,则前向传播与反向传播求导如下代码:

[plain] view plain copy
 
  1. <span style="font-size:14px;">% Read an example image  
  2. x = im2single(imread(\'peppers.png\')) ;  
  3.   
  4. % Create a bank of linear filters and apply them to the image  
  5. w = randn(5,5,3,10,\'single\') ;  
  6. y = vl_nnconv(x, w, []) ;  
  7.   
  8. % Create the derivative dz/dy  
  9. dzdy = randn(size(y), \'single\') ;  
  10.   
  11. % Back-propagation  
  12. [dzdx, dzdw] = vl_nnconv(x, w, [], dzdy) ;  
  13. </span>  

通过上述代码,已经熟悉了CNN卷积网络的前向传导与反向传播求导的基本操作,下面自定义一个目标函数,对其求导,并反向传播。

目标函数定义如下:

在上式中,f(x;w,b) 为网络的输出,也即上文中的 y 。因此,在对 w 与 b 求导前需要对 f(x;w,b) 先求导。下面分析这个目标函数以便进行求导。由上式可得:

现将网络输出  f(x;w,b) 设为res.x3(这里定义了一个三层的网络,且为了与MatConvNet定义的网络输出保持一致),E(3,1)即为上式中的目标函数, 也表示输出 z ,P—>pos, N—>neg, Matlab代码实现如下:

[plain] view plain copy
 
  1. <span style="font-size:14px;">E(1,1) = ...  
  2.     mean(max(0, 1 - res.x3(pos))) + ...  
  3.     mean(max(0, res.x3(neg))) ;  
  4.   E(2,1) = 0.5 * shrinkRate * sum(w(:).^2) ;  
  5.   E(3,1) = E(1,1) + E(2,1) ;  
  6.   
  7.   dzdx3 = ...  
  8.     - single(res.x3 < 1 & pos) / sum(pos(:)) + ...  
  9.     + single(res.x3 > 0 & neg) / sum(neg(:)) ;</span>  


完整代码如下:

[plain] view plain copy
 
    1. <span style="font-size:14px;"> % Forward pass  
    2.   res = tinycnn(im, w, b) ;  
    3.   
    4.   % Loss  
    5.   
    6.   E(1,1) = ...  
    7.     mean(max(0, 1 - res.x3(pos))) + ...  
    8.     mean(max(0, res.x3(neg))) ;  
    9.   E(2,1) = 0.5 * shrinkRate * sum(w(:).^2) ;  
    10.   E(3,1) = E(1,1) + E(2,1) ;  
    11.   
    12.   dzdx3 = ...  
    13.     - single(res.x3 < 1 & pos) / sum(pos(:)) + ...  
    14.     + single(res.x3 > 0 & neg) / sum(neg(:)) ;  
    15.   
    16.   % Backward pass  
    17.   res = tinycnn(im, w, b, dzdx3) ;</span>  

鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Delphi与DirectX之DelphiX(53):TDIB.DoSplitBlur();发布时间:2022-07-18
下一篇:
Delphi日期时间格式错误解决办法发布时间:2022-07-18
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap