feat(视图): 数据预测

This commit is contained in:
wisonic-s 2024-05-17 17:21:47 +08:00
parent 6132566ccf
commit 5329734949
10 changed files with 433 additions and 23 deletions

View File

@ -1,33 +1,34 @@
package io.dataease.commons.utils;
import org.apache.commons.lang3.StringUtils;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Calendar;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.*;
public class DateUtils {
public static final String DATE_PATTERM = "yyyy-MM-dd";
public static final String DATE_PATTERN = "yyyy-MM-dd";
public static final String TIME_PATTERN = "yyyy-MM-dd HH:mm:ss";
/**
 * Parses a date-only string into a {@link Date}.
 *
 * @param dateString text in the {@code yyyy-MM-dd} form
 * @return the parsed date
 * @throws Exception if the text cannot be parsed by {@link SimpleDateFormat}
 */
public static Date getDate(String dateString) throws Exception {
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERM);
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERN);
return dateFormat.parse(dateString);
}
/**
 * Parses a date-time string into a {@link Date}.
 *
 * @param timeString text in the {@code yyyy-MM-dd HH:mm:ss} form
 * @return the parsed instant
 * @throws Exception if the text cannot be parsed by {@link SimpleDateFormat}
 */
public static Date getTime(String timeString) throws Exception {
// A fresh formatter is created per call; SimpleDateFormat instances are not safe to share between threads.
SimpleDateFormat dateFormat = new SimpleDateFormat(TIME_PATTERN);
return dateFormat.parse(timeString);
}
/**
 * Formats a {@link Date} as a date-only string.
 *
 * @param date the date to format
 * @return the date rendered in the {@code yyyy-MM-dd} form
 * @throws Exception declared by the signature; the visible body only formats
 */
public static String getDateString(Date date) throws Exception {
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERM);
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERN);
return dateFormat.format(date);
}
/**
 * Formats a numeric timestamp as a date-only string.
 *
 * @param timeStamp the instant to format; it is boxed and handed to
 *                  {@code Format.format(Object)}, which treats numbers as epoch milliseconds
 * @return the instant rendered in the {@code yyyy-MM-dd} form
 * @throws Exception declared by the signature; the visible body only formats
 */
public static String getDateString(long timeStamp) throws Exception {
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERM);
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERN);
return dateFormat.format(timeStamp);
}
@ -47,10 +48,10 @@ public class DateUtils {
}
/**
 * Shifts a date by a number of calendar days.
 *
 * @param date      the starting date
 * @param countDays days to add; a negative value moves backwards
 * @return a new {@link Date} shifted by {@code countDays} days, computed with the default {@link Calendar}
 */
public static Date dateSum (Date date,int countDays){
public static Date dateSum(Date date, int countDays) {
Calendar calendar = Calendar.getInstance();
calendar.setTime(date);
calendar.add(Calendar.DAY_OF_MONTH,countDays);
calendar.add(Calendar.DAY_OF_MONTH, countDays);
return calendar.getTime();
}
@ -70,7 +71,7 @@ public class DateUtils {
try {
calendar.setTime(date);
calendar.set(Calendar.DAY_OF_WEEK, calendar.getActualMinimum(Calendar.DAY_OF_WEEK));
calendar.add(Calendar.DAY_OF_MONTH,weekDayAdd);
calendar.add(Calendar.DAY_OF_MONTH, weekDayAdd);
//第一天的时分秒是 00:00:00 这里直接取日期默认就是零点零分
Date thisWeekFirstTime = getDate(getDateString(calendar.getTime()));
@ -78,12 +79,12 @@ public class DateUtils {
calendar.clear();
calendar.setTime(date);
calendar.set(Calendar.DAY_OF_WEEK, calendar.getActualMaximum(Calendar.DAY_OF_WEEK));
calendar.add(Calendar.DAY_OF_MONTH,weekDayAdd);
calendar.add(Calendar.DAY_OF_MONTH, weekDayAdd);
//最后一天的时分秒应当是23:59:59 处理方式是增加一天计算日期再-1
calendar.add(Calendar.DAY_OF_MONTH,1);
calendar.add(Calendar.DAY_OF_MONTH, 1);
Date nextWeekFirstDay = getDate(getDateString(calendar.getTime()));
Date thisWeekLastTime = getTime(getTimeString(nextWeekFirstDay.getTime()-1));
Date thisWeekLastTime = getTime(getTimeString(nextWeekFirstDay.getTime() - 1));
returnMap.put("firstTime", thisWeekFirstTime);
returnMap.put("lastTime", thisWeekLastTime);
@ -95,14 +96,91 @@ public class DateUtils {
}
/**
 * Returns the start of the day that contains the given instant.
 * The time-of-day is dropped by formatting the value as {@code yyyy-MM-dd} and parsing it back.
 *
 * @param time the instant of interest, e.g. 2020-12-13 06:12:42
 * @return the start of that day, e.g. 2020-12-13 00:00:00
 * @throws Exception if the intermediate date string cannot be parsed
 */
public static Date getDayStartTime(Date time) throws Exception {
return getDate(getDateString(time));
}
/**
 * Generates the labels of the {@code period} time buckets that immediately follow {@code baseTime}.
 *
 * <p>Supported {@code dateStyle} values and the expected shape of {@code baseTime}
 * (shown with the default {@code -} separator):
 * <ul>
 *   <li>{@code y} - year, e.g. {@code 2023}</li>
 *   <li>{@code y_Q} - year and quarter, e.g. {@code 2023-Q3}</li>
 *   <li>{@code y_M} - year and month, e.g. {@code 2023-11}; generated months are zero-padded</li>
 *   <li>{@code y_W} - year and week, e.g. {@code 2023-W51}; weeks start on Monday and only
 *       complete weeks are counted, generated week numbers are not padded</li>
 *   <li>{@code y_M_d} - full date, e.g. {@code 2023-11-30}</li>
 * </ul>
 * Any other style yields an empty list.
 *
 * @param baseTime  label of the last known bucket; generated labels start one step after it
 * @param period    number of labels to generate; zero or a negative value yields an empty list
 * @param dateStyle granularity key as listed above; {@code null} yields an empty list
 * @param pattern   {@code "date_split"} (case-insensitive) selects {@code /} as the separator,
 *                  anything else (including {@code null}) selects {@code -}
 * @return a new mutable list with the generated labels in chronological order
 * @throws ParseException        if {@code baseTime} is not a valid date for the {@code y_M_d} style
 * @throws NumberFormatException if a numeric part of {@code baseTime} cannot be parsed
 */
public static List<String> getForecastPeriod(String baseTime, int period, String dateStyle, String pattern) throws ParseException {
    // Switching on a null string throws NPE and ArrayList rejects a negative capacity,
    // so treat both as "nothing to forecast".
    if (dateStyle == null || period <= 0) {
        return new ArrayList<>();
    }
    // Literal-first comparison is null-safe for pattern.
    final String split = "date_split".equalsIgnoreCase(pattern) ? "/" : "-";
    final List<String> result = new ArrayList<>(period);
    switch (dateStyle) {
        case "y": {
            int baseYear = Integer.parseInt(baseTime);
            for (int i = 1; i <= period; i++) {
                result.add(String.valueOf(baseYear + i));
            }
            break;
        }
        case "y_Q": {
            String[] parts = baseTime.split(split);
            int year = Integer.parseInt(parts[0]);
            int quarter = Integer.parseInt(parts[1].split("Q")[1]);
            for (int i = 0; i < period; i++) {
                // Advance 1 -> 2 -> 3 -> 4 -> 1; wrapping to Q1 starts a new year.
                quarter = quarter % 4 + 1;
                if (quarter == 1) {
                    year += 1;
                }
                result.add(year + split + "Q" + quarter);
            }
            break;
        }
        case "y_M": {
            String[] parts = baseTime.split(split);
            int year = Integer.parseInt(parts[0]);
            int month = Integer.parseInt(parts[1]);
            for (int i = 0; i < period; i++) {
                // Advance 1..12 cyclically; wrapping to January starts a new year.
                month = month % 12 + 1;
                if (month == 1) {
                    year += 1;
                }
                String padMonth = month < 10 ? "0" + month : "" + month;
                result.add(year + split + padMonth);
            }
            break;
        }
        case "y_W": {
            String[] parts = baseTime.split(split);
            int year = Integer.parseInt(parts[0]);
            int week = Integer.parseInt(parts[1].split("W")[1]);
            // One calendar is enough; only the inspected date changes per iteration.
            Calendar weekCalendar = Calendar.getInstance();
            weekCalendar.setMinimalDaysInFirstWeek(7);
            weekCalendar.setFirstDayOfWeek(Calendar.MONDAY);
            for (int i = 0; i < period; i++) {
                // With 7 minimal days, Dec 31 always falls in the last week of its own year,
                // so its week number is the number of weeks in that year.
                weekCalendar.set(year, Calendar.DECEMBER, 31);
                int lastWeek = weekCalendar.get(Calendar.WEEK_OF_YEAR);
                week += 1;
                if (week > lastWeek) {
                    year += 1;
                    week = 1;
                }
                result.add(year + split + "W" + week);
            }
            break;
        }
        case "y_M_d": {
            SimpleDateFormat sdf = new SimpleDateFormat("yyyy" + split + "MM" + split + "dd");
            Calendar dayCalendar = Calendar.getInstance();
            Date baseDate = sdf.parse(baseTime);
            dayCalendar.setTime(baseDate);
            for (int i = 0; i < period; i++) {
                dayCalendar.add(Calendar.DAY_OF_MONTH, 1);
                result.add(sdf.format(dayCalendar.getTime()));
            }
            break;
        }
        default:
            // Unsupported granularity: no labels can be generated.
            break;
    }
    return result;
}
}

View File

@ -1,8 +1,17 @@
package io.dataease.commons.utils;
import groovy.lang.Tuple2;
import org.apache.commons.math4.legacy.stat.regression.SimpleRegression;
import org.apache.commons.statistics.distribution.NormalDistribution;
import org.apache.commons.statistics.distribution.TDistribution;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;
public class MathUtils {
private static final NormalDistribution NORMAL_DISTRIBUTION = NormalDistribution.of(0.0, 1.0);
/**
* 获取百分比
@ -12,8 +21,57 @@ public class MathUtils {
* @return
*/
public static double getPercentWithDecimal(double value) {
// Multiply by 100, then round half-up to one decimal place before converting back to double.
return new BigDecimal(value * 100)
.setScale(1, BigDecimal.ROUND_HALF_UP)
.doubleValue();
return new BigDecimal(value * 100).setScale(1, RoundingMode.HALF_UP).doubleValue();
}
/**
 * Computes the confidence interval of the predicted (mean) values of a regression.
 * This is the confidence interval of the fitted value, not the wider prediction interval
 * of a single new observation; the two use different formulas.
 * References: <a href="https://zhuanlan.zhihu.com/p/366307027">Zhihu</a>,
 * <a href="https://real-statistics.com/regression/confidence-and-prediction-intervals/">real-statistics</a>
 *
 * <p>For every forecast point the half-width is
 * {@code q * s * sqrt(1/n + (x - xAvg)^2 / Sxx)}, where {@code s = sqrt(SSE / degreeOfFreedom)}
 * is the residual standard error and {@code q} is the two-sided quantile of the t distribution
 * (when {@code n <= 30}) or of the standard normal distribution (otherwise).
 * NOTE(review): the leverage term is the simple-linear-regression form; for polynomial fits it is
 * presumably used as an approximation - confirm this is intended.
 *
 * @param data            original observations as {@code [x, y]} rows
 * @param forecastValue   fitted values obtained by feeding each original {@code x} into the
 *                        regression equation; must have at least {@code data.length} entries
 * @param forecastData    forecast points as {@code [x, predictedY]} rows
 * @param alpha           confidence level, e.g. {@code 0.95}
 * @param degreeOfFreedom residual degrees of freedom (observations minus fitted parameters);
 *                        used for both the residual variance and the t distribution
 * @return one {@code [lower, upper]} row per row of {@code forecastData}
 * @throws IllegalArgumentException if {@code degreeOfFreedom} is not positive
 */
public static double[][] getConfidenceInterval(double[][] data, double[] forecastValue, double[][] forecastData, double alpha, int degreeOfFreedom) {
    // Without a positive df the residual variance is undefined and would silently become NaN/Infinity.
    if (degreeOfFreedom <= 0) {
        throw new IllegalArgumentException("degreeOfFreedom must be positive but was " + degreeOfFreedom);
    }
    final int sampleSize = data.length;

    // Sum of squared residuals and sum of x in a single pass.
    double squaredErrorSum = 0;
    double xTotal = 0;
    for (int i = 0; i < sampleSize; i++) {
        xTotal += data[i][0];
        double residual = data[i][1] - forecastValue[i];
        squaredErrorSum += residual * residual;
    }
    final double xAvg = xTotal / sampleSize;
    // Divide by the residual degrees of freedom rather than a fixed n - 2, so that models with
    // more than two fitted parameters (polynomial degree > 1) get an unbiased variance estimate.
    final double residualStdError = Math.sqrt(squaredErrorSum / degreeOfFreedom);

    // Sum of squared deviations of x from its mean (Sxx).
    double xSubPow = 0;
    for (double[] point : data) {
        double deviation = point[0] - xAvg;
        xSubPow += deviation * deviation;
    }

    // Small samples (n <= 30) use the t distribution, larger ones the standard normal distribution.
    final double quantile = 1 - (1 - alpha) / 2;
    final double tzFactor;
    if (sampleSize <= 30) {
        tzFactor = TDistribution.of(degreeOfFreedom).inverseCumulativeProbability(quantile);
    } else {
        tzFactor = NORMAL_DISTRIBUTION.inverseCumulativeProbability(quantile);
    }

    final double[][] result = new double[forecastData.length][2];
    for (int i = 0; i < forecastData.length; i++) {
        double deviation = forecastData[i][0] - xAvg;
        double halfWidth = tzFactor * residualStdError
                * Math.sqrt(1.0 / sampleSize + deviation * deviation / xSubPow);
        result[i][0] = forecastData[i][1] - halfWidth;
        result[i][1] = forecastData[i][1] + halfWidth;
    }
    return result;
}
}

View File

@ -0,0 +1,35 @@
package io.dataease.dto.chart;
import lombok.Data;
/**
 * Forecast settings of a chart view, deserialized from the {@code forecast} entry
 * of the view's senior configuration.
 */
@Data
public class ChartSeniorForecastDTO {
/**
 * Whether forecasting is enabled.
 */
private boolean enable;
/**
 * Number of future periods to forecast.
 */
private int period;
/**
 * Whether all available data is used for the forecast.
 */
private boolean allPeriod;
/**
 * Number of data rows used for the forecast when {@link #allPeriod} is off.
 */
private int trainingPeriod;
/**
 * Confidence level used for the confidence interval of the forecast.
 */
private float confidenceInterval;
/**
 * Key of the algorithm/model used for the forecast.
 */
private String algorithm;
/**
 * Polynomial degree.
 */
private int degree;
}

View File

@ -0,0 +1,14 @@
package io.dataease.dto.chart;
import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.Data;
/**
 * One forecast point: its regression coordinates plus the bounds of its confidence interval.
 */
@Data
public class ForecastDataDTO {
// Regression x coordinate of the point; annotated to be left out of JSON output.
// NOTE(review): Lombok generates getXVal()/getYVal(); confirm the JSON serializer in use
// links those getters to these annotated fields so the values are really omitted.
@JsonIgnore
private double xVal;
// Predicted value at xVal; annotated to be left out of JSON output.
@JsonIgnore
private double yVal;
// Lower bound of the confidence interval.
private double lower;
// Upper bound of the confidence interval.
private double upper;
}

View File

@ -0,0 +1,11 @@
package io.dataease.dto.chart;
import lombok.Data;
import lombok.EqualsAndHashCode;
/**
 * Forecast point enriched with the chart values it belongs to.
 *
 * @param <D> type of the dimension value
 * @param <Q> type of the quota value
 */
@Data
@EqualsAndHashCode(callSuper = true)
public class ForecastDataVO<D, Q> extends ForecastDataDTO {
// Dimension value of the forecast point.
private D dimension;
// Quota value of the forecast point.
private Q quota;
}

View File

@ -11,6 +11,7 @@ import io.dataease.commons.constants.JdbcConstants;
import io.dataease.commons.model.PluginViewSetImpl;
import io.dataease.commons.utils.AuthUtils;
import io.dataease.commons.utils.BeanUtils;
import io.dataease.commons.utils.DateUtils;
import io.dataease.commons.utils.LogUtil;
import io.dataease.controller.request.chart.*;
import io.dataease.controller.response.ChartDetail;
@ -52,6 +53,8 @@ import io.dataease.plugins.view.service.ViewPluginService;
import io.dataease.plugins.xpack.auth.dto.request.ColumnPermissionItem;
import io.dataease.provider.query.SQLUtils;
import io.dataease.service.chart.util.ChartDataBuild;
import io.dataease.service.chart.util.dataForecast.ForecastAlgo;
import io.dataease.service.chart.util.dataForecast.ForecastAlgoManager;
import io.dataease.service.dataset.*;
import io.dataease.service.datasource.DatasourceService;
import io.dataease.service.engine.EngineService;
@ -71,6 +74,7 @@ import javax.annotation.Resource;
import java.lang.reflect.Type;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.concurrent.atomic.AtomicReference;
@ -1090,7 +1094,7 @@ public class ChartViewService {
logger.info("plugin_sql:" + sql);
Map<String, Object> mapTableNormal = ChartDataBuild.transTableNormal(fieldMap, view, data, desensitizationList);
return uniteViewResult(datasourceRequest.getQuery(), mapChart, mapTableNormal, view, isDrill, drillFilters, dynamicAssistFields, assistData);
return uniteViewResult(datasourceRequest.getQuery(), mapChart, mapTableNormal, view, isDrill, drillFilters, dynamicAssistFields, assistData, Collections.emptyList());
// 如果是插件到此结束
}
@ -1361,6 +1365,16 @@ public class ChartViewService {
}
tempYAxis.addAll(yAxis);
// forecast
List<? extends ForecastDataVO<?, ?>> forecastData = Collections.emptyList();
JSONObject senior = JSONObject.parseObject(view.getSenior());
JSONObject forecastObj = senior.getJSONObject("forecast");
if (forecastObj != null) {
ChartSeniorForecastDTO forecastCfg = forecastObj.toJavaObject(ChartSeniorForecastDTO.class);
if (forecastCfg.isEnable()) {
forecastData = forecastData(forecastCfg, data, xAxis, yAxis, view);
}
}
for (int i = 0; i < tempYAxis.size(); i++) {
ChartViewFieldDTO chartViewFieldDTO = tempYAxis.get(i);
ChartFieldCompareDTO compareCalc = chartViewFieldDTO.getCompareCalc();
@ -1622,12 +1636,43 @@ public class ChartViewService {
mapTableNormal = ChartDataBuild.transTableNormal(xAxis, yAxis, view, data, extStack, desensitizationList);
}
chartViewDTO = uniteViewResult(datasourceRequest.getQuery(), mapChart, mapTableNormal, view, isDrill, drillFilters, dynamicAssistFields, assistData);
chartViewDTO = uniteViewResult(datasourceRequest.getQuery(), mapChart, mapTableNormal, view, isDrill, drillFilters, dynamicAssistFields, assistData, forecastData);
chartViewDTO.setTotalPage(totalPage);
chartViewDTO.setTotalItems(totalItems);
return chartViewDTO;
}
/**
 * Builds the forecast points for a chart whose single x-axis field is a time field.
 *
 * <p>The rows are assumed to be sorted in ascending time order: the first column of the last
 * row is taken as the base time, the following period labels are generated from it and the
 * configured algorithm predicts one value per label.
 *
 * @param forecastCfg forecast settings of the view
 * @param data        query result rows
 * @param xAxis       dimension fields; forecasting only happens for exactly one field whose
 *                    {@code deType} is 1
 * @param yAxis       quota fields (currently not read here)
 * @param view        the chart view, handed through to the algorithm
 * @return one entry per forecast period, or an empty list when the chart is not eligible,
 *         there is no data, the algorithm is unknown, or labels and predictions do not line up
 * @throws ParseException if the base time does not match the expected date format
 */
private List<? extends ForecastDataVO<?, ?>> forecastData(ChartSeniorForecastDTO forecastCfg, List<String[]> data, List<ChartViewFieldDTO> xAxis, List<ChartViewFieldDTO> yAxis, ChartViewDTO view) throws ParseException {
    // Nothing to extrapolate from; also avoids data.get(-1) below.
    if (data == null || data.isEmpty()) {
        return List.of();
    }
    // Only a single time dimension is handled for now.
    if (xAxis.size() != 1 || xAxis.get(0).getDeType() != 1) {
        return List.of();
    }

    // Use the most recent trainingPeriod rows unless all rows are requested.
    // subList's end index is exclusive, so it must be data.size() to keep the newest row.
    // A non-positive training size cannot be honoured and falls back to all rows.
    List<String[]> trainingData = data;
    int trainingPeriod = forecastCfg.getTrainingPeriod();
    if (!forecastCfg.isAllPeriod() && trainingPeriod > 0 && data.size() > trainingPeriod) {
        trainingData = data.subList(data.size() - trainingPeriod, data.size());
    }

    // Handle the time type first; rows are assumed to be in ascending order.
    String lastTime = data.get(data.size() - 1)[0];
    ChartViewFieldDTO timeAxis = xAxis.get(0);
    List<String> forecastPeriod = DateUtils.getForecastPeriod(lastTime, forecastCfg.getPeriod(), timeAxis.getDateStyle(), timeAxis.getDatePattern());
    if (forecastPeriod.isEmpty()) {
        return List.of();
    }

    // A missing or unregistered algorithm key means there is nothing to forecast with.
    String algorithm = forecastCfg.getAlgorithm();
    ForecastAlgo algo = algorithm == null ? null : ForecastAlgoManager.getAlgo(algorithm);
    if (algo == null) {
        return List.of();
    }

    List<ForecastDataDTO> forecasts = algo.forecast(forecastCfg, trainingData, view);
    // Labels and predictions are paired by index, so their counts have to match.
    if (forecasts == null || forecastPeriod.size() != forecasts.size()) {
        return List.of();
    }

    List<ForecastDataVO<String, Double>> result = new ArrayList<>(forecasts.size());
    for (int i = 0; i < forecastPeriod.size(); i++) {
        ForecastDataDTO forecastDataItem = forecasts.get(i);
        ForecastDataVO<String, Double> tmp = new ForecastDataVO<>();
        BeanUtils.copyBean(tmp, forecastDataItem);
        tmp.setDimension(forecastPeriod.get(i));
        tmp.setQuota(forecastDataItem.getYVal());
        result.add(tmp);
    }
    return result;
}
// 对结果排序
public List<String[]> resultCustomSort(List<ChartViewFieldDTO> xAxis, List<String[]> data) {
List<String[]> res = new ArrayList<>(data);
@ -1684,11 +1729,12 @@ public class ChartViewService {
return "SELECT " + stringBuilder + " FROM (" + sql + ") tmp";
}
public ChartViewDTO uniteViewResult(String sql, Map<String, Object> chartData, Map<String, Object> tableData, ChartViewDTO view, Boolean isDrill, List<ChartExtFilterRequest> drillFilters, List<ChartSeniorAssistDTO> dynamicAssistFields, List<String[]> assistData) {
public ChartViewDTO uniteViewResult(String sql, Map<String, Object> chartData, Map<String, Object> tableData, ChartViewDTO view, Boolean isDrill, List<ChartExtFilterRequest> drillFilters, List<ChartSeniorAssistDTO> dynamicAssistFields, List<String[]> assistData, List<? extends ForecastDataVO<?, ?>> forecastData) {
Map<String, Object> map = new HashMap<>();
map.putAll(chartData);
map.putAll(tableData);
map.put("forecastData", forecastData);
List<DatasetTableField> sourceFields = dataSetTableFieldsService.getFieldsByTableId(view.getTableId());
map.put("sourceFields", sourceFields);

View File

@ -0,0 +1,21 @@
package io.dataease.service.chart.util.dataForecast;
import io.dataease.dto.chart.ChartSeniorForecastDTO;
import io.dataease.dto.chart.ChartViewDTO;
import io.dataease.dto.chart.ForecastDataDTO;
import lombok.Getter;
import java.util.List;
/**
 * Base class of the forecast algorithms. Every instance registers itself in
 * {@link ForecastAlgoManager} under its algorithm type while it is being constructed.
 */
@Getter
public abstract class ForecastAlgo {
// Algorithm type captured at construction time; exposed through the Lombok-generated getter.
private final String algoType;
public ForecastAlgo() {
// getAlgoType() is the Lombok getter, which subclasses that declare their own @Getter algoType override.
// NOTE(review): this call runs before subclass field initializers; it appears to rely on
// subclasses declaring algoType as a compile-time constant - confirm before changing either side.
this.algoType = this.getAlgoType();
// Publishes this instance to the static registry from inside the constructor.
ForecastAlgoManager.register(this);
}
/**
 * Produces the forecast for the given training rows.
 *
 * @param forecastCfg forecast settings of the view
 * @param data        training rows
 * @param view        the chart view the data belongs to
 * @return the forecast points
 */
public abstract List<ForecastDataDTO> forecast(ChartSeniorForecastDTO forecastCfg, List<String[]> data, ChartViewDTO view);
}

View File

@ -0,0 +1,17 @@
package io.dataease.service.chart.util.dataForecast;
import io.dataease.service.chart.util.dataForecast.impl.LinearRegressionAlgo;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Static registry that maps an algorithm type key to its {@link ForecastAlgo} instance.
 */
public class ForecastAlgoManager {
    /** Registered algorithms keyed by algorithm type. */
    private static final Map<String, ForecastAlgo> FORECAST_ALGO_MAP = new ConcurrentHashMap<>();

    /**
     * Looks up a registered algorithm.
     *
     * @param algoType algorithm type key; may be {@code null}
     * @return the registered algorithm, or {@code null} when the key is {@code null} or unknown
     */
    public static ForecastAlgo getAlgo(String algoType) {
        // ConcurrentHashMap rejects null keys with a NullPointerException, so a missing
        // algorithm setting is reported the same way as an unknown one.
        if (algoType == null) {
            return null;
        }
        return FORECAST_ALGO_MAP.get(algoType);
    }

    /**
     * Registers an algorithm under its own type, replacing any earlier entry with the same key.
     *
     * @param forecastAlgo the algorithm to register
     */
    public static void register(ForecastAlgo forecastAlgo) {
        FORECAST_ALGO_MAP.put(forecastAlgo.getAlgoType(), forecastAlgo);
    }
}

View File

@ -0,0 +1,59 @@
package io.dataease.service.chart.util.dataForecast.impl;
import com.alibaba.fastjson.JSONObject;
import io.dataease.commons.utils.MathUtils;
import io.dataease.dto.chart.ChartSeniorForecastDTO;
import io.dataease.dto.chart.ChartViewDTO;
import io.dataease.dto.chart.ForecastDataDTO;
import io.dataease.plugins.common.dto.chart.ChartViewFieldDTO;
import io.dataease.service.chart.util.dataForecast.ForecastAlgo;
import lombok.Getter;
import org.apache.commons.math4.legacy.stat.regression.SimpleRegression;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
/**
 * Forecast based on simple (single-variable) linear regression.
 * The row index is used as x and the first quota column as y.
 */
@Component
@Getter
public class LinearRegressionAlgo extends ForecastAlgo {
    /** Registry key of this algorithm; exposed through the Lombok-generated getter. */
    private final String algoType = "linear-regression";

    /** A line has two parameters, so at least three rows are needed for a positive residual df. */
    private static final int MIN_SAMPLE_SIZE = 3;

    /**
     * Fits a line through the training rows and extrapolates it.
     *
     * @param forecastCfg forecast settings; period and confidence level are read
     * @param data        training rows; the value at index {@code xAxis.size()} must be numeric text
     * @param view        the chart view; its x-axis definition locates the quota column
     * @return one point per forecast period with its confidence bounds, or an empty list when
     *         the period is not positive or there are fewer than three training rows
     */
    @Override
    public List<ForecastDataDTO> forecast(ChartSeniorForecastDTO forecastCfg, List<String[]> data, ChartViewDTO view) {
        final int period = forecastCfg.getPeriod();
        final int sampleSize = data.size();
        // With fewer than three rows the residual degrees of freedom (n - 2) are not positive and
        // no interval can be computed; a non-positive period leaves nothing to predict.
        if (period <= 0 || sampleSize < MIN_SAMPLE_SIZE) {
            return new ArrayList<>();
        }

        List<ChartViewFieldDTO> xAxis = JSONObject.parseArray(view.getXAxis(), ChartViewFieldDTO.class);
        // The first quota column follows the dimension columns.
        final int quotaIndex = xAxis.size();

        // Treat the rows as evenly spaced: x is simply the row index.
        final double[][] observed = new double[sampleSize][];
        for (int i = 0; i < sampleSize; i++) {
            observed[i] = new double[]{i, Double.parseDouble(data.get(i)[quotaIndex])};
        }

        SimpleRegression regression = new SimpleRegression();
        regression.addData(observed);

        // Extrapolate right after the last training index.
        final double[][] forecastMatrix = new double[period][];
        for (int i = 0; i < period; i++) {
            double xVal = sampleSize + i;
            forecastMatrix[i] = new double[]{xVal, regression.predict(xVal)};
        }

        // Fitted values at the training points, needed for the residuals.
        final double[] fitted = new double[sampleSize];
        for (int i = 0; i < sampleSize; i++) {
            fitted[i] = regression.predict(observed[i][0]);
        }

        double[][] confidenceInterval = MathUtils.getConfidenceInterval(observed, fitted, forecastMatrix, forecastCfg.getConfidenceInterval(), sampleSize - 2);

        final List<ForecastDataDTO> result = new ArrayList<>(period);
        for (int i = 0; i < forecastMatrix.length; i++) {
            ForecastDataDTO tmp = new ForecastDataDTO();
            tmp.setXVal(forecastMatrix[i][0]);
            tmp.setYVal(forecastMatrix[i][1]);
            tmp.setLower(confidenceInterval[i][0]);
            tmp.setUpper(confidenceInterval[i][1]);
            result.add(tmp);
        }
        return result;
    }
}

View File

@ -0,0 +1,71 @@
package io.dataease.service.chart.util.dataForecast.impl;
import com.alibaba.fastjson.JSONObject;
import io.dataease.commons.utils.MathUtils;
import io.dataease.dto.chart.ChartSeniorForecastDTO;
import io.dataease.dto.chart.ChartViewDTO;
import io.dataease.dto.chart.ForecastDataDTO;
import io.dataease.plugins.common.dto.chart.ChartViewFieldDTO;
import io.dataease.service.chart.util.dataForecast.ForecastAlgo;
import lombok.Getter;
import org.apache.commons.math4.legacy.analysis.function.Logistic;
import org.apache.commons.math4.legacy.fitting.PolynomialCurveFitter;
import org.apache.commons.math4.legacy.fitting.WeightedObservedPoints;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
/**
 * Forecast based on polynomial regression of a configurable degree.
 * The row index is used as x and the first quota column as y.
 */
@Component
@Getter
public class PolynomialRegressionAlgo extends ForecastAlgo {
    /** Registry key of this algorithm; exposed through the Lombok-generated getter. */
    private final String algoType = "polynomial-regression";

    /**
     * Fits a polynomial through the training rows and extrapolates it.
     *
     * @param forecastCfg forecast settings; period, degree and confidence level are read
     * @param data        training rows; the value at index {@code xAxis.size()} must be numeric text
     * @param view        the chart view; its x-axis definition locates the quota column
     * @return one point per forecast period with its confidence bounds, or an empty list when
     *         the period is not positive, the degree is negative, or there are too few rows
     *         for the requested degree
     */
    @Override
    public List<ForecastDataDTO> forecast(ChartSeniorForecastDTO forecastCfg, List<String[]> data, ChartViewDTO view) {
        final int period = forecastCfg.getPeriod();
        final int degree = forecastCfg.getDegree();
        final int sampleSize = data.size();
        // Residual degrees of freedom: observations minus the degree + 1 fitted coefficients.
        final int df = sampleSize - degree - 1;
        // Without a positive df no interval can be computed; a non-positive period leaves nothing to predict.
        if (period <= 0 || degree < 0 || df <= 0) {
            return new ArrayList<>();
        }

        List<ChartViewFieldDTO> xAxis = JSONObject.parseArray(view.getXAxis(), ChartViewFieldDTO.class);
        // The first quota column follows the dimension columns.
        final int quotaIndex = xAxis.size();

        // Treat the rows as evenly spaced: x is simply the row index.
        WeightedObservedPoints points = new WeightedObservedPoints();
        final double[][] originData = new double[sampleSize][];
        for (int i = 0; i < sampleSize; i++) {
            double value = Double.parseDouble(data.get(i)[quotaIndex]);
            points.add(i, value);
            originData[i] = new double[]{i, value};
        }

        PolynomialCurveFitter fitter = PolynomialCurveFitter.create(degree);
        // Coefficients are ordered by ascending power: y = 3 + 2x + x^2 gives [3, 2, 1].
        final double[] coefficients = fitter.fit(points.toList());

        // Extrapolate right after the last training index.
        final double[][] forecastMatrix = new double[period][];
        for (int i = 0; i < period; i++) {
            double xVal = sampleSize + i;
            forecastMatrix[i] = new double[]{xVal, getPolynomialValue(xVal, coefficients)};
        }

        // Fitted values at the training points, needed for the residuals.
        final double[] fitted = new double[sampleSize];
        for (int i = 0; i < sampleSize; i++) {
            fitted[i] = getPolynomialValue(originData[i][0], coefficients);
        }

        double[][] confidenceInterval = MathUtils.getConfidenceInterval(originData, fitted, forecastMatrix, forecastCfg.getConfidenceInterval(), df);

        final List<ForecastDataDTO> result = new ArrayList<>(period);
        for (int i = 0; i < forecastMatrix.length; i++) {
            ForecastDataDTO tmp = new ForecastDataDTO();
            tmp.setXVal(forecastMatrix[i][0]);
            tmp.setYVal(forecastMatrix[i][1]);
            tmp.setLower(confidenceInterval[i][0]);
            tmp.setUpper(confidenceInterval[i][1]);
            result.add(tmp);
        }
        return result;
    }

    /**
     * Evaluates a polynomial with Horner's scheme.
     *
     * @param x            point at which to evaluate
     * @param coefficients coefficients ordered by ascending power
     * @return the polynomial value at {@code x}; {@code 0.0} for an empty coefficient array
     */
    private double getPolynomialValue(double x, double[] coefficients) {
        double result = 0.0;
        // Horner's scheme avoids repeated Math.pow calls and accumulates less rounding error.
        for (int i = coefficients.length - 1; i >= 0; i--) {
            result = result * x + coefficients[i];
        }
        return result;
    }
}