From 5329734949ec1ca6f16dd5ca90e3148504d471a1 Mon Sep 17 00:00:00 2001 From: wisonic-s Date: Fri, 17 May 2024 17:21:47 +0800 Subject: [PATCH] =?UTF-8?q?feat(=E8=A7=86=E5=9B=BE):=20=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E9=A2=84=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../io/dataease/commons/utils/DateUtils.java | 112 +++++++++++++++--- .../io/dataease/commons/utils/MathUtils.java | 64 +++++++++- .../dto/chart/ChartSeniorForecastDTO.java | 35 ++++++ .../dataease/dto/chart/ForecastDataDTO.java | 14 +++ .../io/dataease/dto/chart/ForecastDataVO.java | 11 ++ .../service/chart/ChartViewService.java | 52 +++++++- .../chart/util/dataForecast/ForecastAlgo.java | 21 ++++ .../dataForecast/ForecastAlgoManager.java | 17 +++ .../impl/LinearRegressionAlgo.java | 59 +++++++++ .../impl/PolynomialRegressionAlgo.java | 71 +++++++++++ 10 files changed, 433 insertions(+), 23 deletions(-) create mode 100644 core/backend/src/main/java/io/dataease/dto/chart/ChartSeniorForecastDTO.java create mode 100644 core/backend/src/main/java/io/dataease/dto/chart/ForecastDataDTO.java create mode 100644 core/backend/src/main/java/io/dataease/dto/chart/ForecastDataVO.java create mode 100644 core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/ForecastAlgo.java create mode 100644 core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/ForecastAlgoManager.java create mode 100644 core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/impl/LinearRegressionAlgo.java create mode 100644 core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/impl/PolynomialRegressionAlgo.java diff --git a/core/backend/src/main/java/io/dataease/commons/utils/DateUtils.java b/core/backend/src/main/java/io/dataease/commons/utils/DateUtils.java index d95b7a16c1..57da621224 100644 --- a/core/backend/src/main/java/io/dataease/commons/utils/DateUtils.java +++ b/core/backend/src/main/java/io/dataease/commons/utils/DateUtils.java @@ -1,33 +1,34 @@ package io.dataease.commons.utils; +import org.apache.commons.lang3.StringUtils; + +import java.text.ParseException; import java.text.SimpleDateFormat; -import java.util.Calendar; -import java.util.Date; -import java.util.HashMap; -import java.util.Map; +import java.util.*; public class DateUtils { - public static final String DATE_PATTERM = "yyyy-MM-dd"; + public static final String DATE_PATTERN = "yyyy-MM-dd"; public static final String TIME_PATTERN = "yyyy-MM-dd HH:mm:ss"; public static Date getDate(String dateString) throws Exception { - SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERM); + SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERN); return dateFormat.parse(dateString); } + public static Date getTime(String timeString) throws Exception { SimpleDateFormat dateFormat = new SimpleDateFormat(TIME_PATTERN); return dateFormat.parse(timeString); } public static String getDateString(Date date) throws Exception { - SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERM); + SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERN); return dateFormat.format(date); } public static String getDateString(long timeStamp) throws Exception { - SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERM); + SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_PATTERN); return dateFormat.format(timeStamp); } @@ -47,10 +48,10 @@ public class DateUtils { } - public static Date dateSum (Date date,int countDays){ + public static Date dateSum(Date date, int countDays) { Calendar calendar = Calendar.getInstance(); calendar.setTime(date); - calendar.add(Calendar.DAY_OF_MONTH,countDays); + calendar.add(Calendar.DAY_OF_MONTH, countDays); return calendar.getTime(); } @@ -70,7 +71,7 @@ public class DateUtils { try { calendar.setTime(date); calendar.set(Calendar.DAY_OF_WEEK, calendar.getActualMinimum(Calendar.DAY_OF_WEEK)); - calendar.add(Calendar.DAY_OF_MONTH,weekDayAdd); + calendar.add(Calendar.DAY_OF_MONTH, weekDayAdd); //第一天的时分秒是 00:00:00 这里直接取日期,默认就是零点零分 Date thisWeekFirstTime = getDate(getDateString(calendar.getTime())); @@ -78,12 +79,12 @@ public class DateUtils { calendar.clear(); calendar.setTime(date); calendar.set(Calendar.DAY_OF_WEEK, calendar.getActualMaximum(Calendar.DAY_OF_WEEK)); - calendar.add(Calendar.DAY_OF_MONTH,weekDayAdd); + calendar.add(Calendar.DAY_OF_MONTH, weekDayAdd); //最后一天的时分秒应当是23:59:59。 处理方式是增加一天计算日期再-1 - calendar.add(Calendar.DAY_OF_MONTH,1); + calendar.add(Calendar.DAY_OF_MONTH, 1); Date nextWeekFirstDay = getDate(getDateString(calendar.getTime())); - Date thisWeekLastTime = getTime(getTimeString(nextWeekFirstDay.getTime()-1)); + Date thisWeekLastTime = getTime(getTimeString(nextWeekFirstDay.getTime() - 1)); returnMap.put("firstTime", thisWeekFirstTime); returnMap.put("lastTime", thisWeekLastTime); @@ -95,14 +96,91 @@ public class DateUtils { } - /** * 获取当天的起始时间Date - * @param time 指定日期 例: 2020-12-13 06:12:42 - * @return 当天起始时间 例: 2020-12-13 00:00:00 + * + * @param time 指定日期 例: 2020-12-13 06:12:42 + * @return 当天起始时间 例: 2020-12-13 00:00:00 * @throws Exception */ public static Date getDayStartTime(Date time) throws Exception { return getDate(getDateString(time)); } + + public static List getForecastPeriod(String baseTime, int period, String dateStyle, String pattern) throws ParseException { + String split = "-"; + if (StringUtils.equalsIgnoreCase(pattern, "date_split")) { + split = "/"; + } + + List result = new ArrayList<>(period); + switch (dateStyle) { + case "y": + int baseYear = Integer.parseInt(baseTime); + for (int i = 1; i <= period; i++) { + result.add(baseYear + i + ""); + } + break; + case "y_Q": + String[] yQ = baseTime.split(split); + int year = Integer.parseInt(yQ[0]); + int quarter = Integer.parseInt(yQ[1].split("Q")[1]); + for (int i = 0; i < period; i++) { + quarter = quarter % 4 + 1; + if (quarter == 1) { + year += 1; + } + result.add(year + split + "Q" + quarter); + } + break; + case "y_M": + String[] yM = baseTime.split(split); + int y = Integer.parseInt(yM[0]); + int month = Integer.parseInt(yM[1]); + for (int i = 0; i < period; i++) { + month = month % 12 + 1; + if (month == 1) { + y += 1; + } + String padMonth = month < 10 ? "0" + month : "" + month; + result.add(y + split + padMonth); + } + break; + case "y_W": + String[] yW = baseTime.split(split); + int yy = Integer.parseInt(yW[0]); + int w = Integer.parseInt(yW[1].split("W")[1]); + for (int i = 0; i < period; i++) { + Calendar calendar = Calendar.getInstance(); + calendar.setMinimalDaysInFirstWeek(7); + calendar.setFirstDayOfWeek(Calendar.MONDAY); + calendar.set(Calendar.YEAR, yy); + calendar.set(Calendar.MONTH, Calendar.DECEMBER); + calendar.set(Calendar.DAY_OF_MONTH, 31); + int lastWeek = calendar.get(Calendar.WEEK_OF_YEAR); + w += 1; + if (w > lastWeek) { + yy += 1; + w = 1; + } + result.add(yy + split + "W" + w); + } + break; + case "y_M_d": + SimpleDateFormat sdf = new SimpleDateFormat("yyyy" + split + "MM" + split + "dd"); + Calendar calendar = Calendar.getInstance(); + Date baseDate = sdf.parse(baseTime); + calendar.setTime(baseDate); + for (int i = 0; i < period; i++) { + calendar.add(Calendar.DAY_OF_MONTH, 1); + Date curDate = calendar.getTime(); + String date = sdf.format(curDate); + result.add(date); + } + break; + default: + break; + } + return result; + } } diff --git a/core/backend/src/main/java/io/dataease/commons/utils/MathUtils.java b/core/backend/src/main/java/io/dataease/commons/utils/MathUtils.java index a48a203155..39f8c74b0c 100644 --- a/core/backend/src/main/java/io/dataease/commons/utils/MathUtils.java +++ b/core/backend/src/main/java/io/dataease/commons/utils/MathUtils.java @@ -1,8 +1,17 @@ package io.dataease.commons.utils; +import groovy.lang.Tuple2; +import org.apache.commons.math4.legacy.stat.regression.SimpleRegression; +import org.apache.commons.statistics.distribution.NormalDistribution; +import org.apache.commons.statistics.distribution.TDistribution; + import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.List; public class MathUtils { + private static final NormalDistribution NORMAL_DISTRIBUTION = NormalDistribution.of(0.0, 1.0); /** * 获取百分比 @@ -12,8 +21,57 @@ public class MathUtils { * @return */ public static double getPercentWithDecimal(double value) { - return new BigDecimal(value * 100) - .setScale(1, BigDecimal.ROUND_HALF_UP) - .doubleValue(); + return new BigDecimal(value * 100).setScale(1, RoundingMode.HALF_UP).doubleValue(); } + + /** + * 获取预测数据的置信区间,这边计算的是预测值的置信区间,还有一个是预测值的预测区间,公式不一样,注意区分. + * 参考资料 知乎, + * real-statistics + * @param data 原始数据 + * @param forecastValue 预测得到的数据 + * @param forecastData 将原数据 x 代入回归方程得到的拟合数据 + * @param alpha 置信水平 + * @param degreeOfFreedom 自由度,t分布使用 + * @return 预测值的置信区间数组 + */ + public static double[][] getConfidenceInterval(double[][] data, double[] forecastValue, double[][] forecastData, double alpha, int degreeOfFreedom) { + // y 平均方差 + double totalPow = 0; + double xTotal = 0; + for (int i = 0; i < data.length; i++) { + double xVal = data[i][0]; + xTotal += xVal; + double realVal = data[i][1]; + double predictVal = forecastValue[i]; + totalPow += Math.pow((realVal - predictVal), 2); + } + double xAvg = xTotal / data.length; + double yMseSqrt = Math.sqrt(totalPow / (forecastValue.length - 2)); + // x 均值方差 + double xSubPow = 0; + for (int i = 0; i < data.length; i++) { + double xVal = data[i][0]; + xSubPow += Math.pow(xVal - xAvg, 2); + } + // t/z 值, 样本数 < 30 选 t 分布, > 30 选 z 分布, + double tzFactor; + if (data.length <= 30) { + tzFactor = TDistribution.of(degreeOfFreedom).inverseCumulativeProbability(1 - (1 - alpha) / 2); + } else { + tzFactor = NORMAL_DISTRIBUTION.inverseCumulativeProbability(1 - (1 - alpha) / 2); + } + double[][] result = new double[forecastData.length][2]; + for (int i = 0; i < forecastData.length; i++) { + double xVal = forecastData[i][0]; + double curSubPow = Math.pow(xVal - xAvg, 2); + double sqrt = Math.sqrt(1.0 / data.length + curSubPow / xSubPow); + double lower = forecastData[i][1] - tzFactor * yMseSqrt * sqrt; + double upper = forecastData[i][1] + tzFactor * yMseSqrt * sqrt; + result[i][0] = lower; + result[i][1] = upper; + } + return result; + } + } diff --git a/core/backend/src/main/java/io/dataease/dto/chart/ChartSeniorForecastDTO.java b/core/backend/src/main/java/io/dataease/dto/chart/ChartSeniorForecastDTO.java new file mode 100644 index 0000000000..e4a44dccfd --- /dev/null +++ b/core/backend/src/main/java/io/dataease/dto/chart/ChartSeniorForecastDTO.java @@ -0,0 +1,35 @@ +package io.dataease.dto.chart; + +import lombok.Data; + +@Data +public class ChartSeniorForecastDTO { + /** + * 是否开启预测 + */ + private boolean enable; + /** + * 预测周期 + */ + private int period; + /** + * 是否使用所有数据进行预测 + */ + private boolean allPeriod; + /** + * 用于预测的数据量 + */ + private int trainingPeriod; + /** + * 置信区间 + */ + private float confidenceInterval; + /** + * 预测用的算法/模型 + */ + private String algorithm; + /** + * 多项式阶数 + */ + private int degree; +} diff --git a/core/backend/src/main/java/io/dataease/dto/chart/ForecastDataDTO.java b/core/backend/src/main/java/io/dataease/dto/chart/ForecastDataDTO.java new file mode 100644 index 0000000000..90e889f6a1 --- /dev/null +++ b/core/backend/src/main/java/io/dataease/dto/chart/ForecastDataDTO.java @@ -0,0 +1,14 @@ +package io.dataease.dto.chart; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import lombok.Data; + +@Data +public class ForecastDataDTO { + @JsonIgnore + private double xVal; + @JsonIgnore + private double yVal; + private double lower; + private double upper; +} diff --git a/core/backend/src/main/java/io/dataease/dto/chart/ForecastDataVO.java b/core/backend/src/main/java/io/dataease/dto/chart/ForecastDataVO.java new file mode 100644 index 0000000000..752ce0fc0e --- /dev/null +++ b/core/backend/src/main/java/io/dataease/dto/chart/ForecastDataVO.java @@ -0,0 +1,11 @@ +package io.dataease.dto.chart; + +import lombok.Data; +import lombok.EqualsAndHashCode; + +@Data +@EqualsAndHashCode(callSuper = true) +public class ForecastDataVO extends ForecastDataDTO { + private D dimension; + private Q quota; +} diff --git a/core/backend/src/main/java/io/dataease/service/chart/ChartViewService.java b/core/backend/src/main/java/io/dataease/service/chart/ChartViewService.java index fe48d5db5c..44528a8f51 100644 --- a/core/backend/src/main/java/io/dataease/service/chart/ChartViewService.java +++ b/core/backend/src/main/java/io/dataease/service/chart/ChartViewService.java @@ -11,6 +11,7 @@ import io.dataease.commons.constants.JdbcConstants; import io.dataease.commons.model.PluginViewSetImpl; import io.dataease.commons.utils.AuthUtils; import io.dataease.commons.utils.BeanUtils; +import io.dataease.commons.utils.DateUtils; import io.dataease.commons.utils.LogUtil; import io.dataease.controller.request.chart.*; import io.dataease.controller.response.ChartDetail; @@ -52,6 +53,8 @@ import io.dataease.plugins.view.service.ViewPluginService; import io.dataease.plugins.xpack.auth.dto.request.ColumnPermissionItem; import io.dataease.provider.query.SQLUtils; import io.dataease.service.chart.util.ChartDataBuild; +import io.dataease.service.chart.util.dataForecast.ForecastAlgo; +import io.dataease.service.chart.util.dataForecast.ForecastAlgoManager; import io.dataease.service.dataset.*; import io.dataease.service.datasource.DatasourceService; import io.dataease.service.engine.EngineService; @@ -71,6 +74,7 @@ import javax.annotation.Resource; import java.lang.reflect.Type; import java.math.BigDecimal; import java.math.RoundingMode; +import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.*; import java.util.concurrent.atomic.AtomicReference; @@ -1090,7 +1094,7 @@ public class ChartViewService { logger.info("plugin_sql:" + sql); Map mapTableNormal = ChartDataBuild.transTableNormal(fieldMap, view, data, desensitizationList); - return uniteViewResult(datasourceRequest.getQuery(), mapChart, mapTableNormal, view, isDrill, drillFilters, dynamicAssistFields, assistData); + return uniteViewResult(datasourceRequest.getQuery(), mapChart, mapTableNormal, view, isDrill, drillFilters, dynamicAssistFields, assistData, Collections.emptyList()); // 如果是插件到此结束 } @@ -1361,6 +1365,16 @@ public class ChartViewService { } tempYAxis.addAll(yAxis); + // forecast + List> forecastData = Collections.emptyList(); + JSONObject senior = JSONObject.parseObject(view.getSenior()); + JSONObject forecastObj = senior.getJSONObject("forecast"); + if (forecastObj != null) { + ChartSeniorForecastDTO forecastCfg = forecastObj.toJavaObject(ChartSeniorForecastDTO.class); + if (forecastCfg.isEnable()) { + forecastData = forecastData(forecastCfg, data, xAxis, yAxis, view); + } + } for (int i = 0; i < tempYAxis.size(); i++) { ChartViewFieldDTO chartViewFieldDTO = tempYAxis.get(i); ChartFieldCompareDTO compareCalc = chartViewFieldDTO.getCompareCalc(); @@ -1622,12 +1636,43 @@ public class ChartViewService { mapTableNormal = ChartDataBuild.transTableNormal(xAxis, yAxis, view, data, extStack, desensitizationList); } - chartViewDTO = uniteViewResult(datasourceRequest.getQuery(), mapChart, mapTableNormal, view, isDrill, drillFilters, dynamicAssistFields, assistData); + chartViewDTO = uniteViewResult(datasourceRequest.getQuery(), mapChart, mapTableNormal, view, isDrill, drillFilters, dynamicAssistFields, assistData, forecastData); chartViewDTO.setTotalPage(totalPage); chartViewDTO.setTotalItems(totalItems); return chartViewDTO; } + private List> forecastData(ChartSeniorForecastDTO forecastCfg, List data, List xAxis, List yAxis, ChartViewDTO view) throws ParseException { + List trainingData = data; + if (!forecastCfg.isAllPeriod() && data.size() > forecastCfg.getTrainingPeriod()) { + trainingData = data.subList(data.size() - forecastCfg.getTrainingPeriod(), data.size() - 1); + } + if (xAxis.size() == 1 && xAxis.get(0).getDeType() == 1) { + // 先处理时间类型, 默认数据是有序递增的 + String lastTime = data.get(data.size() - 1)[0]; + ChartViewFieldDTO timeAxis = xAxis.get(0); + List forecastPeriod = DateUtils.getForecastPeriod(lastTime, forecastCfg.getPeriod(), timeAxis.getDateStyle(), timeAxis.getDatePattern()); + if(!forecastPeriod.isEmpty()){ + ForecastAlgo algo = ForecastAlgoManager.getAlgo(forecastCfg.getAlgorithm()); + List forecastData = algo.forecast(forecastCfg, trainingData, view); + if (forecastPeriod.size() == forecastData.size()) { + List> result = new ArrayList<>(); + for (int i = 0; i < forecastPeriod.size(); i++) { + String period = forecastPeriod.get(i); + ForecastDataDTO forecastDataItem = forecastData.get(i); + ForecastDataVO tmp = new ForecastDataVO<>(); + BeanUtils.copyBean(tmp, forecastDataItem); + tmp.setDimension(period); + tmp.setQuota(forecastDataItem.getYVal()); + result.add(tmp); + } + return result; + } + } + } + return List.of(); + } + // 对结果排序 public List resultCustomSort(List xAxis, List data) { List res = new ArrayList<>(data); @@ -1684,11 +1729,12 @@ public class ChartViewService { return "SELECT " + stringBuilder + " FROM (" + sql + ") tmp"; } - public ChartViewDTO uniteViewResult(String sql, Map chartData, Map tableData, ChartViewDTO view, Boolean isDrill, List drillFilters, List dynamicAssistFields, List assistData) { + public ChartViewDTO uniteViewResult(String sql, Map chartData, Map tableData, ChartViewDTO view, Boolean isDrill, List drillFilters, List dynamicAssistFields, List assistData, List> forecastData) { Map map = new HashMap<>(); map.putAll(chartData); map.putAll(tableData); + map.put("forecastData", forecastData); List sourceFields = dataSetTableFieldsService.getFieldsByTableId(view.getTableId()); map.put("sourceFields", sourceFields); diff --git a/core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/ForecastAlgo.java b/core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/ForecastAlgo.java new file mode 100644 index 0000000000..a9328e639c --- /dev/null +++ b/core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/ForecastAlgo.java @@ -0,0 +1,21 @@ +package io.dataease.service.chart.util.dataForecast; + +import io.dataease.dto.chart.ChartSeniorForecastDTO; +import io.dataease.dto.chart.ChartViewDTO; +import io.dataease.dto.chart.ForecastDataDTO; +import lombok.Getter; + +import java.util.List; + +@Getter +public abstract class ForecastAlgo { + private final String algoType; + + public ForecastAlgo() { + this.algoType = this.getAlgoType(); + ForecastAlgoManager.register(this); + } + + public abstract List forecast(ChartSeniorForecastDTO forecastCfg, List data, ChartViewDTO view); + +} diff --git a/core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/ForecastAlgoManager.java b/core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/ForecastAlgoManager.java new file mode 100644 index 0000000000..d759c1f210 --- /dev/null +++ b/core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/ForecastAlgoManager.java @@ -0,0 +1,17 @@ +package io.dataease.service.chart.util.dataForecast; + +import io.dataease.service.chart.util.dataForecast.impl.LinearRegressionAlgo; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class ForecastAlgoManager { + private static final Map FORECAST_ALGO_MAP = new ConcurrentHashMap<>(); + public static ForecastAlgo getAlgo(String algoType) { + return FORECAST_ALGO_MAP.get(algoType); + } + + public static void register(ForecastAlgo forecastAlgo) { + FORECAST_ALGO_MAP.put(forecastAlgo.getAlgoType(), forecastAlgo); + } +} diff --git a/core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/impl/LinearRegressionAlgo.java b/core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/impl/LinearRegressionAlgo.java new file mode 100644 index 0000000000..a72fe920f9 --- /dev/null +++ b/core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/impl/LinearRegressionAlgo.java @@ -0,0 +1,59 @@ +package io.dataease.service.chart.util.dataForecast.impl; + +import com.alibaba.fastjson.JSONObject; +import io.dataease.commons.utils.MathUtils; +import io.dataease.dto.chart.ChartSeniorForecastDTO; +import io.dataease.dto.chart.ChartViewDTO; +import io.dataease.dto.chart.ForecastDataDTO; +import io.dataease.plugins.common.dto.chart.ChartViewFieldDTO; +import io.dataease.service.chart.util.dataForecast.ForecastAlgo; +import lombok.Getter; +import org.apache.commons.math4.legacy.stat.regression.SimpleRegression; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.List; + +@Component +@Getter +public class LinearRegressionAlgo extends ForecastAlgo { + + private final String algoType = "linear-regression"; + + @Override + public List forecast(ChartSeniorForecastDTO forecastCfg, List data, ChartViewDTO view) { + List xAxis = JSONObject.parseArray(view.getXAxis(), ChartViewFieldDTO.class); + final List forecastData = new ArrayList<>(data.size()); + // 先按连续的数据处理 + for (int i = 0; i < data.size(); i++) { + String val = data.get(i)[xAxis.size()]; + double value = Double.parseDouble(val); + forecastData.add(new double[]{i, value}); + } + double[][] matrix = forecastData.toArray(double[][]::new); + SimpleRegression regression = new SimpleRegression(); + regression.addData(matrix); + double[][] forecastMatrix = new double[forecastCfg.getPeriod()][2]; + for (int i = 0; i < forecastCfg.getPeriod(); i++) { + double xVal = data.size() + i; + double predictVal = regression.predict(xVal); + forecastMatrix[i] = new double[]{xVal, predictVal}; + } + final double[] forecastValue = new double[forecastData.size()]; + for (int i = 0; i < forecastData.size(); i++) { + double xVal = forecastData.get(i)[0]; + forecastValue[i] = regression.predict(xVal); + } + double[][] confidenceInterval = MathUtils.getConfidenceInterval(forecastData.toArray(new double[0][0]), forecastValue, forecastMatrix, forecastCfg.getConfidenceInterval(), forecastData.size() - 2); + final List result = new ArrayList<>(forecastCfg.getPeriod()); + for (int i = 0; i < forecastMatrix.length; i++) { + ForecastDataDTO tmp = new ForecastDataDTO(); + tmp.setXVal(forecastMatrix[i][0]); + tmp.setYVal(forecastMatrix[i][1]); + tmp.setLower(confidenceInterval[i][0]); + tmp.setUpper(confidenceInterval[i][1]); + result.add(tmp); + } + return result; + } +} diff --git a/core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/impl/PolynomialRegressionAlgo.java b/core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/impl/PolynomialRegressionAlgo.java new file mode 100644 index 0000000000..f73639e92e --- /dev/null +++ b/core/backend/src/main/java/io/dataease/service/chart/util/dataForecast/impl/PolynomialRegressionAlgo.java @@ -0,0 +1,71 @@ +package io.dataease.service.chart.util.dataForecast.impl; + +import com.alibaba.fastjson.JSONObject; +import io.dataease.commons.utils.MathUtils; +import io.dataease.dto.chart.ChartSeniorForecastDTO; +import io.dataease.dto.chart.ChartViewDTO; +import io.dataease.dto.chart.ForecastDataDTO; +import io.dataease.plugins.common.dto.chart.ChartViewFieldDTO; +import io.dataease.service.chart.util.dataForecast.ForecastAlgo; +import lombok.Getter; +import org.apache.commons.math4.legacy.analysis.function.Logistic; +import org.apache.commons.math4.legacy.fitting.PolynomialCurveFitter; +import org.apache.commons.math4.legacy.fitting.WeightedObservedPoints; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.List; + +@Component +@Getter +public class PolynomialRegressionAlgo extends ForecastAlgo { + private final String algoType = "polynomial-regression"; + + @Override + public List forecast(ChartSeniorForecastDTO forecastCfg, List data, ChartViewDTO view) { + List xAxis = JSONObject.parseArray(view.getXAxis(), ChartViewFieldDTO.class); + WeightedObservedPoints points = new WeightedObservedPoints(); + double[][] originData = new double[data.size()][2]; + // 先按连续的数据处理 + for (int i = 0; i < data.size(); i++) { + String val = data.get(i)[xAxis.size()]; + double value = Double.parseDouble(val); + points.add(i, value); + originData[i] = new double[]{i, value}; + } + PolynomialCurveFitter filter = PolynomialCurveFitter.create(forecastCfg.getDegree()); + // 返回的是多次项系数, y = 3 + 2x + x*2 则为 [3,2,1] + double[] coefficients = filter.fit(points.toList()); + double[][] forecastMatrix = new double[forecastCfg.getPeriod()][2]; + for (int i = 0; i < forecastCfg.getPeriod(); i++) { + double xVal = data.size() + i; + double predictVal = getPolynomialValue(xVal, coefficients); + forecastMatrix[i] = new double[]{xVal, predictVal}; + } + final double[] forecastValue = new double[data.size()]; + for (int i = 0; i < data.size(); i++) { + double xVal = originData[i][0]; + forecastValue[i] = getPolynomialValue(xVal, coefficients); + } + int df = data.size() - forecastCfg.getDegree() - 1; + double[][] confidenceInterval = MathUtils.getConfidenceInterval(originData, forecastValue, forecastMatrix, forecastCfg.getConfidenceInterval(), df); + final List result = new ArrayList<>(forecastCfg.getPeriod()); + for (int i = 0; i < forecastMatrix.length; i++) { + ForecastDataDTO tmp = new ForecastDataDTO(); + tmp.setXVal(forecastMatrix[i][0]); + tmp.setYVal(forecastMatrix[i][1]); + tmp.setLower(confidenceInterval[i][0]); + tmp.setUpper(confidenceInterval[i][1]); + result.add(tmp); + } + return result; + } + + private double getPolynomialValue(double x, double[] coefficients) { + double result = 0.0; + for (int i = 0; i < coefficients.length; i++) { + result += coefficients[i] * Math.pow(x, i); + } + return result; + } +}