Implementing the 2D discrete wavelet transform (DWT) and inverse transform (IDWT) in TensorFlow, with gradients that can be backpropagated

This post implements the wavelet transform and its inverse in TensorFlow in such a way that gradients can be backpropagated through them, so the wavelet transform can be conveniently embedded inside a network architecture.
The code is a port to TensorFlow of a PyTorch implementation of the wavelet transform: https://github.com/fbcotter/pytorch_wavelets.

One problem with the implementation is that TensorFlow does not provide grouped convolution, so the transform has to be implemented by looping a 2D convolution over the input channels, which increases the running time. Grouped convolution is discussed in a TensorFlow issue: https://github.com/tensorflow/tensorflow/issues/3332.
So far I have not found a good solution in TensorFlow. I later tried an implementation based on 3D convolution, but it needed too many tf.reshape, tf.slice and tf.concat operations, so it did not really solve the problem either. If you know a better way to handle grouped convolution, please let me know; one possible workaround is sketched below.
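A possible workaround (a minimal, unbenchmarked sketch; the helper name tf_dwt_depthwise below is illustrative and not part of the original code): when the number of channels C is known statically, tf.nn.depthwise_conv2d can apply the same four analysis filters to every input channel in a single op, and its output channel ordering matches what the per-channel loop in tf_dwt below produces. The boundary padding from tf_dwt would still have to be applied first.

import numpy as np
import tensorflow as tf

def tf_dwt_depthwise(yl, filts, C):
    # filts: the (kh, kw, 1, 4) analysis filter bank built in tf_dwt below;
    # yl must already be padded exactly as in tf_dwt.
    # depthwise_conv2d expects filters of shape (kh, kw, in_channels, channel_multiplier),
    # so the same 4 sub-band filters are tiled once per input channel.
    dw_filts = tf.convert_to_tensor(np.tile(filts, (1, 1, C, 1)))
    # output channels come out as (c0_LL, c0_LH, c0_HL, c0_HH, c1_LL, ...),
    # i.e. the same ordering the per-channel loop in tf_dwt produces
    return tf.nn.depthwise_conv2d(yl, dw_filts, strides=[1, 2, 2, 1], padding='VALID')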

The code below contains the 2D wavelet transform, the inverse wavelet transform, and test code. Note that the tensors passed between these functions are all 4-D. PyWavelets (pywt) must be installed for the code to run.

# -*- coding: utf-8 -*-
# @Author   : Chen Meiya
# @time     : 2018/12/9 21:46
# @File     : tf_dwt_release.py
# @Software : PyCharm

import numpy as np
import tensorflow as tf
from PIL import Image
import pywt
import time
import matplotlib.pyplot as plt


# Single-level (J=1) 2D DWT on an NHWC tensor; C is the number of channels and
# each input channel produces four sub-band channels (ll, lh, hl, hh below)
def tf_dwt(yl,  in_size, wave='db3'):
    w = pywt.Wavelet(wave)
    ll = np.outer(w.dec_lo, w.dec_lo)
    lh = np.outer(w.dec_hi, w.dec_lo)
    hl = np.outer(w.dec_lo, w.dec_hi)
    hh = np.outer(w.dec_hi, w.dec_hi)
    d_temp = np.zeros((np.shape(ll)[0], np.shape(ll)[1], 1, 4))
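    # tf.nn.conv2d computes cross-correlation, so the analysis filters are
    # flipped (::-1) below to obtain a true convolution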
    d_temp[::-1, ::-1, 0, 0] = ll
    d_temp[::-1, ::-1, 0, 1] = lh
    d_temp[::-1, ::-1, 0, 2] = hl
    d_temp[::-1, ::-1, 0, 3] = hh

    filts = d_temp.astype('float32')
    filter = tf.convert_to_tensor(filts)
    sz = 2 * (len(w.dec_lo) // 2 - 1)
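    # sz = boundary padding added on each side (depends on the filter length),
    # so that the DWT output size matches pywt.dwt2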

    with tf.variable_scope('DWT'):

        # Pad odd length images
        if in_size[0] % 2 == 1 and in_size[1] % 2 == 1:
            yl = tf.pad(yl, tf.constant([[0, 0], [sz, sz + 1], [sz, sz + 1], [0, 0]]), mode='reflect')
        elif in_size[0] % 2 == 1:
            yl = tf.pad(yl, tf.constant([[0, 0], [sz, sz + 1], [sz, sz], [0, 0]]), mode='reflect')
        elif in_size[1] % 2 == 1:
            yl = tf.pad(yl, tf.constant([[0, 0], [sz, sz], [sz, sz + 1], [0, 0]]), mode='reflect')
        else:
            yl = tf.pad(yl, tf.constant([[0, 0], [sz, sz], [sz, sz], [0, 0]]), mode='reflect')

        # emulate grouped convolution: convolve each input channel separately and concatenate
        outputs = tf.nn.conv2d(yl[:, :, :, 0:1], filter, padding='VALID', strides=[1, 2, 2, 1])
        for channel in range(1, int(yl.shape.dims[3])):
            temp = tf.nn.conv2d(yl[:, :, :, channel:channel+1], filter, padding='VALID', strides=[1, 2, 2, 1])
            outputs = tf.concat([outputs, temp], axis=3)

    return outputs


def tf_idwt(y,  wave='db3'):
    w = pywt.Wavelet(wave)
    ll = np.outer(w.rec_lo, w.rec_lo)
    lh = np.outer(w.rec_hi, w.rec_lo)
    hl = np.outer(w.rec_lo, w.rec_hi)
    hh = np.outer(w.rec_hi, w.rec_hi)
    d_temp = np.zeros((np.shape(ll)[0], np.shape(ll)[1], 1, 4))
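    # unlike the analysis filters above, these are not flipped: with stride 2,
    # conv2d_transpose upsamples and then applies a true convolution with the
    # reconstruction (synthesis) filters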
    d_temp[:, :, 0, 0] = ll
    d_temp[:, :, 0, 1] = lh
    d_temp[:, :, 0, 2] = hl
    d_temp[:, :, 0, 3] = hh
    filts = d_temp.astype('float32')
    filter = tf.convert_to_tensor(filts)
    s = 2 * (len(w.dec_lo) // 2 - 1)

    with tf.variable_scope('IWT'):
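        # note: the reconstruction assumes square inputs (height == width)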
        out_size = tf.shape(y)[1]
        in_t = tf.slice(y, (0, 0, 0, 0),
                           (tf.shape(y)[0], out_size, out_size, 4))

        outputs = tf.nn.conv2d_transpose(in_t, filter,
                                         output_shape=[tf.shape(y)[0],
                                                       2 * (out_size - 1) + np.shape(ll)[0],
                                                       2 * (out_size - 1) + np.shape(ll)[0], 1],
                                         padding='VALID', strides=[1, 2, 2, 1])
        for channels in range(4, int(y.shape.dims[-1]), 4):
            y_batch = tf.slice(y, (0, 0, 0, channels), (tf.shape(y)[0], out_size, out_size, 4))
            out_t = tf.nn.conv2d_transpose(y_batch, filter,
                                           output_shape=[tf.shape(y)[0],
                                                         2 * (out_size - 1) + np.shape(ll)[0],
                                                         2 * (out_size - 1) + np.shape(ll)[0], 1],
                                           padding='VALID', strides=[1, 2, 2, 1])
            outputs = tf.concat((outputs, out_t), axis=3)
        outputs = outputs[:, s: 2*(out_size-1)+np.shape(ll)[0]-s, s: 2*(out_size-1)+np.shape(ll)[0]-s, :]
    return outputs


if __name__ == '__main__':
    # load images
    a = Image.open('22090.jpg')  # change the image path
    X_n = np.array(a).astype('float32')
    X_n = X_n / 255
    X_n = X_n[0:256, 0:256, :]
    X_t = np.zeros((1, 256, 256, 3), dtype='float32')
    X_t[0, :, :, :] = X_n[:, :, :]

    # test code
    sess = tf.Session()
    inputs = tf.placeholder(tf.float32, [None, None, None, 3], name='inputs')
    outputs_in = tf.placeholder(tf.float32, [None, None, None, 12], name='outputs')
    outputs = tf_dwt(inputs, in_size=[256, 256])
    outputs_mex = tf_idwt(outputs_in)
    sess.run(tf.global_variables_initializer())
    time_start = time.time()
    outputs_dwt = sess.run(outputs, feed_dict={inputs: X_t})
    outputs_mex = sess.run(outputs_mex, feed_dict={outputs_in: outputs_dwt})
    time_end = time.time()
    print('totally cost', time_end - time_start)

    # show the LL sub-band of the first image channel
    plt.figure()
    plt.imshow(outputs_dwt[0, :, :, 0], cmap='gray')

    # reference decomposition computed with PyWavelets (pywt), which tf_dwt already requires
    cA, (cH, cV, cD) = pywt.dwt2(X_n[:, :, 0], 'db3')

    # compare against the PyWavelets result (absolute difference)
    plt.figure()
    plt.imshow(np.abs(cA-outputs_dwt[0, :, :, 0]), cmap='gray')

    plt.show()
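
To back up the claim that gradients flow through the transform, a minimal check (illustrative variable names, appended to the test script above) asks TensorFlow for the gradient of a scalar loss on the DWT output with respect to the input placeholder:

    # gradient check: if the returned gradient is not None and finite, the DWT
    # op chain is differentiable and can be embedded inside a trainable network
    loss = tf.reduce_mean(tf.square(outputs))
    grad = tf.gradients(loss, inputs)[0]
    grad_val = sess.run(grad, feed_dict={inputs: X_t})
    print('gradient shape:', grad_val.shape, 'mean |grad|:', np.abs(grad_val).mean())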

  • An optimized version can be found in the second (follow-up) post.

Reposted from blog.csdn.net/zseqsc_asd/article/details/84932855