package com.digiwin.athena.km_deployer_service.config.mongodb;

import org.springframework.data.domain.Example;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.repository.query.MongoEntityInformation;
import org.springframework.data.mongodb.repository.support.SimpleMongoRepository;
import org.springframework.data.repository.support.PageableExecutionUtils;
import org.springframework.data.util.Streamable;
import org.springframework.util.Assert;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

/**
 * @ClassName ExtendSimpleMongoRepository
 * @Description TODO
 * @Author zhuangli
 * @Date 2022/5/17 20:21
 * @Version 1.0
 **/
public class ExtendSimpleMongoRepository<T, ID extends Serializable> extends SimpleMongoRepository<T, ID> {
    private final MongoOperations mongoOperations;
    private final MongoEntityInformation<T, ID> entityInformation;


    public ExtendSimpleMongoRepository(MongoEntityInformation<T, ID> metadata, MongoOperations mongoOperations) {
        super(metadata, mongoOperations);
        this.entityInformation = metadata;
        this.mongoOperations = mongoOperations;
    }

    public <S extends T> List<S> saveAll(Iterable<S> entities) {
        Assert.notNull(entities, "The given Iterable of entities not be null!");
        Streamable<S> source = Streamable.of(entities);
        boolean allNew = source.stream().allMatch((it) -> {
            return this.entityInformation.isNew(it);
        });
        if (allNew) {
            List<S> result = (List)source.stream().collect(Collectors.toList());
            return new ArrayList(this.mongoOperations.insert(result, this.entityInformation.getCollectionName()));
        } else {
            return (List)source.stream().map(this::save).collect(Collectors.toList());
        }
    }

    public <S extends T> Page<S> findAll(final Example<S> example, final Query query, Pageable pageable) {
        query.addCriteria((new Criteria()).alike(example)).with(pageable);
        List<S> list = this.mongoOperations.find(query, example.getProbeType(), this.entityInformation.getCollectionName());
        return PageableExecutionUtils.getPage(list, pageable, () -> mongoOperations.count(query, example.getProbeType(),entityInformation.getCollectionName()));
    }

    public <S extends T> Page<T> findAll(Query query, Pageable pageable) {
        query.with(pageable);
        List<T> list = mongoOperations.find(query, entityInformation.getJavaType(), entityInformation.getCollectionName());

        return PageableExecutionUtils.getPage(list, pageable,
                () -> mongoOperations.count(query, entityInformation.getJavaType(), entityInformation.getCollectionName()));
    }

}
