package com.digiwin.athena.framework.rw.utils;

import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.expr.SQLIdentifierExpr;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.visitor.SQLASTVisitorAdapter;
import com.alibaba.druid.stat.TableStat;
import com.digiwin.athena.framework.rw.ShardPlugin;
import com.digiwin.athena.framework.rw.strategy.ShardStrategyContext;
import org.springframework.util.Assert;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Scratch harness for Druid-based table-name rewriting used by the sharding plugin.
 *
 * <p>It parses a sample SQL string, collects every table the statements reference, maps each
 * table to a shard table name via {@link #getShardTableName(String)}, rewrites the AST in place
 * and prints the resulting SQL to stdout.
 */
public class Test {

    public static void main(String[] args) {
        // Parse the sample SQL into an AST list.
        List<SQLStatement> stmtList = getSqlStatement();

        // Decide the real (shard) table name for every referenced table, then rewrite the AST.
        Map<String, String> tableNameReplaceMap = buildTableNameReplaceMap(stmtList);
        replaceTableNames(stmtList, tableNameReplaceMap);

        // Print the rewritten SQL of every statement, not only the first one.
        for (SQLStatement stmt : stmtList) {
            System.out.println("====>" + SQLUtils.toMySqlString(stmt));
        }

        // For comparison: parse an unquoted variant of the same query and print its AST as SQL.
        String originalSql = "SELECT a.name, b.age FROM user a JOIN user_detail b ON a.id = b.user_id WHERE b.age > 18";
        List<SQLStatement> unquotedStmtList = SQLUtils.parseStatements(originalSql, ShardPlugin.DB_TYPE);
        System.out.println("------->" + unquotedStmtList.get(0).toString());
    }

    /**
     * Collects the tables referenced by all statements and maps each one to its shard table name.
     *
     * @param stmtList parsed statements to scan
     * @return map from normalized (unquoted) original table name to shard table name; tables for
     *     which {@link #getShardTableName(String)} returns {@code null} are omitted, so they are
     *     left untouched by the rewrite
     */
    private static Map<String, String> buildTableNameReplaceMap(List<SQLStatement> stmtList) {
        MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
        // Visit every statement so tables that only appear in later statements are collected too.
        for (SQLStatement stmt : stmtList) {
            stmt.accept(visitor);
        }

        Map<String, String> tableNameReplaceMap = new HashMap<>();
        for (TableStat.Name name : visitor.getTables().keySet()) {
            // Normalize so that a quoted `user` and a bare user resolve to the same key.
            String originalTableName = SQLUtils.normalize(name.getName());
            String shardTableName = getShardTableName(originalTableName);
            // Skip unmapped tables; storing null would later produce new SQLIdentifierExpr(null).
            if (shardTableName != null) {
                tableNameReplaceMap.put(originalTableName, shardTableName);
            }
        }
        return tableNameReplaceMap;
    }

    /**
     * Rewrites, in place, every {@link SQLExprTableSource} whose table name has an entry in
     * {@code tableNameReplaceMap}.
     *
     * @param stmtList statements to rewrite (mutated)
     * @param tableNameReplaceMap normalized original table name to replacement table name
     */
    private static void replaceTableNames(List<SQLStatement> stmtList, Map<String, String> tableNameReplaceMap) {
        for (SQLStatement stmt : stmtList) {
            stmt.accept(new SQLASTVisitorAdapter() {
                @Override
                public boolean visit(SQLExprTableSource x) {
                    String oldName = x.getExpr().toString();
                    System.out.println("sss:" + ShardStrategyContext.handleName(oldName));
                    // toString() keeps identifier quotes (e.g. `user`), so normalize before lookup.
                    // NOTE(review): only a single quoted identifier is unquoted; a qualified
                    // `db`.`table` name is not split here — confirm whether that case must be handled.
                    String shardTableName = tableNameReplaceMap.get(SQLUtils.normalize(oldName));
                    if (shardTableName != null) {
                        x.setExpr(new SQLIdentifierExpr(shardTableName));
                    }
                    return true;
                }
            });
        }
    }

    /**
     * Parses the hard-coded sample query (note the backtick-quoted {@code user} table).
     *
     * @return the non-empty list of parsed statements
     * @throws IllegalArgumentException if parsing yields no statements (via Spring {@code Assert})
     */
    private static List<SQLStatement> getSqlStatement() {
        String originalSql = "SELECT a.name, b.age FROM `user` a JOIN user_detail b ON a.id = b.user_id WHERE b.age > 18";
        List<SQLStatement> stmtList = SQLUtils.parseStatements(originalSql, ShardPlugin.DB_TYPE);
        Assert.notEmpty(stmtList, "stmtList is empty, sql: " + originalSql);
        return stmtList;
    }

    /**
     * Stub shard-name resolver: maps a logical table name to its shard table name.
     *
     * @param tableName unquoted logical table name
     * @return the shard table name, or {@code null} when the table has no shard mapping
     */
    private static String getShardTableName(String tableName) {
        if ("user".equals(tableName)) {
            return "user_new111";
        }
        if ("user_detail".equals(tableName)) {
            return "user_detail111";
        }
        return null;
    }
}

