package com.digiwin.dap.middle.database.encrypt.utils;

import com.digiwin.dap.middle.database.encrypt.model.ParameterMappingRelation;
import com.digiwin.dap.middle.database.encrypt.model.ResultSetMappingRelation;
import com.digiwin.dap.middle.database.encrypt.model.SqlParseResult;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.statement.update.UpdateSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

/**
 * @author michael
 */
public class SqlParserUtils {

    private final static Logger LOGGER = LoggerFactory.getLogger(SqlParserUtils.class);

    public static void main(String[] args) {
        String[] sqlStatements = {
                "UPDATE table_name AS t SET t.column1 = 'new_value', t.column2 = 'another_value' WHERE t.column3 = 'value'",
                "UPDATE table_name SET column1 = 'new_value' WHERE column2 = 'value2'"
        };

        for (String sqlStatement : sqlStatements) {
            parseSql(sqlStatement);
        }
    }

    /**
     * 解析sql,主要实现功能如下:
     * 1.获取Hibernate和MyBatis中select语句结果集,方便对结果集中涉敏字段进行解密 TODO select * 特殊处理
     * 2.针对Hibernate中insert、update、delete语句涉敏参数加密;MyBatis通过TypeHandler实现涉敏字段参数加密
     * 3.TODO 针对MyBatis中select语句返回类型为ResultMap的额外处理
     *
     * @param sql
     * @return
     * @return java.util.List<com.digiwin.dap.middle.database.encrypt.utils.SqlParserUtils.ColumnInfo>
     * @author michael
     * @date 2024/6/13 09:24
     **/
    public static SqlParseResult parseSql(String sql) {
        SqlParseResult sqlParseResult = new SqlParseResult();
        try {
            Statement statement = CCJSqlParserUtil.parse(sql);
            // 1.获取sql中出现的表名以及其别名的映射关系
            List<TableInfo> tableInfoList = parseTableFromStatement(statement);
            // 2.获取sql中where条件的字段和表或者表别名的映射关系
            List<ParameterMappingRelation> whereConditionInfos = parseWhereFromStatement(statement, tableInfoList);
            // 3.select语句获取结果集表名、表别名、字段名、字段别名的映射关系注意
            List<ResultSetMappingRelation> resultSetInfos = statement instanceof Select ? parseResultSetFromStatement((Select) statement, tableInfoList) : null;
            // 4.insert和update语句的操作列
            List<ParameterMappingRelation> operationColumnInfos = statement instanceof Update ||
                    statement instanceof Insert ? parseOperationColumnFromStatement(statement, tableInfoList) : null;
            sqlParseResult.setResultSetInfos(resultSetInfos);
            sqlParseResult.setOperationColumnInfos(operationColumnInfos);
            sqlParseResult.setWhereConditionInfos(whereConditionInfos);
        } catch (Exception e) {
            LOGGER.error("SQL解析异常", e);
        }
        return sqlParseResult;
    }

    private static List<ResultSetMappingRelation> parseResultSetFromStatement(Select select, List<TableInfo> tableInfoList) {
        List<ResultSetMappingRelation> resultSetMappingRelations = new ArrayList<>();
        SelectBody selectBody = select.getSelectBody();
        selectBody.accept(new SelectVisitorAdapter() {
            @Override
            public void visit(PlainSelect plainSelect) {
                List<SelectItem> selectItems = plainSelect.getSelectItems();
                for (SelectItem selectItem : selectItems) {
                    if (selectItem instanceof SelectExpressionItem) {
                        SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem;
                        // TODO 判断列上是否包含函数,暂不支持函数解析
                        if (expressionItem.getExpression() instanceof Function) {
                            continue;
                        }
//                                Function function = (Function) expressionItem.getExpression();
//                                String functionName = function.getName();
//                                LOGGER.info("===>涉及表字段包含函数:{}", functionName);
//
//                                // 如果需要，可以进一步处理函数的参数等信息
//                                for (Expression expression : function.getParameters().getExpressions()) {
//                                    if (expression instanceof Function) {
//                                        Function eestedFunction = (Function) expression;
//                                        List<Expression> parameters = eestedFunction.getParameters().getExpressions();
//                                        for (Expression parameter : parameters) {
//                                            LOGGER.info("===>涉及表字段包含嵌套函数,函数参数:{}", parameter.toString());
//                                        }
//                                    }else {
//                                        LOGGER.info("===>涉及表字段包含函数,函数参数:{}", expression.toString());
//                                    }
//                                }

                        String expressionString = expressionItem.getExpression().toString();
                        String aliasTableName = expressionString.contains(".") ? expressionString.substring(0, expressionString.indexOf(".")) : "";
                        String columnName = expressionString.contains(".") ? expressionString.substring(expressionString.indexOf(".") + 1) : expressionString;
                        String aliasColumnName = expressionItem.getAlias() != null ? expressionItem.getAlias().getName() : columnName;
                        String tableName = StringUtils.hasLength(aliasTableName) ? tableInfoList.stream().filter(x -> Objects.equals(x.getAliasTableName(), aliasTableName)).findFirst().get().getTableName().toLowerCase() :
                                tableInfoList.get(0).getTableName().toLowerCase();
                        resultSetMappingRelations.add(new ResultSetMappingRelation(tableName, columnName, aliasColumnName));
                        LOGGER.info("===>数据库【{}】表,字段【{}】,映射到对象中属性名为【{}】", tableName, columnName, aliasColumnName);
                    }
                }
            }
        });
        return resultSetMappingRelations;
    }

    private static List<ParameterMappingRelation> parseWhereFromStatement(Statement statement, List<TableInfo> tableInfoList) {
        List<ParameterMappingRelation> whereConditionMappingRelations = new ArrayList<>();
        Expression whereExpression = null;
        if (statement instanceof Select) {
            Select select = (Select) statement;
            SelectBody selectBody = select.getSelectBody();
            if (selectBody instanceof PlainSelect) {
                PlainSelect plainSelect = (PlainSelect) selectBody;
                whereExpression = plainSelect.getWhere();
            }
        } else if (statement instanceof Update) {
            Update updateStatement = (Update) statement;

            whereExpression = updateStatement.getWhere();
        } else if (statement instanceof Delete) {
            Delete deleteStatement = (Delete) statement;
            deleteStatement.getTables();
            whereExpression = deleteStatement.getWhere();
        } else if (statement instanceof Insert) {
            Insert insertStatement = (Insert) statement;
            insertStatement.getTable();
        }

        if (whereExpression != null) {
            whereExpression.accept(new ExpressionVisitorAdapter() {
                @Override
                protected void visitBinaryExpression(net.sf.jsqlparser.expression.BinaryExpression expr) {
                    if (expr.getLeftExpression() instanceof Column) {
                        Column column = (Column) expr.getLeftExpression();
                        String columnName = column.getColumnName();

                        String aliasTableName = columnName.contains(".") ? columnName.substring(0, columnName.indexOf(".")) : "";
                        String realColumnName = columnName.contains(".") ? columnName.substring(columnName.indexOf(".") + 1) : columnName;
                        String tableName = StringUtils.hasLength(aliasTableName) ? tableInfoList.stream().filter(x -> Objects.equals(x.getAliasTableName(), aliasTableName)).findFirst().get().getTableName().toLowerCase() :
                                tableInfoList.get(0).getTableName().toLowerCase();
                        whereConditionMappingRelations.add(new ParameterMappingRelation(tableName, realColumnName));
                    }
                    if (expr.getRightExpression() instanceof Column) {
                        Column column = (Column) expr.getRightExpression();
                        String columnName = column.getColumnName();
                        String aliasTableName = columnName.contains(".") ? columnName.substring(0, columnName.indexOf(".")) : "";
                        String realColumnName = columnName.contains(".") ? columnName.substring(columnName.indexOf(".") + 1) : columnName;
                        String tableName = StringUtils.hasLength(aliasTableName) ? tableInfoList.stream().filter(x -> Objects.equals(x.getAliasTableName(), aliasTableName)).findFirst().get().getTableName().toLowerCase() :
                                tableInfoList.get(0).getTableName().toLowerCase();
                        whereConditionMappingRelations.add(new ParameterMappingRelation(tableName, realColumnName));
                    }
                    super.visitBinaryExpression(expr);
                }
            });
        }
        return whereConditionMappingRelations;
    }

    private static List<ParameterMappingRelation> parseOperationColumnFromStatement(Statement statement, List<TableInfo> tableInfoList) {
        if (!(statement instanceof Update || statement instanceof Insert)) {
            return Collections.emptyList();
        }
        List<ParameterMappingRelation> operationColumnMappingRelations = new ArrayList<>();

        if (statement instanceof Update) {
            Update update = (Update) statement;
            List<UpdateSet> updateSetList = update.getUpdateSets();
            for (UpdateSet updateSet : updateSetList) {
                for (Column column : updateSet.getColumns()) {
                    String columnName = column.getColumnName();
                    String aliasTableName = columnName.contains(".") ? columnName.substring(0, columnName.indexOf(".")) : "";
                    String realColumnName = columnName.contains(".") ? columnName.substring(columnName.indexOf(".") + 1) : columnName;
                    String tableName = StringUtils.hasLength(aliasTableName) ? tableInfoList.stream().filter(x -> Objects.equals(x.getAliasTableName(), aliasTableName)).findFirst().get().getTableName().toLowerCase() :
                            tableInfoList.get(0).getTableName().toLowerCase();
                    operationColumnMappingRelations.add(new ParameterMappingRelation(tableName, realColumnName));
                }
            }
        }
        if (statement instanceof Insert) {
            Insert insert = new Insert();
            for (Column column : insert.getColumns()) {
                String columnName = column.getColumnName();
                String aliasTableName = columnName.contains(".") ? columnName.substring(0, columnName.indexOf(".")) : "";
                String realColumnName = columnName.contains(".") ? columnName.substring(columnName.indexOf(".") + 1) : columnName;
                String tableName = StringUtils.hasLength(aliasTableName) ? tableInfoList.stream().filter(x -> Objects.equals(x.getAliasTableName(), aliasTableName)).findFirst().get().getTableName().toLowerCase() :
                        tableInfoList.get(0).getTableName().toLowerCase();
                operationColumnMappingRelations.add(new ParameterMappingRelation(tableName, realColumnName));
            }
        }
        return operationColumnMappingRelations;
    }

    private static List<TableInfo> parseTableFromStatement(Statement statement) {
        List<TableInfo> tableInfoList = new ArrayList<>();
        if (statement instanceof Select) {
            Select selectStatement = (Select) statement;
            SelectBody selectBody = selectStatement.getSelectBody();
            selectBody.accept(new SelectVisitorAdapter() {
                @Override
                public void visit(PlainSelect plainSelect) {
                    FromItem fromItem = plainSelect.getFromItem();
                    if (fromItem instanceof Table) {
                        Table table = (Table) fromItem;
                        String tableName = table.getName();
                        String[] parts = tableName.split("\\s+");
                        tableName = parts[0];
                        String aliasTableName = table.getAlias() != null ? table.getAlias().getName() : null;
                        tableInfoList.add(new TableInfo(tableName, aliasTableName));
                    }

                    if (plainSelect.getJoins() != null) {
                        for (Join join : plainSelect.getJoins()) {
                            FromItem joinItem = join.getRightItem();
                            if (joinItem instanceof Table) {
                                Table table = (Table) joinItem;
                                String tableName = table.getName();
                                String[] parts = tableName.split("\\s+");
                                tableName = parts[0];
                                String aliasTableName = table.getAlias() != null ? table.getAlias().getName() : null;
                                tableInfoList.add(new TableInfo(tableName, aliasTableName));
                            }
                        }
                    }
                }
            });
        } else if (statement instanceof Update) {
            Update updateStatement = (Update) statement;
            Table table = updateStatement.getTable();
            tableInfoList.add(new TableInfo(table.getName(), table.getAlias().getName()));
            if (updateStatement.getJoins() != null) {
                for (Join join : updateStatement.getJoins()) {
                    FromItem fromItem = join.getRightItem();
                    if (fromItem instanceof Table) {
                        Table table1 = (Table) fromItem;
                        tableInfoList.add(new TableInfo(table1.getName(), table1.getAlias().getName()));
                    }
                }
            }
        } else if (statement instanceof Delete) {
            Delete deleteStatement = (Delete) statement;
            List<Table> tableList = deleteStatement.getTables();
            for (Table table : tableList) {
                tableInfoList.add(new TableInfo(table.getName(), table.getAlias().getName()));
            }
            if (deleteStatement.getJoins() != null) {
                for (Join join : deleteStatement.getJoins()) {
                    FromItem fromItem = join.getRightItem();
                    if (fromItem instanceof Table) {
                        Table table1 = (Table) fromItem;
                        tableInfoList.add(new TableInfo(table1.getName(), table1.getAlias().getName()));
                    }
                }
            }
        } else if (statement instanceof Insert) {
            Insert insertStatement = (Insert) statement;
            Table table = insertStatement.getTable();
            tableInfoList.add(new TableInfo(table.getName(), table.getAlias().getName()));
        }
        return tableInfoList;
    }

    /**
     * SQL中表名与表别名的映射关系
     *
     * @author michael
     * @date 2024/6/6 13:28
     * @see
     **/
    static class TableInfo {
        private String tableName;

        private String aliasTableName;

        public TableInfo(String tableName, String aliasTableName) {
            this.tableName = tableName;
            this.aliasTableName = aliasTableName;
        }

        public String getTableName() {
            return tableName;
        }

        public String getAliasTableName() {
            return aliasTableName;
        }
    }
}
