/*
 * Decompiled with CFR 0.152.
 */
package com.digiwin.athenai.component;

import com.digiwin.athenai.advisor.TokenCounterAdvisor;
import com.digiwin.athenai.component.AiEmbeddingService;
import com.digiwin.athenai.domain.AiRequest;
import com.digiwin.athenai.prop.AiComponentProperties;
import com.digiwin.athenai.utils.AiBeanConverter;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.client.advisor.api.Advisor;
import org.springframework.ai.chat.client.advisor.vectorstore.QuestionAnswerAdvisor;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.converter.ListOutputConverter;
import org.springframework.ai.converter.MapOutputConverter;
import org.springframework.ai.converter.StructuredOutputConverter;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.mcp.SyncMcpToolCallbackProvider;
import org.springframework.ai.rag.advisor.RetrievalAugmentationAdvisor;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.preretrieval.query.expansion.MultiQueryExpander;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.CompressionQueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
import org.springframework.ai.rag.retrieval.join.DocumentJoiner;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;

/**
 * Facade over Spring AI's {@link ChatClient}. It turns an {@link AiRequest} into a configured
 * client plus prompt and offers blocking, structured-output and streaming calls.
 *
 * <p>Requests carrying a non-blank agent name reuse a client cached per agent. All other requests
 * get a freshly built client.
 */
@Component
public class AiChatService
implements InitializingBean, DisposableBean {
    private static final Logger logger = LoggerFactory.getLogger(AiChatService.class);
    /** Advisor-context key under which a request's conversation id is passed (presumably read by the chat-memory advisor). */
    private static final String CONVERSATION_ID_PARAM = "chat_memory_conversation_id";
    /** Message count above which the RAG pipeline additionally applies a {@link CompressionQueryTransformer}. */
    private static final int COMPRESSION_MESSAGE_THRESHOLD = 20;
    /**
     * Clients cached per agent name.
     *
     * <p>NOTE: the first request for an agent fixes that client's system prompt, options, tool
     * names, default tool context and advisors. Later requests for the same agent reuse it
     * unchanged.
     */
    private static final Map<String, ChatClient> agentChatClients = new ConcurrentHashMap<>();
    private final MapOutputConverter mapOutputConverter = new MapOutputConverter();
    private final ListOutputConverter listOutputConverter = new ListOutputConverter();
    @Autowired
    private ChatModel chatModel;
    @Autowired
    private ImageModel imageModel;
    /**
     * Base builder shared by everything this service builds.
     *
     * <p>It is only modified once, in {@link #afterPropertiesSet()}. Per-request configuration
     * is applied to clones.
     */
    @Autowired
    private ChatClient.Builder chatClientBuilder;
    @Autowired
    private ChatClient defaultChatClient;
    @Autowired
    private SyncMcpToolCallbackProvider toolCallbackProvider;
    @Autowired
    private ChatMemoryRepository chatMemoryRepository;
    @Autowired
    AiComponentProperties aiComponentProperties;
    @Autowired
    private AiEmbeddingService aiEmbeddingService;
    @Autowired
    SimpleLoggerAdvisor simpleLoggerAdvisor;
    @Autowired
    MessageChatMemoryAdvisor messageChatMemoryAdvisor;
    // NOTE(review): not used inside this class. It is shut down in destroy() so any threads it
    // starts cannot outlive the bean.
    ExecutorService executor = Executors.newFixedThreadPool(3);

    /**
     * Builds a plain client from the base builder, with no request-specific configuration.
     *
     * @return a new {@link ChatClient}
     */
    public ChatClient chatClient() {
        return this.createChatClient(null);
    }

    /**
     * Builds a new client configured from {@code request}.
     *
     * <p>The request can supply chat options, a system prompt rendered with the request params,
     * tool names, a default tool context, a chat-memory advisor, a vector-store advisor, and
     * (when enabled both globally and on the request) a prompt-logging advisor.
     *
     * @param request the request to configure from; {@code null} yields a plain client
     * @return a new {@link ChatClient}; never cached by this method
     */
    public ChatClient createChatClient(AiRequest request) {
        if (null == request) {
            return this.getChatClientBuilder().build();
        }
        // Configure a copy. The builder setters below change the builder itself, and the injected
        // builder is shared by every client this service builds. Configuring it directly would
        // leak this request's system prompt, options, tools and advisors into all later clients.
        ChatClient.Builder builder = this.chatClientBuilder.clone();
        if (null != request.getChatOptions()) {
            builder.defaultOptions(AiBeanConverter.convertChatOptions(request.getChatOptions()));
        }
        if (StringUtils.isNotEmpty(request.getSystem())) {
            builder.defaultSystem(AiBeanConverter.renderText(request.getSystem(), request.getParams()));
        }
        if (!CollectionUtils.isEmpty(request.getToolNames())) {
            builder.defaultToolNames(request.getToolNames().toArray(new String[0]));
        }
        if (!CollectionUtils.isEmpty(request.getParams())) {
            builder.defaultToolContext(request.getParams());
        }
        if (request.isUseChatMemory()) {
            builder.defaultAdvisors(this.messageChatMemoryAdvisor);
        }
        if (request.isUseVectorStore()) {
            builder.defaultAdvisors(this.vectorStoreAdvisor(request));
        }
        if (this.aiComponentProperties.isLogPromptEnable() && request.isLogPromptEnable()) {
            builder.defaultAdvisors(this.simpleLoggerAdvisor);
        }
        return builder.build();
    }

    /**
     * Builds the retrieval advisor backed by the vector store of the request's agent.
     *
     * <p>When {@code isExpendVectorStoreQuery()} is set, this is a multi-step RAG pipeline: query
     * rewrite, optional compression for long histories, multi-query expansion, retrieval,
     * concatenation and contextual augmentation. Otherwise it is a plain
     * {@link QuestionAnswerAdvisor}.
     */
    private Advisor vectorStoreAdvisor(AiRequest request) {
        VectorStore vectorStore = this.aiEmbeddingService.getOrCreateVectorStore(request.getAgent(), null);
        if (!request.isExpendVectorStoreQuery()) {
            return new QuestionAnswerAdvisor(vectorStore);
        }
        List<QueryTransformer> queryTransformers = new ArrayList<>();
        queryTransformers.add(new RewriteQueryTransformer(this.chatClientBuilder, null, null));
        if (!CollectionUtils.isEmpty(request.getMessages())
                && request.getMessages().size() > COMPRESSION_MESSAGE_THRESHOLD) {
            queryTransformers.add(new CompressionQueryTransformer(this.chatClientBuilder, null));
        }
        return RetrievalAugmentationAdvisor.builder()
                .queryTransformers(queryTransformers)
                .queryExpander(MultiQueryExpander.builder().chatClientBuilder(this.chatClientBuilder).build())
                .documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(vectorStore).build())
                .documentJoiner(new ConcatenationDocumentJoiner())
                .queryAugmenter(ContextualQueryAugmenter.builder().build())
                .build();
    }

    /**
     * Returns the cached client for the request's agent, creating it on first use.
     *
     * <p>Requests that are {@code null} or have a blank agent always get a freshly built client.
     * NOTE: the cache key is the agent name only. Configuration from later requests for the same
     * agent is not applied to the cached client.
     *
     * @param request the request; may be {@code null}
     * @return a cached or newly built {@link ChatClient}
     */
    public ChatClient getOrCreateChatClient(AiRequest request) {
        if (null == request || StringUtils.isBlank(request.getAgent())) {
            return this.createChatClient(request);
        }
        return agentChatClients.computeIfAbsent(request.getAgent(), k -> this.createChatClient(request));
    }

    /**
     * Performs a blocking call and returns the full response.
     *
     * @param request the request; must not be {@code null}
     * @return the model's {@link ChatResponse}
     */
    public ChatResponse call(AiRequest request) {
        return this.chatClientRequestSpec(request).call().chatResponse();
    }

    /**
     * Performs a blocking call and extracts the response text via {@link AiBeanConverter}.
     *
     * @param request the request; must not be {@code null}
     * @return the response content
     */
    public String callContent(AiRequest request) {
        return AiBeanConverter.getContentFromChatResponse(this.call(request));
    }

    /**
     * Performs a blocking call and converts the answer with a {@link MapOutputConverter}.
     *
     * @param request the request; must not be {@code null}
     * @return the converted map
     */
    public Map<String, Object> callContent2Map(AiRequest request) {
        return this.chatClientRequestSpec(request).call().entity(this.mapOutputConverter);
    }

    /**
     * Performs a blocking call and converts the answer with a {@link ListOutputConverter}.
     *
     * @param request the request; must not be {@code null}
     * @return the converted list
     */
    public List<String> callContent2List(AiRequest request) {
        return this.chatClientRequestSpec(request).call().entity(this.listOutputConverter);
    }

    /**
     * Performs a blocking call and maps the answer onto {@code clazz}.
     *
     * @param request the request; must not be {@code null}
     * @param clazz   target type
     * @param <T>     target type
     * @return the converted entity
     */
    public <T> T callContent2Object(AiRequest request, Class<T> clazz) {
        return this.chatClientRequestSpec(request).call().entity(clazz);
    }

    /**
     * Performs a blocking call and maps the answer onto a generic {@code type}.
     *
     * @param request the request; must not be {@code null}
     * @param type    target type reference
     * @param <T>     target type
     * @return the converted entity
     */
    public <T> T callContent2ObjectByType(AiRequest request, ParameterizedTypeReference<T> type) {
        return this.chatClientRequestSpec(request).call().entity(type);
    }

    /**
     * Performs a streaming call.
     *
     * @param request the request; must not be {@code null}
     * @return the stream of partial responses
     */
    public Flux<ChatResponse> callStream(AiRequest request) {
        return this.chatClientRequestSpec(request).stream().chatResponse();
    }

    /**
     * Performs a streaming call and extracts the text of each chunk via {@link AiBeanConverter}.
     *
     * @param request the request; must not be {@code null}
     * @return the stream of content chunks
     */
    public Flux<String> callStreamContent(AiRequest request) {
        return AiBeanConverter.getContentFromChatResponse(this.callStream(request));
    }

    /**
     * Prepares the request spec shared by all call variants.
     *
     * <p>Per-request state is applied here rather than on the (possibly cached) client, so
     * blocking, structured and streaming calls all see it. That state is the conversation id, the
     * tool context and the advisor params.
     */
    private ChatClient.ChatClientRequestSpec chatClientRequestSpec(AiRequest request) {
        Objects.requireNonNull(request, "request");
        ChatClient client = this.getOrCreateChatClient(request);
        Prompt prompt = AiBeanConverter.convertPrompt(request);
        ChatClient.ChatClientRequestSpec spec = client.prompt(prompt);
        if (null != request.getConversationId()) {
            spec.advisors(a -> a.param(CONVERSATION_ID_PARAM, request.getConversationId()));
        }
        if (!CollectionUtils.isEmpty(request.getParams())) {
            spec.toolContext(request.getParams());
            spec.advisors(a -> a.params(request.getParams()));
        }
        return spec;
    }

    /**
     * Sends a single user message using a client built from the base builder.
     *
     * @param query the user message
     * @return the response content
     */
    public String simpleCall(String query) {
        return this.chatClientBuilder.build().prompt(query).call().content();
    }

    /**
     * Sends a system message plus a user message using a client built from the base builder.
     *
     * @param system the system message
     * @param query  the user message
     * @return the response content
     */
    public String simpleCall(String system, String query) {
        return this.chatClientBuilder.build().prompt().system(system).user(query).call().content();
    }

    /**
     * Returns the shared base builder itself, not a copy.
     *
     * <p>Changes made to it affect every client this service builds afterwards.
     *
     * @return the injected {@link ChatClient.Builder}
     */
    public ChatClient.Builder getChatClientBuilder() {
        return this.chatClientBuilder;
    }

    /** Registers the {@link TokenCounterAdvisor} as a baseline advisor on the shared builder. */
    @Override
    public void afterPropertiesSet() throws Exception {
        this.chatClientBuilder.defaultAdvisors(new TokenCounterAdvisor());
    }

    /** Shuts down {@link #executor} when the bean is destroyed, so its threads do not leak. */
    @Override
    public void destroy() {
        this.executor.shutdown();
    }
}

