forked from github/dataease
feat(视图): 数据预测
This commit is contained in:
parent
6132566ccf
commit
5329734949
@ -1,33 +1,34 @@
|
||||
package io.dataease.commons.utils;
|
||||
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.text.ParseException;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.Calendar;
|
||||
import java.util.Date;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
|
||||
public class DateUtils {
|
||||
public static final String DATE_PATTERM = "yyyy-MM-dd";
|
||||
public static final String DATE_PATTERN = "yyyy-MM-dd";
|
||||
public static final String TIME_PATTERN = "yyyy-MM-dd HH:mm:ss";
|
||||
|
||||
|
||||
public static Date getDate(String dateString) throws Exception {
|
||||
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERM);
|
||||
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERN);
|
||||
return dateFormat.parse(dateString);
|
||||
}
|
||||
|
||||
public static Date getTime(String timeString) throws Exception {
|
||||
SimpleDateFormat dateFormat = new SimpleDateFormat(TIME_PATTERN);
|
||||
return dateFormat.parse(timeString);
|
||||
}
|
||||
|
||||
public static String getDateString(Date date) throws Exception {
|
||||
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERM);
|
||||
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERN);
|
||||
return dateFormat.format(date);
|
||||
}
|
||||
|
||||
public static String getDateString(long timeStamp) throws Exception {
|
||||
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERM);
|
||||
SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERN);
|
||||
return dateFormat.format(timeStamp);
|
||||
}
|
||||
|
||||
@ -47,10 +48,10 @@ public class DateUtils {
|
||||
}
|
||||
|
||||
|
||||
public static Date dateSum (Date date,int countDays){
|
||||
public static Date dateSum(Date date, int countDays) {
|
||||
Calendar calendar = Calendar.getInstance();
|
||||
calendar.setTime(date);
|
||||
calendar.add(Calendar.DAY_OF_MONTH,countDays);
|
||||
calendar.add(Calendar.DAY_OF_MONTH, countDays);
|
||||
|
||||
return calendar.getTime();
|
||||
}
|
||||
@ -70,7 +71,7 @@ public class DateUtils {
|
||||
try {
|
||||
calendar.setTime(date);
|
||||
calendar.set(Calendar.DAY_OF_WEEK, calendar.getActualMinimum(Calendar.DAY_OF_WEEK));
|
||||
calendar.add(Calendar.DAY_OF_MONTH,weekDayAdd);
|
||||
calendar.add(Calendar.DAY_OF_MONTH, weekDayAdd);
|
||||
|
||||
//第一天的时分秒是 00:00:00 这里直接取日期,默认就是零点零分
|
||||
Date thisWeekFirstTime = getDate(getDateString(calendar.getTime()));
|
||||
@ -78,12 +79,12 @@ public class DateUtils {
|
||||
calendar.clear();
|
||||
calendar.setTime(date);
|
||||
calendar.set(Calendar.DAY_OF_WEEK, calendar.getActualMaximum(Calendar.DAY_OF_WEEK));
|
||||
calendar.add(Calendar.DAY_OF_MONTH,weekDayAdd);
|
||||
calendar.add(Calendar.DAY_OF_MONTH, weekDayAdd);
|
||||
|
||||
//最后一天的时分秒应当是23:59:59。 处理方式是增加一天计算日期再-1
|
||||
calendar.add(Calendar.DAY_OF_MONTH,1);
|
||||
calendar.add(Calendar.DAY_OF_MONTH, 1);
|
||||
Date nextWeekFirstDay = getDate(getDateString(calendar.getTime()));
|
||||
Date thisWeekLastTime = getTime(getTimeString(nextWeekFirstDay.getTime()-1));
|
||||
Date thisWeekLastTime = getTime(getTimeString(nextWeekFirstDay.getTime() - 1));
|
||||
|
||||
returnMap.put("firstTime", thisWeekFirstTime);
|
||||
returnMap.put("lastTime", thisWeekLastTime);
|
||||
@ -95,14 +96,91 @@ public class DateUtils {
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 获取当天的起始时间Date
|
||||
* @param time 指定日期 例: 2020-12-13 06:12:42
|
||||
* @return 当天起始时间 例: 2020-12-13 00:00:00
|
||||
*
|
||||
* @param time 指定日期 例: 2020-12-13 06:12:42
|
||||
* @return 当天起始时间 例: 2020-12-13 00:00:00
|
||||
* @throws Exception
|
||||
*/
|
||||
public static Date getDayStartTime(Date time) throws Exception {
|
||||
return getDate(getDateString(time));
|
||||
}
|
||||
|
||||
public static List<String> getForecastPeriod(String baseTime, int period, String dateStyle, String pattern) throws ParseException {
|
||||
String split = "-";
|
||||
if (StringUtils.equalsIgnoreCase(pattern, "date_split")) {
|
||||
split = "/";
|
||||
}
|
||||
|
||||
List<String> result = new ArrayList<>(period);
|
||||
switch (dateStyle) {
|
||||
case "y":
|
||||
int baseYear = Integer.parseInt(baseTime);
|
||||
for (int i = 1; i <= period; i++) {
|
||||
result.add(baseYear + i + "");
|
||||
}
|
||||
break;
|
||||
case "y_Q":
|
||||
String[] yQ = baseTime.split(split);
|
||||
int year = Integer.parseInt(yQ[0]);
|
||||
int quarter = Integer.parseInt(yQ[1].split("Q")[1]);
|
||||
for (int i = 0; i < period; i++) {
|
||||
quarter = quarter % 4 + 1;
|
||||
if (quarter == 1) {
|
||||
year += 1;
|
||||
}
|
||||
result.add(year + split + "Q" + quarter);
|
||||
}
|
||||
break;
|
||||
case "y_M":
|
||||
String[] yM = baseTime.split(split);
|
||||
int y = Integer.parseInt(yM[0]);
|
||||
int month = Integer.parseInt(yM[1]);
|
||||
for (int i = 0; i < period; i++) {
|
||||
month = month % 12 + 1;
|
||||
if (month == 1) {
|
||||
y += 1;
|
||||
}
|
||||
String padMonth = month < 10 ? "0" + month : "" + month;
|
||||
result.add(y + split + padMonth);
|
||||
}
|
||||
break;
|
||||
case "y_W":
|
||||
String[] yW = baseTime.split(split);
|
||||
int yy = Integer.parseInt(yW[0]);
|
||||
int w = Integer.parseInt(yW[1].split("W")[1]);
|
||||
for (int i = 0; i < period; i++) {
|
||||
Calendar calendar = Calendar.getInstance();
|
||||
calendar.setMinimalDaysInFirstWeek(7);
|
||||
calendar.setFirstDayOfWeek(Calendar.MONDAY);
|
||||
calendar.set(Calendar.YEAR, yy);
|
||||
calendar.set(Calendar.MONTH, Calendar.DECEMBER);
|
||||
calendar.set(Calendar.DAY_OF_MONTH, 31);
|
||||
int lastWeek = calendar.get(Calendar.WEEK_OF_YEAR);
|
||||
w += 1;
|
||||
if (w > lastWeek) {
|
||||
yy += 1;
|
||||
w = 1;
|
||||
}
|
||||
result.add(yy + split + "W" + w);
|
||||
}
|
||||
break;
|
||||
case "y_M_d":
|
||||
SimpleDateFormat sdf = new SimpleDateFormat("yyyy" + split + "MM" + split + "dd");
|
||||
Calendar calendar = Calendar.getInstance();
|
||||
Date baseDate = sdf.parse(baseTime);
|
||||
calendar.setTime(baseDate);
|
||||
for (int i = 0; i < period; i++) {
|
||||
calendar.add(Calendar.DAY_OF_MONTH, 1);
|
||||
Date curDate = calendar.getTime();
|
||||
String date = sdf.format(curDate);
|
||||
result.add(date);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
@ -1,8 +1,17 @@
|
||||
package io.dataease.commons.utils;
|
||||
|
||||
import groovy.lang.Tuple2;
|
||||
import org.apache.commons.math4.legacy.stat.regression.SimpleRegression;
|
||||
import org.apache.commons.statistics.distribution.NormalDistribution;
|
||||
import org.apache.commons.statistics.distribution.TDistribution;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
import java.math.RoundingMode;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class MathUtils {
|
||||
private static final NormalDistribution NORMAL_DISTRIBUTION = NormalDistribution.of(0.0, 1.0);
|
||||
|
||||
/**
|
||||
* 获取百分比
|
||||
@ -12,8 +21,57 @@ public class MathUtils {
|
||||
* @return
|
||||
*/
|
||||
public static double getPercentWithDecimal(double value) {
|
||||
return new BigDecimal(value * 100)
|
||||
.setScale(1, BigDecimal.ROUND_HALF_UP)
|
||||
.doubleValue();
|
||||
return new BigDecimal(value * 100).setScale(1, RoundingMode.HALF_UP).doubleValue();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取预测数据的置信区间,这边计算的是预测值的置信区间,还有一个是预测值的预测区间,公式不一样,注意区分.
|
||||
* 参考资料 <a href="https://zhuanlan.zhihu.com/p/366307027">知乎</a>,
|
||||
* <a href="https://real-statistics.com/regression/confidence-and-prediction-intervals/">real-statistics</a>
|
||||
* @param data 原始数据
|
||||
* @param forecastValue 预测得到的数据
|
||||
* @param forecastData 将原数据 x 代入回归方程得到的拟合数据
|
||||
* @param alpha 置信水平
|
||||
* @param degreeOfFreedom 自由度,t分布使用
|
||||
* @return 预测值的置信区间数组
|
||||
*/
|
||||
public static double[][] getConfidenceInterval(double[][] data, double[] forecastValue, double[][] forecastData, double alpha, int degreeOfFreedom) {
|
||||
// y 平均方差
|
||||
double totalPow = 0;
|
||||
double xTotal = 0;
|
||||
for (int i = 0; i < data.length; i++) {
|
||||
double xVal = data[i][0];
|
||||
xTotal += xVal;
|
||||
double realVal = data[i][1];
|
||||
double predictVal = forecastValue[i];
|
||||
totalPow += Math.pow((realVal - predictVal), 2);
|
||||
}
|
||||
double xAvg = xTotal / data.length;
|
||||
double yMseSqrt = Math.sqrt(totalPow / (forecastValue.length - 2));
|
||||
// x 均值方差
|
||||
double xSubPow = 0;
|
||||
for (int i = 0; i < data.length; i++) {
|
||||
double xVal = data[i][0];
|
||||
xSubPow += Math.pow(xVal - xAvg, 2);
|
||||
}
|
||||
// t/z 值, 样本数 < 30 选 t 分布, > 30 选 z 分布,
|
||||
double tzFactor;
|
||||
if (data.length <= 30) {
|
||||
tzFactor = TDistribution.of(degreeOfFreedom).inverseCumulativeProbability(1 - (1 - alpha) / 2);
|
||||
} else {
|
||||
tzFactor = NORMAL_DISTRIBUTION.inverseCumulativeProbability(1 - (1 - alpha) / 2);
|
||||
}
|
||||
double[][] result = new double[forecastData.length][2];
|
||||
for (int i = 0; i < forecastData.length; i++) {
|
||||
double xVal = forecastData[i][0];
|
||||
double curSubPow = Math.pow(xVal - xAvg, 2);
|
||||
double sqrt = Math.sqrt(1.0 / data.length + curSubPow / xSubPow);
|
||||
double lower = forecastData[i][1] - tzFactor * yMseSqrt * sqrt;
|
||||
double upper = forecastData[i][1] + tzFactor * yMseSqrt * sqrt;
|
||||
result[i][0] = lower;
|
||||
result[i][1] = upper;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,35 @@
|
||||
package io.dataease.dto.chart;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChartSeniorForecastDTO {
|
||||
/**
|
||||
* 是否开启预测
|
||||
*/
|
||||
private boolean enable;
|
||||
/**
|
||||
* 预测周期
|
||||
*/
|
||||
private int period;
|
||||
/**
|
||||
* 是否使用所有数据进行预测
|
||||
*/
|
||||
private boolean allPeriod;
|
||||
/**
|
||||
* 用于预测的数据量
|
||||
*/
|
||||
private int trainingPeriod;
|
||||
/**
|
||||
* 置信区间
|
||||
*/
|
||||
private float confidenceInterval;
|
||||
/**
|
||||
* 预测用的算法/模型
|
||||
*/
|
||||
private String algorithm;
|
||||
/**
|
||||
* 多项式阶数
|
||||
*/
|
||||
private int degree;
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
package io.dataease.dto.chart;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ForecastDataDTO {
|
||||
@JsonIgnore
|
||||
private double xVal;
|
||||
@JsonIgnore
|
||||
private double yVal;
|
||||
private double lower;
|
||||
private double upper;
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
package io.dataease.dto.chart;
|
||||
|
||||
import lombok.Data;
|
||||
import lombok.EqualsAndHashCode;
|
||||
|
||||
@Data
|
||||
@EqualsAndHashCode(callSuper = true)
|
||||
public class ForecastDataVO<D, Q> extends ForecastDataDTO {
|
||||
private D dimension;
|
||||
private Q quota;
|
||||
}
|
@ -11,6 +11,7 @@ import io.dataease.commons.constants.JdbcConstants;
|
||||
import io.dataease.commons.model.PluginViewSetImpl;
|
||||
import io.dataease.commons.utils.AuthUtils;
|
||||
import io.dataease.commons.utils.BeanUtils;
|
||||
import io.dataease.commons.utils.DateUtils;
|
||||
import io.dataease.commons.utils.LogUtil;
|
||||
import io.dataease.controller.request.chart.*;
|
||||
import io.dataease.controller.response.ChartDetail;
|
||||
@ -52,6 +53,8 @@ import io.dataease.plugins.view.service.ViewPluginService;
|
||||
import io.dataease.plugins.xpack.auth.dto.request.ColumnPermissionItem;
|
||||
import io.dataease.provider.query.SQLUtils;
|
||||
import io.dataease.service.chart.util.ChartDataBuild;
|
||||
import io.dataease.service.chart.util.dataForecast.ForecastAlgo;
|
||||
import io.dataease.service.chart.util.dataForecast.ForecastAlgoManager;
|
||||
import io.dataease.service.dataset.*;
|
||||
import io.dataease.service.datasource.DatasourceService;
|
||||
import io.dataease.service.engine.EngineService;
|
||||
@ -71,6 +74,7 @@ import javax.annotation.Resource;
|
||||
import java.lang.reflect.Type;
|
||||
import java.math.BigDecimal;
|
||||
import java.math.RoundingMode;
|
||||
import java.text.ParseException;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
@ -1090,7 +1094,7 @@ public class ChartViewService {
|
||||
logger.info("plugin_sql:" + sql);
|
||||
Map<String, Object> mapTableNormal = ChartDataBuild.transTableNormal(fieldMap, view, data, desensitizationList);
|
||||
|
||||
return uniteViewResult(datasourceRequest.getQuery(), mapChart, mapTableNormal, view, isDrill, drillFilters, dynamicAssistFields, assistData);
|
||||
return uniteViewResult(datasourceRequest.getQuery(), mapChart, mapTableNormal, view, isDrill, drillFilters, dynamicAssistFields, assistData, Collections.emptyList());
|
||||
// 如果是插件到此结束
|
||||
}
|
||||
|
||||
@ -1361,6 +1365,16 @@ public class ChartViewService {
|
||||
}
|
||||
tempYAxis.addAll(yAxis);
|
||||
|
||||
// forecast
|
||||
List<? extends ForecastDataVO<?, ?>> forecastData = Collections.emptyList();
|
||||
JSONObject senior = JSONObject.parseObject(view.getSenior());
|
||||
JSONObject forecastObj = senior.getJSONObject("forecast");
|
||||
if (forecastObj != null) {
|
||||
ChartSeniorForecastDTO forecastCfg = forecastObj.toJavaObject(ChartSeniorForecastDTO.class);
|
||||
if (forecastCfg.isEnable()) {
|
||||
forecastData = forecastData(forecastCfg, data, xAxis, yAxis, view);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < tempYAxis.size(); i++) {
|
||||
ChartViewFieldDTO chartViewFieldDTO = tempYAxis.get(i);
|
||||
ChartFieldCompareDTO compareCalc = chartViewFieldDTO.getCompareCalc();
|
||||
@ -1622,12 +1636,43 @@ public class ChartViewService {
|
||||
|
||||
mapTableNormal = ChartDataBuild.transTableNormal(xAxis, yAxis, view, data, extStack, desensitizationList);
|
||||
}
|
||||
chartViewDTO = uniteViewResult(datasourceRequest.getQuery(), mapChart, mapTableNormal, view, isDrill, drillFilters, dynamicAssistFields, assistData);
|
||||
chartViewDTO = uniteViewResult(datasourceRequest.getQuery(), mapChart, mapTableNormal, view, isDrill, drillFilters, dynamicAssistFields, assistData, forecastData);
|
||||
chartViewDTO.setTotalPage(totalPage);
|
||||
chartViewDTO.setTotalItems(totalItems);
|
||||
return chartViewDTO;
|
||||
}
|
||||
|
||||
private List<? extends ForecastDataVO<?,?>> forecastData(ChartSeniorForecastDTO forecastCfg, List<String[]> data, List<ChartViewFieldDTO> xAxis, List<ChartViewFieldDTO> yAxis, ChartViewDTO view) throws ParseException {
|
||||
List<String[]> trainingData = data;
|
||||
if (!forecastCfg.isAllPeriod() && data.size() > forecastCfg.getTrainingPeriod()) {
|
||||
trainingData = data.subList(data.size() - forecastCfg.getTrainingPeriod(), data.size() - 1);
|
||||
}
|
||||
if (xAxis.size() == 1 && xAxis.get(0).getDeType() == 1) {
|
||||
// 先处理时间类型, 默认数据是有序递增的
|
||||
String lastTime = data.get(data.size() - 1)[0];
|
||||
ChartViewFieldDTO timeAxis = xAxis.get(0);
|
||||
List<String> forecastPeriod = DateUtils.getForecastPeriod(lastTime, forecastCfg.getPeriod(), timeAxis.getDateStyle(), timeAxis.getDatePattern());
|
||||
if(!forecastPeriod.isEmpty()){
|
||||
ForecastAlgo algo = ForecastAlgoManager.getAlgo(forecastCfg.getAlgorithm());
|
||||
List<ForecastDataDTO> forecastData = algo.forecast(forecastCfg, trainingData, view);
|
||||
if (forecastPeriod.size() == forecastData.size()) {
|
||||
List<ForecastDataVO<String, Double>> result = new ArrayList<>();
|
||||
for (int i = 0; i < forecastPeriod.size(); i++) {
|
||||
String period = forecastPeriod.get(i);
|
||||
ForecastDataDTO forecastDataItem = forecastData.get(i);
|
||||
ForecastDataVO<String, Double> tmp = new ForecastDataVO<>();
|
||||
BeanUtils.copyBean(tmp, forecastDataItem);
|
||||
tmp.setDimension(period);
|
||||
tmp.setQuota(forecastDataItem.getYVal());
|
||||
result.add(tmp);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
return List.of();
|
||||
}
|
||||
|
||||
// 对结果排序
|
||||
public List<String[]> resultCustomSort(List<ChartViewFieldDTO> xAxis, List<String[]> data) {
|
||||
List<String[]> res = new ArrayList<>(data);
|
||||
@ -1684,11 +1729,12 @@ public class ChartViewService {
|
||||
return "SELECT " + stringBuilder + " FROM (" + sql + ") tmp";
|
||||
}
|
||||
|
||||
public ChartViewDTO uniteViewResult(String sql, Map<String, Object> chartData, Map<String, Object> tableData, ChartViewDTO view, Boolean isDrill, List<ChartExtFilterRequest> drillFilters, List<ChartSeniorAssistDTO> dynamicAssistFields, List<String[]> assistData) {
|
||||
public ChartViewDTO uniteViewResult(String sql, Map<String, Object> chartData, Map<String, Object> tableData, ChartViewDTO view, Boolean isDrill, List<ChartExtFilterRequest> drillFilters, List<ChartSeniorAssistDTO> dynamicAssistFields, List<String[]> assistData, List<? extends ForecastDataVO<?, ?>> forecastData) {
|
||||
|
||||
Map<String, Object> map = new HashMap<>();
|
||||
map.putAll(chartData);
|
||||
map.putAll(tableData);
|
||||
map.put("forecastData", forecastData);
|
||||
|
||||
List<DatasetTableField> sourceFields = dataSetTableFieldsService.getFieldsByTableId(view.getTableId());
|
||||
map.put("sourceFields", sourceFields);
|
||||
|
@ -0,0 +1,21 @@
|
||||
package io.dataease.service.chart.util.dataForecast;
|
||||
|
||||
import io.dataease.dto.chart.ChartSeniorForecastDTO;
|
||||
import io.dataease.dto.chart.ChartViewDTO;
|
||||
import io.dataease.dto.chart.ForecastDataDTO;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Getter
|
||||
public abstract class ForecastAlgo {
|
||||
private final String algoType;
|
||||
|
||||
public ForecastAlgo() {
|
||||
this.algoType = this.getAlgoType();
|
||||
ForecastAlgoManager.register(this);
|
||||
}
|
||||
|
||||
public abstract List<ForecastDataDTO> forecast(ChartSeniorForecastDTO forecastCfg, List<String[]> data, ChartViewDTO view);
|
||||
|
||||
}
|
@ -0,0 +1,17 @@
|
||||
package io.dataease.service.chart.util.dataForecast;
|
||||
|
||||
import io.dataease.service.chart.util.dataForecast.impl.LinearRegressionAlgo;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class ForecastAlgoManager {
|
||||
private static final Map<String, ForecastAlgo> FORECAST_ALGO_MAP = new ConcurrentHashMap<>();
|
||||
public static ForecastAlgo getAlgo(String algoType) {
|
||||
return FORECAST_ALGO_MAP.get(algoType);
|
||||
}
|
||||
|
||||
public static void register(ForecastAlgo forecastAlgo) {
|
||||
FORECAST_ALGO_MAP.put(forecastAlgo.getAlgoType(), forecastAlgo);
|
||||
}
|
||||
}
|
@ -0,0 +1,59 @@
|
||||
package io.dataease.service.chart.util.dataForecast.impl;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import io.dataease.commons.utils.MathUtils;
|
||||
import io.dataease.dto.chart.ChartSeniorForecastDTO;
|
||||
import io.dataease.dto.chart.ChartViewDTO;
|
||||
import io.dataease.dto.chart.ForecastDataDTO;
|
||||
import io.dataease.plugins.common.dto.chart.ChartViewFieldDTO;
|
||||
import io.dataease.service.chart.util.dataForecast.ForecastAlgo;
|
||||
import lombok.Getter;
|
||||
import org.apache.commons.math4.legacy.stat.regression.SimpleRegression;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Component
|
||||
@Getter
|
||||
public class LinearRegressionAlgo extends ForecastAlgo {
|
||||
|
||||
private final String algoType = "linear-regression";
|
||||
|
||||
@Override
|
||||
public List<ForecastDataDTO> forecast(ChartSeniorForecastDTO forecastCfg, List<String[]> data, ChartViewDTO view) {
|
||||
List<ChartViewFieldDTO> xAxis = JSONObject.parseArray(view.getXAxis(), ChartViewFieldDTO.class);
|
||||
final List<double[]> forecastData = new ArrayList<>(data.size());
|
||||
// 先按连续的数据处理
|
||||
for (int i = 0; i < data.size(); i++) {
|
||||
String val = data.get(i)[xAxis.size()];
|
||||
double value = Double.parseDouble(val);
|
||||
forecastData.add(new double[]{i, value});
|
||||
}
|
||||
double[][] matrix = forecastData.toArray(double[][]::new);
|
||||
SimpleRegression regression = new SimpleRegression();
|
||||
regression.addData(matrix);
|
||||
double[][] forecastMatrix = new double[forecastCfg.getPeriod()][2];
|
||||
for (int i = 0; i < forecastCfg.getPeriod(); i++) {
|
||||
double xVal = data.size() + i;
|
||||
double predictVal = regression.predict(xVal);
|
||||
forecastMatrix[i] = new double[]{xVal, predictVal};
|
||||
}
|
||||
final double[] forecastValue = new double[forecastData.size()];
|
||||
for (int i = 0; i < forecastData.size(); i++) {
|
||||
double xVal = forecastData.get(i)[0];
|
||||
forecastValue[i] = regression.predict(xVal);
|
||||
}
|
||||
double[][] confidenceInterval = MathUtils.getConfidenceInterval(forecastData.toArray(new double[0][0]), forecastValue, forecastMatrix, forecastCfg.getConfidenceInterval(), forecastData.size() - 2);
|
||||
final List<ForecastDataDTO> result = new ArrayList<>(forecastCfg.getPeriod());
|
||||
for (int i = 0; i < forecastMatrix.length; i++) {
|
||||
ForecastDataDTO tmp = new ForecastDataDTO();
|
||||
tmp.setXVal(forecastMatrix[i][0]);
|
||||
tmp.setYVal(forecastMatrix[i][1]);
|
||||
tmp.setLower(confidenceInterval[i][0]);
|
||||
tmp.setUpper(confidenceInterval[i][1]);
|
||||
result.add(tmp);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
@ -0,0 +1,71 @@
|
||||
package io.dataease.service.chart.util.dataForecast.impl;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import io.dataease.commons.utils.MathUtils;
|
||||
import io.dataease.dto.chart.ChartSeniorForecastDTO;
|
||||
import io.dataease.dto.chart.ChartViewDTO;
|
||||
import io.dataease.dto.chart.ForecastDataDTO;
|
||||
import io.dataease.plugins.common.dto.chart.ChartViewFieldDTO;
|
||||
import io.dataease.service.chart.util.dataForecast.ForecastAlgo;
|
||||
import lombok.Getter;
|
||||
import org.apache.commons.math4.legacy.analysis.function.Logistic;
|
||||
import org.apache.commons.math4.legacy.fitting.PolynomialCurveFitter;
|
||||
import org.apache.commons.math4.legacy.fitting.WeightedObservedPoints;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Component
|
||||
@Getter
|
||||
public class PolynomialRegressionAlgo extends ForecastAlgo {
|
||||
private final String algoType = "polynomial-regression";
|
||||
|
||||
@Override
|
||||
public List<ForecastDataDTO> forecast(ChartSeniorForecastDTO forecastCfg, List<String[]> data, ChartViewDTO view) {
|
||||
List<ChartViewFieldDTO> xAxis = JSONObject.parseArray(view.getXAxis(), ChartViewFieldDTO.class);
|
||||
WeightedObservedPoints points = new WeightedObservedPoints();
|
||||
double[][] originData = new double[data.size()][2];
|
||||
// 先按连续的数据处理
|
||||
for (int i = 0; i < data.size(); i++) {
|
||||
String val = data.get(i)[xAxis.size()];
|
||||
double value = Double.parseDouble(val);
|
||||
points.add(i, value);
|
||||
originData[i] = new double[]{i, value};
|
||||
}
|
||||
PolynomialCurveFitter filter = PolynomialCurveFitter.create(forecastCfg.getDegree());
|
||||
// 返回的是多次项系数, y = 3 + 2x + x*2 则为 [3,2,1]
|
||||
double[] coefficients = filter.fit(points.toList());
|
||||
double[][] forecastMatrix = new double[forecastCfg.getPeriod()][2];
|
||||
for (int i = 0; i < forecastCfg.getPeriod(); i++) {
|
||||
double xVal = data.size() + i;
|
||||
double predictVal = getPolynomialValue(xVal, coefficients);
|
||||
forecastMatrix[i] = new double[]{xVal, predictVal};
|
||||
}
|
||||
final double[] forecastValue = new double[data.size()];
|
||||
for (int i = 0; i < data.size(); i++) {
|
||||
double xVal = originData[i][0];
|
||||
forecastValue[i] = getPolynomialValue(xVal, coefficients);
|
||||
}
|
||||
int df = data.size() - forecastCfg.getDegree() - 1;
|
||||
double[][] confidenceInterval = MathUtils.getConfidenceInterval(originData, forecastValue, forecastMatrix, forecastCfg.getConfidenceInterval(), df);
|
||||
final List<ForecastDataDTO> result = new ArrayList<>(forecastCfg.getPeriod());
|
||||
for (int i = 0; i < forecastMatrix.length; i++) {
|
||||
ForecastDataDTO tmp = new ForecastDataDTO();
|
||||
tmp.setXVal(forecastMatrix[i][0]);
|
||||
tmp.setYVal(forecastMatrix[i][1]);
|
||||
tmp.setLower(confidenceInterval[i][0]);
|
||||
tmp.setUpper(confidenceInterval[i][1]);
|
||||
result.add(tmp);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private double getPolynomialValue(double x, double[] coefficients) {
|
||||
double result = 0.0;
|
||||
for (int i = 0; i < coefficients.length; i++) {
|
||||
result += coefficients[i] * Math.pow(x, i);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user