/*
 * Decompiled with CFR 0.152.
 */
package com.digiwin.athenai.nl2sql.service;

import com.digiwin.athenai.component.AiEmbeddingService;
import com.digiwin.athenai.domain.AiSearchRequest;
import com.digiwin.athenai.nl2sql.connector.config.DbConfig;
import com.digiwin.athenai.nl2sql.entity.SchemaInfo;
import com.digiwin.athenai.nl2sql.service.SchemaInfoService;
import com.digiwin.athenai.utils.PromptUtil;
import java.util.List;
import org.springframework.ai.document.Document;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

/**
 * Spring component that gathers the inputs used by the NL2SQL flow for a given agent:
 * the database configuration, "business" evidence documents and "table" documents.
 *
 * <p>Lookups are delegated to {@link SchemaInfoService} and {@link AiEmbeddingService}.
 */
@Component
public class SqlAsistantService {
    /** Dialect applied when the stored schema info carries no dialect type. */
    private static final String DEFAULT_DIALECT = "mysql";
    /** Filter expression used when searching for business evidence documents. */
    private static final String BUSINESS_FILTER = "type == 'business'";
    /** Filter expression used when searching for table documents. */
    private static final String TABLE_FILTER = "type == 'table'";

    @Autowired
    SchemaInfoService schemaInfoService;
    @Autowired
    AiEmbeddingService aiEmbeddingService;

    /**
     * Looks up the schema info registered for the agent and returns it as the database configuration.
     * When the stored entry has no dialect type, it is set to {@code "mysql"} on the returned object.
     *
     * @param agent agent identifier passed to {@link SchemaInfoService#getSchemaInfoByAgent}
     * @return the schema info for the agent, or {@code null} if the lookup returned {@code null}
     */
    public DbConfig getDbConfig(String agent) {
        SchemaInfo schemaInfo = this.schemaInfoService.getSchemaInfoByAgent(agent);
        if (schemaInfo != null && schemaInfo.getDialectType() == null) {
            schemaInfo.setDialectType(DEFAULT_DIALECT);
        }
        return schemaInfo;
    }

    /**
     * Searches the agent's documents for business evidence matching the query.
     *
     * @param agent agent identifier set on the search request
     * @param query text set as the search query
     * @return whatever {@code expandSearchDocuments} returns for the {@code type == 'business'} filter
     */
    public List<Document> evidences(String agent, String query) {
        AiSearchRequest request = this.buildRequest(agent, query, BUSINESS_FILTER);
        return this.aiEmbeddingService.expandSearchDocuments(request);
    }

    /**
     * Searches the agent's documents for table entries relevant to the query.
     *
     * <p>With no evidences, the raw query is sent through {@code expandSearchDocuments}. Otherwise the
     * formatted evidences are prepended to the query (separated by a newline) and the combined text
     * is sent through {@code searchDocuments}.
     *
     * @param agent     agent identifier set on the search request
     * @param query     user query text
     * @param evidences previously retrieved evidence documents; may be {@code null} or empty
     * @return the documents returned by the embedding service for the {@code type == 'table'} filter
     */
    public List<Document> tables(String agent, String query, List<Document> evidences) {
        if (CollectionUtils.isEmpty(evidences)) {
            AiSearchRequest request = this.buildRequest(agent, query, TABLE_FILTER);
            return this.aiEmbeddingService.expandSearchDocuments(request);
        }
        // Evidences are only formatted when they are actually used, so a null/empty list
        // never reaches PromptUtil.formatDocuments.
        String content = PromptUtil.formatDocuments(evidences) + "\n" + query;
        AiSearchRequest request = this.buildRequest(agent, content, TABLE_FILTER);
        return this.aiEmbeddingService.searchDocuments(request);
    }

    /** Creates a search request carrying the given query, agent and filter expression. */
    private AiSearchRequest buildRequest(String agent, String query, String filterExpression) {
        AiSearchRequest request = new AiSearchRequest();
        request.setQuery(query);
        request.setAgent(agent);
        request.setFilterExpression(filterExpression);
        return request;
    }
}

