思路:
如何切换?
1、通过重写(override)AbstractRoutingDataSource.determineCurrentLookupKey,返回当前所需的数据源名称,以达到动态切换数据源的目的。
为了线程安全,determineCurrentLookupKey 返回的数据源名称保存在 ThreadLocal 中,每个线程各自持有一份。
什么时候切换?
1、利用spring aop 根据当前方法名切换数据源。
AbstractRoutingDataSource实现:
/**
*
* @Description: TODO( spring 通过determineCurrentLookupKey获取数据源名称,
* 所以动态修改 determineCurrentLookupKey返回值,即可达到切换数据源)
* @author: chenwm
* @date: 2017年6月8日 上午9:41:39
*
*/
public class ChooseDataSource extends AbstractRoutingDataSource {
// 获取数据源名称
protected Object determineCurrentLookupKey() {
return HandleDataSource.getDataSource();
}
}
扫描二维码关注公众号,回复:
2826163 查看本文章
HandleDataSource实现:
/**
 * Per-thread holder for the name of the data source that should currently be used.
 *
 * <p>The name lives in a {@link ThreadLocal}, so a value stored by one thread is not
 * visible to any other thread.
 *
 * @author chenwm
 */
public class HandleDataSource {

    // Data source name bound to the calling thread; stays unset until putDataSource is called.
    private static final ThreadLocal<String> CURRENT_NAME = new ThreadLocal<String>();

    /**
     * Looks up the data source name bound to the calling thread.
     *
     * @return the bound name, or {@code null} when nothing is bound
     */
    public static String getDataSource() {
        final String boundName = CURRENT_NAME.get();
        return boundName;
    }

    /** Removes whatever data source name is bound to the calling thread. */
    public static void clear() {
        CURRENT_NAME.remove();
    }

    /**
     * Binds a data source name to the calling thread, replacing any previous binding.
     *
     * @param datasource name of the data source to use on this thread
     */
    public static void putDataSource(String datasource) {
        CURRENT_NAME.set(datasource);
    }
}
定义切入点:
/**
 * Aspect that selects the master or slave data source from the name of the intercepted
 * service method.
 *
 * <p>Before a matched method runs, the chosen data source name is stored through
 * {@code HandleDataSource.putDataSource}; after the method finishes it is removed again
 * through {@code HandleDataSource.clear}.
 *
 * @author chenwm
 */
@Aspect
@Order(-10) // original author's intent: make this advice run before @Transactional
@Component
public class DataSourceAspect {

    /** Lookup key of the slave (read) data source. */
    public static final String slaveDSName = "slaveDataSource";

    /** Lookup key of the master (write) data source. */
    public static final String masterDSName = "masterDataSource";

    /** Method-name prefixes that are routed to the slave data source. */
    public static final String[] slaveReg =
            new String[]{"get", "select", "count", "list", "query", "find"};

    /** Method-name prefixes that are routed to the master data source. */
    public static final String[] masterReg =
            new String[]{"add", "insert", "create", "update", "delete", "remove"};

    /**
     * Tells whether a method name starts with one of the {@link #masterReg} prefixes.
     *
     * @param methodName simple name of the intercepted method
     * @return {@code true} when the name starts with a master prefix
     */
    public boolean isMaster(String methodName) {
        return StringUtils.startsWithAny(methodName, masterReg);
    }

    /**
     * Tells whether a method name starts with one of the {@link #slaveReg} prefixes.
     *
     * @param methodName simple name of the intercepted method
     * @return {@code true} when the name starts with a slave prefix
     */
    public boolean isSlave(String methodName) {
        return StringUtils.startsWithAny(methodName, slaveReg);
    }

    /** Pointcut matching every method of every type under {@code org.quickjee.service}. */
    @Pointcut("execution(* org.quickjee.service..*.*(..))")
    public void aspect() {
    }

    /**
     * Before advice for {@link #aspect()}: prints the invocation and binds the data source
     * name for the current thread.
     *
     * @param point join point of the intercepted method
     */
    @Before("aspect()")
    public void before(JoinPoint point) {
        String className = point.getTarget().getClass().getName();
        String method = point.getSignature().getName();
        // NOTE(review): this prints every argument of every matched call to stdout;
        // consider moving it to a debug-level logger.
        System.out.println(className + "." + method + "(" + StringUtils.join(point.getArgs(), ",") + ")");
        HandleDataSource.putDataSource(resolveDataSourceName(method));
    }

    /**
     * After advice for {@link #aspect()}: removes the data source name bound to the
     * current thread.
     *
     * <p>NOTE(review): when one matched method calls another matched method through its
     * proxy, the inner call's after-advice clears the name the outer call had bound, so
     * the rest of the outer call runs without a bound name. Confirm that this is
     * acceptable for nested service calls.
     *
     * @param point join point of the intercepted method
     */
    @After("aspect()")
    public void after(JoinPoint point) {
        HandleDataSource.clear();
    }

    // Maps a method name to a data source name: slave only for slave-prefixed names that
    // are not master-prefixed, master for everything else (including unknown prefixes).
    private String resolveDataSourceName(String methodName) {
        try {
            if (!isMaster(methodName) && isSlave(methodName)) {
                return slaveDSName;
            }
        } catch (Exception ignored) {
            // Deliberate best-effort kept from the original: any failure while
            // classifying the method name falls back to the master data source.
        }
        return masterDSName;
    }
}
定义数据源:
/**
 * Builds the master and slave Druid data sources from the {@code db.*} properties and
 * registers a {@link ChooseDataSource} bean named {@code dataSource} that routes between
 * them.
 */
public class DynamicDataSourceRegister implements
        ImportBeanDefinitionRegistrar, EnvironmentAware {

    private DataSource masterDataSource; // master
    private DataSource slaveDataSource;  // slave
    private Map<String, DataSource> targetDataSources = new HashMap<String, DataSource>();

    // Builds both pools and indexes them by the lookup keys declared in DataSourceAspect.
    // NOTE(review): both pools are built from the same DBProperties (same url and
    // credentials); they differ only in the default read-only flag.
    private void initTargetDataSources(DBProperties dbParams) {
        slaveDataSource = buildDataSources(dbParams, true);
        masterDataSource = buildDataSources(dbParams, false);
        targetDataSources.put(DataSourceAspect.slaveDSName, slaveDataSource);
        targetDataSources.put(DataSourceAspect.masterDSName, masterDataSource);
    }

    // Creates one Druid pool from the given settings; readOnly becomes the pool's
    // default read-only flag.
    // NOTE(review): maxIdle, maxWait and maxPoolPreparedStatementPerConnectionSize are
    // parsed in getDBParams but not applied to the pool here - confirm whether that is
    // intended.
    private DataSource buildDataSources(DBProperties dbProperties, boolean readOnly) {
        DruidDataSource dataSource = new DruidDataSource();
        dataSource.setDriverClassName(dbProperties.getDriverClassName());
        dataSource.setUrl(dbProperties.getUrl());
        dataSource.setUsername(dbProperties.getUsername());
        dataSource.setPassword(dbProperties.getPassword());
        dataSource.setInitialSize(dbProperties.getInitialSize());
        dataSource.setMaxActive(dbProperties.getMaxActive());
        // Fix: the original passed getMaxIdle() here, so the configured minIdle was ignored.
        dataSource.setMinIdle(dbProperties.getMinIdle());
        dataSource.setDefaultReadOnly(readOnly);
        dataSource.setTestWhileIdle(true);
        dataSource.setTestOnBorrow(false);
        dataSource.setTestOnReturn(false);
        dataSource.setValidationQuery("SELECT 'x'");
        dataSource.setTimeBetweenLogStatsMillis(dbProperties.getTimeBetweenLogStatsMillis());
        dataSource.setTimeBetweenEvictionRunsMillis(dbProperties.getTimeBetweenEvictionRunsMillis());
        dataSource.setMinEvictableIdleTimeMillis(dbProperties.getMinEvictableIdleTimeMillis());
        return dataSource;
    }

    /**
     * Reads the {@code db.*} properties from the environment and builds both data sources.
     *
     * @param env environment the properties are resolved from
     * @throws NumberFormatException if a numeric property is missing or not an integer
     */
    @Override
    public void setEnvironment(Environment env) {
        DBProperties dbParams = getDBParams(new RelaxedPropertyResolver(
                env, "db."));
        initTargetDataSources(dbParams);
    }

    // Copies the individual properties into a DBProperties object, in the same order as
    // the original so that the first failing key is unchanged.
    private DBProperties getDBParams(RelaxedPropertyResolver resolver) {
        DBProperties db = new DBProperties();
        db.setDriverClassName(resolver.getProperty("driverClassName"));
        db.setInitialSize(intProperty(resolver, "initialSize"));
        db.setMaxActive(intProperty(resolver, "maxActive"));
        db.setMaxIdle(intProperty(resolver, "maxIdle"));
        db.setMaxPoolPreparedStatementPerConnectionSize(
                intProperty(resolver, "maxPoolPreparedStatementPerConnectionSize"));
        db.setUrl(resolver.getProperty("url"));
        db.setUsername(resolver.getProperty("username"));
        db.setPassword(resolver.getProperty("password"));
        db.setMinIdle(intProperty(resolver, "minIdle"));
        db.setMaxWait(intProperty(resolver, "maxWait"));
        db.setTimeBetweenLogStatsMillis(intProperty(resolver, "timeBetweenLogStatsMillis"));
        db.setTimeBetweenEvictionRunsMillis(intProperty(resolver, "timeBetweenEvictionRunsMillis"));
        db.setMinEvictableIdleTimeMillis(intProperty(resolver, "minEvictableIdleTimeMillis"));
        return db;
    }

    // Parses one integer property; on failure rethrows the same exception type with the
    // offending key and raw value in the message so misconfiguration is easy to locate.
    private static int intProperty(RelaxedPropertyResolver resolver, String key) {
        String raw = resolver.getProperty(key);
        try {
            return Integer.parseInt(raw);
        } catch (NumberFormatException e) {
            NumberFormatException detailed = new NumberFormatException(
                    "Property '" + key + "' must be an integer but was: " + raw);
            detailed.initCause(e);
            throw detailed;
        }
    }

    /**
     * Registers the routing data source with Spring under the bean name {@code dataSource}.
     *
     * <p>The bean is a {@link ChooseDataSource} whose {@code defaultTargetDataSource} is
     * the master pool and whose {@code targetDataSources} map holds both pools. Both
     * values are only populated by {@link #setEnvironment(Environment)}, so this method
     * presumably relies on Spring invoking that callback first.
     *
     * @param importingClassMetadata annotation metadata of the importing class (unused)
     * @param registry registry the bean definition is added to
     * @see org.springframework.context.annotation.ImportBeanDefinitionRegistrar#registerBeanDefinitions(org.springframework.core.type.AnnotationMetadata, org.springframework.beans.factory.support.BeanDefinitionRegistry)
     */
    @Override
    public void registerBeanDefinitions(
            AnnotationMetadata importingClassMetadata,
            BeanDefinitionRegistry registry) {
        // Bean definition for ChooseDataSource.
        GenericBeanDefinition beanDefinition = new GenericBeanDefinition();
        beanDefinition.setBeanClass(ChooseDataSource.class);
        beanDefinition.setSynthetic(true);
        MutablePropertyValues mpv = beanDefinition.getPropertyValues();
        // Properties inherited from AbstractRoutingDataSource.
        mpv.addPropertyValue("defaultTargetDataSource", masterDataSource);
        mpv.addPropertyValue("targetDataSources", targetDataSources);
        registry.registerBeanDefinition("dataSource", beanDefinition);
    }
}
本例的切入点为 org.quickjee.service..*.*(..),即 service 层的所有方法。
测试demo:
/**
 * Demo service with one read method and one write method on {@code SysUser}.
 */
@Service
public class UserService {

    @Resource
    private UserRepository userRepository;

    /**
     * Loads a user by id.
     *
     * @param id id of the user to load
     * @return whatever {@code userRepository.findOne(id)} returns
     */
    public SysUser findById(long id) {
        return userRepository.findOne(id);
    }

    /**
     * Loads the user with the given id, sets its phone number to a fixed demo value and
     * saves it.
     *
     * @param id id of the user to update
     * @return the loaded user instance after modification
     * @throws IllegalArgumentException if the repository returns no user for {@code id}
     */
    @Transactional
    public SysUser update(long id) {
        SysUser user = userRepository.findOne(id);
        if (user == null) {
            // Fail with a descriptive message instead of a bare NullPointerException below.
            throw new IllegalArgumentException("No SysUser found for id " + id);
        }
        user.setPhone("13900139000");
        userRepository.save(user);
        return user;
    }
}
Controller:
/**
 * Demo controller that exercises one read and one write call on {@code UserService}.
 */
@RestController
public class TestController {

    @Autowired
    private UserService userService;

    /**
     * Calls {@code findById(1L)} followed by {@code update(1L)} on the user service.
     * The results are not used; the calls are made only for their effects.
     *
     * @return the literal string {@code "ok"}
     */
    @RequestMapping("/index")
    public String index() {
        userService.findById(1L);
        userService.update(1L);
        return "ok";
    }
}