/*
 * Decompiled with CFR 0.152.
 */
package com.digiwin.athenai.nl2sql.service;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphRunnerException;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.digiwin.athenai.nl2sql.service.SqlExecutorService;
import com.digiwin.athenai.nl2sql.service.SqlGeneratorService;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;

/**
 * Spring service fronting the NL2SQL workflow.
 *
 * <p>It offers two styles of entry point:
 * <ul>
 *   <li>direct delegation to {@link SqlGeneratorService} ({@link #nl2sql}) and
 *       {@link SqlExecutorService} ({@link #executeSql});</li>
 *   <li>execution of the {@code nl2sqlGraph} {@link StateGraph}, compiled once at
 *       construction time ({@link #invoke}).</li>
 * </ul>
 */
@Service
public class Nl2SqlService {
    private static final Logger LOG = Logger.getLogger(Nl2SqlService.class.getName());

    /** Upper bound on iterations applied to the compiled graph. */
    private static final int MAX_GRAPH_ITERATIONS = 100;

    /** Graph compiled from the {@code nl2sqlGraph} bean; built once in the constructor. */
    private final CompiledGraph nl2sqlGraph;

    // NOTE(review): field injection on package-private fields is kept for compatibility with
    // existing callers/tests; consider migrating to constructor injection with private final fields.
    @Autowired
    SqlExecutorService sqlExecutorService;
    @Autowired
    SqlGeneratorService sqlGeneratorService;

    /**
     * Compiles the injected state graph and caps its iteration count.
     *
     * @param stateGraph the graph bean qualified as {@code nl2sqlGraph}
     * @throws GraphStateException propagated from {@link StateGraph#compile()}
     */
    public Nl2SqlService(@Qualifier(value="nl2sqlGraph") StateGraph stateGraph) throws GraphStateException {
        this.nl2sqlGraph = stateGraph.compile();
        this.nl2sqlGraph.setMaxIterations(MAX_GRAPH_ITERATIONS);
    }

    /**
     * Delegates to {@link SqlGeneratorService#generateSql(String, String)}.
     *
     * @param agent agent identifier forwarded unchanged to the generator
     * @param query query text forwarded unchanged (presumably natural language — confirm)
     * @return whatever the generator returns (presumably the generated SQL text)
     * @throws Exception propagated from the generator
     */
    public String nl2sql(String agent, String query) throws Exception {
        return this.sqlGeneratorService.generateSql(agent, query);
    }

    /**
     * Delegates to {@link SqlExecutorService#executeSql(String, String)}.
     *
     * @param agent agent identifier forwarded unchanged to the executor
     * @param sql   SQL text forwarded unchanged to the executor
     * @return whatever the executor returns as a string
     * @throws Exception propagated from the executor
     */
    public String executeSql(String agent, String sql) throws Exception {
        return this.sqlExecutorService.executeSql(agent, sql);
    }

    /**
     * Runs the compiled {@code nl2sqlGraph} with the given inputs and returns the
     * data of the resulting {@link OverAllState}.
     *
     * @param params graph input map, passed unchanged to {@link CompiledGraph#invoke}
     * @return the final state's data map
     * @throws GraphRunnerException   propagated from {@link CompiledGraph#invoke}
     * @throws NoSuchElementException if the graph completes without producing a final state
     */
    public Map<String, Object> invoke(Map<String, Object> params) throws GraphRunnerException {
        // Lazy supplier: the message (which includes the full input map) is only built when
        // FINE is enabled, instead of being printed to stdout on every request.
        LOG.log(Level.FINE, () -> Thread.currentThread().getName() + " running graph " + params);
        Optional<OverAllState> result = this.nl2sqlGraph.invoke(params);
        OverAllState finalState = result.orElseThrow(
                () -> new NoSuchElementException("nl2sqlGraph completed without producing a final state"));
        return finalState.data();
    }

    /**
     * Returns the graph compiled at construction time.
     *
     * @return the compiled {@code nl2sqlGraph}
     */
    public CompiledGraph getNl2sqlGraph() {
        return this.nl2sqlGraph;
    }
}

