package com.digiwin.athena.apimgmt.infra.spring.context;

import cn.hutool.core.util.StrUtil;
import com.digiwin.athena.apimgmt.constants.ApimgmtConstant;
import com.digiwin.athena.apimgmt.enums.LocaleEnum;
import com.digiwin.athena.apimgmt.infra.auth.ApiMgmtIdentity;
import com.digiwin.athena.apimgmt.infra.auth.IApiMgmtIdentityService;
import com.digiwin.athena.apimgmt.infra.context.IApiMgmtServiceContext;
import jakarta.annotation.Nullable;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/**
 * Spring-backed implementation of {@link IApiMgmtServiceContext}.
 *
 * <p>Reads request/response headers and the caller identity from the request bound to the current
 * thread through Spring's {@link RequestContextHolder}. When no servlet request is bound, single
 * header lookups return {@code null}, header maps are empty and {@link #getUserInfo()} returns an
 * empty {@link Optional} (when no request attributes are bound at all).
 */
@Slf4j
public class SpringApiMgmtServiceContext implements IApiMgmtServiceContext {
    /**
     * Request-scope attribute key under which the parsed identity is cached as an
     * {@code Optional<ApiMgmtIdentity>}.
     */
    public static final String USER_ATTR = "__APIMGMT_USER";

    private final IApiMgmtIdentityService identitySupplier;

    /**
     * Creates a context that resolves identities with the given service.
     *
     * @param identitySupplier service used to turn the request token into an {@link ApiMgmtIdentity}
     */
    public SpringApiMgmtServiceContext(IApiMgmtIdentityService identitySupplier) {
        this.identitySupplier = identitySupplier;
    }

    /**
     * Returns the value of the {@link ApimgmtConstant#TOKEN} request header.
     *
     * @return the token header value, or {@code null} if absent or no request is bound
     */
    @Nullable
    @Override
    public String getToken() {
        return getRequestHeaderValue(ApimgmtConstant.TOKEN);
    }

    /**
     * Returns the identity of the current caller, parsing the request token at most once per request.
     *
     * <p>The parsed result, including an empty one, is stored as a request-scope attribute under
     * {@link #USER_ATTR}, and later calls within the same request reuse it. A value of any other type
     * found under that key is ignored and replaced.
     *
     * @return the caller identity; empty when no request attributes are bound or when
     *     {@code parseToken} returns {@code null}
     */
    @Override
    public Optional<ApiMgmtIdentity> getUserInfo() {
        RequestAttributes attributes = RequestContextHolder.getRequestAttributes();
        if (attributes == null) {
            return Optional.empty();
        }

        Object cached = attributes.getAttribute(USER_ATTR, RequestAttributes.SCOPE_REQUEST);
        if (cached instanceof Optional<?> cachedIdentity) {
            return cachedIdentity.map(ApiMgmtIdentity.class::cast);
        }

        // Cache the Optional itself (even when empty) so a missing/unresolvable token is not
        // re-parsed on every getter call during the same request.
        Optional<ApiMgmtIdentity> identity = Optional.ofNullable(identitySupplier.parseToken(getToken()));
        attributes.setAttribute(USER_ATTR, identity, RequestAttributes.SCOPE_REQUEST);
        return identity;
    }

    /** @return the current caller's user id, or {@code null} if there is no identity */
    @Nullable
    @Override
    public String getUserId() {
        return getUserInfo().map(ApiMgmtIdentity::getUserId).orElse(null);
    }

    /** @return the current caller's user name, or {@code null} if there is no identity */
    @Nullable
    @Override
    public String getUserName() {
        return getUserInfo().map(ApiMgmtIdentity::getUserName).orElse(null);
    }

    /** @return the current caller's tenant id, or {@code null} if there is no identity */
    @Nullable
    @Override
    public String getTenantId() {
        return getUserInfo().map(ApiMgmtIdentity::getTenantId).orElse(null);
    }

    /** @return the current caller's team id, or {@code null} if there is no identity */
    @Nullable
    @Override
    public String getTeamId() {
        return getUserInfo().map(ApiMgmtIdentity::getTeamId).orElse(null);
    }

    /** @return the current caller's team type, or {@code null} if there is no identity */
    @Nullable
    @Override
    public String getTeamType() {
        return getUserInfo().map(ApiMgmtIdentity::getTeamType).orElse(null);
    }

    /**
     * Returns the value of the locale request header ({@code LocaleEnum.LOCALE.getType()}).
     *
     * @return the locale header value, or {@code null} if absent or no request is bound
     */
    @Nullable
    @Override
    public String getLocale() {
        return getRequestHeaderValue(LocaleEnum.LOCALE.getType());
    }

    /**
     * Returns the value of the {@link ApimgmtConstant#ROUTER_KEY} request header.
     *
     * @return the router-key header value, or {@code null} if absent or no request is bound
     */
    @Nullable
    @Override
    public String getRouterKey() {
        return getRequestHeaderValue(ApimgmtConstant.ROUTER_KEY);
    }

    /**
     * Copies the current request's headers into a new mutable map.
     *
     * <p>Keys are the header names as reported by the container, so lookups in the returned map are
     * case-sensitive. Each value is what {@link HttpServletRequest#getHeader(String)} returns for that
     * name (a single value per header).
     *
     * @return a new map of header name to value; empty if no servlet request is bound
     */
    @Override
    public Map<String, Object> getRequestHeader() {
        Map<String, Object> headers = new HashMap<>();
        HttpServletRequest request = getCurrentRequest();
        if (request == null) {
            log.warn("No current HttpServletRequest found in RequestContext");
            return headers;
        }

        Enumeration<String> headerNames = request.getHeaderNames();
        // The Servlet API permits containers that disallow header enumeration to return null here.
        if (headerNames == null) {
            return headers;
        }
        while (headerNames.hasMoreElements()) {
            String headerName = headerNames.nextElement();
            headers.put(headerName, request.getHeader(headerName));
        }
        return headers;
    }

    /**
     * Copies the headers set so far on the current response into a new mutable map.
     *
     * <p>Each value is what {@link HttpServletResponse#getHeader(String)} returns for that name
     * (a single value per header).
     *
     * @return a new map of header name to value; empty if no servlet response is bound
     */
    @Override
    public Map<String, Object> getResponseHeader() {
        Map<String, Object> headers = new HashMap<>();
        HttpServletResponse response = getCurrentResponse();
        if (response == null) {
            log.warn("No current HttpServletResponse found in RequestContext");
            return headers;
        }

        for (String headerName : response.getHeaderNames()) {
            headers.put(headerName, response.getHeader(headerName));
        }
        return headers;
    }

    /**
     * Reads a single request header directly from the servlet request.
     *
     * <p>Goes through {@link HttpServletRequest#getHeader(String)} rather than the copied map from
     * {@link #getRequestHeader()} so the lookup is case-insensitive, as HTTP header names are, and
     * no full header copy is made per lookup.
     *
     * @param name header name
     * @return the header value, or {@code null} if absent or no servlet request is bound
     */
    @Nullable
    private String getRequestHeaderValue(String name) {
        HttpServletRequest request = getCurrentRequest();
        if (request == null) {
            log.warn("No current HttpServletRequest found in RequestContext");
            return null;
        }
        return request.getHeader(name);
    }

    /** @return the servlet request bound to the current thread, or {@code null} if none */
    @Nullable
    private HttpServletRequest getCurrentRequest() {
        if (RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes servletAttributes) {
            return servletAttributes.getRequest();
        }
        return null;
    }

    /** @return the servlet response bound to the current thread, or {@code null} if none */
    @Nullable
    private HttpServletResponse getCurrentResponse() {
        if (RequestContextHolder.getRequestAttributes() instanceof ServletRequestAttributes servletAttributes) {
            return servletAttributes.getResponse();
        }
        return null;
    }
}
