package com.digiwin.dap.middle.database.encrypt.utils;

import com.digiwin.dap.middle.database.encrypt.enums.DatabaseEncryptExceptionEnum;
import com.digiwin.dap.middle.database.encrypt.exception.DatabaseEncryptException;
import com.digiwin.dap.middle.database.encrypt.model.ObjectRelationalMapping;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.CaseExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Objects;

/**
 * SQL parsing helpers built on JSqlParser.
 *
 * <p>The main entry point is {@link #parseQuerySql(String)}, which maps the plain columns of a
 * {@code SELECT} result set back to the table they come from, together with the name under which
 * each column appears in the result set.
 *
 * @author michael
 */
public class SqlParserUtils {

    private static final Logger LOGGER = LoggerFactory.getLogger(SqlParserUtils.class);

    /**
     * MySQL identifier quote character; it is stripped from table names, aliases and column names.
     */
    private static final String MYSQL_DIALECT = "`";

    /**
     * Parses a MyBatis {@code SELECT} statement whose result is mapped via {@code resultType}.
     *
     * <p>For every plain column in the select list this returns the (lower-cased) table name, the
     * column name and the name the column is exposed under (its alias, or the column name itself).
     * Function columns, {@code CASE} expressions, non-expression items such as {@code *}, and
     * columns whose table cannot be resolved (e.g. columns of a derived table) are skipped with a
     * debug log. Only top-level plain selects are inspected.
     *
     * <p>Known limitation: JSqlParser 4.5 treats {@code include} as a keyword, so SQL containing it
     * fails to parse. A newer JSqlParser fixes this, but the current PageHelper version pulls in 4.5.
     *
     * @param sql the SQL text to parse
     * @return the result-column mappings; an empty list if the statement is not a {@code SELECT} or
     *     cannot be parsed
     * @throws DatabaseEncryptException if the SQL parses but extracting the result columns fails
     */
    public static List<ObjectRelationalMapping> parseQuerySql(String sql) {
        List<ObjectRelationalMapping> objectRelationalMappings = new ArrayList<>();
        try {
            Statement statement = CCJSqlParserUtil.parse(sql);
            if (!(statement instanceof Select)) {
                LOGGER.debug("====>sql解析,暂不支持解析非select语句");
                return objectRelationalMappings;
            }
            // 1. Collect the tables referenced by the statement and their aliases.
            List<TableInfo> tableInfoList = parseTableFromStatement(statement);
            // 2. Map each result column to table name, column name and result-set name.
            objectRelationalMappings = parseResultSetFromStatement((Select) statement, tableInfoList);
        } catch (JSQLParserException e) {
            // Unparseable SQL is tolerated: the statement simply yields no mappings.
            LOGGER.error("====>解析sql异常:【{}】", sql, e);
        } catch (Exception e) {
            // Pass the throwable to the logger so the root cause is not lost.
            LOGGER.error("====>解析sql获取结果列异常,异常sql:【{}】", sql, e);
            throw new DatabaseEncryptException(DatabaseEncryptExceptionEnum.PARSE_SQL, sql);
        }
        return objectRelationalMappings;
    }

    /**
     * Walks the select list of a plain {@code SELECT} and builds one mapping per plain column.
     *
     * @param select        the parsed select statement
     * @param tableInfoList the tables (and aliases) referenced by the statement
     * @return the mappings found; never {@code null}
     */
    private static List<ObjectRelationalMapping> parseResultSetFromStatement(Select select, List<TableInfo> tableInfoList) {
        List<ObjectRelationalMapping> objectRelationalMappings = new ArrayList<>();
        select.getSelectBody().accept(new SelectVisitorAdapter() {
            @Override
            public void visit(PlainSelect plainSelect) {
                for (SelectItem selectItem : plainSelect.getSelectItems()) {
                    if (!(selectItem instanceof SelectExpressionItem)) {
                        // e.g. "*" or "t.*": the concrete columns are not known from the SQL alone.
                        continue;
                    }
                    SelectExpressionItem expressionItem = (SelectExpressionItem) selectItem;
                    Expression expression = expressionItem.getExpression();
                    if (expression instanceof Function) {
                        // Function results are not decrypted automatically.
                        LOGGER.debug("====>函数结果列【{}】暂不支持解析", expression);
                    } else if (expression instanceof CaseExpression) {
                        // CASE WHEN can map several database columns onto the same property, so it
                        // cannot be decided whether decryption is required.
                        LOGGER.debug("====>case结果列【{}】暂不支持解析", expression);
                    } else if (expression instanceof Column) {
                        addColumnMapping((Column) expression, expressionItem, tableInfoList, objectRelationalMappings);
                    } else {
                        LOGGER.debug("====>未知结果列【{}】暂不支持解析", expression);
                    }
                }
            }
        });
        return objectRelationalMappings;
    }

    /**
     * Resolves the owning table of {@code column} and, if found, appends its mapping to {@code out}.
     * Columns whose table cannot be resolved are skipped with a debug log instead of failing the
     * whole statement.
     */
    private static void addColumnMapping(Column column, SelectExpressionItem expressionItem,
                                         List<TableInfo> tableInfoList, List<ObjectRelationalMapping> out) {
        // Use the parser's own split so "schema.table.col" and `t`.`col` are handled correctly.
        Table qualifier = column.getTable();
        String aliasTableName = qualifier != null && qualifier.getName() != null
                ? stripMysqlQuotes(qualifier.getName()) : "";
        String tableName = resolveTableName(aliasTableName, tableInfoList);
        if (tableName == null) {
            LOGGER.debug("====>结果列【{}】无法确定所属表,暂不支持解析", column);
            return;
        }
        String columnName = stripMysqlQuotes(column.getColumnName());
        String aliasColumnName = expressionItem.getAlias() != null
                ? stripMysqlQuotes(expressionItem.getAlias().getName()) : columnName;
        out.add(new ObjectRelationalMapping(tableName, columnName, aliasColumnName));
        LOGGER.debug("====>数据库【{}】表,字段【{}】,映射到对象中属性名为【{}】", tableName, columnName, aliasColumnName);
    }

    /**
     * Finds the lower-cased table name for a column qualifier.
     *
     * @param aliasTableName the column's table qualifier with quotes removed; empty if unqualified
     * @param tableInfoList  the tables referenced by the statement
     * @return the table name, or {@code null} if it cannot be determined
     */
    private static String resolveTableName(String aliasTableName, List<TableInfo> tableInfoList) {
        if (StringUtils.hasLength(aliasTableName)) {
            return tableInfoList.stream()
                    .filter(x -> Objects.equals(x.getAliasTableName(), aliasTableName))
                    .map(x -> x.getTableName().toLowerCase(Locale.ROOT))
                    .findFirst()
                    .orElse(null);
        }
        // Unqualified column: attribute it to the first table in FROM. This is a heuristic, and an
        // empty list (e.g. FROM is a subquery) means the table is unknown.
        return tableInfoList.isEmpty() ? null : tableInfoList.get(0).getTableName().toLowerCase(Locale.ROOT);
    }

    /**
     * Collects the physical tables referenced by a statement together with their aliases. Derived
     * tables (subqueries) are not collected.
     *
     * @param statement a parsed SELECT, UPDATE, DELETE or INSERT statement
     * @return the referenced tables in FROM/JOIN order; empty for other statement types
     */
    private static List<TableInfo> parseTableFromStatement(Statement statement) {
        List<TableInfo> tableInfoList = new ArrayList<>();
        if (statement instanceof Select) {
            ((Select) statement).getSelectBody().accept(new SelectVisitorAdapter() {
                @Override
                public void visit(PlainSelect plainSelect) {
                    addTable(tableInfoList, plainSelect.getFromItem());
                    addJoinTables(tableInfoList, plainSelect.getJoins());
                }
            });
        } else if (statement instanceof Update) {
            Update updateStatement = (Update) statement;
            addTable(tableInfoList, updateStatement.getTable());
            addJoinTables(tableInfoList, updateStatement.getJoins());
        } else if (statement instanceof Delete) {
            Delete deleteStatement = (Delete) statement;
            // The primary (FROM) table first, so that its real name wins over bare alias references
            // listed in a multi-table "DELETE t1 FROM ..." target list.
            addTable(tableInfoList, deleteStatement.getTable());
            if (deleteStatement.getTables() != null) {
                for (Table table : deleteStatement.getTables()) {
                    addTable(tableInfoList, table);
                }
            }
            addJoinTables(tableInfoList, deleteStatement.getJoins());
        } else if (statement instanceof Insert) {
            addTable(tableInfoList, ((Insert) statement).getTable());
        }
        return tableInfoList;
    }

    /**
     * Adds the right-hand item of every join that is a physical table; {@code joins} may be null.
     */
    private static void addJoinTables(List<TableInfo> tableInfoList, List<Join> joins) {
        if (joins == null) {
            return;
        }
        for (Join join : joins) {
            addTable(tableInfoList, join.getRightItem());
        }
    }

    /**
     * Adds {@code fromItem} if it is a physical table; null and subqueries are ignored.
     */
    private static void addTable(List<TableInfo> tableInfoList, FromItem fromItem) {
        if (fromItem instanceof Table) {
            tableInfoList.add(toTableInfo((Table) fromItem));
        }
    }

    /**
     * Builds a {@link TableInfo} with quotes removed; a table without an alias is its own alias.
     */
    private static TableInfo toTableInfo(Table table) {
        // Keep only the first whitespace-separated token of the name, as a defensive measure.
        String tableName = stripMysqlQuotes(table.getName().split("\\s+")[0]);
        String aliasTableName = table.getAlias() != null ? stripMysqlQuotes(table.getAlias().getName()) : tableName;
        return new TableInfo(tableName, aliasTableName);
    }

    /**
     * Removes every MySQL backtick quote from an identifier (literal replacement, not a regex).
     */
    private static String stripMysqlQuotes(String identifier) {
        return identifier.replace(MYSQL_DIALECT, "");
    }

    /**
     * Mapping between a table name and the alias it is referenced by in the SQL; when the table has
     * no alias, the alias equals the table name.
     **/
    static class TableInfo {
        private final String tableName;

        private final String aliasTableName;

        public TableInfo(String tableName, String aliasTableName) {
            this.tableName = tableName;
            this.aliasTableName = aliasTableName;
        }

        public String getTableName() {
            return tableName;
        }

        public String getAliasTableName() {
            return aliasTableName;
        }
    }

}
