package com.digiwin.dap.middle.console.serice;

import com.digiwin.dap.middle.console.domain.admin.SqlRule;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.AnalyticExpression;
import net.sf.jsqlparser.expression.AnyComparisonExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.*;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

public class SqlSafeValidator {

    /** Alias-map value used for FROM/JOIN items that are derived tables (sub-queries). */
    private static final String SUBSELECT = "<SUBSELECT>";
    private final SqlRule sqlRule;

    private SqlSafeValidator(SqlRule sqlRule) {
        this.sqlRule = sqlRule;
    }

    /**
     * Creates a validator backed by the given whitelist.
     * <p>
     * The rule is only read. CTE names discovered during validation are tracked per
     * {@link #validate(String)} call instead of being registered on the rule, so the WITH clause of one
     * query can never exempt a real table in another query that shares the same rule.
     *
     * @param sqlRule whitelist of tables, columns and functions
     * @return a new validator
     */
    public static SqlSafeValidator build(SqlRule sqlRule) {
        return new SqlSafeValidator(sqlRule);
    }

    /**
     * Checks that every table, column and function referenced by {@code sql} is inside the configured
     * whitelist. Only a single SELECT statement (optionally with WITH clauses) is supported.
     * <p>
     * The check is fail-closed: query constructs this validator does not understand are rejected
     * instead of being skipped.
     *
     * @param sql the query to check
     * @throws IllegalArgumentException if the SQL is blank, cannot be parsed, is not a SELECT, uses an
     *                                  unsupported construct, or references something outside the whitelist
     */
    public void validate(String sql) {
        if (sql == null || sql.trim().isEmpty()) {
            throw new IllegalArgumentException("SQL 不能为空");
        }

        Statement statement;
        try {
            statement = CCJSqlParserUtil.parse(sql);
        } catch (JSQLParserException e) {
            throw new IllegalArgumentException("SQL 解析失败: " + e.getMessage(), e);
        }
        if (!(statement instanceof Select)) {
            throw new IllegalArgumentException("仅支持 SELECT 语句");
        }

        Select select = (Select) statement;
        // CTE names visible to the statement; local to this call, never stored on sqlRule.
        Set<String> cteNames = new HashSet<>();
        validateWithItems(select.getWithItemsList(), "", cteNames);
        validateSelectBody(select.getSelectBody(), "SELECT", cteNames);
    }

    /**
     * Validates the bodies of a WITH list and adds each CTE name (lower-cased) to {@code cteNames}.
     * <p>
     * A CTE name is registered only AFTER its own body has been validated: in a non-recursive WITH,
     * a reference to the CTE's own name inside its body denotes a real table and must be whitelisted.
     * Once RECURSIVE has been seen, names are registered before the body so self-references work.
     */
    private void validateWithItems(List<WithItem> withItems, String contextPrefix, Set<String> cteNames) {
        if (withItems == null) {
            return;
        }
        boolean recursive = false;
        for (WithItem withItem : withItems) {
            String name = withItem.getName();
            String context = contextPrefix + "WITH 子句 [" + name + "]";
            String key = normalizeIdentifier(name).toLowerCase(Locale.ROOT);
            // RECURSIVE applies to the whole WITH list, but the parser may record it only on the first
            // item, so the flag is kept once seen.
            recursive |= withItem.isRecursive();
            if (recursive) {
                cteNames.add(key);
            }
            if (withItem.getSubSelect() == null) {
                throw new IllegalArgumentException("不支持的 WITH 子句(位置：" + context + ")");
            }
            validateSubSelect(withItem.getSubSelect(), context, cteNames);
            cteNames.add(key);
        }
    }

    /**
     * Validates a sub-query, including any WITH list attached to it. CTEs declared inside the
     * sub-query are only visible within it, so they are added to a copy of the enclosing CTE set.
     */
    private void validateSubSelect(SubSelect subSelect, String context, Set<String> cteNames) {
        Set<String> scopedCteNames = cteNames;
        if (subSelect.getWithItemsList() != null && !subSelect.getWithItemsList().isEmpty()) {
            scopedCteNames = new HashSet<>(cteNames);
            validateWithItems(subSelect.getWithItemsList(), context + " ", scopedCteNames);
        }
        validateSelectBody(subSelect.getSelectBody(), context, scopedCteNames);
    }

    /**
     * Dispatches on the query body type. Set operations (UNION / INTERSECT / ...) are validated member by
     * member; any other body type is rejected rather than silently accepted.
     */
    private void validateSelectBody(SelectBody body, String context, Set<String> cteNames) {
        if (body instanceof PlainSelect) {
            validatePlainSelect((PlainSelect) body, context, cteNames);
        } else if (body instanceof SetOperationList) {
            for (SelectBody subBody : ((SetOperationList) body).getSelects()) {
                validateSelectBody(subBody, "子查询(" + context + ")", cteNames);
            }
        } else {
            throw new IllegalArgumentException("不支持的查询结构 [" + body + "](位置：" + context + ")");
        }
    }

    /** Validates one SELECT block: its sources, select list, and every clause that holds expressions. */
    private void validatePlainSelect(PlainSelect select, String context, Set<String> cteNames) {
        if (select.getIntoTables() != null && !select.getIntoTables().isEmpty()) {
            // SELECT ... INTO writes data; it is not a read-only query.
            throw new IllegalArgumentException("禁止使用 SELECT INTO(位置：" + context + ")");
        }

        // Sources are checked directly. A query such as SELECT COUNT(*) FROM t references no column,
        // so checking tables only via their columns would let t bypass the whitelist.
        validateFromItem(select.getFromItem(), context + " FROM", cteNames);
        if (select.getJoins() != null) {
            for (Join join : select.getJoins()) {
                validateFromItem(join.getRightItem(), context + " JOIN", cteNames);
            }
        }

        Scope scope = new Scope(extractAliasTableMapping(select), extractFieldAliases(select), cteNames);

        for (SelectItem item : select.getSelectItems()) {
            // Every non-expression item is rejected. This covers "*" and also "t.*",
            // which the former AllColumns-only check let through.
            if (!(item instanceof SelectExpressionItem)) {
                throw new IllegalArgumentException("禁止使用 SELECT *(位置：" + context + ")");
            }
            validateExpression(((SelectExpressionItem) item).getExpression(), scope, false, context + " 字段");
        }

        if (select.getWhere() != null) {
            validateExpression(select.getWhere(), scope, false, context + " WHERE");
        }

        if (select.getGroupBy() != null
                && select.getGroupBy().getGroupByExpressionList() != null
                && select.getGroupBy().getGroupByExpressionList().getExpressions() != null) {
            for (Expression expr : select.getGroupBy().getGroupByExpressionList().getExpressions()) {
                validateExpression(expr, scope, true, context + " GROUP BY");
            }
        }

        if (select.getHaving() != null) {
            validateExpression(select.getHaving(), scope, true, context + " HAVING");
        }

        if (select.getOrderByElements() != null) {
            for (OrderByElement element : select.getOrderByElements()) {
                validateExpression(element.getExpression(), scope, true, context + " ORDER BY");
            }
        }

        if (select.getJoins() != null) {
            for (Join join : select.getJoins()) {
                if (join.getOnExpressions() == null) {
                    continue;
                }
                for (Expression expr : join.getOnExpressions()) {
                    validateExpression(expr, scope, false, context + " JOIN ON");
                }
            }
        }
    }

    /**
     * Validates one FROM / JOIN source.
     * <ul>
     *   <li>Real tables must be whitelisted (or be a visible CTE).</li>
     *   <li>Derived tables are validated recursively.</li>
     *   <li>Every other source kind (table functions, LATERAL, nested joins, ...) is rejected.</li>
     * </ul>
     */
    private void validateFromItem(FromItem item, String context, Set<String> cteNames) {
        if (item == null) {
            return;
        }
        if (item instanceof Table) {
            String tableName = resolveTableName((Table) item);
            if (!isCteName(tableName, cteNames) && !sqlRule.containsTable(tableName)) {
                throw new IllegalArgumentException("表 [" + tableName + "] 不在白名单(位置：" + context + ")");
            }
        } else if (item instanceof SubSelect) {
            validateSubSelect((SubSelect) item, context + " 子查询", cteNames);
        } else {
            throw new IllegalArgumentException("不支持的 FROM 类型 [" + item + "](位置：" + context + ")");
        }
    }

    /**
     * Walks an expression tree and validates every column, function and nested sub-query in it.
     *
     * @param allowFieldAliases whether unqualified names may refer to select-list aliases; only true for
     *                          GROUP BY / HAVING / ORDER BY, the clauses where aliases can be referenced
     */
    private void validateExpression(Expression expr, Scope scope, boolean allowFieldAliases, String context) {
        expr.accept(new ExpressionVisitorAdapter() {
            @Override
            public void visit(Column column) {
                validateColumn(column, scope, allowFieldAliases, context);
            }

            @Override
            public void visit(Function function) {
                validateFunction(function, scope, allowFieldAliases, context);
            }

            @Override
            public void visit(AnalyticExpression analytic) {
                // Window functions are a separate node type from Function; their name needs the same check.
                checkFunctionName(analytic.getName(), context);
                super.visit(analytic);
            }

            @Override
            public void visit(AnyComparisonExpression anyComparison) {
                // Handled explicitly so the ANY/ALL sub-query is always validated,
                // regardless of how the adapter traverses it.
                if (anyComparison.getSubSelect() == null) {
                    throw new IllegalArgumentException("不支持的 ANY/ALL 表达式(位置：" + context + ")");
                }
                validateSubSelect(anyComparison.getSubSelect(), context + " 子查询", scope.cteNames);
            }

            @Override
            public void visit(SubSelect subSelect) {
                validateSubSelect(subSelect, context + " 子查询", scope.cteNames);
            }
        });
    }

    /** Resolves a column to its table and checks both against the whitelist. */
    private void validateColumn(Column column, Scope scope, boolean allowFieldAliases, String context) {
        String columnName = normalizeIdentifier(column.getColumnName());
        String tableAlias = column.getTable() != null ? normalizeIdentifier(column.getTable().getName()) : null;

        if (tableAlias == null) {
            // Boolean literals TRUE / FALSE may be parsed as Column nodes (e.g. "WHERE flag = TRUE").
            if ("FALSE".equalsIgnoreCase(columnName) || "TRUE".equalsIgnoreCase(columnName)) {
                return;
            }
            // Only an unqualified name in an alias-capable clause may be a select-list alias.
            // Elsewhere the same name is a real column reference and must still be whitelisted;
            // otherwise "SELECT secret AS secret" would skip the check entirely.
            if (allowFieldAliases && scope.fieldAliases.contains(columnName)) {
                return;
            }
        }

        String actualTable;
        if (tableAlias != null) {
            actualTable = scope.aliasTableMap.get(tableAlias);
        } else if (scope.aliasTableMap.size() == 1) {
            actualTable = scope.aliasTableMap.values().iterator().next();
        } else if (scope.aliasTableMap.isEmpty()) {
            actualTable = null;
        } else {
            throw new IllegalArgumentException("多表查询时，字段 [" + column + "] 必须指定表别名(位置：" + context + ")");
        }

        // Columns of a derived table: the sub-query's own columns are validated when it is visited.
        if (SUBSELECT.equals(actualTable)) {
            return;
        }

        if (actualTable == null) {
            throw new IllegalArgumentException("字段 [" + column + "] 所属表未知(位置：" + context + ")");
        }

        // CTE columns were validated through the CTE body.
        if (isCteName(actualTable, scope.cteNames)) {
            return;
        }

        if (!sqlRule.containsTable(actualTable)) {
            throw new IllegalArgumentException("字段 [" + column + "] 所属表 [" + actualTable + "] 不在白名单(位置：" + context + ")");
        }

        if (!sqlRule.containsColumn(actualTable, columnName)) {
            throw new IllegalArgumentException("字段 [" + column + "] 不在表 [" + actualTable + "] 的白名单中(位置：" + context + ")");
        }
    }

    /** Checks a function name against the whitelist and validates all of its arguments. */
    private void validateFunction(Function function, Scope scope, boolean allowFieldAliases, String context) {
        String funcName = function.getName();
        checkFunctionName(funcName, context);

        if (function.getParameters() != null) {
            ExpressionList params = function.getParameters();
            if (params.getExpressions() != null) {
                for (Expression param : params.getExpressions()) {
                    validateExpression(param, scope, allowFieldAliases, context + " 函数参数");
                }
            }
        } else if (function.isAllColumns() && !"COUNT".equalsIgnoreCase(funcName)) {
            throw new IllegalArgumentException("函数 [" + funcName + "] 不允许使用 * 参数(位置：" + context + ")");
        }

        // Keyword-style arguments (e.g. SUBSTRING(x FROM 1 FOR 2)) are kept separately from the plain
        // parameter list; they can contain columns too.
        if (function.getNamedParameters() != null && function.getNamedParameters().getExpressions() != null) {
            for (Expression param : function.getNamedParameters().getExpressions()) {
                validateExpression(param, scope, allowFieldAliases, context + " 函数参数");
            }
        }
    }

    /** Rejects a function name that is missing or not in the global function whitelist. */
    private void checkFunctionName(String funcName, String context) {
        if (funcName == null) {
            throw new IllegalArgumentException("无法识别的函数(位置：" + context + ")");
        }
        if (!sqlRule.containsFunction(funcName)) {
            throw new IllegalArgumentException("函数 [" + funcName + "] 不在全局白名单中(位置：" + context + ")");
        }
    }

    /**
     * Collects the aliases declared in the select list (the {@code alias} in {@code SELECT expr AS alias}),
     * so they can be recognized when referenced from GROUP BY / HAVING / ORDER BY.
     * Quote characters are stripped so {@code AS 'cnt'} and {@code AS `cnt`} both yield {@code cnt}.
     */
    private List<String> extractFieldAliases(PlainSelect select) {
        List<String> aliases = new ArrayList<>();
        for (SelectItem item : select.getSelectItems()) {
            if (item instanceof SelectExpressionItem) {
                Alias alias = ((SelectExpressionItem) item).getAlias();
                if (alias != null && alias.getName() != null) {
                    aliases.add(normalizeIdentifier(alias.getName()).replace("'", ""));
                }
            }
        }
        return aliases;
    }

    /**
     * Builds the mapping "alias (or bare table name) -> resolved table name" for the FROM/JOIN sources of
     * one SELECT, so a qualified column can be mapped to the whitelist of its real table.
     * Derived tables map to {@link #SUBSELECT}.
     */
    private Map<String, String> extractAliasTableMapping(PlainSelect select) {
        Map<String, String> aliasMap = new HashMap<>();
        processFromItem(select.getFromItem(), aliasMap);
        if (select.getJoins() != null) {
            for (Join join : select.getJoins()) {
                processFromItem(join.getRightItem(), aliasMap);
            }
        }
        return aliasMap;
    }

    /** Adds one FROM/JOIN source to the alias map; keys are normalized the same way lookups are. */
    private void processFromItem(FromItem item, Map<String, String> aliasMap) {
        if (item instanceof Table) {
            Table table = (Table) item;
            String key = table.getAlias() != null
                    ? normalizeIdentifier(table.getAlias().getName())
                    : normalizeIdentifier(table.getName());
            aliasMap.put(key, resolveTableName(table));
        } else if (item instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) item;
            if (subSelect.getAlias() != null) {
                aliasMap.put(normalizeIdentifier(subSelect.getAlias().getName()), SUBSELECT);
            }
        }
    }

    /**
     * Returns the name used for whitelist lookups.
     * <p>
     * The schema/database qualifier is kept: {@code otherdb.orders} is a different table from
     * {@code orders} and must not pass under the unqualified whitelist entry.
     */
    private static String resolveTableName(Table table) {
        return normalizeIdentifier(table.getFullyQualifiedName());
    }

    /** Whether {@code tableName} refers to a CTE visible in the current scope (names stored lower-cased). */
    private static boolean isCteName(String tableName, Set<String> cteNames) {
        return cteNames.contains(tableName.toLowerCase(Locale.ROOT));
    }

    /**
     * Removes identifier quoting, e.g. {@code `order`} or {@code "order"} becomes {@code order}.
     *
     * @param name identifier as written in the SQL, may be {@code null}
     * @return the unquoted identifier, or {@code null} if {@code name} was {@code null}
     */
    private static String normalizeIdentifier(String name) {
        return name != null ? name.replace("`", "").replace("\"", "") : null;
    }

    /** Name-resolution data for a single SELECT block. */
    private static final class Scope {
        /** Table alias (or bare table name) -> resolved table name, or SUBSELECT for derived tables. */
        private final Map<String, String> aliasTableMap;
        /** Aliases declared in the select list. */
        private final List<String> fieldAliases;
        /** Lower-cased names of the CTEs visible to this SELECT block. */
        private final Set<String> cteNames;

        private Scope(Map<String, String> aliasTableMap, List<String> fieldAliases, Set<String> cteNames) {
            this.aliasTableMap = aliasTableMap;
            this.fieldAliases = fieldAliases;
            this.cteNames = cteNames;
        }
    }
}
