目录

java-spark一元(多元)线性回归

# 配置

配置请看我的其他文章 点击跳转 (opens new window)

# spark官方文档

点击跳转官方文档 (opens new window)

# 其它文章

推荐一个在蚂蚁做算法的人写的文章,不过他的文章偏专业化,有很多数学学公式。我是看的比较懵。点击跳转 (opens new window)

# 数据

# 训练数据

在这里插入图片描述

# 预测数据


# 实体类

用了swagger和lombok 不需要的可以删掉


import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull;

/**
 * 线性回归参数
 *
 * @author teler
 * @date 2020-09-21
 */
@Data
public class LinearRegressionEntity {
	/**
	 * 训练数据集路径
	 */
	@ApiModelProperty("训练数据集路径")
	@NotEmpty(message = "必须有样本集")
	private String trainFilePath;
	/**
	 * 预测数据集路径
	 */
	@ApiModelProperty("预测数据集路径")
	@NotEmpty(message = "必须有预测集")
	private String dataFilePath;
	/**
	 * 用于测试模型的数据比例,范围[0,1]
	 */
	@ApiModelProperty("用于测试模型的数据比例,范围[0,1]")
	@Max(value = 1L, message = "数据比例最大值为1.0")
	@Min(value = 0L, message = "数据比例最小值为0.0")
	private double testDataPct;
	/**
	 * 迭代次数
	 */
	@ApiModelProperty("迭代次数")
	@NotNull(message = "迭代次数必填")
	@Min(value = 0, message = "迭代次数最小值为0")
	private Integer iter;
	/**
	 * 正则化参数,范围[0,1]
	 */
	@NotNull(message = "正则化参数必填")
	@Max(value = 1L, message = "正则化参数最大值为1.0")
	@Min(value = 0L, message = "正则化参数最小值为0.0")
	@ApiModelProperty("正则化参数,范围[0,1]")
	private double regParam;
	/**
	 * 弹性网络混合参数,范围[0,1]
	 */
	@ApiModelProperty("弹性网络混合参数,范围[0,1]")
	@Max(value = 1L, message = "弹性网络混合参数最大值为1.0")
	@Min(value = 0L, message = "弹性网络混合参数最小值为0.0")
	private double elasticNetParam;
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60

# 算法实现

里面有些方法是为了保留小数 不需要的自己改


@Resource
private transient SparkSession sparkSession;

@Override
	public Map<String, Object> linearRegression(LinearRegressionEntity record) {
		log.info("========== 线性回归计算开始 ==========");
		Map<String, Object> map = new HashMap<>(16);
		Dataset<Row> source = getDataSetByHdfs(record.getTrainFilePath());

		List<Map<String, String>> sourceList = toList(source);
		//训练数据
		map.put("training", sourceList);

		//根据比例从数据源中随机抽取数据 /训练数据和测试数据比例 建议设为0.8
		Dataset<Row>[] splits = source.randomSplit(new double[]{record.getTestDataPct(), 1 - record.getTestDataPct()},
				1234L);
		//训练数据
		Dataset<Row> trainingData = splits[0].cache();

		// 10 / 0.3 / 0.8
		LinearRegression lr = new LinearRegression()
				                      .setMaxIter(record.getIter())
				                      .setRegParam(record.getRegParam())
				                      .setElasticNetParam(record.getElasticNetParam());

		LinearRegressionModel lrModel = lr.fit(trainingData);

		//系数
		map.put("coefficients", Arrays.stream(lrModel.coefficients().toArray()).map(val -> NumberUtil.roundDown(val, 3).doubleValue()));
		//截距
		map.put("intercept", NumberUtil.roundDown(lrModel.intercept(), 3));

		//训练数据结果集
		LinearRegressionTrainingSummary trainingSummary = lrModel.summary();

		//迭代次数
		map.put("numIterations", trainingSummary.totalIterations());
		//损失率,一般会逐渐减小
		map.put("objectiveHistory", Arrays.stream(trainingSummary.objectiveHistory()).map(val -> NumberUtil.roundDown(val, 3).doubleValue()));
		//均方根误差
		map.put("rmse", NumberUtil.roundDown(trainingSummary.rootMeanSquaredError(), 3));
		//真实误差
		map.put("mae", NumberUtil.roundDown(trainingSummary.meanAbsoluteError(), 3));
		//r平方 越接近1说明效果越好
		map.put("r2", NumberUtil.roundDown(trainingSummary.r2(), 3));

		//预测数据
		Dataset<Row> predictionData = getDataSetByHdfs(record.getDataFilePath());
		Dataset<Row> predictionResult = lrModel.transform(predictionData).selectExpr("label", "features", "round(prediction,3) as prediction");

		predictionResult.show();
		List<Object> predictionFeaturesVal = dataSetToString(predictionResult.select("features"));
		
		map.put("data", toList(predictionResult));
		log.info("========== 线性回归计算结束 ==========");
		return map;
	}



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61

# getDataSetByHdfs方法

这个方法我与上面的方法放在一个类中,所以sparkSession没重复写

/**
	 * 从hdfs中取数据
	 *
	 * @param dataFilePath 数据路径
	 * @return 数据集合
	 */
	private Dataset<Row> getDataSetByHdfs(String dataFilePath) {

		//屏蔽日志
		Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
		Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);

		Dataset<Row> dataset;
		try {
			//我这里的数据是libsvm格式的 如果是其他格式请自行更改
			dataset = sparkSession.read().format("libsvm").load(dataFilePath);
			log.info("获取数据结束 ");
		} catch (Exception e) {
			log.info("读取失败:{} ", e.getMessage());
		}
		return dataset;
	}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

# toList

/**
	 * dataset数据转化为list数据
	 *
	 * @param record 数据
	 * @return 数据集合
	 */
	private List<Map<String, String>> toList(Dataset<Row> record) {
		log.info("格式化结果数据集===============================");
		List<Map<String, String>> list = new ArrayList<>();
		String[] columns = record.columns();
		List<Row> rows = record.collectAsList();
		for (Row row : rows) {
			Map<String, String> obj = new HashMap<>(16);
			for (int j = 0; j < columns.length; j++) {
				String col = columns[j];
				Object rowAs = row.getAs(col);
				String val = "";
				//如果是数组 
				//这一段不需要的可以只留下else的内容
				if (rowAs instanceof DenseVector) {
					if (((DenseVector) rowAs).values() instanceof double[]) {
						val = ArrayUtil.join(
								Arrays.stream(((DenseVector) rowAs).values())
										.map(rowVal -> NumberUtil.roundDown(rowVal, 3).doubleValue()).toArray()
								, ",")
						;
					} else {
						val = rowAs.toString();
					}
				} else {
					val = rowAs.toString();
				}

				obj.put(col, val);
				log.info("列:{},名:{},值:{}", j, col, val);
			}
			list.add(obj);
		}
		return list;
	}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
上次更新: 2024-11-06, 19:27:10
最近更新
01
java playwright爬虫
11-06
02
连接chrome调试
07-23
03
连接chrome调试
07-23
更多文章>