Spring Boot 自定义 Filter 实现请求与响应加解密

1：在 Spring Boot 启动类上添加 @ServletComponentScan 注解，使 @WebFilter 标注的 Filter 能被扫描并注册

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.boot.web.servlet.ServletComponentScan;


/**
 * Application entry point.
 *
 * <p>{@code @ServletComponentScan} tells Spring Boot to register servlet components declared
 * with annotations such as {@code @WebFilter}, which is how the request/response encryption
 * filter gets installed.
 */
@ServletComponentScan
@SpringBootApplication
public class RiskEtlService {

    /** Creates the Spring application for this class and starts it with the given arguments. */
    public static void main(String[] args) {
        SpringApplication application = new SpringApplication(RiskEtlService.class);
        application.run(args);
    }
}

2：创建 filter 包，添加以下三个类（RiskFilter、WrapperRequest、WrapperResponse）


import com.hsgd.risk.etl.utlis.DecryptUtil;
import com.hsgd.risk.etl.utlis.EncryptUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.codec.binary.Base64;
import org.springframework.core.annotation.Order;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;

/**
 * @author liyapeng
 */
@Slf4j
@Order(1)
@WebFilter(filterName = "RiskFilter", urlPatterns = "/*")
public class RiskFilter implements Filter {

    @Override
    public void init(FilterConfig filterConfig) {

    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) {
        try {


            String requestBody = getRequestBody((HttpServletRequest) request);
            byte[] decoded = Base64.decodeBase64(requestBody);
            byte[] decipherBytes = DecryptUtil.decipher(decoded);
            String decRequestBody = new String(decipherBytes, StandardCharsets.UTF_8);

            WrapperRequest wrapRequest = new WrapperRequest((HttpServletRequest) request, decRequestBody);
            WrapperResponse wrapResponse = new WrapperResponse((HttpServletResponse) response);
            filterChain.doFilter(wrapRequest, wrapResponse);

            byte[] data = wrapResponse.getResponseData();
            String encResponseBody = new String(EncryptUtil.encrypt(data), StandardCharsets.UTF_8);

            writeResponse(response, encResponseBody);
        } catch (Exception e) {
            log.error("RiskFilter filed,filed info:{}", e);
        }

    }

    private void writeResponse(ServletResponse response, String responseString)
            throws IOException {
        PrintWriter out = response.getWriter();
        out.print(responseString);
        out.flush();
        out.close();
    }

    private String getRequestBody(HttpServletRequest req) {
        try {
            BufferedReader reader = req.getReader();
            StringBuffer sb = new StringBuffer();
            String line = null;
            while ((line = reader.readLine()) != null) {
                sb.append(line);
            }
            String json = sb.toString();
            return json;
        } catch (IOException e) {
            log.error("验签时请求体读取失败", e);
        }
        return "";
    }

    @Override
    public void destroy() {

    }


}

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;

/**
 * Request wrapper that replaces the body seen by downstream code with a fixed string, for
 * example a body that a filter has already decrypted.
 *
 * <p>When no replacement body is supplied (single-argument constructor), body access falls
 * through to the wrapped request.
 *
 * @author lilili
 * @version 0.0.1
 * @date 2020/10/30 9:47 AM
 */
public class WrapperRequest extends HttpServletRequestWrapper {
    /** Replacement body, or {@code null} to expose the wrapped request's own body. */
    private String requestBody = null;
    /** The wrapped request (the same object returned by {@code getRequest()}). */
    HttpServletRequest req = null;

    public WrapperRequest(HttpServletRequest request) {
        super(request);
        this.req = request;
    }

    public WrapperRequest(HttpServletRequest request, String requestBody) {
        super(request);
        this.requestBody = requestBody;
        this.req = request;
    }

    /** Returns a reader over the replacement body, or the wrapped request's reader if none. */
    @Override
    public BufferedReader getReader() throws IOException {
        if (requestBody == null) {
            return super.getReader();
        }
        return new BufferedReader(new StringReader(requestBody));
    }

    /**
     * Returns a fresh stream over the replacement body, or the wrapped request's stream if none.
     * Each call starts again at the first byte.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (requestBody == null) {
            return super.getInputStream();
        }
        final ByteArrayInputStream in = new ByteArrayInputStream(encodeBody());
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                return in.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The body is fully in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Non-blocking reads are not supported for this in-memory body.
            }

            @Override
            public int read() {
                return in.read();
            }

            @Override
            public int read(byte[] b, int off, int len) {
                return in.read(b, off, len);
            }
        };
    }

    /**
     * Encodes the replacement body with the request's character encoding.
     *
     * <p>{@code getCharacterEncoding()} returns {@code null} when the client sent no charset;
     * {@code String.getBytes(null)} would then throw a NullPointerException, so fall back to
     * UTF-8. That is also the charset RiskFilter uses to decode the decrypted bytes.
     */
    private byte[] encodeBody() throws UnsupportedEncodingException {
        String encoding = req.getCharacterEncoding();
        return requestBody.getBytes(encoding != null ? encoding : "UTF-8");
    }
}

import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.*;

/**
 * Response wrapper that captures everything written to the body in memory instead of sending it
 * to the client. A filter can then post-process the captured bytes (for example, encrypt them)
 * via {@link #getResponseData()}.
 *
 * <p>Headers and status are not buffered; they go straight to the wrapped response.
 *
 * @author lilili
 * @version 0.0.1
 * @date 2020/10/30 9:22 AM
 */
public class WrapperResponse extends HttpServletResponseWrapper {
    /** In-memory sink for the captured body. */
    private ByteArrayOutputStream buffer = null;
    /** Stream returned by {@link #getOutputStream()}; writes into {@link #buffer}. */
    private ServletOutputStream out = null;
    /**
     * Writer returned by {@link #getWriter()}. It is created lazily so it uses the character
     * encoding in effect when it is first requested. This honours any
     * setContentType/setCharacterEncoding made by downstream code, as the Servlet contract
     * requires.
     */
    private PrintWriter writer = null;

    public WrapperResponse(HttpServletResponse resp) throws IOException {
        super(resp);
        buffer = new ByteArrayOutputStream();
        out = new BufferingOutputStream(buffer);
    }

    /** Returns a stream that writes into the in-memory buffer. */
    @Override
    public ServletOutputStream getOutputStream() throws IOException {
        return out;
    }

    /** Returns a writer that encodes into the in-memory buffer with the response's charset. */
    @Override
    public PrintWriter getWriter() throws UnsupportedEncodingException {
        if (writer == null) {
            String encoding = getCharacterEncoding();
            // ISO-8859-1 is the Servlet default when no charset has been specified.
            writer = new PrintWriter(new OutputStreamWriter(buffer,
                    encoding != null ? encoding : "ISO-8859-1"));
        }
        return writer;
    }

    /** Pushes any pending writer/stream data into the buffer; nothing reaches the client. */
    @Override
    public void flushBuffer() throws IOException {
        if (writer != null) {
            writer.flush();
        }
        if (out != null) {
            out.flush();
        }
    }

    /** Clears status and headers on the wrapped response and discards the captured body. */
    @Override
    public void reset() {
        super.reset();
        discardBody();
    }

    /**
     * Discards the captured body. The wrapped response's own buffer is empty, because the
     * body is held here.
     */
    @Override
    public void resetBuffer() {
        discardBody();
    }

    /** Returns every byte written so far through the stream or the writer. */
    public byte[] getResponseData() throws IOException {
        flushBuffer();
        return buffer.toByteArray();
    }

    /** Empties the buffer, including chars still held by the writer's encoder. */
    private void discardBody() {
        if (writer != null) {
            // Flush first so pending chars land in the buffer and are discarded with it,
            // instead of reappearing on the next flush.
            writer.flush();
        }
        buffer.reset();
    }

    /** ServletOutputStream that appends every byte to an in-memory buffer. */
    private static class BufferingOutputStream extends ServletOutputStream {
        private final ByteArrayOutputStream bos;

        BufferingOutputStream(ByteArrayOutputStream stream) {
            bos = stream;
        }

        @Override
        public void write(int b) {
            bos.write(b);
        }

        @Override
        public void write(byte[] b, int off, int len) {
            bos.write(b, off, len);
        }

        @Override
        public boolean isReady() {
            // Writing to memory never blocks.
            return true;
        }

        @Override
        public void setWriteListener(WriteListener writeListener) {
            // Non-blocking writes are not supported for this in-memory stream.
        }
    }
}
 

猜你喜欢

转载自blog.csdn.net/lucklilili/article/details/109384713