/*
 * Decompiled with CFR 0.152.
 */
package com.digiwin.athenai.nl2sql.node;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import com.alibaba.cloud.ai.graph.async.AsyncGenerator;
import com.digiwin.athenai.constant.StreamResponseType;
import com.digiwin.athenai.dashscope.util.StreamingChatGeneratorUtil;
import com.digiwin.athenai.nl2sql.service.SqlAsistantService;
import com.digiwin.athenai.nl2sql.utils.NodeUtil;
import com.digiwin.athenai.utils.ChatResponseUtil;
import com.digiwin.athenai.utils.PromptUtil;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import reactor.core.publisher.Flux;

public class DocumentRetrieveNode
implements NodeAction {
    private static final Logger logger = LoggerFactory.getLogger(DocumentRetrieveNode.class);
    private SqlAsistantService sqlAsistantService;

    public DocumentRetrieveNode(SqlAsistantService sqlAsistantService) {
        this.sqlAsistantService = sqlAsistantService;
    }

    public Map<String, Object> apply(OverAllState state) throws Exception {
        String agent = NodeUtil.agentFromState(state);
        String query = NodeUtil.queryFromState(state);
        List<Document> evidences = this.sqlAsistantService.evidences(agent, query);
        logger.info("retrieve evidences count : {}", (Object)evidences.size());
        List<Document> tables = this.sqlAsistantService.tables(agent, query, evidences);
        logger.info("retrieve tables count : {}", (Object)tables.size());
        String evidencesStr = PromptUtil.formatDocuments(evidences);
        String tablesStr = PromptUtil.formatDocuments(tables);
        Map<String, String> result = Map.of("evidences", evidencesStr, "tables", tablesStr);
        Flux errorDisplayFlux = Flux.create(emitter -> {
            emitter.next((Object)ChatResponseUtil.createStatusResponse((String)"\u5f00\u59cb\u6536\u96c6\u4e1a\u52a1\u76f8\u5173\u6570\u636e..."));
            emitter.next((Object)ChatResponseUtil.createStatusResponse((String)("\u7ed3\u675f\u6536\u96c6\u4e1a\u52a1\u76f8\u5173\u6570\u636e" + evidences.size() + "\u6761")));
            emitter.next((Object)ChatResponseUtil.createStatusResponse((String)"\u5f00\u59cb\u6536\u96c6\u8868\u7ed3\u6784\u76f8\u5173\u6570\u636e... "));
            emitter.next((Object)ChatResponseUtil.createStatusResponse((String)("\u7ed3\u675f\u83b7\u53d6\u8868\u7ed3\u6784\u76f8\u5173\u6570\u636e" + tables.size() + "\u6761")));
            emitter.complete();
        });
        AsyncGenerator generator = StreamingChatGeneratorUtil.createStreamingGeneratorWithMessages(this.getClass(), (OverAllState)state, v -> result, (Flux)errorDisplayFlux, (StreamResponseType)StreamResponseType.REWRITE);
        return Map.of("QUERY_REWRITE_NODE_OUTPUT", generator);
    }
}

