package com.digiwin.athena.framework.rw;

import com.digiwin.athena.framework.rw.context.RWContextHolder;
import com.digiwin.athena.framework.rw.dto.ReadWriterDto;
import com.digiwin.athena.framework.rw.router.DataSourRouter;
import com.digiwin.athena.framework.rw.router.DbSwitchConfig;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.util.CollectionUtils;

import java.util.Arrays;
import java.util.List;


/**
 * MyBatis plugin that records a read/write routing hint in {@link RWContextHolder} before each
 * intercepted {@link Executor} {@code update} / {@code query} call.
 *
 * <p>SELECT statements are tagged with {@code new ReadWriterDto(true, dbSwitchConfig.getReadMode())};
 * INSERT, UPDATE and DELETE statements with {@code new ReadWriterDto(false, dbSwitchConfig.getWriteMode())}.
 * Other command types, and statements whose mapper class is outside the configured
 * {@code mapperBasePackages}, leave the holder untouched.
 *
 * <p>NOTE(review): the hint is never cleared after {@code invocation.proceed()} returns. If
 * {@link RWContextHolder} is thread-bound, a later statement on the same thread that is not
 * tagged here will observe the previous statement's hint. Confirm whether clearing happens
 * elsewhere.
 */
@Intercepts({
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})
})
public class RWTypeInterceptor implements Interceptor {

    /** Source of the read/write modes and of the optional mapper package filter. */
    private final DbSwitchConfig dbSwitchConfig;

    /**
     * Creates the interceptor.
     *
     * @param dataSourRouter not used by this class; retained so existing callers keep compiling
     * @param dbSwitchConfig configuration supplying read/write modes and mapper base packages
     */
    public RWTypeInterceptor(DataSourRouter dataSourRouter, DbSwitchConfig dbSwitchConfig) {
        this.dbSwitchConfig = dbSwitchConfig;
    }

    /**
     * Tags the current statement as a read or a write (when its mapper is in scope) and then
     * continues the invocation chain.
     *
     * @param invocation the intercepted executor call; argument 0 is the {@link MappedStatement}
     * @return whatever the underlying executor call returns
     * @throws Throwable anything thrown by the underlying executor call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MappedStatement ms = (MappedStatement) invocation.getArgs()[0];
        if (shouldIntercept(getMapperClassName(ms.getId()))) {
            markReadOrWrite(ms.getSqlCommandType());
        }
        return invocation.proceed();
    }

    /**
     * Stores the routing hint matching the given command type; other types are ignored.
     *
     * @param sqlCommandType the statement's command type; {@code null} is ignored
     */
    private void markReadOrWrite(SqlCommandType sqlCommandType) {
        // A switch on a null enum would throw; the original membership checks simply skipped null.
        if (sqlCommandType == null) {
            return;
        }
        switch (sqlCommandType) {
            case SELECT:
                RWContextHolder.setDataSourceType(new ReadWriterDto(true, dbSwitchConfig.getReadMode()));
                break;
            case INSERT:
            case UPDATE:
            case DELETE:
                RWContextHolder.setDataSourceType(new ReadWriterDto(false, dbSwitchConfig.getWriteMode()));
                break;
            default:
                // Any other command type (e.g. UNKNOWN, FLUSH) leaves the current hint unchanged.
                break;
        }
    }

    /**
     * Decides whether a mapper class falls under the configured base packages.
     *
     * @param className fully qualified mapper class name (may be empty)
     * @return {@code true} when no base packages are configured, or when {@code className}
     *         starts with at least one non-null configured package prefix
     */
    private boolean shouldIntercept(String className) {
        List<String> basePackages = dbSwitchConfig.getMapperBasePackages();
        if (CollectionUtils.isEmpty(basePackages)) {
            return true;
        }
        // Skip null entries instead of failing with an NPE inside String#startsWith.
        return basePackages.stream().anyMatch(pkg -> pkg != null && className.startsWith(pkg));
    }

    /**
     * Derives the mapper class name from a mapped statement id by dropping the part after the
     * last dot.
     *
     * @param mappedStatementId statement id such as {@code com.example.FooMapper.selectById}
     * @return the id without its last dot-separated segment, the id itself when it contains no
     *         dot, or {@code ""} when the id is {@code null}
     */
    private String getMapperClassName(String mappedStatementId) {
        if (mappedStatementId == null) {
            return "";
        }
        int lastDotIndex = mappedStatementId.lastIndexOf('.');
        return lastDotIndex != -1 ? mappedStatementId.substring(0, lastDotIndex) : mappedStatementId;
    }

    /**
     * Wraps {@link Executor} instances with this interceptor and returns any other target unchanged.
     *
     * @param target the object MyBatis offers for wrapping
     * @return a proxy for executors, otherwise {@code target} itself
     */
    @Override
    public Object plugin(Object target) {
        if (target instanceof Executor) {
            return Plugin.wrap(target, this);
        } else {
            return target;
        }
    }
}
