package com.digiwin.athena.knowledgegraph.utils;

/**
 * @title: SQLIndexSuggestionUtils
 * @author: linc
 * @date 2024/4/17 9:26
 * @version: 1.0
 */
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.operators.relational.Between;
import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
import net.sf.jsqlparser.parser.SimpleNode;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.select.*;
import org.apache.commons.collections.CollectionUtils;

import java.util.*;

/**
 * Heuristic index-suggestion helper built on top of JSqlParser.
 *
 * <p>{@link #extractTablesWithAlias(SelectBody, Map, Set)} walks a parsed SELECT and records
 * table aliases (alias to table name) plus the columns referenced by JOIN and WHERE conditions;
 * {@link #getOptimizationSuggestions(List, Map, List)} then groups those columns by table.
 *
 * <p>Qualified columns are recorded as {@code "<qualifier>%<column>"}, where the qualifier is the
 * table name or alias as written in the SQL; unqualified columns are recorded as the bare column
 * name.
 */
public class SQLIndexSuggestionUtils {

    /** Separator between the table qualifier and the column name in recorded column keys. */
    private static final String COLUMN_SEPARATOR = "%";

    /**
     * Collects table aliases referenced by a single SELECT-list item. Only
     * {@link SelectExpressionItem}s are inspected; other item kinds are ignored.
     */
    private static void extractTablesWithAlias(SelectItem selectItem, Map<String, String> tableAliasMap, Set<String> joinColumnsSet) {
        if (selectItem instanceof SelectExpressionItem) {
            SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem;
            extractTablesWithAliasFromItem(expressionItem.getExpression(), tableAliasMap, joinColumnsSet);
        }
    }

    /**
     * Recursively walks an expression (sub-selects, binary operators, function arguments) and
     * records {@code alias -> table} pairs for columns whose table carries both a name and an alias.
     */
    private static void extractTablesWithAliasFromItem(Expression expression, Map<String, String> tableAliasMap, Set<String> joinColumnsSet) {
        if (expression instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) expression;
            extractTablesWithAlias(subSelect.getSelectBody(), tableAliasMap, joinColumnsSet);
        } else if (expression instanceof BinaryExpression) {
            BinaryExpression binaryExpression = (BinaryExpression) expression;
            extractTablesWithAliasFromItem(binaryExpression.getLeftExpression(), tableAliasMap, joinColumnsSet);
            extractTablesWithAliasFromItem(binaryExpression.getRightExpression(), tableAliasMap, joinColumnsSet);
        } else if (expression instanceof Function) {
            Function function = (Function) expression;
            // Functions such as COUNT(*) may carry no parameter list at all; iterating it
            // unguarded would throw a NullPointerException.
            if (function.getParameters() != null && function.getParameters().getExpressions() != null) {
                for (Expression param : function.getParameters().getExpressions()) {
                    extractTablesWithAliasFromItem(param, tableAliasMap, joinColumnsSet);
                }
            }
        } else if (expression instanceof Column) {
            Column column = (Column) expression;
            String tableName = column.getTable() != null ? column.getTable().getName() : "";
            String tableAlias = column.getTable() != null && column.getTable().getAlias() != null ? column.getTable().getAlias().getName() : "";
            if (!tableName.isEmpty() && !tableAlias.isEmpty()) {
                tableAliasMap.put(tableAlias, tableName);
            }
        }
    }

    /**
     * Records the table behind a FROM/JOIN item, or descends into it when it is a sub-select.
     * A table without an alias is stored under the empty-string key (so several un-aliased
     * tables overwrite each other's entry).
     */
    private static void extractTablesWithAlias(FromItem fromItem, Map<String, String> tableAliasMap, Set<String> joinColumnsSet) {
        if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            extractTablesWithAlias(subSelect.getSelectBody(), tableAliasMap, joinColumnsSet);
        } else if (fromItem instanceof Join) {
            // NOTE(review): Join does not appear to implement FromItem in JSqlParser, so this
            // branch may be unreachable — confirm before relying on it.
            Join join = (Join) fromItem;
            extractTablesWithAliasFromItem(join.getOnExpression(), tableAliasMap, joinColumnsSet);
        } else if (fromItem instanceof Table) {
            Table table = (Table) fromItem;
            tableAliasMap.put(table.getAlias() != null ? table.getAlias().getName() : "", table.getName());
        }
    }

    /**
     * Walks a SELECT body and fills the two output collections. Each branch of a set operation
     * ({@link SetOperationList}) is processed in turn; other body kinds are ignored.
     *
     * @param selectBody     the parsed SELECT body
     * @param tableAliasMap  output map from alias to table name
     * @param joinColumnsSet output set of column keys (see the class documentation for the format)
     *                       referenced by JOIN and WHERE conditions
     */
    public static void extractTablesWithAlias(SelectBody selectBody, Map<String, String> tableAliasMap, Set<String> joinColumnsSet) {
        if (selectBody instanceof PlainSelect) {
            PlainSelect plainSelect = (PlainSelect) selectBody;
            if (plainSelect.getFromItem() != null) {
                extractTablesWithAlias(plainSelect.getFromItem(), tableAliasMap, joinColumnsSet);
            }
            if (plainSelect.getJoins() != null) {
                for (Join join : plainSelect.getJoins()) {
                    extractTablesWithAlias(join.getRightItem(), tableAliasMap, joinColumnsSet);
                    extractJoinTableAndColumn(join, tableAliasMap, joinColumnsSet);
                }
            }
            if (plainSelect.getSelectItems() != null) {
                for (SelectItem selectItem : plainSelect.getSelectItems()) {
                    extractTablesWithAlias(selectItem, tableAliasMap, joinColumnsSet);
                }
            }

            // Collect the columns used in the WHERE clause.
            extractWhereTableAndColumn(plainSelect, tableAliasMap, joinColumnsSet);

        } else if (selectBody instanceof SetOperationList) {
            SetOperationList setOperationList = (SetOperationList) selectBody;
            for (SelectBody select : setOperationList.getSelects()) {
                extractTablesWithAlias(select, tableAliasMap, joinColumnsSet);
            }
        }
    }

    /**
     * Builds the recorded key for a column: {@code "<qualifier>%<column>"}, or just the column
     * name when there is no qualifier.
     */
    private static String columnKey(String qualifier, String columnName) {
        return qualifier == null || qualifier.isEmpty() ? columnName : qualifier + COLUMN_SEPARATOR + columnName;
    }

    /**
     * Records the columns a JOIN uses: its USING columns (qualified with the joined table's name,
     * or with its alias when the joined item is not a plain table) and every column found in its
     * ON conditions.
     */
    private static void extractJoinTableAndColumn(Join join, Map<String, String> tableAliasMap, Set<String> joinColumnsSet) {
        FromItem joinItem = join.getRightItem();
        String joinTableQualifier = joinItem.getAlias() != null ? joinItem.getAlias().getName() : "";
        if (joinItem instanceof Table) {
            joinTableQualifier = ((Table) joinItem).getName();
        }

        List<Column> usingColumns = join.getUsingColumns();
        if (CollectionUtils.isNotEmpty(usingColumns)) {
            for (Column usingColumn : usingColumns) {
                // Same key format as every other recorded column; the former "." separator
                // could not be split back into table and column by getOptimizationSuggestions.
                joinColumnsSet.add(columnKey(joinTableQualifier, usingColumn.getColumnName()));
            }
        }

        // Columns referenced by the ON conditions.
        Collection<Expression> onExpressions = join.getOnExpressions();
        if (CollectionUtils.isNotEmpty(onExpressions)) {
            for (Expression onExpression : onExpressions) {
                processJoinExpression(onExpression, tableAliasMap, joinColumnsSet);
            }
        }
    }

    /** Records the columns referenced by the WHERE clause of {@code plainSelect}, if it has one. */
    private static void extractWhereTableAndColumn(PlainSelect plainSelect, Map<String, String> tableAliasMap, Set<String> joinColumnsSet) {
        Expression whereExpression = plainSelect.getWhere();
        if (whereExpression != null) {
            processWhereExpression(whereExpression, tableAliasMap, joinColumnsSet);
        }
    }

    /**
     * Records the columns of a JOIN ... ON condition. ON conditions are analysed exactly like
     * WHERE conditions; unsupported expression shapes are ignored instead of failing on an
     * unconditional cast to {@link Parenthesis}.
     */
    private static void processJoinExpression(Expression expression, Map<String, String> tableAliasMap, Set<String> tableColumnsList) {
        processWhereExpression(expression, tableAliasMap, tableColumnsList);
    }

    /**
     * Records the columns on both sides of a binary expression (comparison, AND/OR, arithmetic,
     * LIKE, ...), recursing into nested expressions.
     */
    private static void processBinaryExpression(BinaryExpression binaryExpression, Map<String, String> tableAliasMap, Set<String> tableColumnsList) {
        Expression leftExpression = binaryExpression.getLeftExpression();
        Expression rightExpression = binaryExpression.getRightExpression();

        // A LIKE whose literal pattern starts with '%' is not suggested as an index column.
        // Detected via instanceof rather than getASTNode(), which depends on parser AST linking;
        // the StringValue check avoids a ClassCastException for e.g. "LIKE ?".
        if (binaryExpression instanceof LikeExpression
                && rightExpression instanceof StringValue
                && ((StringValue) rightExpression).getValue().startsWith("%")) {
            return;
        }

        // processSingleExpression handles nested binary expressions as well as plain operands.
        processSingleExpression(leftExpression, tableAliasMap, tableColumnsList);
        processSingleExpression(rightExpression, tableAliasMap, tableColumnsList);
    }

    /** Records the tested operand of a BETWEEN; the range bounds are ignored. */
    private static void processBetweenExpression(Between betweenExpression, Map<String, String> tableAliasMap, Set<String> tableColumnsList) {
        processSingleExpression(betweenExpression.getLeftExpression(), tableAliasMap, tableColumnsList);
    }

    /**
     * Records the columns of a WHERE/ON condition. Binary expressions, parenthesised conditions
     * and BETWEEN are supported; other expression types are currently ignored.
     */
    private static void processWhereExpression(Expression expression, Map<String, String> tableAliasMap, Set<String> tableColumnsList) {
        if (expression instanceof BinaryExpression) {
            processBinaryExpression((BinaryExpression) expression, tableAliasMap, tableColumnsList);
        } else if (expression instanceof Parenthesis) {
            processWhereExpression(((Parenthesis) expression).getExpression(), tableAliasMap, tableColumnsList);
        } else if (expression instanceof Between) {
            processBetweenExpression((Between) expression, tableAliasMap, tableColumnsList);
        }
        // Other expression types (other operators or predicates) are not analysed yet.
    }

    /**
     * Records a single operand of a condition: a column is added as a key (see the class
     * documentation), nested conditions and parentheses are descended into, and function calls
     * are skipped.
     */
    private static void processSingleExpression(Expression expression, Map<String, String> tableAliasMap, Set<String> tableColumnsList) {
        if (expression instanceof Column) {
            Column column = (Column) expression;
            Table table = column.getTable();
            String qualifier = table != null ? table.getName() : "";
            tableColumnsList.add(columnKey(qualifier, column.getColumnName()));
        } else if (expression instanceof BinaryExpression) {
            // A nested condition, e.g. the "(b = 2 OR c = 3)" operand of an AND once its
            // parentheses are unwrapped; previously such groups were silently dropped.
            processBinaryExpression((BinaryExpression) expression, tableAliasMap, tableColumnsList);
        } else if (expression instanceof Function) {
            // Columns wrapped in a function call are not suggested for indexing.
            // NOTE(review): the original author observed that
            // "(`oga_file`.OGA02) BETWEEN '2024-01-01' AND '2024-03-31'" reached this branch,
            // which looked wrong; confirm how that SQL is parsed before changing this.
        } else if (expression instanceof Between) {
            processBetweenExpression((Between) expression, tableAliasMap, tableColumnsList);
        } else if (expression instanceof Parenthesis) {
            // A column (or condition) wrapped in parentheses.
            processSingleExpression(((Parenthesis) expression).getExpression(), tableAliasMap, tableColumnsList);
        }
    }

    /**
     * Groups the recorded column keys by table and renders one suggestion per table as
     * {@code "<table>.<col1>,<col2>,..."}; tables that receive no column yield no suggestion.
     *
     * <p>A qualified key ({@code "<qualifier>%<column>"}) is attributed to every table whose name
     * contains the qualifier resolved through {@code tableAliasMap} (the qualifier itself when it
     * is not a known alias). An unqualified key is attributed only when {@code tablesList} holds
     * exactly one table.
     *
     * @param tablesList      tables to build suggestions for
     * @param tableAliasMap   alias to table name, as filled by extractTablesWithAlias; may be null
     * @param joinColumnsList recorded column keys
     * @return one suggestion per table that received at least one column; never null
     */
    public static List<String> getOptimizationSuggestions(List<String> tablesList, Map<String, String> tableAliasMap, List<String> joinColumnsList) {
        List<String> suggestions = new ArrayList<>();
        if (CollectionUtils.isEmpty(joinColumnsList) || CollectionUtils.isEmpty(tablesList)) {
            return suggestions;
        }
        Map<String, String> aliases = tableAliasMap != null ? tableAliasMap : Collections.<String, String>emptyMap();

        for (String table : tablesList) {
            List<String> columns = new ArrayList<>();
            for (String key : joinColumnsList) {
                String[] parts = key.split(COLUMN_SEPARATOR);
                if (parts.length > 1) {
                    String tableName = aliases.getOrDefault(parts[0], parts[0]);
                    // NOTE(review): substring match, so "order" also matches "order_item";
                    // kept because the format of tablesList entries is not visible here.
                    if (table.contains(tableName)) {
                        columns.add(parts[1]);
                    }
                } else if (parts.length == 1 && tablesList.size() == 1) {
                    columns.add(parts[0]);
                }
            }
            if (!columns.isEmpty()) {
                suggestions.add(table + "." + String.join(",", columns));
            }
        }
        return suggestions;
    }
}
