package com.digiwin.dap.middle.cache.limiter;

import com.digiwin.dap.middle.cache.config.LimiterConfig;
import com.digiwin.dap.middleware.domain.CommonErrorCode;
import com.digiwin.dap.middleware.exception.RequestNotPermittedException;
import com.digiwin.dap.middleware.util.UserUtils;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * Aspect that applies a Redis token-bucket rate limit to methods annotated with {@link RateLimiter}.
 * <p>
 * Before each annotated method runs, the configured Lua script is evaluated against two keys derived
 * from the current servlet request. The call is rejected with {@link RequestNotPermittedException}
 * when the script does not return {@code true}. If the limiter itself fails (Redis error, missing
 * request context, ...) the request is allowed ("fail open") and the failure is logged.
 */
@Aspect
@Component
public class RateLimiterInterceptor {

    private static final Logger logger = LoggerFactory.getLogger(RateLimiterInterceptor.class);

    private final RedisTemplate<String, Object> redisTemplate;
    private final RedisScript<Boolean> script;

    @Value("${spring.application.name:dap}")
    private String appName;

    public RateLimiterInterceptor(RedisTemplate<String, Object> redisTemplate, @Qualifier(LimiterConfig.REDIS_SCRIPT_NAME) RedisScript<Boolean> script) {
        this.redisTemplate = redisTemplate;
        this.script = script;
    }

    /**
     * Checks the token bucket for the current request before the annotated method executes.
     *
     * @param joinPoint the intercepted method execution
     * @throws RequestNotPermittedException with {@code TOO_MANY_REQUESTS} when the script denies the
     *                                      request or returns {@code null}
     */
    @Before("(@annotation(com.digiwin.dap.middle.cache.limiter.RateLimiter))")
    public void execute(JoinPoint joinPoint) {
        Boolean allowed = Boolean.TRUE;
        try {
            RateLimiter rateLimiter = findRateLimiter(joinPoint);
            if (rateLimiter == null) {
                // Previously this surfaced as an NPE logged as a Redis failure; make the cause explicit.
                logger.warn("@RateLimiter annotation not resolvable on {}, request is not rate limited", joinPoint.getSignature());
                return;
            }
            List<String> keys = getKeys(appName, rateLimiter);

            // Script args: requests per second allowed, burst size, tokens consumed by this request.
            // NOTE(review): args go through the template's value serializer; this assumes it yields
            // something the Lua script can parse as a number -- confirm against LimiterConfig.
            allowed = redisTemplate.execute(this.script, keys,
                    rateLimiter.replenishRate(), rateLimiter.burstCapacity(), rateLimiter.requestedTokens());
        } catch (Exception e) {
            /*
             * We don't want a hard dependency on Redis to allow traffic. Make sure to set
             * an alert so you know if this is happening too much. Stripe's observed
             * failure rate is 0.01%.
             */
            logger.error("Error determining if user allowed from redis", e);
        }
        // A null script result is treated as a denial.
        if (!Boolean.TRUE.equals(allowed)) {
            throw new RequestNotPermittedException(CommonErrorCode.TOO_MANY_REQUESTS);
        }
    }

    /**
     * Resolves the {@link RateLimiter} annotation for the intercepted method.
     * <p>
     * The method exposed by the join point may be an interface method while the annotation is declared
     * on the implementation, so this also looks at the public method with the same signature on the
     * target object's class.
     *
     * @param joinPoint the intercepted method execution
     * @return the annotation, or {@code null} if it cannot be found
     */
    private static RateLimiter findRateLimiter(JoinPoint joinPoint) {
        Method method = ((MethodSignature) joinPoint.getSignature()).getMethod();
        RateLimiter rateLimiter = method.getAnnotation(RateLimiter.class);
        Object target = joinPoint.getTarget();
        if (rateLimiter != null || target == null) {
            return rateLimiter;
        }
        try {
            Method targetMethod = target.getClass().getMethod(method.getName(), method.getParameterTypes());
            return targetMethod.getAnnotation(RateLimiter.class);
        } catch (NoSuchMethodException ignored) {
            // No public method with that signature on the target class: nothing more to search.
            return null;
        }
    }

    /**
     * Builds the two Redis keys (tokens, timestamp) used by the token-bucket script.
     * <p>
     * The shared prefix is {@code <appName>:limiter:{<servletPath>}}, followed by {@code :<userId>},
     * {@code :<tenantId>} and/or {@code :<sysId>} when the corresponding flag on {@code rateLimiter} is
     * set and the value is non-null. The braces make the servlet path a Redis Cluster hash tag, so both
     * keys map to the same slot.
     *
     * @param appName     namespace for the keys
     * @param rateLimiter annotation selecting which identity segments to include
     * @return a two-element list: the tokens key, then the timestamp key
     * @throws NullPointerException if no request attributes are bound to the current thread
     */
    static List<String> getKeys(String appName, RateLimiter rateLimiter) {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        HttpServletRequest request = Objects.requireNonNull(attributes, "No request attributes bound to current thread").getRequest();

        // Make a unique key.
        StringBuilder prefix = new StringBuilder()
                .append(appName)
                .append(":limiter:{")
                .append(request.getServletPath())
                .append('}');
        // Each getter is only invoked when its flag is enabled.
        if (rateLimiter.user()) {
            appendSegment(prefix, UserUtils.getUserId());
        }
        if (rateLimiter.tenant()) {
            appendSegment(prefix, UserUtils.getTenantId());
        }
        if (rateLimiter.sys()) {
            appendSegment(prefix, UserUtils.getSysId());
        }
        // You need two Redis keys for Token Bucket.
        String tokenKey = prefix + ".tokens";
        String timestampKey = prefix + ".timestamp";
        return Arrays.asList(tokenKey, timestampKey);
    }

    /**
     * Appends {@code ":" + value} to {@code key} when {@code value} is non-null.
     *
     * @param key   key under construction
     * @param value identity segment, may be {@code null}
     */
    private static void appendSegment(StringBuilder key, Object value) {
        if (value != null) {
            key.append(':').append(value);
        }
    }
}
