package com.digiwin.chatbi.service;

import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.digiwin.chatbi.beans.ResultBean;
import com.digiwin.chatbi.beans.pojos.gpt.FieldModelAI;
import com.digiwin.chatbi.beans.pojos.gpt.ModelAI;
import com.digiwin.chatbi.beans.pojos.vector.SearchDocsQo;
import com.digiwin.chatbi.beans.pojos.vector.SearchDocsResult;
import com.digiwin.chatbi.common.annotations.LogRecord;
import com.digiwin.chatbi.common.constant.Constants;
import com.digiwin.chatbi.common.enums.KnowledgeBaseRequest;
import com.google.common.reflect.TypeToken;
import com.google.gson.Gson;
import io.micrometer.core.instrument.binder.BaseUnits;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.apache.http.HttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.core.io.FileSystemResource;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.web.client.RestTemplate;

/**
 * Bridges ChatBI model metadata and the external vector knowledge-base service.
 *
 * <p>Provides document search, knowledge-base listing and creation, and a sync routine. The sync
 * routine serialises {@link ModelAI} table schemas into a text file and uploads it to a knowledge
 * base. Endpoint URLs come from {@link KnowledgeBaseRequest}; the knowledge-base name prefix comes
 * from the {@code vector.prefix} property.
 */
@Service("VectorDataSourceService")
public class VectorDataSourceService {

    private static final Logger log = LoggerFactory.getLogger(VectorDataSourceService.class);

    /** Whitelist applied to the directory name used by {@link #create(List)}. */
    private static final Pattern SAFE_FILENAME_PATTERN = Pattern.compile("^[a-zA-Z0-9_.-]+$");

    /** Value of the {@code Constants.CODE} field that the knowledge-base service uses for success. */
    private static final String SUCCESS_CODE = "200";

    /** Gson instances are thread-safe, so a single shared instance is reused. */
    private static final Gson GSON = new Gson();

    @Autowired
    private RestTemplate restTemplate;

    @Value("${vector.prefix}")
    private String vectorPrefix;

    /** Builds the JSON request headers (UTF-8 content type, JSON accept) shared by REST calls. */
    private HttpHeaders getHeaders() {
        HttpHeaders httpHeaders = new HttpHeaders();
        httpHeaders.setContentType(MediaType.parseMediaType("application/json;charset=UTF-8"));
        httpHeaders.add("Accept", MediaType.APPLICATION_JSON.toString());
        httpHeaders.add("Accept-Charset", "UTF-8");
        return httpHeaders;
    }

    /**
     * Runs a similarity search against a knowledge base.
     *
     * @param searchDocsQo query text, knowledge-base name, {@code top_k} and score threshold
     * @param str not read by this method body; NOTE(review): presumably kept for the
     *     {@code @LogRecord} aspect. Confirm before removing.
     * @return the parsed search hits, or {@code null} when the response body is null or empty
     * @throws Exception propagated from the REST call or JSON parsing
     */
    @LogRecord(isGpt = "1")
    public List<SearchDocsResult> searchDocs(SearchDocsQo searchDocsQo, String str) throws Exception {
        Map<String, Object> body = new HashMap<>();
        body.put("query", searchDocsQo.getQuery());
        body.put("knowledge_base_name", searchDocsQo.getKnowledge_base_name());
        body.put("top_k", searchDocsQo.getTop_k());
        body.put("score_threshold", searchDocsQo.getScore_threshold());
        HttpEntity<Map<String, Object>> request = new HttpEntity<>(body, getHeaders());

        log.info("搜索知识库开始...");
        long start = System.currentTimeMillis();
        ResponseEntity<String> response = this.restTemplate.exchange(
                KnowledgeBaseRequest.SEARCH_DOCS.getUrl(), HttpMethod.POST, request, String.class);
        log.info("搜索知识库耗时：{}{}", System.currentTimeMillis() - start, BaseUnits.MILLISECONDS);
        return GSON.fromJson(response.getBody(), new TypeToken<List<SearchDocsResult>>() {
        }.getType());
    }

    /**
     * Serialises table metadata into the text format that is uploaded to the knowledge base.
     *
     * <p>Each model becomes {@code {name:<table>,datasourceId:<id>,schema:[{t1~~t2~~...}]}}. Field
     * titles are de-duplicated per model, and entries are joined with {@code //}. The following are
     * skipped: null models, models without schema fields, and null field entries. The supplied
     * models are not modified.
     *
     * @param list models to describe; may be null or empty
     * @return the serialised description, or an empty string when no model qualifies
     */
    public String queryDataSourceVoList(List<ModelAI> list) {
        if (CollectionUtils.isEmpty(list)) {
            return "";
        }
        List<String> entries = new ArrayList<>();
        for (ModelAI model : list) {
            if (model == null || CollectionUtils.isEmpty(model.getSchema())) {
                continue;
            }
            Set<String> titles = model.getSchema().stream()
                    .filter(Objects::nonNull)
                    .map(field -> field.getTitle())
                    .collect(Collectors.toSet());
            if (titles.isEmpty()) {
                continue;
            }
            entries.add("{name:" + model.getTableName()
                    + ",datasourceId:" + model.getDatasourceId()
                    + ",schema:[{" + String.join("~~", titles) + "}]}");
        }
        // Joining the collected entries avoids a dangling "//" when trailing models are skipped.
        return String.join("//", entries);
    }

    /**
     * Uploads a local file to a knowledge base, overriding any existing document of the same name.
     *
     * @param str path of the local file to upload
     * @param str2 target knowledge-base name
     * @return {@code ok(msg)} when the service reports success, otherwise {@code fail} carrying the
     *     raw response body
     */
    public ResultBean uploadDoc(String str, String str2) {
        HttpHeaders headers = getHeaders();
        headers.setContentType(MediaType.MULTIPART_FORM_DATA);
        LinkedMultiValueMap<String, Object> form = new LinkedMultiValueMap<>();
        form.add("file", new FileSystemResource(str));
        form.add("knowledge_base_name", String.valueOf(str2));
        form.add("override", true);
        form.add("not_refresh_vs_cache", false);
        HttpEntity<LinkedMultiValueMap<String, Object>> request = new HttpEntity<>(form, headers);

        log.info("上传文件到知识库开始...");
        long start = System.currentTimeMillis();
        String responseBody = this.restTemplate.postForObject(
                KnowledgeBaseRequest.UPLOAD_DOC.getUrl(), request, String.class);
        log.info("上传文件到知识库返回：{}", responseBody);
        log.info("上传文件到知识库耗时：{}{}", System.currentTimeMillis() - start, BaseUnits.MILLISECONDS);
        JSONObject json = JSONObject.parseObject(responseBody);
        return isSuccess(json)
                ? ResultBean.ok(json.getString(Constants.MSG))
                : ResultBean.fail("文件上传失败:" + responseBody);
    }

    /**
     * Syncs model metadata to the vector knowledge base.
     *
     * <p>NOTE(review): this calls {@link #create(List)} on {@code this}, bypassing the Spring proxy,
     * so {@code create} runs synchronously on this path despite its {@code @Async} annotation.
     *
     * @param list models to sync
     * @return the result of {@link #create(List)}
     */
    public ResultBean vectorSync(List<ModelAI> list) {
        return create(list);
    }

    /**
     * Ensures the knowledge base exists, writes the serialised models to a local {@code .tt} file,
     * and uploads that file.
     *
     * <p>NOTE(review): {@code @Async} only applies when this method is invoked through the Spring
     * proxy. In that case, Spring does not hand a non-{@code Future} return value back to the
     * caller. Confirm which invocation style callers rely on.
     *
     * @param list models to sync
     * @return {@code ok("成功")} when the upload succeeds, otherwise {@code fail("失败")}
     */
    @Async
    public ResultBean create(List<ModelAI> list) {
        String knowledgeBaseName = this.vectorPrefix + "_model_center_test";
        ensureKnowledgeBase(knowledgeBaseName);

        String content = queryDataSourceVoList(list);
        String dirName = validateAndGetSafeFileName("file");
        if (dirName == null) {
            // Unreachable for the constant above; kept as a guard should the name become dynamic.
            log.error("向量库文件生成失败");
            return ResultBean.fail("失败");
        }
        // The nested getName(...) calls reduce to "." so the directory is "file" under the
        // working directory; only bare file-name components are ever used to build the path.
        File outputDir = new File(
                FilenameUtils.getName(new File(FilenameUtils.getName(".")).getAbsolutePath()),
                FilenameUtils.getName(dirName));
        if (!outputDir.exists() && !outputDir.mkdirs()) {
            log.warn("创建文件目录失败：{}", outputDir.getAbsolutePath());
        }
        String filePath = outputDir + File.separator + knowledgeBaseName + ".tt";

        // Uploading after a failed write would send a missing or stale file.
        if (!writeFile(filePath, content)) {
            log.error("向量库文件生成失败");
            return ResultBean.fail("失败");
        }

        ResultBean resultBean = null;
        try {
            resultBean = uploadDoc(filePath, knowledgeBaseName);
        } catch (Exception e) {
            log.error("上传文件到知识库失败：{}", e.getMessage(), e);
        }
        if (resultBean != null && resultBean.getCode() != null
                && resultBean.getCode().intValue() == 200) {
            log.info("向量库文件生成成功");
            return ResultBean.ok("成功");
        }
        log.error("向量库文件生成失败");
        return ResultBean.fail("失败");
    }

    /**
     * Creates {@code knowledgeBaseName} unless the listing shows it already exists.
     *
     * <p>Creation is attempted when the listing fails or returns nothing. Failures are logged, not
     * thrown.
     */
    private void ensureKnowledgeBase(String knowledgeBaseName) {
        List<String> existing = null;
        try {
            existing = queryKnowledgeBases();
        } catch (Exception e) {
            log.error("获取知识库列表失败：{}", e.getMessage(), e);
        }
        if (CollectionUtils.isNotEmpty(existing) && existing.contains(knowledgeBaseName)) {
            return;
        }
        try {
            if (SUCCESS_CODE.equals(createKnowledgeBase(knowledgeBaseName))) {
                log.info("创建知识库成功");
            } else {
                log.error("创建知识库失败：{}", knowledgeBaseName);
            }
        } catch (Exception e) {
            log.error("创建知识库失败：{}", e.getMessage(), e);
        }
    }

    /**
     * Writes {@code content} to {@code path} as UTF-8.
     *
     * <p>Success is logged only after the writer is closed (flushed). On I/O failure, the error is
     * logged and {@code false} is returned.
     */
    private static boolean writeFile(String path, String content) {
        try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
                new FileOutputStream(FileUtils.getFile(path)), StandardCharsets.UTF_8))) {
            writer.write(content);
        } catch (IOException e) {
            log.error("文件生成失败：{}", e.getMessage(), e);
            return false;
        }
        log.info("文件生成成功。");
        return true;
    }

    /**
     * Returns the bare file-name component of {@code str} when it matches
     * {@link #SAFE_FILENAME_PATTERN}; otherwise returns {@code null}.
     */
    private static String validateAndGetSafeFileName(String str) {
        String name = FilenameUtils.getName(str);
        return SAFE_FILENAME_PATTERN.matcher(name).matches() ? name : null;
    }

    /** Returns whether a knowledge-base service response carries the success code. */
    private static boolean isSuccess(JSONObject json) {
        return json != null && SUCCESS_CODE.equals(json.getString(Constants.CODE));
    }

    /**
     * Lists the knowledge-base names known to the service.
     *
     * @return the names, or {@code null} when the service reports failure or returns no
     *     {@code data} array
     * @throws Exception on URL, transport or parsing errors
     */
    @LogRecord(isGpt = "1")
    public List<String> queryKnowledgeBases() throws Exception {
        log.info("获取知识库列表开始...");
        long start = System.currentTimeMillis();
        String responseBody;
        // Closing the client releases its connection pool, which the original code leaked per call.
        try (CloseableHttpClient client = HttpClients.createDefault()) {
            HttpGet get = new HttpGet(
                    new URIBuilder(KnowledgeBaseRequest.LIST_KNOWLEDGE_BASES.getUrl()).build());
            // UTF-8 is only the fallback for responses that declare no charset.
            responseBody = EntityUtils.toString(client.execute(get).getEntity(), StandardCharsets.UTF_8);
        }
        log.info("获取知识库列表耗时：{}{}", System.currentTimeMillis() - start, BaseUnits.MILLISECONDS);

        JSONObject json = JSONObject.parseObject(responseBody);
        if (!isSuccess(json)) {
            return null;
        }
        JSONArray data = json.getJSONArray("data");
        return data == null ? null : JSONArray.parseArray(data.toJSONString(), String.class);
    }

    /**
     * Creates a FAISS-backed knowledge base that uses the {@code text-embedding-ada-002} embedding
     * model.
     *
     * @param str knowledge-base name (passed through {@code String.valueOf})
     * @return the service's success code ({@code "200"}), or {@code null} when creation failed
     * @throws Exception on transport or parsing errors
     */
    public String createKnowledgeBase(String str) throws Exception {
        HttpPost post = new HttpPost(KnowledgeBaseRequest.CREATE_KNOWLEDGE_BASE.getUrl());
        post.setHeader("Content-Type", ContentType.APPLICATION_JSON.toString());
        Map<String, Object> body = new HashMap<>();
        body.put("knowledge_base_name", String.valueOf(str));
        body.put("vector_store_type", "faiss");
        body.put("embed_model", "text-embedding-ada-002");
        post.setEntity(new StringEntity(JSONObject.toJSONString(body), ContentType.APPLICATION_JSON));

        log.info("创建知识库开始...");
        long start = System.currentTimeMillis();
        String responseBody;
        try (CloseableHttpClient client = HttpClients.createDefault()) {
            HttpResponse response = client.execute(post);
            log.info("创建知识库耗时：{}{}", System.currentTimeMillis() - start, BaseUnits.MILLISECONDS);
            responseBody = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
        }
        JSONObject json = JSONObject.parseObject(responseBody);
        return isSuccess(json) ? json.getString(Constants.CODE) : null;
    }
}
