package com.digiwin.dap.middle.sql.safe.service;

import com.digiwin.dap.middle.sql.safe.domain.SqlRule;
import com.digiwin.dap.middle.sql.safe.domain.SqlTable;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.*;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
 * Whitelist-based safety validator for SQL queries.
 * <p>
 * Parses a SQL string with JSqlParser and verifies that every table, column and function it
 * references is allowed by the given {@link SqlRule} for the application's database. Only SELECT
 * statements (including WITH / UNION / sub-queries in FROM and JOIN) are accepted; every
 * violation is reported as an {@link IllegalArgumentException}.
 * <p>
 * NOTE(review): CTE names are registered on the supplied {@link SqlRule} via
 * {@code sqlRule.with(...)}. If the same SqlRule instance is shared between validations these
 * registrations presumably outlive the query that declared them — verify against SqlRule.
 */
public class SqlSafeValidator {

    /** Marker stored in the alias map for derived tables (sub-queries in FROM / JOIN). */
    private static final String SUBSELECT = "<SUBSELECT>";

    /** Dummy table (e.g. {@code SELECT NOW() FROM DUAL}) that needs no whitelist entry. */
    private static final String DUAL = "DUAL";

    /** Applications whose database name differs from the application name: app -> db. */
    private static final Map<String, String> APP_DB_MAP;

    static {
        Map<String, String> mapping = new HashMap<>();
        mapping.put("gmc", "omc");
        mapping.put("boss", "iam");
        APP_DB_MAP = Collections.unmodifiableMap(mapping);
    }

    /** Database whose whitelist is consulted; resolved from the application name. */
    private final String db;
    /** Whitelist configuration (tables, columns, functions) plus registered CTE names. */
    private final SqlRule sqlRule;

    private SqlSafeValidator(String appName, SqlRule sqlRule) {
        this.db = APP_DB_MAP.getOrDefault(appName, appName);
        this.sqlRule = sqlRule;
    }

    /**
     * Creates a validator for the given application.
     *
     * @param appName application name; mapped to its database name when the two differ
     * @param sqlRule whitelist configuration to validate against
     * @return a new validator
     */
    public static SqlSafeValidator build(String appName, SqlRule sqlRule) {
        return new SqlSafeValidator(appName, sqlRule);
    }

    /**
     * Validates that all tables, columns and functions referenced by {@code sql} are within the
     * configured whitelist. Only SELECT statements are supported.
     *
     * @param sql the query to validate
     * @throws IllegalArgumentException if the SQL is empty, cannot be parsed, is not a SELECT, or
     *                                  references anything outside the whitelist
     */
    public void validate(String sql) {
        if (sql == null || sql.trim().isEmpty()) {
            throw new IllegalArgumentException("SQL 不能为空");
        }
        Statement statement;
        try {
            statement = CCJSqlParserUtil.parse(sql);
        } catch (JSQLParserException e) {
            throw new IllegalArgumentException("SQL 解析失败: " + e.getMessage(), e);
        }
        if (!(statement instanceof Select)) {
            throw new IllegalArgumentException("仅支持 SELECT 语句");
        }
        // WITH items, UNION branches and parenthesised queries are all handled recursively.
        validateSelectBody((Select) statement, "SELECT", 0);
    }

    /**
     * Validates any kind of SELECT: its own WITH items first, then the body itself
     * (plain select, parenthesised select, or set operation such as UNION / INTERSECT).
     *
     * @param select  the select to validate; {@code null} is ignored
     * @param context location description used in error messages
     * @param level   nesting depth (informational)
     * @throws IllegalArgumentException on any whitelist violation or unsupported select type
     */
    protected void validateSelectBody(Select select, String context, int level) {
        if (select == null) {
            return;
        }
        // Every select may carry its own WITH clause (top level or nested); handle it before the
        // body so the CTE names are registered when the body references them.
        validateWithItems(select, level);

        if (select instanceof PlainSelect) {
            validatePlainSelect((PlainSelect) select, context);
        } else if (select instanceof ParenthesedSelect) {
            validateSelectBody(((ParenthesedSelect) select).getSelect(), context, level + 1);
        } else if (select instanceof SetOperationList) {
            // UNION / INTERSECT / EXCEPT: every branch must be validated.
            List<Select> branches = ((SetOperationList) select).getSelects();
            if (branches != null) {
                for (Select branch : branches) {
                    validateSelectBody(branch, context, level + 1);
                }
            }
        } else {
            throw new IllegalArgumentException("不支持的查询类型 [" + select.getClass().getSimpleName() + "](位置：" + context + ")");
        }
    }

    /** Registers each CTE name of {@code select} and validates the CTE body. */
    private void validateWithItems(Select select, int level) {
        if (select.getWithItemsList() == null) {
            return;
        }
        for (WithItem<?> withItem : select.getWithItemsList()) {
            String name = withItem.getAliasName();
            // Registered before validating the body so recursive / later CTEs can reference it.
            sqlRule.with(name.toLowerCase(Locale.ROOT));
            Select withBody = withItem.getSelect();
            if (withBody == null) {
                throw new IllegalArgumentException("WITH 子句 [" + name + "] 仅支持 SELECT");
            }
            validateSelectBody(withBody, "WITH 子句 [" + name + "]", level + 1);
        }
    }

    /**
     * Validates one plain SELECT: select list, WHERE, GROUP BY, HAVING, ORDER BY, FROM/JOIN items
     * and JOIN conditions.
     */
    private void validatePlainSelect(PlainSelect select, String context) {
        Map<String, String> aliasTableMap = extractAliasTableMapping(select);
        List<String> fieldAliases = extractFiledAliases(select);

        for (SelectItem<?> item : select.getSelectItems()) {
            // AllColumns also covers qualified "t.*" in JSqlParser 5.x.
            if (item.getExpression() instanceof AllColumns) {
                throw new IllegalArgumentException("禁止使用 SELECT *(位置：" + context + ")");
            }
            validateExpression(item.getExpression(), aliasTableMap, fieldAliases, context + " 字段");
        }

        if (select.getWhere() != null) {
            validateExpression(select.getWhere(), aliasTableMap, fieldAliases, context + " WHERE");
        }

        GroupByElement groupBy = select.getGroupBy();
        if (groupBy != null) {
            // Plain "GROUP BY a, b" expressions.
            if (groupBy.getGroupByExpressionList() != null) {
                validateExpression(groupBy.getGroupByExpressionList(), aliasTableMap, fieldAliases, context + " GROUP BY");
            }
            // "GROUP BY GROUPING SETS (...)" expressions.
            if (groupBy.getGroupingSets() != null) {
                for (Expression expr : groupBy.getGroupingSets()) {
                    validateExpression(expr, aliasTableMap, fieldAliases, context + " GROUP BY");
                }
            }
        }

        if (select.getHaving() != null) {
            validateExpression(select.getHaving(), aliasTableMap, fieldAliases, context + " HAVING");
        }

        if (select.getOrderByElements() != null) {
            for (OrderByElement orderBy : select.getOrderByElements()) {
                validateExpression(orderBy.getExpression(), aliasTableMap, fieldAliases, context + " ORDER BY");
            }
        }

        validateFromItem(select.getFromItem(), context + " FROM", 0);

        if (select.getJoins() != null) {
            for (Join join : select.getJoins()) {
                // Every joined item (table or sub-query) is validated, not only sub-queries.
                validateFromItem(join.getRightItem(), context + " JOIN", 0);
                if (join.getOnExpressions() != null) {
                    for (Expression expr : join.getOnExpressions()) {
                        validateExpression(expr, aliasTableMap, fieldAliases, context + " JOIN ON");
                    }
                }
            }
        }
    }

    /**
     * Validates a FROM / JOIN item: real tables against the table whitelist, sub-queries
     * recursively. Unsupported item types (e.g. table functions) are rejected.
     */
    private void validateFromItem(FromItem fromItem, String context, int level) {
        if (fromItem == null) {
            return;
        }
        if (fromItem instanceof Table) {
            validateTable((Table) fromItem, context);
        } else if (fromItem instanceof LateralSubSelect) {
            validateSelectBody(((LateralSubSelect) fromItem).getSelect(), context + " 子查询", level + 1);
        } else if (fromItem instanceof ParenthesedSelect) {
            validateSelectBody(((ParenthesedSelect) fromItem).getSelect(), context + " 子查询", level + 1);
        } else if (fromItem instanceof Select) {
            validateSelectBody((Select) fromItem, context + " 子查询", level + 1);
        } else if (fromItem instanceof ParenthesedFromItem) {
            validateFromItem(((ParenthesedFromItem) fromItem).getFromItem(), context, level + 1);
        } else {
            throw new IllegalArgumentException("不支持的 FROM 项 [" + fromItem.getClass().getSimpleName() + "](位置：" + context + ")");
        }
    }

    /**
     * Checks that a table in FROM / JOIN belongs to this database and is whitelisted.
     * CTE names and the dummy table DUAL are exempt.
     */
    private void validateTable(Table table, String context) {
        String tableName = normalizeTableName(table.getName());
        if (tableName == null || tableName.isEmpty()) {
            throw new IllegalArgumentException("表名为空(位置：" + context + ")");
        }
        String schemaName = normalizeTableName(table.getSchemaName());
        // The whitelist is keyed by this validator's db; a table qualified with another schema
        // would otherwise be checked against the wrong whitelist by its bare name.
        if (schemaName != null && !schemaName.equalsIgnoreCase(db)) {
            throw new IllegalArgumentException("禁止跨库访问表 [" + table.getFullyQualifiedName() + "](位置：" + context + ")");
        }
        if (schemaName == null && (DUAL.equalsIgnoreCase(tableName) || isCteName(tableName))) {
            return;
        }
        if (!sqlRule.containsTable(db, tableName)) {
            throw new IllegalArgumentException("表 [" + new SqlTable(db, tableName) + "] 不在白名单(位置：" + context + ")");
        }
    }

    /** Returns whether {@code name} was registered as a CTE (names are stored lower-cased). */
    private boolean isCteName(String name) {
        return name != null && sqlRule.containsWith(name.toLowerCase(Locale.ROOT));
    }

    /**
     * Walks {@code expr} and checks every column against the table/column whitelist and every
     * function against the global function whitelist.
     *
     * @param aliasTableMap alias -> real table name (or {@link #SUBSELECT}) of the current select
     * @param fieldAliases  SELECT-list aliases of the current select
     * @param context       location description used in error messages
     */
    private void validateExpression(Expression expr, Map<String, String> aliasTableMap, List<String> fieldAliases, String context) {
        // NOTE(review): sub-queries used inside expressions (IN (SELECT ...), EXISTS (...), scalar
        // sub-queries) are only checked if ExpressionVisitorAdapter descends into them; without a
        // SelectVisitor it presumably does not, so their tables/columns may go unchecked — verify.
        expr.accept(new ExpressionVisitorAdapter<Void>() {
            @Override
            public <S> Void visit(Column column, S ctx) {
                String rawColumnName = column.getColumnName();
                String columnName = normalizeTableName(rawColumnName);
                String tableAlias = column.getTable() != null
                        ? normalizeTableName(column.getTable().getName())
                        : null;

                if (tableAlias == null) {
                    // An unqualified TRUE/FALSE is parsed as a column, and an unqualified name may
                    // refer to a SELECT-list alias; neither is a real table column. A qualified
                    // reference (t.x) can never be an alias, so it is always checked.
                    // NOTE(review): MySQL resolves GROUP BY/HAVING names against table columns
                    // before SELECT aliases and WHERE cannot see aliases at all, so an alias that
                    // shadows a real column can hide that column from this check.
                    if ("FALSE".equalsIgnoreCase(rawColumnName)
                            || "TRUE".equalsIgnoreCase(rawColumnName)
                            || fieldAliases.contains(columnName)) {
                        return null;
                    }
                }

                String actualTable;
                if (tableAlias != null) {
                    actualTable = aliasTableMap.get(tableAlias);
                } else if (aliasTableMap.size() == 1) {
                    actualTable = aliasTableMap.values().iterator().next();
                } else if (aliasTableMap.isEmpty()) {
                    actualTable = null; // reported as "unknown table" below
                } else {
                    throw new IllegalArgumentException("多表查询时，字段 [" + column + "] 必须指定表别名(位置：" + context + ")");
                }

                // Columns of derived tables are checked when the sub-query itself is validated.
                if (SUBSELECT.equals(actualTable)) {
                    return null;
                }

                if (actualTable == null) {
                    throw new IllegalArgumentException("字段 [" + column + "] 所属表未知(位置：" + context + ")");
                }

                // CTE (temporary) tables are validated through their own WITH body.
                if (isCteName(actualTable)) {
                    return null;
                }

                if (!sqlRule.containsTable(db, actualTable)) {
                    throw new IllegalArgumentException("字段 [" + column + "] 所属表 [" + new SqlTable(db, actualTable) + "] 不在白名单(位置：" + context + ")");
                }

                if (!sqlRule.containsColumn(db, actualTable, columnName)) {
                    throw new IllegalArgumentException("字段 [" + column + "] 不在表 [" + new SqlTable(db, actualTable) + "] 的白名单中(位置：" + context + ")");
                }

                return null;
            }

            @Override
            public <S> Void visit(Function function, S ctx) {
                String funcName = function.getName();
                if (funcName == null) {
                    return null;
                }

                if (!sqlRule.containsFunction(funcName)) {
                    throw new IllegalArgumentException("函数 [" + funcName + "] 不在全局白名单中(位置：" + context + ")");
                }

                if (function.getParameters() != null) {
                    for (Expression param : function.getParameters()) {
                        if (param instanceof AllColumns) {
                            // Pass the function name as context so the AllColumns visit can
                            // decide whether "*" is allowed (only COUNT(*)).
                            param.accept(this, funcName);
                        } else {
                            param.accept(this, ctx);
                        }
                    }
                }

                return null;
            }

            @Override
            public <S> Void visit(AllColumns allColumns, S funcName) {
                if (funcName instanceof String) {
                    if (!"COUNT".equalsIgnoreCase((String) funcName)) {
                        throw new IllegalArgumentException("函数 [" + funcName + "] 不允许使用 * 参数(位置：" + context + ")");
                    }
                }
                return null;
            }
        });
    }

    /**
     * Extracts the aliases of all SELECT-list items (the {@code alias} in
     * {@code SELECT xxx AS alias}) so that ORDER BY / GROUP BY references to them are not
     * mistaken for table columns. Quotes and backticks are stripped.
     */
    private List<String> extractFiledAliases(PlainSelect select) {
        List<String> aliases = new ArrayList<>();
        for (SelectItem<?> item : select.getSelectItems()) {
            Alias alias = item.getAlias();
            if (alias != null) {
                aliases.add(normalizeTableName(alias.getName().replace("'", "")));
            }
        }
        return aliases;
    }

    /**
     * Builds the mapping alias -> real table name for the FROM item and all JOIN items of
     * {@code select}, so column references can be resolved to the table whose whitelist applies.
     * Derived tables map to {@link #SUBSELECT}.
     */
    private Map<String, String> extractAliasTableMapping(PlainSelect select) {
        Map<String, String> aliasMap = new HashMap<>();
        collectTableAlias(select.getFromItem(), aliasMap);
        if (select.getJoins() != null) {
            for (Join join : select.getJoins()) {
                collectTableAlias(join.getRightItem(), aliasMap);
            }
        }
        return aliasMap;
    }

    /** Adds the alias mapping contributed by one FROM / JOIN item to {@code aliasMap}. */
    private void collectTableAlias(FromItem item, Map<String, String> aliasMap) {
        if (item instanceof Table) {
            Table table = (Table) item;
            String normalizedName = normalizeTableName(table.getName());
            // Without an explicit alias the table is referenced by its own name.
            String alias = table.getAlias() != null
                    ? normalizeTableName(table.getAlias().getName())
                    : normalizedName;
            aliasMap.put(alias, normalizedName);
        } else if (item instanceof ParenthesedSelect) {
            ParenthesedSelect subSelect = (ParenthesedSelect) item;
            if (subSelect.getAlias() != null) {
                aliasMap.put(normalizeTableName(subSelect.getAlias().getName()), SUBSELECT);
            }
        }
    }

    /**
     * Removes the backticks used to quote identifiers (e.g. keyword table names).
     *
     * @param name e.g. {@code `order`}; may be {@code null}
     * @return e.g. {@code order}, or {@code null} when {@code name} is {@code null}
     */
    private String normalizeTableName(String name) {
        return name != null ? name.replace("`", "") : null;
    }
}
