/*
 * Decompiled with CFR 0.152.
 */
package com.digiwin.athenai.nl2sql.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.digiwin.athenai.component.AiChatService;
import com.digiwin.athenai.constant.StreamResponseType;
import com.digiwin.athenai.dashscope.util.StateUtils;
import com.digiwin.athenai.dashscope.util.StreamingChatGeneratorUtil;
import com.digiwin.athenai.nl2sql.constant.PromptConstant;
import com.digiwin.athenai.nl2sql.utils.NodeUtil;
import com.digiwin.athenai.utils.ChatResponseUtil;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;
import reactor.core.publisher.Flux;

public class PlannerNode
implements NodeAction {
    private static final Logger logger = LoggerFactory.getLogger(PlannerNode.class);
    AiChatService aiChatService;

    public PlannerNode(AiChatService aiChatService) {
        this.aiChatService = aiChatService;
    }

    public Map<String, Object> apply(OverAllState state) throws Exception {
        String userPrompt;
        String agent = NodeUtil.agentFromState(state);
        String query = NodeUtil.queryFromState(state);
        String evidencesStr = (String)state.value("evidences", (Object)"");
        String tablesStr = (String)state.value("tables", (Object)"");
        String validationError = (String)state.value("PLAN_VALIDATION_ERROR", (Object)"");
        if (StringUtils.hasText((String)validationError)) {
            logger.warn("This is a plan repair attempt. Previous error: {}", (Object)validationError);
            String previousPlan = StateUtils.getStringValue((OverAllState)state, (String)"PLANNER_NODE_OUTPUT", (String)"");
            userPrompt = String.format("The previous plan you generated failed validation with the following error: %s\n\nHere is the faulty plan:\n%s\n\nPlease correct the plan and provide a new, valid one to answer the original question: %s", validationError, previousPlan, query);
        } else {
            userPrompt = query;
        }
        Map<String, String> params = Map.of("user_question", userPrompt, "schema", tablesStr, "business_knowledge", evidencesStr);
        String plannerPrompt = PromptConstant.sql_planner().render(params);
        Flux chatResponseFlux = this.aiChatService.chatClient().prompt().user(plannerPrompt).stream().chatResponse();
        AsyncGenerator generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), (OverAllState)state, v -> Map.of("PLANNER_NODE_OUTPUT", v, "PLAN_CURRENT_STEP", 1), (Flux)chatResponseFlux, (StreamResponseType)StreamResponseType.PLAN_GENERATION);
        Flux displayFlux = Flux.create(emitter -> {
            emitter.next((Object)ChatResponseUtil.createStatusResponse((String)"\u5f00\u59cb\u6267\u884c\u65b9\u6848......"));
            emitter.complete();
        });
        AsyncGenerator generator2 = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), (OverAllState)state, v -> Map.of(), (Flux)displayFlux, (StreamResponseType)StreamResponseType.REWRITE);
        return Map.of("PLANNER_NODE_OUTPUT", generator, "generator2", generator2);
    }
}

