package com.digiwin.athena.framework.rw;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.visitor.SchemaStatVisitor;
import com.alibaba.fastjson.JSON;
import com.digiwin.athena.framework.rw.contants.ReadType;
import com.digiwin.athena.framework.rw.contants.WriteType;
import com.digiwin.athena.framework.rw.dto.ReadWriterDto;
import com.digiwin.athena.framework.rw.exception.MyBatisShardException;
import com.digiwin.athena.framework.rw.router.DbSwitchConfig;
import com.digiwin.athena.framework.rw.router.MySqlReplaceTableNameVisitor;
import com.digiwin.athena.framework.rw.strategy.AbstractShardStrategy;
import com.digiwin.athena.framework.rw.strategy.ShardStrategyContext;
import com.digiwin.athena.framework.rw.utils.CommonUtils;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.springframework.util.Assert;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.util.List;

/**
 * Rewrites the SQL of one intercepted MyBatis statement for table sharding / read-write switching.
 *
 * <p>The instance reads {@code delegate.boundSql} and {@code delegate.mappedStatement} from the
 * supplied {@link MetaObject} (presumably a meta view over MyBatis' {@code RoutingStatementHandler}
 * — NOTE(review): confirm against {@code ShardPlugin}). On construction the original SQL is parsed,
 * a {@link MySqlReplaceTableNameVisitor} is applied to it, and the resulting MySQL-formatted
 * "shard SQL" is cached for {@link #route()}, {@link #processWrite(Connection)} and double writes.
 */
@Slf4j
public class ShardProcessor {

    /** Meta view over the intercepted statement handler; used to read and overwrite the bound SQL. */
    private final MetaObject metaObject;

    /** Switch configuration supplying the read/write modes; also handed to the replacement visitor. */
    private final DbSwitchConfig dbSwitchConfig;
    private final BoundSql boundSql;
    private final MappedStatement mappedStatement;
    /** SQL text of {@link #boundSql} as it was when this processor was created. */
    private final String originalSql;
    /** {@link #originalSql} after {@link MySqlReplaceTableNameVisitor} has run, rendered as MySQL. */
    private final String shardSql;

    /** Table name reported by the replacement visitor; used to look up the shard strategy. */
    private String tableName;
    /** Read mode captured at construction time; not consulted by the current routing logic. */
    private ReadType readType = ReadType.OLD;
    /** Write mode captured at construction time; {@link #processWrite} re-resolves it from the config. */
    private WriteType writeType = WriteType.OLD;

    /**
     * Captures the bound SQL and mapped statement from {@code metaObject} and pre-computes the shard SQL.
     *
     * @param metaObject     meta view exposing {@code delegate.boundSql} and {@code delegate.mappedStatement}
     * @param dbSwitchConfig switch configuration providing the read/write modes
     * @throws NullPointerException if either argument is {@code null}
     */
    public ShardProcessor(@NonNull MetaObject metaObject, @NonNull DbSwitchConfig dbSwitchConfig) {
        this.dbSwitchConfig = dbSwitchConfig;
        this.metaObject = metaObject;
        this.boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        this.mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        this.originalSql = boundSql.getSql();
        this.shardSql = calculateShardSql();
    }

    /**
     * Parses the original SQL, runs the table-name replacement visitor over it and records the
     * table name plus the configured read/write modes.
     *
     * @return the visited statement rendered back to MySQL SQL text
     */
    private String calculateShardSql() {
        SQLStatement stmt = getSqlStatement();
        MySqlReplaceTableNameVisitor mySqlReplaceTableNameVisitor = new MySqlReplaceTableNameVisitor(boundSql, dbSwitchConfig);
        stmt.accept(mySqlReplaceTableNameVisitor);
        tableName = mySqlReplaceTableNameVisitor.getTableName();
        this.readType = ReadType.valueOfKey(dbSwitchConfig.getReadMode());
        this.writeType = WriteType.valueOfKey(dbSwitchConfig.getWriteMode());
        return SQLUtils.toMySqlString(stmt);
    }

    /**
     * Replaces the statement's bound SQL with the shard SQL.
     *
     * <p>TODO: route by read mode (OLD/NEW) and select the matching data source. The shard SQL
     * is currently applied unconditionally.
     */
    public void route() {
        setShardSql();
    }

    /** Overwrites {@code delegate.boundSql.sql} on the intercepted handler with the shard SQL. */
    private void setShardSql() {
        metaObject.setValue("delegate.boundSql.sql", shardSql);
    }

    /**
     * Collects schema statistics for the original SQL and lets the table's shard strategy
     * adjust the statement parameters.
     */
    public void processParams() {
        SQLStatement stmt = getSqlStatement();
        // Collect the tables, columns, filter conditions, order-by and group-by expressions used in the SQL.
        SchemaStatVisitor schemaStatVisitor = SQLUtils.createSchemaStatVisitor(ShardPlugin.DB_TYPE);
        stmt.accept(schemaStatVisitor);

        AbstractShardStrategy shardStrategy = ShardStrategyContext.getStrategyByTableName(tableName);
        shardStrategy.processParams(metaObject, boundSql, schemaStatVisitor);
    }

    /**
     * Delegates to the table's shard strategy {@code processLocal}.
     *
     * @param readWriterDto value passed through unchanged to the strategy
     */
    public void processLocal(ReadWriterDto readWriterDto) {
        AbstractShardStrategy shardStrategy = ShardStrategyContext.getStrategyByTableName(tableName);
        shardStrategy.processLocal(readWriterDto);
    }

    /**
     * Parses {@link #originalSql} and returns its first statement. A fresh AST is produced on every
     * call, so visitors applied to one result do not affect later calls.
     *
     * @throws IllegalArgumentException if the parser yields no statements
     */
    private SQLStatement getSqlStatement() {
        List<SQLStatement> stmtList = SQLUtils.parseStatements(originalSql, ShardPlugin.DB_TYPE);
        Assert.notEmpty(stmtList, "stmtList is empty, sql: " + originalSql);
        return stmtList.get(0);
    }

    /**
     * Applies the configured write mode.
     *
     * <p>{@code OLD} / {@code NEW}: the bound SQL is replaced with the shard SQL.
     * {@code BOTH}: the shard SQL is executed immediately on {@code connection}, and the bound SQL is
     * left untouched (presumably so MyBatis then runs the original statement — NOTE(review): confirm).
     *
     * @param connection connection used for the extra write in {@code BOTH} mode
     * @throws MyBatisShardException for an unrecognised write mode or if the double write fails
     */
    public void processWrite(@NonNull Connection connection) {
        // Resolve the mode from the config at call time (not from the field cached at construction).
        WriteType currentWriteType = WriteType.valueOfKey(dbSwitchConfig.getWriteMode());
        switch (currentWriteType) {
            case OLD:
            case NEW:
                setShardSql();
                return;
            case BOTH:
                writeShard(connection);
                return;
            default:
                throw new MyBatisShardException("未知WriteType：" + dbSwitchConfig.getWriteMode());
        }
    }

    /**
     * Executes the shard SQL as an update on {@code connection}, binding parameters through MyBatis'
     * {@link DefaultParameterHandler} using the original {@link BoundSql}'s parameter mappings
     * (assumes the shard SQL keeps the same placeholders as the original — TODO confirm).
     *
     * @throws MyBatisShardException wrapping any failure while preparing, binding or executing
     */
    private void writeShard(Connection connection) {
        Object parameterObject = boundSql.getParameterObject();
        // MyBatis passes a null parameter object for parameterless statements; guard before getClass().
        String parameterType = parameterObject == null ? "null" : parameterObject.getClass().getSimpleName();
        log.info("[rw-plugin] double write sql: {} \n  parameterObject({}): {}", CommonUtils.removeBreakingWhitespace(shardSql), parameterType, JSON.toJSONString(parameterObject));
        try (PreparedStatement statement = connection.prepareStatement(shardSql)) {
            ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, boundSql);
            parameterHandler.setParameters(statement);
            statement.executeUpdate();
        } catch (Exception e) {
            throw new MyBatisShardException(String.format("Error: Method ShardPlugin.write execution error of sql : \n %s \n", CommonUtils.removeBreakingWhitespace(shardSql)), e);
        }
    }
}
