package com.digiwin.dap.nest.infrastructure.middleware.rdb.mybatis.page;

import lombok.SneakyThrows;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.reflection.SystemMetaObject;

import java.sql.ResultSet;
import java.sql.Statement;
import java.util.Collections;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Static helpers that rewrite SQL according to the current thread's {@link DwMybatisPageContext}
 * (either into a count query or into a paged query) and capture the result of a count query.
 *
 * <p>Paged SQL uses MySQL-style {@code limit offset, size} syntax. Keyword detection is done with
 * whole-word, ASCII case-insensitive regular expressions applied directly to the original SQL, so
 * it does not depend on the JVM default locale and match indices always refer to the original text.
 */
public class DwMybatisPageProcessor {

    /** {@code UNION} keyword as a whole word, ASCII case-insensitive. */
    private static final Pattern UNION_KEYWORD = Pattern.compile("\\bunion\\b", Pattern.CASE_INSENSITIVE);
    /** {@code DISTINCT} keyword as a whole word, ASCII case-insensitive. */
    private static final Pattern DISTINCT_KEYWORD = Pattern.compile("\\bdistinct\\b", Pattern.CASE_INSENSITIVE);
    /** {@code FROM} keyword as a whole word (so e.g. {@code valid_from} does not match). */
    private static final Pattern FROM_KEYWORD = Pattern.compile("\\bfrom\\b", Pattern.CASE_INSENSITIVE);
    /** {@code ORDER BY} with any whitespace (spaces, tabs, newlines) between the two words. */
    private static final Pattern ORDER_BY_CLAUSE = Pattern.compile("\\border\\s+by\\b", Pattern.CASE_INSENSITIVE);
    /** {@code GROUP BY} with any whitespace (spaces, tabs, newlines) between the two words. */
    private static final Pattern GROUP_BY_CLAUSE = Pattern.compile("\\bgroup\\s+by\\b", Pattern.CASE_INSENSITIVE);

    private DwMybatisPageProcessor() {
    }

    /**
     * Reads the value produced by a count query and stores it on the current thread's page context.
     *
     * <p>This method does not call {@code invocation.proceed()}. Instead, it returns an empty list in
     * place of mapped rows. NOTE(review): presumably it intercepts
     * {@code ResultSetHandler.handleResultSets(Statement)}; confirm against the plugin.
     *
     * @param invocation intercepted invocation whose first argument is the executed {@link Statement}
     * @return an empty list
     * @throws IllegalStateException if no page context is bound to the current thread
     */
    @SneakyThrows
    public static Object getPageCount(Invocation invocation) {
        DwMybatisPageContext context = DwMybatisPage.PageThreadLocal.get();
        if (context == null) {
            throw new IllegalStateException("No DwMybatisPageContext bound to the current thread");
        }
        Statement stmt = (Statement) invocation.getArgs()[0];
        // Nothing downstream consumes this result set (we do not delegate), so close it here.
        try (ResultSet rs = stmt.getResultSet()) {
            // An absent or empty result set means there is nothing to count.
            long count = rs != null && rs.next() ? rs.getLong(1) : 0L;
            context.setCount(count);
        }
        return Collections.emptyList();
    }

    /**
     * Tells whether the current thread's page context asks for a count query.
     *
     * @return {@code true} only if a context is bound and its count flag is {@code TRUE}
     */
    public static boolean executePageCount() {
        DwMybatisPageContext context = DwMybatisPage.PageThreadLocal.get();
        // Null-safe comparison: avoids an unboxing NPE if getIsCount() returns a null Boolean.
        return context != null && Boolean.TRUE.equals(context.getIsCount());
    }

    /**
     * Rewrites the SQL held by {@code boundSql} in place, according to the current thread's page context.
     * If the context's count flag is set, the SQL becomes a count query. Otherwise a {@code limit} clause
     * is appended. Does nothing when no context is bound.
     *
     * @param boundSql the bound SQL whose {@code sql} property is replaced
     */
    public static void executePage(BoundSql boundSql) {
        DwMybatisPageContext context = DwMybatisPage.PageThreadLocal.get();
        if (null != context) {
            String sql = boundSql.getSql();
            if (Boolean.TRUE.equals(context.getIsCount())) {
                sql = wrapCountSql(sql);
            } else {
                sql = wrapPageSql(sql, context.getPageSize(), context.getPageIndex());
            }
            // BoundSql exposes no setter for the SQL text, so write the field reflectively.
            SystemMetaObject.forObject(boundSql).setValue("sql", sql);
        }
    }

    /**
     * Appends a MySQL-style {@code limit offset, size} clause for the requested page.
     *
     * @param sql       the original query
     * @param pageSize  rows per page
     * @param pageIndex 1-based page number; values below 1 are treated as the first page
     * @return the paged query
     */
    private static String wrapPageSql(String sql, Integer pageSize, Integer pageIndex) {
        if (UNION_KEYWORD.matcher(sql).find()) {
            // Wrap so the limit applies to the combined result, not just the last SELECT.
            sql = String.format("select * from (%s) t", sql);
        }
        // Compute in long so large page numbers cannot overflow int, and never emit a negative offset.
        long offset = Math.max(0L, (pageIndex.longValue() - 1) * pageSize);
        return sql + " limit " + offset + ", " + pageSize;
    }

    /**
     * Builds a query that counts the rows the original query would return.
     *
     * <p>When the query has a single top-level FROM and no UNION/DISTINCT, the select list and any
     * trailing ORDER BY are dropped. Otherwise the whole query is wrapped in a sub-query.
     *
     * @param sql the original query
     * @return the count query
     */
    private static String wrapCountSql(String sql) {
        // UNION and DISTINCT affect which rows are produced, so the select list cannot be dropped.
        if (UNION_KEYWORD.matcher(sql).find() || DISTINCT_KEYWORD.matcher(sql).find()) {
            return wrapAsSubqueryCount(sql);
        }

        Matcher fromMatcher = FROM_KEYWORD.matcher(sql);
        // Without a FROM there is nothing to strip, so wrap the whole statement.
        if (!fromMatcher.find()) {
            return wrapAsSubqueryCount(sql);
        }
        int from = fromMatcher.start();
        // More than one FROM (e.g. a sub-query) makes the simple rewrite unsafe; wrap instead.
        if (fromMatcher.find()) {
            return wrapAsSubqueryCount(sql);
        }

        // Drop the select list (its columns do not matter for counting) and a trailing ORDER BY
        // that appears after the FROM clause.
        int orderBy = lastMatchStart(ORDER_BY_CLAUSE, sql);
        String fromSql = orderBy > from ? sql.substring(from, orderBy) : sql.substring(from);

        // A grouped query yields one row per group, so count those rows through a sub-query.
        if (GROUP_BY_CLAUSE.matcher(fromSql).find()) {
            return String.format("select count(1) count from (select 1 %s) t", fromSql);
        }

        return "select count(1) count " + fromSql;
    }

    /** Counts the rows of {@code sql} by wrapping it, unchanged, in a derived table. */
    private static String wrapAsSubqueryCount(String sql) {
        return String.format("select count(1) count from (%s) t", sql);
    }

    /** Returns the start index of the last match of {@code pattern} in {@code text}, or -1 if none. */
    private static int lastMatchStart(Pattern pattern, String text) {
        Matcher matcher = pattern.matcher(text);
        int last = -1;
        while (matcher.find()) {
            last = matcher.start();
        }
        return last;
    }
}
