├── requirements.txt ├── src └── main │ ├── resources │ ├── model │ │ └── dl │ │ │ └── roberta.onnx │ └── properties.properties │ └── java │ ├── sy │ ├── bert │ │ ├── tokenizer │ │ │ └── Tokenizer.java │ │ ├── LoadModel.java │ │ ├── tokenizerimpl │ │ │ ├── BasicTokenizer.java │ │ │ ├── WordpieceTokenizer.java │ │ │ └── BertTokenizer.java │ │ └── utils │ │ │ └── TokenizerUtils.java │ └── BertMask.java │ └── util │ ├── PropertiesReader.java │ └── CollectionUtil.java ├── .gitignore ├── onnx-java.iml ├── .idea ├── codeStyles │ ├── codeStyleConfig.xml │ └── Project.xml ├── .gitignore ├── misc.xml ├── compiler.xml └── jarRepositories.xml ├── pom.xml └── README.MD /requirements.txt: -------------------------------------------------------------------------------- 1 | java11 2 | onnxruntime -------------------------------------------------------------------------------- /src/main/resources/model/dl/roberta.onnx: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project exclude paths 2 | /target/ -------------------------------------------------------------------------------- /onnx-java.iml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/main/resources/properties.properties: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jiangnanboy/onnx-java/HEAD/src/main/resources/properties.properties -------------------------------------------------------------------------------- /.idea/codeStyles/codeStyleConfig.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | -------------------------------------------------------------------------------- /src/main/java/sy/bert/tokenizer/Tokenizer.java: -------------------------------------------------------------------------------- 1 | package sy.bert.tokenizer; 2 | 3 | import java.util.List; 4 | 5 | /** 6 | * @author sy 7 | * @date 2022/5/2 14:03 8 | */ 9 | public interface Tokenizer { 10 | public List tokenize(String text); 11 | } 12 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Zeppelin ignored files 8 | /ZeppelinRemoteNotebooks/ 9 | # Editor-based HTTP Client requests 10 | /httpRequests/ 11 | -------------------------------------------------------------------------------- /.idea/codeStyles/Project.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /.idea/compiler.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /src/main/java/util/PropertiesReader.java: -------------------------------------------------------------------------------- 1 | package util; 2 | 3 | import java.io.IOException; 4 | import java.io.InputStreamReader; 5 | import java.util.Properties; 6 | 7 | /** 8 | * @author sy 9 | * @date 2022/2/2 9:07 10 | */ 11 | public class PropertiesReader { 12 | private static Properties properties = new Properties(); 13 | 14 | static { 15 | try { 16 | properties.load(new InputStreamReader(PropertiesReader.class.getClassLoader().getResourceAsStream("properties.properties"), "UTF-8")); 17 | 18 | } catch (IOException e) { 19 | e.printStackTrace(); 20 | } 21 | } 22 | 23 | public static String get(String keyName) { 24 | return properties.getProperty(keyName); 25 | } 26 | 27 | } 28 | -------------------------------------------------------------------------------- /.idea/jarRepositories.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 9 | 10 | 14 | 15 | 19 | 20 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | org.example 8 | onnx-java 9 | 1.0-SNAPSHOT 10 | 11 | 12 | 13 | com.microsoft.onnxruntime 14 | onnxruntime 15 | 1.11.0 16 | 17 | 18 | 19 | org.apache.commons 20 | commons-lang3 21 | 3.7 22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | 这里利用java加载onnx模型,并进行推理。 2 | 3 | #### 步骤 4 | 5 | 1.利用java加载onnx模型,并进行推理预测。这里采用roberta模型的onnx版。 6 | 7 | 2.pytorch模型从[这里](https://huggingface.co/uer/chinese_roberta_L-2_H-512) 下载。 8 | 9 | 3.pytorch模型转为onnx见[这里](https://github.com/jiangnanboy/model2onnx) 。 10 | 11 | #### 使用 12 | 1.sy/BertMask 13 | 14 | ``` 15 | String text = "中国的首都是[MASK]京。"; 16 | Triple, Integer> triple = null; 17 | try { 18 | triple = parseInputText(text); 19 | } catch (Exception e) { 20 | e.printStackTrace(); 21 | } 22 | var maskPredictions = predMask(triple); 23 | System.out.println(maskPredictions); 24 | ``` 25 | 26 | 2.result 27 | ``` 28 | String text = "中国的首都是[MASK]京。"; 29 | 30 | tokens -> [[CLS], 中, 国, 的, 首, 都, 是, [MASK], 京, 。, [SEP]] 31 | [MASK] predictions -> [北, 南, 东, 燕, 望] 32 | 33 | String text = "我家后面有一[MASK]大树。"; 34 | 35 | tokens -> [[CLS], 我, 家, 后, 面, 有, 一, [MASK], 大, 树, 。, [SEP]] 36 | [MASK] predictions -> [棵, 个, 株, 只, 颗] 37 | ``` 38 | 39 | #### 参考 40 | https://github.com/jiangnanboy/model2onnx 41 | 42 | https://huggingface.co/uer/chinese_roberta_L-2_H-512 43 | 44 | https://arxiv.org/pdf/1907.11692.pdf 45 | 46 | https://github.com/ankiteciitkgp/bertTokenizer 47 | 48 | https://arxiv.org/pdf/1810.04805.pdf 49 | -------------------------------------------------------------------------------- /src/main/java/util/CollectionUtil.java: -------------------------------------------------------------------------------- 1 | package util; 2 | 3 | import java.util.*; 4 | 5 | /** 6 | * @author sy 7 | * @date 2022/2/2 9:36 8 | */ 9 | public class CollectionUtil { 10 | 11 | public static List newArrayList() { 12 | return new ArrayList<>(); 13 | } 14 | 15 | public static List newArrayList(List list) { 16 | return new ArrayList<>(list); 17 | } 18 | 19 | public static LinkedList newLinkedList() { 20 | return new LinkedList<>(); 21 | } 22 | 23 | public static List newArrayList(int N) { 24 | return new ArrayList<>(N); 25 | } 26 | 27 | public static List newArrayList(Set entry) { 28 | return new ArrayList<>(entry); 29 | } 30 | 31 | public static Set newHashset() { 32 | return new HashSet<>(); 33 | } 34 | 35 | public static Set newHashset(List entry) { 36 | return new HashSet<>(entry); 37 | } 38 | 39 | public static Map newHashMap() { 40 | return new HashMap<>(); 41 | } 42 | 43 | public static LinkedHashMap newLinkedHashMap() { 44 | return new LinkedHashMap<>(); 45 | } 46 | 47 | public static Map newTreeMap() { 48 | return new TreeMap<>(); 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /src/main/java/sy/bert/LoadModel.java: -------------------------------------------------------------------------------- 1 | package sy.bert; 2 | 3 | import ai.onnxruntime.OrtEnvironment; 4 | import ai.onnxruntime.OrtException; 5 | import ai.onnxruntime.OrtSession; 6 | import util.PropertiesReader; 7 | 8 | import java.util.Optional; 9 | 10 | /** 11 | * @author sy 12 | * @date 2022/5/2 15:51 13 | */ 14 | public class LoadModel { 15 | 16 | public static OrtSession session; 17 | public static OrtEnvironment env; 18 | /** 19 | * load onnx model 20 | * @throws OrtException 21 | */ 22 | public static void loadOnnxModel() throws OrtException { 23 | System.out.println("load onnx model..."); 24 | String onnxPath = LoadModel.class.getClassLoader().getResource(PropertiesReader.get("onnx_model_path")).getPath().replaceFirst("/", ""); 25 | env = OrtEnvironment.getEnvironment(); 26 | session = env.createSession(onnxPath, new OrtSession.SessionOptions()); 27 | } 28 | 29 | /** 30 | * close onnx model 31 | */ 32 | public static void closeOnnxModel() { 33 | System.out.println("close onnx model..."); 34 | if (Optional.of(session).isPresent()) { 35 | try { 36 | session.close(); 37 | } catch (OrtException e) { 38 | e.printStackTrace(); 39 | } 40 | } 41 | if(Optional.of(env).isPresent()) { 42 | env.close(); 43 | } 44 | } 45 | 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/sy/bert/tokenizerimpl/BasicTokenizer.java: -------------------------------------------------------------------------------- 1 | package sy.bert.tokenizerimpl; 2 | 3 | import sy.bert.tokenizer.Tokenizer; 4 | import sy.bert.utils.TokenizerUtils; 5 | import util.CollectionUtil; 6 | 7 | import java.util.List; 8 | 9 | /** 10 | * @author sy 11 | * @date 2022/5/2 14:03 12 | */ 13 | public class BasicTokenizer implements Tokenizer { 14 | private boolean do_lower_case = true; 15 | private List never_split; 16 | private boolean tokenize_chinese_chars = true; 17 | private List specialTokens; 18 | public BasicTokenizer(boolean do_lower_case, List never_split, boolean tokenize_chinese_chars) { 19 | this.do_lower_case = do_lower_case; 20 | if (never_split == null) { 21 | this.never_split = CollectionUtil.newArrayList(); 22 | } else { 23 | this.never_split = never_split; 24 | } 25 | this.tokenize_chinese_chars = tokenize_chinese_chars; 26 | } 27 | 28 | public BasicTokenizer() { 29 | } 30 | 31 | @Override 32 | public List tokenize(String text) { 33 | text = TokenizerUtils.clean_text(text); 34 | if (tokenize_chinese_chars) { 35 | text = TokenizerUtils.tokenize_chinese_chars(text); 36 | } 37 | List orig_tokens = TokenizerUtils.whitespace_tokenize(text); 38 | List split_tokens = CollectionUtil.newArrayList(); 39 | for (String token : orig_tokens) { 40 | if (do_lower_case && !never_split.contains(token)) { 41 | token = TokenizerUtils.run_strip_accents(token); 42 | split_tokens.addAll(TokenizerUtils.run_split_on_punc(token, never_split)); 43 | } else { 44 | split_tokens.add(token); 45 | } 46 | } 47 | return TokenizerUtils.whitespace_tokenize(String.join(" ", split_tokens)); 48 | } 49 | 50 | } 51 | -------------------------------------------------------------------------------- /src/main/java/sy/bert/tokenizerimpl/WordpieceTokenizer.java: -------------------------------------------------------------------------------- 1 | package sy.bert.tokenizerimpl; 2 | 3 | import sy.bert.tokenizer.Tokenizer; 4 | import sy.bert.utils.TokenizerUtils; 5 | import util.CollectionUtil; 6 | 7 | import java.util.List; 8 | import java.util.Map; 9 | 10 | /** 11 | * @author sy 12 | * @date 2022/5/2 14:03 13 | */ 14 | public class WordpieceTokenizer implements Tokenizer { 15 | private Map vocab; 16 | private String unk_token; 17 | private int max_input_chars_per_word; 18 | private List specialTokensList; 19 | 20 | public WordpieceTokenizer(Map vocab, String unk_token, int max_input_chars_per_word) { 21 | this.vocab = vocab; 22 | this.unk_token = unk_token; 23 | this.max_input_chars_per_word = max_input_chars_per_word; 24 | } 25 | 26 | public WordpieceTokenizer(Map vocab, String unk_token, List specialTokensList) { 27 | this.vocab = vocab; 28 | this.unk_token = unk_token; 29 | this.specialTokensList = specialTokensList; 30 | this.max_input_chars_per_word = 100; 31 | } 32 | 33 | @Override 34 | public List tokenize(String text) { 35 | List output_tokens = CollectionUtil.newArrayList(); 36 | if(this.specialTokensList.contains(text)) { 37 | output_tokens.add(text); 38 | return output_tokens; 39 | } 40 | for (String token : TokenizerUtils.whitespace_tokenize(text)) { 41 | if (token.length() > max_input_chars_per_word) { 42 | output_tokens.add(unk_token); 43 | continue; 44 | } 45 | boolean is_bad = false; 46 | int start = 0; 47 | 48 | List sub_tokens = CollectionUtil.newArrayList(); 49 | while (start < token.length()) { 50 | int end = token.length(); 51 | String cur_substr = ""; 52 | while (start < end) { 53 | String substr = token.substring(start, end); 54 | if (start > 0) { 55 | substr = "##" + substr; 56 | } 57 | if (vocab.containsKey(substr)) { 58 | cur_substr = substr; 59 | break; 60 | } 61 | end -= 1; 62 | } 63 | if (cur_substr == "") { 64 | is_bad = true; 65 | break; 66 | } 67 | sub_tokens.add(cur_substr); 68 | start = end; 69 | } 70 | if (is_bad) { 71 | output_tokens.add(unk_token); 72 | } else { 73 | output_tokens.addAll(sub_tokens); 74 | } 75 | } 76 | return output_tokens; 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/main/java/sy/bert/tokenizerimpl/BertTokenizer.java: -------------------------------------------------------------------------------- 1 | package sy.bert.tokenizerimpl; 2 | 3 | import sy.bert.tokenizer.Tokenizer; 4 | import sy.bert.utils.TokenizerUtils; 5 | import util.CollectionUtil; 6 | import util.PropertiesReader; 7 | 8 | import java.io.IOException; 9 | import java.util.List; 10 | import java.util.Map; 11 | import java.util.stream.Collectors; 12 | 13 | /** 14 | * @author sy 15 | * @date 2022/5/2 14:03 16 | */ 17 | public class BertTokenizer implements Tokenizer { 18 | private String vocab_file = BertTokenizer.class.getClassLoader().getResource(PropertiesReader.get("bert_vocab")).getPath().replaceFirst("/", ""); 19 | private Map token_id_map; 20 | private Map id_token_map; 21 | private boolean do_lower_case = true; 22 | private boolean do_basic_tokenize = true; 23 | private List never_split; 24 | public String unk_token = "[UNK]"; 25 | public String sep_token = "[SEP]"; 26 | public String pad_token = "[PAD]"; 27 | public String cls_token = "[CLS]"; 28 | public String mask_token = "[MASK]"; 29 | private boolean tokenize_chinese_chars = true; 30 | private BasicTokenizer basic_tokenizer; 31 | private WordpieceTokenizer wordpiece_tokenizer; 32 | 33 | private static final int MAX_LEN = 512; 34 | 35 | public BertTokenizer(String vocab_file, boolean do_lower_case, boolean do_basic_tokenize, List never_split, 36 | String unk_token, String sep_token, String pad_token, String cls_token, String mask_token, 37 | boolean tokenize_chinese_chars) { 38 | this.vocab_file = vocab_file; 39 | this.do_lower_case = do_lower_case; 40 | this.do_basic_tokenize = do_basic_tokenize; 41 | this.never_split = never_split; 42 | this.unk_token = unk_token; 43 | this.sep_token = sep_token; 44 | this.pad_token = pad_token; 45 | this.cls_token = cls_token; 46 | this.mask_token = mask_token; 47 | this.tokenize_chinese_chars = tokenize_chinese_chars; 48 | init(); 49 | } 50 | 51 | public BertTokenizer() { 52 | init(); 53 | } 54 | 55 | private void init() { 56 | System.out.println("init bertTokenizer..."); 57 | try { 58 | this.token_id_map = load_vocab(vocab_file); 59 | } catch (IOException e) { 60 | e.printStackTrace(); 61 | } 62 | this.id_token_map = CollectionUtil.newHashMap(); 63 | for (String key : token_id_map.keySet()) { 64 | this.id_token_map.put(token_id_map.get(key), key); 65 | } 66 | never_split = CollectionUtil.newArrayList(); 67 | never_split.add(unk_token); 68 | never_split.add(sep_token); 69 | never_split.add(pad_token); 70 | never_split.add(cls_token); 71 | never_split.add(mask_token); 72 | if (do_basic_tokenize) { 73 | this.basic_tokenizer = new BasicTokenizer(do_lower_case, never_split, tokenize_chinese_chars); 74 | } 75 | this.wordpiece_tokenizer = new WordpieceTokenizer(token_id_map, unk_token, never_split); 76 | } 77 | 78 | private Map load_vocab(String vocab_file_name) throws IOException { 79 | System.out.println("load vocab ..."); 80 | return TokenizerUtils.generateTokenIdMap(vocab_file_name); 81 | } 82 | 83 | @Override 84 | public List tokenize(String text) { 85 | List split_tokens = CollectionUtil.newArrayList(); 86 | if (do_basic_tokenize) { 87 | for (String token : basic_tokenizer.tokenize(text)) { 88 | for (String sub_token : wordpiece_tokenizer.tokenize(token)) { 89 | split_tokens.add(sub_token); 90 | } 91 | } 92 | } else { 93 | split_tokens = wordpiece_tokenizer.tokenize(text); 94 | } 95 | split_tokens.add(0, "[CLS]"); 96 | split_tokens.add("[SEP]"); 97 | return split_tokens; 98 | } 99 | 100 | public List basicTokenize(String text) { 101 | List tokenizeList = basic_tokenizer.tokenize(text); 102 | tokenizeList.add(0, "[CLS]"); 103 | tokenizeList.add("[SEP]"); 104 | return tokenizeList; 105 | } 106 | 107 | public String convert_tokens_to_string(List tokens) { 108 | // Converts a sequence of tokens (string) in a single string. 109 | return tokens.stream().map(s -> s.replace("##", "")).collect(Collectors.joining(" ")); 110 | } 111 | 112 | public List convert_tokens_to_ids(List tokens) { 113 | List output = CollectionUtil.newArrayList(); 114 | for (String s : tokens) { 115 | output.add(token_id_map.get(s.toLowerCase())); 116 | } 117 | return output; 118 | } 119 | 120 | public int convert_tokens_to_ids(String token) { 121 | return token_id_map.get(token.toLowerCase()); 122 | } 123 | 124 | public List convert_ids_to_tokens(List ids) { 125 | List output = CollectionUtil.newArrayList(); 126 | for(int id : ids) { 127 | output.add(id_token_map.get(id)); 128 | } 129 | return output; 130 | } 131 | 132 | public String convert_ids_to_tokens(int id) { 133 | return id_token_map.get(id); 134 | } 135 | 136 | public int vocab_size() { 137 | return token_id_map.size(); 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /src/main/java/sy/BertMask.java: -------------------------------------------------------------------------------- 1 | package sy; 2 | 3 | import ai.onnxruntime.*; 4 | import org.apache.commons.lang3.tuple.Triple; 5 | import sy.bert.LoadModel; 6 | import sy.bert.tokenizerimpl.BertTokenizer; 7 | import util.CollectionUtil; 8 | 9 | import java.util.List; 10 | import java.util.Map; 11 | 12 | /** 13 | * @author sy 14 | * @date 2022/5/2 14:03 15 | */ 16 | public class BertMask { 17 | static BertTokenizer tokenizer; 18 | public static void main(String[] args) { 19 | String text = "中国的首都是北[MASK]。"; 20 | text = "我家后面有一[MASK]大树。"; 21 | Triple, Integer> triple = null; 22 | try { 23 | triple = parseInputText(text); 24 | } catch (Exception e) { 25 | e.printStackTrace(); 26 | } 27 | var maskPredictions = predMask(triple); 28 | System.out.println(maskPredictions); 29 | } 30 | 31 | static { 32 | tokenizer = new BertTokenizer(); 33 | try { 34 | LoadModel.loadOnnxModel(); 35 | } catch (OrtException e) { 36 | e.printStackTrace(); 37 | } 38 | } 39 | 40 | /** 41 | * tokenize text 42 | * @param text 43 | * @return 44 | * @throws Exception 45 | */ 46 | public static Triple, Integer> parseInputText(String text) throws Exception{ 47 | var env = LoadModel.env; 48 | List tokens = tokenizer.tokenize(text); 49 | 50 | System.out.println(tokens); 51 | 52 | List tokenIds = tokenizer.convert_tokens_to_ids(tokens); 53 | int maskId = tokenIds.indexOf(tokenizer.convert_tokens_to_ids("[MASK]")); 54 | long[] inputIds = new long[tokenIds.size()]; 55 | long[] attentionMask = new long[tokenIds.size()]; 56 | long[] tokenTypeIds = new long[tokenIds.size()]; 57 | for(int index=0; index < tokenIds.size(); index ++) { 58 | inputIds[index] = tokenIds.get(index); 59 | attentionMask[index] = 1; 60 | tokenTypeIds[index] = 0; 61 | } 62 | long[] shape = new long[]{1, inputIds.length}; 63 | Object ObjInputIds = OrtUtil.reshape(inputIds, shape); 64 | Object ObjAttentionMask = OrtUtil.reshape(attentionMask, shape); 65 | Object ObjTokenTypeIds = OrtUtil.reshape(tokenTypeIds, shape); 66 | OnnxTensor input_ids = OnnxTensor.createTensor(env, ObjInputIds); 67 | OnnxTensor attention_mask = OnnxTensor.createTensor(env, ObjAttentionMask); 68 | OnnxTensor token_type_ids = OnnxTensor.createTensor(env, ObjTokenTypeIds); 69 | var inputs = Map.of("input_ids", input_ids, "attention_mask", attention_mask, "token_type_ids", token_type_ids); 70 | return Triple.of(tokenizer, inputs, maskId); 71 | } 72 | 73 | /** 74 | * predict mask 75 | * @param triple 76 | * @return 77 | */ 78 | public static List predMask(Triple, Integer> triple) { 79 | return predMask(triple, 5); 80 | } 81 | 82 | public static List predMask(Triple, Integer> triple, int topK) { 83 | var tokenizer = triple.getLeft(); 84 | var inputs =triple.getMiddle(); 85 | var maskId = triple.getRight(); 86 | List maskResults = null; 87 | try{ 88 | var session = LoadModel.session; 89 | try(var results = session.run(inputs)) { 90 | OnnxValue onnxValue = results.get(0); 91 | float[][][] labels = (float[][][]) onnxValue.getValue(); 92 | float[] maskLables = labels[0][maskId]; 93 | int[] index = predSort(maskLables); 94 | maskResults = CollectionUtil.newArrayList(); 95 | for(int idx=0; idx < topK; idx ++) { 96 | maskResults.add(tokenizer.convert_ids_to_tokens(index[idx])); 97 | } 98 | } 99 | } catch (OrtException e) { 100 | e.printStackTrace(); 101 | } 102 | return maskResults; 103 | } 104 | 105 | /** 106 | * 得到最大概率label对应的index 107 | * @param probabilities 108 | * @return 109 | */ 110 | public static int predMax(float[] probabilities) { 111 | float maxVal = Float.NEGATIVE_INFINITY; 112 | int idx = 0; 113 | for (int i = 0; i < probabilities.length; i++) { 114 | if (probabilities[i] > maxVal) { 115 | maxVal = probabilities[i]; 116 | idx = i; 117 | } 118 | } 119 | return idx; 120 | } 121 | 122 | /** 123 | * 对预测的概率进行排序 124 | * @param probabilities 125 | * @return 126 | */ 127 | private static int[] predSort(float[] probabilities) { 128 | int[] indices = new int[probabilities.length]; 129 | for (int i = 0; i < probabilities.length; i++) { 130 | indices[i] = i; 131 | } 132 | predSort(probabilities, 0, probabilities.length-1, indices); 133 | return indices; 134 | } 135 | 136 | private static void predSort(float[] probabilities, int begin, int end, int[] indices) { 137 | if (begin >= 0 && begin < probabilities.length && end >= 0 && end < probabilities.length && begin < end) { 138 | int i = begin, j = end; 139 | float vot = probabilities[i]; 140 | int temp = indices[i]; 141 | while (i != j) { 142 | while(i < j && probabilities[j] <= vot) j--; 143 | if(i < j) { 144 | probabilities[i] = probabilities[j]; 145 | indices[i] = indices[j]; 146 | i++; 147 | } 148 | while(i < j && probabilities[i] >= vot) i++; 149 | if(i < j) { 150 | probabilities[j] = probabilities[i]; 151 | indices[j] = indices[i]; 152 | j--; 153 | } 154 | } 155 | probabilities[i] = vot; 156 | indices[i] = temp; 157 | predSort(probabilities, begin, j-1, indices); 158 | predSort(probabilities, i+1, end, indices); 159 | } 160 | } 161 | 162 | } 163 | 164 | -------------------------------------------------------------------------------- /src/main/java/sy/bert/utils/TokenizerUtils.java: -------------------------------------------------------------------------------- 1 | package sy.bert.utils; 2 | 3 | import util.CollectionUtil; 4 | 5 | import java.io.BufferedReader; 6 | import java.io.IOException; 7 | import java.nio.charset.StandardCharsets; 8 | import java.nio.file.Files; 9 | import java.nio.file.Paths; 10 | import java.text.Normalizer; 11 | import java.text.Normalizer.Form; 12 | import java.util.Arrays; 13 | import java.util.List; 14 | import java.util.Map; 15 | import java.util.Optional; 16 | 17 | /** 18 | * @author sy 19 | * @date 2022/5/2 14:03 20 | */ 21 | public class TokenizerUtils { 22 | 23 | public static String clean_text(String text) { 24 | // Performs invalid character removal and whitespace cleanup on text.""" 25 | 26 | StringBuilder output = new StringBuilder(); 27 | for (int i = 0; i < text.length(); i++) { 28 | Character c = text.charAt(i); 29 | int cp = (int) c; 30 | if (cp == 0 || cp == 0xFFFD || _is_control(c)) { 31 | continue; 32 | } 33 | if (_is_whitespace(c)) { 34 | output.append(" "); 35 | } else { 36 | output.append(c); 37 | } 38 | } 39 | return output.toString(); 40 | } 41 | 42 | public static String tokenize_chinese_chars(String text) { 43 | // Adds whitespace around any CJK character. 44 | StringBuilder output = new StringBuilder(); 45 | for (int i = 0; i < text.length(); i++) { 46 | Character c = text.charAt(i); 47 | int cp = (int) c; 48 | if (_is_chinese_char(cp)) { 49 | output.append(" "); 50 | output.append(c); 51 | output.append(" "); 52 | } else { 53 | output.append(c); 54 | } 55 | } 56 | return output.toString(); 57 | } 58 | 59 | public static List whitespace_tokenize(String text) { 60 | // Runs basic whitespace cleaning and splitting on a piece of text. 61 | text = text.trim(); 62 | if ((text != null) && (text != "")) { 63 | return CollectionUtil.newArrayList(Arrays.asList(text.split("\\s+"))); 64 | } 65 | return CollectionUtil.newArrayList(); 66 | 67 | } 68 | 69 | public static String run_strip_accents(String token) { 70 | token = Normalizer.normalize(token, Form.NFD); 71 | StringBuilder output = new StringBuilder(); 72 | for (int i = 0; i < token.length(); i++) { 73 | Character c = token.charAt(i); 74 | if (Character.NON_SPACING_MARK != Character.getType(c)) { 75 | output.append(c); 76 | } 77 | } 78 | return output.toString(); 79 | } 80 | 81 | public static List run_split_on_punc(String token, List never_split) { 82 | // Splits punctuation on a piece of text. 83 | List output = CollectionUtil.newArrayList(); 84 | if (Optional.of(never_split).isPresent()) { 85 | if(never_split.contains(token)) { 86 | output.add(token); 87 | return output; 88 | } else { 89 | for(String specialToken : never_split) { 90 | if(token.contains(specialToken)) { 91 | int specialTokenIndex = token.indexOf(specialToken); 92 | if(specialTokenIndex == 0) { 93 | String other = token.substring(specialToken.length()); 94 | output.add(specialToken); 95 | output.add(other); 96 | return output; 97 | } else { 98 | String other = token.substring(0, token.indexOf(specialToken)); 99 | output.add(other); 100 | output.add(specialToken); 101 | String another = token.substring(specialTokenIndex + specialToken.length()); 102 | if (another.length() != 0) { 103 | output.add(another); 104 | } 105 | return output; 106 | } 107 | } 108 | } 109 | } 110 | } 111 | 112 | boolean start_new_word = true; 113 | StringBuilder str = new StringBuilder(); 114 | for (int i = 0; i < token.length(); i++) { 115 | Character c = token.charAt(i); 116 | if (_is_punctuation(c)) { 117 | if (str.length() > 0) { 118 | output.add(str.toString()); 119 | str.setLength(0); 120 | } 121 | output.add(c.toString()); 122 | start_new_word = true; 123 | } else { 124 | if (start_new_word && str.length() > 0) { 125 | output.add(str.toString()); 126 | str.setLength(0); 127 | } 128 | start_new_word = false; 129 | str.append(c); 130 | } 131 | } 132 | if (str.length() > 0) { 133 | output.add(str.toString()); 134 | } 135 | return output; 136 | } 137 | 138 | public static Map generateTokenIdMap(String file) throws IOException { 139 | Map token_id_map = CollectionUtil.newHashMap(); 140 | if (file == null) 141 | return token_id_map; 142 | try(BufferedReader br = Files.newBufferedReader(Paths.get(file), StandardCharsets.UTF_8)) { 143 | String line; 144 | int index = 0; 145 | while ((line = br.readLine()) != null) { 146 | token_id_map.put(line.trim().toLowerCase(), index); 147 | index ++; 148 | } 149 | } 150 | return token_id_map; 151 | } 152 | 153 | private static boolean _is_punctuation(char c) { 154 | // Checks whether `chars` is a punctuation character. 155 | int cp = (int) c; 156 | // We treat all non-letter/number ASCII as punctuation. 157 | // Characters such as "^", "$", and "`" are not in the Unicode 158 | // Punctuation class but we treat them as punctuation anyways, for 159 | // consistency. 160 | if ((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) || (cp >= 91 && cp <= 96) || (cp >= 123 && cp <= 126)) { 161 | return true; 162 | } 163 | int charType = Character.getType(c); 164 | if (Character.CONNECTOR_PUNCTUATION == charType || Character.DASH_PUNCTUATION == charType 165 | || Character.END_PUNCTUATION == charType || Character.FINAL_QUOTE_PUNCTUATION == charType 166 | || Character.INITIAL_QUOTE_PUNCTUATION == charType || Character.OTHER_PUNCTUATION == charType 167 | || Character.START_PUNCTUATION == charType) { 168 | return true; 169 | } 170 | return false; 171 | } 172 | 173 | private static boolean _is_whitespace(char c) { 174 | // Checks whether `chars` is a whitespace character. 175 | // \t, \n, and \r are technically contorl characters but we treat them 176 | // as whitespace since they are generally considered as such. 177 | if (c == ' ' || c == '\t' || c == '\n' || c == '\r') { 178 | return true; 179 | } 180 | 181 | int charType = Character.getType(c); 182 | if (Character.SPACE_SEPARATOR == charType) { 183 | return true; 184 | } 185 | return false; 186 | } 187 | 188 | private static boolean _is_control(char c) { 189 | // Checks whether `chars` is a control character. 190 | // These are technically control characters but we count them as whitespace 191 | // characters. 192 | if (c == '\t' || c == '\n' || c == '\r') { 193 | return false; 194 | } 195 | 196 | int charType = Character.getType(c); 197 | if (Character.CONTROL == charType || Character.DIRECTIONALITY_COMMON_NUMBER_SEPARATOR == charType 198 | || Character.FORMAT == charType || Character.PRIVATE_USE == charType || Character.SURROGATE == charType 199 | || Character.UNASSIGNED == charType) { 200 | return true; 201 | } 202 | return false; 203 | } 204 | 205 | private static boolean _is_chinese_char(int cp) { 206 | // Checks whether CP is the codepoint of a CJK character.""" 207 | // This defines a "chinese character" as anything in the CJK Unicode block: 208 | // https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 209 | // 210 | // Note that the CJK Unicode block is NOT all Japanese and Korean characters, 211 | // despite its name. The modern Korean Hangul alphabet is a different block, 212 | // as is Japanese Hiragana and Katakana. Those alphabets are used to write 213 | // space-separated words, so they are not treated specially and handled 214 | // like the all of the other languages. 215 | if ((cp >= 0x4E00 && cp <= 0x9FFF) || (cp >= 0x3400 && cp <= 0x4DBF) || (cp >= 0x20000 && cp <= 0x2A6DF) 216 | || (cp >= 0x2A700 && cp <= 0x2B73F) || (cp >= 0x2B740 && cp <= 0x2B81F) 217 | || (cp >= 0x2B820 && cp <= 0x2CEAF) || (cp >= 0xF900 && cp <= 0xFAFF) 218 | || (cp >= 0x2F800 && cp <= 0x2FA1F)) { 219 | return true; 220 | } 221 | 222 | return false; 223 | } 224 | } 225 | --------------------------------------------------------------------------------