/*
 * Decompiled with CFR 0.152.
 */
package com.digiwin.athenai.nl2sql.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.digiwin.athenai.component.AiChatService;
import com.digiwin.athenai.constant.StreamResponseType;
import com.digiwin.athenai.dashscope.util.StateUtils;
import com.digiwin.athenai.dashscope.util.StreamingChatGeneratorUtil;
import com.digiwin.athenai.model.ExecutionStep;
import com.digiwin.athenai.nl2sql.connector.bo.ResultSetBO;
import com.digiwin.athenai.nl2sql.connector.config.DbConfig;
import com.digiwin.athenai.nl2sql.constant.SqlPrompts;
import com.digiwin.athenai.nl2sql.node.AbstractPlanBasedNode;
import com.digiwin.athenai.nl2sql.service.SqlAsistantService;
import com.digiwin.athenai.nl2sql.service.SqlExecutorService;
import com.digiwin.athenai.nl2sql.utils.NodeUtil;
import com.digiwin.athenai.utils.ChatResponseUtil;
import com.digiwin.athenai.utils.MarkdownParser;
import com.digiwin.athenai.utils.StepResultUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;

/**
 * Plan-step node that rewrites the current step's SQL with the LLM and then executes it.
 *
 * <p>Each attempt renders the {@code SqlPrompts.sql_generator_rewrite()} prompt with the database
 * dialect, the user question, the schema text (state key {@code "tables"}) and the step's original
 * SQL. It strips Markdown from the model's answer and runs the result through
 * {@link SqlExecutorService}. When execution throws, the error message is appended to the next
 * rewrite prompt. At most {@code sqlRetryCount + 1} attempts are made.
 */
public class SqlExecuteNode
extends AbstractPlanBasedNode {
    private static final Logger logger = LoggerFactory.getLogger(SqlExecuteNode.class);
    /** Executes SQL on behalf of an agent. */
    SqlExecutorService sqlExecutorService;
    /** Supplies the agent's {@link DbConfig} (its dialect is passed to the rewrite prompt). */
    SqlAsistantService sqlAsistantService;
    /** LLM client used to rewrite the SQL before every execution attempt. */
    AiChatService aiChatService;
    /** Extra rewrite/execute attempts after the first failure; a negative value makes no attempt. */
    int sqlRetryCount = 3;

    public SqlExecuteNode(SqlExecutorService sqlExecutorService, SqlAsistantService sqlAsistantService, AiChatService aiChatService) {
        this.sqlExecutorService = sqlExecutorService;
        this.sqlAsistantService = sqlAsistantService;
        this.aiChatService = aiChatService;
    }

    /**
     * Rewrites and executes the SQL of the current plan step.
     *
     * @param state graph state; the agent, the user query, {@code "tables"} and the current
     *              execution step are read from it
     * @return on success, {@code SQL_EXECUTE_NODE_OUTPUT} mapped to a streaming generator whose
     *         result carries the state updates; after every attempt has failed,
     *         {@code SQL_EXECUTE_NODE_EXCEPTION_OUTPUT} mapped to a streaming generator; an empty
     *         map when {@code sqlRetryCount} is negative
     * @throws IllegalArgumentException if the agent has no {@link DbConfig} or the current step or
     *         its tool parameters are missing
     * @throws Exception propagated from the state/plan helpers
     */
    public Map<String, Object> apply(OverAllState state) throws Exception {
        this.logNodeEntry();
        String agent = NodeUtil.agentFromState(state);
        String query = NodeUtil.queryFromState(state);
        String tablesStr = (String)state.value("tables", (Object)"");
        DbConfig dbConfig = this.sqlAsistantService.getDbConfig(agent);
        Assert.notNull(dbConfig, "dbConfig is null");
        ExecutionStep executionStep = this.getCurrentExecutionStep(state);
        Assert.notNull(executionStep, "executionStep is null");
        Integer currentStep = this.getCurrentStepNumber(state);
        ExecutionStep.ToolParameters toolParameters = executionStep.getToolParameters();
        Assert.notNull(toolParameters, "toolParameters is null");
        String sqlQuery = toolParameters.getSqlQuery();
        logger.info("Executing SQL query: {}", sqlQuery);
        // Two placeholders: the original message had one, so the description was never logged.
        logger.info("Step {} description: {}", currentStep, toolParameters.getDescription());
        return this.executeSqlQuery(dbConfig, state, currentStep, sqlQuery, tablesStr, query, agent);
    }

    /**
     * Runs the rewrite/execute cycle until the query succeeds or {@code sqlRetryCount + 1}
     * attempts have failed. Only failures of the query execution itself trigger a retry.
     */
    private Map<String, Object> executeSqlQuery(DbConfig dbConfig, OverAllState state, Integer currentStep, String sqlQuery, String schema, String userQuery, String agent) {
        String error = null;
        for (int remaining = this.sqlRetryCount; remaining >= 0; remaining--) {
            String newSql = this.rewriteSql(dbConfig, sqlQuery, schema, userQuery, error);
            ResultSetBO resultSetBO;
            try {
                resultSetBO = this.sqlExecutorService.executeSqlAndReturnObject(agent, newSql);
            }
            catch (Exception e) {
                error = errorMessageOf(e);
                logger.error("SQL execution failed - SQL: [{}] ", newSql, e);
                continue;
            }
            try {
                return this.buildSuccessOutput(state, currentStep, newSql, resultSetBO);
            }
            catch (Exception e) {
                // The query succeeded; rewriting and re-running it cannot fix a failure while
                // processing its result, so report it instead of retrying.
                logger.error("Failed to process SQL result - SQL: [{}] ", newSql, e);
                return this.buildErrorOutput(state, errorMessageOf(e));
            }
        }
        logger.warn("Sql retry count exceeds limit {}", this.sqlRetryCount);
        // error is still null only when no attempt was made (negative sqlRetryCount).
        return error == null ? new HashMap<String, Object>() : this.buildErrorOutput(state, error);
    }

    /**
     * Asks the LLM to rewrite {@code originSql} for the configured dialect and returns the
     * Markdown-stripped answer. A non-empty {@code previousError} is appended to the prompt.
     */
    private String rewriteSql(DbConfig dbConfig, String originSql, String schema, String userQuery, String previousError) {
        Map<String, Object> params = new HashMap<String, Object>();
        params.put("dialect", dbConfig.getDialectType());
        params.put("question", userQuery);
        params.put("schema_info", schema);
        params.put("originSql", originSql);
        String prompt = SqlPrompts.sql_generator_rewrite().render(params);
        if (StringUtils.isNotEmpty(previousError)) {
            // NOTE(review): the header reads "original SQL execution error", but the error comes
            // from the previously rewritten SQL, which is not itself included in the prompt.
            prompt = prompt + "\n\u3010\u539f\u59cbsql\u6267\u884c\u62a5\u9519\u3011\n\n" + previousError;
        }
        String rewritten = MarkdownParser.extractRawText(this.aiChatService.simpleCall(prompt));
        logger.info("rewrite sql: {} => {} ", originSql, rewritten);
        return rewritten;
    }

    /**
     * Records the query result for {@code currentStep}, advances {@code PLAN_CURRENT_STEP}, and
     * wraps these state updates in a streaming generator that displays the executed SQL.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Map<String, Object> buildSuccessOutput(OverAllState state, Integer currentStep, String executedSql, ResultSetBO resultSetBO) throws Exception {
        String jsonStr = resultSetBO.toJsonStr();
        Map existingResults = (Map)StateUtils.getObjectValue(state, "SQL_EXECUTE_NODE_OUTPUT", Map.class, new HashMap());
        Map updatedResults = StepResultUtils.addStepResult(existingResults, currentStep, jsonStr);
        logger.info("SQL execution successful, result count: {}", resultSetBO.getData() != null ? resultSetBO.getData().size() : 0);
        // Map.of rejects null values; a query without data must not look like a failure.
        Object rows = resultSetBO.getData() != null ? resultSetBO.getData() : List.of();
        Map<String, Object> result = Map.of(
                "SQL_EXECUTE_NODE_OUTPUT", updatedResults,
                "result", updatedResults,
                "SQL_EXECUTE_NODE_EXCEPTION_OUTPUT", "",
                "SQL_RESULT_LIST_MEMORY", rows,
                "PLAN_CURRENT_STEP", currentStep + 1);
        Flux<Object> displayFlux = statusFlux(
                "\n\u5f00\u59cb\u6267\u884cSQL...",
                "\u6267\u884cSQL\u67e5\u8be2",
                "``` " + executedSql + "```",
                "\u6267\u884cSQL\u5b8c\u6210");
        AsyncGenerator generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state, v -> result, (Flux)displayFlux, StreamResponseType.EXECUTE_SQL);
        return Map.of("SQL_EXECUTE_NODE_OUTPUT", generator);
    }

    /**
     * Wraps {@code errorMessage} in a streaming generator under
     * {@code SQL_EXECUTE_NODE_EXCEPTION_OUTPUT}, also exposing it as {@code result}.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private Map<String, Object> buildErrorOutput(OverAllState state, String errorMessage) {
        Map<String, Object> errorResult = Map.of("SQL_EXECUTE_NODE_EXCEPTION_OUTPUT", errorMessage, "result", errorMessage);
        Flux<Object> errorDisplayFlux = statusFlux(
                "\n\u5f00\u59cb\u6267\u884cSQL...",
                "\u6267\u884cSQL\u67e5\u8be2",
                "SQL\u6267\u884c\u5931\u8d25: " + errorMessage);
        AsyncGenerator generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), state, v -> errorResult, (Flux)errorDisplayFlux, StreamResponseType.EXECUTE_SQL);
        return Map.of("SQL_EXECUTE_NODE_EXCEPTION_OUTPUT", generator);
    }

    /** Builds a Flux that, when subscribed, emits one status response per message and completes. */
    private static Flux<Object> statusFlux(String... messages) {
        return Flux.create(emitter -> {
            for (String message : messages) {
                emitter.next(ChatResponseUtil.createStatusResponse(message));
            }
            emitter.complete();
        });
    }

    /** Returns the exception message, falling back to {@code toString()} when it is null or empty. */
    private static String errorMessageOf(Exception e) {
        return StringUtils.defaultIfEmpty(e.getMessage(), e.toString());
    }
}

