package com.digiwin.athena.framework.rw.router;

import com.digiwin.athena.framework.rw.RWTypeInterceptor;
import com.digiwin.athena.framework.rw.ShardPlugin;
import org.apache.ibatis.mapping.Environment;
import org.apache.ibatis.session.SqlSessionFactory;
import org.mybatis.spring.boot.autoconfigure.ConfigurationCustomizer;
import org.springframework.beans.BeansException;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import javax.annotation.PostConstruct;
import javax.sql.DataSource;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Registers {@link ShardPlugin} and {@link RWTypeInterceptor} as MyBatis interceptors on every
 * {@link SqlSessionFactory} whose environment {@link DataSource} is the bean named
 * {@code ptmPrimaryDataSource}.
 *
 * <p>Session factories bound to any other data source are left untouched. If no
 * {@code ptmPrimaryDataSource} bean of type {@link DataSource} exists, nothing is registered.
 */
public class MybatisInterceptorInit implements ApplicationContextAware {

    /** Bean name of the data source whose session factories receive the interceptors. */
    private static final String PRIMARY_DATA_SOURCE_BEAN_NAME = "ptmPrimaryDataSource";

    private final List<SqlSessionFactory> sqlSessionFactoryList;
    private final ShardPlugin shardPlugin;
    private final RWTypeInterceptor rwTypeInterceptor;
    private ApplicationContext applicationContext;

    /**
     * @param sqlSessionFactoryList session factories that are candidates for interceptor registration
     * @param shardPlugin           interceptor added to matching factories (first)
     * @param rwTypeInterceptor     interceptor added to matching factories (second)
     */
    MybatisInterceptorInit(List<SqlSessionFactory> sqlSessionFactoryList, ShardPlugin shardPlugin, RWTypeInterceptor rwTypeInterceptor) {
        this.sqlSessionFactoryList = sqlSessionFactoryList;
        this.shardPlugin = shardPlugin;
        this.rwTypeInterceptor = rwTypeInterceptor;
    }

    /**
     * Adds {@code shardPlugin} followed by {@code rwTypeInterceptor} to each session factory whose
     * environment data source is the same instance as the {@code ptmPrimaryDataSource} bean.
     */
    @PostConstruct
    public void addMybatisInterceptor() {
        if (sqlSessionFactoryList == null || sqlSessionFactoryList.isEmpty()) {
            return;
        }

        Map<String, DataSource> dataSourceMap = applicationContext.getBeansOfType(DataSource.class);
        DataSource primaryDataSource = dataSourceMap.get(PRIMARY_DATA_SOURCE_BEAN_NAME);
        if (primaryDataSource == null) {
            // No primary data source bean: no factory can qualify (also avoids null == null matches).
            return;
        }

        for (SqlSessionFactory sqlSessionFactory : sqlSessionFactoryList) {
            org.apache.ibatis.session.Configuration configuration = sqlSessionFactory.getConfiguration();
            Environment environment = configuration.getEnvironment();
            // Identity comparison on purpose: only factories built on this exact bean instance qualify.
            // The previous condition `A || (A && same)` collapsed to `A`, so every factory was matched.
            if (environment != null && environment.getDataSource() == primaryDataSource) {
                configuration.addInterceptor(shardPlugin);
                configuration.addInterceptor(rwTypeInterceptor);
            }
        }
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }
}
