package com.digiwin.athena.adt.app.config.filter;

import ch.qos.logback.classic.util.LogbackMDCAdapter;
import com.jugg.agile.framework.core.dapper.log.JaMDC;
import com.jugg.agile.framework.meta.adapter.JaCoreAdapter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.slf4j.MDC;
import org.slf4j.spi.MDCAdapter;
import org.springframework.stereotype.Component;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.*;

/**
 * Safe MDC filter that does not extend Spring's GenericFilterBean/OncePerRequestFilter
 * to avoid initialization issues. It creates fresh ja-trace-id/ja-req-id per request,
 * masks upstream trace headers, records upstream id only for diagnostics,
 * and clears only keys it sets. Does not depend on Spring Filter base classes.
 */
@Slf4j
@Component
public class SafeMdcTraceFilter implements Filter {

    private static final String JA_TRACE_ID = "ja-trace-id";
    private static final String JA_REQ_ID = "ja-req-id";
    private static final String TRACE_ID = "traceId";
    private static final String PTX_ID = "PtxId";
    private static final String UPSTREAM_TRACE_ID = "upstream-trace-id";
    private static final String APPLIED_ATTR = SafeMdcTraceFilter.class.getName() + ".APPLIED";
    private static final Set<String> BLOCKED_HEADERS = new HashSet<>(Arrays.asList(
            // Pinpoint remote trace headers
            "x-transactionid", "x-spanid", "x-parentspanid", "x-sampled", "x-flags",
            // W3C
            "traceparent", "tracestate",
            // B3/Zipkin
            "x-b3-traceid", "x-b3-spanid", "x-b3-parentspanid", "x-b3-sampled", "x-b3-flags"
    ));

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) {
            chain.doFilter(request, response);
            return;
        }
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse resp = (HttpServletResponse) response;

        // guard against re-entry
        if (req.getAttribute(APPLIED_ATTR) != null) {
            chain.doFilter(request, response);
            return;
        }
        req.setAttribute(APPLIED_ATTR, Boolean.TRUE);
        log.info("这是mdc的链路ID"+MDC.get("PtxId"));
        log.info("这是JaMDC的链路ID"+JaMDC.get("PtxId"));
        log.info("这是JaCoreAdapter的链路ID"+JaCoreAdapter.getTraceId());
        MDCAdapter mdc = MDC.getMDCAdapter();
        if (mdc instanceof LogbackMDCAdapter) {
            Map<String,String> logback = ((LogbackMDCAdapter)mdc).getPropertyMap();
            log.info("这是LogbackMDCAdapter的链路ID"+logback.get("PtxId"));
        } else {
            Map<String,String> contextMap = mdc.getCopyOfContextMap();
            log.info("这是contextMap的链路ID"+contextMap.get("PtxId"));
        }
        String upstream = firstNonBlank(header(req, JA_TRACE_ID), header(req, "x-b3-traceid"), header(req, "traceparent"));
        String traceId = mdc.get("PtxId");
        if(StringUtils.isBlank(traceId)) {
            traceId = "adt-" + safeCreateTraceId();
        }
        String reqId = firstNonBlank(header(req, JA_REQ_ID), header(req, "x-request-id"));
        if (isBlank(reqId)) reqId = genId();
        try {
            MDC.put(JA_TRACE_ID, traceId);
            MDC.put(PTX_ID, traceId);
            MDC.put(JA_REQ_ID, reqId);
            JaMDC.put(traceId);
            mdc.put(PTX_ID, traceId);
            if (!isBlank(upstream)) {
                MDC.put(UPSTREAM_TRACE_ID, upstream);
            }
            resp.setHeader(JA_TRACE_ID, traceId);
            resp.setHeader(JA_REQ_ID, reqId);

            // wrap request to mask remote-trace headers from downstream/agent
            HttpServletRequest wrapped = new HeaderMaskingRequest(req, BLOCKED_HEADERS);
            chain.doFilter(wrapped, response);
        } finally {
            MDC.remove(JA_TRACE_ID);
            MDC.remove(JA_REQ_ID);
            MDC.remove(TRACE_ID);
            MDC.remove(PTX_ID);
            MDC.remove(UPSTREAM_TRACE_ID);
            MDC.remove(PTX_ID);
            JaMDC.remove();
            req.removeAttribute(APPLIED_ATTR);
        }
    }

    private static class HeaderMaskingRequest extends HttpServletRequestWrapper {
        private final Set<String> blockedLower;

        HeaderMaskingRequest(HttpServletRequest request, Set<String> blocked) {
            super(request);
            Set<String> s = new HashSet<>();
            for (String h : blocked) s.add(h.toLowerCase(Locale.ROOT));
            this.blockedLower = Collections.unmodifiableSet(s);
        }

        private boolean blocked(String name) {
            return name != null && blockedLower.contains(name.toLowerCase(Locale.ROOT));
        }

        @Override
        public String getHeader(String name) {
            if (blocked(name)) return null;
            return super.getHeader(name);
        }

        @Override
        public Enumeration<String> getHeaders(String name) {
            if (blocked(name)) return Collections.emptyEnumeration();
            return super.getHeaders(name);
        }

        @Override
        public Enumeration<String> getHeaderNames() {
            List<String> names = Collections.list(super.getHeaderNames());
            names.removeIf(this::blocked);
            return Collections.enumeration(names);
        }
    }

    private static String header(HttpServletRequest req, String name) {
        String v = req.getHeader(name);
        return v == null ? null : v.trim();
    }

    private static String parseTraceParent(String traceparent) {
        if (isBlank(traceparent)) return null;
        String[] parts = traceparent.split("-");
        if (parts.length >= 4) {
            return parts[1];
        }
        return null;
    }

    private static boolean isBlank(String v) {
        return v == null || v.trim().isEmpty();
    }

    private static String firstNonBlank(String... arr) {
        if (arr == null) return null;
        for (String s : arr) {
            if (!isBlank(s)) return s.trim();
        }
        return null;
    }

    private static String genId() {
        return UUID.randomUUID().toString().replace("-", "");
    }

    private static String safeCreateTraceId() {
        try {
            return JaMDC.createTraceId();
        } catch (Throwable ignore) {
            return genId();
        }
    }
}
