java-spark一元(多元)线性回归
# 配置
配置请看我的其他文章 点击跳转 (opens new window)
# spark官方文档
# 其它文章
推荐一个在蚂蚁做算法的人写的文章,不过他的文章偏专业化,有很多数学学公式。我是看的比较懵。点击跳转 (opens new window)
# 数据
# 训练数据
# 预测数据
# 实体类
用了swagger和lombok 不需要的可以删掉
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull;
/**
* 线性回归参数
*
* @author teler
* @date 2020-09-21
*/
@Data
public class LinearRegressionEntity {
/**
* 训练数据集路径
*/
@ApiModelProperty("训练数据集路径")
@NotEmpty(message = "必须有样本集")
private String trainFilePath;
/**
* 预测数据集路径
*/
@ApiModelProperty("预测数据集路径")
@NotEmpty(message = "必须有预测集")
private String dataFilePath;
/**
* 用于测试模型的数据比例,范围[0,1]
*/
@ApiModelProperty("用于测试模型的数据比例,范围[0,1]")
@Max(value = 1L, message = "数据比例最大值为1.0")
@Min(value = 0L, message = "数据比例最小值为0.0")
private double testDataPct;
/**
* 迭代次数
*/
@ApiModelProperty("迭代次数")
@NotNull(message = "迭代次数必填")
@Min(value = 0, message = "迭代次数最小值为0")
private Integer iter;
/**
* 正则化参数,范围[0,1]
*/
@NotNull(message = "正则化参数必填")
@Max(value = 1L, message = "正则化参数最大值为1.0")
@Min(value = 0L, message = "正则化参数最小值为0.0")
@ApiModelProperty("正则化参数,范围[0,1]")
private double regParam;
/**
* 弹性网络混合参数,范围[0,1]
*/
@ApiModelProperty("弹性网络混合参数,范围[0,1]")
@Max(value = 1L, message = "弹性网络混合参数最大值为1.0")
@Min(value = 0L, message = "弹性网络混合参数最小值为0.0")
private double elasticNetParam;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# 算法实现
里面有些方法是为了保留小数 不需要的自己改
@Resource
private transient SparkSession sparkSession;
@Override
public Map<String, Object> linearRegression(LinearRegressionEntity record) {
log.info("========== 线性回归计算开始 ==========");
Map<String, Object> map = new HashMap<>(16);
Dataset<Row> source = getDataSetByHdfs(record.getTrainFilePath());
List<Map<String, String>> sourceList = toList(source);
//训练数据
map.put("training", sourceList);
//根据比例从数据源中随机抽取数据 /训练数据和测试数据比例 建议设为0.8
Dataset<Row>[] splits = source.randomSplit(new double[]{record.getTestDataPct(), 1 - record.getTestDataPct()},
1234L);
//训练数据
Dataset<Row> trainingData = splits[0].cache();
// 10 / 0.3 / 0.8
LinearRegression lr = new LinearRegression()
.setMaxIter(record.getIter())
.setRegParam(record.getRegParam())
.setElasticNetParam(record.getElasticNetParam());
LinearRegressionModel lrModel = lr.fit(trainingData);
//系数
map.put("coefficients", Arrays.stream(lrModel.coefficients().toArray()).map(val -> NumberUtil.roundDown(val, 3).doubleValue()));
//截距
map.put("intercept", NumberUtil.roundDown(lrModel.intercept(), 3));
//训练数据结果集
LinearRegressionTrainingSummary trainingSummary = lrModel.summary();
//迭代次数
map.put("numIterations", trainingSummary.totalIterations());
//损失率,一般会逐渐减小
map.put("objectiveHistory", Arrays.stream(trainingSummary.objectiveHistory()).map(val -> NumberUtil.roundDown(val, 3).doubleValue()));
//均方根误差
map.put("rmse", NumberUtil.roundDown(trainingSummary.rootMeanSquaredError(), 3));
//真实误差
map.put("mae", NumberUtil.roundDown(trainingSummary.meanAbsoluteError(), 3));
//r平方 越接近1说明效果越好
map.put("r2", NumberUtil.roundDown(trainingSummary.r2(), 3));
//预测数据
Dataset<Row> predictionData = getDataSetByHdfs(record.getDataFilePath());
Dataset<Row> predictionResult = lrModel.transform(predictionData).selectExpr("label", "features", "round(prediction,3) as prediction");
predictionResult.show();
List<Object> predictionFeaturesVal = dataSetToString(predictionResult.select("features"));
map.put("data", toList(predictionResult));
log.info("========== 线性回归计算结束 ==========");
return map;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# getDataSetByHdfs方法
这个方法我与上面的方法放在一个类中,所以sparkSession没重复写
/**
* 从hdfs中取数据
*
* @param dataFilePath 数据路径
* @return 数据集合
*/
private Dataset<Row> getDataSetByHdfs(String dataFilePath) {
//屏蔽日志
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);
Dataset<Row> dataset;
try {
//我这里的数据是libsvm格式的 如果是其他格式请自行更改
dataset = sparkSession.read().format("libsvm").load(dataFilePath);
log.info("获取数据结束 ");
} catch (Exception e) {
log.info("读取失败:{} ", e.getMessage());
}
return dataset;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# toList
/**
* dataset数据转化为list数据
*
* @param record 数据
* @return 数据集合
*/
private List<Map<String, String>> toList(Dataset<Row> record) {
log.info("格式化结果数据集===============================");
List<Map<String, String>> list = new ArrayList<>();
String[] columns = record.columns();
List<Row> rows = record.collectAsList();
for (Row row : rows) {
Map<String, String> obj = new HashMap<>(16);
for (int j = 0; j < columns.length; j++) {
String col = columns[j];
Object rowAs = row.getAs(col);
String val = "";
//如果是数组
//这一段不需要的可以只留下else的内容
if (rowAs instanceof DenseVector) {
if (((DenseVector) rowAs).values() instanceof double[]) {
val = ArrayUtil.join(
Arrays.stream(((DenseVector) rowAs).values())
.map(rowVal -> NumberUtil.roundDown(rowVal, 3).doubleValue()).toArray()
, ",")
;
} else {
val = rowAs.toString();
}
} else {
val = rowAs.toString();
}
obj.put(col, val);
log.info("列:{},名:{},值:{}", j, col, val);
}
list.add(obj);
}
return list;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
编辑 (opens new window)
上次更新: 2024-12-06, 10:03:39