/*
 * Decompiled with CFR 0.152.
 */
package com.digiwin.athenai.nl2sql.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.digiwin.athenai.constant.StreamResponseType;
import com.digiwin.athenai.dashscope.util.StateUtils;
import com.digiwin.athenai.dashscope.util.StreamingChatGeneratorUtil;
import com.digiwin.athenai.model.ExecutionStep;
import com.digiwin.athenai.model.Plan;
import com.digiwin.athenai.nl2sql.constant.SqlPrompts;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.OptionalInt;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.converter.BeanOutputConverter;
import org.springframework.core.ParameterizedTypeReference;
import reactor.core.publisher.Flux;

public class ReportGeneratorNode
implements NodeAction {
    private static final Logger logger = LoggerFactory.getLogger(ReportGeneratorNode.class);
    private final ChatClient chatClient;
    private final BeanOutputConverter<Plan> converter;

    public ReportGeneratorNode(ChatClient.Builder chatClientBuilder) {
        this.chatClient = chatClientBuilder.build();
        this.converter = new BeanOutputConverter((ParameterizedTypeReference)new ParameterizedTypeReference<Plan>(this){});
    }

    public Map<String, Object> apply(OverAllState state) throws Exception {
        logger.info("Entering {} node", (Object)this.getClass().getSimpleName());
        String plannerNodeOutput = StateUtils.getStringValue((OverAllState)state, (String)"PLANNER_NODE_OUTPUT");
        String userInput = StateUtils.getStringValue((OverAllState)state, (String)"query");
        Integer currentStep = (Integer)StateUtils.getObjectValue((OverAllState)state, (String)"PLAN_CURRENT_STEP", Integer.class, (Object)1);
        HashMap executionResults = (HashMap)StateUtils.getObjectValue((OverAllState)state, (String)"SQL_EXECUTE_NODE_OUTPUT", HashMap.class, new HashMap());
        logger.info("Planner node output: {}", (Object)plannerNodeOutput);
        Plan plan = (Plan)this.converter.convert(plannerNodeOutput);
        ExecutionStep executionStep = this.getCurrentExecutionStep(plan, currentStep);
        String summaryAndRecommendations = executionStep.getToolParameters().getSummaryAndRecommendations();
        Flux<ChatResponse> reportGenerationFlux = this.generateReport(userInput, plan, executionResults, summaryAndRecommendations);
        AsyncGenerator generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), (OverAllState)state, (String)"\u5f00\u59cb\u751f\u6210\u62a5\u544a...", (String)"\u62a5\u544a\u751f\u6210\u5b8c\u6210\uff01", reportContent -> {
            logger.info("Generated report content: {}", reportContent);
            HashMap<String, String> result = new HashMap<String, String>();
            result.put("result", (String)reportContent);
            result.put("SQL_EXECUTE_NODE_OUTPUT", null);
            result.put("PLAN_CURRENT_STEP", null);
            result.put("PLANNER_NODE_OUTPUT", null);
            return result;
        }, reportGenerationFlux, (StreamResponseType)StreamResponseType.OUTPUT_REPORT);
        return Map.of("result", generator);
    }

    private ExecutionStep getCurrentExecutionStep(Plan plan, Integer currentStep) {
        List executionPlan = plan.getExecutionPlan();
        if (executionPlan == null || executionPlan.isEmpty()) {
            throw new IllegalStateException("Execution plan is empty");
        }
        int stepIndex = currentStep - 1;
        if (stepIndex < 0 || stepIndex >= executionPlan.size()) {
            throw new IllegalStateException("Current step index out of range: " + stepIndex);
        }
        return (ExecutionStep)executionPlan.get(stepIndex);
    }

    private Flux<ChatResponse> generateReport(String userInput, Plan plan, HashMap<String, String> executionResults, String summaryAndRecommendations) {
        String userRequirementsAndPlan = this.buildUserRequirementsAndPlan(userInput, plan);
        String analysisStepsAndData = this.buildAnalysisStepsAndData(plan, executionResults);
        String reportPrompt = ReportGeneratorNode.buildReportGeneratorPromptWithOptimization(userRequirementsAndPlan, analysisStepsAndData, summaryAndRecommendations);
        return this.chatClient.prompt().user(reportPrompt).stream().chatResponse();
    }

    private String buildUserRequirementsAndPlan(String userInput, Plan plan) {
        StringBuilder sb = new StringBuilder();
        sb.append("## \u7528\u6237\u539f\u59cb\u9700\u6c42\n");
        sb.append(userInput).append("\n\n");
        sb.append("## \u6267\u884c\u8ba1\u5212\u6982\u8ff0\n");
        sb.append("**\u601d\u8003\u8fc7\u7a0b**: ").append(plan.getThoughtProcess()).append("\n\n");
        sb.append("## \u8be6\u7ec6\u6267\u884c\u6b65\u9aa4\n");
        List executionPlan = plan.getExecutionPlan();
        for (int i = 0; i < executionPlan.size(); ++i) {
            ExecutionStep step = (ExecutionStep)executionPlan.get(i);
            sb.append("### \u6b65\u9aa4 ").append(i + 1).append(": \u6b65\u9aa4\u7f16\u53f7 ").append(step.getStep()).append("\n");
            sb.append("**\u5de5\u5177**: ").append(step.getToolToUse()).append("\n");
            if (step.getToolParameters() != null) {
                sb.append("**\u53c2\u6570\u63cf\u8ff0**: ").append(step.getToolParameters().getDescription()).append("\n");
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    private String buildAnalysisStepsAndData(Plan plan, HashMap<String, String> executionResults) {
        StringBuilder sb = new StringBuilder();
        sb.append("## \u6570\u636e\u6267\u884c\u7ed3\u679c\n");
        if (executionResults.isEmpty()) {
            sb.append("\u6682\u65e0\u6267\u884c\u7ed3\u679c\u6570\u636e\n");
        } else {
            List executionPlan = plan.getExecutionPlan();
            for (Map.Entry<String, String> entry : executionResults.entrySet()) {
                String stepKey = entry.getKey();
                String stepResult = entry.getValue();
                sb.append("### ").append(stepKey).append("\n");
                try {
                    int stepIndex = Integer.parseInt(stepKey.replace("step_", "")) - 1;
                    if (stepIndex >= 0 && stepIndex < executionPlan.size()) {
                        ExecutionStep step = (ExecutionStep)executionPlan.get(stepIndex);
                        sb.append("**\u6b65\u9aa4\u7f16\u53f7**: ").append(step.getStep()).append("\n");
                        sb.append("**\u4f7f\u7528\u5de5\u5177**: ").append(step.getToolToUse()).append("\n");
                        if (step.getToolParameters() != null) {
                            sb.append("**\u53c2\u6570\u63cf\u8ff0**: ").append(step.getToolParameters().getDescription()).append("\n");
                            if (step.getToolParameters().getSqlQuery() != null) {
                                sb.append("**\u6267\u884cSQL**: \n```sql\n").append(step.getToolParameters().getSqlQuery()).append("\n```\n");
                            }
                        }
                    }
                }
                catch (NumberFormatException numberFormatException) {
                    // empty catch block
                }
                sb.append("**\u6267\u884c\u7ed3\u679c**: \n```json\n").append(stepResult).append("\n```\n\n");
            }
        }
        return sb.toString();
    }

    public static String buildReportGeneratorPromptWithOptimization(String userRequirementsAndPlan, String analysisStepsAndData, String summaryAndRecommendations) {
        HashMap<String, String> params = new HashMap<String, String>();
        params.put("user_requirements_and_plan", userRequirementsAndPlan);
        params.put("analysis_steps_and_data", analysisStepsAndData);
        params.put("summary_and_recommendations", summaryAndRecommendations);
        return SqlPrompts.getReportGeneratorPromptTemplate().render(params);
    }
}

