/*
 * Decompiled with CFR 0.152.
 */
package com.digiwin.athenai.nl2sql.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.digiwin.athenai.constant.StreamResponseType;
import com.digiwin.athenai.dashscope.util.StateUtils;
import com.digiwin.athenai.dashscope.util.StreamingChatGeneratorUtil;
import com.digiwin.athenai.model.ExecutionStep;
import com.digiwin.athenai.nl2sql.code.CodeExecutorProperties;
import com.digiwin.athenai.nl2sql.constant.SqlPrompts;
import com.digiwin.athenai.nl2sql.node.AbstractPlanBasedNode;
import com.digiwin.athenai.utils.MarkdownParser;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import reactor.core.publisher.Flux;

public class PythonGenerateNode
extends AbstractPlanBasedNode
implements NodeAction {
    private static final Logger log = LoggerFactory.getLogger(PythonGenerateNode.class);
    private static final int SAMPLE_DATA_NUMBER = 5;
    private static final int MAX_TRIES_COUNT = 5;
    private final ObjectMapper objectMapper;
    private final CodeExecutorProperties codeExecutorProperties;
    private final ChatClient chatClient;

    public PythonGenerateNode(CodeExecutorProperties codeExecutorProperties, ChatClient.Builder chatClientBuilder) {
        this.codeExecutorProperties = codeExecutorProperties;
        this.chatClient = chatClientBuilder.build();
        this.objectMapper = new ObjectMapper().setSerializationInclusion(JsonInclude.Include.NON_NULL);
    }

    public Map<String, Object> apply(OverAllState state) throws Exception {
        this.logNodeEntry();
        String tablesStr = (String)state.value("tables", (Object)"");
        List sqlResults = StateUtils.getListValue((OverAllState)state, (String)"SQL_RESULT_LIST_MEMORY");
        boolean codeRunSuccess = (Boolean)StateUtils.getObjectValue((OverAllState)state, (String)"PYTHON_IS_SUCCESS", Boolean.class, (Object)true);
        int triesCount = (Integer)StateUtils.getObjectValue((OverAllState)state, (String)"PYTHON_TRIES_COUNT", Integer.class, (Object)5);
        Object userPrompt = StateUtils.getStringValue((OverAllState)state, (String)"query");
        if (!codeRunSuccess) {
            String lastCode = StateUtils.getStringValue((OverAllState)state, (String)"PYTHON_GENERATE_NODE_OUTPUT");
            String lastError = StateUtils.getStringValue((OverAllState)state, (String)"PYTHON_EXECUTE_NODE_OUTPUT");
            userPrompt = (String)userPrompt + String.format("\u4e0a\u6b21\u5c1d\u8bd5\u751f\u6210\u7684Python\u4ee3\u7801\u8fd0\u884c\u5931\u8d25\uff0c\u8bf7\u4f60\u91cd\u65b0\u751f\u6210\u7b26\u5408\u8981\u6c42\u7684Python\u4ee3\u7801\u3002\n\u3010\u4e0a\u6b21\u751f\u6210\u4ee3\u7801\u3011\n```python\n%s\n```\n\u3010\u8fd0\u884c\u9519\u8bef\u4fe1\u606f\u3011\n```\n%s\n```\n", lastCode, lastError);
        }
        ExecutionStep executionStep = this.getCurrentExecutionStep(state);
        ExecutionStep.ToolParameters toolParameters = executionStep.getToolParameters();
        String systemPrompt = SqlPrompts.getPythonGeneratorPromptTemplate().render(Map.of("python_memory", this.codeExecutorProperties.getLimitMemory().toString(), "python_timeout", this.codeExecutorProperties.getCodeTimeout(), "database_schema", tablesStr, "sample_input", this.objectMapper.writeValueAsString(sqlResults.stream().limit(5L).toList()), "plan_description", this.objectMapper.writeValueAsString((Object)toolParameters)));
        Flux pythonGenerateFlux = this.chatClient.prompt().system(systemPrompt).user((String)userPrompt).stream().chatResponse();
        AsyncGenerator generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), (OverAllState)state, (String)"", (String)"", aiResponse -> {
            aiResponse = MarkdownParser.extractRawText((String)aiResponse);
            log.info("Python Generate Code: {}", aiResponse);
            return Map.of("PYTHON_GENERATE_NODE_OUTPUT", aiResponse, "PYTHON_TRIES_COUNT", triesCount - 1);
        }, (Flux)pythonGenerateFlux, (StreamResponseType)StreamResponseType.PYTHON_GENERATE);
        return Map.of("PYTHON_GENERATE_NODE_OUTPUT", generator);
    }
}

