├── .gitignore ├── README.md ├── configs ├── config.yaml ├── config_evaluate.yaml ├── datamodule │ ├── _base.yaml │ ├── ace04.yaml │ ├── ace05.yaml │ ├── ctb7.yaml │ ├── genia.yaml │ └── ptb.yaml ├── exp │ ├── ft_10.yaml │ └── ft_50.yaml ├── finetune │ └── base.yaml ├── logger │ ├── comet.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── neptune.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── model │ ├── _base.yaml │ └── pointer.yaml ├── optim │ ├── exponential.yaml │ └── finetune_bert.yaml └── trainer │ └── default_trainer.yaml ├── evaluate.py ├── fastNLP ├── __init__.py ├── core │ ├── __init__.py │ ├── _logger.py │ ├── _parallel_utils.py │ ├── batch.py │ ├── callback.py │ ├── collate_fn.py │ ├── const.py │ ├── dataset.py │ ├── dist_trainer.py │ ├── field.py │ ├── instance.py │ ├── losses.py │ ├── metrics.py │ ├── optimizer.py │ ├── predictor.py │ ├── sampler.py │ ├── tester.py │ ├── trainer.py │ ├── utils.py │ └── vocabulary.py ├── doc_utils.py ├── embeddings │ ├── __init__.py │ ├── bert_embedding.py │ ├── char_embedding.py │ ├── contextual_embedding.py │ ├── elmo_embedding.py │ ├── embedding.py │ ├── gpt2_embedding.py │ ├── roberta_embedding.py │ ├── stack_embedding.py │ ├── static_embedding.py │ └── utils.py ├── io │ ├── __init__.py │ ├── data_bundle.py │ ├── embed_loader.py │ ├── file_reader.py │ ├── file_utils.py │ ├── loader │ │ ├── __init__.py │ │ ├── classification.py │ │ ├── conll.py │ │ ├── coreference.py │ │ ├── csv.py │ │ ├── cws.py │ │ ├── json.py │ │ ├── loader.py │ │ ├── matching.py │ │ ├── qa.py │ │ └── summarization.py │ ├── model_io.py │ ├── pipe │ │ ├── __init__.py │ │ ├── classification.py │ │ ├── conll.py │ │ ├── coreference.py │ │ ├── cws.py │ │ ├── matching.py │ │ ├── pipe.py │ │ ├── qa.py │ │ ├── summarization.py │ │ └── utils.py │ └── utils.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── bert.py │ ├── biaffine_parser.py │ ├── cnn_text_classification.py │ ├── seq2seq_generator.py │ ├── seq2seq_model.py │ ├── sequence_labeling.py │ ├── snli.py │ └── star_transformer.py └── modules │ ├── __init__.py │ ├── attention.py │ ├── decoder │ ├── __init__.py │ ├── crf.py │ ├── mlp.py │ ├── seq2seq_decoder.py │ ├── seq2seq_state.py │ └── utils.py │ ├── dropout.py │ ├── encoder │ ├── __init__.py │ ├── _elmo.py │ ├── bert.py │ ├── char_encoder.py │ ├── conv_maxpool.py │ ├── gpt2.py │ ├── lstm.py │ ├── pooling.py │ ├── roberta.py │ ├── seq2seq_encoder.py │ ├── star_transformer.py │ ├── transformer.py │ └── variational_rnn.py │ ├── generator │ ├── __init__.py │ └── seq2seq_generator.py │ ├── tokenizer │ ├── __init__.py │ ├── bert_tokenizer.py │ ├── gpt2_tokenizer.py │ └── roberta_tokenizer.py │ └── utils.py ├── requirement.txt ├── src ├── callbacks │ ├── progressbar.py │ ├── transformer_scheduler.py │ └── wandb_callbacks.py ├── constant.py ├── datamodule │ ├── __init__.py │ ├── base.py │ ├── benepar │ │ ├── __init__.py │ │ ├── char_lstm.py │ │ ├── decode_chart.py │ │ ├── integrations │ │ │ ├── __init__.py │ │ │ ├── downloader.py │ │ │ ├── nltk_plugin.py │ │ │ ├── spacy_extensions.py │ │ │ └── spacy_plugin.py │ │ ├── nkutil.py │ │ ├── parse_base.py │ │ ├── parse_chart.py │ │ ├── partitioned_transformer.py │ │ ├── ptb_unescape.py │ │ ├── retokenization.py │ │ ├── spacy_plugin.py │ │ └── subbatching.py │ ├── const_data_pointer.py │ ├── dm_util │ │ ├── datamodule_util.py │ │ ├── fields.py │ │ └── padder.py │ ├── nested_ner.py │ └── trees.py ├── model │ ├── metric.py │ ├── module │ │ ├── ember │ │ │ ├── embedding.py │ │ │ └── ext_embedding.py │ │ ├── encoder │ │ │ ├── 
lstm_encoder.py │ │ │ └── self_attentive.py │ │ └── scorer │ │ │ ├── const_scorer.py │ │ │ ├── module │ │ │ ├── biaffine.py │ │ │ └── triaffine.py │ │ │ └── self_attentive.py │ ├── parsing.py │ └── pointer.py └── runner │ └── base.py ├── supar ├── __init__.py ├── modules │ ├── __init__.py │ ├── affine.py │ ├── bert.py │ ├── char_lstm.py │ ├── dropout.py │ ├── lstm.py │ ├── mlp.py │ ├── scalar_mix.py │ ├── transformer.py │ ├── treecrf.py │ └── variational_inference.py └── utils │ ├── __init__.py │ ├── alg.py │ ├── common.py │ ├── config.py │ ├── data.py │ ├── embedding.py │ ├── field.py │ ├── field_new.py │ ├── fn.py │ ├── logging.py │ ├── metric.py │ ├── parallel.py │ ├── transform.py │ └── vocab.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pointer-net-for-nested 2 | The official implementation of the ACL 2022 paper: [Bottom-Up Constituency Parsing and Nested Named Entity Recognition with Pointer Networks](https://arxiv.org/pdf/2110.05419.pdf) 3 | 4 | 5 | ## Setup 6 | Environment: 7 | ``` 8 | conda create -n parsing python=3.7 9 | conda activate parsing 10 | while read requirement; do pip install $requirement; done < requirement.txt 11 | ``` 12 | 13 | Download the preprocessed PTB, CTB7, and GENIA datasets from: [link](https://drive.google.com/drive/folders/1qFP2JbcltAJ-Jq3MpkS--0MGEIgyE6vQ?usp=sharing) 14 | 15 | For ACE04 and ACE05, please send me an e-mail. 16 | 17 | 18 | 19 | 20 | ## Run 21 | ``` 22 | python train.py +exp=ft_10 datamodule=a model=pointer 23 | a={ptb, ctb7} 24 | 25 | python train.py +exp=ft_10 datamodule=genia model=pointer model.use_prev_label=True 26 | 27 | python train.py +exp=ft_50 datamodule=b model=pointer model.use_prev_label=True 28 | b={ace04, ace05} 29 | ``` 30 | 31 | Multirun example: 32 | ``` 33 | python train.py +exp=base model=pointer datamodule=ptb,ctb7 seed=0,1,2 --multirun 34 | ``` 35 | 36 | Evaluation: 37 | ``` 38 | python evaluate.py +load_from_checkpoint=your/checkpoint/dir 39 | ``` 40 | 41 | 42 | ## Contact 43 | Please let me know if you find any bugs. Feel free to contact bestsonta@gmail.com if you have any questions. 44 | 45 | ## Citation 46 | ``` 47 | @misc{yang2021bottomup, 48 | title={Bottom-Up Constituency Parsing and Nested Named Entity Recognition with Pointer Networks}, 49 | author={Songlin Yang and Kewei Tu}, 50 | year={2021}, 51 | eprint={2110.05419}, 52 | archivePrefix={arXiv}, 53 | primaryClass={cs.CL} 54 | } 55 | ``` 56 | 57 | ## Credits 58 | The code is based on the [lightning+hydra](https://github.com/ashleve/lightning-hydra-template) template. I use [FastNLP](https://github.com/fastnlp/fastNLP) for loading data and many built-in modules (LSTMs, Biaffines, Triaffines, Dropout Layers, etc.) from [Supar](https://github.com/yzhangcs/parser/tree/main/supar).
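
## Outputs and checkpoints

A sketch of how training outputs and evaluation fit together, based on `configs/config.yaml` and `evaluate.py`; the exact run-directory name depends on the composed `datamodule.name`, `model.name`, timestamp, and seed, so the placeholders below must be filled in with your own run:

```
# Training writes into work_dir (see configs/config.yaml), e.g.
#   experiment/<datamodule.name>/<model.name>/<date>-<time>-seed-<seed>/
python train.py +exp=ft_10 datamodule=ptb model=pointer seed=0

# Evaluation expects that Hydra run directory: evaluate.py re-reads
# .hydra/overrides.yaml to rebuild the training config and loads
# checkpoints/checkpoint.ckpt from the same directory.
python evaluate.py +load_from_checkpoint=experiment/ptb/<model.name>/<run-dir>
```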
59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default training configuration 4 | defaults: 5 | - trainer: default_trainer 6 | - model: _base 7 | - datamodule: _base 8 | 9 | runner: 10 | _target_: src.runner.base.Runner 11 | 12 | work_dir: ${hydra:runtime.cwd}/experiment/${datamodule.name}/${model.name}/${now:%Y-%m-%d}-${now:%H-%M-%S}-seed-${seed} 13 | 14 | wandb: False 15 | checkpoint: False 16 | device: 0 17 | seed: 0 18 | accumulation: 1 19 | use_logger: True 20 | distributed: False 21 | debug: False 22 | 23 | 24 | # output paths for hydra logs 25 | root: "." 26 | suffix: "" 27 | 28 | hydra: 29 | run: 30 | dir: ${work_dir} 31 | sweep: 32 | dir: logs/multiruns/experiment/${datamodule.name}/${model.name}/${now:%Y-%m-%d}-${now:%H-%M-%S}-seed-${seed} 33 | subdir: ${hydra.job.num} 34 | job: 35 | env_set: 36 | WANDB_CONSOLE: 'off' 37 | 38 | -------------------------------------------------------------------------------- /configs/config_evaluate.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - config 3 | 4 | load_from_checkpoint: ~ 5 | 6 | hydra: 7 | run: 8 | dir: . 9 | output_subdir: null 10 | -------------------------------------------------------------------------------- /configs/datamodule/_base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | datamodule: 4 | use_char: False 5 | use_bert: True 6 | use_pos: False 7 | use_word: False 8 | use_emb: False 9 | ext_emb_path: "" 10 | bert: '' 11 | min_freq: 2 12 | fix_len: 20 13 | train_sampler_type: 'token' 14 | test_sampler_type: 'token' 15 | bucket: 32 16 | bucket_test: 32 17 | max_tokens: 5000 18 | max_tokens_test: 5000 19 | use_cache: True 20 | use_bert_cache: True 21 | max_len: 10000 22 | max_len_test: 10000 23 | root: '.' 24 | distributed: False 25 | # for PTB only. 
clean (-RHS-) 26 | clean_word: False 27 | 28 | -------------------------------------------------------------------------------- /configs/datamodule/ace04.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _base 4 | 5 | datamodule: 6 | target: 7 | _target_: src.datamodule.nested_ner.NestedNERData 8 | 9 | train: "${root}/data/ace2004/ace2004.train" 10 | dev: "${root}/data/ace2004/ace2004.dev" 11 | test: "${root}/data/ace2004/ace2004.test" 12 | cache: "${root}/data/ace2004/ace2004.pickle" 13 | cache_bert: "${root}/data/ace2004/ace2004.cache_bert" 14 | ext_emb_path: "${root}/data/ptb/glove.6B.100d.txt" 15 | clean_word: False 16 | bert: 'bert-large-cased' 17 | name: 'ace04' 18 | 19 | model: 20 | metric: 21 | target: 22 | _target_: src.model.metric.NERMetric 23 | write_result_to_file: True 24 | 25 | 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /configs/datamodule/ace05.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _base 4 | 5 | datamodule: 6 | target: 7 | _target_: src.datamodule.nested_ner.NestedNERData 8 | 9 | train: "${root}/data/ace2005/ace2005.train" 10 | dev: "${root}/data/ace2005/ace2005.dev" 11 | test: "${root}/data/ace2005/ace2005.test" 12 | cache: "${root}/data/ace2005/ace2005.pickle" 13 | cache_bert: "${root}/data/ace2005/ace2005.cache_bert" 14 | ext_emb_path: "${root}/data/ptb/glove.6B.100d.txt" 15 | clean_word: False 16 | bert: 'bert-large-cased' 17 | name: 'ace05_use_fine_grained' 18 | use_fine_grained: False 19 | 20 | model: 21 | metric: 22 | target: 23 | _target_: src.model.metric.NERMetric 24 | write_result_to_file: True 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /configs/datamodule/ctb7.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _base 4 | 5 | datamodule: 6 | target: 7 | _target_: src.datamodule.const_data_pointer.ConstData4Pointer 8 | train_const: "${root}/data/ctb7/train.pid" 9 | dev_const: "${root}/data/ctb7/dev.pid" 10 | test_const: "${root}/data/ctb7/test.pid" 11 | cache: "${root}/data/ctb7/ctb.const.pickle" 12 | cache_bert: "${root}/data/ctb7/ctb.const.cache_bert" 13 | ext_emb_path: "${root}/data/ctb/glove.6B.100d.txt" 14 | clean_word: False 15 | bert: 'bert-base-chinese' 16 | name: 'cbt7' 17 | 18 | model: 19 | metric: 20 | target: 21 | _target_: src.model.metric.SpanMetric 22 | write_result_to_file: True 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /configs/datamodule/genia.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _base 4 | 5 | datamodule: 6 | target: 7 | _target_: src.datamodule.nested_ner.NestedNERData 8 | 9 | train: "${root}/data/genia/genia.train" 10 | dev: "${root}/data/genia/genia.dev" 11 | test: "${root}/data/genia/genia.test" 12 | cache: "${root}/data/genia/genia.pickle" 13 | cache_bert: "${root}/data/genia/genia.cache_bert" 14 | clean_word: False 15 | bert: "dmis-lab/biobert-large-cased-v1.1" 16 | name: 'genia' 17 | ext_emb_path: "${root}/data/genia/PubMed-shuffle-win-30.txt" 18 | 19 | 20 | model: 21 | metric: 22 | target: 23 | _target_: src.model.metric.NERMetric 24 | write_result_to_file: True 25 | 
-------------------------------------------------------------------------------- /configs/datamodule/ptb.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _base 4 | 5 | datamodule: 6 | target: 7 | _target_: src.datamodule.const_data_pointer.ConstData4Pointer 8 | train_const: "${root}/data/ptb/02-21.10way.clean.txt" 9 | dev_const: "${root}/data/ptb/22.auto.clean.txt" 10 | test_const: "${root}/data/ptb/23.auto.clean.txt" 11 | cache: "${root}/data/ptb/ptb.const.pickle" 12 | bert: 'bert-large-cased' 13 | cache_bert: "${root}/data/ptb/ptb.const.cache_bert" 14 | ext_emb_path: "${root}/data/ptb/glove.6B.100d.txt" 15 | clean_word: False 16 | name: 'ptb' 17 | 18 | 19 | model: 20 | metric: 21 | target: 22 | _target_: src.model.metric.SpanMetric 23 | write_result_to_file: True 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /configs/exp/ft_10.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /optim: finetune_bert 4 | 5 | trainer: 6 | min_epochs: 1 7 | max_epochs: 10 8 | 9 | # 16*250=4000 10 | accumulation: 15 11 | 12 | datamodule: 13 | max_tokens: 200 14 | max_tokens_test: 1000 15 | max_len: 200 16 | 17 | # save checkpoints of the model. 18 | checkpoint: True 19 | 20 | model: 21 | embeder: 22 | finetune: True 23 | 24 | optim: 25 | only_embeder: True 26 | lr_rate: 50 27 | 28 | callbacks: 29 | transformer_scheduler: 30 | _target_: src.callbacks.transformer_scheduler.TransformerLrScheduler 31 | warmup: ${optim.warmup} 32 | 33 | pretty_progress_bar: 34 | _target_: src.callbacks.progressbar.PrettyProgressBar 35 | refresh_rate: 1 36 | process_position: 0 37 | -------------------------------------------------------------------------------- /configs/exp/ft_50.yaml: -------------------------------------------------------------------------------- 1 | 2 | # @package _global_ 3 | defaults: 4 | - /optim: finetune_bert 5 | 6 | 7 | trainer: 8 | min_epochs: 1 9 | max_epochs: 50 10 | 11 | 12 | accumulation: 15 13 | 14 | datamodule: 15 | max_tokens: 200 16 | max_tokens_test: 1000 17 | max_len: 200 18 | use_word: False 19 | use_emb: False 20 | 21 | # save checkpoints of the model. 
22 | checkpoint: True 23 | 24 | 25 | model: 26 | embeder: 27 | finetune: True 28 | 29 | 30 | optim: 31 | only_embeder: True 32 | lr_rate: 50 33 | 34 | 35 | callbacks: 36 | transformer_scheduler: 37 | _target_: src.callbacks.transformer_scheduler.TransformerLrScheduler 38 | warmup: ${optim.warmup} 39 | 40 | pretty_progress_bar: 41 | _target_: src.callbacks.progressbar.PrettyProgressBar 42 | refresh_rate: 1 43 | process_position: 0 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /configs/finetune/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /optim: finetune_bert 5 | 6 | trainer: 7 | min_epochs: 1 8 | max_epochs: 10 9 | 10 | callbacks: 11 | transformer_scheduler: 12 | _target_: src.callbacks.transformer_scheduler.TransformerLrScheduler 13 | warmup: ${optim.warmup} 14 | 15 | model: 16 | embeder: 17 | finetune: True 18 | 19 | optim: 20 | only_embeder: True -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: pytorch_lightning.loggers.comet.CometLogger 5 | api_key: ??? 6 | project_name: "project_template_test" 7 | experiment_name: null 8 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # CSVLogger built in PyTorch Lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "." 6 | name: "csv/" 7 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | - csv.yaml 5 | - wandb.yaml 6 | # - neptune.yaml 7 | # - comet.yaml 8 | # - tensorboard.yaml 9 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | project_name: "your_name/lightning-hydra-template-test" 6 | api_key: ${env:NEPTUNE_API_TOKEN} # api key is laoded from environment variable 7 | # experiment_name: "some_experiment" 8 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # TensorBoard 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "tensorboard/" 6 | name: "default" 7 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | logger: 2 | wandb: 3 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 4 | project: "supervised_parsing" 5 | entity: "sonta2020" 6 | job_type: "train" 7 | group: "" 8 | name: "" 9 | save_dir: "${work_dir}" 10 | -------------------------------------------------------------------------------- /configs/model/_base.yaml: -------------------------------------------------------------------------------- 1 | # 
@package _global_ 2 | 3 | 4 | model: 5 | target: 6 | _target_: src.model.parsing.Parser 7 | 8 | embeder: 9 | target: 10 | _target_: src.model.module.ember.embedding.Embeder 11 | 12 | #pos 13 | n_pos_embed: 100 14 | #char 15 | n_char_embed: 50 16 | n_char_out: 100 17 | char_input_dropout: 0. 18 | # bert 19 | n_bert_out: 1024 20 | n_bert_layers: 4 21 | mix_dropout: 0. 22 | use_projection: False 23 | use_scalarmix: False 24 | finetune: False 25 | #word 26 | n_embed: 300 27 | 28 | encoder: 29 | target: 30 | _target_: src.model.module.encoder.lstm_encoder.LSTMencoder 31 | embed_dropout: .33 32 | embed_dropout_type: shared 33 | lstm_dropout: .33 34 | n_lstm_hidden: 500 35 | n_lstm_layers: 3 36 | before_lstm_dropout: 0. 37 | 38 | scorer: 39 | target: 40 | _target_: src.model.module.scorer.const_scorer.ConstScorer 41 | n_mlp_span: 1000 42 | n_mlp_label: 100 43 | mlp_dropout: .33 44 | scaling: False 45 | use_span: False 46 | use_transition: False 47 | 48 | loss: 49 | target: 50 | _target_: src.model.module.loss.semicrf.SemiCRF 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/model/pointer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | datamodule: 4 | target: 5 | _target_: src.datamodule.const_data_pointer.ConstData4Pointer 6 | 7 | model: 8 | target: 9 | _target_: src.model.pointer.PointerNet 10 | 11 | decode_type: beam_search 12 | encoder_type: LSTM 13 | beam_size: 1 14 | 15 | use_hx: False 16 | use_focus: False 17 | use_prev_span: True 18 | use_prev_label: False 19 | use_remain_span: False 20 | use_action_mask: False 21 | 22 | n_lstm_hidden: 1000 23 | input_span_size: 500 24 | label_emb_size: 500 25 | biaffine_size: 500 26 | lstm_dropout: 0.33 27 | 28 | 29 | self_attentive_encoder: 30 | target: 31 | _target_: src.model.module.encoder.self_attentive.SelfAttentiveEncoder 32 | # d_model: 1024 33 | num_layers: 2 34 | num_heads: 8 35 | d_kv: 64 36 | d_ff: 2048 37 | morpho_emb_dropout: 0.2 38 | attention_dropout: 0.2 39 | relu_dropout: 0.1 40 | residual_dropout: 0.2 41 | 42 | lstm_encoder: 43 | embed_dropout: .33 44 | embed_dropout_type: shared 45 | lstm_dropout: .33 46 | n_lstm_hidden: 1000 47 | n_lstm_layers: 3 48 | before_lstm_dropout: 0. 49 | 50 | 51 | embeder: 52 | #pos 53 | n_pos_embed: 100 54 | #char 55 | n_char_embed: 50 56 | n_char_out: 100 57 | char_input_dropout: 0. 58 | # bert 59 | n_bert_out: 1024 60 | n_bert_layers: 4 61 | mix_dropout: 0. 
62 | use_projection: False 63 | use_scalarmix: False 64 | finetune: False 65 | #word 66 | n_embed: 300 67 | 68 | metric: 69 | write_result_to_file: False 70 | 71 | name: 'pointer_net_use_prev_span_${model.use_prev_span}_use_remain_span_${model.use_remain_span}_use_prev_label_${model.use_prev_label}_use_word_${datamodule.use_word}_use_emb_${datamodule.use_emb}' 72 | 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /configs/optim/exponential.yaml: -------------------------------------------------------------------------------- 1 | 2 | 3 | optimizer: 4 | # Adam-oriented deep learning 5 | _target_: torch.optim.Adam 6 | # These are all default parameters for the Adam optimizer 7 | lr: 2e-3 8 | betas: [ 0.9, 0.9] 9 | eps: 1e-12 10 | weight_decay: 0 11 | 12 | use_lr_scheduler: True 13 | lr_scheduler_name: exponential 14 | only_embeder: False 15 | 16 | lr_scheduler: 17 | _target_: torch.optim.lr_scheduler.ExponentialLR 18 | interval: 'step' 19 | frequency: 1 20 | gamma: 1 21 | 22 | -------------------------------------------------------------------------------- /configs/optim/finetune_bert.yaml: -------------------------------------------------------------------------------- 1 | warmup: 0.1 2 | #beta1: 0.9 3 | #beta2: 0.9 4 | #eps: 1e-12 5 | lr: 5e-5 6 | #weight_decay: 0 7 | lr_rate: 50 8 | scheduler_type: 'linear_warmup' 9 | 10 | 11 | -------------------------------------------------------------------------------- /configs/trainer/default_trainer.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | gpus: 1 # set -1 to train on all GPUs available, set 0 to train on CPU only 3 | min_epochs: 1 4 | max_epochs: 10 5 | gradient_clip_val: 5 6 | num_sanity_val_steps: 3 7 | progress_bar_refresh_rate: 20 8 | weights_summary: null 9 | fast_dev_run: False 10 | #default_root_dir: "lightning_logs/" 11 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import hydra 7 | import omegaconf 8 | import pytorch_lightning as pl 9 | from hydra.core.hydra_config import HydraConfig 10 | import wandb 11 | from pytorch_lightning import LightningModule, LightningDataModule, Callback, Trainer 12 | from pytorch_lightning.loggers import LightningLoggerBase 13 | from pytorch_lightning import seed_everything 14 | 15 | # hydra imports 16 | from omegaconf import DictConfig 17 | from hydra.utils import log 18 | import hydra 19 | from omegaconf import OmegaConf 20 | 21 | # normal imports 22 | from typing import List 23 | import warnings 24 | import logging 25 | from pytorch_lightning.loggers import WandbLogger 26 | from pytorch_lightning.callbacks import ModelCheckpoint 27 | import sys 28 | # sys.setdefaultencoding() does not exist, here! 29 | 30 | from hydra.experimental import compose 31 | 32 | 33 | 34 | def evaluate(config): 35 | 36 | os.chdir(config.load_from_checkpoint) 37 | original_overrides = OmegaConf.load(os.path.join(config.load_from_checkpoint, ".hydra/overrides.yaml")) 38 | current_overrides = HydraConfig.get().overrides.task 39 | hydra_config = OmegaConf.load(os.path.join(config.load_from_checkpoint, ".hydra/hydra.yaml")) 40 | # getting the config name from the previous job. 
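# The checkpoint dir is a Hydra run directory: the training config is rebuilt by
# replaying the overrides saved in .hydra/overrides.yaml under the original config
# name, and the overrides passed to evaluate.py are appended afterwards, so
# evaluation-time command-line overrides take precedence over the saved ones.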
41 | config_name = hydra_config.hydra.job.config_name 42 | # concatenating the original overrides with the current overrides 43 | overrides = original_overrides + current_overrides 44 | # compose a new config from scratch 45 | config = compose(config_name, overrides=overrides) 46 | print(config) 47 | checkpoint = os.path.join(config.load_from_checkpoint, "checkpoints/checkpoint.ckpt") 48 | config.model.metric.write_result_to_file=True 49 | # config.model.metric.target._target_= 'src.model.metric.AttachmentSpanMetric' 50 | config.root = hydra.utils.get_original_cwd() 51 | 52 | hydra_dir = str(os.getcwd()) 53 | 54 | os.chdir(hydra.utils.get_original_cwd()) 55 | 56 | datamodule = hydra.utils.instantiate(config.datamodule.target, config.datamodule, _recursive_=False) 57 | 58 | 59 | log.info("created datamodule") 60 | datamodule.setup() 61 | # Instantiate model, fuck hydra 1.1 62 | config.runner._target_ += '.load_from_checkpoint' 63 | 64 | model = hydra.utils.instantiate(config.runner, cfg = config, fields=datamodule.fields, datamodule=datamodule, checkpoint_path=checkpoint, _recursive_=False) 65 | os.chdir(hydra_dir) 66 | trainer = hydra.utils.instantiate(config.trainer, logger=False,replace_sampler_ddp=False, checkpoint_callback=False) 67 | trainer.test(model, datamodule=datamodule) 68 | 69 | 70 | @hydra.main(config_path="configs/", config_name="config_evaluate.yaml") 71 | def main(config): 72 | evaluate(config) 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /fastNLP/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | fastNLP 由 :mod:`~fastNLP.core` 、 :mod:`~fastNLP.io` 、:mod:`~fastNLP.embeddings` 、 :mod:`~fastNLP.modules`、 3 | :mod:`~fastNLP.models` 等子模块组成,你可以查看每个模块的文档。 4 | 5 | - :mod:`~fastNLP.core` 是fastNLP 的核心模块,包括 DataSet、 Trainer、 Tester 等组件。详见文档 :mod:`fastNLP.core` 6 | - :mod:`~fastNLP.io` 是实现输入输出的模块,包括了数据集的读取,模型的存取等功能。详见文档 :mod:`fastNLP.io` 7 | - :mod:`~fastNLP.embeddings` 提供用于构建复杂网络模型所需的各种embedding。详见文档 :mod:`fastNLP.embeddings` 8 | - :mod:`~fastNLP.modules` 包含了用于搭建神经网络模型的诸多组件,可以帮助用户快速搭建自己所需的网络。详见文档 :mod:`fastNLP.modules` 9 | - :mod:`~fastNLP.models` 包含了一些使用 fastNLP 实现的完整网络模型,包括 :class:`~fastNLP.models.CNNText` 、 :class:`~fastNLP.models.SeqLabeling` 等常见模型。详见文档 :mod:`fastNLP.models` 10 | 11 | fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的文档如下: 12 | """ 13 | __all__ = [ 14 | "Instance", 15 | "FieldArray", 16 | 17 | "DataSetIter", 18 | "BatchIter", 19 | "TorchLoaderIter", 20 | 21 | "Vocabulary", 22 | "DataSet", 23 | "Const", 24 | 25 | "Trainer", 26 | "Tester", 27 | 28 | "DistTrainer", 29 | "get_local_rank", 30 | 31 | "Callback", 32 | "GradientClipCallback", 33 | "EarlyStopCallback", 34 | "FitlogCallback", 35 | "EvaluateCallback", 36 | "LRScheduler", 37 | "ControlC", 38 | "LRFinder", 39 | "TensorboardCallback", 40 | "WarmupCallback", 41 | 'SaveModelCallback', 42 | "CallbackException", 43 | "EarlyStopError", 44 | "CheckPointCallback", 45 | 46 | "Padder", 47 | "AutoPadder", 48 | "EngChar2DPadder", 49 | 50 | # "CollateFn", 51 | "ConcatCollateFn", 52 | 53 | "MetricBase", 54 | "AccuracyMetric", 55 | "SpanFPreRecMetric", 56 | "CMRC2018Metric", 57 | "ClassifyFPreRecMetric", 58 | "ConfusionMatrixMetric", 59 | 60 | "Optimizer", 61 | "SGD", 62 | "Adam", 63 | "AdamW", 64 | 65 | "Sampler", 66 | "SequentialSampler", 67 | "BucketSampler", 68 | "RandomSampler", 69 | "SortedSampler", 70 | "ConstantTokenNumSampler", 71 | 72 | "LossFunc", 73 | "CrossEntropyLoss", 74 | 
"MSELoss", 75 | "L1Loss", 76 | "BCELoss", 77 | "NLLLoss", 78 | "LossInForward", 79 | "LossBase", 80 | "CMRC2018Loss", 81 | 82 | "cache_results", 83 | 84 | 'logger', 85 | "init_logger_dist", 86 | ] 87 | __version__ = '0.5.6' 88 | 89 | import sys 90 | 91 | from . import embeddings 92 | from . import models 93 | from . import modules 94 | from .core import * 95 | from .doc_utils import doc_process 96 | from .io import loader, pipe 97 | 98 | doc_process(sys.modules[__name__]) 99 | -------------------------------------------------------------------------------- /fastNLP/core/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fastNLP 包中直接 import。当然你也同样可以从 core 模块的子模块中 import, 3 | 例如 :class:`~fastNLP.DataSetIter` 组件有两种 import 的方式:: 4 | 5 | # 直接从 fastNLP 中 import 6 | from fastNLP import DataSetIter 7 | 8 | # 从 core 模块的子模块 batch 中 import DataSetIter 9 | from fastNLP.core.batch import DataSetIter 10 | 11 | 对于常用的功能,你只需要在 :mod:`fastNLP` 中查看即可。如果想了解各个子模块的具体作用,您可以在下面找到每个子模块的具体文档。 12 | 13 | """ 14 | __all__ = [ 15 | "DataSet", 16 | 17 | "Instance", 18 | 19 | "FieldArray", 20 | "Padder", 21 | "AutoPadder", 22 | "EngChar2DPadder", 23 | 24 | "ConcatCollateFn", 25 | 26 | "Vocabulary", 27 | 28 | "DataSetIter", 29 | "BatchIter", 30 | "TorchLoaderIter", 31 | 32 | "Const", 33 | 34 | "Tester", 35 | "Trainer", 36 | 37 | "DistTrainer", 38 | "get_local_rank", 39 | 40 | "cache_results", 41 | "seq_len_to_mask", 42 | "get_seq_len", 43 | "logger", 44 | "init_logger_dist", 45 | 46 | "Callback", 47 | "GradientClipCallback", 48 | "EarlyStopCallback", 49 | "FitlogCallback", 50 | "EvaluateCallback", 51 | "LRScheduler", 52 | "ControlC", 53 | "LRFinder", 54 | "TensorboardCallback", 55 | "WarmupCallback", 56 | 'SaveModelCallback', 57 | "CallbackException", 58 | "EarlyStopError", 59 | "CheckPointCallback", 60 | 61 | "LossFunc", 62 | "CrossEntropyLoss", 63 | "L1Loss", 64 | "BCELoss", 65 | "NLLLoss", 66 | "LossInForward", 67 | "CMRC2018Loss", 68 | "MSELoss", 69 | "LossBase", 70 | 71 | "MetricBase", 72 | "AccuracyMetric", 73 | "SpanFPreRecMetric", 74 | "CMRC2018Metric", 75 | "ClassifyFPreRecMetric", 76 | "ConfusionMatrixMetric", 77 | 78 | "Optimizer", 79 | "SGD", 80 | "Adam", 81 | "AdamW", 82 | 83 | "SequentialSampler", 84 | "BucketSampler", 85 | "RandomSampler", 86 | "Sampler", 87 | "SortedSampler", 88 | "ConstantTokenNumSampler" 89 | ] 90 | 91 | from ._logger import logger, init_logger_dist 92 | from .batch import DataSetIter, BatchIter, TorchLoaderIter 93 | from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ 94 | LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \ 95 | EarlyStopError, CheckPointCallback 96 | from .const import Const 97 | from .dataset import DataSet 98 | from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder 99 | from .instance import Instance 100 | from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \ 101 | LossInForward, CMRC2018Loss, LossBase, MSELoss 102 | from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\ 103 | ConfusionMatrixMetric 104 | from .optimizer import Optimizer, SGD, Adam, AdamW 105 | from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler, ConstantTokenNumSampler 106 | from .tester import Tester 107 | from .trainer import Trainer 108 | from .utils import 
cache_results, seq_len_to_mask, get_seq_len 109 | from .vocabulary import Vocabulary 110 | from .collate_fn import ConcatCollateFn 111 | from .dist_trainer import DistTrainer, get_local_rank 112 | -------------------------------------------------------------------------------- /fastNLP/core/_logger.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Logger 是fastNLP中记录日志的模块,logger封装了logging模块的Logger, 3 | 具体使用方式与直接使用logging.Logger相同,同时也新增一些简单好用的API 4 | 使用方式: 5 | from fastNLP import logger 6 | # 7 | # logger 可以和 logging.Logger 一样使用 8 | logger.info('your msg') 9 | logger.error('your msg') 10 | 11 | # logger 新增的API 12 | # 将日志输出到文件,以及输出的日志等级 13 | logger.add_file('/path/to/log', level='INFO') 14 | # 定义在命令行中的显示格式和日志等级 15 | logger.set_stdout('tqdm', level='WARN') 16 | 17 | """ 18 | 19 | __all__ = [ 20 | 'logger', 21 | 'init_logger_dist' 22 | ] 23 | 24 | import logging 25 | import logging.config 26 | import os 27 | import sys 28 | import warnings 29 | from torch import distributed as dist 30 | 31 | ROOT_NAME = 'fastNLP' 32 | 33 | try: 34 | import fitlog 35 | except ImportError: 36 | fitlog = None 37 | try: 38 | from tqdm.auto import tqdm 39 | except ImportError: 40 | tqdm = None 41 | 42 | if tqdm is not None: 43 | class TqdmLoggingHandler(logging.Handler): 44 | def __init__(self, level=logging.INFO): 45 | super().__init__(level) 46 | 47 | def emit(self, record): 48 | try: 49 | msg = self.format(record) 50 | tqdm.write(msg) 51 | self.flush() 52 | except (KeyboardInterrupt, SystemExit): 53 | raise 54 | except: 55 | self.handleError(record) 56 | else: 57 | class TqdmLoggingHandler(logging.StreamHandler): 58 | def __init__(self, level=logging.INFO): 59 | super().__init__(sys.stdout) 60 | self.setLevel(level) 61 | 62 | 63 | def _get_level(level): 64 | if isinstance(level, int): 65 | pass 66 | else: 67 | level = level.lower() 68 | level = {'info': logging.INFO, 'debug': logging.DEBUG, 69 | 'warn': logging.WARN, 'warning': logging.WARN, 70 | 'error': logging.ERROR}[level] 71 | return level 72 | 73 | 74 | def _add_file_handler(logger, path, level='INFO'): 75 | for h in logger.handlers: 76 | if isinstance(h, logging.FileHandler): 77 | if os.path.abspath(path) == h.baseFilename: 78 | # file path already added 79 | return 80 | 81 | # File Handler 82 | if os.path.exists(path): 83 | assert os.path.isfile(path) 84 | warnings.warn('log already exists in {}'.format(path)) 85 | dirname = os.path.abspath(os.path.dirname(path)) 86 | os.makedirs(dirname, exist_ok=True) 87 | 88 | file_handler = logging.FileHandler(path, mode='a') 89 | file_handler.setLevel(_get_level(level)) 90 | file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s', 91 | datefmt='%Y/%m/%d %H:%M:%S') 92 | file_handler.setFormatter(file_formatter) 93 | logger.addHandler(file_handler) 94 | 95 | 96 | def _set_stdout_handler(logger, stdout='tqdm', level='INFO'): 97 | level = _get_level(level) 98 | if stdout not in ['none', 'plain', 'tqdm']: 99 | raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm'])) 100 | # make sure to initialize logger only once 101 | stream_handler = None 102 | for i, h in enumerate(logger.handlers): 103 | if isinstance(h, (logging.StreamHandler, TqdmLoggingHandler)): 104 | stream_handler = h 105 | break 106 | if stream_handler is not None: 107 | logger.removeHandler(stream_handler) 108 | 109 | # Stream Handler 110 | if stdout == 'plain': 111 | stream_handler = logging.StreamHandler(sys.stdout) 112 | elif stdout == 'tqdm': 
113 | stream_handler = TqdmLoggingHandler(level) 114 | else: 115 | stream_handler = None 116 | 117 | if stream_handler is not None: 118 | stream_formatter = logging.Formatter('%(message)s') 119 | stream_handler.setLevel(level) 120 | stream_handler.setFormatter(stream_formatter) 121 | logger.addHandler(stream_handler) 122 | 123 | 124 | class FastNLPLogger(logging.getLoggerClass()): 125 | def __init__(self, name): 126 | super().__init__(name) 127 | 128 | def add_file(self, path='./log.txt', level='INFO'): 129 | r"""add log output file and the output level""" 130 | _add_file_handler(self, path, level) 131 | 132 | def set_stdout(self, stdout='tqdm', level='INFO'): 133 | r"""set stdout format and the output level""" 134 | _set_stdout_handler(self, stdout, level) 135 | 136 | 137 | logging.setLoggerClass(FastNLPLogger) 138 | 139 | 140 | # print(logging.getLoggerClass()) 141 | # print(logging.getLogger()) 142 | 143 | def _init_logger(path=None, stdout='tqdm', level='INFO'): 144 | r"""initialize logger""" 145 | level = _get_level(level) 146 | 147 | # logger = logging.getLogger() 148 | logger = logging.getLogger(ROOT_NAME) 149 | logger.propagate = False 150 | logger.setLevel(1) # make the logger the lowest level 151 | 152 | _set_stdout_handler(logger, stdout, level) 153 | 154 | # File Handler 155 | if path is not None: 156 | _add_file_handler(logger, path, level) 157 | 158 | return logger 159 | 160 | 161 | def _get_logger(name=None, level='INFO'): 162 | level = _get_level(level) 163 | if name is None: 164 | name = ROOT_NAME 165 | assert isinstance(name, str) 166 | if not name.startswith(ROOT_NAME): 167 | name = '{}.{}'.format(ROOT_NAME, name) 168 | logger = logging.getLogger(name) 169 | logger.setLevel(level) 170 | return logger 171 | 172 | 173 | logger = _init_logger(path=None, level='INFO') 174 | 175 | 176 | def init_logger_dist(): 177 | global logger 178 | rank = dist.get_rank() 179 | logger.setLevel(logging.INFO if rank == 0 else logging.WARNING) 180 | -------------------------------------------------------------------------------- /fastNLP/core/_parallel_utils.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [] 4 | 5 | import threading 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn.parallel.parallel_apply import get_a_var 10 | from torch.nn.parallel.replicate import replicate 11 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather 12 | 13 | 14 | def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None): 15 | r"""Applies each `module` in :attr:`modules` in parallel on arguments 16 | contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) 17 | on each of :attr:`devices`. 18 | 19 | :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and 20 | :attr:`devices` (if given) should all have same length. Moreover, each 21 | element of :attr:`inputs` can either be a single object as the only argument 22 | to a module, or a collection of positional arguments. 
23 | """ 24 | assert len(modules) == len(inputs) 25 | if kwargs_tup is not None: 26 | assert len(modules) == len(kwargs_tup) 27 | else: 28 | kwargs_tup = ({},) * len(modules) 29 | if devices is not None: 30 | assert len(modules) == len(devices) 31 | else: 32 | devices = [None] * len(modules) 33 | 34 | lock = threading.Lock() 35 | results = {} 36 | grad_enabled = torch.is_grad_enabled() 37 | 38 | def _worker(i, module, input, kwargs, device=None): 39 | torch.set_grad_enabled(grad_enabled) 40 | if device is None: 41 | device = get_a_var(input).get_device() 42 | try: 43 | with torch.cuda.device(device): 44 | # this also avoids accidental slicing of `input` if it is a Tensor 45 | if not isinstance(input, (list, tuple)): 46 | input = (input,) 47 | output = getattr(module, func_name)(*input, **kwargs) 48 | with lock: 49 | results[i] = output 50 | except Exception as e: 51 | with lock: 52 | results[i] = e 53 | 54 | if len(modules) > 1: 55 | threads = [threading.Thread(target=_worker, 56 | args=(i, module, input, kwargs, device)) 57 | for i, (module, input, kwargs, device) in 58 | enumerate(zip(modules, inputs, kwargs_tup, devices))] 59 | 60 | for thread in threads: 61 | thread.start() 62 | for thread in threads: 63 | thread.join() 64 | else: 65 | _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) 66 | 67 | outputs = [] 68 | for i in range(len(inputs)): 69 | output = results[i] 70 | if isinstance(output, Exception): 71 | raise output 72 | outputs.append(output) 73 | return outputs 74 | 75 | 76 | def _data_parallel_wrapper(func_name, device_ids, output_device): 77 | r""" 78 | 这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数 79 | 80 | :param str, func_name: 对network中的这个函数进行多卡运行 81 | :param device_ids: nn.DataParallel中的device_ids 82 | :param output_device: nn.DataParallel中的output_device 83 | :return: 84 | """ 85 | 86 | def wrapper(network, *inputs, **kwargs): 87 | inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0) 88 | if len(device_ids) == 1: 89 | return getattr(network, func_name)(*inputs[0], **kwargs[0]) 90 | replicas = replicate(network, device_ids[:len(inputs)]) 91 | outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)]) 92 | return gather(outputs, output_device) 93 | 94 | return wrapper 95 | 96 | 97 | def _model_contains_inner_module(model): 98 | r""" 99 | 100 | :param nn.Module model: 模型文件,判断是否内部包含model.module, 多用于check模型是否是nn.DataParallel, 101 | nn.parallel.DistributedDataParallel。主要是在做形参匹配的时候需要使用最内部的model的function。 102 | :return: bool 103 | """ 104 | if isinstance(model, nn.Module): 105 | if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)): 106 | return True 107 | return False 108 | -------------------------------------------------------------------------------- /fastNLP/core/collate_fn.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | from builtins import sorted 3 | 4 | import torch 5 | import numpy as np 6 | from .field import _get_ele_type_and_dim 7 | from .utils import logger 8 | from copy import deepcopy 9 | 10 | 11 | def _check_type(batch_dict, fields): 12 | if len(fields) == 0: 13 | raise RuntimeError 14 | types = [] 15 | dims = [] 16 | for f in fields: 17 | t, d = _get_ele_type_and_dim(batch_dict[f]) 18 | types.append(t) 19 | dims.append(d) 20 | diff_types = set(types) 21 | diff_dims = set(dims) 22 | if len(diff_types) > 1 or len(diff_dims) > 1: 23 | raise ValueError 24 | return types[0] 25 | 26 | 27 | def batching(samples, 
max_len=0, padding_val=0): 28 | if len(samples) == 0: 29 | return samples 30 | if max_len <= 0: 31 | max_len = max(s.shape[0] for s in samples) 32 | batch = np.full((len(samples), max_len), fill_value=padding_val) 33 | for i, s in enumerate(samples): 34 | slen = min(s.shape[0], max_len) 35 | batch[i][:slen] = s[:slen] 36 | return batch 37 | 38 | 39 | class Collater: 40 | r""" 41 | 辅助DataSet管理collate_fn的类 42 | 43 | """ 44 | def __init__(self): 45 | self.collate_fns = {} 46 | 47 | def add_fn(self, fn, name=None): 48 | r""" 49 | 向collater新增一个collate_fn函数 50 | 51 | :param callable fn: 52 | :param str,int name: 53 | :return: 54 | """ 55 | if name in self.collate_fns: 56 | logger.warn(f"collate_fn:{name} will be overwritten.") 57 | if name is None: 58 | name = len(self.collate_fns) 59 | self.collate_fns[name] = fn 60 | 61 | def is_empty(self): 62 | r""" 63 | 返回是否包含collate_fn 64 | 65 | :return: 66 | """ 67 | return len(self.collate_fns) == 0 68 | 69 | def delete_fn(self, name=None): 70 | r""" 71 | 删除collate_fn 72 | 73 | :param str,int name: 如果为None就删除最近加入的collate_fn 74 | :return: 75 | """ 76 | if not self.is_empty(): 77 | if name in self.collate_fns: 78 | self.collate_fns.pop(name) 79 | elif name is None: 80 | last_key = list(self.collate_fns.keys())[0] 81 | self.collate_fns.pop(last_key) 82 | 83 | def collate_batch(self, ins_list): 84 | bx, by = {}, {} 85 | for name, fn in self.collate_fns.items(): 86 | try: 87 | batch_x, batch_y = fn(ins_list) 88 | except BaseException as e: 89 | logger.error(f"Exception:`{e}` happens when call collate_fn:`{name}`.") 90 | raise e 91 | bx.update(batch_x) 92 | by.update(batch_y) 93 | return bx, by 94 | 95 | def copy_from(self, col): 96 | assert isinstance(col, Collater) 97 | new_col = Collater() 98 | new_col.collate_fns = deepcopy(col.collate_fns) 99 | return new_col 100 | 101 | 102 | class ConcatCollateFn: 103 | r""" 104 | field拼接collate_fn,将不同field按序拼接后,padding产生数据。 105 | 106 | :param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field 107 | :param str output: 拼接后的field名称 108 | :param pad_val: padding的数值 109 | :param max_len: 拼接后最大长度 110 | :param is_input: 是否将生成的output设置为input 111 | :param is_target: 是否将生成的output设置为target 112 | """ 113 | 114 | def __init__(self, inputs, output, pad_val=0, max_len=0, is_input=True, is_target=False): 115 | super().__init__() 116 | assert isinstance(inputs, list) 117 | self.inputs = inputs 118 | self.output = output 119 | self.pad_val = pad_val 120 | self.max_len = max_len 121 | self.is_input = is_input 122 | self.is_target = is_target 123 | 124 | @staticmethod 125 | def _to_numpy(seq): 126 | if torch.is_tensor(seq): 127 | return seq.numpy() 128 | else: 129 | return np.array(seq) 130 | 131 | def __call__(self, ins_list): 132 | samples = [] 133 | for i, ins in ins_list: 134 | sample = [] 135 | for input_name in self.inputs: 136 | sample.append(self._to_numpy(ins[input_name])) 137 | samples.append(np.concatenate(sample, axis=0)) 138 | batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val) 139 | b_x, b_y = {}, {} 140 | if self.is_input: 141 | b_x[self.output] = batch 142 | if self.is_target: 143 | b_y[self.output] = batch 144 | 145 | return b_x, b_y 146 | -------------------------------------------------------------------------------- /fastNLP/core/const.py: -------------------------------------------------------------------------------- 1 | r""" 2 | fastNLP包当中的field命名均符合一定的规范,该规范由fastNLP.Const类进行定义。 3 | """ 4 | 5 | __all__ = [ 6 | "Const" 7 | ] 8 | 9 | 10 | class Const: 11 | r""" 12 | fastNLP中field命名常量。 13 | 14 | .. 
todo:: 15 | 把下面这段改成表格 16 | 17 | 具体列表:: 18 | 19 | INPUT 模型的序列输入 words(具有多列words时,依次使用words1, words2, ) 20 | CHAR_INPUT 模型character输入 chars(具有多列chars时,依次使用chars1, chars2) 21 | INPUT_LEN 序列长度 seq_len(具有多列seq_len时,依次使用seq_len1,seq_len2) 22 | OUTPUT 模型输出 pred(具有多列pred时,依次使用pred1, pred2) 23 | TARGET 真实目标 target(具有多列target时,依次使用target1,target2) 24 | LOSS 损失函数 loss (具有多列loss时,依次使用loss1,loss2) 25 | RAW_WORD 原文的词 raw_words (具有多列raw_words时,依次使用raw_words1, raw_words2) 26 | RAW_CHAR 原文的字 raw_chars (具有多列raw_chars时,依次使用raw_chars1, raw_chars2) 27 | 28 | """ 29 | INPUT = 'words' 30 | CHAR_INPUT = 'chars' 31 | INPUT_LEN = 'seq_len' 32 | OUTPUT = 'pred' 33 | TARGET = 'target' 34 | LOSS = 'loss' 35 | RAW_WORD = 'raw_words' 36 | RAW_CHAR = 'raw_chars' 37 | 38 | @staticmethod 39 | def INPUTS(i): 40 | r"""得到第 i 个 ``INPUT`` 的命名""" 41 | i = int(i) + 1 42 | return Const.INPUT + str(i) 43 | 44 | @staticmethod 45 | def CHAR_INPUTS(i): 46 | r"""得到第 i 个 ``CHAR_INPUT`` 的命名""" 47 | i = int(i) + 1 48 | return Const.CHAR_INPUT + str(i) 49 | 50 | @staticmethod 51 | def RAW_WORDS(i): 52 | r"""得到第 i 个 ``RAW_WORDS`` 的命名""" 53 | i = int(i) + 1 54 | return Const.RAW_WORD + str(i) 55 | 56 | @staticmethod 57 | def RAW_CHARS(i): 58 | r"""得到第 i 个 ``RAW_CHARS`` 的命名""" 59 | i = int(i) + 1 60 | return Const.RAW_CHAR + str(i) 61 | 62 | @staticmethod 63 | def INPUT_LENS(i): 64 | r"""得到第 i 个 ``INPUT_LEN`` 的命名""" 65 | i = int(i) + 1 66 | return Const.INPUT_LEN + str(i) 67 | 68 | @staticmethod 69 | def OUTPUTS(i): 70 | r"""得到第 i 个 ``OUTPUT`` 的命名""" 71 | i = int(i) + 1 72 | return Const.OUTPUT + str(i) 73 | 74 | @staticmethod 75 | def TARGETS(i): 76 | r"""得到第 i 个 ``TARGET`` 的命名""" 77 | i = int(i) + 1 78 | return Const.TARGET + str(i) 79 | 80 | @staticmethod 81 | def LOSSES(i): 82 | r"""得到第 i 个 ``LOSS`` 的命名""" 83 | i = int(i) + 1 84 | return Const.LOSS + str(i) 85 | -------------------------------------------------------------------------------- /fastNLP/core/instance.py: -------------------------------------------------------------------------------- 1 | r""" 2 | instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可以认为是一个Instance类型的对象。 3 | 便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 中的表格 4 | 5 | """ 6 | 7 | __all__ = [ 8 | "Instance" 9 | ] 10 | 11 | from .utils import pretty_table_printer 12 | 13 | 14 | class Instance(object): 15 | r""" 16 | Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。 17 | Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示:: 18 | 19 | >>>from fastNLP import Instance 20 | >>>ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) 21 | >>>ins["field_1"] 22 | [1, 1, 1] 23 | >>>ins.add_field("field_3", [3, 3, 3]) 24 | >>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))}) 25 | """ 26 | 27 | def __init__(self, **fields): 28 | 29 | self.fields = fields 30 | 31 | def add_field(self, field_name, field): 32 | r""" 33 | 向Instance中增加一个field 34 | 35 | :param str field_name: 新增field的名称 36 | :param Any field: 新增field的内容 37 | """ 38 | self.fields[field_name] = field 39 | 40 | def items(self): 41 | r""" 42 | 返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value 43 | 44 | :return: 一个迭代器 45 | """ 46 | return self.fields.items() 47 | 48 | def __contains__(self, item): 49 | return item in self.fields 50 | 51 | def __getitem__(self, name): 52 | if name in self.fields: 53 | return self.fields[name] 54 | else: 55 | print(name) 56 | raise KeyError("{} not found".format(name)) 57 | 58 | def __setitem__(self, name, field): 59 | return self.add_field(name, field) 60 | 61 | def __repr__(self): 62 | return 
str(pretty_table_printer(self)) 63 | -------------------------------------------------------------------------------- /fastNLP/core/predictor.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "Predictor" 5 | ] 6 | 7 | from collections import defaultdict 8 | 9 | import torch 10 | 11 | from . import DataSet 12 | from . import DataSetIter 13 | from . import SequentialSampler 14 | from .utils import _build_args, _move_dict_value_to_device, _get_model_device 15 | 16 | 17 | class Predictor(object): 18 | r""" 19 | 一个根据训练模型预测输出的预测器(Predictor) 20 | 21 | 与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 22 | 这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 23 | """ 24 | 25 | def __init__(self, network): 26 | r""" 27 | 28 | :param torch.nn.Module network: 用来完成预测任务的模型 29 | """ 30 | if not isinstance(network, torch.nn.Module): 31 | raise ValueError( 32 | "Only fastNLP.models.BaseModel or torch.nn,Module is allowed, not {}".format(type(network))) 33 | self.network = network 34 | self.batch_size = 1 35 | self.batch_output = [] 36 | 37 | def predict(self, data: DataSet, seq_len_field_name=None): 38 | r"""用已经训练好的模型进行inference. 39 | 40 | :param fastNLP.DataSet data: 待预测的数据集 41 | :param str seq_len_field_name: 表示序列长度信息的field名字 42 | :return: dict dict里面的内容为模型预测的结果 43 | """ 44 | if not isinstance(data, DataSet): 45 | raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) 46 | if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: 47 | raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) 48 | 49 | prev_training = self.network.training 50 | self.network.eval() 51 | network_device = _get_model_device(self.network) 52 | batch_output = defaultdict(list) 53 | data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) 54 | 55 | if hasattr(self.network, "predict"): 56 | predict_func = self.network.predict 57 | else: 58 | predict_func = self.network.forward 59 | 60 | with torch.no_grad(): 61 | for batch_x, _ in data_iterator: 62 | _move_dict_value_to_device(batch_x, _, device=network_device) 63 | refined_batch_x = _build_args(predict_func, **batch_x) 64 | prediction = predict_func(**refined_batch_x) 65 | 66 | if seq_len_field_name is not None: 67 | seq_lens = batch_x[seq_len_field_name].tolist() 68 | 69 | for key, value in prediction.items(): 70 | value = value.cpu().numpy() 71 | if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1): 72 | batch_output[key].extend(value.tolist()) 73 | else: 74 | if seq_len_field_name is not None: 75 | tmp_batch = [] 76 | for idx, seq_len in enumerate(seq_lens): 77 | tmp_batch.append(value[idx, :seq_len]) 78 | batch_output[key].extend(tmp_batch) 79 | else: 80 | batch_output[key].append(value) 81 | 82 | self.network.train(prev_training) 83 | return batch_output 84 | -------------------------------------------------------------------------------- /fastNLP/doc_utils.py: -------------------------------------------------------------------------------- 1 | r"""undocumented 2 | 用于辅助生成 fastNLP 文档的代码 3 | """ 4 | 5 | __all__ = [] 6 | 7 | import inspect 8 | import sys 9 | 10 | 11 | def doc_process(m): 12 | for name, obj in inspect.getmembers(m): 13 | if inspect.isclass(obj) or inspect.isfunction(obj): 14 | if obj.__module__ != m.__name__: 15 | if obj.__doc__ is None: 16 | # print(name, obj.__doc__) 17 | pass 18 | else: 19 | module_name = obj.__module__ 20 | 21 
| # 识别并标注类和函数在不同层次中的位置 22 | 23 | while 1: 24 | defined_m = sys.modules[module_name] 25 | try: 26 | if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__: 27 | obj.__doc__ = r"别名 :class:`" + m.__name__ + "." + name + "`" \ 28 | + " :class:`" + module_name + "." + name + "`\n" + obj.__doc__ 29 | break 30 | module_name = ".".join(module_name.split('.')[:-1]) 31 | if module_name == m.__name__: 32 | # print(name, ": not found defined doc.") 33 | break 34 | except: 35 | print("Warning: Module {} lacks `__doc__`".format(module_name)) 36 | break 37 | 38 | # 识别并标注基类,只有基类也在 fastNLP 中定义才显示 39 | 40 | if inspect.isclass(obj): 41 | for base in obj.__bases__: 42 | if base.__module__.startswith("fastNLP"): 43 | parts = base.__module__.split(".") + [] 44 | module_name, i = "fastNLP", 1 45 | for i in range(len(parts) - 1): 46 | defined_m = sys.modules[module_name] 47 | try: 48 | if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__: 49 | obj.__doc__ = r"基类 :class:`" + defined_m.__name__ + "." + base.__name__ + "` \n\n" + obj.__doc__ 50 | break 51 | module_name += "." + parts[i + 1] 52 | except: 53 | print("Warning: Module {} lacks `__doc__`".format(module_name)) 54 | break 55 | -------------------------------------------------------------------------------- /fastNLP/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | embeddings 模块主要用于从各种预训练的模型中获取词语的分布式表示,目前支持的预训练模型包括word2vec, glove, ELMO, BERT等。这里所有 3 | embedding的forward输入都是形状为 ``(batch_size, max_len)`` 的torch.LongTensor,输出都是 ``(batch_size, max_len, embedding_dim)`` 的 4 | torch.FloatTensor。所有的embedding都可以使用 `self.num_embedding` 获取最大的输入index范围, 用 `self.embeddig_dim` 或 `self.embed_size` 获取embedding的 5 | 输出维度。 6 | """ 7 | 8 | __all__ = [ 9 | "Embedding", 10 | "TokenEmbedding", 11 | "StaticEmbedding", 12 | "ElmoEmbedding", 13 | "BertEmbedding", 14 | "BertWordPieceEncoder", 15 | 16 | "RobertaEmbedding", 17 | "RobertaWordPieceEncoder", 18 | 19 | "GPT2Embedding", 20 | "GPT2WordPieceEncoder", 21 | 22 | "StackEmbedding", 23 | "LSTMCharEmbedding", 24 | "CNNCharEmbedding", 25 | 26 | "get_embeddings", 27 | "get_sinusoid_encoding_table" 28 | ] 29 | 30 | from .embedding import Embedding, TokenEmbedding 31 | from .static_embedding import StaticEmbedding 32 | from .elmo_embedding import ElmoEmbedding 33 | from .bert_embedding import BertEmbedding, BertWordPieceEncoder 34 | from .roberta_embedding import RobertaEmbedding, RobertaWordPieceEncoder 35 | from .gpt2_embedding import GPT2WordPieceEncoder, GPT2Embedding 36 | from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding 37 | from .stack_embedding import StackEmbedding 38 | from .utils import get_embeddings, get_sinusoid_encoding_table 39 | 40 | import sys 41 | from ..doc_utils import doc_process 42 | doc_process(sys.modules[__name__]) -------------------------------------------------------------------------------- /fastNLP/embeddings/contextual_embedding.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. 
todo:: 3 | doc 4 | """ 5 | 6 | __all__ = [ 7 | "ContextualEmbedding" 8 | ] 9 | 10 | from abc import abstractmethod 11 | 12 | import torch 13 | 14 | from .embedding import TokenEmbedding 15 | from ..core import logger 16 | from ..core.batch import DataSetIter 17 | from ..core.dataset import DataSet 18 | from ..core.sampler import SequentialSampler 19 | from ..core.utils import _move_model_to_device, _get_model_device 20 | from ..core.vocabulary import Vocabulary 21 | 22 | 23 | class ContextualEmbedding(TokenEmbedding): 24 | r""" 25 | ContextualEmbedding组件. BertEmbedding与ElmoEmbedding的基类 26 | """ 27 | def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0): 28 | super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) 29 | 30 | def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool = True): 31 | r""" 32 | 由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。 33 | 34 | :param datasets: DataSet对象 35 | :param batch_size: int, 生成cache的sentence表示时使用的batch的大小 36 | :param device: 参考 :class::fastNLP.Trainer 的device 37 | :param delete_weights: 似乎在生成了cache之后删除权重,在不需要finetune动态模型的情况下,删除权重会大量减少内存占用。 38 | :return: 39 | """ 40 | for index, dataset in enumerate(datasets): 41 | try: 42 | assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed." 43 | assert 'words' in dataset.get_input_name(), "`words` field has to be set as input." 44 | except Exception as e: 45 | logger.error(f"Exception happens at {index} dataset.") 46 | raise e 47 | 48 | sent_embeds = {} 49 | _move_model_to_device(self, device=device) 50 | device = _get_model_device(self) 51 | pad_index = self._word_vocab.padding_idx 52 | logger.info("Start to calculate sentence representations.") 53 | with torch.no_grad(): 54 | for index, dataset in enumerate(datasets): 55 | try: 56 | batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler()) 57 | for batch_x, batch_y in batch: 58 | words = batch_x['words'].to(device) 59 | words_list = words.tolist() 60 | seq_len = words.ne(pad_index).sum(dim=-1) 61 | max_len = words.size(1) 62 | # 因为有些情况可能包含CLS, SEP, 从后面往前计算比较安全。 63 | seq_len_from_behind = (max_len - seq_len).tolist() 64 | word_embeds = self(words).detach().cpu().numpy() 65 | for b in range(words.size(0)): 66 | length = seq_len_from_behind[b] 67 | if length == 0: 68 | sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b] 69 | else: 70 | sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length] 71 | except Exception as e: 72 | logger.error(f"Exception happens at {index} dataset.") 73 | raise e 74 | logger.info("Finish calculating sentence representations.") 75 | self.sent_embeds = sent_embeds 76 | if delete_weights: 77 | self._delete_model_weights() 78 | 79 | def _get_sent_reprs(self, words): 80 | r""" 81 | 获取sentence的表示,如果有缓存,则返回缓存的值; 没有缓存则返回None 82 | 83 | :param words: torch.LongTensor 84 | :return: 85 | """ 86 | if hasattr(self, 'sent_embeds'): 87 | words_list = words.tolist() 88 | seq_len = words.ne(self._word_pad_index).sum(dim=-1) 89 | _embeds = [] 90 | for b in range(len(words)): 91 | words_i = tuple(words_list[b][:seq_len[b]]) 92 | embed = self.sent_embeds[words_i] 93 | _embeds.append(embed) 94 | max_sent_len = max(map(len, _embeds)) 95 | embeds = words.new_zeros(len(_embeds), max_sent_len, self.embed_size, dtype=torch.float, 96 | device=words.device) 97 | for i, embed in enumerate(_embeds): 98 | embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device) 99 | 
return embeds 100 | return None 101 | 102 | @abstractmethod 103 | def _delete_model_weights(self): 104 | r"""删除计算表示的模型以节省资源""" 105 | raise NotImplementedError 106 | 107 | def remove_sentence_cache(self): 108 | r""" 109 | 删除缓存的句子表示. 删除之后如果模型权重没有被删除,将开始使用动态计算权重。 110 | 111 | :return: 112 | """ 113 | del self.sent_embeds 114 | -------------------------------------------------------------------------------- /fastNLP/embeddings/stack_embedding.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. todo:: 3 | doc 4 | """ 5 | 6 | __all__ = [ 7 | "StackEmbedding", 8 | ] 9 | 10 | from typing import List 11 | 12 | import torch 13 | from torch import nn as nn 14 | 15 | from .embedding import TokenEmbedding 16 | 17 | 18 | class StackEmbedding(TokenEmbedding): 19 | r""" 20 | 支持将多个embedding集合成一个embedding。 21 | 22 | Example:: 23 | 24 | >>> from fastNLP import Vocabulary 25 | >>> from fastNLP.embeddings import StaticEmbedding, StackEmbedding 26 | >>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) 27 | >>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True) 28 | >>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True) 29 | >>> embed = StackEmbedding([embed_1, embed_2]) 30 | 31 | """ 32 | 33 | def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0): 34 | r""" 35 | 36 | :param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致 37 | :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。不同embedidng会在相同的位置 38 | 被设置为unknown。如果这里设置了dropout,则组成的embedding就不要再设置dropout了。 39 | :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 40 | """ 41 | vocabs = [] 42 | for embed in embeds: 43 | if hasattr(embed, 'get_word_vocab'): 44 | vocabs.append(embed.get_word_vocab()) 45 | _vocab = vocabs[0] 46 | for vocab in vocabs[1:]: 47 | assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary." 48 | 49 | super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout) 50 | assert isinstance(embeds, list) 51 | for embed in embeds: 52 | assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported." 
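        # Register the validated embeddings as sub-modules so PyTorch tracks their parameters;
        # the stacked output dimension computed below is the sum of the individual embed_size values.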
53 | self.embeds = nn.ModuleList(embeds) 54 | self._embed_size = sum([embed.embed_size for embed in self.embeds]) 55 | 56 | def append(self, embed: TokenEmbedding): 57 | r""" 58 | 添加一个embedding到结尾。 59 | :param embed: 60 | :return: 61 | """ 62 | assert isinstance(embed, TokenEmbedding) 63 | self._embed_size += embed.embed_size 64 | self.embeds.append(embed) 65 | return self 66 | 67 | def pop(self): 68 | r""" 69 | 弹出最后一个embed 70 | :return: 71 | """ 72 | embed = self.embeds.pop() 73 | self._embed_size -= embed.embed_size 74 | return embed 75 | 76 | @property 77 | def embed_size(self): 78 | r""" 79 | 该Embedding输出的vector的最后一维的维度。 80 | :return: 81 | """ 82 | return self._embed_size 83 | 84 | def forward(self, words): 85 | r""" 86 | 得到多个embedding的结果,并把结果按照顺序concat起来。 87 | 88 | :param words: batch_size x max_len 89 | :return: 返回的shape和当前这个stack embedding中embedding的组成有关 90 | """ 91 | outputs = [] 92 | words = self.drop_word(words) 93 | for embed in self.embeds: 94 | outputs.append(embed(words)) 95 | outputs = self.dropout(torch.cat(outputs, dim=-1)) 96 | return outputs 97 | -------------------------------------------------------------------------------- /fastNLP/embeddings/utils.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. todo:: 3 | doc 4 | """ 5 | import numpy as np 6 | import torch 7 | from torch import nn as nn 8 | 9 | from ..core.vocabulary import Vocabulary 10 | 11 | __all__ = [ 12 | 'get_embeddings', 13 | 'get_sinusoid_encoding_table' 14 | ] 15 | 16 | 17 | def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, include_word_start_end=True): 18 | r""" 19 | 给定一个word的vocabulary生成character的vocabulary. 20 | 21 | :param vocab: 从vocab 22 | :param min_freq: 23 | :param include_word_start_end: 是否需要包含特殊的 24 | :return: 25 | """ 26 | char_vocab = Vocabulary(min_freq=min_freq) 27 | for word, index in vocab: 28 | if not vocab._is_word_no_create_entry(word): 29 | char_vocab.add_word_lst(list(word)) 30 | if include_word_start_end: 31 | char_vocab.add_word_lst(['', '']) 32 | return char_vocab 33 | 34 | 35 | def get_embeddings(init_embed, padding_idx=None): 36 | r""" 37 | 根据输入的init_embed返回Embedding对象。如果输入是tuple, 则随机初始化一个nn.Embedding; 如果输入是numpy.ndarray, 则按照ndarray 38 | 的值将nn.Embedding初始化; 如果输入是torch.Tensor, 则按该值初始化nn.Embedding; 如果输入是fastNLP中的embedding将不做处理 39 | 返回原对象。 40 | 41 | :param init_embed: 可以是 tuple:(num_embedings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入 42 | nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行,将使用传入的ndarray作为作为Embedding初始化; 43 | 传入torch.Tensor, 将使用传入的值作为Embedding初始化。 44 | :param padding_idx: 当传入tuple时,padding_idx有效 45 | :return nn.Embedding: embeddings 46 | """ 47 | if isinstance(init_embed, tuple): 48 | res = nn.Embedding( 49 | num_embeddings=init_embed[0], embedding_dim=init_embed[1], padding_idx=padding_idx) 50 | nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)), 51 | b=np.sqrt(3 / res.weight.data.size(1))) 52 | elif isinstance(init_embed, nn.Module): 53 | res = init_embed 54 | elif isinstance(init_embed, torch.Tensor): 55 | res = nn.Embedding.from_pretrained(init_embed, freeze=False) 56 | elif isinstance(init_embed, np.ndarray): 57 | init_embed = torch.tensor(init_embed, dtype=torch.float32) 58 | res = nn.Embedding.from_pretrained(init_embed, freeze=False) 59 | else: 60 | raise TypeError( 61 | 'invalid init_embed type: {}'.format((type(init_embed)))) 62 | return res 63 | 64 | 65 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 66 | """ 67 | 
sinusoid的embedding,其中position的表示中,偶数维(0,2,4,...)是sin, 奇数(1,3,5...)是cos 68 | 69 | :param int n_position: 一共多少个position 70 | :param int d_hid: 多少维度,需要为偶数 71 | :param padding_idx: 72 | :return: torch.FloatTensor, shape为n_position x d_hid 73 | """ 74 | 75 | def cal_angle(position, hid_idx): 76 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 77 | 78 | def get_posi_angle_vec(position): 79 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 80 | 81 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 82 | 83 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 84 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 85 | 86 | if padding_idx is not None: 87 | # zero vector for padding dimension 88 | sinusoid_table[padding_idx] = 0. 89 | 90 | return torch.FloatTensor(sinusoid_table) 91 | 92 | -------------------------------------------------------------------------------- /fastNLP/io/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 用于IO的模块, 具体包括: 3 | 4 | 1. 用于读入 embedding 的 :mod:`EmbedLoader ` 类, 5 | 6 | 2. 用于读入不同格式数据的 :mod:`Loader ` 类 7 | 8 | 3. 用于处理读入数据的 :mod:`Pipe ` 类 9 | 10 | 4. 用于保存和载入模型的类, 参考 :mod:`model_io模块 ` 11 | 12 | 这些类的使用方法如下: 13 | """ 14 | __all__ = [ 15 | 'DataBundle', 16 | 17 | 'EmbedLoader', 18 | 19 | 'Loader', 20 | 21 | 'CLSBaseLoader', 22 | 'AGsNewsLoader', 23 | 'DBPediaLoader', 24 | 'YelpFullLoader', 25 | 'YelpPolarityLoader', 26 | 'IMDBLoader', 27 | 'SSTLoader', 28 | 'SST2Loader', 29 | "ChnSentiCorpLoader", 30 | "THUCNewsLoader", 31 | "WeiboSenti100kLoader", 32 | 33 | 'ConllLoader', 34 | 'Conll2003Loader', 35 | 'Conll2003NERLoader', 36 | 'OntoNotesNERLoader', 37 | 'CTBLoader', 38 | "MsraNERLoader", 39 | "WeiboNERLoader", 40 | "PeopleDailyNERLoader", 41 | 42 | 'CSVLoader', 43 | 'JsonLoader', 44 | 45 | 'CWSLoader', 46 | 47 | 'MNLILoader', 48 | "QuoraLoader", 49 | "SNLILoader", 50 | "QNLILoader", 51 | "RTELoader", 52 | "CNXNLILoader", 53 | "BQCorpusLoader", 54 | "LCQMCLoader", 55 | 56 | "CMRC2018Loader", 57 | 58 | "Pipe", 59 | 60 | "CLSBasePipe", 61 | "AGsNewsPipe", 62 | "DBPediaPipe", 63 | "YelpFullPipe", 64 | "YelpPolarityPipe", 65 | "SSTPipe", 66 | "SST2Pipe", 67 | "IMDBPipe", 68 | "ChnSentiCorpPipe", 69 | "THUCNewsPipe", 70 | "WeiboSenti100kPipe", 71 | 72 | "Conll2003Pipe", 73 | "Conll2003NERPipe", 74 | "OntoNotesNERPipe", 75 | "MsraNERPipe", 76 | "PeopleDailyPipe", 77 | "WeiboNERPipe", 78 | 79 | "CWSPipe", 80 | 81 | "Conll2003NERPipe", 82 | "OntoNotesNERPipe", 83 | "MsraNERPipe", 84 | "WeiboNERPipe", 85 | "PeopleDailyPipe", 86 | "Conll2003Pipe", 87 | 88 | "MatchingBertPipe", 89 | "RTEBertPipe", 90 | "SNLIBertPipe", 91 | "QuoraBertPipe", 92 | "QNLIBertPipe", 93 | "MNLIBertPipe", 94 | "CNXNLIBertPipe", 95 | "BQCorpusBertPipe", 96 | "LCQMCBertPipe", 97 | "MatchingPipe", 98 | "RTEPipe", 99 | "SNLIPipe", 100 | "QuoraPipe", 101 | "QNLIPipe", 102 | "MNLIPipe", 103 | "LCQMCPipe", 104 | "CNXNLIPipe", 105 | "BQCorpusPipe", 106 | "RenamePipe", 107 | "GranularizePipe", 108 | "MachingTruncatePipe", 109 | 110 | "CMRC2018BertPipe", 111 | 112 | 'ModelLoader', 113 | 'ModelSaver', 114 | 115 | ] 116 | 117 | import sys 118 | 119 | from .data_bundle import DataBundle 120 | from .embed_loader import EmbedLoader 121 | from .loader import * 122 | from .model_io import ModelLoader, ModelSaver 123 | from .pipe import * 124 | from ..doc_utils import doc_process 125 | 126 | doc_process(sys.modules[__name__]) 
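A minimal, illustrative sketch of the Loader-to-DataBundle flow described in the io package docstring above (the file path and the column names are hypothetical, not taken from this repository):

```
from fastNLP.io import CSVLoader

loader = CSVLoader(headers=["raw_words", "target"], sep="\t", dropna=True)
data_bundle = loader.load("/path/to/train.tsv")   # a single file is read as the 'train' split
train_ds = data_bundle.get_dataset("train")       # a fastNLP.DataSet with one Instance per row
print(train_ds)
```

A Pipe (see fastNLP/io/pipe below) would then be applied to the returned DataBundle to tokenize the raw fields and convert them to indices in place.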
-------------------------------------------------------------------------------- /fastNLP/io/file_reader.py: -------------------------------------------------------------------------------- 1 | r"""undocumented 2 | 此模块用于给其它模块提供读取文件的函数,没有为用户提供 API 3 | """ 4 | 5 | __all__ = [] 6 | 7 | import json 8 | import csv 9 | 10 | from ..core import logger 11 | 12 | 13 | def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): 14 | r""" 15 | Construct a generator to read csv items. 16 | 17 | :param path: file path 18 | :param encoding: file's encoding, default: utf-8 19 | :param headers: file's headers, if None, make file's first line as headers. default: None 20 | :param sep: separator for each column. default: ',' 21 | :param dropna: weather to ignore and drop invalid data, 22 | :if False, raise ValueError when reading invalid data. default: True 23 | :return: generator, every time yield (line number, csv item) 24 | """ 25 | with open(path, 'r', encoding=encoding) as csv_file: 26 | f = csv.reader(csv_file, delimiter=sep) 27 | start_idx = 0 28 | if headers is None: 29 | headers = next(f) 30 | start_idx += 1 31 | elif not isinstance(headers, (list, tuple)): 32 | raise TypeError("headers should be list or tuple, not {}." \ 33 | .format(type(headers))) 34 | for line_idx, line in enumerate(f, start_idx): 35 | contents = line 36 | if len(contents) != len(headers): 37 | if dropna: 38 | continue 39 | else: 40 | if "" in headers: 41 | raise ValueError(("Line {} has {} parts, while header has {} parts.\n" + 42 | "Please check the empty parts or unnecessary '{}'s in header.") 43 | .format(line_idx, len(contents), len(headers), sep)) 44 | else: 45 | raise ValueError("Line {} has {} parts, while header has {} parts." \ 46 | .format(line_idx, len(contents), len(headers))) 47 | _dict = {} 48 | for header, content in zip(headers, contents): 49 | _dict[header] = content 50 | yield line_idx, _dict 51 | 52 | 53 | def _read_json(path, encoding='utf-8', fields=None, dropna=True): 54 | r""" 55 | Construct a generator to read json items. 56 | 57 | :param path: file path 58 | :param encoding: file's encoding, default: utf-8 59 | :param fields: json object's fields that needed, if None, all fields are needed. default: None 60 | :param dropna: weather to ignore and drop invalid data, 61 | :if False, raise ValueError when reading invalid data. default: True 62 | :return: generator, every time yield (line number, json item) 63 | """ 64 | if fields: 65 | fields = set(fields) 66 | with open(path, 'r', encoding=encoding) as f: 67 | for line_idx, line in enumerate(f): 68 | data = json.loads(line) 69 | if fields is None: 70 | yield line_idx, data 71 | continue 72 | _res = {} 73 | for k, v in data.items(): 74 | if k in fields: 75 | _res[k] = v 76 | if len(_res) < len(fields): 77 | if dropna: 78 | continue 79 | else: 80 | raise ValueError('invalid instance at line: {}'.format(line_idx)) 81 | yield line_idx, _res 82 | 83 | 84 | def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): 85 | r""" 86 | Construct a generator to read conll items. 87 | :param path: file path 88 | :param encoding: file's encoding, default: utf-8 89 | :param sep: seperator 90 | :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None 91 | :param dropna: weather to ignore and drop invalid data, 92 | :if False, raise ValueError when reading invalid data. 
default: True 93 | :return: generator, every time yield (line number, conll item) 94 | """ 95 | def parse_conll(sample): 96 | sample = list(map(list, zip(*sample))) 97 | sample = [sample[i] for i in indexes] 98 | for f in sample: 99 | if len(f) <= 0: 100 | raise ValueError('empty field') 101 | return sample 102 | 103 | 104 | with open(path, 'r', encoding=encoding) as f: 105 | sample = [] 106 | start = next(f).strip() 107 | if start != '': 108 | sample.append(start.split(sep)) if sep else sample.append(start.split()) 109 | for line_idx, line in enumerate(f, 1): 110 | line = line.strip() 111 | if line == '': 112 | if len(sample): 113 | try: 114 | res = parse_conll(sample) 115 | sample = [] 116 | yield line_idx, res 117 | except Exception as e: 118 | if dropna: 119 | logger.warning('Invalid instance which ends at line: {} has been dropped.'.format(line_idx)) 120 | sample = [] 121 | continue 122 | raise ValueError('Invalid instance which ends at line: {}'.format(line_idx)) 123 | elif line.startswith('#'): 124 | continue 125 | else: 126 | sample.append(line.split(sep)) if sep else sample.append(line.split()) 127 | if len(sample) > 0: 128 | try: 129 | res = parse_conll(sample) 130 | yield line_idx, res 131 | except Exception as e: 132 | if dropna: 133 | return 134 | logger.error('invalid instance ends at line: {}'.format(line_idx)) 135 | raise e 136 | -------------------------------------------------------------------------------- /fastNLP/io/loader/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的 3 | 三个方法: ``__init__`` , ``_load`` , ``loads`` . 其中 ``__init__(...)`` 用于申明读取参数,以及说明该Loader支持的数据格式, 4 | 读取后 :class:`~fastNLP.DataSet` 中的 `field` ; ``_load(path)`` 方法传入文件路径读取单个文件,并返回 :class:`~fastNLP.DataSet` ; 5 | ``load(paths)`` 用于读取文件夹下的文件,并返回 :class:`~fastNLP.io.DataBundle` 类型的对象 , load()方法支持以下几种类型的参数: 6 | 7 | 0.传入None 8 | 将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。 9 | 10 | 1.传入一个文件的 path 11 | 返回的 `data_bundle` 包含一个名为 `train` 的 dataset ,可以通过 ``data_bundle.get_dataset('train')`` 获取 12 | 13 | 2.传入一个文件夹目录 14 | 将读取的是这个文件夹下文件名中包含 `train` , `test` , `dev` 的文件,其它文件会被忽略。假设某个目录下的文件为:: 15 | 16 | | 17 | +-train3.txt 18 | +-dev.txt 19 | +-test.txt 20 | +-other.txt 21 | 22 | 在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , 23 | ``data_bundle.get_dataset('dev')`` , 24 | ``data_bundle.get_dataset('test')`` 获取对应的 `dataset` ,其中 `other.txt` 的内容会被忽略。假设某个目录下的文件为:: 25 | 26 | | 27 | +-train3.txt 28 | +-dev.txt 29 | 30 | 在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , 31 | ``data_bundle.get_dataset('dev')`` 获取对应的 dataset。 32 | 33 | 3.传入一个字典 34 | 字典的的 key 为 `dataset` 的名称,value 是该 `dataset` 的文件路径:: 35 | 36 | paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'} 37 | 38 | 在 Loader().load(paths) 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , ``data_bundle.get_dataset('dev')`` , 39 | ``data_bundle.get_dataset('test')`` 来获取对应的 `dataset` 40 | 41 | fastNLP 目前提供了如下的 Loader 42 | 43 | 44 | 45 | """ 46 | 47 | __all__ = [ 48 | 'Loader', 49 | 50 | 'CLSBaseLoader', 51 | 'YelpFullLoader', 52 | 'YelpPolarityLoader', 53 | 'AGsNewsLoader', 54 | 'DBPediaLoader', 55 | 'IMDBLoader', 56 | 'SSTLoader', 57 | 'SST2Loader', 58 | "ChnSentiCorpLoader", 59 | "THUCNewsLoader", 60 | "WeiboSenti100kLoader", 61 | 62 | 'ConllLoader', 63 | 'Conll2003Loader', 64 | 
'Conll2003NERLoader', 65 | 'OntoNotesNERLoader', 66 | 'CTBLoader', 67 | "MsraNERLoader", 68 | "PeopleDailyNERLoader", 69 | "WeiboNERLoader", 70 | 71 | 'CSVLoader', 72 | 'JsonLoader', 73 | 74 | 'CWSLoader', 75 | 76 | 'MNLILoader', 77 | "QuoraLoader", 78 | "SNLILoader", 79 | "QNLILoader", 80 | "RTELoader", 81 | "CNXNLILoader", 82 | "BQCorpusLoader", 83 | "LCQMCLoader", 84 | 85 | "CoReferenceLoader", 86 | 87 | "CMRC2018Loader" 88 | ] 89 | from .classification import CLSBaseLoader, YelpFullLoader, YelpPolarityLoader, AGsNewsLoader, IMDBLoader, \ 90 | SSTLoader, SST2Loader, DBPediaLoader, \ 91 | ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader 92 | from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader 93 | from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader 94 | from .coreference import CoReferenceLoader 95 | from .csv import CSVLoader 96 | from .cws import CWSLoader 97 | from .json import JsonLoader 98 | from .loader import Loader 99 | from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, \ 100 | LCQMCLoader 101 | from .qa import CMRC2018Loader 102 | 103 | -------------------------------------------------------------------------------- /fastNLP/io/loader/coreference.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "CoReferenceLoader", 5 | ] 6 | 7 | from ...core.dataset import DataSet 8 | from ..file_reader import _read_json 9 | from ...core.instance import Instance 10 | from ...core.const import Const 11 | from .json import JsonLoader 12 | 13 | 14 | class CoReferenceLoader(JsonLoader): 15 | r""" 16 | 原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 17 | 18 | Example:: 19 | 20 | {"doc_key": "bc/cctv/00/cctv_0000_0", 21 | "speakers": [["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"]], 22 | "clusters": [[[70, 70], [485, 486], [500, 500], [73, 73], [55, 55], [153, 154], [366, 366]]], 23 | "sentences": [["In", "the", "summer", "of", "2005", ",", "a", "picture", "that", "people", "have", "long", "been", "looking", "forward", "to", "started", "emerging", "with", "frequency", "in", "various", "major", "Hong", "Kong", "media", "."], ["With", "their", "unique", "charm", ",", "these", "well", "-", "known", "cartoon", "images", "once", "again", "caused", "Hong", "Kong", "to", "be", "a", "focus", "of", "worldwide", "attention", "."]] 24 | } 25 | 26 | 读取预处理好的Conll2012数据,数据结构如下: 27 | 28 | .. 
csv-table:: 29 | :header: "raw_words1", "raw_words2", "raw_words3", "raw_words4" 30 | 31 | "bc/cctv/00/cctv_0000_0", "[['Speaker#1', 'Speaker#1', 'Speaker#1...", "[[[70, 70], [485, 486], [500, 500], [7...", "[['In', 'the', 'summer', 'of', '2005',..." 32 | "...", "...", "...", "..." 33 | 34 | """ 35 | def __init__(self, fields=None, dropna=False): 36 | super().__init__(fields, dropna) 37 | self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2), 38 | "sentences": Const.RAW_WORDS(3)} 39 | 40 | def _load(self, path): 41 | r""" 42 | 加载数据 43 | :param path: 数据文件路径,文件为json 44 | 45 | :return: 46 | """ 47 | dataset = DataSet() 48 | for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): 49 | if self.fields: 50 | ins = {self.fields[k]: v for k, v in d.items()} 51 | else: 52 | ins = d 53 | dataset.append(Instance(**ins)) 54 | return dataset 55 | 56 | def download(self): 57 | r""" 58 | 由于版权限制,不能提供自动下载功能。可参考 59 | 60 | https://www.aclweb.org/anthology/W12-4501 61 | 62 | :return: 63 | """ 64 | raise RuntimeError("CoReference cannot be downloaded automatically.") 65 | -------------------------------------------------------------------------------- /fastNLP/io/loader/csv.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "CSVLoader", 5 | ] 6 | 7 | from .loader import Loader 8 | from ..file_reader import _read_csv 9 | from ...core.dataset import DataSet 10 | from ...core.instance import Instance 11 | 12 | 13 | class CSVLoader(Loader): 14 | r""" 15 | 读取CSV格式的数据集, 返回 ``DataSet`` 。 16 | 17 | """ 18 | 19 | def __init__(self, headers=None, sep=",", dropna=False): 20 | r""" 21 | 22 | :param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 23 | 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` 24 | :param str sep: CSV文件中列与列之间的分隔符. Default: "," 25 | :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . 26 | Default: ``False`` 27 | """ 28 | super().__init__() 29 | self.headers = headers 30 | self.sep = sep 31 | self.dropna = dropna 32 | 33 | def _load(self, path): 34 | ds = DataSet() 35 | for idx, data in _read_csv(path, headers=self.headers, 36 | sep=self.sep, dropna=self.dropna): 37 | ds.append(Instance(**data)) 38 | return ds 39 | 40 | -------------------------------------------------------------------------------- /fastNLP/io/loader/cws.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "CWSLoader" 5 | ] 6 | 7 | import glob 8 | import os 9 | import random 10 | import shutil 11 | import time 12 | 13 | from .loader import Loader 14 | from ...core.dataset import DataSet 15 | from ...core.instance import Instance 16 | 17 | 18 | class CWSLoader(Loader): 19 | r""" 20 | CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: 21 | 22 | Example:: 23 | 24 | 上海 浦东 开发 与 法制 建设 同步 25 | 新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) 26 | ... 27 | 28 | 该Loader读取后的DataSet具有如下的结构 29 | 30 | .. csv-table:: 31 | :header: "raw_words" 32 | 33 | "上海 浦东 开发 与 法制 建设 同步" 34 | "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" 35 | "..." 
36 | 37 | """ 38 | def __init__(self, dataset_name:str=None): 39 | r""" 40 | 41 | :param str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None 42 | """ 43 | super().__init__() 44 | datanames = {'pku': 'cws-pku', 'msra':'cws-msra', 'as':'cws-as', 'cityu':'cws-cityu'} 45 | if dataset_name in datanames: 46 | self.dataset_name = datanames[dataset_name] 47 | else: 48 | self.dataset_name = None 49 | 50 | def _load(self, path:str): 51 | ds = DataSet() 52 | with open(path, 'r', encoding='utf-8') as f: 53 | for line in f: 54 | line = line.strip() 55 | if line: 56 | ds.append(Instance(raw_words=line)) 57 | return ds 58 | 59 | def download(self, dev_ratio=0.1, re_download=False)->str: 60 | r""" 61 | 如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, 62 | 2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 63 | 64 | :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 65 | :param bool re_download: 是否重新下载数据,以重新切分数据。 66 | :return: str 67 | """ 68 | if self.dataset_name is None: 69 | return None 70 | data_dir = self._get_dataset_path(dataset_name=self.dataset_name) 71 | modify_time = 0 72 | for filepath in glob.glob(os.path.join(data_dir, '*')): 73 | modify_time = os.stat(filepath).st_mtime 74 | break 75 | if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 76 | shutil.rmtree(data_dir) 77 | data_dir = self._get_dataset_path(dataset_name=self.dataset_name) 78 | 79 | if not os.path.exists(os.path.join(data_dir, 'dev.txt')): 80 | if dev_ratio > 0: 81 | assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." 82 | try: 83 | with open(os.path.join(data_dir, 'train3.txt'), 'r', encoding='utf-8') as f, \ 84 | open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ 85 | open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: 86 | for line in f: 87 | if random.random() < dev_ratio: 88 | f2.write(line) 89 | else: 90 | f1.write(line) 91 | os.remove(os.path.join(data_dir, 'train3.txt')) 92 | os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train3.txt')) 93 | finally: 94 | if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): 95 | os.remove(os.path.join(data_dir, 'middle_file.txt')) 96 | 97 | return data_dir 98 | -------------------------------------------------------------------------------- /fastNLP/io/loader/json.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "JsonLoader" 5 | ] 6 | 7 | from .loader import Loader 8 | from ..file_reader import _read_json 9 | from ...core.dataset import DataSet 10 | from ...core.instance import Instance 11 | 12 | 13 | class JsonLoader(Loader): 14 | r""" 15 | 别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader` 16 | 17 | 读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 18 | 19 | :param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name 20 | ``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , 21 | `value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 22 | ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` 23 | :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . 
24 | Default: ``False`` 25 | """ 26 | 27 | def __init__(self, fields=None, dropna=False): 28 | super(JsonLoader, self).__init__() 29 | self.dropna = dropna 30 | self.fields = None 31 | self.fields_list = None 32 | if fields: 33 | self.fields = {} 34 | for k, v in fields.items(): 35 | self.fields[k] = k if v is None else v 36 | self.fields_list = list(self.fields.keys()) 37 | 38 | def _load(self, path): 39 | ds = DataSet() 40 | for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): 41 | if self.fields: 42 | ins = {self.fields[k]: v for k, v in d.items()} 43 | else: 44 | ins = d 45 | ds.append(Instance(**ins)) 46 | return ds 47 | -------------------------------------------------------------------------------- /fastNLP/io/loader/loader.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "Loader" 5 | ] 6 | 7 | from typing import Union, Dict 8 | 9 | from .. import DataBundle 10 | from ..file_utils import _get_dataset_url, get_cache_path, cached_path 11 | from ..utils import check_loader_paths 12 | from ...core.dataset import DataSet 13 | 14 | 15 | class Loader: 16 | r""" 17 | 各种数据 Loader 的基类,提供了 API 的参考. 18 | Loader支持以下的三个函数 19 | 20 | - download() 函数:自动将该数据集下载到缓存地址,默认缓存地址为~/.fastNLP/datasets/。由于版权等原因,不是所有的Loader都实现了该方法。该方法会返回下载后文件所处的缓存地址。 21 | - _load() 函数:从一个数据文件中读取数据,返回一个 :class:`~fastNLP.DataSet` 。返回的DataSet的内容可以通过每个Loader的文档判断出。 22 | - load() 函数:将文件分别读取为DataSet,然后将多个DataSet放入到一个DataBundle中并返回 23 | 24 | """ 25 | 26 | def __init__(self): 27 | pass 28 | 29 | def _load(self, path: str) -> DataSet: 30 | r""" 31 | 给定一个路径,返回读取的DataSet。 32 | 33 | :param str path: 路径 34 | :return: DataSet 35 | """ 36 | raise NotImplementedError 37 | 38 | def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: 39 | r""" 40 | 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 41 | 42 | :param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式: 43 | 44 | 0.如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 45 | 46 | 1.传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件名包含'train'、 'dev'、 'test'则会报错:: 47 | 48 | data_bundle = xxxLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train 49 | # dev、 test等有所变化,可以通过以下的方式取出DataSet 50 | tr_data = data_bundle.get_dataset('train') 51 | te_data = data_bundle.get_dataset('test') # 如果目录下有文件包含test这个字段 52 | 53 | 2.传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: 54 | 55 | paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} 56 | data_bundle = xxxLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" 57 | dev_data = data_bundle.get_dataset('dev') 58 | 59 | 3.传入文件路径:: 60 | 61 | data_bundle = xxxLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' 62 | tr_data = data_bundle.get_dataset('train') # 取出DataSet 63 | 64 | :return: 返回的 :class:`~fastNLP.io.DataBundle` 65 | """ 66 | if paths is None: 67 | paths = self.download() 68 | paths = check_loader_paths(paths) 69 | datasets = {name: self._load(path) for name, path in paths.items()} 70 | data_bundle = DataBundle(datasets=datasets) 71 | return data_bundle 72 | 73 | def download(self) -> str: 74 | r""" 75 | 自动下载该数据集 76 | 77 | :return: 下载后解压目录 78 | """ 79 | raise NotImplementedError(f"{self.__class__} cannot download data automatically.") 80 | 81 | @staticmethod 82 | def _get_dataset_path(dataset_name): 83 | r""" 84 | 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存(如果支持的话) 85 | 86 | :param str dataset_name: 数据集的名称 87 | :return: 
str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 88 | """ 89 | 90 | default_cache_path = get_cache_path() 91 | url = _get_dataset_url(dataset_name) 92 | output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') 93 | 94 | return output_dir 95 | -------------------------------------------------------------------------------- /fastNLP/io/loader/qa.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 该文件中的Loader主要用于读取问答式任务的数据 3 | 4 | """ 5 | 6 | 7 | from . import Loader 8 | import json 9 | from ...core import DataSet, Instance 10 | 11 | __all__ = ['CMRC2018Loader'] 12 | 13 | 14 | class CMRC2018Loader(Loader): 15 | r""" 16 | 请直接使用从fastNLP下载的数据进行处理。该数据集未提供测试集,测试需要通过上传到对应的系统进行评测 17 | 18 | 读取之后训练集DataSet将具备以下的内容,每个问题的答案只有一个 19 | 20 | .. csv-table:: 21 | :header:"title", "context", "question", "answers", "answer_starts", "id" 22 | 23 | "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "范廷颂是什么时候被任为主教的?", ["1963年"], ["30"], "TRAIN_186_QUERY_0" 24 | "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "1990年,范廷颂担任什么职务?", ["1990年被擢升为天..."], ["41"],"TRAIN_186_QUERY_1" 25 | "...", "...", "...","...", ".", "..." 26 | 27 | 其中title是文本的标题,多条记录可能是相同的title;id是该问题的id,具备唯一性 28 | 29 | 验证集DataSet将具备以下的内容,每个问题的答案可能有三个(有时候只是3个重复的答案) 30 | 31 | .. csv-table:: 32 | :header: "title", "context", "question", "answers", "answer_starts", "id" 33 | 34 | "战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "《战国无双3》是由哪两个公司合作开发的?", "['光荣和ω-force', '光荣和ω-force', '光荣和ω-force']", "[30, 30, 30]", "DEV_0_QUERY_0" 35 | "战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "男女主角亦有专属声优这一模式是由谁改编的?", "['村雨城', '村雨城', '任天堂游戏谜之村雨城']", "[226, 226, 219]", "DEV_0_QUERY_1" 36 | "...", "...", "...","...", ".", "..." 37 | 38 | 其中answer_starts是从0开始的index。例如"我来自a复旦大学?",其中"复"的开始index为4。另外"Russell评价说"中的说的index为9, 因为 39 | 英文和数字都直接按照character计量的。 40 | """ 41 | def __init__(self): 42 | super().__init__() 43 | 44 | def _load(self, path: str) -> DataSet: 45 | with open(path, 'r', encoding='utf-8') as f: 46 | data = json.load(f)['data'] 47 | ds = DataSet() 48 | for entry in data: 49 | title = entry['title'] 50 | para = entry['paragraphs'][0] 51 | context = para['context'] 52 | qas = para['qas'] 53 | for qa in qas: 54 | question = qa['question'] 55 | ans = qa['answers'] 56 | answers = [] 57 | answer_starts = [] 58 | id = qa['id'] 59 | for an in ans: 60 | answers.append(an['text']) 61 | answer_starts.append(an['answer_start']) 62 | ds.append(Instance(title=title, context=context, question=question, answers=answers, 63 | answer_starts=answer_starts,id=id)) 64 | return ds 65 | 66 | def download(self) -> str: 67 | r""" 68 | 如果您使用了本数据,请引用A Span-Extraction Dataset for Chinese Machine Reading Comprehension. Yiming Cui, Ting Liu, etc. 69 | 70 | :return: 71 | """ 72 | output_dir = self._get_dataset_path('cmrc2018') 73 | return output_dir 74 | 75 | -------------------------------------------------------------------------------- /fastNLP/io/loader/summarization.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "ExtCNNDMLoader" 5 | ] 6 | 7 | import os 8 | from typing import Union, Dict 9 | 10 | from ..data_bundle import DataBundle 11 | from ..utils import check_loader_paths 12 | from .json import JsonLoader 13 | 14 | 15 | class ExtCNNDMLoader(JsonLoader): 16 | r""" 17 | 读取之后的DataSet中的field情况为 18 | 19 | .. csv-table:: 20 | :header: "text", "summary", "label", "publication" 21 | 22 | ["I got new tires from them and... 
","..."], ["The new tires...","..."], [0, 1], "cnndm" 23 | ["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" 24 | ["..."], ["..."], [], "cnndm" 25 | 26 | """ 27 | 28 | def __init__(self, fields=None): 29 | fields = fields or {"text": None, "summary": None, "label": None, "publication": None} 30 | super(ExtCNNDMLoader, self).__init__(fields=fields) 31 | 32 | def load(self, paths: Union[str, Dict[str, str]] = None): 33 | r""" 34 | 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 35 | 36 | 读取的field根据ExtCNNDMLoader初始化时传入的headers决定。 37 | 38 | :param str paths: 传入一个目录, 将在该目录下寻找train.label.jsonl, dev.label.jsonl 39 | test.label.jsonl三个文件(该目录还应该需要有一个名字为vocab的文件,在 :class:`~fastNLP.io.ExtCNNDMPipe` 40 | 当中需要用到)。 41 | 42 | :return: 返回 :class:`~fastNLP.io.DataBundle` 43 | """ 44 | if paths is None: 45 | paths = self.download() 46 | paths = check_loader_paths(paths) 47 | if ('train' in paths) and ('test' not in paths): 48 | paths['test'] = paths['train'] 49 | paths.pop('train') 50 | 51 | datasets = {name: self._load(path) for name, path in paths.items()} 52 | data_bundle = DataBundle(datasets=datasets) 53 | return data_bundle 54 | 55 | def download(self): 56 | r""" 57 | 如果你使用了这个数据,请引用 58 | 59 | https://arxiv.org/pdf/1506.03340.pdf 60 | :return: 61 | """ 62 | output_dir = self._get_dataset_path('ext-cnndm') 63 | return output_dir 64 | -------------------------------------------------------------------------------- /fastNLP/io/model_io.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 用于载入和保存模型 3 | """ 4 | __all__ = [ 5 | "ModelLoader", 6 | "ModelSaver" 7 | ] 8 | 9 | import torch 10 | 11 | 12 | class ModelLoader: 13 | r""" 14 | 用于读取模型 15 | """ 16 | 17 | def __init__(self): 18 | super(ModelLoader, self).__init__() 19 | 20 | @staticmethod 21 | def load_pytorch(empty_model, model_path): 22 | r""" 23 | 从 ".pkl" 文件读取 PyTorch 模型 24 | 25 | :param empty_model: 初始化参数的 PyTorch 模型 26 | :param str model_path: 模型保存的路径 27 | """ 28 | empty_model.load_state_dict(torch.load(model_path)) 29 | 30 | @staticmethod 31 | def load_pytorch_model(model_path): 32 | r""" 33 | 读取整个模型 34 | 35 | :param str model_path: 模型保存的路径 36 | """ 37 | return torch.load(model_path) 38 | 39 | 40 | class ModelSaver(object): 41 | r""" 42 | 用于保存模型 43 | 44 | Example:: 45 | 46 | saver = ModelSaver("./save/model_ckpt_100.pkl") 47 | saver.save_pytorch(model) 48 | 49 | """ 50 | 51 | def __init__(self, save_path): 52 | r""" 53 | 54 | :param save_path: 模型保存的路径 55 | """ 56 | self.save_path = save_path 57 | 58 | def save_pytorch(self, model, param_only=True): 59 | r""" 60 | 把 PyTorch 模型存入 ".pkl" 文件 61 | 62 | :param model: PyTorch 模型 63 | :param bool param_only: 是否只保存模型的参数(否则保存整个模型) 64 | 65 | """ 66 | if param_only is True: 67 | torch.save(model.state_dict(), self.save_path) 68 | else: 69 | torch.save(model, self.save_path) 70 | -------------------------------------------------------------------------------- /fastNLP/io/pipe/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``process`` 和 ``process_from_file`` 两种方法。 3 | ``process(data_bundle)`` 传入一个 :class:`~fastNLP.io.DataBundle` 类型的对象, 在传入的 `data_bundle` 上进行原位修改,并将其返回; 4 | ``process_from_file(paths)`` 传入的文件路径,返回一个 :class:`~fastNLP.io.DataBundle` 类型的对象。 5 | ``process(data_bundle)`` 或者 ``process_from_file(paths)`` 的返回 `data_bundle` 中的 :class:`~fastNLP.DataSet` 6 | 一般都包含原文与转换为index的输入以及转换为index的target;除了 
:class:`~fastNLP.DataSet` 之外, 7 | `data_bundle` 还会包含将field转为index时所建立的词表。 8 | 9 | """ 10 | __all__ = [ 11 | "Pipe", 12 | 13 | "CWSPipe", 14 | 15 | "CLSBasePipe", 16 | "AGsNewsPipe", 17 | "DBPediaPipe", 18 | "YelpFullPipe", 19 | "YelpPolarityPipe", 20 | "SSTPipe", 21 | "SST2Pipe", 22 | "IMDBPipe", 23 | "ChnSentiCorpPipe", 24 | "THUCNewsPipe", 25 | "WeiboSenti100kPipe", 26 | 27 | "Conll2003NERPipe", 28 | "OntoNotesNERPipe", 29 | "MsraNERPipe", 30 | "WeiboNERPipe", 31 | "PeopleDailyPipe", 32 | "Conll2003Pipe", 33 | 34 | "MatchingBertPipe", 35 | "RTEBertPipe", 36 | "SNLIBertPipe", 37 | "QuoraBertPipe", 38 | "QNLIBertPipe", 39 | "MNLIBertPipe", 40 | "CNXNLIBertPipe", 41 | "BQCorpusBertPipe", 42 | "LCQMCBertPipe", 43 | "MatchingPipe", 44 | "RTEPipe", 45 | "SNLIPipe", 46 | "QuoraPipe", 47 | "QNLIPipe", 48 | "MNLIPipe", 49 | "LCQMCPipe", 50 | "CNXNLIPipe", 51 | "BQCorpusPipe", 52 | "RenamePipe", 53 | "GranularizePipe", 54 | "MachingTruncatePipe", 55 | 56 | "CoReferencePipe", 57 | 58 | "CMRC2018BertPipe" 59 | ] 60 | 61 | from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \ 62 | WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe 63 | from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe 64 | from .conll import Conll2003Pipe 65 | from .coreference import CoReferencePipe 66 | from .cws import CWSPipe 67 | from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ 68 | MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, CNXNLIBertPipe, CNXNLIPipe, BQCorpusBertPipe, \ 69 | LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe 70 | from .pipe import Pipe 71 | from .qa import CMRC2018BertPipe 72 | -------------------------------------------------------------------------------- /fastNLP/io/pipe/pipe.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "Pipe", 5 | ] 6 | 7 | from .. import DataBundle 8 | 9 | 10 | class Pipe: 11 | r""" 12 | Pipe是fastNLP中用于处理DataBundle的类,但实际是处理DataBundle中的DataSet。所有Pipe都会在其process()函数的文档中指出该Pipe可处理的DataSet应该具备怎样的格式;在Pipe 13 | 文档中说明该Pipe返回后DataSet的格式以及其field的信息;以及新增的Vocabulary的信息。 14 | 15 | 一般情况下Pipe处理包含以下的几个过程,(1)将raw_words或raw_chars进行tokenize以切分成不同的词或字; 16 | (2) 再建立词或字的 :class:`~fastNLP.Vocabulary` , 并将词或字转换为index; (3)将target列建立词表并将target列转为index; 17 | 18 | Pipe中提供了两个方法 19 | 20 | -process()函数,输入为DataBundle 21 | -process_from_file()函数,输入为对应Loader的load函数可接受的类型。 22 | 23 | """ 24 | 25 | def process(self, data_bundle: DataBundle) -> DataBundle: 26 | r""" 27 | 对输入的DataBundle进行处理,然后返回该DataBundle。 28 | 29 | :param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象 30 | :return: 31 | """ 32 | raise NotImplementedError 33 | 34 | def process_from_file(self, paths) -> DataBundle: 35 | r""" 36 | 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` 37 | 38 | :param paths: 39 | :return: DataBundle 40 | """ 41 | raise NotImplementedError 42 | -------------------------------------------------------------------------------- /fastNLP/io/utils.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. 
todo:: 3 | doc 4 | """ 5 | 6 | __all__ = [ 7 | "check_loader_paths" 8 | ] 9 | 10 | import os 11 | from pathlib import Path 12 | from typing import Union, Dict 13 | 14 | from ..core import logger 15 | 16 | 17 | def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: 18 | r""" 19 | 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: 20 | 21 | { 22 | 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 23 | 'test': 'xxx' # 可能有,也可能没有 24 | ... 25 | } 26 | 27 | 如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 28 | 29 | :param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找包含train(文件名 30 | 中包含train这个字段), test, dev这三个字段的文件或文件夹; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 31 | :return: 32 | """ 33 | if isinstance(paths, (str, Path)): 34 | paths = os.path.abspath(os.path.expanduser(paths)) 35 | if os.path.isfile(paths): 36 | return {'train': paths} 37 | elif os.path.isdir(paths): 38 | filenames = os.listdir(paths) 39 | filenames.sort() 40 | files = {} 41 | for filename in filenames: 42 | path_pair = None 43 | if 'train' in filename: 44 | path_pair = ('train', filename) 45 | if 'dev' in filename: 46 | if path_pair: 47 | raise Exception( 48 | "Directory:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0])) 49 | path_pair = ('dev', filename) 50 | if 'test' in filename: 51 | if path_pair: 52 | raise Exception( 53 | "Directory:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0])) 54 | path_pair = ('test', filename) 55 | if path_pair: 56 | if path_pair[0] in files: 57 | raise FileExistsError(f"Two files contain `{path_pair[0]}` were found, please specify the " 58 | f"filepath for `{path_pair[0]}`.") 59 | files[path_pair[0]] = os.path.join(paths, path_pair[1]) 60 | if 'train' not in files: 61 | raise KeyError(f"There is no train file in {paths}.") 62 | return files 63 | else: 64 | raise FileNotFoundError(f"{paths} is not a valid file path.") 65 | 66 | elif isinstance(paths, dict): 67 | if paths: 68 | if 'train' not in paths: 69 | raise KeyError("You have to include `train` in your dict.") 70 | for key, value in paths.items(): 71 | if isinstance(key, str) and isinstance(value, str): 72 | value = os.path.abspath(os.path.expanduser(value)) 73 | if not os.path.exists(value): 74 | raise TypeError(f"{value} is not a valid path.") 75 | paths[key] = value 76 | else: 77 | raise TypeError("All keys and values in paths should be str.") 78 | return paths 79 | else: 80 | raise ValueError("Empty paths is not allowed.") 81 | else: 82 | raise TypeError(f"paths only supports str and dict. not {type(paths)}.") 83 | -------------------------------------------------------------------------------- /fastNLP/models/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models.CNNText` 、 3 | :class:`~fastNLP.models.SeqLabeling` 等完整的模型,以供用户直接使用。 4 | 5 | .. 
todo:: 6 | 这些模型的介绍(与主页一致) 7 | 8 | 9 | """ 10 | __all__ = [ 11 | "CNNText", 12 | 13 | "SeqLabeling", 14 | "AdvSeqLabel", 15 | "BiLSTMCRF", 16 | 17 | "ESIM", 18 | 19 | "StarTransEnc", 20 | "STSeqLabel", 21 | "STNLICls", 22 | "STSeqCls", 23 | 24 | "BiaffineParser", 25 | "GraphParser", 26 | 27 | "BertForSequenceClassification", 28 | "BertForSentenceMatching", 29 | "BertForMultipleChoice", 30 | "BertForTokenClassification", 31 | "BertForQuestionAnswering", 32 | 33 | "TransformerSeq2SeqModel", 34 | "LSTMSeq2SeqModel", 35 | "Seq2SeqModel", 36 | 37 | 'SequenceGeneratorModel' 38 | ] 39 | 40 | from .base_model import BaseModel 41 | from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \ 42 | BertForTokenClassification, BertForSentenceMatching 43 | from .biaffine_parser import BiaffineParser, GraphParser 44 | from .cnn_text_classification import CNNText 45 | from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF 46 | from .snli import ESIM 47 | from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel 48 | from .seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel, Seq2SeqModel 49 | from .seq2seq_generator import SequenceGeneratorModel 50 | import sys 51 | from ..doc_utils import doc_process 52 | 53 | doc_process(sys.modules[__name__]) 54 | -------------------------------------------------------------------------------- /fastNLP/models/base_model.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [] 4 | 5 | import torch 6 | 7 | from ..modules.decoder.mlp import MLP 8 | 9 | 10 | class BaseModel(torch.nn.Module): 11 | r"""Base PyTorch model for all models. 12 | """ 13 | 14 | def __init__(self): 15 | super(BaseModel, self).__init__() 16 | 17 | def fit(self, train_data, dev_data=None, **train_args): 18 | pass 19 | 20 | def predict(self, *args, **kwargs): 21 | raise NotImplementedError 22 | 23 | 24 | class NaiveClassifier(BaseModel): 25 | r""" 26 | 一个简单的分类器例子,可用于各种测试 27 | """ 28 | def __init__(self, in_feature_dim, out_feature_dim): 29 | super(NaiveClassifier, self).__init__() 30 | self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) 31 | 32 | def forward(self, x): 33 | return {"predict": torch.sigmoid(self.mlp(x))} 34 | 35 | def predict(self, x): 36 | return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} 37 | -------------------------------------------------------------------------------- /fastNLP/models/cnn_text_classification.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. todo:: 3 | doc 4 | """ 5 | 6 | __all__ = [ 7 | "CNNText" 8 | ] 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from ..core.const import Const as C 14 | from ..core.utils import seq_len_to_mask 15 | from ..embeddings import embedding 16 | from ..modules import encoder 17 | 18 | 19 | class CNNText(torch.nn.Module): 20 | r""" 21 | 使用CNN进行文本分类的模型 22 | 'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.' 
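A minimal usage sketch of CNNText (the vocabulary size, embedding dimension, class count, and batch shapes are illustrative assumptions, not values taken from this repository):

```
import torch
from fastNLP.models import CNNText

# hypothetical sizes: 3000-word vocabulary, 100-dim embeddings, 2 target classes
model = CNNText((3000, 100), num_classes=2, kernel_nums=(30, 40, 50), kernel_sizes=(1, 3, 5))
words = torch.randint(0, 3000, (4, 20))      # [batch_size, seq_len] word indices
seq_len = torch.tensor([20, 18, 15, 9])      # true length of each sentence
logits = model(words, seq_len)               # dict: the [4, 2] scores stored under the Const.OUTPUT key
preds = model.predict(words, seq_len)        # dict: the predicted class index for each sentence
```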
23 | 24 | """ 25 | 26 | def __init__(self, embed, 27 | num_classes, 28 | kernel_nums=(30, 40, 50), 29 | kernel_sizes=(1, 3, 5), 30 | dropout=0.5): 31 | r""" 32 | 33 | :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), 34 | 第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding 35 | :param int num_classes: 一共有多少类 36 | :param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 37 | :param float dropout: Dropout的大小 38 | """ 39 | super(CNNText, self).__init__() 40 | 41 | # no support for pre-trained embedding currently 42 | self.embed = embedding.Embedding(embed) 43 | self.conv_pool = encoder.ConvMaxpool( 44 | in_channels=self.embed.embedding_dim, 45 | out_channels=kernel_nums, 46 | kernel_sizes=kernel_sizes) 47 | self.dropout = nn.Dropout(dropout) 48 | self.fc = nn.Linear(sum(kernel_nums), num_classes) 49 | 50 | def forward(self, words, seq_len=None): 51 | r""" 52 | 53 | :param torch.LongTensor words: [batch_size, seq_len],句子中word的index 54 | :param torch.LongTensor seq_len: [batch,] 每个句子的长度 55 | :return output: dict of torch.LongTensor, [batch_size, num_classes] 56 | """ 57 | x = self.embed(words) # [N,L] -> [N,L,C] 58 | if seq_len is not None: 59 | mask = seq_len_to_mask(seq_len) 60 | x = self.conv_pool(x, mask) 61 | else: 62 | x = self.conv_pool(x) # [N,L,C] -> [N,C] 63 | x = self.dropout(x) 64 | x = self.fc(x) # [N,C] -> [N, N_class] 65 | return {C.OUTPUT: x} 66 | 67 | def predict(self, words, seq_len=None): 68 | r""" 69 | :param torch.LongTensor words: [batch_size, seq_len],句子中word的index 70 | :param torch.LongTensor seq_len: [batch,] 每个句子的长度 71 | 72 | :return predict: dict of torch.LongTensor, [batch_size, ] 73 | """ 74 | output = self(words, seq_len) 75 | _, predict = output[C.OUTPUT].max(dim=1) 76 | return {C.OUTPUT: predict} 77 | -------------------------------------------------------------------------------- /fastNLP/models/seq2seq_generator.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | import torch 4 | from torch import nn 5 | from .seq2seq_model import Seq2SeqModel 6 | from ..modules.generator.seq2seq_generator import SequenceGenerator 7 | 8 | 9 | class SequenceGeneratorModel(nn.Module): 10 | """ 11 | 用于封装Seq2SeqModel使其可以做生成任务 12 | 13 | """ 14 | 15 | def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, max_len_a=0.0, 16 | num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0, 17 | repetition_penalty=1, length_penalty=1.0, pad_token_id=0): 18 | """ 19 | 20 | :param Seq2SeqModel seq2seq_model: 序列到序列模型 21 | :param int,None bos_token_id: 句子开头的token id 22 | :param int,None eos_token_id: 句子结束的token id 23 | :param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len 24 | :param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask 25 | :param int num_beams: beam search的大小 26 | :param bool do_sample: 是否通过采样的方式生成 27 | :param float temperature: 只有在do_sample为True才有意义 28 | :param int top_k: 只从top_k中采样 29 | :param float top_p: 只从top_p的token中采样,nucles sample 30 | :param float repetition_penalty: 多大程度上惩罚重复的token 31 | :param float length_penalty: 对长度的惩罚,小于1鼓励长句,大于1鼓励短剧 32 | :param int pad_token_id: 当某句话生成结束之后,之后生成的内容用pad_token_id补充 33 | """ 34 | super().__init__() 35 | self.seq2seq_model = seq2seq_model 36 | self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, max_len_a=max_len_a, 37 | num_beams=num_beams, 38 
| do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p, 39 | bos_token_id=bos_token_id, 40 | eos_token_id=eos_token_id, 41 | repetition_penalty=repetition_penalty, length_penalty=length_penalty, 42 | pad_token_id=pad_token_id) 43 | 44 | def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): 45 | """ 46 | 透传调用seq2seq_model的forward 47 | 48 | :param torch.LongTensor src_tokens: bsz x max_len 49 | :param torch.LongTensor tgt_tokens: bsz x max_len' 50 | :param torch.LongTensor src_seq_len: bsz 51 | :param torch.LongTensor tgt_seq_len: bsz 52 | :return: 53 | """ 54 | return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) 55 | 56 | def predict(self, src_tokens, src_seq_len=None): 57 | """ 58 | 给定source的内容,输出generate的内容 59 | 60 | :param torch.LongTensor src_tokens: bsz x max_len 61 | :param torch.LongTensor src_seq_len: bsz 62 | :return: 63 | """ 64 | state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len) 65 | result = self.generator.generate(state) 66 | return {'pred': result} 67 | -------------------------------------------------------------------------------- /fastNLP/modules/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 3 | .. image:: figures/text_classification.png 4 | 5 | 大部分用于的 NLP 任务神经网络都可以看做由 :mod:`embedding` 、 :mod:`~fastNLP.modules.encoder` 、 6 | :mod:`~fastNLP.modules.decoder` 三种模块组成。 本模块中实现了 fastNLP 提供的诸多模块组件, 7 | 可以帮助用户快速搭建自己所需的网络。几种模块的功能和常见组件如下: 8 | 9 | .. csv-table:: 10 | :header: "类型", "功能", "常见组件" 11 | 12 | "embedding", 参见 :mod:`/fastNLP.embeddings` , "Elmo, Bert" 13 | "encoder", "将输入编码为具有表示能力的向量", "CNN, LSTM, Transformer" 14 | "decoder", "将具有某种表示意义的向量解码为需要的输出形式 ", "MLP, CRF" 15 | "其它", "配合其它组件使用的组件", "Dropout" 16 | 17 | 18 | """ 19 | __all__ = [ 20 | # "BertModel", 21 | 22 | "ConvolutionCharEncoder", 23 | "LSTMCharEncoder", 24 | 25 | "ConvMaxpool", 26 | 27 | "LSTM", 28 | 29 | "StarTransformer", 30 | 31 | "TransformerEncoder", 32 | 33 | "VarRNN", 34 | "VarLSTM", 35 | "VarGRU", 36 | 37 | "MaxPool", 38 | "MaxPoolWithMask", 39 | "KMaxPool", 40 | "AvgPool", 41 | "AvgPoolWithMask", 42 | 43 | "MultiHeadAttention", 44 | 45 | "MLP", 46 | "ConditionalRandomField", 47 | "viterbi_decode", 48 | "allowed_transitions", 49 | 50 | "TimestepDropout", 51 | 52 | 'summary', 53 | 54 | "BertTokenizer", 55 | "BertModel", 56 | 57 | "RobertaTokenizer", 58 | "RobertaModel", 59 | 60 | "GPT2Model", 61 | "GPT2Tokenizer", 62 | 63 | "TransformerSeq2SeqEncoder", 64 | "LSTMSeq2SeqEncoder", 65 | "Seq2SeqEncoder", 66 | 67 | "TransformerSeq2SeqDecoder", 68 | "LSTMSeq2SeqDecoder", 69 | "Seq2SeqDecoder", 70 | 71 | "TransformerState", 72 | "LSTMState", 73 | "State", 74 | 75 | "SequenceGenerator" 76 | ] 77 | 78 | import sys 79 | 80 | from . import decoder 81 | from . import encoder 82 | from .decoder import * 83 | from .dropout import TimestepDropout 84 | from .encoder import * 85 | from .generator import * 86 | from .utils import summary 87 | from ..doc_utils import doc_process 88 | from .tokenizer import * 89 | 90 | doc_process(sys.modules[__name__]) 91 | -------------------------------------------------------------------------------- /fastNLP/modules/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. 
todo:: 3 | doc 4 | """ 5 | __all__ = [ 6 | "MLP", 7 | "ConditionalRandomField", 8 | "viterbi_decode", 9 | "allowed_transitions", 10 | 11 | "LSTMState", 12 | "TransformerState", 13 | "State", 14 | 15 | "TransformerSeq2SeqDecoder", 16 | "LSTMSeq2SeqDecoder", 17 | "Seq2SeqDecoder" 18 | ] 19 | 20 | from .crf import ConditionalRandomField 21 | from .crf import allowed_transitions 22 | from .mlp import MLP 23 | from .utils import viterbi_decode 24 | from .seq2seq_decoder import Seq2SeqDecoder, LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder 25 | from .seq2seq_state import State, LSTMState, TransformerState 26 | -------------------------------------------------------------------------------- /fastNLP/modules/decoder/mlp.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "MLP" 5 | ] 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from ..utils import initial_parameter 11 | 12 | 13 | class MLP(nn.Module): 14 | r""" 15 | 多层感知器 16 | 17 | 18 | .. note:: 19 | 隐藏层的激活函数通过activation定义。一个str/function或者一个str/function的list可以被传入activation。 20 | 如果只传入了一个str/function,那么所有隐藏层的激活函数都由这个str/function定义; 21 | 如果传入了一个str/function的list,那么每一个隐藏层的激活函数由这个list中对应的元素定义,其中list的长度为隐藏层数。 22 | 输出层的激活函数由output_activation定义,默认值为None,此时输出层没有激活函数。 23 | 24 | Examples:: 25 | 26 | >>> net1 = MLP([5, 10, 5]) 27 | >>> net2 = MLP([5, 10, 5], 'tanh') 28 | >>> net3 = MLP([5, 6, 7, 8, 5], 'tanh') 29 | >>> net4 = MLP([5, 6, 7, 8, 5], 'relu', output_activation='tanh') 30 | >>> net5 = MLP([5, 6, 7, 8, 5], ['tanh', 'relu', 'tanh'], 'tanh') 31 | >>> for net in [net1, net2, net3, net4, net5]: 32 | >>> x = torch.randn(5, 5) 33 | >>> y = net(x) 34 | >>> print(x) 35 | >>> print(y) 36 | """ 37 | 38 | def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0): 39 | r""" 40 | 41 | :param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层是hidden数目。MLP的层数为 len(size_layer) - 1 42 | :param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和 43 | sigmoid,默认值为relu 44 | :param Union[str,func] output_activation: 字符串或者函数,用来定义输出层的激活函数,默认值为None,表示输出层没有激活函数 45 | :param str initial_method: 参数初始化方式 46 | :param float dropout: dropout概率,默认值为0 47 | """ 48 | super(MLP, self).__init__() 49 | self.hiddens = nn.ModuleList() 50 | self.output = None 51 | self.output_activation = output_activation 52 | for i in range(1, len(size_layer)): 53 | if i + 1 == len(size_layer): 54 | self.output = nn.Linear(size_layer[i - 1], size_layer[i]) 55 | else: 56 | self.hiddens.append(nn.Linear(size_layer[i - 1], size_layer[i])) 57 | 58 | self.dropout = nn.Dropout(p=dropout) 59 | 60 | actives = { 61 | 'relu': nn.ReLU(), 62 | 'tanh': nn.Tanh(), 63 | 'sigmoid': nn.Sigmoid(), 64 | } 65 | if not isinstance(activation, list): 66 | activation = [activation] * (len(size_layer) - 2) 67 | elif len(activation) == len(size_layer) - 2: 68 | pass 69 | else: 70 | raise ValueError( 71 | f"the length of activation function list except {len(size_layer) - 2} but got {len(activation)}!") 72 | self.hidden_active = [] 73 | for func in activation: 74 | if callable(activation): 75 | self.hidden_active.append(activation) 76 | elif func.lower() in actives: 77 | self.hidden_active.append(actives[func]) 78 | else: 79 | raise ValueError("should set activation correctly: {}".format(activation)) 80 | if self.output_activation is not None: 81 | if callable(self.output_activation): 82 | pass 83 | elif self.output_activation.lower() in actives: 84 | 
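                # look up the named activation (e.g. 'tanh') and replace the string with the nn.Module instance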
self.output_activation = actives[self.output_activation] 85 | else: 86 | raise ValueError("should set activation correctly: {}".format(activation)) 87 | initial_parameter(self, initial_method) 88 | 89 | def forward(self, x): 90 | r""" 91 | :param torch.Tensor x: MLP接受的输入 92 | :return: torch.Tensor : MLP的输出结果 93 | """ 94 | for layer, func in zip(self.hiddens, self.hidden_active): 95 | x = self.dropout(func(layer(x))) 96 | x = self.output(x) 97 | if self.output_activation is not None: 98 | x = self.output_activation(x) 99 | x = self.dropout(x) 100 | return x 101 | -------------------------------------------------------------------------------- /fastNLP/modules/decoder/utils.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "viterbi_decode" 5 | ] 6 | import torch 7 | 8 | 9 | def viterbi_decode(logits, transitions, mask=None, unpad=False): 10 | r""" 11 | 给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 12 | 13 | :param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 14 | :param torch.FloatTensor transitions: n_tags x n_tags,[i, j]位置的值认为是从tag i到tag j的转换; 或者(n_tags+2) x 15 | (n_tags+2), 其中n_tag是start的index, n_tags+1是end的index; 如果要i->j之间不允许越迁,就把transitions中(i,j)设置为很小的 16 | 负数,例如-10000000.0 17 | :param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 18 | :param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 19 | List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 20 | 个sample的有效长度。 21 | :return: 返回 (paths, scores)。 22 | paths: 是解码后的路径, 其值参照unpad参数. 23 | scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 24 | 25 | """ 26 | batch_size, seq_len, n_tags = logits.size() 27 | if transitions.size(0) == n_tags+2: 28 | include_start_end_trans = True 29 | elif transitions.size(0) == n_tags: 30 | include_start_end_trans = False 31 | else: 32 | raise RuntimeError("The shapes of transitions and feats are not " \ 33 | "compatible.") 34 | logits = logits.transpose(0, 1).data # L, B, H 35 | if mask is not None: 36 | mask = mask.transpose(0, 1).data.eq(True) # L, B 37 | else: 38 | mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8).eq(1) 39 | 40 | trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data 41 | 42 | # dp 43 | vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) 44 | vscore = logits[0] 45 | if include_start_end_trans: 46 | vscore += transitions[n_tags, :n_tags] 47 | 48 | for i in range(1, seq_len): 49 | prev_score = vscore.view(batch_size, n_tags, 1) 50 | cur_score = logits[i].view(batch_size, 1, n_tags) 51 | score = prev_score + trans_score + cur_score 52 | best_score, best_dst = score.max(1) 53 | vpath[i] = best_dst 54 | vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \ 55 | vscore.masked_fill(mask[i].view(batch_size, 1), 0) 56 | 57 | if include_start_end_trans: 58 | vscore += transitions[:n_tags, n_tags + 1].view(1, -1) 59 | # backtrace 60 | batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) 61 | seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) 62 | lens = (mask.long().sum(0) - 1) 63 | # idxes [L, B], batched idx from seq_len-1 to 0 64 | idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len 65 | 66 | ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) 67 | ans_score, last_tags = vscore.max(1) 68 | ans[idxes[0], batch_idx] = last_tags 69 | for i in range(seq_len - 1): 70 | last_tags = 
vpath[idxes[i], batch_idx, last_tags] 71 | ans[idxes[i + 1], batch_idx] = last_tags 72 | ans = ans.transpose(0, 1) 73 | if unpad: 74 | paths = [] 75 | for idx, seq_len in enumerate(lens): 76 | paths.append(ans[idx, :seq_len + 1].tolist()) 77 | else: 78 | paths = ans 79 | return paths, ans_score 80 | -------------------------------------------------------------------------------- /fastNLP/modules/dropout.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "TimestepDropout" 5 | ] 6 | 7 | import torch 8 | 9 | 10 | class TimestepDropout(torch.nn.Dropout): 11 | r""" 12 | 传入参数的shape为 ``(batch_size, num_timesteps, embedding_dim)`` 13 | 使用同一个shape为 ``(batch_size, embedding_dim)`` 的mask在每个timestamp上做dropout。 14 | """ 15 | 16 | def forward(self, x): 17 | dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) 18 | torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) 19 | dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim] 20 | if self.inplace: 21 | x *= dropout_mask 22 | return 23 | else: 24 | return x * dropout_mask 25 | -------------------------------------------------------------------------------- /fastNLP/modules/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. todo:: 3 | doc 4 | """ 5 | 6 | __all__ = [ 7 | "ConvolutionCharEncoder", 8 | "LSTMCharEncoder", 9 | 10 | "ConvMaxpool", 11 | 12 | "LSTM", 13 | 14 | "StarTransformer", 15 | 16 | "TransformerEncoder", 17 | 18 | "VarRNN", 19 | "VarLSTM", 20 | "VarGRU", 21 | 22 | "MaxPool", 23 | "MaxPoolWithMask", 24 | "KMaxPool", 25 | "AvgPool", 26 | "AvgPoolWithMask", 27 | 28 | "MultiHeadAttention", 29 | "BiAttention", 30 | "SelfAttention", 31 | 32 | "BertModel", 33 | 34 | "RobertaModel", 35 | 36 | "GPT2Model", 37 | 38 | "LSTMSeq2SeqEncoder", 39 | "TransformerSeq2SeqEncoder", 40 | "Seq2SeqEncoder" 41 | ] 42 | 43 | from fastNLP.modules.attention import MultiHeadAttention, BiAttention, SelfAttention 44 | from .bert import BertModel 45 | from .roberta import RobertaModel 46 | from .gpt2 import GPT2Model 47 | from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder 48 | from .conv_maxpool import ConvMaxpool 49 | from .lstm import LSTM 50 | from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPool 51 | from .star_transformer import StarTransformer 52 | from .transformer import TransformerEncoder 53 | from .variational_rnn import VarRNN, VarLSTM, VarGRU 54 | from .seq2seq_encoder import LSTMSeq2SeqEncoder, TransformerSeq2SeqEncoder, Seq2SeqEncoder 55 | -------------------------------------------------------------------------------- /fastNLP/modules/encoder/char_encoder.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "ConvolutionCharEncoder", 5 | "LSTMCharEncoder" 6 | ] 7 | import torch 8 | import torch.nn as nn 9 | 10 | from ..utils import initial_parameter 11 | 12 | 13 | # from torch.nn.init import xavier_uniform 14 | class ConvolutionCharEncoder(nn.Module): 15 | r""" 16 | char级别的卷积编码器. 17 | 18 | """ 19 | 20 | def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(1, 3, 5), initial_method=None): 21 | r""" 22 | 23 | :param int char_emb_size: char级别embedding的维度. Default: 50 24 | :例: 有26个字符, 每一个的embedding是一个50维的向量, 所以输入的向量维度为50. 25 | :param tuple feature_maps: 一个由int组成的tuple. tuple的长度是char级别卷积操作的数目, 第`i`个int表示第`i`个卷积操作的filter. 
26 | :param tuple kernels: 一个由int组成的tuple. tuple的长度是char级别卷积操作的数目, 第`i`个int表示第`i`个卷积操作的卷积核. 27 | :param initial_method: 初始化参数的方式, 默认为`xavier normal` 28 | """ 29 | super(ConvolutionCharEncoder, self).__init__() 30 | self.convs = nn.ModuleList([ 31 | nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, 32 | padding=(0, kernels[i] // 2)) 33 | for i in range(len(kernels))]) 34 | 35 | initial_parameter(self, initial_method) 36 | 37 | def forward(self, x): 38 | r""" 39 | :param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` 输入字符的embedding 40 | :return: torch.Tensor : 卷积计算的结果, 维度为[batch_size * sent_length, sum(feature_maps), 1] 41 | """ 42 | x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2)) 43 | # [batch_size*sent_length, channel, width, height] 44 | x = x.transpose(2, 3) 45 | # [batch_size*sent_length, channel, height, width] 46 | return self._convolute(x).unsqueeze(2) 47 | 48 | def _convolute(self, x): 49 | feats = [] 50 | for conv in self.convs: 51 | y = conv(x) 52 | # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1] 53 | y = torch.squeeze(y, 2) 54 | # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1] 55 | y = torch.tanh(y) 56 | y, __ = torch.max(y, 2) 57 | # [batch_size*sent_length, feature_maps[i]] 58 | feats.append(y) 59 | return torch.cat(feats, 1) # [batch_size*sent_length, sum(feature_maps)] 60 | 61 | 62 | class LSTMCharEncoder(nn.Module): 63 | r""" 64 | char级别基于LSTM的encoder. 65 | """ 66 | 67 | def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None): 68 | r""" 69 | :param int char_emb_size: char级别embedding的维度. Default: 50 70 | 例: 有26个字符, 每一个的embedding是一个50维的向量, 所以输入的向量维度为50. 71 | :param int hidden_size: LSTM隐层的大小, 默认为char的embedding维度 72 | :param initial_method: 初始化参数的方式, 默认为`xavier normal` 73 | """ 74 | super(LSTMCharEncoder, self).__init__() 75 | self.hidden_size = char_emb_size if hidden_size is None else hidden_size 76 | 77 | self.lstm = nn.LSTM(input_size=char_emb_size, 78 | hidden_size=self.hidden_size, 79 | num_layers=1, 80 | bias=True, 81 | batch_first=True) 82 | initial_parameter(self, initial_method) 83 | 84 | def forward(self, x): 85 | r""" 86 | :param torch.Tensor x: ``[ n_batch*n_word, word_length, char_emb_size]`` 输入字符的embedding 87 | :return: torch.Tensor : [ n_batch*n_word, char_emb_size]经过LSTM编码的结果 88 | """ 89 | batch_size = x.shape[0] 90 | h0 = torch.empty(1, batch_size, self.hidden_size) 91 | h0 = nn.init.orthogonal_(h0) 92 | c0 = torch.empty(1, batch_size, self.hidden_size) 93 | c0 = nn.init.orthogonal_(c0) 94 | 95 | _, hidden = self.lstm(x, (h0, c0)) 96 | return hidden[0].squeeze().unsqueeze(2) 97 | -------------------------------------------------------------------------------- /fastNLP/modules/encoder/conv_maxpool.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "ConvMaxpool" 5 | ] 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class ConvMaxpool(nn.Module): 12 | r""" 13 | 集合了Convolution和Max-Pooling于一体的层。给定一个batch_size x max_len x input_size的输入,返回batch_size x 14 | sum(output_channels) 大小的matrix。在内部,是先使用CNN给输入做卷积,然后经过activation激活层,在通过在长度(max_len) 15 | 这一维进行max_pooling。最后得到每个sample的一个向量表示。 16 | 17 | """ 18 | 19 | def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"): 20 | r""" 21 | 22 | :param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度 23 | :param int,tuple(int) out_channels: 
输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 24 | :param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 25 | :param str activation: Convolution后的结果将通过该activation后再经过max-pooling。支持relu, sigmoid, tanh 26 | """ 27 | super(ConvMaxpool, self).__init__() 28 | 29 | for kernel_size in kernel_sizes: 30 | assert kernel_size % 2 == 1, "kernel size has to be odd numbers." 31 | 32 | # convolution 33 | if isinstance(kernel_sizes, (list, tuple, int)): 34 | if isinstance(kernel_sizes, int) and isinstance(out_channels, int): 35 | out_channels = [out_channels] 36 | kernel_sizes = [kernel_sizes] 37 | elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)): 38 | assert len(out_channels) == len( 39 | kernel_sizes), "The number of out_channels should be equal to the number" \ 40 | " of kernel_sizes." 41 | else: 42 | raise ValueError("The type of out_channels and kernel_sizes should be the same.") 43 | 44 | self.convs = nn.ModuleList([nn.Conv1d( 45 | in_channels=in_channels, 46 | out_channels=oc, 47 | kernel_size=ks, 48 | stride=1, 49 | padding=ks // 2, 50 | dilation=1, 51 | groups=1, 52 | bias=None) 53 | for oc, ks in zip(out_channels, kernel_sizes)]) 54 | 55 | else: 56 | raise Exception( 57 | 'Incorrect kernel sizes: should be list, tuple or int') 58 | 59 | # activation function 60 | if activation == 'relu': 61 | self.activation = F.relu 62 | elif activation == 'sigmoid': 63 | self.activation = F.sigmoid 64 | elif activation == 'tanh': 65 | self.activation = F.tanh 66 | else: 67 | raise Exception( 68 | "Undefined activation function: choose from: relu, tanh, sigmoid") 69 | 70 | def forward(self, x, mask=None): 71 | r""" 72 | 73 | :param torch.FloatTensor x: batch_size x max_len x input_size, 一般是经过embedding后的值 74 | :param mask: batch_size x max_len, pad的地方为0。不影响卷积运算,max-pool一定不会pool到pad为0的位置 75 | :return: 76 | """ 77 | # [N,L,C] -> [N,C,L] 78 | x = torch.transpose(x, 1, 2) 79 | # convolution 80 | xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] 81 | if mask is not None: 82 | mask = mask.unsqueeze(1) # B x 1 x L 83 | xs = [x.masked_fill_(mask.eq(False), float('-inf')) for x in xs] 84 | # max-pooling 85 | xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) 86 | for i in xs] # [[N, C], ...] 87 | return torch.cat(xs, dim=-1) # [N, C] 88 | -------------------------------------------------------------------------------- /fastNLP/modules/encoder/lstm.py: -------------------------------------------------------------------------------- 1 | r"""undocumented 2 | 轻量封装的 Pytorch LSTM 模块. 3 | 可在 forward 时传入序列的长度, 自动对padding做合适的处理. 4 | """ 5 | 6 | __all__ = [ 7 | "LSTM" 8 | ] 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.utils.rnn as rnn 13 | 14 | 15 | class LSTM(nn.Module): 16 | r""" 17 | LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 18 | 为1; 且可以应对DataParallel中LSTM的使用问题。 19 | 20 | """ 21 | 22 | def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, 23 | bidirectional=False, bias=True): 24 | r""" 25 | 26 | :param input_size: 输入 `x` 的特征维度 27 | :param hidden_size: 隐状态 `h` 的特征维度. 如果bidirectional为True,则输出的维度会是hidde_size*2 28 | :param num_layers: rnn的层数. Default: 1 29 | :param dropout: 层间dropout概率. Default: 0 30 | :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` 31 | :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 32 | :(batch, seq, feature). Default: ``False`` 33 | :param bias: 如果为 ``False``, 模型将不会使用bias. 
Default: ``True`` 34 | """ 35 | super(LSTM, self).__init__() 36 | self.batch_first = batch_first 37 | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, 38 | dropout=dropout, bidirectional=bidirectional) 39 | self.init_param() 40 | 41 | def init_param(self): 42 | for name, param in self.named_parameters(): 43 | if 'bias' in name: 44 | # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871 45 | param.data.fill_(0) 46 | n = param.size(0) 47 | start, end = n // 4, n // 2 48 | param.data[start:end].fill_(1) 49 | else: 50 | nn.init.xavier_uniform_(param) 51 | 52 | def forward(self, x, seq_len=None, h0=None, c0=None): 53 | r""" 54 | 55 | :param x: [batch, seq_len, input_size] 输入序列 56 | :param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` 57 | :param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` 58 | :param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` 59 | :return (output, (ht, ct)): output: [batch, seq_len, hidden_size*num_direction] 输出序列 60 | 和 ht,ct: [num_layers*num_direction, batch, hidden_size] 最后时刻隐状态. 61 | """ 62 | batch_size, max_len, _ = x.size() 63 | if h0 is not None and c0 is not None: 64 | hx = (h0, c0) 65 | else: 66 | hx = None 67 | if seq_len is not None and not isinstance(x, rnn.PackedSequence): 68 | sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) 69 | if self.batch_first: 70 | x = x[sort_idx] 71 | else: 72 | x = x[:, sort_idx] 73 | x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) 74 | output, hx = self.lstm(x, hx) # -> [N,L,C] 75 | output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) 76 | _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) 77 | if self.batch_first: 78 | output = output[unsort_idx] 79 | else: 80 | output = output[:, unsort_idx] 81 | hx = hx[0][:, unsort_idx], hx[1][:, unsort_idx] 82 | else: 83 | output, hx = self.lstm(x, hx) 84 | return output, hx 85 | -------------------------------------------------------------------------------- /fastNLP/modules/encoder/pooling.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "MaxPool", 5 | "MaxPoolWithMask", 6 | "KMaxPool", 7 | "AvgPool", 8 | "AvgPoolWithMask" 9 | ] 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class MaxPool(nn.Module): 15 | r""" 16 | Max-pooling模块。 17 | 18 | """ 19 | 20 | def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, ceil_mode=False): 21 | r""" 22 | 23 | :param stride: 窗口移动大小,默认为kernel_size 24 | :param padding: padding的内容,默认为0 25 | :param dilation: 控制窗口内元素移动距离的大小 26 | :param dimension: MaxPool的维度,支持1,2,3维。 27 | :param kernel_size: max pooling的窗口大小,默认为tensor最后k维,其中k为dimension 28 | :param ceil_mode: 29 | """ 30 | super(MaxPool, self).__init__() 31 | assert dimension in [1, 2, 3], f'Now we only support 1d, 2d, or 3d Pooling' 32 | self.dimension = dimension 33 | self.stride = stride 34 | self.padding = padding 35 | self.dilation = dilation 36 | self.kernel_size = kernel_size 37 | self.ceil_mode = ceil_mode 38 | 39 | def forward(self, x): 40 | if self.dimension == 1: 41 | x = torch.transpose(x, 1, 2) # [N,L,C] -> [N,C,L] 42 | pooling = nn.MaxPool1d( 43 | stride=self.stride, padding=self.padding, dilation=self.dilation, 44 | kernel_size=self.kernel_size if self.kernel_size is not None else x.size(-1), 45 | return_indices=False, 
ceil_mode=self.ceil_mode 46 | ) 47 | elif self.dimension == 2: 48 | pooling = nn.MaxPool2d( 49 | stride=self.stride, padding=self.padding, dilation=self.dilation, 50 | kernel_size=self.kernel_size if self.kernel_size is not None else (x.size(-2), x.size(-1)), 51 | return_indices=False, ceil_mode=self.ceil_mode 52 | ) 53 | else: 54 | pooling = nn.MaxPool3d( 55 | stride=self.stride, padding=self.padding, dilation=self.dilation, 56 | kernel_size=self.kernel_size if self.kernel_size is not None else (x.size(-3), x.size(-2), x.size(-1)), 57 | return_indices=False, ceil_mode=self.ceil_mode 58 | ) 59 | x = pooling(x) 60 | return x.squeeze(dim=-1) # [N,C,1] -> [N,C] 61 | 62 | 63 | class MaxPoolWithMask(nn.Module): 64 | r""" 65 | 带mask矩阵的max pooling。在做max-pooling的时候不会考虑mask值为0的位置。 66 | """ 67 | 68 | def __init__(self): 69 | super(MaxPoolWithMask, self).__init__() 70 | self.inf = 10e12 71 | 72 | def forward(self, tensor, mask, dim=1): 73 | r""" 74 | :param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor 75 | :param torch.LongTensor mask: [batch_size, seq_len] 0/1的mask矩阵 76 | :param int dim: 需要进行max pooling的维度 77 | :return: 78 | """ 79 | masks = mask.view(mask.size(0), mask.size(1), -1) 80 | masks = masks.expand(-1, -1, tensor.size(2)).float() 81 | return torch.max(tensor + masks.le(0.5).float() * -self.inf, dim=dim)[0] 82 | 83 | 84 | class KMaxPool(nn.Module): 85 | r"""K max-pooling module.""" 86 | 87 | def __init__(self, k=1): 88 | super(KMaxPool, self).__init__() 89 | self.k = k 90 | 91 | def forward(self, x): 92 | r""" 93 | :param torch.Tensor x: [N, C, L] 初始tensor 94 | :return: torch.Tensor x: [N, C*k] k-max pool后的结果 95 | """ 96 | x, index = torch.topk(x, self.k, dim=-1, sorted=False) 97 | x = torch.reshape(x, (x.size(0), -1)) 98 | return x 99 | 100 | 101 | class AvgPool(nn.Module): 102 | r""" 103 | 给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size] 104 | """ 105 | 106 | def __init__(self, stride=None, padding=0): 107 | super(AvgPool, self).__init__() 108 | self.stride = stride 109 | self.padding = padding 110 | 111 | def forward(self, x): 112 | r""" 113 | :param torch.Tensor x: [N, C, L] 初始tensor 114 | :return: torch.Tensor x: [N, C] avg pool后的结果 115 | """ 116 | # [N,C,L] -> [N,C] 117 | kernel_size = x.size(2) 118 | pooling = nn.AvgPool1d( 119 | kernel_size=kernel_size, 120 | stride=self.stride, 121 | padding=self.padding) 122 | x = pooling(x) 123 | return x.squeeze(dim=-1) 124 | 125 | 126 | class AvgPoolWithMask(nn.Module): 127 | r""" 128 | 给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 
输出为[batch_size, hidden_size], pooling 129 | 的时候只会考虑mask为1的位置 130 | """ 131 | 132 | def __init__(self): 133 | super(AvgPoolWithMask, self).__init__() 134 | self.inf = 10e12 135 | 136 | def forward(self, tensor, mask, dim=1): 137 | r""" 138 | :param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor 139 | :param torch.LongTensor mask: [batch_size, seq_len] 0/1的mask矩阵 140 | :param int dim: 需要进行max pooling的维度 141 | :return: 142 | """ 143 | masks = mask.view(mask.size(0), mask.size(1), -1).float() 144 | return torch.sum(tensor * masks.float(), dim=dim) / torch.sum(masks.float(), dim=1) 145 | -------------------------------------------------------------------------------- /fastNLP/modules/encoder/transformer.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "TransformerEncoder" 5 | ] 6 | from torch import nn 7 | 8 | from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer 9 | 10 | 11 | class TransformerEncoder(nn.Module): 12 | r""" 13 | transformer的encoder模块,不包含embedding层 14 | 15 | """ 16 | def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1): 17 | """ 18 | 19 | :param int num_layers: 多少层Transformer 20 | :param int d_model: input和output的大小 21 | :param int n_head: 多少个head 22 | :param int dim_ff: FFN中间hidden大小 23 | :param float dropout: 多大概率drop attention和ffn中间的表示 24 | """ 25 | super(TransformerEncoder, self).__init__() 26 | self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff, 27 | dropout = dropout) for _ in range(num_layers)]) 28 | self.norm = nn.LayerNorm(d_model, eps=1e-6) 29 | 30 | def forward(self, x, seq_mask=None): 31 | r""" 32 | :param x: [batch, seq_len, model_size] 输入序列 33 | :param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. 
Positions with mask value 1 are attended to. 34 | Default: ``None`` 35 | :return: [batch, seq_len, model_size] the encoded output sequence 36 | """ 37 | output = x 38 | if seq_mask is None: 39 | seq_mask = x.new_ones(x.size(0), x.size(1)).bool() 40 | for layer in self.layers: 41 | output = layer(output, seq_mask) 42 | return self.norm(output) 43 | -------------------------------------------------------------------------------- /fastNLP/modules/generator/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 3 | """ 4 | 5 | __all__ = [ 6 | "SequenceGenerator" 7 | ] 8 | 9 | from .seq2seq_generator import SequenceGenerator -------------------------------------------------------------------------------- /fastNLP/modules/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 3 | """ 4 | __all__=[ 5 | 'BertTokenizer', 6 | 7 | "GPT2Tokenizer", 8 | 9 | "RobertaTokenizer" 10 | ] 11 | 12 | from .bert_tokenizer import BertTokenizer 13 | from .gpt2_tokenizer import GPT2Tokenizer 14 | from .roberta_tokenizer import RobertaTokenizer -------------------------------------------------------------------------------- /fastNLP/modules/tokenizer/roberta_tokenizer.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 3 | """ 4 | 5 | __all__ = [ 6 | "RobertaTokenizer" 7 | ] 8 | 9 | import json 10 | from .gpt2_tokenizer import GPT2Tokenizer 11 | from fastNLP.io.file_utils import _get_file_name_base_on_postfix 12 | from ...io.file_utils import _get_roberta_dir 13 | 14 | PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES = { 15 | "roberta-base": 512, 16 | "roberta-large": 512, 17 | "roberta-large-mnli": 512, 18 | "distilroberta-base": 512, 19 | "roberta-base-openai-detector": 512, 20 | "roberta-large-openai-detector": 512, 21 | } 22 | 23 | 24 | class RobertaTokenizer(GPT2Tokenizer): 25 | 26 | vocab_files_names = { 27 | "vocab_file": "vocab.json", 28 | "merges_file": "merges.txt", 29 | } 30 | 31 | def __init__( 32 | self, 33 | vocab_file, 34 | merges_file, 35 | errors="replace", 36 | bos_token="<s>", 37 | eos_token="</s>", 38 | sep_token="</s>", 39 | cls_token="<s>", 40 | unk_token="<unk>", 41 | pad_token="<pad>", 42 | mask_token="<mask>", 43 | **kwargs 44 | ): 45 | super().__init__( 46 | vocab_file=vocab_file, 47 | merges_file=merges_file, 48 | errors=errors, 49 | bos_token=bos_token, 50 | eos_token=eos_token, 51 | unk_token=unk_token, 52 | sep_token=sep_token, 53 | cls_token=cls_token, 54 | pad_token=pad_token, 55 | mask_token=mask_token, 56 | **kwargs, 57 | ) 58 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 59 | self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens 60 | 61 | @classmethod 62 | def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): 63 | """ 64 | 65 | :param str model_dir_or_name: a local model directory or a shortcut model name 66 | :param kwargs: 67 | :return: 68 | """ 69 | # Two files are needed: the first is vocab.json, the second is merges.txt.
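# Illustrative sketch (not part of the original fastNLP code): once the model directory is
# resolved, typical usage looks roughly like the following, assuming the directory contains
# vocab.json and merges.txt and that the GPT2Tokenizer base class exposes the usual
# tokenize/convert_tokens_to_ids helpers:
#   tokenizer = RobertaTokenizer.from_pretrained('path/to/roberta-dir')
#   ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize('a short sentence'))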
70 | model_dir = _get_roberta_dir(model_dir_or_name) 71 | # 里面会包含四个文件vocab.json, merge.txt, config.json, model.bin 72 | 73 | tokenizer_config_file = _get_file_name_base_on_postfix(model_dir, 'config.json') 74 | with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: 75 | init_kwargs = json.load(tokenizer_config_handle) 76 | # Set max length if needed 77 | if model_dir_or_name in PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES: 78 | # if we're using a pretrained model, ensure the tokenizer 79 | # wont index sequences longer than the number of positional embeddings 80 | max_len = PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name] 81 | if max_len is not None and isinstance(max_len, (int, float)): 82 | init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len) 83 | 84 | # 将vocab, merge加入到init_kwargs中 85 | if 'vocab_file' in kwargs: # 如果指定了词表则用指定词表 86 | init_kwargs['vocab_file'] = kwargs['vocab_file'] 87 | else: 88 | init_kwargs['vocab_file'] = _get_file_name_base_on_postfix(model_dir, RobertaTokenizer.vocab_files_names['vocab_file']) 89 | init_kwargs['merges_file'] = _get_file_name_base_on_postfix(model_dir, RobertaTokenizer.vocab_files_names['merges_file']) 90 | 91 | init_inputs = init_kwargs.pop("init_inputs", ()) 92 | # Instantiate tokenizer. 93 | try: 94 | tokenizer = cls(*init_inputs, **init_kwargs) 95 | except OSError: 96 | OSError( 97 | "Unable to load vocabulary from file. " 98 | "Please check that the provided vocabulary is accessible and not corrupted." 99 | ) 100 | 101 | return tokenizer 102 | 103 | -------------------------------------------------------------------------------- /fastNLP/modules/utils.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. todo:: 3 | doc 4 | """ 5 | 6 | __all__ = [ 7 | "initial_parameter", 8 | "summary" 9 | ] 10 | 11 | from functools import reduce 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.init as init 16 | 17 | 18 | def initial_parameter(net, initial_method=None): 19 | r"""A method used to initialize the weights of PyTorch models. 20 | 21 | :param net: a PyTorch model 22 | :param str initial_method: one of the following initializations. 
23 | 24 | - xavier_uniform 25 | - xavier_normal (default) 26 | - kaiming_normal, or msra 27 | - kaiming_uniform 28 | - orthogonal 29 | - sparse 30 | - normal 31 | - uniform 32 | 33 | """ 34 | if initial_method == 'xavier_uniform': 35 | init_method = init.xavier_uniform_ 36 | elif initial_method == 'xavier_normal': 37 | init_method = init.xavier_normal_ 38 | elif initial_method == 'kaiming_normal' or initial_method == 'msra': 39 | init_method = init.kaiming_normal_ 40 | elif initial_method == 'kaiming_uniform': 41 | init_method = init.kaiming_uniform_ 42 | elif initial_method == 'orthogonal': 43 | init_method = init.orthogonal_ 44 | elif initial_method == 'sparse': 45 | init_method = init.sparse_ 46 | elif initial_method == 'normal': 47 | init_method = init.normal_ 48 | elif initial_method == 'uniform': 49 | init_method = init.uniform_ 50 | else: 51 | init_method = init.xavier_normal_ 52 | 53 | def weights_init(m): 54 | # classname = m.__class__.__name__ 55 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv3d): # for all the cnn 56 | if initial_method is not None: 57 | init_method(m.weight.data) 58 | else: 59 | init.xavier_normal_(m.weight.data) 60 | init.normal_(m.bias.data) 61 | elif isinstance(m, nn.LSTM): 62 | for w in m.parameters(): 63 | if len(w.data.size()) > 1: 64 | init_method(w.data) # weight 65 | else: 66 | init.normal_(w.data) # bias 67 | elif m is not None and hasattr(m, 'weight') and \ 68 | hasattr(m.weight, "requires_grad"): 69 | if len(m.weight.size()) > 1: 70 | init_method(m.weight.data) 71 | else: 72 | init.normal_(m.weight.data) # batchnorm or layernorm 73 | else: 74 | for w in m.parameters(): 75 | if w.requires_grad: 76 | if len(w.data.size()) > 1: 77 | init_method(w.data) # weight 78 | else: 79 | init.normal_(w.data) # bias 80 | # print("init else") 81 | 82 | net.apply(weights_init) 83 | 84 | 85 | def summary(model: nn.Module): 86 | r""" 87 | 得到模型的总参数量 88 | 89 | :params model: Pytorch 模型 90 | :return tuple: 包含总参数量,可训练参数量,不可训练参数量 91 | """ 92 | train = [] 93 | nontrain = [] 94 | buffer = [] 95 | 96 | def layer_summary(module: nn.Module): 97 | def count_size(sizes): 98 | return reduce(lambda x, y: x * y, sizes) 99 | 100 | for p in module.parameters(recurse=False): 101 | if p.requires_grad: 102 | train.append(count_size(p.shape)) 103 | else: 104 | nontrain.append(count_size(p.shape)) 105 | for p in module.buffers(): 106 | buffer.append(count_size(p.shape)) 107 | for subm in module.children(): 108 | layer_summary(subm) 109 | 110 | layer_summary(model) 111 | total_train = sum(train) 112 | total_nontrain = sum(nontrain) 113 | total = total_train + total_nontrain 114 | strings = [] 115 | strings.append('Total params: {:,}'.format(total)) 116 | strings.append('Trainable params: {:,}'.format(total_train)) 117 | strings.append('Non-trainable params: {:,}'.format(total_nontrain)) 118 | strings.append("Buffer params: {:,}".format(sum(buffer))) 119 | max_len = len(max(strings, key=len)) 120 | bar = '-' * (max_len + 3) 121 | strings = [bar] + strings + [bar] 122 | print('\n'.join(strings)) 123 | return total, total_train, total_nontrain 124 | 125 | 126 | def get_dropout_mask(drop_p: float, tensor: torch.Tensor): 127 | r""" 128 | 根据tensor的形状,生成一个mask 129 | 130 | :param drop_p: float, 以多大的概率置为0。 131 | :param tensor: torch.Tensor 132 | :return: torch.FloatTensor. 
与tensor一样的shape 133 | """ 134 | mask_x = torch.ones_like(tensor) 135 | nn.functional.dropout(mask_x, p=drop_p, 136 | training=False, inplace=True) 137 | return mask_x 138 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.12.0 2 | aiohttp==3.7.4.post0 3 | alembic==1.5.8 4 | antlr4-python3-runtime==4.8 5 | asn1crypto==0.24.0 6 | async-timeout==3.0.1 7 | attrs==20.3.0 8 | blessings==1.7 9 | cachetools==4.2.1 10 | certifi==2020.12.5 11 | chardet==4.0.0 12 | click==7.1.2 13 | cliff==3.7.0 14 | cmaes==0.8.2 15 | cmd2==1.5.0 16 | colorama==0.4.4 17 | colorlog==4.8.0 18 | command-not-found==0.3 19 | commonmark==0.9.1 20 | configparser==5.0.2 21 | cryptography==2.1.4 22 | dataclasses==0.8 23 | distro-info==0.18ubuntu0.18.04.1 24 | docker-pycreds==0.4.0 25 | easydict==1.9 26 | filelock==3.0.12 27 | fsspec==0.8.7 28 | future==0.18.2 29 | gitdb==4.0.5 30 | gitpython==3.1.14 31 | google-auth-oauthlib==0.4.3 32 | google-auth==1.28.0 33 | gpustat==0.6.0 34 | greenlet==1.0.0 35 | grpcio==1.36.1 36 | hydra-core==1.1.0.dev4 37 | hydra-optuna-sweeper==1.1.0.dev1 38 | idna-ssl==1.1.0 39 | idna==2.10 40 | importlib-metadata==3.10.0 41 | importlib-resources==5.1.2 42 | joblib==1.0.1 43 | keyring==10.6.0 44 | keyrings.alt==3.0 45 | language-selector==0.1 46 | mako==1.1.4 47 | markdown==3.3.4 48 | markupsafe==1.1.1 49 | multidict==5.1.0 50 | netifaces==0.10.4 51 | nltk==3.5 52 | numpy==1.19.5 53 | nvidia-ml-py3==7.352.0 54 | oauthlib==3.1.0 55 | omegaconf==2.1.0.dev24 56 | opt-einsum==3.3.0 57 | optuna==2.4.0 58 | packaging==20.9 59 | pathtools==0.1.2 60 | pbr==5.5.1 61 | pillow==8.0.0 62 | pip==9.0.1 63 | prettytable==2.1.0 64 | promise==2.3 65 | protobuf==3.15.6 66 | psutil==5.7.2 67 | pyasn1-modules==0.2.8 68 | pyasn1==0.4.8 69 | pycrypto==2.6.1 70 | pygments==2.8.1 71 | pygobject==3.26.1 72 | pyparsing==2.4.7 73 | pyperclip==1.8.2 74 | python-apt==1.6.5+ubuntu0.3 75 | python-dateutil==2.8.1 76 | python-distutils-extra==2.39 77 | python-editor==1.0.4 78 | pytorch-lightning==1.2.4 79 | pyxdg==0.25 80 | pyyaml==5.4.1 81 | regex==2020.11.13 82 | requests-oauthlib==1.3.0 83 | requests==2.25.1 84 | rich==9.13.0 85 | rsa==4.7.2 86 | sacremoses==0.0.43 87 | scipy==1.5.4 88 | secretstorage==2.3.1 89 | sentry-sdk==1.0.0 90 | setuptools==54.1.2 91 | shortuuid==1.0.1 92 | six==1.15.0 93 | smmap==3.0.5 94 | sqlalchemy==1.4.6 95 | ssh-import-id==5.7 96 | stevedore==3.3.0 97 | subprocess32==3.5.4 98 | tensorboard-plugin-wit==1.8.0 99 | tensorboard==2.4.1 100 | tokenizers==0.10.1 101 | torch-struct==0.5 102 | torch==1.7.0 103 | torchaudio==0.7.0 104 | torchmetrics==0.2.0 105 | torchvision==0.8.0 106 | tqdm==4.60.0 107 | transformers==4.3.3 108 | typing-extensions==3.7.4.3 109 | ufw==0.36 110 | urllib3==1.26.4 111 | wandb==0.10.21 112 | wcwidth==0.2.5 113 | werkzeug==1.0.1 114 | wheel==0.36.2 115 | yarl==1.6.3 116 | zipp==3.4.1 117 | -------------------------------------------------------------------------------- /src/callbacks/progressbar.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint, ProgressBar 2 | from tqdm import tqdm 3 | 4 | class PrettyProgressBar(ProgressBar): 5 | """Good print wrapper.""" 6 | def __init__(self, refresh_rate: int, process_position: int): 7 | super().__init__(refresh_rate=refresh_rate, process_position=process_position) 8 | 9 | 
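    # Usage sketch (illustrative, not part of the original file): like any other Lightning
    # callback, the bar is passed to the Trainer, e.g.
    #   trainer = pl.Trainer(callbacks=[PrettyProgressBar(refresh_rate=1, process_position=0)])
    # where `pl` is `pytorch_lightning` and the argument values are only examples.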
def init_sanity_tqdm(self) -> tqdm: 10 | bar = tqdm(desc='Validation sanity check', 11 | position=self.process_position, 12 | disable=self.is_disabled, 13 | leave=False, 14 | ncols=120, 15 | ascii=True) 16 | return bar 17 | 18 | def init_train_tqdm(self) -> tqdm: 19 | bar = tqdm(desc='Training', 20 | initial=self.train_batch_idx, 21 | position=self.process_position, 22 | disable=self.is_disabled, 23 | leave=True, 24 | smoothing=0, 25 | ncols=120, 26 | ascii=True) 27 | return bar 28 | 29 | 30 | 31 | def init_validation_tqdm(self) -> tqdm: 32 | bar = tqdm(disable=True) 33 | return bar 34 | 35 | def on_epoch_start(self, trainer, pl_module): 36 | super().on_epoch_start(trainer, pl_module) 37 | self.main_progress_bar.set_description(f'E{trainer.current_epoch}|train') 38 | 39 | def on_validation_start(self, trainer, pl_module): 40 | super().on_validation_start(trainer, pl_module) 41 | self.main_progress_bar.set_description(f'E{trainer.current_epoch}|val') 42 | 43 | def on_test_start(self, trainer, pl_module): 44 | super().on_test_start(trainer, pl_module) 45 | self.main_progress_bar.set_description(f'E{trainer.current_epoch}|test') 46 | 47 | -------------------------------------------------------------------------------- /src/callbacks/transformer_scheduler.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from transformers import get_linear_schedule_with_warmup, get_constant_schedule_with_warmup 3 | import logging 4 | log = logging.getLogger(__name__) 5 | 6 | class TransformerLrScheduler(pl.Callback): 7 | def __init__(self, warmup): 8 | self.warmup = warmup 9 | 10 | def on_train_start(self, trainer, pl_module): 11 | for lr_scheduler in trainer.lr_schedulers: 12 | scheduler = lr_scheduler['scheduler'] 13 | n_train = len(pl_module.train_dataloader()) 14 | n_accumulate_grad = trainer.accumulate_grad_batches 15 | n_max_epochs = trainer.max_epochs 16 | num_training_steps = n_train // n_accumulate_grad * n_max_epochs 17 | num_warmup_steps = int(self.warmup * num_training_steps) 18 | 19 | if pl_module.optimizer_cfg.scheduler_type == 'linear_warmup': 20 | lr_scheduler['scheduler'] = get_linear_schedule_with_warmup(scheduler.optimizer, num_warmup_steps, num_training_steps) 21 | 22 | elif pl_module.optimizer_cfg.scheduler_type == 'constant_warmup': 23 | lr_scheduler['scheduler'] = get_constant_schedule_with_warmup(scheduler.optimizer, num_warmup_steps, 24 | ) 25 | 26 | log.info(f"Warm up rate:{self.warmup}") 27 | log.info(f"total number of training step:{num_training_steps}") 28 | log.info(f"number of training batches per epochs in the dataloader:{n_train}") -------------------------------------------------------------------------------- /src/callbacks/wandb_callbacks.py: -------------------------------------------------------------------------------- 1 | # wandb 2 | from pytorch_lightning.loggers import WandbLogger 3 | import wandb 4 | 5 | # pytorch 6 | from pytorch_lightning import Callback 7 | import pytorch_lightning as pl 8 | import torch 9 | 10 | # others 11 | import glob 12 | import os 13 | 14 | 15 | def get_wandb_logger(trainer: pl.Trainer) -> WandbLogger: 16 | logger = None 17 | for lg in trainer.logger: 18 | if isinstance(lg, WandbLogger): 19 | logger = lg 20 | 21 | if not logger: 22 | raise Exception( 23 | "You are using wandb related callback," 24 | "but WandbLogger was not found for some reason..." 
25 | ) 26 | 27 | return logger 28 | 29 | 30 | # class UploadCodeToWandbAsArtifact(Callback): 31 | # """Upload all *.py files to wandb as an artifact at the beginning of the run.""" 32 | # 33 | # def __init__(self, code_dir: str): 34 | # self.code_dir = code_dir 35 | # 36 | # def on_train_start(self, trainer, pl_module): 37 | # logger = get_wandb_logger(trainer=trainer) 38 | # experiment = logger.experiment 39 | # 40 | # code = wandb.Artifact("project-source", type="code") 41 | # for path in glob.glob(os.path.join(self.code_dir, "**/*.py"), recursive=True): 42 | # print(path) 43 | # code.add_file(path) 44 | # print('ok') 45 | # 46 | # 47 | # 48 | # experiment.use_artifact(code) 49 | # print("successfully uploaded the code.") 50 | 51 | 52 | class UploadHydraConfigFileToWandb(Callback): 53 | def on_fit_start(self, trainer, pl_module: pl.LightningModule) -> None: 54 | logger = get_wandb_logger(trainer=trainer) 55 | 56 | logger.experiment.save() 57 | 58 | 59 | 60 | class UploadCheckpointsToWandbAsArtifact(Callback): 61 | """Upload experiment checkpoints to wandb as an artifact at the end of training.""" 62 | 63 | def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False): 64 | self.ckpt_dir = ckpt_dir 65 | self.upload_best_only = upload_best_only 66 | 67 | def on_train_end(self, trainer, pl_module): 68 | logger = get_wandb_logger(trainer=trainer) 69 | experiment = logger.experiment 70 | 71 | ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") 72 | 73 | if self.upload_best_only: 74 | ckpts.add_file(trainer.checkpoint_callback.best_model_path) 75 | else: 76 | for path in glob.glob( 77 | os.path.join(self.ckpt_dir, "**/*.ckpt"), recursive=True 78 | ): 79 | ckpts.add_file(path) 80 | 81 | experiment.use_artifact(ckpts) 82 | 83 | 84 | class WatchModelWithWandb(Callback): 85 | """Make WandbLogger watch model at the beginning of the run.""" 86 | 87 | def __init__(self, log: str = "gradients", log_freq: int = 100): 88 | self.log = log 89 | self.log_freq = log_freq 90 | 91 | def on_train_start(self, trainer, pl_module): 92 | logger = get_wandb_logger(trainer=trainer) 93 | logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) 94 | 95 | -------------------------------------------------------------------------------- /src/constant.py: -------------------------------------------------------------------------------- 1 | pad = '<pad>' 2 | unk = '<unk>' 3 | bos = '<bos>' 4 | eos = '<eos>' 5 | arc = 'arc' 6 | rel = 'rel' 7 | word = 'word' 8 | chart = 'chart' 9 | raw_tree = 'raw_tree' 10 | valid = 'valid' 11 | sib = 'sibling' 12 | relatives = 'relatives' 13 | span_head = 'span_head' 14 | 15 | -------------------------------------------------------------------------------- /src/datamodule/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/datamodule/benepar/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | benepar: Berkeley Neural Parser 3 | """ 4 | 5 | # This file and all code in integrations/ relate to the version of the parser 6 | # released via PyPI. If you only need to run research experiments, it is safe 7 | # to delete the integrations/ folder and replace this __init__.py with an 8 | # empty file.
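# Minimal usage sketch of the API re-exported below (illustrative only; "benepar_en3" is an
# example released model and must be fetched with `download` before parsing):
#   import benepar
#   benepar.download("benepar_en3")
#   parser = benepar.Parser("benepar_en3")
#   tree = parser.parse(benepar.InputSentence(words=["A", "short", "sentence", "."]))
#   tree.pretty_print()  # the result is an nltk.Tree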
9 | 10 | __all__ = [ 11 | "Parser", 12 | "InputSentence", 13 | "download", 14 | "BeneparComponent", 15 | "NonConstituentException", 16 | ] 17 | 18 | from .integrations.downloader import download 19 | from .integrations.nltk_plugin import Parser, InputSentence 20 | from .integrations.spacy_plugin import BeneparComponent, NonConstituentException 21 | -------------------------------------------------------------------------------- /src/datamodule/benepar/integrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sustcsonglin/pointer-net-for-nested/fe9932a361b6af64c87491c3f312917e9a3a40a0/src/datamodule/benepar/integrations/__init__.py -------------------------------------------------------------------------------- /src/datamodule/benepar/integrations/downloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | BENEPAR_SERVER_INDEX = "https://kitaev.com/benepar/index.xml" 4 | 5 | _downloader = None 6 | def get_downloader(): 7 | global _downloader 8 | if _downloader is None: 9 | import nltk.downloader 10 | _downloader = nltk.downloader.Downloader(server_index_url=BENEPAR_SERVER_INDEX) 11 | return _downloader 12 | 13 | def download(*args, **kwargs): 14 | return get_downloader().download(*args, **kwargs) 15 | 16 | def locate_model(name): 17 | if os.path.exists(name): 18 | return name 19 | elif "/" not in name and "." not in name: 20 | import nltk.data 21 | try: 22 | nltk_loc = nltk.data.find(f"models/{name}") 23 | return nltk_loc.path 24 | except LookupError as e: 25 | arg = e.args[0].replace("nltk.download", "benepar.download") 26 | 27 | raise LookupError(arg) 28 | 29 | raise LookupError("Can't find {}".format(name)) 30 | 31 | def load_trained_model(model_name_or_path): 32 | model_path = locate_model(model_name_or_path) 33 | from ..parse_chart import ChartParser 34 | parser = ChartParser.from_trained(model_path) 35 | return parser 36 | -------------------------------------------------------------------------------- /src/datamodule/benepar/nkutil.py: -------------------------------------------------------------------------------- 1 | class HParams: 2 | _skip_keys = ["populate_arguments", "set_from_args", "print", "to_dict"] 3 | 4 | def __init__(self, **kwargs): 5 | for k, v in kwargs.items(): 6 | setattr(self, k, v) 7 | 8 | def __getitem__(self, item): 9 | return getattr(self, item) 10 | 11 | def __setitem__(self, item, value): 12 | if not hasattr(self, item): 13 | raise KeyError(f"Hyperparameter {item} has not been declared yet") 14 | setattr(self, item, value) 15 | 16 | def to_dict(self): 17 | res = {} 18 | for k in dir(self): 19 | if k.startswith("_") or k in self._skip_keys: 20 | continue 21 | res[k] = self[k] 22 | return res 23 | 24 | def populate_arguments(self, parser): 25 | for k in dir(self): 26 | if k.startswith("_") or k in self._skip_keys: 27 | continue 28 | v = self[k] 29 | k = k.replace("_", "-") 30 | if type(v) in (int, float, str): 31 | parser.add_argument(f"--{k}", type=type(v), default=v) 32 | elif isinstance(v, bool): 33 | if not v: 34 | parser.add_argument(f"--{k}", action="store_true") 35 | else: 36 | parser.add_argument(f"--no-{k}", action="store_false") 37 | 38 | def set_from_args(self, args): 39 | for k in dir(self): 40 | if k.startswith("_") or k in self._skip_keys: 41 | continue 42 | if hasattr(args, k): 43 | self[k] = getattr(args, k) 44 | elif hasattr(args, f"no_{k}"): 45 | self[k] = getattr(args, f"no_{k}") 46 | 47 | def 
print(self): 48 | for k in dir(self): 49 | if k.startswith("_") or k in self._skip_keys: 50 | continue 51 | print(k, repr(self[k])) 52 | -------------------------------------------------------------------------------- /src/datamodule/benepar/ptb_unescape.py: -------------------------------------------------------------------------------- 1 | PTB_UNESCAPE_MAPPING = { 2 | "«": '"', 3 | "»": '"', 4 | "‘": "'", 5 | "’": "'", 6 | "“": '"', 7 | "”": '"', 8 | "„": '"', 9 | "‹": "'", 10 | "›": "'", 11 | "\u2013": "--", # en dash 12 | "\u2014": "--", # em dash 13 | } 14 | 15 | NO_SPACE_BEFORE = {"-RRB-", "-RCB-", "-RSB-", "''"} | set("%.,!?:;") 16 | NO_SPACE_AFTER = {"-LRB-", "-LCB-", "-LSB-", "``", "`"} | set("$#") 17 | NO_SPACE_BEFORE_TOKENS_ENGLISH = {"'", "'s", "'ll", "'re", "'d", "'m", "'ve"} 18 | PTB_DASH_ESCAPED = {"-RRB-", "-RCB-", "-RSB-", "-LRB-", "-LCB-", "-LSB-", "--"} 19 | 20 | 21 | def ptb_unescape(words): 22 | cleaned_words = [] 23 | for word in words: 24 | word = PTB_UNESCAPE_MAPPING.get(word, word) 25 | # This un-escaping for / and * was not yet added for the 26 | # parser version in https://arxiv.org/abs/1812.11760v1 27 | # and related model releases (e.g. benepar_en2) 28 | word = word.replace("\\/", "/").replace("\\*", "*") 29 | # Mid-token punctuation occurs in biomedical text 30 | word = word.replace("-LSB-", "[").replace("-RSB-", "]") 31 | word = word.replace("-LRB-", "(").replace("-RRB-", ")") 32 | word = word.replace("-LCB-", "{").replace("-RCB-", "}") 33 | word = word.replace("``", '"').replace("`", "'").replace("''", '"') 34 | cleaned_words.append(word) 35 | return cleaned_words 36 | 37 | 38 | def guess_space_after_non_english(escaped_words): 39 | sp_after = [True for _ in escaped_words] 40 | for i, word in enumerate(escaped_words): 41 | if i > 0 and ( 42 | ( 43 | word.startswith("-") 44 | and not any(word.startswith(x) for x in PTB_DASH_ESCAPED) 45 | ) 46 | or any(word.startswith(x) for x in NO_SPACE_BEFORE) 47 | or word == "'" 48 | ): 49 | sp_after[i - 1] = False 50 | if ( 51 | word.endswith("-") and not any(word.endswith(x) for x in PTB_DASH_ESCAPED) 52 | ) or any(word.endswith(x) for x in NO_SPACE_AFTER): 53 | sp_after[i] = False 54 | 55 | return sp_after 56 | 57 | 58 | def guess_space_after(escaped_words, for_english=True): 59 | if not for_english: 60 | return guess_space_after_non_english(escaped_words) 61 | 62 | sp_after = [True for _ in escaped_words] 63 | for i, word in enumerate(escaped_words): 64 | if word.lower() == "n't" and i > 0: 65 | sp_after[i - 1] = False 66 | elif word.lower() == "not" and i > 0 and escaped_words[i - 1].lower() == "can": 67 | sp_after[i - 1] = False 68 | 69 | if i > 0 and ( 70 | ( 71 | word.startswith("-") 72 | and not any(word.startswith(x) for x in PTB_DASH_ESCAPED) 73 | ) 74 | or any(word.startswith(x) for x in NO_SPACE_BEFORE) 75 | or word.lower() in NO_SPACE_BEFORE_TOKENS_ENGLISH 76 | ): 77 | sp_after[i - 1] = False 78 | if ( 79 | word.endswith("-") and not any(word.endswith(x) for x in PTB_DASH_ESCAPED) 80 | ) or any(word.endswith(x) for x in NO_SPACE_AFTER): 81 | sp_after[i] = False 82 | 83 | return sp_after 84 | -------------------------------------------------------------------------------- /src/datamodule/benepar/spacy_plugin.py: -------------------------------------------------------------------------------- 1 | __all__ = ["BeneparComponent", "NonConstituentException"] 2 | 3 | import warnings 4 | 5 | from .integrations.spacy_plugin import BeneparComponent, NonConstituentException 6 | 7 | warnings.warn( 8 | "BeneparComponent 
and NonConstituentException have been moved to the benepar " 9 | "module. Use `from benepar import BeneparComponent, NonConstituentException` " 10 | "instead of benepar.spacy_plugin. The benepar.spacy_plugin namespace is deprecated " 11 | "and will be removed in a future version.", 12 | FutureWarning, 13 | ) 14 | -------------------------------------------------------------------------------- /src/datamodule/benepar/subbatching.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for splitting batches of examples into smaller sub-batches. 3 | 4 | This is useful during training when the batch size is too large to fit on GPU, 5 | meaning that gradient accumulation across multiple sub-batches must be used. 6 | It is also useful for batching examples during evaluation. Unlike a naive 7 | approach, this code groups examples with similar lengths to reduce the amount 8 | of wasted computation due to padding. 9 | """ 10 | 11 | import numpy as np 12 | 13 | 14 | def split(*data, costs, max_cost): 15 | """Splits a batch of input items into sub-batches. 16 | 17 | Args: 18 | *data: One or more lists of input items, all of the same length 19 | costs: A list of costs for each item 20 | max_cost: Maximum total cost for each sub-batch 21 | 22 | Yields: 23 | (example_ids, *subbatch_data) tuples. 24 | """ 25 | costs = np.asarray(costs, dtype=int) 26 | costs_argsort = np.argsort(costs).tolist() 27 | 28 | subbatch_size = 1 29 | while costs_argsort: 30 | if subbatch_size == len(costs_argsort) or ( 31 | subbatch_size * costs[costs_argsort[subbatch_size]] > max_cost 32 | ): 33 | subbatch_item_ids = costs_argsort[:subbatch_size] 34 | subbatch_data = [[items[i] for i in subbatch_item_ids] for items in data] 35 | yield (subbatch_item_ids,) + tuple(subbatch_data) 36 | costs_argsort = costs_argsort[subbatch_size:] 37 | subbatch_size = 1 38 | else: 39 | subbatch_size += 1 40 | 41 | 42 | def map(func, *data, costs, max_cost, **common_kwargs): 43 | """Maps a function over subbatches of input items. 44 | 45 | Args: 46 | func: Function to map over the data 47 | *data: One or more lists of input items, all of the same length. 48 | costs: A list of costs for each item 49 | max_cost: Maximum total cost for each sub-batch 50 | **common_kwargs: Keyword arguments to pass to all calls of func 51 | 52 | Returns: 53 | A list of outputs from calling func(*subbatch_data, **kwargs) for each 54 | subbatch, and then rearranging the outputs from func into the original 55 | item order. 
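    Example (an illustrative sketch added here, not from the original docstring): doubling
    each item with per-item costs [1, 2, 3, 4] and a budget of 4 per sub-batch returns the
    outputs rearranged back into the original item order:

        >>> map(lambda xs: [2 * x for x in xs], [1, 2, 3, 4], costs=[1, 2, 3, 4], max_cost=4)
        [2, 4, 6, 8]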
56 | """ 57 | res = [None] * len(data[0]) 58 | for item_ids, *subbatch_items in split(*data, costs=costs, max_cost=max_cost): 59 | subbatch_out = func(*subbatch_items, **common_kwargs) 60 | for item_id, item_out in zip(item_ids, subbatch_out): 61 | res[item_id] = item_out 62 | return res 63 | -------------------------------------------------------------------------------- /src/datamodule/dm_util/datamodule_util.py: -------------------------------------------------------------------------------- 1 | from supar.utils.alg import kmeans 2 | from supar.utils.data import Sampler 3 | from fastNLP.core.field import Padder 4 | import numpy as np 5 | 6 | 7 | 8 | 9 | def get_sampler(lengths, max_tokens, n_buckets, shuffle=True, distributed=False, evaluate=False): 10 | buckets = dict(zip(*kmeans(lengths, n_buckets))) 11 | return Sampler(buckets=buckets, 12 | batch_size=max_tokens, 13 | shuffle=shuffle, 14 | distributed=distributed, 15 | evaluate=evaluate) 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /src/model/module/ember/embedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from supar.modules.char_lstm import CharLSTM 4 | from supar.modules import TransformerEmbedding 5 | import copy 6 | 7 | class Embeder(nn.Module): 8 | def __init__(self, conf, fields): 9 | super(Embeder, self).__init__() 10 | self.conf = conf 11 | 12 | if 'pos' in fields.inputs: 13 | self.pos_emb = nn.Embedding(fields.get_vocab_size("pos"), conf.n_pos_embed) 14 | else: 15 | self.pos_emb = None 16 | 17 | if 'char' in fields.inputs: 18 | self.feat = CharLSTM(n_chars=fields.get_vocab_size('char'), 19 | n_embed=conf.n_char_embed, 20 | n_out=conf.n_char_out, 21 | pad_index=fields.get_pad_index('char'), 22 | input_dropout=conf.char_input_dropout) 23 | self.feat_name = 'char' 24 | 25 | if 'bert' in fields.inputs: 26 | self.feat = TransformerEmbedding(model=fields.get_bert_name(), 27 | n_layers=conf.n_bert_layers, 28 | n_out=conf.n_bert_out, 29 | pad_index=fields.get_pad_index("bert"), 30 | dropout=conf.mix_dropout, 31 | requires_grad=conf.finetune, 32 | use_projection=conf.use_projection, 33 | use_scalarmix=conf.use_scalarmix) 34 | self.feat_name = "bert" 35 | print(fields.get_bert_name()) 36 | 37 | if ('char' not in fields.inputs and 'bert' not in fields.inputs): 38 | self.feat = None 39 | 40 | if 'word' in fields.inputs: 41 | ext_emb = fields.get_ext_emb() 42 | if ext_emb: 43 | self.word_emb = copy.deepcopy(ext_emb) 44 | else: 45 | self.word_emb = nn.Embedding(num_embeddings=fields.get_vocab_size('word'), 46 | embedding_dim=conf.n_embed) 47 | else: 48 | self.word_emb = None 49 | 50 | 51 | def forward(self, ctx): 52 | emb = {} 53 | 54 | if self.pos_emb: 55 | emb['pos'] = self.pos_emb(ctx['pos']) 56 | 57 | if self.word_emb: 58 | emb['word'] = self.word_emb(ctx['word']) 59 | 60 | #For now, char or ber、t, choose one. 
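        # (That is, `self.feat` holds either the CharLSTM or the TransformerEmbedding built in
        #  __init__, never both, and its output is stored under `self.feat_name`; ctx['embed']
        #  therefore ends up as a dict such as {'word': ..., 'pos': ..., 'char' or 'bert': ...},
        #  depending on which inputs are configured in `fields.inputs`.)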
61 | if self.feat: 62 | emb[self.feat_name] = self.feat(ctx[self.feat_name]) 63 | 64 | ctx['embed'] = emb 65 | 66 | 67 | def get_output_dim(self): 68 | 69 | size = 0 70 | 71 | if self.pos_emb: 72 | size += self.conf.n_pos_embed 73 | 74 | if self.word_emb: 75 | if isinstance(self.word_emb, nn.Embedding): 76 | size += self.conf.n_embed 77 | else: 78 | size += self.word_emb.get_dim() 79 | 80 | if self.feat: 81 | if self.feat_name == 'char': 82 | size += self.conf.n_char_out 83 | else: 84 | size += self.feat.n_out 85 | return size 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /src/model/module/ember/ext_embedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import logging 3 | 4 | 5 | log = logging.getLogger(__name__) 6 | 7 | # As Supar. 8 | class ExternalEmbeddingSupar(nn.Module): 9 | def __init__(self, emb, origin_word_size, unk_index): 10 | super(ExternalEmbeddingSupar, self).__init__() 11 | 12 | self.pretrained = nn.Embedding.from_pretrained(emb) 13 | self.origin_word_size = origin_word_size 14 | self.word_emb = nn.Embedding(origin_word_size, emb.shape[-1]) 15 | self.unk_index = unk_index 16 | nn.init.zeros_(self.word_emb.weight) 17 | 18 | def forward(self, words): 19 | ext_mask = words.ge(self.word_emb.num_embeddings) 20 | ext_words = words.masked_fill(ext_mask, self.unk_index) 21 | # get outputs from embedding layers 22 | word_embed = self.word_emb(ext_words) 23 | word_embed += self.pretrained(words) 24 | return word_embed 25 | 26 | def get_dim(self): 27 | return self.word_emb.weight.shape[-1] -------------------------------------------------------------------------------- /src/model/module/encoder/lstm_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from supar.modules import LSTM 3 | from supar.modules.dropout import IndependentDropout, SharedDropout 4 | import torch 5 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 6 | 7 | 8 | class LSTMencoder(nn.Module): 9 | def __init__(self, conf, input_dim, **kwargs): 10 | super(LSTMencoder, self).__init__() 11 | self.conf = conf 12 | 13 | self.before_lstm_dropout = None 14 | 15 | if self.conf.embed_dropout_type == 'independent': 16 | self.embed_dropout = IndependentDropout(p=conf.embed_dropout) 17 | if conf.before_lstm_dropout: 18 | self.before_lstm_dropout = SharedDropout(p=conf.before_lstm_dropout) 19 | 20 | elif self.conf.embed_dropout_type == 'shared': 21 | self.embed_dropout = SharedDropout(p=conf.embed_dropout) 22 | 23 | elif self.conf.embed_dropout_type == 'simple': 24 | self.embed_dropout = nn.Dropout(p=conf.embed_dropout) 25 | 26 | else: 27 | self.embed_dropout = nn.Dropout(0.) 
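        # The branches above select the embedding-dropout scheme: 'independent' applies supar's
        # IndependentDropout to each embedding type separately (optionally followed by a
        # SharedDropout right before the LSTM), 'shared' reuses a single mask across timesteps,
        # 'simple' is plain element-wise nn.Dropout, and any other value effectively disables it.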
28 | 29 | self.lstm = LSTM(input_size=input_dim, 30 | hidden_size=conf.n_lstm_hidden, 31 | num_layers=conf.n_lstm_layers, 32 | bidirectional=True, 33 | dropout=conf.lstm_dropout) 34 | self.lstm_dropout = SharedDropout(p=conf.lstm_dropout) 35 | 36 | 37 | 38 | def forward(self, info): 39 | # lstm encoder 40 | embed = info['embed'] 41 | seq_len = info['seq_len'] 42 | 43 | embed = [i for i in embed.values()] 44 | 45 | if self.conf.embed_dropout_type == 'independent': 46 | embed = self.embed_dropout(embed) 47 | embed = torch.cat(embed, dim=-1) 48 | else: 49 | embed = torch.cat(embed, dim=-1) 50 | embed = self.embed_dropout(embed) 51 | 52 | seq_len = seq_len.cpu() 53 | x = pack_padded_sequence(embed, seq_len.cpu() + (embed.shape[1] - seq_len.max()), True, False) 54 | x, _ = self.lstm(x) 55 | x, _ = pad_packed_sequence(x, True, total_length=embed.shape[1]) 56 | x = self.lstm_dropout(x) 57 | 58 | info['encoded_emb'] = x 59 | 60 | def get_output_dim(self): 61 | return self.conf.n_lstm_hidden * 2 62 | 63 | -------------------------------------------------------------------------------- /src/model/module/scorer/const_scorer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .module.biaffine import BiaffineScorer 3 | import torch 4 | 5 | 6 | 7 | class ConstScorer(nn.Module): 8 | def __init__(self, conf, fields, input_dim): 9 | super(ConstScorer, self).__init__() 10 | self.conf = conf 11 | 12 | if self.conf.use_span: 13 | self.span_scorer = BiaffineScorer(n_in=input_dim, n_out=conf.n_mlp_span, bias_x=True, bias_y=False, dropout=conf.mlp_dropout) 14 | self.label_scorer = BiaffineScorer(n_in=input_dim, n_out=conf.n_mlp_label, bias_x=True, bias_y=True, dropout=conf.mlp_dropout, n_out_label=fields.get_vocab_size("chart")) 15 | self.null_idx = fields.get_vocab('chart')['NULL'] 16 | 17 | # if self.conf.use_transition: 18 | # using transition scores? 
19 | # vocab_size = fields.get_vocab_size('chart') 20 | # self.transition = nn.Parameter(torch.rand(vocab_size, vocab_size)) 21 | 22 | 23 | def forward(self, ctx): 24 | x = ctx['encoded_emb'] 25 | 26 | if 'span_repr' not in ctx: 27 | x_f, x_b = x.chunk(2, -1) 28 | x = torch.cat((x_f[:, :-1], x_b[:, 1:]), -1) 29 | else: 30 | x = ctx['span_repr'] 31 | 32 | s_span = self.label_scorer(x) 33 | if self.conf.use_span: 34 | s_span += self.span_scorer(x).unsqueeze(-1) 35 | mask = s_span.new_zeros(s_span.shape[1], s_span.shape[1], dtype=torch.bool) 36 | mask.diagonal(1).fill_(1) 37 | s_span[..., self.null_idx].masked_fill_(~mask.unsqueeze(0).expand(s_span.shape[0], s_span.shape[1], s_span.shape[1]), -1e9) 38 | ctx['s_span'] = s_span 39 | 40 | 41 | -------------------------------------------------------------------------------- /src/model/module/scorer/module/biaffine.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from supar.modules import MLP, Biaffine 3 | 4 | class BiaffineScorer(nn.Module): 5 | def __init__(self, n_in=800, n_out=400, n_out_label=1, bias_x=True, bias_y=False, scaling=False, dropout=0.33): 6 | super(BiaffineScorer, self).__init__() 7 | self.l = MLP(n_in=n_in, n_out=n_out, dropout=dropout) 8 | self.r = MLP(n_in=n_in, n_out=n_out, dropout=dropout) 9 | self.attn = Biaffine(n_in=n_out, n_out=n_out_label, bias_x=bias_x, bias_y=bias_y) 10 | self.scaling = 0 if not scaling else n_out ** (1/4) 11 | self.n_in = n_in 12 | 13 | def forward(self, h): 14 | left = self.l(h) 15 | right = self.r(h) 16 | 17 | if self.scaling: 18 | left = left / self.scaling 19 | right = right / self.scaling 20 | 21 | return self.attn(left, right) 22 | 23 | 24 | def forward_v2(self, h, q): 25 | left = self.l(h) 26 | right = self.r(q) 27 | if self.scaling: 28 | left = left / self.scaling 29 | right = right / self.scaling 30 | 31 | return self.attn(left, right) 32 | 33 | # def forward_v3(self, h, q): 34 | 35 | def forward_v3(self, h, q): 36 | src = self.l(h) 37 | dec = self.r(q) 38 | return self.attn.forward_v2(src, dec) 39 | 40 | def forward_linear(self, h, q): 41 | src = self.l(h) 42 | dec = self.r(q) 43 | return self.attn.forward2(src, dec) 44 | 45 | 46 | class BiaffineScorer2(nn.Module): 47 | def __init__(self, n_in_a=800, n_in_b=800, n_out=400, n_out_label=1, bias_x=False, bias_y=False, scaling=False, dropout=0.33): 48 | super(BiaffineScorer2, self).__init__() 49 | self.l = MLP(n_in=n_in_a, n_out=n_out, dropout=dropout) 50 | self.r = MLP(n_in=n_in_b, n_out=n_out, dropout=dropout) 51 | self.attn = Biaffine(n_in=n_out, n_out=n_out_label, bias_x=bias_x, bias_y=bias_y) 52 | self.scaling = 0 if not scaling else n_out ** (1/4) 53 | 54 | def forward(self, h, q): 55 | src = self.l(h) 56 | dec = self.r(q) 57 | return self.attn.forward_v2(src, dec) 58 | 59 | def forward2(self, h, q): 60 | src = self.l(h) 61 | dec = self.r(q) 62 | return self.attn.forward_v3(src, dec) 63 | 64 | -------------------------------------------------------------------------------- /src/model/module/scorer/module/triaffine.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from supar.modules import MLP, Triaffine 3 | 4 | class TriaffineScorer(nn.Module): 5 | def __init__(self, n_in=800, n_out=400, bias_x=True, bias_y=False, dropout=0.33): 6 | super(TriaffineScorer, self).__init__() 7 | self.l = MLP(n_in=n_in, n_out=n_out, dropout=dropout) 8 | self.m = MLP(n_in=n_in, n_out=n_out, dropout=dropout) 9 | self.r = 
MLP(n_in=n_in, n_out=n_out, dropout=dropout) 10 | self.attn = Triaffine(n_in=n_out, bias_x=bias_x, bias_y=bias_y) 11 | 12 | def forward(self, h): 13 | left = self.l(h) 14 | mid = self.m(h) 15 | right = self.r(h) 16 | #sib, dependent, head) 17 | return self.attn(left, mid, right).permute(0, 2, 3, 1) 18 | 19 | def forward2(self, word, span): 20 | left = self.l(word) 21 | mid = self.m(span) 22 | right = self.r(span) 23 | # head, left_bdr, right_bdr: used in span-head model? 24 | 25 | # word, left, right 26 | # fine. 27 | return self.attn(mid, right, left).permute(0, 2, 3, 1) 28 | 29 | 30 | 31 | # class TriaffineScorer 32 | # class TriaffineScorer(nn.Module): 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /src/model/module/scorer/self_attentive.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from .module.biaffine import BiaffineScorer 4 | 5 | 6 | class LabelMLPScorer(nn.Module): 7 | def __init__(self, conf, fields, input_dim): 8 | super(LabelMLPScorer, self).__init__() 9 | self.conf = conf 10 | self.f_label = nn.Sequential( 11 | nn.Linear(input_dim, conf.d_label_hidden), 12 | nn.LayerNorm(conf.d_label_hidden), 13 | nn.ReLU(), 14 | nn.Linear(conf.d_label_hidden, fields.get_vocab_size('chart')), 15 | ) 16 | 17 | def forward(self, ctx): 18 | fence_post = ctx['fencepost'] 19 | span_repr = fence_post.unsqueeze(1) - fence_post.unsqueeze(2) 20 | ctx['s_span'] = self.f_label(span_repr) 21 | 22 | 23 | -------------------------------------------------------------------------------- /src/model/parsing.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import logging 3 | import hydra 4 | log = logging.getLogger(__name__) 5 | 6 | class Parser(nn.Module): 7 | def __init__(self, conf, fields): 8 | super(Parser, self).__init__() 9 | self.conf = conf 10 | self.fields = fields 11 | self.embeder = hydra.utils.instantiate(conf.embeder.target, conf.embeder, fields=fields, _recursive_=False) 12 | self.encoder = hydra.utils.instantiate(conf.encoder.target, conf.encoder, input_dim= self.embeder.get_output_dim(), _recursive_=False) 13 | self.scorer = hydra.utils.instantiate(conf.scorer.target, conf.scorer, fields=fields, input_dim=self.encoder.get_output_dim(), _recursive_=False) 14 | self.loss = hydra.utils.instantiate(conf.loss.target, conf.loss, _recursive_=False) 15 | self.metric = hydra.utils.instantiate(conf.metric.target, conf.metric, fields=fields, _recursive_=False) 16 | 17 | log.info(self.embeder) 18 | log.info(self.encoder) 19 | log.info(self.scorer) 20 | 21 | def forward(self, ctx): 22 | self.embeder(ctx) 23 | self.encoder(ctx) 24 | self.scorer(ctx) 25 | 26 | def get_loss(self, x, y): 27 | ctx = {**x, **y} 28 | self.forward(ctx) 29 | return self.loss.loss(ctx) 30 | 31 | def decode(self, x, y): 32 | ctx = {**x, **y} 33 | self.forward(ctx) 34 | self.loss.decode(ctx) 35 | return ctx 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /supar/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | 5 | __version__ = '1.0.0' 6 | 7 | 8 | PRETRAINED = { 9 | 'biaffine-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.biaffine.dependency.char.zip', 10 | 'biaffine-dep-zh': 
'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.biaffine.dependency.char.zip', 11 | 'biaffine-dep-bert-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.biaffine.dependency.bert.zip', 12 | 'biaffine-dep-bert-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.biaffine.dependency.bert.zip', 13 | 'crfnp-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crfnp.dependency.char.zip', 14 | 'crfnp-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crfnp.dependency.char.zip', 15 | 'crf-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf.dependency.char.zip', 16 | 'crf-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf.dependency.char.zip', 17 | 'crf2o-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf2o.dependency.char.zip', 18 | 'crf2o-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf2o.dependency.char.zip', 19 | 'crf-con-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf.constituency.char.zip', 20 | 'crf-con-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf.constituency.char.zip', 21 | 'crf-con-bert-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf.constituency.bert.zip', 22 | 'crf-con-bert-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf.constituency.bert.zip' 23 | } 24 | -------------------------------------------------------------------------------- /supar/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .affine import Biaffine, Triaffine 4 | from .bert import BertEmbedding 5 | from .char_lstm import CharLSTM 6 | from .dropout import IndependentDropout, SharedDropout 7 | from .lstm import LSTM 8 | from .mlp import MLP 9 | from .scalar_mix import ScalarMix 10 | from .treecrf import (CRF2oDependency, CRFConstituency, CRFDependency, 11 | MatrixTree) 12 | from .variational_inference import (LBPSemanticDependency, 13 | MFVISemanticDependency) 14 | from .transformer import TransformerEmbedding 15 | 16 | __all__ = ['LSTM', 'MLP', 'BertEmbedding', 'Biaffine', 'CharLSTM', 'CRF2oDependency', 'CRFConstituency', 'CRFDependency', 17 | 'IndependentDropout', 'LBPSemanticDependency', 'MatrixTree', 18 | 'MFVISemanticDependency', 'ScalarMix', 'SharedDropout', 'Triaffine', 'TransformerEmbedding'] 19 | -------------------------------------------------------------------------------- /supar/modules/char_lstm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | 7 | 8 | class CharLSTM(nn.Module): 9 | r""" 10 | CharLSTM aims to generate character-level embeddings for tokens. 11 | It summerizes the information of characters in each token to an embedding using a LSTM layer. 12 | 13 | Args: 14 | n_char (int): 15 | The number of characters. 16 | n_embed (int): 17 | The size of each embedding vector as input to LSTM. 18 | n_out (int): 19 | The size of each output vector. 20 | pad_index (int): 21 | The index of the padding token in the vocabulary. Default: 0. 
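    Examples (an illustrative sketch added here, not part of the original docstring; the
    sizes are arbitrary and the character ids are random, so only the output shape is meaningful):
        >>> import torch
        >>> char_lstm = CharLSTM(n_chars=100, n_embed=50, n_out=100)
        >>> chars = torch.randint(1, 100, (2, 10, 20))  # [batch_size, seq_len, fix_len]
        >>> char_lstm(chars).shape
        torch.Size([2, 10, 100])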
22 | """ 23 | 24 | def __init__(self, n_chars, n_embed, n_out, pad_index=0, input_dropout=0.): 25 | super().__init__() 26 | 27 | self.n_chars = n_chars 28 | self.n_embed = n_embed 29 | self.n_out = n_out 30 | self.pad_index = pad_index 31 | 32 | # the embedding layer 33 | self.embed = nn.Embedding(num_embeddings=n_chars, 34 | embedding_dim=n_embed) 35 | # the lstm layer 36 | self.lstm = nn.LSTM(input_size=n_embed, 37 | hidden_size=n_out//2, 38 | batch_first=True, 39 | bidirectional=True) 40 | 41 | self.input_dropout = nn.Dropout(input_dropout) 42 | 43 | def __repr__(self): 44 | return f"{self.__class__.__name__}({self.n_chars}, {self.n_embed}, n_out={self.n_out}, pad_index={self.pad_index})" 45 | 46 | def forward(self, x): 47 | r""" 48 | Args: 49 | x (~torch.Tensor): ``[batch_size, seq_len, fix_len]``. 50 | Characters of all tokens. 51 | Each token holds no more than `fix_len` characters, and the excess is cut off directly. 52 | Returns: 53 | ~torch.Tensor: 54 | The embeddings of shape ``[batch_size, seq_len, n_out]`` derived from the characters. 55 | """ 56 | # [batch_size, seq_len, fix_len] 57 | mask = x.ne(self.pad_index) 58 | # [batch_size, seq_len] 59 | lens = mask.sum(-1) 60 | char_mask = lens.gt(0) 61 | 62 | # [n, fix_len, n_embed] 63 | x = self.embed(x[char_mask]) 64 | x = self.input_dropout(x) 65 | x = pack_padded_sequence(x, lens[char_mask].cpu(), True, False) 66 | x, (h, _) = self.lstm(x) 67 | # [n, fix_len, n_out] 68 | h = torch.cat(torch.unbind(h), -1) 69 | # [batch_size, seq_len, n_out] 70 | embed = h.new_zeros(*lens.shape, self.n_out) 71 | embed = embed.masked_scatter_(char_mask.unsqueeze(-1), h) 72 | 73 | return embed 74 | -------------------------------------------------------------------------------- /supar/modules/dropout.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class SharedDropout(nn.Module): 8 | r""" 9 | SharedDropout differs from the vanilla dropout strategy in that 10 | the dropout mask is shared across one dimension. 11 | 12 | Args: 13 | p (float): 14 | The probability of an element to be zeroed. Default: 0.5. 15 | batch_first (bool): 16 | If ``True``, the input and output tensors are provided as ``[batch_size, seq_len, *]``. 17 | Default: ``True``. 18 | 19 | Examples: 20 | >>> x = torch.ones(1, 3, 5) 21 | >>> nn.Dropout()(x) 22 | tensor([[[0., 2., 2., 0., 0.], 23 | [2., 2., 0., 2., 2.], 24 | [2., 2., 2., 2., 0.]]]) 25 | >>> SharedDropout()(x) 26 | tensor([[[2., 0., 2., 0., 2.], 27 | [2., 0., 2., 0., 2.], 28 | [2., 0., 2., 0., 2.]]]) 29 | """ 30 | 31 | def __init__(self, p=0.5, batch_first=True): 32 | super().__init__() 33 | 34 | self.p = p 35 | self.batch_first = batch_first 36 | 37 | def __repr__(self): 38 | s = f"p={self.p}" 39 | if self.batch_first: 40 | s += f", batch_first={self.batch_first}" 41 | 42 | return f"{self.__class__.__name__}({s})" 43 | 44 | def forward(self, x): 45 | r""" 46 | Args: 47 | x (~torch.Tensor): 48 | A tensor of any shape. 49 | Returns: 50 | The returned tensor is of the same shape as `x`. 
51 | """ 52 | 53 | if self.training: 54 | if self.batch_first: 55 | mask = self.get_mask(x[:, 0], self.p).unsqueeze(1) 56 | else: 57 | mask = self.get_mask(x[0], self.p) 58 | x = x * mask 59 | 60 | return x 61 | 62 | @staticmethod 63 | def get_mask(x, p): 64 | return x.new_empty(x.shape).bernoulli_(1 - p) / (1 - p) 65 | 66 | 67 | class IndependentDropout(nn.Module): 68 | r""" 69 | For :math:`N` tensors, they use different dropout masks respectively. 70 | When :math:`N-M` of them are dropped, the remaining :math:`M` ones are scaled by a factor of :math:`N/M` to compensate, 71 | and when all of them are dropped together, zeros are returned. 72 | 73 | Args: 74 | p (float): 75 | The probability of an element to be zeroed. Default: 0.5. 76 | 77 | Examples: 78 | >>> x, y = torch.ones(1, 3, 5), torch.ones(1, 3, 5) 79 | >>> x, y = IndependentDropout()(x, y) 80 | >>> x 81 | tensor([[[1., 1., 1., 1., 1.], 82 | [0., 0., 0., 0., 0.], 83 | [2., 2., 2., 2., 2.]]]) 84 | >>> y 85 | tensor([[[1., 1., 1., 1., 1.], 86 | [2., 2., 2., 2., 2.], 87 | [0., 0., 0., 0., 0.]]]) 88 | """ 89 | 90 | def __init__(self, p=0.5): 91 | super().__init__() 92 | 93 | self.p = p 94 | 95 | def __repr__(self): 96 | return f"{self.__class__.__name__}(p={self.p})" 97 | 98 | def forward(self, items): 99 | r""" 100 | Args: 101 | items (list[~torch.Tensor]): 102 | A list of tensors that have the same shape except the last dimension. 103 | Returns: 104 | The returned tensors are of the same shape as `items`. 105 | """ 106 | 107 | if self.training: 108 | masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p) for x in items] 109 | total = sum(masks) 110 | scale = len(items) / total.max(torch.ones_like(total)) 111 | masks = [mask * scale for mask in masks] 112 | items = [item * mask.unsqueeze(-1) for item, mask in zip(items, masks)] 113 | 114 | return items 115 | -------------------------------------------------------------------------------- /supar/modules/mlp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | from supar.modules.dropout import SharedDropout 5 | 6 | 7 | class MLP(nn.Module): 8 | r""" 9 | Applies a linear transformation together with a non-linear activation to the incoming tensor: 10 | :math:`y = \mathrm{Activation}(x A^T + b)` 11 | 12 | Args: 13 | n_in (~torch.Tensor): 14 | The size of each input feature. 15 | n_out (~torch.Tensor): 16 | The size of each output feature. 17 | dropout (float): 18 | If non-zero, introduce a :class:`SharedDropout` layer on the output with this dropout ratio. Default: 0. 19 | activation (bool): 20 | Whether to use activations. Default: True. 21 | """ 22 | 23 | def __init__(self, n_in, n_out, dropout=0, activation=True): 24 | super().__init__() 25 | 26 | self.n_in = n_in 27 | self.n_out = n_out 28 | self.linear = nn.Linear(n_in, n_out) 29 | self.activation = nn.LeakyReLU(negative_slope=0.1) if activation else nn.Identity() 30 | self.dropout = SharedDropout(p=dropout) 31 | 32 | self.reset_parameters() 33 | 34 | def __repr__(self): 35 | s = f"n_in={self.n_in}, n_out={self.n_out}" 36 | if self.dropout.p > 0: 37 | s += f", dropout={self.dropout.p}" 38 | 39 | return f"{self.__class__.__name__}({s})" 40 | 41 | def reset_parameters(self): 42 | nn.init.orthogonal_(self.linear.weight) 43 | nn.init.zeros_(self.linear.bias) 44 | 45 | def forward(self, x): 46 | r""" 47 | Args: 48 | x (~torch.Tensor): 49 | The size of each input feature is `n_in`. 
50 | 51 | Returns: 52 | A tensor with the size of each output feature `n_out`. 53 | """ 54 | 55 | x = self.linear(x) 56 | x = self.activation(x) 57 | x = self.dropout(x) 58 | 59 | return x 60 | -------------------------------------------------------------------------------- /supar/modules/scalar_mix.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ScalarMix(nn.Module): 8 | r""" 9 | Computes a parameterised scalar mixture of :math:`N` tensors, :math:`mixture = \gamma * \sum_{k}(s_k * tensor_k)` 10 | where :math:`s = \mathrm{softmax}(w)`, with :math:`w` and :math:`\gamma` scalar parameters. 11 | 12 | Args: 13 | n_layers (int): 14 | The number of layers to be mixed, i.e., :math:`N`. 15 | dropout (float): 16 | The dropout ratio of the layer weights. 17 | If dropout > 0, then for each scalar weight, adjust its softmax weight mass to 0 18 | with the dropout probability (i.e., setting the unnormalized weight to -inf). 19 | This effectively redistributes the dropped probability mass to all other weights. 20 | Default: 0. 21 | """ 22 | 23 | def __init__(self, n_layers, dropout=0): 24 | super().__init__() 25 | 26 | self.n_layers = n_layers 27 | 28 | self.weights = nn.Parameter(torch.zeros(n_layers)) 29 | self.gamma = nn.Parameter(torch.tensor([1.0])) 30 | self.dropout = nn.Dropout(dropout) 31 | 32 | def __repr__(self): 33 | s = f"n_layers={self.n_layers}" 34 | if self.dropout.p > 0: 35 | s += f", dropout={self.dropout.p}" 36 | 37 | return f"{self.__class__.__name__}({s})" 38 | 39 | def forward(self, tensors): 40 | r""" 41 | Args: 42 | tensors (list[~torch.Tensor]): 43 | :math:`N` tensors to be mixed. 44 | 45 | Returns: 46 | The mixture of :math:`N` tensors. 47 | """ 48 | 49 | normed_weights = self.dropout(self.weights.softmax(-1)) 50 | weighted_sum = sum(w * h for w, h in zip(normed_weights, tensors)) 51 | 52 | return self.gamma * weighted_sum 53 | -------------------------------------------------------------------------------- /supar/modules/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from supar.modules.scalar_mix import ScalarMix 4 | from supar.utils.fn import pad 5 | 6 | 7 | class TransformerEmbedding(nn.Module): 8 | r""" 9 | A module that directly utilizes the pretrained models in `transformers`_ to produce BERT representations. 10 | While mainly tailored to provide input preparation and post-processing for the BERT model, 11 | it is also compatiable with other pretrained language models like XLNet, RoBERTa and ELECTRA, etc. 12 | Args: 13 | model (str): 14 | Path or name of the pretrained models registered in `transformers`_, e.g., ``'bert-base-cased'``. 15 | n_layers (int): 16 | The number of layers from the model to use. 17 | If 0, uses all layers. 18 | n_out (int): 19 | The requested size of the embeddings. Default: 0. 20 | If 0, uses the size of the pretrained embedding model. 21 | stride (int): 22 | A sequence longer than max length will be splitted into several small pieces 23 | with a window size of ``stride``. Default: 10. 24 | pad_index (int): 25 | The index of the padding token in the BERT vocabulary. Default: 0. 26 | dropout (float): 27 | The dropout ratio of BERT layers. Default: 0. 28 | This value will be passed into the :class:`ScalarMix` layer. 
29 | requires_grad (bool): 30 | If ``True``, the model parameters will be updated together with the downstream task. 31 | Default: ``False``. 32 | .. _transformers: 33 | https://github.com/huggingface/transformers 34 | """ 35 | 36 | def __init__(self, model, n_layers, n_out=0, stride=10, pad_index=0, dropout=0, requires_grad=False, use_projection=False, use_scalarmix=True): 37 | super().__init__() 38 | 39 | from transformers import AutoConfig, AutoModel, AutoTokenizer 40 | self.bert = AutoModel.from_pretrained(model, config=AutoConfig.from_pretrained(model, output_hidden_states=True)) 41 | self.bert = self.bert.requires_grad_(requires_grad) 42 | 43 | self.use_scalarmix = use_scalarmix 44 | 45 | self.model = model 46 | self.n_layers = n_layers or self.bert.config.num_hidden_layers 47 | self.hidden_size = self.bert.config.hidden_size 48 | self.n_out = self.hidden_size 49 | self.stride = stride 50 | self.pad_index = pad_index 51 | self.dropout = dropout 52 | self.requires_grad = requires_grad 53 | self.max_len = int(max(0, self.bert.config.max_position_embeddings) or 1e12) 54 | 55 | self.tokenizer = AutoTokenizer.from_pretrained(model) 56 | 57 | self.scalar_mix = ScalarMix(self.n_layers, dropout) 58 | 59 | if use_projection: 60 | self.projection = nn.Linear(self.hidden_size, self.n_out, False) 61 | else: 62 | self.projection = nn.Identity() 63 | 64 | def __repr__(self): 65 | s = f"{self.model}, n_layers={self.n_layers}, n_out={self.n_out}, pad_index={self.pad_index}" 66 | if self.dropout > 0: 67 | s += f", dropout={self.dropout}" 68 | if self.requires_grad: 69 | s += f", requires_grad={self.requires_grad}" 70 | 71 | return f"{self.__class__.__name__}({s})" 72 | 73 | def forward(self, subwords): 74 | r""" 75 | Args: 76 | subwords (~torch.Tensor): ``[batch_size, seq_len, fix_len]``. 77 | Returns: 78 | ~torch.Tensor: 79 | BERT embeddings of shape ``[batch_size, seq_len, n_out]``. 
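        Examples (an illustrative sketch; ``'bert-base-cased'`` (hidden size 768) and the random
        subword ids are assumptions made for this example, and running it downloads the pretrained weights):
            >>> embed = TransformerEmbedding('bert-base-cased', n_layers=4)
            >>> subwords = torch.randint(1, 1000, (2, 10, 5))  # [batch_size, seq_len, fix_len]
            >>> embed(subwords).shape
            torch.Size([2, 10, 768])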
80 | """ 81 | mask = subwords.ne(self.pad_index) 82 | lens = mask.sum((1, 2)) 83 | # [batch_size, n_subwords] 84 | subwords = pad(subwords[mask].split(lens.tolist()), self.pad_index, padding_side=self.tokenizer.padding_side) 85 | bert_mask = pad(mask[mask].split(lens.tolist()), 0, padding_side=self.tokenizer.padding_side) 86 | 87 | # return the hidden states of all layers 88 | bert = self.bert(subwords[:, :self.max_len], attention_mask=bert_mask[:, :self.max_len].float())[-1] 89 | # [n_layers, batch_size, max_len, hidden_size] 90 | 91 | if self.use_scalarmix: 92 | bert = bert[-self.n_layers:] 93 | bert = self.scalar_mix(bert) 94 | # [batch_size, n_subwords, hidden_size] 95 | for i in range(self.stride, 96 | (subwords.shape[1] - self.max_len + self.stride - 1) // self.stride * self.stride + 1, 97 | self.stride): 98 | part = \ 99 | self.bert(subwords[:, i:i + self.max_len], attention_mask=bert_mask[:, i:i + self.max_len].float())[-1] 100 | bert = torch.cat((bert, self.scalar_mix(part[-self.n_layers:])[:, self.max_len - self.stride:]), 1) 101 | 102 | else: 103 | bert = bert[-1] 104 | 105 | 106 | # [batch_size, n_subwords] 107 | bert_lens = mask.sum(-1) 108 | bert_lens = bert_lens.masked_fill_(bert_lens.eq(0), 1) 109 | # [batch_size, seq_len, fix_len, hidden_size] 110 | embed = bert.new_zeros(*mask.shape, self.hidden_size).masked_scatter_(mask.unsqueeze(-1), bert[bert_mask]) 111 | 112 | embed = embed.gather(2, (bert_lens - 1).unsqueeze(-1).repeat(1, 1, self.hidden_size).unsqueeze(2)).squeeze(2) 113 | 114 | # [batch_size, seq_len, hidden_size] 115 | # embed = embed.sum(2) / bert_lens.unsqueeze(-1) 116 | embed = self.projection(embed) 117 | 118 | return embed 119 | -------------------------------------------------------------------------------- /supar/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from . 
import alg, field, fn, metric, transform 4 | from .alg import chuliu_edmonds, cky, eisner, eisner2o, kmeans, mst, tarjan 5 | from .config import Config 6 | from .data import Dataset 7 | from .embedding import Embedding 8 | from .field import ChartField, Field, RawField, SubwordField 9 | from .transform import CoNLL, Transform, Tree 10 | from .vocab import Vocab 11 | 12 | __all__ = ['ChartField', 'CoNLL', 'Config', 'Dataset', 'Embedding', 'Field', 13 | 'RawField', 'SubwordField', 'Transform', 'Tree', 'Vocab', 14 | 'alg', 'field', 'fn', 'metric', 'chuliu_edmonds', 'cky', 15 | 'eisner', 'eisner2o', 'kmeans', 'mst', 'tarjan', 'transform'] 16 | -------------------------------------------------------------------------------- /supar/utils/common.py: -------------------------------------------------------------------------------- 1 | PAD = '' 2 | UNK = '' 3 | BOS = '' 4 | EOS = '' 5 | subword_bos = '' 6 | subword_eos = '<_eos>' 7 | no_label = '' 8 | 9 | -------------------------------------------------------------------------------- /supar/utils/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from ast import literal_eval 4 | from configparser import ConfigParser 5 | 6 | 7 | class Config(object): 8 | 9 | def __init__(self, conf=None, **kwargs): 10 | super(Config, self).__init__() 11 | 12 | config = ConfigParser() 13 | config.read(conf or []) 14 | self.update({**dict((name, literal_eval(value)) 15 | for section in config.sections() 16 | for name, value in config.items(section)), 17 | **kwargs}) 18 | 19 | def __repr__(self): 20 | s = line = "-" * 20 + "-+-" + "-" * 30 + "\n" 21 | s += f"{'Param':20} | {'Value':^30}\n" + line 22 | for name, value in vars(self).items(): 23 | s += f"{name:20} | {str(value):^30}\n" 24 | s += line 25 | 26 | return s 27 | 28 | def __getitem__(self, key): 29 | return getattr(self, key) 30 | 31 | def __getstate__(self): 32 | return vars(self) 33 | 34 | def __setstate__(self, state): 35 | self.__dict__.update(state) 36 | 37 | def keys(self): 38 | return vars(self).keys() 39 | 40 | def items(self): 41 | return vars(self).items() 42 | 43 | def update(self, kwargs): 44 | for key in ('self', 'cls', '__class__'): 45 | kwargs.pop(key, None) 46 | kwargs.update(kwargs.pop('kwargs', dict())) 47 | for name, value in kwargs.items(): 48 | setattr(self, name, value) 49 | 50 | return self 51 | 52 | def pop(self, key, val=None): 53 | return self.__dict__.pop(key, val) 54 | -------------------------------------------------------------------------------- /supar/utils/embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | class Embedding(object): 7 | 8 | def __init__(self, tokens, vectors, unk=None): 9 | self.tokens = tokens 10 | self.vectors = torch.tensor(vectors) 11 | self.pretrained = {w: v for w, v in zip(tokens, vectors)} 12 | self.unk = unk 13 | 14 | def __len__(self): 15 | return len(self.tokens) 16 | 17 | def __contains__(self, token): 18 | return token in self.pretrained 19 | 20 | @property 21 | def dim(self): 22 | return self.vectors.size(1) 23 | 24 | @property 25 | def unk_index(self): 26 | if self.unk is not None: 27 | return self.tokens.index(self.unk) 28 | else: 29 | raise AttributeError 30 | 31 | @classmethod 32 | def load(cls, path, unk=None): 33 | with open(path, 'r', encoding="utf-8") as f: 34 | lines = [line for line in f] 35 | 36 | if len(lines[0].split()) == 2: 37 | lines = lines[1:] 38 | 39 | 
splits = [line.split() for line in lines] 40 | for s in splits: 41 | if not s[1].isdigit(): 42 | splits.remove(s) 43 | 44 | # fuck genia 45 | tokens, vectors = zip(*[(s[0], list(map(float, s[1:]))) 46 | for s in splits]) 47 | 48 | return cls(tokens, vectors, unk=unk) 49 | -------------------------------------------------------------------------------- /supar/utils/fn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import unicodedata 4 | 5 | 6 | def ispunct(token): 7 | return all(unicodedata.category(char).startswith('P') 8 | for char in token) 9 | 10 | 11 | def isfullwidth(token): 12 | return all(unicodedata.east_asian_width(char) in ['W', 'F', 'A'] 13 | for char in token) 14 | 15 | 16 | def islatin(token): 17 | return all('LATIN' in unicodedata.name(char) 18 | for char in token) 19 | 20 | 21 | def isdigit(token): 22 | return all('DIGIT' in unicodedata.name(char) 23 | for char in token) 24 | 25 | 26 | def tohalfwidth(token): 27 | return unicodedata.normalize('NFKC', token) 28 | 29 | 30 | def stripe(x, n, w, offset=(0, 0), dim=1): 31 | # r""" 32 | # Returns a diagonal stripe of the tensor. 33 | # 34 | # Args: 35 | # x (~torch.Tensor): the input tensor with 2 or more dims. 36 | # n (int): the length of the stripe. 37 | # w (int): the width of the stripe. 38 | # offset (tuple): the offset of the first two dims. 39 | # dim (int): 1 if returns a horizontal stripe; 0 otherwise. 40 | # 41 | # Returns: 42 | # a diagonal stripe of the tensor. 43 | # Examples: 44 | # >>> x = torch.arange(25).view(5, 5) 45 | # >>> x 46 | # tensor([[ 0, 1, 2, 3, 4], 47 | # [ 5, 6, 7, 8, 9], 48 | # [10, 11, 12, 13, 14], 49 | # [15, 16, 17, 18, 19], 50 | # [20, 21, 22, 23, 24]]) 51 | # >>> stripe(x, 2, 3) 52 | # tensor([[0, 1, 2], 53 | # [6, 7, 8]]) 54 | # >>> stripe(x, 2, 3, (1, 1)) 55 | # tensor([[ 6, 7, 8], 56 | # [12, 13, 14]]) 57 | # >>> stripe(x, 2, 3, (1, 1), 0) 58 | # tensor([[ 6, 11, 16], 59 | # [12, 17, 22]]) 60 | # """ 61 | 62 | x, seq_len = x.contiguous(), x.size(1) 63 | stride, numel = list(x.stride()), x[0, 0].numel() 64 | stride[0] = (seq_len + 1) * numel 65 | stride[1] = (1 if dim == 1 else seq_len) * numel 66 | return x.as_strided(size=(n, w, *x.shape[2:]), 67 | stride=stride, 68 | storage_offset=(offset[0]*seq_len+offset[1])*numel) 69 | 70 | 71 | def pad(tensors, padding_value=0, total_length=None, padding_side='right'): 72 | size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors) 73 | for i in range(len(tensors[0].size()))] 74 | if total_length is not None: 75 | assert total_length >= size[1] 76 | size[1] = total_length 77 | out_tensor = tensors[0].data.new(*size).fill_(padding_value) 78 | for i, tensor in enumerate(tensors): 79 | out_tensor[i][[slice(-i, None) if padding_side == 'left' else slice(0, i) for i in tensor.size()]] = tensor 80 | return out_tensor 81 | 82 | 83 | if __name__ == '__main__': 84 | import torch 85 | x = torch.arange(25).view(5, 5) 86 | print(stripe(x, 2, 3, (3, 1))) 87 | -------------------------------------------------------------------------------- /supar/utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import logging 4 | import os 5 | 6 | from supar.utils.parallel import is_master 7 | from tqdm import tqdm 8 | 9 | 10 | def get_logger(name): 11 | return logging.getLogger(name) 12 | 13 | 14 | def init_logger(logger, 15 | path=None, 16 | mode='w', 17 | level=None, 18 | handlers=None, 19 | verbose=True): 20 | 
level = level or logging.WARNING 21 | if not handlers: 22 | handlers = [logging.StreamHandler()] 23 | if path: 24 | os.makedirs(os.path.dirname(path), exist_ok=True) 25 | handlers.append(logging.FileHandler(path, mode)) 26 | logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', 27 | datefmt='%Y-%m-%d %H:%M:%S', 28 | level=level, 29 | handlers=handlers) 30 | logger.setLevel(logging.INFO if is_master() and verbose else logging.WARNING) 31 | 32 | 33 | def progress_bar(iterator, 34 | ncols=None, 35 | bar_format='{l_bar}{bar:18}| {n_fmt}/{total_fmt} {elapsed}<{remaining}, {rate_fmt}{postfix}', 36 | leave=True): 37 | return tqdm(iterator, 38 | ncols=ncols, 39 | bar_format=bar_format, 40 | ascii=True, 41 | disable=(not (logger.level == logging.INFO and is_master())), 42 | leave=leave) 43 | 44 | 45 | logger = get_logger('supar') 46 | -------------------------------------------------------------------------------- /supar/utils/parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | from random import Random 5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.nn as nn 9 | 10 | 11 | class DistributedDataParallel(nn.parallel.DistributedDataParallel): 12 | 13 | def __init__(self, module, **kwargs): 14 | super().__init__(module, **kwargs) 15 | 16 | def __getattr__(self, name): 17 | wrapped = super().__getattr__('module') 18 | if hasattr(wrapped, name): 19 | return getattr(wrapped, name) 20 | return super().__getattr__(name) 21 | 22 | 23 | def init_device(device, local_rank=-1, backend='nccl', host=None, port=None): 24 | os.environ['CUDA_VISIBLE_DEVICES'] = device 25 | if torch.cuda.device_count() > 1: 26 | host = host or os.environ.get('MASTER_ADDR', 'localhost') 27 | port = port or os.environ.get('MASTER_PORT', str(Random(0).randint(10000, 20000))) 28 | os.environ['MASTER_ADDR'] = host 29 | os.environ['MASTER_PORT'] = port 30 | dist.init_process_group(backend) 31 | torch.cuda.set_device(local_rank) 32 | 33 | 34 | def is_master(): 35 | return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0 36 | -------------------------------------------------------------------------------- /supar/utils/vocab.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from collections import defaultdict 4 | from collections.abc import Iterable 5 | 6 | 7 | class Vocab(object): 8 | r""" 9 | Defines a vocabulary object that will be used to numericalize a field. 10 | 11 | Args: 12 | counter (~collections.Counter): 13 | :class:`~collections.Counter` object holding the frequencies of each value found in the data. 14 | min_freq (int): 15 | The minimum frequency needed to include a token in the vocabulary. Default: 1. 16 | specials (list[str]): 17 | The list of special tokens (e.g., pad, unk, bos and eos) that will be prepended to the vocabulary. Default: []. 18 | unk_index (int): 19 | The index of unk token. Default: 0. 20 | 21 | Attributes: 22 | itos: 23 | A list of token strings indexed by their numerical identifiers. 24 | stoi: 25 | A :class:`~collections.defaultdict` object mapping token strings to numerical identifiers. 
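    Examples (an illustrative sketch added here; the toy counter and special tokens are made up):
        >>> from collections import Counter
        >>> vocab = Vocab(Counter(['apple', 'apple', 'banana']), min_freq=2, specials=['<pad>', '<unk>'], unk_index=1)
        >>> vocab.itos
        ['<pad>', '<unk>', 'apple']
        >>> vocab['apple']
        2
        >>> vocab['banana']  # below min_freq, so it falls back to unk_index
        1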
26 | """ 27 | 28 | def __init__(self, counter, min_freq=1, specials=[], unk_index=0): 29 | self.itos = list(specials) 30 | self.stoi = defaultdict(lambda: unk_index) 31 | 32 | self.stoi.update({token: i for i, token in enumerate(self.itos)}) 33 | self.extend([token for token, freq in counter.items() 34 | if freq >= min_freq]) 35 | self.unk_index = unk_index 36 | self.n_init = len(self) 37 | 38 | def __len__(self): 39 | return len(self.itos) 40 | 41 | def __getitem__(self, key): 42 | if isinstance(key, str): 43 | return self.stoi[key] 44 | elif not isinstance(key, Iterable): 45 | return self.itos[key] 46 | elif isinstance(key[0], str): 47 | return [self.stoi[i] for i in key] 48 | else: 49 | return [self.itos[i] for i in key] 50 | 51 | def __contains__(self, token): 52 | return token in self.stoi 53 | 54 | def __getstate__(self): 55 | # avoid picking defaultdict 56 | attrs = dict(self.__dict__) 57 | # cast to regular dict 58 | attrs['stoi'] = dict(self.stoi) 59 | return attrs 60 | 61 | def __setstate__(self, state): 62 | stoi = defaultdict(lambda: self.unk_index) 63 | stoi.update(state['stoi']) 64 | state['stoi'] = stoi 65 | self.__dict__.update(state) 66 | 67 | 68 | ### do not change 69 | def extend(self, tokens): 70 | try: 71 | self.stoi = defaultdict(lambda: self.unk_index) 72 | for word in self.itos: 73 | self.stoi[word] 74 | except: 75 | pass 76 | 77 | self.itos.extend(sorted(set(tokens).difference(self.stoi))) 78 | self.stoi.update({token: i for i, token in enumerate(self.itos)}) 79 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import hydra 7 | import omegaconf 8 | import pytorch_lightning as pl 9 | from hydra.core.hydra_config import HydraConfig 10 | import wandb 11 | from pytorch_lightning import LightningModule, LightningDataModule, Callback, Trainer 12 | from pytorch_lightning.loggers import LightningLoggerBase 13 | from pytorch_lightning import seed_everything 14 | 15 | # hydra imports 16 | from omegaconf import DictConfig 17 | from hydra.utils import log 18 | import hydra 19 | from omegaconf import OmegaConf 20 | 21 | # normal imports 22 | from typing import List 23 | import warnings 24 | import logging 25 | from pytorch_lightning.loggers import WandbLogger 26 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 27 | import sys 28 | # sys.setdefaultencoding() does not exist, here! 29 | from omegaconf import OmegaConf, open_dict 30 | 31 | 32 | def train(config): 33 | # if contains this, means we are multi-run and optuna-ing 34 | log.info(OmegaConf.to_container(config,resolve=True)) 35 | config.root = hydra.utils.get_original_cwd() 36 | limit_train_batches = 1.0 37 | 38 | 39 | hydra_dir = str(os.getcwd()) 40 | seed_everything(config.seed) 41 | os.chdir(hydra.utils.get_original_cwd()) 42 | 43 | # Instantiate datamodule 44 | hydra.utils.log.info(os.getcwd()) 45 | hydra.utils.log.info(f"Instantiating <{config.datamodule.target}>") 46 | # Instantiate callbacks and logger. 
47 | callbacks: List[Callback] = [] 48 | logger: List[LightningLoggerBase] = [] 49 | 50 | # ------------------------------- debug -------------------------- # 51 | if config.debug is True: 52 | config.trainer.max_epochs = 10 53 | config.trainer.fast_dev_run = False 54 | config.trainer.gpus = 0 55 | # config.datamodule.debug = True 56 | config.datamodule.use_bert = False 57 | config.datamodule.use_char = True 58 | config.wandb = False 59 | # distributed = False 60 | # config.datamodule.suffix = '.debug' 61 | config.datamodule.use_cache=False 62 | config.datamodule.train_const += ".debug" 63 | config.datamodule.dev_const += ".debug" 64 | config.datamodule.test_const += ".debug" 65 | config.datamodule.use_emb = False 66 | else: 67 | pass 68 | # distributed = config.distributed 69 | 70 | # -------------------------------end debug -------------------------- # 71 | datamodule: pl.LightningDataModule = hydra.utils.instantiate( 72 | config.datamodule.target, config.datamodule, _recursive_=False 73 | ) 74 | 75 | log.info("created datamodule") 76 | datamodule.setup() 77 | model = hydra.utils.instantiate(config.runner, cfg = config, fields=datamodule.fields, datamodule=datamodule, _recursive_=False) 78 | 79 | os.chdir(hydra_dir) 80 | 81 | 82 | 83 | # Train the model ⚡ 84 | if "callbacks" in config: 85 | for _, cb_conf in config["callbacks"].items(): 86 | if "_target_" in cb_conf: 87 | log.info(f"Instantiating callback <{cb_conf._target_}>") 88 | callbacks.append(hydra.utils.instantiate(cb_conf)) 89 | 90 | if config.checkpoint: 91 | callbacks.append( 92 | ModelCheckpoint( 93 | monitor='valid/score', 94 | mode='max', 95 | save_last=False, 96 | filename='checkpoint' 97 | ) 98 | ) 99 | log.info("Instantiating callback, ModelCheckpoint") 100 | 101 | 102 | if config.wandb: 103 | logger.append(hydra.utils.instantiate(config.logger)) 104 | 105 | log.info(f"Instantiating trainer <{config.trainer._target_}>") 106 | trainer: Trainer = hydra.utils.instantiate( 107 | config.trainer, callbacks=callbacks, logger=logger, 108 | replace_sampler_ddp=False, 109 | # accelerator='ddp' if distributed else None, 110 | accumulate_grad_batches=config.accumulation, 111 | # limit_train_batches=0.1, 112 | checkpoint_callback=config.checkpoint, 113 | # turn off the sanity check: 0 disables it, -1 runs all validation batches, and a positive number runs that many batches. 114 | # TODO: there is a bug that conflicts with 'write_result_to_file'; to be fixed. 115 | num_sanity_val_steps=0, 116 | # limit_train_batches=.1, 117 | ) 118 | 119 | log.info(f"Starting training!") 120 | if config.wandb: 121 | logger[-1].experiment.save(str(hydra_dir) + "/.hydra/*", base_path=str(hydra_dir)) 122 | 123 | trainer.fit(model, datamodule) 124 | log.info(f"Finalizing!") 125 | 126 | if config.wandb: 127 | logger[-1].experiment.save(str(hydra_dir) + "/*.log", base_path=str(hydra_dir)) 128 | wandb.finish() 129 | 130 | log.info(f'hydra_path:{os.getcwd()}') 131 | 132 | 133 | @hydra.main(config_path="configs/", config_name="config.yaml") 134 | def main(config): 135 | train(config) 136 | 137 | if __name__ == "__main__": 138 | main() 139 | 140 | 141 | --------------------------------------------------------------------------------