java-spark-朴素贝叶斯算法(naive-bayes)
# 配置
配置请看我的其他文章 点击跳转 (opens new window)
# spark官方文档
# 其它文章
推荐一个在蚂蚁做算法的人写的文章,不过他的文章偏专业化,有很多数学学公式。我是看的比较懵。点击跳转 (opens new window)
# 数据
# 训练数据
# 预测数据
# 实体类
用了swagger和lombok 不需要的可以删掉
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import javax.validation.constraints.Max;
import javax.validation.constraints.Min;
import javax.validation.constraints.NotEmpty;
/**
* 分类算法传参
* 比如决策树 随机森林 朴素贝叶斯
*
* @author teler
* @date 2020-10-09
*/
@Data
@ApiModel("朴素贝叶斯")
public class ClassifyEntity {
/**
* 测试数据
*/
@ApiModelProperty("测试数据")
@NotEmpty(message = "参数不能为空")
String trainFilePath;
/**
* 数据路径
*/
@ApiModelProperty("数据路径")
@NotEmpty(message = "参数不能为空")
String dataFilePath;
/**
* 测试模型的数据比例
*/
@ApiModelProperty("测试模型的数据比例")
@Min(value = 0, message = "最小比例为0")
@Max(value = 1, message = "最大比例为1")
double testDataPct;
/**
* 测试模型准确率阈值
*/
@ApiModelProperty("测试模型准确率阈值")
@Min(value = 0, message = "最小比例为0")
@Max(value = 1, message = "最大比例为1")
double testThreshold;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# 算法实现
里面有些方法是为了保留小数 不需要的自己改
@Resource
private transient SparkSession sparkSession;
@Override
public Map<String, Object> naiveBayes(ClassifyEntity record) {
System.out.println("========== 朴素贝叶斯算法开始 ==========");
Map<String, Object> map = new HashMap<>();
Dataset<Row> source = getDataSetByHdfs(record.getTrainFilePath());
//训练数据
map.put("training", toList(source));
//根据比例从数据源中随机抽取数据 训练数据和测试数据比例 建议设为0.8
Dataset<Row>[] splits = source.randomSplit(new double[]{record.getTestDataPct(), 1 - record.getTestDataPct()},
1234L);
//训练和测试数据
Dataset<Row> trainingData = splits[0].cache();
Dataset<Row> testData = splits[1].cache();
NaiveBayes nb = new NaiveBayes();
NaiveBayesModel model = nb.fit(trainingData);
//准确率
Dataset<Row> testResult = model.transform(testData);
Dataset<Row> predictionAndLabels = testResult.select("prediction", "label");
MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
.setMetricName("accuracy");
//准确率
map.put("accuracy", NumberUtil.roundDown(evaluator.evaluate(predictionAndLabels), 3));
//计算结果
Dataset<Row> dataSource = getDataSetByHdfs(record.getDataFilePath());
Dataset<Row> datasetRow = model.transform(dataSource).select("features", "label", "prediction");
//平滑指数
map.put("smoothing", NumberUtil.roundDown(model.getSmoothing(), 3));
//数据
map.put("data", toList(datasetRow));
return map;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# getDataSetByHdfs方法
这个方法我与上面的方法放在一个类中,所以sparkSession没重复写
/**
* 从hdfs中取数据
*
* @param dataFilePath 数据路径
* @return 数据集合
*/
private Dataset<Row> getDataSetByHdfs(String dataFilePath) {
//屏蔽日志
Logger.getLogger("org.apache.spark").setLevel(Level.WARN);
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);
Dataset<Row> dataset;
try {
//我这里的数据是libsvm格式的 如果是其他格式请自行更改
dataset = sparkSession.read().format("libsvm").load(dataFilePath);
log.info("获取数据结束 ");
} catch (Exception e) {
log.info("读取失败:{} ", e.getMessage());
}
return dataset;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# toList
/**
* dataset数据转化为list数据
*
* @param record 数据
* @return 数据集合
*/
private List<Map<String, String>> toList(Dataset<Row> record) {
log.info("格式化结果数据集===============================");
List<Map<String, String>> list = new ArrayList<>();
String[] columns = record.columns();
List<Row> rows = record.collectAsList();
for (Row row : rows) {
Map<String, String> obj = new HashMap<>(16);
for (int j = 0; j < columns.length; j++) {
String col = columns[j];
Object rowAs = row.getAs(col);
String val = "";
//如果是数组
//这一段不需要的可以只留下else的内容
if (rowAs instanceof DenseVector) {
if (((DenseVector) rowAs).values() instanceof double[]) {
val = ArrayUtil.join(
Arrays.stream(((DenseVector) rowAs).values())
.map(rowVal -> NumberUtil.roundDown(rowVal, 3).doubleValue()).toArray()
, ",")
;
} else {
val = rowAs.toString();
}
} else {
val = rowAs.toString();
}
obj.put(col, val);
log.info("列:{},名:{},值:{}", j, col, val);
}
list.add(obj);
}
return list;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
编辑 (opens new window)
上次更新: 2024-11-06, 19:27:10