├── requirements.txt
├── src
│   └── main
│       ├── resources
│       │   ├── model
│       │   │   └── dl
│       │   │       └── roberta.onnx
│       │   └── properties.properties
│       └── java
│           ├── sy
│           │   ├── bert
│           │   │   ├── tokenizer
│           │   │   │   └── Tokenizer.java
│           │   │   ├── LoadModel.java
│           │   │   ├── tokenizerimpl
│           │   │   │   ├── BasicTokenizer.java
│           │   │   │   ├── WordpieceTokenizer.java
│           │   │   │   └── BertTokenizer.java
│           │   │   └── utils
│           │   │       └── TokenizerUtils.java
│           │   └── BertMask.java
│           └── util
│               ├── PropertiesReader.java
│               └── CollectionUtil.java
├── .gitignore
├── onnx-java.iml
├── .idea
│   ├── codeStyles
│   │   ├── codeStyleConfig.xml
│   │   └── Project.xml
│   ├── .gitignore
│   ├── misc.xml
│   ├── compiler.xml
│   └── jarRepositories.xml
├── pom.xml
└── README.MD
/requirements.txt:
--------------------------------------------------------------------------------
java11
onnxruntime
--------------------------------------------------------------------------------
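Note: requirements.txt appears to record the runtime environment rather than pip packages: Java 11 (the sources use `var` and `Map.of`, which need Java 10 and 9 or later, respectively) and ONNX Runtime, which the Maven build pulls in through the `com.microsoft.onnxruntime:onnxruntime` dependency declared in pom.xml.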
/src/main/resources/model/dl/roberta.onnx:
--------------------------------------------------------------------------------
(binary ONNX model; content omitted)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Project exclude paths
/target/
--------------------------------------------------------------------------------
/onnx-java.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/src/main/resources/properties.properties:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jiangnanboy/onnx-java/HEAD/src/main/resources/properties.properties
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/src/main/java/sy/bert/tokenizer/Tokenizer.java:
--------------------------------------------------------------------------------
package sy.bert.tokenizer;

import java.util.List;

/**
 * @author sy
 * @date 2022/5/2 14:03
 */
public interface Tokenizer {
    List<String> tokenize(String text);
}
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
# Default ignored files
/shelf/
/workspace.xml
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# Zeppelin ignored files
/ZeppelinRemoteNotebooks/
# Editor-based HTTP Client requests
/httpRequests/
--------------------------------------------------------------------------------
/.idea/codeStyles/Project.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/compiler.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/src/main/java/util/PropertiesReader.java:
--------------------------------------------------------------------------------
package util;

import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Properties;

/**
 * @author sy
 * @date 2022/2/2 9:07
 */
public class PropertiesReader {
    private static Properties properties = new Properties();

    static {
        try {
            properties.load(new InputStreamReader(
                    PropertiesReader.class.getClassLoader().getResourceAsStream("properties.properties"),
                    StandardCharsets.UTF_8));
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static String get(String keyName) {
        return properties.getProperty(keyName);
    }

}
--------------------------------------------------------------------------------
/.idea/jarRepositories.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>onnx-java</artifactId>
    <version>1.0-SNAPSHOT</version>

    <dependencies>
        <dependency>
            <groupId>com.microsoft.onnxruntime</groupId>
            <artifactId>onnxruntime</artifactId>
            <version>1.11.0</version>
        </dependency>

        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
            <version>3.7</version>
        </dependency>
    </dependencies>

</project>
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
This project loads an ONNX model in Java and runs inference with it.

#### Steps

1. Load the ONNX model in Java and run inference (a minimal sketch of the runtime calls involved follows below). The ONNX export of a RoBERTa model is used here.

2. Download the PyTorch model from [here](https://huggingface.co/uer/chinese_roberta_L-2_H-512).

3. See [here](https://github.com/jiangnanboy/model2onnx) for converting the PyTorch model to ONNX.
#### Usage
1. sy/BertMask

```
String text = "中国的首都是[MASK]京。";
Triple<BertTokenizer, Map<String, OnnxTensor>, Integer> triple = null;
try {
    triple = parseInputText(text);
} catch (Exception e) {
    e.printStackTrace();
}
var maskPredictions = predMask(triple);
System.out.println(maskPredictions);
```
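
`predMask(triple)` returns the five highest-scoring candidates for the `[MASK]` position; `BertMask` also exposes a `predMask(triple, topK)` overload if you want a longer or shorter candidate list.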

2. Result
```
String text = "中国的首都是[MASK]京。";

tokens -> [[CLS], 中, 国, 的, 首, 都, 是, [MASK], 京, 。, [SEP]]
[MASK] predictions -> [北, 南, 东, 燕, 望]

String text = "我家后面有一[MASK]大树。";

tokens -> [[CLS], 我, 家, 后, 面, 有, 一, [MASK], 大, 树, 。, [SEP]]
[MASK] predictions -> [棵, 个, 株, 只, 颗]
```

#### References
https://github.com/jiangnanboy/model2onnx

https://huggingface.co/uer/chinese_roberta_L-2_H-512

https://arxiv.org/pdf/1907.11692.pdf

https://github.com/ankiteciitkgp/bertTokenizer

https://arxiv.org/pdf/1810.04805.pdf
--------------------------------------------------------------------------------
/src/main/java/util/CollectionUtil.java:
--------------------------------------------------------------------------------
package util;

import java.util.*;

/**
 * @author sy
 * @date 2022/2/2 9:36
 */
public class CollectionUtil {

    public static <E> List<E> newArrayList() {
        return new ArrayList<>();
    }

    public static <E> List<E> newArrayList(List<E> list) {
        return new ArrayList<>(list);
    }

    public static <E> LinkedList<E> newLinkedList() {
        return new LinkedList<>();
    }

    public static <E> List<E> newArrayList(int N) {
        return new ArrayList<>(N);
    }

    public static <E> List<E> newArrayList(Set<E> entry) {
        return new ArrayList<>(entry);
    }

    public static <E> Set<E> newHashset() {
        return new HashSet<>();
    }

    public static <E> Set<E> newHashset(List<E> entry) {
        return new HashSet<>(entry);
    }

    public static <K, V> Map<K, V> newHashMap() {
        return new HashMap<>();
    }

    public static <K, V> LinkedHashMap<K, V> newLinkedHashMap() {
        return new LinkedHashMap<>();
    }

    public static <K, V> Map<K, V> newTreeMap() {
        return new TreeMap<>();
    }

}
--------------------------------------------------------------------------------
/src/main/java/sy/bert/LoadModel.java:
--------------------------------------------------------------------------------
package sy.bert;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import util.PropertiesReader;

/**
 * @author sy
 * @date 2022/5/2 15:51
 */
public class LoadModel {

    public static OrtSession session;
    public static OrtEnvironment env;

    /**
     * load onnx model
     * @throws OrtException
     */
    public static void loadOnnxModel() throws OrtException {
        System.out.println("load onnx model...");
        // strip the leading '/' that URL#getPath yields on Windows ("/C:/...")
        String onnxPath = LoadModel.class.getClassLoader()
                .getResource(PropertiesReader.get("onnx_model_path"))
                .getPath().replaceFirst("/", "");
        env = OrtEnvironment.getEnvironment();
        session = env.createSession(onnxPath, new OrtSession.SessionOptions());
    }

    /**
     * close onnx model
     */
    public static void closeOnnxModel() {
        System.out.println("close onnx model...");
        // close the session first, then the shared environment
        if (session != null) {
            try {
                session.close();
            } catch (OrtException e) {
                e.printStackTrace();
            }
        }
        if (env != null) {
            env.close();
        }
    }

}
--------------------------------------------------------------------------------
/src/main/java/sy/bert/tokenizerimpl/BasicTokenizer.java:
--------------------------------------------------------------------------------
package sy.bert.tokenizerimpl;

import sy.bert.tokenizer.Tokenizer;
import sy.bert.utils.TokenizerUtils;
import util.CollectionUtil;

import java.util.List;

/**
 * @author sy
 * @date 2022/5/2 14:03
 */
public class BasicTokenizer implements Tokenizer {
    private boolean do_lower_case = true;
    private List<String> never_split;
    private boolean tokenize_chinese_chars = true;
    private List<String> specialTokens;

    public BasicTokenizer(boolean do_lower_case, List<String> never_split, boolean tokenize_chinese_chars) {
        this.do_lower_case = do_lower_case;
        if (never_split == null) {
            this.never_split = CollectionUtil.newArrayList();
        } else {
            this.never_split = never_split;
        }
        this.tokenize_chinese_chars = tokenize_chinese_chars;
    }

    public BasicTokenizer() {
        // guard against an NPE in tokenize() when no never_split list is supplied
        this.never_split = CollectionUtil.newArrayList();
    }

    @Override
    public List<String> tokenize(String text) {
        text = TokenizerUtils.clean_text(text);
        if (tokenize_chinese_chars) {
            text = TokenizerUtils.tokenize_chinese_chars(text);
        }
        List<String> orig_tokens = TokenizerUtils.whitespace_tokenize(text);
        List<String> split_tokens = CollectionUtil.newArrayList();
        for (String token : orig_tokens) {
            if (do_lower_case && !never_split.contains(token)) {
                token = TokenizerUtils.run_strip_accents(token);
                split_tokens.addAll(TokenizerUtils.run_split_on_punc(token, never_split));
            } else {
                split_tokens.add(token);
            }
        }
        return TokenizerUtils.whitespace_tokenize(String.join(" ", split_tokens));
    }

}
--------------------------------------------------------------------------------
/src/main/java/sy/bert/tokenizerimpl/WordpieceTokenizer.java:
--------------------------------------------------------------------------------
package sy.bert.tokenizerimpl;

import sy.bert.tokenizer.Tokenizer;
import sy.bert.utils.TokenizerUtils;
import util.CollectionUtil;

import java.util.List;
import java.util.Map;

/**
 * @author sy
 * @date 2022/5/2 14:03
 */
public class WordpieceTokenizer implements Tokenizer {
    private Map<String, Integer> vocab;
    private String unk_token;
    private int max_input_chars_per_word;
    private List<String> specialTokensList;

    public WordpieceTokenizer(Map<String, Integer> vocab, String unk_token, int max_input_chars_per_word) {
        this.vocab = vocab;
        this.unk_token = unk_token;
        this.max_input_chars_per_word = max_input_chars_per_word;
    }

    public WordpieceTokenizer(Map<String, Integer> vocab, String unk_token, List<String> specialTokensList) {
        this.vocab = vocab;
        this.unk_token = unk_token;
        this.specialTokensList = specialTokensList;
        this.max_input_chars_per_word = 100;
    }

    @Override
    public List<String> tokenize(String text) {
        // Greedy longest-match-first WordPiece: repeatedly take the longest
        // substring in the vocabulary ("##"-prefixed for non-initial pieces).
        List<String> output_tokens = CollectionUtil.newArrayList();
        if (this.specialTokensList != null && this.specialTokensList.contains(text)) {
            output_tokens.add(text);
            return output_tokens;
        }
        for (String token : TokenizerUtils.whitespace_tokenize(text)) {
            if (token.length() > max_input_chars_per_word) {
                output_tokens.add(unk_token);
                continue;
            }
            boolean is_bad = false;
            int start = 0;

            List<String> sub_tokens = CollectionUtil.newArrayList();
            while (start < token.length()) {
                int end = token.length();
                String cur_substr = "";
                while (start < end) {
                    String substr = token.substring(start, end);
                    if (start > 0) {
                        substr = "##" + substr;
                    }
                    if (vocab.containsKey(substr)) {
                        cur_substr = substr;
                        break;
                    }
                    end -= 1;
                }
                if (cur_substr.isEmpty()) {
                    is_bad = true;
                    break;
                }
                sub_tokens.add(cur_substr);
                start = end;
            }
            if (is_bad) {
                output_tokens.add(unk_token);
            } else {
                output_tokens.addAll(sub_tokens);
            }
        }
        return output_tokens;
    }
}
--------------------------------------------------------------------------------
/src/main/java/sy/bert/tokenizerimpl/BertTokenizer.java:
--------------------------------------------------------------------------------
package sy.bert.tokenizerimpl;

import sy.bert.tokenizer.Tokenizer;
import sy.bert.utils.TokenizerUtils;
import util.CollectionUtil;
import util.PropertiesReader;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * @author sy
 * @date 2022/5/2 14:03
 */
public class BertTokenizer implements Tokenizer {
    private String vocab_file = BertTokenizer.class.getClassLoader()
            .getResource(PropertiesReader.get("bert_vocab")).getPath().replaceFirst("/", "");
    private Map<String, Integer> token_id_map;
    private Map<Integer, String> id_token_map;
    private boolean do_lower_case = true;
    private boolean do_basic_tokenize = true;
    private List<String> never_split;
    public String unk_token = "[UNK]";
    public String sep_token = "[SEP]";
    public String pad_token = "[PAD]";
    public String cls_token = "[CLS]";
    public String mask_token = "[MASK]";
    private boolean tokenize_chinese_chars = true;
    private BasicTokenizer basic_tokenizer;
    private WordpieceTokenizer wordpiece_tokenizer;

    private static final int MAX_LEN = 512;

    public BertTokenizer(String vocab_file, boolean do_lower_case, boolean do_basic_tokenize,
                         List<String> never_split, String unk_token, String sep_token, String pad_token,
                         String cls_token, String mask_token, boolean tokenize_chinese_chars) {
        this.vocab_file = vocab_file;
        this.do_lower_case = do_lower_case;
        this.do_basic_tokenize = do_basic_tokenize;
        this.never_split = never_split;
        this.unk_token = unk_token;
        this.sep_token = sep_token;
        this.pad_token = pad_token;
        this.cls_token = cls_token;
        this.mask_token = mask_token;
        this.tokenize_chinese_chars = tokenize_chinese_chars;
        init();
    }

    public BertTokenizer() {
        init();
    }

    private void init() {
        System.out.println("init bertTokenizer...");
        try {
            this.token_id_map = load_vocab(vocab_file);
        } catch (IOException e) {
            e.printStackTrace();
        }
        // build the reverse map for id -> token lookups
        this.id_token_map = CollectionUtil.newHashMap();
        for (String key : token_id_map.keySet()) {
            this.id_token_map.put(token_id_map.get(key), key);
        }
        never_split = CollectionUtil.newArrayList();
        never_split.add(unk_token);
        never_split.add(sep_token);
        never_split.add(pad_token);
        never_split.add(cls_token);
        never_split.add(mask_token);
        if (do_basic_tokenize) {
            this.basic_tokenizer = new BasicTokenizer(do_lower_case, never_split, tokenize_chinese_chars);
        }
        this.wordpiece_tokenizer = new WordpieceTokenizer(token_id_map, unk_token, never_split);
    }

    private Map<String, Integer> load_vocab(String vocab_file_name) throws IOException {
        System.out.println("load vocab ...");
        return TokenizerUtils.generateTokenIdMap(vocab_file_name);
    }

    @Override
    public List<String> tokenize(String text) {
        List<String> split_tokens = CollectionUtil.newArrayList();
        if (do_basic_tokenize) {
            for (String token : basic_tokenizer.tokenize(text)) {
                for (String sub_token : wordpiece_tokenizer.tokenize(token)) {
                    split_tokens.add(sub_token);
                }
            }
        } else {
            split_tokens = wordpiece_tokenizer.tokenize(text);
        }
        split_tokens.add(0, "[CLS]");
        split_tokens.add("[SEP]");
        return split_tokens;
    }

    public List<String> basicTokenize(String text) {
        List<String> tokenizeList = basic_tokenizer.tokenize(text);
        tokenizeList.add(0, "[CLS]");
        tokenizeList.add("[SEP]");
        return tokenizeList;
    }

    public String convert_tokens_to_string(List<String> tokens) {
        // Converts a sequence of tokens into a single string.
        return tokens.stream().map(s -> s.replace("##", "")).collect(Collectors.joining(" "));
    }

    public List<Integer> convert_tokens_to_ids(List<String> tokens) {
        List<Integer> output = CollectionUtil.newArrayList();
        for (String s : tokens) {
            output.add(token_id_map.get(s.toLowerCase()));
        }
        return output;
    }

    public int convert_tokens_to_ids(String token) {
        return token_id_map.get(token.toLowerCase());
    }

    public List<String> convert_ids_to_tokens(List<Integer> ids) {
        List<String> output = CollectionUtil.newArrayList();
        for (int id : ids) {
            output.add(id_token_map.get(id));
        }
        return output;
    }

    public String convert_ids_to_tokens(int id) {
        return id_token_map.get(id);
    }

    public int vocab_size() {
        return token_id_map.size();
    }
}
--------------------------------------------------------------------------------
/src/main/java/sy/BertMask.java:
--------------------------------------------------------------------------------
package sy;

import ai.onnxruntime.*;
import org.apache.commons.lang3.tuple.Triple;
import sy.bert.LoadModel;
import sy.bert.tokenizerimpl.BertTokenizer;
import util.CollectionUtil;

import java.util.List;
import java.util.Map;

/**
 * @author sy
 * @date 2022/5/2 14:03
 */
public class BertMask {
    static BertTokenizer tokenizer;

    public static void main(String[] args) {
        String text = "中国的首都是北[MASK]。";
        text = "我家后面有一[MASK]大树。";
        Triple<BertTokenizer, Map<String, OnnxTensor>, Integer> triple = null;
        try {
            triple = parseInputText(text);
        } catch (Exception e) {
            e.printStackTrace();
        }
        var maskPredictions = predMask(triple);
        System.out.println(maskPredictions);
    }

    static {
        tokenizer = new BertTokenizer();
        try {
            LoadModel.loadOnnxModel();
        } catch (OrtException e) {
            e.printStackTrace();
        }
    }

    /**
     * tokenize text
     * @param text
     * @return
     * @throws Exception
     */
    public static Triple<BertTokenizer, Map<String, OnnxTensor>, Integer> parseInputText(String text) throws Exception {
        var env = LoadModel.env;
        List<String> tokens = tokenizer.tokenize(text);

        System.out.println(tokens);

        List<Integer> tokenIds = tokenizer.convert_tokens_to_ids(tokens);
        int maskId = tokenIds.indexOf(tokenizer.convert_tokens_to_ids("[MASK]"));
        long[] inputIds = new long[tokenIds.size()];
        long[] attentionMask = new long[tokenIds.size()];
        long[] tokenTypeIds = new long[tokenIds.size()];
        for (int index = 0; index < tokenIds.size(); index++) {
            inputIds[index] = tokenIds.get(index);
            attentionMask[index] = 1;
            tokenTypeIds[index] = 0;
        }
        // reshape each flat array to [1, seq_len] before wrapping it in a tensor
        long[] shape = new long[]{1, inputIds.length};
        Object objInputIds = OrtUtil.reshape(inputIds, shape);
        Object objAttentionMask = OrtUtil.reshape(attentionMask, shape);
        Object objTokenTypeIds = OrtUtil.reshape(tokenTypeIds, shape);
        OnnxTensor input_ids = OnnxTensor.createTensor(env, objInputIds);
        OnnxTensor attention_mask = OnnxTensor.createTensor(env, objAttentionMask);
        OnnxTensor token_type_ids = OnnxTensor.createTensor(env, objTokenTypeIds);
        var inputs = Map.of("input_ids", input_ids, "attention_mask", attention_mask, "token_type_ids", token_type_ids);
        return Triple.of(tokenizer, inputs, maskId);
    }

    /**
     * predict mask
     * @param triple
     * @return
     */
    public static List<String> predMask(Triple<BertTokenizer, Map<String, OnnxTensor>, Integer> triple) {
        return predMask(triple, 5);
    }

    public static List<String> predMask(Triple<BertTokenizer, Map<String, OnnxTensor>, Integer> triple, int topK) {
        var tokenizer = triple.getLeft();
        var inputs = triple.getMiddle();
        var maskId = triple.getRight();
        List<String> maskResults = null;
        try {
            var session = LoadModel.session;
            try (var results = session.run(inputs)) {
                OnnxValue onnxValue = results.get(0);
                float[][][] labels = (float[][][]) onnxValue.getValue();
                // logits for the [MASK] position, one score per vocab entry
                float[] maskLabels = labels[0][maskId];
                int[] index = predSort(maskLabels);
                maskResults = CollectionUtil.newArrayList();
                for (int idx = 0; idx < topK; idx++) {
                    maskResults.add(tokenizer.convert_ids_to_tokens(index[idx]));
                }
            }
        } catch (OrtException e) {
            e.printStackTrace();
        }
        return maskResults;
    }

    /**
     * index of the label with the highest probability
     * @param probabilities
     * @return
     */
    public static int predMax(float[] probabilities) {
        float maxVal = Float.NEGATIVE_INFINITY;
        int idx = 0;
        for (int i = 0; i < probabilities.length; i++) {
            if (probabilities[i] > maxVal) {
                maxVal = probabilities[i];
                idx = i;
            }
        }
        return idx;
    }

    /**
     * sort the predicted probabilities in descending order, returning the permuted indices
     * @param probabilities
     * @return
     */
    private static int[] predSort(float[] probabilities) {
        int[] indices = new int[probabilities.length];
        for (int i = 0; i < probabilities.length; i++) {
            indices[i] = i;
        }
        predSort(probabilities, 0, probabilities.length - 1, indices);
        return indices;
    }

    // in-place quicksort (descending) that keeps the index array in step with the values
    private static void predSort(float[] probabilities, int begin, int end, int[] indices) {
        if (begin >= 0 && begin < probabilities.length && end >= 0 && end < probabilities.length && begin < end) {
            int i = begin, j = end;
            float vot = probabilities[i];
            int temp = indices[i];
            while (i != j) {
                while (i < j && probabilities[j] <= vot) j--;
                if (i < j) {
                    probabilities[i] = probabilities[j];
                    indices[i] = indices[j];
                    i++;
                }
                while (i < j && probabilities[i] >= vot) i++;
                if (i < j) {
                    probabilities[j] = probabilities[i];
                    indices[j] = indices[i];
                    j--;
                }
            }
            probabilities[i] = vot;
            indices[i] = temp;
            predSort(probabilities, begin, j - 1, indices);
            predSort(probabilities, i + 1, end, indices);
        }
    }

}
--------------------------------------------------------------------------------
/src/main/java/sy/bert/utils/TokenizerUtils.java:
--------------------------------------------------------------------------------
package sy.bert.utils;

import util.CollectionUtil;

import java.io.BufferedReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.text.Normalizer;
import java.text.Normalizer.Form;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
 * @author sy
 * @date 2022/5/2 14:03
 */
public class TokenizerUtils {

    public static String clean_text(String text) {
        // Performs invalid character removal and whitespace cleanup on text.
        StringBuilder output = new StringBuilder();
        for (int i = 0; i < text.length(); i++) {
            Character c = text.charAt(i);
            int cp = (int) c;
            if (cp == 0 || cp == 0xFFFD || _is_control(c)) {
                continue;
            }
            if (_is_whitespace(c)) {
                output.append(" ");
            } else {
                output.append(c);
            }
        }
        return output.toString();
    }

    public static String tokenize_chinese_chars(String text) {
        // Adds whitespace around any CJK character.
        StringBuilder output = new StringBuilder();
        for (int i = 0; i < text.length(); i++) {
            Character c = text.charAt(i);
            int cp = (int) c;
            if (_is_chinese_char(cp)) {
                output.append(" ");
                output.append(c);
                output.append(" ");
            } else {
                output.append(c);
            }
        }
        return output.toString();
    }

    public static List<String> whitespace_tokenize(String text) {
        // Runs basic whitespace cleaning and splitting on a piece of text.
        text = text.trim();
        if (!text.isEmpty()) {
            return CollectionUtil.newArrayList(Arrays.asList(text.split("\\s+")));
        }
        return CollectionUtil.newArrayList();
    }

    public static String run_strip_accents(String token) {
        // Strips accents by NFD-normalizing and dropping combining marks.
        token = Normalizer.normalize(token, Form.NFD);
        StringBuilder output = new StringBuilder();
        for (int i = 0; i < token.length(); i++) {
            Character c = token.charAt(i);
            if (Character.NON_SPACING_MARK != Character.getType(c)) {
                output.append(c);
            }
        }
        return output.toString();
    }

    public static List<String> run_split_on_punc(String token, List<String> never_split) {
        // Splits punctuation on a piece of text.
        List<String> output = CollectionUtil.newArrayList();
        if (never_split != null) {
            if (never_split.contains(token)) {
                output.add(token);
                return output;
            } else {
                for (String specialToken : never_split) {
                    if (token.contains(specialToken)) {
                        int specialTokenIndex = token.indexOf(specialToken);
                        if (specialTokenIndex == 0) {
                            String other = token.substring(specialToken.length());
                            output.add(specialToken);
                            output.add(other);
                            return output;
                        } else {
                            String other = token.substring(0, specialTokenIndex);
                            output.add(other);
                            output.add(specialToken);
                            String another = token.substring(specialTokenIndex + specialToken.length());
                            if (another.length() != 0) {
                                output.add(another);
                            }
                            return output;
                        }
                    }
                }
            }
        }

        boolean start_new_word = true;
        StringBuilder str = new StringBuilder();
        for (int i = 0; i < token.length(); i++) {
            Character c = token.charAt(i);
            if (_is_punctuation(c)) {
                if (str.length() > 0) {
                    output.add(str.toString());
                    str.setLength(0);
                }
                output.add(c.toString());
                start_new_word = true;
            } else {
                if (start_new_word && str.length() > 0) {
                    output.add(str.toString());
                    str.setLength(0);
                }
                start_new_word = false;
                str.append(c);
            }
        }
        if (str.length() > 0) {
            output.add(str.toString());
        }
        return output;
    }

    public static Map<String, Integer> generateTokenIdMap(String file) throws IOException {
        // Reads the vocab file line by line; the line index becomes the token id.
        Map<String, Integer> token_id_map = CollectionUtil.newHashMap();
        if (file == null)
            return token_id_map;
        try (BufferedReader br = Files.newBufferedReader(Paths.get(file), StandardCharsets.UTF_8)) {
            String line;
            int index = 0;
            while ((line = br.readLine()) != null) {
                token_id_map.put(line.trim().toLowerCase(), index);
                index++;
            }
        }
        return token_id_map;
    }

    private static boolean _is_punctuation(char c) {
        // Checks whether `chars` is a punctuation character.
        int cp = (int) c;
        // We treat all non-letter/number ASCII as punctuation.
        // Characters such as "^", "$", and "`" are not in the Unicode
        // Punctuation class but we treat them as punctuation anyways, for
        // consistency.
        if ((cp >= 33 && cp <= 47) || (cp >= 58 && cp <= 64) || (cp >= 91 && cp <= 96) || (cp >= 123 && cp <= 126)) {
            return true;
        }
        int charType = Character.getType(c);
        if (Character.CONNECTOR_PUNCTUATION == charType || Character.DASH_PUNCTUATION == charType
                || Character.END_PUNCTUATION == charType || Character.FINAL_QUOTE_PUNCTUATION == charType
                || Character.INITIAL_QUOTE_PUNCTUATION == charType || Character.OTHER_PUNCTUATION == charType
                || Character.START_PUNCTUATION == charType) {
            return true;
        }
        return false;
    }

    private static boolean _is_whitespace(char c) {
        // Checks whether `chars` is a whitespace character.
        // \t, \n, and \r are technically control characters but we treat them
        // as whitespace since they are generally considered as such.
        if (c == ' ' || c == '\t' || c == '\n' || c == '\r') {
            return true;
        }

        int charType = Character.getType(c);
        if (Character.SPACE_SEPARATOR == charType) {
            return true;
        }
        return false;
    }

    private static boolean _is_control(char c) {
        // Checks whether `chars` is a control character.
        // These are technically control characters but we count them as whitespace
        // characters.
        if (c == '\t' || c == '\n' || c == '\r') {
            return false;
        }

        int charType = Character.getType(c);
        if (Character.CONTROL == charType || Character.DIRECTIONALITY_COMMON_NUMBER_SEPARATOR == charType
                || Character.FORMAT == charType || Character.PRIVATE_USE == charType || Character.SURROGATE == charType
                || Character.UNASSIGNED == charType) {
            return true;
        }
        return false;
    }

    private static boolean _is_chinese_char(int cp) {
        // Checks whether CP is the codepoint of a CJK character.
        // This defines a "chinese character" as anything in the CJK Unicode block:
        // https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        //
        // Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        // despite its name. The modern Korean Hangul alphabet is a different block,
        // as is Japanese Hiragana and Katakana. Those alphabets are used to write
        // space-separated words, so they are not treated specially and handled
        // like all of the other languages.
        if ((cp >= 0x4E00 && cp <= 0x9FFF) || (cp >= 0x3400 && cp <= 0x4DBF) || (cp >= 0x20000 && cp <= 0x2A6DF)
                || (cp >= 0x2A700 && cp <= 0x2B73F) || (cp >= 0x2B740 && cp <= 0x2B81F)
                || (cp >= 0x2B820 && cp <= 0x2CEAF) || (cp >= 0xF900 && cp <= 0xFAFF)
                || (cp >= 0x2F800 && cp <= 0x2FA1F)) {
            return true;
        }

        return false;
    }
}
--------------------------------------------------------------------------------