├── src ├── test │ ├── resources │ │ └── test-vocabularies │ │ │ ├── merges.txt │ │ │ ├── vocabulary.json │ │ │ └── base_vocabulary.json │ └── java │ │ └── com │ │ └── genesys │ │ └── roberta │ │ └── tokenizer │ │ ├── utils │ │ └── CommonTestUtils.java │ │ ├── RobertaTokenizerResourcesTest.java │ │ ├── BytePairEncoderTest.java │ │ └── RobertaTokenizerTest.java └── main │ ├── java │ └── com │ │ └── genesys │ │ └── roberta │ │ └── tokenizer │ │ ├── Tokenizer.java │ │ ├── BiGram.java │ │ ├── RobertaTokenizer.java │ │ ├── BytePairEncoder.java │ │ └── RobertaTokenizerResources.java │ └── resources │ └── checkstyle.xml ├── .gitignore ├── LICENSE ├── README.md └── pom.xml /src/test/resources/test-vocabularies/merges.txt: -------------------------------------------------------------------------------- 1 | \u0120 l 2 | \u0120l o 3 | \u0120lo w 4 | e r 5 | -------------------------------------------------------------------------------- /src/test/resources/test-vocabularies/vocabulary.json: -------------------------------------------------------------------------------- 1 | { 2 | "<s>": 0, 3 | "<pad>": 1, 4 | "</s>": 2, 5 | "<unk>": 3, 6 | "l": 4, 7 | "o": 5, 8 | "w": 6, 9 | "e": 7, 10 | "r": 8, 11 | "s": 9, 12 | "t": 10, 13 | "i": 11, 14 | "d": 12, 15 | "n": 13, 16 | "\u0120": 114, 17 | "\u0120l": 15, 18 | "\u0120n": 16, 19 | "\u0120lo": 17, 20 | "\u0120low": 18, 21 | "er": 19, 22 | "\u0120lowest": 20, 23 | "\u0120newer": 21, 24 | "\u0120wider": 22 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/com/genesys/roberta/tokenizer/Tokenizer.java: -------------------------------------------------------------------------------- 1 | package com.genesys.roberta.tokenizer; 2 | 3 | /** 4 | * Implementations of this interface convert String inputs into arrays of long token IDs. 5 | */ 6 | interface Tokenizer { 7 | 8 | /** 9 | * Converts given input sentence to an array of long tokens 10 | * 11 | * @param sentence One or more words delimited by space 12 
| * @return array of input IDs with the appropriate tokens 13 | */ 14 | long[] tokenize(String sentence); 15 | } 16 | -------------------------------------------------------------------------------- /src/test/java/com/genesys/roberta/tokenizer/utils/CommonTestUtils.java: -------------------------------------------------------------------------------- 1 | package com.genesys.roberta.tokenizer.utils; 2 | 3 | import com.genesys.roberta.tokenizer.RobertaTokenizer; 4 | 5 | import java.net.URI; 6 | import java.net.URL; 7 | import java.nio.file.Paths; 8 | import java.util.Objects; 9 | 10 | public class CommonTestUtils { 11 | 12 | /** 13 | * Returns the absolute filesystem path of the "test-vocabularies" directory found on the test classpath. 14 | * The resource URL is converted through a URI instead of {@code URL#getFile()}, so percent-encoded 15 | * characters (e.g. a space encoded as %20 in the checkout path) are decoded correctly. 16 | * 17 | * @return absolute path of the test vocabularies directory 18 | * @throws NullPointerException if the directory is not on the test classpath 19 | */ 20 | public static String getResourceAbsPath() { 21 | String resourceRelPath = "test-vocabularies"; 22 | URL resourceUrl = Objects.requireNonNull(RobertaTokenizer.class.getClassLoader().getResource(resourceRelPath), 23 | "Test resource directory not found on classpath: " + resourceRelPath); 24 | return Paths.get(URI.create(resourceUrl.toExternalForm())).toAbsolutePath().toString(); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Node artifact files 2 | node_modules/ 3 | dist/ 4 | 5 | # Compiled Java class files 6 | *.class 7 | 8 | # Compiled Python bytecode 9 | *.py[cod] 10 | 11 | # Log files 12 | *.log 13 | 14 | # Package files 15 | *.jar 16 | 17 | # Maven 18 | target/ 19 | dist/ 20 | 21 | # JetBrains IDE 22 | OLD.idea/ 23 | .idea/ 24 | *.iml 25 | 26 | # Unit test reports 27 | TEST*.xml 28 | 29 | # Generated by MacOS 30 | .DS_Store 31 | /target/ 32 | 
__pycache__/ 33 | 34 | # Generated by Windows 35 | Thumbs.db 36 | 37 | # Applications 38 | *.app 39 | *.exe 40 | *.war 41 | 42 | # Large media files 43 | *.mp4 44 | *.tiff 45 | *.avi 46 | *.flv 47 | *.mov 48 | *.wmv 49 | 50 | **/test-output 51 | **/target 52 | **/resources/static/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Genesys Cloud Services, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /src/main/java/com/genesys/roberta/tokenizer/BiGram.java: -------------------------------------------------------------------------------- 1 | package com.genesys.roberta.tokenizer; 2 | 3 | import lombok.EqualsAndHashCode; 4 | import lombok.NonNull; 5 | 6 | import static com.google.common.base.Preconditions.checkNotNull; 7 | import static com.google.common.base.Preconditions.checkState; 8 | 9 | /** 10 | * An ordered pair of adjacent strings (left, right); used as the lookup key for BPE merge ranks. 11 | */ 12 | @EqualsAndHashCode 13 | class BiGram { 14 | private static final int PAIR_SIZE = 2; 15 | private final String left; 16 | private final String right; 17 | 18 | private BiGram(@NonNull final String[] pairArray) { 19 | checkState(pairArray.length == PAIR_SIZE, 20 | "Expecting BiGram pair to be of size: [%s] but it's of size: [%s]", PAIR_SIZE, pairArray.length); 21 | this.left = checkNotNull(pairArray[0], "BiGram left element must not be null"); 22 | this.right = checkNotNull(pairArray[1], "BiGram right element must not be null"); 23 | } 24 | 25 | private BiGram(@NonNull final String left, @NonNull final String right) { 26 | this.left = left; 27 | this.right = right; 28 | } 29 | 30 | /** 31 | * Creates a new BiGram object from an array of Strings. 32 | * Expects exactly 2 non-null elements: IllegalStateException on a wrong size, NullPointerException on a null. 33 | * @param pairArray array of Strings 34 | * @return new BiGram where the String pairArray[0] will be left and pairArray[1] right. 35 | */ 36 | public static BiGram of(@NonNull final String[] pairArray) { 37 | return new BiGram(pairArray); 38 | } 39 | 40 | /** 41 | * Creates a BiGram from the given left and right Strings. 
42 | * @param left will be the left String of the BiGram 43 | * @param right will be the right String of the BiGram 44 | * @return new BiGram object 45 | */ 46 | public static BiGram of(@NonNull final String left, @NonNull final String right) { 47 | return new BiGram(left, right); 48 | } 49 | 50 | public String getLeft() { 51 | return this.left; 52 | } 53 | 54 | public String getRight() { 55 | return this.right; 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/test/java/com/genesys/roberta/tokenizer/RobertaTokenizerResourcesTest.java: -------------------------------------------------------------------------------- 1 | package com.genesys.roberta.tokenizer; 2 | 3 | import lombok.val; 4 | import org.testng.Assert; 5 | import org.testng.annotations.BeforeClass; 6 | import org.testng.annotations.Test; 7 | 8 | import static com.genesys.roberta.tokenizer.utils.CommonTestUtils.getResourceAbsPath; 9 | 10 | public class RobertaTokenizerResourcesTest { 11 | 12 | private static final String VOCABULARY_BASE_DIR_PATH = getResourceAbsPath(); 13 | private static final long UNKNOWN_TOKEN = RobertaTokenizer.DEFAULT_UNK_TOKEN; 14 | private RobertaTokenizerResources robertaTokenizerResources; 15 | 16 | @BeforeClass 17 | public void initDataMembersBeforeClass() { 18 | robertaTokenizerResources = new RobertaTokenizerResources(VOCABULARY_BASE_DIR_PATH); 19 | } 20 | 21 | @Test(expectedExceptions = NullPointerException.class) 22 | public void nullBaseDirPath() { 23 | new RobertaTokenizerResources(null); 24 | } 25 | 26 | @Test(expectedExceptions = IllegalStateException.class) 27 | public void vocabularyBaseDirPathNotExist() { 28 | new RobertaTokenizerResources("dummy/base/dir/path"); 29 | } 30 | 31 | @Test 32 | public void minByteValue() { 33 | byte key = -128; 34 | val encodedChar = robertaTokenizerResources.encodeByte(key); 35 | Assert.assertNotNull(encodedChar); 36 | } 37 | 38 | @Test 39 | public void maxByteValue() { 40 | byte key 
= 127; 41 | val encodedChar = robertaTokenizerResources.encodeByte(key); 42 | Assert.assertNotNull(encodedChar); 43 | } 44 | 45 | @Test 46 | public void wordDoesNotExist() { 47 | String word = "Funnel"; 48 | Long actualToken = robertaTokenizerResources.encodeWord(word, UNKNOWN_TOKEN); 49 | Assert.assertEquals(actualToken.longValue(), UNKNOWN_TOKEN); 50 | } 51 | 52 | @Test 53 | public void wordExists() { 54 | String word = "er"; 55 | long expectedToken = 19; 56 | Long actualToken = robertaTokenizerResources.encodeWord(word, UNKNOWN_TOKEN); 57 | Assert.assertEquals(actualToken.longValue(), expectedToken); 58 | } 59 | 60 | @Test 61 | public void pairExists() { 62 | BiGram bigram = BiGram.of("e", "r"); 63 | int actualRank = robertaTokenizerResources.getRankOrDefault(bigram, Integer.MAX_VALUE); 64 | int expectedRank = 3; 65 | Assert.assertEquals(actualRank, expectedRank); 66 | } 67 | 68 | @Test 69 | public void pairDoesNotExist() { 70 | BiGram bigram = BiGram.of("Zilpa", "Funnel"); 71 | int actualRank = robertaTokenizerResources.getRankOrDefault(bigram, Integer.MAX_VALUE); 72 | Assert.assertEquals(actualRank, Integer.MAX_VALUE); 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /src/test/java/com/genesys/roberta/tokenizer/BytePairEncoderTest.java: -------------------------------------------------------------------------------- 1 | package com.genesys.roberta.tokenizer; 2 | 3 | import org.mockito.Mock; 4 | import org.mockito.MockitoAnnotations; 5 | import org.testng.Assert; 6 | import org.testng.annotations.BeforeClass; 7 | import org.testng.annotations.Test; 8 | 9 | import java.util.Arrays; 10 | import java.util.HashMap; 11 | import java.util.List; 12 | import java.util.Map; 13 | 14 | import static org.mockito.ArgumentMatchers.any; 15 | import static org.mockito.ArgumentMatchers.anyInt; 16 | import static org.mockito.Mockito.when; 17 | 18 | 19 | public class BytePairEncoderTest { 20 | 21 | private BytePairEncoder 
bytePairEncoder; 22 | private Map<BiGram, Integer> ranks; 23 | 24 | @Mock 25 | private RobertaTokenizerResources robertaTokenizerResources; 26 | 27 | @BeforeClass 28 | public void setupBeforeClass() { 29 | MockitoAnnotations.openMocks(this); 30 | ranks = new HashMap<>() {{ 31 | put(BiGram.of("Ġ", "l"), 0); 32 | put(BiGram.of("Ġl", "o"), 1); 33 | put(BiGram.of("Ġlo", "w"), 2); 34 | put(BiGram.of("e", "r"), 3); 35 | }}; 36 | bytePairEncoder = new BytePairEncoder(); 37 | when(robertaTokenizerResources.getRankOrDefault(any(BiGram.class), anyInt())) 38 | .thenAnswer(input -> ranks.getOrDefault(input.getArgument(0), Integer.MAX_VALUE)); 39 | } 40 | 41 | @Test(expectedExceptions = NullPointerException.class) 42 | public void nullWordTest() { 43 | bytePairEncoder.encode(null, robertaTokenizerResources); 44 | } 45 | 46 | @Test 47 | public void correctSplitTest() { 48 | List<String> actualSplit = bytePairEncoder.encode("lowerĠnewer", robertaTokenizerResources); 49 | // The vocabulary rules and characters were taken from here: 50 | // https://github.com/huggingface/transformers/blob/v4.20.1/tests/models/roberta/test_tokenization_roberta.py#L86 51 | List<String> expectedSplit = Arrays.asList("l", "o", "w", "er", "Ġ", "n", "e", "w", "er"); 52 | Assert.assertEquals(actualSplit, expectedSplit); 53 | } 54 | 55 | @Test 56 | public void emptySplitTest() { 57 | List<String> actualSplit = bytePairEncoder.encode("", robertaTokenizerResources); 58 | Assert.assertTrue(actualSplit.isEmpty()); 59 | } 60 | 61 | @Test 62 | public void noMergeRulesForWordTest() { 63 | List<String> actualSplit = bytePairEncoder.encode("qpyt", robertaTokenizerResources); 64 | // Since none of these characters appear in the ranks map, i.e., there is no merge rule for any of them, 65 | // we would expect each one to be encoded alone 66 | List<String> expectedSplit = Arrays.asList("q", "p", "y", "t"); 67 | Assert.assertEquals(actualSplit, expectedSplit); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # RoBERTa Java Tokenizer # 2 | 3 | 4 | ## About 5 | 6 | --- 7 | This repo contains a Java tokenizer for the RoBERTa model. The implementation mainly follows the HuggingFace Python 8 | RoBERTa tokenizer, but we also took references from other implementations, as mentioned in the code and below: 9 | 10 | * https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer 11 | 12 | * https://github.com/huggingface/tflite-android-transformers/blob/master/gpt2/src/main/java/co/huggingface/android_transformers/gpt2/tokenization/GPT2Tokenizer.kt 13 | 14 | * https://github.com/hyunwoongko/gpt2-tokenizer-java/blob/master/src/main/java/ai/tunib/tokenizer/GPT2Tokenizer.java 15 | 16 | The algorithm used is byte-level Byte Pair Encoding (BPE). 17 | 18 | https://huggingface.co/docs/transformers/tokenizer_summary#bytelevel-bpe 19 | ## How do I get set up? ### 20 | 21 | --- 22 | 23 | * Clone the repo to use the source directly. 24 | * Add the Maven dependency to your `pom.xml` to use it in your project: 25 | 26 | ```xml 27 | <dependency> 28 | <groupId>cloud.genesys</groupId> 29 | <artifactId>roberta-tokenizer</artifactId> 30 | <version>1.0.7</version> 31 | </dependency> 32 | 33 | 34 | 35 | ossrh 36 | https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ 37 | 38 | ... 39 | 40 | ``` 41 | 42 | 43 | ### Tests ### 44 | 45 | --- 46 | 47 | * Unit tests - run on the local machine. 48 | 49 | ### File Dependencies ### 50 | 51 | --- 52 | 53 | Since we want efficiency when initializing the tokenizer, we use a factory to create the relevant resource 54 | files and load them "lazily". 55 | 56 | For this tokenizer we need 3 data files: 57 | 58 | * `base_vocabulary.json` - map of numbers ([0,255]) to symbols (Unicode characters). Only those symbols will be known by the 59 | algorithm. E.g., given _s_ as input, it iterates over the bytes of the String _s_ and replaces each byte with its mapped symbol. 
60 | This way we ensure that only known characters are passed on. 
61 | 62 | * `vocabulary.json` - holds all the words (sub-words) and their token IDs, as learned during training. 63 | 64 | * `merges.txt` - describes the merge rules. The algorithm splits the given word into single characters and then repeatedly 65 | merges the adjacent pair of sub-words with the best rank. The earlier a rule appears in the file, the higher its priority. 66 | 67 | __Please note__: 68 | 69 | 1. All three files must be under the same directory. 70 | 71 | 2. They must be named as mentioned above. 72 | 73 | 3. The result of the tokenization depends on the vocabulary and merges files. 74 | 75 | ### Example ### 76 | 77 | --- 78 | 79 | ```java 80 | 81 | String baseDirPath = "base/dir/path"; 82 | RobertaTokenizerResources robertaResources = new RobertaTokenizerResources(baseDirPath); 83 | RobertaTokenizer robertaTokenizer = new RobertaTokenizer(robertaResources); 84 | ... 85 | String sentence = "this must be the place"; 86 | long[] tokenizedSentence = robertaTokenizer.tokenize(sentence); 87 | System.out.println(Arrays.toString(tokenizedSentence)); 88 | 89 | ``` 90 | 91 | An example output would be: `[0, 9226, 531, 28, 5, 317, 2]` - the exact tokens depend on the given vocabulary and merges files. 92 | 93 | ### Contribution guidelines 94 | 95 | --- 96 | 97 | * Use temporary branches for every issue/task. 
98 | -------------------------------------------------------------------------------- /src/main/java/com/genesys/roberta/tokenizer/RobertaTokenizer.java: -------------------------------------------------------------------------------- 1 | package com.genesys.roberta.tokenizer; 2 | 3 | import lombok.NonNull; 4 | import lombok.val; 5 | 6 | import java.nio.charset.StandardCharsets; 7 | import java.util.ArrayList; 8 | import java.util.List; 9 | import java.util.regex.Matcher; 10 | import java.util.regex.Pattern; 11 | import java.util.stream.LongStream; 12 | 13 | import static java.util.stream.LongStream.concat; 14 | import static java.util.stream.LongStream.of; 15 | 16 | /** 17 | * Tokenizer used for the RoBERTa model. 18 | * Encodes sentences into long integer tokens. 19 | * 20 | * This tokenizer is implemented according to the following: 21 | * - https://huggingface.co/docs/transformers/model_doc/roberta#transformers.RobertaTokenizer 22 | * - https://github.com/hyunwoongko/gpt2-tokenizer-java/blob/master/src/main/java/ai/tunib/tokenizer/GPT2Tokenizer.java 23 | * - https://github.com/huggingface/tflite-android-transformers/blob/master/gpt2/src/main/java/co/huggingface/android_transformers/\ 24 | * gpt2/tokenization/GPT2Tokenizer.kt 25 | */ 26 | public class RobertaTokenizer implements Tokenizer { 27 | 28 | public static final long DEFAULT_CLS_TOKEN = 0; 29 | public static final long DEFAULT_SEP_TOKEN = 2; 30 | public static final long DEFAULT_UNK_TOKEN = 3; 31 | 32 | // GPT-2 style pre-tokenization; (?U) = UNICODE_CHARACTER_CLASS so whitespace is Unicode-aware, as in Python. 33 | private static final Pattern PATTERN = Pattern 34 | .compile("(?U)'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)|\\s+"); 35 | 36 | // Special tokens 37 | private final long clsToken; // Also BOS (beginning of sequence) token 38 | private final long sepToken; // Also EOS (end of sequence) token 39 | private final long unkToken; // Unknown Token. 
40 | 41 | private final RobertaTokenizerResources robertaResources; 42 | private final BytePairEncoder bytePairEncoder; 43 | 44 | /** 45 | * Constructs a RoBERTa tokenizer, using byte-level Byte-Pair-Encoding. 46 | * 47 | * @param robertaTokenizerResources - responsible for providing roberta vocabularies and merges files. 48 | * 49 | * Note that this constructor will use HuggingFace's default special tokens: 50 | * [CLS_TOKEN = 0, SEP_TOKEN = 2, UNK_TOKEN = 3] 51 | */ 52 | public RobertaTokenizer(@NonNull final RobertaTokenizerResources robertaTokenizerResources) { 53 | this(robertaTokenizerResources, DEFAULT_CLS_TOKEN, DEFAULT_SEP_TOKEN, DEFAULT_UNK_TOKEN); 54 | } 55 | 56 | /** 57 | * Constructs a RoBERTa tokenizer, using byte-level Byte-Pair-Encoding. 58 | * 59 | * @param robertaTokenizerResources - responsible for providing roberta vocabularies and merges files. 60 | * @param clsToken Classification token 61 | * @param sepToken Separator token 62 | * @param unkToken Unknown token 63 | */ 64 | public RobertaTokenizer(@NonNull final RobertaTokenizerResources robertaTokenizerResources, final long clsToken, 65 | final long sepToken, final long unkToken) { 66 | this.robertaResources = robertaTokenizerResources; 67 | this.bytePairEncoder = new BytePairEncoder(); 68 | this.clsToken = clsToken; 69 | this.sepToken = sepToken; 70 | this.unkToken = unkToken; 71 | } 72 | 73 | /** 74 | * Encodes the given sentence into an array of tokens (long numbers) using byte-level Byte-Pair-Encoding.
75 | * 76 | * @param sentence one or more words (may also contain digits, punctuation and whitespace), split by PATTERN 77 | * @return clsToken, then the vocabulary token of every BPE sub-word (unkToken if missing), then sepToken 78 | */ 79 | @Override 80 | public long[] tokenize(@NonNull final String sentence) { 81 | List<String> encodedStrings = new ArrayList<>(); 82 | 83 | Matcher matcher = PATTERN.matcher(sentence); 84 | while (matcher.find()) { 85 | String matchedSequence = matcher.group(); 86 | val matchedSequenceEncoded = new StringBuilder(); 87 | // Map each UTF-8 byte of the match to its base-vocabulary symbol before running BPE 88 | for (byte b : matchedSequence.getBytes(StandardCharsets.UTF_8)) { 89 | String encodedByte = this.robertaResources.encodeByte(b); 90 | matchedSequenceEncoded.append(encodedByte); 91 | } 92 | 93 | encodedStrings.add(matchedSequenceEncoded.toString()); 94 | } 95 | 96 | LongStream outputTokens = encodedStrings.stream() 97 | // returns list of strings ready for vocabulary mapping 98 | .map(encodedStr -> bytePairEncoder.encode(encodedStr, robertaResources)) 99 | // mapping each word in the given lists to a Long token from the vocabulary 100 | .flatMapToLong(encodedStrList -> encodedStrList.stream() 101 | .mapToLong(word -> this.robertaResources.encodeWord(word, unkToken))); 102 | 103 | outputTokens = concat(of(clsToken), outputTokens); // adding BOS 104 | return concat(outputTokens, of(sepToken)).toArray(); // adding EOS 105 | } 106 | 107 | public long getClsToken() { 108 | return clsToken; 109 | } 110 | 111 | public long getSepToken() { 112 | return sepToken; 113 | } 114 | 115 | public long getUnkToken() { 116 | return unkToken; 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /src/main/java/com/genesys/roberta/tokenizer/BytePairEncoder.java: -------------------------------------------------------------------------------- 1 | package com.genesys.roberta.tokenizer; 2 | 3 | import lombok.NonNull; 4 | 5 | import java.util.ArrayList; 6 | import 
java.util.List; 7 | import java.util.Set; 8 | import java.util.stream.Collectors; 9 | import java.util.stream.IntStream; 10 | 11 | /** 12 | *
Byte-Pair-Encoding (BPE). 13 | * Relies on a pre-tokenizer that first splits the input into words (RobertaTokenizer uses a regex for this). 14 | * 15 | * This greedy algorithm splits the given word into single characters and then repeatedly merges the adjacent 16 | * pair (BiGram) with the lowest rank in the merges file, until no adjacent pair has a merge rule or the word 17 | * has been merged into a single symbol. 18 | */ 19 | class BytePairEncoder { 20 | 21 | /** 22 | * Applies the byte-level BPE algorithm to the given word. 23 | * 24 | * @param word one pre-tokenized word, already mapped to base-vocabulary symbols 25 | * @param robertaTokenizerRobertaResources provides the merge rank of each BiGram 26 | * @return the sub-words of the word after all applicable merges (an empty list for an empty word) 27 | */ 28 | public List<String> encode(@NonNull final String word, @NonNull RobertaTokenizerResources robertaTokenizerRobertaResources) { 29 | List<String> wordCharactersStrList = word.chars() 30 | .mapToObj(Character::toString) 31 | .collect(Collectors.toList()); 32 | 33 | Set<BiGram> biGramsSet = getBiGrams(wordCharactersStrList); 34 | 35 | while (true) { 36 | long minScore = Integer.MAX_VALUE; 37 | BiGram lowestScoreBiGram = null; 38 | 39 | for (BiGram biGram : biGramsSet) { 40 | long score = robertaTokenizerRobertaResources.getRankOrDefault(biGram, Integer.MAX_VALUE); 41 | 42 | // Note that we turn the most frequent bi-gram from a max problem to minimum 43 | // The lower the score the higher the frequency 44 | if (score < minScore) { 45 | minScore = score; 46 | lowestScoreBiGram = biGram; 47 | } 48 | } 49 | 50 | // lowestScoreBiGram stays null when no remaining BiGram has a merge rule (all got rank Integer.MAX_VALUE), 51 | // so no more merges can be done and the current wordCharactersStrList is the final partition of the word.
52 | if (lowestScoreBiGram == null) { 53 | break; 54 | } 55 | 56 | String first = lowestScoreBiGram.getLeft(); 57 | String second = lowestScoreBiGram.getRight(); 58 | List<String> newWordList = new ArrayList<>(); 59 | int currIdx = 0; 60 | // Rebuild the word, replacing each left-to-right, non-overlapping (first, second) pair with first + second 61 | while (currIdx < wordCharactersStrList.size()) { 62 | int biGramStartIndex = getIndexWithStartPosition(wordCharactersStrList, first, currIdx); 63 | 64 | if (biGramStartIndex != -1) { 65 | newWordList.addAll(wordCharactersStrList.subList(currIdx, biGramStartIndex)); 66 | currIdx = biGramStartIndex; 67 | } else { 68 | newWordList.addAll(wordCharactersStrList.subList(currIdx, wordCharactersStrList.size())); 69 | break; 70 | } 71 | 72 | if (wordCharactersStrList.get(currIdx).equals(first) && currIdx < wordCharactersStrList.size() - 1 && 73 | wordCharactersStrList.get(currIdx + 1).equals(second)) { 74 | newWordList.add(first + second); 75 | currIdx += 2; 76 | } else { 77 | newWordList.add(wordCharactersStrList.get(currIdx)); 78 | currIdx += 1; 79 | } 80 | } 81 | 82 | wordCharactersStrList = newWordList; 83 | if (wordCharactersStrList.size() == 1) { 84 | break; 85 | } else { 86 | biGramsSet = getBiGrams(wordCharactersStrList); 87 | } 88 | } 89 | 90 | return wordCharactersStrList; 91 | } 92 | 93 | /** 94 | * Builds the set of distinct adjacent BiGrams of the given symbols. 95 | * @param wordStrChars all characters of the word represented each by a String 96 | * @return set of all adjacent BiGrams (empty when there are fewer than two symbols) 97 | * e.g., "hello" will be given as input: ["h", "e", "l", "l", "o"] and will return {"he", "el","ll", "lo"} 98 | */ 99 | private Set<BiGram> getBiGrams(@NonNull final List<String> wordStrChars) { 100 | return IntStream.range(0, wordStrChars.size() - 1) 101 | .mapToObj(i -> BiGram.of(wordStrChars.get(i), wordStrChars.get(i + 1))) 102 | .collect(Collectors.toSet()); 103 | } 104 | 105 | /** 106 | * Finds the first index at or after startPosition whose element equals the given word. 107 | * 108 | * @param wordCharsList list of characters represented as Strings 109 | * @param word given word to search for 110 | * @param startPosition an index
to start the search from 111 | * @return the index found, otherwise -1 112 | */ 113 | private int getIndexWithStartPosition(@NonNull final List<String> wordCharsList, @NonNull final String word, 114 | final int startPosition) { 115 | return IntStream.range(startPosition, wordCharsList.size()) 116 | .filter(idx -> wordCharsList.get(idx).equals(word)) 117 | .findFirst() 118 | .orElse(-1); 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /src/test/resources/test-vocabularies/base_vocabulary.json: -------------------------------------------------------------------------------- 1 | { 2 | "33": "!", 3 | "34": "\"", 4 | "35": "#", 5 | "36": "$", 6 | "37": "%", 7 | "38": "&", 8 | "39": "'", 9 | "40": "(", 10 | "41": ")", 11 | "42": "*", 12 | "43": "+", 13 | "44": ",", 14 | "45": "-", 15 | "46": ".", 16 | "47": "/", 17 | "48": "0", 18 | "49": "1", 19 | "50": "2", 20 | "51": "3", 21 | "52": "4", 22 | "53": "5", 23 | "54": "6", 24 | "55": "7", 25 | "56": "8", 26 | "57": "9", 27 | "58": ":", 28 | "59": ";", 29 | "60": "<", 30 | "61": "=", 31 | "62": ">", 32 | "63": "?", 33 | "64": "@", 34 | "65": "A", 35 | "66": "B", 36 | "67": "C", 37 | "68": "D", 38 | "69": "E", 39 | "70": "F", 40 | "71": "G", 41 | "72": "H", 42 | "73": "I", 43 | "74": "J", 44 | "75": "K", 45 | "76": "L", 46 | "77": "M", 47 | "78": "N", 48 | "79": "O", 49 | "80": "P", 50 | "81": "Q", 51 | "82": "R", 52 | "83": "S", 53 | "84": "T", 54 | "85": "U", 55 | "86": "V", 56 | "87": "W", 57 | "88": "X", 58 | "89": "Y", 59 | "90": "Z", 60 | "91": "[", 61 | "92": "\\", 62 | "93": "]", 63 | "94": "^", 64 | "95": "_", 65 | "96": "`", 66 | "97": "a", 67 | "98": "b", 68 | "99": "c", 69 | "100": "d", 70 | "101": "e", 71 | "102": "f", 72 | "103": "g", 73 | "104": "h", 74 | "105": "i", 75 | "106": "j", 76 | "107": "k", 77 | "108": "l", 78 | "109": "m", 79 | "110": "n", 80 | "111": "o", 81 | "112": "p", 82 | "113": "q", 83 | "114": "r", 84 | "115": "s", 85 | "116": "t", 86 | "117": "u", 
87 | "118": "v", 88 | "119": "w", 89 | "120": "x", 90 | "121": "y", 91 | "122": "z", 92 | "123": "{", 93 | "124": "|", 94 | "125": "}", 95 | "126": "~", 96 | "161": "\u00a1", 97 | "162": "\u00a2", 98 | "163": "\u00a3", 99 | "164": "\u00a4", 100 | "165": "\u00a5", 101 | "166": "\u00a6", 102 | "167": "\u00a7", 103 | "168": "\u00a8", 104 | "169": "\u00a9", 105 | "170": "\u00aa", 106 | "171": "\u00ab", 107 | "172": "\u00ac", 108 | "174": "\u00ae", 109 | "175": "\u00af", 110 | "176": "\u00b0", 111 | "177": "\u00b1", 112 | "178": "\u00b2", 113 | "179": "\u00b3", 114 | "180": "\u00b4", 115 | "181": "\u00b5", 116 | "182": "\u00b6", 117 | "183": "\u00b7", 118 | "184": "\u00b8", 119 | "185": "\u00b9", 120 | "186": "\u00ba", 121 | "187": "\u00bb", 122 | "188": "\u00bc", 123 | "189": "\u00bd", 124 | "190": "\u00be", 125 | "191": "\u00bf", 126 | "192": "\u00c0", 127 | "193": "\u00c1", 128 | "194": "\u00c2", 129 | "195": "\u00c3", 130 | "196": "\u00c4", 131 | "197": "\u00c5", 132 | "198": "\u00c6", 133 | "199": "\u00c7", 134 | "200": "\u00c8", 135 | "201": "\u00c9", 136 | "202": "\u00ca", 137 | "203": "\u00cb", 138 | "204": "\u00cc", 139 | "205": "\u00cd", 140 | "206": "\u00ce", 141 | "207": "\u00cf", 142 | "208": "\u00d0", 143 | "209": "\u00d1", 144 | "210": "\u00d2", 145 | "211": "\u00d3", 146 | "212": "\u00d4", 147 | "213": "\u00d5", 148 | "214": "\u00d6", 149 | "215": "\u00d7", 150 | "216": "\u00d8", 151 | "217": "\u00d9", 152 | "218": "\u00da", 153 | "219": "\u00db", 154 | "220": "\u00dc", 155 | "221": "\u00dd", 156 | "222": "\u00de", 157 | "223": "\u00df", 158 | "224": "\u00e0", 159 | "225": "\u00e1", 160 | "226": "\u00e2", 161 | "227": "\u00e3", 162 | "228": "\u00e4", 163 | "229": "\u00e5", 164 | "230": "\u00e6", 165 | "231": "\u00e7", 166 | "232": "\u00e8", 167 | "233": "\u00e9", 168 | "234": "\u00ea", 169 | "235": "\u00eb", 170 | "236": "\u00ec", 171 | "237": "\u00ed", 172 | "238": "\u00ee", 173 | "239": "\u00ef", 174 | "240": "\u00f0", 175 | "241": "\u00f1", 176 | 
"242": "\u00f2", 177 | "243": "\u00f3", 178 | "244": "\u00f4", 179 | "245": "\u00f5", 180 | "246": "\u00f6", 181 | "247": "\u00f7", 182 | "248": "\u00f8", 183 | "249": "\u00f9", 184 | "250": "\u00fa", 185 | "251": "\u00fb", 186 | "252": "\u00fc", 187 | "253": "\u00fd", 188 | "254": "\u00fe", 189 | "255": "\u00ff", 190 | "0": "\u0100", 191 | "1": "\u0101", 192 | "2": "\u0102", 193 | "3": "\u0103", 194 | "4": "\u0104", 195 | "5": "\u0105", 196 | "6": "\u0106", 197 | "7": "\u0107", 198 | "8": "\u0108", 199 | "9": "\u0109", 200 | "10": "\u010a", 201 | "11": "\u010b", 202 | "12": "\u010c", 203 | "13": "\u010d", 204 | "14": "\u010e", 205 | "15": "\u010f", 206 | "16": "\u0110", 207 | "17": "\u0111", 208 | "18": "\u0112", 209 | "19": "\u0113", 210 | "20": "\u0114", 211 | "21": "\u0115", 212 | "22": "\u0116", 213 | "23": "\u0117", 214 | "24": "\u0118", 215 | "25": "\u0119", 216 | "26": "\u011a", 217 | "27": "\u011b", 218 | "28": "\u011c", 219 | "29": "\u011d", 220 | "30": "\u011e", 221 | "31": "\u011f", 222 | "32": "\u0120", 223 | "127": "\u0121", 224 | "128": "\u0122", 225 | "129": "\u0123", 226 | "130": "\u0124", 227 | "131": "\u0125", 228 | "132": "\u0126", 229 | "133": "\u0127", 230 | "134": "\u0128", 231 | "135": "\u0129", 232 | "136": "\u012a", 233 | "137": "\u012b", 234 | "138": "\u012c", 235 | "139": "\u012d", 236 | "140": "\u012e", 237 | "141": "\u012f", 238 | "142": "\u0130", 239 | "143": "\u0131", 240 | "144": "\u0132", 241 | "145": "\u0133", 242 | "146": "\u0134", 243 | "147": "\u0135", 244 | "148": "\u0136", 245 | "149": "\u0137", 246 | "150": "\u0138", 247 | "151": "\u0139", 248 | "152": "\u013a", 249 | "153": "\u013b", 250 | "154": "\u013c", 251 | "155": "\u013d", 252 | "156": "\u013e", 253 | "157": "\u013f", 254 | "158": "\u0140", 255 | "159": "\u0141", 256 | "160": "\u0142", 257 | "173": "\u0143" 258 | } 259 | -------------------------------------------------------------------------------- 
/src/test/java/com/genesys/roberta/tokenizer/RobertaTokenizerTest.java: -------------------------------------------------------------------------------- 1 | package com.genesys.roberta.tokenizer; 2 | 3 | import org.testng.Assert; 4 | import org.testng.annotations.BeforeClass; 5 | import org.testng.annotations.Test; 6 | 7 | import java.util.Arrays; 8 | 9 | import static com.genesys.roberta.tokenizer.utils.CommonTestUtils.getResourceAbsPath; 10 | 11 | public class RobertaTokenizerTest { 12 | private static final String VOCABULARY_BASE_DIR_PATH = getResourceAbsPath(); 13 | private long clsToken; 14 | private long sepToken; 15 | 16 | // Real tokenizer backed by the test vocabularies; nothing in this class needs to be mocked. 17 | private RobertaTokenizer robertaTokenizer; 18 | 19 | @BeforeClass 20 | public void initDataMembersBeforeClass() { 21 | RobertaTokenizerResources robertaResources = new RobertaTokenizerResources(VOCABULARY_BASE_DIR_PATH); 22 | robertaTokenizer = new RobertaTokenizer(robertaResources); 23 | clsToken = robertaTokenizer.getClsToken(); 24 | sepToken = robertaTokenizer.getSepToken(); 25 | } 26 | 27 | @Test(expectedExceptions = NullPointerException.class) 28 | public void nullResourcesFactory() { 29 | new RobertaTokenizer(null); 30 | } 31 | 32 | @Test 33 | public void longSentenceWithTruncating() { 34 | // er token is 19, this sentence holds 24 occurrences of "er" 35 | String sentence = "erererererererererererererererererererererererer"; 36 | long expectedToken = 19; 37 | long[] actualEncoding = robertaTokenizer.tokenize(sentence); 38 | Assert.assertEquals(actualEncoding[0], clsToken); 39 | Assert.assertTrue(Arrays.stream(actualEncoding).skip(1).takeWhile(token -> token != sepToken) 40 | .allMatch(token -> token == expectedToken)); 41 | Assert.assertEquals(actualEncoding[actualEncoding.length - 1], sepToken); 42 | } 43 | 44 | @Test 45 | public void addingBeginningAndEndTokensToSentence() { 46 | String sentence = "er"; 47 | 
long expectedToken = 19; 48 | long[] actualTokens = robertaTokenizer.tokenize(sentence); 49 | Assert.assertEquals(actualTokens[0], clsToken); 50 | Assert.assertEquals(actualTokens[1], expectedToken); 51 | Assert.assertEquals(actualTokens[2], sepToken); 52 | } 53 | 54 | /** 55 | * Since this sentence is well-defined according to the vocabulary, we know what tokens to expect. 56 | * Taken from here: 57 | * https://github.com/huggingface/transformers/blob/v4.20.1/tests/models/roberta/test_tokenization_roberta.py#L94 58 | */ 59 | @Test 60 | public void tokenizeCorrectly() { 61 | String sentence = "lower newer"; 62 | long[] expectedTokens = { 63 | clsToken, 64 | 4, 5, 6, 19, // lower 65 | 114, 13, 7, 6, 19, // newer 66 | sepToken}; 67 | long[] actualTokens = robertaTokenizer.tokenize(sentence); 68 | Assert.assertEquals(actualTokens, expectedTokens); 69 | } 70 | 71 | @Test 72 | public void emptySentence() { 73 | long[] actualTokens = robertaTokenizer.tokenize(""); 74 | Assert.assertEquals(actualTokens[0], clsToken); 75 | Assert.assertEquals(actualTokens[1], sepToken); 76 | } 77 | 78 | @Test 79 | public void veryLongWord() { 80 | String originalText = 81 | "https://www.google.com/search?as_q=you+have+to+write+a+really+really+long+search+to+get+to+2000+" + 82 | "characters.+like+seriously%2C+you+have+no+idea+how+long+it+has+to+be&as_epq=2000+characters+" + 83 | "is+absolutely+freaking+enormous.+You+can+fit+sooooooooooooooooooooooooooooooooo+much+data+" + 84 | "into+2000+characters.+My+hands+are+getting+tired+typing+this+many+characters.+I+didn%27t+" + 85 | "even+realise+how+long+it+was+going+to+take+to+type+them+all.&as_oq=Argh!+So+many+" + 86 | "characters.+I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.+I%27m+bored+now%2C+so+I%27ll+" + 87 | "just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+" + 88 | "so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+" + 89 | 
"now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+" + 90 | "bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+" + 91 | "paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+" + 92 | "copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+" + 93 | "I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+" + 94 | "now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+" + 95 | "bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+" + 96 | "paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+" + 97 | "copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+" + 98 | "I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+" + 99 | "now%2C+so+I%27ll+just+copy+and+paste.I%27m+bored+now%2C+so+I%27ll+just+copy+and+" + 100 | "paste.&as_eq=It+has+to+be+freaking+enormously+freaking+enormous&as_nlo=123&as_nhi=456&lr=" + 101 | "lang_hu&cr=countryAD&as_qdr=m&as_sitesearch=stackoverflow.com&as_occt=title&safe=active&tbs=" + 102 | "rl%3A1%2Crls%3A0&as_filetype=xls&as_rights=(cc_publicdomain%7Ccc_attribute%7Ccc_sharealike%" + 103 | "7Ccc_nonderived).-(cc_noncommercial)&gws_rd=ssl"; 104 | long[] actualTokens = robertaTokenizer.tokenize(originalText); 105 | Assert.assertEquals(actualTokens[0], clsToken); 106 | Assert.assertTrue(actualTokens.length > 2); 107 | Assert.assertEquals(actualTokens[actualTokens.length - 1], sepToken); 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | 5 | cloud.genesys 6 | roberta-tokenizer 7 | 1.0.7 8 | jar 9 | Tokenizer for RoBERTa model 10 | roberta-tokenizer 11 | 
https://github.com/purecloudlabs/roberta-tokenizer 12 | 13 | 14 | 15 | MIT License 16 | http://www.opensource.org/licenses/mit-license.php 17 | 18 | 19 | 20 | 21 | 22 | Raviv Trichter 23 | 
raviv.trichter@genesys.com 24 | Genesys 25 | https://github.com/purecloudlabs 26 | 27 | 28 | 29 | 30 | https://github.com/purecloudlabs/roberta-tokenizer.git 31 | https://github.com/purecloudlabs/roberta-tokenizer.git 32 | https://github.com/purecloudlabs/roberta-tokenizer.git 33 | 34 | 35 | 36 | UTF-8 37 | 17 38 | 17 39 | 17 40 | 41 | 42 | 1.18.24 43 | 31.1-jre 44 | 2.9.0 45 | 4.3.1 46 | 7.0.0 47 | 3.10.1 48 | 3.0.0-M7 49 | 3.4.1 50 | 3.2.1 51 | 3.0.1 52 | 1.6.13 53 | 54 | 55 | 56 | 57 | ossrh 58 | https://s01.oss.sonatype.org/content/repositories/snapshots 59 | 60 | 61 | ossrh 62 | https://s01.oss.sonatype.org/service/local/staging/deploy/maven2/ 63 | 64 | 65 | 66 | 67 | 68 | 69 | org.sonatype.plugins 70 | nexus-staging-maven-plugin 71 | ${nexus.staging.maven.plugin.version} 72 | true 73 | 74 | ossrh 75 | https://s01.oss.sonatype.org/ 76 | true 77 | 78 | 79 | 80 | org.apache.maven.plugins 81 | maven-compiler-plugin 82 | ${maven.compiler.plugin.version} 83 | 84 | ${java.version} 85 | ${java.version} 86 | 87 | 88 | org.projectlombok 89 | lombok 90 | ${lombok.version} 91 | 92 | 93 | 94 | 95 | 96 | org.apache.maven.plugins 97 | maven-surefire-plugin 98 | ${maven.surefire.plugin.version} 99 | 100 | false 101 | 102 | 103 | 104 | org.apache.maven.plugins 105 | maven-gpg-plugin 106 | ${maven.gpg.plugin} 107 | 108 | 109 | sign-artifacts 110 | verify 111 | 112 | sign 113 | 114 | 115 | 116 | 117 | 118 | org.apache.maven.plugins 119 | maven-source-plugin 120 | ${maven.source.plugin.version} 121 | 122 | 123 | attach-sources 124 | 125 | jar-no-fork 126 | 127 | 128 | 129 | 130 | 131 | org.apache.maven.plugins 132 | maven-javadoc-plugin 133 | ${maven.javadoc.plugin.version} 134 | 135 | 136 | attach-javadocs 137 | 138 | jar 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | org.mockito 149 | mockito-core 150 | ${mockito.core.version} 151 | test 152 | 153 | 154 | org.projectlombok 155 | lombok 156 | ${lombok.version} 157 | 158 | 159 | com.google.guava 160 | guava 161 | 
${com.google.guava.version} 162 | 163 | 164 | com.google.code.gson 165 | gson 166 | ${com.google.gson.version} 167 | 168 | 169 | org.testng 170 | testng 171 | ${testng.version} 172 | test 173 | 174 | 175 | 176 | -------------------------------------------------------------------------------- /src/main/java/com/genesys/roberta/tokenizer/RobertaTokenizerResources.java: -------------------------------------------------------------------------------- 1 | package com.genesys.roberta.tokenizer; 2 | 3 | import com.google.gson.Gson; 4 | import com.google.gson.reflect.TypeToken; 5 | import lombok.NonNull; 6 | 7 | import java.io.FileNotFoundException; 8 | import java.io.IOException; 9 | import java.nio.charset.StandardCharsets; 10 | import java.nio.file.Files; 11 | import java.nio.file.Path; 12 | import java.nio.file.Paths; 13 | import java.util.Collections; 14 | import java.util.HashMap; 15 | import java.util.List; 16 | import java.util.Map; 17 | import java.util.function.Function; 18 | import java.util.stream.Collectors; 19 | import java.util.stream.IntStream; 20 | 21 | import static com.google.common.base.Preconditions.checkState; 22 | 23 | /** 24 | * Holds the vocabularies and the merges file used to encode and tokenize the inputs. 
25 | */ 26 | public class RobertaTokenizerResources { 27 | 28 | private static final String BASE_VOCABULARY_FILE_NAME = "base_vocabulary.json"; 29 | private static final String VOCABULARY_FILE_NAME = "vocabulary.json"; 30 | private static final String MERGES_FILE_NAME = "merges.txt"; 31 | 32 | private final Map baseVocabularyMap; 33 | private final Map vocabularyMap; 34 | private final Map bpeRanks; 35 | 36 | /** 37 | * @param resourcesPath expecting this path to hold (with their names): 38 | * Base Vocabulary - base_vocabulary.txt 39 | * Vocabulary - vocabulary.json 40 | * Merges - merges.txt 41 | */ 42 | public RobertaTokenizerResources(@NonNull final String resourcesPath) { 43 | this.baseVocabularyMap = loadBaseVocabulary(resourcesPath); 44 | this.vocabularyMap = loadVocabulary(resourcesPath); 45 | this.bpeRanks = loadMergesFile(resourcesPath); 46 | } 47 | 48 | private Map loadBaseVocabulary(@NonNull final String resourcesPath) { 49 | final Path baseVocabPath = Paths.get(resourcesPath, BASE_VOCABULARY_FILE_NAME); 50 | try { 51 | checkPathExists(baseVocabPath, 52 | String.format("base vocabulary file path for Roberta: [ %s ] was not found", baseVocabPath)); 53 | final Map baseVocabMap = new Gson() 54 | .fromJson(Files.readString(baseVocabPath), new TypeToken>(){}.getType()); 55 | return Collections.unmodifiableMap(baseVocabMap); 56 | } catch (IOException e) { 57 | throw new IllegalStateException(String.format( 58 | "Failed to load base vocabulary map for Roberta from [ %s ]", baseVocabPath), e); 59 | } 60 | } 61 | 62 | private Map loadVocabulary(@NonNull final String resourcesPath) { 63 | final Path vocabPath = Paths.get(resourcesPath, VOCABULARY_FILE_NAME); 64 | try { 65 | checkPathExists(vocabPath, 66 | String.format("vocabulary file path for Roberta: [%s] was not found", vocabPath)); 67 | final Map vocabMap = new Gson() 68 | .fromJson(Files.readString(vocabPath), new TypeToken>(){}.getType()); 69 | return Collections.unmodifiableMap(vocabMap); 70 | } catch 
(IOException e) { 71 | throw new IllegalStateException(String.format( 72 | "Failed to load vocabulary for Roberta from file path [ %s ]", vocabPath), e); 73 | } 74 | } 75 | 76 | /** 77 | * This method allows merges file to be with or without the header. 78 | * Other than that, it will accept in every line one BiGram ONLY, split by one space. 79 | * 80 | * @param resourcesPath resources dir path 81 | * @return the merges map 82 | */ 83 | private Map loadMergesFile(@NonNull final String resourcesPath) { 84 | final Path mergesPath = Paths.get(resourcesPath, MERGES_FILE_NAME); 85 | try { 86 | checkPathExists(mergesPath, 87 | String.format("%s merges file path: [%s] was not found", RobertaTokenizerResources.class.getSimpleName(), 88 | mergesPath)); 89 | 90 | final List lines = Files.readAllLines(mergesPath, StandardCharsets.UTF_8); 91 | final int startIndex = isMergesFileWithHeader(lines) ? 1 : 0; 92 | 93 | return IntStream.range(startIndex, lines.size()).boxed() 94 | .collect(Collectors.toUnmodifiableMap(idx -> BiGram.of(lines.get(idx).split(" ")), Function.identity())); 95 | } catch (IOException e) { 96 | throw new IllegalStateException(String.format( 97 | "Failed to load merges file for Roberta from file path [ %s ]", mergesPath), e); 98 | } 99 | } 100 | 101 | /** 102 | * Encoding the given key to a mapped String which represents a character from the base vocabulary. 103 | * Since the input is of type byte values we except only values [-127, 128]. 104 | * Shifting the range with the unsigned int operation to [0, 255] \ 105 | * the exact size of our base vocab map - what assures us valid input. 
106 | * 107 | * @param key - byte to encode 108 | * @return associated String according to the base vocabulary json 109 | */ 110 | public String encodeByte(final byte key) { 111 | // In case the byte is negative we add to it 256 by a Bitwise AND so it will be in range [0, 255] 112 | // This solution was taken from the below StackOverflow thread 113 | // https://stackoverflow.com/questions/22575308/getbytes-returns-negative-number/22575346#22575346 114 | return baseVocabularyMap.get(Byte.toUnsignedInt(key)); 115 | } 116 | 117 | /** 118 | * Converts a word into an integer (long) according to the word vocabulary file 119 | * @param word (or subword) after bpe was applied on it 120 | * @param defaultValue positive integer 121 | * @return mapped token according to the vocabulary or default value if it didn't exist 122 | */ 123 | public Long encodeWord(@NonNull final String word, final long defaultValue) { 124 | return vocabularyMap.getOrDefault(word, defaultValue); 125 | } 126 | 127 | /** 128 | * Returns the rank for the given BiGram according to the rank file 129 | * @param biGram a pair of Strings 130 | * @param defaultValue positive integer 131 | * @return the rank of that pair or default value if it doesn't exist 132 | */ 133 | public Integer getRankOrDefault(@NonNull final BiGram biGram, final int defaultValue) { 134 | return bpeRanks.getOrDefault(biGram, defaultValue); 135 | } 136 | 137 | /** 138 | * Since we use HuggingFace tokenizers, the merges file output might have a comment in the head of the file like: 139 | * "#version: 0.2 - Trained by `huggingface/tokenizers`" 140 | * 141 | * @param lines - all lines of the merges file 142 | * @return true if merges file starts with a comment and false o.w. 
143 | */ 144 | private boolean isMergesFileWithHeader(@NonNull final List lines) { 145 | checkState(!lines.isEmpty(), "provided empty merges file"); 146 | final String header = lines.get(0); 147 | return header.split(" ").length != 2; 148 | } 149 | 150 | private static void checkPathExists(final Path path, final String errorMsg) throws FileNotFoundException { 151 | if (!Files.exists(path)) { 152 | throw new FileNotFoundException(errorMsg); 153 | } 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /src/main/resources/checkstyle.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 47 | 48 | 49 | 50 | 51 | 53 | 54 | 55 | 56 | 57 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 67 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 105 | 106 | 107 | 108 | 110 | 111 | 112 | 113 | 115 | 116 | 117 | 118 | 120 | 121 | 122 | 123 | 125 | 126 | 127 | 128 | 130 | 131 | 132 | 133 | 134 | 136 | 137 | 138 | 139 | 141 | 142 | 143 | 144 | 146 | 147 | 148 | 149 | 151 | 152 | 153 | 154 | 156 | 157 | 158 | 159 | 161 | 162 | 163 | 164 | 166 | 168 | 170 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | --------------------------------------------------------------------------------