├── LICENSE
├── agreement
│   └── Mimix Model Release Agreement.docx
├── app.py
├── conf
│   ├── example_lm_model_conf
│   ├── example_lm_train_conf
│   ├── example_seq2seq_model_conf
│   ├── example_seq2seq_train_conf
│   ├── example_vit_model_conf
│   ├── example_vit_train_conf
│   ├── readme.txt
│   ├── transformer_caption_model_conf
│   └── transformer_caption_train_conf
├── examples
│   ├── example_load_bert.py
│   ├── example_load_mae.py
│   ├── example_load_vit.py
│   ├── example_train_caption_mnist.py
│   ├── example_train_lm.py
│   ├── example_train_lm_deepspeed.py
│   ├── example_train_mnist.py
│   ├── example_train_seq2seq.py
│   └── test.mid
├── gen_midi.py
├── image
│   ├── mimix.png
│   ├── streamlit.png
│   ├── streamlit10.png
│   ├── streamlit2.png
│   ├── streamlit3.png
│   ├── streamlit4.png
│   ├── streamlit5.png
│   ├── streamlit6.png
│   ├── streamlit7.png
│   ├── streamlit8.png
│   ├── streamlit9.png
│   └── wechat.jpg
├── interact.py
├── mimix
│   ├── __init__.py
│   ├── app.py
│   ├── bert_tokenizer.py
│   ├── clustering.py
│   ├── ddp.py
│   ├── decoding.py
│   ├── ds.py
│   ├── evaluate.py
│   ├── interact.py
│   ├── layers.py
│   ├── loss.py
│   ├── models.py
│   ├── optimizer.py
│   ├── predictor.py
│   ├── preprocess.py
│   ├── scheduler.py
│   ├── test.py
│   ├── tokenization.py
│   ├── train.py
│   ├── utils.py
│   └── vis.py
├── model
│   └── readme.txt
├── preprocess.py
├── readme.md
├── requirements.txt
└── solve_sudoku.py
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /agreement/Mimix Model Release Agreement.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/agreement/Mimix Model Release Agreement.docx -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Aug 7 20:30:13 2023 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | 8 | from mimix.app import run_app 9 | 10 | run_app() -------------------------------------------------------------------------------- /conf/example_lm_model_conf: -------------------------------------------------------------------------------- 1 | [int] 2 | trg_max_len = xx 3 | n_heads = xx 4 | d_model = xx 5 | d_ff = xx 6 | n_dec_layers = xx 7 | trg_vocab_size = xx 8 | max_decode_steps = xx 9 | 10 | [float] 11 | dropout = xx 12 | 13 | [str] 14 | trg_vocab = /path/to/trg/vocab 15 | load_model = /path/to/load/for/predict 16 | search_strategy = sample 17 | trg_tokenizer = default 18 | model = transformer 19 | task = lm 20 | 21 | [bool] 22 | use_cuda = False 23 | share_emb_out_proj = True 24 | use_pre_norm = True 25 | norm_after_embedding = False 26 | norm_before_pred = True 27 | use_pos_embedding = True -------------------------------------------------------------------------------- /conf/example_lm_train_conf: -------------------------------------------------------------------------------- 1 | [int] 2 | max_epoch = xx 3 | batch_size = xx 4 | 5 | [float] 6 | lr = xx 7 | 8 | [str] 9 | tmp_dir = path/to/processed/data 10 | train_dir = path/to/raw/train/data 11 | model_dir = path/to/save/model 12 | model_name = name 13 | reload_model = path/to/pretrain/model/for/train 14 | optimizer=adamW 15 | 16 | [bool] 17 | use_cuda = False -------------------------------------------------------------------------------- /conf/example_seq2seq_model_conf: -------------------------------------------------------------------------------- 1 | [int] 2 | src_max_len = xx 3 | trg_max_len = xx 4 | n_heads = xx 5 | d_model = xx 6 | d_ff = xx 7 | n_enc_layers = xx 8 | n_dec_layers = xx 9 | src_vocab_size = xx 10 | trg_vocab_size = xx 11 | beam_size = xx 12 | max_decode_steps = xx 13 | 14 | [float] 15 | dropout = xx 16 | 17 | [str] 18 | src_vocab = /path/to/src/vocab 19 | trg_vocab = /path/to/trg/vocab 20 | load_model = /path/to/load/for/predict 21 | search_strategy = beam_search 22 | src_tokenizer = default 23 | trg_tokenizer = default 24 | model = transformer 25 | task = enc_dec 26 | activation = relu 27 | 28 | [bool] 29 | use_cuda = True 30 | share_src_trg_emb = True 31 | share_emb_out_proj = True 32 | use_pre_norm = True 33 | norm_after_embedding = False 34 | norm_before_pred = True -------------------------------------------------------------------------------- /conf/example_seq2seq_train_conf: -------------------------------------------------------------------------------- 1 | [int] 2 | max_epoch = xx 3 | batch_size = xx 4 | 5 | [float] 6 | lr = xx 7 | 8 | [str] 9 | tmp_dir = path/to/processed/data 10 | train_dir = path/to/raw/train/data 11 | model_dir = path/to/save/model 12 | model_name = name 13 | reload_model = path/to/pretrain/model/for/train 14 | optimizer = adamW 15 | 16 | [bool] 17 | use_cuda = True -------------------------------------------------------------------------------- 
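These conf files share a simple INI layout: each section name ([int], [float], [str], [bool]) declares the type of the keys inside it, "xx" marks a placeholder the user must fill in, and lines starting with # are commented out. A minimal loader sketch for this layout follows. The section-name convention is taken from the files themselves, but the function below is a hypothetical stand-in written for illustration; the actual parsing is done by load_config/load_model_config in mimix/utils.py, whose implementation is not shown in this dump.

import configparser

def load_conf_sketch(path):
    """Hypothetical loader: read a Mimix-style conf file into one flat dict,
    casting each value according to its section name. Placeholder values such
    as "xx" must be filled in first, or the int/float casts will raise."""
    parser = configparser.ConfigParser(interpolation=None)  # values are plain strings, no %-interpolation
    parser.read(path, encoding="utf-8")
    casts = {"int": int, "float": float, "str": str,
             "bool": lambda v: v.strip() == "True"}
    config = {}
    for section in parser.sections():
        cast = casts.get(section, str)
        for key, value in parser.items(section):
            config[key] = cast(value)  # e.g. "128" under [int] becomes 128
    return config

# usage sketch: load_conf_sketch("conf/example_vit_model_conf")["d_model"] would yield 128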
/conf/example_vit_model_conf: -------------------------------------------------------------------------------- 1 | [int] 2 | src_max_len = 50 3 | n_heads = 8 4 | d_model = 128 5 | d_ff = 512 6 | n_enc_layers = 6 7 | n_class = 10 8 | img_h = 28 9 | img_w = 28 10 | patch_h = 4 11 | patch_w = 4 12 | n_channels = 1 13 | 14 | [float] 15 | dropout = 0.1 16 | 17 | [str] 18 | #load_model = model/vit 19 | model = transformer 20 | task = image_classification 21 | activation = gelu 22 | 23 | [bool] 24 | use_cuda = True 25 | use_pre_norm = True 26 | norm_before_pred = True 27 | use_vit_encoder = True 28 | -------------------------------------------------------------------------------- /conf/example_vit_train_conf: -------------------------------------------------------------------------------- 1 | [int] 2 | max_epoch = 100 3 | batch_size = 1024 4 | 5 | [float] 6 | lr = 0.001 7 | 8 | [str] 9 | model_dir = ../model/vit 10 | model_name = mnist 11 | optimizer = adamW 12 | 13 | [bool] 14 | use_cuda = True 15 | -------------------------------------------------------------------------------- /conf/readme.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/conf/readme.txt -------------------------------------------------------------------------------- /conf/transformer_caption_model_conf: -------------------------------------------------------------------------------- 1 | [int] 2 | src_max_len = xx 3 | trg_max_len = xx 4 | n_heads = xx 5 | d_model = xx 6 | d_ff = xx 7 | n_enc_layers = xx 8 | n_dec_layers = xx 9 | trg_vocab_size = xx 10 | beam_size = xx 11 | max_decode_steps = xx 12 | img_h = xx 13 | img_w = xx 14 | patch_h = xx 15 | patch_w = xx 16 | n_channels = xx 17 | 18 | [float] 19 | dropout = 0 20 | 21 | [str] 22 | trg_vocab = xxx 23 | load_model = xxx 24 | search_strategy = beam_search 25 | src_tokenizer = default 26 | trg_tokenizer = default 27 | model = transformer 28 | task = image2text 29 | activation = relu 30 | 31 | [bool] 32 | use_cuda = True 33 | share_src_trg_emb = False 34 | share_emb_out_proj = True 35 | use_pre_norm = True 36 | norm_after_embedding = False 37 | norm_before_pred = True 38 | use_vit_encoder = True -------------------------------------------------------------------------------- /conf/transformer_caption_train_conf: -------------------------------------------------------------------------------- 1 | [int] 2 | max_epoch = xxx 3 | batch_size = xxx 4 | save_steps = xxx 5 | 6 | [float] 7 | lr = xxx 8 | 9 | [str] 10 | model_dir = test_data/model/caption 11 | model_name = xxx 12 | optimizer = adamW 13 | 14 | [bool] 15 | use_cuda = True -------------------------------------------------------------------------------- /examples/example_load_bert.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Aug 6 22:27:06 2023 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import sys 8 | import os 9 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) 10 | import json 11 | import torch 12 | from mimix.tokenization import build_tokenizer 13 | 14 | def load_bert_weights(bert, weight_path): 15 | """ 16 | """ 17 | bert.eval() 18 | state_dict = torch.load(weight_path) 19 | 20 | map_key_dict = {"encoder.word_embedding.W": "bert.embeddings.word_embeddings.weight", 21 | "encoder.pos_embedding.W": "bert.embeddings.position_embeddings.weight", 22 | 
"encoder.type_embedding.W":"bert.embeddings.token_type_embeddings.weight", 23 | "encoder.norm_emb.alpha": "bert.embeddings.LayerNorm.gamma", 24 | "encoder.norm_emb.bias": "bert.embeddings.LayerNorm.beta", 25 | "W_pool": "bert.pooler.dense.weight", 26 | "b_pool": "bert.pooler.dense.bias", 27 | "W_cls": "cls.seq_relationship.weight", 28 | "b_cls": "cls.seq_relationship.bias", 29 | "W_mlm": "cls.predictions.transform.dense.weight", 30 | "b_mlm": "cls.predictions.transform.dense.bias", 31 | "norm_mlm.alpha": "cls.predictions.transform.LayerNorm.gamma", 32 | "norm_mlm.bias": "cls.predictions.transform.LayerNorm.beta", 33 | "b_out_mlm": "cls.predictions.bias"} 34 | 35 | for i in range(bert.n_layers): 36 | map_key_dict["encoder.layers.%d.self_attention.W_q" % i] = "bert.encoder.layer.%d.attention.self.query.weight" % i 37 | map_key_dict["encoder.layers.%d.self_attention.b_q" % i] = "bert.encoder.layer.%d.attention.self.query.bias" % i 38 | map_key_dict["encoder.layers.%d.self_attention.W_k" % i] = "bert.encoder.layer.%d.attention.self.key.weight" % i 39 | map_key_dict["encoder.layers.%d.self_attention.b_k" % i] = "bert.encoder.layer.%d.attention.self.key.bias" % i 40 | map_key_dict["encoder.layers.%d.self_attention.W_v" % i] = "bert.encoder.layer.%d.attention.self.value.weight" % i 41 | map_key_dict["encoder.layers.%d.self_attention.b_v" % i] = "bert.encoder.layer.%d.attention.self.value.bias" % i 42 | map_key_dict["encoder.layers.%d.self_attention.W_o" % i] = "bert.encoder.layer.%d.attention.output.dense.weight" % i 43 | map_key_dict["encoder.layers.%d.self_attention.b_o" % i] = "bert.encoder.layer.%d.attention.output.dense.bias" % i 44 | map_key_dict["encoder.layers.%d.norm_1.alpha" % i] = "bert.encoder.layer.%d.attention.output.LayerNorm.gamma" % i 45 | map_key_dict["encoder.layers.%d.norm_1.bias" % i] = "bert.encoder.layer.%d.attention.output.LayerNorm.beta" % i 46 | map_key_dict["encoder.layers.%d.ffn.W1" % i] = "bert.encoder.layer.%d.intermediate.dense.weight" % i 47 | map_key_dict["encoder.layers.%d.ffn.b1" % i] = "bert.encoder.layer.%d.intermediate.dense.bias" % i 48 | map_key_dict["encoder.layers.%d.ffn.W2" % i] = "bert.encoder.layer.%d.output.dense.weight" % i 49 | map_key_dict["encoder.layers.%d.ffn.b2" % i] = "bert.encoder.layer.%d.output.dense.bias" % i 50 | map_key_dict["encoder.layers.%d.norm_2.alpha" % i] = "bert.encoder.layer.%d.output.LayerNorm.gamma" % i 51 | map_key_dict["encoder.layers.%d.norm_2.bias" % i] = "bert.encoder.layer.%d.output.LayerNorm.beta" % i 52 | 53 | if bert.share_emb_out_proj == False: 54 | map_key_dict["W_out_mlm"] = "cls.predictions.decoder.weight" 55 | 56 | model_state_dict = {} 57 | for key,param in bert.named_parameters(): 58 | model_state_dict[key] = state_dict[map_key_dict[key]] 59 | #model_state_dict[key] = state_dict[map_key_dict[key].replace("gamma", "weight").replace("beta", "bias")] 60 | if key == "W_out_mlm": 61 | model_state_dict[key] = state_dict[map_key_dict[key]].T 62 | 63 | bert.load_state_dict(model_state_dict, False) 64 | 65 | return bert 66 | 67 | 68 | def load_bert_model(model_path, use_cuda=False): 69 | """ 70 | """ 71 | config = json.load(open(os.path.join(model_path, "config.json"))) 72 | mimix_config = {} 73 | mimix_config["attn_dropout"] = config["attention_probs_dropout_prob"] 74 | mimix_config["activation"] = config["hidden_act"] 75 | mimix_config["dropout"] = config["hidden_dropout_prob"] 76 | mimix_config["d_model"] = config["hidden_size"] 77 | mimix_config["d_ff"] = config["intermediate_size"] 78 | mimix_config["ln_eps"] = 
config["layer_norm_eps"] 79 | mimix_config["src_max_len"] = config["max_position_embeddings"] 80 | mimix_config["n_heads"] = config["num_attention_heads"] 81 | mimix_config["n_enc_layers"] = config["num_hidden_layers"] 82 | mimix_config["n_types"] = config["type_vocab_size"] 83 | mimix_config["src_vocab_size"] = config["vocab_size"] 84 | mimix_config["use_pre_norm"] = False 85 | mimix_config["norm_after_embedding"] = True 86 | mimix_config["with_mlm"] = True 87 | from mimix.models import TransformerEncoder 88 | mimix_config["symbols"] = {"_pad_": "[PAD]", 89 | "_bos_": "[unused1]", 90 | "_eos_": "[unused2]", 91 | "_unk_": "[UNK]", 92 | "_cls_": "[CLS]", 93 | "_sep_": "[SEP]", 94 | "_mask_": "[MASK]" 95 | } 96 | vocab = {line.strip():i for i,line in enumerate(open(os.path.join(model_path, "vocab.txt"), "r", encoding="utf-8"))} 97 | mimix_config["symbol2id"] = {k:vocab[mimix_config["symbols"][k]] for k in mimix_config["symbols"]} 98 | bert = TransformerEncoder(**mimix_config) 99 | bert = load_bert_weights(bert, os.path.join(model_path, "pytorch_model.bin")) 100 | if use_cuda == True: 101 | bert = bert.cuda() 102 | return bert 103 | 104 | 105 | bert_model_path = "model/pretrain/bert" 106 | #bert_model_path = "model/pretrain/bert_large" 107 | model = load_bert_model(bert_model_path) 108 | model.eval() 109 | tokenizer = build_tokenizer(**{"tokenizer":"bert", "vocab_file":os.path.join(bert_model_path, "vocab.txt")}) 110 | 111 | #Test for Chinese BERT MLM Task: [MASK]国的首都是曼谷 112 | #output: ['泰'] 0.9495071768760681 113 | x = [101,103] + tokenizer.tokenize_to_ids("国的首都是曼谷") + [102] 114 | y = model({"x":torch.tensor([x], dtype=torch.long), "type_ids":torch.zeros([1,len(x)], dtype=torch.long)})["mlm_logits"] 115 | prob = torch.softmax(y, -1) 116 | word_id = y[0][x.index(103)].argmax().item() 117 | prob = prob[0][x.index(103)][word_id].item() 118 | print(tokenizer.convert_ids_to_tokens([word_id]), prob) 119 | 120 | #Test for Chinese BERT MLM Task: 韩国的首都是[MASK]尔 121 | #output: ['首'] 0.9999905824661255 122 | x = [101] + tokenizer.tokenize_to_ids("韩国的首都是") + [103] + tokenizer.tokenize_to_ids("尔") + [102] 123 | y = model({"x":torch.tensor([x], dtype=torch.long), "type_ids":torch.zeros([1,len(x)], dtype=torch.long)})["mlm_logits"] 124 | prob = torch.softmax(y, -1) 125 | word_id = y[0][x.index(103)].argmax().item() 126 | prob = prob[0][x.index(103)][word_id].item() 127 | print(tokenizer.convert_ids_to_tokens([word_id]), prob) 128 | -------------------------------------------------------------------------------- /examples/example_load_mae.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Jan 6 18:15:21 2024 4 | 5 | @author: 1 6 | """ 7 | import sys 8 | import os 9 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) 10 | import re 11 | import torch 12 | from mimix.models import build_model 13 | from mimix.utils import load_model_config 14 | 15 | model = build_model(load_model_config("conf/mae_base_conf")) 16 | a = torch.load("model/pretrain/vit/mae_visualize_vit_base.pth")["model"] 17 | b = {} 18 | for k in a: 19 | if k == "cls_token": 20 | b["encoder.cls"] = a[k][0,0,:] 21 | elif k == "mask_token": 22 | b["mask"] = a[k][0,0,:] 23 | elif k == "decoder_pred.weight": 24 | b["out_proj.W"] = a[k] 25 | elif k == "decoder_pred.bias": 26 | b["out_proj.b"] = a[k] 27 | elif k == "decoder_embed.weight": 28 | b["dec_embedding.W"] = a[k] 29 | elif k == "decoder_embed.bias": 30 | b["dec_embedding.b"] = 
a[k] 31 | elif k == "patch_embed.proj.weight": 32 | b["encoder.patch_embedding.weight"] = a[k] 33 | elif k == "patch_embed.proj.bias": 34 | b["encoder.patch_embedding.bias"] = a[k] 35 | elif k == "norm.weight": 36 | b["encoder.norm.alpha"] = a[k] 37 | elif k == "norm.bias": 38 | b["encoder.norm.bias"] = a[k] 39 | elif k == "decoder_norm.weight": 40 | b["decoder.norm.alpha"] = a[k] 41 | elif k == "decoder_norm.bias": 42 | b["decoder.norm.bias"] = a[k] 43 | elif re.search("blocks.[0-9]+.norm[0-9]+.weight", k): 44 | idx = re.findall("[0-9]+", k)[0] 45 | idx2 = re.findall("[0-9]+", k)[1] 46 | k2 = "encoder.layers.%s.norm_%s.alpha" % (idx, idx2) 47 | if "decoder" in k: 48 | k2 = k2.replace("encoder", "decoder") 49 | b[k2] = a[k] 50 | elif re.search("blocks.[0-9]+.norm[0-9]+.bias", k): 51 | idx = re.findall("[0-9]+", k)[0] 52 | idx2 = re.findall("[0-9]+", k)[1] 53 | k2 = "encoder.layers.%s.norm_%s.bias" % (idx, idx2) 54 | if "decoder" in k: 55 | k2 = k2.replace("encoder", "decoder") 56 | b[k2] = a[k] 57 | elif re.search("blocks.[0-9]+.attn.qkv.weight", k): 58 | idx = re.findall("[0-9]+", k)[0] 59 | d = a[k].shape[0] // 3 60 | k2 = "encoder.layers.%s.self_attention.W_q" % idx 61 | if "decoder" in k: 62 | k2 = k2.replace("encoder", "decoder") 63 | b[k2] = a[k][:d] 64 | k2 = "encoder.layers.%s.self_attention.W_k" % idx 65 | if "decoder" in k: 66 | k2 = k2.replace("encoder", "decoder") 67 | b[k2] = a[k][d:2*d] 68 | k2 = "encoder.layers.%s.self_attention.W_v" % idx 69 | if "decoder" in k: 70 | k2 = k2.replace("encoder", "decoder") 71 | b[k2] = a[k][2*d:] 72 | 73 | elif re.search("blocks.[0-9]+.attn.qkv.bias", k): 74 | idx = re.findall("[0-9]+", k)[0] 75 | d = a[k].shape[0] // 3 76 | k2 = "encoder.layers.%s.self_attention.b_q" % idx 77 | if "decoder" in k: 78 | k2 = k2.replace("encoder", "decoder") 79 | b[k2] = a[k][:d] 80 | k2 = "encoder.layers.%s.self_attention.b_k" % idx 81 | if "decoder" in k: 82 | k2 = k2.replace("encoder", "decoder") 83 | b[k2] = a[k][d:2*d] 84 | k2 = "encoder.layers.%s.self_attention.b_v" % idx 85 | if "decoder" in k: 86 | k2 = k2.replace("encoder", "decoder") 87 | b[k2] = a[k][2*d:] 88 | 89 | elif re.search("blocks.[0-9]+.attn.proj.weight", k): 90 | idx = re.findall("[0-9]+", k)[0] 91 | k2 = "encoder.layers.%s.self_attention.W_o" % idx 92 | if "decoder" in k: 93 | k2 = k2.replace("encoder", "decoder") 94 | b[k2] = a[k] 95 | elif re.search("blocks.[0-9]+.attn.proj.bias", k): 96 | idx = re.findall("[0-9]+", k)[0] 97 | k2 = "encoder.layers.%s.self_attention.b_o" % idx 98 | if "decoder" in k: 99 | k2 = k2.replace("encoder", "decoder") 100 | b[k2] = a[k] 101 | elif re.search("blocks.[0-9]+.mlp.fc[0-9]+.weight", k): 102 | idx = re.findall("[0-9]+", k)[0] 103 | idx2 = re.findall("[0-9]+", k)[1] 104 | k2 = "encoder.layers.%s.ffn.W%s" % (idx, idx2) 105 | if "decoder" in k: 106 | k2 = k2.replace("encoder", "decoder") 107 | b[k2] = a[k] 108 | elif re.search("blocks.[0-9]+.mlp.fc[0-9]+.bias", k): 109 | idx = re.findall("[0-9]+", k)[0] 110 | idx2 = re.findall("[0-9]+", k)[1] 111 | k2 = "encoder.layers.%s.ffn.b%s" % (idx, idx2) 112 | if "decoder" in k: 113 | k2 = k2.replace("encoder", "decoder") 114 | b[k2] = a[k] 115 | elif re.search("pos_embed", k): 116 | k2 = "encoder.pos_embedding.W" 117 | if "decoder" in k: 118 | k2 = k2.replace("encoder", "decoder") 119 | b[k2] = a[k][0] 120 | 121 | model.load_state_dict(b) 122 | 123 | import sys 124 | import os 125 | import requests 126 | 127 | import torch 128 | import numpy as np 129 | 130 | import matplotlib.pyplot as plt 131 | from PIL import 
Image 132 | 133 | def show_image(image, title=''): 134 | # image is [H, W, 3] 135 | assert image.shape[2] == 3 136 | plt.imshow(torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()) 137 | plt.title(title, fontsize=16) 138 | plt.axis('off') 139 | return 140 | 141 | def run_one_image(img, model): 142 | x = torch.tensor(img) 143 | 144 | # make it a batch-like 145 | x = x.unsqueeze(dim=0) 146 | x = torch.einsum('nhwc->nchw', x) 147 | 148 | # run MAE 149 | model.eval() 150 | with torch.no_grad(): 151 | 152 | outputs = model({"x":x.float(), "mask_ratio":0.75}) 153 | dec_output = outputs["output"] 154 | reconstruct = outputs["reconstruct"] 155 | mask = outputs["mask"] 156 | patchify_x = outputs["patchify_x"] 157 | 158 | y = torch.einsum('nchw->nhwc', reconstruct).cpu() 159 | 160 | print(mask) 161 | # visualize the mask 162 | mask = mask.detach() 163 | mask = mask.unsqueeze(-1).repeat(1, 1, model.pw*model.ph*model.n_channels) # (N, H*W, p*p*3) 164 | print(mask) 165 | h = model.img_h // model.ph 166 | w = model.img_w // model.pw 167 | mask = mask.reshape(shape=(mask.shape[0], h, w, model.ph, model.pw, model.n_channels)) 168 | print(mask) 169 | mask = torch.einsum('nhwpqc->nchpwq', mask) 170 | mask = mask.reshape(shape=(mask.shape[0], model.n_channels, model.img_h, model.img_w)) 171 | mask = torch.einsum('nchw->nhwc', mask).detach().cpu() 172 | print(mask) 173 | 174 | x = torch.einsum('nchw->nhwc', x) 175 | 176 | # masked image 177 | im_masked = x * (1 - mask) 178 | 179 | # MAE reconstruction pasted with visible patches 180 | im_paste = x * (1 - mask) + y * mask 181 | 182 | # make the plt figure larger 183 | plt.rcParams['figure.figsize'] = [24, 24] 184 | 185 | plt.subplot(1, 4, 1) 186 | show_image(x[0], "original") 187 | 188 | plt.subplot(1, 4, 2) 189 | show_image(im_masked[0], "masked") 190 | 191 | plt.subplot(1, 4, 3) 192 | show_image(y[0], "reconstruction") 193 | 194 | plt.subplot(1, 4, 4) 195 | show_image(im_paste[0], "reconstruction + visible") 196 | 197 | plt.show() 198 | 199 | 200 | torch.save(model.state_dict(), "model/mae/mae.base.model") 201 | imagenet_mean = np.array([0.485, 0.456, 0.406]) 202 | imagenet_std = np.array([0.229, 0.224, 0.225]) 203 | # load an image 204 | img_url = 'https://user-images.githubusercontent.com/11435359/147738734-196fd92f-9260-48d5-ba7e-bf103d29364d.jpg' # fox, from ILSVRC2012_val_00046145 205 | # img_url = 'https://user-images.githubusercontent.com/11435359/147743081-0428eecf-89e5-4e07-8da5-a30fd73cc0ba.jpg' # cucumber, from ILSVRC2012_val_00047851 206 | img = Image.open(requests.get(img_url, stream=True).raw) 207 | img = img.resize((224, 224)) 208 | img = np.array(img) / 255. 
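# img is now a float array in [0, 1] with shape (224, 224, 3); the lines below
# standardize it with the ImageNet mean/std defined earlier, the same
# normalization the script applies before feeding the image to the MAE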
209 | 210 | assert img.shape == (224, 224, 3) 211 | 212 | # normalize by ImageNet mean and std 213 | img = img - imagenet_mean 214 | img = img / imagenet_std 215 | 216 | plt.rcParams['figure.figsize'] = [5, 5] 217 | show_image(torch.tensor(img)) 218 | 219 | torch.manual_seed(2) 220 | print('MAE with pixel reconstruction:') 221 | run_one_image(img, model) -------------------------------------------------------------------------------- /examples/example_load_vit.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Aug 14 23:36:34 2023 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import sys 8 | import os 9 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) 10 | import re 11 | import numpy as np 12 | import torch 13 | from mimix.models import TransformerEncoder 14 | 15 | a = np.load("model/pretrain/vit/imagenet21k+imagenet2012_ViT-B_32.npz") 16 | b = {} 17 | for k in a.files: 18 | k2 = k.replace('Transformer/', 'encoder.') 19 | k2 = re.sub('encoderblock_([0-9]+)/', 'layers.\\1.', k2) 20 | k2 = k2.replace('MultiHeadDotProductAttention_1/', 'self_attention.') 21 | k2 = k2.replace('query/kernel', 'W_q') 22 | k2 = k2.replace('query/bias', 'b_q') 23 | k2 = k2.replace('key/kernel', 'W_k') 24 | k2 = k2.replace('key/bias', 'b_k') 25 | k2 = k2.replace('value/kernel', 'W_v') 26 | k2 = k2.replace('value/bias', 'b_v') 27 | k2 = k2.replace('out/kernel', 'W_o') 28 | k2 = k2.replace('out/bias', 'b_o') 29 | k2 = k2.replace('LayerNorm_0/scale', "norm_1.alpha") 30 | k2 = k2.replace('LayerNorm_0/bias', "norm_1.bias") 31 | k2 = k2.replace('LayerNorm_2/scale', "norm_2.alpha") 32 | k2 = k2.replace('LayerNorm_2/bias', "norm_2.bias") 33 | k2 = k2.replace('MlpBlock_3/Dense_0/kernel', "ffn.W1") 34 | k2 = k2.replace('MlpBlock_3/Dense_0/bias', "ffn.b1") 35 | k2 = k2.replace('MlpBlock_3/Dense_1/kernel', "ffn.W2") 36 | k2 = k2.replace('MlpBlock_3/Dense_1/bias', "ffn.b2") 37 | 38 | k2 = k2.replace('encoder_norm/bias', 'norm.bias') 39 | k2 = k2.replace('encoder_norm/scale', 'norm.alpha') 40 | 41 | w = torch.from_numpy(a[k]) 42 | if k == "cls": 43 | w = w.flatten() 44 | k2 = 'encoder.cls' 45 | if k == 'head/kernel': 46 | k2 = 'W_cls' 47 | if k == 'head/bias': 48 | k2 = 'b_cls' 49 | 50 | if k == 'embedding/kernel': 51 | k2 = 'encoder.patch_embedding.weight' 52 | w = torch.einsum('abcd->dcab', w) 53 | elif 'kernel' in k and "head" not in k: 54 | if "attention" in k2: 55 | w = w.view(a["embedding/bias"].shape[0], a["embedding/bias"].shape[0]) 56 | w = w.transpose(0, 1) 57 | 58 | if "Attention" in k and "bias" in k: 59 | w = w.flatten(0) 60 | 61 | if k == 'embedding/bias': 62 | k2 = 'encoder.patch_embedding.bias' 63 | 64 | if 'pos_embedding' in k: 65 | k2 = 'encoder.pos_embedding.W' 66 | w = w.squeeze(0) 67 | 68 | b[k2] = w 69 | 70 | 71 | vit = TransformerEncoder(use_vit_encoder=True, 72 | d_model=768, 73 | n_heads=12, 74 | img_w=384, 75 | img_h=384, 76 | patch_w=32, 77 | patch_h=32, 78 | n_class=1000, 79 | n_enc_layers=12, 80 | n_channels=3, 81 | activation="gelu", 82 | use_pre_norm=True, 83 | norm_before_pred=True, 84 | ln_eps=1e-6) 85 | vit.load_state_dict(b) 86 | 87 | from PIL import Image 88 | #from imagenet 89 | image = Image.open("n01440764_10026.jpg") 90 | from torchvision import transforms 91 | transform = transforms.Compose([ 92 | transforms.Resize((384, 384)), 93 | transforms.ToTensor(), 94 | transforms.Normalize(0.5, 0.5, 0.5), 95 | ]) 96 | x = transform(image).unsqueeze(0) 97 | 98 | vit.eval() 99 | with 
torch.no_grad(): 100 | outputs = vit({"x":x}) 101 | 102 | #predict label: 0 103 | print(outputs["cls_logits"].argmax(-1)) 104 | -------------------------------------------------------------------------------- /examples/example_train_caption_mnist.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Jul 28 22:06:35 2023 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import sys 8 | import os 9 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) 10 | from argparse import ArgumentParser 11 | from mimix.models import build_model, load_model_weights 12 | from mimix.optimizer import build_optimizer 13 | from mimix.scheduler import build_scheduler 14 | from mimix.loss import seq_cross_entropy 15 | from mimix.utils import real_path, load_config, load_model_config 16 | from mimix.train import train 17 | from mimix.evaluate import eval_acc 18 | import torch 19 | import numpy as np 20 | import random 21 | from torchvision import datasets, transforms 22 | 23 | class MNIST(): 24 | """ 25 | """ 26 | def __init__(self, data, batch_size, device): 27 | """ 28 | """ 29 | self.data = data 30 | self.batch_size = batch_size 31 | self.device = device 32 | self.idx = list(range(0, len(self.data))) 33 | random.shuffle(self.idx) 34 | 35 | 36 | def local_shuffle(self): 37 | """ 38 | """ 39 | random.shuffle(self.idx) 40 | 41 | 42 | def __call__(self, steps=0): 43 | """ 44 | """ 45 | i = steps * self.batch_size 46 | while i < len(self.data): 47 | x = torch.cat([self.data[j][0][0].unsqueeze(0) for j in self.idx[i:i+self.batch_size]]) 48 | x = x.float().unsqueeze(1).to(self.device) 49 | y = torch.tensor([[1, 18, 8 + self.data[j][1]] for j in self.idx[i:i+self.batch_size]]) 50 | y = y.long().to(self.device) 51 | y_target = torch.tensor([[18, 8 + self.data[j][1], 2] for j in self.idx[i:i+self.batch_size]]) 52 | y_target = y_target.long().to(self.device) 53 | yield {"x":x, "y":y}, {"y_target":y_target} 54 | i += self.batch_size 55 | 56 | 57 | def main(model_config, train_config): 58 | """ 59 | """ 60 | model = build_model(model_config) 61 | if train_config.get("reload_model", None) is not None: 62 | model = load_model_weights(model, real_path(train_config["reload_model"])) 63 | 64 | device = "cpu" 65 | if train_config["use_cuda"] == True: 66 | device_id = train_config.get("device_id", "0") 67 | device = 'cuda:%s' % device_id 68 | 69 | model = model.to(device) 70 | eps = train_config.get("eps", 0) 71 | model.loss_fn = lambda x,y:seq_cross_entropy(x["logits"], y["y_target"], eps, model.PAD) 72 | symbol2id = model_config["symbol2id"] 73 | batch_size = train_config["batch_size"] 74 | 75 | transform = transforms.Compose([ 76 | transforms.Resize((model_config["img_w"], model_config["img_h"])), 77 | transforms.ToTensor(), 78 | transforms.Normalize(0.5, 0.5), 79 | ]) 80 | 81 | train_dataset = MNIST( 82 | datasets.MNIST("test_data/mnist-data", train=True, download=True, transform=transform), 83 | #datasets.FashionMNIST("test_data/mnist-data", train=True, download=True, transform=transforms.ToTensor()), 84 | train_config["batch_size"], 85 | device) 86 | val_dataset = None 87 | test_dataset = MNIST( 88 | datasets.MNIST("test_data/mnist-data", train=False, download=True, transform=transform), 89 | #datasets.FashionMNIST("test_data/mnist-data", train=True, download=True, transform=transforms.ToTensor()), 90 | train_config["batch_size"], 91 | device) 92 | 93 | optimizer = build_optimizer(model, train_config) 94 | lr_scheduler 
= build_scheduler(train_config, optimizer) 95 | eval_fn_list = [] 96 | train(model, 97 | optimizer, 98 | train_config, 99 | train_dataset, 100 | val_dataset, 101 | test_dataset, 102 | eval_fn_list, 103 | lr_scheduler) 104 | 105 | 106 | def run_train(): 107 | """ 108 | """ 109 | parser = ArgumentParser() 110 | 111 | parser.add_argument("--model_conf", type=str) 112 | parser.add_argument("--train_conf", type=str) 113 | 114 | args = parser.parse_args(sys.argv[1:]) 115 | 116 | model_config = load_model_config(real_path(args.model_conf)) 117 | train_config = load_config(real_path(args.train_conf)) 118 | 119 | main(model_config, train_config) 120 | 121 | 122 | if __name__ == "__main__": 123 | run_train() 124 | -------------------------------------------------------------------------------- /examples/example_train_lm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Jul 28 22:06:35 2023 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import sys 8 | import os 9 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) 10 | from argparse import ArgumentParser 11 | import numpy as np 12 | import torch 13 | from mimix.models import build_model 14 | from mimix.optimizer import build_optimizer 15 | from mimix.scheduler import build_scheduler 16 | from mimix.loss import seq_cross_entropy 17 | from mimix.utils import real_path, load_config, load_model_config, SimpleDataset 18 | from mimix.train import train 19 | 20 | 21 | class LMDataset(SimpleDataset): 22 | """ 23 | """ 24 | def __init__(self, 25 | data_dir, 26 | batch_size, 27 | symbol2id, 28 | device="cpu", 29 | rank=0, 30 | world_size=1): 31 | """ 32 | """ 33 | self.data_dir = data_dir 34 | self.batch_size = batch_size 35 | self.PAD = symbol2id["_pad_"] 36 | self.BOS = symbol2id["_bos_"] 37 | self.EOS = symbol2id["_eos_"] 38 | self.UNK = symbol2id["_unk_"] 39 | self.SEP = symbol2id["_sep_"] 40 | self.CLS = symbol2id["_cls_"] 41 | self.MASK = symbol2id["_mask_"] 42 | self.device = device 43 | self.rank = rank 44 | self.world_size = world_size 45 | self.sort_key_fn = None 46 | 47 | 48 | def vectorize(self, batch_data): 49 | """ 50 | """ 51 | batch_size = len(batch_data) 52 | trg_max_len = max(len(x["trg"]) - 1 for x in batch_data) 53 | y = self.PAD + np.zeros((batch_size, trg_max_len), dtype=np.int64) 54 | y_target = self.PAD + np.zeros((batch_size, trg_max_len), 55 | dtype=np.int64) 56 | 57 | for i, d in enumerate(batch_data): 58 | yy = d["trg"] 59 | y[i, :len(yy) - 1] = yy[:-1] 60 | y_target[i, :len(yy) - 1] = yy[1:] 61 | 62 | y = torch.tensor(y, dtype=torch.long) 63 | y_target = torch.tensor(y_target, dtype=torch.long) 64 | 65 | return {"y":y}, {"y_target":y_target} 66 | 67 | 68 | def main(model_config, train_config): 69 | """ 70 | """ 71 | model = build_model(model_config, train_config.get("reload_model", None)) 72 | 73 | device = "cpu" 74 | if train_config["use_cuda"] == True: 75 | device_id = train_config.get("device_id", "0") 76 | device = 'cuda:%s' % device_id 77 | 78 | model = model.to(device) 79 | eps = train_config.get("eps", 0) 80 | model.loss_fn = lambda x,y:seq_cross_entropy(x["logits"], y["y_target"], eps, model.PAD) 81 | symbol2id = model_config["symbol2id"] 82 | train_dir = os.path.join(train_config["tmp_dir"], "train") 83 | batch_size = train_config["batch_size"] 84 | train_dataset = LMDataset(train_dir, batch_size, symbol2id, device) 85 | val_dataset = None 86 | if train_config.get("val_dir", None) is not None: 87 | val_dir = 
train_config["val_dir"] 88 | test_batch_size = train_config["test_batch_size"] 89 | val_dataset = LMDataset(val_dir, test_batch_size, symbol2id, device) 90 | test_dataset = None 91 | if train_config.get("test_dir", None) is not None: 92 | test_dir = train_config["test_dir"] 93 | test_batch_size = train_config["test_batch_size"] 94 | test_dataset = LMDataset(test_dir, test_batch_size, symbol2id, device) 95 | 96 | optimizer = build_optimizer(model, train_config) 97 | lr_scheduler = build_scheduler(train_config, optimizer) 98 | eval_fn_list = [] 99 | train(model, 100 | optimizer, 101 | train_config, 102 | train_dataset, 103 | val_dataset, 104 | test_dataset, 105 | eval_fn_list, 106 | lr_scheduler) 107 | 108 | 109 | def run_train(): 110 | """ 111 | """ 112 | parser = ArgumentParser() 113 | 114 | parser.add_argument("--model_conf", type=str) 115 | parser.add_argument("--train_conf", type=str) 116 | 117 | args = parser.parse_args(sys.argv[1:]) 118 | 119 | model_config = load_model_config(real_path(args.model_conf)) 120 | train_config = load_config(real_path(args.train_conf)) 121 | 122 | main(model_config, train_config) 123 | 124 | 125 | if __name__ == "__main__": 126 | run_train() 127 | -------------------------------------------------------------------------------- /examples/example_train_lm_deepspeed.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Mar 11 21:41:58 2024 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import sys 8 | import os 9 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) 10 | from argparse import ArgumentParser 11 | import numpy as np 12 | import torch 13 | from mimix.models import build_model, load_model_weights 14 | from mimix.optimizer import build_optimizer 15 | from mimix.scheduler import build_scheduler 16 | from mimix.loss import seq_cross_entropy 17 | from mimix.utils import real_path, load_config, load_model_config, SimpleDataset 18 | import deepspeed 19 | from mimix.ds import train 20 | import re 21 | import random 22 | import json 23 | 24 | import mimix.tokenization as tokenization 25 | 26 | deepspeed.init_distributed() 27 | local_rank = int(os.environ['LOCAL_RANK']) 28 | world_size = int(os.environ['WORLD_SIZE']) 29 | rank = int(os.environ['RANK']) 30 | 31 | def loop_all_texts(): # data-loading stub: yield training batches here 32 | pass 33 | 34 | import math 35 | class Scheduler(): 36 | """ 37 | """ 38 | def __init__(self, optimizer, train_config): 39 | """ 40 | """ 41 | self.optimizer = optimizer 42 | self.steps = 0 43 | self.lr = train_config["lr"] # fix: step() reads self.lr, which was never set 44 | 45 | def step(self): 46 | """ 47 | """ 48 | self.steps += 1 49 | 50 | for param_group in self.optimizer.param_groups: 51 | param_group['lr'] = self.lr 52 | 53 | 54 | def main(model_config, train_config, ds_config): 55 | """ 56 | """ 57 | model = build_model(model_config) 58 | if train_config.get("reload_model", None) is not None: 59 | model = load_model_weights(model, real_path(train_config["reload_model"])) 60 | PAD = model.PAD 61 | eps = train_config.get("eps", 0) 62 | model.loss_fn = lambda x,y:seq_cross_entropy(x["logits"], y["y_target"], eps, PAD) 63 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 64 | 65 | model_engine, optimizer, _, __ = deepspeed.initialize( 66 | model=model, 67 | model_parameters=parameters, 68 | config=ds_config, 69 | ) 70 | lr_scheduler = Scheduler(optimizer, train_config) 71 | train(model_engine, optimizer, train_config, loop_all_texts, lr_scheduler) 72 | 73 | def run_train(): 74 | """ 75 | """ 76 | model_conf = 
"conf/xxx_lm_conf" 77 | train_conf = "conf/xxx_train_conf" 78 | model_config = load_model_config(real_path(model_conf)) 79 | train_config = load_config(real_path(train_conf)) 80 | ds_config = json.loads(open("conf/ds_config_zero2.json", "rb").read()) 81 | 82 | main(model_config, train_config, ds_config) 83 | 84 | 85 | if __name__ == "__main__": 86 | run_train() 87 | 88 | -------------------------------------------------------------------------------- /examples/example_train_mnist.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Jul 28 22:06:35 2023 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import sys 8 | import os 9 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) 10 | from argparse import ArgumentParser 11 | from mimix.models import build_model 12 | from mimix.optimizer import build_optimizer 13 | from mimix.scheduler import build_scheduler 14 | from mimix.loss import classify_loss 15 | from mimix.utils import real_path, load_config, load_model_config 16 | from mimix.train import train 17 | from mimix.evaluate import eval_acc 18 | import torch 19 | import numpy as np 20 | import random 21 | from torchvision import datasets, transforms 22 | 23 | class MNIST(): 24 | """ 25 | """ 26 | def __init__(self, data, batch_size, device): 27 | """ 28 | """ 29 | self.data = data 30 | self.batch_size = batch_size 31 | self.device = device 32 | self.idx = list(range(0, len(self.data))) 33 | random.shuffle(self.idx) 34 | 35 | 36 | def local_shuffle(self): 37 | """ 38 | """ 39 | random.shuffle(self.idx) 40 | 41 | 42 | def __call__(self, steps=0): 43 | """ 44 | """ 45 | i = steps * self.batch_size 46 | while i < len(self.data): 47 | x = torch.cat([self.data[j][0][0].unsqueeze(0) for j in self.idx[i:i+self.batch_size]]) 48 | x = x.float().unsqueeze(1).to(self.device) 49 | y = torch.tensor([self.data[j][1] for j in self.idx[i:i+self.batch_size]]) 50 | y = y.long().to(self.device) 51 | yield {"x":x}, {"labels":y} 52 | i += self.batch_size 53 | 54 | 55 | def main(model_config, train_config): 56 | """ 57 | """ 58 | model = build_model(model_config, train_config.get("reload_model", None)) 59 | 60 | device = "cpu" 61 | if train_config["use_cuda"] == True: 62 | device_id = train_config.get("device_id", "0") 63 | device = 'cuda:%s' % device_id 64 | 65 | model = model.to(device) 66 | eps = train_config.get("eps", 0) 67 | model.loss_fn = lambda x,y:classify_loss(x["cls_logits"], y["labels"], eps) 68 | symbol2id = model_config["symbol2id"] 69 | batch_size = train_config["batch_size"] 70 | 71 | train_dataset = MNIST( 72 | datasets.MNIST("test_data/mnist-data", train=True, download=True, transform=transforms.ToTensor()), 73 | #datasets.FashionMNIST("test_data/mnist-data", train=True, download=True, transform=transforms.ToTensor()), 74 | train_config["batch_size"], 75 | device) 76 | val_dataset = None 77 | test_dataset = MNIST( 78 | datasets.MNIST("test_data/mnist-data", train=False, download=True, transform=transforms.ToTensor()), 79 | #datasets.FashionMNIST("test_data/mnist-data", train=True, download=True, transform=transforms.ToTensor()), 80 | train_config["batch_size"], 81 | device) 82 | 83 | optimizer = build_optimizer(model, train_config) 84 | lr_scheduler = build_scheduler(train_config, optimizer) 85 | eval_fn_list = [eval_acc] 86 | train(model, 87 | optimizer, 88 | train_config, 89 | train_dataset, 90 | val_dataset, 91 | test_dataset, 92 | eval_fn_list, 93 | lr_scheduler) 94 | 95 | 96 | 
def run_train(): 97 | """ 98 | """ 99 | parser = ArgumentParser() 100 | 101 | parser.add_argument("--model_conf", type=str) 102 | parser.add_argument("--train_conf", type=str) 103 | 104 | args = parser.parse_args(sys.argv[1:]) 105 | 106 | model_config = load_model_config(real_path(args.model_conf)) 107 | train_config = load_config(real_path(args.train_conf)) 108 | 109 | main(model_config, train_config) 110 | 111 | 112 | if __name__ == "__main__": 113 | run_train() 114 | 115 | -------------------------------------------------------------------------------- /examples/example_train_seq2seq.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Jul 28 22:06:35 2023 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import sys 8 | import os 9 | sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")) 10 | from argparse import ArgumentParser 11 | import numpy as np 12 | import torch 13 | from mimix.models import build_model 14 | from mimix.optimizer import build_optimizer 15 | from mimix.scheduler import build_scheduler 16 | from mimix.loss import seq_cross_entropy 17 | from mimix.utils import real_path, load_config, load_model_config, SimpleDataset 18 | from mimix.train import train 19 | 20 | 21 | class S2SDataset(SimpleDataset): 22 | """ 23 | """ 24 | def __init__(self, 25 | data_dir, 26 | batch_size, 27 | symbol2id, 28 | device="cpu", 29 | rank=0, 30 | world_size=1): 31 | """ 32 | """ 33 | self.data_dir = data_dir 34 | self.batch_size = batch_size 35 | self.PAD = symbol2id["_pad_"] 36 | self.BOS = symbol2id["_bos_"] 37 | self.EOS = symbol2id["_eos_"] 38 | self.UNK = symbol2id["_unk_"] 39 | self.SEP = symbol2id["_sep_"] 40 | self.CLS = symbol2id["_cls_"] 41 | self.MASK = symbol2id["_mask_"] 42 | self.device = device 43 | self.rank = rank 44 | self.world_size = world_size 45 | self.sort_key_fn = None 46 | 47 | 48 | def vectorize(self, batch_data): 49 | """ 50 | """ 51 | batch_size = len(batch_data) 52 | src_max_len = max(len(x["src"]) for x in batch_data) 53 | trg_max_len = max(len(x["trg"]) - 1 for x in batch_data) 54 | x = self.PAD + np.zeros((batch_size, src_max_len), dtype=np.int64) 55 | y = self.PAD + np.zeros((batch_size, trg_max_len), dtype=np.int64) 56 | y_target = self.PAD + np.zeros((batch_size, trg_max_len), 57 | dtype=np.int64) 58 | 59 | for i, d in enumerate(batch_data): 60 | xx,yy = d["src"],d["trg"] 61 | x[i, :len(xx)] = xx 62 | y[i, :len(yy) - 1] = yy[:-1] 63 | y_target[i, :len(yy) - 1] = yy[1:] 64 | 65 | x = torch.tensor(x, dtype=torch.long) 66 | y = torch.tensor(y, dtype=torch.long) 67 | y_target = torch.tensor(y_target, dtype=torch.long) 68 | 69 | return {"x":x, "y":y}, {"y_target":y_target} 70 | 71 | 72 | def main(model_config, train_config): 73 | """ 74 | """ 75 | model = build_model(model_config, train_config.get("reload_model", None)) 76 | 77 | device = "cpu" 78 | if train_config["use_cuda"] == True: 79 | device_id = train_config.get("device_id", "0") 80 | device = 'cuda:%s' % device_id 81 | 82 | model = model.to(device) 83 | eps = train_config.get("eps", 0) 84 | model.loss_fn = lambda x,y:seq_cross_entropy(x["logits"], y["y_target"], eps, model.PAD) 85 | symbol2id = model_config["symbol2id"] 86 | train_dir = os.path.join(train_config["tmp_dir"], "train") 87 | batch_size = train_config["batch_size"] 88 | train_dataset = S2SDataset(train_dir, batch_size, symbol2id, device) 89 | val_dataset = None 90 | if train_config.get("val_dir", None) is not None: 91 | val_dir = 
train_config["val_dir"] 92 | test_batch_size = train_config["test_batch_size"] 93 | val_dataset = S2SDataset(val_dir, test_batch_size, symbol2id, device) 94 | test_dataset = None 95 | if train_config.get("test_dir", None) is not None: 96 | test_dir = train_config["test_dir"] 97 | test_batch_size = train_config["test_batch_size"] 98 | test_dataset = S2SDataset(test_dir, test_batch_size, symbol2id, device) 99 | 100 | optimizer = build_optimizer(model, train_config) 101 | lr_scheduler = build_scheduler(train_config, optimizer) 102 | eval_fn_list = [] 103 | train(model, 104 | optimizer, 105 | train_config, 106 | train_dataset, 107 | val_dataset, 108 | test_dataset, 109 | eval_fn_list, 110 | lr_scheduler) 111 | 112 | 113 | def run_train(): 114 | """ 115 | """ 116 | parser = ArgumentParser() 117 | 118 | parser.add_argument("--model_conf", type=str) 119 | parser.add_argument("--train_conf", type=str) 120 | 121 | args = parser.parse_args(sys.argv[1:]) 122 | 123 | model_config = load_model_config(real_path(args.model_conf)) 124 | train_config = load_config(real_path(args.train_conf)) 125 | 126 | main(model_config, train_config) 127 | 128 | 129 | if __name__ == "__main__": 130 | run_train() 131 | -------------------------------------------------------------------------------- /examples/test.mid: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/examples/test.mid -------------------------------------------------------------------------------- /gen_midi.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Mar 15 22:06:34 2024 4 | 5 | @author: 1 6 | """ 7 | import sys 8 | import time 9 | import mido 10 | from mimix.predictor import LMGenerator 11 | from mimix.utils import real_path, load_model_config 12 | 13 | 14 | def convert_tokens_to_midi(tokens, output_path): 15 | """ 16 | """ 17 | new_midi_file = mido.MidiFile(ticks_per_beat=384) 18 | new_track = mido.MidiTrack() 19 | new_midi_file.tracks.append(new_track) 20 | new_midi_file.tracks[0].append(mido.MetaMessage('set_tempo', tempo=500000, time=0)) 21 | new_midi_file.tracks[0].append(mido.MetaMessage('time_signature', numerator=4, denominator=4, clocks_per_click=24, notated_32nd_notes_per_beat=8, time=0)) 22 | 23 | def add_event(track, kv): 24 | """ 25 | """ 26 | kv = {k:int(kv[k]) if k != "type" else kv[k] for k in kv} 27 | if len(kv) == 0: 28 | pass 29 | elif kv["type"] == "note_on": 30 | track.append(mido.Message('note_on', note=kv["note"], velocity=kv["velocity"], time=kv["time"])) 31 | elif kv["type"] == "control_change": 32 | track.append(mido.Message('control_change', control=kv["control"], value=kv["value"], time=kv["time"])) 33 | elif kv["type"] == "program_change": 34 | track.append(mido.Message('program_change', program=int(kv["program"]))) 35 | 36 | event = {} 37 | for token in tokens: 38 | 39 | if token.startswith("note_on_note_"): 40 | add_event(new_midi_file.tracks[0], event) 41 | event = {} 42 | event["type"] = "note_on" 43 | k = "note" 44 | event[k] = token.split("_")[-1] 45 | elif token.startswith("note_on_velocity_"): 46 | k = "velocity" 47 | event[k] = token.split("_")[-1] 48 | elif token.startswith("note_on_time_"): 49 | k = "time" 50 | event[k] = token.split("_")[-1] 51 | elif token.startswith("control_change_control_"): 52 | add_event(new_midi_file.tracks[0], event) 53 | event = {} 54 | event["type"] = "control_change" 55 
| k = "control" 56 | event[k] = token.split("_")[-1] 57 | elif token.startswith("control_change_time_"): 58 | k = "time" 59 | event[k] = token.split("_")[-1] 60 | elif token.startswith("control_change_value_"): 61 | k = "value" 62 | event[k] = token.split("_")[-1] 63 | elif token.startswith("program_change_program_"): 64 | add_event(new_midi_file.tracks[0], event) 65 | event["type"] = "program_change" 66 | event = {} 67 | k = "program" 68 | event[k] = token.split("_")[-1] 69 | elif token.startswith("#"): 70 | event[k] = event[k] + token[1:] 71 | new_midi_file.tracks[0].append(mido.MetaMessage('end_of_track', time=0)) 72 | new_midi_file.save(output_path) 73 | 74 | 75 | def midi_demo(): 76 | """ 77 | """ 78 | conf_file = "conf/midi_base_conf" 79 | config = load_model_config(real_path(conf_file)) 80 | lm_gen = LMGenerator(config) 81 | 82 | print("Press to start midi generation.") 83 | for line in sys.stdin: 84 | 85 | start = time.time() 86 | search_res = lm_gen.predict(prefix_list=None) 87 | tokens = search_res[0][1][0][0].split() 88 | convert_tokens_to_midi(tokens, "test.mid") 89 | 90 | print("Generate midi and save to test.mid done.") 91 | end = time.time() 92 | cost = end - start 93 | print("-----cost time: %s s-----" % cost) 94 | print("Press to start midi generation.") 95 | 96 | 97 | if __name__ == "__main__": 98 | 99 | midi_demo() -------------------------------------------------------------------------------- /image/mimix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/image/mimix.png -------------------------------------------------------------------------------- /image/streamlit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/image/streamlit.png -------------------------------------------------------------------------------- /image/streamlit10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/image/streamlit10.png -------------------------------------------------------------------------------- /image/streamlit2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/image/streamlit2.png -------------------------------------------------------------------------------- /image/streamlit3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/image/streamlit3.png -------------------------------------------------------------------------------- /image/streamlit4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/image/streamlit4.png -------------------------------------------------------------------------------- /image/streamlit5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/image/streamlit5.png -------------------------------------------------------------------------------- /image/streamlit6.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/image/streamlit6.png -------------------------------------------------------------------------------- /image/streamlit7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/image/streamlit7.png -------------------------------------------------------------------------------- /image/streamlit8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/image/streamlit8.png -------------------------------------------------------------------------------- /image/streamlit9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/image/streamlit9.png -------------------------------------------------------------------------------- /image/wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/image/wechat.jpg -------------------------------------------------------------------------------- /interact.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Aug 5 19:04:42 2023 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | from mimix.interact import run_interactive 8 | 9 | run_interactive() -------------------------------------------------------------------------------- /mimix/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Nov 22 11:58:37 2019 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | -------------------------------------------------------------------------------- /mimix/app.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Aug 6 20:54:44 2023 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import time 8 | import sys 9 | from argparse import ArgumentParser 10 | import numpy as np 11 | from PIL import Image 12 | import io 13 | import streamlit as st 14 | from mimix.predictor import EncDecGenerator,LMGenerator,TextEncoder 15 | from mimix.predictor import ImageEncoder, ClipMatcher, MAE 16 | from mimix.utils import real_path, load_model_config 17 | 18 | st.set_page_config(page_title="Mimix Demo", initial_sidebar_state="auto", layout="wide") 19 | 20 | def text_gen_app(model): 21 | """ 22 | """ 23 | st.markdown( 24 | """ 25 | ## Text Generation DEMO 26 | """ 27 | ) 28 | 29 | with st.form(key='my_form'): 30 | text = st.text_area("input text", max_chars=512) 31 | values = ('beam_search', 'sample') 32 | strategy = st.selectbox('strategy', values, index=values.index(model.strategy)) 33 | beam_size = st.number_input("beam_size", min_value=0, max_value=10, value=model.beam_size) 34 | group_size = st.number_input("group_size", min_value=0, max_value=5, value=model.group_size) 35 | max_decode_steps = st.number_input("max_decode_steps", min_value=0, max_value=512, value=model.max_decode_steps, step=1) 36 | repetition_penalty = st.slider("repetition_penalty", min_value=-10.0, max_value=0.0, 
value=model.repetition_penalty, step=0.1) 37 | temperature = st.slider("temperature", min_value=0., max_value=10.0, value=model.temperature, step=0.01) 38 | top_k = st.number_input("top_k", min_value=0, max_value=100, value=model.top_k, step=1) 39 | top_p = st.slider("top_p", min_value=0., max_value=1.0, value=model.top_p, step=0.01) 40 | 41 | submit = st.form_submit_button("generate") 42 | 43 | if submit: 44 | model.strategy = strategy 45 | model.beam_size = beam_size 46 | model.group_size = group_size 47 | model.max_decode_steps = max_decode_steps 48 | model.repetition_penalty = repetition_penalty; model.temperature = temperature 49 | model.top_k = top_k 50 | model.top_p = top_p 51 | start_message = st.empty() 52 | if model.group_size > 0 and model.beam_size % model.group_size != 0: 53 | start_message.write("beam_size must be a multiple of group_size!") 54 | else: 55 | start_message.write("generating...") 56 | start_time = time.time() 57 | res = model.predict([text]) 58 | end_time = time.time() 59 | start_message.write("done, cost {}s".format(end_time - start_time)) 60 | for i, (text, score) in enumerate(res[0][1]): 61 | st.text_area("the {} result".format(i + 1), text) 62 | 63 | 64 | def image_classification_app(model): 65 | """ 66 | """ 67 | uploaded_file = st.file_uploader("Choose an image file", type="jpg") 68 | 69 | if uploaded_file is not None: 70 | image = Image.open(io.BytesIO(uploaded_file.getvalue())) 71 | st.image(image, width=224) 72 | res = model.predict_cls([image]) 73 | for i, (label, score) in enumerate(res[0][1]): 74 | st.text_area("top {} result".format(i + 1), label + " " + str(score)) 75 | 76 | 77 | def image2text_app(model): 78 | """ 79 | """ 80 | st.markdown( 81 | """ 82 | ## Image To Text DEMO 83 | """ 84 | ) 85 | 86 | with st.form(key='my_form'): 87 | uploaded_file = st.file_uploader("Choose an image file", type="jpg") 88 | prefix = st.text_area("input text, can be empty", max_chars=512) 89 | values = ('beam_search', 'sample') 90 | strategy = st.selectbox('strategy', values, index=values.index(model.strategy)) 91 | beam_size = st.number_input("beam_size", min_value=0, max_value=10, value=model.beam_size) 92 | group_size = st.number_input("group_size", min_value=0, max_value=5, value=model.group_size) 93 | max_decode_steps = st.number_input("max_decode_steps", min_value=0, max_value=512, value=model.max_decode_steps, step=1) 94 | repetition_penalty = st.slider("repetition_penalty", min_value=-10.0, max_value=0.0, value=model.repetition_penalty, step=0.1) 95 | temperature = st.slider("temperature", min_value=0., max_value=10.0, value=model.temperature, step=0.01) 96 | top_k = st.number_input("top_k", min_value=0, max_value=100, value=model.top_k, step=1) 97 | top_p = st.slider("top_p", min_value=0., max_value=1.0, value=model.top_p, step=0.01) 98 | 99 | submit = st.form_submit_button("generate") 100 | 101 | if uploaded_file is not None and submit: 102 | image = Image.open(io.BytesIO(uploaded_file.getvalue())) 103 | st.image(image, width=224) 104 | model.strategy = strategy 105 | model.beam_size = beam_size 106 | model.group_size = group_size 107 | model.max_decode_steps = max_decode_steps 108 | model.temperature = temperature 109 | model.repetition_penalty = repetition_penalty 110 | model.top_k = top_k 111 | model.top_p = top_p 112 | start_message = st.empty() 113 | if model.group_size > 0 and model.beam_size % model.group_size != 0: 114 | start_message.write("beam_size must be a multiple of group_size!") 115 | else: 116 | start_message.write("generating...") 117 | start_time = time.time() 118 | 119 |
prefix_list = None 120 | if prefix is not None and len(prefix) > 0: 121 | prefix_list = [prefix] 122 | res = model.predict([image], prefix_list) 123 | image.close() 124 | 125 | end_time = time.time() 126 | start_message.write("done, cost {}s".format(end_time - start_time)) 127 | for i, (text, score) in enumerate(res[0][1]): 128 | st.text_area("the {} result".format(i + 1), text) 129 | 130 | 131 | def image_text_match_app(model): 132 | """ 133 | """ 134 | st.markdown( 135 | """ 136 | ## CLIP DEMO 137 | """ 138 | ) 139 | 140 | with st.form(key='my_form'): 141 | uploaded_file = st.file_uploader("Choose an image file", type="jpg", key="1") 142 | uploaded_file_2 = st.file_uploader("Choose an image file", type="jpg", key="2") 143 | texts = st.text_area("input text", max_chars=512) 144 | submit = st.form_submit_button("compute image text match score") 145 | submit_2 = st.form_submit_button("compute image image match score") 146 | #submit_3 = st.form_submit_button("compute text text match score") 147 | if submit: 148 | if uploaded_file is not None and len(texts) > 0: 149 | image = Image.open(io.BytesIO(uploaded_file.getvalue())) 150 | texts = [text.strip() for text in texts.split("\n")] 151 | texts = [text for text in texts if len(text) > 0] 152 | st.image(image, width=224) 153 | res = model.predict_sim([image], texts) 154 | res = [[text,score,prob] for text,score,prob in zip(texts, res[0][0], res[1][0])] 155 | res.sort(key=lambda x:x[1], reverse=True) 156 | for text,score,prob in res: 157 | info = "%s match score: %.2f, match prob: %.2f" % (text, score, prob) 158 | st.text(info) 159 | else: 160 | info = "image and text must not be empty!" 161 | st.text(info) 162 | 163 | if submit_2: 164 | if uploaded_file is not None and uploaded_file_2 is not None: 165 | image = Image.open(io.BytesIO(uploaded_file.getvalue())) 166 | image_2 = Image.open(io.BytesIO(uploaded_file_2.getvalue())) 167 | st.image(image, width=224) 168 | st.image(image_2, width=224) 169 | res = model.predict_images_sim([image, image_2]) 170 | info = "match score: %.2f" % (res[0][0][1]) 171 | st.text(info) 172 | else: 173 | info = "image and image_2 must not be empty!" 174 | st.text(info) 175 | 176 | ''' 177 | if submit_3: 178 | if len(texts) > 0: 179 | res = model.predict_texts_sim(texts) 180 | texts = [text.strip() for text in texts.split("\n")] 181 | texts = [text for text in texts if len(text) > 0] 182 | for i,text in enumerate(texts): 183 | for j,text_2 in enumerate(texts): 184 | if j > i: 185 | info = "%s %s match score: %.2f" % (text, text_2, res[0][i][j]) 186 | st.text(info) 187 | else: 188 | info = "texts must not be empty!"
189 | st.text(info) 190 | ''' 191 | 192 | 193 | def mae_app(model): 194 | """ 195 | """ 196 | uploaded_file = st.file_uploader("Choose an image file", type="jpg") 197 | 198 | if uploaded_file is not None: 199 | image = Image.open(io.BytesIO(uploaded_file.getvalue())) 200 | st.image(image, width=224) 201 | origin, reconstruct, im_masked, im_paste = model.visualize(image) 202 | 203 | origin = Image.fromarray(origin.astype(np.uint8)) 204 | reconstruct = Image.fromarray(reconstruct.astype(np.uint8)) 205 | im_masked = Image.fromarray(im_masked.astype(np.uint8)) 206 | im_paste = Image.fromarray(im_paste.astype(np.uint8)) 207 | 208 | st.text("masked") 209 | st.image(im_masked, width=224) 210 | st.text("reconstruction") 211 | st.image(reconstruct, width=224) 212 | st.text("reconstruction + visible") 213 | st.image(im_paste, width=224) 214 | 215 | def run_app(): 216 | """ 217 | """ 218 | parser = ArgumentParser() 219 | 220 | parser.add_argument("--model_conf", type=str) 221 | 222 | args = parser.parse_args(sys.argv[1:]) 223 | 224 | model_config = load_model_config(real_path(args.model_conf)) 225 | 226 | if model_config["task"] == "enc_dec": 227 | model = EncDecGenerator(model_config) 228 | text_gen_app(model) 229 | elif model_config["task"] == "lm": 230 | model = LMGenerator(model_config) 231 | text_gen_app(model) 232 | elif model_config["task"] == "image_classification": 233 | model = ImageEncoder(model_config) 234 | image_classification_app(model) 235 | elif model_config["task"] == "image2text": 236 | model = EncDecGenerator(model_config) 237 | image2text_app(model) 238 | elif model_config["task"] == "image_text_match": 239 | model = ClipMatcher(model_config) 240 | image_text_match_app(model) 241 | elif model_config["task"] == "masked_auto_encoder": 242 | model = MAE(model_config) 243 | mae_app(model) 244 | 245 | if __name__ == "__main__": 246 | run_app() -------------------------------------------------------------------------------- /mimix/bert_tokenizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
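# Usage sketch (the vocab path here is hypothetical; the sub-word split assumes
# the pieces exist in the vocabulary, mirroring the WordpieceTokenizer docstring
# example below):
#
#     tokenizer = FullTokenizer(vocab_file="model/vocab.txt", do_lower_case=True)
#     tokens = tokenizer.tokenize("unaffable")       # e.g. ["un", "##aff", "##able"]
#     ids = tokenizer.convert_tokens_to_ids(tokens)  # map wordpieces to vocab ids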
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import re 23 | import unicodedata 24 | import six 25 | 26 | def convert_to_unicode(text): 27 | """Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" 28 | if six.PY3: 29 | if isinstance(text, str): 30 | return text 31 | elif isinstance(text, bytes): 32 | return text.decode("utf-8", "ignore") 33 | else: 34 | raise ValueError("Unsupported string type: %s" % (type(text))) 35 | elif six.PY2: 36 | if isinstance(text, str): 37 | return text.decode("utf-8", "ignore") 38 | elif isinstance(text, unicode): 39 | return text 40 | else: 41 | raise ValueError("Unsupported string type: %s" % (type(text))) 42 | else: 43 | raise ValueError("Not running on Python2 or Python 3?") 44 | 45 | 46 | def printable_text(text): 47 | """Returns text encoded in a way suitable for print or `tf.logging`.""" 48 | 49 | # These functions want `str` for both Python2 and Python3, but in one case 50 | # it's a Unicode string and in the other it's a byte string. 51 | if six.PY3: 52 | if isinstance(text, str): 53 | return text 54 | elif isinstance(text, bytes): 55 | return text.decode("utf-8", "ignore") 56 | else: 57 | raise ValueError("Unsupported string type: %s" % (type(text))) 58 | elif six.PY2: 59 | if isinstance(text, str): 60 | return text 61 | elif isinstance(text, unicode): 62 | return text.encode("utf-8") 63 | else: 64 | raise ValueError("Unsupported string type: %s" % (type(text))) 65 | else: 66 | raise ValueError("Not running on Python2 or Python 3?") 67 | 68 | 69 | def load_vocab(vocab_file): 70 | """Loads a vocabulary file into a dictionary.""" 71 | vocab = collections.OrderedDict() 72 | index = 0 73 | #with tf.gfile.GFile(vocab_file, "r") as reader: 74 | with open(vocab_file, "r", encoding="utf-8") as reader: 75 | while True: 76 | token = convert_to_unicode(reader.readline()) 77 | if not token: 78 | break 79 | token = token.strip() 80 | vocab[token] = index 81 | index += 1 82 | return vocab 83 | 84 | 85 | def convert_by_vocab(vocab, items): 86 | """Converts a sequence of [tokens|ids] using the vocab.""" 87 | output = [] 88 | for item in items: 89 | output.append(vocab[item]) 90 | return output 91 | 92 | 93 | def convert_tokens_to_ids(vocab, tokens): 94 | return convert_by_vocab(vocab, tokens) 95 | 96 | 97 | def convert_ids_to_tokens(inv_vocab, ids): 98 | return convert_by_vocab(inv_vocab, ids) 99 | 100 | 101 | def whitespace_tokenize(text): 102 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 103 | text = text.strip() 104 | if not text: 105 | return [] 106 | tokens = text.split() 107 | return tokens 108 | 109 | 110 | class FullTokenizer(object): 111 | """Runs end-to-end tokenziation.""" 112 | 113 | def __init__(self, vocab_file, do_lower_case=True): 114 | self.vocab = load_vocab(vocab_file) 115 | self.inv_vocab = {v: k for k, v in self.vocab.items()} 116 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) 117 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) 118 | 119 | def tokenize(self, text): 120 | split_tokens = [] 121 | for token in self.basic_tokenizer.tokenize(text): 122 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 123 | split_tokens.append(sub_token) 124 | 125 | return split_tokens 126 | 127 | def convert_tokens_to_ids(self, tokens): 128 | return convert_by_vocab(self.vocab, tokens) 129 | 130 | def 
convert_ids_to_tokens(self, ids): 131 | return convert_by_vocab(self.inv_vocab, ids) 132 | 133 | 134 | class BasicTokenizer(object): 135 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 136 | 137 | def __init__(self, do_lower_case=True): 138 | """Constructs a BasicTokenizer. 139 | Args: 140 | do_lower_case: Whether to lower case the input. 141 | """ 142 | self.do_lower_case = do_lower_case 143 | 144 | def tokenize(self, text): 145 | """Tokenizes a piece of text.""" 146 | text = convert_to_unicode(text) 147 | text = self._clean_text(text) 148 | 149 | # This was added on November 1st, 2018 for the multilingual and Chinese 150 | # models. This is also applied to the English models now, but it doesn't 151 | # matter since the English models were not trained on any Chinese data 152 | # and generally don't have any Chinese data in them (there are Chinese 153 | # characters in the vocabulary because Wikipedia does have some Chinese 154 | # words in the English Wikipedia.). 155 | text = self._tokenize_chinese_chars(text) 156 | 157 | orig_tokens = whitespace_tokenize(text) 158 | split_tokens = [] 159 | for token in orig_tokens: 160 | if self.do_lower_case: 161 | token = token.lower() 162 | token = self._run_strip_accents(token) 163 | split_tokens.extend(self._run_split_on_punc(token)) 164 | 165 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 166 | return output_tokens 167 | 168 | def _run_strip_accents(self, text): 169 | """Strips accents from a piece of text.""" 170 | text = unicodedata.normalize("NFD", text) 171 | output = [] 172 | for char in text: 173 | cat = unicodedata.category(char) 174 | if cat == "Mn": 175 | continue 176 | output.append(char) 177 | return "".join(output) 178 | 179 | def _run_split_on_punc(self, text): 180 | """Splits punctuation on a piece of text.""" 181 | chars = list(text) 182 | i = 0 183 | start_new_word = True 184 | output = [] 185 | while i < len(chars): 186 | char = chars[i] 187 | if _is_punctuation(char): 188 | output.append([char]) 189 | start_new_word = True 190 | else: 191 | if start_new_word: 192 | output.append([]) 193 | start_new_word = False 194 | output[-1].append(char) 195 | i += 1 196 | 197 | return ["".join(x) for x in output] 198 | 199 | def _tokenize_chinese_chars(self, text): 200 | """Adds whitespace around any CJK character.""" 201 | output = [] 202 | for char in text: 203 | cp = ord(char) 204 | if self._is_chinese_char(cp): 205 | output.append(" ") 206 | output.append(char) 207 | output.append(" ") 208 | else: 209 | output.append(char) 210 | return "".join(output) 211 | 212 | def _is_chinese_char(self, cp): 213 | """Checks whether CP is the codepoint of a CJK character.""" 214 | # This defines a "chinese character" as anything in the CJK Unicode block: 215 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 216 | # 217 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 218 | # despite its name. The modern Korean Hangul alphabet is a different block, 219 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 220 | # space-separated words, so they are not treated specially and handled 221 | # like the all of the other languages. 
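# (For reference, the hex ranges below are, in order: CJK Unified Ideographs,
# Extension A, Extension B, Extension C, Extension D, Extension E,
# CJK Compatibility Ideographs, and the CJK Compatibility Ideographs Supplement.)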
222 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 223 | (cp >= 0x3400 and cp <= 0x4DBF) or # 224 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 225 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 226 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 227 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 228 | (cp >= 0xF900 and cp <= 0xFAFF) or # 229 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 230 | return True 231 | 232 | return False 233 | 234 | def _clean_text(self, text): 235 | """Performs invalid character removal and whitespace cleanup on text.""" 236 | output = [] 237 | for char in text: 238 | cp = ord(char) 239 | if cp == 0 or cp == 0xfffd or _is_control(char): 240 | continue 241 | if _is_whitespace(char): 242 | output.append(" ") 243 | else: 244 | output.append(char) 245 | return "".join(output) 246 | 247 | 248 | class WordpieceTokenizer(object): 249 | """Runs WordPiece tokenziation.""" 250 | 251 | def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200): 252 | self.vocab = vocab 253 | self.unk_token = unk_token 254 | self.max_input_chars_per_word = max_input_chars_per_word 255 | 256 | def tokenize(self, text): 257 | """Tokenizes a piece of text into its word pieces. 258 | This uses a greedy longest-match-first algorithm to perform tokenization 259 | using the given vocabulary. 260 | For example: 261 | input = "unaffable" 262 | output = ["un", "##aff", "##able"] 263 | Args: 264 | text: A single token or whitespace separated tokens. This should have 265 | already been passed through `BasicTokenizer. 266 | Returns: 267 | A list of wordpiece tokens. 268 | """ 269 | 270 | text = convert_to_unicode(text) 271 | 272 | output_tokens = [] 273 | for token in whitespace_tokenize(text): 274 | chars = list(token) 275 | if len(chars) > self.max_input_chars_per_word: 276 | output_tokens.append(self.unk_token) 277 | continue 278 | 279 | is_bad = False 280 | start = 0 281 | sub_tokens = [] 282 | while start < len(chars): 283 | end = len(chars) 284 | cur_substr = None 285 | while start < end: 286 | substr = "".join(chars[start:end]) 287 | if start > 0: 288 | substr = "##" + substr 289 | if substr in self.vocab: 290 | cur_substr = substr 291 | break 292 | end -= 1 293 | if cur_substr is None: 294 | is_bad = True 295 | break 296 | sub_tokens.append(cur_substr) 297 | start = end 298 | 299 | if is_bad: 300 | output_tokens.append(self.unk_token) 301 | else: 302 | output_tokens.extend(sub_tokens) 303 | return output_tokens 304 | 305 | 306 | def _is_whitespace(char): 307 | """Checks whether `chars` is a whitespace character.""" 308 | # \t, \n, and \r are technically contorl characters but we treat them 309 | # as whitespace since they are generally considered as such. 310 | if char == " " or char == "\t" or char == "\n" or char == "\r": 311 | return True 312 | cat = unicodedata.category(char) 313 | if cat == "Zs": 314 | return True 315 | return False 316 | 317 | 318 | def _is_control(char): 319 | """Checks whether `chars` is a control character.""" 320 | # These are technically control characters but we count them as whitespace 321 | # characters. 322 | if char == "\t" or char == "\n" or char == "\r": 323 | return False 324 | cat = unicodedata.category(char) 325 | if cat in ("Cc", "Cf"): 326 | return True 327 | return False 328 | 329 | 330 | def _is_punctuation(char): 331 | """Checks whether `chars` is a punctuation character.""" 332 | cp = ord(char) 333 | # We treat all non-letter/number ASCII as punctuation. 
334 | # Characters such as "^", "$", and "`" are not in the Unicode 335 | # Punctuation class but we treat them as punctuation anyways, for 336 | # consistency. 337 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 338 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 339 | return True 340 | cat = unicodedata.category(char) 341 | if cat.startswith("P"): 342 | return True 343 | return False 344 | 345 | -------------------------------------------------------------------------------- /mimix/clustering.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Aug 25 19:52:32 2022 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | from argparse import ArgumentParser 8 | import sys 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | from sklearn import manifold,datasets 12 | from pylab import * 13 | mpl.rcParams['font.sans-serif'] = ['SimHei'] 14 | mpl.rcParams['axes.unicode_minus'] = False 15 | import json 16 | from collections import Counter 17 | from annoy import AnnoyIndex 18 | from mimix.predictor import TextEncoder 19 | from mimix.utils import real_path, load_model_config 20 | 21 | 22 | def ann_clustering(fi_path, fo_path, dim, search_n, threshold, min_size): 23 | """ 24 | """ 25 | data = [] 26 | t = AnnoyIndex(dim, 'angular') 27 | for i,line in enumerate(open(fi_path, "r", encoding="utf-8")): 28 | if i % 10000 == 0: 29 | print("load data", i) 30 | d = json.loads(line) 31 | t.add_item(i, d["vec"]) 32 | #del d["vec"] 33 | data.append(d) 34 | print("load over") 35 | t.build(100) 36 | 37 | id2clu = {} 38 | clusters = [] 39 | for i,d in enumerate(data): 40 | if i % 10000 == 0: 41 | print("processed %d" % i) 42 | if i in id2clu: 43 | continue 44 | indexs,scores = t.get_nns_by_item(i, search_n, include_distances=True) 45 | count = Counter() 46 | near_but_no_label = [] 47 | for idx,score in zip(indexs, scores): 48 | if i == idx: 49 | continue 50 | if 1 - score**2/2 > threshold: 51 | if idx in id2clu: 52 | count[id2clu[idx]] += 1 53 | else: 54 | count[-1] += 1 55 | near_but_no_label.append(idx) 56 | if len(count) > 0: 57 | max_id = count.most_common(1)[0][0] 58 | if max_id > -1: 59 | id2clu[i] = max_id 60 | clusters[max_id].append(d) 61 | else: 62 | id2clu[i] = len(clusters) 63 | for idx in near_but_no_label: 64 | id2clu[idx] = len(clusters) 65 | clu = [d] 66 | for idx in near_but_no_label: 67 | clu.append(data[idx]) 68 | clusters.append(clu) 69 | else: 70 | id2clu[i] = len(clusters) 71 | clu = [d] 72 | clusters.append(clu) 73 | 74 | clusters.sort(key=lambda x:len(x), reverse=True) 75 | fo = open(fo_path, "w", encoding="utf-8") 76 | for i,clu in enumerate(clusters): 77 | if len(clu) < min_size: 78 | break 79 | if i < 20: 80 | print("clu %d size: %d\n--------\n" % (i, len(clu))) 81 | for d in clu[:5]: 82 | print(" "*8, d["text"]) 83 | fo.write(json.dumps(clu, ensure_ascii=False) + "\n") 84 | fo.close() 85 | 86 | 87 | def text_clustering(model_config, fi_path, fo_path, search_n, threshold, min_size): 88 | """ 89 | """ 90 | model = TextEncoder(model_config) 91 | print("encode text to vector...") 92 | model.dump_encode_text(fi_path, fo_path + ".vec") 93 | 94 | print("text clustering...") 95 | ann_clustering(fo_path + ".vec", 96 | fo_path + ".clu", 97 | model_config["d_model"], 98 | search_n, 99 | threshold, 100 | min_size 101 | ) 102 | 103 | print("vis clustering...") 104 | vis_clusters(fo_path + ".clu", fo_path + ".png") 105 | 106 | 107 | def vis_clusters(fi_path, fig_path): 108 | """ 109 | """ 110 
| X = [] 111 | y = [] 112 | labels = [] 113 | n = 0 114 | for i,line in enumerate(open(fi_path, "r", encoding="utf-8")): 115 | if i >= 10: 116 | break 117 | data = json.loads(line) 118 | for d in data: 119 | X.append(d["vec"]) 120 | y.append(i) 121 | labels.append(d["text"]) 122 | n += 1 123 | 124 | tsne = manifold.TSNE(n_components=2, init='pca', random_state=501) 125 | X_tsne = tsne.fit_transform(X) 126 | 127 | x_min, x_max = X_tsne.min(0), X_tsne.max(0) 128 | X_norm = (X_tsne - x_min) / (x_max - x_min) 129 | plt.figure(figsize=(8, 8)) 130 | 131 | X_trans = [[[], [], []] for i in range(n)] 132 | for i,j in enumerate(y): 133 | X_trans[j][0].append(X_norm[i, 0]) 134 | X_trans[j][1].append(X_norm[i, 1]) 135 | X_trans[j][2].append(labels[i]) 136 | 137 | for i, (x, y, label) in enumerate(X_trans): 138 | plt.scatter(x, y, s=50, color=plt.cm.Set3(i), marker="x", label=label[0]) 139 | plt.legend(loc='best') 140 | 141 | plt.savefig(fig_path, dpi=300, bbox_inches = 'tight') 142 | #plt.show() 143 | plt.close() 144 | 145 | 146 | def run_clustering(): 147 | """ 148 | """ 149 | parser = ArgumentParser() 150 | 151 | parser.add_argument("--model_conf", type=str) 152 | parser.add_argument("--fi", type=str) 153 | parser.add_argument("--fo", type=str) 154 | parser.add_argument("--search_n", type=int) 155 | parser.add_argument("--threshold", type=float) 156 | parser.add_argument("--min_size", type=int) 157 | 158 | args = parser.parse_args(sys.argv[1:]) 159 | 160 | model_config = load_model_config(real_path(args.model_conf)) 161 | 162 | text_clustering(model_config, 163 | real_path(args.fi), 164 | real_path(args.fo), 165 | args.search_n, 166 | args.threshold, 167 | args.min_size) 168 | 169 | 170 | if __name__ == "__main__": 171 | run_clustering() 172 | -------------------------------------------------------------------------------- /mimix/ddp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Aug 30 15:16:38 2019 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | 8 | import os 9 | from datetime import datetime 10 | import logging 11 | import torch 12 | from mimix.utils import real_path 13 | 14 | LOG_DIR = "logger" 15 | 16 | local_rank = int(os.environ['LOCAL_RANK']) 17 | world_size = int(os.environ['WORLD_SIZE']) 18 | rank = int(os.environ['RANK']) 19 | 20 | if local_rank == 0: 21 | if not os.path.exists(real_path(LOG_DIR)): 22 | os.mkdir(real_path(LOG_DIR)) 23 | 24 | 25 | def build_logger(): 26 | """ 27 | """ 28 | format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 29 | filename = datetime.today().strftime('logger/%Y-%m-%d-%H-%M-%S.log') 30 | logging.basicConfig(filename=real_path(filename), 31 | level=logging.INFO, 32 | format=format_str) 33 | console = logging.StreamHandler() 34 | console.setLevel(logging.INFO) 35 | 36 | formatter = logging.Formatter(format_str, "%Y-%m-%d %H:%M:%S") 37 | console.setFormatter(formatter) 38 | logging.getLogger('').addHandler(console) 39 | logger = logging.getLogger(__name__) 40 | 41 | return logger 42 | 43 | if local_rank == 0: 44 | logger = build_logger() 45 | 46 | def save_model(model, optimizer, model_path): 47 | """ 48 | """ 49 | if local_rank == 0: 50 | logger.info("Save model to %s" % model_path) 51 | torch.save(model.state_dict(), 52 | model_path, 53 | _use_new_zipfile_serialization=False) 54 | 55 | logger.info("Save model complete") 56 | 57 | 58 | def print_model_info(model): 59 | """ 60 | """ 61 | if local_rank == 0: 62 | logger.info("%s" % model) 63 | total_params =
sum(p.numel() for p in model.parameters()) 64 | if local_rank == 0: 65 | logger.info("Total Model Params:%s" % total_params) 66 | total_train_params = sum(p.numel() for p in model.parameters() if p.requires_grad==True) 67 | if local_rank == 0: 68 | logger.info("Trainable Model Params:%s" % total_train_params) 69 | 70 | 71 | def train(model, 72 | optimizer, 73 | train_config, 74 | train_generator, 75 | val_generator=None, 76 | test_generator=None, 77 | eval_fn_list=None, 78 | lr_scheduler=None): 79 | """ 80 | """ 81 | if os.path.exists(real_path(train_config["model_dir"])) == False: 82 | os.mkdir(real_path(train_config["model_dir"])) 83 | 84 | use_amp = train_config.get("use_amp", False) 85 | if use_amp: 86 | scaler = torch.cuda.amp.GradScaler() 87 | 88 | print_model_info(model) 89 | 90 | if local_rank == 0: 91 | logger.info("Train Start!") 92 | 93 | accumulate_steps = train_config.get("accumulate_steps", 1) 94 | print_every_n_steps = train_config.get("print_every_n_steps", 100) 95 | model_path = real_path(os.path.join(real_path(train_config["model_dir"]), "%s." + train_config["model_name"])) 96 | save_steps = train_config.get("save_steps", 100000) 97 | tmp_save_steps = train_config.get("tmp_save_steps", 10000) 98 | grad_clip = train_config.get("grad_clip", None) 99 | 100 | history_loss = [] 101 | 102 | epoch,steps,total_steps = 0, 0, 0 103 | while epoch < train_config["max_epoch"]: 104 | model.train() 105 | 106 | for inputs,targets in train_generator(): 107 | if use_amp == True: 108 | with torch.cuda.amp.autocast(): 109 | outputs = model(inputs, targets=targets, compute_loss=True) 110 | loss = outputs["loss"] 111 | history_loss = history_loss[-999:] + [loss.item()] 112 | loss = loss / accumulate_steps 113 | else: 114 | outputs = model(inputs, targets=targets, compute_loss=True) 115 | loss = outputs["loss"] 116 | history_loss = history_loss[-999:] + [loss.item()] 117 | loss = loss / accumulate_steps 118 | 119 | if total_steps % print_every_n_steps == 0: 120 | ma_loss = sum(history_loss) / len(history_loss) 121 | if local_rank == 0: 122 | logger.info( 123 | "%d epoch %d step total %d steps loss: %.3f" % 124 | (epoch, 125 | steps, 126 | total_steps, 127 | ma_loss) 128 | ) 129 | 130 | if use_amp == True: 131 | scaler.scale(loss).backward() 132 | else: 133 | loss.backward() 134 | 135 | if lr_scheduler is not None: 136 | lr_scheduler.step() 137 | 138 | total_steps += 1 139 | steps += 1 140 | 141 | if total_steps % save_steps == 0: 142 | save_model(model, optimizer, model_path % ("%d.%d.%d" % (epoch, steps, total_steps))) 143 | 144 | if total_steps % tmp_save_steps == 0: 145 | save_model(model, optimizer, model_path % "tmp") 146 | 147 | if total_steps % accumulate_steps == 0: 148 | if use_amp == True: 149 | if grad_clip is not None: 150 | scaler.unscale_(optimizer)  # unscale grads so clipping sees their true values 151 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 152 | scaler.step(optimizer) 153 | optimizer.zero_grad() 154 | scaler.update() 155 | else: 156 | if grad_clip is not None: 157 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 158 | optimizer.step() 159 | optimizer.zero_grad() 160 | 161 | 162 | epoch += 1 163 | steps = 0 164 | 165 | if len(eval_fn_list) > 0: 166 | if val_generator is not None: 167 | if local_rank == 0: 168 | logger.info("Eval val now...") 169 | for eval_fn in eval_fn_list: 170 | eval_res = eval_fn(model, val_generator) 171 | if local_rank == 0: 172 | logger.info("Result: %s" % eval_res) 173 | if test_generator is not None: 174 | if local_rank == 0: 175 | logger.info("Eval test now...") 176 | for eval_fn in eval_fn_list: 177 | eval_res = eval_fn(model, test_generator) 178 | if local_rank == 0: 179 |
logger.info("Result: %s" % eval_res) 180 | save_model(model, optimizer, model_path % ("%d.%d.%d" % (epoch, steps, total_steps))) 181 | if local_rank == 0: 182 | logger.info("Train Completed!") 183 | -------------------------------------------------------------------------------- /mimix/decoding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Aug 23 10:59:36 2019 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import torch 8 | 9 | def init_search(model, batch_size, device): 10 | """ 11 | """ 12 | vocab_size = model.trg_vocab_size 13 | y = torch.zeros(batch_size, 1, dtype=torch.long) + model.BOS 14 | 15 | log_probs = torch.zeros(batch_size, 1, dtype=torch.float) 16 | finished = torch.zeros(batch_size, 1, dtype=torch.uint8) 17 | hypothesis = torch.zeros(batch_size, 1, dtype=torch.long) + model.BOS 18 | history_probs = torch.zeros(batch_size, 0, dtype=torch.float) 19 | 20 | mask_finished = torch.tensor([model.MIN_LOGITS] * vocab_size, 21 | dtype=torch.float) 22 | mask_finished[model.PAD] = model.MAX_LOGITS 23 | 24 | states = [y.to(device), 25 | log_probs.to(device), 26 | finished.to(device), 27 | mask_finished.to(device), 28 | hypothesis.to(device), 29 | history_probs.to(device)] 30 | return states 31 | 32 | 33 | def process_logits(model, logits, states, repetition_penalty, repetition_window_size): 34 | """ 35 | """ 36 | mask_unk = torch.zeros_like(logits) 37 | mask_unk[:,model.UNK] = model.MIN_LOGITS 38 | logits = logits + mask_unk 39 | 40 | if repetition_penalty < 0: 41 | y, log_probs, finished, mask_finished, hypothesis, history_probs = states 42 | mask = torch.zeros(hypothesis.shape[0]*hypothesis.shape[1], 43 | model.trg_vocab_size, 44 | device = hypothesis.device) 45 | mask.scatter_(1, hypothesis.view(-1, 1), 1) 46 | mask = mask.view(hypothesis.shape[0], hypothesis.shape[1], model.trg_vocab_size) 47 | 48 | mask = mask[:,-repetition_window_size:,:] 49 | mask = torch.sum(mask, 1) 50 | 51 | logits = logits + mask * repetition_penalty 52 | 53 | return logits 54 | 55 | 56 | def top_k_top_p_sampling(logits, 57 | top_k=-1, 58 | top_p=-1, 59 | temperature=1, 60 | n_samples=1, 61 | replacement=True): 62 | """ 63 | """ 64 | logits /= temperature 65 | probs = torch.softmax(logits, -1) 66 | 67 | if top_k > 0 or top_p > 0: 68 | _logits, _indices = torch.sort(logits, descending=True) 69 | 70 | if top_k > 0: 71 | probs[logits < _logits[:, top_k, None]] = 0 72 | 73 | if top_p > 0: 74 | cumulative_logits = torch.cumsum(torch.softmax(_logits, -1), dim=-1) 75 | need_filter = (cumulative_logits > top_p) 76 | 77 | need_filter[:, 1:] = need_filter[:, :-1].clone() 78 | need_filter[:, 0] = 0 79 | 80 | filter_indice = need_filter.scatter(1, _indices, need_filter) 81 | probs[filter_indice] = 0 82 | 83 | probs /= torch.sum(probs, dim=-1, keepdim=True) 84 | 85 | samples = torch.multinomial(probs, n_samples, replacement=replacement) 86 | probs = torch.gather(probs, 1, samples) 87 | 88 | return samples, probs 89 | 90 | 91 | def search(model, 92 | beam_size, 93 | inputs=None, 94 | device="cpu", 95 | strategy="beam_search", 96 | top_k=-1, 97 | top_p=-1, 98 | temperature=1, 99 | eos=None, 100 | group_size=-1, 101 | repetition_penalty=0, 102 | repetition_window_size=0, 103 | use_mask_unk=False, 104 | max_decode_steps=None): 105 | """ 106 | """ 107 | batch_size = 1 108 | if inputs is not None: 109 | if inputs.get("x", None) is not None: 110 | batch_size = inputs["x"].size(0) 111 | elif inputs.get("y", None) is not None: 112 | batch_size = 
inputs["y"].size(0) 113 | 114 | states = init_search(model, batch_size, device) 115 | 116 | states, cache = model.init_search(states, inputs) 117 | 118 | steps = 0 119 | last_beam_size = 1 120 | cur_batch_size = batch_size 121 | cur_beam_size = beam_size 122 | if group_size > 0: 123 | cur_beam_size = group_size 124 | 125 | while True: 126 | y, log_probs, finished, mask_finished, hypothesis, history_probs = states 127 | 128 | 129 | logits, cache = model.step(states, cache) 130 | 131 | vocab_size = logits.size(-1) 132 | logits = logits.view(-1, vocab_size) 133 | 134 | #logits: (B x last_beam_size) x V 135 | #probs: (B x last_beam_size) x V 136 | logits = process_logits(model, logits, states, repetition_penalty, repetition_window_size) 137 | probs = torch.softmax(logits, -1) 138 | 139 | if strategy == "beam_search": 140 | 141 | #log_probs: B x last_beam_size 142 | #finished: (B x last_beam_size) x 1 143 | #mask_finished: vocab_size 144 | #cur_log_probs: (B x last_beam_size) x V 145 | masked_logits = logits * (1 - finished.float()) + mask_finished * finished.float() 146 | cur_log_probs = log_probs.view(-1, 1) + torch.log_softmax(masked_logits, -1) 147 | 148 | #topk_log_probs: B x cur_beam_size 149 | #topk_ids: B x cur_beam_size 150 | #y: (B x cur_beam_size) x 1 151 | #probs: (B x cur_beam_size) x 1 152 | topk_log_probs, topk_ids = cur_log_probs.view(cur_batch_size, (last_beam_size * vocab_size)).topk(cur_beam_size) 153 | y = (topk_ids % vocab_size).view(-1, 1) 154 | probs = torch.gather(probs.view(cur_batch_size, (last_beam_size * vocab_size)), 1, topk_ids).view(-1, 1) 155 | 156 | #base_id: B 157 | #beam_id: (B x cur_beam_size) 158 | base_id = torch.arange(0, cur_batch_size, device = y.device) * last_beam_size 159 | beam_id = (base_id.view(-1, 1) + topk_ids // vocab_size).view(-1) 160 | 161 | cur_log_probs = topk_log_probs.view(-1) 162 | 163 | cache = model.gather_cache(cache, beam_id) 164 | else: 165 | replacement = not (group_size > 0 and steps == 0) 166 | logits = logits * (1 - finished.float()) + mask_finished * finished.float() 167 | y,probs = top_k_top_p_sampling(logits, 168 | top_k, 169 | top_p, 170 | temperature, 171 | n_samples=cur_beam_size, 172 | replacement=replacement) 173 | 174 | base_id = torch.arange(0, cur_batch_size, device = y.device) * last_beam_size 175 | beam_id = (base_id.view(-1, 1) + y // vocab_size).view(-1) 176 | 177 | y = y.view(-1, 1) 178 | probs = probs.view(-1, 1) 179 | 180 | cur_log_probs = log_probs[beam_id] + torch.log(probs) 181 | 182 | cache = model.gather_cache(cache, beam_id) 183 | 184 | if strategy == "beam_search" or last_beam_size != cur_beam_size: 185 | finished = finished[beam_id,:] 186 | hypothesis = hypothesis[beam_id,:] 187 | history_probs = history_probs[beam_id, :] 188 | 189 | finished = (finished | y.eq(eos).byte()) 190 | hypothesis = torch.cat([hypothesis, y], 1) 191 | history_probs = torch.cat([history_probs, probs], 1) 192 | 193 | if strategy == "beam_search": 194 | if group_size > 0: 195 | if steps == 0: 196 | cur_batch_size = batch_size * group_size 197 | cur_beam_size = beam_size // group_size 198 | else: 199 | last_beam_size = cur_beam_size 200 | else: 201 | last_beam_size = cur_beam_size 202 | cur_batch_size = batch_size 203 | cur_beam_size = beam_size 204 | elif strategy == "sample": 205 | if group_size > 0: 206 | if steps == 0: 207 | cur_batch_size = batch_size * group_size 208 | cur_beam_size = beam_size // group_size 209 | else: 210 | cur_batch_size = batch_size * beam_size 211 | cur_beam_size = 1 212 | else: 213 | cur_batch_size = 
batch_size * beam_size 214 | cur_beam_size = 1 215 | 216 | states = [y, cur_log_probs, finished, mask_finished, hypothesis, history_probs] 217 | steps += 1 218 | yield states, cache 219 | 220 | if finished.all() or (max_decode_steps is not None and steps >= max_decode_steps): 221 | break 222 | 223 | 224 | def crf_model_decoding(model, x): 225 | """ 226 | """ 227 | emission = model.get_emission(x) 228 | 229 | mask = x.ne(model.PAD).float() 230 | 231 | crf, emission, mask, pad_tag = model.crf, emission, mask, model.PAD 232 | 233 | batch_size, seq_len, n_labels = emission.size() 234 | 235 | scores = crf.start_trans + emission[:, 0, :] 236 | path_table = torch.zeros(batch_size, seq_len-1, n_labels, dtype=torch.long, device=x.device) 237 | 238 | for i in range(1, seq_len): 239 | all_scores = scores.unsqueeze(2) + emission[:, i, :].unsqueeze(1) + crf.trans.unsqueeze(0) 240 | 241 | next_scores,indices = torch.max(all_scores, 1) 242 | 243 | next_scores = mask[:,i:i+1] * next_scores + (1 - mask[:,i:i+1]) * scores 244 | 245 | path_table[:, i-1, :] = indices 246 | 247 | scores = next_scores 248 | 249 | best_scores,end_tag = torch.max(scores, 1) 250 | end_tag = end_tag.unsqueeze(-1) 251 | 252 | indice = end_tag 253 | best_path = indice 254 | for i in range(seq_len-2, -1, -1): 255 | indice = torch.gather(path_table[:,i,:], -1, indice) 256 | best_path = torch.cat([indice, best_path], 1) 257 | 258 | if mask is not None: 259 | best_path[mask<1] = pad_tag 260 | 261 | return best_path 262 | -------------------------------------------------------------------------------- /mimix/ds.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Aug 30 15:16:38 2019 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | 8 | import os 9 | from datetime import datetime 10 | import logging 11 | import torch 12 | from mimix.utils import real_path 13 | 14 | LOG_DIR = "logger" 15 | 16 | local_rank = int(os.environ['LOCAL_RANK']) 17 | world_size = int(os.environ['WORLD_SIZE']) 18 | rank = int(os.environ['RANK']) 19 | 20 | if rank == 0 and local_rank == 0: 21 | if not os.path.exists(real_path(LOG_DIR)): 22 | os.mkdir(real_path(LOG_DIR)) 23 | 24 | 25 | def build_logger(): 26 | """ 27 | """ 28 | format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 29 | filename = datetime.today().strftime('logger/%Y-%m-%d-%H-%M-%S-rank' + str(rank) + '.log') 30 | logging.basicConfig(filename=real_path(filename), 31 | level=logging.INFO, 32 | format=format_str) 33 | console = logging.StreamHandler() 34 | console.setLevel(logging.INFO) 35 | 36 | formatter = logging.Formatter(format_str, "%Y-%m-%d %H:%M:%S") 37 | console.setFormatter(formatter) 38 | logging.getLogger('').addHandler(console) 39 | logger = logging.getLogger(__name__) 40 | 41 | return logger 42 | 43 | if local_rank == 0: 44 | logger = build_logger() 45 | 46 | def save_model(model, optimizer, model_path): 47 | """ 48 | """ 49 | if rank == 0 and local_rank == 0: 50 | logger.info("Save model to %s" % model_path) 51 | torch.save(model.state_dict(), 52 | model_path, 53 | _use_new_zipfile_serialization=False) 54 | 55 | logger.info("Save model complete") 56 | 57 | 58 | def print_model_info(model): 59 | """ 60 | """ 61 | if local_rank == 0: 62 | logger.info("%s" % model) 63 | total_params = sum(p.numel() for p in model.parameters()) 64 | if local_rank == 0: 65 | logger.info("Total Model Params:%s" % total_params) 66 | total_train_params = sum(p.numel() for p in model.parameters() if p.requires_grad==True) 
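# Note: numel() counts the elements of one tensor, so summing it over
# parameters() gives the total parameter count; filtering on requires_grad
# separates trainable weights from frozen ones.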
67 | if local_rank == 0: 68 | logger.info("Trainable Model Params:%s" % total_train_params) 69 | 70 | 71 | def train(model, 72 | optimizer, 73 | train_config, 74 | train_generator, 75 | lr_scheduler=None): 76 | """ 77 | """ 78 | if os.path.exists(real_path(train_config["model_dir"])) == False: 79 | os.mkdir(real_path(train_config["model_dir"])) 80 | 81 | print_model_info(model) 82 | 83 | if local_rank == 0: 84 | logger.info("Train Start!") 85 | 86 | print_every_n_steps = train_config.get("print_every_n_steps", 100) 87 | model_path = real_path(os.path.join(real_path(train_config["model_dir"]), "%s." + train_config["model_name"])) 88 | save_steps = train_config.get("save_steps", 100000) 89 | tmp_save_steps = train_config.get("tmp_save_steps", 10000) 90 | history_loss = [] 91 | epoch,steps,total_steps = 0, 0, 0 92 | while epoch < train_config["max_epoch"]: 93 | for inputs,targets in train_generator(): 94 | inputs = [inputs.to(model.device)] 95 | targets = [targets.to(model.device)] 96 | outputs = model(inputs, targets=targets, compute_loss=True) 97 | loss = outputs["loss"] 98 | history_loss = history_loss[-999:] + [loss.item()] 99 | 100 | if total_steps % print_every_n_steps == 0: 101 | ma_loss = sum(history_loss) / len(history_loss) 102 | if local_rank == 0: 103 | logger.info( 104 | "%d epoch %d step total %d steps loss: %.3f" % 105 | (epoch, 106 | steps, 107 | total_steps, 108 | ma_loss) 109 | ) 110 | 111 | model.backward(loss) 112 | model.step() 113 | total_steps += 1 114 | steps += 1 115 | 116 | if lr_scheduler is not None: 117 | lr_scheduler.step() 118 | 119 | if total_steps % save_steps == 0: 120 | save_model(model, optimizer, model_path % ("%d.%d.%d" % (epoch, steps, total_steps))) 121 | 122 | if total_steps % tmp_save_steps == 0: 123 | save_model(model, optimizer, model_path % "tmp") 124 | 125 | epoch += 1 126 | steps = 0 127 | 128 | save_model(model, optimizer, model_path % ("%d.%d.%d" % (epoch, steps, total_steps))) 129 | if local_rank == 0: 130 | logger.info("Train Completed!") 131 | -------------------------------------------------------------------------------- /mimix/evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Nov 18 15:17:36 2019 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from mimix.decoding import crf_model_decoding 11 | 12 | def eval_acc(model, generator): 13 | """ 14 | """ 15 | model.eval() 16 | with torch.no_grad(): 17 | shot_count = 0 18 | total_count = 0 19 | for inputs,targets in generator(): 20 | outputs = model(inputs) 21 | pred = outputs["cls_logits"] 22 | shot = torch.sum(pred.argmax(1) == targets["labels"].view(-1)) 23 | 24 | shot_count = shot_count + shot.item() 25 | total_count = total_count + targets["labels"].size(0) 26 | 27 | acc = shot_count / total_count 28 | return {"acc":acc} 29 | 30 | 31 | def eval_perplexity(model, generator): 32 | """ 33 | """ 34 | model.eval() 35 | with torch.no_grad(): 36 | sum_log_p = 0 37 | sum_len = 0 38 | for inputs,targets in generator(): 39 | outputs = model(inputs) 40 | logits = outputs["logits"] 41 | 42 | log_probs = torch.gather(F.log_softmax(logits, 2), 43 | 2, 44 | targets["y_target"].unsqueeze(-1)) 45 | 46 | mask = (inputs["y"] != model.PAD).float() 47 | seq_len = torch.sum(mask) 48 | log_probs = torch.sum(mask * log_probs.squeeze(-1)) 49 | sum_log_p = sum_log_p + log_probs.item() 50 | sum_len = sum_len + seq_len.item() 51 | 52 | perplexity = 
np.exp(-sum_log_p / sum_len) 53 | return {"ppl":perplexity} 54 | 55 | 56 | def eval_sequence_labeling_acc(model, generator): 57 | """ 58 | """ 59 | model.eval() 60 | with torch.no_grad(): 61 | shot_count = 0 62 | total_count = 0 63 | for inputs,targets in generator(): 64 | x = inputs[0] 65 | if model.crf is not None: 66 | pred = crf_model_decoding(model, x) 67 | else: 68 | pred = model.get_emission(x).argmax(-1) 69 | 70 | shot = torch.sum((pred == targets[0]) * (x != model.PAD)) 71 | 72 | shot_count = shot_count + shot.item() 73 | total_count = total_count + torch.sum(x != model.PAD).item() 74 | 75 | acc = shot_count / total_count 76 | return {"acc":acc} 77 | 78 | """ 79 | This is a modified version of tensor2tensor bleu and rouge 80 | See https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/bleu_hook.py 81 | See https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/rouge.py 82 | """ 83 | import re 84 | import six 85 | import unicodedata 86 | import collections 87 | import math 88 | import sys 89 | 90 | def _count_ngrams(segment, max_order): 91 | """Extracts all n-grams up to a given maximum order from an input segment. 92 | Args: 93 | segment: text segment from which n-grams will be extracted. 94 | max_order: maximum length in tokens of the n-grams returned by this 95 | methods. 96 | Returns: 97 | The Counter containing all n-grams up to max_order in segment 98 | with a count of how many times each n-gram occurred. 99 | """ 100 | ngram_counts = collections.Counter() 101 | for order in range(1, max_order + 1): 102 | for i in range(0, len(segment) - order + 1): 103 | ngram = tuple(segment[i:i + order]) 104 | ngram_counts[ngram] += 1 105 | return ngram_counts 106 | 107 | 108 | def compute_bleu(reference_corpus, 109 | translation_corpus, 110 | max_order=4, 111 | use_bp=True): 112 | """Computes BLEU score of translated segments against one or more references. 113 | Args: 114 | reference_corpus: list of references for each translation. Each 115 | reference should be tokenized into a list of tokens. 116 | translation_corpus: list of translations to score. Each translation 117 | should be tokenized into a list of tokens. 118 | max_order: Maximum n-gram order to use when computing BLEU score. 119 | use_bp: boolean, whether to apply brevity penalty. 120 | Returns: 121 | BLEU score. 
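    Example (illustrative; the token lists are hypothetical):
        ref_corpus = [["the", "cat", "sat"]]
        hyp_corpus = [["the", "cat", "sat"]]
        compute_bleu(ref_corpus, hyp_corpus)  # an exact match scores 1.0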
122 | """ 123 | reference_length = 0 124 | translation_length = 0 125 | bp = 1.0 126 | geo_mean = 0 127 | 128 | matches_by_order = [0] * max_order 129 | possible_matches_by_order = [0] * max_order 130 | precisions = [] 131 | 132 | for (references, translations) in zip(reference_corpus, translation_corpus): 133 | reference_length += len(references) 134 | translation_length += len(translations) 135 | ref_ngram_counts = _count_ngrams(references, max_order) 136 | translation_ngram_counts = _count_ngrams(translations, max_order) 137 | 138 | overlap = dict((ngram, 139 | min(count, translation_ngram_counts[ngram])) 140 | for ngram, count in ref_ngram_counts.items()) 141 | 142 | for ngram in overlap: 143 | matches_by_order[len(ngram) - 1] += overlap[ngram] 144 | for ngram in translation_ngram_counts: 145 | possible_matches_by_order[len(ngram)-1] += translation_ngram_counts[ngram] 146 | precisions = [0] * max_order 147 | smooth = 1.0 148 | for i in range(0, max_order): 149 | if possible_matches_by_order[i] > 0: 150 | precisions[i] = matches_by_order[i] / possible_matches_by_order[i] 151 | if matches_by_order[i] > 0: 152 | precisions[i] = matches_by_order[i] / possible_matches_by_order[i] 153 | else: 154 | smooth *= 2 155 | precisions[i] = 1.0 / (smooth * possible_matches_by_order[i]) 156 | else: 157 | precisions[i] = 0.0 158 | 159 | if max(precisions) > 0: 160 | p_log_sum = sum(math.log(p) for p in precisions if p) 161 | geo_mean = math.exp(p_log_sum/max_order) 162 | 163 | if use_bp: 164 | if not reference_length: 165 | bp = 1.0 166 | else: 167 | ratio = translation_length / reference_length 168 | if ratio <= 0.0: 169 | bp = 0.0 170 | elif ratio >= 1.0: 171 | bp = 1.0 172 | else: 173 | bp = math.exp(1 - 1. / ratio) 174 | bleu = geo_mean * bp 175 | return np.float32(bleu) 176 | 177 | 178 | class UnicodeRegex(object): 179 | """Ad-hoc hack to recognize all punctuation and symbols.""" 180 | 181 | def __init__(self): 182 | punctuation = self.property_chars("P") 183 | self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])") 184 | self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])") 185 | self.symbol_re = re.compile("([" + self.property_chars("S") + "])") 186 | 187 | def property_chars(self, prefix): 188 | return "".join(six.unichr(x) for x in range(sys.maxunicode) 189 | if unicodedata.category(six.unichr(x)).startswith(prefix)) 190 | 191 | uregex = UnicodeRegex() 192 | 193 | 194 | def bleu_tokenize(string): 195 | r"""Tokenize a string following the official BLEU implementation. 196 | See https://github.com/moses-smt/mosesdecoder/" 197 | "blob/master/scripts/generic/mteval-v14.pl#L954-L983 198 | In our case, the input string is expected to be just one line 199 | and no HTML entities de-escaping is needed. 200 | So we just tokenize on punctuation and symbols, 201 | except when a punctuation is preceded and followed by a digit 202 | (e.g. a comma/dot as a thousand/decimal separator). 203 | Note that a number (e.g. a year) followed by a dot at the end of sentence 204 | is NOT tokenized, 205 | i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g` 206 | does not match this case (unless we add a space after each sentence). 207 | However, this error is already in the original mteval-v14.pl 208 | and we want to be consistent with it. 
209 | Args: 210 | string: the input string 211 | Returns: 212 | a list of tokens 213 | """ 214 | string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string) 215 | string = uregex.punct_nondigit_re.sub(r" \1 \2", string) 216 | string = uregex.symbol_re.sub(r" \1 ", string) 217 | return string.split() 218 | 219 | 220 | def _len_lcs(x, y): 221 | """Returns the length of the Longest Common Subsequence between two seqs. 222 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 223 | Args: 224 | x: sequence of words 225 | y: sequence of words 226 | Returns: 227 | integer: Length of LCS between x and y 228 | """ 229 | table = _lcs(x, y) 230 | n, m = len(x), len(y) 231 | return table[n, m] 232 | 233 | 234 | def _lcs(x, y): 235 | """Computes the length of the LCS between two seqs. 236 | The implementation below uses a dynamic programming algorithm and runs 237 | in O(nm) time where n = len(x) and m = len(y). 238 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 239 | Args: 240 | x: collection of words 241 | y: collection of words 242 | Returns: 243 | Table of dictionary of coord and len lcs 244 | """ 245 | n, m = len(x), len(y) 246 | table = {} 247 | for i in range(n + 1): 248 | for j in range(m + 1): 249 | if i == 0 or j == 0: 250 | table[i, j] = 0 251 | elif x[i - 1] == y[j - 1]: 252 | table[i, j] = table[i - 1, j - 1] + 1 253 | else: 254 | table[i, j] = max(table[i - 1, j], table[i, j - 1]) 255 | return table 256 | 257 | 258 | def _recon_lcs(x, y): 259 | """ 260 | Returns the Longest Subsequence between x and y. 261 | Source: http://www.algorithmist.com/index.php/Longest_Common_Subsequence 262 | 263 | Args: 264 | x: sequence of words 265 | y: sequence of words 266 | 267 | Returns: 268 | sequence: LCS of x and y 269 | """ 270 | i, j = len(x), len(y) 271 | table = _lcs(x, y) 272 | 273 | def _recon(i, j): 274 | """private recon calculation""" 275 | if i == 0 or j == 0: 276 | return [] 277 | elif x[i - 1] == y[j - 1]: 278 | return _recon(i - 1, j - 1) + [(x[i - 1], i)] 279 | elif table[i - 1, j] > table[i, j - 1]: 280 | return _recon(i - 1, j) 281 | else: 282 | return _recon(i, j - 1) 283 | 284 | recon_list = list(map(lambda x: x[0], _recon(i, j))) 285 | 286 | return recon_list 287 | 288 | 289 | def _f_lcs(llcs, m, n): 290 | """Computes the LCS-based F-measure score. 291 | Source: https://www.microsoft.com/en-us/research/publication/ 292 | rouge-a-package-for-automatic-evaluation-of-summaries/ 293 | Args: 294 | llcs: Length of LCS 295 | m: number of words in reference summary 296 | n: number of words in candidate summary 297 | Returns: 298 | Float. LCS-based F-measure score 299 | """ 300 | r_lcs = llcs / m 301 | p_lcs = llcs / n 302 | f_lcs = 2.0 * ((p_lcs * r_lcs) / (p_lcs + r_lcs + 1e-8)) 303 | return f_lcs 310 | 311 | 312 | def rouge_l_sentence_level(eval_sentences, ref_sentences): 313 | """Computes ROUGE-L (sentence level) of two collections of sentences.
314 | Source: https://www.microsoft.com/en-us/research/publication/ 315 | rouge-a-package-for-automatic-evaluation-of-summaries/ 316 | Calculated according to: 317 | R_lcs = LCS(X,Y)/m 318 | P_lcs = LCS(X,Y)/n 319 | F_lcs = ((1 + beta^2)*R_lcs*P_lcs) / (R_lcs + (beta^2) * P_lcs) 320 | where: 321 | X = reference summary 322 | Y = Candidate summary 323 | m = length of reference summary 324 | n = length of candidate summary 325 | Args: 326 | eval_sentences: The sentences that have been picked by the summarizer 327 | ref_sentences: The sentences from the reference set 328 | Returns: 329 | A float: F_lcs 330 | """ 331 | 332 | f1_scores = [] 333 | for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences): 334 | #m = len(ref_sentence) 335 | #n = len(eval_sentence) 336 | #lcs = _len_lcs(eval_sentence, ref_sentence) 337 | 338 | m = len(_get_ngrams(1, ref_sentence)) 339 | n = len(_get_ngrams(1, eval_sentence)) 340 | lcs = _recon_lcs(eval_sentence, ref_sentence) 341 | lcs = len(_get_ngrams(1, lcs)) 342 | 343 | f1_scores.append(_f_lcs(lcs, m, n)) 344 | return np.mean(f1_scores, dtype=np.float32) 345 | 346 | 347 | def _get_ngrams(n, text): 348 | """Calculates n-grams. 349 | Args: 350 | n: which n-grams to calculate 351 | text: An array of tokens 352 | Returns: 353 | A set of n-grams 354 | """ 355 | ngram_set = set() 356 | text_length = len(text) 357 | max_index_ngram_start = text_length - n 358 | for i in range(max_index_ngram_start + 1): 359 | ngram_set.add(tuple(text[i:i + n])) 360 | return ngram_set 361 | 362 | 363 | def rouge_n(eval_sentences, ref_sentences, n=2): 364 | """Computes ROUGE-N f1 score of two text collections of sentences. 365 | Source: https://www.microsoft.com/en-us/research/publication/ 366 | rouge-a-package-for-automatic-evaluation-of-summaries/ 367 | Args: 368 | eval_sentences: The sentences that have been picked by the summarizer 369 | ref_sentences: The sentences from the reference set 370 | n: Size of ngram. Defaults to 2. 371 | Returns: 372 | f1 score for ROUGE-N 373 | """ 374 | 375 | f1_scores = [] 376 | for eval_sentence, ref_sentence in zip(eval_sentences, ref_sentences): 377 | eval_ngrams = _get_ngrams(n, eval_sentence) 378 | ref_ngrams = _get_ngrams(n, ref_sentence) 379 | ref_count = len(ref_ngrams) 380 | eval_count = len(eval_ngrams) 381 | 382 | # Gets the overlapping ngrams between evaluated and reference 383 | overlapping_ngrams = eval_ngrams.intersection(ref_ngrams) 384 | overlapping_count = len(overlapping_ngrams) 385 | 386 | # Handle edge case. 
This isn't mathematically correct, but it's good enough 387 | if eval_count == 0: 388 | precision = 0.0 389 | else: 390 | precision = overlapping_count / eval_count 391 | 392 | if ref_count == 0: 393 | recall = 0.0 394 | else: 395 | recall = overlapping_count / ref_count 396 | 397 | f1_scores.append(2.0 * ((precision * recall) / (precision + recall + 1e-8))) 398 | 399 | # return overlapping_count / reference_count 400 | return np.mean(f1_scores, dtype=np.float32) 401 | 402 | -------------------------------------------------------------------------------- /mimix/interact.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Jul 16 12:20:14 2018 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import os 8 | import platform 9 | from argparse import ArgumentParser 10 | import sys 11 | import time 12 | from PIL import Image 13 | from mimix.predictor import EncDecGenerator,LMGenerator,TextEncoder 14 | from mimix.predictor import ImageEncoder, ClipMatcher 15 | from mimix.utils import real_path, load_model_config 16 | 17 | 18 | def pretty_print(res): 19 | """ 20 | Assume res = [{k1:v1,k2:v2,...,kn:[[s1, score1], [s2, score2]]}, {}, ... ] 21 | """ 22 | for dic in res: 23 | info = [[k,dic[k]] for k in dic if not isinstance(dic[k], list)] 24 | info = " ".join("%s:%s" % (k,v) for k,v in info) 25 | if len(info) > 0: 26 | print(info) 27 | print("--------------------") 28 | for k in dic: 29 | if isinstance(dic[k], list): 30 | for a in dic[k]: 31 | info = " ".join([str(x) for x in a]) 32 | print(info) 33 | 34 | 35 | def enc_dec_demo(config): 36 | """ 37 | """ 38 | enc_dec_gen = EncDecGenerator(config) 39 | 40 | print("INPUT TEXT:") 41 | for line in sys.stdin: 42 | line = line.strip() 43 | 44 | if len(line) == 0: 45 | continue 46 | 47 | src_list = [line.strip().split("\t")[0]] 48 | 49 | prefix_list = None 50 | if "\t" in line: 51 | arr = line.strip().split("\t") 52 | src_list = [s for i,s in enumerate(arr) if i % 2 == 0] 53 | prefix_list = [s for i,s in enumerate(arr) if i % 2 == 1] 54 | 55 | start = time.time() 56 | search_res = enc_dec_gen.predict(src_list, prefix_list=prefix_list) 57 | 58 | search_res = [{"src":x, "predict":y} for x,y in search_res] 59 | pretty_print(search_res) 60 | 61 | end = time.time() 62 | cost = end - start 63 | print("-----cost time: %s s-----" % cost) 64 | 65 | 66 | def lm_demo(config): 67 | """ 68 | """ 69 | lm_gen = LMGenerator(config) 70 | 71 | print("INPUT TEXT:") 72 | for line in sys.stdin: 73 | line = line.strip() 74 | prefix_list = None 75 | if len(line) > 0: 76 | prefix_list = line.split("\t") 77 | 78 | start = time.time() 79 | search_res = lm_gen.predict(prefix_list=prefix_list) 80 | 81 | search_res = [{"src":x, "predict":y} for x,y in search_res] 82 | pretty_print(search_res) 83 | 84 | end = time.time() 85 | cost = end - start 86 | print("-----cost time: %s s-----" % cost) 87 | 88 | 89 | def mlm_demo(config): 90 | """ 91 | """ 92 | lm_gen = TextEncoder(config) 93 | 94 | print("INPUT TEXT:") 95 | 96 | for line in sys.stdin: 97 | line = line.strip() 98 | 99 | if len(line) == 0: 100 | continue 101 | 102 | start = time.time() 103 | 104 | res = lm_gen.predict_mlm([line]) 105 | for src,pred in res: 106 | print("src:", src) 107 | for li in pred: 108 | print(" ".join(["%s:%s" % (w,s) for w,s in li])) 109 | end = time.time() 110 | cost = end - start 111 | print("-----cost time: %s s-----" % cost) 112 | 113 | 114 | def match_text_demo(config): 115 | """ 116 | """ 117 | text_matcher = 
TextEncoder(config) 118 | 119 | print("INPUT TEXT:") 120 | 121 | for line in sys.stdin: 122 | line = line.strip() 123 | 124 | if len(line) == 0: 125 | continue 126 | 127 | start = time.time() 128 | 129 | texts = line.split("\t") 130 | res = text_matcher.predict_sim(texts) 131 | for i,text_1 in enumerate(texts): 132 | for j, text_2 in enumerate(texts): 133 | if j <= i: 134 | continue 135 | print(text_1, text_2, res[i][j]) 136 | end = time.time() 137 | cost = end - start 138 | print("-----cost time: %s s-----" % cost) 139 | 140 | 141 | def classification_demo(config): 142 | """ 143 | """ 144 | classifier = TextEncoder(config) 145 | 146 | print("INPUT TEXT:") 147 | 148 | for line in sys.stdin: 149 | 150 | line = line.strip() 151 | 152 | if len(line) == 0: 153 | continue 154 | 155 | src_list = [line] 156 | 157 | start = time.time() 158 | 159 | res = classifier.predict_cls(src_list) 160 | res = [{"src":src, "labels":li} for src,li in res] 161 | pretty_print(res) 162 | 163 | end = time.time() 164 | cost = end - start 165 | print("-----cost time: %s s-----" % cost) 166 | 167 | 168 | def sequene_labeling_demo(config): 169 | """ 170 | """ 171 | labeler = TextEncoder(config) 172 | 173 | print("INPUT TEXT:") 174 | 175 | for line in sys.stdin: 176 | 177 | line = line.strip() 178 | 179 | if len(line) == 0: 180 | continue 181 | 182 | src_list = [line] 183 | 184 | start = time.time() 185 | 186 | res = labeler.predict_seq(src_list) 187 | res = [{"src":src, "labels":li} for src,li in res] 188 | pretty_print(res) 189 | 190 | end = time.time() 191 | cost = end - start 192 | print("-----cost time: %s s-----" % cost) 193 | 194 | 195 | def enc_dec_debug(config): 196 | """ 197 | """ 198 | enc_dec_gen = EncDecGenerator(config) 199 | 200 | print("INPUT TEXT:") 201 | src_list = [] 202 | trg_list = [] 203 | for line in sys.stdin: 204 | line = line.strip() 205 | 206 | src,trg = line.split("\t")[:2] 207 | 208 | src_list.append(src) 209 | trg_list.append(trg) 210 | 211 | res = enc_dec_gen.get_topk_pred(src_list, trg_list, topk=10)[0] 212 | words, topk_pairs, history, sum_log_probs, entropy = res 213 | print("src: %s" % src) 214 | print("trg: %s" % trg) 215 | print("sum_log_probs: %.2f" % sum_log_probs) 216 | print("avg_log_probs: %.2f" % (sum_log_probs / len(words))) 217 | for i,word in enumerate(words): 218 | info = word 219 | info = info + " prob: %.2f entropy: %.2f" % (history[i], entropy[i]) 220 | info = info + " topk:" + " ".join(["%s:%.2f" % (w,s) for w,s in topk_pairs[i]]) 221 | print(info) 222 | 223 | 224 | def lm_debug(config): 225 | """ 226 | """ 227 | lm_gen = LMGenerator(config) 228 | 229 | print("INPUT TEXT:") 230 | for line in sys.stdin: 231 | line = line.strip() 232 | 233 | trg = line.strip() 234 | 235 | res = lm_gen.get_topk_pred([trg], topk=10)[0] 236 | words, topk_pairs, history, sum_log_probs, entropy = res 237 | print("trg: %s" % trg) 238 | print("sum_log_probs: %.2f" % sum_log_probs) 239 | print("avg_log_probs: %.2f" % (sum_log_probs / len(words))) 240 | for i,word in enumerate(words): 241 | info = word 242 | info = info + " prob: %.2f entropy: %.2f" % (history[i], entropy[i]) 243 | info = info + " topk:" + " ".join(["%s:%.2f" % (w,s) for w,s in topk_pairs[i]]) 244 | print(info) 245 | 246 | 247 | def enc_dec_score(config): 248 | """ 249 | """ 250 | enc_dec_gen = EncDecGenerator(config) 251 | 252 | start = time.time() 253 | 254 | print("INPUT TEXT:") 255 | for line in sys.stdin: 256 | line = line.strip() 257 | 258 | if len(line) == 0: 259 | continue 260 | 261 | arr = line.strip().split("\t") 262 | 
src = arr[0] 263 | trg_list = arr[1:] 264 | 265 | pairs_list = [[src, trg_list]] 266 | 267 | res = enc_dec_gen.scoring(pairs_list) 268 | 269 | for src, trg_list in res: 270 | for trg,score in trg_list: 271 | print(src, trg, score) 272 | 273 | end = time.time() 274 | cost = end - start 275 | print("#cost time: %s s" % cost) 276 | 277 | 278 | def lm_score(config): 279 | """ 280 | """ 281 | lm_gen = LMGenerator(config) 282 | 283 | start = time.time() 284 | 285 | print("INPUT TEXT:") 286 | for line in sys.stdin: 287 | line = line.strip() 288 | 289 | if len(line) == 0: 290 | continue 291 | 292 | trg_list = line.strip().split("\t") 293 | 294 | res = lm_gen.scoring(trg_list) 295 | 296 | for trg, score in res: 297 | print(trg, score) 298 | 299 | end = time.time() 300 | cost = end - start 301 | print("#cost time: %s s" % cost) 302 | 303 | 304 | def image_classification_demo(config): 305 | """ 306 | """ 307 | classifier = ImageEncoder(config) 308 | 309 | print("INPUT IMAGE PATH:") 310 | 311 | for line in sys.stdin: 312 | 313 | line = line.strip() 314 | 315 | if len(line) == 0: 316 | continue 317 | 318 | image_path = line 319 | images = [Image.open(image_path)] 320 | 321 | start = time.time() 322 | 323 | res = classifier.predict_cls(images) 324 | images[0].close() 325 | res = [{"labels":li} for src,li in res] 326 | pretty_print(res) 327 | 328 | end = time.time() 329 | cost = end - start 330 | print("-----cost time: %s s-----" % cost) 331 | 332 | 333 | def image2text_demo(config): 334 | """ 335 | """ 336 | if config["model"] == "transformer": 337 | enc_dec_gen = EncDecGenerator(config) 338 | 339 | print("INPUT IMAGE PATH:") 340 | 341 | for line in sys.stdin: 342 | 343 | line = line.strip() 344 | 345 | if len(line) == 0: 346 | continue 347 | 348 | image_path = line 349 | images = [Image.open(image_path)] 350 | 351 | start = time.time() 352 | 353 | search_res = enc_dec_gen.predict(images) 354 | images[0].close() 355 | search_res = [{"predict":y} for x,y in search_res] 356 | pretty_print(search_res) 357 | 358 | end = time.time() 359 | cost = end - start 360 | print("-----cost time: %s s-----" % cost) 361 | 362 | 363 | def text_image_match_demo(config): 364 | """ 365 | """ 366 | clip = ClipMatcher(config) 367 | 368 | print("INPUT IMAGE PATH AND TEXT:") 369 | 370 | for line in sys.stdin: 371 | 372 | line = line.strip() 373 | 374 | if len(line) == 0: 375 | continue 376 | 377 | image_path = line.split("\t")[0] 378 | images = [Image.open(image_path)] 379 | texts = line.split("\t")[1:] 380 | 381 | start = time.time() 382 | 383 | res = clip.predict_sim(images,texts) 384 | images[0].close() 385 | for text,score,prob in zip(texts, res[0][0], res[1][0]): 386 | print(text, score, prob) 387 | 388 | end = time.time() 389 | cost = end - start 390 | print("-----cost time: %s s-----" % cost) 391 | 392 | 393 | def stream_enc_dec_demo(config): 394 | """ 395 | """ 396 | config["beam_size"] = 1 397 | enc_dec_gen = EncDecGenerator(config) 398 | 399 | print("INPUT TEXT:") 400 | for line in sys.stdin: 401 | line = line.strip() 402 | 403 | if len(line) == 0: 404 | continue 405 | 406 | src_list = [line.strip().split("\t")[0]] 407 | 408 | prefix_list = None 409 | if "\t" in line: 410 | prefix_list = [line.split("\t")[1]] 411 | 412 | start = time.time() 413 | search_res = enc_dec_gen.predict_stream(src_list, prefix_list=prefix_list) 414 | 415 | text = "" 416 | while True: 417 | try: 418 | _text = next(search_res)[0][1][0][0] 419 | print(_text[len(text):], end="", flush=True) 420 | text = _text 421 | except: 422 | break 423 | 424 | 
print() 425 | 426 | end = time.time() 427 | cost = end - start 428 | print("-----cost time: %s s-----" % cost) 429 | 430 | 431 | def stream_lm_demo(config): 432 | """ 433 | """ 434 | config["beam_size"] = 1 435 | lm_gen = LMGenerator(config) 436 | 437 | print("INPUT TEXT:") 438 | for line in sys.stdin: 439 | line = line.strip() 440 | prefix_list = None 441 | if len(line) > 0: 442 | prefix_list = [line] 443 | 444 | start = time.time() 445 | search_res = lm_gen.predict_stream(prefix_list=prefix_list) 446 | 447 | text = "" 448 | while True: 449 | try: 450 | _text = next(search_res)[0][1][0][0] 451 | print(_text[len(text):], end="", flush=True) 452 | text = _text 453 | except: 454 | break 455 | 456 | print() 457 | 458 | end = time.time() 459 | cost = end - start 460 | print("-----cost time: %s s-----" % cost) 461 | 462 | 463 | def chat(config): 464 | 465 | print("loading model...") 466 | max_history_len = config.get("max_history_len", 2000) 467 | max_history_turn = config.get("max_history_turn", 20) 468 | sysinfo = config.get("sysinfo", "") 469 | 470 | assert config["is_mimix_chat"] == True 471 | assert max_history_len < config["trg_max_len"] 472 | 473 | lm_gen = LMGenerator(config) 474 | history = [] 475 | if platform.system() == "Windows": 476 | os.system("cls") 477 | else: 478 | os.system("clear") 479 | print("Welcome to MimixAI.") 480 | while True: 481 | print("User:") 482 | user_input = input() 483 | if user_input == ":restart": 484 | if platform.system() == "Windows": 485 | os.system("cls") 486 | else: 487 | os.system("clear") 488 | print("Welcome to MimixAI.") 489 | history = [] 490 | continue 491 | elif user_input == ":exit": 492 | break 493 | history.append(user_input) 494 | context = " _mimix_" 495 | for i,text in enumerate(history[::-1]): 496 | if i > max_history_turn: 497 | break 498 | if len(context) > max_history_len: 499 | break 500 | if i % 2 == 0: 501 | context = " _mimixuser_ " + text + context 502 | else: 503 | context = " _mimix_ " + text + context 504 | context = (sysinfo + context).strip() 505 | 506 | search_res = lm_gen.predict_stream(prefix_list=[context]) 507 | resp = "" 508 | print("Mimix:") 509 | while True: 510 | try: 511 | _resp = next(search_res)[0][1][0][0].split("_mimix_")[-1].strip() 512 | print(_resp[len(resp):], end="", flush=True) 513 | resp = _resp 514 | except: 515 | break 516 | print() 517 | 518 | history.append(resp) 519 | 520 | 521 | def run_interactive(): 522 | """ 523 | """ 524 | parser = ArgumentParser() 525 | 526 | parser.add_argument("--model_conf", type=str) 527 | parser.add_argument("--mode", type=str, default="pred") 528 | parser.add_argument('--stream', action='store_true') 529 | parser.set_defaults(stream=False) 530 | 531 | args = parser.parse_args(sys.argv[1:]) 532 | 533 | conf_file = args.model_conf 534 | config = load_model_config(real_path(conf_file)) 535 | if "convert_special_token" not in config: 536 | config["convert_special_token"] = False 537 | 538 | if args.mode == "pred": 539 | if config["task"] == "enc_dec": 540 | if args.stream == True: 541 | config["beam_size"] = 1 542 | stream_enc_dec_demo(config) 543 | else: 544 | enc_dec_demo(config) 545 | elif config["task"] == "cls": 546 | classification_demo(config) 547 | elif config["task"] == "lm": 548 | if config.get("is_mimix_chat", False) == True: 549 | config["convert_special_token"] = True 550 | chat(config) 551 | elif args.stream == True: 552 | config["beam_size"] = 1 553 | stream_lm_demo(config) 554 | else: 555 | lm_demo(config) 556 | elif config["task"] == "mlm": 557 | 
            mlm_demo(config)
558 |         elif config["task"] == "seqcls":
559 |             sequene_labeling_demo(config)
560 |         elif config["task"] == "match":
561 |             match_text_demo(config)
562 |         elif config["task"] == "image_classification":
563 |             image_classification_demo(config)
564 |         elif config["task"] == "image2text":
565 |             image2text_demo(config)
566 |         elif config["task"] == "image_text_match":
567 |             text_image_match_demo(config)
568 |     elif args.mode == "debug":
569 |         if config["task"] == "enc_dec":
570 |             enc_dec_debug(config)
571 |         elif config["task"] == "lm":
572 |             lm_debug(config)
573 |     elif args.mode == "scoring":
574 |         if config["task"] == "enc_dec":
575 |             enc_dec_score(config)
576 |         elif config["task"] == "lm":
577 |             lm_score(config)
578 | 
579 | if __name__ == "__main__":
580 |     run_interactive()
581 | 
--------------------------------------------------------------------------------
/mimix/loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sun Jun 28 09:38:03 2020
4 | 
5 | @author: Xiaoyuan Yao
6 | """
7 | import torch
8 | import torch.nn.functional as F
9 | 
10 | def cross_entropy_with_smoothing(logits, target, eps, pad):
11 |     """
12 |     """
13 |     n_class = logits.size(1)
14 |     one_hot = torch.eye(n_class, device=target.device)[target.view(-1)]
15 |     smoothed = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
16 |     loss = F.cross_entropy(logits, smoothed, reduction="none")
17 | 
18 |     if pad is not None:
19 |         mask = target.view(-1).ne(pad)  # mask built from the original ids
20 |         loss = torch.sum(loss * mask.float()) / torch.sum(mask.float())
21 |     else:
22 |         loss = torch.mean(loss)
23 | 
24 |     return loss
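
# Illustrative call (a sketch, not part of the module): smoothing one batch of
# (batch, n_class) logits with no padding mask; cross entropy over soft
# targets needs torch >= 1.10, which requirements.txt already guarantees.
#
#     logits = torch.randn(8, 100)
#     target = torch.randint(0, 100, (8,))
#     loss = cross_entropy_with_smoothing(logits, target, eps=0.1, pad=None)
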
25 | 
26 | 
27 | def seq_cross_entropy(logits, target, eps, pad):
28 |     """
29 |     """
30 | 
31 |     if logits.dim() == 3:
32 |         logits = torch.flatten(logits, start_dim=0, end_dim=1)
33 |     if eps > 0:
34 |         loss = cross_entropy_with_smoothing(logits, target, eps, pad)
35 |     else:
36 |         loss = F.cross_entropy(logits,
37 |                                target.view(-1),
38 |                                ignore_index=pad)
39 |     return loss
40 | 
41 | 
42 | def contrastive_loss(vec, target, sim_alpha):
43 |     """
44 |     """
45 |     norm_vec = F.normalize(vec, p=2, dim=1)
46 |     sim = torch.mm(norm_vec, norm_vec.T)
47 |     sim = torch.masked_fill(sim, torch.eye(sim.shape[0], device=sim.device).bool(), -1000)
48 |     loss = F.cross_entropy(sim_alpha * sim, target)
49 | 
50 |     return loss
51 | 
52 | 
53 | def symmetric_contrastive_loss(vecs, target, sim_alpha):
54 |     """
55 |     """
56 |     vec_1,vec_2 = vecs
57 |     norm_vec_1 = F.normalize(vec_1, p=2, dim=1)
58 |     norm_vec_2 = F.normalize(vec_2, p=2, dim=1)
59 |     sim = torch.mm(norm_vec_1, norm_vec_2.T)
60 |     loss = (F.cross_entropy(sim_alpha * sim, target) + F.cross_entropy(sim_alpha * sim.T, target)) / 2
61 | 
62 |     return loss
63 | 
64 | 
65 | def classify_loss(logits, target, eps):
66 |     """
67 |     """
68 |     if eps > 0:
69 |         loss = cross_entropy_with_smoothing(logits, target, eps, None)
70 |     else:
71 |         loss = F.cross_entropy(logits, target)
72 | 
73 |     return loss
74 | 
75 | 
76 | def kl_loss(logits, soft_target, target, pad, temperature):
77 |     """
78 |     """
79 |     if logits.dim() == 3:
80 |         logits = logits.view(-1, logits.size(-1))
81 |     if soft_target.dim() == 3:
82 |         soft_target = soft_target.view(-1, soft_target.size(-1))
83 | 
84 |     kl_loss = F.kl_div(F.log_softmax(logits/temperature, dim=1),
85 |                        F.softmax(soft_target/temperature, dim=1),
86 |                        reduction="none")
87 | 
88 |     mask = target.ne(pad).view(-1, 1).float()
89 | 
90 |     kl_loss = torch.sum(kl_loss * mask) / torch.sum(mask)
91 |     kl_loss = kl_loss * temperature * temperature
92 | 
93 |     return kl_loss
94 | 
95 | 
96 | def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):
97 |     """
98 |     """
99 |     kld = -0.5 * torch.sum(1 + (recog_logvar - prior_logvar)
100 |                            - torch.pow(prior_mu - recog_mu, 2) / torch.exp(prior_logvar)
101 |                            - torch.exp(recog_logvar) / torch.exp(prior_logvar),
102 |                            1)
103 |     return kld
104 | 
105 | 
--------------------------------------------------------------------------------
/mimix/optimizer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Jan 11 15:54:15 2021
4 | 
5 | @author: Xiaoyuan Yao
6 | """
7 | from typing import Tuple, Optional, Callable
8 | from torch import optim
9 | import torch
10 | from torch.optim.optimizer import Optimizer
11 | 
12 | # functions
13 | 
14 | def exists(val):
15 |     return val is not None
16 | 
17 | # update functions
18 | 
19 | def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
20 |     # stepweight decay
21 | 
22 |     p.data.mul_(1 - lr * wd)
23 | 
24 |     # weight update
25 | 
26 |     update = exp_avg.clone().mul_(beta1).add(grad, alpha = 1 - beta1).sign_()
27 |     p.add_(update, alpha = -lr)
28 | 
29 |     # decay the momentum running average coefficient
30 | 
31 |     exp_avg.mul_(beta2).add_(grad, alpha = 1 - beta2)
32 | 
33 | # class
34 | 
35 | class Lion(Optimizer):
36 |     def __init__(
37 |         self,
38 |         params,
39 |         lr: float = 1e-4,
40 |         betas: Tuple[float, float] = (0.9, 0.99),
41 |         weight_decay: float = 0.0,
42 |         use_triton: bool = False
43 |     ):
44 |         assert lr > 0.
45 |         assert all([0. <= beta <= 1. for beta in betas])
46 | 
47 |         defaults = dict(
48 |             lr = lr,
49 |             betas = betas,
50 |             weight_decay = weight_decay
51 |         )
52 | 
53 |         super().__init__(params, defaults)
54 | 
55 |         self.update_fn = update_fn
56 | 
57 |         if use_triton:
58 |             from lion_pytorch.triton import update_fn as triton_update_fn
59 |             self.update_fn = triton_update_fn
60 | 
61 |     @torch.no_grad()
62 |     def step(
63 |         self,
64 |         closure: Optional[Callable] = None
65 |     ):
66 | 
67 |         loss = None
68 |         if exists(closure):
69 |             with torch.enable_grad():
70 |                 loss = closure()
71 | 
72 |         for group in self.param_groups:
73 |             for p in filter(lambda p: exists(p.grad), group['params']):
74 | 
75 |                 grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], self.state[p]
76 | 
77 |                 # init state - exponential moving average of gradient values
78 | 
79 |                 if len(state) == 0:
80 |                     state['exp_avg'] = torch.zeros_like(p)
81 | 
82 |                 exp_avg = state['exp_avg']
83 | 
84 |                 self.update_fn(
85 |                     p,
86 |                     grad,
87 |                     exp_avg,
88 |                     lr,
89 |                     wd,
90 |                     beta1,
91 |                     beta2
92 |                 )
93 | 
94 |         return loss
95 | 
96 | 
97 | def build_optimizer(model, train_config):
98 |     """
99 |     """
100 |     if train_config.get("optimizer", "adamW") == "adamW":
101 |         optimizer = optim.AdamW(
102 |             filter(lambda p: p.requires_grad, model.parameters()),
103 |             train_config.get("lr", 1e-4), amsgrad=True)
104 |     elif train_config["optimizer"] == "lion":
105 |         optimizer = Lion(
106 |             filter(lambda p: p.requires_grad, model.parameters()),
107 |             train_config.get("lr", 1e-4)
108 |         )
109 |     else:
110 |         raise ValueError("optimizer not correct!")
111 |     return optimizer
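
# Illustrative wiring (a sketch, not part of the module; "model" stands for
# any torch.nn.Module, and the dict mirrors the train_conf style used by the
# configs under conf/):
#
#     train_config = {"optimizer": "lion", "lr": 1e-5}
#     optimizer = build_optimizer(model, train_config)
#
# Lion's sign-based update tends to want a smaller learning rate than AdamW,
# hence the lower lr in the sketch.
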
--------------------------------------------------------------------------------
/mimix/preprocess.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Sun Aug 18 10:39:14 2019
4 | 
5 | @author: Xiaoyuan Yao
6 | """
7 | import sys
8 | from argparse import ArgumentParser
9 | import json
10 | import os
11 | import random
12 | from mimix.tokenization import build_tokenizer
13 | from mimix.utils import real_path, load_vocab, load_config, load_model_config
14 | 
15 | class TextProcessor():
16 |     """
17 |     """
18 |     def __init__(self, **kwargs):
19 |         """
20 |         """
21 |         self.src_max_len = kwargs.get("src_max_len", None)
22 |         self.trg_max_len = kwargs.get("trg_max_len", None)
23 |         self.PAD = kwargs["symbol2id"]["_pad_"]
24 |         self.BOS = kwargs["symbol2id"]["_bos_"]
25 |         self.EOS = kwargs["symbol2id"]["_eos_"]
26 |         self.UNK = kwargs["symbol2id"]["_unk_"]
27 |         self.SEP = kwargs["symbol2id"]["_sep_"]
28 |         self.CLS = kwargs["symbol2id"]["_cls_"]
29 |         self.MASK = kwargs["symbol2id"]["_mask_"]
30 |         self.src_tokenizer = None
31 |         self.src_tokenizer = build_tokenizer(
32 |             tokenizer=kwargs["src_tokenizer"],
33 |             vocab_file=real_path(kwargs["src_vocab"]))
34 |         self.trg_tokenizer = None
35 |         self.trg_tokenizer = build_tokenizer(
36 |             tokenizer=kwargs["trg_tokenizer"],
37 |             vocab_file=real_path(kwargs["trg_vocab"]))
38 |         self.label2id = None
39 |         if kwargs.get("label2id", None) is not None:
40 |             self.label2id = load_vocab(real_path(kwargs["label2id"]))
41 |         self.task = kwargs["task"]
42 | 
43 | 
44 |     def parse(self, line):
45 |         """
46 |         """
47 |         try:
48 |             data = json.loads(line)
49 |             return data
50 |         except:
51 |             return None
52 | 
53 | 
54 |     def preprocess(self, data):
55 |         """
56 |         """
57 |         src = data.get("src", None)
58 |         trg = data.get("trg", None)
59 |         label = data.get("label", None)
60 |         seq_label = data.get("seq_label", None)
61 | 
62 |         data = {}
63 |         if src:
64 |             src = self.src_tokenizer.tokenize_to_ids(src)
65 |             src = src[:self.src_max_len]
66 |             if self.task == "classify" or self.task == "match":
67 |                 src = [self.CLS] + src[:self.src_max_len - 1]
68 |             data["src"] = src
69 | 
70 |         if trg:
71 |             trg = self.trg_tokenizer.tokenize_to_ids(trg)
72 |             trg = trg[:self.trg_max_len - 1]
73 |             trg = [self.BOS] + trg + [self.EOS]
74 |             data["trg"] = trg
75 | 
76 |         if label is not None:
77 |             if self.label2id is not None:
78 |                 label = self.label2id[label]
79 |             else:
80 |                 label = int(label)
81 |             data["label"] = label
82 | 
83 |         if seq_label is not None:
84 |             if self.label2id is not None:
85 |                 seq_label = [self.label2id[s] for s in seq_label]
86 |             else:
87 |                 seq_label = [int(s) for s in seq_label]
88 |             data["seq_label"] = seq_label
89 | 
90 |         return data
91 | 
92 | 
93 |     def __call__(self, line):
94 |         """
95 |         """
96 |         parsed = self.parse(line)
97 |         processed = None
98 |         if parsed is not None:
99 |             processed = self.preprocess(parsed)
100 | 
101 |         return processed
102 | 
103 | 
104 | def preprocess(data_dir,
105 |                dest_dir,
106 |                num_shards=1,
107 |                data_preprocessor=None,
108 |                sort_key_fn=None):
109 |     """
110 |     Shuffle data into shards, optionally preprocessing and sorting each shard
111 |     """
112 |     data_files = [f for f in os.listdir(data_dir)]
113 | 
114 |     fo_list = []
115 |     for f in range(num_shards):
116 |         fo_list.append(open(os.path.join(dest_dir, str(f)), "w", encoding="utf-8"))
117 | 
118 |     for fi in data_files:
119 |         for line in open(os.path.join(data_dir, fi), "r", encoding="utf-8"):
120 |             fo = random.choice(fo_list)
121 |             if data_preprocessor is not None:
122 |                 data = data_preprocessor(line)
123 |                 if data is None: continue  # skip lines that fail to parse
124 |                 line = json.dumps(data, ensure_ascii=False) + "\n"
125 |             fo.write(line)
126 |     for fo in fo_list:
127 |         fo.close()
128 | 
129 |     for f in range(num_shards):
130 |         lines = [line for line in open(os.path.join(dest_dir, str(f)), "r", encoding="utf-8")]
131 |         random.shuffle(lines)
132 |         if sort_key_fn is not None:
133 |             lines = [[line, sort_key_fn(json.loads(line))] for line in lines]
134 |             lines.sort(key=lambda x:x[1])
135 |             lines = [x[0] for x in lines]
136 |         fo = 
open(os.path.join(dest_dir, str(f)), "w", encoding="utf-8") 137 | for line in lines: 138 | fo.write(line) 139 | fo.close() 140 | 141 | 142 | def run_preprocess(): 143 | """ 144 | """ 145 | parser = ArgumentParser() 146 | 147 | parser.add_argument("--model_conf", type=str) 148 | parser.add_argument("--train_conf", type=str) 149 | 150 | args = parser.parse_args(sys.argv[1:]) 151 | 152 | model_config = load_model_config(real_path(args.model_conf)) 153 | train_config = load_config(real_path(args.train_conf)) 154 | 155 | processor = TextProcessor(**model_config) 156 | 157 | preprocess(train_config["train_dir"], 158 | os.path.join(real_path(train_config["tmp_dir"]), "train"), 159 | num_shards=train_config.get("num_shards", 1), 160 | data_preprocessor=processor, 161 | sort_key_fn=None) 162 | if train_config.get("val_dir", None) is not None: 163 | preprocess(train_config["val_dir"], 164 | os.path.join(real_path(train_config["tmp_dir"]), "val"), 165 | num_shards=1, 166 | data_preprocessor=processor, 167 | sort_key_fn=None) 168 | if train_config.get("test_dir", None) is not None: 169 | preprocess(train_config["test_dir"], 170 | os.path.join(real_path(train_config["tmp_dir"]), "test"), 171 | num_shards=1, 172 | data_preprocessor=processor, 173 | sort_key_fn=None) 174 | 175 | if __name__ == "__main__": 176 | run_preprocess() 177 | 178 | 179 | 180 | 181 | 182 | -------------------------------------------------------------------------------- /mimix/scheduler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Jan 7 11:53:07 2021 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import math 8 | 9 | class ConstantScheduler(): 10 | """ 11 | """ 12 | def __init__(self, train_config, optimizer): 13 | """ 14 | """ 15 | self.lr = train_config["lr"] 16 | self.optimizer = optimizer 17 | 18 | 19 | def step(self): 20 | """ 21 | """ 22 | for param_group in self.optimizer.param_groups: 23 | param_group['lr'] = self.lr 24 | 25 | 26 | def build_scheduler(train_config, optimizer): 27 | """ 28 | """ 29 | if "scheduler" not in train_config: 30 | return ConstantScheduler(train_config, optimizer) 31 | else: 32 | raise ValueError("scheduler not correct!") -------------------------------------------------------------------------------- /mimix/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Jan 25 11:10:26 2021 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import mimix.tokenization as tokenization 8 | import mimix.evaluate as evaluate 9 | 10 | def test_tokenization(): 11 | """ 12 | """ 13 | test_str = "1234567号选手是top10哦,hello你好666啊,春 秋 忽 代 谢windows7le, _苍天啊_苍天_오늘 날씨가 참 좋네요." 14 | mimix_tokenizer = tokenization.MimixTokenizer(vocab_file="model/vocab/zh_vocab.txt") 15 | 16 | print(mimix_tokenizer.tokenize(test_str)) 17 | 18 | mimix_tokenizer = tokenization.MimixTokenizer(vocab_file="model/vocab/zh_words_vocab.txt") 19 | 20 | print(mimix_tokenizer.tokenize(test_str)) 21 | 22 | #bert_tokenizer = tokenization.BertTokenizer( 23 | # vocab_file="model/pretrain/bert-base-chinese/vocab.txt") 24 | 25 | #print(bert_tokenizer.tokenize(test_str)) 26 | 27 | test_str = "1234567号选手是top10哦,_mask_hello你好666啊,春 秋 忽 代 谢windows7le, _苍天啊_苍天_오늘 날씨가 참 좋네요." 
28 | mimix_tokenizer = tokenization.MimixTokenizer(vocab_file="model/vocab/zh_vocab.txt") 29 | 30 | print(mimix_tokenizer.tokenize(test_str)) 31 | 32 | 33 | def test_evaluate(): 34 | """ 35 | """ 36 | ref_corpus = ["今 天 天 气 真 不 错".split(), 37 | "我 好 无 聊 啊".split(),] 38 | eval_corpus = ["今 天 天 气 真 的 不 错".split(), 39 | "我 真 的 很 无 聊".split()] 40 | print(ref_corpus, eval_corpus) 41 | print(evaluate.compute_bleu(ref_corpus, eval_corpus, 4)) 42 | print(evaluate.rouge_n(eval_corpus, ref_corpus, 1)) 43 | print(evaluate.rouge_n(eval_corpus, ref_corpus, 2)) 44 | print(evaluate.rouge_l_sentence_level(eval_corpus, ref_corpus)) 45 | 46 | print("------") 47 | 48 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction, corpus_bleu 49 | print("bleu-4", corpus_bleu( 50 | list_of_references=[[s] for s in ref_corpus], 51 | hypotheses=eval_corpus, 52 | smoothing_function=SmoothingFunction().method1 53 | ) 54 | ) 55 | 56 | 57 | from rouge import Rouge 58 | rouge = Rouge() 59 | res = rouge.get_scores(hyps=[" ".join(s) for s in eval_corpus], 60 | refs=[" ".join(s) for s in ref_corpus], 61 | avg=True) 62 | for k in res: 63 | print(k, res[k]["f"]) 64 | 65 | if __name__ == "__main__": 66 | 67 | test_tokenization() 68 | 69 | test_evaluate() 70 | 71 | 72 | -------------------------------------------------------------------------------- /mimix/tokenization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Nov 20 17:40:07 2020 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import re 8 | from abc import abstractmethod 9 | from mimix.utils import load_vocab 10 | 11 | def is_alphabet(ch): 12 | """ 13 | """ 14 | code = ord(ch) 15 | return 0x3041 <= code <= 0x3093 or \ 16 | 0x30a1 <= code <= 0x30f3 or \ 17 | 0xac00 <= code <= 0xd7af or \ 18 | 0x1100 <= code <= 0x11ff or \ 19 | 0x3130 <= code <= 0x318f or \ 20 | 0xa960 <= code <= 0xa97f or \ 21 | 0xd7b0 <= code <= 0xd7ff or \ 22 | 0x61 <= code <= 0x7a or \ 23 | 0x41 <= code <= 0x5a or \ 24 | 0x430 <= code <= 0x44f or \ 25 | 0x410 <= code <= 0x42f or \ 26 | 0x3b1 <= code <= 0x3c9 or \ 27 | 0x391 <= code <= 0x3a9 or \ 28 | 0xc0 <= code <= 0xff 29 | 30 | 31 | def is_cjk(ch): 32 | """ 33 | """ 34 | code = ord(ch) 35 | return 0x4E00 <= code <= 0x9FFF or \ 36 | 0x3400 <= code <= 0x4DBF or \ 37 | 0x20000 <= code <= 0x2A6DF or \ 38 | 0x2A700 <= code <= 0x2B73F or \ 39 | 0x2B740 <= code <= 0x2B81F or \ 40 | 0x2B820 <= code <= 0x2CEAF or \ 41 | 0xF900 <= code <= 0xFAFF or \ 42 | 0x2F800 <= code <= 0x2FA1F 43 | 44 | 45 | def is_num(ch): 46 | """ 47 | """ 48 | code = ord(ch) 49 | return 0x30 <= code <= 0x39 50 | 51 | 52 | def is_special(ch): 53 | """ 54 | """ 55 | return ch == "_" 56 | 57 | 58 | def is_useless(ch): 59 | """ 60 | """ 61 | code = ord(ch) 62 | return code == 0xfffd or \ 63 | 0x80 <= code <= 0x9F or \ 64 | 0x00 <= code <= 0x1F 65 | 66 | 67 | def is_space(ch): 68 | """ 69 | """ 70 | return ch == " " or ch == " " 71 | 72 | 73 | def is_newline(ch): 74 | """ 75 | """ 76 | return ch == "\n" 77 | 78 | 79 | class Tokenizer(): 80 | """ 81 | """ 82 | def __init__(self, vocab_file): 83 | """ 84 | """ 85 | self.vocab = None 86 | self.id2word = None 87 | 88 | self.vocab = load_vocab(vocab_file) 89 | self.id2word = {self.vocab[w]:w for w in self.vocab} 90 | 91 | 92 | @abstractmethod 93 | def tokenize(self, text): 94 | """ 95 | """ 96 | pass 97 | 98 | 99 | @abstractmethod 100 | def detokenize(self, tokens, convert_special_token=True): 101 | """ 102 | """ 103 | pass 104 | 105 | 106 | def 
convert_tokens_to_ids(self, tokenized): 107 | """ 108 | """ 109 | unk = self.vocab["_unk_"] 110 | word_ids = [self.vocab.get(word, unk) for word in tokenized] 111 | return word_ids 112 | 113 | 114 | def convert_ids_to_tokens(self, word_ids): 115 | """ 116 | """ 117 | tokens = [self.id2word.get(word, "_unk_") for word in word_ids] 118 | return tokens 119 | 120 | 121 | def tokenize_to_ids(self, text): 122 | """ 123 | """ 124 | tokens = self.tokenize(text) 125 | word_ids = self.convert_tokens_to_ids(tokens) 126 | 127 | return word_ids 128 | 129 | 130 | def detokenize_ids(self, word_ids, convert_special_token=True): 131 | """ 132 | """ 133 | tokens = self.convert_ids_to_tokens(word_ids) 134 | text = self.detokenize(tokens, convert_special_token=convert_special_token) 135 | 136 | return text 137 | 138 | 139 | class SpaceTokenizer(Tokenizer): 140 | """ 141 | """ 142 | def __init__(self, vocab_file): 143 | """ 144 | """ 145 | super(SpaceTokenizer,self).__init__(vocab_file) 146 | 147 | 148 | def tokenize(self, text): 149 | """ 150 | """ 151 | tokens = text.split() 152 | 153 | return tokens 154 | 155 | 156 | def detokenize(self, tokens, convert_special_token=True): 157 | """ 158 | """ 159 | text = " ".join(tokens) 160 | 161 | return text 162 | 163 | 164 | class MimixTokenizer(Tokenizer): 165 | """ 166 | """ 167 | def __init__(self, vocab_file, uncased=True, match_special_symbols=True): 168 | """ 169 | """ 170 | super(MimixTokenizer,self).__init__(vocab_file) 171 | 172 | zh_words = [ww for ww in self.vocab if all([is_cjk(ch) for ch in ww])] 173 | self.tri_tree = self.build_tri_tree(zh_words) 174 | 175 | self.space_token = "_mimixsp_" 176 | self.newline_token = "_mimixnl_" 177 | self.pad_token = "_pad_" 178 | self.bos_token = "_bos_" 179 | self.eos_token = "_eos_" 180 | self.unk_token = "_unk_" 181 | self.sep_token = "_sep_" 182 | self.mask_token = "_mask_" 183 | self.cls_token = "_cls_" 184 | 185 | self.special_symbols = [self.pad_token, 186 | self.bos_token, 187 | self.eos_token, 188 | self.unk_token] 189 | 190 | self.match_special_symbols = match_special_symbols 191 | self.symbols = set() 192 | for word in self.vocab: 193 | if re.search("^_unused[0-9]+_$", word): 194 | continue 195 | if word in self.special_symbols: 196 | continue 197 | if match_special_symbols == False: 198 | continue 199 | if re.search("^_[0-9a-z]+_$", word): 200 | self.symbols.add(word) 201 | 202 | self.symbols_tri_tree = self.build_tri_tree(self.symbols) 203 | 204 | self.uncased = uncased 205 | 206 | 207 | def build_tri_tree(self, keywords): 208 | """ 209 | """ 210 | tri_tree = {} 211 | for key in keywords: 212 | root = tri_tree 213 | for ch in key: 214 | if ch not in root: 215 | root[ch] = {} 216 | root = root[ch] 217 | 218 | root.setdefault(u"##", {}) 219 | 220 | return tri_tree 221 | 222 | 223 | def prefix_match(self, s): 224 | """ 225 | """ 226 | start = 0 227 | size = len(s) 228 | 229 | root = self.symbols_tri_tree 230 | end = start 231 | matched = "" 232 | matched_end = start 233 | while end < size and s[end] in root: 234 | if u"##" in root: 235 | matched = s[start:end] 236 | matched_end = end 237 | 238 | root = root[s[end]] 239 | end += 1 240 | 241 | if u"##" in root: 242 | matched = s[start:end] 243 | matched_end = end 244 | 245 | if matched_end == start: 246 | return "" 247 | 248 | return matched 249 | 250 | 251 | def maximum_match(self, s): 252 | """ 253 | """ 254 | 255 | tokenized = [] 256 | 257 | start = 0 258 | size = len(s) 259 | 260 | while start < size: 261 | root = self.tri_tree 262 | end = start 263 | matched 
= "" 264 | matched_end = start 265 | while end < size and s[end] in root: 266 | if u"##" in root: 267 | matched = s[start:end] 268 | matched_end = end 269 | 270 | root = root[s[end]] 271 | end += 1 272 | 273 | if u"##" in root: 274 | matched = s[start:end] 275 | matched_end = end 276 | 277 | if matched_end == start: 278 | matched = s[start:start + 1] 279 | matched_end = start + 1 280 | 281 | tokenized.append(matched) 282 | 283 | start = matched_end 284 | 285 | return tokenized 286 | 287 | 288 | def tokenize(self, text): 289 | """ 290 | """ 291 | if self.uncased == True: 292 | text = text.lower() 293 | 294 | i = 0 295 | tokenized = "" 296 | is_last_cjk = False 297 | is_last_num_or_alphabet = False 298 | while i < len(text): 299 | ch = text[i] 300 | if is_special(ch): 301 | if self.match_special_symbols == False: 302 | tokenized += (" " + ch + " ") 303 | i += 1 304 | else: 305 | matched = self.prefix_match(text[i:]) 306 | if len(matched) > 0: 307 | tokenized += (" " + matched + " ") 308 | i += len(matched) 309 | else: 310 | tokenized += (" " + ch + " ") 311 | i += 1 312 | is_last_cjk = False 313 | is_last_num_or_alphabet = False 314 | elif is_cjk(ch): 315 | if is_last_cjk == True: 316 | tokenized += (ch) 317 | else: 318 | tokenized += (" " + ch) 319 | is_last_cjk = True 320 | is_last_num_or_alphabet = False 321 | i += 1 322 | elif is_num(ch) or is_alphabet(ch): 323 | if is_last_num_or_alphabet == True: 324 | tokenized += ch 325 | else: 326 | tokenized += (" " + ch) 327 | is_last_cjk = False 328 | is_last_num_or_alphabet = True 329 | i += 1 330 | elif is_space(ch): 331 | 332 | if i == 0 or i == len(text) - 1: 333 | tokenized += (" " + self.space_token + " ") 334 | elif is_alphabet(text[i-1]) and is_alphabet(text[i+1]): 335 | tokenized += " " 336 | else: 337 | ignore = False 338 | if self.match_special_symbols == True: 339 | if re.search("_[0-9a-z]+_$", text[:i]): 340 | ignore = True 341 | if re.search(" _[0-9a-z]+_", text[i:]): 342 | ignore = True 343 | if ignore == True: 344 | tokenized += " " 345 | else: 346 | tokenized += (" " + self.space_token + " ") 347 | 348 | is_last_cjk = False 349 | is_last_num_or_alphabet = False 350 | i += 1 351 | elif is_newline(ch): 352 | tokenized += (" " + self.newline_token + " ") 353 | is_last_cjk = False 354 | is_last_num_or_alphabet = False 355 | i += 1 356 | elif is_useless(ch): 357 | is_last_cjk = False 358 | is_last_num_or_alphabet = False 359 | i += 1 360 | else: 361 | tokenized += (" " + ch + " ") 362 | is_last_cjk = False 363 | is_last_num_or_alphabet = False 364 | i += 1 365 | 366 | tokens = [] 367 | for token in tokenized.split(): 368 | if len(token) == 0: 369 | continue 370 | elif re.search("^_[0-9a-z]+_$", token): 371 | tokens.append(token) 372 | elif all(is_cjk(ch) for ch in token): 373 | tokens.extend(self.maximum_match(token)) 374 | else: 375 | tokens.extend(self.wordpiece(token)) 376 | 377 | return tokens 378 | 379 | 380 | def wordpiece(self, word): 381 | """ 382 | """ 383 | if word in self.vocab: 384 | return [word] 385 | 386 | tokens = [] 387 | start, stop = 0, 0 388 | while start < len(word): 389 | stop = len(word) 390 | while stop > start: 391 | sub = word[start:stop] 392 | if start > 0: 393 | sub = '##' + sub 394 | if sub in self.vocab: 395 | break 396 | stop -= 1 397 | if start == stop: 398 | stop += 1 399 | tokens.append(sub) 400 | start = stop 401 | 402 | return tokens 403 | 404 | 405 | def detokenize(self, tokens, convert_special_token=True): 406 | """ 407 | """ 408 | text = "" 409 | 410 | is_last_alphabet = False 411 | for token in 
tokens: 412 | if all(is_cjk(ch) for ch in token): 413 | text += token 414 | is_last_alphabet = False 415 | elif token.startswith("##"): 416 | text += token[2:] 417 | is_last_alphabet = True 418 | elif re.search("^_[0-9a-z]+_$", token): 419 | if convert_special_token == False: 420 | text = text + (" " + token + " ") 421 | else: 422 | if token == self.space_token: 423 | text += " " 424 | elif token == self.newline_token: 425 | text += "\n" 426 | else: 427 | if text.endswith(" ") == False: 428 | text += " " 429 | text += (token + " ") 430 | 431 | is_last_alphabet = False 432 | else: 433 | is_cur_alphabet = False 434 | if all(is_alphabet(ch) or is_num(ch) for ch in token): 435 | is_cur_alphabet = True 436 | if is_last_alphabet == True and is_cur_alphabet == True: 437 | text += (" " + token) 438 | else: 439 | text += token 440 | is_last_alphabet = is_cur_alphabet 441 | 442 | return text 443 | 444 | 445 | class BertTokenizer(Tokenizer): 446 | """ 447 | """ 448 | def __init__(self, vocab_file, uncased=True): 449 | """ 450 | """ 451 | super(BertTokenizer,self).__init__(vocab_file) 452 | from mimix.bert_tokenizer import FullTokenizer 453 | self.tokenizer = FullTokenizer(vocab_file, do_lower_case=uncased) 454 | 455 | 456 | def tokenize(self, text): 457 | """ 458 | """ 459 | return self.tokenizer.tokenize(text) 460 | 461 | 462 | def detokenize(self, tokens, convert_special_token=True): 463 | """ 464 | """ 465 | return " ".join(tokens) 466 | 467 | 468 | def convert_tokens_to_ids(self, tokenized): 469 | """ 470 | """ 471 | return self.tokenizer.convert_tokens_to_ids(tokenized) 472 | 473 | 474 | def convert_ids_to_tokens(self, word_ids): 475 | """ 476 | """ 477 | return self.tokenizer.convert_ids_to_tokens(word_ids) 478 | 479 | 480 | def build_tokenizer(**args): 481 | """ 482 | """ 483 | if args["tokenizer"] == "default": 484 | tokenizer = SpaceTokenizer(vocab_file=args["vocab_file"]) 485 | elif args["tokenizer"] == "mimix": 486 | tokenizer = MimixTokenizer(vocab_file=args["vocab_file"]) 487 | elif args["tokenizer"] == "mimix-cased": 488 | tokenizer = MimixTokenizer(vocab_file=args["vocab_file"], uncased=False) 489 | elif args["tokenizer"] == "bert": 490 | tokenizer = BertTokenizer(vocab_file=args["vocab_file"], uncased=True) 491 | elif args["tokenizer"] == "bert-cased": 492 | tokenizer = BertTokenizer(vocab_file=args["vocab_file"], uncased=False) 493 | else: 494 | raise ValueError("tokenizer not correct!") 495 | 496 | return tokenizer 497 | 498 | -------------------------------------------------------------------------------- /mimix/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Aug 30 15:16:38 2019 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | 8 | import os 9 | from datetime import datetime 10 | import logging 11 | import torch 12 | from mimix.utils import real_path 13 | 14 | LOG_DIR = "logger" 15 | 16 | if not os.path.exists(real_path(LOG_DIR)): 17 | os.mkdir(real_path(LOG_DIR)) 18 | 19 | 20 | def build_logger(): 21 | """ 22 | """ 23 | format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' 24 | filename = datetime.today().strftime('logger/%Y-%m-%d-%H-%M-%S.log') 25 | logging.basicConfig(filename=real_path(filename), 26 | level=logging.INFO, 27 | format=format_str) 28 | console = logging.StreamHandler() 29 | console.setLevel(logging.INFO) 30 | 31 | formatter = logging.Formatter(format_str, "%Y-%m-%d %H:%M:%S") 32 | console.setFormatter(formatter) 33 | logging.getLogger('').addHandler(console) 34 
|     logger = logging.getLogger(__name__)
35 | 
36 |     return logger
37 | 
38 | logger = build_logger()
39 | 
40 | 
41 | def save_model(model, optimizer, model_path):
42 |     """
43 |     """
44 |     logger.info("Save model to %s" % model_path)
45 | 
46 |     torch.save(model.state_dict(),
47 |                model_path,
48 |                _use_new_zipfile_serialization=False)
49 | 
50 |     torch.save(optimizer.state_dict(),
51 |                model_path + ".optimizer",
52 |                _use_new_zipfile_serialization=False)
53 | 
54 |     logger.info("Save model complete")
55 | 
56 | 
57 | def print_model_info(model):
58 |     """
59 |     """
60 |     logger.info("%s" % model)
61 |     total_params = sum(p.numel() for p in model.parameters())
62 |     logger.info("Total Model Params:%s" % total_params)
63 |     total_train_params = sum(p.numel() for p in model.parameters() if p.requires_grad==True)
64 |     logger.info("Trainable Model Params:%s" % total_train_params)
65 | 
66 | 
67 | def train(model,
68 |           optimizer,
69 |           train_config,
70 |           train_generator,
71 |           val_generator=None,
72 |           test_generator=None,
73 |           eval_fn_list=None,
74 |           lr_scheduler=None):
75 |     """
76 |     """
77 |     if os.path.exists(real_path(train_config["model_dir"])) == False:
78 |         os.mkdir(real_path(train_config["model_dir"]))
79 | 
80 |     use_amp = train_config.get("use_amp", False)
81 |     if use_amp:
82 |         scaler = torch.cuda.amp.GradScaler()
83 | 
84 |     print_model_info(model)
85 | 
86 |     logger.info("Train Start!")
87 | 
88 |     accumulate_steps = train_config.get("accumulate_steps", 1)
89 |     print_every_n_steps = train_config.get("print_every_n_steps", 100)
90 |     model_path = real_path(os.path.join(real_path(train_config["model_dir"]), "%s." + train_config["model_name"]))
91 |     save_steps = train_config.get("save_steps", 100000)
92 |     tmp_save_steps = train_config.get("tmp_save_steps", 10000)
93 |     grad_clip = train_config.get("grad_clip", None)
94 | 
95 |     history_loss = []
96 | 
97 |     epoch,steps,total_steps = 0, 0, 0
98 |     while epoch < train_config["max_epoch"]:
99 |         model.train()
100 | 
101 |         for inputs,targets in train_generator():
102 |             if use_amp == True:
103 |                 with torch.cuda.amp.autocast():
104 |                     outputs = model(inputs, targets=targets, compute_loss=True)
105 |                     loss = outputs["loss"]
106 |                     history_loss = history_loss[-999:] + [loss.item()]
107 |                     loss = loss / accumulate_steps
108 |             else:
109 |                 outputs = model(inputs, targets=targets, compute_loss=True)
110 |                 loss = outputs["loss"]
111 |                 history_loss = history_loss[-999:] + [loss.item()]
112 |                 loss = loss / accumulate_steps
113 | 
114 |             if total_steps % print_every_n_steps == 0:
115 |                 ma_loss = sum(history_loss) / len(history_loss)
116 |                 logger.info(
117 |                     "%d epoch %d step total %d steps loss: %.3f" %
118 |                     (epoch,
119 |                      steps,
120 |                      total_steps,
121 |                      ma_loss)
122 |                 )
123 | 
124 |             if use_amp == True:
125 |                 scaler.scale(loss).backward()
126 |             else:
127 |                 loss.backward()
128 | 
129 |             if lr_scheduler is not None:
130 |                 lr_scheduler.step()
131 | 
132 |             total_steps += 1
133 |             steps += 1
134 | 
135 |             if total_steps % save_steps == 0:
136 |                 save_model(model, optimizer, model_path % ("%d.%d.%d" % (epoch, steps, total_steps)))
137 | 
138 |             if total_steps % tmp_save_steps == 0:
139 |                 save_model(model, optimizer, model_path % "tmp")
140 | 
141 |             if total_steps % accumulate_steps == 0:
142 |                 if grad_clip is not None:
143 |                     if use_amp == True:
144 |                         # gradients must be unscaled before clipping under AMP
145 |                         scaler.unscale_(optimizer)
146 |                     torch.nn.utils.clip_grad_norm_(
147 |                         model.parameters(),
148 |                         grad_clip
149 |                     )
150 |                 if use_amp == True:
151 |                     scaler.step(optimizer)
152 |                     optimizer.zero_grad()
153 |                     scaler.update()
154 |                 else:
155 |                     optimizer.step()
156 |                     optimizer.zero_grad()
157 | 
158 |         epoch += 1
159 |         steps = 0
160 | 
161 |         if eval_fn_list:
162 |             if val_generator is not None:
163 |                 logger.info("Eval val now...")
164 |                 for eval_fn in eval_fn_list:
165 |                     eval_res = eval_fn(model, val_generator)
166 |                     logger.info("Result: %s" % eval_res)
167 |             if test_generator is not None:
168 |                 logger.info("Eval test now...")
169 |                 for eval_fn in eval_fn_list:
170 |                     eval_res = eval_fn(model, test_generator)
171 |                     logger.info("Result: %s" % eval_res)
172 |         save_model(model, optimizer, model_path % ("%d.%d.%d" % (epoch, steps, total_steps)))
173 |     logger.info("Train Completed!")
174 | 
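
# Illustrative wiring (a sketch only; "model" and "train_generator" are
# assumptions standing in for the builders used by the example_train_* scripts
# under examples/, and the keys mirror the train_conf files under conf/):
#
#     train_config = {"model_dir": "model", "model_name": "lm.mdl",
#                     "max_epoch": 1, "lr": 1e-4,
#                     "accumulate_steps": 4, "grad_clip": 1.0}
#     optimizer = build_optimizer(model, train_config)
#     scheduler = build_scheduler(train_config, optimizer)
#     train(model, optimizer, train_config, train_generator,
#           eval_fn_list=[], lr_scheduler=scheduler)
#
# With accumulate_steps = 4, each micro-batch loss is divided by 4 and four
# backward passes are summed before one optimizer.step(), so the effective
# batch size is 4x what the generator yields.
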
--------------------------------------------------------------------------------
/mimix/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Tue Aug 13 11:39:56 2019
4 | 
5 | @author: Xiaoyuan Yao
6 | """
7 | import sys
8 | import os
9 | import configparser
10 | import json
11 | import random
12 | import numpy as np
13 | from abc import ABC, abstractmethod
14 | 
15 | home_dir = os.path.abspath(os.getcwd())
16 | 
17 | 
18 | SYMBOLS = {"_pad_" : "_pad_",
19 |            "_bos_" : "_bos_",
20 |            "_eos_" : "_eos_",
21 |            "_unk_" : "_unk_",
22 |            "_sep_" : "_sep_",
23 |            "_cls_" : "_cls_",
24 |            "_mask_" : "_mask_"}
25 | 
26 | SYMBOL2ID = {"_pad_":0,
27 |              "_bos_":1,
28 |              "_eos_":2,
29 |              "_unk_":3,
30 |              "_sep_":4,
31 |              "_cls_":5,
32 |              "_mask_":6}
33 | 
34 | 
35 | def real_path(path, base_dir=None):
36 |     """
37 |     get real path
38 |     """
39 |     if path is None:
40 |         return None
41 |     if os.path.isabs(path) == True:
42 |         return path
43 |     if base_dir is None:
44 |         base_dir = home_dir
45 |     return os.path.join(base_dir, path)
46 | 
47 | 
48 | def load_config(config_file):
49 |     """
50 |     load config
51 |     """
52 |     config = configparser.RawConfigParser()
53 |     config.optionxform = str
54 | 
55 |     config_file = real_path(config_file)
56 |     if not os.path.exists(config_file):
57 |         print("config file %s not exist!"
% config_file) 58 | sys.exit(0) 59 | 60 | config.read(config_file, encoding='utf-8') 61 | 62 | loaded_config = {} 63 | 64 | for dtype in config.sections(): 65 | if dtype not in ["str", "int", "float", "bool"]: 66 | continue 67 | for k,v in config.items(dtype): 68 | if dtype == "str": 69 | loaded_config[k] = str(v) 70 | elif dtype == "int": 71 | loaded_config[k] = int(v) 72 | elif dtype == "float": 73 | loaded_config[k] = float(v) 74 | elif dtype == "bool": 75 | if v.lower() == "false": 76 | loaded_config[k] = False 77 | elif v.lower() == "true": 78 | loaded_config[k] = True 79 | return loaded_config 80 | 81 | 82 | def load_model_config(config_file): 83 | """ 84 | load config 85 | """ 86 | loaded_config = load_config(config_file) 87 | 88 | loaded_config["symbols"] = SYMBOLS 89 | loaded_config["symbol2id"] = SYMBOL2ID 90 | 91 | for symbol in SYMBOLS: 92 | if symbol + "2tok" in loaded_config: 93 | loaded_config["symbols"][symbol] = loaded_config[symbol + "2tok"] 94 | 95 | for symbol in SYMBOL2ID: 96 | if symbol + "2id" in loaded_config: 97 | loaded_config["symbol2id"][symbol] = loaded_config[symbol + "2id"] 98 | 99 | return loaded_config 100 | 101 | 102 | def load_vocab(vocab_path): 103 | """ 104 | """ 105 | vocab = {} 106 | for i,line in enumerate(open(real_path(vocab_path), "rb")): 107 | line = line.decode("utf-8").strip() 108 | if "\t" in line: 109 | word, word_id = line.split("\t") 110 | else: 111 | word, word_id = line, i 112 | vocab[word] = int(word_id) 113 | 114 | return vocab 115 | 116 | 117 | def invert_dict(dic): 118 | """ 119 | """ 120 | return {dic[k]:k for k in dic} 121 | 122 | 123 | def cut_and_pad_seq(seq, max_len, pad, left=False): 124 | """ 125 | """ 126 | if left == True: 127 | return [pad] * (max_len - len(seq)) + seq[:max_len] 128 | return seq[:max_len] + [pad] * (max_len - len(seq)) 129 | 130 | 131 | def cut_and_pad_seq_list(seq_list, max_len, pad, auto=False, pad_left=False): 132 | """ 133 | """ 134 | if auto == True: 135 | max_len = min(max(len(seq) for seq in seq_list), max_len) 136 | 137 | x = [] 138 | for seq in seq_list: 139 | x.append(cut_and_pad_seq(seq, max_len, pad, pad_left)) 140 | 141 | return x 142 | 143 | 144 | def derange(xs): 145 | for a in range(1, len(xs)): 146 | b = random.randint(0, a-1) 147 | xs[a], xs[b] = xs[b], xs[a] 148 | return xs 149 | 150 | 151 | def nested_to_device(nested_tensor, device): 152 | """ 153 | """ 154 | import torch 155 | res = nested_tensor 156 | if isinstance(nested_tensor, list) == True: 157 | res = [] 158 | for elem in nested_tensor: 159 | res.append(nested_to_device(elem, device)) 160 | elif isinstance(nested_tensor, tuple) == True: 161 | res = [] 162 | for elem in nested_tensor: 163 | res.append(nested_to_device(elem, device)) 164 | res = tuple(res) 165 | elif isinstance(nested_tensor, dict) == True: 166 | res = {} 167 | for k in nested_tensor: 168 | res[k] = nested_to_device(nested_tensor[k], device) 169 | elif isinstance(nested_tensor, torch.Tensor) == True: 170 | res = nested_tensor.to(device) 171 | return res 172 | 173 | 174 | def word_dropout(word_list, rate, replace_token): 175 | """ 176 | """ 177 | if rate > 0: 178 | tmp = [] 179 | 180 | for word in word_list: 181 | if random.random() < rate: 182 | tmp.append(replace_token) 183 | else: 184 | tmp.append(word) 185 | 186 | word_list = tmp 187 | 188 | return word_list 189 | 190 | 191 | class SimpleDataset(ABC): 192 | """ 193 | """ 194 | def __init__(self, device="cpu", rank=0, world_size=1): 195 | """ 196 | """ 197 | self.rank = rank 198 | self.world_size = world_size 199 
| self.device = device 200 | self.sort_key_fn = None 201 | 202 | 203 | @abstractmethod 204 | def vectorize(self, batch_data): 205 | """ 206 | """ 207 | pass 208 | 209 | 210 | def local_shuffle(self): 211 | """ 212 | """ 213 | for f in os.listdir(self.data_dir): 214 | lines = [line for line in open(os.path.join(self.data_dir, f), "r", encoding="utf-8")] 215 | random.shuffle(lines) 216 | if self.sort_key_fn is not None: 217 | lines = [[line, self.sort_key_fn(json.loads(line))] for line in lines] 218 | lines.sort(key=lambda x:x[1]) 219 | lines = [x[0] for x in lines] 220 | fo = open(os.path.join(self.data_dir, f), "w", encoding="utf-8") 221 | for line in lines: 222 | fo.write(line) 223 | fo.close() 224 | 225 | 226 | def __call__(self, start_steps=0): 227 | """ 228 | """ 229 | data = [] 230 | files = os.listdir(self.data_dir) 231 | files.sort() 232 | 233 | steps = 1 234 | for fi in files: 235 | fi = os.path.join(self.data_dir, fi) 236 | for line in open(fi, "r", encoding="utf-8", errors="ignore"): 237 | steps += 1 238 | if steps < start_steps * self.batch_size: 239 | continue 240 | if steps % self.world_size != self.rank: 241 | continue 242 | data.append(json.loads(line)) 243 | if len(data) % (20 * self.batch_size) == 0: 244 | batch_data = data[:self.batch_size] 245 | data = data[self.batch_size:] 246 | yield nested_to_device(self.vectorize(batch_data), self.device) 247 | 248 | while len(data) > 0: 249 | batch_data = data[:self.batch_size] 250 | yield nested_to_device(self.vectorize(batch_data), self.device) 251 | data = data[self.batch_size:] 252 | 253 | 254 | if __name__ == "__main__": 255 | pass 256 | 257 | 258 | 259 | -------------------------------------------------------------------------------- /mimix/vis.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Apr 7 13:37:34 2020 4 | 5 | @author: Xiaoyuan Yao 6 | """ 7 | import os 8 | import sys 9 | import torch 10 | import numpy as np 11 | from argparse import ArgumentParser 12 | import matplotlib 13 | matplotlib.rcParams['font.sans-serif'] = ['KaiTi'] 14 | import matplotlib.pyplot as plt 15 | import matplotlib.ticker as ticker 16 | from pylab import mpl 17 | from mimix.predictor import EncDecGenerator 18 | from mimix.utils import load_model_config,real_path 19 | 20 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 21 | 22 | def draw_heatmap(x, y, scores, fig_path): 23 | """ 24 | """ 25 | #mpl.rcconfig['font.sans-serif'] = ['STZhongsong'] 26 | #mpl.rcconfig['axes.unicode_minus'] = False 27 | 28 | scores = np.round(scores, 2) 29 | 30 | fig, ax = plt.subplots() 31 | im = ax.imshow(scores, cmap='hot_r') 32 | 33 | # We want to show all ticks... 34 | ax.set_xticks(np.arange(len(y))) 35 | ax.set_yticks(np.arange(len(x))) 36 | # ... and label them with the respective list entries 37 | ax.set_xticklabels(y) 38 | ax.set_yticklabels(x) 39 | 40 | # Rotate the tick labels and set their alignment. 41 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", 42 | rotation_mode="anchor") 43 | 44 | # Loop over data dimensions and create text annotations. 
--------------------------------------------------------------------------------
/mimix/vis.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 7 13:37:34 2020

@author: Xiaoyuan Yao
"""
import os
import sys
import torch
import numpy as np
from argparse import ArgumentParser
import matplotlib
matplotlib.rcParams['font.sans-serif'] = ['KaiTi']
import matplotlib.pyplot as plt
from mimix.predictor import EncDecGenerator
from mimix.utils import load_model_config, real_path

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

def draw_heatmap(x, y, scores, fig_path):
    """
    Render scores (len(x) rows by len(y) columns) as an annotated heatmap
    and save it to fig_path.
    """
    scores = np.round(scores, 2)

    fig, ax = plt.subplots()
    ax.imshow(scores, cmap='hot_r')

    # show all ticks and label them with the respective list entries
    ax.set_xticks(np.arange(len(y)))
    ax.set_yticks(np.arange(len(x)))
    ax.set_xticklabels(y)
    ax.set_yticklabels(x)

    # rotate the tick labels and set their alignment
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # loop over data dimensions and create text annotations
    for i in range(len(x)):
        for j in range(len(y)):
            ax.text(j, i, scores[i, j],
                    ha="center", va="center", color="w")

    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    # plt.show()  # uncomment to display interactively
    plt.close()


def analysis_search(enc_dec_gen, src, trg):
    """
    Force-decode (src, trg) and collect the decoder-encoder attention of
    every decoder layer, averaged over attention heads.
    """
    src_list, trg_list = [src], [trg]
    batch_size = 1

    x, y = enc_dec_gen.encode_inputs(src_list, trg_list, add_bos=True, add_eos=True)
    enc_dec_gen.model.eval()
    with torch.no_grad():
        outputs = enc_dec_gen.model({"x": x, "y": y[:, :-1]})

    dec_enc_attn_weights_list = outputs["enc_attn_weights_list"]

    attn_score_list = []

    for i in range(enc_dec_gen.model.n_dec_layers):
        attn_weights = dec_enc_attn_weights_list[i]
        # average over the head dimension
        attn_weights = attn_weights.mean(1).cpu().numpy()
        attn_score_list.append(attn_weights)

    res_list = []

    x, y = x.cpu().numpy(), y.cpu().numpy()
    for i in range(batch_size):
        src = [enc_dec_gen.src_id2word.get(w, "_unk_") for w in x[i]]
        trg = [enc_dec_gen.trg_id2word.get(w, "_unk_") for w in y[i][1:]]

        tmp = []
        for j in range(enc_dec_gen.model.n_dec_layers):
            tmp.append(attn_score_list[j][i][:len(trg), :len(src)].T)

        res = [src, trg, tmp]

        res_list.append(res)

    return res_list


def visualize_enc_dec(config):
    """
    Read "src<TAB>trg" pairs from stdin and dump one attention heatmap per
    decoder layer into the logger/ directory.
    """
    enc_dec_gen = EncDecGenerator(config)

    # make sure the output directory exists before writing figures
    os.makedirs("logger", exist_ok=True)

    print("INPUT TEXT:")
    for line in sys.stdin:
        line = line.strip()

        if len(line) == 0:
            continue

        src, trg = line.split("\t")[:2]

        res = analysis_search(enc_dec_gen, src, trg)
        src, trg, attn_score_list = res[0]

        for i in range(enc_dec_gen.model.n_dec_layers):
            draw_heatmap(src, trg, attn_score_list[i], "logger/dec_enc_%d.png" % i)

        print("generate heatmap done.")


def run_visualize():
    """
    Entry point: load the model config given by --model_conf and visualize.
    """
    parser = ArgumentParser()

    parser.add_argument("--model_conf", type=str)
    args = parser.parse_args(sys.argv[1:])

    conf_file = args.model_conf
    config = load_model_config(real_path(conf_file))

    if config["task"] == "enc_dec":
        visualize_enc_dec(config)


if __name__ == "__main__":
    run_visualize()
--------------------------------------------------------------------------------
/model/readme.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yaoxiaoyuan/mimix/3a93d23b18f433c4a871032a47459d9abcde1a94/model/readme.txt
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sun Aug 6 21:42:21 2023

@author: Xiaoyuan Yao
"""
from mimix.preprocess import run_preprocess

run_preprocess()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=1.13
streamlit
torchvision
mido
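Since draw_heatmap in mimix/vis.py only needs two label lists and a score
matrix, it can be exercised without a trained model. A minimal sketch; the
labels, scores and output path are made up:

import numpy as np
from mimix.vis import draw_heatmap

# 2 source tokens (rows) x 3 target tokens (columns) of fake attention mass
scores = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.3, 0.6]])
draw_heatmap(["src_a", "src_b"], ["trg_1", "trg_2", "trg_3"],
             scores, "demo_heatmap.png")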
--------------------------------------------------------------------------------
/solve_sudoku.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 29 20:54:40 2024

@author: 1
"""
import re
import sys
import numpy as np
from mimix.predictor import TextEncoder
from mimix.utils import real_path, load_model_config

def sudoku_demo():
    """
    Read 81-character puzzles (digits 0-9, 0 = blank) from stdin, fill the
    blanks with a masked-LM prediction, then check the solution.
    """
    conf_file = "conf/sudoku_bert_base_conf"
    config = load_model_config(real_path(conf_file))
    lm_gen = TextEncoder(config)

    print("INPUT PUZZLE:")

    for line in sys.stdin:
        line = line.strip()

        if len(line) != 81 or re.search("[^0-9]", line):
            print("invalid puzzle!")
            continue

        arr = np.zeros([9, 9], dtype=np.int64)
        for i, w in enumerate(line):
            arr[i // 9][i % 9] = int(w)

        print("puzzle:")
        print(arr)

        res = lm_gen.predict_mlm([" ".join(line)])
        for i in range(81):
            if arr[i // 9][i % 9] == 0:
                # res[0][1][i+1] holds the candidates for cell i (offset by one
                # for the leading special token); [0][0] is the top-ranked token
                arr[i // 9][i % 9] = int(res[0][1][i + 1][0][0])
        print("predict:")
        print(arr)

        # a valid solution has the digits 1..9 in every row, column and 3x3 box
        flag = True
        for i in range(9):
            flag = flag and all(j in arr[i, :] for j in range(1, 10))
        for i in range(9):
            flag = flag and all(j in arr[:, i] for j in range(1, 10))
        for i in range(3):
            for j in range(3):
                flag = flag and all(k in arr[3*i:3*i+3, 3*j:3*j+3] for k in range(1, 10))

        if flag:
            print("solve success!")
        else:
            print("solve failed!")


if __name__ == "__main__":
    sudoku_demo()
--------------------------------------------------------------------------------
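The validity test at the end of sudoku_demo packs three constraints into one
flag. The same check, rewritten as a standalone helper for clarity (a sketch,
independent of the model code):

import numpy as np

def is_valid_solution(arr):
    """True iff every row, column and 3x3 box of a 9x9 grid holds 1..9."""
    digits = set(range(1, 10))
    for i in range(9):
        if set(arr[i, :]) != digits or set(arr[:, i]) != digits:
            return False
    for i in range(0, 9, 3):
        for j in range(0, 9, 3):
            if set(arr[i:i+3, j:j+3].ravel()) != digits:
                return False
    return True

# e.g. is_valid_solution(np.array([[(3 * i + i // 3 + j) % 9 + 1
#                                   for j in range(9)] for i in range(9)]))
# returns True for this classic shifted-rows grid.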