目录

java-spark-kMeans算法

# 配置

配置请看我的其他文章 点击跳转 (opens new window)

# spark官方文档

点击跳转官方文档 (opens new window)

# 数据

# 训练数据

在这里插入图片描述


# 实体类

用了swagger和lombok 不需要的可以删掉


import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import javax.validation.constraints.Min;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull;

/**
 * k means聚类
 *
 * @author teler
 */
@Data
public class KMeansEntity {
	/**
	 * 数据集路径
	 */
	@ApiModelProperty("数据源路径")
	@NotEmpty(message = "数据源路径不能为空")
	private String dataFilePath;
	/**
	 * 簇数
	 */
	@ApiModelProperty("簇数")
	@NotNull(message = "簇数必填")
	@Min(value = 1, message = "簇数最小为1")
	private int k;
	/**
	 * 迭代次数
	 */
	@ApiModelProperty("迭代次数")
	@NotNull(message = "迭代次数必填")
	@Min(value = 0, message = "迭代次数最小为0")
	private int iter;

}


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39

# 算法实现

里面有些方法是为了保留小数 不需要的自己改


@Resource
private transient SparkSession sparkSession;

@Override
	public Map<String, Object> kmeans(KMeansEntity record) {
		log.info("========== k均值聚类算法开始 ==========");
		Dataset<Row> dataset = getDataSetByHdfs(record.getDataFilePath());

		//设置聚类算法参数
		KMeans kmeans = new KMeans().setK(record.getK()).setSeed(1L).setMaxIter(record.getIter());
		KMeansModel model = kmeans.fit(dataset);

		Dataset<Row> predictions = model.transform(dataset);
		ClusteringEvaluator evaluator = new ClusteringEvaluator();
		Map<String, Object> map = new HashMap<>();
		List<Map<String, String>> centers = new ArrayList<>();
		Vector[] clusterArray = model.clusterCenters();
		for (Vector center : clusterArray) {
			log.info("中心值:{}", center.toString());
			Map<String, String> centerMap = new HashMap<>();
			centerMap.put("center", ArrayUtil.join(Arrays.stream(center.toArray()).map(val -> NumberUtil.roundDown(val, 3).doubleValue()).toArray(), ","));
			centers.add(centerMap);
		}
		// 簇中心点
		map.put("clusterCenters", centers);
		// 误差 Silhouette with squared euclidean distance
		map.put("sse", NumberUtil.roundDown(evaluator.evaluate(predictions), 3));

		// 区分样本数据所属簇
		List<Row> rows = predictions.collectAsList();
		List<Map<String, Object>> data = new ArrayList<>();
		for (int i = 0; i < rows.size(); i++) {
			Map<String, Object> dataMap = new HashMap<>();
			dataMap.put("index", i);

			Vector vector = rows.get(i).getAs("features");
			dataMap.put("features", vector.toArray());

			int belong = model.predict(vector);
			dataMap.put("belong", belong);

			dataMap.put("center", ArrayUtil.join(Arrays.stream(clusterArray[belong].toArray()).map(val -> NumberUtil.roundDown(val, 3).doubleValue()).toArray(), ","));

			data.add(dataMap);
		}
		map.put("data", data);
		log.info("========== k均值聚类算法结束 ==========");
		return map;
	}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

# getDataSetByHdfs方法

这个方法我与上面的方法放在一个类中,所以sparkSession没重复写

/**
	 * 从hdfs中取数据
	 *
	 * @param dataFilePath 数据路径
	 * @return 数据集合
	 */
	private Dataset<Row> getDataSetByHdfs(String dataFilePath) {

		//屏蔽日志
		Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
		Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);

		Dataset<Row> dataset;
		try {
			//我这里的数据是libsvm格式的 如果是其他格式请自行更改
			dataset = sparkSession.read().format("libsvm").load(dataFilePath);
			log.info("获取数据结束 ");
		} catch (Exception e) {
			log.info("读取失败:{} ", e.getMessage());
		}
		return dataset;
	}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

# toList

/**
	 * dataset数据转化为list数据
	 *
	 * @param record 数据
	 * @return 数据集合
	 */
	private List<Map<String, String>> toList(Dataset<Row> record) {
		log.info("格式化结果数据集===============================");
		List<Map<String, String>> list = new ArrayList<>();
		String[] columns = record.columns();
		List<Row> rows = record.collectAsList();
		for (Row row : rows) {
			Map<String, String> obj = new HashMap<>(16);
			for (int j = 0; j < columns.length; j++) {
				String col = columns[j];
				Object rowAs = row.getAs(col);
				String val = "";
				//如果是数组 
				//这一段不需要的可以只留下else的内容
				if (rowAs instanceof DenseVector) {
					if (((DenseVector) rowAs).values() instanceof double[]) {
						val = ArrayUtil.join(
								Arrays.stream(((DenseVector) rowAs).values())
										.map(rowVal -> NumberUtil.roundDown(rowVal, 3).doubleValue()).toArray()
								, ",")
						;
					} else {
						val = rowAs.toString();
					}
				} else {
					val = rowAs.toString();
				}

				obj.put(col, val);
				log.info("列:{},名:{},值:{}", j, col, val);
			}
			list.add(obj);
		}
		return list;
	}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
上次更新: 2024-01-03, 13:22:13
最近更新
01
2023年度总结
01-03
02
MongoDB的简单的常用语法
12-11
03
cetnos7通过nfs共享磁盘文件
11-24
更多文章>