/*
 * Decompiled with CFR 0.152.
 */
package com.digiwin.athenai.nl2sql.config;

import com.alibaba.cloud.ai.graph.GraphRepresentation;
import com.alibaba.cloud.ai.graph.KeyStrategyFactory;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.AsyncEdgeAction;
import com.alibaba.cloud.ai.graph.action.AsyncNodeAction;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import com.digiwin.athenai.component.AiChatService;
import com.digiwin.athenai.nl2sql.code.CodeExecutorProperties;
import com.digiwin.athenai.nl2sql.code.CodePoolExecutorService;
import com.digiwin.athenai.nl2sql.dispatcher.PlanExecutorDispatcher;
import com.digiwin.athenai.nl2sql.dispatcher.PythonExecutorDispatcher;
import com.digiwin.athenai.nl2sql.dispatcher.SQLExecutorDispatcher;
import com.digiwin.athenai.nl2sql.node.DocumentRetrieveNode;
import com.digiwin.athenai.nl2sql.node.PlanExecutorNode;
import com.digiwin.athenai.nl2sql.node.PlannerNode;
import com.digiwin.athenai.nl2sql.node.PrintNode;
import com.digiwin.athenai.nl2sql.node.PythonAnalyzeNode;
import com.digiwin.athenai.nl2sql.node.PythonExecuteNode;
import com.digiwin.athenai.nl2sql.node.PythonGenerateNode;
import com.digiwin.athenai.nl2sql.node.ReportGeneratorNode;
import com.digiwin.athenai.nl2sql.node.SqlExecuteNode;
import com.digiwin.athenai.nl2sql.service.SqlAsistantService;
import com.digiwin.athenai.nl2sql.service.SqlExecutorService;
import java.util.HashMap;
import java.util.Map;
import lombok.Generated;
import org.mybatis.spring.annotation.MapperScan;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
@MapperScan(value={"com.digiwin.athenai.nl2sql.mapper"})
public class Nl2SqlConfiguration {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(Nl2SqlConfiguration.class);
    @Autowired
    private SqlAsistantService sqlAsistantService;
    @Autowired
    AiChatService aiChatService;
    @Autowired
    SqlExecutorService sqlExecutorService;
    @Autowired
    private CodeExecutorProperties codeExecutorProperties;
    @Autowired
    private CodePoolExecutorService codePoolExecutor;
    @Autowired
    ChatClient.Builder chatClientBuilder;

    @Bean(value={"nl2sqlGraph"})
    public StateGraph nl2sqlGraph() throws GraphStateException {
        KeyStrategyFactory keyStrategyFactory = () -> {
            HashMap<String, ReplaceStrategy> keyStrategyHashMap = new HashMap<String, ReplaceStrategy>();
            keyStrategyHashMap.put("agent", new ReplaceStrategy());
            keyStrategyHashMap.put("query", new ReplaceStrategy());
            keyStrategyHashMap.put("evidences", new ReplaceStrategy());
            keyStrategyHashMap.put("tables", new ReplaceStrategy());
            keyStrategyHashMap.put("PLANNER_NODE_OUTPUT", new ReplaceStrategy());
            keyStrategyHashMap.put("PLAN_CURRENT_STEP", new ReplaceStrategy());
            keyStrategyHashMap.put("PLAN_NEXT_NODE", new ReplaceStrategy());
            keyStrategyHashMap.put("PLAN_VALIDATION_STATUS", new ReplaceStrategy());
            keyStrategyHashMap.put("PLAN_VALIDATION_ERROR", new ReplaceStrategy());
            keyStrategyHashMap.put("PLAN_REPAIR_COUNT", new ReplaceStrategy());
            keyStrategyHashMap.put("SQL_EXECUTE_NODE_OUTPUT", new ReplaceStrategy());
            keyStrategyHashMap.put("SQL_EXECUTE_NODE_EXCEPTION_OUTPUT", new ReplaceStrategy());
            keyStrategyHashMap.put("SQL_RESULT_LIST_MEMORY", new ReplaceStrategy());
            keyStrategyHashMap.put("PYTHON_IS_SUCCESS", new ReplaceStrategy());
            keyStrategyHashMap.put("PYTHON_TRIES_COUNT", new ReplaceStrategy());
            keyStrategyHashMap.put("PYTHON_EXECUTE_NODE_OUTPUT", new ReplaceStrategy());
            keyStrategyHashMap.put("PYTHON_GENERATE_NODE_OUTPUT", new ReplaceStrategy());
            keyStrategyHashMap.put("PYTHON_ANALYSIS_NODE_OUTPUT", new ReplaceStrategy());
            keyStrategyHashMap.put("REPORT_GENERATOR_NODE", new ReplaceStrategy());
            keyStrategyHashMap.put("result", new ReplaceStrategy());
            return keyStrategyHashMap;
        };
        StateGraph stateGraph = new StateGraph("nl2sqlGraph", keyStrategyFactory).addNode("DocumentRetrieveNode", AsyncNodeAction.node_async((NodeAction)new DocumentRetrieveNode(this.sqlAsistantService))).addNode("PrintNode", AsyncNodeAction.node_async((NodeAction)new PrintNode())).addNode("PLANNER_NODE", AsyncNodeAction.node_async((NodeAction)new PlannerNode(this.aiChatService))).addNode("PLAN_EXECUTOR_NODE", AsyncNodeAction.node_async((NodeAction)new PlanExecutorNode())).addNode("SQL_EXECUTE_NODE", AsyncNodeAction.node_async((NodeAction)new SqlExecuteNode(this.sqlExecutorService, this.sqlAsistantService, this.aiChatService))).addNode("REPORT_GENERATOR_NODE", AsyncNodeAction.node_async((NodeAction)new ReportGeneratorNode(this.chatClientBuilder))).addNode("PYTHON_GENERATE_NODE", AsyncNodeAction.node_async((NodeAction)new PythonGenerateNode(this.codeExecutorProperties, this.chatClientBuilder))).addNode("PYTHON_EXECUTE_NODE", AsyncNodeAction.node_async((NodeAction)new PythonExecuteNode(this.codePoolExecutor))).addNode("PYTHON_ANALYZE_NODE", AsyncNodeAction.node_async((NodeAction)new PythonAnalyzeNode(this.chatClientBuilder)));
        stateGraph.addEdge("__START__", "DocumentRetrieveNode").addEdge("DocumentRetrieveNode", "PLANNER_NODE").addEdge("PLANNER_NODE", "PLAN_EXECUTOR_NODE").addConditionalEdges("PLAN_EXECUTOR_NODE", AsyncEdgeAction.edge_async((EdgeAction)new PlanExecutorDispatcher()), Map.of("PLANNER_NODE", "PLANNER_NODE", "SQL_EXECUTE_NODE", "SQL_EXECUTE_NODE", "PYTHON_GENERATE_NODE", "PYTHON_GENERATE_NODE", "REPORT_GENERATOR_NODE", "REPORT_GENERATOR_NODE", "__END__", "__END__")).addConditionalEdges("SQL_EXECUTE_NODE", AsyncEdgeAction.edge_async((EdgeAction)new SQLExecutorDispatcher()), Map.of("PLANNER_NODE", "PLANNER_NODE", "PLAN_EXECUTOR_NODE", "PLAN_EXECUTOR_NODE")).addEdge("PYTHON_GENERATE_NODE", "PYTHON_EXECUTE_NODE").addConditionalEdges("PYTHON_EXECUTE_NODE", AsyncEdgeAction.edge_async((EdgeAction)new PythonExecutorDispatcher()), Map.of("PYTHON_ANALYZE_NODE", "PYTHON_ANALYZE_NODE", "__END__", "__END__", "PYTHON_GENERATE_NODE", "PYTHON_GENERATE_NODE")).addEdge("PYTHON_ANALYZE_NODE", "PLAN_EXECUTOR_NODE").addEdge("REPORT_GENERATOR_NODE", "__END__");
        GraphRepresentation graphRepresentation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML, "workflow graph");
        log.info("\n\n");
        log.info(graphRepresentation.content());
        log.info("\n\n");
        return stateGraph;
    }
}

