pytorch 自定义核进行卷积操作

1.介绍

    高斯滤波的用处很多,也有很多现成的包可以被调用,比如opencv里面的cv2.GaussianBlur,一般情况,我们是没必要去造轮子,除非遇到特殊情况,比如我们在使用pytorch的过程中,需要自定义高斯核进行卷积操作,假设,我们要用的高斯核的参数是以下数目:

0.00655965 0.01330373 0.00655965 0.00078633 0.00002292
0.00655965 0.05472157 0.11098164 0.05472157 0.00655965
0.01330373 0.11098164 0.22508352 0.11098164 0.01330373
0.00655965 0.05472157 0.11098164 0.05472157 0.00655965
0.00078633 0.00655965 0.01330373 0.00655965 0.00078633

    在使用pytorch过程中,常用的卷积函数是:

class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
  
  

    感觉是无法自定义卷积权重,那么我们就此放弃吗?肯定不是,当你再仔细看看pytorch的说明书之后,会发现一个好东西:

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
  
  

    里面的weight参数刚好可以用高斯核参数来填充。

2.代码


  
  
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.autograd import Variable
  5. import numpy as np
  6. import cv2
  7. class GaussianBlurConv(nn.Module):
  8. def __init__(self, channels=3):
  9. super(GaussianBlurConv, self).__init__()
  10. self.channels = channels
  11. kernel = [[ 0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633],
  12. [ 0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
  13. [ 0.01330373, 0.11098164, 0.22508352, 0.11098164, 0.01330373],
  14. [ 0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
  15. [ 0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633]]
  16. kernel = torch.FloatTensor(kernel).unsqueeze( 0).unsqueeze( 0)
  17. kernel = np.repeat(kernel, self.channels, axis= 0)
  18. self.weight = nn.Parameter(data=kernel, requires_grad= False)
  19. def __call__(self, x):
  20. x = F.conv2d(x.unsqueeze( 0), self.weight, padding= 2, groups=self.channels)
  21. return x
  22. input_x = cv2.imread( "kodim04.png")
  23. cv2.imshow( "input_x", input_x)
  24. input_x = Variable(torch.from_numpy(input_x.astype(np.float32))).permute( 2, 0, 1)
  25. gaussian_conv = GaussianBlurConv()
  26. out_x = gaussian_conv(input_x)
  27. out_x = out_x.squeeze( 0).permute( 1, 2, 0).data.numpy().astype(np.uint8)
  28. cv2.imshow( "out_x", out_x)
  29. cv2.waitKey( 0)

    原图:

    输出图:

3.扩展应用

    我们知道了怎么自定义高斯核,其它的核都可以照搬,这里就不一一讲述了。

发布了381 篇原创文章 · 获赞 67 · 访问量 15万+

猜你喜欢

转载自blog.csdn.net/Arthur_Holmes/article/details/104499107