软件编程
位置:首页>> 软件编程>> java编程>> 使用注解解决ShardingJdbc不支持复杂SQL方法

使用注解解决ShardingJdbc不支持复杂SQL方法

作者:女友在高考  发布时间:2022-08-03 02:15:12 

标签:ShardingJdbc,SQL

背景介绍

公司最近做分库分表业务,接入了 Sharding JDBC,接入完成后,回归测试时发现好几个 SQL 执行报错,关键这几个表都还不是分片表。

报错如下:

使用注解解决ShardingJdbc不支持复杂SQL方法

这下糟了嘛。熟悉 Sharding JDBC 的同学应该知道,有很多 SQL 它是不支持的。官方截图如下:

使用注解解决ShardingJdbc不支持复杂SQL方法

如果要去修改这些复杂 SQL 的话,可能要花费很多时间。那怎么办呢?只能从 Sharding JDBC 这里找突破口了,两天的研究,出来了下面这个只需要加一个注解轻松解决 Sharding Jdbc 不支持复杂 SQL 的方案。

问题复现

我本地写了一个复杂 SQL 进行测试:

public List<Map<String, Object>> queryOrder(){
       List<Map<String, Object>> orders = borderRepository.findOrders();
       return orders;
   }
public interface BOrderRepository extends JpaRepository<BOrder,Long> {
   @Query(value = "SELECT * FROM (SELECT id,CASE WHEN company_id =1 THEN '小' WHEN company_id=4 THEN '中' ELSE '大' END AS com,user_id as userId FROM b_order0) t WHERE t.com ='中'",nativeQuery =true)
   List<Map<String, Object>> findOrders();
}

写了个测试 controller 来调用,调用后果然报错了。

使用注解解决ShardingJdbc不支持复杂SQL方法

解决思路

因为查询的复杂 SQL 的表不是分片表,那能不能指定这几个复杂查询的时候不用 Sharding JDBC 的数据源呢?

  • 在注入 Sharding JDBC 数据源的地方做处理,注入一个我们自定义的数据源

  • 这样我们获取连接的时候就能返回原生数据源了

  • 另外我们声明一个注解,对标识了注解的就返回原生数据源,否则还是返回 Sharding 数据源

具体实现

编写autoConfig 类

  • 编写一个 autoConfig 类,来替换 ShardingSphereAutoConfiguration 类

/**
* 动态数据源核心自动配置类
*
*
*/
@Configuration
@ComponentScan("org.apache.shardingsphere.spring.boot.converter")
@EnableConfigurationProperties(SpringBootPropertiesConfiguration.class)
@ConditionalOnProperty(prefix = "spring.shardingsphere", name = "enabled", havingValue = "true", matchIfMissing = true)
@AutoConfigureBefore(DataSourceAutoConfiguration.class)
public class DynamicDataSourceAutoConfiguration implements EnvironmentAware {
   private String databaseName;
   private final SpringBootPropertiesConfiguration props;
   private final Map<String, DataSource> dataSourceMap = new LinkedHashMap<>();
   public DynamicDataSourceAutoConfiguration(SpringBootPropertiesConfiguration props) {
       this.props = props;
   }
   /**
    * Get mode configuration.
    *
    * @return mode configuration
    */
   @Bean
   public ModeConfiguration modeConfiguration() {
       return null == props.getMode() ? null : new ModeConfigurationYamlSwapper().swapToObject(props.getMode());
   }
   /**
    * Get ShardingSphere data source bean.
    *
    * @param rules rules configuration
    * @param modeConfig mode configuration
    * @return data source bean
    * @throws SQLException SQL exception
    */
   @Bean
   @Conditional(LocalRulesCondition.class)
   @Autowired(required = false)
   public DataSource shardingSphereDataSource(final ObjectProvider<List<RuleConfiguration>> rules, final ObjectProvider<ModeConfiguration> modeConfig) throws SQLException {
       Collection<RuleConfiguration> ruleConfigs = Optional.ofNullable(rules.getIfAvailable()).orElseGet(Collections::emptyList);
       DataSource dataSource = ShardingSphereDataSourceFactory.createDataSource(databaseName, modeConfig.getIfAvailable(), dataSourceMap, ruleConfigs, props.getProps());
       return new WrapShardingDataSource((ShardingSphereDataSource) dataSource,dataSourceMap);
   }
   /**
    * Get data source bean from registry center.
    *
    * @param modeConfig mode configuration
    * @return data source bean
    * @throws SQLException SQL exception
    */
   @Bean
   @ConditionalOnMissingBean(DataSource.class)
   public DataSource dataSource(final ModeConfiguration modeConfig) throws SQLException {
       DataSource dataSource = !dataSourceMap.isEmpty() ? ShardingSphereDataSourceFactory.createDataSource(databaseName, modeConfig, dataSourceMap, Collections.emptyList(), props.getProps())
               : ShardingSphereDataSourceFactory.createDataSource(databaseName, modeConfig);
       return new WrapShardingDataSource((ShardingSphereDataSource) dataSource,dataSourceMap);
   }
   /**
    * Create transaction type scanner.
    *
    * @return transaction type scanner
    */
   @Bean
   public TransactionTypeScanner transactionTypeScanner() {
       return new TransactionTypeScanner();
   }
   @Override
   public final void setEnvironment(final Environment environment) {
       dataSourceMap.putAll(DataSourceMapSetter.getDataSourceMap(environment));
       databaseName = DatabaseNameSetter.getDatabaseName(environment);
   }
   @Role(BeanDefinition.ROLE_INFRASTRUCTURE)
   @Bean
   @ConditionalOnProperty(prefix = "spring.datasource.dynamic.aop", name = "enabled", havingValue = "true", matchIfMissing = true)
   public Advisor dynamicDatasourceAnnotationAdvisor() {
       DynamicDataSourceAnnotationInterceptor interceptor = new DynamicDataSourceAnnotationInterceptor(true);
       DynamicDataSourceAnnotationAdvisor advisor = new DynamicDataSourceAnnotationAdvisor(interceptor, DS.class);
       return advisor;
   }
}

自定义数据源

public class WrapShardingDataSource extends AbstractDataSourceAdapter implements AutoCloseable{
   private ShardingSphereDataSource dataSource;
   private Map&lt;String, DataSource&gt; dataSourceMap;
   public WrapShardingDataSource(ShardingSphereDataSource dataSource, Map&lt;String, DataSource&gt; dataSourceMap) {
       this.dataSource = dataSource;
       this.dataSourceMap = dataSourceMap;
   }
   public DataSource getTargetDataSource(){
       String peek = DynamicDataSourceContextHolder.peek();
       if(StringUtils.isEmpty(peek)){
           return dataSource;
       }
       return dataSourceMap.get(peek);
   }
   @Override
   public Connection getConnection() throws SQLException {
       return getTargetDataSource().getConnection();
   }
   @Override
   public Connection getConnection(final String username, final String password) throws SQLException {
       return getConnection();
   }
   @Override
   public void close() throws Exception {
       DataSource targetDataSource = getTargetDataSource();
       if (targetDataSource instanceof AutoCloseable) {
           ((AutoCloseable) targetDataSource).close();
       }
   }
   @Override
   public int getLoginTimeout() throws SQLException {
       DataSource targetDataSource = getTargetDataSource();
       return targetDataSource ==null ? 0 : targetDataSource.getLoginTimeout();
   }
   @Override
   public void setLoginTimeout(final int seconds) throws SQLException {
       DataSource targetDataSource = getTargetDataSource();
       targetDataSource.setLoginTimeout(seconds);
   }
}
  • 声明指定数据源注解

@Target({ElementType.TYPE, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DS {
   /**
    * 数据源名
    */
   String value();
}
  • 另外使用 AOP 的方式拦截使用了注解的类或方法,并且要将这些用了注解的方法存起来,在获取数据源连接的时候取出来进行判断。这就还要用到 ThreadLocal。

aop * :

public class DynamicDataSourceAnnotationInterceptor implements MethodInterceptor {
   private final DataSourceClassResolver dataSourceClassResolver;
   public DynamicDataSourceAnnotationInterceptor(Boolean allowedPublicOnly) {
       dataSourceClassResolver = new DataSourceClassResolver(allowedPublicOnly);
   }
   @Override
   public Object invoke(MethodInvocation invocation) throws Throwable {
       String dsKey = determineDatasourceKey(invocation);
       DynamicDataSourceContextHolder.push(dsKey);
       try {
           return invocation.proceed();
       } finally {
           DynamicDataSourceContextHolder.poll();
       }
   }
   private String determineDatasourceKey(MethodInvocation invocation) {
       String key = dataSourceClassResolver.findKey(invocation.getMethod(), invocation.getThis());
       return key;
   }
}

aop 切面定义

/**
* aop Advisor
*/
public class DynamicDataSourceAnnotationAdvisor extends AbstractPointcutAdvisor implements BeanFactoryAware {
   private final Advice advice;
   private final Pointcut pointcut;
   private final Class<? extends Annotation> annotation;
   public DynamicDataSourceAnnotationAdvisor(MethodInterceptor advice,
                                              Class<? extends Annotation> annotation) {
       this.advice = advice;
       this.annotation = annotation;
       this.pointcut = buildPointcut();
   }
   @Override
   public Pointcut getPointcut() {
       return this.pointcut;
   }
   @Override
   public Advice getAdvice() {
       return this.advice;
   }
   @Override
   public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
       if (this.advice instanceof BeanFactoryAware) {
           ((BeanFactoryAware) this.advice).setBeanFactory(beanFactory);
       }
   }
   private Pointcut buildPointcut() {
       Pointcut cpc = new AnnotationMatchingPointcut(annotation, true);
       Pointcut mpc = new AnnotationMethodPoint(annotation);
       return new ComposablePointcut(cpc).union(mpc);
   }
   /**
    * In order to be compatible with the spring lower than 5.0
    */
   private static class AnnotationMethodPoint implements Pointcut {
       private final Class<? extends Annotation> annotationType;
       public AnnotationMethodPoint(Class<? extends Annotation> annotationType) {
           Assert.notNull(annotationType, "Annotation type must not be null");
           this.annotationType = annotationType;
       }
       @Override
       public ClassFilter getClassFilter() {
           return ClassFilter.TRUE;
       }
       @Override
       public MethodMatcher getMethodMatcher() {
           return new AnnotationMethodMatcher(annotationType);
       }
       private static class AnnotationMethodMatcher extends StaticMethodMatcher {
           private final Class<? extends Annotation> annotationType;
           public AnnotationMethodMatcher(Class<? extends Annotation> annotationType) {
               this.annotationType = annotationType;
           }
           @Override
           public boolean matches(Method method, Class<?> targetClass) {
               if (matchesMethod(method)) {
                   return true;
               }
               // Proxy classes never have annotations on their redeclared methods.
               if (Proxy.isProxyClass(targetClass)) {
                   return false;
               }
               // The method may be on an interface, so let's check on the target class as well.
               Method specificMethod = AopUtils.getMostSpecificMethod(method, targetClass);
               return (specificMethod != method && matchesMethod(specificMethod));
           }
           private boolean matchesMethod(Method method) {
               return AnnotatedElementUtils.hasAnnotation(method, this.annotationType);
           }
       }
   }
}
/**
* 数据源解析器
*
*/
public class DataSourceClassResolver {
   private static boolean mpEnabled = false;
   private static Field mapperInterfaceField;
   static {
       Class<?> proxyClass = null;
       try {
           proxyClass = Class.forName("com.baomidou.mybatisplus.core.override.MybatisMapperProxy");
       } catch (ClassNotFoundException e1) {
           try {
               proxyClass = Class.forName("com.baomidou.mybatisplus.core.override.PageMapperProxy");
           } catch (ClassNotFoundException e2) {
               try {
                   proxyClass = Class.forName("org.apache.ibatis.binding.MapperProxy");
               } catch (ClassNotFoundException ignored) {
               }
           }
       }
       if (proxyClass != null) {
           try {
               mapperInterfaceField = proxyClass.getDeclaredField("mapperInterface");
               mapperInterfaceField.setAccessible(true);
               mpEnabled = true;
           } catch (NoSuchFieldException e) {
               e.printStackTrace();
           }
       }
   }
   /**
    * 缓存方法对应的数据源
    */
   private final Map<Object, String> dsCache = new ConcurrentHashMap<>();
   private final boolean allowedPublicOnly;
   /**
    * 加入扩展, 给外部一个修改aop条件的机会
    *
    * @param allowedPublicOnly 只允许公共的方法, 默认为true
    */
   public DataSourceClassResolver(boolean allowedPublicOnly) {
       this.allowedPublicOnly = allowedPublicOnly;
   }
   /**
    * 从缓存获取数据
    *
    * @param method       方法
    * @param targetObject 目标对象
    * @return ds
    */
   public String findKey(Method method, Object targetObject) {
       if (method.getDeclaringClass() == Object.class) {
           return "";
       }
       Object cacheKey = new MethodClassKey(method, targetObject.getClass());
       String ds = this.dsCache.get(cacheKey);
       if (ds == null) {
           ds = computeDatasource(method, targetObject);
           if (ds == null) {
               ds = "";
           }
           this.dsCache.put(cacheKey, ds);
       }
       return ds;
   }
   /**
    * 查找注解的顺序
    * 1. 当前方法
    * 2. 桥接方法
    * 3. 当前类开始一直找到Object
    * 4. 支持mybatis-plus, mybatis-spring
    *
    * @param method       方法
    * @param targetObject 目标对象
    * @return ds
    */
   private String computeDatasource(Method method, Object targetObject) {
       if (allowedPublicOnly && !Modifier.isPublic(method.getModifiers())) {
           return null;
       }
       //1. 从当前方法接口中获取
       String dsAttr = findDataSourceAttribute(method);
       if (dsAttr != null) {
           return dsAttr;
       }
       Class<?> targetClass = targetObject.getClass();
       Class<?> userClass = ClassUtils.getUserClass(targetClass);
       // JDK代理时,  获取实现类的方法声明.  method: 接口的方法, specificMethod: 实现类方法
       Method specificMethod = ClassUtils.getMostSpecificMethod(method, userClass);
       specificMethod = BridgeMethodResolver.findBridgedMethod(specificMethod);
       //2. 从桥接方法查找
       dsAttr = findDataSourceAttribute(specificMethod);
       if (dsAttr != null) {
           return dsAttr;
       }
       // 从当前方法声明的类查找
       dsAttr = findDataSourceAttribute(userClass);
       if (dsAttr != null && ClassUtils.isUserLevelMethod(method)) {
           return dsAttr;
       }
       //since 3.4.1 从接口查找,只取第一个找到的
       for (Class<?> interfaceClazz : ClassUtils.getAllInterfacesForClassAsSet(userClass)) {
           dsAttr = findDataSourceAttribute(interfaceClazz);
           if (dsAttr != null) {
               return dsAttr;
           }
       }
       // 如果存在桥接方法
       if (specificMethod != method) {
           // 从桥接方法查找
           dsAttr = findDataSourceAttribute(method);
           if (dsAttr != null) {
               return dsAttr;
           }
           // 从桥接方法声明的类查找
           dsAttr = findDataSourceAttribute(method.getDeclaringClass());
           if (dsAttr != null && ClassUtils.isUserLevelMethod(method)) {
               return dsAttr;
           }
       }
       return getDefaultDataSourceAttr(targetObject);
   }
   /**
    * 默认的获取数据源名称方式
    *
    * @param targetObject 目标对象
    * @return ds
    */
   private String getDefaultDataSourceAttr(Object targetObject) {
       Class<?> targetClass = targetObject.getClass();
       // 如果不是代理类, 从当前类开始, 不断的找父类的声明
       if (!Proxy.isProxyClass(targetClass)) {
           Class<?> currentClass = targetClass;
           while (currentClass != Object.class) {
               String datasourceAttr = findDataSourceAttribute(currentClass);
               if (datasourceAttr != null) {
                   return datasourceAttr;
               }
               currentClass = currentClass.getSuperclass();
           }
       }
       // mybatis-plus, mybatis-spring 的获取方式
       if (mpEnabled) {
           final Class<?> clazz = getMapperInterfaceClass(targetObject);
           if (clazz != null) {
               String datasourceAttr = findDataSourceAttribute(clazz);
               if (datasourceAttr != null) {
                   return datasourceAttr;
               }
               // 尝试从其父接口获取
               return findDataSourceAttribute(clazz.getSuperclass());
           }
       }
       return null;
   }
   /**
    * 用于处理嵌套代理
    *
    * @param target JDK 代理类对象
    * @return InvocationHandler 的 Class
    */
   private Class<?> getMapperInterfaceClass(Object target) {
       Object current = target;
       while (Proxy.isProxyClass(current.getClass())) {
           Object currentRefObject = AopProxyUtils.getSingletonTarget(current);
           if (currentRefObject == null) {
               break;
           }
           current = currentRefObject;
       }
       try {
           if (Proxy.isProxyClass(current.getClass())) {
               return (Class<?>) mapperInterfaceField.get(Proxy.getInvocationHandler(current));
           }
       } catch (IllegalAccessException ignore) {
       }
       return null;
   }
   /**
    * 通过 AnnotatedElement 查找标记的注解, 映射为  DatasourceHolder
    *
    * @param ae AnnotatedElement
    * @return 数据源映射持有者
    */
   private String findDataSourceAttribute(AnnotatedElement ae) {
       AnnotationAttributes attributes = AnnotatedElementUtils.getMergedAnnotationAttributes(ae, DS.class);
       if (attributes != null) {
           return attributes.getString("value");
       }
       return null;
   }
}

ThreadLocal

public final class DynamicDataSourceContextHolder {
   /**
    * 为什么要用链表存储(准确的是栈)
    * <pre>
    * 为了支持嵌套切换,如ABC三个service都是不同的数据源
    * 其中A的某个业务要调B的方法,B的方法需要调用C的方法。一级一级调用切换,形成了链。
    * 传统的只设置当前线程的方式不能满足此业务需求,必须使用栈,后进先出。
    * </pre>
    */
   private static final ThreadLocal<Deque<String>> LOOKUP_KEY_HOLDER = new NamedThreadLocal<Deque<String>>("dynamic-datasource") {
       @Override
       protected Deque<String> initialValue() {
           return new ArrayDeque<>();
       }
   };
   private DynamicDataSourceContextHolder() {
   }
   /**
    * 获得当前线程数据源
    *
    * @return 数据源名称
    */
   public static String peek() {
       return LOOKUP_KEY_HOLDER.get().peek();
   }
   /**
    * 设置当前线程数据源
    * <p>
    * 如非必要不要手动调用,调用后确保最终清除
    * </p>
    *
    * @param ds 数据源名称
    */
   public static String push(String ds) {
       String dataSourceStr = StringUtils.isEmpty(ds) ? "" : ds;
       LOOKUP_KEY_HOLDER.get().push(dataSourceStr);
       return dataSourceStr;
   }
   /**
    * 清空当前线程数据源
    * <p>
    * 如果当前线程是连续切换数据源 只会移除掉当前线程的数据源名称
    * </p>
    */
   public static void poll() {
       Deque<String> deque = LOOKUP_KEY_HOLDER.get();
       deque.poll();
       if (deque.isEmpty()) {
           LOOKUP_KEY_HOLDER.remove();
       }
   }
   /**
    * 强制清空本地线程
    * <p>
    * 防止内存泄漏,如手动调用了push可调用此方法确保清除
    * </p>
    */
   public static void clear() {
       LOOKUP_KEY_HOLDER.remove();
   }
}

启动类配置

引入我们写的自动配置类,排除 ShardingJdbc 的自动配置类。

@SpringBootApplication(exclude = ShardingSphereAutoConfiguration.class)
@Import({DynamicDataSourceAutoConfiguration.class})
public class ShardingRunApplication {
   public static void main(String[] args) {
       SpringApplication.run(ShardingRunApplication.class);
   }
}

最后,我们给之前写的 Repository 加上注解:

public interface BOrderRepository extends JpaRepository<BOrder,Long> {
   @DS("slave0")
   @Query(value = "SELECT * FROM (SELECT id,CASE WHEN company_id =1 THEN '小' WHEN company_id=4 THEN '中' ELSE '大' END AS com,user_id as userId FROM b_order0) t WHERE t.com ='中'",nativeQuery =true)
   List<Map<String, Object>> findOrders();
}

再次调用,查询成功!!!

使用注解解决ShardingJdbc不支持复杂SQL方法

来源:https://juejin.cn/post/7141242019120152612

0
投稿

猜你喜欢

手机版 软件编程 asp之家 www.aspxhome.com