java-spark-主成分分析算法-pca
# 配置
配置请看我的其他文章 点击跳转 (opens new window)
# spark官方文档
# 数据
# 训练数据
# 代码
PCA算法的应用场景不是太明确,没做太多验证
# 实体类
用了swagger和lombok 不需要的可以删掉
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull;
/**
* 主成分分析算法
*
* @author teler
*/
@Data
@ApiModel("主成分分析算法")
public class PcaEntity {
/**
* 数据集路径
*/
@ApiModelProperty("数据源路径")
@NotEmpty(message = "数据源路径不能为空")
private String dataFilePath;
/**
* 迭代次数
*/
@ApiModelProperty("维度")
@NotNull(message = "k不能为空")
@Min(value = 1, message = "k最小值为1")
private Integer k;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# 算法实现
里面有些方法是为了保留小数 不需要的自己改
@Resource
private transient SparkSession sparkSession;
@Override
public Map<String, Object> pca(PcaEntity record) {
System.out.println("================== 主成分分析算法开始===========");
Map<String, Object> map = new HashMap<>(ConstantIntEnum.HASH_MAP_INITIAL_CAPACITY.getCode());
Dataset<Row> source = getDataSetByHdfs(record.getDataFilePath());
//训练数据
;
map.put("training", toList(source));
//设置算法参数
PCA pca = new PCA()
.setInputCol("features")
.setOutputCol("pcaFeatures")
.setK(record.getK());
//训练模型
PCAModel pcaModel = pca.fit(source);
pcaModel.pc();
//转化数据
Dataset<Row> predictions = pcaModel.transform(source).select("pcaFeatures");
predictions.show();
map.put("data", toList(predictions));
return map;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# getDataSetByHdfs方法
这个方法我与上面的方法放在一个类中,所以sparkSession没重复写
/**
* 从hdfs中取数据
*
* @param dataFilePath 数据路径
* @return 数据集合
*/
private Dataset<Row> getDataSetByHdfs(String dataFilePath) {
//屏蔽日志
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);
Dataset<Row> dataset;
try {
//我这里的数据是libsvm格式的 如果是其他格式请自行更改
dataset = sparkSession.read().format("libsvm").load(dataFilePath);
log.info("获取数据结束 ");
} catch (Exception e) {
log.info("读取失败:{} ", e.getMessage());
}
return dataset;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# toList
/**
* dataset数据转化为list数据
*
* @param record 数据
* @return 数据集合
*/
private List<Map<String, String>> toList(Dataset<Row> record) {
log.info("格式化结果数据集===============================");
List<Map<String, String>> list = new ArrayList<>();
String[] columns = record.columns();
List<Row> rows = record.collectAsList();
for (Row row : rows) {
Map<String, String> obj = new HashMap<>(16);
for (int j = 0; j < columns.length; j++) {
String col = columns[j];
Object rowAs = row.getAs(col);
String val = "";
//如果是数组
//这一段不需要的可以只留下else的内容
if (rowAs instanceof DenseVector) {
if (((DenseVector) rowAs).values() instanceof double[]) {
val = ArrayUtil.join(
Arrays.stream(((DenseVector) rowAs).values())
.map(rowVal -> NumberUtil.roundDown(rowVal, 3).doubleValue()).toArray()
, ",")
;
} else {
val = rowAs.toString();
}
} else {
val = rowAs.toString();
}
obj.put(col, val);
log.info("列:{},名:{},值:{}", j, col, val);
}
list.add(obj);
}
return list;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
编辑 (opens new window)
上次更新: 2024-11-06, 19:27:10