package com.digiwin.dap.middleware.aspect;

import com.digiwin.dap.middleware.entity.BaseEntity;
import com.digiwin.dap.middleware.service.CascadeDeleteEntityService;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.persistence.Table;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

/**
 * Performs cascade deletion when data is deleted through a Spring Data repository.
 * <p>
 * Before a public {@code void} {@code delete}, {@code deleteById}, {@code deleteAll} or
 * {@code deleteInBatch} method runs on any subtype of
 * {@code org.springframework.data.repository.Repository}, this aspect collects the sids of the
 * entities about to be deleted and passes them, together with the entity's table name, to
 * {@link CascadeDeleteEntityService#delete}, which presumably removes the associated rows.
 *
 * @author fobgochod
 * @date 2020/4/27
 */
@Aspect
@Component
public class CascadeDeleteEntityRepository {

    @Autowired
    private CascadeDeleteEntityService cascadeDeleteEntityService;

    /**
     * Pointcut: any {@code public void deleteById(..)} on a Spring Data repository.
     */
    @Pointcut(value = "execution(public void org.springframework.data.repository.Repository+.deleteById(..))")
    public void deleteByIdPoint() {
    }

    /**
     * Pointcut: any {@code public void delete(..)} on a Spring Data repository.
     */
    @Pointcut(value = "execution(public void org.springframework.data.repository.Repository+.delete(..))")
    public void deletePoint() {
    }

    /**
     * Pointcut: any {@code public void deleteAll(..)} on a Spring Data repository
     * (including the no-argument form).
     */
    @Pointcut(value = "execution(public void org.springframework.data.repository.Repository+.deleteAll(..))")
    public void deleteAllPoint() {
    }

    /**
     * Pointcut: any {@code public void deleteInBatch(..)} on a Spring Data repository.
     */
    @Pointcut(value = "execution(public void org.springframework.data.repository.Repository+.deleteInBatch(..))")
    public void deleteInBatchPoint() {
    }

    /**
     * Before a batch delete, deletes the entities associated with the entities being deleted.
     *
     * @param pjp the intercepted {@code deleteInBatch(..)} call
     */
    @Before("deleteInBatchPoint()")
    public void cascadeDeleteByBatch(JoinPoint pjp) {
        deleteInBatch(pjp);
    }

    /**
     * Before a {@code deleteAll(..)}, deletes the entities associated with the entities being
     * deleted. Only the iterable-argument form is cascaded: the no-argument {@code deleteAll()}
     * also matches the pointcut, but has no iterable to read sids from, so nothing is cascaded.
     *
     * @param pjp the intercepted {@code deleteAll(..)} call
     */
    @Before("deleteAllPoint()")
    public void cascadeDeleteByAll(JoinPoint pjp) {
        deleteInBatch(pjp);
    }

    /**
     * Cascades the delete for every {@link BaseEntity} in the first argument, when that argument is
     * an {@link Iterable}. Elements that are {@code null} or not a {@code BaseEntity}, and entities
     * whose sid is {@code null}, are skipped.
     *
     * @param pjp the intercepted repository call
     */
    private void deleteInBatch(JoinPoint pjp) {
        Object arg = firstArg(pjp);
        if (!(arg instanceof Iterable<?>)) {
            return;
        }
        // NOTE(review): the iterable is traversed here and then again by the repository itself;
        // a single-use Iterable would already be exhausted when the real delete runs.
        List<Long> sids = new ArrayList<>();
        for (Object element : (Iterable<?>) arg) {
            if (element instanceof BaseEntity) {
                addSid(sids, ((BaseEntity) element).getSid());
            }
        }
        cascade(pjp, sids);
    }

    /**
     * Before a delete by id, deletes the entities associated with that id.
     * Only {@link Long} ids are cascaded; any other id type is left untouched.
     *
     * @param pjp the intercepted {@code deleteById(..)} call
     */
    @Before("deleteByIdPoint()")
    public void cascadeDeleteById(JoinPoint pjp) {
        Object id = firstArg(pjp);
        if (id instanceof Long) {
            List<Long> sids = new ArrayList<>();
            sids.add((Long) id);
            cascade(pjp, sids);
        }
    }

    /**
     * Before an entity is deleted, deletes the entities associated with it.
     * Only {@link BaseEntity} arguments with a non-null sid are cascaded.
     *
     * @param pjp the intercepted {@code delete(..)} call
     */
    @Before("deletePoint()")
    public void cascadeDelete(JoinPoint pjp) {
        Object arg = firstArg(pjp);
        if (arg instanceof BaseEntity) {
            List<Long> sids = new ArrayList<>();
            addSid(sids, ((BaseEntity) arg).getSid());
            cascade(pjp, sids);
        }
    }

    /**
     * Returns the first argument of the intercepted call, or {@code null} when it has none.
     * The {@code (..)} pointcuts also match zero-argument overloads, so indexing blindly could throw.
     */
    private static Object firstArg(JoinPoint pjp) {
        Object[] args = pjp.getArgs();
        return args == null || args.length == 0 ? null : args[0];
    }

    /**
     * Adds {@code sid} to {@code sids} unless it is {@code null}, for example on an entity that was never persisted.
     */
    private static void addSid(List<Long> sids, Long sid) {
        if (sid != null) {
            sids.add(sid);
        }
    }

    /**
     * Hands the sids to the cascade service, keyed by the repository's table name.
     * Does nothing when there is nothing to cascade, so the service is never called with an
     * empty list and the table name is not resolved needlessly.
     */
    private void cascade(JoinPoint pjp, List<Long> sids) {
        if (sids.isEmpty()) {
            return;
        }
        cascadeDeleteEntityService.delete(getTableName(pjp), sids);
    }

    /**
     * Resolves the table name of the entity managed by the intercepted repository.
     * Uses {@code @Table(name = ...)} when present and non-empty. Otherwise it falls back to the
     * entity's simple class name in lower case, computed locale-independently so that, for
     * example, {@code Item} never becomes {@code ıtem} under a Turkish default locale.
     *
     * @param pjp the intercepted repository call
     * @return the table name passed to the cascade service
     * @throws IllegalStateException if the entity type cannot be determined from the repository
     */
    private String getTableName(JoinPoint pjp) {
        Class<?> entityClass = resolveEntityClass(pjp.getTarget().getClass());
        Table annotation = entityClass.getAnnotation(Table.class);
        if (annotation != null && !annotation.name().isEmpty()) {
            return annotation.name();
        }
        return entityClass.getSimpleName().toLowerCase(Locale.ROOT);
    }

    /**
     * Finds the entity type of a repository by scanning, in declaration order, the generic
     * super-interfaces of each interface the target class implements. It returns the first class
     * type argument of the first parameterized one, e.g. {@code User} for
     * {@code UserRepository extends JpaRepository<User, Long>}. When the first super-interface of
     * the first interface is parameterized, this returns exactly what the previous hard-coded
     * {@code getInterfaces()[0].getGenericInterfaces()[0]} lookup did.
     */
    private static Class<?> resolveEntityClass(Class<?> targetClass) {
        for (Class<?> repositoryInterface : targetClass.getInterfaces()) {
            for (Type genericInterface : repositoryInterface.getGenericInterfaces()) {
                if (genericInterface instanceof ParameterizedType) {
                    Type[] params = ((ParameterizedType) genericInterface).getActualTypeArguments();
                    if (params.length > 0 && params[0] instanceof Class) {
                        return (Class<?>) params[0];
                    }
                }
            }
        }
        throw new IllegalStateException("Cannot resolve the entity type of repository " + targetClass.getName());
    }
}
