/*
 * Decompiled with CFR 0.152.
 */
package com.digiwin.athenai.nl2sql.service;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphRunnerException;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.fastjson.JSON;
import com.digiwin.athenai.component.AiChatService;
import com.digiwin.athenai.domain.KnowledgeBase;
import com.digiwin.athenai.nl2sql.connector.config.DbConfig;
import com.digiwin.athenai.nl2sql.constant.PromptConstant;
import com.digiwin.athenai.nl2sql.dto.SqlResultDo;
import com.digiwin.athenai.nl2sql.entity.TableInfo;
import com.digiwin.athenai.nl2sql.service.SqlAsistantService;
import com.digiwin.athenai.nl2sql.service.SqlExecutorService;
import com.digiwin.athenai.nl2sql.utils.Nl2SqlUtils;
import com.digiwin.athenai.tool.TimeTool;
import com.digiwin.athenai.utils.MarkdownParser;
import com.digiwin.athenai.utils.PromptUtil;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;

@Service
public class Nl2SqlService {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(Nl2SqlService.class);
    private final CompiledGraph nl2sqlGraph;
    @Autowired
    SqlExecutorService sqlExecutorService;
    @Autowired
    AiChatService aiChatService;
    @Autowired
    SqlAsistantService sqlAsistantService;
    static String defaultSqlDemo = "[\u95ee\u9898]\n\u67e5\u8be2\u6700\u8fd1\u5b8c\u5de5\u7684\u5de5\u5355\n[\u8f93\u51fa]\n{\"sql\":\"SELECT t0.MO_ID AS MO_ID, t0.ITEM_DESCRIPTION AS ITEM_DESCRIPTION, t0.ITEM_SPECIFICATION AS ITEM_SPECIFICATION, t0.BOM_VERSION_TIMES AS BOM_VERSION_TIMES, t0.BOM_DATE AS BOM_DATE, t0.URGENT AS URGENT, t0.STATUS AS STATUS, t0.LOT_MO_FLAG AS LOT_MO_FLAG, t0.PLAN_QTY AS PLAN_QTY, t0.REQ_QTY AS REQ_QTY, t0.COMPLETED_QTY AS COMPLETED_QTY, t0.SCRAP_QTY AS SCRAP_QTY, t0.DESTROYED_QTY AS DESTROYED_QTY, t0.PLAN_START_DATE AS PLAN_START_DATE, t0.PLAN_COMPLETE_DATE AS PLAN_COMPLETE_DATE, t0.ACTUAL_START_DATE AS ACTUAL_START_DATE, t0.ACTUAL_COMPLETE_DATE AS ACTUAL_COMPLETE_DATE FROM MO AS t0 WHERE t0.STATUS = 'Y' and t0.ACTUAL_COMPLETE_DATE >= DATEADD(day, -7, GETDATE()) ORDER BY t0.ACTUAL_COMPLETE_DATE DESC\",\"tables\":[\"MO\"]}\n\n[\u95ee\u9898]\n\u67e5\u770b\u5df2\u5ba1\u6838\u7684\u5de5\u5355\n[\u8f93\u51fa]\n{\"sql\":\"SELECT t0.MO_ID AS MO_ID, t0.ITEM_DESCRIPTION AS ITEM_DESCRIPTION, t0.ITEM_SPECIFICATION AS ITEM_SPECIFICATION, t0.BOM_VERSION_TIMES AS BOM_VERSION_TIMES, t0.BOM_DATE AS BOM_DATE, t0.URGENT AS URGENT, t0.STATUS AS STATUS, t0.LOT_MO_FLAG AS LOT_MO_FLAG, t0.PLAN_QTY AS PLAN_QTY, t0.REQ_QTY AS REQ_QTY, t0.COMPLETED_QTY AS COMPLETED_QTY, t0.SCRAP_QTY AS SCRAP_QTY, t0.DESTROYED_QTY AS DESTROYED_QTY, t0.ApproveStatus AS ApproveStatus, t0.ApproveDate AS ApproveDate, t1.USER_NAME as USER_NAME FROM [MO] AS t0 left join [USER] t1 on t0.ApproveBy=t1.USER_ID WHERE t0.ApproveStatus = 'Y' ORDER BY t0.ApproveStatus DESC\",\"tables\":[\"MO\",\"USER\"]}\n\n[\u95ee\u9898]\n\u67e5\u8be223\u5e7410\u6708\u672a\u5ba1\u6838\u7684\u5de5\u5355\uff0c\u663e\u793a\u5de5\u5382\u3001\u7269\u6599\u3001\u5de5\u827a\u3001\u9879\u76ee\u3001\u6570\u91cf\u3001\u72b6\u6001\u3001\u6765\u6e90\u3001\u90e8\u95e8\u7b49\uff0c\u5e76\u6309\u5de5\u5355\u53f7\u548c\u65e5\u671f\u6392\u5e8f\uff0c\u4ee5\u4fbf\u5ba1\u6279\u548c\u751f\u4ea7\u8ba1\u5212\u7ba1\u7406\n[\u8f93\u51fa]\n{\"sql\":\"SELECT t0.MO_ID, t1.PLANT_NAME AS Factory, t0.ITEM_DESCRIPTION AS Material, t9.ROUTING_DES AS Process, t15.PROJECT_NAME AS Project, t0.PLAN_QTY AS Quantity, t0.STATUS AS Status, CASE t0.SOURCE_ID_RTK WHEN 'SUPPLIER' THEN t6.SUPPLIER_NAME WHEN 'WORK_CENTER' THEN t7.WORK_CENTER_NAME END AS Source, t13.ADMIN_UNIT_NAME AS Department, t19.USER_NAME AS Approver FROM MO AS t0 LEFT JOIN PLANT AS t1 ON t0.Owner_Org_ROid = t1.PLANT_ID LEFT JOIN ITEM_ROUTING AS t9 ON t0.ITEM_ROUTING_ID = t9.ITEM_ROUTING_ID LEFT JOIN PROJECT AS t15 ON t0.PROJECT_ID = t15.PROJECT_ID LEFT JOIN SUPPLIER AS t6 ON t0.SOURCE_ID_ROid = t6.SUPPLIER_BUSINESS_ID LEFT JOIN WORK_CENTER AS t7 ON t0.SOURCE_ID_ROid = t7.WORK_CENTER_ID LEFT JOIN ADMIN_UNIT AS t13 ON t0.Owner_Dept = t13.ADMIN_UNIT_ID LEFT JOIN [USER] AS t19 ON t0.ApproveBy = t19.USER_ID WHERE t0.DOC_DATE >= '2023-10-01' AND t0.DOC_DATE <= '2023-10-31' AND t0.ApproveStatus = 'N' ORDER BY t0.DOC_NO ASC, t0.DOC_DATE ASC;\",\"tables\":[\"MO\",\"PLANT\",\"ITEM_ROUTING\",\"PROJECT\",\"SUPPLIER\",\"WORK_CENTER\",\"ADMIN_UNIT\",\"USER\"]}\n\n";

    public Nl2SqlService(@Qualifier(value="nl2sqlGraph") StateGraph stateGraph) throws GraphStateException {
        this.nl2sqlGraph = stateGraph.compile();
        this.nl2sqlGraph.setMaxIterations(100);
    }

    public String nl2sql(String agent, String query) throws Exception {
        return this.query1(agent, query).getData();
    }

    private SqlResultDo generateSql(SqlResultDo dto, Map<String, Object> params, String originSql, String errorMsg) {
        if (null != errorMsg && null != originSql) {
            params.put("error_sql", originSql);
            params.put("error_message", errorMsg);
            String user = PromptConstant.getSqlErrorFixerPromptTemplate().render(params);
            String content = this.aiChatService.getOrCreateChatClient(null).prompt().user(user).tools(new Object[]{new TimeTool()}).call().content();
            String sql = MarkdownParser.extractRawText((String)content);
            dto.setSql(sql);
            log.info("nl2sqlFix: {} =>{} ", params.get("question"), (Object)sql);
        } else {
            String system = PromptConstant.sql_generator_system().render(params);
            String user = PromptConstant.sql_generator_user().render(params);
            String content = this.aiChatService.getOrCreateChatClient(null).prompt().system(system).user(user).tools(new Object[]{new TimeTool()}).call().content();
            log.info("nl2sql: {} =>{} ", params.get("question"), (Object)content);
            SqlResultDo dto2 = (SqlResultDo)JSON.parseObject((String)content, SqlResultDo.class);
            dto.setSql(dto2.getSql());
            dto.setTables(dto2.getTables());
        }
        return dto;
    }

    public String executeSql(String agent, String sql) throws Exception {
        String result = this.sqlExecutorService.executeSql(agent, sql);
        return result;
    }

    public SqlResultDo generateAndExecute(String agent, Map<String, Object> params) {
        SqlResultDo result = new SqlResultDo();
        this.innerGenerateAndExecute(result, agent, params, new AtomicInteger(3), null, null);
        return result;
    }

    public SqlResultDo query1(String agent, String query) throws Exception {
        Assert.notNull((Object)agent, (String)"agent is null");
        Assert.notNull((Object)query, (String)"query is null");
        DbConfig dbConfig = this.sqlAsistantService.getDbConfig(agent);
        Assert.notNull((Object)dbConfig, (String)"dbConfig is null");
        List<Document> evidences = this.sqlAsistantService.evidences(agent, query);
        log.info("retrieve evidences:" + evidences.size());
        String evidencesStr = PromptUtil.formatDocuments(evidences);
        List<TableInfo> schemas = this.sqlAsistantService.tables(agent, query, evidences);
        Assert.notEmpty(schemas, (String)("no table info found by query: " + query));
        String tableNames = String.join((CharSequence)",", schemas.stream().map(KnowledgeBase::getCode).toList());
        log.info("retrieve tables:" + schemas.size() + " :" + tableNames);
        String schemaStr = Nl2SqlUtils.formatTablePrompt(schemas);
        HashMap<String, Object> params = new HashMap<String, Object>();
        params.put("dialect", dbConfig.getDialectType());
        params.put("question", query);
        params.put("schema_info", schemaStr);
        params.put("evidence", evidencesStr);
        params.put("eg", defaultSqlDemo);
        SqlResultDo result = this.generateAndExecute(agent, params);
        return result;
    }

    public SqlResultDo query2(String agent, String query) throws Exception {
        Assert.notNull((Object)agent, (String)"agent is null");
        Assert.notNull((Object)query, (String)"query is null");
        DbConfig dbConfig = this.sqlAsistantService.getDbConfig(agent);
        Assert.notNull((Object)dbConfig, (String)"dbConfig is null");
        List<Document> evidences = this.sqlAsistantService.evidences(agent, query);
        log.info("retrieve evidences:" + evidences.size());
        String evidencesStr = PromptUtil.formatDocuments(evidences);
        List<TableInfo> schemas = this.sqlAsistantService.tables2(agent, query, evidences);
        Assert.notEmpty(schemas, (String)("no table info found by query: " + query));
        String tableNames = String.join((CharSequence)",", schemas.stream().map(KnowledgeBase::getCode).toList());
        log.info("retrieve tables:" + schemas.size() + " :" + tableNames);
        String schemaStr = Nl2SqlUtils.formatTablePrompt(schemas);
        HashMap<String, Object> params = new HashMap<String, Object>();
        params.put("dialect", dbConfig.getDialectType());
        params.put("question", query);
        params.put("schema_info", schemaStr);
        params.put("evidence", evidencesStr);
        params.put("eg", defaultSqlDemo);
        SqlResultDo result = this.generateAndExecute(agent, params);
        return result;
    }

    private SqlResultDo innerGenerateAndExecute(SqlResultDo dto, String agent, Map<String, Object> params, AtomicInteger counter, String originSql, String errorMsg) {
        if (counter.getAndIncrement() <= 0) {
            throw new RuntimeException("\u8d85\u8fc7\u6700\u5927\u6267\u884c\u6b21\u6570,\u751f\u6210sql\u5e76\u6267\u884c\u5931\u8d25:" + errorMsg);
        }
        this.generateSql(dto, params, originSql, errorMsg);
        String sql = dto.getSql();
        try {
            String result = this.sqlExecutorService.executeSql(agent, sql);
            dto.setData(result);
            dto.setSchemas((String)params.get("schema_info"));
        }
        catch (Exception e) {
            this.innerGenerateAndExecute(dto, agent, params, counter, sql, e.getMessage());
        }
        return dto;
    }

    public Map<String, Object> invoke(Map<String, Object> params) throws GraphRunnerException {
        System.out.println(Thread.currentThread().getName() + " running graph " + String.valueOf(params));
        Optional result = this.nl2sqlGraph.invoke(params);
        return ((OverAllState)result.get()).data();
    }

    public CompiledGraph getNl2sqlGraph() {
        return this.nl2sqlGraph;
    }

    public static void main(String[] args) {
        SqlResultDo do1 = new SqlResultDo();
        do1.setSql("aaaa");
        System.out.println(do1.getSql());
    }
}

