package com.digiwin.athena.framework.rw.strategy;

import cn.hutool.core.util.ReflectUtil;
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.visitor.SchemaStatVisitor;
import com.alibaba.druid.stat.TableStat;
import com.alibaba.fastjson.JSON;
import com.digiwin.athena.framework.rw.dto.ReadWriterDto;
import com.digiwin.athena.framework.rw.router.DbSwitchConfig;
import com.digiwin.athena.framework.rw.exception.MyBatisShardException;
import lombok.Getter;
import lombok.NonNull;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.reflection.MetaObject;
import org.springframework.util.StringUtils;

import java.util.Collection;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;

/**
 * Base class for table shard strategies.
 *
 * <p>Each strategy is bound to one logical (original) table name. It rewrites references to that
 * table in parsed SQL to a shard table name resolved from {@link DbSwitchConfig}. It also offers
 * helpers for reading and populating parameter values on a MyBatis {@link BoundSql}.
 */
@Getter
public abstract class AbstractShardStrategy {

    /** Logical table name this strategy is registered under. */
    private final String originalTableName;

    public AbstractShardStrategy(String originalTableName) {
        this.originalTableName = originalTableName;
    }

    protected String getOriginalTableName() {
        return originalTableName;
    }

    /**
     * Replaces the table expression of {@code sqlExprTableSource} with the shard table name
     * resolved by {@link #getShardTableName(DbSwitchConfig)}.
     *
     * @param sqlExprTableSource table source in the parsed SQL AST; rewritten in place
     * @param boundSql           bound SQL of the statement (unused by this base implementation)
     * @param dbSwitchConfig     routing configuration; may be {@code null}
     * @return always {@code false} in this base implementation
     */
    public boolean replaceTableName(@NonNull SQLExprTableSource sqlExprTableSource, @NonNull BoundSql boundSql, DbSwitchConfig dbSwitchConfig) {
        sqlExprTableSource.setExpr(getShardTableName(dbSwitchConfig));
        return false;
    }

    /**
     * Looks up the shard table name for this strategy's original table.
     *
     * <p>The lookup key is {@code ShardStrategyContext.handleName(originalTableName)}. The original
     * table name is returned when the config is {@code null}, when it has no table mapping, or when
     * the mapped value is {@code null} or empty.
     *
     * @param dbSwitchConfig routing configuration; may be {@code null}
     * @return the shard table name, or the original table name as a fallback
     */
    public String getShardTableName(DbSwitchConfig dbSwitchConfig) {
        String originalName = this.getOriginalTableName();
        if (dbSwitchConfig == null || dbSwitchConfig.getTableMapping() == null) {
            return originalName;
        }
        String shardTableName = dbSwitchConfig.getTableMapping().get(ShardStrategyContext.handleName(originalName));
        return StringUtils.hasLength(shardTableName) ? shardTableName : originalName;
    }

    /**
     * Subclass hook that receives the MyBatis parameter {@link MetaObject}, the {@link BoundSql}
     * and the Druid {@link SchemaStatVisitor} of the statement.
     * NOTE(review): presumably used to adjust shard-related parameters; confirm with implementations.
     */
    public abstract void processParams(@NonNull MetaObject metaObject, @NonNull BoundSql boundSql, @NonNull SchemaStatVisitor schemaStatVisitor);

    /** Registers this strategy in {@link ShardStrategyContext} under its original table name. */
    public void register() {
        ShardStrategyContext.registerStrategy(getOriginalTableName(), this);
    }

    /** Subclass hook invoked with a {@link ReadWriterDto}; semantics are defined by implementations. */
    public abstract void processBefore(ReadWriterDto readWriterDto);

    /**
     * Fills parameter {@code columnName} with {@code supplier.get()} wherever it is still unset.
     *
     * <p>Nothing happens unless {@code columns} contains a column whose name equals
     * {@code columnName}. Otherwise the parameter object of {@code boundSql} is handled as follows:
     * <ul>
     *   <li>{@code null}: nothing to populate, so the method returns.</li>
     *   <li>A {@link Map} that contains the key: the entry is set only if it currently maps to
     *       {@code null} (via {@link Map#computeIfAbsent}).</li>
     *   <li>A {@link Map} without the key: for the first {@link Collection} value found, the field
     *       {@code columnName} of each non-null element is set if it is {@code null}.</li>
     *   <li>Any other object: its field {@code columnName} is set if it is {@code null}.</li>
     * </ul>
     *
     * @param boundSql   bound SQL whose parameter object is populated
     * @param columnName column / parameter / field name to populate
     * @param columns    columns referenced by the statement
     * @param supplier   produces the value; only invoked when a value is actually written
     */
    protected void setColumnValue(BoundSql boundSql, String columnName, Collection<TableStat.Column> columns, Supplier<Object> supplier) {
        boolean columnReferenced = columns.stream()
                .anyMatch(column -> Objects.equals(column.getName(), columnName));
        if (!columnReferenced) {
            return;
        }

        Object parameterObject = boundSql.getParameterObject();
        if (parameterObject == null) {
            // e.g. the column is bound to a literal and the statement has no parameters
            return;
        }
        if (!(parameterObject instanceof Map)) {
            setNx(parameterObject, columnName, supplier);
            return;
        }

        @SuppressWarnings("unchecked")
        Map<Object, Object> paramMap = (Map<Object, Object>) parameterObject;
        if (paramMap.containsKey(columnName)) {
            // computeIfAbsent only writes when the existing mapping is null
            paramMap.computeIfAbsent(columnName, k -> supplier.get());
            return;
        }
        for (Object value : paramMap.values()) {
            if (value instanceof Collection) {
                for (Object element : (Collection<?>) value) {
                    if (element != null) {
                        setNx(element, columnName, supplier);
                    }
                }
                // only the first collection-valued entry is populated
                return;
            }
        }
    }

    /** Sets field {@code key} of {@code o} to {@code supplier.get()} only if its current value is {@code null}. */
    private void setNx(@NonNull Object o, @NonNull String key, Supplier<Object> supplier) {
        if (Objects.nonNull(ReflectUtil.getFieldValue(o, key))) {
            return;
        }
        ReflectUtil.setFieldValue(o, key, supplier.get());
    }

    /**
     * Reads the value named {@code key} from a MyBatis parameter object.
     *
     * <p>For a {@link Map} parameter the order is:
     * <ol>
     *   <li>the direct map entry {@code key};</li>
     *   <li>otherwise the field {@code key} of the first non-null map value that declares it;</li>
     *   <li>or, for a {@link Collection} value, of its first non-null element that declares it.</li>
     * </ol>
     * Entries whose value is {@code null}, or that lack the field, are skipped. For any other
     * parameter object, the field {@code key} is read reflectively.
     *
     * @param parameterObject MyBatis parameter object
     * @param key             map key / field name
     * @return the located value (may be {@code null})
     * @throws MyBatisShardException if {@code parameterObject} is a map and no entry provides {@code key}
     */
    protected Object getValue(@NonNull Object parameterObject, @NonNull String key) {
        if (!(parameterObject instanceof Map)) {
            return ReflectUtil.getFieldValue(parameterObject, key);
        }

        Map<?, ?> paramMap = (Map<?, ?>) parameterObject;
        if (paramMap.containsKey(key)) {
            return paramMap.get(key);
        }
        for (Object value : paramMap.values()) {
            if (value == null) {
                continue;
            }
            if (ReflectUtil.hasField(value.getClass(), key)) {
                return ReflectUtil.getFieldValue(value, key);
            }
            if (value instanceof Collection) {
                Object first = firstNonNull((Collection<?>) value);
                // only accept the element if it really carries the field; otherwise keep searching
                if (first != null && ReflectUtil.hasField(first.getClass(), key)) {
                    return ReflectUtil.getFieldValue(first, key);
                }
            }
        }
        throw valueNotFound(parameterObject, key);
    }

    /** Returns the first non-null element of {@code collection}, or {@code null} if there is none. */
    private static Object firstNonNull(Collection<?> collection) {
        for (Object element : collection) {
            if (element != null) {
                return element;
            }
        }
        return null;
    }

    /** Builds the exception thrown when {@link #getValue} cannot locate {@code key}. */
    private MyBatisShardException valueNotFound(@NonNull Object parameterObject, @NonNull String key) {
        return new MyBatisShardException(String.format("Error: Method AbstractShardStrategy.getValue, parameterObject(%s): %s, key: %s", parameterObject.getClass().getSimpleName(), JSON.toJSONString(parameterObject), key));
    }


}
