MyBatis: Multi-tenant SQL Interceptor
系统的租户隔离有多种实现方式:
- 完全隔离(不同数据库): 没什么好讲的, 相当于部署多套独立系统; 毫无疑问, 此方式成本最高, 玩不起 玩不起…
- 共享隔离(共享同一个数据库), 又分为以下两种:
  - 多个Schema, 表完全隔离: 一般通过中间件, 根据会话标识路由到指定Schema即可
  - 同一个Schema, 表上添加租户标识: 比较底层, 必须拦截并重写SQL才能实现
下面介绍"同一Schema、表上添加租户标识"方案的具体实现代码
1. 添加依赖
<dependency> <!-- 需要借助MyBatis拦截器插件, implements org.apache.ibatis.plugin.Interceptor -->
<groupId>org.mybatis</groupId>
<artifactId>mybatis</artifactId>
</dependency>
<dependency> <!-- 需要一个SQL解析库, GitHub上正好有现成的: https://github.com/JSQLParser/JSqlParser (撰文时 Star 2.5k) -->
<groupId>com.github.jsqlparser</groupId>
<artifactId>jsqlparser</artifactId>
<version>1.4</version>
</dependency>
2. 代码片段及解析
class imports
import java.sql.Connection;
import java.util.List;
import java.util.Properties;

import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.select.SelectBody;
import net.sf.jsqlparser.statement.select.SetOperationList;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.util.TablesNamesFinder;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
/**
 * MyBatis interceptor for shared-schema multi-tenancy. Every INSERT gets the tenant column
 * ({@code SQL_TENANT_ID}). DELETE, UPDATE and SELECT get a tenant condition on their
 * target/main table.
 *
 * <p>Background:
 * <ul>
 *   <li>MyBatis plugins can only intercept four types: {@code ParameterHandler},
 *       {@code ResultSetHandler}, {@code StatementHandler} and {@code Executor}.</li>
 *   <li>A plugin can be registered in two ways:
 *     <ol type="a">
 *       <li>configuration file:
 *       <pre>{@code
 * <configuration>
 *   <plugins>
 *     <plugin />
 *   </plugins>
 * </configuration>
 *       }</pre></li>
 *       <li>code:
 *       <pre>
 * &#64;Resource
 * SqlSessionFactory sqlSessionFactory;
 * sqlSessionFactory.getConfiguration().addInterceptor(interceptor);
 *       </pre></li>
 *     </ol>
 *   </li>
 * </ul>
 *
 * <p>The interceptor only runs after {@link #plugin(Object)} has wrapped the target in a proxy.
 * That proxy is what places the interceptor in the interceptor chain.
 */
@Log4j2
// @Intercepts declares which MyBatis type (type) and which of its methods (method, given as
// @Signature entries) this interceptor hooks. Here: StatementHandler#prepare(Connection, Integer).
@Intercepts({@Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class, Integer.class})})
// Executor methods could be intercepted instead (not shown here):
// @Intercepts({
//     @Signature(type = Executor.class, method = "update", args = {
//         MappedStatement.class, Object.class }),
//     @Signature(type = Executor.class, method = "query", args = {
//         MappedStatement.class, Object.class, RowBounds.class,
//         ResultHandler.class }) })
public class TenantInterceptor implements Interceptor {

    /** Tenant column name; use whatever your schema uses. */
    private static final String SQL_TENANT_ID = "tenant_id";

    /**
     * Decides whether the statement with the given id is rewritten.
     * {@link #intercept} applies the tenant rewrite only when this returns {@code true}.
     *
     * @param statementId the {@code MappedStatement} id, i.e. the id declared in the mapper,
     *                    e.g. {@code <select id="selectXXX" ../>}
     * @return {@code true} if the tenant condition must be injected
     */
    private boolean onFilter(String statementId) {
        // TODO implement as needed. Mainly used to exclude statements that must not be
        //  tenant-filtered.
        ...
    }

    /**
     * Rewrites the SQL bound to the intercepted {@code StatementHandler} before it is prepared.
     *
     * @param invocation the intercepted {@code StatementHandler#prepare} call
     * @return the result of the original {@code prepare} call
     * @throws Throwable anything thrown while rewriting, or by the intercepted call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Only StatementHandler#prepare is intercepted, so the target is a StatementHandler.
        StatementHandler handler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = handler.getBoundSql();
        String sql = boundSql.getSql();
        log.debug("Intercept SQL: {}", sql);
        String delegateSql = sql;
        // NOTE(review): the "delegate.*" property paths assume the target is MyBatis'
        // RoutingStatementHandler. If another plugin already wrapped the StatementHandler, the
        // target is likely a proxy and these paths will not resolve. Confirm the plugin order.
        MetaObject statementHandler = SystemMetaObject.forObject(handler);
        // The MappedStatement carries the mapper statement id passed to onFilter.
        MappedStatement mappedStatement = (MappedStatement) statementHandler.getValue("delegate.mappedStatement");
        if (this.onFilter(mappedStatement.getId())) {
            delegateSql = this.delegate(sql);
            statementHandler.setValue("delegate.boundSql.sql", delegateSql);
        }
        log.debug("Delegate SQL: {}", delegateSql);
        return invocation.proceed();
    }

    /** Returns {@code originSql} rewritten with the tenant column/condition. */
    private String delegate(String originSql) throws Exception {
        return this.rewrite(originSql);
    }

    /**
     * Parses {@code originSql} and dispatches to the rewrite for its statement type.
     *
     * @throws Exception if the SQL cannot be parsed, or the statement type is not
     *                   INSERT/DELETE/UPDATE/SELECT
     */
    private String rewrite(String originSql) throws Exception {
        Statement statement = CCJSqlParserUtil.parse(originSql);
        if (statement instanceof Insert) {
            return this.rewriteInsertSql(statement);
        } else if (statement instanceof Delete) {
            return this.rewriteDeleteSql(statement);
        } else if (statement instanceof Update) {
            return this.rewriteUpdateSql(statement);
        } else if (statement instanceof Select) {
            return this.rewriteSelectSql(statement);
        } else {
            // Project-specific exception; implement it yourself.
            throw new SQLNotSupportedException();
        }
    }

    /**
     * Appends the tenant column to the column list and the tenant id to every VALUES row:
     * {@code insert into A (a, b, tenant_id) values ('a', 'b', 'T'), ('a', 'b', 'T') ...}.
     *
     * @throws SQLInterceptException in three cases: the statement has no explicit column list,
     *         it already sets the tenant column, or it has no VALUES list
     *         (e.g. {@code INSERT ... SELECT})
     */
    private String rewriteInsertSql(Statement statement) {
        Insert insert = (Insert) statement;
        List<Column> columns = insert.getColumns();
        // Without an explicit column list there is nothing to append the tenant column to.
        if (columns == null || columns.isEmpty()) {
            throw new SQLInterceptException("Insert-Statement must declare its columns");
        }
        // Listing the tenant column twice would make the statement invalid.
        for (Column column : columns) {
            if (SQL_TENANT_ID.equalsIgnoreCase(column.getColumnName())) {
                throw new SQLInterceptException("Insert-Statement must not set " + SQL_TENANT_ID + " itself");
            }
        }
        StringValue tenantId = this.tenantIdValue();
        if (insert.getItemsList() instanceof MultiExpressionList) {
            for (ExpressionList row : ((MultiExpressionList) insert.getItemsList()).getExprList()) {
                row.getExpressions().add(tenantId);
            }
        } else if (insert.getItemsList() instanceof ExpressionList) {
            ((ExpressionList) insert.getItemsList()).getExpressions().add(tenantId);
        } else {
            throw new SQLInterceptException("Insert-Statement without a VALUES list is not supported");
        }
        // Add the column only after the values were extended, so a rejected statement stays untouched.
        columns.add(new Column(SQL_TENANT_ID));
        return insert.toString();
    }

    /**
     * Prefixes the WHERE clause with the tenant condition:
     * {@code delete from A where A.tenant_id = 'T' and (<original where>)}.
     *
     * @throws SQLInterceptException if the statement has no WHERE clause
     */
    private String rewriteDeleteSql(Statement statement) {
        Delete deleteStatement = (Delete) statement;
        Expression whereExpression = deleteStatement.getWhere();
        if (whereExpression == null) {
            throw new SQLInterceptException("Delete-Statement must be set conditions");
        }
        // Applied to every WHERE shape (IN, IS NULL, EXISTS, BETWEEN, ...), not only binary
        // expressions. Otherwise such deletes would run without the tenant filter.
        deleteStatement.setWhere(this.withTenantCondition(deleteStatement.getTable(), whereExpression));
        return deleteStatement.toString();
    }

    /**
     * Prefixes the WHERE clause with one tenant condition per updated table:
     * {@code update A set name = '' where A.tenant_id = 'T' and (<original where>)}.
     *
     * @throws SQLInterceptException if the statement has no WHERE clause
     */
    private String rewriteUpdateSql(Statement statement) {
        Update updateStatement = (Update) statement;
        if (updateStatement.getWhere() == null) {
            throw new SQLInterceptException("Update-Statement must be set conditions");
        }
        // Only the tables being updated are constrained, each qualified by its alias or name.
        // Tables that only appear in sub-queries are not filtered here.
        List<Table> tables = updateStatement.getTables();
        if (tables == null || tables.isEmpty()) {
            return updateStatement.toString();
        }
        for (Table table : tables) {
            updateStatement.setWhere(this.withTenantCondition(table, updateStatement.getWhere()));
        }
        return updateStatement.toString();
    }

    /**
     * Adds the tenant condition to the main (FROM) table of the query. For UNION-style queries
     * it does so in every branch. Joined tables and sub-queries are not filtered.
     *
     * @throws SQLInterceptException if a main table is not a plain table (e.g. a sub-select), or
     *         the select body type is not supported
     */
    private String rewriteSelectSql(Statement statement) {
        Select selectStatement = (Select) statement;
        List<String> tableNames = new TablesNamesFinder().getTableList(selectStatement);
        // select 1 OR select now(): no table, nothing to filter
        if (tableNames.isEmpty()) {
            return selectStatement.toString();
        }
        this.addTenantCondition(selectStatement.getSelectBody());
        return selectStatement.toString();
    }

    /** Recursively applies the tenant condition to the main table of each plain SELECT in {@code body}. */
    private void addTenantCondition(SelectBody body) {
        if (body instanceof PlainSelect) {
            PlainSelect plainSelect = (PlainSelect) body;
            if (plainSelect.getFromItem() == null) {
                // e.g. the "select 0" branch of a UNION: no main table to filter.
                return;
            }
            if (!(plainSelect.getFromItem() instanceof Table)) {
                throw new SQLInterceptException("Select-Statement main table must be a plain table: " + plainSelect);
            }
            Table mainTable = (Table) plainSelect.getFromItem();
            plainSelect.setWhere(this.withTenantCondition(mainTable, plainSelect.getWhere()));
        } else if (body instanceof SetOperationList) {
            // UNION / INTERSECT / ...: filter every branch.
            for (SelectBody branch : ((SetOperationList) body).getSelects()) {
                this.addTenantCondition(branch);
            }
        } else {
            throw new SQLInterceptException("Unsupported select body: " + body);
        }
    }

    /**
     * Returns {@code <tenant condition> AND (<where>)}. When {@code where} is {@code null},
     * returns just the tenant condition.
     */
    private Expression withTenantCondition(Table table, Expression where) {
        EqualsTo tenantCondition = this.newEqualTo(table);
        if (where == null) {
            return tenantCondition;
        }
        return new AndExpression(tenantCondition, new Parenthesis(where));
    }

    /**
     * Builds {@code <qualifier>.tenant_id = '<tenant id>'}. The qualifier is the table alias if
     * present, otherwise the table name. A {@code null} table yields an unqualified column.
     */
    private EqualsTo newEqualTo(Table table) {
        String column = SQL_TENANT_ID;
        if (table != null) {
            Alias alias = table.getAlias();
            column = (alias != null ? alias.getName() : table.getName()) + '.' + SQL_TENANT_ID;
        }
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(new Column(column));
        equalsTo.setRightExpression(this.tenantIdValue());
        return equalsTo;
    }

    /**
     * Builds the SQL string literal for the current tenant id.
     *
     * @throws SQLInterceptException if no tenant id is available, or it contains a quote or a
     *         backslash
     */
    private StringValue tenantIdValue() {
        String tenantId = this.tenantSupport.getTenantId();
        if (tenantId == null || tenantId.isEmpty()) {
            throw new SQLInterceptException("Tenant id is missing");
        }
        // The id is inlined into the SQL text. Characters that could terminate the literal are
        // rejected rather than escaped, because backslash escaping is dialect-specific (e.g. MySQL).
        if (tenantId.indexOf('\'') >= 0 || tenantId.indexOf('\\') >= 0) {
            throw new SQLInterceptException("Illegal tenant id: " + tenantId);
        }
        // Set the raw value explicitly. Some JSqlParser versions treat the constructor argument as
        // an already-quoted literal and strip its first and last characters.
        StringValue value = new StringValue("''");
        value.setValue(tenantId);
        return value;
    }

    /** Wraps only {@code StatementHandler} targets in the plugin proxy; all others pass through. */
    @Override
    public Object plugin(Object target) {
        if (target instanceof StatementHandler) {
            return Plugin.wrap(target, this);
        } else {
            return target;
        }
    }

    /** Properties are ignored; this interceptor has no configuration. */
    @Override
    public void setProperties(Properties properties) {
    }
    ...
}
小提示
<!-- 租户标识最好对业务代码透明(隐式处理), 可在用MyBatis Generator生成基础文件时忽略该列 -->
<table ...>
<ignoreColumn column="tenant_id"/>
</table>