/*
 * Decompiled with CFR 0.152.
 */
package com.digiwin.athenai.nl2sql.service;

import com.digiwin.athenai.component.AiChatService;
import com.digiwin.athenai.component.AiEmbeddingService;
import com.digiwin.athenai.domain.AiSearchRequest;
import com.digiwin.athenai.nl2sql.connector.config.DbConfig;
import com.digiwin.athenai.nl2sql.constant.PromptConstant;
import com.digiwin.athenai.nl2sql.entity.SchemaInfo;
import com.digiwin.athenai.nl2sql.entity.TableInfo;
import com.digiwin.athenai.nl2sql.service.SchemaInfoService;
import com.digiwin.athenai.nl2sql.service.TableInfoService;
import com.digiwin.athenai.utils.PromptUtil;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;

@Component
public class SqlAsistantService {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(SqlAsistantService.class);

    /** Message ("no table found") thrown when table retrieval yields no usable table code. */
    private static final String NO_TABLE_FOUND = "\u6ca1\u6709\u627e\u5230\u8868";

    /** Dialect assigned by {@link #getDbConfig(String)} when the schema has none configured. */
    private static final String DEFAULT_DIALECT = "mysql";

    /** Shared parser for the LLM's JSON reply; Gson instances are documented as thread-safe. */
    private static final Gson GSON = new Gson();

    @Autowired
    SchemaInfoService schemaInfoService;
    @Autowired
    AiEmbeddingService aiEmbeddingService;
    @Autowired
    @Lazy
    TableInfoService tableInfoService;
    @Autowired
    AiChatService aiChatService;

    /**
     * Looks up the schema configuration registered for an agent.
     *
     * <p>If the schema exists but has no dialect set, the dialect is set to {@code "mysql"} on the
     * returned object. NOTE(review): this mutates the object returned by SchemaInfoService; confirm
     * it is not a shared/cached instance.
     *
     * @param agent agent identifier
     * @return the agent's schema configuration, or {@code null} if SchemaInfoService returns none
     */
    public DbConfig getDbConfig(String agent) {
        SchemaInfo dbConfig = this.schemaInfoService.getSchemaInfoByAgent(agent);
        if (dbConfig != null && dbConfig.getDialectType() == null) {
            dbConfig.setDialectType(DEFAULT_DIALECT);
        }
        return dbConfig;
    }

    /**
     * Searches the agent's embedding store for evidence documents relevant to a query.
     *
     * <p>NOTE(review): this filters on the {@code type} metadata key, whereas the table searches
     * filter on {@code vectorType}; confirm evidence documents are indexed with {@code type}.
     *
     * @param agent agent identifier
     * @param query natural-language user question
     * @return documents returned by the embedding search
     */
    public List<Document> evidences(String agent, String query) {
        AiSearchRequest request = new AiSearchRequest();
        request.setQuery(query);
        request.setAgent(agent);
        request.setFilterExpression("type == 'evidence'");
        return this.aiEmbeddingService.searchDocuments(request);
    }

    /**
     * Resolves candidate tables by vector search over table-level embeddings.
     *
     * <p>The evidences and the query are combined into one search text. The {@code name}
     * metadata of up to 30 matching documents is then collected as table codes; blank names are
     * skipped.
     *
     * @param agent     agent identifier
     * @param query     natural-language user question
     * @param evidences evidence documents, formatted into the search text
     * @return table metadata for the matched table codes
     * @throws RuntimeException if no non-blank table name was found
     */
    public List<TableInfo> tables(String agent, String query, List<Document> evidences) {
        List<Document> schemas =
                searchTableDocuments(agent, query, evidences, "vectorType == 'table'", 30);
        List<String> tables = new ArrayList<>();
        for (Document document : schemas) {
            String table = (String) document.getMetadata().get("name");
            if (StringUtils.isNotBlank(table)) {
                tables.add(table);
            }
        }
        if (tables.isEmpty()) {
            throw new RuntimeException(NO_TABLE_FOUND);
        }
        return this.tableInfoService.findByCodes(tables, agent);
    }

    /**
     * Resolves candidate tables by letting the LLM pick from table summaries.
     *
     * <p>Up to 100 table-summary documents are retrieved by vector search and rendered into the
     * {@code sql_retrive_table_from_summary} prompt. The chat model's reply is then parsed as a
     * JSON array of table codes; null or blank entries are dropped.
     *
     * @param agent     agent identifier
     * @param query     natural-language user question
     * @param evidences evidence documents, formatted into the search text
     * @return table metadata for the table codes chosen by the LLM
     * @throws RuntimeException if no summary was retrieved or the LLM returned no usable code
     *                          (a malformed JSON reply surfaces as Gson's JsonSyntaxException)
     */
    public List<TableInfo> tables2(String agent, String query, List<Document> evidences) {
        List<Document> schemas =
                searchTableDocuments(agent, query, evidences, "vectorType == 'tableSummary'", 100);
        log.info("tableSummary size:{}", schemas.size());
        if (schemas.isEmpty()) {
            throw new RuntimeException(NO_TABLE_FOUND);
        }
        String schemaStr = SqlAsistantService.formatTablePrompt(schemas);
        String prompt = PromptConstant.sql_retrive_table_from_summary()
                .render(Map.of("query", query, "schema", schemaStr));
        // Route the (potentially large) prompt through the logger instead of stdout.
        log.debug("tableSummary prompt:{}", prompt);
        String tablesCodes = this.aiChatService.simpleCall(prompt);
        List<String> tables = parseTableCodes(tablesCodes);
        log.info("tableSummary get tables:{}", tables);
        if (tables.isEmpty()) {
            throw new RuntimeException(NO_TABLE_FOUND);
        }
        return this.tableInfoService.findByCodes(tables, agent);
    }

    /**
     * Renders table documents into the text block embedded in the table-selection prompt.
     *
     * <p>For each document this writes the labels "表" (table), "名称" (name), optional "外键"
     * (foreign key) and "描述" (description). It fills them from the {@code name},
     * {@code description} and {@code foreignKey} metadata and the document text.
     *
     * @param schemas table documents to render
     * @return the formatted prompt fragment, terminated by two line separators
     */
    public static String formatTablePrompt(List<Document> schemas) {
        String nl = System.lineSeparator();
        StringBuilder stringBuilder = new StringBuilder();
        for (Document tableInfo : schemas) {
            String name = (String) tableInfo.getMetadata().get("name");
            String desc = (String) tableInfo.getMetadata().get("description");
            String foreignKey = (String) tableInfo.getMetadata().get("foreignKey");
            String text = tableInfo.getText();
            stringBuilder.append(nl).append("\u8868:").append(name).append(nl);
            stringBuilder.append("\u540d\u79f0:").append(desc).append(nl);
            if (foreignKey != null) {
                stringBuilder.append("\u5916\u952e:").append(foreignKey).append(nl);
            }
            stringBuilder.append("\u63cf\u8ff0:").append(nl);
            stringBuilder.append(text).append(nl);
        }
        stringBuilder.append(nl).append(nl);
        return stringBuilder.toString();
    }

    /**
     * Runs a filtered vector search whose text is the formatted evidences followed by the query.
     *
     * @return matching documents, never {@code null}
     */
    private List<Document> searchTableDocuments(String agent, String query, List<Document> evidences,
            String filterExpression, int topK) {
        String evidencesStr = PromptUtil.formatDocuments(evidences);
        AiSearchRequest request = new AiSearchRequest();
        request.setQuery(evidencesStr + System.lineSeparator() + query);
        request.setAgent(agent);
        request.setFilterExpression(filterExpression);
        request.setTopK(topK);
        List<Document> documents = this.aiEmbeddingService.searchDocuments(request);
        return documents == null ? List.of() : documents;
    }

    /**
     * Parses the LLM reply as a JSON string array of table codes, keeping only non-blank entries.
     * A {@code null} or empty reply (for which Gson returns {@code null}) yields an empty list.
     */
    private static List<String> parseTableCodes(String json) {
        List<String> parsed = GSON.fromJson(json, new TypeToken<List<String>>() {}.getType());
        List<String> tables = new ArrayList<>();
        if (parsed != null) {
            for (String code : parsed) {
                if (StringUtils.isNotBlank(code)) {
                    tables.add(code);
                }
            }
        }
        return tables;
    }
}

