package com.digiwin.dap.middleware.mybatis;

import com.github.pagehelper.JSqlParser;
import com.github.pagehelper.parser.CountSqlParser;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.parser.Token;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.*;

import java.util.ArrayList;
import java.util.List;

/**
 * {@link CountSqlParser} variant whose count rewrite can be forced into the "simple" form.
 * <p>
 * When {@code CustomPageSettingsHolder.isSimpleCountSql()} returns {@code true}, a plain
 * {@code SELECT} is rewritten in place to {@code SELECT count(...) FROM ...} (with any
 * {@code DISTINCT} cleared) instead of being wrapped as
 * {@code SELECT count(...) FROM (original) table_count}.
 *
 * @author blockWilling
 * @date 2024/1/2 10:40
 * @mail kangjin@digiwin.com
 */
public class CustomCountSqlParser extends CountSqlParser {
    /** Alias of the wrapping sub-query; rendered without the {@code AS} keyword. */
    private static final Alias TABLE_ALIAS;

    static {
        TABLE_ALIAS = new Alias("table_count");
        TABLE_ALIAS.setUseAs(false);
    }

    /** Parser used by {@link #getSmartCountSql(String, String)}; replaceable via the setter. */
    private JSqlParser jSqlParser;

    /**
     * Creates a parser that uses {@code jSqlParser} both here and in the parent class.
     *
     * @param jSqlParser parser used to turn SQL text into a statement tree
     */
    public CustomCountSqlParser(JSqlParser jSqlParser) {
        super(jSqlParser);
        this.jSqlParser = jSqlParser;
    }

    /** @return the parser used by {@link #getSmartCountSql(String, String)} */
    public JSqlParser getjSqlParser() {
        return jSqlParser;
    }

    /**
     * Replaces the parser used by {@link #getSmartCountSql(String, String)}. The instance that
     * was passed to the parent constructor is not affected.
     *
     * @param jSqlParser the new parser
     */
    public void setjSqlParser(JSqlParser jSqlParser) {
        this.jSqlParser = jSqlParser;
    }

    /**
     * Rewrites {@code select} in place into a count query.
     * <p>
     * A {@link PlainSelect} is rewritten directly when the simple-count setting is enabled or
     * {@code isSimpleCount} accepts it: the select list is replaced by {@code count(name)} and
     * {@code DISTINCT} is cleared. Every other body is wrapped as
     * {@code SELECT count(name) FROM (body) table_count}.
     *
     * @param select the parsed statement; modified in place
     * @param name   the expression placed inside {@code count(...)}
     */
    @Override
    public void sqlToCount(Select select, String name) {
        SelectBody selectBody = select.getSelectBody();
        List<SelectItem> countItems = new ArrayList<SelectItem>();
        countItems.add(new SelectExpressionItem(new Column("count(" + name + ")")));
        // The setting is checked first, so isSimpleCount is skipped when simple counting is forced.
        if (selectBody instanceof PlainSelect
                && (CustomPageSettingsHolder.isSimpleCountSql() || isSimpleCount((PlainSelect) selectBody))) {
            PlainSelect plainSelect = (PlainSelect) selectBody;
            plainSelect.setSelectItems(countItems);
            // A forced simple count may still carry DISTINCT; it is dropped so rows are counted.
            // NOTE(review): this changes the result for DISTINCT queries. Confirm that is intended.
            plainSelect.setDistinct(null);
            return;
        }
        PlainSelect wrapper = new PlainSelect();
        SubSelect subSelect = new SubSelect();
        subSelect.setSelectBody(selectBody);
        subSelect.setAlias(TABLE_ALIAS);
        wrapper.setFromItem(subSelect);
        wrapper.setSelectItems(countItems);
        select.setSelectBody(wrapper);
    }

    /**
     * Builds a count query for {@code sql}, removing ORDER BY where possible.
     * <p>
     * The method falls back to {@code getSimpleCountSql(sql, countColumn)} in four cases:
     * <ul>
     *   <li>the SQL contains {@code /*keep orderby*&#47;} or {@code keepOrderBy()} is set;</li>
     *   <li>the SQL cannot be parsed;</li>
     *   <li>the parsed statement is not a {@code SELECT};</li>
     *   <li>{@code processSelectBody} fails.</li>
     * </ul>
     * Every fallback keeps the caller's {@code countColumn}.
     *
     * @param sql         the original query
     * @param countColumn the expression placed inside {@code count(...)}
     * @return the count query text
     */
    @Override
    public String getSmartCountSql(String sql, String countColumn) {
        if (sql.contains("/*keep orderby*/") || this.keepOrderBy()) {
            return this.getSimpleCountSql(sql, countColumn);
        }

        Statement stmt;
        try {
            stmt = this.jSqlParser.parse(sql);
        } catch (Throwable parseFailure) {
            // Throwable rather than Exception, so parser errors also lead to the fallback.
            return this.getSimpleCountSql(sql, countColumn);
        }
        if (!(stmt instanceof Select)) {
            // Without this guard a non-SELECT statement would throw ClassCastException.
            return this.getSimpleCountSql(sql, countColumn);
        }

        Select select = (Select) stmt;
        SelectBody selectBody = select.getSelectBody();
        try {
            this.processSelectBody(selectBody);
        } catch (Exception processFailure) {
            return this.getSimpleCountSql(sql, countColumn);
        }

        this.processWithItemsList(select.getWithItemsList());
        this.sqlToCount(select, countColumn);
        return prependLeadingHint(selectBody, select.toString());
    }

    /**
     * Re-attaches a leading comment hint to the rendered count SQL.
     * <p>
     * The hint is taken from the special token in front of the original {@link PlainSelect}.
     * It is added only when the token is a {@code /* ... *&#47;} comment and {@code countSql}
     * does not already start with a comment.
     *
     * @param originalBody the select body as parsed, before the count rewrite
     * @param countSql     the rendered count query
     * @return {@code countSql}, possibly prefixed with the hint
     */
    private String prependLeadingHint(SelectBody originalBody, String countSql) {
        if (!(originalBody instanceof PlainSelect)) {
            return countSql;
        }
        PlainSelect plainSelect = (PlainSelect) originalBody;
        if (plainSelect.getASTNode() == null) {
            // Defensive: with no AST node there is no token to read a hint from.
            return countSql;
        }
        Token token = plainSelect.getASTNode().jjtGetFirstToken().specialToken;
        if (token == null) {
            return countSql;
        }
        String hints = token.toString().trim();
        if (hints.startsWith("/*") && hints.endsWith("*/") && !countSql.startsWith("/*")) {
            return hints + countSql;
        }
        return countSql;
    }
}
