Struts2 文件上传兼容 Servlet 3.1

虽然 Struts2 已经被淘汰，但一些老项目仍在使用。Servlet 3.1 已经原生支持 multipart 解析（通过 Part API），而 Struts2 默认的 multipart 解析器（基于 commons-fileupload）并不兼容 Servlet 3.1 的 multipart 机制。
解决办法：为 Struts2 重新编写一个基于 Servlet 3.1 Part API 的 multipart 解析器，并在 struts.xml 中通过 struts.multipart.parser 常量注册它。

<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE struts PUBLIC
    "-//Apache Software Foundation//DTD Struts Configuration 2.3//EN"
    "http://struts.apache.org/dtds/struts-2.3.dtd">
<struts>
	<constant name="struts.action.extension" value="," />
	<!-- NOTE(review): devMode is a development setting; presumably it should be false in production. -->
	<constant name="struts.devMode" value="true" />
	<constant name="struts.i18n.encoding" value="UTF-8" />
	<!-- 52428800 = 50 * 1024 * 1024, i.e. a 50 MiB limit (assuming the value is in bytes).
	     NOTE(review): the custom JakartaMultiPartRequest parser does not check this value itself;
	     confirm where the upload size limit is actually enforced (e.g. the container's multipart config). -->
	<constant name="struts.multipart.maxSize" value="52428800" />
    <!-- Upload directory, presumably passed to the parser's parse() as saveDir.
         Note that "/temp" is an absolute path, not a path relative to the web application. -->
    <constant name="struts.multipart.saveDir" value="/temp" />

	<!-- Register the custom Servlet 3.1 based multipart parser.
	     "com.x.x.x" is a placeholder: replace it with the real package of JakartaMultiPartRequest. -->
    <constant name="struts.multipart.parser" value="com.x.x.x.JakartaMultiPartRequest" />
</struts>

JakartaMultiPartRequest 的实现：

import org.apache.commons.fileupload.FileItem;
import org.apache.commons.fileupload.FileItemHeaders;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.struts2.dispatcher.LocalizedMessage;
import org.apache.struts2.dispatcher.multipart.AbstractMultiPartRequest;
import org.apache.struts2.dispatcher.multipart.StrutsUploadedFile;
import org.apache.struts2.dispatcher.multipart.UploadedFile;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.Part;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;

import static java.lang.String.format;


/**
 * Struts 2 multipart parser that reads uploads through the Servlet 3.1 {@link Part} API.
 *
 * <p>Regular form fields are taken from {@link HttpServletRequest#getParameterMap()}; every
 * part that carries a submitted file name is copied to a temporary file (see
 * {@link #getTempFile(File)}) and exposed to Struts as a {@link Servlet31FileItem}.
 * Register it with the {@code struts.multipart.parser} constant in {@code struts.xml}.
 *
 * <p>{@link #parse} accumulates state in the {@code files} and {@code params} maps, so an
 * instance is presumably meant to serve a single request; confirm against the Struts
 * configuration in use.
 */
public class JakartaMultiPartRequest extends AbstractMultiPartRequest {
    static final Logger LOG = LogManager.getLogger(JakartaMultiPartRequest.class);

    /** Uploaded file items, keyed by form field name. */
    protected Map<String, List<FileItem>> files = new HashMap<>();

    /** Regular (non-file) request parameters, keyed by parameter name. */
    protected Map<String, List<String>> params = new HashMap<>();

    /**
     * {@link FileItem} adapter over a Servlet 3.1 {@link Part} whose content has been copied
     * into a temporary file, so Struts can handle it like a commons-fileupload item.
     */
    public static class Servlet31FileItem implements FileItem {
        private final Part part;
        private final File tempFile;

        /**
         * Copies the content of {@code part} into {@code tempFile}.
         *
         * @param part     the uploaded file part
         * @param tempFile the file that receives the part content
         * @throws IOException if the part cannot be read or the file cannot be written; any
         *                     partially written {@code tempFile} is deleted before rethrowing
         */
        public Servlet31FileItem(Part part, File tempFile) throws IOException {
            this.part = part;
            this.tempFile = tempFile;
            try {
                FileUtils.copyInputStreamToFile(part.getInputStream(), tempFile);
            } catch (IOException e) {
                // A failed item never reaches the files map, so cleanUp() could not remove it.
                FileUtils.deleteQuietly(tempFile);
                throw e;
            }
        }

        /** Opens a new stream over the stored temporary file; the caller must close it. */
        @Override
        public InputStream getInputStream() throws IOException {
            return new FileInputStream(this.tempFile);
        }

        @Override
        public String getContentType() {
            return part.getContentType();
        }

        /**
         * Returns the file name submitted by the client ({@link Part#getSubmittedFileName()}).
         * The form field name is available from {@link #getFieldName()}.
         */
        @Override
        public String getName() {
            return part.getSubmittedFileName();
        }

        /** Always {@code false}: the content is always stored in the temporary file. */
        @Override
        public boolean isInMemory() {
            return false;
        }

        @Override
        public long getSize() {
            return part.getSize();
        }

        /**
         * Reads the whole temporary file into memory.
         *
         * @return the file content, or {@code null} if it could not be read (the error is logged)
         */
        @Override
        public byte[] get() {
            try {
                return FileUtils.readFileToByteArray(this.tempFile);
            } catch (IOException e) {
                LOG.error("readFileToByteArray error", e);
                return null;
            }
        }

        /**
         * Decodes the file content with the given charset.
         *
         * @param encoding charset name used for decoding
         * @return the decoded content, or {@code null} if the file could not be read
         * @throws UnsupportedEncodingException if {@code encoding} is not supported
         */
        @Override
        public String getString(String encoding) throws UnsupportedEncodingException {
            byte[] content = get();
            return content == null ? null : new String(content, encoding);
        }

        /**
         * Decodes the file content as UTF-8.
         *
         * @return the decoded content, or {@code null} if the file could not be read
         */
        @Override
        public String getString() {
            byte[] content = get();
            return content == null ? null : new String(content, StandardCharsets.UTF_8);
        }

        /** Copies the temporary file to {@code file}. */
        @Override
        public void write(File file) throws Exception {
            FileUtils.copyFile(this.tempFile, file);
        }

        /**
         * Best-effort removal of both the temporary copy and the container's part storage.
         * Failures are logged at debug level and never propagated, so cleanup of other
         * items can continue.
         */
        @Override
        public void delete() {
            FileUtils.deleteQuietly(tempFile);
            try {
                part.delete();
            } catch (IOException | RuntimeException e) {
                LOG.debug("Unable to delete part {}", part.getName(), e);
            }
        }

        @Override
        public String getFieldName() {
            return part.getName();
        }

        /** Not supported: the field name always comes from the underlying part. */
        @Override
        public void setFieldName(String name) {
        }

        /** Always {@code false}: only file parts are wrapped by this class. */
        @Override
        public boolean isFormField() {
            return false;
        }

        /** Not supported: items of this class are always file items. */
        @Override
        public void setFormField(boolean state) {
        }

        /** Not supported: the content is written once, from the part, in the constructor. */
        @Override
        public OutputStream getOutputStream() throws IOException {
            return null;
        }

        /** Not supported: part headers are not exposed through this adapter. */
        @Override
        public FileItemHeaders getHeaders() {
            return null;
        }

        /** Not supported: part headers are not exposed through this adapter. */
        @Override
        public void setHeaders(FileItemHeaders headers) {
        }
    }

    private static final AtomicInteger COUNTER = new AtomicInteger(0);
    private static final String UID = UUID.randomUUID().toString().replace('-', '_');

    /** Returns a process-unique sequence number, zero-padded to 8 digits below 100 million. */
    private static String getUniqueId() {
        final int limit = 100000000;
        int current = COUNTER.getAndIncrement();
        String id = Integer.toString(current);

        // If you manage to get more than 100 million of ids, you'll
        // start getting ids longer than 8 characters.
        if (current < limit) {
            id = ("00000000" + id).substring(id.length());
        }
        return id;
    }

    /**
     * Builds a new, unique temporary file location (the file itself is not created here).
     *
     * @param repository directory for the file; {@code null} means {@code java.io.tmpdir}
     * @return the temporary file path
     */
    protected File getTempFile(File repository) {
        File tempDir = repository;
        if (tempDir == null) {
            tempDir = new File(System.getProperty("java.io.tmpdir"));
        }
        String tempFileName = format("upload_%s_%s.tmp", UID, getUniqueId());
        return new File(tempDir, tempFileName);
    }

    /**
     * Collects the request parameters and stores every uploaded file part in a temporary file.
     * Any failure is logged and recorded in {@code errors} instead of being thrown.
     *
     * @param request the multipart request
     * @param saveDir directory for temporary files; {@code null} means {@code java.io.tmpdir}
     */
    public void parse(HttpServletRequest request, String saveDir) throws IOException {
        LOG.debug("parse multi request");
        try {
            processParameters(request);
            // new File(null) would throw; getTempFile() already handles a null repository.
            File repository = saveDir == null ? null : new File(saveDir);
            for (Part part : request.getParts()) {
                processPart(part, repository);
            }
        } catch (Exception e) {
            LOG.warn("Unable to parse request", e);
            LocalizedMessage errorMessage = buildErrorMessage(e, new Object[]{});
            if (!errors.contains(errorMessage)) {
                errors.add(errorMessage);
            }
        }
    }

    /** Copies every non-empty request parameter into {@code params}. */
    private void processParameters(HttpServletRequest request) {
        for (Map.Entry<String, String[]> entry : request.getParameterMap().entrySet()) {
            String[] values = entry.getValue();
            // Debug level only: parameter values may contain sensitive user input.
            LOG.debug("entry: {}, {}", entry.getKey(), values);
            if (values != null && values.length > 0) {
                params.put(entry.getKey(), Arrays.asList(values));
            }
        }
    }

    /**
     * Stores {@code part} as an uploaded file unless it is a plain form field or a file input
     * that was submitted without a file.
     */
    private void processPart(Part part, File repository) throws IOException {
        String name = part.getName();
        String fileName = part.getSubmittedFileName();
        LOG.debug("part: {}, {}, {}", name, part.getSize(), part.getContentType());
        if (fileName == null) {
            // No file name: an ordinary form field, already collected by processParameters().
            return;
        }
        if (fileName.trim().isEmpty()) {
            // A file input left empty is still submitted, just with an empty file name.
            LOG.debug("No file has been uploaded for the field: {}", name);
            return;
        }
        Servlet31FileItem item = new Servlet31FileItem(part, getTempFile(repository));
        List<FileItem> items = files.get(name);
        if (items == null) {
            items = new ArrayList<>();
            files.put(name, items);
        }
        items.add(item);
    }

    /** Returns the names of all form fields that carried at least one uploaded file. */
    public Enumeration<String> getFileParameterNames() {
        return Collections.enumeration(files.keySet());
    }

    /**
     * Returns the client-declared content types of the files uploaded for {@code fieldName},
     * or {@code null} if the field has no uploads.
     */
    public String[] getContentType(String fieldName) {
        List<FileItem> items = files.get(fieldName);
        if (items == null) {
            return null;
        }
        List<String> contentTypes = new ArrayList<>(items.size());
        for (FileItem fileItem : items) {
            contentTypes.add(fileItem.getContentType());
        }
        return contentTypes.toArray(new String[contentTypes.size()]);
    }

    /**
     * Returns the stored temporary files for {@code fieldName}, or {@code null} if the field
     * has no uploads.
     */
    public UploadedFile[] getFile(String fieldName) {
        List<FileItem> items = files.get(fieldName);
        if (items == null) {
            return null;
        }
        List<UploadedFile> fileList = new ArrayList<>(items.size());
        for (FileItem fileItem : items) {
            // Only Servlet31FileItem instances are ever put into the files map (see processPart).
            fileList.add(new StrutsUploadedFile(((Servlet31FileItem) fileItem).tempFile));
        }
        return fileList.toArray(new UploadedFile[fileList.size()]);
    }

    /**
     * Returns the client-side file names (path stripped by {@code getCanonicalName}) of the
     * files uploaded for {@code fieldName}, or {@code null} if the field has no uploads.
     */
    public String[] getFileNames(String fieldName) {
        List<FileItem> items = files.get(fieldName);
        if (items == null) {
            return null;
        }
        List<String> fileNames = new ArrayList<>(items.size());
        for (FileItem fileItem : items) {
            fileNames.add(getCanonicalName(fileItem.getName()));
        }
        return fileNames.toArray(new String[fileNames.size()]);
    }

    /**
     * Returns the names of the temporary files stored for {@code fieldName}, or {@code null}
     * if the field has no uploads.
     */
    public String[] getFilesystemName(String fieldName) {
        List<FileItem> items = files.get(fieldName);
        if (items == null) {
            return null;
        }
        List<String> fileNames = new ArrayList<>(items.size());
        for (FileItem fileItem : items) {
            fileNames.add(((Servlet31FileItem) fileItem).tempFile.getName());
        }
        return fileNames.toArray(new String[fileNames.size()]);
    }

    /** Returns the first value of parameter {@code name}, or {@code null} if it is absent. */
    public String getParameter(String name) {
        List<String> v = params.get(name);
        if (v != null && v.size() > 0) {
            return v.get(0);
        }

        return null;
    }

    /** Returns the names of all regular (non-file) parameters. */
    public Enumeration<String> getParameterNames() {
        return Collections.enumeration(params.keySet());
    }

    /** Returns all values of parameter {@code name}, or {@code null} if it is absent. */
    public String[] getParameterValues(String name) {
        List<String> v = params.get(name);
        if (v != null && v.size() > 0) {
            return v.toArray(new String[v.size()]);
        }

        return null;
    }

    /** Deletes the temporary files (and part storage) of every uploaded item. */
    public void cleanUp() {
        for (Map.Entry<String, List<FileItem>> entry : files.entrySet()) {
            for (FileItem item : entry.getValue()) {
                LOG.debug("Removing file {} {}", entry.getKey(), item);
                if (!item.isInMemory()) {
                    item.delete();
                }
            }
        }
    }
}

猜你喜欢

转载自blog.csdn.net/wzj_whut/article/details/88219274