feat: copilot优化

This commit is contained in:
junjun 2024-07-24 18:20:05 +08:00
parent 7866211c58
commit ef8555cef4

View File

@ -32,6 +32,10 @@ import io.dataease.utils.AuthUtils;
import io.dataease.utils.BeanUtils;
import io.dataease.utils.JsonUtil;
import jakarta.annotation.Resource;
import org.apache.calcite.config.Lex;
import org.apache.calcite.sql.SqlDialect;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.commons.lang3.ObjectUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
@ -148,7 +152,6 @@ public class CopilotManage {
for (Map.Entry<Long, DatasourceSchemaDTO> next : dsMap.entrySet()) {
dsList.add(next.getValue().getType());
}
boolean needOrder = Utils.isNeedOrder(dsList);
if (!crossDs) {
sql = Utils.replaceSchemaAlias(sql, dsMap);
@ -181,6 +184,9 @@ public class CopilotManage {
String querySQL = sql;
String copilotSQL = receiveDTO.getSql();
// 用calcite尝试将SQL转方言如果失败了就按照原SQL执行
// copilotSQL = transSql(type, copilotSQL, provider, receiveDTO);
// 通过数据源请求数据
// 调用数据源的calcite获得data
DatasourceRequest datasourceRequest = new DatasourceRequest();
@ -280,7 +286,7 @@ public class CopilotManage {
token = copilotAPI.getFreeToken();
}
tokenManage.updateToken(token, valid);
receiveDTO = copilotAPI.generateChart(token, deSendDTO);
throw new Exception(e.getMessage());
}
}
@ -428,4 +434,38 @@ public class CopilotManage {
default -> "VARCHAR(50)";
};
}
public Lex getLex(String type) {
switch (type) {
case "oracle":
return Lex.ORACLE;
case "sqlServer":
case "mssql":
return Lex.SQL_SERVER;
default:
return Lex.JAVA;
}
}
public String transSql(String type, String copilotSQL, Provider provider, ReceiveDTO receiveDTO) {
if (type.equals("oracle") || type.equals("sqlServer")) {
try {
if (copilotSQL.trim().endsWith(";")) {
copilotSQL = copilotSQL.substring(0, copilotSQL.length() - 1);
}
DatasourceSchemaDTO datasourceSchemaDTO = new DatasourceSchemaDTO();
datasourceSchemaDTO.setType(type);
SqlDialect dialect = provider.getDialect(datasourceSchemaDTO);
SqlParser parser = SqlParser.create(copilotSQL, SqlParser.Config.DEFAULT.withLex(getLex(type)));
SqlNode sqlNode = parser.parseStmt();
return sqlNode.toSqlString(dialect).toString().toLowerCase();
} catch (Exception e) {
logger.info("calcite trans copilot SQL error");
return receiveDTO.getSql();
}
} else {
return copilotSQL;
}
}
}