├── easybert ├── VERSION.txt ├── __init__.py ├── __main__.py └── bert.py ├── setup.cfg ├── MANIFEST.in ├── requirements.txt ├── requirements-gpu.txt ├── LICENSE.txt ├── setup.py ├── .gitignore ├── docker ├── cpu │ └── Dockerfile └── gpu │ └── Dockerfile ├── src └── main │ └── java │ └── com │ └── robrua │ └── nlp │ └── bert │ ├── Tokenizer.java │ ├── WordpieceTokenizer.java │ ├── FullTokenizer.java │ ├── BasicTokenizer.java │ └── Bert.java ├── README.md └── pom.xml /easybert/VERSION.txt: -------------------------------------------------------------------------------- 1 | 1.0.4 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore=E501 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include easybert/VERSION.txt 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | flake8 2 | numpy 3 | tensorflow==1.15.2 4 | tensorflow-hub==0.4.0 5 | bert-tensorflow==1.0.1 6 | click 7 | -------------------------------------------------------------------------------- /requirements-gpu.txt: -------------------------------------------------------------------------------- 1 | flake8 2 | numpy 3 | tensorflow-gpu==1.15.2 4 | tensorflow-hub==0.4.0 5 | bert-tensorflow==1.0.1 6 | click 7 | -------------------------------------------------------------------------------- /easybert/__init__.py: -------------------------------------------------------------------------------- 1 | import pkg_resources 2 | 3 | from .bert import Bert 4 | 5 | 6 | __version__ = pkg_resources.resource_string("easybert", "VERSION.txt").decode("UTF-8").strip() 7 | 8 | 9 | __all__ = ["__version__", "Bert"]  # __all__ entries must be strings so `from easybert import *` works 10 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2019 Rob Rua 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from pathlib import Path 3 | import os 4 | 5 | from setuptools import setup, find_packages 6 | 7 | 8 | install_requires = [ 9 | "numpy", 10 | "tensorflow-hub==0.4.0", 11 | "bert-tensorflow==1.0.1", 12 | "click" 13 | ] 14 | 15 | # Hacky check for whether CUDA is installed 16 | has_cuda = any("CUDA" in name.split("_") for name in os.environ.keys()) 17 | install_requires.append("tensorflow-gpu==1.13.1" if has_cuda else "tensorflow==1.13.1") 18 | 19 | version_file = Path(__file__).parent.joinpath("easybert", "VERSION.txt") 20 | version = version_file.read_text(encoding="UTF-8").strip() 21 | 22 | setup( 23 | name="easybert", 24 | version=version, 25 | url="https://github.com/robrua/easy-bert", 26 | author="Rob Rua", 27 | author_email="robertrua@gmail.com", 28 | description="A Dead Simple BERT API (https://github.com/google-research/bert)", 29 | keywords=["BERT", "Natural Language Processing", "NLP", "Language Model", "Language Models", "Machine Learning", "ML", "TensorFlow", "Embeddings", "Word Embeddings", "Sentence Embeddings"], 30 | classifiers=[ 31 | "Development Status :: 4 - Beta", 32 | "Intended Audience :: Developers", 33 | "License :: OSI Approved :: MIT License", 34 | "Operating System :: OS Independent", 35 | "Programming Language :: Python :: 3" 36 | ], 37 | license="MIT", 38 | packages=find_packages(), 39 | entry_points={"console_scripts": ["bert=easybert.__main__:_main"]}, 40 | zip_safe=True, 41 | install_requires=install_requires, 42 | include_package_data=True 43 | ) 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Windows git error dumps 37 | *sh.exe.stackdump 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *,cover 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | 56 | # Sphinx documentation 57 | docs/ 58 | 59 | # PyBuilder 60 | target/ 61 | 62 | # PyCharm 63 | .idea/ 64 | 65 | # Compiled class file 66 | *.class 67 | 68 | # Log file 69 | *.log 70 | 71 | # BlueJ files 72 | *.ctxt 73 | 74 | # Mobile Tools for Java (J2ME) 75 | .mtj.tmp/ 76 | 77 | # Package Files # 78 | *.jar 79 | *.war 80 | *.ear 81 | *.zip 82 | *.tar.gz 83 | *.rar 84 | 85 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 86 | hs_err_pid* 87 | 88 | # Eclipse & build artifacts 89 | *.classpath 90 | *.project 91 | *.settings 92 | *target/ 93 | /.metadata/ 94 | *dependency-reduced-pom.xml 95 | 96 | # Dockerfile copied during build 97 | /Dockerfile 98 | -------------------------------------------------------------------------------- /docker/cpu/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 as build 2 | 3 | 4 | MAINTAINER Rob Rua 5 | 6 | 7 | ENV DEBIAN_FRONTEND noninteractive 8 | ENV LC_ALL C.UTF-8 9 | ENV LANG C.UTF-8 10 | 11 | 12 | # Install utilities 13 | RUN apt-get update --fix-missing && \ 14 | apt-get install -y wget bzip2 ca-certificates curl git && \ 15 | apt-get clean && \ 16 | rm -rf /var/lib/apt/lists/* 17 | 18 | 19 | # Anaconda home setup 20 | ENV CONDA_HOME /opt/conda 21 | ENV PATH $CONDA_HOME/bin:$PATH 22 | 23 | 24 | # Python version 25 | ARG python_version=3.6 26 | 27 | 28 | # Install Miniconda python 29 | RUN wget --quiet https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/conda.sh && \ 30 | /bin/bash ~/conda.sh -b -p $CONDA_HOME && \ 31 | rm ~/conda.sh && \ 32 | conda install -y python=${python_version} && \ 33 | $CONDA_HOME/bin/conda clean -tipsy 34 | 35 | 36 | # Install easy-bert 37 | ADD requirements.txt /opt/easy-bert/requirements.txt 38 | RUN pip install -r /opt/easy-bert/requirements.txt 39 | 40 | ADD easybert /opt/easy-bert/easybert 41 | ADD MANIFEST.in /opt/easy-bert/MANIFEST.in 42 | ADD setup.py /opt/easy-bert/setup.py 43 | 44 | WORKDIR /opt/easy-bert 45 | RUN python setup.py install 46 | 47 | 48 | # Use multi-stage build to minimize image size 49 | FROM ubuntu:18.04 50 | 51 | 52 | MAINTAINER Rob Rua 53 | 54 | 55 | ENV DEBIAN_FRONTEND noninteractive 56 | ENV LC_ALL C.UTF-8 57 | ENV LANG C.UTF-8 58 | 59 | 60 | # Anaconda home setup 61 | ENV CONDA_HOME /opt/conda 62 | ENV PATH $CONDA_HOME/bin:$PATH 63 | 64 | 65 | COPY --from=build $CONDA_HOME $CONDA_HOME 66 | 67 | 68 | ENTRYPOINT ["bert"] 69 | -------------------------------------------------------------------------------- /docker/gpu/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04 as build 2 | 3 | 4 | MAINTAINER Rob Rua 5 | 6 | 7 | ENV DEBIAN_FRONTEND noninteractive 8 | ENV LC_ALL C.UTF-8 9 | ENV LANG C.UTF-8 10 | 11 | 12 | # Install utilities 13 | RUN apt-get update --fix-missing && \ 14 | apt-get install -y wget bzip2 ca-certificates curl git && \ 15 | apt-get clean && \ 16 | rm -rf /var/lib/apt/lists/* 17 | 18 | 19 | # Anaconda home setup 20 | ENV CONDA_HOME /opt/conda 21 | ENV 
PATH $CONDA_HOME/bin:$PATH 22 | 23 | 24 | # Python version 25 | ARG python_version=3.6 26 | 27 | 28 | # Install Miniconda python 29 | RUN wget --quiet https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/conda.sh && \ 30 | /bin/bash ~/conda.sh -b -p $CONDA_HOME && \ 31 | rm ~/conda.sh && \ 32 | conda install -y python=${python_version} && \ 33 | $CONDA_HOME/bin/conda clean -tipsy 34 | 35 | 36 | # Install easy-bert 37 | ADD requirements-gpu.txt /opt/easy-bert/requirements.txt 38 | RUN pip install -r /opt/easy-bert/requirements.txt 39 | 40 | ADD easybert /opt/easy-bert/easybert 41 | ADD MANIFEST.in /opt/easy-bert/MANIFEST.in 42 | ADD setup.py /opt/easy-bert/setup.py 43 | 44 | WORKDIR /opt/easy-bert 45 | RUN python setup.py install 46 | 47 | 48 | # Use multi-stage build to minimize image size 49 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04 50 | 51 | 52 | MAINTAINER Rob Rua 53 | 54 | 55 | ENV DEBIAN_FRONTEND noninteractive 56 | ENV LC_ALL C.UTF-8 57 | ENV LANG C.UTF-8 58 | 59 | 60 | # Anaconda home setup 61 | ENV CONDA_HOME /opt/conda 62 | ENV PATH $CONDA_HOME/bin:$PATH 63 | 64 | 65 | COPY --from=build $CONDA_HOME $CONDA_HOME 66 | 67 | 68 | ENTRYPOINT ["bert"] 69 | -------------------------------------------------------------------------------- /src/main/java/com/robrua/nlp/bert/Tokenizer.java: -------------------------------------------------------------------------------- 1 | package com.robrua.nlp.bert; 2 | 3 | import java.util.Arrays; 4 | import java.util.Iterator; 5 | import java.util.List; 6 | import java.util.stream.Stream; 7 | 8 | import com.google.common.collect.Lists; 9 | 10 | /** 11 | * A tokenizer that converts text sequences into tokens or sub-tokens for BERT to use 12 | * 13 | * @author Rob Rua (https://github.com/robrua) 14 | * @version 1.0.3 15 | * @since 1.0.3 16 | */ 17 | public abstract class Tokenizer { 18 | /** 19 | * Splits a sequence into tokens based on whitespace 20 | * 21 | * @param sequence 22 | * the sequence to split 23 | * @return a stream of the tokens from the stream that were separated by whitespace 24 | * @since 1.0.3 25 | */ 26 | protected static Stream whitespaceTokenize(final String sequence) { 27 | return Arrays.stream(sequence.trim().split("\\s+")); 28 | } 29 | 30 | /** 31 | * Tokenizes a multiple sequences 32 | * 33 | * @param sequences 34 | * the sequences to tokenize 35 | * @return the tokens in the sequences, in the order the {@link java.lang.Iterable} provided them 36 | * @since 1.0.3 37 | */ 38 | public String[][] tokenize(final Iterable sequences) { 39 | final List list = Lists.newArrayList(sequences); 40 | return tokenize(list.toArray(new String[list.size()])); 41 | } 42 | 43 | /** 44 | * Tokenizes a multiple sequences 45 | * 46 | * @param sequences 47 | * the sequences to tokenize 48 | * @return the tokens in the sequences, in the order the {@link java.util.Iterator} provided them 49 | * @since 1.0.3 50 | */ 51 | public String[][] tokenize(final Iterator sequences) { 52 | final List list = Lists.newArrayList(sequences); 53 | return tokenize(list.toArray(new String[list.size()])); 54 | } 55 | 56 | /** 57 | * Tokenizes a single sequence 58 | * 59 | * @param sequence 60 | * the sequence to tokenize 61 | * @return the tokens in the sequence 62 | * @since 1.0.3 63 | */ 64 | public abstract String[] tokenize(String sequence); 65 | 66 | /** 67 | * Tokenizes a multiple sequences 68 | * 69 | * @param sequences 70 | * the sequences to tokenize 71 | * @return the tokens in the sequences, in the order they were provided 72 | * 
@since 1.0.3 73 | */ 74 | public abstract String[][] tokenize(String... sequences); 75 | } 76 | -------------------------------------------------------------------------------- /src/main/java/com/robrua/nlp/bert/WordpieceTokenizer.java: -------------------------------------------------------------------------------- 1 | package com.robrua.nlp.bert; 2 | 3 | import java.util.Arrays; 4 | import java.util.Map; 5 | import java.util.stream.Stream; 6 | 7 | /** 8 | * A port of the BERT WordpieceTokenizer in the BERT GitHub Repository. 9 | * 10 | * The WordpieceTokenizer processes tokens from the {@link com.robrua.nlp.bert.BasicTokenizer} into sub-tokens - parts of words that compose BERT's vocabulary. 11 | * These tokens can then be converted into the inputIds that the BERT model accepts. 12 | * 13 | * @author Rob Rua (https://github.com/robrua) 14 | * @version 1.0.3 15 | * @since 1.0.3 16 | * 17 | * @see The Python tokenization code this is ported from 18 | */ 19 | public class WordpieceTokenizer extends Tokenizer { 20 | private static final int DEFAULT_MAX_CHARACTERS_PER_WORD = 200; 21 | private static final String DEFAULT_UNKNOWN_TOKEN = "[UNK]"; 22 | 23 | private final int maxCharactersPerWord; 24 | private final String unknownToken; 25 | private final Map vocabulary; 26 | 27 | /** 28 | * Creates a BERT {@link com.robrua.nlp.bert.WordpieceTokenizer} 29 | * 30 | * @param vocabulary 31 | * a mapping from sub-tokens in the BERT vocabulary to their inputIds 32 | * @since 1.0.3 33 | */ 34 | public WordpieceTokenizer(final Map vocabulary) { 35 | this.vocabulary = vocabulary; 36 | unknownToken = DEFAULT_UNKNOWN_TOKEN; 37 | maxCharactersPerWord = DEFAULT_MAX_CHARACTERS_PER_WORD; 38 | } 39 | 40 | /** 41 | * Creates a BERT {@link com.robrua.nlp.bert.WordpieceTokenizer} 42 | * 43 | * @param vocabulary 44 | * a mapping from sub-tokens in the BERT vocabulary to their inputIds 45 | * @param unknownToken 46 | * the sub-token to use when an unrecognized or too-long token is encountered 47 | * @param maxCharactersPerToken 48 | * the maximum number of characters allowed in a token to be sub-tokenized 49 | * @since 1.0.3 50 | */ 51 | public WordpieceTokenizer(final Map vocabulary, final String unknownToken, final int maxCharactersPerToken) { 52 | this.vocabulary = vocabulary; 53 | this.unknownToken = unknownToken; 54 | maxCharactersPerWord = maxCharactersPerToken; 55 | } 56 | 57 | private Stream splitToken(final String token) { 58 | final char[] characters = token.toCharArray(); 59 | if(characters.length > maxCharactersPerWord) { 60 | return Stream.of(unknownToken); 61 | } 62 | 63 | final Stream.Builder subtokens = Stream.builder(); 64 | int start = 0; 65 | int end; 66 | while(start < characters.length) { 67 | end = characters.length; 68 | boolean found = false; 69 | while(start < end) { 70 | final String substring = (start > 0 ? "##" : "") + String.valueOf(characters, start, end - start); 71 | if(vocabulary.containsKey(substring)) { 72 | subtokens.accept(substring); 73 | start = end; 74 | found = true; 75 | break; 76 | } 77 | end--; 78 | } 79 | if(!found) { 80 | subtokens.accept(unknownToken); 81 | break; 82 | } 83 | start = end; 84 | } 85 | return subtokens.build(); 86 | } 87 | 88 | @Override 89 | public String[] tokenize(final String sequence) { 90 | return whitespaceTokenize(sequence) 91 | .flatMap(this::splitToken) 92 | .toArray(String[]::new); 93 | } 94 | 95 | @Override 96 | public String[][] tokenize(final String... 
sequences) { 97 | return Arrays.stream(sequences) 98 | .map((final String sequence) -> whitespaceTokenize(sequence).toArray(String[]::new)) 99 | .map((final String[] tokens) -> Arrays.stream(tokens) 100 | .flatMap(this::splitToken) 101 | .toArray(String[]::new)) 102 | .toArray(String[][]::new); 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /src/main/java/com/robrua/nlp/bert/FullTokenizer.java: -------------------------------------------------------------------------------- 1 | package com.robrua.nlp.bert; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.File; 5 | import java.io.IOException; 6 | import java.net.URISyntaxException; 7 | import java.nio.charset.Charset; 8 | import java.nio.file.Files; 9 | import java.nio.file.Path; 10 | import java.nio.file.Paths; 11 | import java.util.Arrays; 12 | import java.util.HashMap; 13 | import java.util.Map; 14 | import java.util.stream.Stream; 15 | 16 | import com.google.common.io.Resources; 17 | 18 | /** 19 | * A port of the BERT FullTokenizer in the BERT GitHub Repository. 20 | * 21 | * It's used to segment input sequences into the BERT tokens that exist in the model's vocabulary. These tokens are later converted into inputIds for the model. 22 | * 23 | * It basically just feeds sequences to the {@link com.robrua.nlp.bert.BasicTokenizer} then passes those results to the 24 | * {@link com.robrua.nlp.bert.WordpieceTokenizer} 25 | * 26 | * @author Rob Rua (https://github.com/robrua) 27 | * @version 1.0.3 28 | * @since 1.0.3 29 | * 30 | * @see The Python tokenization code this is ported from 31 | */ 32 | public class FullTokenizer extends Tokenizer { 33 | private static final boolean DEFAULT_DO_LOWER_CASE = false; 34 | 35 | private static Map loadVocabulary(final Path file) { 36 | final Map vocabulary = new HashMap<>(); 37 | try(BufferedReader reader = Files.newBufferedReader(file, Charset.forName("UTF-8"))) { 38 | int index = 0; 39 | String line; 40 | while((line = reader.readLine()) != null) { 41 | vocabulary.put(line.trim(), index++); 42 | } 43 | } catch(final IOException e) { 44 | throw new RuntimeException(e); 45 | } 46 | return vocabulary; 47 | } 48 | 49 | private static Path toPath(final String resource) { 50 | try { 51 | return Paths.get(Resources.getResource(resource).toURI()); 52 | } catch(final URISyntaxException e) { 53 | throw new RuntimeException(e); 54 | } 55 | } 56 | 57 | private final BasicTokenizer basic; 58 | private final Map vocabulary; 59 | private final WordpieceTokenizer wordpiece; 60 | 61 | /** 62 | * Creates a BERT {@link com.robrua.nlp.bert.FullTokenizer} 63 | * 64 | * @param vocabulary 65 | * the BERT vocabulary file to use for tokenization 66 | * @since 1.0.3 67 | */ 68 | public FullTokenizer(final File vocabulary) { 69 | this(Paths.get(vocabulary.toURI()), DEFAULT_DO_LOWER_CASE); 70 | } 71 | 72 | /** 73 | * Creates a BERT {@link com.robrua.nlp.bert.FullTokenizer} 74 | * 75 | * @param vocabulary 76 | * the BERT vocabulary file to use for tokenization 77 | * @param doLowerCase 78 | * whether to convert sequences to lower case during tokenization 79 | * @since 1.0.3 80 | */ 81 | public FullTokenizer(final File vocabulary, final boolean doLowerCase) { 82 | this(Paths.get(vocabulary.toURI()), doLowerCase); 83 | } 84 | 85 | /** 86 | * Creates a BERT {@link com.robrua.nlp.bert.FullTokenizer} 87 | * 88 | * @param vocabularyPath 89 | * the path to the BERT vocabulary file to use for tokenization 90 | * @since 1.0.3 91 | */ 92 | public FullTokenizer(final Path 
vocabularyPath) { 93 | this(vocabularyPath, DEFAULT_DO_LOWER_CASE); 94 | } 95 | 96 | /** 97 | * Creates a BERT {@link com.robrua.nlp.bert.FullTokenizer} 98 | * 99 | * @param vocabularyPath 100 | * the path to the BERT vocabulary file to use for tokenization 101 | * @param doLowerCase 102 | * whether to convert sequences to lower case during tokenization 103 | * @since 1.0.3 104 | */ 105 | public FullTokenizer(final Path vocabularyPath, final boolean doLowerCase) { 106 | vocabulary = loadVocabulary(vocabularyPath); 107 | basic = new BasicTokenizer(doLowerCase); 108 | wordpiece = new WordpieceTokenizer(vocabulary); 109 | } 110 | 111 | /** 112 | * Creates a BERT {@link com.robrua.nlp.bert.FullTokenizer} 113 | * 114 | * @param vocabularyResource 115 | * the resource path to the BERT vocabulary file to use for tokenization 116 | * @since 1.0.3 117 | */ 118 | public FullTokenizer(final String vocabularyResource) { 119 | this(toPath(vocabularyResource), DEFAULT_DO_LOWER_CASE); 120 | } 121 | 122 | /** 123 | * Creates a BERT {@link com.robrua.nlp.bert.FullTokenizer} 124 | * 125 | * @param vocabularyResource 126 | * the resource path to the BERT vocabulary file to use for tokenization 127 | * @param doLowerCase 128 | * whether to convert sequences to lower case during tokenization 129 | * @since 1.0.3 130 | */ 131 | public FullTokenizer(final String vocabularyResource, final boolean doLowerCase) { 132 | this(toPath(vocabularyResource), doLowerCase); 133 | } 134 | 135 | /** 136 | * Converts BERT sub-tokens into their inputIds 137 | * 138 | * @param tokens 139 | * the tokens to convert 140 | * @return the inputIds for the tokens 141 | * @since 1.0.3 142 | */ 143 | public int[] convert(final String[] tokens) { 144 | return Arrays.stream(tokens).mapToInt(vocabulary::get).toArray(); 145 | } 146 | 147 | @Override 148 | public String[] tokenize(final String sequence) { 149 | return Arrays.stream(wordpiece.tokenize(basic.tokenize(sequence))) 150 | .flatMap(Stream::of) 151 | .toArray(String[]::new); 152 | } 153 | 154 | @Override 155 | public String[][] tokenize(final String... sequences) { 156 | return Arrays.stream(basic.tokenize(sequences)) 157 | .map((final String[] tokens) -> Arrays.stream(wordpiece.tokenize(tokens)) 158 | .flatMap(Stream::of) 159 | .toArray(String[]::new)) 160 | .toArray(String[][]::new); 161 | } 162 | } 163 | -------------------------------------------------------------------------------- /easybert/__main__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Runs a BERT model from the command line 3 | """ 4 | from typing import ContextManager 5 | from contextlib import contextmanager 6 | from pathlib import Path 7 | import os 8 | 9 | import numpy as np 10 | import click 11 | 12 | from .bert import Bert, _DEFAULT_MAX_SEQUENCE_LENGTH, _DEFAULT_PER_TOKEN_EMBEDDING 13 | from . import __version__ 14 | 15 | 16 | @contextmanager 17 | def _gpu(gpu: bool) -> ContextManager: 18 | """ 19 | A contextmanager for controlling the visibility of CUDA devices. 
20 | Allows for running on CPU or GPU on devices which support both 21 | 22 | Args: 23 | gpu (bool): whether to use the GPU 24 | """ 25 | if gpu: 26 | yield 27 | else: 28 | try: 29 | visible_devices = os.environ["CUDA_VISIBLE_DEVICES"] 30 | except KeyError: 31 | visible_devices = None 32 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 33 | yield 34 | if visible_devices: 35 | os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices 36 | else: 37 | del os.environ["CUDA_VISIBLE_DEVICES"] 38 | 39 | 40 | @contextmanager 41 | def _errors_only(activate: bool) -> ContextManager: 42 | """ 43 | A contextmanager for stopping TensorFlow from spamming the console with non-errors 44 | 45 | Args: 46 | activate (bool): whether to restrict TensorFlow logging to errors 47 | """ 48 | if not activate: 49 | yield 50 | else: 51 | try: 52 | log_level = os.environ["TF_CPP_MIN_LOG_LEVEL"] 53 | except KeyError: 54 | log_level = None 55 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" 56 | yield 57 | if log_level: 58 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = log_level 59 | else: 60 | del os.environ["TF_CPP_MIN_LOG_LEVEL"] 61 | 62 | 63 | @click.group(help="Run a pretrained BERT model") 64 | @click.version_option(version=__version__) 65 | def _main() -> None: 66 | pass 67 | 68 | 69 | _DEFAULT_ENCODING = "UTF-8" 70 | _DEFAULT_HUB_MODEL = "https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1" 71 | _DEFAULT_GPU = False 72 | _DEFAULT_VERBOSE = False 73 | 74 | 75 | @_main.command(name="embed", help="Gets BERT embeddings of provided data") 76 | @click.option("-s", "--sequence", default=None, type=str, help="the sequence to embed") 77 | @click.option("-i", "--input", default=None, type=str, help="the path to a .txt file containing sequences to embed, one per line") 78 | @click.option("-e", "--encoding", default=_DEFAULT_ENCODING, help="the text encoding of the input file provided by (-i/--input)", show_default=True) 79 | @click.option("-o", "--output", default=None, type=str, help="the path to put the resuling [default: print the embeddings to console]", show_default=False) 80 | @click.option("-m", "--model", default=None, type=str, help="the path to the TensorFlow saved model to use [default: use a model from TensorFlow Hub (-h/--hub-model)]", show_default=False) 81 | @click.option("-t/-p", "--tokens/--pooled", default=_DEFAULT_PER_TOKEN_EMBEDDING, help="whether to return per-token embeddings or pooled embeddings for the full sequences [default: {}]".format("tokens" if _DEFAULT_PER_TOKEN_EMBEDDING else "pooled"), show_default=False) 82 | @click.option("-h", "--hub-model", default=_DEFAULT_HUB_MODEL, help="the url to the TensorFlow Hub BERT model to use", show_default=True) 83 | @click.option("-l", "--max-sequence-length", default=_DEFAULT_MAX_SEQUENCE_LENGTH, help="the max sequence length that a model initialized from TensorFlow Hub should allow, if one is being used", show_default=True) 84 | @click.option("-g/-c", "--gpu/--cpu", default=_DEFAULT_GPU, help="whether to use the gpu [default: {}]".format("gpu" if _DEFAULT_GPU else "cpu"), show_default=False) 85 | @click.option("-v/-q", "--verbose/--quiet", default=_DEFAULT_VERBOSE, help="whether to log verbose TensorFlow output [default: {}]".format("verbose" if _DEFAULT_VERBOSE else "quiet"), show_default=False) 86 | def _embed(sequence: str = None, 87 | input: str = None, 88 | encoding: str = _DEFAULT_ENCODING, 89 | output: str = None, 90 | model: str = None, 91 | tokens: bool = _DEFAULT_PER_TOKEN_EMBEDDING, 92 | hub_model: str = _DEFAULT_HUB_MODEL, 93 | max_sequence_length: int = 
_DEFAULT_MAX_SEQUENCE_LENGTH, 94 | gpu: bool = _DEFAULT_GPU, 95 | verbose: bool = _DEFAULT_VERBOSE) -> None: 96 | # Check inputs 97 | if sequence is None and input is None: 98 | print("Error: Missing option \"-s\" / \"--sequence\" or \"-i\" / \"--input\". Please include a sequence or input file of sequences to embed.") 99 | exit(1) 100 | if sequence is not None and input is not None: 101 | print("Error: Redundant options \"-s\" / \"--sequence\" and \"-i\" / \"--input\". Only one of these options should be provided.") 102 | exit(1) 103 | 104 | # Get sequences to embed 105 | if sequence is not None: 106 | sequences = sequence 107 | elif input is not None: 108 | input = Path(input) 109 | with input.open("r", encoding=encoding) as in_file: 110 | sequences = [sequence.strip() for sequence in in_file] 111 | 112 | with _errors_only(not verbose), _gpu(gpu): 113 | # Load model 114 | if model is not None: 115 | bert = Bert.load(path=model) 116 | else: 117 | bert = Bert(tf_hub_url=hub_model, max_sequence_length=max_sequence_length) 118 | 119 | # Embed 120 | embeddings = bert.embed(sequences=sequences, per_token=tokens) 121 | 122 | # Output embeddings 123 | if output is not None: 124 | np.save(output, embeddings, allow_pickle=False) 125 | else: 126 | print(embeddings) 127 | 128 | 129 | @_main.command(name="download", help="Downloads a TensorFlow Hub BERT model and converts it into a TensorFlow saved model") 130 | @click.option("-m", "--model", required=True, type=str, help="the path to save the BERT model to") 131 | @click.option("-h", "--hub-model", default=_DEFAULT_HUB_MODEL, help="the url to the TensorFlow Hub BERT model to use", show_default=True) 132 | @click.option("-l", "--max-sequence-length", default=_DEFAULT_MAX_SEQUENCE_LENGTH, help="the max sequence length that the model should allow", show_default=True) 133 | @click.option("-o/-s", "--overwrite/--safe", default=False, help="whether to overwrite the model directory if there's already a file or directory ther [default: safe]", show_default=False) 134 | @click.option("-g/-c", "--gpu/--cpu", default=_DEFAULT_GPU, help="whether to use the gpu [default: {}]".format("gpu" if _DEFAULT_GPU else "cpu"), show_default=False) 135 | @click.option("-v/-q", "--verbose/--quiet", default=_DEFAULT_VERBOSE, help="whether to log verbose TensorFlow output [default: {}]".format("verbose" if _DEFAULT_VERBOSE else "quiet"), show_default=False) 136 | def _download(model: str, 137 | hub_model: str = _DEFAULT_HUB_MODEL, 138 | max_sequence_length: int = _DEFAULT_MAX_SEQUENCE_LENGTH, 139 | overwrite: bool = False, 140 | gpu: bool = _DEFAULT_GPU, 141 | verbose: bool = _DEFAULT_VERBOSE) -> None: 142 | with _errors_only(not verbose), _gpu(gpu): 143 | bert = Bert(tf_hub_url=hub_model, max_sequence_length=max_sequence_length) 144 | bert.save(path=model, overwrite=overwrite) 145 | 146 | 147 | if __name__ == "__main__": 148 | _main(prog_name="bert") 149 | -------------------------------------------------------------------------------- /src/main/java/com/robrua/nlp/bert/BasicTokenizer.java: -------------------------------------------------------------------------------- 1 | package com.robrua.nlp.bert; 2 | 3 | import java.text.Normalizer; 4 | import java.util.Arrays; 5 | import java.util.Set; 6 | import java.util.stream.Stream; 7 | 8 | import com.google.common.collect.ImmutableSet; 9 | 10 | /** 11 | * A port of the BERT BasicTokenizer in the BERT GitHub Repository. 
12 | * 13 | * The BasicTokenizer is used to segment input sequences into linguistic tokens, which in most cases are words. These tokens can be fed to the 14 | * {@link com.robrua.nlp.bert.WordpieceTokenizer} to further segment them into the BERT tokens that are used for input into the model. 15 | * 16 | * @author Rob Rua (https://github.com/robrua) 17 | * @version 1.0.3 18 | * @since 1.0.3 19 | * 20 | * @see The Python tokenization code this is ported from 21 | */ 22 | public class BasicTokenizer extends Tokenizer { 23 | private static final Set CONTROL_CATEGORIES = ImmutableSet.of((int)Character.CONTROL, 24 | (int)Character.FORMAT, 25 | (int)Character.PRIVATE_USE, 26 | (int)Character.SURROGATE, 27 | (int)Character.UNASSIGNED); // In bert-tensorflow this is any category where the Unicode specification starts with "C" 28 | 29 | private static final Set PUNCTUATION_CATEGORIES = ImmutableSet.of((int)Character.CONNECTOR_PUNCTUATION, 30 | (int)Character.DASH_PUNCTUATION, 31 | (int)Character.END_PUNCTUATION, 32 | (int)Character.FINAL_QUOTE_PUNCTUATION, 33 | (int)Character.INITIAL_QUOTE_PUNCTUATION, 34 | (int)Character.OTHER_PUNCTUATION, 35 | (int)Character.START_PUNCTUATION); // In bert-tensorflow this is any category where the Unicode specification starts with "P" 36 | 37 | private static final Set SAFE_CONTROL_CHARACTERS = ImmutableSet.of((int)'\t', (int)'\n', (int)'\r'); 38 | private static final Set STRIP_CHARACTERS = ImmutableSet.of(0, 0xFFFD); 39 | private static final Set WHITESPACE_CHARACTERS = ImmutableSet.of((int)' ', (int)'\t', (int)'\n', (int)'\r'); 40 | 41 | private static String cleanText(final String sequence) { 42 | final StringBuilder builder = new StringBuilder(); 43 | sequence.codePoints().filter((final int codePoint) -> !STRIP_CHARACTERS.contains(codePoint) && !isControl(codePoint)) 44 | .map((final int codePoint) -> isWhitespace(codePoint) ? 
' ' : codePoint) 45 | .forEachOrdered((final int codePoint) -> builder.append(Character.toChars(codePoint))); 46 | return builder.toString(); 47 | } 48 | 49 | private static boolean isChineseCharacter(final int codePoint) { 50 | return codePoint >= 0x4E00 && codePoint <= 0x9FFF || 51 | codePoint >= 0x3400 && codePoint <= 0x4DBF || 52 | codePoint >= 0x20000 && codePoint <= 0x2A6DF || 53 | codePoint >= 0x2A700 && codePoint <= 0x2B73F || 54 | codePoint >= 0x2B740 && codePoint <= 0x2B81F || 55 | codePoint >= 0x2B820 && codePoint <= 0x2CEAF || 56 | codePoint >= 0xF900 && codePoint <= 0xFAFF || 57 | codePoint >= 0x2F800 && codePoint <= 0x2FA1F; 58 | } 59 | 60 | private static boolean isControl(final int codePoint) { 61 | return !SAFE_CONTROL_CHARACTERS.contains(codePoint) && CONTROL_CATEGORIES.contains(Character.getType(codePoint)); 62 | } 63 | 64 | private static boolean isPunctuation(final int codePoint) { 65 | return codePoint >= 33 && codePoint <= 47 || 66 | codePoint >= 58 && codePoint <= 64 || 67 | codePoint >= 91 && codePoint <= 96 || 68 | codePoint >= 123 && codePoint <= 126 || 69 | PUNCTUATION_CATEGORIES.contains(Character.getType(codePoint)); 70 | } 71 | 72 | private static boolean isWhitespace(final int codePoint) { 73 | return WHITESPACE_CHARACTERS.contains(codePoint) || Character.SPACE_SEPARATOR == Character.getType(codePoint); 74 | } 75 | 76 | private static Stream splitOnPunctuation(final String token) { 77 | final Stream.Builder stream = Stream.builder(); 78 | 79 | final StringBuilder builder = new StringBuilder(); 80 | token.codePoints().forEachOrdered((final int codePoint) -> { 81 | if(isPunctuation(codePoint)) { 82 | stream.accept(builder.toString()); 83 | builder.setLength(0); 84 | stream.accept(String.valueOf(Character.toChars(codePoint))); 85 | } else { 86 | builder.append(Character.toChars(codePoint)); 87 | } 88 | }); 89 | if(builder.length() > 0) { 90 | stream.accept(builder.toString()); 91 | } 92 | 93 | return stream.build(); 94 | } 95 | 96 | private static String stripAccents(final String token) { 97 | final StringBuilder builder = new StringBuilder(); 98 | Normalizer.normalize(token, Normalizer.Form.NFD).codePoints() 99 | .filter((final int codePoint) -> Character.NON_SPACING_MARK != Character.getType(codePoint)) 100 | .forEachOrdered((final int codePoint) -> builder.append(Character.toChars(codePoint))); 101 | return builder.toString(); 102 | } 103 | 104 | private static String tokenizeChineseCharacters(final String sequence) { 105 | final StringBuilder builder = new StringBuilder(); 106 | sequence.codePoints().forEachOrdered((final int codePoint) -> { 107 | if(isChineseCharacter(codePoint)) { 108 | builder.append(' '); 109 | builder.append(Character.toChars(codePoint)); 110 | builder.append(' '); 111 | } else { 112 | builder.append(Character.toChars(codePoint)); 113 | } 114 | }); 115 | return builder.toString(); 116 | } 117 | 118 | private final boolean doLowerCase; 119 | 120 | /** 121 | * Creates a BERT {@link com.robrua.nlp.bert.BasicTokenizer} 122 | * 123 | * @param doLowerCase 124 | * whether to convert sequences to lower case during tokenization 125 | * @since 1.0.3 126 | */ 127 | public BasicTokenizer(final boolean doLowerCase) { 128 | this.doLowerCase = doLowerCase; 129 | } 130 | 131 | private String stripAndSplit(String token) { 132 | if(doLowerCase) { 133 | token = stripAccents(token.toLowerCase()); 134 | } 135 | return String.join(" ", splitOnPunctuation(token).toArray(String[]::new)); 136 | } 137 | 138 | @Override 139 | public String[][] tokenize(final 
String... sequences) { 140 | return Arrays.stream(sequences) 141 | .map(BasicTokenizer::cleanText) 142 | .map(BasicTokenizer::tokenizeChineseCharacters) 143 | .map((final String sequence) -> whitespaceTokenize(sequence).toArray(String[]::new)) 144 | .map((final String[] tokens) -> Arrays.stream(tokens) 145 | .map(this::stripAndSplit) 146 | .flatMap(BasicTokenizer::whitespaceTokenize) 147 | .toArray(String[]::new)) 148 | .toArray(String[][]::new); 149 | } 150 | 151 | @Override 152 | public String[] tokenize(final String sequence) { 153 | return whitespaceTokenize(tokenizeChineseCharacters(cleanText(sequence))) 154 | .map(this::stripAndSplit) 155 | .flatMap(BasicTokenizer::whitespaceTokenize) 156 | .toArray(String[]::new); 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![MIT Licensed](https://img.shields.io/badge/license-MIT-green.svg)](https://github.com/robrua/easy-bert/blob/master/LICENSE.txt) 2 | [![PyPI](https://img.shields.io/pypi/v/easybert.svg)](https://pypi.org/project/easybert/) 3 | [![Maven Central](https://img.shields.io/maven-central/v/com.robrua.nlp/easy-bert.svg)](https://search.maven.org/search?q=g:com.robrua.nlp%20a:easy-bert) 4 | [![JavaDocs](https://javadoc.io/badge/com.robrua.nlp/easy-bert.svg)](https://javadoc.io/doc/com.robrua.nlp/easy-bert) 5 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.2651822.svg)](https://doi.org/10.5281/zenodo.2651822) 6 | 7 | # easy-bert 8 | easy-bert is a dead simple API for using Google's high quality [BERT](https://github.com/google-research/bert) language model in Python and Java. 9 | 10 | Currently, easy-bert is focused on getting embeddings from pre-trained BERT models in both Python and Java. Support for fine-tuning and pre-training in Python will be added in the future, as well as support for using easy-bert for other tasks besides getting embeddings. 11 | 12 | ## Python 13 | 14 | ### How To Get It 15 | easy-bert is available on [PyPI](https://pypi.org/project/easybert/). You can install with `pip install easybert` or `pip install git+https://github.com/robrua/easy-bert.git` if you want the very latest. 16 | 17 | ### Usage 18 | You can use easy-bert with pre-trained BERT models from TensorFlow Hub or from local models in the TensorFlow saved model format. 19 | 20 | To create a BERT embedder from a TensorFlow Hub model, simply instantiate a Bert object with the target tf-hub URL: 21 | 22 | ```python 23 | from easybert import Bert 24 | bert = Bert("https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1") 25 | ``` 26 | 27 | You can also load a local model in TensorFlow's saved model format using `Bert.load`: 28 | 29 | ```python 30 | from easybert import Bert 31 | bert = Bert.load("/path/to/your/model/") 32 | ``` 33 | 34 | Once you have a BERT model loaded, you can get sequence embeddings using `bert.embed`: 35 | 36 | ```python 37 | x = bert.embed("A sequence") 38 | y = bert.embed(["Multiple", "Sequences"]) 39 | ``` 40 | 41 | If you want per-token embeddings, you can set `per_token=True`: 42 | 43 | ```python 44 | x = bert.embed("A sequence", per_token=True) 45 | y = bert.embed(["Multiple", "Sequences"], per_token=True) 46 | ``` 47 | 48 | easy-bert returns BERT embeddings as numpy arrays. 49 | 50 | 51 | Every time you call `bert.embed`, a new TensorFlow session is created and used for the computation.
If you're calling `bert.embed` a lot sequentially, you can speed up your code by sharing a TensorFlow session among those calls using a `with` statement: 52 | 53 | ```python 54 | with bert: 55 | x = bert.embed("A sequence", per_token=True) 56 | y = bert.embed(["Multiple", "Sequences"], per_token=True) 57 | ``` 58 | 59 | You can save a BERT model using `bert.save`, then reload it later using `Bert.load`: 60 | 61 | ```python 62 | bert.save("/path/to/your/model/") 63 | bert = Bert.load("/path/to/your/model/") 64 | ``` 65 | 66 | ### CLI 67 | easy-bert also provides a CLI tool to conveniently compute one-off embeddings of sequences with BERT. It can also convert a TensorFlow Hub model to a saved model. 68 | 69 | Run `bert --help`, `bert embed --help` or `bert download --help` to get details about the CLI tool. 70 | 71 | ### Docker 72 | easy-bert comes with a [Docker build](https://hub.docker.com/r/robrua/easy-bert) that can be used as a base image for applications that rely on BERT embeddings or to just run the CLI tool without needing to install an environment. 73 | 74 | ## Java 75 | 76 | ### How To Get It 77 | easy-bert is available on [Maven Central](https://search.maven.org/search?q=g:com.robrua.nlp%20a:easy-bert). It is also distributed through the [releases page](https://github.com/robrua/easy-bert/releases). 78 | 79 | To add the latest easy-bert release version to your Maven project, add the dependency to your `pom.xml` dependencies section: 80 | ```xml 81 | <dependencies> 82 | <dependency> 83 | <groupId>com.robrua.nlp</groupId> 84 | <artifactId>easy-bert</artifactId> 85 | <version>1.0.3</version> 86 | </dependency> 87 | </dependencies> 88 | ``` 89 | Or, if you want to get the latest development version, add the [Sonatype Snapshot Repository](https://oss.sonatype.org/content/repositories/snapshots/) to your `pom.xml` as well: 90 | ```xml 91 | <dependencies> 92 | <dependency> 93 | <groupId>com.robrua.nlp</groupId> 94 | <artifactId>easy-bert</artifactId> 95 | <version>1.0.4-SNAPSHOT</version> 96 | </dependency> 97 | </dependencies> 98 | 99 | <repositories> 100 | <repository> 101 | <id>snapshots-repo</id> 102 | <url>https://oss.sonatype.org/content/repositories/snapshots</url> 103 | <releases> 104 | <enabled>false</enabled> 105 | </releases> 106 | <snapshots> 107 | <enabled>true</enabled> 108 | </snapshots> 109 | </repository> 110 | </repositories> 111 | ``` 112 | 113 | ### Usage 114 | You can use easy-bert with pre-trained BERT models generated with easy-bert's Python tools. You can also use pre-generated models on Maven Central. 115 | 116 | To load a model from your local filesystem, you can use: 117 | 118 | ```java 119 | try(Bert bert = Bert.load(new File("/path/to/your/model/"))) { 120 | // Embed some sequences 121 | } 122 | ``` 123 | 124 | If the model is in your classpath (e.g. if you're pulling it in via Maven), you can use: 125 | 126 | ```java 127 | try(Bert bert = Bert.load("/resource/path/to/your/model")) { 128 | // Embed some sequences 129 | } 130 | ``` 131 | 132 | Once you have a BERT model loaded, you can get sequence embeddings using `bert.embedSequence` or `bert.embedSequences`: 133 | 134 | ```java 135 | float[] embedding = bert.embedSequence("A sequence"); 136 | float[][] embeddings = bert.embedSequences("Multiple", "Sequences"); 137 | ``` 138 | 139 | If you want per-token embeddings, you can use `bert.embedTokens`: 140 | 141 | ```java 142 | float[][] embedding = bert.embedTokens("A sequence"); 143 | float[][][] embeddings = bert.embedTokens("Multiple", "Sequences"); 144 | ``` 145 | 146 | ### Pre-Generated Maven Central Models 147 | Various TensorFlow Hub BERT models are available in easy-bert format on [Maven Central](https://search.maven.org/search?q=g:com.robrua.nlp.models).
To use one in your project, add the following to your `pom.xml`, substituting one of the Artifact IDs listed below in place of `ARTIFACT-ID` in the `artifactId`: 148 | 149 | ```xml 150 | <dependencies> 151 | <dependency> 152 | <groupId>com.robrua.nlp.models</groupId> 153 | <artifactId>ARTIFACT-ID</artifactId> 154 | <version>1.0.0</version> 155 | </dependency> 156 | </dependencies> 157 | ``` 158 | 159 | Once you've pulled in the dependency, you can load the model using this code. Substitute the appropriate Resource Path from the list below in place of `RESOURCE-PATH` based on the model you added as a dependency: 160 | 161 | ```java 162 | try(Bert bert = Bert.load("RESOURCE-PATH")) { 163 | // Embed some sequences 164 | } 165 | ``` 166 | 167 | #### Available Models 168 | | Model | Languages | Layers | Embedding Size | Heads | Parameters | Artifact ID | Resource Path | 169 | | --- | --- | --- | --- | --- | --- | --- | --- | 170 | | [BERT-Base, Uncased](https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1) | English | 12 | 768 | 12 | 110M | easy-bert-uncased-L-12-H-768-A-12 [![Maven Central](https://img.shields.io/maven-central/v/com.robrua.nlp.models/easy-bert-uncased-L-12-H-768-A-12.svg)](https://search.maven.org/search?q=g:com.robrua.nlp.models%20a:easy-bert-uncased-L-12-H-768-A-12) | com/robrua/nlp/easy-bert/bert-uncased-L-12-H-768-A-12 | 171 | | [BERT-Base, Cased](https://tfhub.dev/google/bert_cased_L-12_H-768_A-12/1) | English | 12 | 768 | 12 | 110M | easy-bert-cased-L-12-H-768-A-12 [![Maven Central](https://img.shields.io/maven-central/v/com.robrua.nlp.models/easy-bert-cased-L-12-H-768-A-12.svg)](https://search.maven.org/search?q=g:com.robrua.nlp.models%20a:easy-bert-cased-L-12-H-768-A-12) | com/robrua/nlp/easy-bert/bert-cased-L-12-H-768-A-12 | 172 | | [BERT-Base, Multilingual Cased](https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1) | 104 Languages | 12 | 768 | 12 | 110M | easy-bert-multi-cased-L-12-H-768-A-12 [![Maven Central](https://img.shields.io/maven-central/v/com.robrua.nlp.models/easy-bert-multi-cased-L-12-H-768-A-12.svg)](https://search.maven.org/search?q=g:com.robrua.nlp.models%20a:easy-bert-multi-cased-L-12-H-768-A-12) | com/robrua/nlp/easy-bert/bert-multi-cased-L-12-H-768-A-12 | 173 | | [BERT-Base, Chinese](https://tfhub.dev/google/bert_chinese_L-12_H-768_A-12/1) | Chinese Simplified and Traditional | 12 | 768 | 12 | 110M | easy-bert-chinese-L-12-H-768-A-12 [![Maven Central](https://img.shields.io/maven-central/v/com.robrua.nlp.models/easy-bert-chinese-L-12-H-768-A-12.svg)](https://search.maven.org/search?q=g:com.robrua.nlp.models%20a:easy-bert-chinese-L-12-H-768-A-12) | com/robrua/nlp/easy-bert/bert-chinese-L-12-H-768-A-12 | 174 | 175 | ### Creating Your Own Models 176 | For now, easy-bert can only use pre-trained TensorFlow Hub BERT models that have been converted using the Python tools. We will be adding support for fine-tuning and pre-training new models easily, but there are no plans to support these on the Java side. You'll need to train in Python, save the model, then load it in Java (see the short round-trip sketch at the end of this README). 177 | 178 | ## Bugs 179 | If you find bugs please let us know via a pull request or issue. 180 | 181 | ## Citing easy-bert 182 | If you used easy-bert for your research, please [cite the project](https://doi.org/10.5281/zenodo.2651822).
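### Python-to-Java Round Trip
As a quick recap of the workflow described in the Usage and Creating Your Own Models sections above, here is a minimal sketch of the full round trip: pull a model from TensorFlow Hub in Python, save it as a TensorFlow saved model, and embed with it. The Hub URL is the one used earlier in this README; the model directory is a placeholder you should replace with a real path.

```python
from easybert import Bert

# Pull a pre-trained multilingual BERT model from TensorFlow Hub (same URL as in the Usage section)
bert = Bert("https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1")

# Save it as a TensorFlow saved model; "/path/to/your/model/" is a placeholder directory
bert.save("/path/to/your/model/", overwrite=True)

# Reload the saved model and embed a few sequences inside one shared TensorFlow session
bert = Bert.load("/path/to/your/model/")
with bert:
    pooled = bert.embed(["Multiple", "Sequences"])        # one pooled vector per sequence
    per_token = bert.embed("A sequence", per_token=True)  # one vector per BERT token
```

On the Java side, the same saved directory can then be loaded with `Bert.load(new File("/path/to/your/model/"))`, as shown in the Java Usage section above.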
183 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | com.robrua.nlp 8 | easy-bert 9 | 1.0.4-SNAPSHOT 10 | 11 | easy-bert 12 | A Dead Simple BERT API (https://github.com/google-research/bert) 13 | https://github.com/robrua/easy-bert 14 | 15 | 16 | 17 | robrua 18 | Rob Rua 19 | robertrua@gmail.com 20 | 21 | 22 | 23 | 24 | 25 | MIT License 26 | https://opensource.org/licenses/MIT 27 | repo 28 | 29 | 30 | 31 | 32 | scm:git:git@github.com:robrua/easy-bert.git 33 | scm:git:git@github.com:robrua/easy-bert.git 34 | git@github.com:robrua/easy-bert.git 35 | 36 | 37 | 38 | 39 | snapshots-repo 40 | https://oss.sonatype.org/content/repositories/snapshots 41 | 42 | false 43 | 44 | 45 | true 46 | 47 | 48 | 49 | 50 | 51 | UTF-8 52 | 3.8.0 53 | 3.0.1 54 | 3.1.0 55 | 3.1.0 56 | 2.5.2 57 | 2.8.2 58 | 1.6.0 59 | 1.4.10 60 | 1.6 61 | 1.6.8 62 | 1.20 63 | 3.2.1 64 | 65 | 2.10.0.pr3 66 | 27.1-jre 67 | 1.13.1 68 | 69 | robrua/easy-bert 70 | cpu 71 | docker/cpu 72 | 73 | 74 | 75 | 76 | com.fasterxml.jackson.core 77 | jackson-databind 78 | ${jackson.version} 79 | 80 | 81 | com.google.guava 82 | guava 83 | ${guava.version} 84 | 85 | 86 | org.tensorflow 87 | tensorflow 88 | ${tensorflow.version} 89 | 90 | 91 | 92 | 93 | 94 | 95 | maven-compiler-plugin 96 | ${maven-compiler.version} 97 | 98 | ${project.build.sourceEncoding} 99 | 1.8 100 | 1.8 101 | 102 | 103 | 104 | maven-source-plugin 105 | ${maven-source.version} 106 | 107 | 108 | attach-sources 109 | verify 110 | 111 | jar-no-fork 112 | 113 | 114 | 115 | 116 | 117 | maven-javadoc-plugin 118 | ${maven-javadoc.version} 119 | 120 | 121 | attach-javadocs 122 | verify 123 | 124 | jar 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | release 135 | 136 | 137 | ossrh 138 | https://oss.sonatype.org/content/repositories/snapshots 139 | 140 | 141 | ossrh 142 | https://oss.sonatype.org/service/local/staging/deploy/maven2/ 143 | 144 | 145 | 146 | 147 | 148 | 149 | maven-gpg-plugin 150 | ${maven-gpg.version} 151 | 152 | 153 | sign-artifacts 154 | verify 155 | 156 | sign 157 | 158 | 159 | 160 | 161 | 162 | org.sonatype.plugins 163 | nexus-staging-maven-plugin 164 | ${nexus-maven.version} 165 | true 166 | 167 | ossrh 168 | https://oss.sonatype.org/ 169 | true 170 | 171 | 172 | 173 | 174 | 175 | 176 | shaded 177 | 178 | 179 | 180 | org.codehaus.mojo 181 | license-maven-plugin 182 | ${license-maven.version} 183 | 184 | system,test 185 | 186 | 187 | 188 | add-third-party 189 | 190 | add-third-party 191 | 192 | 193 | 194 | 195 | 196 | org.apache.maven.plugins 197 | maven-shade-plugin 198 | ${maven-shade.version} 199 | 200 | 201 | package 202 | 203 | shade 204 | 205 | 206 | ${project.build.finalName}-jar-with-dependencies 207 | true 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | python 217 | 218 | 219 | 220 | org.apache.maven.plugins 221 | maven-install-plugin 222 | ${maven-install.version} 223 | 224 | true 225 | 226 | 227 | 228 | org.apache.maven.plugins 229 | maven-deploy-plugin 230 | ${maven-deploy.version} 231 | 232 | true 233 | 234 | 235 | 236 | org.codehaus.mojo 237 | exec-maven-plugin 238 | ${exec-maven.version} 239 | 240 | 241 | install 242 | 243 | exec 244 | 245 | install 246 | 247 | python 248 | 249 | setup.py 250 | install 251 | 252 | 253 | 254 | 255 | deploy 256 | 257 | exec 258 | 259 | deploy 260 | 261 | python 262 | 263 | setup.py 264 | sdist 265 | upload 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 
| 275 | docker 276 | 277 | 278 | 279 | org.apache.maven.plugins 280 | maven-resources-plugin 281 | ${maven-resources.version} 282 | 283 | 284 | package 285 | 286 | copy-resources 287 | 288 | 289 | 290 | 291 | ${basedir} 292 | true 293 | 294 | 295 | ${docker.directory} 296 | 297 | Dockerfile 298 | 299 | 300 | 301 | 302 | 303 | 304 | com.spotify 305 | dockerfile-maven-plugin 306 | ${dockerfile-maven.version} 307 | 308 | 309 | build 310 | package 311 | 312 | build 313 | 314 | 315 | 316 | push 317 | deploy 318 | 319 | push 320 | 321 | 322 | 323 | 324 | ${docker.image} 325 | ${docker.tag} 326 | false 327 | true 328 | 329 | 330 | 331 | 332 | 333 | 334 | gpu 335 | 336 | gpu 337 | docker/gpu 338 | 339 | 340 | 341 | org.tensorflow 342 | libtensorflow 343 | ${tensorflow.version} 344 | 345 | 346 | org.tensorflow 347 | libtensorflow_jni_gpu 348 | ${tensorflow.version} 349 | 350 | 351 | 352 | 353 | 354 | -------------------------------------------------------------------------------- /easybert/bert.py: -------------------------------------------------------------------------------- 1 | """ 2 | easy-bert is a dead simple API for using Google's high quality BERT language model (https://github.com/google-research/bert). 3 | 4 | Currently, easy-bert is focused on getting embeddings from pre-trained BERT models. Support for fine-tuning and pre-training will be added in the future, 5 | as well as support for using easy-bert for other tasks besides getting embeddings. 6 | 7 | You can use easy-bert with pretrained BERT models from TensorFlow Hub or from local models in the TensorFlow saved model format. 8 | 9 | 10 | To create a BERT embedder from a TensowFlow Hub model, simply instantiate a Bert object with the target tf-hub URL: 11 | 12 | from easybert import Bert 13 | bert = Bert("https://tfhub.dev/google/bert_multi_cased_L-12_H-768_A-12/1") 14 | 15 | You can also load a local model in TensorFlow's saved model format using Bert.load: 16 | 17 | from easybert import Bert 18 | bert = Bert.load("/path/to/your/model/") 19 | 20 | Once you have a BERT model loaded, you can get sequence embeddings using bert.embed: 21 | 22 | x = bert.embed("A sequence") 23 | y = bert.embed(["Multiple", "Sequences"]) 24 | 25 | If you want per-token embeddings, you can set per_token=True: 26 | 27 | x = bert.embed("A sequence", per_token=True) 28 | y = bert.embed(["Multiple", "Sequences"], per_token=True) 29 | 30 | easy-bert returns BERT embeddings as numpy arrays 31 | 32 | 33 | Every time you call bert.embed, a new TensorFlow session is created and used for the computation. 
If you're calling bert.embed a lot 34 | sequentially, you can speed up your code by sharing a TensorFlow session among those calls using a with statement: 35 | 36 | with bert: 37 | x = bert.embed("A sequence", per_token=True) 38 | y = bert.embed(["Multiple", "Sequences"], per_token=True) 39 | 40 | 41 | You can save a BERT model using bert.save, then reload it later using Bert.load: 42 | 43 | bert.save("/path/to/your/model/") 44 | bert = Bert.load("/path/to/your/model/") 45 | """ 46 | from typing import Union, Iterable 47 | from types import TracebackType 48 | from pathlib import Path 49 | import json 50 | 51 | from bert.tokenization import FullTokenizer 52 | from bert import run_classifier 53 | import tensorflow_hub as hub 54 | import tensorflow as tf 55 | import numpy as np 56 | 57 | 58 | _DEFAULT_MAX_SEQUENCE_LENGTH = 128 # Max number of BERT tokens in a sequence 59 | _DEFAULT_PER_TOKEN_EMBEDDING = False # Whether to return per-token embeddings or pooled embeddings for the full sequences 60 | _MODEL_DETAILS = "model.json" # The asset file in the saved model used to store parameters for serving in other languages 61 | _VOCAB_FILE = "vocab.txt" # The asset file in the saved model that stores the vocabulary for the tokenizer 62 | 63 | 64 | class Bert(object): 65 | """ 66 | A BERT model that can be used for generating high quality sentence and word embeddings easily (https://github.com/google-research/bert) 67 | 68 | Args: 69 | tf_hub_url (str): the URL to the TensorFlow Hub model to load 70 | max_sequence_length (int): the maximum number of BERT tokens allowed in an input sequence 71 | """ 72 | def __init__(self, tf_hub_url: str, max_sequence_length: int = _DEFAULT_MAX_SEQUENCE_LENGTH) -> None: 73 | self._graph = tf.Graph() 74 | self._session = None 75 | 76 | # Initialize the BERT model 77 | with tf.Session(graph=self._graph) as session: 78 | # Download module from tf-hub 79 | bert_module = hub.Module(tf_hub_url) 80 | 81 | # Get the tokenizer from the module 82 | tokenization_info = bert_module(signature="tokenization_info", as_dict=True) 83 | self._vocab_file, self._do_lower_case = session.run([tokenization_info["vocab_file"], tokenization_info["do_lower_case"]]) 84 | self._vocab_file = self._vocab_file.decode("UTF-8") 85 | self._do_lower_case = bool(self._do_lower_case) 86 | self._tokenizer = FullTokenizer(vocab_file=self._vocab_file, do_lower_case=self._do_lower_case) 87 | 88 | # Create symbolic input tensors as inputs to the model 89 | self._input_ids = tf.placeholder(name="input_ids", shape=(None, max_sequence_length), dtype=tf.int32) 90 | self._input_mask = tf.placeholder(name="input_mask", shape=(None, max_sequence_length), dtype=tf.int32) 91 | self._segment_ids = tf.placeholder(name="segment_ids", shape=(None, max_sequence_length), dtype=tf.int32) 92 | 93 | # Get the symbolic output tensors 94 | self._outputs = bert_module({ 95 | "input_ids": self._input_ids, 96 | "input_mask": self._input_mask, 97 | "segment_ids": self._segment_ids 98 | }, signature="tokens", as_dict=True) 99 | 100 | def __enter__(self) -> None: 101 | # Start a session 102 | if self._session is None: 103 | self._session = tf.Session(graph=self._graph) 104 | self._session.__enter__() 105 | self._session.run(tf.global_variables_initializer()) 106 | 107 | def __exit__(self, exc_type: type = None, exc_value: Exception = None, traceback: TracebackType = None) -> None: 108 | # Close an open session 109 | if self._session is not None: 110 | self._session.__exit__(exc_type, exc_value, traceback) 111 | self._session = None 
112 | 113 | def embed(self, sequences: Union[str, Iterable[str]], per_token: bool = _DEFAULT_PER_TOKEN_EMBEDDING) -> np.ndarray: 114 | """ 115 | Embeds a sequence or multiple sequences using the BERT model 116 | 117 | Args: 118 | sequences (Union[str, Iterable[str]]): the sequence(s) to embed 119 | per_token (bool): whether to produce an embedding per token or a pooled embedding for the whole sequence 120 | 121 | Returns: 122 | a numpy array with the embedding(s) of the sequence(s) 123 | """ 124 | single_input = isinstance(sequences, str) 125 | if single_input: 126 | sequences = [sequences] 127 | 128 | # Convert sequnces into BERT input format 129 | input_examples = [run_classifier.InputExample(guid=None, text_a=sequence, text_b=None, label=0) for sequence in sequences] 130 | input_features = run_classifier.convert_examples_to_features(input_examples, [0], self._input_ids.shape[1], self._tokenizer) 131 | 132 | # Execute the computation graph on the inputs 133 | if self._session is not None: 134 | output = self._session.run(self._outputs["sequence_output" if per_token else "pooled_output"], feed_dict={ 135 | self._input_ids: [sequence.input_ids for sequence in input_features], 136 | self._input_mask: [sequence.input_mask for sequence in input_features], 137 | self._segment_ids: [sequence.segment_ids for sequence in input_features] 138 | }) 139 | else: 140 | with tf.Session(graph=self._graph) as session: 141 | session.run(tf.global_variables_initializer()) 142 | 143 | output = session.run(self._outputs["sequence_output" if per_token else "pooled_output"], feed_dict={ 144 | self._input_ids: [sequence.input_ids for sequence in input_features], 145 | self._input_mask: [sequence.input_mask for sequence in input_features], 146 | self._segment_ids: [sequence.segment_ids for sequence in input_features] 147 | }) 148 | 149 | if single_input: 150 | output = output.reshape(output.shape[1:]) 151 | return output 152 | 153 | def save(self, path: Union[str, Path], overwrite: bool = False) -> None: 154 | """ 155 | Saves the BERT model to a directory as a TensorFlow saved model 156 | 157 | Args: 158 | path (Union[str, Path]): the directory to save the model to 159 | overwrite (bool): whether to automatically overwrite the directory if it already exists 160 | """ 161 | if isinstance(path, str): 162 | path = Path(path) 163 | 164 | if path.exists(): 165 | if not overwrite: 166 | raise ValueError("Model path already exists and overwrite was set to False") 167 | _delete(path) 168 | 169 | if self._session is not None: 170 | tf.saved_model.simple_save(self._session, str(path), inputs={ 171 | "input_ids": self._input_ids, 172 | "input_mask": self._input_mask, 173 | "segment_ids": self._segment_ids 174 | }, outputs=self._outputs) 175 | else: 176 | with tf.Session(graph=self._graph) as session: 177 | session.run(tf.global_variables_initializer()) 178 | 179 | tf.saved_model.simple_save(session, str(path), inputs={ 180 | "input_ids": self._input_ids, 181 | "input_mask": self._input_mask, 182 | "segment_ids": self._segment_ids 183 | }, outputs=self._outputs) 184 | 185 | # Save needed information to get the tokenizer and load models in other languages 186 | with path.joinpath("assets", _MODEL_DETAILS).open("w", encoding="UTF-8") as out_file: 187 | json.dump({ 188 | "doLowerCase": self._do_lower_case, 189 | "inputIds": self._input_ids.name, 190 | "inputMask": self._input_mask.name, 191 | "segmentIds": self._segment_ids.name, 192 | "pooledOutput": self._outputs["pooled_output"].name, 193 | "sequenceOutput": 
self._outputs["sequence_output"].name, 194 | "maxSequenceLength": int(self._input_ids.shape[1]) 195 | }, out_file) 196 | 197 | @classmethod 198 | def load(cls, path: Union[str, Path]) -> "Bert": 199 | """ 200 | Loads a BERT model that has been saved to a directory as a TensorFlow saved model 201 | 202 | Args: 203 | path (Union[str, Path]): the directory that contains the model 204 | 205 | Returns: 206 | the saved BERT model 207 | """ 208 | if isinstance(path, str): 209 | path = Path(path) 210 | 211 | bert = cls.__new__(cls) 212 | bert._graph = tf.Graph() 213 | bert._session = None 214 | 215 | # Load graph from disk 216 | with tf.Session(graph=bert._graph) as session: 217 | bundle = tf.saved_model.load(session, ["serve"], str(path)) 218 | 219 | # Get tokenizer parameters 220 | with path.joinpath("assets", _MODEL_DETAILS).open("r", encoding="UTF-8") as in_file: 221 | details = json.load(in_file) 222 | 223 | bert._vocab_file = str(path.joinpath("assets", _VOCAB_FILE)) 224 | bert._do_lower_case = details["doLowerCase"] 225 | bert._tokenizer = FullTokenizer(vocab_file=bert._vocab_file, do_lower_case=bert._do_lower_case) 226 | 227 | # Initialize inputs/outputs for use in bert.embed 228 | bert._input_ids = bert._graph.get_tensor_by_name(bundle.signature_def["serving_default"].inputs["input_ids"].name) 229 | bert._input_mask = bert._graph.get_tensor_by_name(bundle.signature_def["serving_default"].inputs["input_mask"].name) 230 | bert._segment_ids = bert._graph.get_tensor_by_name(bundle.signature_def["serving_default"].inputs["segment_ids"].name) 231 | bert._outputs = { 232 | "pooled_output": bert._graph.get_tensor_by_name(bundle.signature_def["serving_default"].outputs["pooled_output"].name), 233 | "sequence_output": bert._graph.get_tensor_by_name(bundle.signature_def["serving_default"].outputs["sequence_output"].name) 234 | } 235 | 236 | return bert 237 | 238 | 239 | def _delete(path: Path) -> None: 240 | """ 241 | Recursively deletes a Path regardless of whether it's a file, empty directory, or non-empty directory 242 | 243 | Args: 244 | path (Path): the path to delete 245 | """ 246 | if not path.is_dir(): 247 | path.unlink() 248 | else: 249 | for subpath in path.iterdir(): 250 | _delete(subpath) 251 | path.rmdir() 252 | -------------------------------------------------------------------------------- /src/main/java/com/robrua/nlp/bert/Bert.java: -------------------------------------------------------------------------------- 1 | package com.robrua.nlp.bert; 2 | 3 | import java.io.File; 4 | import java.io.IOException; 5 | import java.io.OutputStream; 6 | import java.net.URL; 7 | import java.nio.IntBuffer; 8 | import java.nio.file.Files; 9 | import java.nio.file.Path; 10 | import java.nio.file.Paths; 11 | import java.util.Comparator; 12 | import java.util.Iterator; 13 | import java.util.List; 14 | import java.util.zip.ZipEntry; 15 | import java.util.zip.ZipInputStream; 16 | 17 | import org.tensorflow.SavedModelBundle; 18 | import org.tensorflow.Tensor; 19 | 20 | import com.fasterxml.jackson.databind.ObjectMapper; 21 | import com.google.common.collect.Lists; 22 | import com.google.common.io.Resources; 23 | 24 | /** 25 | *

26 | * easy-bert is a dead simple API for using Google's high quality BERT language model. 27 | * 28 | * The easy-bert Java bindings allow you to run pre-trained BERT models generated with easy-bert's Python tools. You can also use pre-generated models on Maven 29 | * Central. 30 | *
31 | *
32 | *

33 | * To load a model from your local filesystem, you can use: 34 | * 35 | *

36 | *
 37 |  * {@code
 38 |  * try(Bert bert = Bert.load(new File("/path/to/your/model/"))) {
 39 |  *     // Embed some sequences
 40 |  * }
 41 |  * }
 42 |  * 
43 | *
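 *
 * Bert.load is also overloaded to accept a java.nio.file.Path directly, so the following is equivalent (a brief sketch; the path is illustrative):
 *
 * {@code
 * try(Bert bert = Bert.load(Paths.get("/path/to/your/model/"))) {
 *     // Embed some sequences
 * }
 * }
 *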
44 | * 45 | * If the model is on your classpath (e.g. if you're pulling it in via Maven), you can use: 46 | * 47 | *
48 | *
 49 |  * {@code
 50 |  * try(Bert bert = Bert.load("/resource/path/to/your/model/")) {
 51 |  *     // Embed some sequences
 52 |  * }
 53 |  * }
 54 |  * 
55 | *
56 | * 57 | * See the easy-bert GitHub Repository for information about models available via Maven Central. 58 | *
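 *
 * For reference, a Maven dependency on the bindings typically looks like the snippet below; the coordinates and version shown are illustrative assumptions, so check the repository README for the exact artifacts, including the pre-generated model artifacts:
 *
 * {@code
 * <dependency>
 *     <groupId>com.robrua.nlp</groupId>
 *     <artifactId>easy-bert</artifactId>
 *     <version>1.0.3</version>
 * </dependency>
 * }
 *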
59 | *
60 | *

61 | * Once you have a BERT model loaded, you can get sequence embeddings using {@link com.robrua.nlp.bert.Bert#embedSequence(String)}, 62 | * {@link com.robrua.nlp.bert.Bert#embedSequences(String...)}, {@link com.robrua.nlp.bert.Bert#embedSequences(Iterable)}, or 63 | * {@link com.robrua.nlp.bert.Bert#embedSequences(Iterator)}: 64 | * 65 | *

66 | *
 67 |  * {@code
 68 |  * float[] embedding = bert.embedSequence("A sequence");
 69 |  * float[][] embeddings = bert.embedSequences("Multiple", "Sequences");
 70 |  * }
 71 |  * 
72 | *
73 | * 74 | * If you want per-token embeddings, you can use {@link com.robrua.nlp.bert.Bert#embedTokens(String)}, {@link com.robrua.nlp.bert.Bert#embedTokens(String...)}, 75 | * {@link com.robrua.nlp.bert.Bert#embedTokens(Iterable)}, or {@link com.robrua.nlp.bert.Bert#embedTokens(Iterator)}: 76 | * 77 | *
78 | *
 79 |  * {@code
 80 |  * float[][] embedding = bert.embedTokens("A sequence");
 81 |  * float[][][] embeddings = bert.embedTokens("Multiple", "Sequences");
 82 |  * }
 83 |  * 
84 | *
85 | * 86 | * @author Rob Rua (https://github.com/robrua) 87 | * @version 1.0.3 88 | * @since 1.0.3 89 | * 90 | * @see The easy-bert GitHub Repository 91 | * @see Google's BERT GitHub Repository 92 | */ 93 | public class Bert implements AutoCloseable { 94 | private class Inputs implements AutoCloseable { 95 | private final Tensor inputIds, inputMask, segmentIds; 96 | 97 | public Inputs(final IntBuffer inputIds, final IntBuffer inputMask, final IntBuffer segmentIds, final int count) { 98 | this.inputIds = Tensor.create(new long[] {count, model.maxSequenceLength}, inputIds); 99 | this.inputMask = Tensor.create(new long[] {count, model.maxSequenceLength}, inputMask); 100 | this.segmentIds = Tensor.create(new long[] {count, model.maxSequenceLength}, segmentIds); 101 | } 102 | 103 | @Override 104 | public void close() { 105 | inputIds.close(); 106 | inputMask.close(); 107 | segmentIds.close(); 108 | } 109 | } 110 | 111 | private static class ModelDetails { 112 | public boolean doLowerCase; 113 | public String inputIds, inputMask, segmentIds, pooledOutput, sequenceOutput; 114 | public int maxSequenceLength; 115 | } 116 | 117 | private static final int FILE_COPY_BUFFER_BYTES = 1024 * 1024; 118 | private static final String MODEL_DETAILS = "model.json"; 119 | private static final String SEPARATOR_TOKEN = "[SEP]"; 120 | private static final String START_TOKEN = "[CLS]"; 121 | private static final String VOCAB_FILE = "vocab.txt"; 122 | 123 | /** 124 | * Loads a pre-trained BERT model from a TensorFlow saved model saved by the easy-bert Python utilities 125 | * 126 | * @param model 127 | * the model to load 128 | * @return a ready-to-use BERT model 129 | * @since 1.0.3 130 | */ 131 | public static Bert load(final File model) { 132 | return load(Paths.get(model.toURI())); 133 | } 134 | 135 | /** 136 | * Loads a pre-trained BERT model from a TensorFlow saved model saved by the easy-bert Python utilities 137 | * 138 | * @param path 139 | * the path to load the model from 140 | * @return a ready-to-use BERT model 141 | * @since 1.0.3 142 | */ 143 | public static Bert load(Path path) { 144 | path = path.toAbsolutePath(); 145 | ModelDetails model; 146 | try { 147 | model = new ObjectMapper().readValue(path.resolve("assets").resolve(MODEL_DETAILS).toFile(), ModelDetails.class); 148 | } catch(final IOException e) { 149 | throw new RuntimeException(e); 150 | } 151 | 152 | return new Bert(SavedModelBundle.load(path.toString(), "serve"), model, path.resolve("assets").resolve(VOCAB_FILE)); 153 | } 154 | 155 | /** 156 | * Loads a pre-trained BERT model from a TensorFlow saved model saved by the easy-bert Python utilities. The target resource should be in .zip format. 
157 |      *
158 |      * @param resource
159 |      *            the resource path to load the model from - should be in .zip format
160 |      * @return a ready-to-use BERT model
161 |      * @since 1.0.3
162 |      */
163 |     public static Bert load(final String resource) {
164 |         Path directory = null;
165 |         try {
166 |             // Create a temp directory to unpack the zip into
167 |             final URL model = Resources.getResource(resource);
168 |             directory = Files.createTempDirectory("easy-bert-");
169 | 
170 |             try(ZipInputStream zip = new ZipInputStream(Resources.asByteSource(model).openBufferedStream())) {
171 |                 ZipEntry entry;
172 |                 // Copy each zip entry into the temp directory
173 |                 while((entry = zip.getNextEntry()) != null) {
174 |                     final Path path = directory.resolve(entry.getName());
175 |                     if(entry.getName().endsWith("/")) {
176 |                         Files.createDirectories(path);
177 |                     } else {
178 |                         Files.createFile(path);
179 | 
180 |                         try(OutputStream output = Files.newOutputStream(path)) {
181 |                             final byte[] buffer = new byte[FILE_COPY_BUFFER_BYTES];
182 |                             int bytes;
183 |                             while((bytes = zip.read(buffer)) > 0) {
184 |                                 output.write(buffer, 0, bytes);
185 |                             }
186 |                         }
187 |                     }
188 |                     zip.closeEntry();
189 |                 }
190 |             }
191 | 
192 |             // Load a BERT model from the temp directory
193 |             return Bert.load(directory);
194 |         } catch(final IOException e) {
195 |             throw new RuntimeException(e);
196 |         } finally {
197 |             // Clean up the temp directory
198 |             if(directory != null && Files.exists(directory)) {
199 |                 try {
200 |                     Files.walk(directory)
201 |                         .sorted(Comparator.reverseOrder())
202 |                         .forEach((final Path file) -> {
203 |                             try {
204 |                                 Files.delete(file);
205 |                             } catch(final IOException e) {
206 |                                 throw new RuntimeException(e);
207 |                             }
208 |                         });
209 |                 } catch(final IOException e) {
210 |                     throw new RuntimeException(e);
211 |                 }
212 |             }
213 |         }
214 |     }
215 | 
216 |     private final SavedModelBundle bundle;
217 |     private final ModelDetails model;
218 |     private final int separatorTokenId;
219 |     private final int startTokenId;
220 |     private final FullTokenizer tokenizer;
221 | 
222 |     private Bert(final SavedModelBundle bundle, final ModelDetails model, final Path vocabulary) {
223 |         tokenizer = new FullTokenizer(vocabulary, model.doLowerCase);
224 |         this.bundle = bundle;
225 |         this.model = model;
226 | 
227 |         final int[] ids = tokenizer.convert(new String[] {START_TOKEN, SEPARATOR_TOKEN});
228 |         startTokenId = ids[0];
229 |         separatorTokenId = ids[1];
230 |     }
231 | 
232 |     @Override
233 |     public void close() {
234 |         bundle.close();
235 |     }
236 | 
237 |     /**
238 |      * Gets a pooled BERT embedding for a single sequence. Sequences are usually individual sentences, but don't have to be.
239 |      *
240 |      * @param sequence
241 |      *            the sequence to embed
242 |      * @return the pooled embedding for the sequence
243 |      * @since 1.0.3
244 |      */
245 |     public float[] embedSequence(final String sequence) {
246 |         try(Inputs inputs = getInputs(sequence)) {
247 |             final List<Tensor<?>> output = bundle.session().runner()
248 |                 .feed(model.inputIds, inputs.inputIds)
249 |                 .feed(model.inputMask, inputs.inputMask)
250 |                 .feed(model.segmentIds, inputs.segmentIds)
251 |                 .fetch(model.pooledOutput)
252 |                 .run();
253 | 
254 |             try(Tensor<?> embedding = output.get(0)) {
255 |                 final float[][] converted = new float[1][(int)embedding.shape()[1]];
256 |                 embedding.copyTo(converted);
257 |                 return converted[0];
258 |             }
259 |         }
260 |     }
261 | 
262 |     /**
263 |      * Gets pooled BERT embeddings for multiple sequences. Sequences are usually individual sentences, but don't have to be.
264 |      * The sequences will be processed in parallel as a single batch input to the TensorFlow model.
265 |      *
266 |      * @param sequences
267 |      *            the sequences to embed
268 |      * @return the pooled embeddings for the sequences, in the order the input {@link java.lang.Iterable} provided them
269 |      * @since 1.0.3
270 |      */
271 |     public float[][] embedSequences(final Iterable<String> sequences) {
272 |         final List<String> list = Lists.newArrayList(sequences);
273 |         return embedSequences(list.toArray(new String[list.size()]));
274 |     }
275 | 
276 |     /**
277 |      * Gets pooled BERT embeddings for multiple sequences. Sequences are usually individual sentences, but don't have to be.
278 |      * The sequences will be processed in parallel as a single batch input to the TensorFlow model.
279 |      *
280 |      * @param sequences
281 |      *            the sequences to embed
282 |      * @return the pooled embeddings for the sequences, in the order the input {@link java.util.Iterator} provided them
283 |      * @since 1.0.3
284 |      */
285 |     public float[][] embedSequences(final Iterator<String> sequences) {
286 |         final List<String> list = Lists.newArrayList(sequences);
287 |         return embedSequences(list.toArray(new String[list.size()]));
288 |     }
289 | 
290 |     /**
291 |      * Gets pooled BERT embeddings for multiple sequences. Sequences are usually individual sentences, but don't have to be.
292 |      * The sequences will be processed in parallel as a single batch input to the TensorFlow model.
293 |      *
294 |      * @param sequences
295 |      *            the sequences to embed
296 |      * @return the pooled embeddings for the sequences, in the order they were provided
297 |      * @since 1.0.3
298 |      */
299 |     public float[][] embedSequences(final String... sequences) {
300 |         try(Inputs inputs = getInputs(sequences)) {
301 |             final List<Tensor<?>> output = bundle.session().runner()
302 |                 .feed(model.inputIds, inputs.inputIds)
303 |                 .feed(model.inputMask, inputs.inputMask)
304 |                 .feed(model.segmentIds, inputs.segmentIds)
305 |                 .fetch(model.pooledOutput)
306 |                 .run();
307 | 
308 |             try(Tensor<?> embedding = output.get(0)) {
309 |                 final float[][] converted = new float[sequences.length][(int)embedding.shape()[1]];
310 |                 embedding.copyTo(converted);
311 |                 return converted;
312 |             }
313 |         }
314 |     }
315 | 
316 |     /**
317 |      * Gets BERT embeddings for each of the tokens in multiple sequences. Sequences are usually individual sentences, but don't have to be.
318 |      * The sequences will be processed in parallel as a single batch input to the TensorFlow model.
319 |      *
320 |      * @param sequences
321 |      *            the sequences to embed
322 |      * @return the token embeddings for the sequences, in the order the input {@link java.lang.Iterable} provided them
323 |      * @since 1.0.3
324 |      */
325 |     public float[][][] embedTokens(final Iterable<String> sequences) {
326 |         final List<String> list = Lists.newArrayList(sequences);
327 |         return embedTokens(list.toArray(new String[list.size()]));
328 |     }
329 | 
330 |     /**
331 |      * Gets BERT embeddings for each of the tokens in multiple sequences. Sequences are usually individual sentences, but don't have to be.
332 |      * The sequences will be processed in parallel as a single batch input to the TensorFlow model.
333 |      *
334 |      * @param sequences
335 |      *            the sequences to embed
336 |      * @return the token embeddings for the sequences, in the order the input {@link java.util.Iterator} provided them
337 |      * @since 1.0.3
338 |      */
339 |     public float[][][] embedTokens(final Iterator<String> sequences) {
340 |         final List<String> list = Lists.newArrayList(sequences);
341 |         return embedTokens(list.toArray(new String[list.size()]));
342 |     }
343 | 
344 |     /**
345 |      * Gets BERT embeddings for each of the tokens in a single sequence. Sequences are usually individual sentences, but don't have to be.
346 |      *
347 |      * @param sequence
348 |      *            the sequence to embed
349 |      * @return the token embeddings for the sequence
350 |      * @since 1.0.3
351 |      */
352 |     public float[][] embedTokens(final String sequence) {
353 |         try(Inputs inputs = getInputs(sequence)) {
354 |             final List<Tensor<?>> output = bundle.session().runner()
355 |                 .feed(model.inputIds, inputs.inputIds)
356 |                 .feed(model.inputMask, inputs.inputMask)
357 |                 .feed(model.segmentIds, inputs.segmentIds)
358 |                 .fetch(model.sequenceOutput)
359 |                 .run();
360 | 
361 |             try(Tensor<?> embedding = output.get(0)) {
362 |                 final float[][][] converted = new float[1][(int)embedding.shape()[1]][(int)embedding.shape()[2]];
363 |                 embedding.copyTo(converted);
364 |                 return converted[0];
365 |             }
366 |         }
367 |     }
368 | 
369 |     /**
370 |      * Gets BERT embeddings for each of the tokens in multiple sequences. Sequences are usually individual sentences, but don't have to be.
371 |      * The sequences will be processed in parallel as a single batch input to the TensorFlow model.
372 |      *
373 |      * @param sequences
374 |      *            the sequences to embed
375 |      * @return the token embeddings for the sequences, in the order they were provided
376 |      * @since 1.0.3
377 |      */
378 |     public float[][][] embedTokens(final String... sequences) {
379 |         try(Inputs inputs = getInputs(sequences)) {
380 |             final List<Tensor<?>> output = bundle.session().runner()
381 |                 .feed(model.inputIds, inputs.inputIds)
382 |                 .feed(model.inputMask, inputs.inputMask)
383 |                 .feed(model.segmentIds, inputs.segmentIds)
384 |                 .fetch(model.sequenceOutput)
385 |                 .run();
386 | 
387 |             try(Tensor<?> embedding = output.get(0)) {
388 |                 final float[][][] converted = new float[sequences.length][(int)embedding.shape()[1]][(int)embedding.shape()[2]];
389 |                 embedding.copyTo(converted);
390 |                 return converted;
391 |             }
392 |         }
393 |     }
394 | 
395 |     private Inputs getInputs(final String sequence) {
396 |         final String[] tokens = tokenizer.tokenize(sequence);
397 | 
398 |         final IntBuffer inputIds = IntBuffer.allocate(model.maxSequenceLength);
399 |         final IntBuffer inputMask = IntBuffer.allocate(model.maxSequenceLength);
400 |         final IntBuffer segmentIds = IntBuffer.allocate(model.maxSequenceLength);
401 | 
402 |         /*
403 |          * In BERT:
404 |          * inputIds are the indexes in the vocabulary for each token in the sequence
405 |          * inputMask is a binary mask that shows which inputIds have valid data in them
406 |          * segmentIds are meant to distinguish paired sequences during training tasks. Here they're always 0 since we're only doing inference.
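         *
         * For illustration (the token ids are made up, and maxSequenceLength is assumed to be 8 here), the sequence "hello world" would produce:
         *     inputIds   = [id([CLS]), id(hello), id(world), id([SEP]), 0, 0, 0, 0]
         *     inputMask  = [1,         1,         1,         1,         0, 0, 0, 0]
         *     segmentIds = [0,         0,         0,         0,         0, 0, 0, 0]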
407 | */ 408 | final int[] ids = tokenizer.convert(tokens); 409 | inputIds.put(startTokenId); 410 | inputMask.put(1); 411 | segmentIds.put(0); 412 | for(int i = 0; i < ids.length && i < model.maxSequenceLength - 2; i++) { 413 | inputIds.put(ids[i]); 414 | inputMask.put(1); 415 | segmentIds.put(0); 416 | } 417 | inputIds.put(separatorTokenId); 418 | inputMask.put(1); 419 | segmentIds.put(0); 420 | 421 | while(inputIds.position() < model.maxSequenceLength) { 422 | inputIds.put(0); 423 | inputMask.put(0); 424 | segmentIds.put(0); 425 | } 426 | 427 | inputIds.rewind(); 428 | inputMask.rewind(); 429 | segmentIds.rewind(); 430 | 431 | return new Inputs(inputIds, inputMask, segmentIds, 1); 432 | } 433 | 434 | private Inputs getInputs(final String[] sequences) { 435 | final String[][] tokens = tokenizer.tokenize(sequences); 436 | 437 | final IntBuffer inputIds = IntBuffer.allocate(sequences.length * model.maxSequenceLength); 438 | final IntBuffer inputMask = IntBuffer.allocate(sequences.length * model.maxSequenceLength); 439 | final IntBuffer segmentIds = IntBuffer.allocate(sequences.length * model.maxSequenceLength); 440 | 441 | /* 442 | * In BERT: 443 | * inputIds are the indexes in the vocabulary for each token in the sequence 444 | * inputMask is a binary mask that shows which inputIds have valid data in them 445 | * segmentIds are meant to distinguish paired sequences during training tasks. Here they're always 0 since we're only doing inference. 446 | */ 447 | int instance = 1; 448 | for(final String[] token : tokens) { 449 | final int[] ids = tokenizer.convert(token); 450 | inputIds.put(startTokenId); 451 | inputMask.put(1); 452 | segmentIds.put(0); 453 | for(int i = 0; i < ids.length && i < model.maxSequenceLength - 2; i++) { 454 | inputIds.put(ids[i]); 455 | inputMask.put(1); 456 | segmentIds.put(0); 457 | } 458 | inputIds.put(separatorTokenId); 459 | inputMask.put(1); 460 | segmentIds.put(0); 461 | 462 | while(inputIds.position() < model.maxSequenceLength * instance) { 463 | inputIds.put(0); 464 | inputMask.put(0); 465 | segmentIds.put(0); 466 | } 467 | instance++; 468 | } 469 | 470 | inputIds.rewind(); 471 | inputMask.rewind(); 472 | segmentIds.rewind(); 473 | 474 | return new Inputs(inputIds, inputMask, segmentIds, sequences.length); 475 | } 476 | } 477 | --------------------------------------------------------------------------------