package com.digiwin.athena.framework.multitx;

import com.jugg.agile.framework.core.util.concurrent.JaThreadLocal;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.stereotype.Component;
import org.springframework.transaction.PlatformTransactionManager;
import org.springframework.transaction.TransactionDefinition;
import org.springframework.transaction.TransactionStatus;
import org.springframework.transaction.support.DefaultTransactionDefinition;
import org.springframework.util.CollectionUtils;

import java.util.Map;
import java.util.Vector;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * AOP aspect for multi-threaded transaction handling.
 *
 * <p>A method annotated with {@code @MainTransaction} registers a
 * {@link MultiTransactionContext} keyed by the current thread's name. Methods
 * annotated with {@code @SubTransaction} look that context up, run inside their
 * own {@code PROPAGATION_REQUIRES_NEW} transaction and leave it open. Once every
 * sub-transaction has counted down the shared latch, the main advice commits or
 * rolls back all collected sub-transactions together.
 *
 * <p>NOTE(review): the sub-thread lookup only works if {@code JaThreadLocal}
 * makes the value set by the main thread visible in the worker threads —
 * presumably it does; confirm against its implementation.
 *
 * <p>NOTE(review): contexts are keyed by thread name, which is not guaranteed
 * to be unique across the JVM — confirm this cannot collide in practice.
 */
@Slf4j
@Aspect
@ConditionalOnProperty(prefix = "athena.tx", name = "enable", havingValue = "true")
public class TransactionAop {

    /** Active multi-transaction contexts, keyed by the name of the thread running the main method. */
    private static final Map<String, MultiTransactionContext> contextMap = new ConcurrentHashMap<>();

    /** Carries the main thread's name so that sub-transaction advice can find the matching context. */
    private static final JaThreadLocal<TransactionThreadContext> contextHolder = new JaThreadLocal<>();

    @Autowired
    private PlatformTransactionManager transactionManager;

    /**
     * Runs the main method, waits for all sub-transactions, then commits or
     * rolls back every collected sub-transaction as a unit.
     *
     * <p>NOTE(review): this advice returns {@code void}, so the advised method's
     * return value is not passed back through it — confirm that every
     * {@code @MainTransaction} method returns {@code void}.
     *
     * <p>NOTE(review): the latch count comes from the annotation value; if the
     * main method fails before starting that many sub-transactions, the
     * {@code await()} below never returns — confirm callers cannot hit this.
     *
     * @param joinPoint       the intercepted invocation
     * @param mainTransaction the annotation instance; its value is passed to the context constructor
     * @throws Throwable the first recorded exception of the main method or of any
     *                   sub-transaction, or any failure while waiting / completing
     */
    @Around("@annotation(mainTransaction)")
    public void mainIntercept(ProceedingJoinPoint joinPoint, MainTransaction mainTransaction) throws Throwable {

        MultiTransactionContext context = preInitContext(mainTransaction);
        try {
            try {
                // Run the target method.
                joinPoint.proceed();
            } catch (Throwable e) {
                // Record the failure first in line and flag everything for rollback.
                context.getExceptionVector().add(0, e);
                context.getRollBackFlag().set(true);
            }
            // Wait until every sub-transaction has reported completion.
            context.getSubDownLatch().await();

            // Decide once, for all sub-transactions, whether to commit or roll back.
            // NOTE(review): these statuses were created on the worker threads but are
            // completed here on the main thread — confirm the configured transaction
            // manager supports completing a transaction from a different thread.
            if (context.getRollBackFlag().get()) {
                context.getTransactionStatuses().forEach(transactionManager::rollback);
            } else {
                context.getTransactionStatuses().forEach(transactionManager::commit);
            }

            if (!CollectionUtils.isEmpty(context.getExceptionVector())) {
                // Surface the first recorded exception to the caller.
                throw context.getExceptionVector().get(0);
            }
        } finally {
            // Always release the context — on success as well as on failure — so that
            // neither the static map nor the thread-local keeps stale state that a
            // later, unrelated sub-transaction on a reused thread could pick up.
            contextMap.remove(Thread.currentThread().getName());
            contextHolder.remove();
        }
    }

    /**
     * Runs a sub-transaction method inside a new transaction that is left open
     * for the main advice to commit or roll back.
     *
     * <p>If no main context can be found, the method is simply invoked without
     * any multi-thread transaction handling. Failures of the sub-method are not
     * rethrown here; they are recorded in the shared context and reported by
     * the main advice.
     *
     * @param joinPoint      the intercepted invocation
     * @param subTransaction the annotation instance (only used for pointcut binding)
     * @throws Throwable only from the pass-through invocation when no context exists
     */
    @Around("@annotation(subTransaction)")
    public void sonIntercept(ProceedingJoinPoint joinPoint, SubTransaction subTransaction) throws Throwable {
        TransactionThreadContext transactionThreadContext = contextHolder.get();
        String threadName = transactionThreadContext == null ? "" : transactionThreadContext.getThreadName();
        MultiTransactionContext context = contextMap.get(threadName);

        // Without a main transaction / context the multi-thread transaction does not apply.
        if (context == null) {
            joinPoint.proceed();
            return;
        }

        CountDownLatch sonDownLatch = context.getSubDownLatch();
        AtomicBoolean rollBackFlag = context.getRollBackFlag();
        Vector<Throwable> exceptionVector = context.getExceptionVector();

        // Each sub-transaction runs in its own, independent transaction.
        DefaultTransactionDefinition def = new DefaultTransactionDefinition();
        def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRES_NEW);

        TransactionStatus status = null;
        try {
            // Started inside the try block: if opening the transaction fails, the latch
            // must still be counted down, otherwise the main thread waits forever.
            status = transactionManager.getTransaction(def);
            // Run the sub-transaction body.
            joinPoint.proceed();
            // Hand the still-open transaction to the main advice for the final decision.
            context.getTransactionStatuses().add(status);
        } catch (Throwable e) {
            // Record the failure and flag everything for rollback.
            exceptionVector.add(0, e);
            rollBackFlag.set(true);
            // A failed sub-transaction is never handed to the main advice, so it has to
            // be rolled back here; otherwise it would stay open indefinitely.
            if (status != null && !status.isCompleted()) {
                try {
                    transactionManager.rollback(status);
                } catch (RuntimeException rollbackFailure) {
                    // Keep the original failure as the primary one.
                    e.addSuppressed(rollbackFailure);
                }
            }
        } finally {
            // Signal completion of this sub-transaction to the main thread.
            sonDownLatch.countDown();
        }
    }

    /**
     * Creates the context for a main transaction and registers it for lookup.
     *
     * <p>The context is stored in {@link #contextMap} under the current thread's
     * name, and that name is placed into {@link #contextHolder}.
     *
     * @param mainTransaction the annotation whose value is passed to the context constructor
     * @return the newly created and registered context
     */
    private MultiTransactionContext preInitContext(MainTransaction mainTransaction) {
        String threadName = Thread.currentThread().getName();
        MultiTransactionContext context = new MultiTransactionContext(mainTransaction.value());
        // Register the context under the current thread's name.
        contextMap.put(threadName, context);
        TransactionThreadContext transactionThreadContext = new TransactionThreadContext();
        transactionThreadContext.setThreadName(threadName);
        contextHolder.set(transactionThreadContext);
        return context;
    }

    /** Holder for the name of the thread that started the main transaction. */
    static class TransactionThreadContext {

        private String threadName;

        public String getThreadName() {
            return threadName;
        }

        public void setThreadName(String threadName) {
            this.threadName = threadName;
        }
    }
}


