├── src
│   ├── test
│   │   ├── resources
│   │   │   └── test-vocabularies
│   │   │       ├── merges.txt
│   │   │       ├── vocabulary.json
│   │   │       └── base_vocabulary.json
│   │   └── java
│   │       └── com
│   │           └── genesys
│   │               └── roberta
│   │                   └── tokenizer
│   │                       ├── utils
│   │                       │   └── CommonTestUtils.java
│   │                       ├── RobertaTokenizerResourcesTest.java
│   │                       ├── BytePairEncoderTest.java
│   │                       └── RobertaTokenizerTest.java
│   └── main
│       ├── java
│       │   └── com
│       │       └── genesys
│       │           └── roberta
│       │               └── tokenizer
│       │                   ├── Tokenizer.java
│       │                   ├── BiGram.java
│       │                   ├── RobertaTokenizer.java
│       │                   ├── BytePairEncoder.java
│       │                   └── RobertaTokenizerResources.java
│       └── resources
│           └── checkstyle.xml
├── .gitignore
├── LICENSE
├── README.md
└── pom.xml
/src/test/resources/test-vocabularies/merges.txt:
--------------------------------------------------------------------------------
1 | Ġ l
2 | Ġl o
3 | Ġlo w
4 | e r
5 |
--------------------------------------------------------------------------------
/src/test/resources/test-vocabularies/vocabulary.json:
--------------------------------------------------------------------------------
1 | {
2 | "": 0,
3 | "": 1,
4 | "": 2,
5 | "": 3,
6 | "l": 4,
7 | "o": 5,
8 | "w": 6,
9 | "e": 7,
10 | "r": 8,
11 | "s": 9,
12 | "t": 10,
13 | "i": 11,
14 | "d": 12,
15 | "n": 13,
16 | "\u0120": 114,
17 | "\u0120l": 15,
18 | "\u0120n": 16,
19 | "\u0120lo": 17,
20 | "\u0120low": 18,
21 | "er": 19,
22 | "\u0120lowest": 20,
23 | "\u0120newer": 21,
24 | "\u0120wider": 22
25 | }
26 |
--------------------------------------------------------------------------------
/src/main/java/com/genesys/roberta/tokenizer/Tokenizer.java:
--------------------------------------------------------------------------------
1 | package com.genesys.roberta.tokenizer;
2 |
3 | /**
4 |  * Implementations of this interface can tokenize given String inputs
5 | */
6 | interface Tokenizer {
7 |
8 | /**
9 |  * Converts a given input sentence to an array of long tokens
10 |  *
11 |  * @param sentence one or more words delimited by spaces
12 |  * @return an array of input IDs holding the corresponding tokens
13 | */
14 | long[] tokenize(String sentence);
15 | }
16 |
--------------------------------------------------------------------------------
/src/test/java/com/genesys/roberta/tokenizer/utils/CommonTestUtils.java:
--------------------------------------------------------------------------------
1 | package com.genesys.roberta.tokenizer.utils;
2 |
3 | import com.genesys.roberta.tokenizer.RobertaTokenizer;
4 |
5 | import java.io.File;
6 | import java.util.Objects;
7 |
8 | public class CommonTestUtils {
9 |
10 | public static String getResourceAbsPath() {
11 | String resourceRelPath = "test-vocabularies";
12 | return new File(Objects.requireNonNull(RobertaTokenizer.class.getClassLoader().getResource(resourceRelPath))
13 | .getFile()).getAbsolutePath();
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Node artifact files
2 | node_modules/
3 | dist/
4 |
5 | # Compiled Java class files
6 | *.class
7 |
8 | # Compiled Python bytecode
9 | *.py[cod]
10 |
11 | # Log files
12 | *.log
13 |
14 | # Package files
15 | *.jar
16 |
17 | # Maven
18 | target/
19 | dist/
20 |
21 | # JetBrains IDE
22 | OLD.idea/
23 | .idea/
24 | *.iml
25 |
26 | # Unit test reports
27 | TEST*.xml
28 |
29 | # Generated by MacOS
30 | .DS_Store
31 | /target/
32 | __pycache__/
33 |
34 | # Generated by Windows
35 | Thumbs.db
36 |
37 | # Applications
38 | *.app
39 | *.exe
40 | *.war
41 |
42 | # Large media files
43 | *.mp4
44 | *.tiff
45 | *.avi
46 | *.flv
47 | *.mov
48 | *.wmv
49 |
50 | **/test-output
51 | **/target
52 | **/resources/static/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Genesys Cloud Services, Inc.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/src/main/java/com/genesys/roberta/tokenizer/BiGram.java:
--------------------------------------------------------------------------------
1 | package com.genesys.roberta.tokenizer;
2 |
3 | import lombok.EqualsAndHashCode;
4 | import lombok.NonNull;
5 |
6 | import static com.google.common.base.Preconditions.checkState;
7 | import static java.lang.String.format;
8 |
9 | /**
10 |  * A pair of two adjacent elements from a string, distinguished by their position - left or right
11 | */
12 | @EqualsAndHashCode
13 | class BiGram {
14 | private static final int PAIR_SIZE = 2;
15 | private final String left;
16 | private final String right;
17 |
18 | private BiGram(@NonNull final String[] pairArray) {
19 | checkState(pairArray.length == PAIR_SIZE,
20 | format("Expecting BiGram pair to be of size: [%d] but it's of size: [%d]", PAIR_SIZE, pairArray.length));
21 | this.left = pairArray[0];
22 | this.right = pairArray[1];
23 | }
24 |
25 | private BiGram(@NonNull final String left, @NonNull final String right) {
26 | this.left = left;
27 | this.right = right;
28 | }
29 |
30 | /**
31 | * Creates a new BiGram object from an array of Strings.
32 | * Expecting the array to be of size 2.
33 | * @param pairArray array of Strings
34 | * @return new BiGram where the String pairArray[0] will be left and pairArray[1] right.
35 | */
36 | public static BiGram of(@NonNull final String[] pairArray) {
37 | return new BiGram(pairArray);
38 | }
39 |
40 | /**
41 | * Creates an object with given parameters.
42 |  * @param left will be the left String of the BiGram
43 |  * @param right will be the right String of the BiGram
44 | * @return new BiGram object
45 | */
46 | public static BiGram of(@NonNull final String left, @NonNull final String right) {
47 | return new BiGram(left, right);
48 | }
49 |
50 | public String getLeft() {
51 | return this.left;
52 | }
53 |
54 | public String getRight() {
55 | return this.right;
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/src/test/java/com/genesys/roberta/tokenizer/RobertaTokenizerResourcesTest.java:
--------------------------------------------------------------------------------
1 | package com.genesys.roberta.tokenizer;
2 |
3 | import lombok.val;
4 | import org.testng.Assert;
5 | import org.testng.annotations.BeforeClass;
6 | import org.testng.annotations.Test;
7 |
8 | import static com.genesys.roberta.tokenizer.utils.CommonTestUtils.getResourceAbsPath;
9 |
10 | public class RobertaTokenizerResourcesTest {
11 |
12 | private static final String VOCABULARY_BASE_DIR_PATH = getResourceAbsPath();
13 | private static final long UNKNOWN_TOKEN = RobertaTokenizer.DEFAULT_UNK_TOKEN;
14 | private RobertaTokenizerResources robertaTokenizerResources;
15 |
16 | @BeforeClass
17 | public void initDataMembersBeforeClass() {
18 | robertaTokenizerResources = new RobertaTokenizerResources(VOCABULARY_BASE_DIR_PATH);
19 | }
20 |
21 | @Test(expectedExceptions = NullPointerException.class)
22 | public void nullBaseDirPath() {
23 | new RobertaTokenizerResources(null);
24 | }
25 |
26 | @Test(expectedExceptions = IllegalStateException.class)
27 | public void vocabularyBaseDirPathNotExist() {
28 | new RobertaTokenizerResources("dummy/base/dir/path");
29 | }
30 |
31 | @Test
32 | public void minByteValue() {
33 | byte key = -128;
34 | val encodedChar = robertaTokenizerResources.encodeByte(key);
35 | Assert.assertNotNull(encodedChar);
36 | }
37 |
38 | @Test
39 | public void maxByteValue() {
40 | byte key = 127;
41 | val encodedChar = robertaTokenizerResources.encodeByte(key);
42 | Assert.assertNotNull(encodedChar);
43 | }
44 |
45 | @Test
46 | public void wordDoesNotExist() {
47 | String word = "Funnel";
48 | Long actualToken = robertaTokenizerResources.encodeWord(word, UNKNOWN_TOKEN);
49 | Assert.assertEquals(actualToken.longValue(), UNKNOWN_TOKEN);
50 | }
51 |
52 | @Test
53 | public void wordExists() {
54 | String word = "er";
55 | long expectedToken = 19;
56 | Long actualToken = robertaTokenizerResources.encodeWord(word, UNKNOWN_TOKEN);
57 | Assert.assertEquals(actualToken.longValue(), expectedToken);
58 | }
59 |
60 | @Test
61 | public void pairExists() {
62 | BiGram bigram = BiGram.of("e", "r");
63 | int actualRank = robertaTokenizerResources.getRankOrDefault(bigram, Integer.MAX_VALUE);
64 | int expectedRank = 3;
65 | Assert.assertEquals(actualRank, expectedRank);
66 | }
67 |
68 | @Test
69 | public void pairDoesNotExist() {
70 | BiGram bigram = BiGram.of("Zilpa", "Funnel");
71 | int actualRank = robertaTokenizerResources.getRankOrDefault(bigram, Integer.MAX_VALUE);
72 | Assert.assertEquals(actualRank, Integer.MAX_VALUE);
73 | }
74 | }
75 |
--------------------------------------------------------------------------------
/src/test/java/com/genesys/roberta/tokenizer/BytePairEncoderTest.java:
--------------------------------------------------------------------------------
1 | package com.genesys.roberta.tokenizer;
2 |
3 | import org.mockito.Mock;
4 | import org.mockito.MockitoAnnotations;
5 | import org.testng.Assert;
6 | import org.testng.annotations.BeforeClass;
7 | import org.testng.annotations.Test;
8 |
9 | import java.util.Arrays;
10 | import java.util.HashMap;
11 | import java.util.List;
12 | import java.util.Map;
13 |
14 | import static org.mockito.ArgumentMatchers.any;
15 | import static org.mockito.ArgumentMatchers.anyInt;
16 | import static org.mockito.Mockito.when;
17 |
18 |
19 | public class BytePairEncoderTest {
20 |
21 | private BytePairEncoder bytePairEncoder;
22 | private Map<BiGram, Integer> ranks;
23 |
24 | @Mock
25 | private RobertaTokenizerResources robertaTokenizerResources;
26 |
27 | @BeforeClass
28 | public void setupBeforeClass() {
29 | MockitoAnnotations.openMocks(this);
30 | ranks = new HashMap<>() {{
31 | put(BiGram.of("Ġ", "l"), 0);
32 | put(BiGram.of("Ġl", "o"), 1);
33 | put(BiGram.of("Ġlo", "w"), 2);
34 | put(BiGram.of("e", "r"), 3);
35 | }};
36 | bytePairEncoder = new BytePairEncoder();
37 | when(robertaTokenizerResources.getRankOrDefault(any(BiGram.class), anyInt()))
38 | .thenAnswer(input -> ranks.getOrDefault(input.getArgument(0), Integer.MAX_VALUE));
39 | }
40 |
41 | @Test(expectedExceptions = NullPointerException.class)
42 | public void nullWordTest() {
43 | bytePairEncoder.encode(null, robertaTokenizerResources);
44 | }
45 |
46 | @Test
47 | public void correctSplitTest() {
48 | List<String> actualSplit = bytePairEncoder.encode("lowerĠnewer", robertaTokenizerResources);
49 | // The vocabulary rules and characters were taken from here:
50 | // https://github.com/huggingface/transformers/blob/v4.20.1/tests/models/roberta/test_tokenization_roberta.py#L86
51 | List<String> expectedSplit = Arrays.asList("l", "o", "w", "er", "Ġ", "n", "e", "w", "er");
52 | Assert.assertEquals(actualSplit, expectedSplit);
53 | }
54 |
55 | @Test
56 | public void emptySplitTest() {
57 | List<String> actualSplit = bytePairEncoder.encode("", robertaTokenizerResources);
58 | Assert.assertTrue(actualSplit.isEmpty());
59 | }
60 |
61 | @Test
62 | public void noMergeRulesForWordTest() {
63 | List<String> actualSplit = bytePairEncoder.encode("qpyt", robertaTokenizerResources);
64 | // Since none of these characters appears in the ranks map, i.e., there is no merge rule for any of them,
65 | // we expect each one to be encoded alone
66 | List<String> expectedSplit = Arrays.asList("q", "p", "y", "t");
67 | Assert.assertEquals(actualSplit, expectedSplit);
68 | }
69 | }
70 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RoBERTa Java Tokenizer #
2 |
3 |
4 | ## About
5 |
6 | ---
7 | This repo contains a Java tokenizer for the RoBERTa model. The implementation mainly follows the HuggingFace Python
8 | RoBERTa tokenizer, with additional reference to other implementations, as mentioned in the code and below:
9 |
10 | * https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer
11 |
12 | * https://github.com/huggingface/tflite-android-transformers/blob/master/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/tokenization/GPT2Tokenizer.kt
13 |
14 | * https://github.com/hyunwoongko/gpt2-tokenizer-java/blob/master/src/main/java/ai/tunib/tokenizer/GPT2Tokenizer.java
15 |
16 | The algorithm used is a byte-level Byte Pair Encoding.
17 |
18 | https://huggingface.co/docs/transformers/tokenizer_summary#bytelevel-bpe
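
For illustration, the byte-level step first maps every UTF-8 byte of the input to a printable symbol from `base_vocabulary.json` before any merges are applied; a space (byte 32) becomes `Ġ` (`\u0120`). A minimal sketch using this library's own `RobertaTokenizerResources` (the `base/dir/path` below is a placeholder):

```
RobertaTokenizerResources resources = new RobertaTokenizerResources("base/dir/path");
StringBuilder encoded = new StringBuilder();
for (byte b : " lower".getBytes(StandardCharsets.UTF_8)) {
    encoded.append(resources.encodeByte(b)); // byte 32 (space) maps to "Ġ"
}
// encoded.toString() is "Ġlower" - the string the BPE merge step then operates on
```
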
19 | ## How do I get set up? ##
20 |
21 | ---
22 |
23 | * Clone the repo to use the sources directly.
24 | * Add the Maven dependency to your `pom.xml` for usage in your project:
25 |
26 | ```
27 | <dependency>
28 |     <groupId>cloud.genesys</groupId>
29 |     <artifactId>roberta-tokenizer</artifactId>
30 |     <version>1.0.7</version>
31 | </dependency>
32 |
33 | <repositories>
34 |     <repository>
35 |         <id>ossrh</id>
36 |         <url>https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
37 |     </repository>
38 |     ...
39 | </repositories>
40 | ```
41 |
42 |
43 | ### Tests ###
44 |
45 | ---
46 |
47 | * Unit tests - run on a local machine.
48 |
49 | ### File Dependencies ###
50 |
51 | ---
52 |
53 | Since we want the tokenizer to initialize efficiently, we use a factory that creates the relevant resource
54 | files "lazily".
55 |
56 | For this tokenizer we need 3 data files:
57 |
58 | * `base_vocabulary.json` - a map of numbers ([0, 255]) to symbols (Unicode characters). Only these symbols will be known by the
59 | algorithm: e.g., given a String _s_ as input, it iterates over the bytes of _s_ and replaces each byte with its mapped symbol.
60 | This way we control exactly which characters are passed on.
61 |
62 | * `vocabulary.json` - holds all the words (and sub-words) and their tokens, as learned during training.
63 |
64 | * `merges.txt` - describes the merge rules for words. The algorithm splits the given word into two sub-words, then
65 | decides the best split according to the rank of the sub-words; the higher a rule appears in this file, the higher its rank (see the example layout after the notes below).
66 |
67 | __Please note__:
68 |
69 | 1. All three files must be under the same directory.
70 |
71 | 2. They must be named exactly as mentioned above.
72 |
73 | 3. The result of the tokenization depends on the vocabulary and merges files.
74 |
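For illustration, a minimal resource directory (mirroring this repo's own test resources under `src/test/resources/test-vocabularies`) looks like:

```
base/dir/path
├── base_vocabulary.json    // {"33": "!", ..., "32": "\u0120", ...}
├── vocabulary.json         // {"<s>": 0, ..., "er": 19, ...}
└── merges.txt              // one merge rule per line, e.g. "e r"
```
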
75 | ### Example ###
76 |
77 | ---
78 |
79 | ```
80 |
81 | String baseDirPath = "base/dir/path";
82 | RobertaTokenizerResources robertaResources = new RobertaTokenizerResources(baseDirPath);
83 | RobertaTokenizer robertaTokenizer = new RobertaTokenizer(robertaResources);
84 | ...
85 | String sentence = "this must be the place";
86 | long[] tokenizedSentence = robertaTokenizer.tokenize(sentence);
87 | System.out.println(Arrays.toString(tokenizedSentence));
88 |
89 | ```
90 |
91 | An example output would be `[0, 9226, 531, 28, 5, 317, 2]`, depending on the given vocabulary and merges files.
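
The defaults for the special tokens are `CLS = 0`, `SEP = 2` and `UNK = 3`. If your vocabulary uses different IDs, they can be passed explicitly through the overloaded constructor (the values below simply restate the defaults):

```
RobertaTokenizerResources robertaResources = new RobertaTokenizerResources("base/dir/path");
// clsToken, sepToken, unkToken - pass the IDs your vocabulary actually uses
RobertaTokenizer robertaTokenizer = new RobertaTokenizer(robertaResources, 0L, 2L, 3L);
```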
92 |
93 | ### Contribution guidelines
94 |
95 | ---
96 |
97 | * Use temporary branches for every issue/task.
98 |
--------------------------------------------------------------------------------
/src/main/java/com/genesys/roberta/tokenizer/RobertaTokenizer.java:
--------------------------------------------------------------------------------
1 | package com.genesys.roberta.tokenizer;
2 |
3 | import lombok.NonNull;
4 | import lombok.val;
5 |
6 | import java.nio.charset.StandardCharsets;
7 | import java.util.ArrayList;
8 | import java.util.List;
9 | import java.util.regex.Matcher;
10 | import java.util.regex.Pattern;
11 | import java.util.stream.LongStream;
12 |
13 | import static java.util.stream.LongStream.concat;
14 | import static java.util.stream.LongStream.of;
15 |
16 | /**
17 | * Tokenizer used for the RoBERTa model.
18 |  * Encodes sentences into integer (long) tokens.
19 | *
20 | * This tokenizer is implemented according to the following:
21 | * - https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer
22 | * - https://github.com/hyunwoongko/gpt2-tokenizer-java/blob/master/src/main/java/ai/tunib/tokenizer/GPT2Tokenizer.java
23 | * - https://github.com/huggingface/tflite-android-transformers/blob/master/gpt2/src/main/java/co/huggingface/android_transformers/\
24 | * gpt2/tokenization/GPT2Tokenizer.kt
25 | */
26 | public class RobertaTokenizer implements Tokenizer {
27 |
28 | public static final long DEFAULT_CLS_TOKEN = 0;
29 | public static final long DEFAULT_SEP_TOKEN = 2;
30 | public static final long DEFAULT_UNK_TOKEN = 3;
31 |
32 | // splits a given sentence into words, numbers, punctuation and common English contractions
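// e.g., "lower newer" is split into ["lower", " newer"]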
33 | private static final Pattern PATTERN = Pattern
34 | .compile("'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+");
35 |
36 | // Special tokens
37 | private final long clsToken; // Also BOS (beginning of sequence) token
38 | private final long sepToken; // Also EOS (end of sequence) token
39 | private final long unkToken; // Unknown Token.
40 |
41 | private final RobertaTokenizerResources robertaResources;
42 | private final BytePairEncoder bytePairEncoder;
43 |
44 | /**
45 | * Constructs a RoBERTa tokenizer, using byte-level Byte-Pair-Encoding.
46 | *
47 | * @param robertaTokenizerResources - responsible for providing roberta vocabularies and merges files.
48 | *
49 | * Note that this constructor will use HuggingFace's default special tokens:
50 | * [CLS_TOKEN = 0, SEP_TOKEN = 2, UNK_TOKEN = 3]
51 | */
52 | public RobertaTokenizer(@NonNull final RobertaTokenizerResources robertaTokenizerResources) {
53 | this(robertaTokenizerResources, DEFAULT_CLS_TOKEN, DEFAULT_SEP_TOKEN, DEFAULT_UNK_TOKEN);
54 | }
55 |
56 | /**
57 | * Constructs a RoBERTa tokenizer, using byte-level Byte-Pair-Encoding.
58 | *
59 | * @param robertaTokenizerResources - responsible for providing roberta vocabularies and merges files.
60 | * @param clsToken Classification token
61 | * @param sepToken Separator token
62 | * @param unkToken Unknown token
63 | */
64 | public RobertaTokenizer(@NonNull final RobertaTokenizerResources robertaTokenizerResources, final long clsToken,
65 | final long sepToken, final long unkToken) {
66 | this.robertaResources = robertaTokenizerResources;
67 | this.bytePairEncoder = new BytePairEncoder();
68 | this.clsToken = clsToken;
69 | this.sepToken = sepToken;
70 | this.unkToken = unkToken;
71 | }
72 |
73 | /**
74 |  * Encodes the given sentence into an array of tokens (long numbers) using byte-level Byte-Pair-Encoding.
75 |  *
76 |  * @param sentence one or more words separated by spaces
77 |  * @return an array of token (long) values
78 | */
79 | @Override
80 | public long[] tokenize(@NonNull final String sentence) {
81 | List<String> encodedStrings = new ArrayList<>();
82 |
83 | Matcher matcher = PATTERN.matcher(sentence);
84 | while (matcher.find()) {
85 | String matchedSequence = matcher.group();
86 | val matchedSequenceEncoded = new StringBuilder();
87 |
88 | for (byte b : matchedSequence.getBytes(StandardCharsets.UTF_8)) {
89 | String encodedByte = this.robertaResources.encodeByte(b);
90 | matchedSequenceEncoded.append(encodedByte);
91 | }
92 |
93 | encodedStrings.add(matchedSequenceEncoded.toString());
94 | }
95 |
96 | LongStream outputTokens = encodedStrings.stream()
97 | // returns list of strings ready for vocabulary mapping
98 | .map(encodedStr -> bytePairEncoder.encode(encodedStr, robertaResources))
99 | // mapping each word in the given lists to a Long token from the vocabulary
100 | .flatMapToLong(encodedStrList -> encodedStrList.stream()
101 | .mapToLong(word -> this.robertaResources.encodeWord(word, unkToken)));
102 |
103 | outputTokens = concat(of(clsToken), outputTokens); // adding BOS
104 | return concat(outputTokens, of(sepToken)).toArray(); // adding EOS
105 | }
106 |
107 | public long getClsToken() {
108 | return clsToken;
109 | }
110 |
111 | public long getSepToken() {
112 | return sepToken;
113 | }
114 |
115 | public long getUnkToken() {
116 | return unkToken;
117 | }
118 | }
119 |
--------------------------------------------------------------------------------
/src/main/java/com/genesys/roberta/tokenizer/BytePairEncoder.java:
--------------------------------------------------------------------------------
1 | package com.genesys.roberta.tokenizer;
2 |
3 | import lombok.NonNull;
4 |
5 | import java.util.ArrayList;
6 | import java.util.List;
7 | import java.util.Set;
8 | import java.util.stream.Collectors;
9 | import java.util.stream.IntStream;
10 |
11 | /**
12 | * Byte-Pair-Encoding
13 |  * Relies on a pre-tokenizer that splits the training data into words; in our case, by spaces.
14 |  *
15 |  * This greedy algorithm looks for the best way to divide a given input word.
16 |  * It does that by splitting the word into characters, then assembling substrings of the word, trying to find the best
17 |  * partition of the word according to the ranks in the merges file.
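 *
 * For example, with the single merge rule ("e", "r") - as in this repo's test resources - encoding "lower"
 * goes ["l", "o", "w", "e", "r"] -> ["l", "o", "w", "er"], after which no remaining bi-gram has a merge rank.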
18 | */
19 | class BytePairEncoder {
20 |
21 | /**
22 | * Applies the byte level BPE algorithm on the given word
23 | *
24 | * @param word one word from an input sentence
25 | * @param robertaTokenizerRobertaResources holds the vocabulary resources
26 | * @return a list of strings optimally partitioned and ready for tokenization
27 | */
28 | public List<String> encode(@NonNull final String word, @NonNull RobertaTokenizerResources robertaTokenizerRobertaResources) {
29 | List<String> wordCharactersStrList = word.chars()
30 | .mapToObj(Character::toString)
31 | .collect(Collectors.toList());
32 |
33 | Set<BiGram> biGramsSet = getBiGrams(wordCharactersStrList);
34 |
35 | while (true) {
36 | long minScore = Integer.MAX_VALUE;
37 | BiGram lowestScoreBiGram = null;
38 |
39 | for (BiGram biGram : biGramsSet) {
40 | long score = robertaTokenizerRobertaResources.getRankOrDefault(biGram, Integer.MAX_VALUE);
41 |
42 | // Note that we turn the most frequent bi-gram from a max problem to minimum
43 | // The lower the score the higher the frequency
44 | if (score < minScore) {
45 | minScore = score;
46 | lowestScoreBiGram = biGram;
47 | }
48 | }
49 |
50 | // Reaching here means that only BiGrams without a merge rule (rank Integer.MAX_VALUE) are left in
51 | // wordCharactersStrList, so no more merges should be done and the final tokenized word is the current wordCharactersStrList.
52 | if (lowestScoreBiGram == null) {
53 | break;
54 | }
55 |
56 | String first = lowestScoreBiGram.getLeft();
57 | String second = lowestScoreBiGram.getRight();
58 | List<String> newWordList = new ArrayList<>();
59 | int currIdx = 0;
60 |
61 | while (currIdx < wordCharactersStrList.size()) {
62 | int biGramStartIndex = getIndexWithStartPosition(wordCharactersStrList, first, currIdx);
63 |
64 | if (biGramStartIndex != -1) {
65 | newWordList.addAll(wordCharactersStrList.subList(currIdx, biGramStartIndex));
66 | currIdx = biGramStartIndex;
67 | } else {
68 | newWordList.addAll(wordCharactersStrList.subList(currIdx, wordCharactersStrList.size()));
69 | break;
70 | }
71 |
72 | if (wordCharactersStrList.get(currIdx).equals(first) && currIdx < wordCharactersStrList.size() - 1 &&
73 | wordCharactersStrList.get(currIdx + 1).equals(second)) {
74 | newWordList.add(first + second);
75 | currIdx += 2;
76 | } else {
77 | newWordList.add(wordCharactersStrList.get(currIdx));
78 | currIdx += 1;
79 | }
80 | }
81 |
82 | wordCharactersStrList = newWordList;
83 | if (wordCharactersStrList.size() == 1) {
84 | break;
85 | } else {
86 | biGramsSet = getBiGrams(wordCharactersStrList);
87 | }
88 | }
89 |
90 | return wordCharactersStrList;
91 | }
92 |
93 | /**
94 | *
95 |  * @param wordStrChars all characters of the word, each represented as a String
96 |  * @return set of all adjacent BiGrams,
97 |  * e.g., "hello" will be given as input ["h", "e", "l", "l", "o"] and will return {("h","e"), ("e","l"), ("l","l"), ("l","o")}
98 | */
99 | private Set<BiGram> getBiGrams(@NonNull final List<String> wordStrChars) {
100 | return IntStream.range(0, wordStrChars.size() - 1)
101 | .mapToObj(i -> BiGram.of(wordStrChars.get(i), wordStrChars.get(i + 1)))
102 | .collect(Collectors.toSet());
103 | }
104 |
105 | /**
106 |  * Searches for the given word in wordCharsList and returns its index if found
107 | *
108 | * @param wordCharsList list of characters represented as Strings
109 | * @param word given word to search for
110 | * @param startPosition an index to start the search from
111 |  * @return the index if found, otherwise -1
112 | */
113 | private int getIndexWithStartPosition(@NonNull final List<String> wordCharsList, @NonNull final String word,
114 | final int startPosition) {
115 | return IntStream.range(startPosition, wordCharsList.size())
116 | .filter(idx -> wordCharsList.get(idx).equals(word))
117 | .findFirst()
118 | .orElse(-1);
119 | }
120 | }
121 |
--------------------------------------------------------------------------------
/src/test/resources/test-vocabularies/base_vocabulary.json:
--------------------------------------------------------------------------------
1 | {
2 | "33": "!",
3 | "34": "\"",
4 | "35": "#",
5 | "36": "$",
6 | "37": "%",
7 | "38": "&",
8 | "39": "'",
9 | "40": "(",
10 | "41": ")",
11 | "42": "*",
12 | "43": "+",
13 | "44": ",",
14 | "45": "-",
15 | "46": ".",
16 | "47": "/",
17 | "48": "0",
18 | "49": "1",
19 | "50": "2",
20 | "51": "3",
21 | "52": "4",
22 | "53": "5",
23 | "54": "6",
24 | "55": "7",
25 | "56": "8",
26 | "57": "9",
27 | "58": ":",
28 | "59": ";",
29 | "60": "<",
30 | "61": "=",
31 | "62": ">",
32 | "63": "?",
33 | "64": "@",
34 | "65": "A",
35 | "66": "B",
36 | "67": "C",
37 | "68": "D",
38 | "69": "E",
39 | "70": "F",
40 | "71": "G",
41 | "72": "H",
42 | "73": "I",
43 | "74": "J",
44 | "75": "K",
45 | "76": "L",
46 | "77": "M",
47 | "78": "N",
48 | "79": "O",
49 | "80": "P",
50 | "81": "Q",
51 | "82": "R",
52 | "83": "S",
53 | "84": "T",
54 | "85": "U",
55 | "86": "V",
56 | "87": "W",
57 | "88": "X",
58 | "89": "Y",
59 | "90": "Z",
60 | "91": "[",
61 | "92": "\\",
62 | "93": "]",
63 | "94": "^",
64 | "95": "_",
65 | "96": "`",
66 | "97": "a",
67 | "98": "b",
68 | "99": "c",
69 | "100": "d",
70 | "101": "e",
71 | "102": "f",
72 | "103": "g",
73 | "104": "h",
74 | "105": "i",
75 | "106": "j",
76 | "107": "k",
77 | "108": "l",
78 | "109": "m",
79 | "110": "n",
80 | "111": "o",
81 | "112": "p",
82 | "113": "q",
83 | "114": "r",
84 | "115": "s",
85 | "116": "t",
86 | "117": "u",
87 | "118": "v",
88 | "119": "w",
89 | "120": "x",
90 | "121": "y",
91 | "122": "z",
92 | "123": "{",
93 | "124": "|",
94 | "125": "}",
95 | "126": "~",
96 | "161": "\u00a1",
97 | "162": "\u00a2",
98 | "163": "\u00a3",
99 | "164": "\u00a4",
100 | "165": "\u00a5",
101 | "166": "\u00a6",
102 | "167": "\u00a7",
103 | "168": "\u00a8",
104 | "169": "\u00a9",
105 | "170": "\u00aa",
106 | "171": "\u00ab",
107 | "172": "\u00ac",
108 | "174": "\u00ae",
109 | "175": "\u00af",
110 | "176": "\u00b0",
111 | "177": "\u00b1",
112 | "178": "\u00b2",
113 | "179": "\u00b3",
114 | "180": "\u00b4",
115 | "181": "\u00b5",
116 | "182": "\u00b6",
117 | "183": "\u00b7",
118 | "184": "\u00b8",
119 | "185": "\u00b9",
120 | "186": "\u00ba",
121 | "187": "\u00bb",
122 | "188": "\u00bc",
123 | "189": "\u00bd",
124 | "190": "\u00be",
125 | "191": "\u00bf",
126 | "192": "\u00c0",
127 | "193": "\u00c1",
128 | "194": "\u00c2",
129 | "195": "\u00c3",
130 | "196": "\u00c4",
131 | "197": "\u00c5",
132 | "198": "\u00c6",
133 | "199": "\u00c7",
134 | "200": "\u00c8",
135 | "201": "\u00c9",
136 | "202": "\u00ca",
137 | "203": "\u00cb",
138 | "204": "\u00cc",
139 | "205": "\u00cd",
140 | "206": "\u00ce",
141 | "207": "\u00cf",
142 | "208": "\u00d0",
143 | "209": "\u00d1",
144 | "210": "\u00d2",
145 | "211": "\u00d3",
146 | "212": "\u00d4",
147 | "213": "\u00d5",
148 | "214": "\u00d6",
149 | "215": "\u00d7",
150 | "216": "\u00d8",
151 | "217": "\u00d9",
152 | "218": "\u00da",
153 | "219": "\u00db",
154 | "220": "\u00dc",
155 | "221": "\u00dd",
156 | "222": "\u00de",
157 | "223": "\u00df",
158 | "224": "\u00e0",
159 | "225": "\u00e1",
160 | "226": "\u00e2",
161 | "227": "\u00e3",
162 | "228": "\u00e4",
163 | "229": "\u00e5",
164 | "230": "\u00e6",
165 | "231": "\u00e7",
166 | "232": "\u00e8",
167 | "233": "\u00e9",
168 | "234": "\u00ea",
169 | "235": "\u00eb",
170 | "236": "\u00ec",
171 | "237": "\u00ed",
172 | "238": "\u00ee",
173 | "239": "\u00ef",
174 | "240": "\u00f0",
175 | "241": "\u00f1",
176 | "242": "\u00f2",
177 | "243": "\u00f3",
178 | "244": "\u00f4",
179 | "245": "\u00f5",
180 | "246": "\u00f6",
181 | "247": "\u00f7",
182 | "248": "\u00f8",
183 | "249": "\u00f9",
184 | "250": "\u00fa",
185 | "251": "\u00fb",
186 | "252": "\u00fc",
187 | "253": "\u00fd",
188 | "254": "\u00fe",
189 | "255": "\u00ff",
190 | "0": "\u0100",
191 | "1": "\u0101",
192 | "2": "\u0102",
193 | "3": "\u0103",
194 | "4": "\u0104",
195 | "5": "\u0105",
196 | "6": "\u0106",
197 | "7": "\u0107",
198 | "8": "\u0108",
199 | "9": "\u0109",
200 | "10": "\u010a",
201 | "11": "\u010b",
202 | "12": "\u010c",
203 | "13": "\u010d",
204 | "14": "\u010e",
205 | "15": "\u010f",
206 | "16": "\u0110",
207 | "17": "\u0111",
208 | "18": "\u0112",
209 | "19": "\u0113",
210 | "20": "\u0114",
211 | "21": "\u0115",
212 | "22": "\u0116",
213 | "23": "\u0117",
214 | "24": "\u0118",
215 | "25": "\u0119",
216 | "26": "\u011a",
217 | "27": "\u011b",
218 | "28": "\u011c",
219 | "29": "\u011d",
220 | "30": "\u011e",
221 | "31": "\u011f",
222 | "32": "\u0120",
223 | "127": "\u0121",
224 | "128": "\u0122",
225 | "129": "\u0123",
226 | "130": "\u0124",
227 | "131": "\u0125",
228 | "132": "\u0126",
229 | "133": "\u0127",
230 | "134": "\u0128",
231 | "135": "\u0129",
232 | "136": "\u012a",
233 | "137": "\u012b",
234 | "138": "\u012c",
235 | "139": "\u012d",
236 | "140": "\u012e",
237 | "141": "\u012f",
238 | "142": "\u0130",
239 | "143": "\u0131",
240 | "144": "\u0132",
241 | "145": "\u0133",
242 | "146": "\u0134",
243 | "147": "\u0135",
244 | "148": "\u0136",
245 | "149": "\u0137",
246 | "150": "\u0138",
247 | "151": "\u0139",
248 | "152": "\u013a",
249 | "153": "\u013b",
250 | "154": "\u013c",
251 | "155": "\u013d",
252 | "156": "\u013e",
253 | "157": "\u013f",
254 | "158": "\u0140",
255 | "159": "\u0141",
256 | "160": "\u0142",
257 | "173": "\u0143"
258 | }
259 |
--------------------------------------------------------------------------------
/src/test/java/com/genesys/roberta/tokenizer/RobertaTokenizerTest.java:
--------------------------------------------------------------------------------
1 | package com.genesys.roberta.tokenizer;
2 |
5 | import org.testng.Assert;
6 | import org.testng.annotations.BeforeClass;
7 | import org.testng.annotations.Test;
8 |
9 | import java.util.Arrays;
10 |
11 | import static com.genesys.roberta.tokenizer.utils.CommonTestUtils.getResourceAbsPath;
12 |
13 | public class RobertaTokenizerTest {
14 | private static final String VOCABULARY_BASE_DIR_PATH = getResourceAbsPath();
15 | private long clsToken;
16 | private long sepToken;
17 |
18 | private RobertaTokenizer robertaTokenizer;
20 |
21 | @BeforeClass
22 | public void initDataMembersBeforeClass() {
24 | RobertaTokenizerResources robertaResources = new RobertaTokenizerResources(VOCABULARY_BASE_DIR_PATH);
25 | robertaTokenizer = new RobertaTokenizer(robertaResources);
26 | clsToken = robertaTokenizer.getClsToken();
27 | sepToken = robertaTokenizer.getSepToken();
28 | }
29 |
30 | @Test(expectedExceptions = NullPointerException.class)
31 | public void nullResourcesFactory() {
32 | new RobertaTokenizer(null);
33 | }
34 |
35 | @Test
36 | public void longSentenceWithTruncating() {
37 | // er token is 19, this sentence holds 24 occurrences of "er"
38 | String sentence = "erererererererererererererererererererererererer";
39 | long expectedToken = 19;
40 | long[] actualEncoding = robertaTokenizer.tokenize(sentence);
41 | Assert.assertEquals(actualEncoding[0], clsToken);
42 | Assert.assertTrue(Arrays.stream(actualEncoding).skip(1).takeWhile(token -> token != sepToken)
43 | .allMatch(token -> token == expectedToken));
44 | Assert.assertEquals(actualEncoding[actualEncoding.length - 1], sepToken);
45 | }
46 |
47 | @Test
48 | public void addingBeginningAndEndTokensToSentence() {
49 | String sentence = "er";
50 | long expectedToken = 19;
51 | long[] actualTokens = robertaTokenizer.tokenize(sentence);
52 | Assert.assertEquals(actualTokens[0], clsToken);
53 | Assert.assertEquals(actualTokens[1], expectedToken);
54 | Assert.assertEquals(actualTokens[2], sepToken);
55 | }
56 |
57 | /**
58 | * Since this sentence is well-defined according to the vocabulary, we know what tokens to expect.
59 | * Taken from here:
60 | * https://github.com/huggingface/transformers/blob/v4.20.1/tests/models/roberta/test_tokenization_roberta.py#L94
61 | */
62 | @Test
63 | public void tokenizeCorrectly() {
64 | String sentence = "lower newer";
65 | long[] expectedTokens = {
66 | clsToken,
67 | 4, 5, 6, 19, // lower
68 | 114, 13, 7, 6, 19, // newer
69 | sepToken};
70 | long[] actualTokens = robertaTokenizer.tokenize(sentence);
71 | Assert.assertEquals(actualTokens, expectedTokens);
72 | }
73 |
74 | @Test
75 | public void emptySentence() {
76 | long[] actualTokens = robertaTokenizer.tokenize("");
77 | Assert.assertEquals(actualTokens[0], clsToken);
78 | Assert.assertEquals(actualTokens[1], sepToken);
79 | }
80 |
81 | @Test
82 | public void veryLongWord() {
83 | String originalText =
84 | "https://www.google.com/search?as_q=you+have+to+write+a+really+really+long+search+to+get+to+2000+" +
85 | "characters.+like+seriously%2C+you+have+no+idea+how+long+it+has+to+be&as_epq=2000+characters+" +
86 | "is+absolutely+freaking+enormous.+You+can+fit+sooooooooooooooooooooooooooooooooo+much+data+" +
87 | "into+2000+characters.+My+hands+are+getting+tired+typing+this+many+characters.+I+didn%27t+" +
88 | "even+realise+how+long+it+was+going+to+take+to+type+them+all.&as_oq=Argh!+So+many+" +
89 | "characters.+I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.+I%27m+bored+now%2C+so+I%27ll+" +
90 | "just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+" +
91 | "so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+" +
92 | "now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+" +
93 | "bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+" +
94 | "paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+" +
95 | "copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+" +
96 | "I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+" +
97 | "now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+" +
98 | "bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+" +
99 | "paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+" +
100 | "copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+" +
101 | "I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+" +
102 | "now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+" +
103 | "paste.&as_eq=It+has+to+be+freaking+enormously+freaking+enormous&as_nlo=123&as_nhi=456&lr=" +
104 | "lang_hu&cr=countryAD&as_qdr=m&as_sitesearch=stackoverflow.com&as_occt=title&safe=active&tbs=" +
105 | "rl%3A1%2Crls%3A0&as_filetype=xls&as_rights=(cc_publicdomain%7Ccc_attribute%7Ccc_sharealike%" +
106 | "7Ccc_nonderived).-(cc_noncommercial)&gws_rd=ssl";
107 | robertaTokenizer.tokenize(originalText);
108 | }
109 | }
110 |
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
3 |     <modelVersion>4.0.0</modelVersion>
4 |
5 |     <groupId>cloud.genesys</groupId>
6 |     <artifactId>roberta-tokenizer</artifactId>
7 |     <version>1.0.7</version>
8 |     <packaging>jar</packaging>
9 |     <description>Tokenizer for RoBERTa model</description>
10 |     <name>roberta-tokenizer</name>
11 |     <url>https://github.com/purecloudlabs/roberta-tokenizer</url>
12 |
13 |     <licenses>
14 |         <license>
15 |             <name>MIT License</name>
16 |             <url>http://www.opensource.org/licenses/mit-license.php</url>
17 |         </license>
18 |     </licenses>
19 |
20 |     <developers>
21 |         <developer>
22 |             <name>Raviv Trichter</name>
23 |             <email>raviv.trichter@genesys.com</email>
24 |             <organization>Genesys</organization>
25 |             <organizationUrl>https://github.com/purecloudlabs</organizationUrl>
26 |         </developer>
27 |     </developers>
28 |
29 |     <scm>
30 |         <connection>https://github.com/purecloudlabs/roberta-tokenizer.git</connection>
31 |         <developerConnection>https://github.com/purecloudlabs/roberta-tokenizer.git</developerConnection>
32 |         <url>https://github.com/purecloudlabs/roberta-tokenizer.git</url>
33 |     </scm>
34 |
35 |     <properties>
36 |         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
37 |         <java.version>17</java.version>
38 |         <maven.compiler.source>17</maven.compiler.source>
39 |         <maven.compiler.target>17</maven.compiler.target>
40 |
41 |
42 |         <lombok.version>1.18.24</lombok.version>
43 |         <com.google.guava.version>31.1-jre</com.google.guava.version>
44 |         <com.google.gson.version>2.9.0</com.google.gson.version>
45 |         <mockito.core.version>4.3.1</mockito.core.version>
46 |         <testng.version>7.0.0</testng.version>
47 |         <maven.compiler.plugin.version>3.10.1</maven.compiler.plugin.version>
48 |         <maven.surefire.plugin.version>3.0.0-M7</maven.surefire.plugin.version>
49 |         <maven.javadoc.plugin.version>3.4.1</maven.javadoc.plugin.version>
50 |         <maven.source.plugin.version>3.2.1</maven.source.plugin.version>
51 |         <maven.gpg.plugin>3.0.1</maven.gpg.plugin>
52 |         <nexus.staging.maven.plugin.version>1.6.13</nexus.staging.maven.plugin.version>
53 |     </properties>
54 |
55 |     <distributionManagement>
56 |         <snapshotRepository>
57 |             <id>ossrh</id>
58 |             <url>https://s01.oss.sonatype.org/content/repositories/snapshots</url>
59 |         </snapshotRepository>
60 |         <repository>
61 |             <id>ossrh</id>
62 |             <url>https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/</url>
63 |         </repository>
64 |     </distributionManagement>
65 |
66 |     <build>
67 |         <plugins>
68 |             <plugin>
69 |                 <groupId>org.sonatype.plugins</groupId>
70 |                 <artifactId>nexus-staging-maven-plugin</artifactId>
71 |                 <version>${nexus.staging.maven.plugin.version}</version>
72 |                 <extensions>true</extensions>
73 |                 <configuration>
74 |                     <serverId>ossrh</serverId>
75 |                     <nexusUrl>https://s01.oss.sonatype.org/</nexusUrl>
76 |                     <autoReleaseAfterClose>true</autoReleaseAfterClose>
77 |                 </configuration>
78 |             </plugin>
79 |             <plugin>
80 |                 <groupId>org.apache.maven.plugins</groupId>
81 |                 <artifactId>maven-compiler-plugin</artifactId>
82 |                 <version>${maven.compiler.plugin.version}</version>
83 |                 <configuration>
84 |                     <source>${java.version}</source>
85 |                     <target>${java.version}</target>
86 |                     <annotationProcessorPaths>
87 |                         <path>
88 |                             <groupId>org.projectlombok</groupId>
89 |                             <artifactId>lombok</artifactId>
90 |                             <version>${lombok.version}</version>
91 |                         </path>
92 |                     </annotationProcessorPaths>
93 |                 </configuration>
94 |             </plugin>
95 |             <plugin>
96 |                 <groupId>org.apache.maven.plugins</groupId>
97 |                 <artifactId>maven-surefire-plugin</artifactId>
98 |                 <version>${maven.surefire.plugin.version}</version>
99 |                 <configuration>
100 |                     <skipTests>false</skipTests>
101 |                 </configuration>
102 |             </plugin>
103 |             <plugin>
104 |                 <groupId>org.apache.maven.plugins</groupId>
105 |                 <artifactId>maven-gpg-plugin</artifactId>
106 |                 <version>${maven.gpg.plugin}</version>
107 |                 <executions>
108 |                     <execution>
109 |                         <id>sign-artifacts</id>
110 |                         <phase>verify</phase>
111 |                         <goals>
112 |                             <goal>sign</goal>
113 |                         </goals>
114 |                     </execution>
115 |                 </executions>
116 |             </plugin>
117 |             <plugin>
118 |                 <groupId>org.apache.maven.plugins</groupId>
119 |                 <artifactId>maven-source-plugin</artifactId>
120 |                 <version>${maven.source.plugin.version}</version>
121 |                 <executions>
122 |                     <execution>
123 |                         <id>attach-sources</id>
124 |                         <goals>
125 |                             <goal>jar-no-fork</goal>
126 |                         </goals>
127 |                     </execution>
128 |                 </executions>
129 |             </plugin>
130 |             <plugin>
131 |                 <groupId>org.apache.maven.plugins</groupId>
132 |                 <artifactId>maven-javadoc-plugin</artifactId>
133 |                 <version>${maven.javadoc.plugin.version}</version>
134 |                 <executions>
135 |                     <execution>
136 |                         <id>attach-javadocs</id>
137 |                         <goals>
138 |                             <goal>jar</goal>
139 |                         </goals>
140 |                     </execution>
141 |                 </executions>
142 |             </plugin>
143 |         </plugins>
144 |     </build>
145 |
146 |     <dependencies>
147 |         <dependency>
148 |             <groupId>org.mockito</groupId>
149 |             <artifactId>mockito-core</artifactId>
150 |             <version>${mockito.core.version}</version>
151 |             <scope>test</scope>
152 |         </dependency>
153 |         <dependency>
154 |             <groupId>org.projectlombok</groupId>
155 |             <artifactId>lombok</artifactId>
156 |             <version>${lombok.version}</version>
157 |         </dependency>
158 |         <dependency>
159 |             <groupId>com.google.guava</groupId>
160 |             <artifactId>guava</artifactId>
161 |             <version>${com.google.guava.version}</version>
162 |         </dependency>
163 |         <dependency>
164 |             <groupId>com.google.code.gson</groupId>
165 |             <artifactId>gson</artifactId>
166 |             <version>${com.google.gson.version}</version>
167 |         </dependency>
168 |         <dependency>
169 |             <groupId>org.testng</groupId>
170 |             <artifactId>testng</artifactId>
171 |             <version>${testng.version}</version>
172 |             <scope>test</scope>
173 |         </dependency>
174 |     </dependencies>
175 |
176 | </project>
--------------------------------------------------------------------------------
/src/main/java/com/genesys/roberta/tokenizer/RobertaTokenizerResources.java:
--------------------------------------------------------------------------------
1 | package com.genesys.roberta.tokenizer;
2 |
3 | import com.google.gson.Gson;
4 | import com.google.gson.reflect.TypeToken;
5 | import lombok.NonNull;
6 |
7 | import java.io.FileNotFoundException;
8 | import java.io.IOException;
9 | import java.nio.charset.StandardCharsets;
10 | import java.nio.file.Files;
11 | import java.nio.file.Path;
12 | import java.nio.file.Paths;
13 | import java.util.Collections;
14 | import java.util.HashMap;
15 | import java.util.List;
16 | import java.util.Map;
17 | import java.util.function.Function;
18 | import java.util.stream.Collectors;
19 | import java.util.stream.IntStream;
20 |
21 | import static com.google.common.base.Preconditions.checkState;
22 |
23 | /**
24 | * Holds the vocabularies and the merges file used to encode and tokenize the inputs.
25 | */
26 | public class RobertaTokenizerResources {
27 |
28 | private static final String BASE_VOCABULARY_FILE_NAME = "base_vocabulary.json";
29 | private static final String VOCABULARY_FILE_NAME = "vocabulary.json";
30 | private static final String MERGES_FILE_NAME = "merges.txt";
31 |
32 | private final Map<Integer, String> baseVocabularyMap;
33 | private final Map<String, Long> vocabularyMap;
34 | private final Map<BiGram, Integer> bpeRanks;
35 |
36 | /**
37 |  * @param resourcesPath expecting this path to hold files with exactly these names:
38 |  * Base Vocabulary - base_vocabulary.json
39 | * Vocabulary - vocabulary.json
40 | * Merges - merges.txt
41 | */
42 | public RobertaTokenizerResources(@NonNull final String resourcesPath) {
43 | this.baseVocabularyMap = loadBaseVocabulary(resourcesPath);
44 | this.vocabularyMap = loadVocabulary(resourcesPath);
45 | this.bpeRanks = loadMergesFile(resourcesPath);
46 | }
47 |
48 | private Map<Integer, String> loadBaseVocabulary(@NonNull final String resourcesPath) {
49 | final Path baseVocabPath = Paths.get(resourcesPath, BASE_VOCABULARY_FILE_NAME);
50 | try {
51 | checkPathExists(baseVocabPath,
52 | String.format("base vocabulary file path for Roberta: [ %s ] was not found", baseVocabPath));
53 | final Map<Integer, String> baseVocabMap = new Gson()
54 | .fromJson(Files.readString(baseVocabPath), new TypeToken<Map<Integer, String>>(){}.getType());
55 | return Collections.unmodifiableMap(baseVocabMap);
56 | } catch (IOException e) {
57 | throw new IllegalStateException(String.format(
58 | "Failed to load base vocabulary map for Roberta from [ %s ]", baseVocabPath), e);
59 | }
60 | }
61 |
62 | private Map<String, Long> loadVocabulary(@NonNull final String resourcesPath) {
63 | final Path vocabPath = Paths.get(resourcesPath, VOCABULARY_FILE_NAME);
64 | try {
65 | checkPathExists(vocabPath,
66 | String.format("vocabulary file path for Roberta: [%s] was not found", vocabPath));
67 | final Map<String, Long> vocabMap = new Gson()
68 | .fromJson(Files.readString(vocabPath), new TypeToken<Map<String, Long>>(){}.getType());
69 | return Collections.unmodifiableMap(vocabMap);
70 | } catch (IOException e) {
71 | throw new IllegalStateException(String.format(
72 | "Failed to load vocabulary for Roberta from file path [ %s ]", vocabPath), e);
73 | }
74 | }
75 |
76 | /**
77 |  * This method allows the merges file to be with or without a header.
78 |  * Other than that, each line must contain exactly ONE BiGram, separated by a single space.
79 | *
80 | * @param resourcesPath resources dir path
81 | * @return the merges map
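 * Note: each line's index in the file becomes that BiGram's rank, so earlier lines have higher merge priority.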
82 | */
83 | private Map<BiGram, Integer> loadMergesFile(@NonNull final String resourcesPath) {
84 | final Path mergesPath = Paths.get(resourcesPath, MERGES_FILE_NAME);
85 | try {
86 | checkPathExists(mergesPath,
87 | String.format("%s merges file path: [%s] was not found", RobertaTokenizerResources.class.getSimpleName(),
88 | mergesPath));
89 |
90 | final List<String> lines = Files.readAllLines(mergesPath, StandardCharsets.UTF_8);
91 | final int startIndex = isMergesFileWithHeader(lines) ? 1 : 0;
92 |
93 | return IntStream.range(startIndex, lines.size()).boxed()
94 | .collect(Collectors.toUnmodifiableMap(idx -> BiGram.of(lines.get(idx).split(" ")), Function.identity()));
95 | } catch (IOException e) {
96 | throw new IllegalStateException(String.format(
97 | "Failed to load merges file for Roberta from file path [ %s ]", mergesPath), e);
98 | }
99 | }
100 |
101 | /**
102 | * Encoding the given key to a mapped String which represents a character from the base vocabulary.
103 |  * Since the input is of type byte, we expect only values in [-128, 127].
104 |  * The unsigned-int conversion shifts this range to [0, 255],
105 |  * the exact size of our base vocabulary map, which assures a valid lookup.
106 | *
107 | * @param key - byte to encode
108 | * @return associated String according to the base vocabulary json
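 * e.g., byte -1 is shifted to key 255 and returns "\u00ff"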
109 | */
110 | public String encodeByte(final byte key) {
111 | // In case the byte is negative, Byte.toUnsignedInt masks it with 0xFF (equivalent to adding 256), so it falls in [0, 255]
112 | // This solution was taken from the below StackOverflow thread
113 | // https://stackoverflow.com/questions/22575308/getbytes-returns-negative-number/22575346#22575346
114 | return baseVocabularyMap.get(Byte.toUnsignedInt(key));
115 | }
116 |
117 | /**
118 | * Converts a word into an integer (long) according to the word vocabulary file
119 | * @param word (or subword) after bpe was applied on it
120 | * @param defaultValue positive integer
121 | * @return mapped token according to the vocabulary or default value if it didn't exist
122 | */
123 | public Long encodeWord(@NonNull final String word, final long defaultValue) {
124 | return vocabularyMap.getOrDefault(word, defaultValue);
125 | }
126 |
127 | /**
128 | * Returns the rank for the given BiGram according to the rank file
129 | * @param biGram a pair of Strings
130 | * @param defaultValue positive integer
131 | * @return the rank of that pair or default value if it doesn't exist
132 | */
133 | public Integer getRankOrDefault(@NonNull final BiGram biGram, final int defaultValue) {
134 | return bpeRanks.getOrDefault(biGram, defaultValue);
135 | }
136 |
137 | /**
138 | * Since we use HuggingFace tokenizers, the merges file output might have a comment in the head of the file like:
139 | * "#version: 0.2 - Trained by `huggingface/tokenizers`"
140 | *
141 | * @param lines - all lines of the merges file
142 |  * @return true if the merges file starts with a comment, false otherwise
143 | */
144 | private boolean isMergesFileWithHeader(@NonNull final List<String> lines) {
145 | checkState(!lines.isEmpty(), "provided empty merges file");
146 | final String header = lines.get(0);
147 | return header.split(" ").length != 2;
148 | }
149 |
150 | private static void checkPathExists(final Path path, final String errorMsg) throws FileNotFoundException {
151 | if (!Files.exists(path)) {
152 | throw new FileNotFoundException(errorMsg);
153 | }
154 | }
155 | }
156 |
--------------------------------------------------------------------------------
/src/main/resources/checkstyle.xml:
--------------------------------------------------------------------------------
<!-- checkstyle module definitions not recoverable (stripped during extraction) -->
--------------------------------------------------------------------------------