package com.digiwin.athena.knowledgegraph.service.impl;

import com.digiwin.athena.kmservice.locale.Lang;
import com.digiwin.athena.knowledgegraph.domain.inference.InferenceEntity;
import com.digiwin.athena.knowledgegraph.domain.inference.InferenceParamVO;
import com.digiwin.athena.knowledgegraph.domain.inference.InferenceReqVO;
import com.digiwin.athena.knowledgegraph.service.IPocInferenceService;
import com.digiwin.athena.repository.neo4j.InferenceRepository;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.neo4j.ogm.model.Result;
import org.neo4j.ogm.session.SessionFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;

/**
 * Inference service that runs one variable-length path query per
 * {@link InferenceParamVO} against Neo4j and merges the per-query target
 * lists according to the requested merge strategy.
 *
 * <p>Labels and relationship types cannot be passed as Cypher parameters, so
 * they are validated against a strict identifier whitelist before being
 * placed in the query text. The {@code code} values are always passed as
 * query parameters.
 *
 * @author tang jie cheng
 * @date 2023-08-31 10:33
 */
@Lang
@Service
@Slf4j
public class InferenceService implements IPocInferenceService {
    /** A single unquoted identifier: letter or underscore, then letters, digits or underscores. */
    private static final String IDENTIFIER = "[\\p{L}_][\\p{L}\\p{N}_]*";
    /** One or more labels joined by ':' (e.g. {@code Person} or {@code Person:Employee}). */
    private static final Pattern LABEL_EXPRESSION =
            Pattern.compile(IDENTIFIER + "(?::" + IDENTIFIER + ")*");
    /** One or more relationship types joined by '|' or '|:' (e.g. {@code KNOWS} or {@code A|B}). */
    private static final Pattern RELATIONSHIP_EXPRESSION =
            Pattern.compile(IDENTIFIER + "(?:\\|:?" + IDENTIFIER + ")*");

    @Autowired
    InferenceRepository inferenceRepository;
    @Autowired
    SessionFactory neo4jSessionFactory;

    /**
     * Runs every inference query of the request and merges the results.
     *
     * <p>For the {@code intersection} strategy: with more than one query the
     * result is the set of targets present in every query result; with zero or
     * one query the raw result list is returned unchanged.
     *
     * @param inferenceParams request holding the query parameters and the merge strategy
     * @return the merged list of {@link InferenceEntity}
     * @throws IllegalArgumentException if a label or relationship type is missing or is not a
     *                                  plain identifier expression
     * @throws RuntimeException         if the merge strategy is missing or not supported
     */
    @Override
    public Object postQueryInference(InferenceReqVO inferenceParams) throws Exception {
        // A null strategy would otherwise surface as an opaque NPE from the switch below.
        Object requestedStrategy = inferenceParams.getMergeStrategy();
        if (requestedStrategy == null) {
            throw new RuntimeException("postQueryInference.mergeStrategy not exist");
        }
        List<InferenceParamVO> inferenceParamsList = inferenceParams.getInferenceParams();
        if (inferenceParamsList == null) {
            inferenceParamsList = Collections.emptyList();
        }
        switch (inferenceParams.getMergeStrategy()) {
            case intersection:
                // Queries run only once the strategy is known to be supported.
                Map<String, List<InferenceEntity>> resultsByParam = runQueries(inferenceParamsList);
                if (inferenceParamsList.size() > 1) {
                    return findDuplicateEntities(resultsByParam);
                }
                // A single query has nothing to intersect with: return its rows as they are.
                List<InferenceEntity> resultList = new ArrayList<>();
                for (List<InferenceEntity> value : resultsByParam.values()) {
                    resultList.addAll(value);
                }
                return resultList;
            default:
                throw new RuntimeException("postQueryInference.mergeStrategy not exist");
        }
    }

    /**
     * Executes one query per parameter object, keyed by the parameter's {@code toString()}.
     * Insertion order is preserved so the merged output order is deterministic.
     */
    private Map<String, List<InferenceEntity>> runQueries(List<InferenceParamVO> params) {
        // TODO run the queries on worker threads, then apply the merge strategy
        Map<String, List<InferenceEntity>> resultsByParam = new LinkedHashMap<>();
        for (InferenceParamVO param : params) {
            resultsByParam.put(param.toString(), queryTargets(param));
        }
        return resultsByParam;
    }

    /** Runs the path query for a single parameter object and maps every row to an entity. */
    private List<InferenceEntity> queryTargets(InferenceParamVO param) {
        if (param.getSource() == null || param.getTarget() == null) {
            throw new IllegalArgumentException("postQueryInference.source and target are required");
        }
        Map<String, Object> cypherParams = new HashMap<>();
        String cypher = buildCypher(param, cypherParams);
        Result rows = neo4jSessionFactory.openSession().query(cypher, cypherParams);

        List<InferenceEntity> targets = new ArrayList<>();
        for (Map<String, Object> row : rows) {
            InferenceEntity entity = toEntity(row);
            entity.setLabel(param.getTarget().getLabel());
            targets.add(entity);
        }
        return targets;
    }

    /**
     * Builds a query of the form
     * {@code MATCH (source:<label> {code: $sourceCode})-[:<rel>*<depth>]->(target:<label> {code: $targetCode}) RETURN ...}.
     * A depth of -1 means "unbounded"; the target code filter is omitted when the code is empty.
     * Code values are written into {@code cypherParams}, never into the query text.
     */
    private static String buildCypher(InferenceParamVO param, Map<String, Object> cypherParams) {
        StringBuilder cypher = new StringBuilder("MATCH (source:")
                .append(requireExpression(LABEL_EXPRESSION, param.getSource().getLabel(), "source.label"))
                .append(" {code: $sourceCode})-[:")
                .append(requireExpression(RELATIONSHIP_EXPRESSION, param.getRelationship(), "relationship"))
                .append("*");
        cypherParams.put("sourceCode", param.getSource().getCode());
        if (param.getDepth() != -1) {
            cypher.append(param.getDepth());
        }
        cypher.append("]->(target:")
                .append(requireExpression(LABEL_EXPRESSION, param.getTarget().getLabel(), "target.label"));
        if (!StringUtils.isEmpty(param.getTarget().getCode())) {
            cypher.append(" {code: $targetCode}");
            cypherParams.put("targetCode", param.getTarget().getCode());
        }
        cypher.append(") RETURN id(target) as id,target.code as code,target.name as name,"
                + "target.value as value,target.language as language,target.version as version");
        return cypher.toString();
    }

    /**
     * Returns {@code value} as text if it fully matches {@code allowed}; otherwise rejects it.
     * This is the only guard for text that has to be embedded in the query string.
     */
    private static String requireExpression(Pattern allowed, Object value, String field) {
        String text = value == null ? null : String.valueOf(value);
        if (text == null || !allowed.matcher(text).matches()) {
            throw new IllegalArgumentException(
                    "postQueryInference." + field + " is not a valid identifier: " + text);
        }
        return text;
    }

    /** Maps one result row to an entity; the label is set by the caller. */
    @SuppressWarnings("unchecked") // the "language" column is cast exactly as before; its runtime type is not checked here
    private static InferenceEntity toEntity(Map<String, Object> row) {
        InferenceEntity entity = new InferenceEntity();
        entity.setId(Long.parseLong(String.valueOf(row.get("id"))));
        entity.setCode(asString(row.get("code")));
        entity.setName(asString(row.get("name")));
        entity.setLanguage((Map<String, Map<String, String>>) row.get("language"));
        entity.setValue(asString(row.get("value")));
        return entity;
    }

    /** Null-preserving string conversion: a missing property stays {@code null} instead of becoming "null". */
    private static String asString(Object value) {
        return value == null ? null : String.valueOf(value);
    }

    /**
     * Returns the entities that occur in every list of {@code map}.
     *
     * <p>Entities are identified by {@code label + "-" + code}. Each key is counted at most once
     * per list, so repeated rows inside one list (the same target reached through several paths)
     * cannot fake or hide membership in the intersection. The first entity encountered for each
     * key is the one returned, in first-encounter order.
     *
     * @param map per-query result lists; a {@code null} list is treated as empty
     * @return entities present in all lists, without duplicates; empty if {@code map} is null or empty
     */
    public static List<InferenceEntity> findDuplicateEntities(Map<String, List<InferenceEntity>> map) {
        List<InferenceEntity> resultList = new ArrayList<>();
        if (map == null || map.isEmpty()) {
            return resultList;
        }
        Map<String, Integer> listsContainingKey = new HashMap<>();
        Map<String, InferenceEntity> firstSeenByKey = new LinkedHashMap<>();
        for (List<InferenceEntity> entities : map.values()) {
            if (entities == null) {
                continue;
            }
            Set<String> keysInThisList = new HashSet<>();
            for (InferenceEntity entity : entities) {
                String key = entity.getLabel() + "-" + entity.getCode();
                // Count a key only the first time it shows up within the current list.
                if (keysInThisList.add(key)) {
                    listsContainingKey.merge(key, 1, Integer::sum);
                    firstSeenByKey.putIfAbsent(key, entity);
                }
            }
        }
        // Keep only the keys that were found in every list.
        for (Map.Entry<String, InferenceEntity> candidate : firstSeenByKey.entrySet()) {
            if (listsContainingKey.get(candidate.getKey()) == map.size()) {
                resultList.add(candidate.getValue());
            }
        }
        return resultList;
    }
}