package com.digiwin.dap.middle.cache.limiter;

import com.digiwin.dap.middle.cache.limiter.constants.RateLimiterConstant;
import com.digiwin.dap.middle.cache.limiter.enums.RateLimitingDimensionEnum;
import com.digiwin.dap.middle.cache.limiter.enums.RateLimitingDimensionOperatorEnum;
import com.digiwin.dap.middleware.commons.util.StrUtils;
import com.digiwin.dap.middleware.util.NetUtils;
import com.digiwin.dap.middleware.util.UserUtils;
import org.aspectj.lang.Signature;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Builds the cache keys used by the rate limiter for an intercepted method invocation.
 * <p>
 * Every key has the form
 * {@code RATE_LIMITER_PREFIX + HASH_TAG + <target class name> + "." + <method name> + ":" + <dimension part>}.
 * The dimension part depends on the configured {@link RateLimitingDimensionOperatorEnum}:
 * <ul>
 *     <li>{@code OR}: one key per resolved dimension value, so each value is limited independently;</li>
 *     <li>{@code AND}: a single key whose dimension part is all resolved values joined with {@code ':'}.</li>
 * </ul>
 * Dimension values that cannot be resolved are skipped. This covers an empty tenant/user/sys id, no servlet
 * request bound to the current thread for {@code IP}, and a {@code CUSTOM_PARAM} expression that evaluates
 * to {@code null} or empty.
 *
 * @author michael
 * @since 2.7.19.16
 */
public class RateLimiterKeyGenerators {

    private static final ExpressionParser EXPRESSION_PARSER = new SpelExpressionParser();

    /** Separator between the method part of the key and its dimension values. */
    private static final String SEPARATOR = ":";

    /**
     * Kept public for backward compatibility; all members of this class are static.
     */
    public RateLimiterKeyGenerators() {
    }

    /**
     * Generates the rate-limiter keys for one method invocation.
     *
     * @param args                              the invocation arguments, positionally matching the
     *                                          signature's parameter names
     * @param target                            the object the method is invoked on; also the SpEL root
     *                                          object for {@code CUSTOM_PARAM}
     * @param signature                         the intercepted method signature; must be an AspectJ
     *                                          {@code CodeSignature} when {@code CUSTOM_PARAM} is requested
     * @param param                             SpEL expression evaluated for {@code CUSTOM_PARAM}; method
     *                                          arguments are available as {@code #parameterName} variables
     * @param rateLimitingDimensionEnums        the dimensions to resolve, in key order
     * @param rateLimitingDimensionOperatorEnum how resolved dimension values are combined into keys
     * @return one key per resolved dimension value for {@code OR}, or exactly one key for {@code AND}.
     *         For {@code AND}, if no dimension resolves, the key ends with the separator and is shared by
     *         every invocation of the method. Any other operator yields an empty list.
     * @throws NullPointerException if {@code rateLimitingDimensionOperatorEnum} is {@code null}
     */
    public static List<String> generateRateLimiterKey(Object[] args, Object target, Signature signature, String param,
                                                      RateLimitingDimensionEnum[] rateLimitingDimensionEnums,
                                                      RateLimitingDimensionOperatorEnum rateLimitingDimensionOperatorEnum) {
        Objects.requireNonNull(rateLimitingDimensionOperatorEnum, "rateLimitingDimensionOperatorEnum");

        List<String> dimensionList = new ArrayList<>();
        for (RateLimitingDimensionEnum rateLimitingDimensionEnum : rateLimitingDimensionEnums) {
            String value = resolveDimension(rateLimitingDimensionEnum, args, target, signature, param);
            // Skip unresolved values so keys never contain an empty segment or the literal "null".
            if (StrUtils.isNotEmpty(value)) {
                dimensionList.add(value);
            }
        }

        String keyPrefix = RateLimiterConstant.RATE_LIMITER_PREFIX + RateLimiterConstant.HASH_TAG
                + target.getClass().getName() + "." + signature.getName() + SEPARATOR;

        if (rateLimitingDimensionOperatorEnum == RateLimitingDimensionOperatorEnum.OR) {
            return dimensionList.stream()
                    .map(dimension -> keyPrefix + dimension)
                    .collect(Collectors.toList());
        }
        if (rateLimitingDimensionOperatorEnum == RateLimitingDimensionOperatorEnum.AND) {
            return Collections.singletonList(keyPrefix + String.join(SEPARATOR, dimensionList));
        }
        return Collections.emptyList();
    }

    /**
     * Resolves the value of a single dimension.
     *
     * @return the dimension value, or {@code null} when the dimension has no value ({@code UNDEFINED} or
     *         any constant not handled here)
     */
    private static String resolveDimension(RateLimitingDimensionEnum dimension, Object[] args, Object target,
                                           Signature signature, String param) {
        switch (dimension) {
            case TENANT:
                return UserUtils.getTenantId();
            case USER:
                return UserUtils.getUserId();
            case SYS:
                return UserUtils.getSysId();
            case IP:
                return getClientIP();
            case CUSTOM_PARAM:
                return evaluateCustomParam(args, target, signature, param);
            case UNDEFINED:
            default:
                return null;
        }
    }

    /**
     * Evaluates {@code param} as a SpEL expression. {@code target} is the root object, and each method
     * argument is bound as a variable named after its parameter.
     *
     * @return the expression result converted to {@code String}, possibly {@code null}
     */
    private static String evaluateCustomParam(Object[] args, Object target, Signature signature, String param) {
        StandardEvaluationContext context = new StandardEvaluationContext(target);
        String[] parameterNames = ((org.aspectj.lang.reflect.CodeSignature) signature).getParameterNames();
        // Parameter names can be null when they cannot be discovered.
        // Bind only the names that pair with an actual argument.
        if (parameterNames != null && args != null) {
            int bound = Math.min(parameterNames.length, args.length);
            for (int i = 0; i < bound; i++) {
                context.setVariable(parameterNames[i], args[i]);
            }
        }
        return EXPRESSION_PARSER.parseExpression(param).getValue(context, String.class);
    }

    /**
     * Returns the client IP of the servlet request bound to the current thread.
     *
     * @return the client IP as reported by {@code NetUtils}, or {@code null} when no servlet request is
     *         bound to the current thread
     */
    private static String getClientIP() {
        // getRequestAttributes() returns null when nothing is bound to this thread, and the bound
        // attributes are not necessarily servlet-based, so check before casting.
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (!(attributes instanceof ServletRequestAttributes)) {
            return null;
        }
        HttpServletRequest request = ((ServletRequestAttributes) attributes).getRequest();
        return NetUtils.getClientIP(request);
    }
}
