/*
 * Decompiled with CFR 0.152.
 */
package com.digiwin.athenai.nl2sql2.service;

import ch.qos.logback.core.util.StringUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.digiwin.athenai.component.AiChatService;
import com.digiwin.athenai.component.AiEmbeddingService;
import com.digiwin.athenai.domain.AiSearchRequest;
import com.digiwin.athenai.entity.AgentScenario;
import com.digiwin.athenai.nl2sql.connector.config.DbConfig;
import com.digiwin.athenai.nl2sql.dto.SqlResultDo;
import com.digiwin.athenai.nl2sql.entity.TableInfo;
import com.digiwin.athenai.nl2sql.service.Nl2SqlService;
import com.digiwin.athenai.nl2sql.service.SqlAsistantService;
import com.digiwin.athenai.nl2sql.service.SqlExecutorService;
import com.digiwin.athenai.nl2sql.service.TableInfoService;
import com.digiwin.athenai.nl2sql.utils.Nl2SqlUtils;
import com.digiwin.athenai.service.AgentScenarioService;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.template.TemplateRenderer;
import org.springframework.ai.template.st.StTemplateRenderer;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Service;
import org.springframework.util.Assert;

/**
 * Routes a natural-language query either to a predefined agent "scenario" (scenario-specific SQL
 * generation) or, when no scenario matches, to the generic NL2SQL pipeline.
 *
 * <p>Scenario matching is two-stage: candidate scenarios are retrieved by vector search from the
 * {@code <agent>_scenario} index, then the chat model picks the matching one (if any) using the
 * {@link #system} prompt.
 */
@Service
public class ScenarioParseService {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(ScenarioParseService.class);

    /** Maximum number of scenario candidates retrieved by vector search. */
    private static final int SCENARIO_TOP_K = 4;
    /** Minimum similarity score for a scenario document to be considered a candidate. */
    private static final double SCENARIO_SIMILARITY_THRESHOLD = 0.4;
    /** Maximum number of table documents retrieved by vector search for scenario SQL generation. */
    private static final int TABLE_TOP_K = 20;

    /**
     * Scenario-matching prompt (in Chinese). It asks the model to match the user's intent against
     * the scenario definitions and, on a hit, return the scenario {@code code} and {@code name} as
     * JSON; otherwise briefly explain why nothing matched. One worked example is followed by the
     * {@code <query>} and {@code <scenario>} placeholders that {@link #findScenario} fills in.
     */
    static String system = "         \u8bf7\u6839\u636e\u7528\u6237\u7684\u610f\u56fe\u4e0e\u573a\u666f\u5b9a\u4e49\u8fdb\u884c\u5339\u914d,\u5982\u679c\u547d\u4e2d\u5219\u4ee5json\u7ed3\u6784\u8fd4\u56de\u573a\u666fcode\u548cname,\u5982\u679c\u672a\u547d\u4e2d\u5219\u7b80\u8981\u8bf4\u660e\u539f\u56e0.\n# \u6837\u4f8b\n**\u7528\u6237\u610f\u56fe**\n\u67e5\u770b\u672a\u5ba1\u6838\u5de5\u5355\n\n**\u573a\u666f\u5b9a\u4e49**\n         \u573a\u666f1\n         code:e1030202approve\n         name:\u5de5\u5355\u5ba1\u6838\n         \u63cf\u8ff0:\u5de5\u5355\u5ba1\u6838\n\n         \u573a\u666f2\n         code:e103001\n         name:\u5de5\u5355\u67e5\u8be2\n         \u63cf\u8ff0:\u5de5\u5355\u67e5\u8be2\n\n         \u573a\u666f3\n         code:e1030202que_zhiling\n         name:\u5236\u4ee4\u5de5\u5355\u67e5\u8be2\n         \u63cf\u8ff0:\u5236\u4ee4\u5de5\u5355\u67e5\u8be2\n\n         \u573a\u666f4\n         code:e1030202create\n         name:\u5de5\u5355\u521b\u5efa\n         \u63cf\u8ff0:\u5de5\u5355\u521b\u5efa\n\n         **\u7ed3\u679c**\n         {\"code\":\"e103001\",\"name\":\"\u5de5\u5355\u67e5\u8be2\"}\n\n# \u9a8c\u8bc1\n**\u7528\u6237\u610f\u56fe**\n<query>\n\n**\u573a\u666f\u5b9a\u4e49**\n<scenario>\n\n **\u7ed3\u679c**\n\n";

    /**
     * Renders {@link #system}. It uses '<' and '>' instead of StTemplateRenderer's default '{' / '}'
     * delimiters, so the literal JSON example inside the prompt is not parsed as a template expression.
     */
    PromptTemplate promptTemplate = PromptTemplate.builder().renderer((TemplateRenderer)StTemplateRenderer.builder().startDelimiterToken('<').endDelimiterToken('>').build()).template(system).build();

    @Autowired
    AiChatService chatService;
    @Autowired
    AiEmbeddingService embeddingService;
    @Autowired
    Nl2SqlService nl2SqlService;
    @Autowired
    SqlExecutorService sqlExecutorService;
    @Autowired
    SqlAsistantService sqlAsistantService;
    // NOTE(review): aiChatService / aiEmbeddingService have the same types as chatService /
    // embeddingService and are injected separately; kept because they are package-visible and may be
    // referenced elsewhere. Consider consolidating.
    @Autowired
    AiChatService aiChatService;
    @Autowired
    AiEmbeddingService aiEmbeddingService;
    @Autowired
    @Lazy
    TableInfoService tableInfoService;
    @Autowired
    AgentScenarioService agentScenarioService;

    /**
     * Answers {@code query} for {@code agent}, preferring a matching predefined scenario.
     *
     * @param agent agent identifier
     * @param query the user's natural-language question
     * @return the result of {@link #scenarioSqlQuery} when a scenario matches and is stored for this
     *     agent; otherwise the result of the generic {@code nl2SqlService.query1}
     * @throws Exception as declared by this method (presumably from the NL2SQL fallback)
     */
    public SqlResultDo scenarioQuery(String agent, String query) throws Exception {
        String code = this.findScenario(agent, query);
        if (StringUtils.isNotBlank(code)) {
            AgentScenario scenario = (AgentScenario) this.agentScenarioService.findByAgentAndCode(agent, code);
            if (scenario != null) {
                return this.scenarioSqlQuery(scenario, query);
            }
        }
        return this.nl2SqlService.query1(agent, query);
    }

    /**
     * Finds the code of the scenario that best matches {@code query} for {@code agent}.
     *
     * <p>Retrieves up to {@value #SCENARIO_TOP_K} candidates (similarity threshold
     * {@value #SCENARIO_SIMILARITY_THRESHOLD}) from the {@code <agent>_scenario} vector index and asks
     * the chat model to pick one. The model reply is expected to be JSON with a {@code code} field. A
     * surrounding Markdown code fence is tolerated.
     *
     * @param agent agent identifier
     * @param query the user's natural-language question
     * @return the matched scenario code, or {@code null} when the query is blank, no candidate has a
     *     code, or the model reports no match / returns unparseable output
     */
    public String findScenario(String agent, String query) {
        if (StringUtils.isBlank(query)) {
            // Nothing to match; also avoids Map.of rejecting a null value below.
            return null;
        }
        AiSearchRequest searchRequest = new AiSearchRequest();
        searchRequest.setAgent(agent + "_scenario");
        searchRequest.setTopK(SCENARIO_TOP_K);
        searchRequest.setSimilarityThreshold(SCENARIO_SIMILARITY_THRESHOLD);
        searchRequest.setQuery(query);
        List<Document> docs = this.embeddingService.searchDocuments(searchRequest);
        int docCount = docs == null ? 0 : docs.size();
        log.info("match scenario docs:{}", docCount);
        log.info("{}", docs);
        if (docCount == 0) {
            return null;
        }
        String candidates = formatScenarioCandidates(docs);
        if (candidates.isEmpty()) {
            // No retrieved document carried a scenario code: skip the LLM call instead of asking it
            // to choose from an empty list (where it could only hallucinate a code).
            return null;
        }
        String prompt = this.promptTemplate.render(Map.of("query", query, "scenario", candidates));
        String content = this.chatService.simpleCall(prompt);
        return parseScenarioCode(content);
    }

    /**
     * Formats the retrieved scenario documents in the numbered layout used by the prompt example.
     * Documents without a {@code code} metadata entry are skipped and do not consume a number.
     *
     * @return the formatted candidates, or an empty string when no document has a code
     */
    private static String formatScenarioCandidates(List<Document> docs) {
        String nl = System.lineSeparator();
        StringBuilder sb = new StringBuilder();
        int index = 1;
        for (Document doc : docs) {
            Object code = doc.getMetadata().get("code");
            if (code == null) {
                continue;
            }
            Object name = doc.getMetadata().get("name");
            sb.append("\u573a\u666f").append(index).append(nl)
                    .append("code:").append(code).append(nl)
                    .append("name:").append(name).append(nl)
                    .append("\u63cf\u8ff0:").append(doc.getText()).append(nl)
                    .append(nl);
            index++;
        }
        return sb.toString();
    }

    /**
     * Extracts the scenario code from the chat model's reply.
     *
     * @param content raw model output; may be {@code null}
     * @return the non-blank {@code code} field, or {@code null} when the reply is not JSON or has no code
     */
    private static String parseScenarioCode(String content) {
        try {
            AgentScenario matched = com.alibaba.fastjson2.JSON.parseObject(stripCodeFence(content), AgentScenario.class);
            String code = matched == null ? null : matched.getCode();
            if (StringUtils.isNotBlank(code)) {
                log.info("\u547d\u4e2d\u573a\u666f,{}", com.alibaba.fastjson2.JSON.toJSONString(matched));
                return code;
            }
            log.info("\u672a\u547d\u4e2d\u573a\u666f1:{}", content);
        } catch (Exception e) {
            // A non-JSON reply is the expected "no match" outcome (the prompt asks for a prose reason).
            log.info("\u672a\u547d\u4e2d\u573a\u666f2:{}", content);
        }
        return null;
    }

    /**
     * Removes a surrounding Markdown code fence (for example one opened with {@code ```json}) from
     * {@code content}, if present.
     *
     * @return the fence-free, stripped text; {@code null} if {@code content} is {@code null}
     */
    private static String stripCodeFence(String content) {
        if (content == null) {
            return null;
        }
        String trimmed = content.strip();
        if (trimmed.startsWith("```")) {
            int bodyStart = trimmed.indexOf('\n');
            int bodyEnd = trimmed.lastIndexOf("```");
            if (bodyStart >= 0 && bodyEnd > bodyStart) {
                return trimmed.substring(bodyStart + 1, bodyEnd).strip();
            }
        }
        return trimmed;
    }

    /**
     * Generates and executes SQL for {@code query} within a matched scenario.
     *
     * <p>The schema context is built from the scenario's configured core tables plus up to
     * {@value #TABLE_TOP_K} tables found by vector search. The scenario's SQL examples are passed to
     * the generator as {@code eg}.
     *
     * @param scenario the matched scenario; its agent selects the database configuration
     * @param query the user's natural-language question
     * @return the generation/execution result with the schema prompt attached
     * @throws IllegalArgumentException if no database configuration exists for the scenario's agent
     */
    public SqlResultDo scenarioSqlQuery(AgentScenario scenario, String query) {
        String agent = scenario.getAgent();
        DbConfig dbConfig = this.sqlAsistantService.getDbConfig(agent);
        Assert.notNull(dbConfig, "dbConfig is null");

        HashSet<String> tables = new HashSet<>();
        addCoreTables(scenario, tables);
        addSearchedTables(agent, query, tables);

        List<TableInfo> tableInfos = this.tableInfoService.findByCodes(tables.stream().toList(), agent);
        String schemaStr = Nl2SqlUtils.formatTablePrompt(tableInfos);

        HashMap<String, Object> params = new HashMap<>();
        params.put("dialect", dbConfig.getDialectType());
        params.put("question", query);
        params.put("evidence", "");
        params.put("schema_info", schemaStr);
        params.put("eg", StringUtils.defaultString(scenario.getSqlDemoStr()));

        SqlResultDo result = this.nl2SqlService.generateAndExecute(agent, params);
        result.setSchemas(schemaStr);
        return result;
    }

    /**
     * Adds the scenario's core tables to {@code tables}. {@code coreTableStr} is parsed as a JSON
     * array; non-string and blank entries are skipped, and a malformed value is logged and ignored.
     */
    private static void addCoreTables(AgentScenario scenario, HashSet<String> tables) {
        String coreTableStr = scenario.getCoreTableStr();
        if (StringUtils.isBlank(coreTableStr)) {
            return;
        }
        try {
            JSONArray jsonArray = JSON.parseArray(coreTableStr);
            if (jsonArray == null) {
                return;
            }
            for (Object item : jsonArray) {
                if (item instanceof String table && StringUtils.isNotBlank(table)) {
                    tables.add(table);
                }
            }
        } catch (Exception e) {
            // Best effort: a bad core-table list must not block SQL generation.
            log.warn("invalid coreTableStr for scenario {}: {}", scenario.getCode(), coreTableStr, e);
        }
    }

    /**
     * Adds to {@code tables} the non-blank {@code name} metadata of the table documents (filter
     * {@code vectorType == 'table'}) retrieved for {@code query} from {@code agent}'s index.
     */
    private void addSearchedTables(String agent, String query, HashSet<String> tables) {
        AiSearchRequest request = new AiSearchRequest();
        request.setQuery(query);
        request.setAgent(agent);
        request.setFilterExpression("vectorType == 'table'");
        request.setTopK(TABLE_TOP_K);
        List<Document> schemas = this.aiEmbeddingService.searchDocuments(request);
        if (schemas == null) {
            return;
        }
        for (Document document : schemas) {
            if (document.getMetadata().get("name") instanceof String table && StringUtils.isNotBlank(table)) {
                tables.add(table);
            }
        }
    }
}

