package com.digiwin.athena.framework.dispatch;

import lombok.extern.slf4j.Slf4j;
import org.springframework.http.*;
import org.springframework.http.client.*;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import org.springframework.util.AntPathMatcher;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerMapping;
import javax.servlet.http.HttpServletRequest;
import java.io.*;
import java.net.URLDecoder;
import java.util.*;

@Slf4j
public class LocalDispatcherInterceptor implements ClientHttpRequestInterceptor {

    /**
     * Header added to every locally dispatched request. Outgoing requests that already carry it are always
     * executed remotely.
     */
    private static final String LOCAL_CALL_HEADER = "digi-local-call";

    /**
     * Attribute-name prefixes that describe the outer request's own dispatch. Attributes with these prefixes
     * are not propagated to the locally dispatched request; see {@link #copyAttribute(MockHttpServletRequest)}.
     */
    private static final String[] NON_PROPAGATED_ATTRIBUTE_PREFIXES = {
            "org.springframework.web.",
            "javax.servlet.include.",
            "javax.servlet.forward.",
            "javax.servlet.error.",
            "javax.servlet.async."
    };

    private final DispatcherServlet dispatcherServlet;
    private final LocalControllerMatcher matcher;
    private final WebApplicationContext context;
    private final RestLocalProperties properties;
    private final AntPathMatcher pathMatcher = new AntPathMatcher();

    /**
     * Creates the interceptor.
     *
     * @param context           web application context whose {@code ServletContext} backs the mock requests
     * @param mapping           handler mapping used to decide whether a path/method is served locally
     * @param dispatcherServlet servlet that executes locally dispatched requests
     * @param properties        enable flag and exclude patterns
     */
    public LocalDispatcherInterceptor(WebApplicationContext context, RequestMappingHandlerMapping mapping, DispatcherServlet dispatcherServlet, RestLocalProperties properties) {
        this.context = context;
        this.dispatcherServlet = dispatcherServlet;
        this.matcher = new LocalControllerMatcher(mapping);
        this.properties = properties;
    }

    /**
     * Dispatches {@code request} in-process when a local controller matches it; otherwise executes it remotely.
     *
     * <p>The request always goes remote when local dispatch is disabled, when it already carries the
     * {@code digi-local-call} header, or when its path matches a configured exclude pattern.
     */
    @Override
    public ClientHttpResponse intercept(HttpRequest request, byte[] body, ClientHttpRequestExecution execution) throws IOException {
        String path = request.getURI().getPath();
        HttpMethod method = request.getMethod();
        if (!properties.isEnabled() || request.getHeaders().containsKey(LOCAL_CALL_HEADER) || isExcludedPath(path)) {
            return execution.execute(request, body);
        }
        if (matcher.matches(path, method)) {
            log.info("LocalDispatcherInterceptor url:{},走本地", path);
            return forwardLocally(request, body);
        }
        log.info("LocalDispatcherInterceptor url:{},走远程", path);
        return execution.execute(request, body);
    }

    /** Returns {@code true} if {@code path} matches any configured Ant-style exclude pattern. */
    private boolean isExcludedPath(String path) {
        if (properties.getExcludes() == null) {
            return false;
        }
        return properties.getExcludes().stream().anyMatch(pattern -> pathMatcher.match(pattern, path));
    }

    /**
     * Executes {@code request} in-process and adapts the captured mock response.
     *
     * <p>Builds a {@link MockHttpServletRequest} mirroring the outgoing request and runs it through
     * {@link DispatcherServlet#service}. The captured {@link MockHttpServletResponse} is then adapted to a
     * {@link ClientHttpResponse}.
     *
     * @param request the intercepted outgoing request
     * @param body    the request body bytes
     * @return the locally produced response
     * @throws IOException if the dispatcher throws (the cause is preserved) or query decoding fails
     */
    private ClientHttpResponse forwardLocally(HttpRequest request, byte[] body) throws IOException {
        MockHttpServletRequest servletRequest = new MockHttpServletRequest(context.getServletContext());
        servletRequest.setMethod(request.getMethod().name());
        // Pass the still-encoded path. Spring MVC URL-decodes the request URI itself during handler lookup,
        // so handing it the already-decoded path would decode twice (e.g. "%2520" -> "%20" -> " ").
        servletRequest.setRequestURI(request.getURI().getRawPath());
        servletRequest.setContent(body);
        servletRequest.addHeader(LOCAL_CALL_HEADER, "true");

        String query = request.getURI().getRawQuery();
        if (query != null) {
            servletRequest.setQueryString(query);
            addQueryParameters(servletRequest, query);
        }

        request.getHeaders().forEach((key, values) -> values.forEach(value -> servletRequest.addHeader(key, value)));

        // Servlet filters (e.g. a character-encoding filter) do not run for this in-process call. Without an
        // explicit encoding the raw path above would be URL-decoded with Spring's ISO-8859-1 fallback, so
        // default to UTF-8 (the charset used for the query parameters) unless Content-Type declared one.
        if (servletRequest.getCharacterEncoding() == null) {
            servletRequest.setCharacterEncoding("UTF-8");
        }
        copyAttribute(servletRequest);

        MockHttpServletResponse servletResponse = new MockHttpServletResponse();
        try {
            dispatcherServlet.service(servletRequest, servletResponse);
        } catch (Exception e) {
            throw new IOException("Local dispatch failed: " + request.getURI(), e);
        }
        return toClientHttpResponse(servletResponse);
    }

    /**
     * Decodes a raw ({@code application/x-www-form-urlencoded}) query string into request parameters.
     *
     * <p>Follows servlet-container semantics:
     * <ul>
     *   <li>{@code a} and {@code a=} yield {@code a=""};</li>
     *   <li>{@code a=b=c} yields {@code a="b=c"};</li>
     *   <li>empty segments and pairs with an empty name are ignored.</li>
     * </ul>
     *
     * @throws UnsupportedEncodingException never in practice (UTF-8 is always supported)
     */
    private static void addQueryParameters(MockHttpServletRequest servletRequest, String rawQuery) throws UnsupportedEncodingException {
        for (String pair : rawQuery.split("&")) {
            if (pair.isEmpty()) {
                continue;
            }
            int eq = pair.indexOf('=');
            String rawName = eq < 0 ? pair : pair.substring(0, eq);
            String rawValue = eq < 0 ? "" : pair.substring(eq + 1);
            if (rawName.isEmpty()) {
                continue;
            }
            servletRequest.addParameter(URLDecoder.decode(rawName, "UTF-8"), URLDecoder.decode(rawValue, "UTF-8"));
        }
    }

    /**
     * Adapts a completed mock servlet response to the {@link ClientHttpResponse} contract.
     *
     * <p>Each call to {@code getBody()} returns a fresh stream over the captured bytes. {@code close()} is a
     * no-op because nothing external is held.
     */
    private static ClientHttpResponse toClientHttpResponse(MockHttpServletResponse servletResponse) {
        return new ClientHttpResponse() {
            /** @throws IllegalArgumentException for status codes {@link HttpStatus} does not define */
            @Override
            public HttpStatus getStatusCode() {
                return HttpStatus.valueOf(servletResponse.getStatus());
            }

            @Override
            public int getRawStatusCode() {
                return servletResponse.getStatus();
            }

            /** Returns the message passed to {@code sendError}, or an empty string if none was set. */
            @Override
            public String getStatusText() {
                return Optional.ofNullable(servletResponse.getErrorMessage()).orElse("");
            }

            @Override
            public void close() {
            }

            @Override
            public InputStream getBody() {
                return new ByteArrayInputStream(servletResponse.getContentAsByteArray());
            }

            @Override
            public HttpHeaders getHeaders() {
                HttpHeaders headers = new HttpHeaders();
                for (String name : servletResponse.getHeaderNames()) {
                    headers.put(name, new ArrayList<>(servletResponse.getHeaders(name)));
                }
                return headers;
            }
        };
    }

    /**
     * Copies the attributes of the current servlet request onto {@code servletRequest}.
     *
     * <p>The current request is the one bound to this thread via {@link RequestContextHolder}. Copying lets
     * values stored upstream as request attributes reach the locally dispatched handler.
     *
     * <p>Attributes describing the outer request's own dispatch are skipped:
     * <ul>
     *   <li>anything under {@code org.springframework.web.}: handler-mapping results, cached lookup paths,
     *       the {@code WebAsyncManager};</li>
     *   <li>the Servlet include/forward/error/async attributes.</li>
     * </ul>
     * Copying those would leak outer dispatch state into the inner one. For example, the shared
     * {@code WebAsyncManager} would be re-bound to the mock request, and an include URI would override the
     * inner request URI during handler lookup. The inner {@link DispatcherServlet} recomputes what it needs.
     *
     * <p>Does nothing when no servlet request is bound to the current thread.
     *
     * @param servletRequest the request that will be dispatched locally
     */
    public void copyAttribute(MockHttpServletRequest servletRequest) {
        Object attributes = RequestContextHolder.getRequestAttributes();
        if (!(attributes instanceof ServletRequestAttributes)) {
            return;
        }
        HttpServletRequest currentRequest = ((ServletRequestAttributes) attributes).getRequest();
        Enumeration<String> attrNames = currentRequest.getAttributeNames();
        while (attrNames.hasMoreElements()) {
            String name = attrNames.nextElement();
            if (isDispatchScopedAttribute(name)) {
                continue;
            }
            servletRequest.setAttribute(name, currentRequest.getAttribute(name));
        }
    }

    /** Returns {@code true} for attributes that belong to a single dispatch and must not be propagated. */
    private static boolean isDispatchScopedAttribute(String name) {
        for (String prefix : NON_PROPAGATED_ATTRIBUTE_PREFIXES) {
            if (name.startsWith(prefix)) {
                return true;
            }
        }
        return false;
    }

}