package com.digiwin.dap.middle.encrypt.filter;

import cn.hutool.extra.servlet.ServletUtil;
import com.digiwin.dap.middle.encrypt.AnnotationUtil;
import com.digiwin.dap.middle.encrypt.config.EncryptRequestWrapper;
import com.digiwin.dap.middle.encrypt.constant.EncryptConstants;
import com.digiwin.dap.middle.encrypt.domain.annotation.DapEncrypt;
import com.digiwin.dap.middle.encrypt.domain.annotation.DapSign;
import com.digiwin.dap.middle.encrypt.support.DapSecretSupport;
import com.digiwin.dap.middleware.auth.AppAuthContextHolder;
import com.digiwin.dap.middleware.cache.RedisUtils;
import com.digiwin.dap.middleware.commons.crypto.AES;
import com.digiwin.dap.middleware.commons.crypto.SignUtils;
import com.digiwin.dap.middleware.commons.util.StrUtils;
import com.digiwin.dap.middleware.constant.DapHttpHeaders;
import com.digiwin.dap.middleware.domain.CommonErrorCode;
import com.digiwin.dap.middleware.domain.DapEnv;
import com.digiwin.dap.middleware.domain.FilterOrderEnum;
import com.digiwin.dap.middleware.exception.BusinessException;
import com.digiwin.dap.middleware.exception.UnauthorizedException;
import com.digiwin.dap.middleware.util.JsonUtils;
import com.digiwin.dap.middleware.util.UserUtils;
import com.digiwin.dap.middleware.util.VerifyUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.context.ApplicationContext;
import org.springframework.core.MethodParameter;
import org.springframework.core.Ordered;
import org.springframework.util.CollectionUtils;
import org.springframework.web.bind.annotation.ModelAttribute;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.context.support.WebApplicationContextUtils;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.beans.Introspector;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * Reads request parameters and processes them: signature verification, payload decryption and parameter filtering.
 *
 * @author ChenZhuang
 * @date 2024-1-22 17:14:07
 */
public class EncryptRequestFilter extends OncePerRequestFilter implements Ordered {

    private static final Logger LOGGER = LoggerFactory.getLogger(EncryptRequestFilter.class);
    /** HTTP methods whose request body may carry an encrypted payload. */
    private static final List<String> BODY_METHODS = Arrays.asList(RequestMethod.POST.name(), RequestMethod.PUT.name());

    private final DapEnv dapEnv;
    private final DapSecretSupport dapSecretSupport;

    public EncryptRequestFilter(DapEnv dapEnv, DapSecretSupport dapSecretSupport) {
        this.dapEnv = dapEnv;
        this.dapSecretSupport = dapSecretSupport;
    }

    /**
     * Resolves the handler chain that the {@link RequestMappingHandlerMapping} bean would select for this request.
     *
     * @param request current request
     * @return the matching handler chain, or {@code null} when no mapping matches
     * @throws RuntimeException wrapping any exception raised by the handler mapping
     */
    private static HandlerExecutionChain getHandlerExecutionChain(HttpServletRequest request) {
        ApplicationContext applicationContext = WebApplicationContextUtils.getWebApplicationContext(request.getServletContext());
        String mappingName = Introspector.decapitalize(RequestMappingHandlerMapping.class.getSimpleName());
        RequestMappingHandlerMapping handlerMapping = applicationContext.getBean(mappingName, RequestMappingHandlerMapping.class);
        try {
            return handlerMapping.getHandler(request);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the body part of the signature input.
     * <p>
     * Only the first handler parameter annotated with {@link RequestBody} or {@link ModelAttribute} is considered.
     * Simple-typed parameters are signed over the raw body. Bean-typed parameters are signed over
     * {@link SignUtils#sortParam(String)} of the body.
     *
     * @param bodyString raw request body
     * @param method     handler method
     * @return the body string to sign, or {@code null} when the body is empty or no parameter binds the body
     */
    private static String getBodyParam(String bodyString, Method method) {
        if (StrUtils.isEmpty(bodyString)) {
            return null;
        }
        for (int i = 0; i < method.getParameterCount(); i++) {
            MethodParameter mp = new MethodParameter(method, i);
            for (Annotation anno : mp.getParameterAnnotations()) {
                if (anno instanceof RequestBody || anno instanceof ModelAttribute) {
                    return BeanUtils.isSimpleProperty(mp.getParameterType()) ? bodyString : SignUtils.sortParam(bodyString);
                }
            }
        }
        return null;
    }

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        HandlerExecutionChain handlerChain = getHandlerExecutionChain(request);
        // Only handler methods can carry @DapSign/@DapEncrypt; anything else passes through untouched.
        if (handlerChain == null || !(handlerChain.getHandler() instanceof HandlerMethod)) {
            filterChain.doFilter(request, response);
            return;
        }

        HandlerMethod handlerMethod = (HandlerMethod) handlerChain.getHandler();
        Method method = handlerMethod.getMethod();

        DapSign dapSign = AnnotationUtil.getAnnotation(handlerMethod, DapSign.class);
        DapEncrypt dapEncrypt = AnnotationUtil.getAnnotation(handlerMethod, DapEncrypt.class);

        String signHeader = request.getHeader(DapHttpHeaders.APP_ARGS.getHeader());
        boolean sign = isSign(dapSign, signHeader);

        // Wrap the request so the body can be read here and then re-read (or replaced) downstream.
        EncryptRequestWrapper encryptRequestWrapper = new EncryptRequestWrapper(request);
        String bodyString = encryptRequestWrapper.getBodyString();
        boolean encrypt = isEncrypt(request, dapEncrypt, bodyString);
        String appToken = request.getHeader(DapHttpHeaders.APP_TOKEN.getHeader());
        String userToken = request.getHeader(DapHttpHeaders.USER_TOKEN.getHeader());
        if (!sign && dapSign != null) {
            LOGGER.error("未处理加签数据：路由:{},appToken:{},userToken:{}", AppAuthContextHolder.getContext().getRequestInfo().getPath(), appToken, userToken);
        }
        if (!encrypt && dapEncrypt != null) {
            LOGGER.error("未处理加密数据：路由:{},appToken:{},userToken:{}", AppAuthContextHolder.getContext().getRequestInfo().getPath(), appToken, userToken);
        }

        List<String> appSecretList = new ArrayList<>();
        if (dapSign != null || dapEncrypt != null) {
            appSecretList = dapSecretSupport.getEnableAppSecret(UserUtils.getToken(), appToken);
        }
        // Expose the first enabled secret by default; a successful signature check replaces it with the matching one.
        if (!CollectionUtils.isEmpty(appSecretList)) {
            request.setAttribute(EncryptConstants.APP_SECRET_KEY, appSecretList.get(0));
        }
        if (!sign && !encrypt) {
            filterChain.doFilter(encryptRequestWrapper, response);
            return;
        }
        if (CollectionUtils.isEmpty(appSecretList)) {
            throw new BusinessException(CommonErrorCode.APP_ID_SECRET_NONE, new Object[]{UserUtils.getSysId()});
        }
        if (StrUtils.isEmpty(UserUtils.getSysId())) {
            throw new BusinessException(CommonErrorCode.APP_ID_NONE);
        }
        // Signed requests use the secret that verified the signature. Encrypt-only requests previously had no
        // secret here and always failed; they now use the first enabled secret (the one exposed above).
        String appSecret = sign
                ? verifySign(request, dapSign, signHeader, bodyString, method, appSecretList)
                : appSecretList.get(0);
        if (StrUtils.isEmpty(appSecret)) {
            throw new BusinessException(CommonErrorCode.APP_ID_SECRET_NONE, new Object[]{UserUtils.getSysId()});
        }
        if (encrypt) {
            com.digiwin.dap.middleware.domain.DapEncrypt body = StrUtils.isEmpty(bodyString)
                    ? null
                    : JsonUtils.readValue(bodyString, com.digiwin.dap.middleware.domain.DapEncrypt.class);
            // In force/global-encrypt mode isEncrypt does not validate the body, so guard here instead of hitting an NPE.
            if (body == null || StrUtils.isEmpty(body.geteData())) {
                throw new BusinessException(CommonErrorCode.ENCRYPT_REQUEST_BODY_EMPTY);
            }
            String data = AES.decryptIvCBC(body.geteData(), appSecret);
            encryptRequestWrapper.setBodyString(data.getBytes(StandardCharsets.UTF_8));
        }
        filterChain.doFilter(encryptRequestWrapper, response);
    }

    /**
     * Verifies the request signature against each enabled app secret.
     * <p>
     * On success, the matching secret is stored as the {@link EncryptConstants#APP_SECRET_KEY} request attribute.
     * The nonce lock from {@link #getLock(boolean, String)} is passed to {@link VerifyUtils#sign}.
     *
     * @return the secret that verified the signature
     * @throws UnauthorizedException SIGN_INCONSISTENT_SIGNATURE when parsing or verification fails,
     *                               or when no secret matches
     */
    private String verifySign(HttpServletRequest request, DapSign dapSign, String signHeader, String bodyString,
                              Method method, List<String> appSecretList) {
        try {
            com.digiwin.dap.middleware.domain.DapSign signInfo = com.digiwin.dap.middleware.domain.DapSign.get(signHeader);
            VerifyUtils.sign(signInfo, () -> getLock(dapSign.resubmit(), signInfo.getNonce()));
            Map<String, String> queryParams = ServletUtil.getParamMap(request);
            String bodyParams = getBodyParam(bodyString, method);
            for (String candidate : appSecretList) {
                if (SignUtils.verify(JsonUtils.objToMap(signInfo), candidate, queryParams, bodyParams)) {
                    request.setAttribute(EncryptConstants.APP_SECRET_KEY, candidate);
                    return candidate;
                }
            }
            throw new UnauthorizedException(CommonErrorCode.SIGN_INCONSISTENT_SIGNATURE);
        } catch (Exception e) {
            // Never log the secrets themselves (they are signing/encryption keys); keep the stack trace instead.
            LOGGER.error("验签失败：body:{},enable secret count:{}", bodyString, appSecretList.size(), e);
            throw new UnauthorizedException(CommonErrorCode.SIGN_INCONSISTENT_SIGNATURE, e.getMessage());
        }
    }

    /**
     * Anti-replay guard for a signature nonce.
     *
     * @param resubmit when {@code true}, resubmission is allowed and no lock is taken
     * @param nonce    nonce taken from the signature header
     * @return {@code true} if the nonce may be used. This is {@code RedisUtils.setIfAbsent} on a key built from
     *         the app name and nonce, with a TTL of {@code 2 * EncryptConstants.EXPIRE_TIME} ms.
     */
    private boolean getLock(boolean resubmit, String nonce) {
        if (resubmit) {
            return true;
        }
        String key = String.format(EncryptConstants.REDIS_DWPAY_SIGN_NONCE, dapEnv.getAppName(), nonce);
        return RedisUtils.setIfAbsent(key, 1, Duration.ofMillis(2L * EncryptConstants.EXPIRE_TIME));
    }

    @Override
    public int getOrder() {
        return FilterOrderEnum.API_ENCRYPT.order();
    }

    /**
     * Checks whether the signature must be verified. This requires {@code dapSign} to be present and any one of:
     * <ol>
     *     <li>startup property {@code dap.middleware.sign=true}</li>
     *     <li>annotation attribute {@code force=true}</li>
     *     <li>annotation attribute {@code request=true} and the signature header is non-empty</li>
     * </ol>
     * Compatibility configuration: with {@code sign=false && force=false && request=true}, callers that send
     * the signature header get verified and callers that omit it do not.
     */
    private boolean isSign(DapSign dapSign, String signHeader) {
        if (dapSign == null) {
            return false;
        }
        if (Boolean.TRUE.equals(dapEnv.getSign()) || dapSign.force()) {
            return true;
        }
        return dapSign.request() && StrUtils.isNotEmpty(signHeader);
    }

    /**
     * Checks whether the body must be decrypted. This requires a POST or PUT request with {@code dapEncrypt}
     * present, and either:
     * <ol>
     *     <li>startup property {@code dap.middleware.encrypt=true}</li>
     *     <li>annotation attribute {@code force=true}</li>
     * </ol>
     * Compatibility configuration: with {@code encrypt=false && force=false}, the body is decrypted only if it
     * carries a non-empty {@code eData}.
     *
     * @throws BusinessException ENCRYPT_REQUEST_BODY_EMPTY in compatibility mode when the body is empty
     */
    private boolean isEncrypt(HttpServletRequest request, DapEncrypt dapEncrypt, String bodyString) {
        if (!BODY_METHODS.contains(request.getMethod()) || dapEncrypt == null) {
            return false;
        }
        if (Boolean.TRUE.equals(dapEnv.getEncrypt()) || dapEncrypt.force()) {
            return true;
        }
        // A POST/PUT endpoint annotated with @DapEncrypt must receive a non-empty body.
        if (StrUtils.isEmpty(bodyString)) {
            throw new BusinessException(CommonErrorCode.ENCRYPT_REQUEST_BODY_EMPTY);
        }
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("请求参数body：{}", bodyString);
        }
        com.digiwin.dap.middleware.domain.DapEncrypt body = JsonUtils.readValue(bodyString, com.digiwin.dap.middleware.domain.DapEncrypt.class);
        return body != null && StrUtils.isNotEmpty(body.geteData());
    }
}
