目录

java-spark-朴素贝叶斯算法(naive-bayes)

# 配置

配置请看我的其他文章 点击跳转 (opens new window)

# spark官方文档

点击跳转官方文档 (opens new window)

# 其它文章

推荐一个在蚂蚁做算法的人写的文章,不过他的文章偏专业化,有很多数学学公式。我是看的比较懵。点击跳转 (opens new window)

# 数据

# 训练数据

在这里插入图片描述

# 预测数据


# 实体类

用了swagger和lombok 不需要的可以删掉

import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotEmpty;

/**
 * 分类算法传参
 * 比如决策树 随机森林 朴素贝叶斯
 *
 * @author teler
 * @date 2020-10-09
 */

@Data
@ApiModel("朴素贝叶斯")
public class ClassifyEntity {
	/**
	 * 测试数据
	 */
	@ApiModelProperty("测试数据")
	@NotEmpty(message = "参数不能为空")
	String trainFilePath;
	/**
	 * 数据路径
	 */
	@ApiModelProperty("数据路径")
	@NotEmpty(message = "参数不能为空")
	String dataFilePath;

	/**
	 * 测试模型的数据比例
	 */
	@ApiModelProperty("测试模型的数据比例")
	@Min(value = 0, message = "最小比例为0")
	@Max(value = 1, message = "最大比例为1")
	double testDataPct;

	/**
	 * 测试模型准确率阈值
	 */
	@ApiModelProperty("测试模型准确率阈值")
	@Min(value = 0, message = "最小比例为0")
	@Max(value = 1, message = "最大比例为1")
	double testThreshold;

}

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50

# 算法实现

里面有些方法是为了保留小数 不需要的自己改


@Resource
private transient SparkSession sparkSession;

@Override
public Map<String, Object> naiveBayes(ClassifyEntity record) {
	System.out.println("========== 朴素贝叶斯算法开始 ==========");
	Map<String, Object> map = new HashMap<>();
	Dataset<Row> source = getDataSetByHdfs(record.getTrainFilePath());
	//训练数据
	map.put("training", toList(source));
	//根据比例从数据源中随机抽取数据 训练数据和测试数据比例 建议设为0.8
	Dataset<Row>[] splits = source.randomSplit(new double[]{record.getTestDataPct(), 1 - record.getTestDataPct()},
			1234L);
	//训练和测试数据
	Dataset<Row> trainingData = splits[0].cache();
	Dataset<Row> testData = splits[1].cache();

	NaiveBayes nb = new NaiveBayes();
	NaiveBayesModel model = nb.fit(trainingData);

	//准确率
	Dataset<Row> testResult = model.transform(testData);
	Dataset<Row> predictionAndLabels = testResult.select("prediction", "label");

	MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
			                                              .setLabelCol("label")
			                                              .setPredictionCol("prediction")
			                                              .setMetricName("accuracy");
	//准确率
	
	map.put("accuracy", NumberUtil.roundDown(evaluator.evaluate(predictionAndLabels), 3));


	//计算结果
	Dataset<Row> dataSource = getDataSetByHdfs(record.getDataFilePath());
	Dataset<Row> datasetRow = model.transform(dataSource).select("features", "label", "prediction");
	//平滑指数
	map.put("smoothing", NumberUtil.roundDown(model.getSmoothing(), 3));
	//数据
	map.put("data", toList(datasetRow));
	return map;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

# getDataSetByHdfs方法

这个方法我与上面的方法放在一个类中,所以sparkSession没重复写

/**
	 * 从hdfs中取数据
	 *
	 * @param dataFilePath 数据路径
	 * @return 数据集合
	 */
	private Dataset<Row> getDataSetByHdfs(String dataFilePath) {

		//屏蔽日志
		Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
		Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);

		Dataset<Row> dataset;
		try {
			//我这里的数据是libsvm格式的 如果是其他格式请自行更改
			dataset = sparkSession.read().format("libsvm").load(dataFilePath);
			log.info("获取数据结束 ");
		} catch (Exception e) {
			log.info("读取失败:{} ", e.getMessage());
		}
		return dataset;
	}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

# toList

/**
	 * dataset数据转化为list数据
	 *
	 * @param record 数据
	 * @return 数据集合
	 */
	private List<Map<String, String>> toList(Dataset<Row> record) {
		log.info("格式化结果数据集===============================");
		List<Map<String, String>> list = new ArrayList<>();
		String[] columns = record.columns();
		List<Row> rows = record.collectAsList();
		for (Row row : rows) {
			Map<String, String> obj = new HashMap<>(16);
			for (int j = 0; j < columns.length; j++) {
				String col = columns[j];
				Object rowAs = row.getAs(col);
				String val = "";
				//如果是数组 
				//这一段不需要的可以只留下else的内容
				if (rowAs instanceof DenseVector) {
					if (((DenseVector) rowAs).values() instanceof double[]) {
						val = ArrayUtil.join(
								Arrays.stream(((DenseVector) rowAs).values())
										.map(rowVal -> NumberUtil.roundDown(rowVal, 3).doubleValue()).toArray()
								, ",")
						;
					} else {
						val = rowAs.toString();
					}
				} else {
					val = rowAs.toString();
				}

				obj.put(col, val);
				log.info("列:{},名:{},值:{}", j, col, val);
			}
			list.add(obj);
		}
		return list;
	}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
上次更新: 2024-01-03, 13:22:13
最近更新
01
2023年度总结
01-03
02
MongoDB的简单的常用语法
12-11
03
cetnos7通过nfs共享磁盘文件
11-24
更多文章>