package com.digiwin.dap.middle.encrypt.filter;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.collection.ListUtil;
import cn.hutool.core.date.DatePattern;
import cn.hutool.core.date.LocalDateTimeUtil;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.CharsetUtil;
import cn.hutool.crypto.SecureUtil;
import cn.hutool.extra.servlet.ServletUtil;
import com.digiwin.dap.middle.encrypt.config.EncryptRequestWrapper;
import com.digiwin.dap.middle.encrypt.contstant.EncryptConstants;
import com.digiwin.dap.middle.encrypt.domain.DapEncryptDTO;
import com.digiwin.dap.middle.encrypt.domain.DapSignInfo;
import com.digiwin.dap.middle.encrypt.domain.annotation.DapEncrypt;
import com.digiwin.dap.middle.encrypt.domain.annotation.DapSign;
import com.digiwin.dap.middle.encrypt.support.DapSecretSupport;
import com.digiwin.dap.middleware.auth.AuthoredSys;
import com.digiwin.dap.middleware.cache.RedisUtils;
import com.digiwin.dap.middleware.constant.GlobalConstants;
import com.digiwin.dap.middleware.domain.CommonErrorCode;
import com.digiwin.dap.middleware.domain.DapEnv;
import com.digiwin.dap.middleware.domain.FilterOrderEnum;
import com.digiwin.dap.middleware.exception.BusinessException;
import com.digiwin.dap.middleware.util.JsonUtils;
import com.digiwin.dap.middleware.util.SecureUtils;
import com.digiwin.dap.middleware.util.UserUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.context.ApplicationContext;
import org.springframework.core.MethodParameter;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.context.support.WebApplicationContextUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.beans.Introspector;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;


/**
 * Pre-processes requests whose handler is annotated with {@link DapSign} and/or {@link DapEncrypt}:
 * verifies the request signature, decrypts the encrypted request body, and forwards an
 * {@link EncryptRequestWrapper} (carrying the possibly replaced body) down the filter chain.
 *
 * @author ChenZhuang
 * @date 2024-1-22 17:14:07
 */
public class EncryptRequestFilter extends OncePerRequestFilter implements Ordered {
    private static final Logger LOGGER = LoggerFactory.getLogger(EncryptRequestFilter.class);

    /** HTTP methods whose request body is eligible for decryption. */
    private static final List<String> BODY_METHODS = Arrays.asList(RequestMethod.POST.name(), RequestMethod.PUT.name());

    private final DapEnv dapEnv;
    private final DapSecretSupport dapSecretSupport;

    public EncryptRequestFilter(DapEnv dapEnv, DapSecretSupport dapSecretSupport) {
        this.dapEnv = dapEnv;
        this.dapSecretSupport = dapSecretSupport;
    }

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        HandlerExecutionChain handlerChain = getHandlerExecutionChain(request);
        // No annotated controller method to inspect: unmapped requests, failed lookups, or handlers that are not
        // a HandlerMethod (e.g. Spring's CORS pre-flight handler). Let the rest of the chain deal with them.
        if (handlerChain == null || !(handlerChain.getHandler() instanceof HandlerMethod)) {
            filterChain.doFilter(request, response);
            return;
        }

        HandlerMethod handlerMethod = (HandlerMethod) handlerChain.getHandler();
        Method method = handlerMethod.getMethod();

        DapSign dapSign = getDapSign(handlerMethod);
        DapEncrypt dapEncrypt = getDapEncrypt(handlerMethod);

        String signHeader = request.getHeader(EncryptConstants.HTTP_HEADER_SIGN_ARG_KEY);
        boolean sign = isSign(dapSign, signHeader);

        // Read the request body (the wrapper keeps it so it can be re-read / replaced downstream)
        EncryptRequestWrapper encryptRequestWrapper = new EncryptRequestWrapper(request);
        String bodyString = encryptRequestWrapper.getBodyString();
        boolean encrypt = isEncrypt(request, dapEncrypt, bodyString);

        if (!sign && !encrypt) {
            filterChain.doFilter(encryptRequestWrapper, response);
            return;
        }

        AuthoredSys authoredSys = UserUtils.getAuthoredSys();
        if (authoredSys == null || ObjectUtils.isEmpty(authoredSys.getId())) {
            throw new BusinessException(CommonErrorCode.APP_ID_NONE);
        }

        String appToken = request.getHeader(GlobalConstants.HTTP_HEADER_APP_TOKEN_KEY);
        String appSecret = dapSecretSupport.getAppSecret(UserUtils.getToken(), appToken);
        if (ObjectUtils.isEmpty(appSecret)) {
            throw new BusinessException(CommonErrorCode.APP_ID_SECRET_NONE, new Object[]{UserUtils.getSysId()});
        }

        // Verify the signature (computed over the raw, still-encrypted body)
        if (sign) {
            sign(request, appSecret, signHeader, bodyString, method);
        }

        // Decrypt and replace the body seen by the handler
        if (encrypt) {
            String data = decrypt(request, bodyString, appSecret);
            encryptRequestWrapper.setBodyString(data.getBytes(StandardCharsets.UTF_8));
        }
        filterChain.doFilter(encryptRequestWrapper, response);
    }

    @Override
    public int getOrder() {
        return FilterOrderEnum.API_ENCRYPT.order();
    }

    private static DapEncrypt getDapEncrypt(HandlerMethod handlerMethod) {
        return findHandlerAnnotation(handlerMethod, DapEncrypt.class);
    }

    private static DapSign getDapSign(HandlerMethod handlerMethod) {
        return findHandlerAnnotation(handlerMethod, DapSign.class);
    }

    /**
     * Resolves an annotation for a handler. A method-level declaration takes precedence over one on the
     * controller class, so that a method can override the class-wide settings (e.g. {@code force}).
     *
     * @return the annotation, or {@code null} when neither the method nor the class declares it
     */
    private static <A extends Annotation> A findHandlerAnnotation(HandlerMethod handlerMethod, Class<A> annotationType) {
        A annotation = handlerMethod.getMethodAnnotation(annotationType);
        return annotation != null ? annotation : AnnotationUtils.findAnnotation(handlerMethod.getBeanType(), annotationType);
    }

    /**
     * Decides whether the signature must be verified.
     * Verification is skipped only when all of these hold: force=false, the environment sign switch is not
     * TRUE (null or false), and the client sent no signature header.
     */
    private boolean isSign(DapSign dapSign, String signHeader) {
        if (dapSign == null) {
            return false;
        }
        return dapSign.force() || Boolean.TRUE.equals(dapEnv.getSign()) || !ObjectUtils.isEmpty(signHeader);
    }

    /**
     * Decides whether the request body must be decrypted.
     * Only POST/PUT endpoints annotated with {@link DapEncrypt} qualify, and for those the body must not be empty.
     * <ul>
     *   <li>force=true or environment encrypt switch TRUE: always decrypt.</li>
     *   <li>Otherwise: decrypt only when the body's {@code eData} field has a value.</li>
     * </ul>
     *
     * @throws BusinessException ENCRYPT_REQUEST_BODY_EMPTY when the non-forced path sees an empty body
     */
    private boolean isEncrypt(HttpServletRequest request, DapEncrypt dapEncrypt, String bodyString) {
        if (!BODY_METHODS.contains(request.getMethod()) || dapEncrypt == null) {
            return false;
        }
        if (!dapEncrypt.force() && !Boolean.TRUE.equals(dapEnv.getEncrypt())) {
            // Check whether the body carries an eData value
            if (ObjectUtils.isEmpty(bodyString)) {
                throw new BusinessException(CommonErrorCode.ENCRYPT_REQUEST_BODY_EMPTY);
            }
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("请求参数body：{}", bodyString);
            }

            DapEncryptDTO dto = JsonUtils.readValue(bodyString, DapEncryptDTO.class);
            return dto != null && !ObjectUtils.isEmpty(dto.geteData());
        }
        return true;
    }

    /**
     * Decrypts the {@code eData} field of the body with the app secret and marks the request as decrypted.
     *
     * @return the decrypted payload
     * @throws BusinessException ENCRYPT_REQUEST_BODY_EMPTY when the body is empty or has no {@code eData}
     *                           (reachable on the forced path, where {@link #isEncrypt} does not inspect the body)
     */
    private static String decrypt(HttpServletRequest request, String bodyString, String appSecret) {
        DapEncryptDTO dto = ObjectUtils.isEmpty(bodyString) ? null : JsonUtils.readValue(bodyString, DapEncryptDTO.class);
        if (dto == null || ObjectUtils.isEmpty(dto.geteData())) {
            throw new BusinessException(CommonErrorCode.ENCRYPT_REQUEST_BODY_EMPTY);
        }
        String data = SecureUtils.decryptBase64(dto.geteData(), appSecret);
        request.setAttribute(EncryptConstants.ENCRYPT_STATUS_KEY, true);
        request.setAttribute(EncryptConstants.APP_SECRET_KEY, appSecret);
        return data;
    }

    /**
     * Verifies the request signature: header fields, timestamp window, nonce replay, then HMAC-SHA256 over
     * the sign header fields, query parameters and body (each sorted and joined).
     *
     * @throws BusinessException on expired timestamp, reused nonce or signature mismatch
     */
    private void sign(HttpServletRequest request, String appSecret, String signHeader, String bodyString, Method method) {
        // 3. Parse the signature header and validate its fields. The header is re-decoded from ISO-8859-1
        // bytes as UTF-8 (presumably because the container decoded the raw bytes as ISO-8859-1).
        // A missing header (possible when verification is forced) is parsed as empty instead of NPE-ing.
        // NOTE(review): confirm DapSignInfo.get/verify reject an empty header with a proper business error.
        String parsed = signHeader == null
                ? ""
                : new String(signHeader.getBytes(StandardCharsets.ISO_8859_1), StandardCharsets.UTF_8);
        DapSignInfo signInfo = DapSignInfo.get(parsed);
        DapSignInfo.verify(signInfo);

        // 4. Timestamp must lie within +/- EXPIRE_TIME of now
        LocalDateTime dateTime = LocalDateTimeUtil.parse(signInfo.getTimestamp(), DatePattern.PURE_DATETIME_PATTERN);
        if (Math.abs(System.currentTimeMillis() - LocalDateTimeUtil.toEpochMilli(dateTime)) > EncryptConstants.EXPIRE_TIME) {
            throw new BusinessException(CommonErrorCode.SIGN_TIMESTAMP_EXPIRED, new Object[]{signInfo.getTimestamp()});
        }

        // 5. Reject a reused nonce. The TTL is twice the window so the key outlives every instant at which
        // the timestamp check above could still pass. The nonce is claimed BEFORE the signature is compared:
        // keep this order, because the mismatch error below echoes the expected signature, and burning the
        // nonce first keeps that value from ever being replayed.
        String key = getKey(signInfo.getNonce());
        boolean absent = RedisUtils.setIfAbsent(key, 1, Duration.ofMillis(2L * EncryptConstants.EXPIRE_TIME));
        if (!absent) {
            throw new BusinessException(CommonErrorCode.SIGN_DUPLICATE_REQUEST_ERROR, new Object[]{signInfo.getNonce()});
        }

        // 6. Build signString: sign header fields, query params and body, each sorted, joined with '&'
        Map<String, String> signInfoMap = JsonUtils.objToMap(signInfo);
        signInfoMap.remove("sign");
        String signHeaderParams = MapUtil.sortJoin(signInfoMap, EncryptConstants.AND, EncryptConstants.EQUALS_SIGN, true);

        Map<String, String> paramMap = ServletUtil.getParamMap(request);
        String requestParams = paramMap.isEmpty()
                ? null
                : MapUtil.sortJoin(paramMap, EncryptConstants.AND, EncryptConstants.EQUALS_SIGN, true);

        String beanParams = bodySignParams(bodyString, method);

        List<String> linkedList = ListUtil.toLinkedList(signHeaderParams, requestParams, beanParams);
        CollUtil.removeEmpty(linkedList);
        String signString = CollUtil.join(linkedList, EncryptConstants.AND);

        // 7. Compare signatures.
        // NOTE(review): the error carries the *expected* signature back to the caller, i.e. a signing oracle
        // for anyone holding an app token; consider returning the received signature and logging the expected one.
        String encryptSign = SecureUtil.hmacSha256(appSecret).digestBase64(signString, CharsetUtil.CHARSET_UTF_8, true);
        if (!Objects.equals(encryptSign, signInfo.getSign())) {
            throw new BusinessException(CommonErrorCode.SIGN_INCONSISTENT_SIGNATURES_ERROR, new Object[]{encryptSign});
        }

        request.setAttribute(EncryptConstants.SIGN_STATUS_KEY, true);
        request.setAttribute(EncryptConstants.APP_SECRET_KEY, appSecret);
    }

    /**
     * Builds the body part of the signing string.
     * Simple body types contribute the raw body. Any other type has its body parsed as a JSON object,
     * sorted and joined.
     * NOTE(review): a JSON array body (e.g. {@code @RequestBody List<X>}) cannot be read as a Map here.
     *
     * @return the body part, or {@code null} when the body is empty or the handler has no
     *         {@code @RequestBody}/{@code @ModelAttribute} parameter
     */
    private static String bodySignParams(String bodyString, Method method) {
        if (ObjectUtils.isEmpty(bodyString)) {
            return null;
        }
        Class<?> bodyType = findBodyParameterType(method);
        if (bodyType == null) {
            return null;
        }
        if (BeanUtils.isSimpleProperty(bodyType)) {
            return bodyString;
        }
        Map<String, String> beanMap = JsonUtils.jsonToObj(bodyString, Map.class);
        return MapUtil.sortJoin(beanMap, EncryptConstants.AND, EncryptConstants.EQUALS_SIGN, true);
    }

    /**
     * Returns the type of the first handler parameter annotated with {@code @RequestBody} or
     * {@code @ModelAttribute}, or {@code null} if there is none.
     */
    private static Class<?> findBodyParameterType(Method method) {
        for (int i = 0; i < method.getParameterCount(); i++) {
            MethodParameter parameter = new MethodParameter(method, i);
            for (Annotation annotation : parameter.getParameterAnnotations()) {
                if (annotation instanceof RequestBody || annotation instanceof ModelAttribute) {
                    return parameter.getParameterType();
                }
            }
        }
        return null;
    }

    /**
     * Redis key recording a used nonce, namespaced by the (lower-cased) application name.
     * Lower-casing uses {@link Locale#ROOT} so the key does not depend on the server's default locale.
     */
    private String getKey(String nonce) {
        String appName = ObjectUtils.isEmpty(dapEnv.getAppName())
                ? EncryptConstants.MIDDLEWARE
                : dapEnv.getAppName().toLowerCase(Locale.ROOT);
        return String.format(EncryptConstants.REDIS_DWPAY_SIGN_NONCE, appName, nonce);
    }

    /**
     * Looks up the handler that the default {@link RequestMappingHandlerMapping} bean would select for the request.
     *
     * @return the execution chain, or {@code null} when there is no web application context, no matching
     *         handler, or the lookup fails (e.g. method/media type not supported). In the failure case the
     *         DispatcherServlet performs the same lookup and produces the proper 4xx response without ever
     *         invoking a handler, so passing the request through here is safe.
     */
    private static HandlerExecutionChain getHandlerExecutionChain(HttpServletRequest request) {
        ApplicationContext applicationContext = WebApplicationContextUtils.getWebApplicationContext(request.getServletContext());
        if (applicationContext == null) {
            return null;
        }
        // Look the bean up by name: other RequestMappingHandlerMapping beans may exist alongside the default one.
        String mappingName = Introspector.decapitalize(RequestMappingHandlerMapping.class.getSimpleName());
        RequestMappingHandlerMapping handlerMapping = applicationContext.getBean(mappingName, RequestMappingHandlerMapping.class);

        try {
            return handlerMapping.getHandler(request);
        } catch (Exception e) {
            LOGGER.debug("Handler lookup failed for {} {}, skipping sign/encrypt processing", request.getMethod(), request.getRequestURI(), e);
            return null;
        }
    }
}