package com.digiwin.dap.middle.sql.safe.api;

import com.digiwin.dap.middle.sql.safe.domain.SqlRequest;
import com.digiwin.dap.middle.sql.safe.domain.SqlRule;
import com.digiwin.dap.middle.sql.safe.service.CountSqlBuilder;
import com.digiwin.dap.middle.sql.safe.service.SqlSafeValidator;
import com.digiwin.dap.middle.sql.safe.service.SqlWhitelistLoader;
import com.digiwin.dap.middle.kms.constants.KeyConstant;
import com.digiwin.dap.middleware.commons.crypto.AES;
import com.digiwin.dap.middleware.commons.crypto.SignUtils;
import com.digiwin.dap.middleware.commons.util.EncryptUtils;
import com.digiwin.dap.middleware.commons.util.StrUtils;
import com.digiwin.dap.middleware.constant.GlobalConstants;
import com.digiwin.dap.middleware.domain.*;
import com.digiwin.dap.middleware.exception.BusinessException;
import com.digiwin.dap.middleware.exception.ThirdCallException;
import com.digiwin.dap.middleware.support.EnvSupport;
import com.digiwin.dap.middleware.util.JsonUtils;
import com.digiwin.dap.middleware.util.UserUtils;
import com.digiwin.dap.middleware.util.VerifyUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.client.HttpStatusCodeException;
import org.springframework.web.client.RestTemplate;

import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
 * Admin endpoint that decrypts an incoming SQL statement, validates it against the configured
 * whitelist (when running as BOSS) and either executes it on the local database or forwards
 * it to the component named in the request.
 */
@RestController
@RequestMapping("/api/${spring.application.name}/v2/admin/sql")
public class AdminSQLController {

    @Autowired
    private DapEnv dapEnv;
    @Autowired
    private EnvSupport envSupport;
    @Autowired
    private JdbcTemplate jdbcTemplate;
    @Autowired
    private RestTemplate restTemplate;

    /**
     * Unified query entry point exposed by BOSS; the {@code app} field of the request selects
     * which application's database is queried.
     *
     * <ol>
     *     <li>When this service is BOSS, the SQL is validated against the configured whitelist.</li>
     *     <li>IAM and BOSS requests are executed directly against the local database; requests
     *     for any other component are signed and forwarded to that component over HTTP.</li>
     *     <li>When this service is not BOSS, the request signature is verified before the SQL
     *     is executed locally.</li>
     * </ol>
     *
     * @param body request carrying the AES-encrypted SQL, the target app, optional bind
     *             parameters and, for forwarded calls, the signature
     * @return an empty map when the decrypted SQL is empty; page data (count plus rows) when a
     *         count statement could be derived from the SQL; otherwise the plain row list
     */
    @PostMapping("/query")
    public StdData<?> query(@RequestBody SqlRequest body) {
        String querySql = AES.decrypt(body.getSql(), KeyConstant.WECHAT_UNION_ID);
        if (StrUtils.isEmpty(querySql)) {
            return StdData.ok(Collections.emptyMap());
        }

        // equalsIgnoreCase is null-safe on its argument and independent of the default locale,
        // and matches the comparisons used for body.getApp() below.
        if (CommonCode.BOSS.name().equalsIgnoreCase(dapEnv.getAppName())) {
            // BOSS is the entry point: validate the SQL here before anything is executed.
            SqlRule sqlRule = SqlWhitelistLoader.load(jdbcTemplate);
            SqlSafeValidator sqlSafeValidator = SqlSafeValidator.build(body.getApp(), sqlRule);
            // Whitelist validation first (guards against malicious structure or use of *).
            sqlSafeValidator.validate(querySql);

            boolean targetsIam = CommonCode.IAM.name().equalsIgnoreCase(body.getApp());
            boolean targetsBoss = CommonCode.BOSS.name().equalsIgnoreCase(body.getApp());
            if (!targetsIam && !targetsBoss) {
                // Sign the decrypted SQL for the internal call; the receiving component
                // verifies it in the else-branch below.
                String signParam = EncryptUtils.sign(KeyConstant.OTHER, Collections.singletonMap("sql", querySql));
                body.setSign(signParam);
                // Forward to the target component.
                return this.querySql(body);
            }
        } else {
            // Other components are called through BOSS: check the signature attached by BOSS.
            // NOTE(review): assumes SqlRequest.get never returns null for a missing or malformed
            // sign - confirm, otherwise the next line throws a NullPointerException.
            SqlRequest dapSign = SqlRequest.get(body.getSign(), SqlRequest.class);
            dapSign.setSql(querySql);
            VerifyUtils.sign(dapSign, () -> true);
            SignUtils.verify(JsonUtils.objToMap(dapSign), KeyConstant.OTHER);
        }

        // Paged query: a non-null count statement means the result is returned as page data.
        String countSql = CountSqlBuilder.buildCountSql(querySql);
        Object[] args = body.getParams() == null ? new Object[0] : body.getParams().toArray();
        if (countSql == null) {
            List<Map<String, Object>> list = jdbcTemplate.queryForList(querySql, args);
            return StdData.ok(list);
        }

        Long count = jdbcTemplate.queryForObject(countSql, Long.class, args);
        // NOTE(review): presumably only a null count is treated as empty here - verify
        // StrUtils.isEmpty semantics for a Long.
        if (StrUtils.isEmpty(count)) {
            return StdData.ok(PageData.zero());
        }
        // Execute the SQL with bind parameters (guards against injection).
        List<Map<String, Object>> list = jdbcTemplate.queryForList(querySql, args);
        return StdData.ok(PageData.data(count, list));
    }

    /**
     * Forwards the request to the {@code /v2/admin/sql/query} endpoint of the component named
     * by {@code body.getApp()}, passing the current user token in the request headers.
     *
     * @param body the signed request to forward
     * @return the response body returned by the remote component
     * @throws ThirdCallException when the remote call fails with an HTTP status error
     * @throws BusinessException  when the remote call fails for any other reason
     */
    private StdData<?> querySql(SqlRequest body) {
        // NOTE(review): if CommonCode.ofName can return null for an unknown app, app.getPath()
        // below throws a NullPointerException - confirm against CommonCode.
        CommonCode app = CommonCode.ofName(body.getApp());
        // The dev deploy area uses the regular URI; every other area uses the local URI.
        String domain = DeployAreaEnum.isDev(dapEnv.getDeployArea()) ? envSupport.getUri(app) : envSupport.getLocalUri(app);
        String path = "/api/" + app.getPath() + "/v2/admin/sql/query";
        String uri = domain + path;
        try {
            HttpHeaders headers = new HttpHeaders();
            headers.setContentType(MediaType.APPLICATION_JSON);
            headers.add(GlobalConstants.HTTP_HEADER_USER_TOKEN_KEY, UserUtils.getToken());
            HttpEntity<SqlRequest> httpEntity = new HttpEntity<>(body, headers);
            return restTemplate.postForObject(uri, httpEntity, StdData.class);
        } catch (HttpStatusCodeException e) {
            throw new ThirdCallException(CommonErrorCode.REMOTE_UNEXPECTED, uri, e);
        } catch (Exception e) {
            // NOTE(review): only the message is propagated, the original cause is dropped;
            // consider passing e along if BusinessException offers a cause-accepting constructor.
            throw new BusinessException(CommonErrorCode.BUSINESS, e.getMessage());
        }
    }
}
