package com.digiwin.dap.middleware.aspect;

import com.digiwin.dap.middle.database.encrypt.annotation.Desensitization;
import com.digiwin.dap.middle.database.encrypt.config.DatabaseEncryptConfig;
import com.digiwin.dap.middle.database.encrypt.desensitization.context.DesensitizationConvertContext;
import com.digiwin.dap.middle.database.encrypt.desensitization.service.DesensitizationConverter;
import com.digiwin.dap.middle.database.encrypt.model.ObjectRelationalMapping;
import com.digiwin.dap.middle.database.encrypt.utils.CamelToSnakeUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.hibernate.proxy.HibernateProxy;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.parser.Part;
import org.springframework.data.repository.query.parser.PartTree;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import javax.persistence.Column;
import javax.persistence.Table;
import java.io.*;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Aspect that applies database encryption to JPA repository calls.
 *
 * <p>Method arguments are passed through {@link DesensitizationConverter#desensitize} before the
 * repository method runs, and a deep copy of the result is passed through
 * {@link DesensitizationConverter#revert} afterwards.
 *
 * <p>Notes carried over from the original author:
 * <ol>
 *   <li>Only predefined and derived repository methods are handled; any method annotated with
 *       {@code @Query} is passed through untouched, so {@code nativeQuery = true} is not supported.</li>
 *   <li>Parameters are encrypted and result sets decrypted automatically; fuzzy (LIKE) queries on
 *       sensitive fields are not supported.</li>
 *   <li>JPQL must not use {@code select *}, otherwise result decryption does not work.
 *       NOTE(review): {@code @Query} methods are currently skipped entirely - confirm whether this
 *       note still applies.</li>
 * </ol>
 *
 * @author michael
 */
@Aspect
@Component
@ConditionalOnBean(DatabaseEncryptConfig.class)
public class EncryptionAndDecryptionAspect {

    private static final Logger LOGGER = LoggerFactory.getLogger(EncryptionAndDecryptionAspect.class);

    private static final String QUERY_PATTERN = "find|read|get|query|search|stream";
    private static final String COUNT_PATTERN = "count";
    private static final String EXISTS_PATTERN = "exists";
    private static final String DELETE_PATTERN = "delete|remove";
    /** Matches the subject prefix of a derived query method name, e.g. {@code findBy}, {@code deleteAllBy}. */
    private static final Pattern PREFIX_TEMPLATE = Pattern.compile( //
            "^(" + QUERY_PATTERN + "|" + COUNT_PATTERN + "|" + EXISTS_PATTERN + "|" + DELETE_PATTERN + ")((\\p{Lu}.*?))??By");

    @Autowired
    private DesensitizationConverter<Object> desensitizationConverter;

    /** Pointcut covering the project repositories and the Spring Data / project base repository interfaces. */
    @Pointcut("execution(* com.digiwin.dap.middleware.*.repository..*.*(..)) ||" +
            "execution(* org.springframework.data.jpa.repository.JpaRepository.*(..)) || " +
            "execution(* com.digiwin.dap.middleware.repository.BaseEntityWithIdRepository.*(..)) || " +
            "execution(* com.digiwin.dap.middleware.repository.BaseEntityRepository.*(..))")
    public void repositoryMethods() {
    }

    /**
     * Encrypts the arguments of a repository call, runs it, and decrypts a deep copy of the result.
     *
     * <p>The call is passed through unchanged when the target cannot be resolved, when the entity is
     * not annotated with {@code @Table} and an enabled {@code @Desensitization}, or when the method
     * is annotated with {@code @Query}. Encryption and decryption are best effort: a failure is
     * logged and the call continues.
     *
     * @param joinPoint the intercepted repository invocation
     * @return the decrypted copy of the result, or the original result when decryption is skipped or fails
     * @throws Throwable whatever the intercepted method throws
     */
    @Around("repositoryMethods()")
    public Object beforeRepositoryMethodExecution(ProceedingJoinPoint joinPoint) throws Throwable {
        Object[] args = joinPoint.getArgs();

        // Resolve the repository interface, the invoked method and the entity class.
        TargetMethodInfo targetMethodInfo;
        try {
            targetMethodInfo = getTargetMethodInfo(joinPoint);
        } catch (Exception e) {
            LOGGER.error("===>JPA解析方法异常", e);
            return joinPoint.proceed();
        }

        // Validate BEFORE building the mappings: building them for a null entity class or an entity
        // without @Table would throw and be logged as an error for a perfectly normal call.
        Class<?> entityClass = targetMethodInfo.getEntityClass();
        Method method = targetMethodInfo.getMethod();
        if (entityClass == null || method == null) {
            return joinPoint.proceed();
        }
        if (!isDesensitizationEnabled(entityClass)) {
            return joinPoint.proceed();
        }
        if (method.isAnnotationPresent(Query.class)) {
            return joinPoint.proceed();
        }

        List<ObjectRelationalMapping> mappings;
        try {
            mappings = buildObjectRelationalMapping(entityClass);
        } catch (Exception e) {
            LOGGER.error("===>JPA解析方法异常", e);
            return joinPoint.proceed();
        }

        String targetMethod = targetMethodInfo.getInterfaceClass().getName() + "." + method.getName();
        // Encrypt the arguments.
        try {
            List<String> conditionList = parserMethodName(entityClass, joinPoint.getSignature().getName(), args);
            for (int i = 0; i < args.length; i++) {
                // Arguments that the method name does not bind to a property fall back to all mappings.
                String fieldName = conditionList.size() > i ? conditionList.get(i) : "";
                List<ObjectRelationalMapping> paramMappings;
                if (StringUtils.hasLength(fieldName)) {
                    paramMappings = mappings.stream()
                            .filter(x -> x.getObjectPropertyName().equals(fieldName))
                            .collect(Collectors.toList());
                } else {
                    paramMappings = mappings;
                }
                DesensitizationConvertContext<Object> parameterContext
                        = new DesensitizationConvertContext<>(targetMethod, paramMappings);
                parameterContext.setContext(args[i]);
                args[i] = desensitizationConverter.desensitize(parameterContext);
            }
        } catch (Exception e) {
            LOGGER.error("===>{}加密参数异常", targetMethod, e);
        }

        // Run the repository method.
        Object result = joinPoint.proceed(args);
        if (result == null) {
            // Nothing to decrypt (e.g. void methods or an absent entity).
            return null;
        }
        if (result instanceof HibernateProxy) {
            LOGGER.warn("===>{}启用了Hibernate延迟加载不支持加解密", targetMethod);
            return result;
        }

        // Decrypt the result. Per the original author, modifying the returned object directly would
        // trigger Hibernate's dirty checking; the two options are clearing the entity manager cache
        // or working on a deep copy. The deep copy is used here.
        try {
            DesensitizationConvertContext<Object> resultContext
                    = new DesensitizationConvertContext<>(targetMethod, mappings);
            resultContext.setContext(cloneResult(result));
            Object newObject = desensitizationConverter.revert(resultContext);
            return result instanceof Optional ? Optional.ofNullable(newObject) : newObject;
        } catch (Exception e) {
            LOGGER.error("===>{}结果解密失败", targetMethod, e);
        }
        return result;
    }

    /**
     * Tells whether the entity carries {@code @Table} and an enabled {@code @Desensitization}.
     *
     * @param entityClass the entity class, never {@code null}
     * @return {@code true} when both annotations are present and desensitization is enabled
     */
    private static boolean isDesensitizationEnabled(Class<?> entityClass) {
        Desensitization desensitization = entityClass.getAnnotation(Desensitization.class);
        return entityClass.isAnnotationPresent(Table.class)
                && desensitization != null
                && desensitization.enabled();
    }

    /**
     * Builds one table/column/property mapping per field declared directly on the entity class.
     * The column name comes from {@code @Column(name)} when set, otherwise from the snake-cased
     * property name.
     *
     * @param entityClass an entity class annotated with {@code @Table}
     * @return the mappings for the declared fields
     */
    private List<ObjectRelationalMapping> buildObjectRelationalMapping(Class<?> entityClass) {
        List<ObjectRelationalMapping> objectRelationalMappings = new ArrayList<>();
        Table table = entityClass.getAnnotation(Table.class);
        for (Field field : entityClass.getDeclaredFields()) {
            String objectPropertyName = field.getName();
            Column column = field.getAnnotation(Column.class);
            String dataBaseColumnName = column != null && StringUtils.hasLength(column.name())
                    ? column.name()
                    : CamelToSnakeUtils.convertCamelToSnake(objectPropertyName);
            objectRelationalMappings.add(new ObjectRelationalMapping(table.name(), dataBaseColumnName, objectPropertyName));
        }
        return objectRelationalMappings;
    }

    /**
     * Deep-copies a result through Java serialization.
     *
     * @param result the repository result; an {@link Optional} is unwrapped before copying
     * @return the copy of the result (of the wrapped value for an {@code Optional}, {@code null} when empty)
     * @throws IOException            if the value cannot be serialized
     * @throws ClassNotFoundException if a class of the serialized graph cannot be loaded
     */
    private Object cloneResult(Object result) throws IOException, ClassNotFoundException {
        // Optional itself is not serializable, so copy the value it wraps.
        Object payload = result instanceof Optional ? ((Optional<?>) result).orElse(null) : result;
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bos)) {
            out.writeObject(payload);
        }
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray()))) {
            return in.readObject();
        }
    }

    /**
     * Derives, from a Spring Data derived-query method name, the entity property each positional
     * argument is bound to.
     *
     * @param entityClass the entity the repository manages
     * @param methodName  the invoked method name
     * @param args        the invocation arguments
     * @return property paths aligned with the argument positions; empty when there are no arguments
     *         or the name is not a derived find/count/exists/delete query
     * @see <a href="https://springdoc.cn/spring-data-jpa/#repository-query-keywords">repository query keywords</a>
     */
    private List<String> parserMethodName(Class<?> entityClass, String methodName, Object[] args) {
        List<String> conditionFieldList = new ArrayList<>();
        if (args == null || args.length == 0) {
            return conditionFieldList;
        }
        Matcher matcher = PREFIX_TEMPLATE.matcher(methodName);
        if (!matcher.find()) {
            return conditionFieldList;
        }
        PartTree tree = new PartTree(methodName, entityClass);
        for (Part part : tree.getParts()) {
            String property = part.getProperty().toDotPath();
            // A part may bind zero (IsNull, True), one, or two (Between) arguments; repeat the
            // property accordingly so list indexes stay aligned with argument positions.
            for (int i = 0; i < part.getNumberOfArguments(); i++) {
                conditionFieldList.add(property);
            }
        }
        return conditionFieldList;
    }

    /**
     * Resolves the repository interface, the invoked method and the entity class of a join point.
     *
     * @param joinPoint the intercepted invocation
     * @return the resolved info; fields that cannot be resolved are {@code null}
     */
    private TargetMethodInfo getTargetMethodInfo(ProceedingJoinPoint joinPoint) {
        String methodName = joinPoint.getSignature().getName();
        // getInterfaces() returns the directly implemented interfaces in declaration order, so the
        // first one is assumed to be the custom repository interface.
        Class<?>[] interfaces = joinPoint.getTarget().getClass().getInterfaces();
        if (interfaces.length == 0) {
            return new TargetMethodInfo(null, null, null);
        }
        Class<?> interfaceClass = interfaces[0];

        Class<?> entityClass = null;
        Type[] genericInterfaces = interfaceClass.getGenericInterfaces();
        if (genericInterfaces.length > 0 && genericInterfaces[0] instanceof ParameterizedType) {
            Type[] typeArgs = ((ParameterizedType) genericInterfaces[0]).getActualTypeArguments();
            // The first type argument may be a type variable rather than a concrete class.
            if (typeArgs.length > 0 && typeArgs[0] instanceof Class) {
                entityClass = (Class<?>) typeArgs[0];
            }
        }

        Object[] args = joinPoint.getArgs();
        int argCount = args == null ? 0 : args.length;
        List<Method> sameName = Arrays.stream(interfaceClass.getMethods())
                .filter(candidate -> candidate.getName().equals(methodName))
                .collect(Collectors.toList());
        // Prefer the overload matching the call's arity so the @Query check inspects the right method.
        Method method = sameName.stream()
                .filter(candidate -> candidate.getParameterCount() == argCount)
                .findFirst()
                .orElse(sameName.isEmpty() ? null : sameName.get(0));
        return new TargetMethodInfo(interfaceClass, method, entityClass);
    }

    /** Immutable holder for the repository interface, invoked method and entity class of a call. */
    static class TargetMethodInfo {
        private final Class<?> interfaceClass;

        private final Method method;

        private final Class<?> entityClass;

        TargetMethodInfo(Class<?> interfaceClass, Method method, Class<?> entityClass) {
            this.interfaceClass = interfaceClass;
            this.method = method;
            this.entityClass = entityClass;
        }

        public Class<?> getInterfaceClass() {
            return interfaceClass;
        }

        public Method getMethod() {
            return method;
        }

        public Class<?> getEntityClass() {
            return entityClass;
        }
    }
}
