package com.digiwin.dap.middle.console.serice;

import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Alias;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.*;

import java.util.Collections;

/**
 * Builds COUNT statements from paged SELECT statements using JSqlParser, so that the total row
 * count of a paged query can be obtained.
 */
public class CountSqlBuilder {

    /** Alias given to the derived table when the original query is wrapped in a sub-select. */
    private static final String COUNT_ALIAS = "tmp_count";

    /**
     * Builds the COUNT statement corresponding to a paged query, for use in pagination. <br/>
     * Returns {@code null} when the query contains no paging information. <br/>
     *
     * <p>Paging clauses (LIMIT / OFFSET / FETCH) and ORDER BY are removed first. Set operations
     * (e.g. UNION) and plain selects using GROUP BY, DISTINCT or HAVING are wrapped as
     * {@code SELECT COUNT(*) FROM (...) tmp_count}, because replacing their select list would
     * change the number of result rows. Any other plain select is rewritten to
     * {@code SELECT COUNT(*)} over the same FROM / JOIN / WHERE. A leading WITH clause is kept so
     * that CTE references in the count query still resolve.
     *
     * @param sql the query SQL
     * @return the count SQL, or {@code null} if the query has no paging information or its body is
     *     neither a plain select nor a set operation
     * @throws IllegalArgumentException if {@code sql} is null, is not a SELECT statement, or cannot
     *     be parsed
     */
    public static String buildCountSql(String sql) {
        if (sql == null) {
            throw new IllegalArgumentException("SQL 不能为空");
        }

        Statement statement;
        try {
            statement = CCJSqlParserUtil.parse(sql);
        } catch (JSQLParserException e) {
            throw new IllegalArgumentException("SQL 解析失败: " + e.getMessage(), e);
        }
        if (!(statement instanceof Select)) {
            throw new IllegalArgumentException("仅支持 SELECT 语句");
        }

        Select select = (Select) statement;
        SelectBody body = select.getSelectBody();

        // Only paged queries need a separate count statement.
        if (!hasPaging(body)) {
            return null;
        }

        // Paging and ordering do not affect the total row count.
        removePagingAndOrdering(body);

        SelectBody countBody;
        if (body instanceof SetOperationList) {
            countBody = wrapCountSubSelect(body);
        } else if (body instanceof PlainSelect) {
            PlainSelect ps = (PlainSelect) body;
            countBody = needsSubSelect(ps) ? wrapCountSubSelect(ps) : buildSimpleCountSelect(ps);
        } else {
            return null;
        }

        // Reuse the parsed statement instead of a new Select so its WITH items are preserved.
        return select.withSelectBody(countBody).toString();
    }

    /**
     * Reports whether the top-level query body carries LIMIT, OFFSET or FETCH. Only the outermost
     * body is inspected, not nested sub-queries or individual set-operation branches.
     *
     * <p>Examples noted by the original author: paging is NOT detected for
     * {@code SELECT name FROM student UNION SELECT name FROM teacher LIMIT 10}, but IS detected for
     * {@code SELECT id FROM a UNION ALL SELECT id FROM b ORDER BY id DESC LIMIT 20}.
     * NOTE(review): presumably the parser attaches the first example's LIMIT to the last branch
     * rather than to the set operation — confirm for the JSqlParser version in use.
     *
     * @param body the top-level select body
     * @return whether the body contains paging
     */
    private static boolean hasPaging(SelectBody body) {
        if (body instanceof PlainSelect) {
            PlainSelect ps = (PlainSelect) body;
            return ps.getLimit() != null || ps.getOffset() != null || ps.getFetch() != null;
        } else if (body instanceof SetOperationList) {
            SetOperationList set = (SetOperationList) body;
            return set.getLimit() != null || set.getOffset() != null || set.getFetch() != null;
        }
        return false;
    }

    /**
     * Clears LIMIT, OFFSET, FETCH and ORDER BY on the given body, in place.
     *
     * @param body the top-level select body to mutate
     */
    private static void removePagingAndOrdering(SelectBody body) {
        if (body instanceof PlainSelect) {
            PlainSelect ps = (PlainSelect) body;
            ps.setLimit(null);
            ps.setOffset(null);
            ps.setFetch(null);
            ps.setOrderByElements(null);
        } else if (body instanceof SetOperationList) {
            SetOperationList set = (SetOperationList) body;
            set.setLimit(null);
            set.setOffset(null);
            set.setFetch(null);
            set.setOrderByElements(null);
        }
    }

    /**
     * Reports whether a plain select must be counted through a sub-select. GROUP BY and DISTINCT
     * collapse rows, and HAVING filters groups (even without GROUP BY). Counting only
     * FROM/JOIN/WHERE would therefore give the wrong number for such queries.
     *
     * @param ps the plain select, already stripped of paging
     * @return true if the query must be wrapped
     */
    private static boolean needsSubSelect(PlainSelect ps) {
        return ps.getGroupBy() != null || ps.getDistinct() != null || ps.getHaving() != null;
    }

    /**
     * Wraps the given body as {@code SELECT COUNT(*) FROM (body) tmp_count}.
     *
     * @param body the query body to count
     * @return the counting select body
     */
    private static PlainSelect wrapCountSubSelect(SelectBody body) {
        SubSelect subSelect = new SubSelect();
        subSelect.setSelectBody(body);
        subSelect.setAlias(new Alias(COUNT_ALIAS));

        PlainSelect countSelect = new PlainSelect();
        countSelect.setFromItem(subSelect);
        countSelect.setSelectItems(Collections.singletonList(countAllItem()));
        return countSelect;
    }

    /**
     * Builds {@code SELECT COUNT(*)} reusing the FROM item, joins and WHERE of the given select.
     *
     * @param ps a plain select for which {@link #needsSubSelect(PlainSelect)} is false
     * @return the counting select body
     */
    private static PlainSelect buildSimpleCountSelect(PlainSelect ps) {
        PlainSelect countSelect = new PlainSelect();
        countSelect.setSelectItems(Collections.singletonList(countAllItem()));
        countSelect.setFromItem(ps.getFromItem());
        countSelect.setJoins(ps.getJoins());
        countSelect.setWhere(ps.getWhere());
        return countSelect;
    }

    /**
     * Creates the {@code COUNT(*)} select item as a real function expression.
     *
     * @return a new select item rendering as {@code COUNT(*)}
     */
    private static SelectExpressionItem countAllItem() {
        Function countFunction = new Function();
        countFunction.setName("COUNT");
        ExpressionList list = new ExpressionList();
        list.setExpressions(Collections.singletonList(new AllColumns()));
        countFunction.setParameters(list);
        return new SelectExpressionItem(countFunction);
    }

}
