网上有很多MyBatis物理分页插件,基本都是围绕拦截指定分页对象来处理的;网上分页逻辑为,先查出总的列表数,使用分页插件查出分页后列表集合,封装总列表数和分页列表集合到Page结果对象;用起来还要自己封装结果类,相对比较麻烦;我想省去这些复杂的步骤,直接传递分页对象,返回结果对象。
需求:
这里我想像JPA一样传递Pagination分页对象进去,返回Page结果对象。
逻辑:
MyBatis预处理拦截器拦截指定的Pagination,并重写当前的SQL为分页SQL
MyBatis结果集拦截器拦截指定的Pagination,根据分页SQL得到查询总量的SQL执行得到总的列表数,得到原分页后的列表集合,封装为Page对象并返回
代码类:
PaginationStatementInterceptor类为SQL预处理拦截类,拦截指定的分页类,这里我使用的是org.springframework.data.domain.Pageable分页类(Spring-data-commons包下的,如果用Spring-data应该会附带出来这个包的);如果MyBatis中分页方法有这个分页类,则拦截并修改原来的执行SQL为想要的分页SQL;以MySQL为例:之前的SQL为 select * from user,修改后的SQL为 select * from user limit 100, 10;实现方式是利用反射替换delegate.boundSql.sql中保存的原SQL。
下面举例用的方言为MySQL对应的类MySql5Dialect继承了Dialect类,该类的getLimitString得到分页SQL,如果是别的数据库继承Dialect类,重写getLimitString即可
import com.mm.persist.expands.mybatis.dialect.Dialect;
import com.mm.persist.expands.mybatis.dialect.MySql5Dialect;
import com.mm.persist.expands.mybatis.dialect.OracleDialect;
import com.mm.persist.expands.mybatis.dialect.SQLServer2005Dialect;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;
import org.apache.ibatis.session.Configuration;
import org.springframework.data.domain.Pageable;

import java.sql.Connection;
import java.util.Properties;

/**
 * MyBatis plugin that rewrites a query into its paged form before the statement is prepared.
 *
 * <p>It intercepts {@code StatementHandler.prepare(Connection)}. When one of the named mapper
 * parameters is a {@link Pageable}, the SQL stored at {@code delegate.boundSql.sql} is replaced
 * with the SQL returned by {@link Dialect#getLimitString(String, int, int)}. Calls that carry no
 * {@code Pageable} are passed through unchanged.
 *
 * <p>The dialect is selected by the MyBatis configuration variable {@code dialect}, whose
 * upper-cased value must be the name of a {@link Dialect.Type} constant.
 */
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class})})
public class PaginationStatementInterceptor implements Interceptor {

    private final static Log log = LogFactory.getLog(PaginationStatementInterceptor.class);

    /**
     * Rewrites the pending SQL into paging SQL when a {@link Pageable} parameter is present,
     * then lets the intercepted {@code prepare} call continue.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return whatever the intercepted call returns
     * @throws Throwable if the dialect cannot be determined or the intercepted call fails
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        ParameterHandler parameterHandler = statementHandler.getParameterHandler();
        Pageable pagination = findPageable(parameterHandler.getParameterObject());

        if (pagination != null) {
            MetaObject metaStatementHandler = MetaObject.forObject(
                    statementHandler, new DefaultObjectFactory(), new DefaultObjectWrapperFactory());
            Configuration configuration =
                    (Configuration) metaStatementHandler.getValue("delegate.configuration");
            Dialect dialect = resolveDialect(configuration);

            // The SQL lives in a field without a setter, so it is swapped through MetaObject.
            String originalSql = (String) metaStatementHandler.getValue("delegate.boundSql.sql");
            int pageSize = pagination.getPageSize();
            int offset = pagination.getPageNumber() * pageSize;
            metaStatementHandler.setValue("delegate.boundSql.sql",
                    dialect.getLimitString(originalSql, offset, pageSize));

            if (log.isDebugEnabled()) {
                BoundSql boundSql = statementHandler.getBoundSql();
                log.debug("Generate SQL : " + boundSql.getSql());
            }
        }
        return invocation.proceed();
    }

    /**
     * Returns the first {@link Pageable} found among the values of a
     * {@code MapperMethod.ParamMap}, or {@code null} when the parameter object is not a
     * {@code ParamMap} or holds no {@code Pageable}.
     */
    private static Pageable findPageable(Object parameterObject) {
        if (!(parameterObject instanceof MapperMethod.ParamMap)) {
            return null;
        }
        // Raw type kept on purpose: the original code does not rely on ParamMap being generic.
        MapperMethod.ParamMap paramMap = (MapperMethod.ParamMap) parameterObject;
        for (Object value : paramMap.values()) {
            if (value instanceof Pageable) {
                return (Pageable) value;
            }
        }
        return null;
    }

    /**
     * Builds the {@link Dialect} named by the {@code dialect} configuration variable.
     *
     * @throws Exception if the variable is missing or does not name a {@link Dialect.Type};
     *                   the underlying failure is attached as the cause
     */
    private static Dialect resolveDialect(Configuration configuration) throws Exception {
        Dialect.Type databaseType;
        try {
            databaseType = Dialect.Type.valueOf(
                    configuration.getVariables().getProperty("dialect").toUpperCase());
        } catch (Exception e) {
            // Keep the cause so a missing or misspelled "dialect" variable can be diagnosed.
            throw new Exception("Generate SQL: Obtain DatabaseType Failed!", e);
        }
        switch (databaseType) {
            case MYSQL:
                return new MySql5Dialect();
            case ORACLE:
                return new OracleDialect();
            case SQLSERVER:
                return new SQLServer2005Dialect();
            default:
                // Only reachable if Dialect.Type gains a constant that is not mapped here.
                throw new IllegalStateException("Unsupported dialect: " + databaseType);
        }
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /** No plugin properties are read by this interceptor. */
    @Override
    public void setProperties(Properties properties) {
    }
}
PaginationResultSetInterceptor类为结果拦截类,拦截指定的分页类,这里我使用的是org.springframework.data.domain.Pageable分页类(Spring-data-commons包下的,如果用Spring-data应该会附带出来这个包的);如果MyBatis中分页方法有这个分页类,则拦截结果集进行重写得到最终想要的org.springframework.data.domain.Page结果类(Spring-data-commons包下的,如果用Spring-data应该会附带出来这个包的);Page中会包含分页信息和分页后的列表信息。 下面举例用的方言为MySQL对应的类MySql5Dialect继承了Dialect类,该类的getCountString方法是根据分页SQL解析获取到对应的查询总记录数的SQL,如果是别的数据库继承Dialect类,重写getCountString即可 invocation.proceed()为原分页结果集,这里根据上面的查询总记录数的SQL执行获取总记录数结果,封装Page对象设置原结果集为分页结果列表,设置总记录数,并返回
import com.mm.persist.expands.mybatis.dialect.Dialect; import com.mm.persist.expands.mybatis.dialect.MySql5Dialect; import com.mm.persist.expands.mybatis.dialect.OracleDialect; import com.mm.persist.expands.mybatis.dialect.SQLServer2005Dialect; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.ibatis.binding.MapperMethod; import org.apache.ibatis.executor.parameter.ParameterHandler; import org.apache.ibatis.executor.resultset.DefaultResultSetHandler; import org.apache.ibatis.executor.resultset.ResultSetHandler; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.plugin.*; import org.apache.ibatis.reflection.MetaObject; import org.apache.ibatis.reflection.factory.DefaultObjectFactory; import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory; import org.apache.ibatis.session.Configuration; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageImpl; import org.springframework.data.domain.Pageable; import org.springframework.jdbc.support.JdbcUtils; import java.sql.*; import java.util.ArrayList; import java.util.List; import java.util.Properties; @Intercepts({@Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class})}) public class PaginationResultSetInterceptor implements Interceptor { private final static Log log = LogFactory.getLog(PaginationResultSetInterceptor.class); @Override public Object intercept(Invocation invocation) throws Throwable { DefaultResultSetHandler resultSetHandler = (DefaultResultSetHandler) invocation.getTarget(); MetaObject metaResultSetHandler = MetaObject.forObject(resultSetHandler, new DefaultObjectFactory(), new DefaultObjectWrapperFactory()); try { ParameterHandler parameterHandler = (ParameterHandler) metaResultSetHandler.getValue("parameterHandler"); Object parameterObject = parameterHandler.getParameterObject(); Pageable pagination = null; if(parameterObject instanceof 
MapperMethod.ParamMap){ MapperMethod.ParamMap paramMapObject = (MapperMethod.ParamMap)parameterObject ; if(paramMapObject != null){ for(Object key : paramMapObject.keySet()){ if(paramMapObject.get(key) instanceof Pageable){ pagination = (Pageable) paramMapObject.get(key); break; } } } } if (pagination != null) { BoundSql boundSql = (BoundSql) metaResultSetHandler.getValue("parameterHandler.boundSql"); Configuration configuration = (Configuration) metaResultSetHandler.getValue("configuration"); Connection connection = (Connection) metaResultSetHandler.getValue("executor.delegate.transaction.connection"); String originalSql = boundSql.getSql(); Dialect.Type databaseType = Dialect.Type.valueOf(configuration.getVariables().getProperty("dialect").toUpperCase()); Dialect dialect = null; switch (databaseType) { case MYSQL: dialect = new MySql5Dialect(); break; case ORACLE: dialect = new OracleDialect(); break; case SQLSERVER: dialect = new SQLServer2005Dialect(); break; } // 修改sql,用于返回总记录数 String sql = dialect.getCountString(originalSql); Long totalRecord = getTotalRecord(connection, sql, parameterHandler); Object result = invocation.proceed(); Page page = new PageImpl((List)result, pagination, totalRecord); // // 设置返回对象类型 // metaResultSetHandler.setValue("mappedStatement.resultMaps[0].type.name", Page.class.getName()); // 设置返回值 List<Page> pageList = new ArrayList<Page>(); pageList.add(page); return pageList; } } catch (Exception e) { throw new Exception("Overwrite SQL : Fail!"); } return invocation.proceed(); } /** * 执行 count 操作 * @param connection 数据库连接 * @param sql sql * @param parameterHandler 参数设置处理器 * @return */ private Long getTotalRecord(Connection connection,String sql,ParameterHandler parameterHandler){ PreparedStatement preparedStatement = null; ResultSet resultSet = null; try { preparedStatement = connection.prepareStatement(sql); parameterHandler.setParameters(preparedStatement); resultSet = preparedStatement.executeQuery(); resultSet.next(); return (Long) 
JdbcUtils.getResultSetValue(resultSet, 1, Long.class); } catch (SQLException e) { e.printStackTrace(); }finally { JdbcUtils.closeResultSet(resultSet); JdbcUtils.closeStatement(preparedStatement); } return 0l; } @Override public Object plugin(Object target) { return Plugin.wrap(target, this); } @Override public void setProperties(Properties properties) { } }这里我仅以MySQL为例,其他需要同学们自己写
public class MySql5Dialect extends Dialect { public String getLimitString(String querySqlString, int offset, int limit) { return querySqlString + " limit " + offset + " ," + limit; } @Override public String getCountString(String querySqlString) { int limitIndex = querySqlString.lastIndexOf("limit"); if(limitIndex != -1){ querySqlString = querySqlString.substring(0, limitIndex != -1 ? limitIndex : querySqlString.length() - 1); } // 用的过程中会发现这里对原有sql进行包装一层select count会有SQL效率低的问题 // 等待优化 return "SELECT COUNT(*) FROM (" + querySqlString + ") tem"; } public boolean supportsLimit() { return true; } }
/**
 * Base class for database-specific SQL generation used by the pagination interceptors.
 * One subclass per database supplies the paging and counting SQL.
 */
public abstract class Dialect {

    /** Databases for which a dialect can be selected. */
    public enum Type {
        MYSQL,
        ORACLE,
        SQLSERVER
    }

    /**
     * Produces the paged form of a query.
     *
     * @param querySqlString the query to page
     * @param offset         index of the first row to return
     * @param limit          maximum number of rows to return
     * @return the paging SQL
     */
    public abstract String getLimitString(String querySqlString, int offset, int limit);

    /**
     * Produces the SQL that counts the total rows of a query.
     *
     * @param querySqlString the query whose rows are to be counted
     * @return the counting SQL
     */
    public abstract String getCountString(String querySqlString);
}
以上代码已经过测试并在工作项目中使用;文中代码可能不全,有时间会做个demo并上传。
参考文档: http://mybatis.github.io/mybatis-3/zh/configuration.html#plugins