生成PMML模型
具体见我的上一篇博客Python XGBoost保存模型PMML
Java调用PMML模型
Java基本的运行环境就不说了,大家如果能看到这篇文章,基本上就都掌握了Java运行环境。
首先maven导入需要的jar包
<dependencies>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.11</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator</artifactId>
<version>1.4.1</version>
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-evaluator-extension</artifactId>
<version>1.4.1</version>
</dependency>
</dependencies>
导入jar包后,将下面代码复制到代码处
package sso.passport;
/**
* function:java实现调用pmml文件
* datatime:2020-07-10 16:09
**/
import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.*;
import java.io.*;
import java.util.*;
import java.util.List;
public class Classification {
public static void main(String[] args) throws Exception {
//模型路径
String pathxml = System.getProperty("user.dir") + "/model/xgboost.pmml";
//传入模型特征数据
Map<String, Double> map = new HashMap<String, Double>();
map.put("x1", 5.1);
map.put("x2", 3.5);
map.put("x3", 0.4);
map.put("x4", 0.2);
//模型预测
predictLrHeart(map, pathxml);
}
public static void predictLrHeart(Map<String, Double> irismap, String pathxml) throws Exception {
PMML pmml;
File file = new File(pathxml);
InputStream inputStream = new FileInputStream(file);
try (InputStream is = inputStream) {
pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
ModelEvaluator<?> modelEvaluator = modelEvaluatorFactory.newModelEvaluator(pmml);
Evaluator evaluator = (Evaluator) modelEvaluator;
List<InputField> inputFields = evaluator.getInputFields();
Map<FieldName, FieldValue> argements = new LinkedHashMap<>();
for (InputField inputField : inputFields) {
FieldName inputFieldName = inputField.getName();
Object raeValue = irismap.get(inputFieldName.getValue());
FieldValue inputFieldValue = inputField.prepare(raeValue);
argements.put(inputFieldName, inputFieldValue);
}
Map<FieldName, ?> results = evaluator.evaluate(argements);
List<TargetField> targetFields = evaluator.getTargetFields();
for (TargetField targetField : targetFields) {
FieldName targetFieldName = targetField.getName();
Object targetFieldValue = results.get(targetFieldName);
// System.out.println("target: " + targetFieldName.getValue());
System.out.println(targetFieldValue);
}
}
}
}
本代码也是根据鸢尾花数据进行操作的,由于本人对java语言不甚了解,其中详细注释不好多说,但是一看就能明白。
大家运行如果报错的话,请看下一篇文章(可能有你需要的哦)。
(1)、如果您在阅读博客时遇到问题或者不理解的地方,可以联系我,互相交流、互相进步;
(2)、本人业余时间可以承接毕业设计和各种小项目,如系统构建、成立网站、数据挖掘、机器学习、深度学习等。有需要的加QQ:1143948594,备注“csdn项目”。