目录

java-spark-主成分分析算法-pca

# 配置

配置请看我的其他文章 点击跳转 (opens new window)

# spark官方文档

点击跳转官方文档 (opens new window)

# 数据

# 训练数据

在这里插入图片描述


# 代码

PCA算法的应用场景不是太明确,没做太多验证

# 实体类

用了swagger和lombok 不需要的可以删掉


import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import javax.validation.constraints.Min;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull;

/**
 * 主成分分析算法
 *
 * @author teler
 */

@Data
@ApiModel("主成分分析算法")
public class PcaEntity {
	/**
	 * 数据集路径
	 */
	@ApiModelProperty("数据源路径")
	@NotEmpty(message = "数据源路径不能为空")
	private String dataFilePath;

	/**
	 * 迭代次数
	 */
	@ApiModelProperty("维度")
	@NotNull(message = "k不能为空")
	@Min(value = 1, message = "k最小值为1")
	private Integer k;
}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

# 算法实现

里面有些方法是为了保留小数 不需要的自己改


@Resource
private transient SparkSession sparkSession;
@Override
	public Map<String, Object> pca(PcaEntity record) {
		System.out.println("================== 主成分分析算法开始===========");
		Map<String, Object> map = new HashMap<>(ConstantIntEnum.HASH_MAP_INITIAL_CAPACITY.getCode());

		Dataset<Row> source = getDataSetByHdfs(record.getDataFilePath());
		//训练数据
		;
		map.put("training", toList(source));
		//设置算法参数
		PCA pca = new PCA()
				          .setInputCol("features")
				          .setOutputCol("pcaFeatures")
				          .setK(record.getK());
		//训练模型
		PCAModel pcaModel = pca.fit(source);

		pcaModel.pc();
		//转化数据
		Dataset<Row> predictions = pcaModel.transform(source).select("pcaFeatures");
		predictions.show();

		map.put("data", toList(predictions));
		return map;
	}


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

# getDataSetByHdfs方法

这个方法我与上面的方法放在一个类中,所以sparkSession没重复写

/**
	 * 从hdfs中取数据
	 *
	 * @param dataFilePath 数据路径
	 * @return 数据集合
	 */
	private Dataset<Row> getDataSetByHdfs(String dataFilePath) {

		//屏蔽日志
		Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
		Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);

		Dataset<Row> dataset;
		try {
			//我这里的数据是libsvm格式的 如果是其他格式请自行更改
			dataset = sparkSession.read().format("libsvm").load(dataFilePath);
			log.info("获取数据结束 ");
		} catch (Exception e) {
			log.info("读取失败:{} ", e.getMessage());
		}
		return dataset;
	}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

# toList

/**
	 * dataset数据转化为list数据
	 *
	 * @param record 数据
	 * @return 数据集合
	 */
	private List<Map<String, String>> toList(Dataset<Row> record) {
		log.info("格式化结果数据集===============================");
		List<Map<String, String>> list = new ArrayList<>();
		String[] columns = record.columns();
		List<Row> rows = record.collectAsList();
		for (Row row : rows) {
			Map<String, String> obj = new HashMap<>(16);
			for (int j = 0; j < columns.length; j++) {
				String col = columns[j];
				Object rowAs = row.getAs(col);
				String val = "";
				//如果是数组 
				//这一段不需要的可以只留下else的内容
				if (rowAs instanceof DenseVector) {
					if (((DenseVector) rowAs).values() instanceof double[]) {
						val = ArrayUtil.join(
								Arrays.stream(((DenseVector) rowAs).values())
										.map(rowVal -> NumberUtil.roundDown(rowVal, 3).doubleValue()).toArray()
								, ",")
						;
					} else {
						val = rowAs.toString();
					}
				} else {
					val = rowAs.toString();
				}

				obj.put(col, val);
				log.info("列:{},名:{},值:{}", j, col, val);
			}
			list.add(obj);
		}
		return list;
	}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
上次更新: 2024-11-06, 19:27:10
最近更新
01
java playwright爬虫
11-06
02
连接chrome调试
07-23
03
连接chrome调试
07-23
更多文章>