基于Spring Boot的ALBERT词向量服务(3)

Spring Boot工程开发

前面我们已经做好了模型的准备,下面就可以进行Spring Boot工程的开发了。这里我使用的是IntelliJ IDEA,借助Spring initializer创建一个Spring Boot Maven工程,添加Spring Boot Web starter和Lombok,确保IDEA已经安装了Lombok插件,这个是使用IDEA创建Spring Boot工程的基本操作,就不多说了。然后正式开始撸代码。

添加Maven依赖项和资源文件

这里主要添加两个依赖项,分别是TensorFlow Java版(这里我用的是1.12.0)和commons-io,分别用于模型推理和模型载入:

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow</artifactId>
  <version>1.12.0</version>
</dependency>
<dependency>
  <groupId>commons-io</groupId>
  <artifactId>commons-io</artifactId>
  <version>2.6</version>
</dependency>

为了后期能直接把模型文件打到jar包之中而无需额外指定,这里将模型文件和词汇表引入资源文件。这么做还比较合适的重要原因也是我们这里用到的ALBERT本身的模型文件比较小,只有不到16M。如果是原生的BERT,模型文件动辄几百MB甚至更大,那直接放入资源文件可能就不是很合适了。这里单独创建一个文件夹albert-model存放:

resources.PNG

模型与词汇表的启动自加载(com.aiwiscal.albert.model)

这里借助Sprint Boot的_@_PostConstruct 注解,使Spring Boot工程构建完成后就自动把模型和词汇表加载到内存中,然后对应的类依赖注入,就不用每次推理时重新加载模型和创建相关对象,直接拿来用就可以,大大提高了效率。另外借助lombok,可自动生成Getter/Setter/Constructor使代码更加精简:

@Getter // lombok生成getter方法,方便取数据
@Component("loadALBERT")
public class LoadALBERT {
    private Logger logger = LoggerFactory.getLogger(getClass());

    private Session session;  // TensorFlow Session 对象,可完成推理

    private final int vectorDim = 312;  // albert 向量维数

    private final int maxSupportLen = 510; // albert 支持的最大长度,原始为512,去掉首尾的[CLS]和[SEP],即为510

    private final String modelPath = "albert-model/albert_tiny_zh_google.pb"; // 模型资源文件路径



    @PostConstruct
    private void init(){
        loadGraph();  // 调用加载方法
    }

    private void loadGraph(){
        Graph graph = new Graph();
        try{
            // 获取资源文件中的模型文件输入流
            InputStream inputStream = this.getClass().getClassLoader()
                .getResourceAsStream(modelPath);
            // 使用commons-io中的IOUtils将模型文件输入流转化为byte数组
            byte[] graphPb = IOUtils.toByteArray(inputStream);
            //初始化TensorFlow graph
            graph.importGraphDef(graphPb);
            // 把graph装入一个新的Session,可运行推理
            this.session = new Session(graph);
            logger.info("ALBERT checkpoint loaded @ {}, vector dimension - {}, maxSupportLen - {}",
                    modelPath, vectorDim, maxSupportLen);
        } catch (Exception e){
            logger.error("Failed to load ALBERT checkpoint @ {} ! - {}", modelPath, e.toString());
        }
    }
}

其余说明参见代码中的注释,另外使用slf4j打印日志。同理,词汇表也需要预先自加载,并使用HashMap存储字符-Token ID映射:

@Getter
@Component("loadVocab")
public class LoadVocab {
    private Logger logger = LoggerFactory.getLogger(getClass());

    private Map<String, Integer> vocabTable = new HashMap<>();

    @PostConstruct
    private void init(){
        loadVocab();
    }

    private void loadVocab(){
        try{
            InputStreamReader inputReader = new InputStreamReader(
                    this.getClass().getClassLoader().getResourceAsStream("albert-model/vocab.txt"));
            BufferedReader bf = new BufferedReader(inputReader);
            String str;
            int n = 0;
            // 按行读,分配id
            while ((str = bf.readLine()) != null) {
                String lineWord = str.trim();
                this.vocabTable.put(lineWord, n);
                n++;
            }
            bf.close();
            inputReader.close();
            logger.info("ALBERT vocab loaded. total number - {} ", n);
        }catch (Exception e){
            logger.error("failed to load ALBERT vocab! - {}", e.toString());
        }
    }
}

这样以上两个类在Spring Boot启动时会自动调用加载方法,完成必要数据的加载,以供其他依赖对象使用。具体包路径是com.aiwiscal.albert.model。

输入输出接口定义(com.aiwiscal.albert.param)

定义了几个输入输出,以及中间结果类,比较零碎,可参照开源代码与注释理解:

类名 说明
com.aiwiscal.albert.param.InputText 原始输入类
com.aiwiscal.albert.param.InputTextValid 有效输入类,对原始输入去空字符或补充[PAD]
com.aiwiscal.albert.param.OutputToken 模型推理输入类,包含tokenId和segmentId
com.aiwiscal.albert.param.OutputVector 最终请求返回类

简易分词器(com.aiwiscal.albert.service.Tokenizer)

这里实现了一个简易分词器,把原始输入文本简单清洗后,切分为单字符查表获得TokenId。这里就需要把之前加载好的词汇表映射注入进来(@Autowired):

@Component("tokenizer")
public class Tokenizer {
    private Logger logger = LoggerFactory.getLogger(getClass());

    @Autowired
    private LoadVocab loadVocab; // 注入词汇表映射

    @Autowired
    private LoadALBERT loadALBERT; // 注入albert模型数据

    // 原始输入清洗,切分字符并查表得到tokenId和segmentId
    // 返回OutputToken类对象
    public OutputToken tokenize(InputText inputText){
        OutputToken outputToken = new OutputToken();
        try{
            // 简单清洗文本,生成有效输入类InputTextValid对象
            InputTextValid inputTextValid = validateText(inputText);
            
            // 获得tokenId
            float[] tokenId = getTokenId(inputTextValid);
            
            // 获得segmentId
            float[] segmentId = getSegmentId(inputTextValid);
            
            outputToken.setTokenId(tokenId);
            outputToken.setSegmentId(segmentId);
            outputToken.setSuccess(true);
            outputToken.setInputTextValid(inputTextValid);
            logger.debug("Text tokenized ...");
        }catch (Exception e){
            logger.error("Failed to tokenize the text - {}", e.toString());
            outputToken.setSuccess(false);
        }
        return outputToken;
    }
  // ...... 以下省略
}

这里重点提一下TokenId的获取,通过注入的loadVocab对象获得其中的HashMap映射表,可以方便的取到对应字符的Token Id:

// 查表获得tokenId
    private float[] getTokenId(InputTextValid inputTextValid){
        String[] textTokenList = inputTextValid.getTextTokenList();
        Map<String, Integer> vocabTable = loadVocab.getVocabTable(); // 获得注入loadVacab对象中的字符映射表
        float[] tokenId = new float[textTokenList.length + 2];
        tokenId[0] = 101; // 头部添加[CLS]标记,token id为101
        tokenId[tokenId.length - 1] = 102; // 尾部添加[SEP]标记,token id为102
        for (int i = 0; i < textTokenList.length; i++) {
            String currentCharStr = textTokenList[i];
            if(!vocabTable.containsKey(currentCharStr)){
                currentCharStr = "[UNK]"; // 不在词汇表中的,设定为[UNK]
            }
            tokenId[i+1] = vocabTable.get(currentCharStr); // 查表
        }
        return tokenId;
    }

而对于segment id,这里只是输入单文本,所以直接生成全0数组表示:

private float[] getSegmentId(InputTextValid inputTextValid){
        return new float[inputTextValid.getTextTokenList().length + 2];
    }

模型推理(com.aiwiscal.albert.service.InferALBERT)

这是向量生成的核心,通过运行模型来获取结果,作为一个Service。需要注入模型数据,调用TensorFlow组件完成推理,参照以下代码中的注释:

@Service
public class InferALBERT {
    private Logger logger = LoggerFactory.getLogger(getClass());

    @Autowired
    private LoadALBERT loadALBERT;  // 注入模型数据

    @Autowired
    private Tokenizer tokenizer;    // 注入分词器

    private float[] inferArr(float[] inputToken, float[] inputSegment){
        // 将1维数组扩展为2维以满足输入需要
        float[][] inputToken2D = new float[1][inputToken.length];
        float[][] inputSegment2D = new float[1][inputSegment.length];
        System.arraycopy(inputToken, 0, inputToken2D[0], 0, inputToken.length);
        System.arraycopy(inputSegment, 0, inputSegment2D[0], 0, inputSegment.length);

        // 调用TensorFlow会话(Session)中的runner,实现模型推理
        // 注入数据使用feed,取结果使用fetch,根据输入输出tensor的名称操作
        Tensor result = loadALBERT.getSession().runner()
                .feed("Input-Token", Tensor.create(inputToken2D))
                .feed("Input-Segment", Tensor.create(inputSegment2D))
                .fetch("output_1")
                .run().get(0);
        float[] ret = new float[loadALBERT.getVectorDim()];
        // 将结果的Tensor对象内部数据拷贝至原生数组
        result.copyTo(ret);
        return ret;
    }
    public OutputVector infer(InputText inputText){
        // 调用了上述inferArr
        // 总流程,null检查,原始输入处理,分词,模型推理并打印相关日志,最终返回
        // 具体代码省略 ......
        // ......
    }
}

这里用上了前面在Python环境中生成pb文件时得到的输入输出Tensor名称,以实现正确的数据注入,推理和结果获取。最终再把结果拷贝到Java原生数组方便后续的处理。当然在外层还封装了一个总流程方法,完成一系列操作并且返回最终的OutputVector对象。

Http请求处理(com.aiwiscal.albert.controller.AlbertVecController)

基于@RestController注解完成对外部请求的处理,调用InferALBERT(Service)推理后返回给请求端:

@RestController
public class AlbertVecController {
    private Logger logger = LoggerFactory.getLogger(getClass());

    @Autowired
    private InferALBERT inferALBERT; // 注入核心推理类Service

    @Value("${server.port}")
    private int port;

    @RequestMapping("/")
    public String runStatus(){
        return String.format("======== ALBERT Vector Service is running @ port %d =======", this.port);
    }

    @PostMapping(path="/vector") //处理post向量生成请求
    public OutputVector getVector(@RequestBody InputText inputText){
        return inferALBERT.infer(inputText);
    }
}

在配置文件(application.properties)里设定运行端口为7777(server.port=7777)。另外实现Spring Boot CommandLineRunner接口,使Spring Boot启动时对向量生成进行简单自检,位于com.aiwiscal.albert.starter.ServicePass类,若成功启动则输出日志**“ALBERT Vector Service is ready to listen …”。**

工程启动(com.aiwiscal.albert.AlbertVecApplication)

运行Spring Boot Application主应用,这里是com.aiwiscal.albert.AlbertVecApplication,查看日志输入:

 .   ____          _            __ _ _
 /\\ / ___'_ __ _ _(_)_ __  __ _ \ \ \ \
( ( )\___ | '_ | '_| | '_ \/ _` | \ \ \ \
 \\/  ___)| |_)| | | | | || (_| |  ) ) ) )
  '  |____| .__|_| |_|_| |_\__, | / / / /
 =========|_|==============|___/=/_/_/_/
 :: Spring Boot ::        (v2.2.5.RELEASE)

2020-03-28 10:17:18.167  INFO 10964 --- [           main] c.aiwiscal.albert.AlbertVecApplication   : Starting AlbertVecApplication on LAPTOP-MVOM84AD with PID 10964 (E:\IdeaProjects\albert-vec\target\classes started by Wenhan in E:\IdeaProjects\albert-vec)
2020-03-28 10:17:18.171  INFO 10964 --- [           main] c.aiwiscal.albert.AlbertVecApplication   : No active profile set, falling back to default profiles: default
2020-03-28 10:17:18.939  INFO 10964 --- [           main] o.s.b.w.embedded.tomcat.TomcatWebServer  : Tomcat initialized with port(s): 7777 (http)
2020-03-28 10:17:18.946  INFO 10964 --- [           main] o.apache.catalina.core.StandardService   : Starting service [Tomcat]
2020-03-28 10:17:18.946  INFO 10964 --- [           main] org.apache.catalina.core.StandardEngine  : Starting Servlet engine: [Apache Tomcat/9.0.31]
2020-03-28 10:17:19.008  INFO 10964 --- [           main] o.a.c.c.C.[Tomcat].[localhost].[/]       : Initializing Spring embedded WebApplicationContext
2020-03-28 10:17:19.009  INFO 10964 --- [           main] o.s.web.context.ContextLoader            : Root WebApplicationContext: initialization completed in 782 ms
2020-03-28 10:17:20.094609: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
2020-03-28 10:17:20.103  INFO 10964 --- [           main] com.aiwiscal.albert.model.LoadALBERT     : ALBERT checkpoint loaded @ albert-model/albert_tiny_zh_google.pb, vector dimension - 312, maxSupportLen - 510
2020-03-28 10:17:20.112  INFO 10964 --- [           main] com.aiwiscal.albert.model.LoadVocab      : ALBERT vocab loaded. total number - 21128 
2020-03-28 10:17:20.207  INFO 10964 --- [           main] o.s.s.concurrent.ThreadPoolTaskExecutor  : Initializing ExecutorService 'applicationTaskExecutor'
2020-03-28 10:17:20.333  INFO 10964 --- [           main] o.s.b.w.embedded.tomcat.TomcatWebServer  : Tomcat started on port(s): 7777 (http) with context path ''
2020-03-28 10:17:20.336  INFO 10964 --- [           main] c.aiwiscal.albert.AlbertVecApplication   : Started AlbertVecApplication in 2.562 seconds (JVM running for 3.697)
2020-03-28 10:17:20.339  INFO 10964 --- [           main] com.aiwiscal.albert.service.InferALBERT  : Raw Input: Text - "你好 世 界, 世界你好!", ValidLength - 5 
2020-03-28 10:17:20.339  INFO 10964 --- [           main] com.aiwiscal.albert.service.InferALBERT  : Validated Input: Text - "你好世界,", ValidLength - 5 
2020-03-28 10:17:21.756  INFO 10964 --- [           main] com.aiwiscal.albert.service.InferALBERT  : ALBERT vector generation finished - time cost: 1417 ms. 
2020-03-28 10:17:21.757  INFO 10964 --- [           main] com.aiwiscal.albert.starter.ServicePass  : ALBERT Vector Service is ready to listen ...

看到最后的"2020-03-28 10:17:21.757 INFO 10964 — [ main] com.aiwiscal.albert.starter.ServicePass : ALBERT Vector Service is ready to listen …",说明已经成功启动了,在浏览器里输入127.0.0.1:7777回车:

get.PNG

小结

以上大致说明了Spring Boot工程的开发思路和流程,总体还是比较简单的,下一步会在Python中对我们启动的上述向量服务进行请求,进行应用示例。

Python支持工程开源代码:https://github.com/Aiwiscal/albert-vec-support
Java主工程开源代码:https://github.com/Aiwiscal/albert-vec

喜欢请给star哦~

发布了30 篇原创文章 · 获赞 205 · 访问量 9万+

猜你喜欢

转载自blog.csdn.net/qq_15746879/article/details/105166300