feat: 新增copilot所需接口

This commit is contained in:
junjun 2024-07-09 11:48:44 +08:00
parent aa205e9a09
commit 1020d2375b
22 changed files with 1379 additions and 0 deletions

View File

@ -0,0 +1,104 @@
package io.dataease.copilot.api;
import io.dataease.api.copilot.dto.ReceiveDTO;
import io.dataease.api.copilot.dto.SendDTO;
import io.dataease.copilot.dao.auto.entity.CoreCopilotConfig;
import io.dataease.copilot.dao.auto.mapper.CoreCopilotConfigMapper;
import io.dataease.exception.DEException;
import io.dataease.utils.HttpClientConfig;
import io.dataease.utils.HttpClientUtil;
import io.dataease.utils.JsonUtil;
import jakarta.annotation.Resource;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.Header;
import org.apache.http.HttpResponse;
import org.json.simple.JSONObject;
import org.springframework.stereotype.Component;
import java.util.Base64;
import java.util.Map;
/**
* @Author Junjun
*/
@Component
public class CopilotAPI {
public static final String TOKEN = "/auth/token/license";
public static final String FREE_TOKEN = "/auth/token/free";
public static final String API = "/copilot/v1";
public static final String CHART = "/generate-chart";
public static final String RATE_LIMIT = "/rate-limit";
@Resource
private CoreCopilotConfigMapper coreCopilotConfigMapper;
public String basicAuth(String userName, String password) {
String auth = userName + ":" + password;
String encodedAuth = Base64.getEncoder().encodeToString(auth.getBytes());
return "Basic " + encodedAuth;
}
public String bearerAuth(String token) {
return "Bearer " + token;
}
public CoreCopilotConfig getConfig() {
CoreCopilotConfig coreCopilotConfig = coreCopilotConfigMapper.selectById(1);
coreCopilotConfig.setPwd(new String(Base64.getDecoder().decode(coreCopilotConfig.getPwd())));
return coreCopilotConfig;
}
public String getToken(String license) throws Exception {
String url = getConfig().getCopilotUrl() + TOKEN;
JSONObject json = new JSONObject();
json.put("licenseText", license);
HttpClientConfig httpClientConfig = new HttpClientConfig();
httpClientConfig.addHeader("Authorization", basicAuth(getConfig().getUsername(), getConfig().getPwd()));
String tokenJson = HttpClientUtil.post(url, json.toString(), httpClientConfig);
return (String) JsonUtil.parse(tokenJson, Map.class).get("accessToken");
}
public String getFreeToken() throws Exception {
String url = getConfig().getCopilotUrl() + FREE_TOKEN;
HttpClientConfig httpClientConfig = new HttpClientConfig();
httpClientConfig.addHeader("Authorization", basicAuth(getConfig().getUsername(), getConfig().getPwd()));
String tokenJson = HttpClientUtil.post(url, "", httpClientConfig);
return (String) JsonUtil.parse(tokenJson, Map.class).get("accessToken");
}
public ReceiveDTO generateChart(String token, SendDTO sendDTO) {
String url = getConfig().getCopilotUrl() + API + CHART;
String request = (String) JsonUtil.toJSONString(sendDTO);
HttpClientConfig httpClientConfig = new HttpClientConfig();
httpClientConfig.addHeader("Authorization", bearerAuth(token));
String result = HttpClientUtil.post(url, request, httpClientConfig);
return JsonUtil.parseObject(result, ReceiveDTO.class);
}
public void checkRateLimit(String token) {
String url = getConfig().getCopilotUrl() + API + RATE_LIMIT;
HttpClientConfig httpClientConfig = new HttpClientConfig();
httpClientConfig.addHeader("Authorization", bearerAuth(token));
HttpResponse httpResponse = HttpClientUtil.postWithHeaders(url, null, httpClientConfig);
Header[] allHeaders = httpResponse.getAllHeaders();
String limit = "";
String seconds = "";
for (Header header : allHeaders) {
if (StringUtils.equalsIgnoreCase(header.getName(), "x-rate-limit-remaining")) {
limit = header.getValue();
}
if (StringUtils.equalsIgnoreCase(header.getName(), "x-rate-limit-retry-after-seconds")) {
seconds = header.getValue();
}
}
if (Long.parseLong(limit) <= 0) {
DEException.throwException(String.format("当前请求频率已达上限,请在%s秒后重试", seconds));
}
}
}

View File

@ -0,0 +1,71 @@
package io.dataease.copilot.dao.auto.entity;
import com.baomidou.mybatisplus.annotation.TableName;
import java.io.Serializable;
/**
* <p>
*
* </p>
*
* @author fit2cloud
* @since 2024-07-08
*/
@TableName("core_copilot_config")
public class CoreCopilotConfig implements Serializable {
private static final long serialVersionUID = 1L;
/**
* ID
*/
private Long id;
private String copilotUrl;
private String username;
private String pwd;
public Long getId() {
return id;
}
public void setId(Long id) {
this.id = id;
}
public String getCopilotUrl() {
return copilotUrl;
}
public void setCopilotUrl(String copilotUrl) {
this.copilotUrl = copilotUrl;
}
public String getUsername() {
return username;
}
public void setUsername(String username) {
this.username = username;
}
public String getPwd() {
return pwd;
}
public void setPwd(String pwd) {
this.pwd = pwd;
}
@Override
public String toString() {
return "CoreCopilotConfig{" +
"id = " + id +
", copilotUrl = " + copilotUrl +
", username = " + username +
", pwd = " + pwd +
"}";
}
}

View File

@ -0,0 +1,276 @@
package io.dataease.copilot.dao.auto.entity;
import com.baomidou.mybatisplus.annotation.TableName;
import java.io.Serializable;
/**
* <p>
*
* </p>
*
* @author fit2cloud
* @since 2024-07-04
*/
@TableName("core_copilot_msg")
public class CoreCopilotMsg implements Serializable {
private static final long serialVersionUID = 1L;
/**
* ID
*/
private Long id;
/**
* 用户ID
*/
private Long userId;
/**
* 数据集ID
*/
private Long datasetGroupId;
/**
* user or api
*/
private String msgType;
/**
* mysql oracle ...
*/
private String engineType;
/**
* create sql
*/
private String schemaSql;
/**
* 用户提问
*/
private String question;
/**
* 历史信息
*/
private String history;
/**
* copilot 返回 sql
*/
private String copilotSql;
/**
* copilot 返回信息
*/
private String apiMsg;
/**
* sql 状态
*/
private Integer sqlOk;
/**
* chart 状态
*/
private Integer chartOk;
/**
* chart 内容
*/
private String chart;
/**
* 视图数据
*/
private String chartData;
/**
* 执行请求的SQL
*/
private String execSql;
/**
* msg状态0失败 1成功
*/
private Integer msgStatus;
/**
* de错误信息
*/
private String errMsg;
/**
* 创建时间
*/
private Long createTime;
public Long getId() {
return id;
}
public void setId(Long id) {
this.id = id;
}
public Long getUserId() {
return userId;
}
public void setUserId(Long userId) {
this.userId = userId;
}
public Long getDatasetGroupId() {
return datasetGroupId;
}
public void setDatasetGroupId(Long datasetGroupId) {
this.datasetGroupId = datasetGroupId;
}
public String getMsgType() {
return msgType;
}
public void setMsgType(String msgType) {
this.msgType = msgType;
}
public String getEngineType() {
return engineType;
}
public void setEngineType(String engineType) {
this.engineType = engineType;
}
public String getSchemaSql() {
return schemaSql;
}
public void setSchemaSql(String schemaSql) {
this.schemaSql = schemaSql;
}
public String getQuestion() {
return question;
}
public void setQuestion(String question) {
this.question = question;
}
public String getHistory() {
return history;
}
public void setHistory(String history) {
this.history = history;
}
public String getCopilotSql() {
return copilotSql;
}
public void setCopilotSql(String copilotSql) {
this.copilotSql = copilotSql;
}
public String getApiMsg() {
return apiMsg;
}
public void setApiMsg(String apiMsg) {
this.apiMsg = apiMsg;
}
public Integer getSqlOk() {
return sqlOk;
}
public void setSqlOk(Integer sqlOk) {
this.sqlOk = sqlOk;
}
public Integer getChartOk() {
return chartOk;
}
public void setChartOk(Integer chartOk) {
this.chartOk = chartOk;
}
public String getChart() {
return chart;
}
public void setChart(String chart) {
this.chart = chart;
}
public String getChartData() {
return chartData;
}
public void setChartData(String chartData) {
this.chartData = chartData;
}
public String getExecSql() {
return execSql;
}
public void setExecSql(String execSql) {
this.execSql = execSql;
}
public Integer getMsgStatus() {
return msgStatus;
}
public void setMsgStatus(Integer msgStatus) {
this.msgStatus = msgStatus;
}
public String getErrMsg() {
return errMsg;
}
public void setErrMsg(String errMsg) {
this.errMsg = errMsg;
}
public Long getCreateTime() {
return createTime;
}
public void setCreateTime(Long createTime) {
this.createTime = createTime;
}
@Override
public String toString() {
return "CoreCopilotMsg{" +
"id = " + id +
", userId = " + userId +
", datasetGroupId = " + datasetGroupId +
", msgType = " + msgType +
", engineType = " + engineType +
", schemaSql = " + schemaSql +
", question = " + question +
", history = " + history +
", copilotSql = " + copilotSql +
", apiMsg = " + apiMsg +
", sqlOk = " + sqlOk +
", chartOk = " + chartOk +
", chart = " + chart +
", chartData = " + chartData +
", execSql = " + execSql +
", msgStatus = " + msgStatus +
", errMsg = " + errMsg +
", createTime = " + createTime +
"}";
}
}

View File

@ -0,0 +1,74 @@
package io.dataease.copilot.dao.auto.entity;
import com.baomidou.mybatisplus.annotation.TableName;
import java.io.Serializable;
/**
* <p>
*
* </p>
*
* @author fit2cloud
* @since 2024-07-08
*/
@TableName("core_copilot_token")
public class CoreCopilotToken implements Serializable {
private static final long serialVersionUID = 1L;
/**
* ID
*/
private Long id;
/**
* free or license
*/
private String type;
private String token;
private Long updateTime;
public Long getId() {
return id;
}
public void setId(Long id) {
this.id = id;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public String getToken() {
return token;
}
public void setToken(String token) {
this.token = token;
}
public Long getUpdateTime() {
return updateTime;
}
public void setUpdateTime(Long updateTime) {
this.updateTime = updateTime;
}
@Override
public String toString() {
return "CoreCopilotToken{" +
"id = " + id +
", type = " + type +
", token = " + token +
", updateTime = " + updateTime +
"}";
}
}

View File

@ -0,0 +1,18 @@
package io.dataease.copilot.dao.auto.mapper;
import io.dataease.copilot.dao.auto.entity.CoreCopilotConfig;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
/**
* <p>
* Mapper 接口
* </p>
*
* @author fit2cloud
* @since 2024-07-08
*/
@Mapper
public interface CoreCopilotConfigMapper extends BaseMapper<CoreCopilotConfig> {
}

View File

@ -0,0 +1,18 @@
package io.dataease.copilot.dao.auto.mapper;
import io.dataease.copilot.dao.auto.entity.CoreCopilotMsg;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
/**
* <p>
* Mapper 接口
* </p>
*
* @author fit2cloud
* @since 2024-07-04
*/
@Mapper
public interface CoreCopilotMsgMapper extends BaseMapper<CoreCopilotMsg> {
}

View File

@ -0,0 +1,18 @@
package io.dataease.copilot.dao.auto.mapper;
import io.dataease.copilot.dao.auto.entity.CoreCopilotToken;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Mapper;
/**
* <p>
* Mapper 接口
* </p>
*
* @author fit2cloud
* @since 2024-07-08
*/
@Mapper
public interface CoreCopilotTokenMapper extends BaseMapper<CoreCopilotToken> {
}

View File

@ -0,0 +1,377 @@
package io.dataease.copilot.manage;
import com.fasterxml.jackson.core.type.TypeReference;
import io.dataease.api.copilot.dto.DESendDTO;
import io.dataease.api.copilot.dto.MsgDTO;
import io.dataease.api.copilot.dto.ReceiveDTO;
import io.dataease.api.copilot.dto.TokenDTO;
import io.dataease.api.dataset.union.DatasetGroupInfoDTO;
import io.dataease.api.dataset.union.UnionDTO;
import io.dataease.chart.utils.ChartDataBuild;
import io.dataease.copilot.api.CopilotAPI;
import io.dataease.dataset.dao.auto.entity.CoreDatasetGroup;
import io.dataease.dataset.dao.auto.mapper.CoreDatasetGroupMapper;
import io.dataease.dataset.manage.DatasetDataManage;
import io.dataease.dataset.manage.DatasetSQLManage;
import io.dataease.dataset.manage.DatasetTableFieldManage;
import io.dataease.dataset.manage.PermissionManage;
import io.dataease.engine.constant.DeTypeConstants;
import io.dataease.engine.utils.Utils;
import io.dataease.exception.DEException;
import io.dataease.extensions.datasource.constant.SqlPlaceholderConstants;
import io.dataease.extensions.datasource.dto.DatasetTableFieldDTO;
import io.dataease.extensions.datasource.dto.DatasourceRequest;
import io.dataease.extensions.datasource.dto.DatasourceSchemaDTO;
import io.dataease.extensions.datasource.dto.TableField;
import io.dataease.extensions.datasource.factory.ProviderFactory;
import io.dataease.extensions.datasource.provider.Provider;
import io.dataease.extensions.view.dto.ColumnPermissionItem;
import io.dataease.i18n.Translator;
import io.dataease.license.dao.po.LicensePO;
import io.dataease.license.manage.F2CLicManage;
import io.dataease.license.utils.LicenseUtil;
import io.dataease.utils.AuthUtils;
import io.dataease.utils.BeanUtils;
import io.dataease.utils.JsonUtil;
import jakarta.annotation.Resource;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import java.math.BigDecimal;
import java.util.*;
import java.util.stream.Collectors;
/**
* @Author Junjun
*/
@Component
public class CopilotManage {
@Resource
private DatasetSQLManage datasetSQLManage;
@Resource
private CoreDatasetGroupMapper coreDatasetGroupMapper;
@Resource
private DatasetTableFieldManage datasetTableFieldManage;
@Resource
private DatasetDataManage datasetDataManage;
@Resource
private PermissionManage permissionManage;
@Resource
private MsgManage msgManage;
private static Logger logger = LoggerFactory.getLogger(CopilotManage.class);
@Resource
private TokenManage tokenManage;
@Resource
private CopilotAPI copilotAPI;
@Resource
private F2CLicManage f2CLicManage;
public MsgDTO chat(MsgDTO msgDTO) throws Exception {
CoreDatasetGroup coreDatasetGroup = coreDatasetGroupMapper.selectById(msgDTO.getDatasetGroupId());
if (coreDatasetGroup == null) {
return null;
}
DatasetGroupInfoDTO dto = new DatasetGroupInfoDTO();
BeanUtils.copyBean(dto, coreDatasetGroup);
dto.setUnionSql(null);
List<UnionDTO> unionDTOList = JsonUtil.parseList(coreDatasetGroup.getInfo(), new TypeReference<>() {
});
dto.setUnion(unionDTOList);
// 获取field
List<DatasetTableFieldDTO> dsFields = datasetTableFieldManage.selectByDatasetGroupId(msgDTO.getDatasetGroupId());
List<DatasetTableFieldDTO> allFields = dsFields.stream().filter(ele -> ele.getExtField() == 0)
.map(ele -> {
DatasetTableFieldDTO datasetTableFieldDTO = new DatasetTableFieldDTO();
BeanUtils.copyBean(datasetTableFieldDTO, ele);
datasetTableFieldDTO.setFieldShortName(ele.getDataeaseName());
return datasetTableFieldDTO;
}).collect(Collectors.toList());
Map<String, Object> sqlMap = datasetSQLManage.getUnionSQLForEdit(dto, null);
String sql = (String) sqlMap.get("sql");// 数据集原始SQL
Map<Long, DatasourceSchemaDTO> dsMap = (Map<Long, DatasourceSchemaDTO>) sqlMap.get("dsMap");
boolean crossDs = Utils.isCrossDs(dsMap);
if (crossDs) {
DEException.throwException("跨源数据集不支持该功能");
}
// 调用copilot service 获取SQL和chart struct将返回SQL中表名替换成数据集SQL
// deSendDTO 构建schema和engine
if (ObjectUtils.isEmpty(dsMap)) {
DEException.throwException("No datasource");
}
DatasourceSchemaDTO ds = dsMap.entrySet().iterator().next().getValue();
String type = ds.getType();// 数据库类型如mysqloracle等可能需要映射成copilot需要的类型
datasetDataManage.buildFieldName(sqlMap, allFields);
List<String> strings = transCreateTableFields(allFields);
String createSql = "CREATE TABLE de_tmp_table (" + StringUtils.join(strings, ",") + ")";
logger.info("Copilot Schema SQL: " + createSql);
// PerMsgDTO perMsgDTO = new PerMsgDTO();
msgDTO.setDatasetGroupId(dto.getId());
msgDTO.setMsgType("user");
msgDTO.setEngineType(type);
msgDTO.setSchemaSql(createSql);
msgDTO.setHistory(msgDTO.getHistory());
msgDTO.setMsgStatus(1);
msgManage.save(msgDTO);// 到这里为止提问所需参数构建完毕往数据库插入一条提问记录
DESendDTO deSendDTO = new DESendDTO();
deSendDTO.setDatasetGroupId(dto.getId());
deSendDTO.setQuestion(msgDTO.getQuestion());
deSendDTO.setHistory(msgDTO.getHistory());
deSendDTO.setEngine(type);
deSendDTO.setSchema(createSql);
// do copilot chat
ReceiveDTO receiveDTO = copilotChat(deSendDTO);
// copilot 请求结束继续de获取数据
// 获取数据集相关行列权限最终再套一层SQL
Map<String, ColumnPermissionItem> desensitizationList = new HashMap<>();
allFields = permissionManage.filterColumnPermissions(allFields, desensitizationList, dto.getId(), null);
if (ObjectUtils.isEmpty(allFields)) {
DEException.throwException(Translator.get("i18n_no_column_permission"));
}
List<String> dsList = new ArrayList<>();
for (Map.Entry<Long, DatasourceSchemaDTO> next : dsMap.entrySet()) {
dsList.add(next.getValue().getType());
}
boolean needOrder = Utils.isNeedOrder(dsList);
if (!crossDs) {
sql = Utils.replaceSchemaAlias(sql, dsMap);
}
Provider provider;
if (crossDs) {
provider = ProviderFactory.getDefaultProvider();
} else {
provider = ProviderFactory.getProvider(dsList.getFirst());
}
// List<DataSetRowPermissionsTreeDTO> rowPermissionsTree = new ArrayList<>();
// TokenUserBO user = AuthUtils.getUser();
// if (user != null) {
// rowPermissionsTree = permissionManage.getRowPermissionsTree(dto.getId(), user.getUserId());
// }
// build query sql
// SQLMeta sqlMeta = new SQLMeta();
// Table2SQLObj.table2sqlobj(sqlMeta, null, "(" + sql + ")", crossDs);
// Field2SQLObj.field2sqlObj(sqlMeta, allFields, allFields, crossDs, dsMap);
// WhereTree2Str.transFilterTrees(sqlMeta, rowPermissionsTree, allFields, crossDs, dsMap);
// Order2SQLObj.getOrders(sqlMeta, dto.getSortFields(), allFields, crossDs, dsMap);
// String querySQL = SQLProvider.createQuerySQL(sqlMeta, false, false, needOrder);
// querySQL = provider.rebuildSQL(querySQL, sqlMeta, crossDs, dsMap);
// logger.info("preview sql: " + querySQL);
// todo test
String querySQL = sql;
String copilotSQL = receiveDTO.getSql();
// 通过数据源请求数据
// 调用数据源的calcite获得data
DatasourceRequest datasourceRequest = new DatasourceRequest();
datasourceRequest.setDsList(dsMap);
String s = "";
Map<String, Object> data;
try {
s = copilotSQL
.replaceAll(SqlPlaceholderConstants.KEYWORD_PREFIX_REGEX + "de_tmp_table" + SqlPlaceholderConstants.KEYWORD_SUFFIX_REGEX, "(" + querySQL + ")")
.replaceAll(SqlPlaceholderConstants.KEYWORD_PREFIX_REGEX + "DE_TMP_TABLE" + SqlPlaceholderConstants.KEYWORD_SUFFIX_REGEX, "(" + querySQL + ")");
logger.info("copilot sql: " + s);
datasourceRequest.setQuery(s);
data = provider.fetchResultField(datasourceRequest);
} catch (Exception e) {
try {
s = copilotSQL
.replaceAll(SqlPlaceholderConstants.KEYWORD_PREFIX_REGEX + "de_tmp_table" + SqlPlaceholderConstants.KEYWORD_SUFFIX_REGEX, "(" + querySQL + ") tmp")
.replaceAll(SqlPlaceholderConstants.KEYWORD_PREFIX_REGEX + "DE_TMP_TABLE" + SqlPlaceholderConstants.KEYWORD_SUFFIX_REGEX, "(" + querySQL + ") tmp");
logger.info("copilot sql: " + s);
datasourceRequest.setQuery(s);
data = provider.fetchResultField(datasourceRequest);
} catch (Exception e1) {
// 如果异常则获取最后一条成功的history
MsgDTO lastSuccessMsg = msgManage.getLastSuccessMsg(AuthUtils.getUser().getUserId(), dto.getId());
// 请求数据出错记录错误信息和copilot返回的信息
MsgDTO result = new MsgDTO();
result.setDatasetGroupId(dto.getId());
result.setMsgType("api");
result.setHistory(lastSuccessMsg == null ? new ArrayList<>() : lastSuccessMsg.getHistory());
result.setCopilotSql(receiveDTO.getSql());
result.setApiMsg(receiveDTO.getApiMessage());
result.setSqlOk(receiveDTO.getSqlOk() ? 1 : 0);
result.setChartOk(receiveDTO.getChartOk() ? 1 : 0);
result.setChart(receiveDTO.getChart());
result.setExecSql(s);
result.setMsgStatus(0);
result.setErrMsg(e1.getMessage());
msgManage.save(result);
return result;
}
}
List<TableField> fields = (List<TableField>) data.get("fields");
Map<String, Object> map = new LinkedHashMap<>();
// 重新构造data
Map<String, Object> previewData = buildPreviewData(data, fields, desensitizationList);
map.put("data", previewData);
// map.put("allFields", allFields);// map.data 中包含了fields和data
// if (ObjectUtils.isEmpty(dto.getId())) {
// map.put("allFields", allFields);
// } else {
// List<DatasetTableFieldDTO> fieldList = datasetTableFieldManage.selectByDatasetGroupId(dto.getId());
// map.put("allFields", fieldList);
// }
map.put("sql", Base64.getEncoder().encodeToString(s.getBytes()));
MsgDTO result = new MsgDTO();
result.setDatasetGroupId(dto.getId());
result.setMsgType("api");
result.setHistory(receiveDTO.getHistory());
result.setCopilotSql(receiveDTO.getSql());
result.setApiMsg(receiveDTO.getApiMessage());
result.setSqlOk(receiveDTO.getSqlOk() ? 1 : 0);
result.setChartOk(receiveDTO.getChartOk() ? 1 : 0);
result.setChart(receiveDTO.getChart());
result.setChartData(map);
result.setExecSql(s);
result.setMsgStatus(1);
msgManage.save(result);
return result;
}
public ReceiveDTO copilotChat(DESendDTO deSendDTO) throws Exception {
boolean valid = LicenseUtil.licenseValid();
// call copilot service
TokenDTO tokenDTO = tokenManage.getToken(valid);
ReceiveDTO receiveDTO;
if (StringUtils.isEmpty(tokenDTO.getToken())) {
// get token
String token;
if (valid) {
LicensePO read = f2CLicManage.read();
token = copilotAPI.getToken(read.getLicense());
} else {
token = copilotAPI.getFreeToken();
}
tokenManage.updateToken(token, valid);
receiveDTO = copilotAPI.generateChart(token, deSendDTO);
} else {
try {
receiveDTO = copilotAPI.generateChart(tokenDTO.getToken(), deSendDTO);
} catch (Exception e) {
// error, get token again
String token;
if (valid) {
LicensePO read = f2CLicManage.read();
token = copilotAPI.getToken(read.getLicense());
} else {
token = copilotAPI.getFreeToken();
}
tokenManage.updateToken(token, valid);
receiveDTO = copilotAPI.generateChart(token, deSendDTO);
}
}
if (!receiveDTO.getSqlOk() || !receiveDTO.getChartOk()) {
DEException.throwException((String) JsonUtil.toJSONString(receiveDTO));
}
logger.info("Copilot Service SQL: " + receiveDTO.getSql());
logger.info("Copilot Service Chart: " + JsonUtil.toJSONString(receiveDTO.getChart()));
return receiveDTO;
}
public List<MsgDTO> getList(Long userId) {
MsgDTO lastMsg = msgManage.getLastMsg(userId);
if (lastMsg == null) {
return null;
}
List<MsgDTO> msg = msgManage.getMsg(lastMsg);
msgManage.deleteMsg(lastMsg);
return msg;
}
public void clearAll(Long userId) {
msgManage.clearAllUserMsg(userId);
}
public MsgDTO errorMsg(MsgDTO msgDTO, String errMsg) throws Exception {
// 如果异常则获取最后一条成功的history
MsgDTO lastSuccessMsg = msgManage.getLastSuccessMsg(AuthUtils.getUser().getUserId(), msgDTO.getDatasetGroupId());
MsgDTO dto = new MsgDTO();
dto.setDatasetGroupId(msgDTO.getDatasetGroupId());
dto.setHistory(lastSuccessMsg == null ? new ArrayList<>() : lastSuccessMsg.getHistory());
dto.setMsgStatus(0);
dto.setMsgType("api");
dto.setErrMsg(errMsg);
msgManage.save(dto);
return dto;
}
public Map<String, Object> buildPreviewData(Map<String, Object> data, List<TableField> fields, Map<String, ColumnPermissionItem> desensitizationList) {
Map<String, Object> map = new LinkedHashMap<>();
List<String[]> dataList = (List<String[]>) data.get("data");
List<LinkedHashMap<String, Object>> dataObjectList = new ArrayList<>();
if (ObjectUtils.isNotEmpty(dataList)) {
for (int i = 0; i < dataList.size(); i++) {
String[] row = dataList.get(i);
LinkedHashMap<String, Object> obj = new LinkedHashMap<>();
if (row.length > 0) {
for (int j = 0; j < fields.size(); j++) {
TableField tableField = fields.get(j);
if (desensitizationList.containsKey(tableField.getOriginName())) {
obj.put(tableField.getOriginName(), ChartDataBuild.desensitizationValue(desensitizationList.get(tableField.getOriginName()), String.valueOf(row[j])));
} else {
if (tableField.getDeExtractType() == DeTypeConstants.DE_INT
|| tableField.getDeExtractType() == DeTypeConstants.DE_FLOAT
|| tableField.getDeExtractType() == DeTypeConstants.DE_BOOL) {
try {
obj.put(tableField.getOriginName(), new BigDecimal(row[j]));
} catch (Exception e) {
obj.put(tableField.getOriginName(), row[j]);
}
} else {
obj.put(tableField.getOriginName(), row[j]);
}
}
}
}
dataObjectList.add(obj);
}
}
map.put("fields", fields);
map.put("data", dataObjectList);
return map;
}
public List<String> transCreateTableFields(List<DatasetTableFieldDTO> allFields) {
List<String> list = new ArrayList<>();
for (DatasetTableFieldDTO dto : allFields) {
list.add(" " + dto.getDataeaseName() + " " + transNum2Type(dto.getDeExtractType()) +
" COMMENT '" + dto.getName() + "'");
}
return list;
}
public String transNum2Type(Integer num) {
return switch (num) {
case 0, 1, 5 -> "VARCHAR(50)";
case 2, 3, 4 -> "INT(10)";
default -> "VARCHAR(50)";
};
}
}

View File

@ -0,0 +1,100 @@
package io.dataease.copilot.manage;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.dataease.api.copilot.dto.ChartDTO;
import io.dataease.api.copilot.dto.HistoryDTO;
import io.dataease.api.copilot.dto.MsgDTO;
import io.dataease.copilot.dao.auto.entity.CoreCopilotMsg;
import io.dataease.copilot.dao.auto.mapper.CoreCopilotMsgMapper;
import io.dataease.utils.AuthUtils;
import io.dataease.utils.BeanUtils;
import io.dataease.utils.IDUtils;
import io.dataease.utils.JsonUtil;
import jakarta.annotation.Resource;
import org.apache.commons.lang3.ObjectUtils;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* @Author Junjun
*/
@Component
public class MsgManage {
@Resource
private CoreCopilotMsgMapper coreCopilotMsgMapper;
private ObjectMapper objectMapper = new ObjectMapper();
public void save(MsgDTO msgDTO) throws Exception {
msgDTO.setId(IDUtils.snowID());
msgDTO.setCreateTime(System.currentTimeMillis());
msgDTO.setUserId(AuthUtils.getUser().getUserId());
coreCopilotMsgMapper.insert(transDTO(msgDTO));
}
public List<MsgDTO> getMsg(MsgDTO msgDTO) {
QueryWrapper<CoreCopilotMsg> wrapper = new QueryWrapper<>();
wrapper.eq("user_id", msgDTO.getUserId());
wrapper.eq("dataset_group_id", msgDTO.getDatasetGroupId());
wrapper.orderByAsc("create_time");
List<CoreCopilotMsg> perCopilotMsgs = coreCopilotMsgMapper.selectList(wrapper);
return perCopilotMsgs.stream().map(this::transRecord).toList();
}
public void deleteMsg(MsgDTO msgDTO) {
QueryWrapper<CoreCopilotMsg> wrapper = new QueryWrapper<>();
wrapper.eq("user_id", msgDTO.getUserId());
wrapper.ne("dataset_group_id", msgDTO.getDatasetGroupId());
coreCopilotMsgMapper.delete(wrapper);
}
public void clearAllUserMsg(Long userId) {
QueryWrapper<CoreCopilotMsg> wrapper = new QueryWrapper<>();
wrapper.eq("user_id", userId);
coreCopilotMsgMapper.delete(wrapper);
}
public MsgDTO getLastMsg(Long userId) {
QueryWrapper<CoreCopilotMsg> wrapper = new QueryWrapper<>();
wrapper.eq("user_id", userId);
wrapper.orderByDesc("create_time");
List<CoreCopilotMsg> perCopilotMsgs = coreCopilotMsgMapper.selectList(wrapper);
return ObjectUtils.isEmpty(perCopilotMsgs) ? null : transRecord(perCopilotMsgs.getFirst());
}
public MsgDTO getLastSuccessMsg(Long userId, Long datasetGroupId) {
QueryWrapper<CoreCopilotMsg> wrapper = new QueryWrapper<>();
wrapper.eq("user_id", userId);
wrapper.eq("dataset_group_id", datasetGroupId);
wrapper.eq("msg_status", 1);
wrapper.eq("msg_type", "api");
wrapper.orderByDesc("create_time");
List<CoreCopilotMsg> perCopilotMsgs = coreCopilotMsgMapper.selectList(wrapper);
return ObjectUtils.isEmpty(perCopilotMsgs) ? null : transRecord(perCopilotMsgs.getFirst());
}
private CoreCopilotMsg transDTO(MsgDTO dto) throws Exception {
CoreCopilotMsg record = new CoreCopilotMsg();
BeanUtils.copyBean(record, dto);
record.setHistory(dto.getHistory() == null ? null : objectMapper.writeValueAsString(dto.getHistory()));
record.setChart(dto.getChart() == null ? null : objectMapper.writeValueAsString(dto.getChart()));
record.setChartData(dto.getChartData() == null ? null : objectMapper.writeValueAsString(dto.getChartData()));
return record;
}
private MsgDTO transRecord(CoreCopilotMsg record) {
MsgDTO dto = new MsgDTO();
BeanUtils.copyBean(dto, record);
TypeReference<List<HistoryDTO>> tokenType = new TypeReference<>() {
};
dto.setHistory(record.getHistory() == null ? new ArrayList<>() : JsonUtil.parseList(record.getHistory(), tokenType));
dto.setChart(record.getChart() == null ? null : JsonUtil.parseObject(record.getChart(), ChartDTO.class));
dto.setChartData(record.getChartData() == null ? null : JsonUtil.parse(record.getChartData(), Map.class));
return dto;
}
}

View File

@ -0,0 +1,36 @@
package io.dataease.copilot.manage;
import io.dataease.api.copilot.dto.TokenDTO;
import io.dataease.copilot.dao.auto.entity.CoreCopilotToken;
import io.dataease.copilot.dao.auto.mapper.CoreCopilotTokenMapper;
import io.dataease.utils.BeanUtils;
import jakarta.annotation.Resource;
import org.springframework.stereotype.Component;
/**
* @Author Junjun
*/
@Component
public class TokenManage {
@Resource
private CoreCopilotTokenMapper coreCopilotTokenMapper;
public TokenDTO getToken(boolean valid) {
CoreCopilotToken perCopilotToken = coreCopilotTokenMapper.selectById(valid ? 2 : 1);
return transRecord(perCopilotToken);
}
public void updateToken(String token, boolean valid) {
CoreCopilotToken record = new CoreCopilotToken();
record.setId(valid ? 2L : 1L);
record.setToken(token);
record.setUpdateTime(System.currentTimeMillis());
coreCopilotTokenMapper.updateById(record);
}
private TokenDTO transRecord(CoreCopilotToken perCopilotToken) {
TokenDTO dto = new TokenDTO();
BeanUtils.copyBean(dto, perCopilotToken);
return dto;
}
}

View File

@ -0,0 +1,40 @@
package io.dataease.copilot.service;
import io.dataease.api.copilot.CopilotApi;
import io.dataease.api.copilot.dto.MsgDTO;
import io.dataease.copilot.manage.CopilotManage;
import io.dataease.utils.AuthUtils;
import jakarta.annotation.Resource;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
/**
* @Author Junjun
*/
@RestController
@RequestMapping("copilot")
public class CopilotService implements CopilotApi {
@Resource
private CopilotManage copilotManage;
@Override
public MsgDTO chat(MsgDTO msgDTO) throws Exception {
try {
return copilotManage.chat(msgDTO);
} catch (Exception e) {
return copilotManage.errorMsg(msgDTO, e.getMessage());
}
}
@Override
public List<MsgDTO> getList() throws Exception {
return copilotManage.getList(AuthUtils.getUser().getUserId());
}
@Override
public void clearAll() throws Exception {
copilotManage.clearAll(AuthUtils.getUser().getUserId());
}
}

View File

@ -0,0 +1,48 @@
DROP TABLE IF EXISTS `core_copilot_msg`;
CREATE TABLE `core_copilot_msg` (
`id` bigint NOT NULL COMMENT 'ID',
`user_id` bigint DEFAULT NULL COMMENT '用户ID',
`dataset_group_id` bigint DEFAULT NULL COMMENT '数据集ID',
`msg_type` varchar(255) DEFAULT NULL COMMENT 'user or api',
`engine_type` varchar(255) DEFAULT NULL COMMENT 'mysql oracle ...',
`schema_sql` longtext COMMENT 'create sql',
`question` longtext COMMENT '用户提问',
`history` longtext COMMENT '历史信息',
`copilot_sql` longtext COMMENT 'copilot 返回 sql',
`api_msg` longtext COMMENT 'copilot 返回信息',
`sql_ok` int DEFAULT NULL COMMENT 'sql 状态',
`chart_ok` int DEFAULT NULL COMMENT 'chart 状态',
`chart` longtext COMMENT 'chart 内容',
`chart_data` longtext COMMENT '视图数据',
`exec_sql` longtext COMMENT '执行请求的SQL',
`msg_status` int DEFAULT NULL COMMENT 'msg状态0失败 1成功',
`err_msg` longtext COMMENT 'de错误信息',
`create_time` bigint DEFAULT NULL COMMENT '创建时间',
PRIMARY KEY (`id`)
);
DROP TABLE IF EXISTS `core_copilot_token`;
CREATE TABLE `core_copilot_token` (
`id` bigint NOT NULL COMMENT 'ID',
`type` varchar(255) DEFAULT NULL COMMENT 'free or license',
`token` longtext,
`update_time` bigint DEFAULT NULL,
PRIMARY KEY (`id`)
);
INSERT INTO `core_copilot_token` VALUES (1, 'free', null, null);
INSERT INTO `core_copilot_token` VALUES (2, 'license', null, null);
DROP TABLE IF EXISTS `core_copilot_config`;
CREATE TABLE `core_copilot_config` (
`id` bigint NOT NULL COMMENT 'ID',
`copilot_url` varchar(255) DEFAULT NULL,
`username` varchar(255) DEFAULT NULL,
`pwd` varchar(255) DEFAULT NULL,
PRIMARY KEY (`id`)
);
INSERT INTO `core_copilot_config` VALUES (1, 'https://copilot-demo.test.fit2cloud.dev:5000', 'xlab', 'Q2Fsb25nQDIwMTU=');

View File

@ -0,0 +1,30 @@
package io.dataease.api.copilot;
import com.github.xiaoymin.knife4j.annotations.ApiSupport;
import io.dataease.api.copilot.dto.MsgDTO;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import java.util.List;
/**
* @Author Junjun
*/
@Tag(name = "Copilot")
@ApiSupport(order = 999)
public interface CopilotApi {
@Operation(summary = "发起一次对话")
@PostMapping("chat")
MsgDTO chat(@RequestBody MsgDTO msgDTO) throws Exception;
@Operation(summary = "获取对话记录")
@PostMapping("getList")
List<MsgDTO> getList() throws Exception;
@Operation(summary = "清空对话")
@PostMapping("clearAll")
void clearAll() throws Exception;
}

View File

@ -0,0 +1,12 @@
package io.dataease.api.copilot.dto;
import lombok.Data;
/**
* @Author Junjun
*/
@Data
public class AxisDTO {
private String x;
private String y;
}

View File

@ -0,0 +1,13 @@
package io.dataease.api.copilot.dto;
import lombok.Data;
/**
* @Author Junjun
*/
@Data
public class ChartDTO {
private String type;
private AxisDTO axis;
private String title;
}

View File

@ -0,0 +1,13 @@
package io.dataease.api.copilot.dto;
import lombok.Data;
import java.util.Map;
/**
* @Author Junjun
*/
@Data
public class DEReceiveDTO extends ReceiveDTO {
private Map<String, Object> chartData;
}

View File

@ -0,0 +1,16 @@
package io.dataease.api.copilot.dto;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
import lombok.Data;
/**
* @Author Junjun
*/
@Data
public class DESendDTO extends SendDTO {
@JsonSerialize(using = ToStringSerializer.class)
private Long id;
@JsonSerialize(using = ToStringSerializer.class)
private Long datasetGroupId;
}

View File

@ -0,0 +1,12 @@
package io.dataease.api.copilot.dto;
import lombok.Data;
/**
* @Author Junjun
*/
@Data
public class HistoryDTO {
private String type;
private String message;
}

View File

@ -0,0 +1,53 @@
package io.dataease.api.copilot.dto;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
import lombok.Data;
import java.util.List;
import java.util.Map;
/**
* @Author Junjun
*/
@Data
public class MsgDTO {
@JsonSerialize(using = ToStringSerializer.class)
private Long id;
@JsonSerialize(using = ToStringSerializer.class)
private Long userId;
@JsonSerialize(using = ToStringSerializer.class)
private Long datasetGroupId;
private String msgType;
private String engineType;
private String schemaSql;
private String question;
private List<HistoryDTO> history;
private String copilotSql;
private String apiMsg;
private Integer sqlOk;
private Integer chartOk;
private ChartDTO chart;
private Long createTime;
private Map<String, Object> chartData;
private String execSql;
private Integer msgStatus;
private String errMsg;
}

View File

@ -0,0 +1,18 @@
package io.dataease.api.copilot.dto;
import lombok.Data;
import java.util.List;
/**
* @Author Junjun
*/
@Data
public class ReceiveDTO {
private String sql;
private List<HistoryDTO> history;
private String apiMessage;
private Boolean sqlOk;
private Boolean chartOk;
private ChartDTO chart;
}

View File

@ -0,0 +1,16 @@
package io.dataease.api.copilot.dto;
import lombok.Data;
import java.util.List;
/**
* @Author Junjun
*/
@Data
public class SendDTO {
private String engine;
private String schema;
private String question;
private List<HistoryDTO> history;
}

View File

@ -0,0 +1,16 @@
package io.dataease.api.copilot.dto;
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
import com.fasterxml.jackson.databind.ser.std.ToStringSerializer;
import lombok.Data;
/**
* @Author Junjun
*/
@Data
public class TokenDTO {
@JsonSerialize(using = ToStringSerializer.class)
private Long id;
private String token;
private Long updateTime;
}