├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── span-based-dependency-parsing-2.iml └── vcs.xml ├── README.md ├── configs ├── config.yaml ├── datamodule │ ├── _base.yaml │ ├── ctb.yaml │ ├── ptb.yaml │ └── ud2.2.yaml ├── exp │ └── base.yaml ├── finetune │ └── base.yaml ├── logger │ ├── comet.yaml │ ├── csv.yaml │ ├── many_loggers.yaml │ ├── neptune.yaml │ ├── tensorboard.yaml │ └── wandb.yaml ├── model │ ├── _base.yaml │ ├── biaffine.yaml │ ├── biaffine2o.yaml │ ├── span.yaml │ ├── span1o.yaml │ ├── span1oheadplit.yaml │ └── span2oheadplit.yaml ├── optim │ └── finetune_bert.yaml └── trainer │ └── default_trainer.yaml ├── fastNLP ├── __init__.py ├── core │ ├── __init__.py │ ├── _logger.py │ ├── _parallel_utils.py │ ├── batch.py │ ├── callback.py │ ├── collate_fn.py │ ├── const.py │ ├── dataset.py │ ├── dist_trainer.py │ ├── field.py │ ├── instance.py │ ├── losses.py │ ├── metrics.py │ ├── optimizer.py │ ├── predictor.py │ ├── sampler.py │ ├── tester.py │ ├── trainer.py │ ├── utils.py │ └── vocabulary.py ├── doc_utils.py ├── embeddings │ ├── __init__.py │ ├── bert_embedding.py │ ├── char_embedding.py │ ├── contextual_embedding.py │ ├── elmo_embedding.py │ ├── embedding.py │ ├── gpt2_embedding.py │ ├── roberta_embedding.py │ ├── stack_embedding.py │ ├── static_embedding.py │ └── utils.py ├── io │ ├── __init__.py │ ├── data_bundle.py │ ├── embed_loader.py │ ├── file_reader.py │ ├── file_utils.py │ ├── loader │ │ ├── __init__.py │ │ ├── classification.py │ │ ├── conll.py │ │ ├── coreference.py │ │ ├── csv.py │ │ ├── cws.py │ │ ├── json.py │ │ ├── loader.py │ │ ├── matching.py │ │ ├── qa.py │ │ └── summarization.py │ ├── model_io.py │ ├── pipe │ │ ├── __init__.py │ │ ├── classification.py │ │ ├── conll.py │ │ ├── coreference.py │ │ ├── cws.py │ │ ├── matching.py │ │ ├── pipe.py │ │ ├── qa.py │ │ ├── summarization.py │ │ └── utils.py │ └── utils.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── bert.py │ ├── biaffine_parser.py │ ├── cnn_text_classification.py │ ├── seq2seq_generator.py │ ├── seq2seq_model.py │ ├── sequence_labeling.py │ ├── snli.py │ └── star_transformer.py └── modules │ ├── __init__.py │ ├── attention.py │ ├── decoder │ ├── __init__.py │ ├── crf.py │ ├── mlp.py │ ├── seq2seq_decoder.py │ ├── seq2seq_state.py │ └── utils.py │ ├── dropout.py │ ├── encoder │ ├── __init__.py │ ├── _elmo.py │ ├── bert.py │ ├── char_encoder.py │ ├── conv_maxpool.py │ ├── gpt2.py │ ├── lstm.py │ ├── pooling.py │ ├── roberta.py │ ├── seq2seq_encoder.py │ ├── star_transformer.py │ ├── transformer.py │ └── variational_rnn.py │ ├── generator │ ├── __init__.py │ └── seq2seq_generator.py │ ├── tokenizer │ ├── __init__.py │ ├── bert_tokenizer.py │ ├── gpt2_tokenizer.py │ └── roberta_tokenizer.py │ └── utils.py ├── requirements.txt ├── src ├── callbacks │ ├── progressbar.py │ ├── transformer_scheduler.py │ └── wandb_callbacks.py ├── constant.py ├── datamodule │ ├── __init__.py │ ├── base.py │ ├── dep_data.py │ └── dm_util │ │ ├── datamodule_util.py │ │ ├── fields.py │ │ ├── padder.py │ │ └── util.py ├── inside │ ├── __init__.py │ ├── eisner.py │ ├── eisner2o.py │ ├── eisner_satta.py │ ├── fn.py │ └── span.py ├── loss │ ├── __init__.py │ ├── dep_loss.py │ └── get_score.py ├── model │ ├── dep_parsing.py │ ├── metric.py │ └── module │ │ ├── ember │ │ ├── embedding.py │ │ └── ext_embedding.py │ │ ├── encoder │ │ └── lstm_encoder.py │ │ └── scorer │ │ ├── dep_scorer.py │ │ └── module │ │ ├── biaffine.py │ │ 
├── nhpsg_scorer.py │ │ ├── quadra_linear.py │ │ └── triaffine.py └── runner │ └── base.py ├── supar ├── __init__.py ├── modules │ ├── __init__.py │ ├── affine.py │ ├── bert.py │ ├── char_lstm.py │ ├── dropout.py │ ├── lstm.py │ ├── mlp.py │ ├── scalar_mix.py │ ├── transformer.py │ ├── treecrf.py │ └── variational_inference.py └── utils │ ├── __init__.py │ ├── alg.py │ ├── common.py │ ├── config.py │ ├── data.py │ ├── embedding.py │ ├── field.py │ ├── field_new.py │ ├── fn.py │ ├── logging.py │ ├── metric.py │ ├── parallel.py │ ├── transform.py │ └── vocab.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 15 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/span-based-dependency-parsing-2.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # span-based-dependency-parsing 2 | Source code of the ACL 2022 paper "[Headed-Span-Based Projective Dependency Parsing](http://arxiv.org/abs/2108.04750)" and 3 | the Findings of ACL 2022 paper "[Combining (second-order) graph-based and headed-span-based projective dependency parsing](https://arxiv.org/pdf/2108.05838.pdf)" 4 | 5 | ## Setup 6 | Set up the environment: 7 | ``` 8 | conda create -n parsing python=3.7 9 | conda activate parsing 10 | while read requirement; do pip install $requirement; done < requirements.txt 11 | ``` 12 | 13 | Set up the datasets: 14 | 15 | You can download the datasets I used from [link](https://mega.nz/file/jFIijLTI#b0b7550tdYVNcpGfgaXc0sk0F943lrt8D35v1SW2wbg).
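Once downloaded, place the corpora under `data/` in the project root (the default `root: '.'` in `configs/config.yaml`; `data/` is git-ignored). The layout below is only a sketch inferred from the default paths in `configs/datamodule/ptb.yaml`, `ctb.yaml`, and `ud2.2.yaml`, not an official description of the archive, so adjust either the files or the config paths if your copy is organized differently. For UD, only German is shown; each language follows the same pattern.

```
data/
├── ptb
│   ├── train.gold.conllu
│   ├── dev.gold.conllu
│   ├── test.gold.conllu
│   └── glove.6B.100d.txt
├── ctb
│   ├── train.ctb.conll
│   ├── dev.ctb.conll
│   └── test.ctb.conll
└── ud2.2
    ├── fasttext
    │   └── nogen
    │       └── de.lower.nogen.300.txt
    └── UD_German-GSD
        ├── de_gsd-ud-train.conllu
        ├── de_gsd-ud-dev.conllu
        └── de_gsd-ud-test.conllu
```

The GloVe/fastText files correspond to the `ext_emb_path` entries in those configs; whether they are actually loaded depends on the `use_emb` flag of the datamodule.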
16 | 17 | ## Run 18 | ``` 19 | python train.py +exp=base datamodule=a model=b seed=0 20 | a={ptb, ctb, ud2.2} 21 | b={biaffine, biaffine2o, span, span1o, span1oheadsplit, span2oheadsplit} 22 | ``` 23 | 24 | Multirun example: 25 | ``` 26 | python train.py +exp=base datamodule=ud2.2 model=b datamodule.ud_lan=de,it,en,ca,cs,es,fr,no,ru,ro,nl,bg seed=0,1,2 --multirun 27 | ``` 28 | For UD, you also need to set up a Java environment for the use of MaltParser. 29 | You need to download MaltParser v1.9.2 from [link](https://www.maltparser.org/download.html). 30 | 31 | ## Contact 32 | Please let me know if there are any bugs. Also, feel free to contact bestsonta@gmail.com if you have any questions. 33 | 34 | ## Citation 35 | ``` 36 | @inproceedings{yang-tu-2022-headed, 37 | title={Headed-Span-Based Projective Dependency Parsing}, 38 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics", 39 | author={Songlin Yang and Kewei Tu}, 40 | year={2022} 41 | } 42 | 43 | @inproceedings{yang-tu-2022-combining, 44 | title={Combining (second-order) graph-based and headed-span-based projective dependency parsing}, 45 | author={Songlin Yang and Kewei Tu}, 46 | year={2022}, 47 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2022", 48 | } 49 | ``` 50 | 51 | ## Acknowledgements 52 | The code is based on the [lightning+hydra](https://github.com/ashleve/lightning-hydra-template) template. I use [FastNLP](https://github.com/fastnlp/fastNLP) as the dataloader. I use many built-in modules (LSTMs, biaffines, triaffines, dropout layers, etc.) from [Supar](https://github.com/yzhangcs/parser/tree/main/supar). 53 | 54 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here the default training configuration 4 | defaults: 5 | - trainer: default_trainer 6 | - optim: finetune_bert 7 | - model: _base 8 | - datamodule: _base 9 | 10 | runner: 11 | _target_: src.runner.base.Runner 12 | 13 | work_dir: ${hydra:runtime.cwd}/experiment/${datamodule.name}/${model.name}/${now:%Y-%m-%d}-${now:%H-%M-%S}-seed-${seed} 14 | 15 | wandb: False 16 | checkpoint: False 17 | device: 0 18 | seed: 0 19 | accumulation: 1 20 | use_logger: True 21 | distributed: False 22 | 23 | # output paths for hydra logs 24 | root: "." 25 | suffix: "" 26 | 27 | hydra: 28 | run: 29 | dir: ${work_dir} 30 | sweep: 31 | dir: logs/multiruns/experiment/${datamodule.name}/${model.name}/${now:%Y-%m-%d}-${now:%H-%M-%S}-seed-${seed} 32 | subdir: ${hydra.job.num} 33 | job: 34 | env_set: 35 | WANDB_CONSOLE: 'off' 36 | 37 | -------------------------------------------------------------------------------- /configs/datamodule/_base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | datamodule: 4 | use_char: False 5 | use_bert: True 6 | use_pos: False 7 | use_word: False 8 | use_emb: False 9 | use_sib: "${model.scorer.use_sib}" 10 | use_span_head_word: "${model.scorer.use_span}" 11 | ext_emb_path: "" 12 | bert: '' 13 | min_freq: 2 14 | fix_len: 20 15 | train_sampler_type: 'token' 16 | test_sampler_type: 'token' 17 | bucket: 32 18 | bucket_test: 32 19 | max_tokens: 5000 20 | max_tokens_test: 5000 21 | use_cache: True 22 | use_bert_cache: True 23 | max_len: 10000 24 | max_len_test: 10000 25 | root: '.' 26 | distributed: False 27 | # for PTB only.
clean (-RHS-) 28 | clean_word: False 29 | 30 | -------------------------------------------------------------------------------- /configs/datamodule/ctb.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _base 5 | 6 | datamodule: 7 | target: 8 | _target_: src.datamodule.dep_data.DepData 9 | 10 | train_dep: "${root}/data/ctb/train.ctb.conll" 11 | dev_dep: "${root}/data/ctb/dev.ctb.conll" 12 | test_dep: "${root}/data/ctb/test.ctb.conll" 13 | 14 | bert: "bert-base-chinese" 15 | ignore_punct: True 16 | name: 'ctb' 17 | use_pos: True 18 | 19 | # ext_emb_path: ${.mapping.${.emb_type}} 20 | # emb_type: sskip 21 | # mapping: 22 | # giga: "data/giga.100.txt" 23 | # sskip: "data/sskip.chn.50" 24 | 25 | cache: "${root}/data/ctb/ctb.dep.pickle" 26 | cache_bert: "${root}/data/ctb/ctb.dep.cache_${bert}" 27 | 28 | model: 29 | metric: 30 | target: 31 | _target_: src.model.metric.AttachmentMetric 32 | write_result_to_file: True 33 | 34 | 35 | -------------------------------------------------------------------------------- /configs/datamodule/ptb.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _base 4 | 5 | datamodule: 6 | target: 7 | _target_: src.datamodule.dep_data.DepData 8 | name: 'ptb' 9 | train_dep: "${root}/data/ptb/train.gold.conllu" 10 | dev_dep: "${root}/data/ptb/dev.gold.conllu" 11 | test_dep: "${root}/data/ptb/test.gold.conllu" 12 | cache: "${root}/data/ptb/ptb.dep.pickle" 13 | cache_bert: "${root}/data/ptb/ptb.dep.cache_${datamodule.bert}" 14 | ext_emb_path: "${root}/data/ptb/glove.6B.100d.txt" 15 | ignore_punct: True 16 | clean_word: True 17 | bert: 'bert-large-cased' 18 | 19 | model: 20 | metric: 21 | target: 22 | _target_: src.model.metric.AttachmentMetric 23 | write_result_to_file: True 24 | 25 | 26 | -------------------------------------------------------------------------------- /configs/datamodule/ud2.2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _base 4 | 5 | datamodule: 6 | target: 7 | _target_: src.datamodule.dep_data.DepUD 8 | abbreviation: 9 | 'no': no_bokmaal 10 | bg: bg_btb 11 | ca: ca_ancora 12 | cs: cs_pdt 13 | de: de_gsd 14 | en: en_ewt 15 | es: es_ancora 16 | fr: fr_gsd 17 | it: it_isdt 18 | 'nl': nl_alpino 19 | ro: ro_rrt 20 | ru: ru_syntagrus 21 | # extra languages 22 | ta: ta_ttb 23 | ko: ko_gsd 24 | zh: zh_gsd 25 | 26 | use_pos: True 27 | ud_lan: bg 28 | ud_name: "${.abbreviation.${.ud_lan}}" 29 | ud_ver: 2.2 30 | name: "ud2.2_${.ud_lan}" 31 | # Do not ignore punctuations while evaluating 32 | ignore_punct: True 33 | max_len: 200 34 | train_dep: "${root}/data/ud${.ud_ver}/UD_${.ud_mapping.${.ud_name}}/${.ud_name}-ud-train.conllu" 35 | dev_dep: "${root}/data/ud${.ud_ver}/UD_${.ud_mapping.${.ud_name}}/${.ud_name}-ud-dev.conllu" 36 | test_dep: "${root}/data/ud${.ud_ver}/UD_${.ud_mapping.${.ud_name}}/${.ud_name}-ud-test.conllu" 37 | cache: "${root}/data/ud${.ud_ver}/UD_${.ud_mapping.${.ud_name}}/${.ud_name}-ud-cache.pickle" 38 | cache_bert: "${root}/data/ud${.ud_ver}/UD_${.ud_mapping.${.ud_name}}/${.ud_name}-ud-cache_bert.pickle" 39 | ext_emb_path: "${root}/data/ud${.ud_ver}/fasttext/nogen/${.ud_lan}.lower.nogen.300.txt" 40 | bert: "bert-base-multilingual-cased" 41 | #use_bert: False 42 | #use_emb: False 43 | #ext_emb: "" 44 | ud_mapping: 45 | bg_btb: Bulgarian-BTB 46 | ca_ancora: Catalan-AnCora 47 | cs_pdt: Czech-PDT 48 | 
de_gsd: German-GSD 49 | en_ewt: English-EWT 50 | es_ancora: Spanish-AnCora 51 | fr_gsd: French-GSD 52 | it_isdt: Italian-ISDT 53 | nl_alpino: Dutch-Alpino 54 | no_bokmaal: Norwegian-Bokmaal 55 | ro_rrt: Romanian-RRT 56 | ru_syntagrus: Russian-SynTagRus 57 | # extra languages 58 | ta_ttb : Tamil-TTB 59 | ko_gsd: Korean-GSD 60 | zh_gsd: Chinese-GSD 61 | 62 | model: 63 | metric: 64 | target: 65 | _target_: src.model.metric.PseudoProjDepExternalMetric 66 | write_result_to_file: True 67 | lan: ${datamodule.ud_lan} 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /configs/exp/base.yaml: -------------------------------------------------------------------------------- 1 | 2 | # @package _global_ 3 | 4 | 5 | trainer: 6 | min_epochs: 1 7 | max_epochs: 10 8 | 9 | 10 | # 16*250=4000 11 | accumulation: 16 12 | 13 | datamodule: 14 | max_tokens: 250 15 | max_tokens_test: 250 16 | 17 | # save checkpoints of the model. 18 | checkpoint: False 19 | 20 | model: 21 | embeder: 22 | finetune: True 23 | 24 | optim: 25 | only_embeder: True 26 | 27 | 28 | callbacks: 29 | transformer_scheduler: 30 | _target_: src.callbacks.transformer_scheduler.TransformerLrScheduler 31 | warmup: ${optim.warmup} 32 | 33 | pretty_progress_bar: 34 | _target_: src.callbacks.progressbar.PrettyProgressBar 35 | refresh_rate: 1 36 | process_position: 0 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /configs/finetune/base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /optim: finetune_bert 5 | 6 | trainer: 7 | min_epochs: 1 8 | max_epochs: 10 9 | 10 | callbacks: 11 | transformer_scheduler: 12 | _target_: src.callbacks.transformer_scheduler.TransformerLrScheduler 13 | warmup: ${optim.warmup} 14 | 15 | model: 16 | embeder: 17 | finetune: True 18 | 19 | optim: 20 | only_embeder: True -------------------------------------------------------------------------------- /configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: pytorch_lightning.loggers.comet.CometLogger 5 | api_key: ??? 6 | project_name: "project_template_test" 7 | experiment_name: null 8 | -------------------------------------------------------------------------------- /configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # CSVLogger built in PyTorch Lightning 2 | 3 | csv: 4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger 5 | save_dir: "." 
6 | name: "csv/" 7 | -------------------------------------------------------------------------------- /configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | - csv.yaml 5 | - wandb.yaml 6 | # - neptune.yaml 7 | # - comet.yaml 8 | # - tensorboard.yaml 9 | -------------------------------------------------------------------------------- /configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger 5 | project_name: "your_name/lightning-hydra-template-test" 6 | api_key: ${env:NEPTUNE_API_TOKEN} # api key is loaded from the environment variable 7 | # experiment_name: "some_experiment" 8 | -------------------------------------------------------------------------------- /configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # TensorBoard 2 | 3 | tensorboard: 4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "tensorboard/" 6 | name: "default" 7 | -------------------------------------------------------------------------------- /configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | logger: 2 | wandb: 3 | _target_: pytorch_lightning.loggers.wandb.WandbLogger 4 | project: "supervised_parsing" 5 | entity: "sonta2020" 6 | job_type: "train" 7 | group: "" 8 | name: "" 9 | save_dir: "${work_dir}" 10 | -------------------------------------------------------------------------------- /configs/model/_base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | 4 | model: 5 | target: 6 | _target_: src.model.dep_parsing.ProjectiveDepParser 7 | 8 | embeder: 9 | target: 10 | _target_: src.model.module.ember.embedding.Embeder 11 | 12 | #pos 13 | n_pos_embed: 100 14 | #char 15 | n_char_embed: 50 16 | n_char_out: 100 17 | char_input_dropout: 0. 18 | # bert 19 | n_bert_out: 1024 20 | n_bert_layers: 4 21 | mix_dropout: 0. 22 | use_projection: False 23 | use_scalarmix: False 24 | finetune: True 25 | #word 26 | n_embed: 300 27 | 28 | encoder: 29 | target: 30 | _target_: src.model.module.encoder.lstm_encoder.LSTMencoder 31 | embed_dropout: .33 32 | embed_dropout_type: shared 33 | lstm_dropout: .33 34 | n_lstm_hidden: 1000 35 | n_lstm_layers: 3 36 | before_lstm_dropout: 0.
37 | 38 | scorer: 39 | target: 40 | _target_: src.model.module.scorer.dep_scorer.DepScorer 41 | n_mlp_arc: 600 42 | n_mlp_rel: 300 43 | n_mlp_sib: 300 44 | mlp_dropout: .33 45 | scaling: False 46 | use_arc: False 47 | use_span: False 48 | use_sib: False 49 | span_scorer_type: biaffine 50 | 51 | 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/model/biaffine.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _base 5 | 6 | model: 7 | scorer: 8 | use_arc: True 9 | use_sib: False 10 | use_span: False 11 | 12 | 13 | loss: 14 | target: 15 | _target_: src.loss.dep_loss.DepLoss 16 | loss_type: 'mm' 17 | 18 | name: 'dep1o_${model.loss.loss_type}' 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /configs/model/biaffine2o.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _base 5 | 6 | model: 7 | scorer: 8 | use_arc: True 9 | use_sib: True 10 | use_span: False 11 | 12 | loss: 13 | target: 14 | _target_: src.loss.dep_loss.Dep2O 15 | loss_type: 'mm' 16 | 17 | name: 'dep2o_${model.loss.loss_type}' 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /configs/model/span.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _base 5 | 6 | model: 7 | scorer: 8 | use_arc: False 9 | use_sib: False 10 | use_span: True 11 | 12 | 13 | loss: 14 | target: 15 | _target_: src.loss.dep_loss.Span 16 | loss_type: 'mm' 17 | 18 | name: 'span_${model.loss.loss_type}' 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /configs/model/span1o.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _base 5 | 6 | model: 7 | scorer: 8 | use_arc: True 9 | use_sib: False 10 | use_span: True 11 | span_scorer_type: biaffine 12 | 13 | loss: 14 | target: 15 | _target_: src.loss.dep_loss.DepSpanLoss 16 | loss_type: 'mm' 17 | 18 | name: 'dep1o_span_${model.loss.loss_type}' 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /configs/model/span1oheadplit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _base 5 | 6 | model: 7 | scorer: 8 | use_arc: True 9 | use_sib: False 10 | use_span: True 11 | span_scorer_type: headsplit 12 | 13 | loss: 14 | target: 15 | _target_: src.loss.dep_loss.DepSpanLoss 16 | loss_type: 'mm' 17 | 18 | name: 'dep1o_span_headsplit_${model.loss.loss_type}' 19 | 20 | 21 | -------------------------------------------------------------------------------- /configs/model/span2oheadplit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _base 5 | 6 | model: 7 | scorer: 8 | use_arc: True 9 | use_sib: True 10 | use_span: True 11 | span_scorer_type: headsplit 12 | 13 | loss: 14 | target: 15 | _target_: src.loss.dep_loss.DepSpanLoss 16 | loss_type: 'mm' 17 | 18 | name: 'dep2o_span_headsplit_${model.loss.loss_type}' 19 | 20 | 21 | -------------------------------------------------------------------------------- /configs/optim/finetune_bert.yaml: 
-------------------------------------------------------------------------------- 1 | warmup: 0.1 2 | #beta1: 0.9 3 | #beta2: 0.9 4 | #eps: 1e-12 5 | lr: 5e-5 6 | #weight_decay: 0 7 | lr_rate: 50 8 | scheduler_type: 'linear_warmup' 9 | 10 | -------------------------------------------------------------------------------- /configs/trainer/default_trainer.yaml: -------------------------------------------------------------------------------- 1 | _target_: pytorch_lightning.Trainer 2 | gpus: 1 # set -1 to train on all GPUs available, set 0 to train on CPU only 3 | min_epochs: 1 4 | max_epochs: 10 5 | gradient_clip_val: 5 6 | num_sanity_val_steps: 3 7 | progress_bar_refresh_rate: 20 8 | weights_summary: null 9 | fast_dev_run: False 10 | #default_root_dir: "lightning_logs/" 11 | -------------------------------------------------------------------------------- /fastNLP/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | fastNLP 由 :mod:`~fastNLP.core` 、 :mod:`~fastNLP.io` 、:mod:`~fastNLP.embeddings` 、 :mod:`~fastNLP.modules`、 3 | :mod:`~fastNLP.models` 等子模块组成,你可以查看每个模块的文档。 4 | 5 | - :mod:`~fastNLP.core` 是fastNLP 的核心模块,包括 DataSet、 Trainer、 Tester 等组件。详见文档 :mod:`fastNLP.core` 6 | - :mod:`~fastNLP.io` 是实现输入输出的模块,包括了数据集的读取,模型的存取等功能。详见文档 :mod:`fastNLP.io` 7 | - :mod:`~fastNLP.embeddings` 提供用于构建复杂网络模型所需的各种embedding。详见文档 :mod:`fastNLP.embeddings` 8 | - :mod:`~fastNLP.modules` 包含了用于搭建神经网络模型的诸多组件,可以帮助用户快速搭建自己所需的网络。详见文档 :mod:`fastNLP.modules` 9 | - :mod:`~fastNLP.models` 包含了一些使用 fastNLP 实现的完整网络模型,包括 :class:`~fastNLP.models.CNNText` 、 :class:`~fastNLP.models.SeqLabeling` 等常见模型。详见文档 :mod:`fastNLP.models` 10 | 11 | fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的文档如下: 12 | """ 13 | __all__ = [ 14 | "Instance", 15 | "FieldArray", 16 | 17 | "DataSetIter", 18 | "BatchIter", 19 | "TorchLoaderIter", 20 | 21 | "Vocabulary", 22 | "DataSet", 23 | "Const", 24 | 25 | "Trainer", 26 | "Tester", 27 | 28 | "DistTrainer", 29 | "get_local_rank", 30 | 31 | "Callback", 32 | "GradientClipCallback", 33 | "EarlyStopCallback", 34 | "FitlogCallback", 35 | "EvaluateCallback", 36 | "LRScheduler", 37 | "ControlC", 38 | "LRFinder", 39 | "TensorboardCallback", 40 | "WarmupCallback", 41 | 'SaveModelCallback', 42 | "CallbackException", 43 | "EarlyStopError", 44 | "CheckPointCallback", 45 | 46 | "Padder", 47 | "AutoPadder", 48 | "EngChar2DPadder", 49 | 50 | # "CollateFn", 51 | "ConcatCollateFn", 52 | 53 | "MetricBase", 54 | "AccuracyMetric", 55 | "SpanFPreRecMetric", 56 | "CMRC2018Metric", 57 | "ClassifyFPreRecMetric", 58 | "ConfusionMatrixMetric", 59 | 60 | "Optimizer", 61 | "SGD", 62 | "Adam", 63 | "AdamW", 64 | 65 | "Sampler", 66 | "SequentialSampler", 67 | "BucketSampler", 68 | "RandomSampler", 69 | "SortedSampler", 70 | "ConstantTokenNumSampler", 71 | 72 | "LossFunc", 73 | "CrossEntropyLoss", 74 | "MSELoss", 75 | "L1Loss", 76 | "BCELoss", 77 | "NLLLoss", 78 | "LossInForward", 79 | "LossBase", 80 | "CMRC2018Loss", 81 | 82 | "cache_results", 83 | 84 | 'logger', 85 | "init_logger_dist", 86 | ] 87 | __version__ = '0.5.6' 88 | 89 | import sys 90 | 91 | from . import embeddings 92 | from . import models 93 | from . 
import modules 94 | from .core import * 95 | from .doc_utils import doc_process 96 | from .io import loader, pipe 97 | 98 | doc_process(sys.modules[__name__]) 99 | -------------------------------------------------------------------------------- /fastNLP/core/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fastNLP 包中直接 import。当然你也同样可以从 core 模块的子模块中 import, 3 | 例如 :class:`~fastNLP.DataSetIter` 组件有两种 import 的方式:: 4 | 5 | # 直接从 fastNLP 中 import 6 | from fastNLP import DataSetIter 7 | 8 | # 从 core 模块的子模块 batch 中 import DataSetIter 9 | from fastNLP.core.batch import DataSetIter 10 | 11 | 对于常用的功能,你只需要在 :mod:`fastNLP` 中查看即可。如果想了解各个子模块的具体作用,您可以在下面找到每个子模块的具体文档。 12 | 13 | """ 14 | __all__ = [ 15 | "DataSet", 16 | 17 | "Instance", 18 | 19 | "FieldArray", 20 | "Padder", 21 | "AutoPadder", 22 | "EngChar2DPadder", 23 | 24 | "ConcatCollateFn", 25 | 26 | "Vocabulary", 27 | 28 | "DataSetIter", 29 | "BatchIter", 30 | "TorchLoaderIter", 31 | 32 | "Const", 33 | 34 | "Tester", 35 | "Trainer", 36 | 37 | "DistTrainer", 38 | "get_local_rank", 39 | 40 | "cache_results", 41 | "seq_len_to_mask", 42 | "get_seq_len", 43 | "logger", 44 | "init_logger_dist", 45 | 46 | "Callback", 47 | "GradientClipCallback", 48 | "EarlyStopCallback", 49 | "FitlogCallback", 50 | "EvaluateCallback", 51 | "LRScheduler", 52 | "ControlC", 53 | "LRFinder", 54 | "TensorboardCallback", 55 | "WarmupCallback", 56 | 'SaveModelCallback', 57 | "CallbackException", 58 | "EarlyStopError", 59 | "CheckPointCallback", 60 | 61 | "LossFunc", 62 | "CrossEntropyLoss", 63 | "L1Loss", 64 | "BCELoss", 65 | "NLLLoss", 66 | "LossInForward", 67 | "CMRC2018Loss", 68 | "MSELoss", 69 | "LossBase", 70 | 71 | "MetricBase", 72 | "AccuracyMetric", 73 | "SpanFPreRecMetric", 74 | "CMRC2018Metric", 75 | "ClassifyFPreRecMetric", 76 | "ConfusionMatrixMetric", 77 | 78 | "Optimizer", 79 | "SGD", 80 | "Adam", 81 | "AdamW", 82 | 83 | "SequentialSampler", 84 | "BucketSampler", 85 | "RandomSampler", 86 | "Sampler", 87 | "SortedSampler", 88 | "ConstantTokenNumSampler" 89 | ] 90 | 91 | from ._logger import logger, init_logger_dist 92 | from .batch import DataSetIter, BatchIter, TorchLoaderIter 93 | from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ 94 | LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \ 95 | EarlyStopError, CheckPointCallback 96 | from .const import Const 97 | from .dataset import DataSet 98 | from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder 99 | from .instance import Instance 100 | from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \ 101 | LossInForward, CMRC2018Loss, LossBase, MSELoss 102 | from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\ 103 | ConfusionMatrixMetric 104 | from .optimizer import Optimizer, SGD, Adam, AdamW 105 | from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler, ConstantTokenNumSampler 106 | from .tester import Tester 107 | from .trainer import Trainer 108 | from .utils import cache_results, seq_len_to_mask, get_seq_len 109 | from .vocabulary import Vocabulary 110 | from .collate_fn import ConcatCollateFn 111 | from .dist_trainer import DistTrainer, get_local_rank 112 | -------------------------------------------------------------------------------- /fastNLP/core/_parallel_utils.py: 
-------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [] 4 | 5 | import threading 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn.parallel.parallel_apply import get_a_var 10 | from torch.nn.parallel.replicate import replicate 11 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather 12 | 13 | 14 | def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None): 15 | r"""Applies each `module` in :attr:`modules` in parallel on arguments 16 | contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) 17 | on each of :attr:`devices`. 18 | 19 | :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and 20 | :attr:`devices` (if given) should all have same length. Moreover, each 21 | element of :attr:`inputs` can either be a single object as the only argument 22 | to a module, or a collection of positional arguments. 23 | """ 24 | assert len(modules) == len(inputs) 25 | if kwargs_tup is not None: 26 | assert len(modules) == len(kwargs_tup) 27 | else: 28 | kwargs_tup = ({},) * len(modules) 29 | if devices is not None: 30 | assert len(modules) == len(devices) 31 | else: 32 | devices = [None] * len(modules) 33 | 34 | lock = threading.Lock() 35 | results = {} 36 | grad_enabled = torch.is_grad_enabled() 37 | 38 | def _worker(i, module, input, kwargs, device=None): 39 | torch.set_grad_enabled(grad_enabled) 40 | if device is None: 41 | device = get_a_var(input).get_device() 42 | try: 43 | with torch.cuda.device(device): 44 | # this also avoids accidental slicing of `input` if it is a Tensor 45 | if not isinstance(input, (list, tuple)): 46 | input = (input,) 47 | output = getattr(module, func_name)(*input, **kwargs) 48 | with lock: 49 | results[i] = output 50 | except Exception as e: 51 | with lock: 52 | results[i] = e 53 | 54 | if len(modules) > 1: 55 | threads = [threading.Thread(target=_worker, 56 | args=(i, module, input, kwargs, device)) 57 | for i, (module, input, kwargs, device) in 58 | enumerate(zip(modules, inputs, kwargs_tup, devices))] 59 | 60 | for thread in threads: 61 | thread.start() 62 | for thread in threads: 63 | thread.join() 64 | else: 65 | _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) 66 | 67 | outputs = [] 68 | for i in range(len(inputs)): 69 | output = results[i] 70 | if isinstance(output, Exception): 71 | raise output 72 | outputs.append(output) 73 | return outputs 74 | 75 | 76 | def _data_parallel_wrapper(func_name, device_ids, output_device): 77 | r""" 78 | 这个函数是用于对需要多卡执行的函数的wrapper函数。参考的nn.DataParallel的forward函数 79 | 80 | :param str, func_name: 对network中的这个函数进行多卡运行 81 | :param device_ids: nn.DataParallel中的device_ids 82 | :param output_device: nn.DataParallel中的output_device 83 | :return: 84 | """ 85 | 86 | def wrapper(network, *inputs, **kwargs): 87 | inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0) 88 | if len(device_ids) == 1: 89 | return getattr(network, func_name)(*inputs[0], **kwargs[0]) 90 | replicas = replicate(network, device_ids[:len(inputs)]) 91 | outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)]) 92 | return gather(outputs, output_device) 93 | 94 | return wrapper 95 | 96 | 97 | def _model_contains_inner_module(model): 98 | r""" 99 | 100 | :param nn.Module model: 模型文件,判断是否内部包含model.module, 多用于check模型是否是nn.DataParallel, 101 | nn.parallel.DistributedDataParallel。主要是在做形参匹配的时候需要使用最内部的model的function。 102 | :return: bool 103 | """ 104 | if isinstance(model, nn.Module): 
105 | if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)): 106 | return True 107 | return False 108 | -------------------------------------------------------------------------------- /fastNLP/core/collate_fn.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | from builtins import sorted 3 | 4 | import torch 5 | import numpy as np 6 | from .field import _get_ele_type_and_dim 7 | from .utils import logger 8 | from copy import deepcopy 9 | 10 | 11 | def _check_type(batch_dict, fields): 12 | if len(fields) == 0: 13 | raise RuntimeError 14 | types = [] 15 | dims = [] 16 | for f in fields: 17 | t, d = _get_ele_type_and_dim(batch_dict[f]) 18 | types.append(t) 19 | dims.append(d) 20 | diff_types = set(types) 21 | diff_dims = set(dims) 22 | if len(diff_types) > 1 or len(diff_dims) > 1: 23 | raise ValueError 24 | return types[0] 25 | 26 | 27 | def batching(samples, max_len=0, padding_val=0): 28 | if len(samples) == 0: 29 | return samples 30 | if max_len <= 0: 31 | max_len = max(s.shape[0] for s in samples) 32 | batch = np.full((len(samples), max_len), fill_value=padding_val) 33 | for i, s in enumerate(samples): 34 | slen = min(s.shape[0], max_len) 35 | batch[i][:slen] = s[:slen] 36 | return batch 37 | 38 | 39 | class Collater: 40 | r""" 41 | 辅助DataSet管理collate_fn的类 42 | 43 | """ 44 | def __init__(self): 45 | self.collate_fns = {} 46 | 47 | def add_fn(self, fn, name=None): 48 | r""" 49 | 向collater新增一个collate_fn函数 50 | 51 | :param callable fn: 52 | :param str,int name: 53 | :return: 54 | """ 55 | if name in self.collate_fns: 56 | logger.warn(f"collate_fn:{name} will be overwritten.") 57 | if name is None: 58 | name = len(self.collate_fns) 59 | self.collate_fns[name] = fn 60 | 61 | def is_empty(self): 62 | r""" 63 | 返回是否包含collate_fn 64 | 65 | :return: 66 | """ 67 | return len(self.collate_fns) == 0 68 | 69 | def delete_fn(self, name=None): 70 | r""" 71 | 删除collate_fn 72 | 73 | :param str,int name: 如果为None就删除最近加入的collate_fn 74 | :return: 75 | """ 76 | if not self.is_empty(): 77 | if name in self.collate_fns: 78 | self.collate_fns.pop(name) 79 | elif name is None: 80 | last_key = list(self.collate_fns.keys())[0] 81 | self.collate_fns.pop(last_key) 82 | 83 | def collate_batch(self, ins_list): 84 | bx, by = {}, {} 85 | for name, fn in self.collate_fns.items(): 86 | try: 87 | batch_x, batch_y = fn(ins_list) 88 | except BaseException as e: 89 | logger.error(f"Exception:`{e}` happens when call collate_fn:`{name}`.") 90 | raise e 91 | bx.update(batch_x) 92 | by.update(batch_y) 93 | return bx, by 94 | 95 | def copy_from(self, col): 96 | assert isinstance(col, Collater) 97 | new_col = Collater() 98 | new_col.collate_fns = deepcopy(col.collate_fns) 99 | return new_col 100 | 101 | 102 | class ConcatCollateFn: 103 | r""" 104 | field拼接collate_fn,将不同field按序拼接后,padding产生数据。 105 | 106 | :param List[str] inputs: 将哪些field的数据拼接起来, 目前仅支持1d的field 107 | :param str output: 拼接后的field名称 108 | :param pad_val: padding的数值 109 | :param max_len: 拼接后最大长度 110 | :param is_input: 是否将生成的output设置为input 111 | :param is_target: 是否将生成的output设置为target 112 | """ 113 | 114 | def __init__(self, inputs, output, pad_val=0, max_len=0, is_input=True, is_target=False): 115 | super().__init__() 116 | assert isinstance(inputs, list) 117 | self.inputs = inputs 118 | self.output = output 119 | self.pad_val = pad_val 120 | self.max_len = max_len 121 | self.is_input = is_input 122 | self.is_target = is_target 123 | 124 | @staticmethod 125 | def _to_numpy(seq): 126 | if 
torch.is_tensor(seq): 127 | return seq.numpy() 128 | else: 129 | return np.array(seq) 130 | 131 | def __call__(self, ins_list): 132 | samples = [] 133 | for i, ins in ins_list: 134 | sample = [] 135 | for input_name in self.inputs: 136 | sample.append(self._to_numpy(ins[input_name])) 137 | samples.append(np.concatenate(sample, axis=0)) 138 | batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val) 139 | b_x, b_y = {}, {} 140 | if self.is_input: 141 | b_x[self.output] = batch 142 | if self.is_target: 143 | b_y[self.output] = batch 144 | 145 | return b_x, b_y 146 | -------------------------------------------------------------------------------- /fastNLP/core/const.py: -------------------------------------------------------------------------------- 1 | r""" 2 | fastNLP包当中的field命名均符合一定的规范,该规范由fastNLP.Const类进行定义。 3 | """ 4 | 5 | __all__ = [ 6 | "Const" 7 | ] 8 | 9 | 10 | class Const: 11 | r""" 12 | fastNLP中field命名常量。 13 | 14 | .. todo:: 15 | 把下面这段改成表格 16 | 17 | 具体列表:: 18 | 19 | INPUT 模型的序列输入 words(具有多列words时,依次使用words1, words2, ) 20 | CHAR_INPUT 模型character输入 chars(具有多列chars时,依次使用chars1, chars2) 21 | INPUT_LEN 序列长度 seq_len(具有多列seq_len时,依次使用seq_len1,seq_len2) 22 | OUTPUT 模型输出 pred(具有多列pred时,依次使用pred1, pred2) 23 | TARGET 真实目标 target(具有多列target时,依次使用target1,target2) 24 | LOSS 损失函数 loss (具有多列loss时,依次使用loss1,loss2) 25 | RAW_WORD 原文的词 raw_words (具有多列raw_words时,依次使用raw_words1, raw_words2) 26 | RAW_CHAR 原文的字 raw_chars (具有多列raw_chars时,依次使用raw_chars1, raw_chars2) 27 | 28 | """ 29 | INPUT = 'words' 30 | CHAR_INPUT = 'chars' 31 | INPUT_LEN = 'seq_len' 32 | OUTPUT = 'pred' 33 | TARGET = 'target' 34 | LOSS = 'loss' 35 | RAW_WORD = 'raw_words' 36 | RAW_CHAR = 'raw_chars' 37 | 38 | @staticmethod 39 | def INPUTS(i): 40 | r"""得到第 i 个 ``INPUT`` 的命名""" 41 | i = int(i) + 1 42 | return Const.INPUT + str(i) 43 | 44 | @staticmethod 45 | def CHAR_INPUTS(i): 46 | r"""得到第 i 个 ``CHAR_INPUT`` 的命名""" 47 | i = int(i) + 1 48 | return Const.CHAR_INPUT + str(i) 49 | 50 | @staticmethod 51 | def RAW_WORDS(i): 52 | r"""得到第 i 个 ``RAW_WORDS`` 的命名""" 53 | i = int(i) + 1 54 | return Const.RAW_WORD + str(i) 55 | 56 | @staticmethod 57 | def RAW_CHARS(i): 58 | r"""得到第 i 个 ``RAW_CHARS`` 的命名""" 59 | i = int(i) + 1 60 | return Const.RAW_CHAR + str(i) 61 | 62 | @staticmethod 63 | def INPUT_LENS(i): 64 | r"""得到第 i 个 ``INPUT_LEN`` 的命名""" 65 | i = int(i) + 1 66 | return Const.INPUT_LEN + str(i) 67 | 68 | @staticmethod 69 | def OUTPUTS(i): 70 | r"""得到第 i 个 ``OUTPUT`` 的命名""" 71 | i = int(i) + 1 72 | return Const.OUTPUT + str(i) 73 | 74 | @staticmethod 75 | def TARGETS(i): 76 | r"""得到第 i 个 ``TARGET`` 的命名""" 77 | i = int(i) + 1 78 | return Const.TARGET + str(i) 79 | 80 | @staticmethod 81 | def LOSSES(i): 82 | r"""得到第 i 个 ``LOSS`` 的命名""" 83 | i = int(i) + 1 84 | return Const.LOSS + str(i) 85 | -------------------------------------------------------------------------------- /fastNLP/core/instance.py: -------------------------------------------------------------------------------- 1 | r""" 2 | instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可以认为是一个Instance类型的对象。 3 | 便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 中的表格 4 | 5 | """ 6 | 7 | __all__ = [ 8 | "Instance" 9 | ] 10 | 11 | from .utils import pretty_table_printer 12 | 13 | 14 | class Instance(object): 15 | r""" 16 | Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。 17 | Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示:: 18 | 19 | >>>from fastNLP import Instance 20 | >>>ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) 21 | >>>ins["field_1"] 22 | 
[1, 1, 1] 23 | >>>ins.add_field("field_3", [3, 3, 3]) 24 | >>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))}) 25 | """ 26 | 27 | def __init__(self, **fields): 28 | 29 | self.fields = fields 30 | 31 | def add_field(self, field_name, field): 32 | r""" 33 | 向Instance中增加一个field 34 | 35 | :param str field_name: 新增field的名称 36 | :param Any field: 新增field的内容 37 | """ 38 | self.fields[field_name] = field 39 | 40 | def items(self): 41 | r""" 42 | 返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value 43 | 44 | :return: 一个迭代器 45 | """ 46 | return self.fields.items() 47 | 48 | def __contains__(self, item): 49 | return item in self.fields 50 | 51 | def __getitem__(self, name): 52 | if name in self.fields: 53 | return self.fields[name] 54 | else: 55 | print(name) 56 | raise KeyError("{} not found".format(name)) 57 | 58 | def __setitem__(self, name, field): 59 | return self.add_field(name, field) 60 | 61 | def __repr__(self): 62 | return str(pretty_table_printer(self)) 63 | -------------------------------------------------------------------------------- /fastNLP/core/predictor.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "Predictor" 5 | ] 6 | 7 | from collections import defaultdict 8 | 9 | import torch 10 | 11 | from . import DataSet 12 | from . import DataSetIter 13 | from . import SequentialSampler 14 | from .utils import _build_args, _move_dict_value_to_device, _get_model_device 15 | 16 | 17 | class Predictor(object): 18 | r""" 19 | 一个根据训练模型预测输出的预测器(Predictor) 20 | 21 | 与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 22 | 这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 23 | """ 24 | 25 | def __init__(self, network): 26 | r""" 27 | 28 | :param torch.nn.Module network: 用来完成预测任务的模型 29 | """ 30 | if not isinstance(network, torch.nn.Module): 31 | raise ValueError( 32 | "Only fastNLP.models.BaseModel or torch.nn,Module is allowed, not {}".format(type(network))) 33 | self.network = network 34 | self.batch_size = 1 35 | self.batch_output = [] 36 | 37 | def predict(self, data: DataSet, seq_len_field_name=None): 38 | r"""用已经训练好的模型进行inference. 
39 | 40 | :param fastNLP.DataSet data: 待预测的数据集 41 | :param str seq_len_field_name: 表示序列长度信息的field名字 42 | :return: dict dict里面的内容为模型预测的结果 43 | """ 44 | if not isinstance(data, DataSet): 45 | raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) 46 | if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: 47 | raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) 48 | 49 | prev_training = self.network.training 50 | self.network.eval() 51 | network_device = _get_model_device(self.network) 52 | batch_output = defaultdict(list) 53 | data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) 54 | 55 | if hasattr(self.network, "predict"): 56 | predict_func = self.network.predict 57 | else: 58 | predict_func = self.network.forward 59 | 60 | with torch.no_grad(): 61 | for batch_x, _ in data_iterator: 62 | _move_dict_value_to_device(batch_x, _, device=network_device) 63 | refined_batch_x = _build_args(predict_func, **batch_x) 64 | prediction = predict_func(**refined_batch_x) 65 | 66 | if seq_len_field_name is not None: 67 | seq_lens = batch_x[seq_len_field_name].tolist() 68 | 69 | for key, value in prediction.items(): 70 | value = value.cpu().numpy() 71 | if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1): 72 | batch_output[key].extend(value.tolist()) 73 | else: 74 | if seq_len_field_name is not None: 75 | tmp_batch = [] 76 | for idx, seq_len in enumerate(seq_lens): 77 | tmp_batch.append(value[idx, :seq_len]) 78 | batch_output[key].extend(tmp_batch) 79 | else: 80 | batch_output[key].append(value) 81 | 82 | self.network.train(prev_training) 83 | return batch_output 84 | -------------------------------------------------------------------------------- /fastNLP/doc_utils.py: -------------------------------------------------------------------------------- 1 | r"""undocumented 2 | 用于辅助生成 fastNLP 文档的代码 3 | """ 4 | 5 | __all__ = [] 6 | 7 | import inspect 8 | import sys 9 | 10 | 11 | def doc_process(m): 12 | for name, obj in inspect.getmembers(m): 13 | if inspect.isclass(obj) or inspect.isfunction(obj): 14 | if obj.__module__ != m.__name__: 15 | if obj.__doc__ is None: 16 | # print(name, obj.__doc__) 17 | pass 18 | else: 19 | module_name = obj.__module__ 20 | 21 | # 识别并标注类和函数在不同层次中的位置 22 | 23 | while 1: 24 | defined_m = sys.modules[module_name] 25 | try: 26 | if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__: 27 | obj.__doc__ = r"别名 :class:`" + m.__name__ + "." + name + "`" \ 28 | + " :class:`" + module_name + "." + name + "`\n" + obj.__doc__ 29 | break 30 | module_name = ".".join(module_name.split('.')[:-1]) 31 | if module_name == m.__name__: 32 | # print(name, ": not found defined doc.") 33 | break 34 | except: 35 | print("Warning: Module {} lacks `__doc__`".format(module_name)) 36 | break 37 | 38 | # 识别并标注基类,只有基类也在 fastNLP 中定义才显示 39 | 40 | if inspect.isclass(obj): 41 | for base in obj.__bases__: 42 | if base.__module__.startswith("fastNLP"): 43 | parts = base.__module__.split(".") + [] 44 | module_name, i = "fastNLP", 1 45 | for i in range(len(parts) - 1): 46 | defined_m = sys.modules[module_name] 47 | try: 48 | if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__: 49 | obj.__doc__ = r"基类 :class:`" + defined_m.__name__ + "." + base.__name__ + "` \n\n" + obj.__doc__ 50 | break 51 | module_name += "." 
+ parts[i + 1] 52 | except: 53 | print("Warning: Module {} lacks `__doc__`".format(module_name)) 54 | break 55 | -------------------------------------------------------------------------------- /fastNLP/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | embeddings 模块主要用于从各种预训练的模型中获取词语的分布式表示,目前支持的预训练模型包括word2vec, glove, ELMO, BERT等。这里所有 3 | embedding的forward输入都是形状为 ``(batch_size, max_len)`` 的torch.LongTensor,输出都是 ``(batch_size, max_len, embedding_dim)`` 的 4 | torch.FloatTensor。所有的embedding都可以使用 `self.num_embedding` 获取最大的输入index范围, 用 `self.embeddig_dim` 或 `self.embed_size` 获取embedding的 5 | 输出维度。 6 | """ 7 | 8 | __all__ = [ 9 | "Embedding", 10 | "TokenEmbedding", 11 | "StaticEmbedding", 12 | "ElmoEmbedding", 13 | "BertEmbedding", 14 | "BertWordPieceEncoder", 15 | 16 | "RobertaEmbedding", 17 | "RobertaWordPieceEncoder", 18 | 19 | "GPT2Embedding", 20 | "GPT2WordPieceEncoder", 21 | 22 | "StackEmbedding", 23 | "LSTMCharEmbedding", 24 | "CNNCharEmbedding", 25 | 26 | "get_embeddings", 27 | "get_sinusoid_encoding_table" 28 | ] 29 | 30 | from .embedding import Embedding, TokenEmbedding 31 | from .static_embedding import StaticEmbedding 32 | from .elmo_embedding import ElmoEmbedding 33 | from .bert_embedding import BertEmbedding, BertWordPieceEncoder 34 | from .roberta_embedding import RobertaEmbedding, RobertaWordPieceEncoder 35 | from .gpt2_embedding import GPT2WordPieceEncoder, GPT2Embedding 36 | from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding 37 | from .stack_embedding import StackEmbedding 38 | from .utils import get_embeddings, get_sinusoid_encoding_table 39 | 40 | import sys 41 | from ..doc_utils import doc_process 42 | doc_process(sys.modules[__name__]) -------------------------------------------------------------------------------- /fastNLP/embeddings/contextual_embedding.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. todo:: 3 | doc 4 | """ 5 | 6 | __all__ = [ 7 | "ContextualEmbedding" 8 | ] 9 | 10 | from abc import abstractmethod 11 | 12 | import torch 13 | 14 | from .embedding import TokenEmbedding 15 | from ..core import logger 16 | from ..core.batch import DataSetIter 17 | from ..core.dataset import DataSet 18 | from ..core.sampler import SequentialSampler 19 | from ..core.utils import _move_model_to_device, _get_model_device 20 | from ..core.vocabulary import Vocabulary 21 | 22 | 23 | class ContextualEmbedding(TokenEmbedding): 24 | r""" 25 | ContextualEmbedding组件. BertEmbedding与ElmoEmbedding的基类 26 | """ 27 | def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0): 28 | super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) 29 | 30 | def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool = True): 31 | r""" 32 | 由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。 33 | 34 | :param datasets: DataSet对象 35 | :param batch_size: int, 生成cache的sentence表示时使用的batch的大小 36 | :param device: 参考 :class::fastNLP.Trainer 的device 37 | :param delete_weights: 似乎在生成了cache之后删除权重,在不需要finetune动态模型的情况下,删除权重会大量减少内存占用。 38 | :return: 39 | """ 40 | for index, dataset in enumerate(datasets): 41 | try: 42 | assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed." 43 | assert 'words' in dataset.get_input_name(), "`words` field has to be set as input." 
44 | except Exception as e: 45 | logger.error(f"Exception happens at {index} dataset.") 46 | raise e 47 | 48 | sent_embeds = {} 49 | _move_model_to_device(self, device=device) 50 | device = _get_model_device(self) 51 | pad_index = self._word_vocab.padding_idx 52 | logger.info("Start to calculate sentence representations.") 53 | with torch.no_grad(): 54 | for index, dataset in enumerate(datasets): 55 | try: 56 | batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler()) 57 | for batch_x, batch_y in batch: 58 | words = batch_x['words'].to(device) 59 | words_list = words.tolist() 60 | seq_len = words.ne(pad_index).sum(dim=-1) 61 | max_len = words.size(1) 62 | # 因为有些情况可能包含CLS, SEP, 从后面往前计算比较安全。 63 | seq_len_from_behind = (max_len - seq_len).tolist() 64 | word_embeds = self(words).detach().cpu().numpy() 65 | for b in range(words.size(0)): 66 | length = seq_len_from_behind[b] 67 | if length == 0: 68 | sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b] 69 | else: 70 | sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length] 71 | except Exception as e: 72 | logger.error(f"Exception happens at {index} dataset.") 73 | raise e 74 | logger.info("Finish calculating sentence representations.") 75 | self.sent_embeds = sent_embeds 76 | if delete_weights: 77 | self._delete_model_weights() 78 | 79 | def _get_sent_reprs(self, words): 80 | r""" 81 | 获取sentence的表示,如果有缓存,则返回缓存的值; 没有缓存则返回None 82 | 83 | :param words: torch.LongTensor 84 | :return: 85 | """ 86 | if hasattr(self, 'sent_embeds'): 87 | words_list = words.tolist() 88 | seq_len = words.ne(self._word_pad_index).sum(dim=-1) 89 | _embeds = [] 90 | for b in range(len(words)): 91 | words_i = tuple(words_list[b][:seq_len[b]]) 92 | embed = self.sent_embeds[words_i] 93 | _embeds.append(embed) 94 | max_sent_len = max(map(len, _embeds)) 95 | embeds = words.new_zeros(len(_embeds), max_sent_len, self.embed_size, dtype=torch.float, 96 | device=words.device) 97 | for i, embed in enumerate(_embeds): 98 | embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device) 99 | return embeds 100 | return None 101 | 102 | @abstractmethod 103 | def _delete_model_weights(self): 104 | r"""删除计算表示的模型以节省资源""" 105 | raise NotImplementedError 106 | 107 | def remove_sentence_cache(self): 108 | r""" 109 | 删除缓存的句子表示. 删除之后如果模型权重没有被删除,将开始使用动态计算权重。 110 | 111 | :return: 112 | """ 113 | del self.sent_embeds 114 | -------------------------------------------------------------------------------- /fastNLP/embeddings/stack_embedding.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. 
todo:: 3 | doc 4 | """ 5 | 6 | __all__ = [ 7 | "StackEmbedding", 8 | ] 9 | 10 | from typing import List 11 | 12 | import torch 13 | from torch import nn as nn 14 | 15 | from .embedding import TokenEmbedding 16 | 17 | 18 | class StackEmbedding(TokenEmbedding): 19 | r""" 20 | 支持将多个embedding集合成一个embedding。 21 | 22 | Example:: 23 | 24 | >>> from fastNLP import Vocabulary 25 | >>> from fastNLP.embeddings import StaticEmbedding, StackEmbedding 26 | >>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) 27 | >>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True) 28 | >>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True) 29 | >>> embed = StackEmbedding([embed_1, embed_2]) 30 | 31 | """ 32 | 33 | def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0): 34 | r""" 35 | 36 | :param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致 37 | :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。不同embedidng会在相同的位置 38 | 被设置为unknown。如果这里设置了dropout,则组成的embedding就不要再设置dropout了。 39 | :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 40 | """ 41 | vocabs = [] 42 | for embed in embeds: 43 | if hasattr(embed, 'get_word_vocab'): 44 | vocabs.append(embed.get_word_vocab()) 45 | _vocab = vocabs[0] 46 | for vocab in vocabs[1:]: 47 | assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary." 48 | 49 | super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout) 50 | assert isinstance(embeds, list) 51 | for embed in embeds: 52 | assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported." 53 | self.embeds = nn.ModuleList(embeds) 54 | self._embed_size = sum([embed.embed_size for embed in self.embeds]) 55 | 56 | def append(self, embed: TokenEmbedding): 57 | r""" 58 | 添加一个embedding到结尾。 59 | :param embed: 60 | :return: 61 | """ 62 | assert isinstance(embed, TokenEmbedding) 63 | self._embed_size += embed.embed_size 64 | self.embeds.append(embed) 65 | return self 66 | 67 | def pop(self): 68 | r""" 69 | 弹出最后一个embed 70 | :return: 71 | """ 72 | embed = self.embeds.pop() 73 | self._embed_size -= embed.embed_size 74 | return embed 75 | 76 | @property 77 | def embed_size(self): 78 | r""" 79 | 该Embedding输出的vector的最后一维的维度。 80 | :return: 81 | """ 82 | return self._embed_size 83 | 84 | def forward(self, words): 85 | r""" 86 | 得到多个embedding的结果,并把结果按照顺序concat起来。 87 | 88 | :param words: batch_size x max_len 89 | :return: 返回的shape和当前这个stack embedding中embedding的组成有关 90 | """ 91 | outputs = [] 92 | words = self.drop_word(words) 93 | for embed in self.embeds: 94 | outputs.append(embed(words)) 95 | outputs = self.dropout(torch.cat(outputs, dim=-1)) 96 | return outputs 97 | -------------------------------------------------------------------------------- /fastNLP/embeddings/utils.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. todo:: 3 | doc 4 | """ 5 | import numpy as np 6 | import torch 7 | from torch import nn as nn 8 | 9 | from ..core.vocabulary import Vocabulary 10 | 11 | __all__ = [ 12 | 'get_embeddings', 13 | 'get_sinusoid_encoding_table' 14 | ] 15 | 16 | 17 | def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, include_word_start_end=True): 18 | r""" 19 | 给定一个word的vocabulary生成character的vocabulary. 
20 | 21 | :param vocab: 从vocab 22 | :param min_freq: 23 | :param include_word_start_end: 是否需要包含特殊的 24 | :return: 25 | """ 26 | char_vocab = Vocabulary(min_freq=min_freq) 27 | for word, index in vocab: 28 | if not vocab._is_word_no_create_entry(word): 29 | char_vocab.add_word_lst(list(word)) 30 | if include_word_start_end: 31 | char_vocab.add_word_lst(['', '']) 32 | return char_vocab 33 | 34 | 35 | def get_embeddings(init_embed, padding_idx=None): 36 | r""" 37 | 根据输入的init_embed返回Embedding对象。如果输入是tuple, 则随机初始化一个nn.Embedding; 如果输入是numpy.ndarray, 则按照ndarray 38 | 的值将nn.Embedding初始化; 如果输入是torch.Tensor, 则按该值初始化nn.Embedding; 如果输入是fastNLP中的embedding将不做处理 39 | 返回原对象。 40 | 41 | :param init_embed: 可以是 tuple:(num_embedings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入 42 | nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行,将使用传入的ndarray作为作为Embedding初始化; 43 | 传入torch.Tensor, 将使用传入的值作为Embedding初始化。 44 | :param padding_idx: 当传入tuple时,padding_idx有效 45 | :return nn.Embedding: embeddings 46 | """ 47 | if isinstance(init_embed, tuple): 48 | res = nn.Embedding( 49 | num_embeddings=init_embed[0], embedding_dim=init_embed[1], padding_idx=padding_idx) 50 | nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)), 51 | b=np.sqrt(3 / res.weight.data.size(1))) 52 | elif isinstance(init_embed, nn.Module): 53 | res = init_embed 54 | elif isinstance(init_embed, torch.Tensor): 55 | res = nn.Embedding.from_pretrained(init_embed, freeze=False) 56 | elif isinstance(init_embed, np.ndarray): 57 | init_embed = torch.tensor(init_embed, dtype=torch.float32) 58 | res = nn.Embedding.from_pretrained(init_embed, freeze=False) 59 | else: 60 | raise TypeError( 61 | 'invalid init_embed type: {}'.format((type(init_embed)))) 62 | return res 63 | 64 | 65 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 66 | """ 67 | sinusoid的embedding,其中position的表示中,偶数维(0,2,4,...)是sin, 奇数(1,3,5...)是cos 68 | 69 | :param int n_position: 一共多少个position 70 | :param int d_hid: 多少维度,需要为偶数 71 | :param padding_idx: 72 | :return: torch.FloatTensor, shape为n_position x d_hid 73 | """ 74 | 75 | def cal_angle(position, hid_idx): 76 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) 77 | 78 | def get_posi_angle_vec(position): 79 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] 80 | 81 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) 82 | 83 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 84 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 85 | 86 | if padding_idx is not None: 87 | # zero vector for padding dimension 88 | sinusoid_table[padding_idx] = 0. 89 | 90 | return torch.FloatTensor(sinusoid_table) 91 | 92 | -------------------------------------------------------------------------------- /fastNLP/io/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 用于IO的模块, 具体包括: 3 | 4 | 1. 用于读入 embedding 的 :mod:`EmbedLoader ` 类, 5 | 6 | 2. 用于读入不同格式数据的 :mod:`Loader ` 类 7 | 8 | 3. 用于处理读入数据的 :mod:`Pipe ` 类 9 | 10 | 4. 
用于保存和载入模型的类, 参考 :mod:`model_io模块 ` 11 | 12 | 这些类的使用方法如下: 13 | """ 14 | __all__ = [ 15 | 'DataBundle', 16 | 17 | 'EmbedLoader', 18 | 19 | 'Loader', 20 | 21 | 'CLSBaseLoader', 22 | 'AGsNewsLoader', 23 | 'DBPediaLoader', 24 | 'YelpFullLoader', 25 | 'YelpPolarityLoader', 26 | 'IMDBLoader', 27 | 'SSTLoader', 28 | 'SST2Loader', 29 | "ChnSentiCorpLoader", 30 | "THUCNewsLoader", 31 | "WeiboSenti100kLoader", 32 | 33 | 'ConllLoader', 34 | 'Conll2003Loader', 35 | 'Conll2003NERLoader', 36 | 'OntoNotesNERLoader', 37 | 'CTBLoader', 38 | "MsraNERLoader", 39 | "WeiboNERLoader", 40 | "PeopleDailyNERLoader", 41 | 42 | 'CSVLoader', 43 | 'JsonLoader', 44 | 45 | 'CWSLoader', 46 | 47 | 'MNLILoader', 48 | "QuoraLoader", 49 | "SNLILoader", 50 | "QNLILoader", 51 | "RTELoader", 52 | "CNXNLILoader", 53 | "BQCorpusLoader", 54 | "LCQMCLoader", 55 | 56 | "CMRC2018Loader", 57 | 58 | "Pipe", 59 | 60 | "CLSBasePipe", 61 | "AGsNewsPipe", 62 | "DBPediaPipe", 63 | "YelpFullPipe", 64 | "YelpPolarityPipe", 65 | "SSTPipe", 66 | "SST2Pipe", 67 | "IMDBPipe", 68 | "ChnSentiCorpPipe", 69 | "THUCNewsPipe", 70 | "WeiboSenti100kPipe", 71 | 72 | "Conll2003Pipe", 73 | "Conll2003NERPipe", 74 | "OntoNotesNERPipe", 75 | "MsraNERPipe", 76 | "PeopleDailyPipe", 77 | "WeiboNERPipe", 78 | 79 | "CWSPipe", 80 | 81 | "Conll2003NERPipe", 82 | "OntoNotesNERPipe", 83 | "MsraNERPipe", 84 | "WeiboNERPipe", 85 | "PeopleDailyPipe", 86 | "Conll2003Pipe", 87 | 88 | "MatchingBertPipe", 89 | "RTEBertPipe", 90 | "SNLIBertPipe", 91 | "QuoraBertPipe", 92 | "QNLIBertPipe", 93 | "MNLIBertPipe", 94 | "CNXNLIBertPipe", 95 | "BQCorpusBertPipe", 96 | "LCQMCBertPipe", 97 | "MatchingPipe", 98 | "RTEPipe", 99 | "SNLIPipe", 100 | "QuoraPipe", 101 | "QNLIPipe", 102 | "MNLIPipe", 103 | "LCQMCPipe", 104 | "CNXNLIPipe", 105 | "BQCorpusPipe", 106 | "RenamePipe", 107 | "GranularizePipe", 108 | "MachingTruncatePipe", 109 | 110 | "CMRC2018BertPipe", 111 | 112 | 'ModelLoader', 113 | 'ModelSaver', 114 | 115 | ] 116 | 117 | import sys 118 | 119 | from .data_bundle import DataBundle 120 | from .embed_loader import EmbedLoader 121 | from .loader import * 122 | from .model_io import ModelLoader, ModelSaver 123 | from .pipe import * 124 | from ..doc_utils import doc_process 125 | 126 | doc_process(sys.modules[__name__]) -------------------------------------------------------------------------------- /fastNLP/io/loader/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的 3 | 三个方法: ``__init__`` , ``_load`` , ``loads`` . 
其中 ``__init__(...)`` 用于申明读取参数,以及说明该Loader支持的数据格式, 4 | 读取后 :class:`~fastNLP.DataSet` 中的 `field` ; ``_load(path)`` 方法传入文件路径读取单个文件,并返回 :class:`~fastNLP.DataSet` ; 5 | ``load(paths)`` 用于读取文件夹下的文件,并返回 :class:`~fastNLP.io.DataBundle` 类型的对象 , load()方法支持以下几种类型的参数: 6 | 7 | 0.传入None 8 | 将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。 9 | 10 | 1.传入一个文件的 path 11 | 返回的 `data_bundle` 包含一个名为 `train` 的 dataset ,可以通过 ``data_bundle.get_dataset('train')`` 获取 12 | 13 | 2.传入一个文件夹目录 14 | 将读取的是这个文件夹下文件名中包含 `train` , `test` , `dev` 的文件,其它文件会被忽略。假设某个目录下的文件为:: 15 | 16 | | 17 | +-train3.txt 18 | +-dev.txt 19 | +-test.txt 20 | +-other.txt 21 | 22 | 在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , 23 | ``data_bundle.get_dataset('dev')`` , 24 | ``data_bundle.get_dataset('test')`` 获取对应的 `dataset` ,其中 `other.txt` 的内容会被忽略。假设某个目录下的文件为:: 25 | 26 | | 27 | +-train3.txt 28 | +-dev.txt 29 | 30 | 在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , 31 | ``data_bundle.get_dataset('dev')`` 获取对应的 dataset。 32 | 33 | 3.传入一个字典 34 | 字典的的 key 为 `dataset` 的名称,value 是该 `dataset` 的文件路径:: 35 | 36 | paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'} 37 | 38 | 在 Loader().load(paths) 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , ``data_bundle.get_dataset('dev')`` , 39 | ``data_bundle.get_dataset('test')`` 来获取对应的 `dataset` 40 | 41 | fastNLP 目前提供了如下的 Loader 42 | 43 | 44 | 45 | """ 46 | 47 | __all__ = [ 48 | 'Loader', 49 | 50 | 'CLSBaseLoader', 51 | 'YelpFullLoader', 52 | 'YelpPolarityLoader', 53 | 'AGsNewsLoader', 54 | 'DBPediaLoader', 55 | 'IMDBLoader', 56 | 'SSTLoader', 57 | 'SST2Loader', 58 | "ChnSentiCorpLoader", 59 | "THUCNewsLoader", 60 | "WeiboSenti100kLoader", 61 | 62 | 'ConllLoader', 63 | 'Conll2003Loader', 64 | 'Conll2003NERLoader', 65 | 'OntoNotesNERLoader', 66 | 'CTBLoader', 67 | "MsraNERLoader", 68 | "PeopleDailyNERLoader", 69 | "WeiboNERLoader", 70 | 71 | 'CSVLoader', 72 | 'JsonLoader', 73 | 74 | 'CWSLoader', 75 | 76 | 'MNLILoader', 77 | "QuoraLoader", 78 | "SNLILoader", 79 | "QNLILoader", 80 | "RTELoader", 81 | "CNXNLILoader", 82 | "BQCorpusLoader", 83 | "LCQMCLoader", 84 | 85 | "CoReferenceLoader", 86 | 87 | "CMRC2018Loader" 88 | ] 89 | from .classification import CLSBaseLoader, YelpFullLoader, YelpPolarityLoader, AGsNewsLoader, IMDBLoader, \ 90 | SSTLoader, SST2Loader, DBPediaLoader, \ 91 | ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader 92 | from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader 93 | from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader 94 | from .coreference import CoReferenceLoader 95 | from .csv import CSVLoader 96 | from .cws import CWSLoader 97 | from .json import JsonLoader 98 | from .loader import Loader 99 | from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, \ 100 | LCQMCLoader 101 | from .qa import CMRC2018Loader 102 | 103 | -------------------------------------------------------------------------------- /fastNLP/io/loader/coreference.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "CoReferenceLoader", 5 | ] 6 | 7 | from ...core.dataset import DataSet 8 | from ..file_reader import _read_json 9 | from ...core.instance import Instance 10 | from ...core.const import Const 11 | from .json import JsonLoader 12 | 13 | 14 | class CoReferenceLoader(JsonLoader): 15 | r""" 16 | 
原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 17 | 18 | Example:: 19 | 20 | {"doc_key": "bc/cctv/00/cctv_0000_0", 21 | "speakers": [["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"]], 22 | "clusters": [[[70, 70], [485, 486], [500, 500], [73, 73], [55, 55], [153, 154], [366, 366]]], 23 | "sentences": [["In", "the", "summer", "of", "2005", ",", "a", "picture", "that", "people", "have", "long", "been", "looking", "forward", "to", "started", "emerging", "with", "frequency", "in", "various", "major", "Hong", "Kong", "media", "."], ["With", "their", "unique", "charm", ",", "these", "well", "-", "known", "cartoon", "images", "once", "again", "caused", "Hong", "Kong", "to", "be", "a", "focus", "of", "worldwide", "attention", "."]] 24 | } 25 | 26 | 读取预处理好的Conll2012数据,数据结构如下: 27 | 28 | .. csv-table:: 29 | :header: "raw_words1", "raw_words2", "raw_words3", "raw_words4" 30 | 31 | "bc/cctv/00/cctv_0000_0", "[['Speaker#1', 'Speaker#1', 'Speaker#1...", "[[[70, 70], [485, 486], [500, 500], [7...", "[['In', 'the', 'summer', 'of', '2005',..." 32 | "...", "...", "...", "..." 33 | 34 | """ 35 | def __init__(self, fields=None, dropna=False): 36 | super().__init__(fields, dropna) 37 | self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2), 38 | "sentences": Const.RAW_WORDS(3)} 39 | 40 | def _load(self, path): 41 | r""" 42 | 加载数据 43 | :param path: 数据文件路径,文件为json 44 | 45 | :return: 46 | """ 47 | dataset = DataSet() 48 | for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): 49 | if self.fields: 50 | ins = {self.fields[k]: v for k, v in d.items()} 51 | else: 52 | ins = d 53 | dataset.append(Instance(**ins)) 54 | return dataset 55 | 56 | def download(self): 57 | r""" 58 | 由于版权限制,不能提供自动下载功能。可参考 59 | 60 | https://www.aclweb.org/anthology/W12-4501 61 | 62 | :return: 63 | """ 64 | raise RuntimeError("CoReference cannot be downloaded automatically.") 65 | -------------------------------------------------------------------------------- /fastNLP/io/loader/csv.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "CSVLoader", 5 | ] 6 | 7 | from .loader import Loader 8 | from ..file_reader import _read_csv 9 | from ...core.dataset import DataSet 10 | from ...core.instance import Instance 11 | 12 | 13 | class CSVLoader(Loader): 14 | r""" 15 | 读取CSV格式的数据集, 返回 ``DataSet`` 。 16 | 17 | """ 18 | 19 | def __init__(self, headers=None, sep=",", dropna=False): 20 | r""" 21 | 22 | :param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 23 | 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . 
Default: ``None`` 24 | :param str sep: CSV文件中列与列之间的分隔符. Default: "," 25 | :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . 26 | Default: ``False`` 27 | """ 28 | super().__init__() 29 | self.headers = headers 30 | self.sep = sep 31 | self.dropna = dropna 32 | 33 | def _load(self, path): 34 | ds = DataSet() 35 | for idx, data in _read_csv(path, headers=self.headers, 36 | sep=self.sep, dropna=self.dropna): 37 | ds.append(Instance(**data)) 38 | return ds 39 | 40 | -------------------------------------------------------------------------------- /fastNLP/io/loader/cws.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "CWSLoader" 5 | ] 6 | 7 | import glob 8 | import os 9 | import random 10 | import shutil 11 | import time 12 | 13 | from .loader import Loader 14 | from ...core.dataset import DataSet 15 | from ...core.instance import Instance 16 | 17 | 18 | class CWSLoader(Loader): 19 | r""" 20 | CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: 21 | 22 | Example:: 23 | 24 | 上海 浦东 开发 与 法制 建设 同步 25 | 新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) 26 | ... 27 | 28 | 该Loader读取后的DataSet具有如下的结构 29 | 30 | .. csv-table:: 31 | :header: "raw_words" 32 | 33 | "上海 浦东 开发 与 法制 建设 同步" 34 | "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" 35 | "..." 36 | 37 | """ 38 | def __init__(self, dataset_name:str=None): 39 | r""" 40 | 41 | :param str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None 42 | """ 43 | super().__init__() 44 | datanames = {'pku': 'cws-pku', 'msra':'cws-msra', 'as':'cws-as', 'cityu':'cws-cityu'} 45 | if dataset_name in datanames: 46 | self.dataset_name = datanames[dataset_name] 47 | else: 48 | self.dataset_name = None 49 | 50 | def _load(self, path:str): 51 | ds = DataSet() 52 | with open(path, 'r', encoding='utf-8') as f: 53 | for line in f: 54 | line = line.strip() 55 | if line: 56 | ds.append(Instance(raw_words=line)) 57 | return ds 58 | 59 | def download(self, dev_ratio=0.1, re_download=False)->str: 60 | r""" 61 | 如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, 62 | 2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 63 | 64 | :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 65 | :param bool re_download: 是否重新下载数据,以重新切分数据。 66 | :return: str 67 | """ 68 | if self.dataset_name is None: 69 | return None 70 | data_dir = self._get_dataset_path(dataset_name=self.dataset_name) 71 | modify_time = 0 72 | for filepath in glob.glob(os.path.join(data_dir, '*')): 73 | modify_time = os.stat(filepath).st_mtime 74 | break 75 | if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 76 | shutil.rmtree(data_dir) 77 | data_dir = self._get_dataset_path(dataset_name=self.dataset_name) 78 | 79 | if not os.path.exists(os.path.join(data_dir, 'dev.txt')): 80 | if dev_ratio > 0: 81 | assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." 
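                # Split the original training file in place: each line is written to
                # dev.txt with probability dev_ratio, the remaining lines go to a
                # temporary middle_file.txt which is then renamed back over the training
                # file, so the resulting train/dev files are disjoint.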
82 | try: 83 | with open(os.path.join(data_dir, 'train3.txt'), 'r', encoding='utf-8') as f, \ 84 | open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ 85 | open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: 86 | for line in f: 87 | if random.random() < dev_ratio: 88 | f2.write(line) 89 | else: 90 | f1.write(line) 91 | os.remove(os.path.join(data_dir, 'train3.txt')) 92 | os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train3.txt')) 93 | finally: 94 | if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): 95 | os.remove(os.path.join(data_dir, 'middle_file.txt')) 96 | 97 | return data_dir 98 | -------------------------------------------------------------------------------- /fastNLP/io/loader/json.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "JsonLoader" 5 | ] 6 | 7 | from .loader import Loader 8 | from ..file_reader import _read_json 9 | from ...core.dataset import DataSet 10 | from ...core.instance import Instance 11 | 12 | 13 | class JsonLoader(Loader): 14 | r""" 15 | 别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader` 16 | 17 | 读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 18 | 19 | :param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name 20 | ``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , 21 | `value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 22 | ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` 23 | :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . 24 | Default: ``False`` 25 | """ 26 | 27 | def __init__(self, fields=None, dropna=False): 28 | super(JsonLoader, self).__init__() 29 | self.dropna = dropna 30 | self.fields = None 31 | self.fields_list = None 32 | if fields: 33 | self.fields = {} 34 | for k, v in fields.items(): 35 | self.fields[k] = k if v is None else v 36 | self.fields_list = list(self.fields.keys()) 37 | 38 | def _load(self, path): 39 | ds = DataSet() 40 | for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): 41 | if self.fields: 42 | ins = {self.fields[k]: v for k, v in d.items()} 43 | else: 44 | ins = d 45 | ds.append(Instance(**ins)) 46 | return ds 47 | -------------------------------------------------------------------------------- /fastNLP/io/loader/loader.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "Loader" 5 | ] 6 | 7 | from typing import Union, Dict 8 | 9 | from .. import DataBundle 10 | from ..file_utils import _get_dataset_url, get_cache_path, cached_path 11 | from ..utils import check_loader_paths 12 | from ...core.dataset import DataSet 13 | 14 | 15 | class Loader: 16 | r""" 17 | 各种数据 Loader 的基类,提供了 API 的参考. 
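    A minimal usage sketch (``xxxLoader`` stands for any concrete subclass; the path is a
    placeholder)::

        >>> data_bundle = xxxLoader().load('/path/to/data')   # a DataBundle
        >>> train_data = data_bundle.get_dataset('train')     # one DataSet inside it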
18 | Loader支持以下的三个函数 19 | 20 | - download() 函数:自动将该数据集下载到缓存地址,默认缓存地址为~/.fastNLP/datasets/。由于版权等原因,不是所有的Loader都实现了该方法。该方法会返回下载后文件所处的缓存地址。 21 | - _load() 函数:从一个数据文件中读取数据,返回一个 :class:`~fastNLP.DataSet` 。返回的DataSet的内容可以通过每个Loader的文档判断出。 22 | - load() 函数:将文件分别读取为DataSet,然后将多个DataSet放入到一个DataBundle中并返回 23 | 24 | """ 25 | 26 | def __init__(self): 27 | pass 28 | 29 | def _load(self, path: str) -> DataSet: 30 | r""" 31 | 给定一个路径,返回读取的DataSet。 32 | 33 | :param str path: 路径 34 | :return: DataSet 35 | """ 36 | raise NotImplementedError 37 | 38 | def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: 39 | r""" 40 | 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 41 | 42 | :param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式: 43 | 44 | 0.如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 45 | 46 | 1.传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件名包含'train'、 'dev'、 'test'则会报错:: 47 | 48 | data_bundle = xxxLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train 49 | # dev、 test等有所变化,可以通过以下的方式取出DataSet 50 | tr_data = data_bundle.get_dataset('train') 51 | te_data = data_bundle.get_dataset('test') # 如果目录下有文件包含test这个字段 52 | 53 | 2.传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: 54 | 55 | paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} 56 | data_bundle = xxxLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" 57 | dev_data = data_bundle.get_dataset('dev') 58 | 59 | 3.传入文件路径:: 60 | 61 | data_bundle = xxxLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' 62 | tr_data = data_bundle.get_dataset('train') # 取出DataSet 63 | 64 | :return: 返回的 :class:`~fastNLP.io.DataBundle` 65 | """ 66 | if paths is None: 67 | paths = self.download() 68 | paths = check_loader_paths(paths) 69 | datasets = {name: self._load(path) for name, path in paths.items()} 70 | data_bundle = DataBundle(datasets=datasets) 71 | return data_bundle 72 | 73 | def download(self) -> str: 74 | r""" 75 | 自动下载该数据集 76 | 77 | :return: 下载后解压目录 78 | """ 79 | raise NotImplementedError(f"{self.__class__} cannot download data automatically.") 80 | 81 | @staticmethod 82 | def _get_dataset_path(dataset_name): 83 | r""" 84 | 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存(如果支持的话) 85 | 86 | :param str dataset_name: 数据集的名称 87 | :return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 88 | """ 89 | 90 | default_cache_path = get_cache_path() 91 | url = _get_dataset_url(dataset_name) 92 | output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') 93 | 94 | return output_dir 95 | -------------------------------------------------------------------------------- /fastNLP/io/loader/qa.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 该文件中的Loader主要用于读取问答式任务的数据 3 | 4 | """ 5 | 6 | 7 | from . import Loader 8 | import json 9 | from ...core import DataSet, Instance 10 | 11 | __all__ = ['CMRC2018Loader'] 12 | 13 | 14 | class CMRC2018Loader(Loader): 15 | r""" 16 | 请直接使用从fastNLP下载的数据进行处理。该数据集未提供测试集,测试需要通过上传到对应的系统进行评测 17 | 18 | 读取之后训练集DataSet将具备以下的内容,每个问题的答案只有一个 19 | 20 | .. csv-table:: 21 | :header:"title", "context", "question", "answers", "answer_starts", "id" 22 | 23 | "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "范廷颂是什么时候被任为主教的?", ["1963年"], ["30"], "TRAIN_186_QUERY_0" 24 | "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "1990年,范廷颂担任什么职务?", ["1990年被擢升为天..."], ["41"],"TRAIN_186_QUERY_1" 25 | "...", "...", "...","...", ".", "..." 
26 | 27 | 其中title是文本的标题,多条记录可能是相同的title;id是该问题的id,具备唯一性 28 | 29 | 验证集DataSet将具备以下的内容,每个问题的答案可能有三个(有时候只是3个重复的答案) 30 | 31 | .. csv-table:: 32 | :header: "title", "context", "question", "answers", "answer_starts", "id" 33 | 34 | "战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "《战国无双3》是由哪两个公司合作开发的?", "['光荣和ω-force', '光荣和ω-force', '光荣和ω-force']", "[30, 30, 30]", "DEV_0_QUERY_0" 35 | "战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "男女主角亦有专属声优这一模式是由谁改编的?", "['村雨城', '村雨城', '任天堂游戏谜之村雨城']", "[226, 226, 219]", "DEV_0_QUERY_1" 36 | "...", "...", "...","...", ".", "..." 37 | 38 | 其中answer_starts是从0开始的index。例如"我来自a复旦大学?",其中"复"的开始index为4。另外"Russell评价说"中的说的index为9, 因为 39 | 英文和数字都直接按照character计量的。 40 | """ 41 | def __init__(self): 42 | super().__init__() 43 | 44 | def _load(self, path: str) -> DataSet: 45 | with open(path, 'r', encoding='utf-8') as f: 46 | data = json.load(f)['data'] 47 | ds = DataSet() 48 | for entry in data: 49 | title = entry['title'] 50 | para = entry['paragraphs'][0] 51 | context = para['context'] 52 | qas = para['qas'] 53 | for qa in qas: 54 | question = qa['question'] 55 | ans = qa['answers'] 56 | answers = [] 57 | answer_starts = [] 58 | id = qa['id'] 59 | for an in ans: 60 | answers.append(an['text']) 61 | answer_starts.append(an['answer_start']) 62 | ds.append(Instance(title=title, context=context, question=question, answers=answers, 63 | answer_starts=answer_starts,id=id)) 64 | return ds 65 | 66 | def download(self) -> str: 67 | r""" 68 | 如果您使用了本数据,请引用A Span-Extraction Dataset for Chinese Machine Reading Comprehension. Yiming Cui, Ting Liu, etc. 69 | 70 | :return: 71 | """ 72 | output_dir = self._get_dataset_path('cmrc2018') 73 | return output_dir 74 | 75 | -------------------------------------------------------------------------------- /fastNLP/io/loader/summarization.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "ExtCNNDMLoader" 5 | ] 6 | 7 | import os 8 | from typing import Union, Dict 9 | 10 | from ..data_bundle import DataBundle 11 | from ..utils import check_loader_paths 12 | from .json import JsonLoader 13 | 14 | 15 | class ExtCNNDMLoader(JsonLoader): 16 | r""" 17 | 读取之后的DataSet中的field情况为 18 | 19 | .. csv-table:: 20 | :header: "text", "summary", "label", "publication" 21 | 22 | ["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" 23 | ["Don't waste your time. 
We had two...","..."], ["Time is precious","..."], [1], "cnndm" 24 | ["..."], ["..."], [], "cnndm" 25 | 26 | """ 27 | 28 | def __init__(self, fields=None): 29 | fields = fields or {"text": None, "summary": None, "label": None, "publication": None} 30 | super(ExtCNNDMLoader, self).__init__(fields=fields) 31 | 32 | def load(self, paths: Union[str, Dict[str, str]] = None): 33 | r""" 34 | 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 35 | 36 | 读取的field根据ExtCNNDMLoader初始化时传入的headers决定。 37 | 38 | :param str paths: 传入一个目录, 将在该目录下寻找train.label.jsonl, dev.label.jsonl 39 | test.label.jsonl三个文件(该目录还应该需要有一个名字为vocab的文件,在 :class:`~fastNLP.io.ExtCNNDMPipe` 40 | 当中需要用到)。 41 | 42 | :return: 返回 :class:`~fastNLP.io.DataBundle` 43 | """ 44 | if paths is None: 45 | paths = self.download() 46 | paths = check_loader_paths(paths) 47 | if ('train' in paths) and ('test' not in paths): 48 | paths['test'] = paths['train'] 49 | paths.pop('train') 50 | 51 | datasets = {name: self._load(path) for name, path in paths.items()} 52 | data_bundle = DataBundle(datasets=datasets) 53 | return data_bundle 54 | 55 | def download(self): 56 | r""" 57 | 如果你使用了这个数据,请引用 58 | 59 | https://arxiv.org/pdf/1506.03340.pdf 60 | :return: 61 | """ 62 | output_dir = self._get_dataset_path('ext-cnndm') 63 | return output_dir 64 | -------------------------------------------------------------------------------- /fastNLP/io/model_io.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 用于载入和保存模型 3 | """ 4 | __all__ = [ 5 | "ModelLoader", 6 | "ModelSaver" 7 | ] 8 | 9 | import torch 10 | 11 | 12 | class ModelLoader: 13 | r""" 14 | 用于读取模型 15 | """ 16 | 17 | def __init__(self): 18 | super(ModelLoader, self).__init__() 19 | 20 | @staticmethod 21 | def load_pytorch(empty_model, model_path): 22 | r""" 23 | 从 ".pkl" 文件读取 PyTorch 模型 24 | 25 | :param empty_model: 初始化参数的 PyTorch 模型 26 | :param str model_path: 模型保存的路径 27 | """ 28 | empty_model.load_state_dict(torch.load(model_path)) 29 | 30 | @staticmethod 31 | def load_pytorch_model(model_path): 32 | r""" 33 | 读取整个模型 34 | 35 | :param str model_path: 模型保存的路径 36 | """ 37 | return torch.load(model_path) 38 | 39 | 40 | class ModelSaver(object): 41 | r""" 42 | 用于保存模型 43 | 44 | Example:: 45 | 46 | saver = ModelSaver("./save/model_ckpt_100.pkl") 47 | saver.save_pytorch(model) 48 | 49 | """ 50 | 51 | def __init__(self, save_path): 52 | r""" 53 | 54 | :param save_path: 模型保存的路径 55 | """ 56 | self.save_path = save_path 57 | 58 | def save_pytorch(self, model, param_only=True): 59 | r""" 60 | 把 PyTorch 模型存入 ".pkl" 文件 61 | 62 | :param model: PyTorch 模型 63 | :param bool param_only: 是否只保存模型的参数(否则保存整个模型) 64 | 65 | """ 66 | if param_only is True: 67 | torch.save(model.state_dict(), self.save_path) 68 | else: 69 | torch.save(model, self.save_path) 70 | -------------------------------------------------------------------------------- /fastNLP/io/pipe/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``process`` 和 ``process_from_file`` 两种方法。 3 | ``process(data_bundle)`` 传入一个 :class:`~fastNLP.io.DataBundle` 类型的对象, 在传入的 `data_bundle` 上进行原位修改,并将其返回; 4 | ``process_from_file(paths)`` 传入的文件路径,返回一个 :class:`~fastNLP.io.DataBundle` 类型的对象。 5 | ``process(data_bundle)`` 或者 ``process_from_file(paths)`` 的返回 `data_bundle` 中的 :class:`~fastNLP.DataSet` 6 | 一般都包含原文与转换为index的输入以及转换为index的target;除了 :class:`~fastNLP.DataSet` 之外, 7 | `data_bundle` 还会包含将field转为index时所建立的词表。 8 | 9 | """ 10 | __all__ = [ 11 | 
"Pipe", 12 | 13 | "CWSPipe", 14 | 15 | "CLSBasePipe", 16 | "AGsNewsPipe", 17 | "DBPediaPipe", 18 | "YelpFullPipe", 19 | "YelpPolarityPipe", 20 | "SSTPipe", 21 | "SST2Pipe", 22 | "IMDBPipe", 23 | "ChnSentiCorpPipe", 24 | "THUCNewsPipe", 25 | "WeiboSenti100kPipe", 26 | 27 | "Conll2003NERPipe", 28 | "OntoNotesNERPipe", 29 | "MsraNERPipe", 30 | "WeiboNERPipe", 31 | "PeopleDailyPipe", 32 | "Conll2003Pipe", 33 | 34 | "MatchingBertPipe", 35 | "RTEBertPipe", 36 | "SNLIBertPipe", 37 | "QuoraBertPipe", 38 | "QNLIBertPipe", 39 | "MNLIBertPipe", 40 | "CNXNLIBertPipe", 41 | "BQCorpusBertPipe", 42 | "LCQMCBertPipe", 43 | "MatchingPipe", 44 | "RTEPipe", 45 | "SNLIPipe", 46 | "QuoraPipe", 47 | "QNLIPipe", 48 | "MNLIPipe", 49 | "LCQMCPipe", 50 | "CNXNLIPipe", 51 | "BQCorpusPipe", 52 | "RenamePipe", 53 | "GranularizePipe", 54 | "MachingTruncatePipe", 55 | 56 | "CoReferencePipe", 57 | 58 | "CMRC2018BertPipe" 59 | ] 60 | 61 | from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \ 62 | WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe 63 | from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe 64 | from .conll import Conll2003Pipe 65 | from .coreference import CoReferencePipe 66 | from .cws import CWSPipe 67 | from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ 68 | MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, CNXNLIBertPipe, CNXNLIPipe, BQCorpusBertPipe, \ 69 | LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe 70 | from .pipe import Pipe 71 | from .qa import CMRC2018BertPipe 72 | -------------------------------------------------------------------------------- /fastNLP/io/pipe/pipe.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "Pipe", 5 | ] 6 | 7 | from .. import DataBundle 8 | 9 | 10 | class Pipe: 11 | r""" 12 | Pipe是fastNLP中用于处理DataBundle的类,但实际是处理DataBundle中的DataSet。所有Pipe都会在其process()函数的文档中指出该Pipe可处理的DataSet应该具备怎样的格式;在Pipe 13 | 文档中说明该Pipe返回后DataSet的格式以及其field的信息;以及新增的Vocabulary的信息。 14 | 15 | 一般情况下Pipe处理包含以下的几个过程,(1)将raw_words或raw_chars进行tokenize以切分成不同的词或字; 16 | (2) 再建立词或字的 :class:`~fastNLP.Vocabulary` , 并将词或字转换为index; (3)将target列建立词表并将target列转为index; 17 | 18 | Pipe中提供了两个方法 19 | 20 | -process()函数,输入为DataBundle 21 | -process_from_file()函数,输入为对应Loader的load函数可接受的类型。 22 | 23 | """ 24 | 25 | def process(self, data_bundle: DataBundle) -> DataBundle: 26 | r""" 27 | 对输入的DataBundle进行处理,然后返回该DataBundle。 28 | 29 | :param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象 30 | :return: 31 | """ 32 | raise NotImplementedError 33 | 34 | def process_from_file(self, paths) -> DataBundle: 35 | r""" 36 | 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` 37 | 38 | :param paths: 39 | :return: DataBundle 40 | """ 41 | raise NotImplementedError 42 | -------------------------------------------------------------------------------- /fastNLP/io/utils.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. 
todo:: 3 | doc 4 | """ 5 | 6 | __all__ = [ 7 | "check_loader_paths" 8 | ] 9 | 10 | import os 11 | from pathlib import Path 12 | from typing import Union, Dict 13 | 14 | from ..core import logger 15 | 16 | 17 | def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: 18 | r""" 19 | 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: 20 | 21 | { 22 | 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 23 | 'test': 'xxx' # 可能有,也可能没有 24 | ... 25 | } 26 | 27 | 如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 28 | 29 | :param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找包含train(文件名 30 | 中包含train这个字段), test, dev这三个字段的文件或文件夹; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 31 | :return: 32 | """ 33 | if isinstance(paths, (str, Path)): 34 | paths = os.path.abspath(os.path.expanduser(paths)) 35 | if os.path.isfile(paths): 36 | return {'train': paths} 37 | elif os.path.isdir(paths): 38 | filenames = os.listdir(paths) 39 | filenames.sort() 40 | files = {} 41 | for filename in filenames: 42 | path_pair = None 43 | if 'train' in filename: 44 | path_pair = ('train', filename) 45 | if 'dev' in filename: 46 | if path_pair: 47 | raise Exception( 48 | "Directory:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0])) 49 | path_pair = ('dev', filename) 50 | if 'test' in filename: 51 | if path_pair: 52 | raise Exception( 53 | "Directory:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0])) 54 | path_pair = ('test', filename) 55 | if path_pair: 56 | if path_pair[0] in files: 57 | raise FileExistsError(f"Two files contain `{path_pair[0]}` were found, please specify the " 58 | f"filepath for `{path_pair[0]}`.") 59 | files[path_pair[0]] = os.path.join(paths, path_pair[1]) 60 | if 'train' not in files: 61 | raise KeyError(f"There is no train file in {paths}.") 62 | return files 63 | else: 64 | raise FileNotFoundError(f"{paths} is not a valid file path.") 65 | 66 | elif isinstance(paths, dict): 67 | if paths: 68 | if 'train' not in paths: 69 | raise KeyError("You have to include `train` in your dict.") 70 | for key, value in paths.items(): 71 | if isinstance(key, str) and isinstance(value, str): 72 | value = os.path.abspath(os.path.expanduser(value)) 73 | if not os.path.exists(value): 74 | raise TypeError(f"{value} is not a valid path.") 75 | paths[key] = value 76 | else: 77 | raise TypeError("All keys and values in paths should be str.") 78 | return paths 79 | else: 80 | raise ValueError("Empty paths is not allowed.") 81 | else: 82 | raise TypeError(f"paths only supports str and dict. not {type(paths)}.") 83 | -------------------------------------------------------------------------------- /fastNLP/models/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models.CNNText` 、 3 | :class:`~fastNLP.models.SeqLabeling` 等完整的模型,以供用户直接使用。 4 | 5 | .. 
todo:: 6 | 这些模型的介绍(与主页一致) 7 | 8 | 9 | """ 10 | __all__ = [ 11 | "CNNText", 12 | 13 | "SeqLabeling", 14 | "AdvSeqLabel", 15 | "BiLSTMCRF", 16 | 17 | "ESIM", 18 | 19 | "StarTransEnc", 20 | "STSeqLabel", 21 | "STNLICls", 22 | "STSeqCls", 23 | 24 | "BiaffineParser", 25 | "GraphParser", 26 | 27 | "BertForSequenceClassification", 28 | "BertForSentenceMatching", 29 | "BertForMultipleChoice", 30 | "BertForTokenClassification", 31 | "BertForQuestionAnswering", 32 | 33 | "TransformerSeq2SeqModel", 34 | "LSTMSeq2SeqModel", 35 | "Seq2SeqModel", 36 | 37 | 'SequenceGeneratorModel' 38 | ] 39 | 40 | from .base_model import BaseModel 41 | from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \ 42 | BertForTokenClassification, BertForSentenceMatching 43 | from .biaffine_parser import BiaffineParser, GraphParser 44 | from .cnn_text_classification import CNNText 45 | from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF 46 | from .snli import ESIM 47 | from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel 48 | from .seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel, Seq2SeqModel 49 | from .seq2seq_generator import SequenceGeneratorModel 50 | import sys 51 | from ..doc_utils import doc_process 52 | 53 | doc_process(sys.modules[__name__]) 54 | -------------------------------------------------------------------------------- /fastNLP/models/base_model.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [] 4 | 5 | import torch 6 | 7 | from ..modules.decoder.mlp import MLP 8 | 9 | 10 | class BaseModel(torch.nn.Module): 11 | r"""Base PyTorch model for all models. 12 | """ 13 | 14 | def __init__(self): 15 | super(BaseModel, self).__init__() 16 | 17 | def fit(self, train_data, dev_data=None, **train_args): 18 | pass 19 | 20 | def predict(self, *args, **kwargs): 21 | raise NotImplementedError 22 | 23 | 24 | class NaiveClassifier(BaseModel): 25 | r""" 26 | 一个简单的分类器例子,可用于各种测试 27 | """ 28 | def __init__(self, in_feature_dim, out_feature_dim): 29 | super(NaiveClassifier, self).__init__() 30 | self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim]) 31 | 32 | def forward(self, x): 33 | return {"predict": torch.sigmoid(self.mlp(x))} 34 | 35 | def predict(self, x): 36 | return {"predict": torch.sigmoid(self.mlp(x)) > 0.5} 37 | -------------------------------------------------------------------------------- /fastNLP/models/cnn_text_classification.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. todo:: 3 | doc 4 | """ 5 | 6 | __all__ = [ 7 | "CNNText" 8 | ] 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from ..core.const import Const as C 14 | from ..core.utils import seq_len_to_mask 15 | from ..embeddings import embedding 16 | from ..modules import encoder 17 | 18 | 19 | class CNNText(torch.nn.Module): 20 | r""" 21 | 使用CNN进行文本分类的模型 22 | 'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.' 
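    A minimal usage sketch (the vocabulary size, embedding dimension and class count are
    arbitrary; ``embed`` may also be a pre-built embedding object instead of a tuple)::

        >>> import torch
        >>> model = CNNText((1000, 50), num_classes=2)
        >>> words = torch.randint(0, 1000, (4, 20))            # batch_size x seq_len
        >>> seq_len = torch.full((4,), 20, dtype=torch.long)
        >>> output = model(words, seq_len)                     # dict holding a (4, 2) score tensor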
23 | 24 | """ 25 | 26 | def __init__(self, embed, 27 | num_classes, 28 | kernel_nums=(30, 40, 50), 29 | kernel_sizes=(1, 3, 5), 30 | dropout=0.5): 31 | r""" 32 | 33 | :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int), 34 | 第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding 35 | :param int num_classes: 一共有多少类 36 | :param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 37 | :param float dropout: Dropout的大小 38 | """ 39 | super(CNNText, self).__init__() 40 | 41 | # no support for pre-trained embedding currently 42 | self.embed = embedding.Embedding(embed) 43 | self.conv_pool = encoder.ConvMaxpool( 44 | in_channels=self.embed.embedding_dim, 45 | out_channels=kernel_nums, 46 | kernel_sizes=kernel_sizes) 47 | self.dropout = nn.Dropout(dropout) 48 | self.fc = nn.Linear(sum(kernel_nums), num_classes) 49 | 50 | def forward(self, words, seq_len=None): 51 | r""" 52 | 53 | :param torch.LongTensor words: [batch_size, seq_len],句子中word的index 54 | :param torch.LongTensor seq_len: [batch,] 每个句子的长度 55 | :return output: dict of torch.LongTensor, [batch_size, num_classes] 56 | """ 57 | x = self.embed(words) # [N,L] -> [N,L,C] 58 | if seq_len is not None: 59 | mask = seq_len_to_mask(seq_len) 60 | x = self.conv_pool(x, mask) 61 | else: 62 | x = self.conv_pool(x) # [N,L,C] -> [N,C] 63 | x = self.dropout(x) 64 | x = self.fc(x) # [N,C] -> [N, N_class] 65 | return {C.OUTPUT: x} 66 | 67 | def predict(self, words, seq_len=None): 68 | r""" 69 | :param torch.LongTensor words: [batch_size, seq_len],句子中word的index 70 | :param torch.LongTensor seq_len: [batch,] 每个句子的长度 71 | 72 | :return predict: dict of torch.LongTensor, [batch_size, ] 73 | """ 74 | output = self(words, seq_len) 75 | _, predict = output[C.OUTPUT].max(dim=1) 76 | return {C.OUTPUT: predict} 77 | -------------------------------------------------------------------------------- /fastNLP/models/seq2seq_generator.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | import torch 4 | from torch import nn 5 | from .seq2seq_model import Seq2SeqModel 6 | from ..modules.generator.seq2seq_generator import SequenceGenerator 7 | 8 | 9 | class SequenceGeneratorModel(nn.Module): 10 | """ 11 | 用于封装Seq2SeqModel使其可以做生成任务 12 | 13 | """ 14 | 15 | def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, max_len_a=0.0, 16 | num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0, 17 | repetition_penalty=1, length_penalty=1.0, pad_token_id=0): 18 | """ 19 | 20 | :param Seq2SeqModel seq2seq_model: 序列到序列模型 21 | :param int,None bos_token_id: 句子开头的token id 22 | :param int,None eos_token_id: 句子结束的token id 23 | :param int max_length: 生成句子的最大长度, 每句话的decode长度为max_length + max_len_a*src_len 24 | :param float max_len_a: 每句话的decode长度为max_length + max_len_a*src_len。 如果不为0,需要保证State中包含encoder_mask 25 | :param int num_beams: beam search的大小 26 | :param bool do_sample: 是否通过采样的方式生成 27 | :param float temperature: 只有在do_sample为True才有意义 28 | :param int top_k: 只从top_k中采样 29 | :param float top_p: 只从top_p的token中采样,nucles sample 30 | :param float repetition_penalty: 多大程度上惩罚重复的token 31 | :param float length_penalty: 对长度的惩罚,小于1鼓励长句,大于1鼓励短剧 32 | :param int pad_token_id: 当某句话生成结束之后,之后生成的内容用pad_token_id补充 33 | """ 34 | super().__init__() 35 | self.seq2seq_model = seq2seq_model 36 | self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, max_len_a=max_len_a, 37 | num_beams=num_beams, 38 
| do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p, 39 | bos_token_id=bos_token_id, 40 | eos_token_id=eos_token_id, 41 | repetition_penalty=repetition_penalty, length_penalty=length_penalty, 42 | pad_token_id=pad_token_id) 43 | 44 | def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None): 45 | """ 46 | 透传调用seq2seq_model的forward 47 | 48 | :param torch.LongTensor src_tokens: bsz x max_len 49 | :param torch.LongTensor tgt_tokens: bsz x max_len' 50 | :param torch.LongTensor src_seq_len: bsz 51 | :param torch.LongTensor tgt_seq_len: bsz 52 | :return: 53 | """ 54 | return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len) 55 | 56 | def predict(self, src_tokens, src_seq_len=None): 57 | """ 58 | 给定source的内容,输出generate的内容 59 | 60 | :param torch.LongTensor src_tokens: bsz x max_len 61 | :param torch.LongTensor src_seq_len: bsz 62 | :return: 63 | """ 64 | state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len) 65 | result = self.generator.generate(state) 66 | return {'pred': result} 67 | -------------------------------------------------------------------------------- /fastNLP/modules/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 3 | .. image:: figures/text_classification.png 4 | 5 | 大部分用于的 NLP 任务神经网络都可以看做由 :mod:`embedding` 、 :mod:`~fastNLP.modules.encoder` 、 6 | :mod:`~fastNLP.modules.decoder` 三种模块组成。 本模块中实现了 fastNLP 提供的诸多模块组件, 7 | 可以帮助用户快速搭建自己所需的网络。几种模块的功能和常见组件如下: 8 | 9 | .. csv-table:: 10 | :header: "类型", "功能", "常见组件" 11 | 12 | "embedding", 参见 :mod:`/fastNLP.embeddings` , "Elmo, Bert" 13 | "encoder", "将输入编码为具有表示能力的向量", "CNN, LSTM, Transformer" 14 | "decoder", "将具有某种表示意义的向量解码为需要的输出形式 ", "MLP, CRF" 15 | "其它", "配合其它组件使用的组件", "Dropout" 16 | 17 | 18 | """ 19 | __all__ = [ 20 | # "BertModel", 21 | 22 | "ConvolutionCharEncoder", 23 | "LSTMCharEncoder", 24 | 25 | "ConvMaxpool", 26 | 27 | "LSTM", 28 | 29 | "StarTransformer", 30 | 31 | "TransformerEncoder", 32 | 33 | "VarRNN", 34 | "VarLSTM", 35 | "VarGRU", 36 | 37 | "MaxPool", 38 | "MaxPoolWithMask", 39 | "KMaxPool", 40 | "AvgPool", 41 | "AvgPoolWithMask", 42 | 43 | "MultiHeadAttention", 44 | 45 | "MLP", 46 | "ConditionalRandomField", 47 | "viterbi_decode", 48 | "allowed_transitions", 49 | 50 | "TimestepDropout", 51 | 52 | 'summary', 53 | 54 | "BertTokenizer", 55 | "BertModel", 56 | 57 | "RobertaTokenizer", 58 | "RobertaModel", 59 | 60 | "GPT2Model", 61 | "GPT2Tokenizer", 62 | 63 | "TransformerSeq2SeqEncoder", 64 | "LSTMSeq2SeqEncoder", 65 | "Seq2SeqEncoder", 66 | 67 | "TransformerSeq2SeqDecoder", 68 | "LSTMSeq2SeqDecoder", 69 | "Seq2SeqDecoder", 70 | 71 | "TransformerState", 72 | "LSTMState", 73 | "State", 74 | 75 | "SequenceGenerator" 76 | ] 77 | 78 | import sys 79 | 80 | from . import decoder 81 | from . import encoder 82 | from .decoder import * 83 | from .dropout import TimestepDropout 84 | from .encoder import * 85 | from .generator import * 86 | from .utils import summary 87 | from ..doc_utils import doc_process 88 | from .tokenizer import * 89 | 90 | doc_process(sys.modules[__name__]) 91 | -------------------------------------------------------------------------------- /fastNLP/modules/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. 
todo:: 3 | doc 4 | """ 5 | __all__ = [ 6 | "MLP", 7 | "ConditionalRandomField", 8 | "viterbi_decode", 9 | "allowed_transitions", 10 | 11 | "LSTMState", 12 | "TransformerState", 13 | "State", 14 | 15 | "TransformerSeq2SeqDecoder", 16 | "LSTMSeq2SeqDecoder", 17 | "Seq2SeqDecoder" 18 | ] 19 | 20 | from .crf import ConditionalRandomField 21 | from .crf import allowed_transitions 22 | from .mlp import MLP 23 | from .utils import viterbi_decode 24 | from .seq2seq_decoder import Seq2SeqDecoder, LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder 25 | from .seq2seq_state import State, LSTMState, TransformerState 26 | -------------------------------------------------------------------------------- /fastNLP/modules/decoder/mlp.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "MLP" 5 | ] 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from ..utils import initial_parameter 11 | 12 | 13 | class MLP(nn.Module): 14 | r""" 15 | 多层感知器 16 | 17 | 18 | .. note:: 19 | 隐藏层的激活函数通过activation定义。一个str/function或者一个str/function的list可以被传入activation。 20 | 如果只传入了一个str/function,那么所有隐藏层的激活函数都由这个str/function定义; 21 | 如果传入了一个str/function的list,那么每一个隐藏层的激活函数由这个list中对应的元素定义,其中list的长度为隐藏层数。 22 | 输出层的激活函数由output_activation定义,默认值为None,此时输出层没有激活函数。 23 | 24 | Examples:: 25 | 26 | >>> net1 = MLP([5, 10, 5]) 27 | >>> net2 = MLP([5, 10, 5], 'tanh') 28 | >>> net3 = MLP([5, 6, 7, 8, 5], 'tanh') 29 | >>> net4 = MLP([5, 6, 7, 8, 5], 'relu', output_activation='tanh') 30 | >>> net5 = MLP([5, 6, 7, 8, 5], ['tanh', 'relu', 'tanh'], 'tanh') 31 | >>> for net in [net1, net2, net3, net4, net5]: 32 | >>> x = torch.randn(5, 5) 33 | >>> y = net(x) 34 | >>> print(x) 35 | >>> print(y) 36 | """ 37 | 38 | def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0): 39 | r""" 40 | 41 | :param List[int] size_layer: 一个int的列表,用来定义MLP的层数,列表中的数字为每一层是hidden数目。MLP的层数为 len(size_layer) - 1 42 | :param Union[str,func,List[str]] activation: 一个字符串或者函数的列表,用来定义每一个隐层的激活函数,字符串包括relu,tanh和 43 | sigmoid,默认值为relu 44 | :param Union[str,func] output_activation: 字符串或者函数,用来定义输出层的激活函数,默认值为None,表示输出层没有激活函数 45 | :param str initial_method: 参数初始化方式 46 | :param float dropout: dropout概率,默认值为0 47 | """ 48 | super(MLP, self).__init__() 49 | self.hiddens = nn.ModuleList() 50 | self.output = None 51 | self.output_activation = output_activation 52 | for i in range(1, len(size_layer)): 53 | if i + 1 == len(size_layer): 54 | self.output = nn.Linear(size_layer[i - 1], size_layer[i]) 55 | else: 56 | self.hiddens.append(nn.Linear(size_layer[i - 1], size_layer[i])) 57 | 58 | self.dropout = nn.Dropout(p=dropout) 59 | 60 | actives = { 61 | 'relu': nn.ReLU(), 62 | 'tanh': nn.Tanh(), 63 | 'sigmoid': nn.Sigmoid(), 64 | } 65 | if not isinstance(activation, list): 66 | activation = [activation] * (len(size_layer) - 2) 67 | elif len(activation) == len(size_layer) - 2: 68 | pass 69 | else: 70 | raise ValueError( 71 | f"the length of activation function list except {len(size_layer) - 2} but got {len(activation)}!") 72 | self.hidden_active = [] 73 | for func in activation: 74 | if callable(activation): 75 | self.hidden_active.append(activation) 76 | elif func.lower() in actives: 77 | self.hidden_active.append(actives[func]) 78 | else: 79 | raise ValueError("should set activation correctly: {}".format(activation)) 80 | if self.output_activation is not None: 81 | if callable(self.output_activation): 82 | pass 83 | elif self.output_activation.lower() in actives: 84 | 
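                # the activation was given by name (e.g. 'tanh'); swap the string for the
                # corresponding nn.Module instance from the `actives` lookup table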
self.output_activation = actives[self.output_activation] 85 | else: 86 | raise ValueError("should set activation correctly: {}".format(activation)) 87 | initial_parameter(self, initial_method) 88 | 89 | def forward(self, x): 90 | r""" 91 | :param torch.Tensor x: MLP接受的输入 92 | :return: torch.Tensor : MLP的输出结果 93 | """ 94 | for layer, func in zip(self.hiddens, self.hidden_active): 95 | x = self.dropout(func(layer(x))) 96 | x = self.output(x) 97 | if self.output_activation is not None: 98 | x = self.output_activation(x) 99 | x = self.dropout(x) 100 | return x 101 | -------------------------------------------------------------------------------- /fastNLP/modules/decoder/utils.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "viterbi_decode" 5 | ] 6 | import torch 7 | 8 | 9 | def viterbi_decode(logits, transitions, mask=None, unpad=False): 10 | r""" 11 | 给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 12 | 13 | :param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 14 | :param torch.FloatTensor transitions: n_tags x n_tags,[i, j]位置的值认为是从tag i到tag j的转换; 或者(n_tags+2) x 15 | (n_tags+2), 其中n_tag是start的index, n_tags+1是end的index; 如果要i->j之间不允许越迁,就把transitions中(i,j)设置为很小的 16 | 负数,例如-10000000.0 17 | :param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 18 | :param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 19 | List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 20 | 个sample的有效长度。 21 | :return: 返回 (paths, scores)。 22 | paths: 是解码后的路径, 其值参照unpad参数. 23 | scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 24 | 25 | """ 26 | batch_size, seq_len, n_tags = logits.size() 27 | if transitions.size(0) == n_tags+2: 28 | include_start_end_trans = True 29 | elif transitions.size(0) == n_tags: 30 | include_start_end_trans = False 31 | else: 32 | raise RuntimeError("The shapes of transitions and feats are not " \ 33 | "compatible.") 34 | logits = logits.transpose(0, 1).data # L, B, H 35 | if mask is not None: 36 | mask = mask.transpose(0, 1).data.eq(True) # L, B 37 | else: 38 | mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8).eq(1) 39 | 40 | trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data 41 | 42 | # dp 43 | vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) 44 | vscore = logits[0] 45 | if include_start_end_trans: 46 | vscore += transitions[n_tags, :n_tags] 47 | 48 | for i in range(1, seq_len): 49 | prev_score = vscore.view(batch_size, n_tags, 1) 50 | cur_score = logits[i].view(batch_size, 1, n_tags) 51 | score = prev_score + trans_score + cur_score 52 | best_score, best_dst = score.max(1) 53 | vpath[i] = best_dst 54 | vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \ 55 | vscore.masked_fill(mask[i].view(batch_size, 1), 0) 56 | 57 | if include_start_end_trans: 58 | vscore += transitions[:n_tags, n_tags + 1].view(1, -1) 59 | # backtrace 60 | batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) 61 | seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) 62 | lens = (mask.long().sum(0) - 1) 63 | # idxes [L, B], batched idx from seq_len-1 to 0 64 | idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len 65 | 66 | ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) 67 | ans_score, last_tags = vscore.max(1) 68 | ans[idxes[0], batch_idx] = last_tags 69 | for i in range(seq_len - 1): 70 | last_tags = 
vpath[idxes[i], batch_idx, last_tags] 71 | ans[idxes[i + 1], batch_idx] = last_tags 72 | ans = ans.transpose(0, 1) 73 | if unpad: 74 | paths = [] 75 | for idx, seq_len in enumerate(lens): 76 | paths.append(ans[idx, :seq_len + 1].tolist()) 77 | else: 78 | paths = ans 79 | return paths, ans_score 80 | -------------------------------------------------------------------------------- /fastNLP/modules/dropout.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "TimestepDropout" 5 | ] 6 | 7 | import torch 8 | 9 | 10 | class TimestepDropout(torch.nn.Dropout): 11 | r""" 12 | 传入参数的shape为 ``(batch_size, num_timesteps, embedding_dim)`` 13 | 使用同一个shape为 ``(batch_size, embedding_dim)`` 的mask在每个timestamp上做dropout。 14 | """ 15 | 16 | def forward(self, x): 17 | dropout_mask = x.new_ones(x.shape[0], x.shape[-1]) 18 | torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True) 19 | dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim] 20 | if self.inplace: 21 | x *= dropout_mask 22 | return 23 | else: 24 | return x * dropout_mask 25 | -------------------------------------------------------------------------------- /fastNLP/modules/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. todo:: 3 | doc 4 | """ 5 | 6 | __all__ = [ 7 | "ConvolutionCharEncoder", 8 | "LSTMCharEncoder", 9 | 10 | "ConvMaxpool", 11 | 12 | "LSTM", 13 | 14 | "StarTransformer", 15 | 16 | "TransformerEncoder", 17 | 18 | "VarRNN", 19 | "VarLSTM", 20 | "VarGRU", 21 | 22 | "MaxPool", 23 | "MaxPoolWithMask", 24 | "KMaxPool", 25 | "AvgPool", 26 | "AvgPoolWithMask", 27 | 28 | "MultiHeadAttention", 29 | "BiAttention", 30 | "SelfAttention", 31 | 32 | "BertModel", 33 | 34 | "RobertaModel", 35 | 36 | "GPT2Model", 37 | 38 | "LSTMSeq2SeqEncoder", 39 | "TransformerSeq2SeqEncoder", 40 | "Seq2SeqEncoder" 41 | ] 42 | 43 | from fastNLP.modules.attention import MultiHeadAttention, BiAttention, SelfAttention 44 | from .bert import BertModel 45 | from .roberta import RobertaModel 46 | from .gpt2 import GPT2Model 47 | from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder 48 | from .conv_maxpool import ConvMaxpool 49 | from .lstm import LSTM 50 | from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPool 51 | from .star_transformer import StarTransformer 52 | from .transformer import TransformerEncoder 53 | from .variational_rnn import VarRNN, VarLSTM, VarGRU 54 | from .seq2seq_encoder import LSTMSeq2SeqEncoder, TransformerSeq2SeqEncoder, Seq2SeqEncoder 55 | -------------------------------------------------------------------------------- /fastNLP/modules/encoder/char_encoder.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "ConvolutionCharEncoder", 5 | "LSTMCharEncoder" 6 | ] 7 | import torch 8 | import torch.nn as nn 9 | 10 | from ..utils import initial_parameter 11 | 12 | 13 | # from torch.nn.init import xavier_uniform 14 | class ConvolutionCharEncoder(nn.Module): 15 | r""" 16 | char级别的卷积编码器. 17 | 18 | """ 19 | 20 | def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(1, 3, 5), initial_method=None): 21 | r""" 22 | 23 | :param int char_emb_size: char级别embedding的维度. Default: 50 24 | :例: 有26个字符, 每一个的embedding是一个50维的向量, 所以输入的向量维度为50. 25 | :param tuple feature_maps: 一个由int组成的tuple. tuple的长度是char级别卷积操作的数目, 第`i`个int表示第`i`个卷积操作的filter. 
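            :例: with the default feature_maps=(40, 30, 30) and kernels=(1, 3, 5), three parallel
                convolutions are built and their outputs concatenated, so each word yields
                sum(feature_maps) = 100 channels.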
26 | :param tuple kernels: 一个由int组成的tuple. tuple的长度是char级别卷积操作的数目, 第`i`个int表示第`i`个卷积操作的卷积核. 27 | :param initial_method: 初始化参数的方式, 默认为`xavier normal` 28 | """ 29 | super(ConvolutionCharEncoder, self).__init__() 30 | self.convs = nn.ModuleList([ 31 | nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, 32 | padding=(0, kernels[i] // 2)) 33 | for i in range(len(kernels))]) 34 | 35 | initial_parameter(self, initial_method) 36 | 37 | def forward(self, x): 38 | r""" 39 | :param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` 输入字符的embedding 40 | :return: torch.Tensor : 卷积计算的结果, 维度为[batch_size * sent_length, sum(feature_maps), 1] 41 | """ 42 | x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2)) 43 | # [batch_size*sent_length, channel, width, height] 44 | x = x.transpose(2, 3) 45 | # [batch_size*sent_length, channel, height, width] 46 | return self._convolute(x).unsqueeze(2) 47 | 48 | def _convolute(self, x): 49 | feats = [] 50 | for conv in self.convs: 51 | y = conv(x) 52 | # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1] 53 | y = torch.squeeze(y, 2) 54 | # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1] 55 | y = torch.tanh(y) 56 | y, __ = torch.max(y, 2) 57 | # [batch_size*sent_length, feature_maps[i]] 58 | feats.append(y) 59 | return torch.cat(feats, 1) # [batch_size*sent_length, sum(feature_maps)] 60 | 61 | 62 | class LSTMCharEncoder(nn.Module): 63 | r""" 64 | char级别基于LSTM的encoder. 65 | """ 66 | 67 | def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None): 68 | r""" 69 | :param int char_emb_size: char级别embedding的维度. Default: 50 70 | 例: 有26个字符, 每一个的embedding是一个50维的向量, 所以输入的向量维度为50. 71 | :param int hidden_size: LSTM隐层的大小, 默认为char的embedding维度 72 | :param initial_method: 初始化参数的方式, 默认为`xavier normal` 73 | """ 74 | super(LSTMCharEncoder, self).__init__() 75 | self.hidden_size = char_emb_size if hidden_size is None else hidden_size 76 | 77 | self.lstm = nn.LSTM(input_size=char_emb_size, 78 | hidden_size=self.hidden_size, 79 | num_layers=1, 80 | bias=True, 81 | batch_first=True) 82 | initial_parameter(self, initial_method) 83 | 84 | def forward(self, x): 85 | r""" 86 | :param torch.Tensor x: ``[ n_batch*n_word, word_length, char_emb_size]`` 输入字符的embedding 87 | :return: torch.Tensor : [ n_batch*n_word, char_emb_size]经过LSTM编码的结果 88 | """ 89 | batch_size = x.shape[0] 90 | h0 = torch.empty(1, batch_size, self.hidden_size) 91 | h0 = nn.init.orthogonal_(h0) 92 | c0 = torch.empty(1, batch_size, self.hidden_size) 93 | c0 = nn.init.orthogonal_(c0) 94 | 95 | _, hidden = self.lstm(x, (h0, c0)) 96 | return hidden[0].squeeze().unsqueeze(2) 97 | -------------------------------------------------------------------------------- /fastNLP/modules/encoder/conv_maxpool.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "ConvMaxpool" 5 | ] 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class ConvMaxpool(nn.Module): 12 | r""" 13 | 集合了Convolution和Max-Pooling于一体的层。给定一个batch_size x max_len x input_size的输入,返回batch_size x 14 | sum(output_channels) 大小的matrix。在内部,是先使用CNN给输入做卷积,然后经过activation激活层,在通过在长度(max_len) 15 | 这一维进行max_pooling。最后得到每个sample的一个向量表示。 16 | 17 | """ 18 | 19 | def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"): 20 | r""" 21 | 22 | :param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度 23 | :param int,tuple(int) out_channels: 
输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 24 | :param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 25 | :param str activation: Convolution后的结果将通过该activation后再经过max-pooling。支持relu, sigmoid, tanh 26 | """ 27 | super(ConvMaxpool, self).__init__() 28 | 29 | for kernel_size in kernel_sizes: 30 | assert kernel_size % 2 == 1, "kernel size has to be odd numbers." 31 | 32 | # convolution 33 | if isinstance(kernel_sizes, (list, tuple, int)): 34 | if isinstance(kernel_sizes, int) and isinstance(out_channels, int): 35 | out_channels = [out_channels] 36 | kernel_sizes = [kernel_sizes] 37 | elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)): 38 | assert len(out_channels) == len( 39 | kernel_sizes), "The number of out_channels should be equal to the number" \ 40 | " of kernel_sizes." 41 | else: 42 | raise ValueError("The type of out_channels and kernel_sizes should be the same.") 43 | 44 | self.convs = nn.ModuleList([nn.Conv1d( 45 | in_channels=in_channels, 46 | out_channels=oc, 47 | kernel_size=ks, 48 | stride=1, 49 | padding=ks // 2, 50 | dilation=1, 51 | groups=1, 52 | bias=None) 53 | for oc, ks in zip(out_channels, kernel_sizes)]) 54 | 55 | else: 56 | raise Exception( 57 | 'Incorrect kernel sizes: should be list, tuple or int') 58 | 59 | # activation function 60 | if activation == 'relu': 61 | self.activation = F.relu 62 | elif activation == 'sigmoid': 63 | self.activation = F.sigmoid 64 | elif activation == 'tanh': 65 | self.activation = F.tanh 66 | else: 67 | raise Exception( 68 | "Undefined activation function: choose from: relu, tanh, sigmoid") 69 | 70 | def forward(self, x, mask=None): 71 | r""" 72 | 73 | :param torch.FloatTensor x: batch_size x max_len x input_size, 一般是经过embedding后的值 74 | :param mask: batch_size x max_len, pad的地方为0。不影响卷积运算,max-pool一定不会pool到pad为0的位置 75 | :return: 76 | """ 77 | # [N,L,C] -> [N,C,L] 78 | x = torch.transpose(x, 1, 2) 79 | # convolution 80 | xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] 81 | if mask is not None: 82 | mask = mask.unsqueeze(1) # B x 1 x L 83 | xs = [x.masked_fill_(mask.eq(False), float('-inf')) for x in xs] 84 | # max-pooling 85 | xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) 86 | for i in xs] # [[N, C], ...] 87 | return torch.cat(xs, dim=-1) # [N, C] 88 | -------------------------------------------------------------------------------- /fastNLP/modules/encoder/lstm.py: -------------------------------------------------------------------------------- 1 | r"""undocumented 2 | 轻量封装的 Pytorch LSTM 模块. 3 | 可在 forward 时传入序列的长度, 自动对padding做合适的处理. 4 | """ 5 | 6 | __all__ = [ 7 | "LSTM" 8 | ] 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.utils.rnn as rnn 13 | 14 | 15 | class LSTM(nn.Module): 16 | r""" 17 | LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 18 | 为1; 且可以应对DataParallel中LSTM的使用问题。 19 | 20 | """ 21 | 22 | def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, 23 | bidirectional=False, bias=True): 24 | r""" 25 | 26 | :param input_size: 输入 `x` 的特征维度 27 | :param hidden_size: 隐状态 `h` 的特征维度. 如果bidirectional为True,则输出的维度会是hidde_size*2 28 | :param num_layers: rnn的层数. Default: 1 29 | :param dropout: 层间dropout概率. Default: 0 30 | :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` 31 | :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 32 | :(batch, seq, feature). Default: ``False`` 33 | :param bias: 如果为 ``False``, 模型将不会使用bias. 
Default: ``True`` 34 | """ 35 | super(LSTM, self).__init__() 36 | self.batch_first = batch_first 37 | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, 38 | dropout=dropout, bidirectional=bidirectional) 39 | self.init_param() 40 | 41 | def init_param(self): 42 | for name, param in self.named_parameters(): 43 | if 'bias' in name: 44 | # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871 45 | param.data.fill_(0) 46 | n = param.size(0) 47 | start, end = n // 4, n // 2 48 | param.data[start:end].fill_(1) 49 | else: 50 | nn.init.xavier_uniform_(param) 51 | 52 | def forward(self, x, seq_len=None, h0=None, c0=None): 53 | r""" 54 | 55 | :param x: [batch, seq_len, input_size] 输入序列 56 | :param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` 57 | :param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` 58 | :param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` 59 | :return (output, (ht, ct)): output: [batch, seq_len, hidden_size*num_direction] 输出序列 60 | 和 ht,ct: [num_layers*num_direction, batch, hidden_size] 最后时刻隐状态. 61 | """ 62 | batch_size, max_len, _ = x.size() 63 | if h0 is not None and c0 is not None: 64 | hx = (h0, c0) 65 | else: 66 | hx = None 67 | if seq_len is not None and not isinstance(x, rnn.PackedSequence): 68 | sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) 69 | if self.batch_first: 70 | x = x[sort_idx] 71 | else: 72 | x = x[:, sort_idx] 73 | x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) 74 | output, hx = self.lstm(x, hx) # -> [N,L,C] 75 | output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len) 76 | _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) 77 | if self.batch_first: 78 | output = output[unsort_idx] 79 | else: 80 | output = output[:, unsort_idx] 81 | hx = hx[0][:, unsort_idx], hx[1][:, unsort_idx] 82 | else: 83 | output, hx = self.lstm(x, hx) 84 | return output, hx 85 | -------------------------------------------------------------------------------- /fastNLP/modules/encoder/transformer.py: -------------------------------------------------------------------------------- 1 | r"""undocumented""" 2 | 3 | __all__ = [ 4 | "TransformerEncoder" 5 | ] 6 | from torch import nn 7 | 8 | from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer 9 | 10 | 11 | class TransformerEncoder(nn.Module): 12 | r""" 13 | transformer的encoder模块,不包含embedding层 14 | 15 | """ 16 | def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1): 17 | """ 18 | 19 | :param int num_layers: 多少层Transformer 20 | :param int d_model: input和output的大小 21 | :param int n_head: 多少个head 22 | :param int dim_ff: FFN中间hidden大小 23 | :param float dropout: 多大概率drop attention和ffn中间的表示 24 | """ 25 | super(TransformerEncoder, self).__init__() 26 | self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff, 27 | dropout = dropout) for _ in range(num_layers)]) 28 | self.norm = nn.LayerNorm(d_model, eps=1e-6) 29 | 30 | def forward(self, x, seq_mask=None): 31 | r""" 32 | :param x: [batch, seq_len, model_size] 输入序列 33 | :param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. 
为1的地方需要attend 34 | Default: ``None`` 35 | :return: [batch, seq_len, model_size] 输出序列 36 | """ 37 | output = x 38 | if seq_mask is None: 39 | seq_mask = x.new_ones(x.size(0), x.size(1)).bool() 40 | for layer in self.layers: 41 | output = layer(output, seq_mask) 42 | return self.norm(output) 43 | -------------------------------------------------------------------------------- /fastNLP/modules/generator/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 3 | """ 4 | 5 | __all__ = [ 6 | "SequenceGenerator" 7 | ] 8 | 9 | from .seq2seq_generator import SequenceGenerator -------------------------------------------------------------------------------- /fastNLP/modules/tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 3 | """ 4 | __all__=[ 5 | 'BertTokenizer', 6 | 7 | "GPT2Tokenizer", 8 | 9 | "RobertaTokenizer" 10 | ] 11 | 12 | from .bert_tokenizer import BertTokenizer 13 | from .gpt2_tokenizer import GPT2Tokenizer 14 | from .roberta_tokenizer import RobertaTokenizer -------------------------------------------------------------------------------- /fastNLP/modules/tokenizer/roberta_tokenizer.py: -------------------------------------------------------------------------------- 1 | r""" 2 | 3 | """ 4 | 5 | __all__ = [ 6 | "RobertaTokenizer" 7 | ] 8 | 9 | import json 10 | from .gpt2_tokenizer import GPT2Tokenizer 11 | from fastNLP.io.file_utils import _get_file_name_base_on_postfix 12 | from ...io.file_utils import _get_roberta_dir 13 | 14 | PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES = { 15 | "roberta-base": 512, 16 | "roberta-large": 512, 17 | "roberta-large-mnli": 512, 18 | "distilroberta-base": 512, 19 | "roberta-base-openai-detector": 512, 20 | "roberta-large-openai-detector": 512, 21 | } 22 | 23 | 24 | class RobertaTokenizer(GPT2Tokenizer): 25 | 26 | vocab_files_names = { 27 | "vocab_file": "vocab.json", 28 | "merges_file": "merges.txt", 29 | } 30 | 31 | def __init__( 32 | self, 33 | vocab_file, 34 | merges_file, 35 | errors="replace", 36 | bos_token="", 37 | eos_token="", 38 | sep_token="", 39 | cls_token="", 40 | unk_token="", 41 | pad_token="", 42 | mask_token="", 43 | **kwargs 44 | ): 45 | super().__init__( 46 | vocab_file=vocab_file, 47 | merges_file=merges_file, 48 | errors=errors, 49 | bos_token=bos_token, 50 | eos_token=eos_token, 51 | unk_token=unk_token, 52 | sep_token=sep_token, 53 | cls_token=cls_token, 54 | pad_token=pad_token, 55 | mask_token=mask_token, 56 | **kwargs, 57 | ) 58 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens 59 | self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens 60 | 61 | @classmethod 62 | def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs): 63 | """ 64 | 65 | :param str model_dir_or_name: 目录或者缩写名 66 | :param kwargs: 67 | :return: 68 | """ 69 | # 它需要两个文件,第一个是vocab.json,第二个是merge_file? 
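# The two files this loader needs are vocab.json (the BPE token-to-id map) and
# merges.txt (the BPE merge ranks); both are looked up below, next to the
# tokenizer's config.json. A minimal usage sketch (hedged: the shortcut name
# 'roberta-base' and the tokenize() call are assumptions based on the parent
# GPT2Tokenizer API, not verified against this repository):
#
#   tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
#   tokens = tokenizer.tokenize('a span-based dependency parser')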
70 | model_dir = _get_roberta_dir(model_dir_or_name) 71 | # The resolved directory is expected to contain four files: vocab.json, merges.txt, config.json and model.bin. 72 | 73 | tokenizer_config_file = _get_file_name_base_on_postfix(model_dir, 'config.json') 74 | with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle: 75 | init_kwargs = json.load(tokenizer_config_handle) 76 | # Set max length if needed 77 | if model_dir_or_name in PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES: 78 | # if we're using a pretrained model, ensure the tokenizer 79 | # won't index sequences longer than the number of positional embeddings 80 | max_len = PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name] 81 | if max_len is not None and isinstance(max_len, (int, float)): 82 | init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len) 83 | 84 | # Add the vocab and merges file paths to init_kwargs 85 | if 'vocab_file' in kwargs: # use the user-specified vocabulary if one is given 86 | init_kwargs['vocab_file'] = kwargs['vocab_file'] 87 | else: 88 | init_kwargs['vocab_file'] = _get_file_name_base_on_postfix(model_dir, RobertaTokenizer.vocab_files_names['vocab_file']) 89 | init_kwargs['merges_file'] = _get_file_name_base_on_postfix(model_dir, RobertaTokenizer.vocab_files_names['merges_file']) 90 | 91 | init_inputs = init_kwargs.pop("init_inputs", ()) 92 | # Instantiate tokenizer. 93 | try: 94 | tokenizer = cls(*init_inputs, **init_kwargs) 95 | except OSError: 96 | raise OSError( 97 | "Unable to load vocabulary from file. " 98 | "Please check that the provided vocabulary is accessible and not corrupted." 99 | ) 100 | 101 | return tokenizer 102 | 103 | -------------------------------------------------------------------------------- /fastNLP/modules/utils.py: -------------------------------------------------------------------------------- 1 | r""" 2 | .. todo:: 3 | doc 4 | """ 5 | 6 | __all__ = [ 7 | "initial_parameter", 8 | "summary" 9 | ] 10 | 11 | from functools import reduce 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.init as init 16 | 17 | 18 | def initial_parameter(net, initial_method=None): 19 | r"""A method used to initialize the weights of PyTorch models. 20 | 21 | :param net: a PyTorch model 22 | :param str initial_method: one of the following initializations. 
23 | 24 | - xavier_uniform 25 | - xavier_normal (default) 26 | - kaiming_normal, or msra 27 | - kaiming_uniform 28 | - orthogonal 29 | - sparse 30 | - normal 31 | - uniform 32 | 33 | """ 34 | if initial_method == 'xavier_uniform': 35 | init_method = init.xavier_uniform_ 36 | elif initial_method == 'xavier_normal': 37 | init_method = init.xavier_normal_ 38 | elif initial_method == 'kaiming_normal' or initial_method == 'msra': 39 | init_method = init.kaiming_normal_ 40 | elif initial_method == 'kaiming_uniform': 41 | init_method = init.kaiming_uniform_ 42 | elif initial_method == 'orthogonal': 43 | init_method = init.orthogonal_ 44 | elif initial_method == 'sparse': 45 | init_method = init.sparse_ 46 | elif initial_method == 'normal': 47 | init_method = init.normal_ 48 | elif initial_method == 'uniform': 49 | init_method = init.uniform_ 50 | else: 51 | init_method = init.xavier_normal_ 52 | 53 | def weights_init(m): 54 | # classname = m.__class__.__name__ 55 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv3d): # for all the cnn 56 | if initial_method is not None: 57 | init_method(m.weight.data) 58 | else: 59 | init.xavier_normal_(m.weight.data) 60 | init.normal_(m.bias.data) 61 | elif isinstance(m, nn.LSTM): 62 | for w in m.parameters(): 63 | if len(w.data.size()) > 1: 64 | init_method(w.data) # weight 65 | else: 66 | init.normal_(w.data) # bias 67 | elif m is not None and hasattr(m, 'weight') and \ 68 | hasattr(m.weight, "requires_grad"): 69 | if len(m.weight.size()) > 1: 70 | init_method(m.weight.data) 71 | else: 72 | init.normal_(m.weight.data) # batchnorm or layernorm 73 | else: 74 | for w in m.parameters(): 75 | if w.requires_grad: 76 | if len(w.data.size()) > 1: 77 | init_method(w.data) # weight 78 | else: 79 | init.normal_(w.data) # bias 80 | # print("init else") 81 | 82 | net.apply(weights_init) 83 | 84 | 85 | def summary(model: nn.Module): 86 | r""" 87 | 得到模型的总参数量 88 | 89 | :params model: Pytorch 模型 90 | :return tuple: 包含总参数量,可训练参数量,不可训练参数量 91 | """ 92 | train = [] 93 | nontrain = [] 94 | buffer = [] 95 | 96 | def layer_summary(module: nn.Module): 97 | def count_size(sizes): 98 | return reduce(lambda x, y: x * y, sizes) 99 | 100 | for p in module.parameters(recurse=False): 101 | if p.requires_grad: 102 | train.append(count_size(p.shape)) 103 | else: 104 | nontrain.append(count_size(p.shape)) 105 | for p in module.buffers(): 106 | buffer.append(count_size(p.shape)) 107 | for subm in module.children(): 108 | layer_summary(subm) 109 | 110 | layer_summary(model) 111 | total_train = sum(train) 112 | total_nontrain = sum(nontrain) 113 | total = total_train + total_nontrain 114 | strings = [] 115 | strings.append('Total params: {:,}'.format(total)) 116 | strings.append('Trainable params: {:,}'.format(total_train)) 117 | strings.append('Non-trainable params: {:,}'.format(total_nontrain)) 118 | strings.append("Buffer params: {:,}".format(sum(buffer))) 119 | max_len = len(max(strings, key=len)) 120 | bar = '-' * (max_len + 3) 121 | strings = [bar] + strings + [bar] 122 | print('\n'.join(strings)) 123 | return total, total_train, total_nontrain 124 | 125 | 126 | def get_dropout_mask(drop_p: float, tensor: torch.Tensor): 127 | r""" 128 | 根据tensor的形状,生成一个mask 129 | 130 | :param drop_p: float, 以多大的概率置为0。 131 | :param tensor: torch.Tensor 132 | :return: torch.FloatTensor. 
与tensor一样的shape 133 | """ 134 | mask_x = torch.ones_like(tensor) 135 | nn.functional.dropout(mask_x, p=drop_p, 136 | training=False, inplace=True) 137 | return mask_x 138 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.12.0 2 | aiohttp==3.7.4.post0 3 | alembic==1.5.8 4 | antlr4-python3-runtime==4.8 5 | asn1crypto==0.24.0 6 | async-timeout==3.0.1 7 | attrs==20.3.0 8 | blessings==1.7 9 | cachetools==4.2.1 10 | certifi==2020.12.5 11 | chardet==4.0.0 12 | click==7.1.2 13 | cliff==3.7.0 14 | cmaes==0.8.2 15 | cmd2==1.5.0 16 | colorama==0.4.4 17 | colorlog==4.8.0 18 | command-not-found==0.3 19 | commonmark==0.9.1 20 | configparser==5.0.2 21 | cryptography==2.1.4 22 | dataclasses==0.8 23 | distro-info==0.18ubuntu0.18.04.1 24 | docker-pycreds==0.4.0 25 | easydict==1.9 26 | filelock==3.0.12 27 | fsspec==0.8.7 28 | future==0.18.2 29 | gitdb==4.0.5 30 | gitpython==3.1.14 31 | google-auth-oauthlib==0.4.3 32 | google-auth==1.28.0 33 | gpustat==0.6.0 34 | greenlet==1.0.0 35 | grpcio==1.36.1 36 | hydra-core==1.1.0.dev4 37 | hydra-optuna-sweeper==1.1.0.dev1 38 | idna-ssl==1.1.0 39 | idna==2.10 40 | importlib-metadata==3.10.0 41 | importlib-resources==5.1.2 42 | joblib==1.0.1 43 | keyring==10.6.0 44 | keyrings.alt==3.0 45 | language-selector==0.1 46 | mako==1.1.4 47 | markdown==3.3.4 48 | markupsafe==1.1.1 49 | multidict==5.1.0 50 | netifaces==0.10.4 51 | nltk==3.5 52 | numpy==1.19.5 53 | nvidia-ml-py3==7.352.0 54 | oauthlib==3.1.0 55 | omegaconf==2.1.0.dev24 56 | opt-einsum==3.3.0 57 | optuna==2.4.0 58 | packaging==20.9 59 | pathtools==0.1.2 60 | pbr==5.5.1 61 | pillow==8.0.0 62 | pip==9.0.1 63 | prettytable==2.1.0 64 | promise==2.3 65 | protobuf==3.15.6 66 | psutil==5.7.2 67 | pyasn1-modules==0.2.8 68 | pyasn1==0.4.8 69 | pycrypto==2.6.1 70 | pygments==2.8.1 71 | pygobject==3.26.1 72 | pyparsing==2.4.7 73 | pyperclip==1.8.2 74 | python-apt==1.6.5+ubuntu0.3 75 | python-dateutil==2.8.1 76 | python-distutils-extra==2.39 77 | python-editor==1.0.4 78 | pytorch-lightning==1.2.4 79 | pyxdg==0.25 80 | pyyaml==5.4.1 81 | regex==2020.11.13 82 | requests-oauthlib==1.3.0 83 | requests==2.25.1 84 | rich==9.13.0 85 | rsa==4.7.2 86 | sacremoses==0.0.43 87 | scipy==1.5.4 88 | secretstorage==2.3.1 89 | sentry-sdk==1.0.0 90 | setuptools==54.1.2 91 | shortuuid==1.0.1 92 | six==1.15.0 93 | smmap==3.0.5 94 | sqlalchemy==1.4.6 95 | ssh-import-id==5.7 96 | stevedore==3.3.0 97 | subprocess32==3.5.4 98 | tensorboard-plugin-wit==1.8.0 99 | tensorboard==2.4.1 100 | tokenizers==0.10.1 101 | torch-struct==0.5 102 | torch==1.7.0 103 | torchaudio==0.7.0 104 | torchmetrics==0.2.0 105 | torchvision==0.8.0 106 | tqdm==4.60.0 107 | transformers==4.3.3 108 | typing-extensions==3.7.4.3 109 | ufw==0.36 110 | urllib3==1.26.4 111 | wandb==0.10.21 112 | wcwidth==0.2.5 113 | werkzeug==1.0.1 114 | wheel==0.36.2 115 | yarl==1.6.3 116 | zipp==3.4.1 117 | -------------------------------------------------------------------------------- /src/callbacks/progressbar.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint, ProgressBar 2 | from tqdm import tqdm 3 | 4 | class PrettyProgressBar(ProgressBar): 5 | """Good print wrapper.""" 6 | def __init__(self, refresh_rate: int, process_position: int): 7 | super().__init__(refresh_rate=refresh_rate, process_position=process_position) 8 | 9 | 
def init_sanity_tqdm(self) -> tqdm: 10 | bar = tqdm(desc='Validation sanity check', 11 | position=self.process_position, 12 | disable=self.is_disabled, 13 | leave=False, 14 | ncols=120, 15 | ascii=True) 16 | return bar 17 | 18 | def init_train_tqdm(self) -> tqdm: 19 | bar = tqdm(desc='Training', 20 | initial=self.train_batch_idx, 21 | position=self.process_position, 22 | disable=self.is_disabled, 23 | leave=True, 24 | smoothing=0, 25 | ncols=120, 26 | ascii=True) 27 | return bar 28 | 29 | 30 | 31 | def init_validation_tqdm(self) -> tqdm: 32 | bar = tqdm(disable=True) 33 | return bar 34 | 35 | def on_epoch_start(self, trainer, pl_module): 36 | super().on_epoch_start(trainer, pl_module) 37 | self.main_progress_bar.set_description(f'E{trainer.current_epoch}|train') 38 | 39 | def on_validation_start(self, trainer, pl_module): 40 | super().on_validation_start(trainer, pl_module) 41 | self.main_progress_bar.set_description(f'E{trainer.current_epoch}|val') 42 | 43 | def on_test_start(self, trainer, pl_module): 44 | super().on_test_start(trainer, pl_module) 45 | self.main_progress_bar.set_description(f'E{trainer.current_epoch}|test') 46 | 47 | -------------------------------------------------------------------------------- /src/callbacks/transformer_scheduler.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from transformers import get_linear_schedule_with_warmup, get_constant_schedule_with_warmup 3 | import logging 4 | log = logging.getLogger(__name__) 5 | 6 | class TransformerLrScheduler(pl.Callback): 7 | def __init__(self, warmup): 8 | self.warmup = warmup 9 | 10 | def on_train_start(self, trainer, pl_module): 11 | for lr_scheduler in trainer.lr_schedulers: 12 | scheduler = lr_scheduler['scheduler'] 13 | n_train = len(pl_module.train_dataloader()) 14 | n_accumulate_grad = trainer.accumulate_grad_batches 15 | n_max_epochs = trainer.max_epochs 16 | num_training_steps = n_train // n_accumulate_grad * n_max_epochs 17 | num_warmup_steps = int(self.warmup * num_training_steps) 18 | 19 | if pl_module.optimizer_cfg.scheduler_type == 'linear_warmup': 20 | lr_scheduler['scheduler'] = get_linear_schedule_with_warmup(scheduler.optimizer, num_warmup_steps, num_training_steps) 21 | 22 | elif pl_module.optimizer_cfg.scheduler_type == 'constant_warmup': 23 | lr_scheduler['scheduler'] = get_constant_schedule_with_warmup(scheduler.optimizer, num_warmup_steps, 24 | ) 25 | 26 | log.info(f"Warm up rate:{self.warmup}") 27 | log.info(f"total number of training step:{num_training_steps}") 28 | log.info(f"number of training batches per epochs in the dataloader:{n_train}") -------------------------------------------------------------------------------- /src/callbacks/wandb_callbacks.py: -------------------------------------------------------------------------------- 1 | # wandb 2 | from pytorch_lightning.loggers import WandbLogger 3 | import wandb 4 | 5 | # pytorch 6 | from pytorch_lightning import Callback 7 | import pytorch_lightning as pl 8 | import torch 9 | 10 | # others 11 | import glob 12 | import os 13 | 14 | 15 | def get_wandb_logger(trainer: pl.Trainer) -> WandbLogger: 16 | logger = None 17 | for lg in trainer.logger: 18 | if isinstance(lg, WandbLogger): 19 | logger = lg 20 | 21 | if not logger: 22 | raise Exception( 23 | "You are using wandb related callback," 24 | "but WandbLogger was not found for some reason..." 
25 | ) 26 | 27 | return logger 28 | 29 | 30 | # class UploadCodeToWandbAsArtifact(Callback): 31 | # """Upload all *.py files to wandb as an artifact at the beginning of the run.""" 32 | # 33 | # def __init__(self, code_dir: str): 34 | # self.code_dir = code_dir 35 | # 36 | # def on_train_start(self, trainer, pl_module): 37 | # logger = get_wandb_logger(trainer=trainer) 38 | # experiment = logger.experiment 39 | # 40 | # code = wandb.Artifact("project-source", type="code") 41 | # for path in glob.glob(os.path.join(self.code_dir, "**/*.py"), recursive=True): 42 | # print(path) 43 | # code.add_file(path) 44 | # print('ok') 45 | # 46 | # 47 | # 48 | # experiment.use_artifact(code) 49 | # print("successfully update the code .") 50 | 51 | 52 | class UploadHydraConfigFileToWandb(Callback): 53 | def on_fit_start(self, trainer, pl_module: LightningModule) -> None: 54 | logger = get_wandb_logger(trainer=trainer) 55 | 56 | logger.experiment.save() 57 | 58 | 59 | 60 | class UploadCheckpointsToWandbAsArtifact(Callback): 61 | """Upload experiment checkpoints to wandb as an artifact at the end of training.""" 62 | 63 | def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False): 64 | self.ckpt_dir = ckpt_dir 65 | self.upload_best_only = upload_best_only 66 | 67 | def on_train_end(self, trainer, pl_module): 68 | logger = get_wandb_logger(trainer=trainer) 69 | experiment = logger.experiment 70 | 71 | ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") 72 | 73 | if self.upload_best_only: 74 | ckpts.add_file(trainer.checkpoint_callback.best_model_path) 75 | else: 76 | for path in glob.glob( 77 | os.path.join(self.ckpt_dir, "**/*.ckpt"), recursive=True 78 | ): 79 | ckpts.add_file(path) 80 | 81 | experiment.use_artifact(ckpts) 82 | 83 | 84 | class WatchModelWithWandb(Callback): 85 | """Make WandbLogger watch model at the beginning of the run.""" 86 | 87 | def __init__(self, log: str = "gradients", log_freq: int = 100): 88 | self.log = log 89 | self.log_freq = log_freq 90 | 91 | def on_train_start(self, trainer, pl_module): 92 | logger = get_wandb_logger(trainer=trainer) 93 | logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) 94 | 95 | -------------------------------------------------------------------------------- /src/constant.py: -------------------------------------------------------------------------------- 1 | pad = '' 2 | unk = '' 3 | bos = '' 4 | eos = '' 5 | arc = 'arc' 6 | rel = 'rel' 7 | word = 'word' 8 | chart = 'chart' 9 | raw_tree = 'raw_tree' 10 | valid = 'valid' 11 | sib = 'sibling' 12 | relatives = 'relatives' 13 | span_head = 'span_head' 14 | 15 | -------------------------------------------------------------------------------- /src/datamodule/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/datamodule/dm_util/datamodule_util.py: -------------------------------------------------------------------------------- 1 | from supar.utils.alg import kmeans 2 | from supar.utils.data import Sampler 3 | from fastNLP.core.field import Padder 4 | import numpy as np 5 | 6 | 7 | 8 | 9 | def get_sampler(lengths, max_tokens, n_buckets, shuffle=True, distributed=False, evaluate=False): 10 | buckets = dict(zip(*kmeans(lengths, n_buckets))) 11 | return Sampler(buckets=buckets, 12 | batch_size=max_tokens, 13 | shuffle=shuffle, 14 | distributed=distributed, 15 | evaluate=evaluate) 16 | 17 | 18 | 19 | class 
SiblingPadder(Padder): 20 | def __call__(self, contents, field_name, field_ele_dtype, dim: int): 21 | # max_sent_length = max(rule.shape[0] for rule in contents) 22 | batch_size = sum(len(r) for r in contents) 23 | padded_array = np.full((batch_size, 4), fill_value=self.pad_val, 24 | dtype=np.long) 25 | 26 | i = 0 27 | for b_idx, relations in enumerate(contents): 28 | for (head, child, sibling, _) in relations: 29 | padded_array[i] = np.array([b_idx, head, child, sibling]) 30 | i+=1 31 | 32 | return padded_array 33 | 34 | 35 | 36 | 37 | class ConstAsDepPadder(Padder): 38 | def __call__(self, contents, field_name, field_ele_dtype, dim: int): 39 | # max_sent_length = max(rule.shape[0] for rule in contents) 40 | batch_size = sum(len(r) for r in contents) 41 | padded_array = np.full((batch_size, 4), fill_value=self.pad_val, 42 | dtype=np.long) 43 | i = 0 44 | for b_idx, relations in enumerate(contents): 45 | for (head, child, sibling, *_) in relations: 46 | padded_array[i] = np.array([b_idx, head, child, sibling]) 47 | i+=1 48 | return padded_array 49 | 50 | class GrandPadder(Padder): 51 | def __call__(self, contents, field_name, field_ele_dtype, dim: int): 52 | batch_size = sum(len(r) for r in contents) 53 | padded_array = np.full((batch_size, 4), fill_value=self.pad_val, 54 | dtype=np.float) 55 | i = 0 56 | for b_idx, relations in enumerate(contents): 57 | for (head, child, _, grandparent, *_) in relations: 58 | padded_array[i] = np.array([b_idx, head, child, grandparent]) 59 | i += 1 60 | 61 | return padded_array 62 | 63 | class SpanPadderCAD(Padder): 64 | def __call__(self, contents, field_name, field_ele_dtype, dim: int): 65 | batch_size = sum((len(r) * 2 - 1) for r in contents) 66 | padded_array = np.full((batch_size, 6), fill_value=self.pad_val, 67 | dtype=np.float) 68 | i = 0 69 | 70 | ## 0 stands for inherent 71 | ## 1 stands for noninherent 72 | for b_idx, span in enumerate(contents): 73 | for (head, child, ih_start, ih_end, ni_start, ni_end) in span: 74 | if not ih_start == -1: 75 | padded_array[i] = np.array([b_idx, head, child, ih_start, ih_end, 0]) 76 | i += 1 77 | padded_array[i] = np.array([b_idx, head, child, ni_start, ni_end, 1]) 78 | i += 1 79 | 80 | assert i == batch_size 81 | return padded_array 82 | 83 | 84 | 85 | 86 | class SpanHeadPadder(Padder): 87 | def __call__(self, contents, field_name, field_ele_dtype, dim: int): 88 | batch_size = sum(len(r) for r in contents) 89 | padded_array = np.full((batch_size, 4), fill_value=self.pad_val, 90 | dtype=np.float) 91 | i = 0 92 | for b_idx, relations in enumerate(contents): 93 | for (left, right, head) in relations: 94 | padded_array[i] = np.array([b_idx, left, right, head]) 95 | i += 1 96 | return padded_array 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /src/datamodule/dm_util/padder.py: -------------------------------------------------------------------------------- 1 | from fastNLP.core.field import Padder 2 | import numpy as np 3 | 4 | 5 | def set_padder(datasets, name, padder): 6 | for _, dataset in datasets.items(): 7 | dataset.add_field(name, dataset[name].content, padder=padder, ignore_type=True) 8 | 9 | 10 | class DepSibPadder(Padder): 11 | def __call__(self, contents, field_name, field_ele_dtype, dim: int): 12 | # max_sent_length = max(rule.shape[0] for rule in contents) 13 | padded_array = [] 14 | # dependency head or relations. 
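# Note on the expected layout (wording added for clarity, not from the original
# authors): `contents` holds one list per sentence, indexed by 0-based child
# position, and each entry is a (head_index, sibling_index) pair. Children with
# sibling index -1 are skipped; every kept child becomes one row of
# [batch_idx, child_idx + 1, head_idx, sibling_idx]. For example,
# contents = [[(2, -1), (0, -1), (2, 1)]] yields the single row [0, 3, 2, 1].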
15 | for b_idx, dep in enumerate(contents): 16 | for (child_idx, (head_idx, sib_ix)) in enumerate(dep): 17 | # -1 means no sib; 18 | if sib_ix != -1: 19 | padded_array.append([b_idx, child_idx + 1, head_idx, sib_ix]) 20 | else: 21 | pass 22 | return np.array(padded_array) 23 | 24 | 25 | 26 | 27 | 28 | class DepPadder(Padder): 29 | def __call__(self, contents, field_name, field_ele_dtype, dim: int): 30 | # max_sent_length = max(rule.shape[0] for rule in contents) 31 | padded_array = [] 32 | # dependency head or relations. 33 | for b_idx, dep in enumerate(contents): 34 | for (child_idx, dep_idx) in enumerate(dep): 35 | padded_array.append([b_idx, child_idx + 1, dep_idx]) 36 | return np.array(padded_array) 37 | 38 | 39 | 40 | 41 | class SpanHeadWordPadder(Padder): 42 | def __call__(self, contents, field_name, field_ele_dtype, dim: int): 43 | padded_array = [] 44 | for b_idx, relations in enumerate(contents): 45 | for (left, right, _, head) in relations: 46 | padded_array.append([b_idx, left, right, head]) 47 | return np.array(padded_array) 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /src/datamodule/dm_util/util.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | 3 | punct_set = '.' '``' "''" ':' ',' 4 | import re 5 | 6 | 7 | # https://github.com/DoodleJZ/HPSG-Neural-Parser/blob/cdcffa78945359e14063326cadd93fd4c509c585/src_joint/dep_eval.py 8 | def is_uni_punctuation(word): 9 | match = re.match("^[^\w\s]+$]", word, flags=re.UNICODE) 10 | return match is not None 11 | 12 | def is_punctuation(word, pos, punct_set=punct_set): 13 | if punct_set is None: 14 | return is_uni_punctuation(word) 15 | else: 16 | return pos in punct_set or pos == 'PU' # for chinese 17 | 18 | def get_path(path): 19 | return path 20 | 21 | def get_path_debug(path): 22 | return path + ".debug" 23 | 24 | def clean_number(w): 25 | new_w = re.sub('[0-9]{1,}([,.]?[0-9]*)*', '0', w) 26 | return new_w 27 | 28 | def clean_word(words): 29 | PTB_UNESCAPE_MAPPING = { 30 | "«": '"', 31 | "»": '"', 32 | "‘": "'", 33 | "’": "'", 34 | "“": '"', 35 | "”": '"', 36 | "„": '"', 37 | "‹": "'", 38 | "›": "'", 39 | "\u2013": "--", # en dash 40 | "\u2014": "--", # em dash 41 | } 42 | cleaned_words = [] 43 | for word in words: 44 | word = PTB_UNESCAPE_MAPPING.get(word, word) 45 | word = word.replace("\\/", "/").replace("\\*", "*") 46 | # Mid-token punctuation occurs in biomedical text 47 | word = word.replace("-LSB-", "[").replace("-RSB-", "]") 48 | word = word.replace("-LRB-", "(").replace("-RRB-", ")") 49 | word = word.replace("-LCB-", "{").replace("-RCB-", "}") 50 | word = word.replace("``", '"').replace("`", "'").replace("''", '"') 51 | word = clean_number(word) 52 | cleaned_words.append(word) 53 | return cleaned_words 54 | 55 | 56 | 57 | 58 | def find_dep_boundary(heads): 59 | left_bd = [i for i in range(len(heads))] 60 | right_bd = [i + 1 for i in range(len(heads))] 61 | 62 | for child_idx, head_idx in enumerate(heads): 63 | if head_idx > 0: 64 | if left_bd[child_idx] < left_bd[head_idx - 1]: 65 | left_bd[head_idx - 1] = left_bd[child_idx] 66 | 67 | elif child_idx > right_bd[head_idx - 1] - 1: 68 | right_bd[head_idx - 1] = child_idx + 1 69 | while head_idx != 0: 70 | if heads[head_idx-1] > 0 and child_idx + 1 > right_bd[ heads[head_idx-1] - 1] : 71 | right_bd[ heads[head_idx-1] - 1] = child_idx + 1 72 | head_idx = heads[head_idx-1] 73 | else: 74 | break 75 | 76 | # (head_word_idx, left_bd_idx, right_bd_idx) 77 | 
triplet = [] 78 | # head index should add1, as the root token would be the first token. 79 | # [ ) left bdr, right bdr. 80 | # for i in range(len(heads)): 81 | # what do I want? 82 | # 生成整个span的score???????? 83 | # seems ok.s 84 | 85 | # left boundary, right boundary, parent, head 86 | for i, (parent, left_bdr, right_bdr) in enumerate(zip(heads, left_bd, right_bd)): 87 | triplet.append([left_bdr, right_bdr, parent-1, i]) 88 | 89 | 90 | 91 | 92 | 93 | return triplet 94 | 95 | 96 | 97 | 98 | 99 | def isProjective(heads): 100 | pairs = [(h, d) for d, h in enumerate(heads, 1) if h >= 0] 101 | for i, (hi, di) in enumerate(pairs): 102 | for hj, dj in pairs[i+1:]: 103 | (li, ri), (lj, rj) = sorted([hi, di]), sorted([hj, dj]) 104 | if li <= hj <= ri and hi == dj: 105 | return False 106 | if lj <= hi <= rj and hj == di: 107 | return False 108 | if (li < lj < ri or li < rj < ri) and (li - lj)*(ri - rj) > 0: 109 | return False 110 | return True 111 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /src/inside/__init__.py: -------------------------------------------------------------------------------- 1 | from .eisner_satta import es4dep 2 | from .eisner import eisner, eisner_headsplit 3 | from .eisner2o import eisner2o, eisner2o_headsplit 4 | from .span import span_inside 5 | 6 | __all__ = ['eisner_satta', 7 | 'eisner', 8 | 'eisner2o', 9 | 'es4dep', 10 | 'eisner_headsplit', 11 | 'eisner2o_headsplit', 12 | 'span_inside' 13 | ] 14 | 15 | -------------------------------------------------------------------------------- /src/inside/eisner_satta.py: -------------------------------------------------------------------------------- 1 | from .fn import * 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | from ..loss.get_score import augment_score 5 | 6 | 7 | 8 | # O(n4) span+arc 9 | @torch.enable_grad() 10 | def es4dep(ctx, decode=False, max_margin=False): 11 | if max_margin: 12 | augment_score(ctx) 13 | 14 | lens = ctx['seq_len'] 15 | 16 | dependency = ctx['s_arc'] 17 | 18 | 19 | B, seq_len = dependency.shape[:2] 20 | head_score = ctx['s_span_head_word'] 21 | 22 | if decode: 23 | dependency = dependency.detach().clone().requires_grad_(True) 24 | 25 | if max_margin: 26 | dependency = dependency.detach().clone().requires_grad_(True) 27 | head_score = head_score.detach().clone().requires_grad_(True) 28 | 29 | dep = dependency[:, 1:, 1:].contiguous() 30 | root = dependency[:, 1:, 0].contiguous() 31 | 32 | viterbi = decode or max_margin 33 | 34 | N = seq_len 35 | H = N - 1 36 | 37 | s = dependency.new_zeros(B, N, N, H).fill_(-1e9) 38 | s_close = dependency.new_zeros(B, N, N, H).fill_(-1e9) 39 | s_need_dad = dependency.new_zeros(B, N, N, H).fill_(-1e9) 40 | 41 | s_close[:, torch.arange(N - 1), torch.arange(N - 1) + 1, torch.arange(N - 1)] = head_score[:, torch.arange(N - 1), torch.arange(N - 1) + 1, torch.arange(N - 1)] 42 | 43 | s[:, torch.arange(N - 1), torch.arange(N - 1) + 1, torch.arange(N - 1)] = 0 44 | 45 | s_need_dad[:, torch.arange(N - 1), torch.arange(N - 1) + 1, :] = dep[:, torch.arange(N - 1)] + s_close[:, torch.arange(N - 1), torch.arange(N - 1) + 1, torch.arange(N - 1)].unsqueeze(-1) 46 | 47 | def merge(left, right, left_need_dad, right_need_dad): 48 | left = (left + right_need_dad) 49 | right = (right + left_need_dad) 50 | if viterbi: 51 | headed = torch.stack([left.max(2)[0], right.max(2)[0]]) 52 | return headed.max(0)[0] 53 | else: 54 | headed = torch.stack([left.logsumexp(2), right.logsumexp(2)]) 55 
| return headed.logsumexp(0) 56 | 57 | def seek_head(a, b): 58 | if viterbi: 59 | tmp = (a + b).max(-2)[0] 60 | else: 61 | tmp = (a + b).logsumexp(-2) 62 | return tmp 63 | 64 | 65 | for w in range(2, N): 66 | n = N - w 67 | left = stripe_version2(s, n, w - 1, (0, 1)) 68 | right = stripe_version2(s, n, w - 1, (1, w), 0) 69 | left_need_dad = stripe_version2(s_need_dad, n, w - 1, (0, 1)) 70 | right_need_dad = stripe_version2(s_need_dad, n, w - 1, (1, w), 0) 71 | headed = checkpoint(merge, left.clone(), right.clone(), left_need_dad.clone(),right_need_dad.clone()) 72 | diagonal_copy_v2(s, headed, w) 73 | headed = headed + diagonal_v2(head_score, w) 74 | diagonal_copy_v2(s_close, headed, w) 75 | 76 | if w < N - 1: 77 | u = checkpoint(seek_head, headed.unsqueeze(-1), stripe_version5(dep, n, w)) 78 | diagonal_copy_(s_need_dad, u, w) 79 | 80 | logZ = (s_close[torch.arange(B), 0, lens] + root) 81 | if viterbi: 82 | logZ = logZ.max(-1)[0] 83 | else: 84 | logZ = logZ.logsumexp(-1) 85 | 86 | #crf loss 87 | if not decode and not max_margin: 88 | return logZ 89 | 90 | logZ.sum().backward() 91 | 92 | if decode: 93 | predicted_arc = s.new_zeros(B, seq_len).long() 94 | arc = dependency.grad.nonzero() 95 | predicted_arc[arc[:, 0], arc[:, 1]] = arc[:, 2] 96 | ctx['arc_pred'] = predicted_arc 97 | 98 | if max_margin: 99 | ctx['s_arc_grad'] = dependency.grad 100 | if head_score is not None: 101 | ctx['s_span_head_word_grad'] = head_score.grad 102 | -------------------------------------------------------------------------------- /src/inside/span.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from src.inside.fn import * 3 | from src.inside.eisner_satta import es4dep 4 | from ..loss.get_score import augment_score 5 | 6 | # O(n3) algorithm, only use span core, do not use arc score for projective dependency parsing. 7 | @torch.enable_grad() 8 | def span_inside(ctx, decode=False, max_margin=False): 9 | 10 | assert decode or max_margin 11 | 12 | if max_margin: 13 | augment_score(ctx) 14 | 15 | s_span_score = ctx['s_span_head_word'] 16 | lens = ctx['seq_len'] 17 | 18 | s_span_score = s_span_score.detach().clone().requires_grad_(True) 19 | 20 | b, seq_len = s_span_score.shape[:2] 21 | 22 | s_inside_children = s_span_score.new_zeros(b, seq_len, seq_len).fill_(-1e9) 23 | s_inside_close = s_span_score.new_zeros(b, seq_len, seq_len).fill_(-1e9) 24 | 25 | # do I need s_close? it seems that i do not need this term? right 26 | s_inside_children[:, torch.arange(seq_len-1), torch.arange(seq_len-1)+1] = s_span_score[:, torch.arange(seq_len-1), torch.arange(seq_len-1)+1, torch.arange(seq_len-1)] 27 | 28 | for w in range(2, seq_len): 29 | n = seq_len - w 30 | 31 | # two child compose together. 32 | left = stripe(s_inside_children, n, w - 1, (0, 1)) 33 | right = stripe(s_inside_children, n, w - 1, (1, w), 0) 34 | compose = (left + right).max(2)[0] 35 | 36 | # case 1: the head word is right-most 37 | l = left[:, :, -1] 38 | compose_score1 = l + s_span_score[:, torch.arange(n), torch.arange(n)+w, torch.arange(n)+w-1] 39 | 40 | # case 2: the head word is left-most. 
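# Elaboration (added comment, consistent with the code below): when the head
# word of span (i, i + w) is its left-most token i, the remaining words form a
# single children block covering (i + 1, i + w); right[:, :, 0] is that block's
# inside score, and the span-head score s_span_score[:, i, i + w, i] is added
# on top, mirroring case 1 above where the head is the right-most token.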
41 | r = right[:, :, 0] 42 | compose_score2 = r + s_span_score[:, torch.arange(n), torch.arange(n)+w, torch.arange(n)] 43 | 44 | if w > 2: 45 | left = stripe(s_inside_children, n, w - 2, (0, 1)) 46 | right = stripe(s_inside_children, n, w - 2, (2, w), 0) 47 | compose_score3 = left + right + diagonal_v2(s_span_score, w)[:, :, 1:-1] 48 | compose_score = torch.cat([compose_score1.unsqueeze(2), compose_score3, compose_score2.unsqueeze(2)], dim=2) 49 | compose_score = compose_score.max(2)[0] 50 | 51 | else: 52 | compose_score = torch.cat([compose_score1.unsqueeze(2), compose_score2.unsqueeze(2)], dim=2) 53 | compose_score = compose_score.max(2)[0] 54 | 55 | compose = torch.stack([compose, compose_score]).max(0)[0] 56 | diagonal_copy_(s_inside_children, compose, w) 57 | diagonal_copy_(s_inside_close, compose_score, w) 58 | 59 | try: 60 | s_inside_close[torch.arange(b), 0, lens].sum().backward() 61 | except: 62 | pass 63 | 64 | if decode: 65 | ctx['arc_pred'] = recover_arc(s_span_score.grad, lens) 66 | 67 | if max_margin: 68 | ctx['s_span_head_word_grad'] = s_span_score.grad 69 | 70 | 71 | 72 | 73 | # heads: (left, right, head word). 74 | # from decoded span to recover arcs. 75 | def recover_arc(heads, lens): 76 | if heads is None: 77 | return lens.new_zeros(lens.shape[0], lens.max() + 1) 78 | result = np.zeros(shape=(heads.shape[0], heads.shape[1])) 79 | lens = lens.detach().cpu().numpy() 80 | for i in range(heads.shape[0]): 81 | if lens[i] == 1: 82 | result[i][1] = 0 83 | else: 84 | span = heads[i].detach().nonzero().cpu().numpy() 85 | start = span[:,0] 86 | end = span[:,1] 87 | preorder_sort = np.lexsort((-end, start)) 88 | start = start[preorder_sort] 89 | end = end[preorder_sort] 90 | head = span[:,2][preorder_sort] 91 | stack = [] 92 | arcs = [] 93 | stack.append((start[0], end[0], 0)) 94 | result[i][head[0]+1] = 0 95 | j = 0 96 | while j < start.shape[0]-1: 97 | j+=1 98 | s = start[j] 99 | e = end[j] 100 | top = stack[-1] 101 | top_s, top_e, top_i = top 102 | if top_s <= s and top_e >= e: 103 | arcs.append((head[top_i] + 1, head[j] + 1)) 104 | result[i][head[j] + 1] = head[top_i] + 1 105 | stack.append((s, e, j)) 106 | else: 107 | while top_s > s or top_e < e: 108 | stack.pop() 109 | top = stack[-1] 110 | top_s, top_e, top_i = top 111 | arcs.append([head[top_i] + 1, head[j] + 1]) 112 | result[i][head[j] + 1] = head[top_i] + 1 113 | stack.append((s, e, j)) 114 | return torch.tensor(result, device=heads.device, dtype=torch.long) 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | -------------------------------------------------------------------------------- /src/loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sustcsonglin/span-based-dependency-parsing/db7b94a06a8da3d3055e5baf50fe2fe2f10e58d3/src/loss/__init__.py -------------------------------------------------------------------------------- /src/loss/dep_loss.py: -------------------------------------------------------------------------------- 1 | from ..inside import * 2 | from supar.utils.common import * 3 | from .get_score import * 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | class DepLoss(): 8 | def __init__(self, conf): 9 | self.conf = conf 10 | self.inside = eisner 11 | 12 | 13 | @classmethod 14 | def label_loss(self, ctx, reduction='mean'): 15 | s_rel = ctx['s_rel'] 16 | gold_rel = ctx['rel'] 17 | gold_arc = ctx['head'] 18 | if len(s_rel.shape) == 4: 19 | return F.cross_entropy(s_rel[gold_arc[:, 0], gold_arc[:, 
1], gold_arc[:, 2]] , torch.as_tensor(gold_rel[:, -1], device=s_rel.device, dtype=torch.long),reduction=reduction) 20 | elif len(s_rel.shape) == 3: 21 | return F.cross_entropy(s_rel[gold_arc[:, 0], gold_arc[:, 1]], torch.as_tensor(gold_rel[:, -1], device=s_rel.device, dtype=torch.long),reduction=reduction) 22 | else: 23 | raise AssertionError 24 | 25 | @classmethod 26 | def get_pred_rels(self,ctx): 27 | arc_preds = ctx['arc_pred'] 28 | s_rel = ctx['s_rel'] 29 | ctx['rel_pred'] = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1) 30 | 31 | # for evaluation. 32 | @classmethod 33 | def _transform(self, ctx): 34 | arc_preds, arc_golds = ctx['arc_pred'], ctx['head'] 35 | rel_preds, rel_golds = ctx['rel_pred'], ctx['rel'] 36 | 37 | arc_golds = torch.as_tensor(arc_golds, device=arc_preds.device, dtype=arc_preds.dtype) 38 | rel_golds = torch.as_tensor(rel_golds, device=rel_preds.device, dtype=rel_preds.dtype) 39 | 40 | arc_gold = arc_preds.new_zeros(*arc_preds.shape).fill_(-1) 41 | arc_gold[arc_golds[:, 0], arc_golds[:, 1]] = arc_golds[:, 2] 42 | 43 | rel_gold = rel_preds.new_zeros(*rel_preds.shape).fill_(-1) 44 | rel_gold[rel_golds[:, 0], rel_golds[:, 1]] = rel_golds[:, 2] 45 | mask_dep = arc_gold.ne(-1) 46 | 47 | #ignore punct. 48 | if 'is_punct' in ctx: 49 | mask_punct = ctx['is_punct'].nonzero() 50 | mask_dep[mask_punct[:, 0], mask_punct[:, 1] + 1] = False 51 | ctx['arc_gold'] = arc_gold 52 | ctx['rel_gold'] = rel_gold 53 | ctx['mask_dep'] = mask_dep 54 | 55 | def max_margin_loss(self, ctx): 56 | with torch.no_grad(): 57 | self.inside(ctx, max_margin=True) 58 | gold_score = u_score(ctx) 59 | predict_score = predict_score_mm(ctx) 60 | return (predict_score - gold_score)/ctx['seq_len'].sum() 61 | 62 | def crf_loss(self, ctx): 63 | logZ = self.inside(ctx).sum() 64 | gold_score = u_score(ctx) 65 | return (logZ - gold_score)/ ctx['seq_len'].sum() 66 | 67 | def local_loss(self, ctx): 68 | raise NotImplementedError 69 | 70 | def loss(self, ctx): 71 | if self.conf.loss_type == 'mm': 72 | tree_loss = self.max_margin_loss(ctx) 73 | elif self.conf.loss_type == 'crf': 74 | tree_loss = self.crf_loss(ctx) 75 | elif self.conf.loss_type == 'local': 76 | tree_loss = self.local_loss(ctx) 77 | label_loss = self.label_loss(ctx) 78 | return tree_loss + label_loss 79 | 80 | def decode(self, ctx): 81 | self.inside(ctx, decode=True) 82 | self.get_pred_rels(ctx) 83 | self._transform(ctx) 84 | 85 | 86 | class Dep1OSpan(DepLoss): 87 | def __init__(self, conf): 88 | super(Dep1OSpan, self).__init__(conf) 89 | self.inside = es4dep 90 | 91 | class Dep2O(DepLoss): 92 | def __init__(self, conf): 93 | super(Dep2O, self).__init__(conf) 94 | self.inside = eisner2o 95 | 96 | class Dep2OHeadSplit(DepLoss): 97 | def __init__(self, conf): 98 | super(Dep2OHeadSplit, self).__init__(conf) 99 | self.inside = eisner2o_headsplit 100 | 101 | class Dep1OHeadSplit(DepLoss): 102 | def __init__(self, conf): 103 | super(Dep1OHeadSplit, self).__init__(conf) 104 | self.inside = eisner_headsplit 105 | 106 | class Span(DepLoss): 107 | def __init__(self, conf): 108 | super(Span, self).__init__(conf) 109 | self.inside = span_inside 110 | 111 | 112 | -------------------------------------------------------------------------------- /src/loss/get_score.py: -------------------------------------------------------------------------------- 1 | # obtain the score of unlabeled trees. 
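# Convention used throughout this module (summary added for clarity; the
# wording is ours, not the original authors'): `ctx` is a dict carrying
#   - score tensors under 's_*' keys, e.g. ctx['s_arc'] of shape
#     [batch, child_position, head_position];
#   - gold index tensors such as ctx['head'], whose rows are
#     [batch_idx, child_position, head_position] (see the padders in
#     src/datamodule);
#   - after an inside algorithm is run with max_margin=True, gradient tensors
#     under 's_*_grad' that mark the highest-scoring structure.
# u_score() sums the scores of the gold structure, predict_score_mm() sums the
# scores selected by those gradients, and augment_score() subtracts a margin at
# gold positions, i.e. cost-augmented decoding for the max-margin loss.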
2 | def u_score(ctx): 3 | score = 0 4 | 5 | if 's_arc' in ctx: 6 | score += get_arc_score(ctx) 7 | 8 | if 's_span_head_word' in ctx: 9 | score += get_span_head_word_score(ctx) 10 | 11 | if 's_bd' in ctx: 12 | score += get_bd_score(ctx) 13 | 14 | if 's_bd_left' in ctx: 15 | assert 's_bd_right' in ctx 16 | score += get_bd_left_score(ctx) 17 | score += get_bd_right_score(ctx) 18 | 19 | if 's_sib' in ctx: 20 | score += get_sib_score(ctx) 21 | 22 | return score 23 | 24 | def predict_score_mm(ctx): 25 | score = 0 26 | if 's_arc' in ctx : 27 | score += (ctx['s_arc'] * ctx['s_arc_grad']).sum() 28 | 29 | if 's_span_head_word' in ctx: 30 | score += (ctx['s_span_head_word'] * ctx['s_span_head_word_grad']).sum() 31 | 32 | if 's_bd' in ctx: 33 | score += (ctx['s_bd'] * ctx['s_bd_grad']).sum() 34 | 35 | if 's_bd_left' in ctx: 36 | score += (ctx['s_bd_left'] * ctx['s_bd_left_grad']).sum() 37 | 38 | if 's_bd_right' in ctx: 39 | score += (ctx['s_bd_right'] * ctx['s_bd_right_grad']).sum() 40 | 41 | if 's_sib' in ctx: 42 | try: 43 | score += (ctx['s_sib'] * ctx['s_sib_grad']).sum() 44 | except: 45 | # corner case: (e.g. sentences of length 2, no siblings) 46 | pass 47 | 48 | return score 49 | 50 | 51 | def augment_score(ctx): 52 | if 's_arc' in ctx: 53 | aug_arc_score(ctx) 54 | 55 | if 's_span_head_word' in ctx: 56 | aug_span_head_word_score(ctx) 57 | 58 | if 's_bd' in ctx: 59 | aug_bd_score(ctx) 60 | 61 | if 's_bd_left' in ctx: 62 | aug_bd_left_score(ctx) 63 | 64 | if 's_bd_right' in ctx: 65 | aug_bd_right_score(ctx) 66 | 67 | if 's_sib' in ctx: 68 | aug_sib_score(ctx) 69 | 70 | 71 | 72 | def get_arc_score(ctx): 73 | s_arc = ctx['s_arc'] 74 | arcs = ctx['head'] 75 | return s_arc[arcs[:, 0], arcs[:, 1], arcs[:, 2]].sum() 76 | 77 | 78 | 79 | def aug_arc_score(ctx): 80 | s_arc = ctx['s_arc'] 81 | arcs = ctx['head'] 82 | s_arc[arcs[:, 0], arcs[:, 1], arcs[:, 2]] -= 1 83 | 84 | 85 | 86 | def get_sib_score(ctx): 87 | s_sib = ctx['s_sib'] 88 | sib = ctx['sib'] 89 | try: 90 | return s_sib[sib[:, 0], sib[:, 1], sib[:, 2], sib[:, 3]].sum() 91 | except: 92 | return 0 93 | 94 | def aug_sib_score(ctx): 95 | s_sib = ctx['s_sib'] 96 | sib = ctx['sib'] 97 | try: 98 | s_sib[sib[:, 0], sib[:, 1], sib[:, 2], sib[:, 3]] -= 1 99 | except: 100 | pass 101 | 102 | 103 | def get_span_head_word_score(ctx): 104 | span_head_word = ctx['span_head_word'] 105 | s_span_head_word = ctx['s_span_head_word'] 106 | score = s_span_head_word[span_head_word[:, 0], span_head_word[:, 1], span_head_word[:, 2], span_head_word[:, 3]].sum() 107 | return score 108 | 109 | def aug_span_head_word_score(ctx): 110 | span_head_word = ctx['span_head_word'] 111 | s_span_head_word = ctx['s_span_head_word'] 112 | s_span_head_word[span_head_word[:, 0], span_head_word[:, 1], span_head_word[:, 2], span_head_word[:, 3]] -= 1.5 113 | 114 | 115 | def get_bd_score(ctx): 116 | s_bd = ctx['s_bd'] 117 | span_head = ctx['span_head_word'] 118 | score = 0 119 | score += s_bd[span_head[:, 0], span_head[:, 3], span_head[:, 1]].sum() 120 | score += s_bd[span_head[:, 0], span_head[:, 3], span_head[:, 2]].sum() 121 | return score 122 | 123 | def aug_bd_score(ctx): 124 | s_bd = ctx['s_bd'] 125 | span_head = ctx['span_head_word'] 126 | s_bd[span_head[:, 0], span_head[:, 3], span_head[:,1]] -= 1 127 | s_bd[span_head[:, 0], span_head[:, 3], span_head[:,2]] -= 1 128 | 129 | 130 | def get_bd_left_score(ctx): 131 | s_bd = ctx['s_bd_left'] 132 | span_head = ctx['span_head_word'] 133 | score = s_bd[span_head[:, 0], span_head[:, 3], span_head[:, 1]].sum() 134 | return score 135 | 136 | def 
aug_bd_left_score(ctx): 137 | s_bd = ctx['s_bd_left'] 138 | span_head = ctx['span_head_word'] 139 | s_bd[span_head[:, 0], span_head[:, 3], span_head[:,1]] -= 1 140 | 141 | 142 | def get_bd_right_score(ctx): 143 | s_bd = ctx['s_bd_right'] 144 | span_head = ctx['span_head_word'] 145 | score = s_bd[span_head[:, 0], span_head[:, 3], span_head[:, 2]].sum() 146 | return score 147 | 148 | 149 | def aug_bd_right_score(ctx): 150 | s_bd = ctx['s_bd_right'] 151 | span_head = ctx['span_head_word'] 152 | s_bd[span_head[:, 0], span_head[:, 3], span_head[:, 2]] -= 1 153 | 154 | 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /src/model/dep_parsing.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import logging 3 | import hydra 4 | log = logging.getLogger(__name__) 5 | 6 | class ProjectiveDepParser(nn.Module): 7 | def __init__(self, conf, fields): 8 | super(ProjectiveDepParser, self).__init__() 9 | self.conf = conf 10 | self.fields = fields 11 | self.embeder = hydra.utils.instantiate(conf.embeder.target, conf.embeder, fields=fields, _recursive_=False) 12 | self.encoder = hydra.utils.instantiate(conf.encoder.target, conf.encoder, input_dim= self.embeder.get_output_dim(), _recursive_=False) 13 | self.scorer = hydra.utils.instantiate(conf.scorer.target, conf.scorer, fields=fields, input_dim=self.encoder.get_output_dim(), _recursive_=False) 14 | self.loss = hydra.utils.instantiate(conf.loss.target, conf.loss, _recursive_=False) 15 | self.metric = hydra.utils.instantiate(conf.metric.target, conf.metric, fields=fields, _recursive_=False) 16 | 17 | log.info(self.embeder) 18 | log.info(self.encoder) 19 | log.info(self.scorer) 20 | 21 | def forward(self, ctx): 22 | self.embeder(ctx) 23 | self.encoder(ctx) 24 | self.scorer(ctx) 25 | 26 | def get_loss(self, x, y): 27 | ctx = {**x, **y} 28 | self.forward(ctx) 29 | return self.loss.loss(ctx) 30 | 31 | def decode(self, x, y): 32 | ctx = {**x, **y} 33 | self.forward(ctx) 34 | self.loss.decode(ctx) 35 | return ctx 36 | 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /src/model/module/ember/embedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from supar.modules.char_lstm import CharLSTM 4 | from supar.modules import TransformerEmbedding 5 | from src.model.module.ember.ext_embedding import ExternalEmbeddingSupar 6 | import copy 7 | 8 | class Embeder(nn.Module): 9 | def __init__(self, conf, fields): 10 | super(Embeder, self).__init__() 11 | self.conf = conf 12 | 13 | if 'pos' in fields.inputs: 14 | self.pos_emb = nn.Embedding(fields.get_vocab_size("pos"), conf.n_pos_embed) 15 | else: 16 | self.pos_emb = None 17 | 18 | if 'char' in fields.inputs: 19 | self.feat = CharLSTM(n_chars=fields.get_vocab_size('char'), 20 | n_embed=conf.n_char_embed, 21 | n_out=conf.n_char_out, 22 | pad_index=fields.get_pad_index('char'), 23 | input_dropout=conf.char_input_dropout) 24 | self.feat_name = 'char' 25 | 26 | if 'bert' in fields.inputs: 27 | self.feat = TransformerEmbedding(model=fields.get_bert_name(), 28 | n_layers=conf.n_bert_layers, 29 | n_out=conf.n_bert_out, 30 | pad_index=fields.get_pad_index("bert"), 31 | dropout=conf.mix_dropout, 32 | requires_grad=conf.finetune, 33 | use_projection=conf.use_projection, 34 | use_scalarmix=conf.use_scalarmix) 35 | self.feat_name = "bert" 36 | print(fields.get_bert_name()) 37 | 38 | 
if ('char' not in fields.inputs and 'bert' not in fields.inputs): 39 | self.feat = None 40 | 41 | if 'word' in fields.inputs: 42 | ext_emb = fields.get_ext_emb() 43 | if ext_emb: 44 | self.word_emb = copy.deepcopy(ext_emb) 45 | else: 46 | self.word_emb = nn.Embedding(num_embeddings=fields.get_vocab_size('word'), 47 | embedding_dim=conf.n_embed) 48 | else: 49 | self.word_emb = None 50 | 51 | 52 | def forward(self, ctx): 53 | emb = {} 54 | 55 | if self.pos_emb: 56 | emb['pos'] = self.pos_emb(ctx['pos']) 57 | 58 | if self.word_emb: 59 | emb['word'] = self.word_emb(ctx['word']) 60 | 61 | #For now, char or ber、t, choose one. 62 | if self.feat: 63 | emb[self.feat_name] = self.feat(ctx[self.feat_name]) 64 | 65 | ctx['embed'] = emb 66 | 67 | 68 | def get_output_dim(self): 69 | 70 | size = 0 71 | 72 | if self.pos_emb: 73 | size += self.conf.n_pos_embed 74 | 75 | if self.word_emb: 76 | if isinstance(self.word_emb, nn.Embedding): 77 | size += self.conf.n_embed 78 | else: 79 | size += self.word_emb.get_dim() 80 | 81 | if self.feat: 82 | if self.feat_name == 'char': 83 | size += self.conf.n_char_out 84 | else: 85 | size += self.feat.n_out 86 | return size 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /src/model/module/ember/ext_embedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import logging 3 | 4 | 5 | log = logging.getLogger(__name__) 6 | 7 | # As Supar. 8 | class ExternalEmbeddingSupar(nn.Module): 9 | def __init__(self, emb, origin_word_size, unk_index): 10 | super(ExternalEmbeddingSupar, self).__init__() 11 | 12 | self.pretrained = nn.Embedding.from_pretrained(emb) 13 | self.origin_word_size = origin_word_size 14 | self.word_emb = nn.Embedding(origin_word_size, emb.shape[-1]) 15 | self.unk_index = unk_index 16 | nn.init.zeros_(self.word_emb.weight) 17 | 18 | def forward(self, words): 19 | ext_mask = words.ge(self.word_emb.num_embeddings) 20 | ext_words = words.masked_fill(ext_mask, self.unk_index) 21 | # get outputs from embedding layers 22 | word_embed = self.word_emb(ext_words) 23 | word_embed += self.pretrained(words) 24 | return word_embed 25 | 26 | def get_dim(self): 27 | return self.word_emb.weight.shape[-1] -------------------------------------------------------------------------------- /src/model/module/encoder/lstm_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from supar.modules import LSTM 3 | from supar.modules.dropout import IndependentDropout, SharedDropout 4 | import torch 5 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 6 | 7 | 8 | class LSTMencoder(nn.Module): 9 | def __init__(self, conf, input_dim, **kwargs): 10 | super(LSTMencoder, self).__init__() 11 | self.conf = conf 12 | 13 | self.before_lstm_dropout = None 14 | 15 | if self.conf.embed_dropout_type == 'independent': 16 | self.embed_dropout = IndependentDropout(p=conf.embed_dropout) 17 | if conf.before_lstm_dropout: 18 | self.before_lstm_dropout = SharedDropout(p=conf.before_lstm_dropout) 19 | 20 | elif self.conf.embed_dropout_type == 'shared': 21 | self.embed_dropout = SharedDropout(p=conf.embed_dropout) 22 | 23 | elif self.conf.embed_dropout_type == 'simple': 24 | self.embed_dropout = nn.Dropout(p=conf.embed_dropout) 25 | 26 | else: 27 | self.embed_dropout = nn.Dropout(0.) 
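# Fields read from `conf` by this encoder (collected here for reference):
# embed_dropout_type, embed_dropout, before_lstm_dropout, n_lstm_hidden,
# n_lstm_layers and lstm_dropout. A hypothetical Hydra/YAML fragment matching
# these names (the concrete values are illustrative guesses, not taken from the
# configs/ directory):
#
#   encoder:
#     embed_dropout_type: independent
#     embed_dropout: 0.33
#     before_lstm_dropout: 0.0
#     n_lstm_hidden: 400
#     n_lstm_layers: 3
#     lstm_dropout: 0.33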
28 | 29 | self.lstm = LSTM(input_size=input_dim, 30 | hidden_size=conf.n_lstm_hidden, 31 | num_layers=conf.n_lstm_layers, 32 | bidirectional=True, 33 | dropout=conf.lstm_dropout) 34 | self.lstm_dropout = SharedDropout(p=conf.lstm_dropout) 35 | 36 | 37 | 38 | def forward(self, info): 39 | # lstm encoder 40 | embed = info['embed'] 41 | seq_len = info['seq_len'] 42 | 43 | embed = [i for i in embed.values()] 44 | 45 | if self.conf.embed_dropout_type == 'independent': 46 | embed = self.embed_dropout(embed) 47 | embed = torch.cat(embed, dim=-1) 48 | else: 49 | embed = torch.cat(embed, dim=-1) 50 | embed = self.embed_dropout(embed) 51 | 52 | seq_len = seq_len.cpu() 53 | x = pack_padded_sequence(embed, seq_len.cpu() + (embed.shape[1] - seq_len.max()), True, False) 54 | x, _ = self.lstm(x) 55 | x, _ = pad_packed_sequence(x, True, total_length=embed.shape[1]) 56 | x = self.lstm_dropout(x) 57 | 58 | info['encoded_emb'] = x 59 | 60 | def get_output_dim(self): 61 | return self.conf.n_lstm_hidden * 2 62 | 63 | -------------------------------------------------------------------------------- /src/model/module/scorer/dep_scorer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .module.biaffine import BiaffineScorer 3 | from .module.triaffine import TriaffineScorer 4 | import torch 5 | import logging 6 | log = logging.getLogger(__name__) 7 | 8 | 9 | 10 | class DepScorer(nn.Module): 11 | def __init__(self, conf, fields, input_dim): 12 | super(DepScorer, self).__init__() 13 | self.conf = conf 14 | 15 | 16 | self.rel_scorer = BiaffineScorer(n_in=input_dim, n_out=conf.n_mlp_rel, bias_x=True, bias_y=True, 17 | dropout=conf.mlp_dropout, n_out_label=fields.get_vocab_size("rel"), 18 | scaling=conf.scaling) 19 | 20 | 21 | if self.conf.use_arc: 22 | self.arc_scorer = BiaffineScorer(n_in=input_dim, n_out=conf.n_mlp_arc, bias_x=True, bias_y=False, dropout=conf.mlp_dropout, scaling=conf.scaling) 23 | log.info("Use arc score") 24 | if self.conf.use_sib: 25 | log.info("Use sib score") 26 | self.sib_scorer = TriaffineScorer(n_in=input_dim, n_out=conf.n_mlp_sib, bias_x=True, bias_y=True, 27 | dropout=conf.mlp_dropout) 28 | 29 | if self.conf.use_span: 30 | log.info('use span score') 31 | if self.conf.span_scorer_type == 'biaffine': 32 | assert not self.conf.use_sib 33 | self.span_scorer = BiaffineScorer(n_in=input_dim, n_out=conf.n_mlp_arc, bias_x=True, bias_y=True, dropout=conf.mlp_dropout, scaling=conf.scaling) 34 | 35 | elif self.conf.span_scorer_type == 'triaffine': 36 | assert not self.conf.use_sib 37 | self.span_scorer = TriaffineScorer(n_in=input_dim, n_out=conf.n_mlp_arc, bias_x=True, bias_y=True, dropout=conf.mlp_dropout, ) 38 | 39 | elif self.conf.span_scorer_type == 'headsplit': 40 | assert self.conf.use_arc 41 | self.span_scorer_left = BiaffineScorer(n_in=input_dim, n_out=conf.n_mlp_arc, bias_x=True, bias_y=True, dropout=conf.mlp_dropout, scaling=conf.scaling) 42 | self.span_scorer_right = BiaffineScorer(n_in=input_dim, n_out=conf.n_mlp_arc, bias_x=True, bias_y=True, dropout=conf.mlp_dropout, scaling=conf.scaling) 43 | else: 44 | raise NotImplementedError 45 | 46 | 47 | 48 | def forward(self, ctx): 49 | x = ctx['encoded_emb'] 50 | x_f, x_b = x.chunk(2, -1) 51 | x_boundary = torch.cat((x_f[:, :-1], x_b[:, 1:]), -1) 52 | 53 | if self.conf.use_arc: 54 | ctx['s_arc'] = self.arc_scorer(x[:, :-1]) 55 | if self.conf.use_sib: 56 | ctx['s_sib'] = self.sib_scorer(x[:, :-1]) 57 | 58 | if self.conf.use_span: 59 | if self.conf.span_scorer_type == 
'headsplit': 60 | ctx['s_bd_left'] = self.span_scorer_left.forward_v2(x[:,1:], x_boundary) 61 | ctx['s_bd_right'] = self.span_scorer_right.forward_v2(x[:,1:], x_boundary) 62 | 63 | elif self.conf.span_scorer_type == 'biaffine': 64 | # LSTM minus features 65 | span_repr = (x_boundary.unsqueeze(1) - x_boundary.unsqueeze(2)) 66 | batch, seq_len = span_repr.shape[:2] 67 | span_repr2 = span_repr.reshape(batch, seq_len * seq_len, -1) 68 | ctx['s_span_head_word'] = self.span_scorer.forward_v2(span_repr2, x[:, 1:-1]).reshape(batch, seq_len, 69 | seq_len, -1) 70 | elif self.conf.span_scorer_type == 'triaffine': 71 | ctx['s_span_head_word'] = self.span_scorer.forward2(x[:, 1:-1], x_boundary) 72 | else: 73 | raise NotImplementedError 74 | 75 | ctx['s_rel'] = self.rel_scorer(x[:, :-1]) 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /src/model/module/scorer/module/biaffine.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from supar.modules import MLP, Biaffine 3 | 4 | class BiaffineScorer(nn.Module): 5 | def __init__(self, n_in=800, n_out=400, n_out_label=1, bias_x=True, bias_y=False, scaling=False, dropout=0.33): 6 | super(BiaffineScorer, self).__init__() 7 | self.l = MLP(n_in=n_in, n_out=n_out, dropout=dropout) 8 | self.r = MLP(n_in=n_in, n_out=n_out, dropout=dropout) 9 | self.attn = Biaffine(n_in=n_out, n_out=n_out_label, bias_x=bias_x, bias_y=bias_y) 10 | self.scaling = 0 if not scaling else n_out ** (1/4) 11 | self.n_in = n_in 12 | 13 | def forward(self, h): 14 | left = self.l(h) 15 | right = self.r(h) 16 | 17 | if self.scaling: 18 | left = left / self.scaling 19 | right = right / self.scaling 20 | 21 | return self.attn(left, right) 22 | 23 | 24 | def forward_v2(self, h, q): 25 | left = self.l(h) 26 | right = self.r(q) 27 | if self.scaling: 28 | left = left / self.scaling 29 | right = right / self.scaling 30 | 31 | return self.attn(left, right) 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /src/model/module/scorer/module/nhpsg_scorer.py: -------------------------------------------------------------------------------- 1 | # import torch.nn as nn 2 | # 3 | # 4 | # class BiAAttention(nn.Module): 5 | # ''' 6 | # Bi-Affine attention layer. 
7 | # ''' 8 | # 9 | # def __init__(self, hparams): 10 | # super(BiAAttention, self).__init__() 11 | # self.hparams = hparams 12 | # 13 | # self.dep_weight = nn.Parameter(torch_t.FloatTensor(hparams.d_biaffine + 1, hparams.d_biaffine + 1)) 14 | # nn.init.xavier_uniform_(self.dep_weight) 15 | # 16 | # def forward(self, input_d, input_e, input_s = None): 17 | # 18 | # score = torch.matmul(torch.cat( 19 | # [input_d, torch_t.FloatTensor(input_d.size(0), 1).fill_(1).requires_grad_(False)], 20 | # dim=1), self.dep_weight) 21 | # score1 = torch.matmul(score, torch.transpose(torch.cat( 22 | # [input_e, torch_t.FloatTensor(input_e.size(0), 1).fill_(1).requires_grad_(False)], 23 | # dim=1), 0, 1)) 24 | # 25 | # return score1 26 | # 27 | # class Dep_score(nn.Module): 28 | # def __init__(self, hparams, num_labels): 29 | # super(Dep_score, self).__init__() 30 | # 31 | # self.dropout_out = nn.Dropout2d(p=0.33) 32 | # self.hparams = hparams 33 | # out_dim = hparams.d_biaffine#d_biaffine 34 | # self.arc_h = nn.Linear(hparams.d_model, hparams.d_biaffine) 35 | # self.arc_c = nn.Linear(hparams.d_model, hparams.d_biaffine) 36 | # 37 | # self.attention = BiAAttention(hparams) 38 | # 39 | # self.type_h = nn.Linear(hparams.d_model, hparams.d_label_hidden) 40 | # self.type_c = nn.Linear(hparams.d_model, hparams.d_label_hidden) 41 | # self.bilinear = BiLinear(hparams.d_label_hidden, hparams.d_label_hidden, num_labels) 42 | # 43 | # def forward(self, outputs, outpute): 44 | # # output from rnn [batch, length, hidden_size] 45 | # 46 | # # apply dropout for output 47 | # # [batch, length, hidden_size] --> [batch, hidden_size, length] --> [batch, length, hidden_size] 48 | # outpute = self.dropout_out(outpute.transpose(1, 0)).transpose(1, 0) 49 | # outputs = self.dropout_out(outputs.transpose(1, 0)).transpose(1, 0) 50 | # 51 | # # output size [batch, length, arc_space] 52 | # arc_h = nn.functional.relu(self.arc_h(outputs)) 53 | # arc_c = nn.functional.relu(self.arc_c(outpute)) 54 | # 55 | # # output size [batch, length, type_space] 56 | # type_h = nn.functional.relu(self.type_h(outputs)) 57 | # type_c = nn.functional.relu(self.type_c(outpute)) 58 | # 59 | # # apply dropout 60 | # # [batch, length, dim] --> [batch, 2 * length, dim] 61 | # arc = torch.cat([arc_h, arc_c], dim=0) 62 | # type = torch.cat([type_h, type_c], dim=0) 63 | # 64 | # arc = self.dropout_out(arc.transpose(1, 0)).transpose(1, 0) 65 | # arc_h, arc_c = arc.chunk(2, 0) 66 | # 67 | # type = self.dropout_out(type.transpose(1, 0)).transpose(1, 0) 68 | # type_h, type_c = type.chunk(2, 0) 69 | # type_h = type_h.contiguous() 70 | # type_c = type_c.contiguous() 71 | # 72 | # out_arc = self.attention(arc_h, arc_c) 73 | # out_type = self.bilinear(type_h, type_c) 74 | # 75 | # return out_arc, out_type 76 | -------------------------------------------------------------------------------- /src/model/module/scorer/module/quadra_linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | import pdb 5 | import opt_einsum 6 | 7 | ## TODO: implement a quadra-linear scoring function for labels? 8 | class QlinearScorer(nn.Module): 9 | """ 10 | Outer-product (factorized) version of a quadrilinear scoring function. 11 | Quadrilinear attention layer. 12 | """ 13 | 14 | def __init__(self, input_size_1, input_size_2, input_size_3, input_size_4, factorize=True, 15 | **kwargs): 16 | """ 17 | Args: 18 | input_size_encoder: int 19 | the dimension of the encoder input.
20 | input_size_decoder: int 21 | the dimension of the decoder input. 22 | num_labels: int 23 | the number of labels of the crf layer 24 | biaffine: bool 25 | if apply bi-affine parameter. 26 | **kwargs: 27 | """ 28 | # super(QlinearScorer, self).__init__() 29 | # self.input_size_1 = input_size_1 + 1 30 | # self.input_size_2 = input_size_2 + 1 31 | # self.input_size_3 = input_size_3 + 1 32 | # self.input_size_4 = input_size_4 + 1 33 | # self.rank = rank 34 | # self.init_std = init_std 35 | # self.factorize = factorize 36 | # if not factorize: 37 | # self.W = Parameter(torch.Tensor(self.input_size_1, self.input_size_2, self.input_size_3, self.input_size_4)) 38 | # else: 39 | # self.W_1 = Parameter(torch.Tensor(self.input_size_1, self.rank)) 40 | # self.W_2 = Parameter(torch.Tensor(self.input_size_2, self.rank)) 41 | # self.W_3 = Parameter(torch.Tensor(self.input_size_3, self.rank)) 42 | # self.W_4 = Parameter(torch.Tensor(self.input_size_4, self.rank)) 43 | # self.W_5 = Parameter(torch.Tensor(self.rank, )) 44 | # # if self.biaffine: 45 | # # self.U = Parameter(torch.Tensor(self.num_labels, self.input_size_decoder, self.input_size_encoder)) 46 | # # else: 47 | # # self.register_parameter('U', None) 48 | # 49 | # self.reset_parameters() 50 | 51 | # def reset_parameters(self): 52 | # if not self.factorize: 53 | # nn.init.xavier_normal_(self.W) 54 | # else: 55 | # nn.init.xavier_normal_(self.W_1, gain=self.init_std) 56 | # nn.init.xavier_normal_(self.W_2, gain=self.init_std) 57 | # nn.init.xavier_normal_(self.W_3, gain=self.init_std) 58 | # nn.init.xavier_normal_(self.W_4, gain=self.init_std) 59 | 60 | def forward(self, layer1, layer2, layer3, layer4): 61 | """ 62 | Args: 63 | 64 | Returns: Tensor 65 | the energy tensor with shape = [batch, length] 66 | """ 67 | assert layer1.size(0) == layer2.size(0), 'batch sizes of encoder and decoder are required to be equal.' 68 | layer_shape = layer1.size() 69 | one_shape = list(layer_shape[:2]) + [1] 70 | ones = torch.ones(one_shape).cuda() 71 | layer1 = torch.cat([layer1, ones], -1) 72 | layer2 = torch.cat([layer2, ones], -1) 73 | layer3 = torch.cat([layer3, ones], -1) 74 | layer4 = torch.cat([layer4, ones], -1) 75 | 76 | layer = opt_einsum.contract('bnx,xr,bny,yr,bnz,zr,bns,sr->bn', layer1, self.W_1, layer2, self.W_2, layer3, self.W_3, layer4, self.W_4, backend='torch') 77 | return layer 78 | -------------------------------------------------------------------------------- /src/model/module/scorer/module/triaffine.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from supar.modules import MLP, Triaffine 3 | 4 | class TriaffineScorer(nn.Module): 5 | def __init__(self, n_in=800, n_out=400, bias_x=True, bias_y=False, dropout=0.33): 6 | super(TriaffineScorer, self).__init__() 7 | self.l = MLP(n_in=n_in, n_out=n_out, dropout=dropout) 8 | self.m = MLP(n_in=n_in, n_out=n_out, dropout=dropout) 9 | self.r = MLP(n_in=n_in, n_out=n_out, dropout=dropout) 10 | self.attn = Triaffine(n_in=n_out, bias_x=bias_x, bias_y=bias_y) 11 | 12 | def forward(self, h): 13 | left = self.l(h) 14 | mid = self.m(h) 15 | right = self.r(h) 16 | #sib, dependent, head) 17 | return self.attn(left, mid, right).permute(0, 2, 3, 1) 18 | 19 | def forward2(self, word, span): 20 | left = self.l(word) 21 | mid = self.m(span) 22 | right = self.r(span) 23 | # head, left_bdr, right_bdr: used in span-head model? 24 | 25 | # word, left, right 26 | # fine.
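# In forward2, `word` carries the head-word representations and `span` the boundary
# representations; the tensor contraction itself is delegated to supar's Triaffine
# module, and the permute below only reorders the score dimensions for the caller
# (see DepScorer, which feeds x[:, 1:-1] and x_boundary into this method).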
27 | return self.attn(mid, right, left).permute(0, 2, 3, 1) 28 | 29 | 30 | 31 | # class TriaffineScorer 32 | # class TriaffineScorer(nn.Module): 33 | 34 | 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /supar/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | 5 | __version__ = '1.0.0' 6 | 7 | 8 | PRETRAINED = { 9 | 'biaffine-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.biaffine.dependency.char.zip', 10 | 'biaffine-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.biaffine.dependency.char.zip', 11 | 'biaffine-dep-bert-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.biaffine.dependency.bert.zip', 12 | 'biaffine-dep-bert-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.biaffine.dependency.bert.zip', 13 | 'crfnp-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crfnp.dependency.char.zip', 14 | 'crfnp-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crfnp.dependency.char.zip', 15 | 'crf-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf.dependency.char.zip', 16 | 'crf-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf.dependency.char.zip', 17 | 'crf2o-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf2o.dependency.char.zip', 18 | 'crf2o-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf2o.dependency.char.zip', 19 | 'crf-con-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf.constituency.char.zip', 20 | 'crf-con-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf.constituency.char.zip', 21 | 'crf-con-bert-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf.constituency.bert.zip', 22 | 'crf-con-bert-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf.constituency.bert.zip' 23 | } 24 | -------------------------------------------------------------------------------- /supar/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .affine import Biaffine, Triaffine 4 | from .bert import BertEmbedding 5 | from .char_lstm import CharLSTM 6 | from .dropout import IndependentDropout, SharedDropout 7 | from .lstm import LSTM 8 | from .mlp import MLP 9 | from .scalar_mix import ScalarMix 10 | from .treecrf import (CRF2oDependency, CRFConstituency, CRFDependency, 11 | MatrixTree) 12 | from .variational_inference import (LBPSemanticDependency, 13 | MFVISemanticDependency) 14 | from .transformer import TransformerEmbedding 15 | 16 | __all__ = ['LSTM', 'MLP', 'BertEmbedding', 'Biaffine', 'CharLSTM', 'CRF2oDependency', 'CRFConstituency', 'CRFDependency', 17 | 'IndependentDropout', 'LBPSemanticDependency', 'MatrixTree', 18 | 'MFVISemanticDependency', 'ScalarMix', 'SharedDropout', 'Triaffine', 'TransformerEmbedding'] 19 | -------------------------------------------------------------------------------- /supar/modules/char_lstm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.utils.rnn import pack_padded_sequence 6 | 7 | 8 | class CharLSTM(nn.Module): 9 | r""" 10 | CharLSTM aims to generate character-level embeddings for 
tokens. 11 | It summarizes the information of characters in each token into an embedding using an LSTM layer. 12 | 13 | Args: 14 | n_chars (int): 15 | The number of characters. 16 | n_embed (int): 17 | The size of each embedding vector as input to LSTM. 18 | n_out (int): 19 | The size of each output vector. 20 | pad_index (int): 21 | The index of the padding token in the vocabulary. Default: 0. 22 | """ 23 | 24 | def __init__(self, n_chars, n_embed, n_out, pad_index=0, input_dropout=0.): 25 | super().__init__() 26 | 27 | self.n_chars = n_chars 28 | self.n_embed = n_embed 29 | self.n_out = n_out 30 | self.pad_index = pad_index 31 | 32 | # the embedding layer 33 | self.embed = nn.Embedding(num_embeddings=n_chars, 34 | embedding_dim=n_embed) 35 | # the lstm layer 36 | self.lstm = nn.LSTM(input_size=n_embed, 37 | hidden_size=n_out//2, 38 | batch_first=True, 39 | bidirectional=True) 40 | 41 | self.input_dropout = nn.Dropout(input_dropout) 42 | 43 | def __repr__(self): 44 | return f"{self.__class__.__name__}({self.n_chars}, {self.n_embed}, n_out={self.n_out}, pad_index={self.pad_index})" 45 | 46 | def forward(self, x): 47 | r""" 48 | Args: 49 | x (~torch.Tensor): ``[batch_size, seq_len, fix_len]``. 50 | Characters of all tokens. 51 | Each token holds no more than `fix_len` characters, and the excess is cut off directly. 52 | Returns: 53 | ~torch.Tensor: 54 | The embeddings of shape ``[batch_size, seq_len, n_out]`` derived from the characters. 55 | """ 56 | # [batch_size, seq_len, fix_len] 57 | mask = x.ne(self.pad_index) 58 | # [batch_size, seq_len] 59 | lens = mask.sum(-1) 60 | char_mask = lens.gt(0) 61 | 62 | # [n, fix_len, n_embed] 63 | x = self.embed(x[char_mask]) 64 | x = self.input_dropout(x) 65 | x = pack_padded_sequence(x, lens[char_mask].cpu(), True, False) 66 | x, (h, _) = self.lstm(x) 67 | # [n, fix_len, n_out] 68 | h = torch.cat(torch.unbind(h), -1) 69 | # [batch_size, seq_len, n_out] 70 | embed = h.new_zeros(*lens.shape, self.n_out) 71 | embed = embed.masked_scatter_(char_mask.unsqueeze(-1), h) 72 | 73 | return embed 74 | -------------------------------------------------------------------------------- /supar/modules/dropout.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class SharedDropout(nn.Module): 8 | r""" 9 | SharedDropout differs from the vanilla dropout strategy in that 10 | the dropout mask is shared across one dimension. 11 | 12 | Args: 13 | p (float): 14 | The probability of an element to be zeroed. Default: 0.5. 15 | batch_first (bool): 16 | If ``True``, the input and output tensors are provided as ``[batch_size, seq_len, *]``. 17 | Default: ``True``. 18 | 19 | Examples: 20 | >>> x = torch.ones(1, 3, 5) 21 | >>> nn.Dropout()(x) 22 | tensor([[[0., 2., 2., 0., 0.], 23 | [2., 2., 0., 2., 2.], 24 | [2., 2., 2., 2., 0.]]]) 25 | >>> SharedDropout()(x) 26 | tensor([[[2., 0., 2., 0., 2.], 27 | [2., 0., 2., 0., 2.], 28 | [2., 0., 2., 0., 2.]]]) 29 | """ 30 | 31 | def __init__(self, p=0.5, batch_first=True): 32 | super().__init__() 33 | 34 | self.p = p 35 | self.batch_first = batch_first 36 | 37 | def __repr__(self): 38 | s = f"p={self.p}" 39 | if self.batch_first: 40 | s += f", batch_first={self.batch_first}" 41 | 42 | return f"{self.__class__.__name__}({s})" 43 | 44 | def forward(self, x): 45 | r""" 46 | Args: 47 | x (~torch.Tensor): 48 | A tensor of any shape. 49 | Returns: 50 | The returned tensor is of the same shape as `x`.
51 | """ 52 | 53 | if self.training: 54 | if self.batch_first: 55 | mask = self.get_mask(x[:, 0], self.p).unsqueeze(1) 56 | else: 57 | mask = self.get_mask(x[0], self.p) 58 | x = x * mask 59 | 60 | return x 61 | 62 | @staticmethod 63 | def get_mask(x, p): 64 | return x.new_empty(x.shape).bernoulli_(1 - p) / (1 - p) 65 | 66 | 67 | class IndependentDropout(nn.Module): 68 | r""" 69 | For :math:`N` tensors, they use different dropout masks respectively. 70 | When :math:`N-M` of them are dropped, the remaining :math:`M` ones are scaled by a factor of :math:`N/M` to compensate, 71 | and when all of them are dropped together, zeros are returned. 72 | 73 | Args: 74 | p (float): 75 | The probability of an element to be zeroed. Default: 0.5. 76 | 77 | Examples: 78 | >>> x, y = torch.ones(1, 3, 5), torch.ones(1, 3, 5) 79 | >>> x, y = IndependentDropout()(x, y) 80 | >>> x 81 | tensor([[[1., 1., 1., 1., 1.], 82 | [0., 0., 0., 0., 0.], 83 | [2., 2., 2., 2., 2.]]]) 84 | >>> y 85 | tensor([[[1., 1., 1., 1., 1.], 86 | [2., 2., 2., 2., 2.], 87 | [0., 0., 0., 0., 0.]]]) 88 | """ 89 | 90 | def __init__(self, p=0.5): 91 | super().__init__() 92 | 93 | self.p = p 94 | 95 | def __repr__(self): 96 | return f"{self.__class__.__name__}(p={self.p})" 97 | 98 | def forward(self, items): 99 | r""" 100 | Args: 101 | items (list[~torch.Tensor]): 102 | A list of tensors that have the same shape except the last dimension. 103 | Returns: 104 | The returned tensors are of the same shape as `items`. 105 | """ 106 | 107 | if self.training: 108 | masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p) for x in items] 109 | total = sum(masks) 110 | scale = len(items) / total.max(torch.ones_like(total)) 111 | masks = [mask * scale for mask in masks] 112 | items = [item * mask.unsqueeze(-1) for item, mask in zip(items, masks)] 113 | 114 | return items 115 | -------------------------------------------------------------------------------- /supar/modules/mlp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | from supar.modules.dropout import SharedDropout 5 | 6 | 7 | class MLP(nn.Module): 8 | r""" 9 | Applies a linear transformation together with a non-linear activation to the incoming tensor: 10 | :math:`y = \mathrm{Activation}(x A^T + b)` 11 | 12 | Args: 13 | n_in (~torch.Tensor): 14 | The size of each input feature. 15 | n_out (~torch.Tensor): 16 | The size of each output feature. 17 | dropout (float): 18 | If non-zero, introduce a :class:`SharedDropout` layer on the output with this dropout ratio. Default: 0. 19 | activation (bool): 20 | Whether to use activations. Default: True. 21 | """ 22 | 23 | def __init__(self, n_in, n_out, dropout=0, activation=True): 24 | super().__init__() 25 | 26 | self.n_in = n_in 27 | self.n_out = n_out 28 | self.linear = nn.Linear(n_in, n_out) 29 | self.activation = nn.LeakyReLU(negative_slope=0.1) if activation else nn.Identity() 30 | self.dropout = SharedDropout(p=dropout) 31 | 32 | self.reset_parameters() 33 | 34 | def __repr__(self): 35 | s = f"n_in={self.n_in}, n_out={self.n_out}" 36 | if self.dropout.p > 0: 37 | s += f", dropout={self.dropout.p}" 38 | 39 | return f"{self.__class__.__name__}({s})" 40 | 41 | def reset_parameters(self): 42 | nn.init.orthogonal_(self.linear.weight) 43 | nn.init.zeros_(self.linear.bias) 44 | 45 | def forward(self, x): 46 | r""" 47 | Args: 48 | x (~torch.Tensor): 49 | The size of each input feature is `n_in`. 
50 | 51 | Returns: 52 | A tensor with the size of each output feature `n_out`. 53 | """ 54 | 55 | x = self.linear(x) 56 | x = self.activation(x) 57 | x = self.dropout(x) 58 | 59 | return x 60 | -------------------------------------------------------------------------------- /supar/modules/scalar_mix.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class ScalarMix(nn.Module): 8 | r""" 9 | Computes a parameterised scalar mixture of :math:`N` tensors, :math:`mixture = \gamma * \sum_{k}(s_k * tensor_k)` 10 | where :math:`s = \mathrm{softmax}(w)`, with :math:`w` and :math:`\gamma` scalar parameters. 11 | 12 | Args: 13 | n_layers (int): 14 | The number of layers to be mixed, i.e., :math:`N`. 15 | dropout (float): 16 | The dropout ratio of the layer weights. 17 | If dropout > 0, then for each scalar weight, adjust its softmax weight mass to 0 18 | with the dropout probability (i.e., setting the unnormalized weight to -inf). 19 | This effectively redistributes the dropped probability mass to all other weights. 20 | Default: 0. 21 | """ 22 | 23 | def __init__(self, n_layers, dropout=0): 24 | super().__init__() 25 | 26 | self.n_layers = n_layers 27 | 28 | self.weights = nn.Parameter(torch.zeros(n_layers)) 29 | self.gamma = nn.Parameter(torch.tensor([1.0])) 30 | self.dropout = nn.Dropout(dropout) 31 | 32 | def __repr__(self): 33 | s = f"n_layers={self.n_layers}" 34 | if self.dropout.p > 0: 35 | s += f", dropout={self.dropout.p}" 36 | 37 | return f"{self.__class__.__name__}({s})" 38 | 39 | def forward(self, tensors): 40 | r""" 41 | Args: 42 | tensors (list[~torch.Tensor]): 43 | :math:`N` tensors to be mixed. 44 | 45 | Returns: 46 | The mixture of :math:`N` tensors. 47 | """ 48 | 49 | normed_weights = self.dropout(self.weights.softmax(-1)) 50 | weighted_sum = sum(w * h for w, h in zip(normed_weights, tensors)) 51 | 52 | return self.gamma * weighted_sum 53 | -------------------------------------------------------------------------------- /supar/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from . 
import alg, field, fn, metric, transform 4 | from .alg import chuliu_edmonds, cky, eisner, eisner2o, kmeans, mst, tarjan 5 | from .config import Config 6 | from .data import Dataset 7 | from .embedding import Embedding 8 | from .field import ChartField, Field, RawField, SubwordField 9 | from .transform import CoNLL, Transform, Tree 10 | from .vocab import Vocab 11 | 12 | __all__ = ['ChartField', 'CoNLL', 'Config', 'Dataset', 'Embedding', 'Field', 13 | 'RawField', 'SubwordField', 'Transform', 'Tree', 'Vocab', 14 | 'alg', 'field', 'fn', 'metric', 'chuliu_edmonds', 'cky', 15 | 'eisner', 'eisner2o', 'kmeans', 'mst', 'tarjan', 'transform'] 16 | -------------------------------------------------------------------------------- /supar/utils/common.py: -------------------------------------------------------------------------------- 1 | PAD = '<pad>' 2 | UNK = '<unk>' 3 | BOS = '<bos>' 4 | EOS = '<eos>' 5 | subword_bos = '<_bos>' 6 | subword_eos = '<_eos>' 7 | no_label = '' 8 | 9 | -------------------------------------------------------------------------------- /supar/utils/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from ast import literal_eval 4 | from configparser import ConfigParser 5 | 6 | 7 | class Config(object): 8 | 9 | def __init__(self, conf=None, **kwargs): 10 | super(Config, self).__init__() 11 | 12 | config = ConfigParser() 13 | config.read(conf or []) 14 | self.update({**dict((name, literal_eval(value)) 15 | for section in config.sections() 16 | for name, value in config.items(section)), 17 | **kwargs}) 18 | 19 | def __repr__(self): 20 | s = line = "-" * 20 + "-+-" + "-" * 30 + "\n" 21 | s += f"{'Param':20} | {'Value':^30}\n" + line 22 | for name, value in vars(self).items(): 23 | s += f"{name:20} | {str(value):^30}\n" 24 | s += line 25 | 26 | return s 27 | 28 | def __getitem__(self, key): 29 | return getattr(self, key) 30 | 31 | def __getstate__(self): 32 | return vars(self) 33 | 34 | def __setstate__(self, state): 35 | self.__dict__.update(state) 36 | 37 | def keys(self): 38 | return vars(self).keys() 39 | 40 | def items(self): 41 | return vars(self).items() 42 | 43 | def update(self, kwargs): 44 | for key in ('self', 'cls', '__class__'): 45 | kwargs.pop(key, None) 46 | kwargs.update(kwargs.pop('kwargs', dict())) 47 | for name, value in kwargs.items(): 48 | setattr(self, name, value) 49 | 50 | return self 51 | 52 | def pop(self, key, val=None): 53 | return self.__dict__.pop(key, val) 54 | -------------------------------------------------------------------------------- /supar/utils/embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | class Embedding(object): 7 | 8 | def __init__(self, tokens, vectors, unk=None): 9 | self.tokens = tokens 10 | self.vectors = torch.tensor(vectors) 11 | self.pretrained = {w: v for w, v in zip(tokens, vectors)} 12 | self.unk = unk 13 | 14 | def __len__(self): 15 | return len(self.tokens) 16 | 17 | def __contains__(self, token): 18 | return token in self.pretrained 19 | 20 | @property 21 | def dim(self): 22 | return self.vectors.size(1) 23 | 24 | @property 25 | def unk_index(self): 26 | if self.unk is not None: 27 | return self.tokens.index(self.unk) 28 | else: 29 | raise AttributeError 30 | 31 | @classmethod 32 | def load(cls, path, unk=None): 33 | with open(path, 'r', encoding="utf-8") as f: 34 | lines = [line for line in f] 35 | 36 | if len(lines[0].split()) == 2: 37 | lines = lines[1:] 38 | 39 | 
splits = [line.split() for line in lines] 40 | tokens, vectors = zip(*[(s[0], list(map(float, s[1:]))) 41 | for s in splits]) 42 | 43 | return cls(tokens, vectors, unk=unk) 44 | -------------------------------------------------------------------------------- /supar/utils/fn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import unicodedata 4 | 5 | 6 | def ispunct(token): 7 | return all(unicodedata.category(char).startswith('P') 8 | for char in token) 9 | 10 | 11 | def isfullwidth(token): 12 | return all(unicodedata.east_asian_width(char) in ['W', 'F', 'A'] 13 | for char in token) 14 | 15 | 16 | def islatin(token): 17 | return all('LATIN' in unicodedata.name(char) 18 | for char in token) 19 | 20 | 21 | def isdigit(token): 22 | return all('DIGIT' in unicodedata.name(char) 23 | for char in token) 24 | 25 | 26 | def tohalfwidth(token): 27 | return unicodedata.normalize('NFKC', token) 28 | 29 | 30 | def stripe(x, n, w, offset=(0, 0), dim=1): 31 | # r""" 32 | # Returns a diagonal stripe of the tensor. 33 | # 34 | # Args: 35 | # x (~torch.Tensor): the input tensor with 2 or more dims. 36 | # n (int): the length of the stripe. 37 | # w (int): the width of the stripe. 38 | # offset (tuple): the offset of the first two dims. 39 | # dim (int): 1 if returns a horizontal stripe; 0 otherwise. 40 | # 41 | # Returns: 42 | # a diagonal stripe of the tensor. 43 | # Examples: 44 | # >>> x = torch.arange(25).view(5, 5) 45 | # >>> x 46 | # tensor([[ 0, 1, 2, 3, 4], 47 | # [ 5, 6, 7, 8, 9], 48 | # [10, 11, 12, 13, 14], 49 | # [15, 16, 17, 18, 19], 50 | # [20, 21, 22, 23, 24]]) 51 | # >>> stripe(x, 2, 3) 52 | # tensor([[0, 1, 2], 53 | # [6, 7, 8]]) 54 | # >>> stripe(x, 2, 3, (1, 1)) 55 | # tensor([[ 6, 7, 8], 56 | # [12, 13, 14]]) 57 | # >>> stripe(x, 2, 3, (1, 1), 0) 58 | # tensor([[ 6, 11, 16], 59 | # [12, 17, 22]]) 60 | # """ 61 | 62 | x, seq_len = x.contiguous(), x.size(1) 63 | stride, numel = list(x.stride()), x[0, 0].numel() 64 | stride[0] = (seq_len + 1) * numel 65 | stride[1] = (1 if dim == 1 else seq_len) * numel 66 | return x.as_strided(size=(n, w, *x.shape[2:]), 67 | stride=stride, 68 | storage_offset=(offset[0]*seq_len+offset[1])*numel) 69 | 70 | 71 | def pad(tensors, padding_value=0, total_length=None, padding_side='right'): 72 | size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors) 73 | for i in range(len(tensors[0].size()))] 74 | if total_length is not None: 75 | assert total_length >= size[1] 76 | size[1] = total_length 77 | out_tensor = tensors[0].data.new(*size).fill_(padding_value) 78 | for i, tensor in enumerate(tensors): 79 | out_tensor[i][[slice(-i, None) if padding_side == 'left' else slice(0, i) for i in tensor.size()]] = tensor 80 | return out_tensor 81 | 82 | 83 | if __name__ == '__main__': 84 | import torch 85 | x = torch.arange(25).view(5, 5) 86 | print(stripe(x, 2, 3, (3, 1))) 87 | -------------------------------------------------------------------------------- /supar/utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import logging 4 | import os 5 | 6 | from supar.utils.parallel import is_master 7 | from tqdm import tqdm 8 | 9 | 10 | def get_logger(name): 11 | return logging.getLogger(name) 12 | 13 | 14 | def init_logger(logger, 15 | path=None, 16 | mode='w', 17 | level=None, 18 | handlers=None, 19 | verbose=True): 20 | level = level or logging.WARNING 21 | if not handlers: 22 | handlers = [logging.StreamHandler()] 
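# If a log path is supplied, the stream handler is supplemented with a FileHandler
# (its parent directory is created on demand); after basicConfig, the logger level is
# set to WARNING on non-master processes (or when verbose is False) so that only
# rank 0 emits INFO messages.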
23 | if path: 24 | os.makedirs(os.path.dirname(path), exist_ok=True) 25 | handlers.append(logging.FileHandler(path, mode)) 26 | logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', 27 | datefmt='%Y-%m-%d %H:%M:%S', 28 | level=level, 29 | handlers=handlers) 30 | logger.setLevel(logging.INFO if is_master() and verbose else logging.WARNING) 31 | 32 | 33 | def progress_bar(iterator, 34 | ncols=None, 35 | bar_format='{l_bar}{bar:18}| {n_fmt}/{total_fmt} {elapsed}<{remaining}, {rate_fmt}{postfix}', 36 | leave=True): 37 | return tqdm(iterator, 38 | ncols=ncols, 39 | bar_format=bar_format, 40 | ascii=True, 41 | disable=(not (logger.level == logging.INFO and is_master())), 42 | leave=leave) 43 | 44 | 45 | logger = get_logger('supar') 46 | -------------------------------------------------------------------------------- /supar/utils/parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | from random import Random 5 | 6 | import torch 7 | import torch.distributed as dist 8 | import torch.nn as nn 9 | 10 | 11 | class DistributedDataParallel(nn.parallel.DistributedDataParallel): 12 | 13 | def __init__(self, module, **kwargs): 14 | super().__init__(module, **kwargs) 15 | 16 | def __getattr__(self, name): 17 | wrapped = super().__getattr__('module') 18 | if hasattr(wrapped, name): 19 | return getattr(wrapped, name) 20 | return super().__getattr__(name) 21 | 22 | 23 | def init_device(device, local_rank=-1, backend='nccl', host=None, port=None): 24 | os.environ['CUDA_VISIBLE_DEVICES'] = device 25 | if torch.cuda.device_count() > 1: 26 | host = host or os.environ.get('MASTER_ADDR', 'localhost') 27 | port = port or os.environ.get('MASTER_PORT', str(Random(0).randint(10000, 20000))) 28 | os.environ['MASTER_ADDR'] = host 29 | os.environ['MASTER_PORT'] = port 30 | dist.init_process_group(backend) 31 | torch.cuda.set_device(local_rank) 32 | 33 | 34 | def is_master(): 35 | return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0 36 | -------------------------------------------------------------------------------- /supar/utils/vocab.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from collections import defaultdict 4 | from collections.abc import Iterable 5 | 6 | 7 | class Vocab(object): 8 | r""" 9 | Defines a vocabulary object that will be used to numericalize a field. 10 | 11 | Args: 12 | counter (~collections.Counter): 13 | :class:`~collections.Counter` object holding the frequencies of each value found in the data. 14 | min_freq (int): 15 | The minimum frequency needed to include a token in the vocabulary. Default: 1. 16 | specials (list[str]): 17 | The list of special tokens (e.g., pad, unk, bos and eos) that will be prepended to the vocabulary. Default: []. 18 | unk_index (int): 19 | The index of unk token. Default: 0. 20 | 21 | Attributes: 22 | itos: 23 | A list of token strings indexed by their numerical identifiers. 24 | stoi: 25 | A :class:`~collections.defaultdict` object mapping token strings to numerical identifiers. 
26 | """ 27 | 28 | def __init__(self, counter, min_freq=1, specials=[], unk_index=0): 29 | self.itos = list(specials) 30 | self.stoi = defaultdict(lambda: unk_index) 31 | 32 | self.stoi.update({token: i for i, token in enumerate(self.itos)}) 33 | self.extend([token for token, freq in counter.items() 34 | if freq >= min_freq]) 35 | self.unk_index = unk_index 36 | self.n_init = len(self) 37 | 38 | def __len__(self): 39 | return len(self.itos) 40 | 41 | def __getitem__(self, key): 42 | if isinstance(key, str): 43 | return self.stoi[key] 44 | elif not isinstance(key, Iterable): 45 | return self.itos[key] 46 | elif isinstance(key[0], str): 47 | return [self.stoi[i] for i in key] 48 | else: 49 | return [self.itos[i] for i in key] 50 | 51 | def __contains__(self, token): 52 | return token in self.stoi 53 | 54 | def __getstate__(self): 55 | # avoid picking defaultdict 56 | attrs = dict(self.__dict__) 57 | # cast to regular dict 58 | attrs['stoi'] = dict(self.stoi) 59 | return attrs 60 | 61 | def __setstate__(self, state): 62 | stoi = defaultdict(lambda: self.unk_index) 63 | stoi.update(state['stoi']) 64 | state['stoi'] = stoi 65 | self.__dict__.update(state) 66 | 67 | 68 | ### do not change 69 | def extend(self, tokens): 70 | try: 71 | self.stoi = defaultdict(lambda: self.unk_index) 72 | for word in self.itos: 73 | self.stoi[word] 74 | except: 75 | pass 76 | 77 | self.itos.extend(sorted(set(tokens).difference(self.stoi))) 78 | self.stoi.update({token: i for i, token in enumerate(self.itos)}) 79 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | from typing import List 5 | 6 | import hydra 7 | import omegaconf 8 | import pytorch_lightning as pl 9 | from hydra.core.hydra_config import HydraConfig 10 | import wandb 11 | from pytorch_lightning import LightningModule, LightningDataModule, Callback, Trainer 12 | from pytorch_lightning.loggers import LightningLoggerBase 13 | from pytorch_lightning import seed_everything 14 | 15 | # hydra imports 16 | from omegaconf import DictConfig 17 | from hydra.utils import log 18 | import hydra 19 | from omegaconf import OmegaConf 20 | 21 | # normal imports 22 | from typing import List 23 | import warnings 24 | import logging 25 | from pytorch_lightning.loggers import WandbLogger 26 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping 27 | import sys 28 | # sys.setdefaultencoding() does not exist, here! 29 | from omegaconf import OmegaConf, open_dict 30 | 31 | 32 | def train(config): 33 | # if contains this, means we are multi-run and optuna-ing 34 | log.info(OmegaConf.to_container(config,resolve=True)) 35 | config.root = hydra.utils.get_original_cwd() 36 | limit_train_batches = 1.0 37 | 38 | hydra_dir = str(os.getcwd()) 39 | seed_everything(config.seed) 40 | os.chdir(hydra.utils.get_original_cwd()) 41 | 42 | # Instantiate datamodule 43 | hydra.utils.log.info(os.getcwd()) 44 | hydra.utils.log.info(f"Instantiating <{config.datamodule.target}>") 45 | # Instantiate callbacks and logger. 
46 | callbacks: List[Callback] = [] 47 | logger: List[LightningLoggerBase] = [] 48 | 49 | datamodule: pl.LightningDataModule = hydra.utils.instantiate( 50 | config.datamodule.target, config.datamodule, _recursive_=False 51 | ) 52 | 53 | log.info("created datamodule") 54 | datamodule.setup() 55 | model = hydra.utils.instantiate(config.runner, cfg = config, fields=datamodule.fields, datamodule=datamodule, _recursive_=False) 56 | 57 | os.chdir(hydra_dir) 58 | # Train the model ⚡ 59 | if "callbacks" in config: 60 | for _, cb_conf in config["callbacks"].items(): 61 | if "_target_" in cb_conf: 62 | log.info(f"Instantiating callback <{cb_conf._target_}>") 63 | callbacks.append(hydra.utils.instantiate(cb_conf)) 64 | 65 | if config.checkpoint: 66 | callbacks.append( 67 | ModelCheckpoint( 68 | monitor='valid/score', 69 | mode='max', 70 | save_last=False, 71 | filename='checkpoint' 72 | ) 73 | ) 74 | log.info("Instantiating callback, ModelCheckpoint") 75 | 76 | 77 | if config.wandb: 78 | logger.append(hydra.utils.instantiate(config.logger)) 79 | 80 | log.info(f"Instantiating trainer <{config.trainer._target_}>") 81 | trainer: Trainer = hydra.utils.instantiate( 82 | config.trainer, callbacks=callbacks, logger=logger, 83 | replace_sampler_ddp=False, 84 | # accelerator='ddp' if distributed else None, 85 | accumulate_grad_batches=config.accumulation, 86 | limit_train_batches=limit_train_batches, 87 | checkpoint_callback=False 88 | ) 89 | 90 | log.info(f"Starting training!") 91 | if config.wandb: 92 | logger[-1].experiment.save(str(hydra_dir) + "/.hydra/*", base_path=str(hydra_dir)) 93 | 94 | trainer.fit(model, datamodule) 95 | log.info(f"Finalizing!") 96 | 97 | if config.wandb: 98 | logger[-1].experiment.save(str(hydra_dir) + "/*.log", base_path=str(hydra_dir)) 99 | wandb.finish() 100 | 101 | log.info(f'hydra_path:{os.getcwd()}') 102 | 103 | 104 | @hydra.main(config_path="configs/", config_name="config.yaml") 105 | def main(config): 106 | train(config) 107 | 108 | if __name__ == "__main__": 109 | main() 110 | 111 | 112 | --------------------------------------------------------------------------------
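# Usage sketch (assuming standard Hydra CLI behaviour; the concrete config groups
# live under configs/ and are not reproduced here):
#   python train.py seed=0 wandb=false checkpoint=true
# seed, wandb, checkpoint and accumulation are the top-level keys read inside train();
# overrides for model or data settings depend on the YAML files and are not shown here.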