java-spark-kMeans算法
# 配置
配置请看我的其他文章 点击跳转 (opens new window)
# spark官方文档
# 数据
# 训练数据
# 实体类
用了swagger和lombok 不需要的可以删掉
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotEmpty;
import javax.validation.constraints.NotNull;
/**
* k means聚类
*
* @author teler
*/
@Data
public class KMeansEntity {
/**
* 数据集路径
*/
@ApiModelProperty("数据源路径")
@NotEmpty(message = "数据源路径不能为空")
private String dataFilePath;
/**
* 簇数
*/
@ApiModelProperty("簇数")
@NotNull(message = "簇数必填")
@Min(value = 1, message = "簇数最小为1")
private int k;
/**
* 迭代次数
*/
@ApiModelProperty("迭代次数")
@NotNull(message = "迭代次数必填")
@Min(value = 0, message = "迭代次数最小为0")
private int iter;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# 算法实现
里面有些方法是为了保留小数 不需要的自己改
@Resource
private transient SparkSession sparkSession;
@Override
public Map<String, Object> kmeans(KMeansEntity record) {
log.info("========== k均值聚类算法开始 ==========");
Dataset<Row> dataset = getDataSetByHdfs(record.getDataFilePath());
//设置聚类算法参数
KMeans kmeans = new KMeans().setK(record.getK()).setSeed(1L).setMaxIter(record.getIter());
KMeansModel model = kmeans.fit(dataset);
Dataset<Row> predictions = model.transform(dataset);
ClusteringEvaluator evaluator = new ClusteringEvaluator();
Map<String, Object> map = new HashMap<>();
List<Map<String, String>> centers = new ArrayList<>();
Vector[] clusterArray = model.clusterCenters();
for (Vector center : clusterArray) {
log.info("中心值:{}", center.toString());
Map<String, String> centerMap = new HashMap<>();
centerMap.put("center", ArrayUtil.join(Arrays.stream(center.toArray()).map(val -> NumberUtil.roundDown(val, 3).doubleValue()).toArray(), ","));
centers.add(centerMap);
}
// 簇中心点
map.put("clusterCenters", centers);
// 误差 Silhouette with squared euclidean distance
map.put("sse", NumberUtil.roundDown(evaluator.evaluate(predictions), 3));
// 区分样本数据所属簇
List<Row> rows = predictions.collectAsList();
List<Map<String, Object>> data = new ArrayList<>();
for (int i = 0; i < rows.size(); i++) {
Map<String, Object> dataMap = new HashMap<>();
dataMap.put("index", i);
Vector vector = rows.get(i).getAs("features");
dataMap.put("features", vector.toArray());
int belong = model.predict(vector);
dataMap.put("belong", belong);
dataMap.put("center", ArrayUtil.join(Arrays.stream(clusterArray[belong].toArray()).map(val -> NumberUtil.roundDown(val, 3).doubleValue()).toArray(), ","));
data.add(dataMap);
}
map.put("data", data);
log.info("========== k均值聚类算法结束 ==========");
return map;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# getDataSetByHdfs方法
这个方法我与上面的方法放在一个类中,所以sparkSession没重复写
/**
* 从hdfs中取数据
*
* @param dataFilePath 数据路径
* @return 数据集合
*/
private Dataset<Row> getDataSetByHdfs(String dataFilePath) {
//屏蔽日志
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);
Dataset<Row> dataset;
try {
//我这里的数据是libsvm格式的 如果是其他格式请自行更改
dataset = sparkSession.read().format("libsvm").load(dataFilePath);
log.info("获取数据结束 ");
} catch (Exception e) {
log.info("读取失败:{} ", e.getMessage());
}
return dataset;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# toList
/**
* dataset数据转化为list数据
*
* @param record 数据
* @return 数据集合
*/
private List<Map<String, String>> toList(Dataset<Row> record) {
log.info("格式化结果数据集===============================");
List<Map<String, String>> list = new ArrayList<>();
String[] columns = record.columns();
List<Row> rows = record.collectAsList();
for (Row row : rows) {
Map<String, String> obj = new HashMap<>(16);
for (int j = 0; j < columns.length; j++) {
String col = columns[j];
Object rowAs = row.getAs(col);
String val = "";
//如果是数组
//这一段不需要的可以只留下else的内容
if (rowAs instanceof DenseVector) {
if (((DenseVector) rowAs).values() instanceof double[]) {
val = ArrayUtil.join(
Arrays.stream(((DenseVector) rowAs).values())
.map(rowVal -> NumberUtil.roundDown(rowVal, 3).doubleValue()).toArray()
, ",")
;
} else {
val = rowAs.toString();
}
} else {
val = rowAs.toString();
}
obj.put(col, val);
log.info("列:{},名:{},值:{}", j, col, val);
}
list.add(obj);
}
return list;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
编辑 (opens new window)
上次更新: 2024-12-06, 10:03:39