package com.digiwin.dap.middle.encrypt.filter;

import cn.hutool.extra.servlet.ServletUtil;
import com.digiwin.dap.middle.encrypt.config.EncryptRequestWrapper;
import com.digiwin.dap.middle.encrypt.constant.EncryptConstants;
import com.digiwin.dap.middle.encrypt.domain.annotation.DapEncrypt;
import com.digiwin.dap.middle.encrypt.domain.annotation.DapSign;
import com.digiwin.dap.middle.encrypt.support.DapSecretSupport;
import com.digiwin.dap.middleware.cache.RedisUtils;
import com.digiwin.dap.middleware.commons.crypto.AES;
import com.digiwin.dap.middleware.commons.crypto.SignUtils;
import com.digiwin.dap.middleware.constant.DapHttpHeaders;
import com.digiwin.dap.middleware.domain.CommonErrorCode;
import com.digiwin.dap.middleware.domain.DapEnv;
import com.digiwin.dap.middleware.domain.FilterOrderEnum;
import com.digiwin.dap.middleware.exception.BusinessException;
import com.digiwin.dap.middleware.exception.UnauthorizedException;
import com.digiwin.dap.middleware.util.JsonUtils;
import com.digiwin.dap.middleware.util.VerifyUtils;
import com.digiwin.dap.middleware.util.UserUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.context.ApplicationContext;
import org.springframework.core.MethodParameter;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.context.support.WebApplicationContextUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.beans.Introspector;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * Servlet filter that inspects the incoming request before it reaches the controller and,
 * depending on the {@link DapSign} / {@link DapEncrypt} annotations found on the resolved
 * handler, verifies the request signature and/or decrypts the request body.
 *
 * <p>The request body is always read through an {@link EncryptRequestWrapper}, and that wrapper
 * (not the raw request) is what is handed to the rest of the filter chain once a handler method
 * has been resolved.
 *
 * @author ChenZhuang
 * @date 2024-1-22 17:14:07
 */
public class EncryptRequestFilter extends OncePerRequestFilter implements Ordered {

    private static final Logger LOGGER = LoggerFactory.getLogger(EncryptRequestFilter.class);
    /** HTTP methods for which an encrypted request body is expected. */
    private static final List<String> BODY_METHODS = Arrays.asList(RequestMethod.POST.name(), RequestMethod.PUT.name());

    private final DapEnv dapEnv;
    private final DapSecretSupport dapSecretSupport;

    public EncryptRequestFilter(DapEnv dapEnv, DapSecretSupport dapSecretSupport) {
        this.dapEnv = dapEnv;
        this.dapSecretSupport = dapSecretSupport;
    }

    /**
     * Looks up {@link DapEncrypt} on the handler's bean type first and falls back to the
     * handler method when the type carries no such annotation.
     *
     * @param handlerMethod resolved controller method
     * @return the annotation, or {@code null} when neither the type nor the method is annotated
     */
    private static DapEncrypt getDapEncrypt(HandlerMethod handlerMethod) {
        DapEncrypt dapEncrypt = AnnotationUtils.findAnnotation(handlerMethod.getBeanType(), DapEncrypt.class);
        if (dapEncrypt == null) {
            dapEncrypt = handlerMethod.getMethodAnnotation(DapEncrypt.class);
        }
        return dapEncrypt;
    }

    /**
     * Looks up {@link DapSign} on the handler's bean type first and falls back to the
     * handler method when the type carries no such annotation.
     *
     * @param handlerMethod resolved controller method
     * @return the annotation, or {@code null} when neither the type nor the method is annotated
     */
    private static DapSign getDapSign(HandlerMethod handlerMethod) {
        DapSign dapSign = AnnotationUtils.findAnnotation(handlerMethod.getBeanType(), DapSign.class);
        if (dapSign == null) {
            dapSign = handlerMethod.getMethodAnnotation(DapSign.class);
        }
        return dapSign;
    }

    /**
     * Resolves the handler chain for the request through the {@code requestMappingHandlerMapping} bean.
     *
     * @param request current request
     * @return the handler chain, or {@code null} when no request mapping matches
     * @throws IllegalStateException when no web application context is bound to the servlet context
     * @throws RuntimeException      wrapping any exception raised by the handler mapping
     */
    private static HandlerExecutionChain getHandlerExecutionChain(HttpServletRequest request) {
        ApplicationContext applicationContext = WebApplicationContextUtils.getWebApplicationContext(request.getServletContext());
        if (applicationContext == null) {
            // Fail with an explicit message instead of an NPE; a security filter must not silently skip its checks.
            throw new IllegalStateException("No WebApplicationContext found in ServletContext; cannot resolve request handler");
        }
        String mappingName = Introspector.decapitalize(RequestMappingHandlerMapping.class.getSimpleName());
        RequestMappingHandlerMapping handlerMapping = applicationContext.getBean(mappingName, RequestMappingHandlerMapping.class);

        try {
            return handlerMapping.getHandler(request);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Produces the body string that takes part in signature verification.
     *
     * <p>The first handler parameter annotated with {@link RequestBody} or {@link ModelAttribute}
     * decides the result: simple-typed parameters use the raw body, every other type uses the body
     * after {@link SignUtils#sortParam(String)}.
     *
     * @param bodyString raw request body, may be empty
     * @param method     controller method about to be invoked
     * @return the body to sign, or {@code null} when the body is empty or no parameter is annotated
     */
    private static String getBodyParam(String bodyString, Method method) {
        if (ObjectUtils.isEmpty(bodyString)) {
            return null;
        }
        for (int i = 0; i < method.getParameterCount(); i++) {
            MethodParameter parameter = new MethodParameter(method, i);
            for (Annotation annotation : parameter.getParameterAnnotations()) {
                if (annotation instanceof RequestBody || annotation instanceof ModelAttribute) {
                    return BeanUtils.isSimpleProperty(parameter.getParameterType())
                            ? bodyString
                            : SignUtils.sortParam(bodyString);
                }
            }
        }
        return null;
    }

    /**
     * Applies signature verification and body decryption to the request as required by the
     * annotations on the resolved handler, then continues the filter chain.
     *
     * @throws BusinessException     when the app id or app secret cannot be determined, or the
     *                               encrypted body is missing
     * @throws UnauthorizedException when signature verification fails
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        HandlerExecutionChain handlerChain = getHandlerExecutionChain(request);
        if (handlerChain == null) {
            filterChain.doFilter(request, response);
            return;
        }

        Object handler = handlerChain.getHandler();
        if (!(handler instanceof HandlerMethod)) {
            // The chain may carry a handler that is not a controller method (e.g. the handler Spring
            // substitutes for CORS pre-flight requests); there are no annotations to honour then.
            filterChain.doFilter(request, response);
            return;
        }
        HandlerMethod handlerMethod = (HandlerMethod) handler;

        DapSign dapSign = getDapSign(handlerMethod);
        DapEncrypt dapEncrypt = getDapEncrypt(handlerMethod);

        String signHeader = request.getHeader(DapHttpHeaders.APP_ARGS.getHeader());
        boolean sign = isSign(dapSign, signHeader);

        // Read the request body through the wrapper.
        EncryptRequestWrapper encryptRequestWrapper = new EncryptRequestWrapper(request);
        String bodyString = encryptRequestWrapper.getBodyString();
        boolean encrypt = isEncrypt(request, dapEncrypt, bodyString);

        if (!sign && !encrypt) {
            filterChain.doFilter(encryptRequestWrapper, response);
            return;
        }

        if (ObjectUtils.isEmpty(UserUtils.getSysId())) {
            throw new BusinessException(CommonErrorCode.APP_ID_NONE);
        }

        String appToken = request.getHeader(DapHttpHeaders.APP_TOKEN.getHeader());
        String appSecret = dapSecretSupport.getAppSecret(UserUtils.getToken(), appToken);
        if (ObjectUtils.isEmpty(appSecret)) {
            throw new BusinessException(CommonErrorCode.APP_ID_SECRET_NONE, new Object[]{UserUtils.getSysId()});
        }

        // Signature verification
        if (sign) {
            verifySign(request, dapSign, signHeader, bodyString, handlerMethod.getMethod(), appSecret);
        }

        // Decryption
        if (encrypt) {
            decryptBody(request, encryptRequestWrapper, bodyString, appSecret);
        }
        filterChain.doFilter(encryptRequestWrapper, response);
    }

    /**
     * Verifies the signature header against the query parameters and body, and marks the request
     * as signed on success.
     *
     * <p>Any failure inside the verification, whatever its type, is reported to the caller as an
     * {@link UnauthorizedException} carrying the original message; the original exception is logged.
     */
    private void verifySign(HttpServletRequest request, DapSign dapSign, String signHeader,
                            String bodyString, Method method, String appSecret) {
        try {
            com.digiwin.dap.middleware.domain.DapSign signInfo = com.digiwin.dap.middleware.domain.DapSign.get(signHeader);
            VerifyUtils.sign(signInfo, () -> getLock(dapSign.resubmit(), signInfo.getNonce()));
            Map<String, String> queryParams = ServletUtil.getParamMap(request);
            String bodyParams = getBodyParam(bodyString, method);
            boolean verified = SignUtils.verify(JsonUtils.objToMap(signInfo), appSecret, queryParams, bodyParams);
            if (!verified) {
                throw new UnauthorizedException(CommonErrorCode.SIGN_INCONSISTENT_SIGNATURE);
            }
            request.setAttribute(EncryptConstants.SIGN_STATUS_KEY, true);
            request.setAttribute(EncryptConstants.APP_SECRET_KEY, appSecret);
        } catch (Exception e) {
            // The exception thrown below only keeps the message, so log the cause here to keep the stack trace.
            LOGGER.warn("Signature verification failed for {} {}", request.getMethod(), request.getRequestURI(), e);
            throw new UnauthorizedException(CommonErrorCode.SIGN_INCONSISTENT_SIGNATURE, e.getMessage());
        }
    }

    /**
     * Decrypts the {@code eData} field of the request body with the app secret and installs the
     * plaintext as the wrapper's new body.
     *
     * @throws BusinessException when the body cannot be mapped to an envelope or carries no {@code eData}
     */
    private void decryptBody(HttpServletRequest request, EncryptRequestWrapper encryptRequestWrapper,
                             String bodyString, String appSecret) {
        com.digiwin.dap.middleware.domain.DapEncrypt body = JsonUtils.readValue(bodyString, com.digiwin.dap.middleware.domain.DapEncrypt.class);
        if (body == null || ObjectUtils.isEmpty(body.geteData())) {
            // Reachable when decryption is forced: nothing has checked eData before this point in that mode.
            throw new BusinessException(CommonErrorCode.ENCRYPT_REQUEST_BODY_EMPTY);
        }
        String data = AES.decryptIvCBC(body.geteData(), appSecret);
        request.setAttribute(EncryptConstants.ENCRYPT_STATUS_KEY, true);
        request.setAttribute(EncryptConstants.APP_SECRET_KEY, appSecret);
        encryptRequestWrapper.setBodyString(data.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Decides whether the request identified by {@code nonce} may proceed.
     *
     * @param resubmit when {@code true} the nonce is not checked and the request is always allowed
     * @param nonce    nonce taken from the signature header
     * @return {@code true} when resubmission is allowed or the nonce key was newly stored in Redis
     */
    private boolean getLock(boolean resubmit, String nonce) {
        if (resubmit) {
            return true;
        }
        String key = String.format(EncryptConstants.REDIS_DWPAY_SIGN_NONCE, dapEnv.getAppName(), nonce);
        // Boolean.TRUE.equals guards against a null result being unboxed.
        return Boolean.TRUE.equals(RedisUtils.setIfAbsent(key, 1, Duration.ofMillis(2L * EncryptConstants.EXPIRE_TIME)));
    }

    @Override
    public int getOrder() {
        return FilterOrderEnum.API_ENCRYPT.order();
    }

    /**
     * Decides whether the signature has to be verified.
     *
     * @param dapSign    annotation found on the handler, may be {@code null}
     * @param signHeader value of the signature header, may be empty
     * @return {@code false} when the handler is not annotated, or when verification is not forced,
     *         not enabled in the environment and no signature header was sent; {@code true} otherwise
     */
    private boolean isSign(DapSign dapSign, String signHeader) {
        if (dapSign == null) {
            return false;
        }
        // Compatibility: skip verification only when force=false, env sign is null/false and the header is empty.
        return dapSign.force() || Boolean.TRUE.equals(dapEnv.getSign()) || !ObjectUtils.isEmpty(signHeader);
    }

    /**
     * Decides whether the request body has to be decrypted.
     *
     * @param request    current request
     * @param dapEncrypt annotation found on the handler, may be {@code null}
     * @param bodyString raw request body
     * @return {@code false} for non POST/PUT requests and un-annotated handlers; otherwise
     *         {@code true} when decryption is forced or enabled in the environment, or when the
     *         body carries a non-empty {@code eData} field
     * @throws BusinessException when an annotated POST/PUT request has an empty body
     */
    private boolean isEncrypt(HttpServletRequest request, DapEncrypt dapEncrypt, String bodyString) {
        if (!BODY_METHODS.contains(request.getMethod()) || dapEncrypt == null) {
            return false;
        }
        // A POST/PUT endpoint annotated with @DapEncrypt must receive a body, whichever mode applies.
        if (ObjectUtils.isEmpty(bodyString)) {
            throw new BusinessException(CommonErrorCode.ENCRYPT_REQUEST_BODY_EMPTY);
        }
        if (dapEncrypt.force() || Boolean.TRUE.equals(dapEnv.getEncrypt())) {
            return true;
        }
        // Compatibility mode (force=false, env encrypt null/false): decrypt only when eData is present.
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("请求参数body：{}", bodyString);
        }

        com.digiwin.dap.middleware.domain.DapEncrypt dto = JsonUtils.readValue(bodyString, com.digiwin.dap.middleware.domain.DapEncrypt.class);
        return dto != null && !ObjectUtils.isEmpty(dto.geteData());
    }
}