├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── span-based-dependency-parsing-2.iml
│   └── vcs.xml
├── README.md
├── configs
│   ├── config.yaml
│   ├── datamodule
│   │   ├── _base.yaml
│   │   ├── ctb.yaml
│   │   ├── ptb.yaml
│   │   └── ud2.2.yaml
│   ├── exp
│   │   └── base.yaml
│   ├── finetune
│   │   └── base.yaml
│   ├── logger
│   │   ├── comet.yaml
│   │   ├── csv.yaml
│   │   ├── many_loggers.yaml
│   │   ├── neptune.yaml
│   │   ├── tensorboard.yaml
│   │   └── wandb.yaml
│   ├── model
│   │   ├── _base.yaml
│   │   ├── biaffine.yaml
│   │   ├── biaffine2o.yaml
│   │   ├── span.yaml
│   │   ├── span1o.yaml
│   │   ├── span1oheadplit.yaml
│   │   └── span2oheadplit.yaml
│   ├── optim
│   │   └── finetune_bert.yaml
│   └── trainer
│       └── default_trainer.yaml
├── fastNLP
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── _logger.py
│   │   ├── _parallel_utils.py
│   │   ├── batch.py
│   │   ├── callback.py
│   │   ├── collate_fn.py
│   │   ├── const.py
│   │   ├── dataset.py
│   │   ├── dist_trainer.py
│   │   ├── field.py
│   │   ├── instance.py
│   │   ├── losses.py
│   │   ├── metrics.py
│   │   ├── optimizer.py
│   │   ├── predictor.py
│   │   ├── sampler.py
│   │   ├── tester.py
│   │   ├── trainer.py
│   │   ├── utils.py
│   │   └── vocabulary.py
│   ├── doc_utils.py
│   ├── embeddings
│   │   ├── __init__.py
│   │   ├── bert_embedding.py
│   │   ├── char_embedding.py
│   │   ├── contextual_embedding.py
│   │   ├── elmo_embedding.py
│   │   ├── embedding.py
│   │   ├── gpt2_embedding.py
│   │   ├── roberta_embedding.py
│   │   ├── stack_embedding.py
│   │   ├── static_embedding.py
│   │   └── utils.py
│   ├── io
│   │   ├── __init__.py
│   │   ├── data_bundle.py
│   │   ├── embed_loader.py
│   │   ├── file_reader.py
│   │   ├── file_utils.py
│   │   ├── loader
│   │   │   ├── __init__.py
│   │   │   ├── classification.py
│   │   │   ├── conll.py
│   │   │   ├── coreference.py
│   │   │   ├── csv.py
│   │   │   ├── cws.py
│   │   │   ├── json.py
│   │   │   ├── loader.py
│   │   │   ├── matching.py
│   │   │   ├── qa.py
│   │   │   └── summarization.py
│   │   ├── model_io.py
│   │   ├── pipe
│   │   │   ├── __init__.py
│   │   │   ├── classification.py
│   │   │   ├── conll.py
│   │   │   ├── coreference.py
│   │   │   ├── cws.py
│   │   │   ├── matching.py
│   │   │   ├── pipe.py
│   │   │   ├── qa.py
│   │   │   ├── summarization.py
│   │   │   └── utils.py
│   │   └── utils.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── bert.py
│   │   ├── biaffine_parser.py
│   │   ├── cnn_text_classification.py
│   │   ├── seq2seq_generator.py
│   │   ├── seq2seq_model.py
│   │   ├── sequence_labeling.py
│   │   ├── snli.py
│   │   └── star_transformer.py
│   └── modules
│       ├── __init__.py
│       ├── attention.py
│       ├── decoder
│       │   ├── __init__.py
│       │   ├── crf.py
│       │   ├── mlp.py
│       │   ├── seq2seq_decoder.py
│       │   ├── seq2seq_state.py
│       │   └── utils.py
│       ├── dropout.py
│       ├── encoder
│       │   ├── __init__.py
│       │   ├── _elmo.py
│       │   ├── bert.py
│       │   ├── char_encoder.py
│       │   ├── conv_maxpool.py
│       │   ├── gpt2.py
│       │   ├── lstm.py
│       │   ├── pooling.py
│       │   ├── roberta.py
│       │   ├── seq2seq_encoder.py
│       │   ├── star_transformer.py
│       │   ├── transformer.py
│       │   └── variational_rnn.py
│       ├── generator
│       │   ├── __init__.py
│       │   └── seq2seq_generator.py
│       ├── tokenizer
│       │   ├── __init__.py
│       │   ├── bert_tokenizer.py
│       │   ├── gpt2_tokenizer.py
│       │   └── roberta_tokenizer.py
│       └── utils.py
├── requirements.txt
├── src
│   ├── callbacks
│   │   ├── progressbar.py
│   │   ├── transformer_scheduler.py
│   │   └── wandb_callbacks.py
│   ├── constant.py
│   ├── datamodule
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── dep_data.py
│   │   └── dm_util
│   │       ├── datamodule_util.py
│   │       ├── fields.py
│   │       ├── padder.py
│   │       └── util.py
│   ├── inside
│   │   ├── __init__.py
│   │   ├── eisner.py
│   │   ├── eisner2o.py
│   │   ├── eisner_satta.py
│   │   ├── fn.py
│   │   └── span.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── dep_loss.py
│   │   └── get_score.py
│   ├── model
│   │   ├── dep_parsing.py
│   │   ├── metric.py
│   │   └── module
│   │       ├── ember
│   │       │   ├── embedding.py
│   │       │   └── ext_embedding.py
│   │       ├── encoder
│   │       │   └── lstm_encoder.py
│   │       └── scorer
│   │           ├── dep_scorer.py
│   │           └── module
│   │               ├── biaffine.py
│   │               ├── nhpsg_scorer.py
│   │               ├── quadra_linear.py
│   │               └── triaffine.py
│   └── runner
│       └── base.py
├── supar
│   ├── __init__.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── affine.py
│   │   ├── bert.py
│   │   ├── char_lstm.py
│   │   ├── dropout.py
│   │   ├── lstm.py
│   │   ├── mlp.py
│   │   ├── scalar_mix.py
│   │   ├── transformer.py
│   │   ├── treecrf.py
│   │   └── variational_inference.py
│   └── utils
│       ├── __init__.py
│       ├── alg.py
│       ├── common.py
│       ├── config.py
│       ├── data.py
│       ├── embedding.py
│       ├── field.py
│       ├── field_new.py
│       ├── fn.py
│       ├── logging.py
│       ├── metric.py
│       ├── parallel.py
│       ├── transform.py
│       └── vocab.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
(IDE inspection-profile XML; markup not preserved in this export)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
(IDE inspection-profile settings XML; markup not preserved in this export)
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(IDE project settings XML; markup not preserved in this export)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(IDE module list XML; markup not preserved in this export)
--------------------------------------------------------------------------------
/.idea/span-based-dependency-parsing-2.iml:
--------------------------------------------------------------------------------
(IDE module file XML; markup not preserved in this export)
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
(IDE VCS mapping XML; markup not preserved in this export)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # span-based-dependency-parsing
2 | Source code of the ACL 2022 paper "[Headed-Span-Based Projective Dependency Parsing](http://arxiv.org/abs/2108.04750)" and
3 | the Findings of ACL 2022 paper "[Combining (second-order) graph-based and headed-span-based projective dependency parsing](https://arxiv.org/pdf/2108.05838.pdf)".
4 |
5 | ## Setup
6 | Set up the environment:
7 | ```
8 | conda create -n parsing python=3.7
9 | conda activate parsing
10 | while read requirement; do pip install $requirement; done < requirements.txt
11 | ```
12 |
13 | Set up the datasets:
14 |
15 | You can download the datasets I used from this [link](https://mega.nz/file/jFIijLTI#b0b7550tdYVNcpGfgaXc0sk0F943lrt8D35v1SW2wbg).
16 |
17 | ## Run
18 | ```
19 | python train.py +exp=base datamodule=a model=b seed=0
20 | a={ptb, ctb, ud2.2}
21 | b={biaffine, biaffine2o, span, span1o, span1oheadsplit, span2oheadsplit}
22 | ```
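For example, a concrete single run of the pure headed-span model (`model=span`) on PTB would be:
```
python train.py +exp=base datamodule=ptb model=span seed=0
```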
23 |
24 | Multirun example:
25 | ```
26 | python train.py +exp=base datamodule=ud2.2 model=b datamodule.ud_lan=de,it,en,ca,cs,es,fr,no,ru,ro,nl,bg seed=0,1,2 --multirun
27 | ```
28 | For UD, you also need to set up a Java environment for MaltParser.
29 | Download MaltParser v1.9.2 from this [link](https://www.maltparser.org/download.html).
30 |
31 | ## Contact
32 | Please let me know if there are any bugs. Also, feel free to contact bestsonta@gmail.com if you have any questions.
33 |
34 | ## Citation
35 | ```
36 | @inproceedings{yang-tu-2022-headed,
37 | title={Headed-Span-Based Projective Dependency Parsing},
38 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics",
39 | author={Songlin Yang and Kewei Tu},
40 | year={2022}
41 | }
42 |
43 | @misc{yang-tu-2022-combining,
44 | title={Combining (second-order) graph-based and headed-span-based projective dependency parsing},
45 | author={Songlin Yang and Kewei Tu},
46 | year={2022},
47 | booktitle = "Findings of ACL",
48 | }
49 | ```
50 |
51 | ## Acknowledgments
52 | The code is based on the [lightning+hydra](https://github.com/ashleve/lightning-hydra-template) template. I use [FastNLP](https://github.com/fastnlp/fastNLP) as the dataloader and many built-in modules (LSTMs, biaffines, triaffines, dropout layers, etc.) from [Supar](https://github.com/yzhangcs/parser/tree/main/supar).
53 |
54 |
--------------------------------------------------------------------------------
/configs/config.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # specify here default training configuration
4 | defaults:
5 | - trainer: default_trainer
6 | - optim: finetune_bert
7 | - model: _base
8 | - datamodule: _base
9 |
10 | runner:
11 | _target_: src.runner.base.Runner
12 |
13 | work_dir: ${hydra:runtime.cwd}/experiment/${datamodule.name}/${model.name}/${now:%Y-%m-%d}-${now:%H-%M-%S}-seed-${seed}
14 |
15 | wandb: False
16 | checkpoint: False
17 | device: 0
18 | seed: 0
19 | accumulation: 1
20 | use_logger: True
21 | distributed: False
22 |
23 | # output paths for hydra logs
24 | root: "."
25 | suffix: ""
26 |
27 | hydra:
28 | run:
29 | dir: ${work_dir}
30 | sweep:
31 | dir: logs/multiruns/experiment/${datamodule.name}/${model.name}/${now:%Y-%m-%d}-${now:%H-%M-%S}-seed-${seed}
32 | subdir: ${hydra.job.num}
33 | job:
34 | env_set:
35 | WANDB_CONSOLE: 'off'
36 |
37 |
--------------------------------------------------------------------------------
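The keys defined in config.yaml above (`device`, `seed`, `checkpoint`, `wandb`, `accumulation`, ...) are global Hydra config values, so they can be overridden from the command line just like `datamodule` and `model`; a sketch:

```
python train.py +exp=base datamodule=ptb model=biaffine device=1 seed=3 checkpoint=True
```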
/configs/datamodule/_base.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | datamodule:
4 | use_char: False
5 | use_bert: True
6 | use_pos: False
7 | use_word: False
8 | use_emb: False
9 | use_sib: "${model.scorer.use_sib}"
10 | use_span_head_word: "${model.scorer.use_span}"
11 | ext_emb_path: ""
12 | bert: ''
13 | min_freq: 2
14 | fix_len: 20
15 | train_sampler_type: 'token'
16 | test_sampler_type: 'token'
17 | bucket: 32
18 | bucket_test: 32
19 | max_tokens: 5000
20 | max_tokens_test: 5000
21 | use_cache: True
22 | use_bert_cache: True
23 | max_len: 10000
24 | max_len_test: 10000
25 | root: '.'
26 | distributed: False
27 | # for PTB only. clean (-RHS-)
28 | clean_word: False
29 |
30 |
--------------------------------------------------------------------------------
/configs/datamodule/ctb.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _base
5 |
6 | datamodule:
7 | target:
8 | _target_: src.datamodule.dep_data.DepData
9 |
10 | train_dep: "${root}/data/ctb/train.ctb.conll"
11 | dev_dep: "${root}/data/ctb/dev.ctb.conll"
12 | test_dep: "${root}/data/ctb/test.ctb.conll"
13 |
14 | bert: "bert-base-chinese"
15 | ignore_punct: True
16 | name: 'ctb'
17 | use_pos: True
18 |
19 | # ext_emb_path: ${.mapping.${.emb_type}}
20 | # emb_type: sskip
21 | # mapping:
22 | # giga: "data/giga.100.txt"
23 | # sskip: "data/sskip.chn.50"
24 |
25 | cache: "${root}/data/ctb/ctb.dep.pickle"
26 | cache_bert: "${root}/data/ctb/ctb.dep.cache_${bert}"
27 |
28 | model:
29 | metric:
30 | target:
31 | _target_: src.model.metric.AttachmentMetric
32 | write_result_to_file: True
33 |
34 |
35 |
--------------------------------------------------------------------------------
/configs/datamodule/ptb.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - _base
4 |
5 | datamodule:
6 | target:
7 | _target_: src.datamodule.dep_data.DepData
8 | name: 'ptb'
9 | train_dep: "${root}/data/ptb/train.gold.conllu"
10 | dev_dep: "${root}/data/ptb/dev.gold.conllu"
11 | test_dep: "${root}/data/ptb/test.gold.conllu"
12 | cache: "${root}/data/ptb/ptb.dep.pickle"
13 | cache_bert: "${root}/data/ptb/ptb.dep.cache_${datamodule.bert}"
14 | ext_emb_path: "${root}/data/ptb/glove.6B.100d.txt"
15 | ignore_punct: True
16 | clean_word: True
17 | bert: 'bert-large-cased'
18 |
19 | model:
20 | metric:
21 | target:
22 | _target_: src.model.metric.AttachmentMetric
23 | write_result_to_file: True
24 |
25 |
26 |
--------------------------------------------------------------------------------
/configs/datamodule/ud2.2.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - _base
4 |
5 | datamodule:
6 | target:
7 | _target_: src.datamodule.dep_data.DepUD
8 | abbreviation:
9 | 'no': no_bokmaal
10 | bg: bg_btb
11 | ca: ca_ancora
12 | cs: cs_pdt
13 | de: de_gsd
14 | en: en_ewt
15 | es: es_ancora
16 | fr: fr_gsd
17 | it: it_isdt
18 | 'nl': nl_alpino
19 | ro: ro_rrt
20 | ru: ru_syntagrus
21 | # extra languages
22 | ta: ta_ttb
23 | ko: ko_gsd
24 | zh: zh_gsd
25 |
26 | use_pos: True
27 | ud_lan: bg
28 | ud_name: "${.abbreviation.${.ud_lan}}"
29 | ud_ver: 2.2
30 | name: "ud2.2_${.ud_lan}"
31 | # Do not ignore punctuations while evaluating
32 | ignore_punct: True
33 | max_len: 200
34 | train_dep: "${root}/data/ud${.ud_ver}/UD_${.ud_mapping.${.ud_name}}/${.ud_name}-ud-train.conllu"
35 | dev_dep: "${root}/data/ud${.ud_ver}/UD_${.ud_mapping.${.ud_name}}/${.ud_name}-ud-dev.conllu"
36 | test_dep: "${root}/data/ud${.ud_ver}/UD_${.ud_mapping.${.ud_name}}/${.ud_name}-ud-test.conllu"
37 | cache: "${root}/data/ud${.ud_ver}/UD_${.ud_mapping.${.ud_name}}/${.ud_name}-ud-cache.pickle"
38 | cache_bert: "${root}/data/ud${.ud_ver}/UD_${.ud_mapping.${.ud_name}}/${.ud_name}-ud-cache_bert.pickle"
39 | ext_emb_path: "${root}/data/ud${.ud_ver}/fasttext/nogen/${.ud_lan}.lower.nogen.300.txt"
40 | bert: "bert-base-multilingual-cased"
41 | #use_bert: False
42 | #use_emb: False
43 | #ext_emb: ""
44 | ud_mapping:
45 | bg_btb: Bulgarian-BTB
46 | ca_ancora: Catalan-AnCora
47 | cs_pdt: Czech-PDT
48 | de_gsd: German-GSD
49 | en_ewt: English-EWT
50 | es_ancora: Spanish-AnCora
51 | fr_gsd: French-GSD
52 | it_isdt: Italian-ISDT
53 | nl_alpino: Dutch-Alpino
54 | no_bokmaal: Norwegian-Bokmaal
55 | ro_rrt: Romanian-RRT
56 | ru_syntagrus: Russian-SynTagRus
57 | # extra languages
58 | ta_ttb : Tamil-TTB
59 | ko_gsd: Korean-GSD
60 | zh_gsd: Chinese-GSD
61 |
62 | model:
63 | metric:
64 | target:
65 | _target_: src.model.metric.PseudoProjDepExternalMetric
66 | write_result_to_file: True
67 | lan: ${datamodule.ud_lan}
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/configs/exp/base.yaml:
--------------------------------------------------------------------------------
1 |
2 | # @package _global_
3 |
4 |
5 | trainer:
6 | min_epochs: 1
7 | max_epochs: 10
8 |
9 |
10 | # 16*250=4000
11 | accumulation: 16
12 |
13 | datamodule:
14 | max_tokens: 250
15 | max_tokens_test: 250
16 |
17 | # save checkpoints of the model.
18 | checkpoint: False
19 |
20 | model:
21 | embeder:
22 | finetune: True
23 |
24 | optim:
25 | only_embeder: True
26 |
27 |
28 | callbacks:
29 | transformer_scheduler:
30 | _target_: src.callbacks.transformer_scheduler.TransformerLrScheduler
31 | warmup: ${optim.warmup}
32 |
33 | pretty_progress_bar:
34 | _target_: src.callbacks.progressbar.PrettyProgressBar
35 | refresh_rate: 1
36 | process_position: 0
37 |
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/configs/finetune/base.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /optim: finetune_bert
5 |
6 | trainer:
7 | min_epochs: 1
8 | max_epochs: 10
9 |
10 | callbacks:
11 | transformer_scheduler:
12 | _target_: src.callbacks.transformer_scheduler.TransformerLrScheduler
13 | warmup: ${optim.warmup}
14 |
15 | model:
16 | embeder:
17 | finetune: True
18 |
19 | optim:
20 | only_embeder: True
--------------------------------------------------------------------------------
/configs/logger/comet.yaml:
--------------------------------------------------------------------------------
1 | # https://www.comet.ml
2 |
3 | comet:
4 | _target_: pytorch_lightning.loggers.comet.CometLogger
5 | api_key: ???
6 | project_name: "project_template_test"
7 | experiment_name: null
8 |
--------------------------------------------------------------------------------
/configs/logger/csv.yaml:
--------------------------------------------------------------------------------
1 | # CSVLogger built in PyTorch Lightning
2 |
3 | csv:
4 | _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
5 | save_dir: "."
6 | name: "csv/"
7 |
--------------------------------------------------------------------------------
/configs/logger/many_loggers.yaml:
--------------------------------------------------------------------------------
1 | # train with many loggers at once
2 |
3 | defaults:
4 | - csv.yaml
5 | - wandb.yaml
6 | # - neptune.yaml
7 | # - comet.yaml
8 | # - tensorboard.yaml
9 |
--------------------------------------------------------------------------------
/configs/logger/neptune.yaml:
--------------------------------------------------------------------------------
1 | # https://neptune.ai
2 |
3 | neptune:
4 | _target_: pytorch_lightning.loggers.neptune.NeptuneLogger
5 | project_name: "your_name/lightning-hydra-template-test"
6 |   api_key: ${env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
7 | # experiment_name: "some_experiment"
8 |
--------------------------------------------------------------------------------
/configs/logger/tensorboard.yaml:
--------------------------------------------------------------------------------
1 | # TensorBoard
2 |
3 | tensorboard:
4 | _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
5 | save_dir: "tensorboard/"
6 | name: "default"
7 |
--------------------------------------------------------------------------------
/configs/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | logger:
2 | wandb:
3 | _target_: pytorch_lightning.loggers.wandb.WandbLogger
4 | project: "supervised_parsing"
5 | entity: "sonta2020"
6 | job_type: "train"
7 | group: ""
8 | name: ""
9 | save_dir: "${work_dir}"
10 |
--------------------------------------------------------------------------------
/configs/model/_base.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 |
4 | model:
5 | target:
6 | _target_: src.model.dep_parsing.ProjectiveDepParser
7 |
8 | embeder:
9 | target:
10 | _target_: src.model.module.ember.embedding.Embeder
11 |
12 | #pos
13 | n_pos_embed: 100
14 | #char
15 | n_char_embed: 50
16 | n_char_out: 100
17 | char_input_dropout: 0.
18 | # bert
19 | n_bert_out: 1024
20 | n_bert_layers: 4
21 | mix_dropout: 0.
22 | use_projection: False
23 | use_scalarmix: False
24 | finetune: True
25 | #word
26 | n_embed: 300
27 |
28 | encoder:
29 | target:
30 | _target_: src.model.module.encoder.lstm_encoder.LSTMencoder
31 | embed_dropout: .33
32 | embed_dropout_type: shared
33 | lstm_dropout: .33
34 | n_lstm_hidden: 1000
35 | n_lstm_layers: 3
36 | before_lstm_dropout: 0.
37 |
38 | scorer:
39 | target:
40 | _target_: src.model.module.scorer.dep_scorer.DepScorer
41 | n_mlp_arc: 600
42 | n_mlp_rel: 300
43 | n_mlp_sib: 300
44 | mlp_dropout: .33
45 | scaling: False
46 | use_arc: False
47 | use_span: False
48 | use_sib: False
49 | span_scorer_type: biaffine
50 |
51 |
52 |
53 |
54 |
55 |
56 |
--------------------------------------------------------------------------------
/configs/model/biaffine.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _base
5 |
6 | model:
7 | scorer:
8 | use_arc: True
9 | use_sib: False
10 | use_span: False
11 |
12 |
13 | loss:
14 | target:
15 | _target_: src.loss.dep_loss.DepLoss
16 | loss_type: 'mm'
17 |
18 | name: 'dep1o_${model.loss.loss_type}'
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/configs/model/biaffine2o.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _base
5 |
6 | model:
7 | scorer:
8 | use_arc: True
9 | use_sib: True
10 | use_span: False
11 |
12 | loss:
13 | target:
14 | _target_: src.loss.dep_loss.Dep2O
15 | loss_type: 'mm'
16 |
17 | name: 'dep2o_${model.loss.loss_type}'
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/configs/model/span.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _base
5 |
6 | model:
7 | scorer:
8 | use_arc: False
9 | use_sib: False
10 | use_span: True
11 |
12 |
13 | loss:
14 | target:
15 | _target_: src.loss.dep_loss.Span
16 | loss_type: 'mm'
17 |
18 | name: 'span_${model.loss.loss_type}'
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/configs/model/span1o.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _base
5 |
6 | model:
7 | scorer:
8 | use_arc: True
9 | use_sib: False
10 | use_span: True
11 | span_scorer_type: biaffine
12 |
13 | loss:
14 | target:
15 | _target_: src.loss.dep_loss.DepSpanLoss
16 | loss_type: 'mm'
17 |
18 | name: 'dep1o_span_${model.loss.loss_type}'
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/configs/model/span1oheadplit.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _base
5 |
6 | model:
7 | scorer:
8 | use_arc: True
9 | use_sib: False
10 | use_span: True
11 | span_scorer_type: headsplit
12 |
13 | loss:
14 | target:
15 | _target_: src.loss.dep_loss.DepSpanLoss
16 | loss_type: 'mm'
17 |
18 | name: 'dep1o_span_headsplit_${model.loss.loss_type}'
19 |
20 |
21 |
--------------------------------------------------------------------------------
/configs/model/span2oheadplit.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - _base
5 |
6 | model:
7 | scorer:
8 | use_arc: True
9 | use_sib: True
10 | use_span: True
11 | span_scorer_type: headsplit
12 |
13 | loss:
14 | target:
15 | _target_: src.loss.dep_loss.DepSpanLoss
16 | loss_type: 'mm'
17 |
18 | name: 'dep2o_span_headsplit_${model.loss.loss_type}'
19 |
20 |
21 |
--------------------------------------------------------------------------------
/configs/optim/finetune_bert.yaml:
--------------------------------------------------------------------------------
1 | warmup: 0.1
2 | #beta1: 0.9
3 | #beta2: 0.9
4 | #eps: 1e-12
5 | lr: 5e-5
6 | #weight_decay: 0
7 | lr_rate: 50
8 | scheduler_type: 'linear_warmup'
9 |
10 |
--------------------------------------------------------------------------------
/configs/trainer/default_trainer.yaml:
--------------------------------------------------------------------------------
1 | _target_: pytorch_lightning.Trainer
2 | gpus: 1 # set -1 to train on all GPUs available, set 0 to train on CPU only
3 | min_epochs: 1
4 | max_epochs: 10
5 | gradient_clip_val: 5
6 | num_sanity_val_steps: 3
7 | progress_bar_refresh_rate: 20
8 | weights_summary: null
9 | fast_dev_run: False
10 | #default_root_dir: "lightning_logs/"
11 |
--------------------------------------------------------------------------------
/fastNLP/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 | fastNLP consists of the submodules :mod:`~fastNLP.core`, :mod:`~fastNLP.io`, :mod:`~fastNLP.embeddings`, :mod:`~fastNLP.modules`,
3 | :mod:`~fastNLP.models`, etc.; you can consult the documentation of each module.
4 |
5 | - :mod:`~fastNLP.core` is the core module of fastNLP, including components such as DataSet, Trainer and Tester. See :mod:`fastNLP.core`
6 | - :mod:`~fastNLP.io` implements input and output, including reading datasets and saving/loading models. See :mod:`fastNLP.io`
7 | - :mod:`~fastNLP.embeddings` provides the various embeddings needed for building complex network models. See :mod:`fastNLP.embeddings`
8 | - :mod:`~fastNLP.modules` contains many components for building neural network models, helping users quickly assemble the networks they need. See :mod:`fastNLP.modules`
9 | - :mod:`~fastNLP.models` contains complete network models implemented with fastNLP, including common models such as :class:`~fastNLP.models.CNNText` and :class:`~fastNLP.models.SeqLabeling`. See :mod:`fastNLP.models`
10 |
11 | The most commonly used components of fastNLP can be imported directly from the fastNLP package; their documentation is listed below:
12 | """
13 | __all__ = [
14 | "Instance",
15 | "FieldArray",
16 |
17 | "DataSetIter",
18 | "BatchIter",
19 | "TorchLoaderIter",
20 |
21 | "Vocabulary",
22 | "DataSet",
23 | "Const",
24 |
25 | "Trainer",
26 | "Tester",
27 |
28 | "DistTrainer",
29 | "get_local_rank",
30 |
31 | "Callback",
32 | "GradientClipCallback",
33 | "EarlyStopCallback",
34 | "FitlogCallback",
35 | "EvaluateCallback",
36 | "LRScheduler",
37 | "ControlC",
38 | "LRFinder",
39 | "TensorboardCallback",
40 | "WarmupCallback",
41 | 'SaveModelCallback',
42 | "CallbackException",
43 | "EarlyStopError",
44 | "CheckPointCallback",
45 |
46 | "Padder",
47 | "AutoPadder",
48 | "EngChar2DPadder",
49 |
50 | # "CollateFn",
51 | "ConcatCollateFn",
52 |
53 | "MetricBase",
54 | "AccuracyMetric",
55 | "SpanFPreRecMetric",
56 | "CMRC2018Metric",
57 | "ClassifyFPreRecMetric",
58 | "ConfusionMatrixMetric",
59 |
60 | "Optimizer",
61 | "SGD",
62 | "Adam",
63 | "AdamW",
64 |
65 | "Sampler",
66 | "SequentialSampler",
67 | "BucketSampler",
68 | "RandomSampler",
69 | "SortedSampler",
70 | "ConstantTokenNumSampler",
71 |
72 | "LossFunc",
73 | "CrossEntropyLoss",
74 | "MSELoss",
75 | "L1Loss",
76 | "BCELoss",
77 | "NLLLoss",
78 | "LossInForward",
79 | "LossBase",
80 | "CMRC2018Loss",
81 |
82 | "cache_results",
83 |
84 | 'logger',
85 | "init_logger_dist",
86 | ]
87 | __version__ = '0.5.6'
88 |
89 | import sys
90 |
91 | from . import embeddings
92 | from . import models
93 | from . import modules
94 | from .core import *
95 | from .doc_utils import doc_process
96 | from .io import loader, pipe
97 |
98 | doc_process(sys.modules[__name__])
99 |
--------------------------------------------------------------------------------
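As the module docstring above says, the most commonly used components can be imported straight from the `fastNLP` package; a minimal sketch using only names exported in `__all__`:

```python
from fastNLP import DataSet, Instance, Vocabulary

ds = DataSet()
ds.append(Instance(raw_words="a toy sentence".split()))
print(len(ds))  # 1
```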
/fastNLP/core/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 | The core module implements the core framework of fastNLP; the commonly used functionality can be imported directly from the fastNLP package. You can of course also import from the submodules of core;
3 | for example, the :class:`~fastNLP.DataSetIter` component can be imported in two ways::
4 |
5 |     # import directly from fastNLP
6 |     from fastNLP import DataSetIter
7 |
8 |     # import DataSetIter from the batch submodule of core
9 |     from fastNLP.core.batch import DataSetIter
10 |
11 | For the commonly used functionality, looking at :mod:`fastNLP` is enough. If you want to know what each submodule does, you can find the documentation of every submodule below.
12 |
13 | """
14 | __all__ = [
15 | "DataSet",
16 |
17 | "Instance",
18 |
19 | "FieldArray",
20 | "Padder",
21 | "AutoPadder",
22 | "EngChar2DPadder",
23 |
24 | "ConcatCollateFn",
25 |
26 | "Vocabulary",
27 |
28 | "DataSetIter",
29 | "BatchIter",
30 | "TorchLoaderIter",
31 |
32 | "Const",
33 |
34 | "Tester",
35 | "Trainer",
36 |
37 | "DistTrainer",
38 | "get_local_rank",
39 |
40 | "cache_results",
41 | "seq_len_to_mask",
42 | "get_seq_len",
43 | "logger",
44 | "init_logger_dist",
45 |
46 | "Callback",
47 | "GradientClipCallback",
48 | "EarlyStopCallback",
49 | "FitlogCallback",
50 | "EvaluateCallback",
51 | "LRScheduler",
52 | "ControlC",
53 | "LRFinder",
54 | "TensorboardCallback",
55 | "WarmupCallback",
56 | 'SaveModelCallback',
57 | "CallbackException",
58 | "EarlyStopError",
59 | "CheckPointCallback",
60 |
61 | "LossFunc",
62 | "CrossEntropyLoss",
63 | "L1Loss",
64 | "BCELoss",
65 | "NLLLoss",
66 | "LossInForward",
67 | "CMRC2018Loss",
68 | "MSELoss",
69 | "LossBase",
70 |
71 | "MetricBase",
72 | "AccuracyMetric",
73 | "SpanFPreRecMetric",
74 | "CMRC2018Metric",
75 | "ClassifyFPreRecMetric",
76 | "ConfusionMatrixMetric",
77 |
78 | "Optimizer",
79 | "SGD",
80 | "Adam",
81 | "AdamW",
82 |
83 | "SequentialSampler",
84 | "BucketSampler",
85 | "RandomSampler",
86 | "Sampler",
87 | "SortedSampler",
88 | "ConstantTokenNumSampler"
89 | ]
90 |
91 | from ._logger import logger, init_logger_dist
92 | from .batch import DataSetIter, BatchIter, TorchLoaderIter
93 | from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
94 | LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, CallbackException, \
95 | EarlyStopError, CheckPointCallback
96 | from .const import Const
97 | from .dataset import DataSet
98 | from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
99 | from .instance import Instance
100 | from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, \
101 | LossInForward, CMRC2018Loss, LossBase, MSELoss
102 | from .metrics import AccuracyMetric, SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, MetricBase,\
103 | ConfusionMatrixMetric
104 | from .optimizer import Optimizer, SGD, Adam, AdamW
105 | from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler, SortedSampler, ConstantTokenNumSampler
106 | from .tester import Tester
107 | from .trainer import Trainer
108 | from .utils import cache_results, seq_len_to_mask, get_seq_len
109 | from .vocabulary import Vocabulary
110 | from .collate_fn import ConcatCollateFn
111 | from .dist_trainer import DistTrainer, get_local_rank
112 |
--------------------------------------------------------------------------------
/fastNLP/core/_parallel_utils.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = []
4 |
5 | import threading
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn.parallel.parallel_apply import get_a_var
10 | from torch.nn.parallel.replicate import replicate
11 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
12 |
13 |
14 | def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
15 | r"""Applies each `module` in :attr:`modules` in parallel on arguments
16 | contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
17 | on each of :attr:`devices`.
18 |
19 | :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
20 | :attr:`devices` (if given) should all have same length. Moreover, each
21 | element of :attr:`inputs` can either be a single object as the only argument
22 | to a module, or a collection of positional arguments.
23 | """
24 | assert len(modules) == len(inputs)
25 | if kwargs_tup is not None:
26 | assert len(modules) == len(kwargs_tup)
27 | else:
28 | kwargs_tup = ({},) * len(modules)
29 | if devices is not None:
30 | assert len(modules) == len(devices)
31 | else:
32 | devices = [None] * len(modules)
33 |
34 | lock = threading.Lock()
35 | results = {}
36 | grad_enabled = torch.is_grad_enabled()
37 |
38 | def _worker(i, module, input, kwargs, device=None):
39 | torch.set_grad_enabled(grad_enabled)
40 | if device is None:
41 | device = get_a_var(input).get_device()
42 | try:
43 | with torch.cuda.device(device):
44 | # this also avoids accidental slicing of `input` if it is a Tensor
45 | if not isinstance(input, (list, tuple)):
46 | input = (input,)
47 | output = getattr(module, func_name)(*input, **kwargs)
48 | with lock:
49 | results[i] = output
50 | except Exception as e:
51 | with lock:
52 | results[i] = e
53 |
54 | if len(modules) > 1:
55 | threads = [threading.Thread(target=_worker,
56 | args=(i, module, input, kwargs, device))
57 | for i, (module, input, kwargs, device) in
58 | enumerate(zip(modules, inputs, kwargs_tup, devices))]
59 |
60 | for thread in threads:
61 | thread.start()
62 | for thread in threads:
63 | thread.join()
64 | else:
65 | _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
66 |
67 | outputs = []
68 | for i in range(len(inputs)):
69 | output = results[i]
70 | if isinstance(output, Exception):
71 | raise output
72 | outputs.append(output)
73 | return outputs
74 |
75 |
76 | def _data_parallel_wrapper(func_name, device_ids, output_device):
77 | r"""
78 |     A wrapper for functions of a network that need to run on multiple GPUs; modelled on nn.DataParallel's forward function.
79 |
80 |     :param str func_name: the function of the network to run on multiple GPUs
81 |     :param device_ids: the device_ids of nn.DataParallel
82 |     :param output_device: the output_device of nn.DataParallel
83 | :return:
84 | """
85 |
86 | def wrapper(network, *inputs, **kwargs):
87 | inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
88 | if len(device_ids) == 1:
89 | return getattr(network, func_name)(*inputs[0], **kwargs[0])
90 | replicas = replicate(network, device_ids[:len(inputs)])
91 | outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
92 | return gather(outputs, output_device)
93 |
94 | return wrapper
95 |
96 |
97 | def _model_contains_inner_module(model):
98 | r"""
99 |
100 |     :param nn.Module model: the model; checks whether it internally contains model.module, which is mostly used to check whether the model is nn.DataParallel or
101 |         nn.parallel.DistributedDataParallel. The functions of the innermost model are needed when matching call arguments.
102 | :return: bool
103 | """
104 | if isinstance(model, nn.Module):
105 | if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
106 | return True
107 | return False
108 |
--------------------------------------------------------------------------------
/fastNLP/core/collate_fn.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 | from builtins import sorted
3 |
4 | import torch
5 | import numpy as np
6 | from .field import _get_ele_type_and_dim
7 | from .utils import logger
8 | from copy import deepcopy
9 |
10 |
11 | def _check_type(batch_dict, fields):
12 | if len(fields) == 0:
13 | raise RuntimeError
14 | types = []
15 | dims = []
16 | for f in fields:
17 | t, d = _get_ele_type_and_dim(batch_dict[f])
18 | types.append(t)
19 | dims.append(d)
20 | diff_types = set(types)
21 | diff_dims = set(dims)
22 | if len(diff_types) > 1 or len(diff_dims) > 1:
23 | raise ValueError
24 | return types[0]
25 |
26 |
27 | def batching(samples, max_len=0, padding_val=0):
28 | if len(samples) == 0:
29 | return samples
30 | if max_len <= 0:
31 | max_len = max(s.shape[0] for s in samples)
32 | batch = np.full((len(samples), max_len), fill_value=padding_val)
33 | for i, s in enumerate(samples):
34 | slen = min(s.shape[0], max_len)
35 | batch[i][:slen] = s[:slen]
36 | return batch
37 |
38 |
39 | class Collater:
40 | r"""
41 |     A helper class that manages collate_fn functions for a DataSet.
42 |
43 | """
44 | def __init__(self):
45 | self.collate_fns = {}
46 |
47 | def add_fn(self, fn, name=None):
48 | r"""
49 |         Add a new collate_fn to the collater.
50 |
51 | :param callable fn:
52 | :param str,int name:
53 | :return:
54 | """
55 | if name in self.collate_fns:
56 | logger.warn(f"collate_fn:{name} will be overwritten.")
57 | if name is None:
58 | name = len(self.collate_fns)
59 | self.collate_fns[name] = fn
60 |
61 | def is_empty(self):
62 | r"""
63 |         Return whether any collate_fn has been registered.
64 |
65 | :return:
66 | """
67 | return len(self.collate_fns) == 0
68 |
69 | def delete_fn(self, name=None):
70 | r"""
71 |         Delete a collate_fn.
72 |
73 |         :param str,int name: if None, the most recently added collate_fn is deleted
74 | :return:
75 | """
76 | if not self.is_empty():
77 | if name in self.collate_fns:
78 | self.collate_fns.pop(name)
79 | elif name is None:
80 | last_key = list(self.collate_fns.keys())[0]
81 | self.collate_fns.pop(last_key)
82 |
83 | def collate_batch(self, ins_list):
84 | bx, by = {}, {}
85 | for name, fn in self.collate_fns.items():
86 | try:
87 | batch_x, batch_y = fn(ins_list)
88 | except BaseException as e:
89 | logger.error(f"Exception:`{e}` happens when call collate_fn:`{name}`.")
90 | raise e
91 | bx.update(batch_x)
92 | by.update(batch_y)
93 | return bx, by
94 |
95 | def copy_from(self, col):
96 | assert isinstance(col, Collater)
97 | new_col = Collater()
98 | new_col.collate_fns = deepcopy(col.collate_fns)
99 | return new_col
100 |
101 |
102 | class ConcatCollateFn:
103 | r"""
104 |     A collate_fn that concatenates fields: the listed fields are concatenated in order and then padded to produce the batch.
105 |
106 |     :param List[str] inputs: which fields to concatenate; currently only 1-d fields are supported
107 |     :param str output: name of the concatenated field
108 |     :param pad_val: value used for padding
109 |     :param max_len: maximum length after concatenation
110 |     :param is_input: whether to set the generated output as an input field
111 |     :param is_target: whether to set the generated output as a target field
112 | """
113 |
114 | def __init__(self, inputs, output, pad_val=0, max_len=0, is_input=True, is_target=False):
115 | super().__init__()
116 | assert isinstance(inputs, list)
117 | self.inputs = inputs
118 | self.output = output
119 | self.pad_val = pad_val
120 | self.max_len = max_len
121 | self.is_input = is_input
122 | self.is_target = is_target
123 |
124 | @staticmethod
125 | def _to_numpy(seq):
126 | if torch.is_tensor(seq):
127 | return seq.numpy()
128 | else:
129 | return np.array(seq)
130 |
131 | def __call__(self, ins_list):
132 | samples = []
133 | for i, ins in ins_list:
134 | sample = []
135 | for input_name in self.inputs:
136 | sample.append(self._to_numpy(ins[input_name]))
137 | samples.append(np.concatenate(sample, axis=0))
138 | batch = batching(samples, max_len=self.max_len, padding_val=self.pad_val)
139 | b_x, b_y = {}, {}
140 | if self.is_input:
141 | b_x[self.output] = batch
142 | if self.is_target:
143 | b_y[self.output] = batch
144 |
145 | return b_x, b_y
146 |
--------------------------------------------------------------------------------
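A quick sketch of how `ConcatCollateFn` consumes a batch, using toy `Instance` objects and hypothetical field names (`words`, `chars`):

```python
from fastNLP import Instance
from fastNLP.core.collate_fn import ConcatCollateFn

# (index, Instance) pairs, mirroring what DataSetIter hands to Collater.collate_batch
ins_list = [
    (0, Instance(words=[1, 2, 3], chars=[7, 8])),
    (1, Instance(words=[4, 5], chars=[9])),
]
fn = ConcatCollateFn(inputs=['words', 'chars'], output='words_chars', pad_val=0)
batch_x, batch_y = fn(ins_list)
print(batch_x['words_chars'].shape)  # (2, 5); the shorter row is padded with 0
```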
/fastNLP/core/const.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Field names inside the fastNLP package follow a fixed convention, which is defined by the fastNLP.Const class.
3 | """
4 |
5 | __all__ = [
6 | "Const"
7 | ]
8 |
9 |
10 | class Const:
11 | r"""
12 |     Constants for field names in fastNLP.
13 |
14 |     .. todo::
15 |         turn the list below into a table
16 |
17 |     The full list::
18 |
19 |         INPUT       the model's sequence input, words (use words1, words2, ... when there are multiple word columns)
20 |         CHAR_INPUT  the model's character input, chars (use chars1, chars2 when there are multiple char columns)
21 |         INPUT_LEN   sequence length, seq_len (use seq_len1, seq_len2 when there are multiple columns)
22 |         OUTPUT      model output, pred (use pred1, pred2 when there are multiple columns)
23 |         TARGET      ground-truth target, target (use target1, target2 when there are multiple columns)
24 |         LOSS        loss value, loss (use loss1, loss2 when there are multiple columns)
25 |         RAW_WORD    raw words of the text, raw_words (use raw_words1, raw_words2 when there are multiple columns)
26 |         RAW_CHAR    raw characters of the text, raw_chars (use raw_chars1, raw_chars2 when there are multiple columns)
27 |
28 | """
29 | INPUT = 'words'
30 | CHAR_INPUT = 'chars'
31 | INPUT_LEN = 'seq_len'
32 | OUTPUT = 'pred'
33 | TARGET = 'target'
34 | LOSS = 'loss'
35 | RAW_WORD = 'raw_words'
36 | RAW_CHAR = 'raw_chars'
37 |
38 | @staticmethod
39 | def INPUTS(i):
40 | r"""得到第 i 个 ``INPUT`` 的命名"""
41 | i = int(i) + 1
42 | return Const.INPUT + str(i)
43 |
44 | @staticmethod
45 | def CHAR_INPUTS(i):
46 | r"""得到第 i 个 ``CHAR_INPUT`` 的命名"""
47 | i = int(i) + 1
48 | return Const.CHAR_INPUT + str(i)
49 |
50 | @staticmethod
51 | def RAW_WORDS(i):
52 | r"""得到第 i 个 ``RAW_WORDS`` 的命名"""
53 | i = int(i) + 1
54 | return Const.RAW_WORD + str(i)
55 |
56 | @staticmethod
57 | def RAW_CHARS(i):
58 | r"""得到第 i 个 ``RAW_CHARS`` 的命名"""
59 | i = int(i) + 1
60 | return Const.RAW_CHAR + str(i)
61 |
62 | @staticmethod
63 | def INPUT_LENS(i):
64 | r"""得到第 i 个 ``INPUT_LEN`` 的命名"""
65 | i = int(i) + 1
66 | return Const.INPUT_LEN + str(i)
67 |
68 | @staticmethod
69 | def OUTPUTS(i):
70 | r"""得到第 i 个 ``OUTPUT`` 的命名"""
71 | i = int(i) + 1
72 | return Const.OUTPUT + str(i)
73 |
74 | @staticmethod
75 | def TARGETS(i):
76 | r"""得到第 i 个 ``TARGET`` 的命名"""
77 | i = int(i) + 1
78 | return Const.TARGET + str(i)
79 |
80 | @staticmethod
81 | def LOSSES(i):
82 | r"""得到第 i 个 ``LOSS`` 的命名"""
83 | i = int(i) + 1
84 | return Const.LOSS + str(i)
85 |
--------------------------------------------------------------------------------
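A small illustration of the naming scheme defined by `Const`:

```python
from fastNLP import Const

print(Const.INPUT)       # 'words'
print(Const.INPUTS(0))   # 'words1'  (column index is 0-based, the name suffix is 1-based)
print(Const.TARGETS(1))  # 'target2'
```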
/fastNLP/core/instance.py:
--------------------------------------------------------------------------------
1 | r"""
2 | The instance module implements the Instance class, which corresponds to a sample in fastNLP; a sample can be regarded as an object of type Instance.
3 | For an illustrative example, see the table in the documentation of :mod:`fastNLP.core.dataset`.
4 |
5 | """
6 |
7 | __all__ = [
8 | "Instance"
9 | ]
10 |
11 | from .utils import pretty_table_printer
12 |
13 |
14 | class Instance(object):
15 | r"""
16 |     Instance is the class that corresponds to a single sample in fastNLP; every sample is an Instance object.
17 |     Instance is generally used together with :class:`~fastNLP.DataSet`. An Instance is initialized as in the Example below::
18 |
19 | >>>from fastNLP import Instance
20 | >>>ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
21 | >>>ins["field_1"]
22 | [1, 1, 1]
23 | >>>ins.add_field("field_3", [3, 3, 3])
24 | >>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))})
25 | """
26 |
27 | def __init__(self, **fields):
28 |
29 | self.fields = fields
30 |
31 | def add_field(self, field_name, field):
32 | r"""
33 |         Add a field to the Instance.
34 |
35 |         :param str field_name: name of the new field
36 |         :param Any field: content of the new field
37 | """
38 | self.fields[field_name] = field
39 |
40 | def items(self):
41 | r"""
42 |         Return an iterator that yields two items at a time: the first is the field_name, the second is the field_value.
43 |
44 |         :return: an iterator
45 | """
46 | return self.fields.items()
47 |
48 | def __contains__(self, item):
49 | return item in self.fields
50 |
51 | def __getitem__(self, name):
52 | if name in self.fields:
53 | return self.fields[name]
54 | else:
55 | print(name)
56 | raise KeyError("{} not found".format(name))
57 |
58 | def __setitem__(self, name, field):
59 | return self.add_field(name, field)
60 |
61 | def __repr__(self):
62 | return str(pretty_table_printer(self))
63 |
--------------------------------------------------------------------------------
/fastNLP/core/predictor.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "Predictor"
5 | ]
6 |
7 | from collections import defaultdict
8 |
9 | import torch
10 |
11 | from . import DataSet
12 | from . import DataSetIter
13 | from . import SequentialSampler
14 | from .utils import _build_args, _move_dict_value_to_device, _get_model_device
15 |
16 |
17 | class Predictor(object):
18 | r"""
19 |     A predictor that produces outputs from a trained model.
20 |
21 |     Unlike a Tester, the Predictor does not care about evaluation metrics; it only runs inference.
22 |     It is a high-level model wrapper called by fastNLP and shares no operations with Trainer or Tester.
23 | """
24 |
25 | def __init__(self, network):
26 | r"""
27 |
28 |         :param torch.nn.Module network: the model used to perform prediction
29 | """
30 | if not isinstance(network, torch.nn.Module):
31 | raise ValueError(
32 | "Only fastNLP.models.BaseModel or torch.nn,Module is allowed, not {}".format(type(network)))
33 | self.network = network
34 | self.batch_size = 1
35 | self.batch_output = []
36 |
37 | def predict(self, data: DataSet, seq_len_field_name=None):
38 | r"""用已经训练好的模型进行inference.
39 |
40 | :param fastNLP.DataSet data: 待预测的数据集
41 | :param str seq_len_field_name: 表示序列长度信息的field名字
42 | :return: dict dict里面的内容为模型预测的结果
43 | """
44 | if not isinstance(data, DataSet):
45 | raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
46 | if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
47 | raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
48 |
49 | prev_training = self.network.training
50 | self.network.eval()
51 | network_device = _get_model_device(self.network)
52 | batch_output = defaultdict(list)
53 | data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
54 |
55 | if hasattr(self.network, "predict"):
56 | predict_func = self.network.predict
57 | else:
58 | predict_func = self.network.forward
59 |
60 | with torch.no_grad():
61 | for batch_x, _ in data_iterator:
62 | _move_dict_value_to_device(batch_x, _, device=network_device)
63 | refined_batch_x = _build_args(predict_func, **batch_x)
64 | prediction = predict_func(**refined_batch_x)
65 |
66 | if seq_len_field_name is not None:
67 | seq_lens = batch_x[seq_len_field_name].tolist()
68 |
69 | for key, value in prediction.items():
70 | value = value.cpu().numpy()
71 | if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
72 | batch_output[key].extend(value.tolist())
73 | else:
74 | if seq_len_field_name is not None:
75 | tmp_batch = []
76 | for idx, seq_len in enumerate(seq_lens):
77 | tmp_batch.append(value[idx, :seq_len])
78 | batch_output[key].extend(tmp_batch)
79 | else:
80 | batch_output[key].append(value)
81 |
82 | self.network.train(prev_training)
83 | return batch_output
84 |
--------------------------------------------------------------------------------
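A minimal sketch of the intended usage, assuming `model` is a trained `torch.nn.Module` and `test_data` is a fastNLP `DataSet` whose input fields are already set (both names are placeholders):

```python
from fastNLP.core.predictor import Predictor

predictor = Predictor(model)
outputs = predictor.predict(test_data, seq_len_field_name='seq_len')
# `outputs` maps each prediction key to a list with one entry per instance.
```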
/fastNLP/doc_utils.py:
--------------------------------------------------------------------------------
1 | r"""undocumented
2 | Helper code for generating the fastNLP documentation
3 | """
4 |
5 | __all__ = []
6 |
7 | import inspect
8 | import sys
9 |
10 |
11 | def doc_process(m):
12 | for name, obj in inspect.getmembers(m):
13 | if inspect.isclass(obj) or inspect.isfunction(obj):
14 | if obj.__module__ != m.__name__:
15 | if obj.__doc__ is None:
16 | # print(name, obj.__doc__)
17 | pass
18 | else:
19 | module_name = obj.__module__
20 |
21 |                     # identify and mark where classes and functions sit at the different module levels
22 |
23 | while 1:
24 | defined_m = sys.modules[module_name]
25 | try:
26 | if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__:
27 | obj.__doc__ = r"别名 :class:`" + m.__name__ + "." + name + "`" \
28 | + " :class:`" + module_name + "." + name + "`\n" + obj.__doc__
29 | break
30 | module_name = ".".join(module_name.split('.')[:-1])
31 | if module_name == m.__name__:
32 | # print(name, ": not found defined doc.")
33 | break
34 | except:
35 | print("Warning: Module {} lacks `__doc__`".format(module_name))
36 | break
37 |
38 |             # identify and mark base classes; only shown when the base class is also defined in fastNLP
39 |
40 | if inspect.isclass(obj):
41 | for base in obj.__bases__:
42 | if base.__module__.startswith("fastNLP"):
43 | parts = base.__module__.split(".") + []
44 | module_name, i = "fastNLP", 1
45 | for i in range(len(parts) - 1):
46 | defined_m = sys.modules[module_name]
47 | try:
48 | if "undocumented" not in defined_m.__doc__ and name in defined_m.__all__:
49 | obj.__doc__ = r"基类 :class:`" + defined_m.__name__ + "." + base.__name__ + "` \n\n" + obj.__doc__
50 | break
51 | module_name += "." + parts[i + 1]
52 | except:
53 | print("Warning: Module {} lacks `__doc__`".format(module_name))
54 | break
55 |
--------------------------------------------------------------------------------
/fastNLP/embeddings/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 | The embeddings module is mainly used to obtain distributed representations of words from various pretrained models; the currently supported pretrained models include word2vec, GloVe, ELMo, BERT, etc. The forward input of every
3 | embedding here is a torch.LongTensor of shape ``(batch_size, max_len)``, and the output is a torch.FloatTensor of shape ``(batch_size, max_len, embedding_dim)``.
4 | For every embedding, `self.num_embedding` gives the largest valid input index, and `self.embedding_dim` or `self.embed_size` gives the
5 | output dimension of the embedding.
6 | """
7 |
8 | __all__ = [
9 | "Embedding",
10 | "TokenEmbedding",
11 | "StaticEmbedding",
12 | "ElmoEmbedding",
13 | "BertEmbedding",
14 | "BertWordPieceEncoder",
15 |
16 | "RobertaEmbedding",
17 | "RobertaWordPieceEncoder",
18 |
19 | "GPT2Embedding",
20 | "GPT2WordPieceEncoder",
21 |
22 | "StackEmbedding",
23 | "LSTMCharEmbedding",
24 | "CNNCharEmbedding",
25 |
26 | "get_embeddings",
27 | "get_sinusoid_encoding_table"
28 | ]
29 |
30 | from .embedding import Embedding, TokenEmbedding
31 | from .static_embedding import StaticEmbedding
32 | from .elmo_embedding import ElmoEmbedding
33 | from .bert_embedding import BertEmbedding, BertWordPieceEncoder
34 | from .roberta_embedding import RobertaEmbedding, RobertaWordPieceEncoder
35 | from .gpt2_embedding import GPT2WordPieceEncoder, GPT2Embedding
36 | from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding
37 | from .stack_embedding import StackEmbedding
38 | from .utils import get_embeddings, get_sinusoid_encoding_table
39 |
40 | import sys
41 | from ..doc_utils import doc_process
42 | doc_process(sys.modules[__name__])
--------------------------------------------------------------------------------
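A minimal sketch of the input/output contract described above, assuming `StaticEmbedding`'s randomly-initialized mode (`model_dir_or_name=None` with an explicit `embedding_dim`):

```python
import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary().add_word_lst("the weather is good .".split())
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=50)
words = torch.LongTensor([[vocab.to_index(w) for w in "the weather is good .".split()]])
print(embed(words).shape)  # torch.Size([1, 5, 50]): (batch_size, max_len, embedding_dim)
```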
/fastNLP/embeddings/contextual_embedding.py:
--------------------------------------------------------------------------------
1 | r"""
2 | .. todo::
3 | doc
4 | """
5 |
6 | __all__ = [
7 | "ContextualEmbedding"
8 | ]
9 |
10 | from abc import abstractmethod
11 |
12 | import torch
13 |
14 | from .embedding import TokenEmbedding
15 | from ..core import logger
16 | from ..core.batch import DataSetIter
17 | from ..core.dataset import DataSet
18 | from ..core.sampler import SequentialSampler
19 | from ..core.utils import _move_model_to_device, _get_model_device
20 | from ..core.vocabulary import Vocabulary
21 |
22 |
23 | class ContextualEmbedding(TokenEmbedding):
24 | r"""
25 |     The ContextualEmbedding component, the base class of BertEmbedding and ElmoEmbedding.
26 | """
27 | def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):
28 | super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
29 |
30 | def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool = True):
31 | r"""
32 |         Since generating dynamic embeddings is time-consuming, the embedding of every sentence can be cached so that the generation process does not have to be rerun each time.
33 |
34 |         :param datasets: DataSet objects
35 |         :param batch_size: int, batch size used when generating the cached sentence representations
36 |         :param device: see the device argument of :class:`fastNLP.Trainer`
37 |         :param delete_weights: whether to delete the model weights once the cache has been generated; when the dynamic model does not need fine-tuning, deleting the weights greatly reduces memory usage.
38 | :return:
39 | """
40 | for index, dataset in enumerate(datasets):
41 | try:
42 | assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed."
43 | assert 'words' in dataset.get_input_name(), "`words` field has to be set as input."
44 | except Exception as e:
45 | logger.error(f"Exception happens at {index} dataset.")
46 | raise e
47 |
48 | sent_embeds = {}
49 | _move_model_to_device(self, device=device)
50 | device = _get_model_device(self)
51 | pad_index = self._word_vocab.padding_idx
52 | logger.info("Start to calculate sentence representations.")
53 | with torch.no_grad():
54 | for index, dataset in enumerate(datasets):
55 | try:
56 | batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
57 | for batch_x, batch_y in batch:
58 | words = batch_x['words'].to(device)
59 | words_list = words.tolist()
60 | seq_len = words.ne(pad_index).sum(dim=-1)
61 | max_len = words.size(1)
62 |                         # some inputs may contain CLS and SEP, so it is safer to count from the end.
63 | seq_len_from_behind = (max_len - seq_len).tolist()
64 | word_embeds = self(words).detach().cpu().numpy()
65 | for b in range(words.size(0)):
66 | length = seq_len_from_behind[b]
67 | if length == 0:
68 | sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b]
69 | else:
70 | sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length]
71 | except Exception as e:
72 | logger.error(f"Exception happens at {index} dataset.")
73 | raise e
74 | logger.info("Finish calculating sentence representations.")
75 | self.sent_embeds = sent_embeds
76 | if delete_weights:
77 | self._delete_model_weights()
78 |
79 | def _get_sent_reprs(self, words):
80 | r"""
81 |         Get the sentence representations; if a cache exists, return the cached values, otherwise return None.
82 |
83 | :param words: torch.LongTensor
84 | :return:
85 | """
86 | if hasattr(self, 'sent_embeds'):
87 | words_list = words.tolist()
88 | seq_len = words.ne(self._word_pad_index).sum(dim=-1)
89 | _embeds = []
90 | for b in range(len(words)):
91 | words_i = tuple(words_list[b][:seq_len[b]])
92 | embed = self.sent_embeds[words_i]
93 | _embeds.append(embed)
94 | max_sent_len = max(map(len, _embeds))
95 | embeds = words.new_zeros(len(_embeds), max_sent_len, self.embed_size, dtype=torch.float,
96 | device=words.device)
97 | for i, embed in enumerate(_embeds):
98 | embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device)
99 | return embeds
100 | return None
101 |
102 | @abstractmethod
103 | def _delete_model_weights(self):
104 | r"""删除计算表示的模型以节省资源"""
105 | raise NotImplementedError
106 |
107 | def remove_sentence_cache(self):
108 | r"""
109 |         Delete the cached sentence representations. Afterwards, if the model weights have not been deleted, representations will again be computed dynamically.
110 |
111 | :return:
112 | """
113 | del self.sent_embeds
114 |
--------------------------------------------------------------------------------
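A sketch of the caching workflow that `add_sentence_cache` describes; `vocab`, `train_data` and `dev_data` are placeholders, and the BERT model name is an assumption:

```python
from fastNLP.embeddings import BertEmbedding  # BertEmbedding subclasses ContextualEmbedding

embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased')  # model name assumed
embed.add_sentence_cache(train_data, dev_data, batch_size=32, device='cuda:0', delete_weights=True)
# Later forward passes read the cached sentence representations instead of re-running BERT.
```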
/fastNLP/embeddings/stack_embedding.py:
--------------------------------------------------------------------------------
1 | r"""
2 | .. todo::
3 | doc
4 | """
5 |
6 | __all__ = [
7 | "StackEmbedding",
8 | ]
9 |
10 | from typing import List
11 |
12 | import torch
13 | from torch import nn as nn
14 |
15 | from .embedding import TokenEmbedding
16 |
17 |
18 | class StackEmbedding(TokenEmbedding):
19 | r"""
20 |     Combines several embeddings into a single embedding.
21 |
22 | Example::
23 |
24 | >>> from fastNLP import Vocabulary
25 | >>> from fastNLP.embeddings import StaticEmbedding, StackEmbedding
26 | >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
27 | >>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True)
28 | >>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
29 | >>> embed = StackEmbedding([embed_1, embed_2])
30 |
31 | """
32 |
33 | def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0):
34 | r"""
35 |
36 |         :param embeds: a list of TokenEmbedding objects; every TokenEmbedding is required to share the same vocabulary
37 |         :param float word_dropout: the probability of replacing a word with unk. This both trains unk and acts as a form of regularization. The different embeddings
38 |             are set to unknown at the same positions. If dropout is set here, do not set dropout on the individual embeddings again.
39 |         :param float dropout: the probability of applying Dropout to the embedding representation; 0.1 means 10% of the values are randomly set to 0.
40 | """
41 | vocabs = []
42 | for embed in embeds:
43 | if hasattr(embed, 'get_word_vocab'):
44 | vocabs.append(embed.get_word_vocab())
45 | _vocab = vocabs[0]
46 | for vocab in vocabs[1:]:
47 | assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary."
48 |
49 | super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
50 | assert isinstance(embeds, list)
51 | for embed in embeds:
52 | assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."
53 | self.embeds = nn.ModuleList(embeds)
54 | self._embed_size = sum([embed.embed_size for embed in self.embeds])
55 |
56 | def append(self, embed: TokenEmbedding):
57 | r"""
58 |         Append an embedding at the end.
59 | :param embed:
60 | :return:
61 | """
62 | assert isinstance(embed, TokenEmbedding)
63 | self._embed_size += embed.embed_size
64 | self.embeds.append(embed)
65 | return self
66 |
67 | def pop(self):
68 | r"""
69 |         Pop the last embedding.
70 | :return:
71 | """
72 | embed = self.embeds.pop()
73 | self._embed_size -= embed.embed_size
74 | return embed
75 |
76 | @property
77 | def embed_size(self):
78 | r"""
79 |         The size of the last dimension of the vectors produced by this embedding.
80 | :return:
81 | """
82 | return self._embed_size
83 |
84 | def forward(self, words):
85 | r"""
86 |         Compute the outputs of the individual embeddings and concatenate them in order.
87 |
88 |         :param words: batch_size x max_len
89 |         :return: the returned shape depends on which embeddings make up this stack embedding
90 | """
91 | outputs = []
92 | words = self.drop_word(words)
93 | for embed in self.embeds:
94 | outputs.append(embed(words))
95 | outputs = self.dropout(torch.cat(outputs, dim=-1))
96 | return outputs
97 |
--------------------------------------------------------------------------------
/fastNLP/embeddings/utils.py:
--------------------------------------------------------------------------------
1 | r"""
2 | .. todo::
3 | doc
4 | """
5 | import numpy as np
6 | import torch
7 | from torch import nn as nn
8 |
9 | from ..core.vocabulary import Vocabulary
10 |
11 | __all__ = [
12 | 'get_embeddings',
13 | 'get_sinusoid_encoding_table'
14 | ]
15 |
16 |
17 | def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1, include_word_start_end=True):
18 | r"""
19 |     Build a character vocabulary from a word vocabulary.
20 |
21 |     :param vocab: the word vocabulary to build from
22 |     :param min_freq:
23 |     :param include_word_start_end: whether to include the special word-start and word-end tokens
24 | :return:
25 | """
26 | char_vocab = Vocabulary(min_freq=min_freq)
27 | for word, index in vocab:
28 | if not vocab._is_word_no_create_entry(word):
29 | char_vocab.add_word_lst(list(word))
30 | if include_word_start_end:
31 |         char_vocab.add_word_lst(['<bow>', '<eow>'])  # special word-start / word-end tokens
32 | return char_vocab
33 |
34 |
35 | def get_embeddings(init_embed, padding_idx=None):
36 | r"""
37 |     Return an Embedding object according to init_embed. If the input is a tuple, a randomly initialized nn.Embedding is created; if the input is a numpy.ndarray,
38 |     the nn.Embedding is initialized with the values of the ndarray; if the input is a torch.Tensor, the nn.Embedding is initialized with that value; if the input is a fastNLP embedding,
39 |     it is returned unchanged.
40 |
41 |     :param init_embed: can be a tuple (num_embeddings, embedding_dim), i.e. the size of the embedding and the dimension of each word; an
42 |         nn.Embedding object, which is then used as the embedding directly; an np.ndarray, whose values are used to initialize the Embedding;
43 |         or a torch.Tensor, whose values are used to initialize the Embedding.
44 |     :param padding_idx: only effective when a tuple is passed
45 | :return nn.Embedding: embeddings
46 | """
47 | if isinstance(init_embed, tuple):
48 | res = nn.Embedding(
49 | num_embeddings=init_embed[0], embedding_dim=init_embed[1], padding_idx=padding_idx)
50 | nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)),
51 | b=np.sqrt(3 / res.weight.data.size(1)))
52 | elif isinstance(init_embed, nn.Module):
53 | res = init_embed
54 | elif isinstance(init_embed, torch.Tensor):
55 | res = nn.Embedding.from_pretrained(init_embed, freeze=False)
56 | elif isinstance(init_embed, np.ndarray):
57 | init_embed = torch.tensor(init_embed, dtype=torch.float32)
58 | res = nn.Embedding.from_pretrained(init_embed, freeze=False)
59 | else:
60 | raise TypeError(
61 | 'invalid init_embed type: {}'.format((type(init_embed))))
62 | return res
63 |
64 |
65 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
66 | """
67 |     A sinusoid positional embedding: in the representation of each position, the even dimensions (0, 2, 4, ...) use sin and the odd dimensions (1, 3, 5, ...) use cos.
68 |
69 |     :param int n_position: how many positions in total
70 |     :param int d_hid: number of dimensions; must be even
71 |     :param padding_idx:
72 |     :return: torch.FloatTensor of shape n_position x d_hid
73 | """
74 |
75 | def cal_angle(position, hid_idx):
76 | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
77 |
78 | def get_posi_angle_vec(position):
79 | return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
80 |
81 | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)])
82 |
83 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
84 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
85 |
86 | if padding_idx is not None:
87 | # zero vector for padding dimension
88 | sinusoid_table[padding_idx] = 0.
89 |
90 | return torch.FloatTensor(sinusoid_table)
91 |
92 |
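A short sketch of how get_embeddings dispatches on the type of init_embed, and of the shape returned by get_sinusoid_encoding_table; the sizes are arbitrary::

    import numpy as np
    from fastNLP.embeddings.utils import get_embeddings, get_sinusoid_encoding_table

    emb_from_tuple = get_embeddings((100, 32))                # randomly initialised nn.Embedding
    emb_from_array = get_embeddings(np.random.rand(100, 32))  # weights copied from the ndarray
    pos_table = get_sinusoid_encoding_table(n_position=512, d_hid=32)

    print(emb_from_tuple.weight.shape)  # torch.Size([100, 32])
    print(emb_from_array.weight.shape)  # torch.Size([100, 32])
    print(pos_table.shape)              # torch.Size([512, 32])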
--------------------------------------------------------------------------------
/fastNLP/io/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Module for IO. It contains:
3 |
4 | 1. the :mod:`EmbedLoader` class for loading embeddings,
5 |
6 | 2. the :mod:`Loader` classes for reading data in different formats,
7 |
8 | 3. the :mod:`Pipe` classes for processing the loaded data,
9 |
10 | 4. classes for saving and loading models, see the :mod:`model_io` module.
11 |
12 | These classes are used as follows:
13 | """
14 | __all__ = [
15 | 'DataBundle',
16 |
17 | 'EmbedLoader',
18 |
19 | 'Loader',
20 |
21 | 'CLSBaseLoader',
22 | 'AGsNewsLoader',
23 | 'DBPediaLoader',
24 | 'YelpFullLoader',
25 | 'YelpPolarityLoader',
26 | 'IMDBLoader',
27 | 'SSTLoader',
28 | 'SST2Loader',
29 | "ChnSentiCorpLoader",
30 | "THUCNewsLoader",
31 | "WeiboSenti100kLoader",
32 |
33 | 'ConllLoader',
34 | 'Conll2003Loader',
35 | 'Conll2003NERLoader',
36 | 'OntoNotesNERLoader',
37 | 'CTBLoader',
38 | "MsraNERLoader",
39 | "WeiboNERLoader",
40 | "PeopleDailyNERLoader",
41 |
42 | 'CSVLoader',
43 | 'JsonLoader',
44 |
45 | 'CWSLoader',
46 |
47 | 'MNLILoader',
48 | "QuoraLoader",
49 | "SNLILoader",
50 | "QNLILoader",
51 | "RTELoader",
52 | "CNXNLILoader",
53 | "BQCorpusLoader",
54 | "LCQMCLoader",
55 |
56 | "CMRC2018Loader",
57 |
58 | "Pipe",
59 |
60 | "CLSBasePipe",
61 | "AGsNewsPipe",
62 | "DBPediaPipe",
63 | "YelpFullPipe",
64 | "YelpPolarityPipe",
65 | "SSTPipe",
66 | "SST2Pipe",
67 | "IMDBPipe",
68 | "ChnSentiCorpPipe",
69 | "THUCNewsPipe",
70 | "WeiboSenti100kPipe",
71 |
72 | "Conll2003Pipe",
73 | "Conll2003NERPipe",
74 | "OntoNotesNERPipe",
75 | "MsraNERPipe",
76 | "PeopleDailyPipe",
77 | "WeiboNERPipe",
78 |
79 | "CWSPipe",
80 |
81 | "Conll2003NERPipe",
82 | "OntoNotesNERPipe",
83 | "MsraNERPipe",
84 | "WeiboNERPipe",
85 | "PeopleDailyPipe",
86 | "Conll2003Pipe",
87 |
88 | "MatchingBertPipe",
89 | "RTEBertPipe",
90 | "SNLIBertPipe",
91 | "QuoraBertPipe",
92 | "QNLIBertPipe",
93 | "MNLIBertPipe",
94 | "CNXNLIBertPipe",
95 | "BQCorpusBertPipe",
96 | "LCQMCBertPipe",
97 | "MatchingPipe",
98 | "RTEPipe",
99 | "SNLIPipe",
100 | "QuoraPipe",
101 | "QNLIPipe",
102 | "MNLIPipe",
103 | "LCQMCPipe",
104 | "CNXNLIPipe",
105 | "BQCorpusPipe",
106 | "RenamePipe",
107 | "GranularizePipe",
108 | "MachingTruncatePipe",
109 |
110 | "CMRC2018BertPipe",
111 |
112 | 'ModelLoader',
113 | 'ModelSaver',
114 |
115 | ]
116 |
117 | import sys
118 |
119 | from .data_bundle import DataBundle
120 | from .embed_loader import EmbedLoader
121 | from .loader import *
122 | from .model_io import ModelLoader, ModelSaver
123 | from .pipe import *
124 | from ..doc_utils import doc_process
125 |
126 | doc_process(sys.modules[__name__])
--------------------------------------------------------------------------------
/fastNLP/io/loader/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 | A Loader reads data and loads it into a :class:`~fastNLP.DataSet` or a :class:`~fastNLP.io.DataBundle`. Every Loader supports the following
3 | three methods: ``__init__`` , ``_load`` , ``load`` . ``__init__(...)`` declares the reading options and documents the data format this Loader supports as well as
4 | the `field` s of the resulting :class:`~fastNLP.DataSet` ; ``_load(path)`` takes a file path, reads that single file and returns a :class:`~fastNLP.DataSet` ;
5 | ``load(paths)`` reads the files under a directory and returns an object of type :class:`~fastNLP.io.DataBundle` . The load() method supports the following kinds of arguments:
6 |
7 | 0. Passing None
8 | The Loader tries to download the dataset automatically and cache it. Not every dataset can be downloaded this way.
9 |
10 | 1. Passing the path of a single file
11 | The returned `data_bundle` contains one dataset named `train` , which can be retrieved with ``data_bundle.get_dataset('train')``
12 |
13 | 2. Passing a directory
14 | The files under that directory whose names contain `train` , `test` or `dev` are read; other files are ignored. Suppose a directory contains::
15 |
16 | |
17 | +-train3.txt
18 | +-dev.txt
19 | +-test.txt
20 | +-other.txt
21 |
22 | Then in the `data_bundle` returned by Loader().load('/path/to/dir'), ``data_bundle.get_dataset('train')`` ,
23 | ``data_bundle.get_dataset('dev')`` and
24 | ``data_bundle.get_dataset('test')`` retrieve the corresponding `dataset` , and the content of `other.txt` is ignored. If a directory instead contains::
25 |
26 | |
27 | +-train3.txt
28 | +-dev.txt
29 |
30 | then in the `data_bundle` returned by Loader().load('/path/to/dir'), ``data_bundle.get_dataset('train')`` and
31 | ``data_bundle.get_dataset('dev')`` retrieve the corresponding dataset.
32 |
33 | 3. Passing a dict
34 | The keys of the dict are the names of the `dataset` s and the values are the file paths of those `dataset` s::
35 |
36 | paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'}
37 |
38 | In the `data_bundle` returned by Loader().load(paths), ``data_bundle.get_dataset('train')`` , ``data_bundle.get_dataset('dev')`` and
39 | ``data_bundle.get_dataset('test')`` retrieve the corresponding `dataset` s
40 |
41 | fastNLP currently provides the following Loaders
42 |
43 |
44 |
45 | """
46 |
47 | __all__ = [
48 | 'Loader',
49 |
50 | 'CLSBaseLoader',
51 | 'YelpFullLoader',
52 | 'YelpPolarityLoader',
53 | 'AGsNewsLoader',
54 | 'DBPediaLoader',
55 | 'IMDBLoader',
56 | 'SSTLoader',
57 | 'SST2Loader',
58 | "ChnSentiCorpLoader",
59 | "THUCNewsLoader",
60 | "WeiboSenti100kLoader",
61 |
62 | 'ConllLoader',
63 | 'Conll2003Loader',
64 | 'Conll2003NERLoader',
65 | 'OntoNotesNERLoader',
66 | 'CTBLoader',
67 | "MsraNERLoader",
68 | "PeopleDailyNERLoader",
69 | "WeiboNERLoader",
70 |
71 | 'CSVLoader',
72 | 'JsonLoader',
73 |
74 | 'CWSLoader',
75 |
76 | 'MNLILoader',
77 | "QuoraLoader",
78 | "SNLILoader",
79 | "QNLILoader",
80 | "RTELoader",
81 | "CNXNLILoader",
82 | "BQCorpusLoader",
83 | "LCQMCLoader",
84 |
85 | "CoReferenceLoader",
86 |
87 | "CMRC2018Loader"
88 | ]
89 | from .classification import CLSBaseLoader, YelpFullLoader, YelpPolarityLoader, AGsNewsLoader, IMDBLoader, \
90 | SSTLoader, SST2Loader, DBPediaLoader, \
91 | ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader
92 | from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
93 | from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader
94 | from .coreference import CoReferenceLoader
95 | from .csv import CSVLoader
96 | from .cws import CWSLoader
97 | from .json import JsonLoader
98 | from .loader import Loader
99 | from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, \
100 | LCQMCLoader
101 | from .qa import CMRC2018Loader
102 |
103 |
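A sketch of the three ways of passing paths to load(), using ConllLoader for illustration; the file paths are hypothetical and the headers/indexes arguments are only illustrative::

    from fastNLP.io import ConllLoader

    loader = ConllLoader(headers=['raw_words', 'pos'], indexes=[1, 3])

    # 1. a single file: the bundle contains only a 'train' DataSet
    bundle = loader.load('/path/to/train.conll')

    # 2. a directory: files whose names contain train/dev/test are picked up
    bundle = loader.load('/path/to/dir')

    # 3. an explicit dict mapping split names to files
    bundle = loader.load({'train': '/path/to/tr.conll',
                          'dev': '/path/to/va.conll',
                          'test': '/path/to/te.conll'})
    train_ds = bundle.get_dataset('train')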
--------------------------------------------------------------------------------
/fastNLP/io/loader/coreference.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "CoReferenceLoader",
5 | ]
6 |
7 | from ...core.dataset import DataSet
8 | from ..file_reader import _read_json
9 | from ...core.instance import Instance
10 | from ...core.const import Const
11 | from .json import JsonLoader
12 |
13 |
14 | class CoReferenceLoader(JsonLoader):
15 | r"""
16 | In the raw data, each line should be one json object, where doc_key holds the document's category information, speakers holds the speaker of each sentence, clusters groups mentions that refer to the same real-world entity, and sentences holds the text itself.
17 |
18 | Example::
19 |
20 | {"doc_key": "bc/cctv/00/cctv_0000_0",
21 | "speakers": [["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"]],
22 | "clusters": [[[70, 70], [485, 486], [500, 500], [73, 73], [55, 55], [153, 154], [366, 366]]],
23 | "sentences": [["In", "the", "summer", "of", "2005", ",", "a", "picture", "that", "people", "have", "long", "been", "looking", "forward", "to", "started", "emerging", "with", "frequency", "in", "various", "major", "Hong", "Kong", "media", "."], ["With", "their", "unique", "charm", ",", "these", "well", "-", "known", "cartoon", "images", "once", "again", "caused", "Hong", "Kong", "to", "be", "a", "focus", "of", "worldwide", "attention", "."]]
24 | }
25 |
26 | Reads preprocessed CoNLL-2012 data; the resulting structure is:
27 |
28 | .. csv-table::
29 | :header: "raw_words1", "raw_words2", "raw_words3", "raw_words4"
30 |
31 | "bc/cctv/00/cctv_0000_0", "[['Speaker#1', 'Speaker#1', 'Speaker#1...", "[[[70, 70], [485, 486], [500, 500], [7...", "[['In', 'the', 'summer', 'of', '2005',..."
32 | "...", "...", "...", "..."
33 |
34 | """
35 | def __init__(self, fields=None, dropna=False):
36 | super().__init__(fields, dropna)
37 | self.fields = {"doc_key": Const.RAW_WORDS(0), "speakers": Const.RAW_WORDS(1), "clusters": Const.RAW_WORDS(2),
38 | "sentences": Const.RAW_WORDS(3)}
39 |
40 | def _load(self, path):
41 | r"""
42 | Load the data.
43 | :param path: path to the data file; the file is json
44 |
45 | :return:
46 | """
47 | dataset = DataSet()
48 | for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
49 | if self.fields:
50 | ins = {self.fields[k]: v for k, v in d.items()}
51 | else:
52 | ins = d
53 | dataset.append(Instance(**ins))
54 | return dataset
55 |
56 | def download(self):
57 | r"""
58 | Automatic download is not provided due to copyright restrictions. See
59 |
60 | https://www.aclweb.org/anthology/W12-4501
61 |
62 | :return:
63 | """
64 | raise RuntimeError("CoReference cannot be downloaded automatically.")
65 |
--------------------------------------------------------------------------------
/fastNLP/io/loader/csv.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "CSVLoader",
5 | ]
6 |
7 | from .loader import Loader
8 | from ..file_reader import _read_csv
9 | from ...core.dataset import DataSet
10 | from ...core.instance import Instance
11 |
12 |
13 | class CSVLoader(Loader):
14 | r"""
15 | Reads a CSV-format dataset and returns a ``DataSet`` .
16 |
17 | """
18 |
19 | def __init__(self, headers=None, sep=",", dropna=False):
20 | r"""
21 |
22 | :param List[str] headers: the header of the CSV file; defines the name of each column, i.e. the `field` names in the returned DataSet.
23 | If ``None`` , the first line of the file is treated as the ``headers`` . Default: ``None``
24 | :param str sep: the separator between columns in the CSV file. Default: ","
25 | :param bool dropna: whether to skip invalid rows; if ``True`` they are skipped, if ``False`` a ``ValueError`` is raised when an invalid row is met.
26 | Default: ``False``
27 | """
28 | super().__init__()
29 | self.headers = headers
30 | self.sep = sep
31 | self.dropna = dropna
32 |
33 | def _load(self, path):
34 | ds = DataSet()
35 | for idx, data in _read_csv(path, headers=self.headers,
36 | sep=self.sep, dropna=self.dropna):
37 | ds.append(Instance(**data))
38 | return ds
39 |
40 |
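A minimal sketch of CSVLoader on a hypothetical tab-separated file whose two columns should become the raw_words and target fields::

    from fastNLP.io import CSVLoader

    loader = CSVLoader(headers=['raw_words', 'target'], sep='\t')
    data_bundle = loader.load('/path/to/data.tsv')   # hypothetical path
    print(data_bundle.get_dataset('train'))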
--------------------------------------------------------------------------------
/fastNLP/io/loader/cws.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "CWSLoader"
5 | ]
6 |
7 | import glob
8 | import os
9 | import random
10 | import shutil
11 | import time
12 |
13 | from .loader import Loader
14 | from ...core.dataset import DataSet
15 | from ...core.instance import Instance
16 |
17 |
18 | class CWSLoader(Loader):
19 | r"""
20 | CWSLoader expects one sentence per line, with words separated by spaces, for example:
21 |
22 | Example::
23 |
24 | 上海 浦东 开发 与 法制 建设 同步
25 | 新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )
26 | ...
27 |
28 | The DataSet produced by this Loader has the following structure
29 |
30 | .. csv-table::
31 | :header: "raw_words"
32 |
33 | "上海 浦东 开发 与 法制 建设 同步"
34 | "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )"
35 | "..."
36 |
37 | """
38 | def __init__(self, dataset_name:str=None):
39 | r"""
40 |
41 | :param str dataset_name: name of the dataset; supports pku, msra, cityu (traditional Chinese), as (traditional Chinese), None
42 | """
43 | super().__init__()
44 | datanames = {'pku': 'cws-pku', 'msra':'cws-msra', 'as':'cws-as', 'cityu':'cws-cityu'}
45 | if dataset_name in datanames:
46 | self.dataset_name = datanames[dataset_name]
47 | else:
48 | self.dataset_name = None
49 |
50 | def _load(self, path:str):
51 | ds = DataSet()
52 | with open(path, 'r', encoding='utf-8') as f:
53 | for line in f:
54 | line = line.strip()
55 | if line:
56 | ds.append(Instance(raw_words=line))
57 | return ds
58 |
59 | def download(self, dev_ratio=0.1, re_download=False)->str:
60 | r"""
61 | If you use this dataset, please cite: Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff,
62 | 2005. More information is available at http://sighan.cs.uchicago.edu/bakeoff2005/
63 |
64 | :param float dev_ratio: if the path contains no dev set, the fraction of train to split off as dev. If 0, no dev split is made.
65 | :param bool re_download: whether to re-download the data in order to re-split it.
66 | :return: str
67 | """
68 | if self.dataset_name is None:
69 | return None
70 | data_dir = self._get_dataset_path(dataset_name=self.dataset_name)
71 | modify_time = 0
72 | for filepath in glob.glob(os.path.join(data_dir, '*')):
73 | modify_time = os.stat(filepath).st_mtime
74 | break
75 | if time.time() - modify_time > 1 and re_download:  # a rather crude way of checking whether the files were only just downloaded
76 | shutil.rmtree(data_dir)
77 | data_dir = self._get_dataset_path(dataset_name=self.dataset_name)
78 |
79 | if not os.path.exists(os.path.join(data_dir, 'dev.txt')):
80 | if dev_ratio > 0:
81 | assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
82 | try:
83 | with open(os.path.join(data_dir, 'train3.txt'), 'r', encoding='utf-8') as f, \
84 | open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \
85 | open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2:
86 | for line in f:
87 | if random.random() < dev_ratio:
88 | f2.write(line)
89 | else:
90 | f1.write(line)
91 | os.remove(os.path.join(data_dir, 'train3.txt'))
92 | os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train3.txt'))
93 | finally:
94 | if os.path.exists(os.path.join(data_dir, 'middle_file.txt')):
95 | os.remove(os.path.join(data_dir, 'middle_file.txt'))
96 |
97 | return data_dir
98 |
--------------------------------------------------------------------------------
/fastNLP/io/loader/json.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "JsonLoader"
5 | ]
6 |
7 | from .loader import Loader
8 | from ..file_reader import _read_json
9 | from ...core.dataset import DataSet
10 | from ...core.instance import Instance
11 |
12 |
13 | class JsonLoader(Loader):
14 | r"""
15 | Alias: :class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader`
16 |
17 | Reads json-format data. The data must be stored line by line, each line being a json object with the relevant attributes.
18 |
19 | :param dict fields: the json attribute names to read and the field_name each is stored under in the DataSet.
20 | The `key` s of ``fields`` must be attribute names of the json objects. Each `value` of ``fields`` is the `field_name` used in the DataSet;
21 | a `value` may also be ``None`` , in which case the `field_name` is the same as the corresponding json attribute name.
22 | ``fields`` itself may be ``None`` , in which case every attribute of the json objects is kept in the DataSet. Default: ``None``
23 | :param bool dropna: whether to skip invalid data; if ``True`` it is skipped, if ``False`` a ``ValueError`` is raised when invalid data is met.
24 | Default: ``False``
25 | """
26 |
27 | def __init__(self, fields=None, dropna=False):
28 | super(JsonLoader, self).__init__()
29 | self.dropna = dropna
30 | self.fields = None
31 | self.fields_list = None
32 | if fields:
33 | self.fields = {}
34 | for k, v in fields.items():
35 | self.fields[k] = k if v is None else v
36 | self.fields_list = list(self.fields.keys())
37 |
38 | def _load(self, path):
39 | ds = DataSet()
40 | for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
41 | if self.fields:
42 | ins = {self.fields[k]: v for k, v in d.items()}
43 | else:
44 | ins = d
45 | ds.append(Instance(**ins))
46 | return ds
47 |
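A minimal sketch of the fields mapping: the keys are attributes of each json object, the values are the DataSet field names they are stored under (the path is hypothetical)::

    from fastNLP.io import JsonLoader

    loader = JsonLoader(fields={'text': 'raw_words', 'label': 'target'})
    data_bundle = loader.load('/path/to/data.jsonl')   # hypothetical path
    print(data_bundle.get_dataset('train'))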
--------------------------------------------------------------------------------
/fastNLP/io/loader/loader.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "Loader"
5 | ]
6 |
7 | from typing import Union, Dict
8 |
9 | from .. import DataBundle
10 | from ..file_utils import _get_dataset_url, get_cache_path, cached_path
11 | from ..utils import check_loader_paths
12 | from ...core.dataset import DataSet
13 |
14 |
15 | class Loader:
16 | r"""
17 | Base class of the various data Loaders; it serves as the reference for their API.
18 | A Loader supports the following three functions
19 |
20 | - download(): automatically downloads the dataset to the cache directory, which defaults to ~/.fastNLP/datasets/. For copyright and other reasons, not every Loader implements this method. It returns the cache directory the files were downloaded to.
21 | - _load(): reads data from a single data file and returns a :class:`~fastNLP.DataSet` . What the returned DataSet contains can be found in each Loader's documentation.
22 | - load(): reads each file into a DataSet, then puts the DataSets into one DataBundle and returns it
23 |
24 | """
25 |
26 | def __init__(self):
27 | pass
28 |
29 | def _load(self, path: str) -> DataSet:
30 | r"""
31 | Given a path, return the DataSet read from it.
32 |
33 | :param str path: the path
34 | :return: DataSet
35 | """
36 | raise NotImplementedError
37 |
38 | def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle:
39 | r"""
40 | Read data from the file(s) at one or more given paths and return a :class:`~fastNLP.io.DataBundle` .
41 |
42 | :param Union[str, Dict[str, str]] paths: the following kinds of input are supported:
43 |
44 | 0. If None, the local cache is checked first; if nothing is cached, the data is downloaded and cached automatically.
45 |
46 | 1. A directory: files under it whose names contain train are taken as train, those containing test as test and those containing dev as dev; an error is raised if several file names contain 'train', 'dev' or 'test'::
47 |
48 | data_bundle = xxxLoader().load('/path/to/dir')  # the datasets in the returned DataBundle depend on whether train,
49 | # dev, test etc. are detected under the directory; the DataSets can be retrieved as follows
50 | tr_data = data_bundle.get_dataset('train')
51 | te_data = data_bundle.get_dataset('test')  # if some file under the directory contains the string test
52 |
53 | 2. A dict, e.g. when train, dev and test are not in the same directory, or their names do not contain train, dev, test::
54 |
55 | paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"}
56 | data_bundle = xxxLoader().load(paths)  # the datasets in the returned DataBundle contain "train", "dev", "test"
57 | dev_data = data_bundle.get_dataset('dev')
58 |
59 | 3. A file path::
60 |
61 | data_bundle = xxxLoader().load("/path/to/a/train.conll") # returns a DataBundle whose datasets contain only 'train'
62 | tr_data = data_bundle.get_dataset('train')  # retrieve the DataSet
63 |
64 | :return: the resulting :class:`~fastNLP.io.DataBundle`
65 | """
66 | if paths is None:
67 | paths = self.download()
68 | paths = check_loader_paths(paths)
69 | datasets = {name: self._load(path) for name, path in paths.items()}
70 | data_bundle = DataBundle(datasets=datasets)
71 | return data_bundle
72 |
73 | def download(self) -> str:
74 | r"""
75 | Automatically download this dataset
76 |
77 | :return: the directory the download was extracted to
78 | """
79 | raise NotImplementedError(f"{self.__class__} cannot download data automatically.")
80 |
81 | @staticmethod
82 | def _get_dataset_path(dataset_name):
83 | r"""
84 | Given a dataset name, return the directory to read the data from. If the data does not exist, automatic download and caching are attempted (when supported).
85 |
86 | :param str dataset_name: name of the dataset
87 | :return: str, the directory of the dataset. The data can be read directly from this directory.
88 | """
89 |
90 | default_cache_path = get_cache_path()
91 | url = _get_dataset_url(dataset_name)
92 | output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset')
93 |
94 | return output_dir
95 |
--------------------------------------------------------------------------------
/fastNLP/io/loader/qa.py:
--------------------------------------------------------------------------------
1 | r"""
2 | The Loaders in this file are mainly for reading question-answering data
3 |
4 | """
5 |
6 |
7 | from . import Loader
8 | import json
9 | from ...core import DataSet, Instance
10 |
11 | __all__ = ['CMRC2018Loader']
12 |
13 |
14 | class CMRC2018Loader(Loader):
15 | r"""
16 | Please process the data downloaded from fastNLP directly. This dataset provides no test set; for testing, results have to be uploaded to the corresponding evaluation system.
17 |
18 | After loading, the training DataSet has the following content; each question has exactly one answer
19 |
20 | .. csv-table::
21 | :header: "title", "context", "question", "answers", "answer_starts", "id"
22 |
23 | "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "范廷颂是什么时候被任为主教的?", ["1963年"], ["30"], "TRAIN_186_QUERY_0"
24 | "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "1990年,范廷颂担任什么职务?", ["1990年被擢升为天..."], ["41"],"TRAIN_186_QUERY_1"
25 | "...", "...", "...","...", ".", "..."
26 |
27 | title is the title of the text, and several records may share the same title; id is the question's id and is unique
28 |
29 | The dev DataSet has the following content; each question may have three answers (sometimes just the same answer repeated 3 times)
30 |
31 | .. csv-table::
32 | :header: "title", "context", "question", "answers", "answer_starts", "id"
33 |
34 | "战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "《战国无双3》是由哪两个公司合作开发的?", "['光荣和ω-force', '光荣和ω-force', '光荣和ω-force']", "[30, 30, 30]", "DEV_0_QUERY_0"
35 | "战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "男女主角亦有专属声优这一模式是由谁改编的?", "['村雨城', '村雨城', '任天堂游戏谜之村雨城']", "[226, 226, 219]", "DEV_0_QUERY_1"
36 | "...", "...", "...","...", ".", "..."
37 |
38 | answer_starts are 0-based indices. For example, in "我来自a复旦大学?" the start index of "复" is 4; likewise the index of 说 in "Russell评价说" is 9, because
39 | English letters and digits are counted character by character.
40 | """
41 | def __init__(self):
42 | super().__init__()
43 |
44 | def _load(self, path: str) -> DataSet:
45 | with open(path, 'r', encoding='utf-8') as f:
46 | data = json.load(f)['data']
47 | ds = DataSet()
48 | for entry in data:
49 | title = entry['title']
50 | para = entry['paragraphs'][0]
51 | context = para['context']
52 | qas = para['qas']
53 | for qa in qas:
54 | question = qa['question']
55 | ans = qa['answers']
56 | answers = []
57 | answer_starts = []
58 | id = qa['id']
59 | for an in ans:
60 | answers.append(an['text'])
61 | answer_starts.append(an['answer_start'])
62 | ds.append(Instance(title=title, context=context, question=question, answers=answers,
63 | answer_starts=answer_starts,id=id))
64 | return ds
65 |
66 | def download(self) -> str:
67 | r"""
68 | If you use this data, please cite: A Span-Extraction Dataset for Chinese Machine Reading Comprehension. Yiming Cui, Ting Liu, et al.
69 |
70 | :return:
71 | """
72 | output_dir = self._get_dataset_path('cmrc2018')
73 | return output_dir
74 |
75 |
--------------------------------------------------------------------------------
/fastNLP/io/loader/summarization.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "ExtCNNDMLoader"
5 | ]
6 |
7 | import os
8 | from typing import Union, Dict
9 |
10 | from ..data_bundle import DataBundle
11 | from ..utils import check_loader_paths
12 | from .json import JsonLoader
13 |
14 |
15 | class ExtCNNDMLoader(JsonLoader):
16 | r"""
17 | The fields of the DataSet after loading are
18 |
19 | .. csv-table::
20 | :header: "text", "summary", "label", "publication"
21 |
22 | ["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm"
23 | ["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm"
24 | ["..."], ["..."], [], "cnndm"
25 |
26 | """
27 |
28 | def __init__(self, fields=None):
29 | fields = fields or {"text": None, "summary": None, "label": None, "publication": None}
30 | super(ExtCNNDMLoader, self).__init__(fields=fields)
31 |
32 | def load(self, paths: Union[str, Dict[str, str]] = None):
33 | r"""
34 | Read data from the file(s) at one or more given paths and return a :class:`~fastNLP.io.DataBundle` .
35 |
36 | Which fields are read is determined by the fields passed when the ExtCNNDMLoader was initialised.
37 |
38 | :param str paths: a directory in which the three files train.label.jsonl, dev.label.jsonl and
39 | test.label.jsonl are looked for (the directory should also contain a file named vocab, which is
40 | needed by :class:`~fastNLP.io.ExtCNNDMPipe` ).
41 |
42 | :return: the resulting :class:`~fastNLP.io.DataBundle`
43 | """
44 | if paths is None:
45 | paths = self.download()
46 | paths = check_loader_paths(paths)
47 | if ('train' in paths) and ('test' not in paths):
48 | paths['test'] = paths['train']
49 | paths.pop('train')
50 |
51 | datasets = {name: self._load(path) for name, path in paths.items()}
52 | data_bundle = DataBundle(datasets=datasets)
53 | return data_bundle
54 |
55 | def download(self):
56 | r"""
57 | If you use this data, please cite
58 |
59 | https://arxiv.org/pdf/1506.03340.pdf
60 | :return:
61 | """
62 | output_dir = self._get_dataset_path('ext-cnndm')
63 | return output_dir
64 |
--------------------------------------------------------------------------------
/fastNLP/io/model_io.py:
--------------------------------------------------------------------------------
1 | r"""
2 | Utilities for loading and saving models
3 | """
4 | __all__ = [
5 | "ModelLoader",
6 | "ModelSaver"
7 | ]
8 |
9 | import torch
10 |
11 |
12 | class ModelLoader:
13 | r"""
14 | Used to load models
15 | """
16 |
17 | def __init__(self):
18 | super(ModelLoader, self).__init__()
19 |
20 | @staticmethod
21 | def load_pytorch(empty_model, model_path):
22 | r"""
23 | 从 ".pkl" 文件读取 PyTorch 模型
24 |
25 | :param empty_model: 初始化参数的 PyTorch 模型
26 | :param str model_path: 模型保存的路径
27 | """
28 | empty_model.load_state_dict(torch.load(model_path))
29 |
30 | @staticmethod
31 | def load_pytorch_model(model_path):
32 | r"""
33 | Load an entire model
34 |
35 | :param str model_path: path the model was saved to
36 | """
37 | return torch.load(model_path)
38 |
39 |
40 | class ModelSaver(object):
41 | r"""
42 | Used to save models
43 |
44 | Example::
45 |
46 | saver = ModelSaver("./save/model_ckpt_100.pkl")
47 | saver.save_pytorch(model)
48 |
49 | """
50 |
51 | def __init__(self, save_path):
52 | r"""
53 |
54 | :param save_path: path to save the model to
55 | """
56 | self.save_path = save_path
57 |
58 | def save_pytorch(self, model, param_only=True):
59 | r"""
60 | Save a PyTorch model to a ".pkl" file
61 |
62 | :param model: the PyTorch model
63 | :param bool param_only: whether to save only the model's parameters (otherwise the entire model is saved)
64 |
65 | """
66 | if param_only is True:
67 | torch.save(model.state_dict(), self.save_path)
68 | else:
69 | torch.save(model, self.save_path)
70 |
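A small round-trip sketch with ModelSaver and ModelLoader; any nn.Module works, a plain Linear layer is used here just for illustration::

    import torch
    from fastNLP.io import ModelSaver, ModelLoader

    model = torch.nn.Linear(10, 2)
    ModelSaver("./model_ckpt.pkl").save_pytorch(model, param_only=True)

    empty_model = torch.nn.Linear(10, 2)                        # same architecture, fresh parameters
    ModelLoader.load_pytorch(empty_model, "./model_ckpt.pkl")   # state dict is loaded in place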
--------------------------------------------------------------------------------
/fastNLP/io/pipe/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 | A Pipe processes the data read by a Loader. Every Pipe provides two methods, ``process`` and ``process_from_file`` .
3 | ``process(data_bundle)`` takes an object of type :class:`~fastNLP.io.DataBundle` , modifies the given `data_bundle` in place and returns it;
4 | ``process_from_file(paths)`` takes file path(s) and returns an object of type :class:`~fastNLP.io.DataBundle` .
5 | The :class:`~fastNLP.DataSet` s in the `data_bundle` returned by ``process(data_bundle)`` or ``process_from_file(paths)``
6 | usually contain the raw text, the input converted to indices and the target converted to indices; besides the :class:`~fastNLP.DataSet` s,
7 | the `data_bundle` also holds the vocabularies that were built when converting the fields to indices.
8 |
9 | """
10 | __all__ = [
11 | "Pipe",
12 |
13 | "CWSPipe",
14 |
15 | "CLSBasePipe",
16 | "AGsNewsPipe",
17 | "DBPediaPipe",
18 | "YelpFullPipe",
19 | "YelpPolarityPipe",
20 | "SSTPipe",
21 | "SST2Pipe",
22 | "IMDBPipe",
23 | "ChnSentiCorpPipe",
24 | "THUCNewsPipe",
25 | "WeiboSenti100kPipe",
26 |
27 | "Conll2003NERPipe",
28 | "OntoNotesNERPipe",
29 | "MsraNERPipe",
30 | "WeiboNERPipe",
31 | "PeopleDailyPipe",
32 | "Conll2003Pipe",
33 |
34 | "MatchingBertPipe",
35 | "RTEBertPipe",
36 | "SNLIBertPipe",
37 | "QuoraBertPipe",
38 | "QNLIBertPipe",
39 | "MNLIBertPipe",
40 | "CNXNLIBertPipe",
41 | "BQCorpusBertPipe",
42 | "LCQMCBertPipe",
43 | "MatchingPipe",
44 | "RTEPipe",
45 | "SNLIPipe",
46 | "QuoraPipe",
47 | "QNLIPipe",
48 | "MNLIPipe",
49 | "LCQMCPipe",
50 | "CNXNLIPipe",
51 | "BQCorpusPipe",
52 | "RenamePipe",
53 | "GranularizePipe",
54 | "MachingTruncatePipe",
55 |
56 | "CoReferencePipe",
57 |
58 | "CMRC2018BertPipe"
59 | ]
60 |
61 | from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \
62 | WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe
63 | from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
64 | from .conll import Conll2003Pipe
65 | from .coreference import CoReferencePipe
66 | from .cws import CWSPipe
67 | from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \
68 | MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, CNXNLIBertPipe, CNXNLIPipe, BQCorpusBertPipe, \
69 | LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe
70 | from .pipe import Pipe
71 | from .qa import CMRC2018BertPipe
72 |
--------------------------------------------------------------------------------
/fastNLP/io/pipe/pipe.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "Pipe",
5 | ]
6 |
7 | from .. import DataBundle
8 |
9 |
10 | class Pipe:
11 | r"""
12 | Pipe is the class fastNLP uses to process a DataBundle, which in practice means processing the DataSets inside the DataBundle. Every Pipe states in the documentation of its process() function what format the DataSets it can handle should have; the Pipe's
13 | documentation also describes the format and fields of the DataSets it returns, as well as any newly created Vocabulary.
14 |
15 | In general, a Pipe performs the following steps: (1) tokenize raw_words or raw_chars into words or characters;
16 | (2) build a :class:`~fastNLP.Vocabulary` over the words or characters and convert them to indices; (3) build a vocabulary over the target column and convert it to indices;
17 |
18 | Pipe provides two methods
19 |
20 | - process(), whose input is a DataBundle
21 | - process_from_file(), whose input is of whatever type the load function of the corresponding Loader accepts.
22 |
23 | """
24 |
25 | def process(self, data_bundle: DataBundle) -> DataBundle:
26 | r"""
27 | Process the given DataBundle and return it.
28 |
29 | :param ~fastNLP.DataBundle data_bundle: the DataBundle object to be processed
30 | :return:
31 | """
32 | raise NotImplementedError
33 |
34 | def process_from_file(self, paths) -> DataBundle:
35 | r"""
36 | Take file path(s) and produce a processed DataBundle object. The path formats supported by paths are described in :meth:`fastNLP.io.Loader.load()`
37 |
38 | :param paths:
39 | :return: DataBundle
40 | """
41 | raise NotImplementedError
42 |
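A toy sketch of the Pipe contract, assuming a hypothetical LowercasePipe that lower-cases the raw_words field of every DataSet and reuses CSVLoader in process_from_file::

    from fastNLP.io import CSVLoader, DataBundle, Pipe

    class LowercasePipe(Pipe):
        """Hypothetical Pipe: lower-cases the raw_words field of every DataSet."""

        def process(self, data_bundle: DataBundle) -> DataBundle:
            for name, dataset in data_bundle.datasets.items():
                dataset.apply_field(str.lower, field_name='raw_words', new_field_name='raw_words')
            return data_bundle

        def process_from_file(self, paths) -> DataBundle:
            data_bundle = CSVLoader(headers=['raw_words', 'target'], sep='\t').load(paths)
            return self.process(data_bundle)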
--------------------------------------------------------------------------------
/fastNLP/io/utils.py:
--------------------------------------------------------------------------------
1 | r"""
2 | .. todo::
3 | doc
4 | """
5 |
6 | __all__ = [
7 | "check_loader_paths"
8 | ]
9 |
10 | import os
11 | from pathlib import Path
12 | from typing import Union, Dict
13 |
14 | from ..core import logger
15 |
16 |
17 | def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
18 | r"""
19 | Check that the paths passed to a dataloader are valid. For a valid path, a dict containing at least the key 'train' is returned, similar to::
20 |
21 | {
22 | 'train': '/some/path/to/', # always present; the vocabulary should be built on this one, the remaining files only need to be processed and indexed.
23 | 'test': 'xxx' # may or may not be present
24 | ...
25 | }
26 |
27 | If paths is invalid, the corresponding error is raised directly. An error is also raised if paths does not contain train.
28 |
29 | :param str paths: the path(s). Can be a single file path (that file is then taken to be the train file); a directory, in which files or folders whose names
30 | contain train, test or dev are looked for; or a dict, whose keys are user-defined names for the files and whose values are the file paths.
31 | :return:
32 | """
33 | if isinstance(paths, (str, Path)):
34 | paths = os.path.abspath(os.path.expanduser(paths))
35 | if os.path.isfile(paths):
36 | return {'train': paths}
37 | elif os.path.isdir(paths):
38 | filenames = os.listdir(paths)
39 | filenames.sort()
40 | files = {}
41 | for filename in filenames:
42 | path_pair = None
43 | if 'train' in filename:
44 | path_pair = ('train', filename)
45 | if 'dev' in filename:
46 | if path_pair:
47 | raise Exception(
48 | "Directory:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
49 | path_pair = ('dev', filename)
50 | if 'test' in filename:
51 | if path_pair:
52 | raise Exception(
53 | "Directory:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
54 | path_pair = ('test', filename)
55 | if path_pair:
56 | if path_pair[0] in files:
57 | raise FileExistsError(f"Two files contain `{path_pair[0]}` were found, please specify the "
58 | f"filepath for `{path_pair[0]}`.")
59 | files[path_pair[0]] = os.path.join(paths, path_pair[1])
60 | if 'train' not in files:
61 | raise KeyError(f"There is no train file in {paths}.")
62 | return files
63 | else:
64 | raise FileNotFoundError(f"{paths} is not a valid file path.")
65 |
66 | elif isinstance(paths, dict):
67 | if paths:
68 | if 'train' not in paths:
69 | raise KeyError("You have to include `train` in your dict.")
70 | for key, value in paths.items():
71 | if isinstance(key, str) and isinstance(value, str):
72 | value = os.path.abspath(os.path.expanduser(value))
73 | if not os.path.exists(value):
74 | raise TypeError(f"{value} is not a valid path.")
75 | paths[key] = value
76 | else:
77 | raise TypeError("All keys and values in paths should be str.")
78 | return paths
79 | else:
80 | raise ValueError("Empty paths is not allowed.")
81 | else:
82 | raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
83 |
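A runnable sketch of the directory branch of check_loader_paths, using a temporary directory with empty train/dev/test files::

    import os
    import tempfile
    from fastNLP.io.utils import check_loader_paths

    tmp_dir = tempfile.mkdtemp()
    for name in ('train.txt', 'dev.txt', 'test.txt'):
        open(os.path.join(tmp_dir, name), 'w').close()

    paths = check_loader_paths(tmp_dir)
    print(sorted(paths))   # ['dev', 'test', 'train']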
--------------------------------------------------------------------------------
/fastNLP/models/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 | The :mod:`~fastNLP.models` module of fastNLP ships complete models such as :class:`~fastNLP.models.CNNText` and
3 | :class:`~fastNLP.models.SeqLabeling` for users to use directly.
4 |
5 | .. todo::
6 | describe these models (consistent with the project homepage)
7 |
8 |
9 | """
10 | __all__ = [
11 | "CNNText",
12 |
13 | "SeqLabeling",
14 | "AdvSeqLabel",
15 | "BiLSTMCRF",
16 |
17 | "ESIM",
18 |
19 | "StarTransEnc",
20 | "STSeqLabel",
21 | "STNLICls",
22 | "STSeqCls",
23 |
24 | "BiaffineParser",
25 | "GraphParser",
26 |
27 | "BertForSequenceClassification",
28 | "BertForSentenceMatching",
29 | "BertForMultipleChoice",
30 | "BertForTokenClassification",
31 | "BertForQuestionAnswering",
32 |
33 | "TransformerSeq2SeqModel",
34 | "LSTMSeq2SeqModel",
35 | "Seq2SeqModel",
36 |
37 | 'SequenceGeneratorModel'
38 | ]
39 |
40 | from .base_model import BaseModel
41 | from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \
42 | BertForTokenClassification, BertForSentenceMatching
43 | from .biaffine_parser import BiaffineParser, GraphParser
44 | from .cnn_text_classification import CNNText
45 | from .sequence_labeling import SeqLabeling, AdvSeqLabel, BiLSTMCRF
46 | from .snli import ESIM
47 | from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
48 | from .seq2seq_model import TransformerSeq2SeqModel, LSTMSeq2SeqModel, Seq2SeqModel
49 | from .seq2seq_generator import SequenceGeneratorModel
50 | import sys
51 | from ..doc_utils import doc_process
52 |
53 | doc_process(sys.modules[__name__])
54 |
--------------------------------------------------------------------------------
/fastNLP/models/base_model.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = []
4 |
5 | import torch
6 |
7 | from ..modules.decoder.mlp import MLP
8 |
9 |
10 | class BaseModel(torch.nn.Module):
11 | r"""Base PyTorch model for all models.
12 | """
13 |
14 | def __init__(self):
15 | super(BaseModel, self).__init__()
16 |
17 | def fit(self, train_data, dev_data=None, **train_args):
18 | pass
19 |
20 | def predict(self, *args, **kwargs):
21 | raise NotImplementedError
22 |
23 |
24 | class NaiveClassifier(BaseModel):
25 | r"""
26 | A simple example classifier that can be used for various tests
27 | """
28 | def __init__(self, in_feature_dim, out_feature_dim):
29 | super(NaiveClassifier, self).__init__()
30 | self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])
31 |
32 | def forward(self, x):
33 | return {"predict": torch.sigmoid(self.mlp(x))}
34 |
35 | def predict(self, x):
36 | return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
37 |
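A quick sketch of NaiveClassifier on random features; forward returns sigmoid scores and predict thresholds them at 0.5::

    import torch
    from fastNLP.models.base_model import NaiveClassifier

    clf = NaiveClassifier(in_feature_dim=16, out_feature_dim=3)
    x = torch.randn(8, 16)
    print(clf(x)['predict'].shape)          # torch.Size([8, 3]) -- sigmoid scores
    print(clf.predict(x)['predict'].dtype)  # torch.bool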
--------------------------------------------------------------------------------
/fastNLP/models/cnn_text_classification.py:
--------------------------------------------------------------------------------
1 | r"""
2 | .. todo::
3 | doc
4 | """
5 |
6 | __all__ = [
7 | "CNNText"
8 | ]
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | from ..core.const import Const as C
14 | from ..core.utils import seq_len_to_mask
15 | from ..embeddings import embedding
16 | from ..modules import encoder
17 |
18 |
19 | class CNNText(torch.nn.Module):
20 | r"""
21 | A text classification model based on CNNs, following
22 | 'Yoon Kim. 2014. Convolutional Neural Networks for Sentence Classification.'
23 |
24 | """
25 |
26 | def __init__(self, embed,
27 | num_classes,
28 | kernel_nums=(30, 40, 50),
29 | kernel_sizes=(1, 3, 5),
30 | dropout=0.5):
31 | r"""
32 |
33 | :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: the Embedding to use (pass a tuple(int, int) whose
34 | first int is the vocab_size and whose second int is the embed_dim); if a Tensor, Embedding, ndarray etc. is given, that value is used to initialise the Embedding directly
35 | :param int num_classes: number of classes
36 | :param int,tuple(int) kernel_sizes: kernel size(s) of the output channels.
37 | :param float dropout: dropout rate
38 | """
39 | super(CNNText, self).__init__()
40 |
41 | # no support for pre-trained embedding currently
42 | self.embed = embedding.Embedding(embed)
43 | self.conv_pool = encoder.ConvMaxpool(
44 | in_channels=self.embed.embedding_dim,
45 | out_channels=kernel_nums,
46 | kernel_sizes=kernel_sizes)
47 | self.dropout = nn.Dropout(dropout)
48 | self.fc = nn.Linear(sum(kernel_nums), num_classes)
49 |
50 | def forward(self, words, seq_len=None):
51 | r"""
52 |
53 | :param torch.LongTensor words: [batch_size, seq_len], word indices of the sentences
54 | :param torch.LongTensor seq_len: [batch,] length of each sentence
55 | :return output: dict of torch.LongTensor, [batch_size, num_classes]
56 | """
57 | x = self.embed(words) # [N,L] -> [N,L,C]
58 | if seq_len is not None:
59 | mask = seq_len_to_mask(seq_len)
60 | x = self.conv_pool(x, mask)
61 | else:
62 | x = self.conv_pool(x) # [N,L,C] -> [N,C]
63 | x = self.dropout(x)
64 | x = self.fc(x) # [N,C] -> [N, N_class]
65 | return {C.OUTPUT: x}
66 |
67 | def predict(self, words, seq_len=None):
68 | r"""
69 | :param torch.LongTensor words: [batch_size, seq_len], word indices of the sentences
70 | :param torch.LongTensor seq_len: [batch,] length of each sentence
71 |
72 | :return predict: dict of torch.LongTensor, [batch_size, ]
73 | """
74 | output = self(words, seq_len)
75 | _, predict = output[C.OUTPUT].max(dim=1)
76 | return {C.OUTPUT: predict}
77 |
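A minimal sketch of CNNText with a randomly initialised embedding given as a (vocab_size, embed_dim) tuple; Const is used so the output key does not need to be hard-coded::

    import torch
    from fastNLP.core.const import Const as C
    from fastNLP.models import CNNText

    model = CNNText(embed=(1000, 100), num_classes=5)
    words = torch.randint(0, 1000, (4, 20))   # batch_size x seq_len word indices
    seq_len = torch.tensor([20, 18, 15, 9])

    print(model(words, seq_len)[C.OUTPUT].shape)          # torch.Size([4, 5])
    print(model.predict(words, seq_len)[C.OUTPUT].shape)  # torch.Size([4])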
--------------------------------------------------------------------------------
/fastNLP/models/seq2seq_generator.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | import torch
4 | from torch import nn
5 | from .seq2seq_model import Seq2SeqModel
6 | from ..modules.generator.seq2seq_generator import SequenceGenerator
7 |
8 |
9 | class SequenceGeneratorModel(nn.Module):
10 | """
11 | Wraps a Seq2SeqModel so that it can be used for generation tasks
12 |
13 | """
14 |
15 | def __init__(self, seq2seq_model: Seq2SeqModel, bos_token_id, eos_token_id=None, max_length=30, max_len_a=0.0,
16 | num_beams=1, do_sample=True, temperature=1.0, top_k=50, top_p=1.0,
17 | repetition_penalty=1, length_penalty=1.0, pad_token_id=0):
18 | """
19 |
20 | :param Seq2SeqModel seq2seq_model: the sequence-to-sequence model
21 | :param int,None bos_token_id: token id that starts a sentence
22 | :param int,None eos_token_id: token id that ends a sentence
23 | :param int max_length: maximum length of the generated sentences; the decode length of each sentence is max_length + max_len_a*src_len
24 | :param float max_len_a: the decode length of each sentence is max_length + max_len_a*src_len. If non-zero, the State must contain encoder_mask
25 | :param int num_beams: beam size for beam search
26 | :param bool do_sample: whether to generate by sampling
27 | :param float temperature: only meaningful when do_sample is True
28 | :param int top_k: sample only from the top_k tokens
29 | :param float top_p: sample only from tokens within top_p probability mass (nucleus sampling)
30 | :param float repetition_penalty: how strongly repeated tokens are penalised
31 | :param float length_penalty: length penalty; values below 1 favour long sentences, values above 1 favour short ones
32 | :param int pad_token_id: once a sentence has finished generating, subsequent positions are filled with pad_token_id
33 | """
34 | super().__init__()
35 | self.seq2seq_model = seq2seq_model
36 | self.generator = SequenceGenerator(seq2seq_model.decoder, max_length=max_length, max_len_a=max_len_a,
37 | num_beams=num_beams,
38 | do_sample=do_sample, temperature=temperature, top_k=top_k, top_p=top_p,
39 | bos_token_id=bos_token_id,
40 | eos_token_id=eos_token_id,
41 | repetition_penalty=repetition_penalty, length_penalty=length_penalty,
42 | pad_token_id=pad_token_id)
43 |
44 | def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None):
45 | """
46 | Pass-through call to the seq2seq_model's forward
47 |
48 | :param torch.LongTensor src_tokens: bsz x max_len
49 | :param torch.LongTensor tgt_tokens: bsz x max_len'
50 | :param torch.LongTensor src_seq_len: bsz
51 | :param torch.LongTensor tgt_seq_len: bsz
52 | :return:
53 | """
54 | return self.seq2seq_model(src_tokens, tgt_tokens, src_seq_len, tgt_seq_len)
55 |
56 | def predict(self, src_tokens, src_seq_len=None):
57 | """
58 | Given the source content, output the generated content
59 |
60 | :param torch.LongTensor src_tokens: bsz x max_len
61 | :param torch.LongTensor src_seq_len: bsz
62 | :return:
63 | """
64 | state = self.seq2seq_model.prepare_state(src_tokens, src_seq_len)
65 | result = self.generator.generate(state)
66 | return {'pred': result}
67 |
--------------------------------------------------------------------------------
/fastNLP/modules/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 |
3 | .. image:: figures/text_classification.png
4 |
5 | Most neural networks for NLP tasks can be viewed as being composed of three kinds of modules: :mod:`embedding` , :mod:`~fastNLP.modules.encoder`
6 | and :mod:`~fastNLP.modules.decoder` . This module implements the many components fastNLP provides,
7 | which help users quickly build the networks they need. The roles and common components of each kind of module are:
8 |
9 | .. csv-table::
10 | :header: "type", "role", "common components"
11 |
12 | "embedding", see :mod:`/fastNLP.embeddings` , "Elmo, Bert"
13 | "encoder", "encodes the input into vectors with representational power", "CNN, LSTM, Transformer"
14 | "decoder", "decodes vectors carrying some representation into the required output form", "MLP, CRF"
15 | "others", "components used together with the other components", "Dropout"
16 |
17 |
18 | """
19 | __all__ = [
20 | # "BertModel",
21 |
22 | "ConvolutionCharEncoder",
23 | "LSTMCharEncoder",
24 |
25 | "ConvMaxpool",
26 |
27 | "LSTM",
28 |
29 | "StarTransformer",
30 |
31 | "TransformerEncoder",
32 |
33 | "VarRNN",
34 | "VarLSTM",
35 | "VarGRU",
36 |
37 | "MaxPool",
38 | "MaxPoolWithMask",
39 | "KMaxPool",
40 | "AvgPool",
41 | "AvgPoolWithMask",
42 |
43 | "MultiHeadAttention",
44 |
45 | "MLP",
46 | "ConditionalRandomField",
47 | "viterbi_decode",
48 | "allowed_transitions",
49 |
50 | "TimestepDropout",
51 |
52 | 'summary',
53 |
54 | "BertTokenizer",
55 | "BertModel",
56 |
57 | "RobertaTokenizer",
58 | "RobertaModel",
59 |
60 | "GPT2Model",
61 | "GPT2Tokenizer",
62 |
63 | "TransformerSeq2SeqEncoder",
64 | "LSTMSeq2SeqEncoder",
65 | "Seq2SeqEncoder",
66 |
67 | "TransformerSeq2SeqDecoder",
68 | "LSTMSeq2SeqDecoder",
69 | "Seq2SeqDecoder",
70 |
71 | "TransformerState",
72 | "LSTMState",
73 | "State",
74 |
75 | "SequenceGenerator"
76 | ]
77 |
78 | import sys
79 |
80 | from . import decoder
81 | from . import encoder
82 | from .decoder import *
83 | from .dropout import TimestepDropout
84 | from .encoder import *
85 | from .generator import *
86 | from .utils import summary
87 | from ..doc_utils import doc_process
88 | from .tokenizer import *
89 |
90 | doc_process(sys.modules[__name__])
91 |
--------------------------------------------------------------------------------
/fastNLP/modules/decoder/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 | .. todo::
3 | doc
4 | """
5 | __all__ = [
6 | "MLP",
7 | "ConditionalRandomField",
8 | "viterbi_decode",
9 | "allowed_transitions",
10 |
11 | "LSTMState",
12 | "TransformerState",
13 | "State",
14 |
15 | "TransformerSeq2SeqDecoder",
16 | "LSTMSeq2SeqDecoder",
17 | "Seq2SeqDecoder"
18 | ]
19 |
20 | from .crf import ConditionalRandomField
21 | from .crf import allowed_transitions
22 | from .mlp import MLP
23 | from .utils import viterbi_decode
24 | from .seq2seq_decoder import Seq2SeqDecoder, LSTMSeq2SeqDecoder, TransformerSeq2SeqDecoder
25 | from .seq2seq_state import State, LSTMState, TransformerState
26 |
--------------------------------------------------------------------------------
/fastNLP/modules/decoder/mlp.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "MLP"
5 | ]
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from ..utils import initial_parameter
11 |
12 |
13 | class MLP(nn.Module):
14 | r"""
15 | Multi-layer perceptron
16 |
17 |
18 | .. note::
19 | The activation function of the hidden layers is defined by activation. Either a single str/function or a list of str/function can be passed as activation.
20 | If a single str/function is passed, the activation of every hidden layer is defined by it;
21 | if a list of str/function is passed, the activation of each hidden layer is defined by the corresponding element of the list, whose length equals the number of hidden layers.
22 | The activation of the output layer is defined by output_activation, which defaults to None, meaning the output layer has no activation.
23 |
24 | Examples::
25 |
26 | >>> net1 = MLP([5, 10, 5])
27 | >>> net2 = MLP([5, 10, 5], 'tanh')
28 | >>> net3 = MLP([5, 6, 7, 8, 5], 'tanh')
29 | >>> net4 = MLP([5, 6, 7, 8, 5], 'relu', output_activation='tanh')
30 | >>> net5 = MLP([5, 6, 7, 8, 5], ['tanh', 'relu', 'tanh'], 'tanh')
31 | >>> for net in [net1, net2, net3, net4, net5]:
32 | >>> x = torch.randn(5, 5)
33 | >>> y = net(x)
34 | >>> print(x)
35 | >>> print(y)
36 | """
37 |
38 | def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0):
39 | r"""
40 |
41 | :param List[int] size_layer: a list of ints defining the layers of the MLP; each number is the hidden size of that layer. The MLP has len(size_layer) - 1 layers
42 | :param Union[str,func,List[str]] activation: a string, or a list of strings/functions, defining the activation of each hidden layer; supported strings are relu, tanh and
43 | sigmoid, default relu
44 | :param Union[str,func] output_activation: string or function defining the activation of the output layer; default None, meaning the output layer has no activation
45 | :param str initial_method: parameter initialisation method
46 | :param float dropout: dropout probability, default 0
47 | """
48 | super(MLP, self).__init__()
49 | self.hiddens = nn.ModuleList()
50 | self.output = None
51 | self.output_activation = output_activation
52 | for i in range(1, len(size_layer)):
53 | if i + 1 == len(size_layer):
54 | self.output = nn.Linear(size_layer[i - 1], size_layer[i])
55 | else:
56 | self.hiddens.append(nn.Linear(size_layer[i - 1], size_layer[i]))
57 |
58 | self.dropout = nn.Dropout(p=dropout)
59 |
60 | actives = {
61 | 'relu': nn.ReLU(),
62 | 'tanh': nn.Tanh(),
63 | 'sigmoid': nn.Sigmoid(),
64 | }
65 | if not isinstance(activation, list):
66 | activation = [activation] * (len(size_layer) - 2)
67 | elif len(activation) == len(size_layer) - 2:
68 | pass
69 | else:
70 | raise ValueError(
71 | f"the length of activation function list except {len(size_layer) - 2} but got {len(activation)}!")
72 | self.hidden_active = []
73 | for func in activation:
74 | if callable(func):
75 | self.hidden_active.append(func)
76 | elif func.lower() in actives:
77 | self.hidden_active.append(actives[func])
78 | else:
79 | raise ValueError("should set activation correctly: {}".format(activation))
80 | if self.output_activation is not None:
81 | if callable(self.output_activation):
82 | pass
83 | elif self.output_activation.lower() in actives:
84 | self.output_activation = actives[self.output_activation]
85 | else:
86 | raise ValueError("should set activation correctly: {}".format(activation))
87 | initial_parameter(self, initial_method)
88 |
89 | def forward(self, x):
90 | r"""
91 | :param torch.Tensor x: the input to the MLP
92 | :return: torch.Tensor : the output of the MLP
93 | """
94 | for layer, func in zip(self.hiddens, self.hidden_active):
95 | x = self.dropout(func(layer(x)))
96 | x = self.output(x)
97 | if self.output_activation is not None:
98 | x = self.output_activation(x)
99 | x = self.dropout(x)
100 | return x
101 |
--------------------------------------------------------------------------------
/fastNLP/modules/decoder/utils.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "viterbi_decode"
5 | ]
6 | import torch
7 |
8 |
9 | def viterbi_decode(logits, transitions, mask=None, unpad=False):
10 | r"""
11 | Given a feature matrix and a transition score matrix, compute the best path and its score
12 |
13 | :param torch.FloatTensor logits: batch_size x max_len x num_tags, the feature matrix.
14 | :param torch.FloatTensor transitions: n_tags x n_tags, where the value at [i, j] is treated as the score for transitioning from tag i to tag j; or (n_tags+2) x
15 | (n_tags+2), where index n_tags is the start tag and n_tags+1 is the end tag. To forbid the transition i->j, set transitions[i, j] to a very small
16 | negative number, e.g. -10000000.0
17 | :param torch.ByteTensor mask: batch_size x max_len, positions that are 0 are treated as pad; if None, no padding is assumed.
18 | :param bool unpad: whether to strip padding from the result. If False, a batch_size x max_len tensor is returned; if True, a
19 | List[List[int]] is returned, where each inner List[int] holds the labels of one sequence with the pad part removed, i.e. the length of each List[int]
20 | is that sample's effective length.
21 | :return: (paths, scores).
22 | paths: the decoded paths; their form follows the unpad argument.
23 | scores: torch.FloatTensor of size (batch_size,), the score of each best path.
24 |
25 | """
26 | batch_size, seq_len, n_tags = logits.size()
27 | if transitions.size(0) == n_tags+2:
28 | include_start_end_trans = True
29 | elif transitions.size(0) == n_tags:
30 | include_start_end_trans = False
31 | else:
32 | raise RuntimeError("The shapes of transitions and feats are not " \
33 | "compatible.")
34 | logits = logits.transpose(0, 1).data # L, B, H
35 | if mask is not None:
36 | mask = mask.transpose(0, 1).data.eq(True) # L, B
37 | else:
38 | mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8).eq(1)
39 |
40 | trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
41 |
42 | # dp
43 | vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
44 | vscore = logits[0]
45 | if include_start_end_trans:
46 | vscore += transitions[n_tags, :n_tags]
47 |
48 | for i in range(1, seq_len):
49 | prev_score = vscore.view(batch_size, n_tags, 1)
50 | cur_score = logits[i].view(batch_size, 1, n_tags)
51 | score = prev_score + trans_score + cur_score
52 | best_score, best_dst = score.max(1)
53 | vpath[i] = best_dst
54 | vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \
55 | vscore.masked_fill(mask[i].view(batch_size, 1), 0)
56 |
57 | if include_start_end_trans:
58 | vscore += transitions[:n_tags, n_tags + 1].view(1, -1)
59 | # backtrace
60 | batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
61 | seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
62 | lens = (mask.long().sum(0) - 1)
63 | # idxes [L, B], batched idx from seq_len-1 to 0
64 | idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
65 |
66 | ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
67 | ans_score, last_tags = vscore.max(1)
68 | ans[idxes[0], batch_idx] = last_tags
69 | for i in range(seq_len - 1):
70 | last_tags = vpath[idxes[i], batch_idx, last_tags]
71 | ans[idxes[i + 1], batch_idx] = last_tags
72 | ans = ans.transpose(0, 1)
73 | if unpad:
74 | paths = []
75 | for idx, seq_len in enumerate(lens):
76 | paths.append(ans[idx, :seq_len + 1].tolist())
77 | else:
78 | paths = ans
79 | return paths, ans_score
80 |
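A small sketch of viterbi_decode on random scores; with unpad=True each decoded path is trimmed to its sequence's effective length::

    import torch
    from fastNLP.modules import viterbi_decode

    batch_size, max_len, n_tags = 2, 5, 4
    logits = torch.randn(batch_size, max_len, n_tags)
    transitions = torch.randn(n_tags, n_tags)
    mask = torch.tensor([[1, 1, 1, 1, 0],
                         [1, 1, 1, 0, 0]], dtype=torch.bool)

    paths, scores = viterbi_decode(logits, transitions, mask=mask, unpad=True)
    print([len(p) for p in paths])   # [4, 3] -- the effective lengths
    print(scores.shape)              # torch.Size([2])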
--------------------------------------------------------------------------------
/fastNLP/modules/dropout.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "TimestepDropout"
5 | ]
6 |
7 | import torch
8 |
9 |
10 | class TimestepDropout(torch.nn.Dropout):
11 | r"""
12 | The input has shape ``(batch_size, num_timesteps, embedding_dim)`` .
13 | The same ``(batch_size, embedding_dim)``-shaped mask is used to apply dropout at every timestep.
14 | """
15 |
16 | def forward(self, x):
17 | dropout_mask = x.new_ones(x.shape[0], x.shape[-1])
18 | torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True)
19 | dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim]
20 | if self.inplace:
21 | x *= dropout_mask
22 | return
23 | else:
24 | return x * dropout_mask
25 |
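A quick check of the behaviour described above: in training mode the same embedding dimensions are dropped at every timestep, so all timesteps of an all-ones input stay identical to each other::

    import torch
    from fastNLP.modules import TimestepDropout

    drop = TimestepDropout(p=0.5)
    drop.train()
    x = torch.ones(2, 4, 8)               # (batch_size, num_timesteps, embedding_dim)
    y = drop(x)
    print(torch.equal(y[:, 0], y[:, 1]))  # True -- the dropout mask is shared across timesteps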
--------------------------------------------------------------------------------
/fastNLP/modules/encoder/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 | .. todo::
3 | doc
4 | """
5 |
6 | __all__ = [
7 | "ConvolutionCharEncoder",
8 | "LSTMCharEncoder",
9 |
10 | "ConvMaxpool",
11 |
12 | "LSTM",
13 |
14 | "StarTransformer",
15 |
16 | "TransformerEncoder",
17 |
18 | "VarRNN",
19 | "VarLSTM",
20 | "VarGRU",
21 |
22 | "MaxPool",
23 | "MaxPoolWithMask",
24 | "KMaxPool",
25 | "AvgPool",
26 | "AvgPoolWithMask",
27 |
28 | "MultiHeadAttention",
29 | "BiAttention",
30 | "SelfAttention",
31 |
32 | "BertModel",
33 |
34 | "RobertaModel",
35 |
36 | "GPT2Model",
37 |
38 | "LSTMSeq2SeqEncoder",
39 | "TransformerSeq2SeqEncoder",
40 | "Seq2SeqEncoder"
41 | ]
42 |
43 | from fastNLP.modules.attention import MultiHeadAttention, BiAttention, SelfAttention
44 | from .bert import BertModel
45 | from .roberta import RobertaModel
46 | from .gpt2 import GPT2Model
47 | from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
48 | from .conv_maxpool import ConvMaxpool
49 | from .lstm import LSTM
50 | from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, KMaxPool
51 | from .star_transformer import StarTransformer
52 | from .transformer import TransformerEncoder
53 | from .variational_rnn import VarRNN, VarLSTM, VarGRU
54 | from .seq2seq_encoder import LSTMSeq2SeqEncoder, TransformerSeq2SeqEncoder, Seq2SeqEncoder
55 |
--------------------------------------------------------------------------------
/fastNLP/modules/encoder/char_encoder.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "ConvolutionCharEncoder",
5 | "LSTMCharEncoder"
6 | ]
7 | import torch
8 | import torch.nn as nn
9 |
10 | from ..utils import initial_parameter
11 |
12 |
13 | # from torch.nn.init import xavier_uniform
14 | class ConvolutionCharEncoder(nn.Module):
15 | r"""
16 | Character-level convolutional encoder.
17 |
18 | """
19 |
20 | def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(1, 3, 5), initial_method=None):
21 | r"""
22 |
23 | :param int char_emb_size: dimension of the character-level embedding. Default: 50
24 | e.g. with 26 characters, each embedded as a 50-dimensional vector, the input vector dimension is 50.
25 | :param tuple feature_maps: a tuple of ints. Its length is the number of character-level convolutions; the `i`-th int is the number of filters of the `i`-th convolution.
26 | :param tuple kernels: a tuple of ints. Its length is the number of character-level convolutions; the `i`-th int is the kernel size of the `i`-th convolution.
27 | :param initial_method: parameter initialisation method, defaults to `xavier normal`
28 | """
29 | super(ConvolutionCharEncoder, self).__init__()
30 | self.convs = nn.ModuleList([
31 | nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True,
32 | padding=(0, kernels[i] // 2))
33 | for i in range(len(kernels))])
34 |
35 | initial_parameter(self, initial_method)
36 |
37 | def forward(self, x):
38 | r"""
39 | :param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` embeddings of the input characters
40 | :return: torch.Tensor : result of the convolutions, of shape [batch_size * sent_length, sum(feature_maps), 1]
41 | """
42 | x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2))
43 | # [batch_size*sent_length, channel, width, height]
44 | x = x.transpose(2, 3)
45 | # [batch_size*sent_length, channel, height, width]
46 | return self._convolute(x).unsqueeze(2)
47 |
48 | def _convolute(self, x):
49 | feats = []
50 | for conv in self.convs:
51 | y = conv(x)
52 | # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
53 | y = torch.squeeze(y, 2)
54 | # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
55 | y = torch.tanh(y)
56 | y, __ = torch.max(y, 2)
57 | # [batch_size*sent_length, feature_maps[i]]
58 | feats.append(y)
59 | return torch.cat(feats, 1) # [batch_size*sent_length, sum(feature_maps)]
60 |
61 |
62 | class LSTMCharEncoder(nn.Module):
63 | r"""
64 | Character-level LSTM-based encoder.
65 | """
66 |
67 | def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None):
68 | r"""
69 | :param int char_emb_size: dimension of the character-level embedding. Default: 50
70 | e.g. with 26 characters, each embedded as a 50-dimensional vector, the input vector dimension is 50.
71 | :param int hidden_size: hidden size of the LSTM, defaults to the char embedding dimension
72 | :param initial_method: parameter initialisation method, defaults to `xavier normal`
73 | """
74 | super(LSTMCharEncoder, self).__init__()
75 | self.hidden_size = char_emb_size if hidden_size is None else hidden_size
76 |
77 | self.lstm = nn.LSTM(input_size=char_emb_size,
78 | hidden_size=self.hidden_size,
79 | num_layers=1,
80 | bias=True,
81 | batch_first=True)
82 | initial_parameter(self, initial_method)
83 |
84 | def forward(self, x):
85 | r"""
86 | :param torch.Tensor x: ``[ n_batch*n_word, word_length, char_emb_size]`` embeddings of the input characters
87 | :return: torch.Tensor : [ n_batch*n_word, char_emb_size] the result of the LSTM encoding
88 | """
89 | batch_size = x.shape[0]
90 | h0 = torch.empty(1, batch_size, self.hidden_size)
91 | h0 = nn.init.orthogonal_(h0)
92 | c0 = torch.empty(1, batch_size, self.hidden_size)
93 | c0 = nn.init.orthogonal_(c0)
94 |
95 | _, hidden = self.lstm(x, (h0, c0))
96 | return hidden[0].squeeze().unsqueeze(2)
97 |
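A shape-only sketch of ConvolutionCharEncoder with its default feature maps (40, 30, 30), whose outputs are concatenated to sum(feature_maps) = 100 channels::

    import torch
    from fastNLP.modules import ConvolutionCharEncoder

    encoder = ConvolutionCharEncoder(char_emb_size=50, feature_maps=(40, 30, 30), kernels=(1, 3, 5))
    chars = torch.randn(32 * 6, 7, 50)   # (batch_size * sent_length, word_length, char_emb_size)
    print(encoder(chars).shape)          # torch.Size([192, 100, 1])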
--------------------------------------------------------------------------------
/fastNLP/modules/encoder/conv_maxpool.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "ConvMaxpool"
5 | ]
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class ConvMaxpool(nn.Module):
12 | r"""
13 | A layer combining convolution and max-pooling. Given an input of batch_size x max_len x input_size, it returns a matrix of size batch_size x
14 | sum(output_channels). Internally, the input is first convolved with CNNs, then passed through the activation, and then
15 | max-pooled over the length (max_len) dimension, giving one vector representation per sample.
16 |
17 | """
18 |
19 | def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"):
20 | r"""
21 |
22 | :param int in_channels: size of the input channels, usually the embedding dimension or the encoder's output dimension
23 | :param int,tuple(int) out_channels: number of output channels. If a list, its length must match the number of kernel_sizes
24 | :param int,tuple(int) kernel_sizes: kernel size for each group of output channels.
25 | :param str activation: the convolution output is passed through this activation before max-pooling. Supports relu, sigmoid, tanh
26 | """
27 | super(ConvMaxpool, self).__init__()
28 |
29 | for kernel_size in kernel_sizes:
30 | assert kernel_size % 2 == 1, "kernel size has to be odd numbers."
31 |
32 | # convolution
33 | if isinstance(kernel_sizes, (list, tuple, int)):
34 | if isinstance(kernel_sizes, int) and isinstance(out_channels, int):
35 | out_channels = [out_channels]
36 | kernel_sizes = [kernel_sizes]
37 | elif isinstance(kernel_sizes, (tuple, list)) and isinstance(out_channels, (tuple, list)):
38 | assert len(out_channels) == len(
39 | kernel_sizes), "The number of out_channels should be equal to the number" \
40 | " of kernel_sizes."
41 | else:
42 | raise ValueError("The type of out_channels and kernel_sizes should be the same.")
43 |
44 | self.convs = nn.ModuleList([nn.Conv1d(
45 | in_channels=in_channels,
46 | out_channels=oc,
47 | kernel_size=ks,
48 | stride=1,
49 | padding=ks // 2,
50 | dilation=1,
51 | groups=1,
52 | bias=None)
53 | for oc, ks in zip(out_channels, kernel_sizes)])
54 |
55 | else:
56 | raise Exception(
57 | 'Incorrect kernel sizes: should be list, tuple or int')
58 |
59 | # activation function
60 | if activation == 'relu':
61 | self.activation = F.relu
62 | elif activation == 'sigmoid':
63 | self.activation = torch.sigmoid
64 | elif activation == 'tanh':
65 | self.activation = torch.tanh
66 | else:
67 | raise Exception(
68 | "Undefined activation function: choose from: relu, tanh, sigmoid")
69 |
70 | def forward(self, x, mask=None):
71 | r"""
72 |
73 | :param torch.FloatTensor x: batch_size x max_len x input_size, usually the output of an embedding layer
74 | :param mask: batch_size x max_len, 0 at padded positions. It does not affect the convolution, but max-pooling will never pick positions where the mask is 0
75 | :return:
76 | """
77 | # [N,L,C] -> [N,C,L]
78 | x = torch.transpose(x, 1, 2)
79 | # convolution
80 | xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...]
81 | if mask is not None:
82 | mask = mask.unsqueeze(1) # B x 1 x L
83 | xs = [x.masked_fill_(mask.eq(False), float('-inf')) for x in xs]
84 | # max-pooling
85 | xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2)
86 | for i in xs] # [[N, C], ...]
87 | return torch.cat(xs, dim=-1) # [N, C]
88 |
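A minimal usage sketch of ConvMaxpool (sizes are illustrative):

import torch
from fastNLP.modules.encoder.conv_maxpool import ConvMaxpool

x = torch.randn(2, 10, 100)                    # batch of 2, max_len 10, embedding size 100
mask = torch.ones(2, 10, dtype=torch.bool)
mask[1, 6:] = False                            # the second sentence has only 6 real tokens
layer = ConvMaxpool(in_channels=100, out_channels=[30, 40, 50], kernel_sizes=[1, 3, 5])
out = layer(x, mask)
print(out.shape)                               # expected: torch.Size([2, 120]), i.e. 30 + 40 + 50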
--------------------------------------------------------------------------------
/fastNLP/modules/encoder/lstm.py:
--------------------------------------------------------------------------------
1 | r"""undocumented
2 | A lightweight wrapper around the PyTorch LSTM module.
3 | Sequence lengths can be passed to forward, and padding is then handled automatically.
4 | """
5 |
6 | __all__ = [
7 | "LSTM"
8 | ]
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.utils.rnn as rnn
13 |
14 |
15 | class LSTM(nn.Module):
16 | r"""
17 | LSTM module, a lightweight wrapper around the PyTorch LSTM. When seq_len is given, pack_padded_sequence is used automatically; the forget-gate bias is initialized
18 | to 1 by default; it also works around issues when using LSTM inside DataParallel.
19 |
20 | """
21 |
22 | def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
23 | bidirectional=False, bias=True):
24 | r"""
25 |
26 | :param input_size: feature dimension of the input `x`
27 | :param hidden_size: feature dimension of the hidden state `h`. If bidirectional is True, the output dimension is hidden_size*2
28 | :param num_layers: number of RNN layers. Default: 1
29 | :param dropout: dropout probability between layers. Default: 0
30 | :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
31 | :param batch_first: if ``True``, input and output ``Tensor`` shapes are
32 | (batch, seq, feature). Default: ``True``
33 | :param bias: if ``False``, the layers use no bias. Default: ``True``
34 | """
35 | super(LSTM, self).__init__()
36 | self.batch_first = batch_first
37 | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
38 | dropout=dropout, bidirectional=bidirectional)
39 | self.init_param()
40 |
41 | def init_param(self):
42 | for name, param in self.named_parameters():
43 | if 'bias' in name:
44 | # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871
45 | param.data.fill_(0)
46 | n = param.size(0)
47 | start, end = n // 4, n // 2
48 | param.data[start:end].fill_(1)
49 | else:
50 | nn.init.xavier_uniform_(param)
51 |
52 | def forward(self, x, seq_len=None, h0=None, c0=None):
53 | r"""
54 |
55 | :param x: [batch, seq_len, input_size] input sequence
56 | :param seq_len: [batch, ] sequence lengths; if ``None``, all inputs are treated as having the same length. Default: ``None``
57 | :param h0: [batch, hidden_size] initial hidden state; if ``None``, a zero vector is used. Default: ``None``
58 | :param c0: [batch, hidden_size] initial cell state; if ``None``, a zero vector is used. Default: ``None``
59 | :return (output, (ht, ct)): output: [batch, seq_len, hidden_size*num_direction] output sequence
60 | and ht, ct: [num_layers*num_direction, batch, hidden_size] hidden and cell states at the last time step.
61 | """
62 | batch_size, max_len, _ = x.size()
63 | if h0 is not None and c0 is not None:
64 | hx = (h0, c0)
65 | else:
66 | hx = None
67 | if seq_len is not None and not isinstance(x, rnn.PackedSequence):
68 | sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
69 | if self.batch_first:
70 | x = x[sort_idx]
71 | else:
72 | x = x[:, sort_idx]
73 | x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first)
74 | output, hx = self.lstm(x, hx) # -> [N,L,C]
75 | output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first, total_length=max_len)
76 | _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False)
77 | if self.batch_first:
78 | output = output[unsort_idx]
79 | else:
80 | output = output[:, unsort_idx]
81 | hx = hx[0][:, unsort_idx], hx[1][:, unsort_idx]
82 | else:
83 | output, hx = self.lstm(x, hx)
84 | return output, hx
85 |
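A minimal usage sketch of the LSTM wrapper with variable-length input (sizes are illustrative):

import torch
from fastNLP.modules.encoder.lstm import LSTM

x = torch.randn(3, 8, 20)                      # 3 sentences padded to length 8, feature size 20
seq_len = torch.tensor([8, 5, 2])              # true lengths
lstm = LSTM(input_size=20, hidden_size=50, bidirectional=True)
output, (h, c) = lstm(x, seq_len=seq_len)
print(output.shape)                            # expected: torch.Size([3, 8, 100])
print(h.shape)                                 # expected: torch.Size([2, 3, 50])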
--------------------------------------------------------------------------------
/fastNLP/modules/encoder/transformer.py:
--------------------------------------------------------------------------------
1 | r"""undocumented"""
2 |
3 | __all__ = [
4 | "TransformerEncoder"
5 | ]
6 | from torch import nn
7 |
8 | from .seq2seq_encoder import TransformerSeq2SeqEncoderLayer
9 |
10 |
11 | class TransformerEncoder(nn.Module):
12 | r"""
13 | The encoder stack of a Transformer, without the embedding layer
14 |
15 | """
16 | def __init__(self, num_layers, d_model=512, n_head=8, dim_ff=2048, dropout=0.1):
17 | """
18 |
19 | :param int num_layers: number of Transformer layers
20 | :param int d_model: input and output size
21 | :param int n_head: number of attention heads
22 | :param int dim_ff: hidden size of the feed-forward network
23 | :param float dropout: dropout probability applied to the attention and the intermediate FFN representation
24 | """
25 | super(TransformerEncoder, self).__init__()
26 | self.layers = nn.ModuleList([TransformerSeq2SeqEncoderLayer(d_model = d_model, n_head = n_head, dim_ff = dim_ff,
27 | dropout = dropout) for _ in range(num_layers)])
28 | self.norm = nn.LayerNorm(d_model, eps=1e-6)
29 |
30 | def forward(self, x, seq_mask=None):
31 | r"""
32 | :param x: [batch, seq_len, model_size] input sequence
33 | :param seq_mask: [batch, seq_len] padding mask of the input sequence; if ``None``, an all-ones mask is generated. Positions marked 1 are attended to.
34 | Default: ``None``
35 | :return: [batch, seq_len, model_size] output sequence
36 | """
37 | output = x
38 | if seq_mask is None:
39 | seq_mask = x.new_ones(x.size(0), x.size(1)).bool()
40 | for layer in self.layers:
41 | output = layer(output, seq_mask)
42 | return self.norm(output)
43 |
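A minimal usage sketch of TransformerEncoder (sizes are illustrative):

import torch
from fastNLP.modules.encoder.transformer import TransformerEncoder

x = torch.randn(2, 6, 512)                     # [batch, seq_len, d_model]
seq_mask = torch.ones(2, 6, dtype=torch.bool)
seq_mask[1, 4:] = False                        # padding positions of the second sequence
encoder = TransformerEncoder(num_layers=2, d_model=512, n_head=8, dim_ff=2048, dropout=0.1)
out = encoder(x, seq_mask)
print(out.shape)                               # expected: torch.Size([2, 6, 512])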
--------------------------------------------------------------------------------
/fastNLP/modules/generator/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 |
3 | """
4 |
5 | __all__ = [
6 | "SequenceGenerator"
7 | ]
8 |
9 | from .seq2seq_generator import SequenceGenerator
--------------------------------------------------------------------------------
/fastNLP/modules/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | r"""
2 |
3 | """
4 | __all__=[
5 | 'BertTokenizer',
6 |
7 | "GPT2Tokenizer",
8 |
9 | "RobertaTokenizer"
10 | ]
11 |
12 | from .bert_tokenizer import BertTokenizer
13 | from .gpt2_tokenizer import GPT2Tokenizer
14 | from .roberta_tokenizer import RobertaTokenizer
--------------------------------------------------------------------------------
/fastNLP/modules/tokenizer/roberta_tokenizer.py:
--------------------------------------------------------------------------------
1 | r"""
2 |
3 | """
4 |
5 | __all__ = [
6 | "RobertaTokenizer"
7 | ]
8 |
9 | import json
10 | from .gpt2_tokenizer import GPT2Tokenizer
11 | from fastNLP.io.file_utils import _get_file_name_base_on_postfix
12 | from ...io.file_utils import _get_roberta_dir
13 |
14 | PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES = {
15 | "roberta-base": 512,
16 | "roberta-large": 512,
17 | "roberta-large-mnli": 512,
18 | "distilroberta-base": 512,
19 | "roberta-base-openai-detector": 512,
20 | "roberta-large-openai-detector": 512,
21 | }
22 |
23 |
24 | class RobertaTokenizer(GPT2Tokenizer):
25 |
26 | vocab_files_names = {
27 | "vocab_file": "vocab.json",
28 | "merges_file": "merges.txt",
29 | }
30 |
31 | def __init__(
32 | self,
33 | vocab_file,
34 | merges_file,
35 | errors="replace",
36 | bos_token="<s>",
37 | eos_token="</s>",
38 | sep_token="</s>",
39 | cls_token="<s>",
40 | unk_token="<unk>",
41 | pad_token="<pad>",
42 | mask_token="<mask>",
43 | **kwargs
44 | ):
45 | super().__init__(
46 | vocab_file=vocab_file,
47 | merges_file=merges_file,
48 | errors=errors,
49 | bos_token=bos_token,
50 | eos_token=eos_token,
51 | unk_token=unk_token,
52 | sep_token=sep_token,
53 | cls_token=cls_token,
54 | pad_token=pad_token,
55 | mask_token=mask_token,
56 | **kwargs,
57 | )
58 | self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
59 | self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
60 |
61 | @classmethod
62 | def from_pretrained(cls, model_dir_or_name, *inputs, **kwargs):
63 | """
64 |
65 | :param str model_dir_or_name: a directory or a shortcut model name
66 | :param kwargs:
67 | :return:
68 | """
69 | # Two files are needed: vocab.json and the merges file.
70 | model_dir = _get_roberta_dir(model_dir_or_name)
71 | # The directory contains four files: vocab.json, merges.txt, config.json and model.bin.
72 |
73 | tokenizer_config_file = _get_file_name_base_on_postfix(model_dir, 'config.json')
74 | with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
75 | init_kwargs = json.load(tokenizer_config_handle)
76 | # Set max length if needed
77 | if model_dir_or_name in PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES:
78 | # if we're using a pretrained model, ensure the tokenizer
79 | # wont index sequences longer than the number of positional embeddings
80 | max_len = PRETRAINED_ROBERTA_POSITIONAL_EMBEDDINGS_SIZES[model_dir_or_name]
81 | if max_len is not None and isinstance(max_len, (int, float)):
82 | init_kwargs["max_len"] = min(init_kwargs.get("max_len", int(1e12)), max_len)
83 |
84 | # add the vocab and merges files to init_kwargs
85 | if 'vocab_file' in kwargs: # if a vocab file is specified, use it
86 | init_kwargs['vocab_file'] = kwargs['vocab_file']
87 | else:
88 | init_kwargs['vocab_file'] = _get_file_name_base_on_postfix(model_dir, RobertaTokenizer.vocab_files_names['vocab_file'])
89 | init_kwargs['merges_file'] = _get_file_name_base_on_postfix(model_dir, RobertaTokenizer.vocab_files_names['merges_file'])
90 |
91 | init_inputs = init_kwargs.pop("init_inputs", ())
92 | # Instantiate tokenizer.
93 | try:
94 | tokenizer = cls(*init_inputs, **init_kwargs)
95 | except OSError:
96 | raise OSError(
97 | "Unable to load vocabulary from file. "
98 | "Please check that the provided vocabulary is accessible and not corrupted."
99 | )
100 |
101 | return tokenizer
102 |
103 |
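A usage sketch of from_pretrained; the path is a placeholder for a directory holding vocab.json, merges.txt and config.json (see the comments above), and the tokenize/convert_tokens_to_ids calls assume the usual methods inherited from GPT2Tokenizer:

from fastNLP.modules.tokenizer.roberta_tokenizer import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('/path/to/roberta-base')   # placeholder path
tokens = tokenizer.tokenize("a span-based dependency parser")
ids = tokenizer.convert_tokens_to_ids(tokens)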
--------------------------------------------------------------------------------
/fastNLP/modules/utils.py:
--------------------------------------------------------------------------------
1 | r"""
2 | .. todo::
3 | doc
4 | """
5 |
6 | __all__ = [
7 | "initial_parameter",
8 | "summary"
9 | ]
10 |
11 | from functools import reduce
12 |
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.init as init
16 |
17 |
18 | def initial_parameter(net, initial_method=None):
19 | r"""A method used to initialize the weights of PyTorch models.
20 |
21 | :param net: a PyTorch model
22 | :param str initial_method: one of the following initializations.
23 |
24 | - xavier_uniform
25 | - xavier_normal (default)
26 | - kaiming_normal, or msra
27 | - kaiming_uniform
28 | - orthogonal
29 | - sparse
30 | - normal
31 | - uniform
32 |
33 | """
34 | if initial_method == 'xavier_uniform':
35 | init_method = init.xavier_uniform_
36 | elif initial_method == 'xavier_normal':
37 | init_method = init.xavier_normal_
38 | elif initial_method == 'kaiming_normal' or initial_method == 'msra':
39 | init_method = init.kaiming_normal_
40 | elif initial_method == 'kaiming_uniform':
41 | init_method = init.kaiming_uniform_
42 | elif initial_method == 'orthogonal':
43 | init_method = init.orthogonal_
44 | elif initial_method == 'sparse':
45 | init_method = init.sparse_
46 | elif initial_method == 'normal':
47 | init_method = init.normal_
48 | elif initial_method == 'uniform':
49 | init_method = init.uniform_
50 | else:
51 | init_method = init.xavier_normal_
52 |
53 | def weights_init(m):
54 | # classname = m.__class__.__name__
55 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv3d): # for all the cnn
56 | if initial_method is not None:
57 | init_method(m.weight.data)
58 | else:
59 | init.xavier_normal_(m.weight.data)
60 | init.normal_(m.bias.data)
61 | elif isinstance(m, nn.LSTM):
62 | for w in m.parameters():
63 | if len(w.data.size()) > 1:
64 | init_method(w.data) # weight
65 | else:
66 | init.normal_(w.data) # bias
67 | elif m is not None and hasattr(m, 'weight') and \
68 | hasattr(m.weight, "requires_grad"):
69 | if len(m.weight.size()) > 1:
70 | init_method(m.weight.data)
71 | else:
72 | init.normal_(m.weight.data) # batchnorm or layernorm
73 | else:
74 | for w in m.parameters():
75 | if w.requires_grad:
76 | if len(w.data.size()) > 1:
77 | init_method(w.data) # weight
78 | else:
79 | init.normal_(w.data) # bias
80 | # print("init else")
81 |
82 | net.apply(weights_init)
83 |
84 |
85 | def summary(model: nn.Module):
86 | r"""
87 | Get the total number of parameters of a model
88 |
89 | :params model: a PyTorch model
90 | :return tuple: (total number of parameters, number of trainable parameters, number of non-trainable parameters)
91 | """
92 | train = []
93 | nontrain = []
94 | buffer = []
95 |
96 | def layer_summary(module: nn.Module):
97 | def count_size(sizes):
98 | return reduce(lambda x, y: x * y, sizes)
99 |
100 | for p in module.parameters(recurse=False):
101 | if p.requires_grad:
102 | train.append(count_size(p.shape))
103 | else:
104 | nontrain.append(count_size(p.shape))
105 | for p in module.buffers():
106 | buffer.append(count_size(p.shape))
107 | for subm in module.children():
108 | layer_summary(subm)
109 |
110 | layer_summary(model)
111 | total_train = sum(train)
112 | total_nontrain = sum(nontrain)
113 | total = total_train + total_nontrain
114 | strings = []
115 | strings.append('Total params: {:,}'.format(total))
116 | strings.append('Trainable params: {:,}'.format(total_train))
117 | strings.append('Non-trainable params: {:,}'.format(total_nontrain))
118 | strings.append("Buffer params: {:,}".format(sum(buffer)))
119 | max_len = len(max(strings, key=len))
120 | bar = '-' * (max_len + 3)
121 | strings = [bar] + strings + [bar]
122 | print('\n'.join(strings))
123 | return total, total_train, total_nontrain
124 |
125 |
126 | def get_dropout_mask(drop_p: float, tensor: torch.Tensor):
127 | r"""
128 | Generate a dropout mask with the same shape as the given tensor
129 |
130 | :param drop_p: float, probability of setting an entry to 0.
131 | :param tensor: torch.Tensor
132 | :return: torch.FloatTensor with the same shape as tensor
133 | """
134 | mask_x = torch.ones_like(tensor)
135 | nn.functional.dropout(mask_x, p=drop_p,
136 | training=True, inplace=True)  # training=True so that entries are actually zeroed with probability drop_p
137 | return mask_x
138 |
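A minimal usage sketch of initial_parameter and summary:

import torch.nn as nn
from fastNLP.modules.utils import initial_parameter, summary

net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5))
initial_parameter(net, initial_method='kaiming_uniform')    # re-initialize all weights in place
total, trainable, non_trainable = summary(net)              # prints a small table and returns the counts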
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.12.0
2 | aiohttp==3.7.4.post0
3 | alembic==1.5.8
4 | antlr4-python3-runtime==4.8
5 | asn1crypto==0.24.0
6 | async-timeout==3.0.1
7 | attrs==20.3.0
8 | blessings==1.7
9 | cachetools==4.2.1
10 | certifi==2020.12.5
11 | chardet==4.0.0
12 | click==7.1.2
13 | cliff==3.7.0
14 | cmaes==0.8.2
15 | cmd2==1.5.0
16 | colorama==0.4.4
17 | colorlog==4.8.0
18 | command-not-found==0.3
19 | commonmark==0.9.1
20 | configparser==5.0.2
21 | cryptography==2.1.4
22 | dataclasses==0.8
23 | distro-info==0.18ubuntu0.18.04.1
24 | docker-pycreds==0.4.0
25 | easydict==1.9
26 | filelock==3.0.12
27 | fsspec==0.8.7
28 | future==0.18.2
29 | gitdb==4.0.5
30 | gitpython==3.1.14
31 | google-auth-oauthlib==0.4.3
32 | google-auth==1.28.0
33 | gpustat==0.6.0
34 | greenlet==1.0.0
35 | grpcio==1.36.1
36 | hydra-core==1.1.0.dev4
37 | hydra-optuna-sweeper==1.1.0.dev1
38 | idna-ssl==1.1.0
39 | idna==2.10
40 | importlib-metadata==3.10.0
41 | importlib-resources==5.1.2
42 | joblib==1.0.1
43 | keyring==10.6.0
44 | keyrings.alt==3.0
45 | language-selector==0.1
46 | mako==1.1.4
47 | markdown==3.3.4
48 | markupsafe==1.1.1
49 | multidict==5.1.0
50 | netifaces==0.10.4
51 | nltk==3.5
52 | numpy==1.19.5
53 | nvidia-ml-py3==7.352.0
54 | oauthlib==3.1.0
55 | omegaconf==2.1.0.dev24
56 | opt-einsum==3.3.0
57 | optuna==2.4.0
58 | packaging==20.9
59 | pathtools==0.1.2
60 | pbr==5.5.1
61 | pillow==8.0.0
62 | pip==9.0.1
63 | prettytable==2.1.0
64 | promise==2.3
65 | protobuf==3.15.6
66 | psutil==5.7.2
67 | pyasn1-modules==0.2.8
68 | pyasn1==0.4.8
69 | pycrypto==2.6.1
70 | pygments==2.8.1
71 | pygobject==3.26.1
72 | pyparsing==2.4.7
73 | pyperclip==1.8.2
74 | python-apt==1.6.5+ubuntu0.3
75 | python-dateutil==2.8.1
76 | python-distutils-extra==2.39
77 | python-editor==1.0.4
78 | pytorch-lightning==1.2.4
79 | pyxdg==0.25
80 | pyyaml==5.4.1
81 | regex==2020.11.13
82 | requests-oauthlib==1.3.0
83 | requests==2.25.1
84 | rich==9.13.0
85 | rsa==4.7.2
86 | sacremoses==0.0.43
87 | scipy==1.5.4
88 | secretstorage==2.3.1
89 | sentry-sdk==1.0.0
90 | setuptools==54.1.2
91 | shortuuid==1.0.1
92 | six==1.15.0
93 | smmap==3.0.5
94 | sqlalchemy==1.4.6
95 | ssh-import-id==5.7
96 | stevedore==3.3.0
97 | subprocess32==3.5.4
98 | tensorboard-plugin-wit==1.8.0
99 | tensorboard==2.4.1
100 | tokenizers==0.10.1
101 | torch-struct==0.5
102 | torch==1.7.0
103 | torchaudio==0.7.0
104 | torchmetrics==0.2.0
105 | torchvision==0.8.0
106 | tqdm==4.60.0
107 | transformers==4.3.3
108 | typing-extensions==3.7.4.3
109 | ufw==0.36
110 | urllib3==1.26.4
111 | wandb==0.10.21
112 | wcwidth==0.2.5
113 | werkzeug==1.0.1
114 | wheel==0.36.2
115 | yarl==1.6.3
116 | zipp==3.4.1
117 |
--------------------------------------------------------------------------------
/src/callbacks/progressbar.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint, ProgressBar
2 | from tqdm import tqdm
3 |
4 | class PrettyProgressBar(ProgressBar):
5 | """Good print wrapper."""
6 | def __init__(self, refresh_rate: int, process_position: int):
7 | super().__init__(refresh_rate=refresh_rate, process_position=process_position)
8 |
9 | def init_sanity_tqdm(self) -> tqdm:
10 | bar = tqdm(desc='Validation sanity check',
11 | position=self.process_position,
12 | disable=self.is_disabled,
13 | leave=False,
14 | ncols=120,
15 | ascii=True)
16 | return bar
17 |
18 | def init_train_tqdm(self) -> tqdm:
19 | bar = tqdm(desc='Training',
20 | initial=self.train_batch_idx,
21 | position=self.process_position,
22 | disable=self.is_disabled,
23 | leave=True,
24 | smoothing=0,
25 | ncols=120,
26 | ascii=True)
27 | return bar
28 |
29 |
30 |
31 | def init_validation_tqdm(self) -> tqdm:
32 | bar = tqdm(disable=True)
33 | return bar
34 |
35 | def on_epoch_start(self, trainer, pl_module):
36 | super().on_epoch_start(trainer, pl_module)
37 | self.main_progress_bar.set_description(f'E{trainer.current_epoch}|train')
38 |
39 | def on_validation_start(self, trainer, pl_module):
40 | super().on_validation_start(trainer, pl_module)
41 | self.main_progress_bar.set_description(f'E{trainer.current_epoch}|val')
42 |
43 | def on_test_start(self, trainer, pl_module):
44 | super().on_test_start(trainer, pl_module)
45 | self.main_progress_bar.set_description(f'E{trainer.current_epoch}|test')
46 |
47 |
--------------------------------------------------------------------------------
/src/callbacks/transformer_scheduler.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from transformers import get_linear_schedule_with_warmup, get_constant_schedule_with_warmup
3 | import logging
4 | log = logging.getLogger(__name__)
5 |
6 | class TransformerLrScheduler(pl.Callback):
7 | def __init__(self, warmup):
8 | self.warmup = warmup
9 |
10 | def on_train_start(self, trainer, pl_module):
11 | for lr_scheduler in trainer.lr_schedulers:
12 | scheduler = lr_scheduler['scheduler']
13 | n_train = len(pl_module.train_dataloader())
14 | n_accumulate_grad = trainer.accumulate_grad_batches
15 | n_max_epochs = trainer.max_epochs
16 | num_training_steps = n_train // n_accumulate_grad * n_max_epochs
17 | num_warmup_steps = int(self.warmup * num_training_steps)
18 |
19 | if pl_module.optimizer_cfg.scheduler_type == 'linear_warmup':
20 | lr_scheduler['scheduler'] = get_linear_schedule_with_warmup(scheduler.optimizer, num_warmup_steps, num_training_steps)
21 |
22 | elif pl_module.optimizer_cfg.scheduler_type == 'constant_warmup':
23 | lr_scheduler['scheduler'] = get_constant_schedule_with_warmup(scheduler.optimizer, num_warmup_steps,
24 | )
25 |
26 | log.info(f"Warm up rate:{self.warmup}")
27 | log.info(f"total number of training step:{num_training_steps}")
28 | log.info(f"number of training batches per epochs in the dataloader:{n_train}")
--------------------------------------------------------------------------------
/src/callbacks/wandb_callbacks.py:
--------------------------------------------------------------------------------
1 | # wandb
2 | from pytorch_lightning.loggers import WandbLogger
3 | import wandb
4 |
5 | # pytorch
6 | from pytorch_lightning import Callback
7 | import pytorch_lightning as pl
8 | import torch
9 |
10 | # others
11 | import glob
12 | import os
13 |
14 |
15 | def get_wandb_logger(trainer: pl.Trainer) -> WandbLogger:
16 | logger = None
17 | for lg in trainer.logger:
18 | if isinstance(lg, WandbLogger):
19 | logger = lg
20 |
21 | if not logger:
22 | raise Exception(
23 | "You are using wandb related callback,"
24 | "but WandbLogger was not found for some reason..."
25 | )
26 |
27 | return logger
28 |
29 |
30 | # class UploadCodeToWandbAsArtifact(Callback):
31 | # """Upload all *.py files to wandb as an artifact at the beginning of the run."""
32 | #
33 | # def __init__(self, code_dir: str):
34 | # self.code_dir = code_dir
35 | #
36 | # def on_train_start(self, trainer, pl_module):
37 | # logger = get_wandb_logger(trainer=trainer)
38 | # experiment = logger.experiment
39 | #
40 | # code = wandb.Artifact("project-source", type="code")
41 | # for path in glob.glob(os.path.join(self.code_dir, "**/*.py"), recursive=True):
42 | # print(path)
43 | # code.add_file(path)
44 | # print('ok')
45 | #
46 | #
47 | #
48 | # experiment.use_artifact(code)
49 | # print("successfully update the code .")
50 |
51 |
52 | class UploadHydraConfigFileToWandb(Callback):
53 | def on_fit_start(self, trainer, pl_module: pl.LightningModule) -> None:
54 | logger = get_wandb_logger(trainer=trainer)
55 |
56 | logger.experiment.save()
57 |
58 |
59 |
60 | class UploadCheckpointsToWandbAsArtifact(Callback):
61 | """Upload experiment checkpoints to wandb as an artifact at the end of training."""
62 |
63 | def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False):
64 | self.ckpt_dir = ckpt_dir
65 | self.upload_best_only = upload_best_only
66 |
67 | def on_train_end(self, trainer, pl_module):
68 | logger = get_wandb_logger(trainer=trainer)
69 | experiment = logger.experiment
70 |
71 | ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints")
72 |
73 | if self.upload_best_only:
74 | ckpts.add_file(trainer.checkpoint_callback.best_model_path)
75 | else:
76 | for path in glob.glob(
77 | os.path.join(self.ckpt_dir, "**/*.ckpt"), recursive=True
78 | ):
79 | ckpts.add_file(path)
80 |
81 | experiment.use_artifact(ckpts)
82 |
83 |
84 | class WatchModelWithWandb(Callback):
85 | """Make WandbLogger watch model at the beginning of the run."""
86 |
87 | def __init__(self, log: str = "gradients", log_freq: int = 100):
88 | self.log = log
89 | self.log_freq = log_freq
90 |
91 | def on_train_start(self, trainer, pl_module):
92 | logger = get_wandb_logger(trainer=trainer)
93 | logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq)
94 |
95 |
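A minimal sketch of wiring these callbacks into a Lightning Trainer; the project name and checkpoint directory are placeholders, and the logger is passed as a list so that trainer.logger is iterable, which get_wandb_logger above expects:

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from src.callbacks.wandb_callbacks import WatchModelWithWandb, UploadCheckpointsToWandbAsArtifact

logger = WandbLogger(project="span-based-dependency-parsing")   # placeholder project name
trainer = pl.Trainer(
    logger=[logger],
    callbacks=[WatchModelWithWandb(log="gradients", log_freq=100),
               UploadCheckpointsToWandbAsArtifact(ckpt_dir="checkpoints/")],
)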
--------------------------------------------------------------------------------
/src/constant.py:
--------------------------------------------------------------------------------
1 | pad = '<pad>'
2 | unk = '<unk>'
3 | bos = '<bos>'
4 | eos = '<eos>'
5 | arc = 'arc'
6 | rel = 'rel'
7 | word = 'word'
8 | chart = 'chart'
9 | raw_tree = 'raw_tree'
10 | valid = 'valid'
11 | sib = 'sibling'
12 | relatives = 'relatives'
13 | span_head = 'span_head'
14 |
15 |
--------------------------------------------------------------------------------
/src/datamodule/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/datamodule/dm_util/datamodule_util.py:
--------------------------------------------------------------------------------
1 | from supar.utils.alg import kmeans
2 | from supar.utils.data import Sampler
3 | from fastNLP.core.field import Padder
4 | import numpy as np
5 |
6 |
7 |
8 |
9 | def get_sampler(lengths, max_tokens, n_buckets, shuffle=True, distributed=False, evaluate=False):
10 | buckets = dict(zip(*kmeans(lengths, n_buckets)))
11 | return Sampler(buckets=buckets,
12 | batch_size=max_tokens,
13 | shuffle=shuffle,
14 | distributed=distributed,
15 | evaluate=evaluate)
16 |
17 |
18 |
19 | class SiblingPadder(Padder):
20 | def __call__(self, contents, field_name, field_ele_dtype, dim: int):
21 | # max_sent_length = max(rule.shape[0] for rule in contents)
22 | batch_size = sum(len(r) for r in contents)
23 | padded_array = np.full((batch_size, 4), fill_value=self.pad_val,
24 | dtype=np.long)
25 |
26 | i = 0
27 | for b_idx, relations in enumerate(contents):
28 | for (head, child, sibling, _) in relations:
29 | padded_array[i] = np.array([b_idx, head, child, sibling])
30 | i+=1
31 |
32 | return padded_array
33 |
34 |
35 |
36 |
37 | class ConstAsDepPadder(Padder):
38 | def __call__(self, contents, field_name, field_ele_dtype, dim: int):
39 | # max_sent_length = max(rule.shape[0] for rule in contents)
40 | batch_size = sum(len(r) for r in contents)
41 | padded_array = np.full((batch_size, 4), fill_value=self.pad_val,
42 | dtype=np.long)
43 | i = 0
44 | for b_idx, relations in enumerate(contents):
45 | for (head, child, sibling, *_) in relations:
46 | padded_array[i] = np.array([b_idx, head, child, sibling])
47 | i+=1
48 | return padded_array
49 |
50 | class GrandPadder(Padder):
51 | def __call__(self, contents, field_name, field_ele_dtype, dim: int):
52 | batch_size = sum(len(r) for r in contents)
53 | padded_array = np.full((batch_size, 4), fill_value=self.pad_val,
54 | dtype=np.float)
55 | i = 0
56 | for b_idx, relations in enumerate(contents):
57 | for (head, child, _, grandparent, *_) in relations:
58 | padded_array[i] = np.array([b_idx, head, child, grandparent])
59 | i += 1
60 |
61 | return padded_array
62 |
63 | class SpanPadderCAD(Padder):
64 | def __call__(self, contents, field_name, field_ele_dtype, dim: int):
65 | batch_size = sum((len(r) * 2 - 1) for r in contents)
66 | padded_array = np.full((batch_size, 6), fill_value=self.pad_val,
67 | dtype=np.float)
68 | i = 0
69 |
70 | ## 0 stands for inherent
71 | ## 1 stands for noninherent
72 | for b_idx, span in enumerate(contents):
73 | for (head, child, ih_start, ih_end, ni_start, ni_end) in span:
74 | if not ih_start == -1:
75 | padded_array[i] = np.array([b_idx, head, child, ih_start, ih_end, 0])
76 | i += 1
77 | padded_array[i] = np.array([b_idx, head, child, ni_start, ni_end, 1])
78 | i += 1
79 |
80 | assert i == batch_size
81 | return padded_array
82 |
83 |
84 |
85 |
86 | class SpanHeadPadder(Padder):
87 | def __call__(self, contents, field_name, field_ele_dtype, dim: int):
88 | batch_size = sum(len(r) for r in contents)
89 | padded_array = np.full((batch_size, 4), fill_value=self.pad_val,
90 | dtype=np.float)
91 | i = 0
92 | for b_idx, relations in enumerate(contents):
93 | for (left, right, head) in relations:
94 | padded_array[i] = np.array([b_idx, left, right, head])
95 | i += 1
96 | return padded_array
97 |
98 |
99 |
100 |
101 |
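A rough usage sketch of get_sampler, assuming the supar version used by this repo (its kmeans/Sampler API is imported above); the numbers are illustrative:

from src.datamodule.dm_util.datamodule_util import get_sampler

lengths = [12, 7, 31, 18, 25, 9, 14, 22]       # token counts of the training sentences
sampler = get_sampler(lengths, max_tokens=1000, n_buckets=4, shuffle=True)
for batch in sampler:                          # each batch is a list of sentence indices of similar length
    print(batch)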
--------------------------------------------------------------------------------
/src/datamodule/dm_util/padder.py:
--------------------------------------------------------------------------------
1 | from fastNLP.core.field import Padder
2 | import numpy as np
3 |
4 |
5 | def set_padder(datasets, name, padder):
6 | for _, dataset in datasets.items():
7 | dataset.add_field(name, dataset[name].content, padder=padder, ignore_type=True)
8 |
9 |
10 | class DepSibPadder(Padder):
11 | def __call__(self, contents, field_name, field_ele_dtype, dim: int):
12 | # max_sent_length = max(rule.shape[0] for rule in contents)
13 | padded_array = []
14 | # dependency head or relations.
15 | for b_idx, dep in enumerate(contents):
16 | for (child_idx, (head_idx, sib_ix)) in enumerate(dep):
17 | # -1 means no sib;
18 | if sib_ix != -1:
19 | padded_array.append([b_idx, child_idx + 1, head_idx, sib_ix])
20 | else:
21 | pass
22 | return np.array(padded_array)
23 |
24 |
25 |
26 |
27 |
28 | class DepPadder(Padder):
29 | def __call__(self, contents, field_name, field_ele_dtype, dim: int):
30 | # max_sent_length = max(rule.shape[0] for rule in contents)
31 | padded_array = []
32 | # dependency head or relations.
33 | for b_idx, dep in enumerate(contents):
34 | for (child_idx, dep_idx) in enumerate(dep):
35 | padded_array.append([b_idx, child_idx + 1, dep_idx])
36 | return np.array(padded_array)
37 |
38 |
39 |
40 |
41 | class SpanHeadWordPadder(Padder):
42 | def __call__(self, contents, field_name, field_ele_dtype, dim: int):
43 | padded_array = []
44 | for b_idx, relations in enumerate(contents):
45 | for (left, right, _, head) in relations:
46 | padded_array.append([b_idx, left, right, head])
47 | return np.array(padded_array)
48 |
49 |
50 |
51 |
52 |
53 |
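A small worked example of DepPadder: for a batch containing one 3-token sentence with gold heads [2, 0, 2] (token 2 is the root), the padder emits one (batch_idx, child, head) row per token:

from src.datamodule.dm_util.padder import DepPadder

padder = DepPadder(pad_val=0)
contents = [[2, 0, 2]]                          # heads of one sentence, 1-indexed, 0 = root
rows = padder(contents, field_name='head', field_ele_dtype=int, dim=1)
print(rows)                                     # [[0 1 2]
                                                #  [0 2 0]
                                                #  [0 3 2]]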
--------------------------------------------------------------------------------
/src/datamodule/dm_util/util.py:
--------------------------------------------------------------------------------
1 | import unicodedata
2 |
3 | punct_set = '.' '``' "''" ':' ','
4 | import re
5 |
6 |
7 | # https://github.com/DoodleJZ/HPSG-Neural-Parser/blob/cdcffa78945359e14063326cadd93fd4c509c585/src_joint/dep_eval.py
8 | def is_uni_punctuation(word):
9 | match = re.match(r"^[^\w\s]+$", word, flags=re.UNICODE)
10 | return match is not None
11 |
12 | def is_punctuation(word, pos, punct_set=punct_set):
13 | if punct_set is None:
14 | return is_uni_punctuation(word)
15 | else:
16 | return pos in punct_set or pos == 'PU' # for chinese
17 |
18 | def get_path(path):
19 | return path
20 |
21 | def get_path_debug(path):
22 | return path + ".debug"
23 |
24 | def clean_number(w):
25 | new_w = re.sub('[0-9]{1,}([,.]?[0-9]*)*', '0', w)
26 | return new_w
27 |
28 | def clean_word(words):
29 | PTB_UNESCAPE_MAPPING = {
30 | "«": '"',
31 | "»": '"',
32 | "‘": "'",
33 | "’": "'",
34 | "“": '"',
35 | "”": '"',
36 | "„": '"',
37 | "‹": "'",
38 | "›": "'",
39 | "\u2013": "--", # en dash
40 | "\u2014": "--", # em dash
41 | }
42 | cleaned_words = []
43 | for word in words:
44 | word = PTB_UNESCAPE_MAPPING.get(word, word)
45 | word = word.replace("\\/", "/").replace("\\*", "*")
46 | # Mid-token punctuation occurs in biomedical text
47 | word = word.replace("-LSB-", "[").replace("-RSB-", "]")
48 | word = word.replace("-LRB-", "(").replace("-RRB-", ")")
49 | word = word.replace("-LCB-", "{").replace("-RCB-", "}")
50 | word = word.replace("``", '"').replace("`", "'").replace("''", '"')
51 | word = clean_number(word)
52 | cleaned_words.append(word)
53 | return cleaned_words
54 |
55 |
56 |
57 |
58 | def find_dep_boundary(heads):
59 | left_bd = [i for i in range(len(heads))]
60 | right_bd = [i + 1 for i in range(len(heads))]
61 |
62 | for child_idx, head_idx in enumerate(heads):
63 | if head_idx > 0:
64 | if left_bd[child_idx] < left_bd[head_idx - 1]:
65 | left_bd[head_idx - 1] = left_bd[child_idx]
66 |
67 | elif child_idx > right_bd[head_idx - 1] - 1:
68 | right_bd[head_idx - 1] = child_idx + 1
69 | while head_idx != 0:
70 | if heads[head_idx-1] > 0 and child_idx + 1 > right_bd[ heads[head_idx-1] - 1] :
71 | right_bd[ heads[head_idx-1] - 1] = child_idx + 1
72 | head_idx = heads[head_idx-1]
73 | else:
74 | break
75 |
76 | # (head_word_idx, left_bd_idx, right_bd_idx)
77 | triplet = []
78 | # head indices are shifted by 1, since the root token is the first token.
79 | # boundaries are half-open intervals: [left_bdr, right_bdr).
80 | # for each word i, record the span covered by the subtree rooted at word i.
81 |
82 |
83 |
84 |
85 | # left boundary, right boundary, parent, head
86 | for i, (parent, left_bdr, right_bdr) in enumerate(zip(heads, left_bd, right_bd)):
87 | triplet.append([left_bdr, right_bdr, parent-1, i])
88 |
89 |
90 |
91 |
92 |
93 | return triplet
94 |
95 |
96 |
97 |
98 |
99 | def isProjective(heads):
100 | pairs = [(h, d) for d, h in enumerate(heads, 1) if h >= 0]
101 | for i, (hi, di) in enumerate(pairs):
102 | for hj, dj in pairs[i+1:]:
103 | (li, ri), (lj, rj) = sorted([hi, di]), sorted([hj, dj])
104 | if li <= hj <= ri and hi == dj:
105 | return False
106 | if lj <= hi <= rj and hj == di:
107 | return False
108 | if (li < lj < ri or li < rj < ri) and (li - lj)*(ri - rj) > 0:
109 | return False
110 | return True
111 |
112 |
113 |
114 |
115 |
116 |
117 |
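A small worked example of find_dep_boundary and isProjective: for gold heads [2, 0, 2] (word 2 is the root), each word's subtree span is returned as [left, right, parent, head] with half-open boundaries:

from src.datamodule.dm_util.util import find_dep_boundary, isProjective

heads = [2, 0, 2]                               # 1-indexed heads, 0 = root
print(find_dep_boundary(heads))
# [[0, 1, 1, 0], [0, 3, -1, 1], [2, 3, 1, 2]]
# word 1 spans [0, 1) with parent index 1; word 2 spans the whole sentence [0, 3)
# with parent -1 (the root); word 3 spans [2, 3) with parent index 1.
print(isProjective(heads))                      # True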
--------------------------------------------------------------------------------
/src/inside/__init__.py:
--------------------------------------------------------------------------------
1 | from .eisner_satta import es4dep
2 | from .eisner import eisner, eisner_headsplit
3 | from .eisner2o import eisner2o, eisner2o_headsplit
4 | from .span import span_inside
5 |
6 | __all__ = ['eisner_satta',
7 | 'eisner',
8 | 'eisner2o',
9 | 'es4dep',
10 | 'eisner_headsplit',
11 | 'eisner2o_headsplit',
12 | 'span_inside'
13 | ]
14 |
15 |
--------------------------------------------------------------------------------
/src/inside/eisner_satta.py:
--------------------------------------------------------------------------------
1 | from .fn import *
2 | import torch.nn as nn
3 | from torch.utils.checkpoint import checkpoint
4 | from ..loss.get_score import augment_score
5 |
6 |
7 |
8 | # O(n4) span+arc
9 | @torch.enable_grad()
10 | def es4dep(ctx, decode=False, max_margin=False):
11 | if max_margin:
12 | augment_score(ctx)
13 |
14 | lens = ctx['seq_len']
15 |
16 | dependency = ctx['s_arc']
17 |
18 |
19 | B, seq_len = dependency.shape[:2]
20 | head_score = ctx['s_span_head_word']
21 |
22 | if decode:
23 | dependency = dependency.detach().clone().requires_grad_(True)
24 |
25 | if max_margin:
26 | dependency = dependency.detach().clone().requires_grad_(True)
27 | head_score = head_score.detach().clone().requires_grad_(True)
28 |
29 | dep = dependency[:, 1:, 1:].contiguous()
30 | root = dependency[:, 1:, 0].contiguous()
31 |
32 | viterbi = decode or max_margin
33 |
34 | N = seq_len
35 | H = N - 1
36 |
37 | s = dependency.new_zeros(B, N, N, H).fill_(-1e9)
38 | s_close = dependency.new_zeros(B, N, N, H).fill_(-1e9)
39 | s_need_dad = dependency.new_zeros(B, N, N, H).fill_(-1e9)
40 |
41 | s_close[:, torch.arange(N - 1), torch.arange(N - 1) + 1, torch.arange(N - 1)] = head_score[:, torch.arange(N - 1), torch.arange(N - 1) + 1, torch.arange(N - 1)]
42 |
43 | s[:, torch.arange(N - 1), torch.arange(N - 1) + 1, torch.arange(N - 1)] = 0
44 |
45 | s_need_dad[:, torch.arange(N - 1), torch.arange(N - 1) + 1, :] = dep[:, torch.arange(N - 1)] + s_close[:, torch.arange(N - 1), torch.arange(N - 1) + 1, torch.arange(N - 1)].unsqueeze(-1)
46 |
47 | def merge(left, right, left_need_dad, right_need_dad):
48 | left = (left + right_need_dad)
49 | right = (right + left_need_dad)
50 | if viterbi:
51 | headed = torch.stack([left.max(2)[0], right.max(2)[0]])
52 | return headed.max(0)[0]
53 | else:
54 | headed = torch.stack([left.logsumexp(2), right.logsumexp(2)])
55 | return headed.logsumexp(0)
56 |
57 | def seek_head(a, b):
58 | if viterbi:
59 | tmp = (a + b).max(-2)[0]
60 | else:
61 | tmp = (a + b).logsumexp(-2)
62 | return tmp
63 |
64 |
65 | for w in range(2, N):
66 | n = N - w
67 | left = stripe_version2(s, n, w - 1, (0, 1))
68 | right = stripe_version2(s, n, w - 1, (1, w), 0)
69 | left_need_dad = stripe_version2(s_need_dad, n, w - 1, (0, 1))
70 | right_need_dad = stripe_version2(s_need_dad, n, w - 1, (1, w), 0)
71 | headed = checkpoint(merge, left.clone(), right.clone(), left_need_dad.clone(),right_need_dad.clone())
72 | diagonal_copy_v2(s, headed, w)
73 | headed = headed + diagonal_v2(head_score, w)
74 | diagonal_copy_v2(s_close, headed, w)
75 |
76 | if w < N - 1:
77 | u = checkpoint(seek_head, headed.unsqueeze(-1), stripe_version5(dep, n, w))
78 | diagonal_copy_(s_need_dad, u, w)
79 |
80 | logZ = (s_close[torch.arange(B), 0, lens] + root)
81 | if viterbi:
82 | logZ = logZ.max(-1)[0]
83 | else:
84 | logZ = logZ.logsumexp(-1)
85 |
86 | #crf loss
87 | if not decode and not max_margin:
88 | return logZ
89 |
90 | logZ.sum().backward()
91 |
92 | if decode:
93 | predicted_arc = s.new_zeros(B, seq_len).long()
94 | arc = dependency.grad.nonzero()
95 | predicted_arc[arc[:, 0], arc[:, 1]] = arc[:, 2]
96 | ctx['arc_pred'] = predicted_arc
97 |
98 | if max_margin:
99 | ctx['s_arc_grad'] = dependency.grad
100 | if head_score is not None:
101 | ctx['s_span_head_word_grad'] = head_score.grad
102 |
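A rough decoding sketch for es4dep. The shapes below are inferred from the indexing above and may need adjusting to the actual datamodule: s_arc is [B, n+1, n+1] with the root in column 0, s_span_head_word is [B, n+1, n+1, n] scoring span (i, j) headed by word h, and seq_len holds the number of words per sentence:

import torch
from src.inside import es4dep

B, n = 2, 5
ctx = {
    's_arc': torch.randn(B, n + 1, n + 1),              # arc scores; column 0 is the root
    's_span_head_word': torch.randn(B, n + 1, n + 1, n),
    'seq_len': torch.tensor([n, n]),
}
es4dep(ctx, decode=True)       # Viterbi variant; backprop through the max recovers the best tree
print(ctx['arc_pred'])         # predicted head for each word (position 0 is unused)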
--------------------------------------------------------------------------------
/src/inside/span.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from src.inside.fn import *
3 | from src.inside.eisner_satta import es4dep
4 | from ..loss.get_score import augment_score
5 |
6 | # O(n^3) algorithm: uses only span scores (no arc scores) for projective dependency parsing.
7 | @torch.enable_grad()
8 | def span_inside(ctx, decode=False, max_margin=False):
9 |
10 | assert decode or max_margin
11 |
12 | if max_margin:
13 | augment_score(ctx)
14 |
15 | s_span_score = ctx['s_span_head_word']
16 | lens = ctx['seq_len']
17 |
18 | s_span_score = s_span_score.detach().clone().requires_grad_(True)
19 |
20 | b, seq_len = s_span_score.shape[:2]
21 |
22 | s_inside_children = s_span_score.new_zeros(b, seq_len, seq_len).fill_(-1e9)
23 | s_inside_close = s_span_score.new_zeros(b, seq_len, seq_len).fill_(-1e9)
24 |
25 | # s_inside_close[i, j]: best score of span (i, j) as a single complete subtree (its head word chosen and the span score added); the final tree score is read from it at the root.
26 | s_inside_children[:, torch.arange(seq_len-1), torch.arange(seq_len-1)+1] = s_span_score[:, torch.arange(seq_len-1), torch.arange(seq_len-1)+1, torch.arange(seq_len-1)]
27 |
28 | for w in range(2, seq_len):
29 | n = seq_len - w
30 |
31 | # two child compose together.
32 | left = stripe(s_inside_children, n, w - 1, (0, 1))
33 | right = stripe(s_inside_children, n, w - 1, (1, w), 0)
34 | compose = (left + right).max(2)[0]
35 |
36 | # case 1: the head word is right-most
37 | l = left[:, :, -1]
38 | compose_score1 = l + s_span_score[:, torch.arange(n), torch.arange(n)+w, torch.arange(n)+w-1]
39 |
40 | # case 2: the head word is left-most.
41 | r = right[:, :, 0]
42 | compose_score2 = r + s_span_score[:, torch.arange(n), torch.arange(n)+w, torch.arange(n)]
43 |
44 | if w > 2:
45 | left = stripe(s_inside_children, n, w - 2, (0, 1))
46 | right = stripe(s_inside_children, n, w - 2, (2, w), 0)
47 | compose_score3 = left + right + diagonal_v2(s_span_score, w)[:, :, 1:-1]
48 | compose_score = torch.cat([compose_score1.unsqueeze(2), compose_score3, compose_score2.unsqueeze(2)], dim=2)
49 | compose_score = compose_score.max(2)[0]
50 |
51 | else:
52 | compose_score = torch.cat([compose_score1.unsqueeze(2), compose_score2.unsqueeze(2)], dim=2)
53 | compose_score = compose_score.max(2)[0]
54 |
55 | compose = torch.stack([compose, compose_score]).max(0)[0]
56 | diagonal_copy_(s_inside_children, compose, w)
57 | diagonal_copy_(s_inside_close, compose_score, w)
58 |
59 | try:
60 | s_inside_close[torch.arange(b), 0, lens].sum().backward()
61 | except:
62 | pass
63 |
64 | if decode:
65 | ctx['arc_pred'] = recover_arc(s_span_score.grad, lens)
66 |
67 | if max_margin:
68 | ctx['s_span_head_word_grad'] = s_span_score.grad
69 |
70 |
71 |
72 |
73 | # heads: (left, right, head word).
74 | # from decoded span to recover arcs.
75 | def recover_arc(heads, lens):
76 | if heads is None:
77 | return lens.new_zeros(lens.shape[0], lens.max() + 1)
78 | result = np.zeros(shape=(heads.shape[0], heads.shape[1]))
79 | lens = lens.detach().cpu().numpy()
80 | for i in range(heads.shape[0]):
81 | if lens[i] == 1:
82 | result[i][1] = 0
83 | else:
84 | span = heads[i].detach().nonzero().cpu().numpy()
85 | start = span[:,0]
86 | end = span[:,1]
87 | preorder_sort = np.lexsort((-end, start))
88 | start = start[preorder_sort]
89 | end = end[preorder_sort]
90 | head = span[:,2][preorder_sort]
91 | stack = []
92 | arcs = []
93 | stack.append((start[0], end[0], 0))
94 | result[i][head[0]+1] = 0
95 | j = 0
96 | while j < start.shape[0]-1:
97 | j+=1
98 | s = start[j]
99 | e = end[j]
100 | top = stack[-1]
101 | top_s, top_e, top_i = top
102 | if top_s <= s and top_e >= e:
103 | arcs.append((head[top_i] + 1, head[j] + 1))
104 | result[i][head[j] + 1] = head[top_i] + 1
105 | stack.append((s, e, j))
106 | else:
107 | while top_s > s or top_e < e:
108 | stack.pop()
109 | top = stack[-1]
110 | top_s, top_e, top_i = top
111 | arcs.append([head[top_i] + 1, head[j] + 1])
112 | result[i][head[j] + 1] = head[top_i] + 1
113 | stack.append((s, e, j))
114 | return torch.tensor(result, device=heads.device, dtype=torch.long)
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
--------------------------------------------------------------------------------
/src/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sustcsonglin/span-based-dependency-parsing/db7b94a06a8da3d3055e5baf50fe2fe2f10e58d3/src/loss/__init__.py
--------------------------------------------------------------------------------
/src/loss/dep_loss.py:
--------------------------------------------------------------------------------
1 | from ..inside import *
2 | from supar.utils.common import *
3 | from .get_score import *
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | class DepLoss():
8 | def __init__(self, conf):
9 | self.conf = conf
10 | self.inside = eisner
11 |
12 |
13 | @classmethod
14 | def label_loss(self, ctx, reduction='mean'):
15 | s_rel = ctx['s_rel']
16 | gold_rel = ctx['rel']
17 | gold_arc = ctx['head']
18 | if len(s_rel.shape) == 4:
19 | return F.cross_entropy(s_rel[gold_arc[:, 0], gold_arc[:, 1], gold_arc[:, 2]] , torch.as_tensor(gold_rel[:, -1], device=s_rel.device, dtype=torch.long),reduction=reduction)
20 | elif len(s_rel.shape) == 3:
21 | return F.cross_entropy(s_rel[gold_arc[:, 0], gold_arc[:, 1]], torch.as_tensor(gold_rel[:, -1], device=s_rel.device, dtype=torch.long),reduction=reduction)
22 | else:
23 | raise AssertionError
24 |
25 | @classmethod
26 | def get_pred_rels(self,ctx):
27 | arc_preds = ctx['arc_pred']
28 | s_rel = ctx['s_rel']
29 | ctx['rel_pred'] = s_rel.argmax(-1).gather(-1, arc_preds.unsqueeze(-1)).squeeze(-1)
30 |
31 | # for evaluation.
32 | @classmethod
33 | def _transform(self, ctx):
34 | arc_preds, arc_golds = ctx['arc_pred'], ctx['head']
35 | rel_preds, rel_golds = ctx['rel_pred'], ctx['rel']
36 |
37 | arc_golds = torch.as_tensor(arc_golds, device=arc_preds.device, dtype=arc_preds.dtype)
38 | rel_golds = torch.as_tensor(rel_golds, device=rel_preds.device, dtype=rel_preds.dtype)
39 |
40 | arc_gold = arc_preds.new_zeros(*arc_preds.shape).fill_(-1)
41 | arc_gold[arc_golds[:, 0], arc_golds[:, 1]] = arc_golds[:, 2]
42 |
43 | rel_gold = rel_preds.new_zeros(*rel_preds.shape).fill_(-1)
44 | rel_gold[rel_golds[:, 0], rel_golds[:, 1]] = rel_golds[:, 2]
45 | mask_dep = arc_gold.ne(-1)
46 |
47 | #ignore punct.
48 | if 'is_punct' in ctx:
49 | mask_punct = ctx['is_punct'].nonzero()
50 | mask_dep[mask_punct[:, 0], mask_punct[:, 1] + 1] = False
51 | ctx['arc_gold'] = arc_gold
52 | ctx['rel_gold'] = rel_gold
53 | ctx['mask_dep'] = mask_dep
54 |
55 | def max_margin_loss(self, ctx):
56 | with torch.no_grad():
57 | self.inside(ctx, max_margin=True)
58 | gold_score = u_score(ctx)
59 | predict_score = predict_score_mm(ctx)
60 | return (predict_score - gold_score)/ctx['seq_len'].sum()
61 |
62 | def crf_loss(self, ctx):
63 | logZ = self.inside(ctx).sum()
64 | gold_score = u_score(ctx)
65 | return (logZ - gold_score)/ ctx['seq_len'].sum()
66 |
67 | def local_loss(self, ctx):
68 | raise NotImplementedError
69 |
70 | def loss(self, ctx):
71 | if self.conf.loss_type == 'mm':
72 | tree_loss = self.max_margin_loss(ctx)
73 | elif self.conf.loss_type == 'crf':
74 | tree_loss = self.crf_loss(ctx)
75 | elif self.conf.loss_type == 'local':
76 | tree_loss = self.local_loss(ctx)
77 | label_loss = self.label_loss(ctx)
78 | return tree_loss + label_loss
79 |
80 | def decode(self, ctx):
81 | self.inside(ctx, decode=True)
82 | self.get_pred_rels(ctx)
83 | self._transform(ctx)
84 |
85 |
86 | class Dep1OSpan(DepLoss):
87 | def __init__(self, conf):
88 | super(Dep1OSpan, self).__init__(conf)
89 | self.inside = es4dep
90 |
91 | class Dep2O(DepLoss):
92 | def __init__(self, conf):
93 | super(Dep2O, self).__init__(conf)
94 | self.inside = eisner2o
95 |
96 | class Dep2OHeadSplit(DepLoss):
97 | def __init__(self, conf):
98 | super(Dep2OHeadSplit, self).__init__(conf)
99 | self.inside = eisner2o_headsplit
100 |
101 | class Dep1OHeadSplit(DepLoss):
102 | def __init__(self, conf):
103 | super(Dep1OHeadSplit, self).__init__(conf)
104 | self.inside = eisner_headsplit
105 |
106 | class Span(DepLoss):
107 | def __init__(self, conf):
108 | super(Span, self).__init__(conf)
109 | self.inside = span_inside
110 |
111 |
112 |
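In both structured losses above, the score difference is normalized by the total number of words in the batch: the max-margin loss is (s'(T_pred) - s'(T_gold)) / sum_b n_b, where s' denotes the cost-augmented scores produced by augment_score and T_pred is the highest-scoring tree under s', while the CRF loss is (log Z - s(T_gold)) / sum_b n_b, with log Z computed by the corresponding inside algorithm.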
--------------------------------------------------------------------------------
/src/loss/get_score.py:
--------------------------------------------------------------------------------
1 | # obtain the score of unlabeled trees.
2 | def u_score(ctx):
3 | score = 0
4 |
5 | if 's_arc' in ctx:
6 | score += get_arc_score(ctx)
7 |
8 | if 's_span_head_word' in ctx:
9 | score += get_span_head_word_score(ctx)
10 |
11 | if 's_bd' in ctx:
12 | score += get_bd_score(ctx)
13 |
14 | if 's_bd_left' in ctx:
15 | assert 's_bd_right' in ctx
16 | score += get_bd_left_score(ctx)
17 | score += get_bd_right_score(ctx)
18 |
19 | if 's_sib' in ctx:
20 | score += get_sib_score(ctx)
21 |
22 | return score
23 |
24 | def predict_score_mm(ctx):
25 | score = 0
26 | if 's_arc' in ctx :
27 | score += (ctx['s_arc'] * ctx['s_arc_grad']).sum()
28 |
29 | if 's_span_head_word' in ctx:
30 | score += (ctx['s_span_head_word'] * ctx['s_span_head_word_grad']).sum()
31 |
32 | if 's_bd' in ctx:
33 | score += (ctx['s_bd'] * ctx['s_bd_grad']).sum()
34 |
35 | if 's_bd_left' in ctx:
36 | score += (ctx['s_bd_left'] * ctx['s_bd_left_grad']).sum()
37 |
38 | if 's_bd_right' in ctx:
39 | score += (ctx['s_bd_right'] * ctx['s_bd_right_grad']).sum()
40 |
41 | if 's_sib' in ctx:
42 | try:
43 | score += (ctx['s_sib'] * ctx['s_sib_grad']).sum()
44 | except:
45 | # corner case: (e.g. sentences of length 2, no siblings)
46 | pass
47 |
48 | return score
49 |
50 |
51 | def augment_score(ctx):
52 | if 's_arc' in ctx:
53 | aug_arc_score(ctx)
54 |
55 | if 's_span_head_word' in ctx:
56 | aug_span_head_word_score(ctx)
57 |
58 | if 's_bd' in ctx:
59 | aug_bd_score(ctx)
60 |
61 | if 's_bd_left' in ctx:
62 | aug_bd_left_score(ctx)
63 |
64 | if 's_bd_right' in ctx:
65 | aug_bd_right_score(ctx)
66 |
67 | if 's_sib' in ctx:
68 | aug_sib_score(ctx)
69 |
70 |
71 |
72 | def get_arc_score(ctx):
73 | s_arc = ctx['s_arc']
74 | arcs = ctx['head']
75 | return s_arc[arcs[:, 0], arcs[:, 1], arcs[:, 2]].sum()
76 |
77 |
78 |
79 | def aug_arc_score(ctx):
80 | s_arc = ctx['s_arc']
81 | arcs = ctx['head']
82 | s_arc[arcs[:, 0], arcs[:, 1], arcs[:, 2]] -= 1
83 |
84 |
85 |
86 | def get_sib_score(ctx):
87 | s_sib = ctx['s_sib']
88 | sib = ctx['sib']
89 | try:
90 | return s_sib[sib[:, 0], sib[:, 1], sib[:, 2], sib[:, 3]].sum()
91 | except:
92 | return 0
93 |
94 | def aug_sib_score(ctx):
95 | s_sib = ctx['s_sib']
96 | sib = ctx['sib']
97 | try:
98 | s_sib[sib[:, 0], sib[:, 1], sib[:, 2], sib[:, 3]] -= 1
99 | except:
100 | pass
101 |
102 |
103 | def get_span_head_word_score(ctx):
104 | span_head_word = ctx['span_head_word']
105 | s_span_head_word = ctx['s_span_head_word']
106 | score = s_span_head_word[span_head_word[:, 0], span_head_word[:, 1], span_head_word[:, 2], span_head_word[:, 3]].sum()
107 | return score
108 |
109 | def aug_span_head_word_score(ctx):
110 | span_head_word = ctx['span_head_word']
111 | s_span_head_word = ctx['s_span_head_word']
112 | s_span_head_word[span_head_word[:, 0], span_head_word[:, 1], span_head_word[:, 2], span_head_word[:, 3]] -= 1.5
113 |
114 |
115 | def get_bd_score(ctx):
116 | s_bd = ctx['s_bd']
117 | span_head = ctx['span_head_word']
118 | score = 0
119 | score += s_bd[span_head[:, 0], span_head[:, 3], span_head[:, 1]].sum()
120 | score += s_bd[span_head[:, 0], span_head[:, 3], span_head[:, 2]].sum()
121 | return score
122 |
123 | def aug_bd_score(ctx):
124 | s_bd = ctx['s_bd']
125 | span_head = ctx['span_head_word']
126 | s_bd[span_head[:, 0], span_head[:, 3], span_head[:,1]] -= 1
127 | s_bd[span_head[:, 0], span_head[:, 3], span_head[:,2]] -= 1
128 |
129 |
130 | def get_bd_left_score(ctx):
131 | s_bd = ctx['s_bd_left']
132 | span_head = ctx['span_head_word']
133 | score = s_bd[span_head[:, 0], span_head[:, 3], span_head[:, 1]].sum()
134 | return score
135 |
136 | def aug_bd_left_score(ctx):
137 | s_bd = ctx['s_bd_left']
138 | span_head = ctx['span_head_word']
139 | s_bd[span_head[:, 0], span_head[:, 3], span_head[:,1]] -= 1
140 |
141 |
142 | def get_bd_right_score(ctx):
143 | s_bd = ctx['s_bd_right']
144 | span_head = ctx['span_head_word']
145 | score = s_bd[span_head[:, 0], span_head[:, 3], span_head[:, 2]].sum()
146 | return score
147 |
148 |
149 | def aug_bd_right_score(ctx):
150 | s_bd = ctx['s_bd_right']
151 | span_head = ctx['span_head_word']
152 | s_bd[span_head[:, 0], span_head[:, 3], span_head[:, 2]] -= 1
153 |
154 |
155 |
156 |
157 |
158 |
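The gradient trick behind predict_score_mm: in the Viterbi (max) variants of the inside algorithms, the gradient of the best tree's total score with respect to each score tensor is a 0/1 indicator of the parts used by that tree, so summing score * grad over all part types recovers the score of the predicted tree without materializing the tree itself.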
--------------------------------------------------------------------------------
/src/model/dep_parsing.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import logging
3 | import hydra
4 | log = logging.getLogger(__name__)
5 |
6 | class ProjectiveDepParser(nn.Module):
7 | def __init__(self, conf, fields):
8 | super(ProjectiveDepParser, self).__init__()
9 | self.conf = conf
10 | self.fields = fields
11 | self.embeder = hydra.utils.instantiate(conf.embeder.target, conf.embeder, fields=fields, _recursive_=False)
12 | self.encoder = hydra.utils.instantiate(conf.encoder.target, conf.encoder, input_dim= self.embeder.get_output_dim(), _recursive_=False)
13 | self.scorer = hydra.utils.instantiate(conf.scorer.target, conf.scorer, fields=fields, input_dim=self.encoder.get_output_dim(), _recursive_=False)
14 | self.loss = hydra.utils.instantiate(conf.loss.target, conf.loss, _recursive_=False)
15 | self.metric = hydra.utils.instantiate(conf.metric.target, conf.metric, fields=fields, _recursive_=False)
16 |
17 | log.info(self.embeder)
18 | log.info(self.encoder)
19 | log.info(self.scorer)
20 |
21 | def forward(self, ctx):
22 | self.embeder(ctx)
23 | self.encoder(ctx)
24 | self.scorer(ctx)
25 |
26 | def get_loss(self, x, y):
27 | ctx = {**x, **y}
28 | self.forward(ctx)
29 | return self.loss.loss(ctx)
30 |
31 | def decode(self, x, y):
32 | ctx = {**x, **y}
33 | self.forward(ctx)
34 | self.loss.decode(ctx)
35 | return ctx
36 |
37 |
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/src/model/module/ember/embedding.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from supar.modules.char_lstm import CharLSTM
4 | from supar.modules import TransformerEmbedding
5 | from src.model.module.ember.ext_embedding import ExternalEmbeddingSupar
6 | import copy
7 |
8 | class Embeder(nn.Module):
9 | def __init__(self, conf, fields):
10 | super(Embeder, self).__init__()
11 | self.conf = conf
12 |
13 | if 'pos' in fields.inputs:
14 | self.pos_emb = nn.Embedding(fields.get_vocab_size("pos"), conf.n_pos_embed)
15 | else:
16 | self.pos_emb = None
17 |
18 | if 'char' in fields.inputs:
19 | self.feat = CharLSTM(n_chars=fields.get_vocab_size('char'),
20 | n_embed=conf.n_char_embed,
21 | n_out=conf.n_char_out,
22 | pad_index=fields.get_pad_index('char'),
23 | input_dropout=conf.char_input_dropout)
24 | self.feat_name = 'char'
25 |
26 | if 'bert' in fields.inputs:
27 | self.feat = TransformerEmbedding(model=fields.get_bert_name(),
28 | n_layers=conf.n_bert_layers,
29 | n_out=conf.n_bert_out,
30 | pad_index=fields.get_pad_index("bert"),
31 | dropout=conf.mix_dropout,
32 | requires_grad=conf.finetune,
33 | use_projection=conf.use_projection,
34 | use_scalarmix=conf.use_scalarmix)
35 | self.feat_name = "bert"
36 | print(fields.get_bert_name())
37 |
38 | if ('char' not in fields.inputs and 'bert' not in fields.inputs):
39 | self.feat = None
40 |
41 | if 'word' in fields.inputs:
42 | ext_emb = fields.get_ext_emb()
43 | if ext_emb:
44 | self.word_emb = copy.deepcopy(ext_emb)
45 | else:
46 | self.word_emb = nn.Embedding(num_embeddings=fields.get_vocab_size('word'),
47 | embedding_dim=conf.n_embed)
48 | else:
49 | self.word_emb = None
50 |
51 |
52 | def forward(self, ctx):
53 | emb = {}
54 |
55 | if self.pos_emb:
56 | emb['pos'] = self.pos_emb(ctx['pos'])
57 |
58 | if self.word_emb:
59 | emb['word'] = self.word_emb(ctx['word'])
60 |
61 | # For now, choose either char or bert, not both.
62 | if self.feat:
63 | emb[self.feat_name] = self.feat(ctx[self.feat_name])
64 |
65 | ctx['embed'] = emb
66 |
67 |
68 | def get_output_dim(self):
69 |
70 | size = 0
71 |
72 | if self.pos_emb:
73 | size += self.conf.n_pos_embed
74 |
75 | if self.word_emb:
76 | if isinstance(self.word_emb, nn.Embedding):
77 | size += self.conf.n_embed
78 | else:
79 | size += self.word_emb.get_dim()
80 |
81 | if self.feat:
82 | if self.feat_name == 'char':
83 | size += self.conf.n_char_out
84 | else:
85 | size += self.feat.n_out
86 | return size
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
/src/model/module/ember/ext_embedding.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import logging
3 |
4 |
5 | log = logging.getLogger(__name__)
6 |
7 | # As Supar.
8 | class ExternalEmbeddingSupar(nn.Module):
9 | def __init__(self, emb, origin_word_size, unk_index):
10 | super(ExternalEmbeddingSupar, self).__init__()
11 |
12 | self.pretrained = nn.Embedding.from_pretrained(emb)
13 | self.origin_word_size = origin_word_size
14 | self.word_emb = nn.Embedding(origin_word_size, emb.shape[-1])
15 | self.unk_index = unk_index
16 | nn.init.zeros_(self.word_emb.weight)
17 |
18 | def forward(self, words):
19 | ext_mask = words.ge(self.word_emb.num_embeddings)
20 | ext_words = words.masked_fill(ext_mask, self.unk_index)
21 | # get outputs from embedding layers
22 | word_embed = self.word_emb(ext_words)
23 | word_embed += self.pretrained(words)
24 | return word_embed
25 |
26 | def get_dim(self):
27 | return self.word_emb.weight.shape[-1]
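A minimal usage sketch of ExternalEmbeddingSupar (sizes are illustrative): indices at or beyond origin_word_size fall back to unk in the trainable table but still receive their own pretrained vector:

import torch
from src.model.module.ember.ext_embedding import ExternalEmbeddingSupar

pretrained = torch.randn(8, 100)                # pretrained matrix extended to 8 words
emb = ExternalEmbeddingSupar(pretrained, origin_word_size=5, unk_index=1)

words = torch.tensor([[0, 4, 7]])               # index 7 is outside the trainable vocab (>= 5)
out = emb(words)
print(out.shape)                                # expected: torch.Size([1, 3, 100])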
--------------------------------------------------------------------------------
/src/model/module/encoder/lstm_encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from supar.modules import LSTM
3 | from supar.modules.dropout import IndependentDropout, SharedDropout
4 | import torch
5 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
6 |
7 |
8 | class LSTMencoder(nn.Module):
9 | def __init__(self, conf, input_dim, **kwargs):
10 | super(LSTMencoder, self).__init__()
11 | self.conf = conf
12 |
13 | self.before_lstm_dropout = None
14 |
15 | if self.conf.embed_dropout_type == 'independent':
16 | self.embed_dropout = IndependentDropout(p=conf.embed_dropout)
17 | if conf.before_lstm_dropout:
18 | self.before_lstm_dropout = SharedDropout(p=conf.before_lstm_dropout)
19 |
20 | elif self.conf.embed_dropout_type == 'shared':
21 | self.embed_dropout = SharedDropout(p=conf.embed_dropout)
22 |
23 | elif self.conf.embed_dropout_type == 'simple':
24 | self.embed_dropout = nn.Dropout(p=conf.embed_dropout)
25 |
26 | else:
27 | self.embed_dropout = nn.Dropout(0.)
28 |
29 | self.lstm = LSTM(input_size=input_dim,
30 | hidden_size=conf.n_lstm_hidden,
31 | num_layers=conf.n_lstm_layers,
32 | bidirectional=True,
33 | dropout=conf.lstm_dropout)
34 | self.lstm_dropout = SharedDropout(p=conf.lstm_dropout)
35 |
36 |
37 |
38 | def forward(self, info):
39 | # lstm encoder
40 | embed = info['embed']
41 | seq_len = info['seq_len']
42 |
43 | embed = list(embed.values())
44 |
45 | if self.conf.embed_dropout_type == 'independent':
46 | embed = self.embed_dropout(embed)
47 | embed = torch.cat(embed, dim=-1)
48 | else:
49 | embed = torch.cat(embed, dim=-1)
50 | embed = self.embed_dropout(embed)
51 |
52 | seq_len = seq_len.cpu()
53 | x = pack_padded_sequence(embed, seq_len + (embed.shape[1] - seq_len.max()), True, False)
54 | x, _ = self.lstm(x)
55 | x, _ = pad_packed_sequence(x, True, total_length=embed.shape[1])
56 | x = self.lstm_dropout(x)
57 |
58 | info['encoded_emb'] = x
59 |
60 | def get_output_dim(self):
61 | return self.conf.n_lstm_hidden * 2
62 |
63 |
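
A minimal way to drive LSTMencoder end to end. The conf attribute names mirror exactly the fields read above, but every value here is a made-up placeholder rather than the project's actual configuration:

import torch
from types import SimpleNamespace
from src.model.module.encoder.lstm_encoder import LSTMencoder

# placeholder hyper-parameters; only the attribute names are taken from the code above
conf = SimpleNamespace(embed_dropout_type='independent', embed_dropout=0.33,
                       before_lstm_dropout=0.0, n_lstm_hidden=300,
                       n_lstm_layers=3, lstm_dropout=0.33)

encoder = LSTMencoder(conf, input_dim=150)   # e.g. 100 (word) + 50 (pos)
info = {'embed': {'word': torch.randn(2, 9, 100), 'pos': torch.randn(2, 9, 50)},
        'seq_len': torch.tensor([9, 6])}
encoder(info)
print(info['encoded_emb'].shape)             # torch.Size([2, 9, 600]) = 2 * n_lstm_hidden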
--------------------------------------------------------------------------------
/src/model/module/scorer/dep_scorer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from .module.biaffine import BiaffineScorer
3 | from .module.triaffine import TriaffineScorer
4 | import torch
5 | import logging
6 | log = logging.getLogger(__name__)
7 |
8 |
9 |
10 | class DepScorer(nn.Module):
11 | def __init__(self, conf, fields, input_dim):
12 | super(DepScorer, self).__init__()
13 | self.conf = conf
14 |
15 |
16 | self.rel_scorer = BiaffineScorer(n_in=input_dim, n_out=conf.n_mlp_rel, bias_x=True, bias_y=True,
17 | dropout=conf.mlp_dropout, n_out_label=fields.get_vocab_size("rel"),
18 | scaling=conf.scaling)
19 |
20 |
21 | if self.conf.use_arc:
22 | self.arc_scorer = BiaffineScorer(n_in=input_dim, n_out=conf.n_mlp_arc, bias_x=True, bias_y=False, dropout=conf.mlp_dropout, scaling=conf.scaling)
23 | log.info("Use arc score")
24 | if self.conf.use_sib:
25 | log.info("Use sib score")
26 | self.sib_scorer = TriaffineScorer(n_in=input_dim, n_out=conf.n_mlp_sib, bias_x=True, bias_y=True,
27 | dropout=conf.mlp_dropout)
28 |
29 | if self.conf.use_span:
30 | log.info("Use span score")
31 | if self.conf.span_scorer_type == 'biaffine':
32 | assert not self.conf.use_sib
33 | self.span_scorer = BiaffineScorer(n_in=input_dim, n_out=conf.n_mlp_arc, bias_x=True, bias_y=True, dropout=conf.mlp_dropout, scaling=conf.scaling)
34 |
35 | elif self.conf.span_scorer_type == 'triaffine':
36 | assert not self.conf.use_sib
37 | self.span_scorer = TriaffineScorer(n_in=input_dim, n_out=conf.n_mlp_arc, bias_x=True, bias_y=True, dropout=conf.mlp_dropout, )
38 |
39 | elif self.conf.span_scorer_type == 'headsplit':
40 | assert self.conf.use_arc
41 | self.span_scorer_left = BiaffineScorer(n_in=input_dim, n_out=conf.n_mlp_arc, bias_x=True, bias_y=True, dropout=conf.mlp_dropout, scaling=conf.scaling)
42 | self.span_scorer_right = BiaffineScorer(n_in=input_dim, n_out=conf.n_mlp_arc, bias_x=True, bias_y=True, dropout=conf.mlp_dropout, scaling=conf.scaling)
43 | else:
44 | raise NotImplementedError
45 |
46 |
47 |
48 | def forward(self, ctx):
49 | x = ctx['encoded_emb']
50 | x_f, x_b = x.chunk(2, -1)
51 | x_boundary = torch.cat((x_f[:, :-1], x_b[:, 1:]), -1)
52 |
53 | if self.conf.use_arc:
54 | ctx['s_arc'] = self.arc_scorer(x[:, :-1])
55 | if self.conf.use_sib:
56 | ctx['s_sib'] = self.sib_scorer(x[:, :-1])
57 |
58 | if self.conf.use_span:
59 | if self.conf.span_scorer_type == 'headsplit':
60 | ctx['s_bd_left'] = self.span_scorer_left.forward_v2(x[:,1:], x_boundary)
61 | ctx['s_bd_right'] = self.span_scorer_right.forward_v2(x[:,1:], x_boundary)
62 |
63 | elif self.conf.span_scorer_type == 'biaffine':
64 | # LSTM minus features
65 | span_repr = (x_boundary.unsqueeze(1) - x_boundary.unsqueeze(2))
66 | batch, seq_len = span_repr.shape[:2]
67 | span_repr2 = span_repr.reshape(batch, seq_len * seq_len, -1)
68 | ctx['s_span_head_word'] = self.span_scorer.forward_v2(span_repr2, x[:, 1:-1]).reshape(batch, seq_len,
69 | seq_len, -1)
70 | elif self.conf.span_scorer_type == 'triaffine':
71 | ctx['s_span_head_word'] = self.span_scorer.forward2(x[:, 1:-1], x_boundary)
72 | else:
73 | raise NotImplementedError
74 |
75 | ctx['s_rel'] = self.rel_scorer(x[:, :-1])
76 |
77 |
78 |
79 |
80 |
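
The forward above splits the BiLSTM output into its forward/backward halves and builds fencepost boundary vectors x_boundary = [f_i ; b_{i+1}]; the biaffine span branch then scores differences of boundary vectors (the "LSTM minus" features). A shape-only walk-through with hypothetical sizes:

import torch

batch, fenceposts, hidden = 2, 8, 600       # hidden = 2 * n_lstm_hidden (hypothetical)
x = torch.randn(batch, fenceposts + 1, hidden)

x_f, x_b = x.chunk(2, -1)                                  # each [2, 9, 300]
x_boundary = torch.cat((x_f[:, :-1], x_b[:, 1:]), -1)      # [2, 8, 600]

# the span feature for (i, j) is the difference of its two boundary vectors
span_repr = x_boundary.unsqueeze(1) - x_boundary.unsqueeze(2)   # [2, 8, 8, 600]
print(span_repr.shape)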
--------------------------------------------------------------------------------
/src/model/module/scorer/module/biaffine.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from supar.modules import MLP, Biaffine
3 |
4 | class BiaffineScorer(nn.Module):
5 | def __init__(self, n_in=800, n_out=400, n_out_label=1, bias_x=True, bias_y=False, scaling=False, dropout=0.33):
6 | super(BiaffineScorer, self).__init__()
7 | self.l = MLP(n_in=n_in, n_out=n_out, dropout=dropout)
8 | self.r = MLP(n_in=n_in, n_out=n_out, dropout=dropout)
9 | self.attn = Biaffine(n_in=n_out, n_out=n_out_label, bias_x=bias_x, bias_y=bias_y)
10 | self.scaling = 0 if not scaling else n_out ** (1/4)
11 | self.n_in = n_in
12 |
13 | def forward(self, h):
14 | left = self.l(h)
15 | right = self.r(h)
16 |
17 | if self.scaling:
18 | left = left / self.scaling
19 | right = right / self.scaling
20 |
21 | return self.attn(left, right)
22 |
23 |
24 | def forward_v2(self, h, q):
25 | left = self.l(h)
26 | right = self.r(q)
27 | if self.scaling:
28 | left = left / self.scaling
29 | right = right / self.scaling
30 |
31 | return self.attn(left, right)
32 |
33 |
34 |
35 |
36 |
37 |
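
A small usage sketch of BiaffineScorer under made-up dimensions. The output shapes follow supar's Biaffine as used here: a singleton label dimension is squeezed away, so an unlabeled arc scorer yields [batch, seq, seq] while a labeled scorer keeps the label axis:

import torch
from src.model.module.scorer.module.biaffine import BiaffineScorer

x = torch.randn(2, 10, 800)                  # [batch, seq_len, n_in]

arc = BiaffineScorer(n_in=800, n_out=400, bias_x=True, bias_y=False)
s_arc = arc(x)                               # [2, 10, 10]

rel = BiaffineScorer(n_in=800, n_out=100, n_out_label=30, bias_x=True, bias_y=True)
s_rel = rel(x)                               # [2, 30, 10, 10]

print(s_arc.shape, s_rel.shape)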
--------------------------------------------------------------------------------
/src/model/module/scorer/module/nhpsg_scorer.py:
--------------------------------------------------------------------------------
1 | # import torch.nn as nn
2 | #
3 | #
4 | # class BiAAttention(nn.Module):
5 | # '''
6 | # Bi-Affine attention layer.
7 | # '''
8 | #
9 | # def __init__(self, hparams):
10 | # super(BiAAttention, self).__init__()
11 | # self.hparams = hparams
12 | #
13 | # self.dep_weight = nn.Parameter(torch_t.FloatTensor(hparams.d_biaffine + 1, hparams.d_biaffine + 1))
14 | # nn.init.xavier_uniform_(self.dep_weight)
15 | #
16 | # def forward(self, input_d, input_e, input_s = None):
17 | #
18 | # score = torch.matmul(torch.cat(
19 | # [input_d, torch_t.FloatTensor(input_d.size(0), 1).fill_(1).requires_grad_(False)],
20 | # dim=1), self.dep_weight)
21 | # score1 = torch.matmul(score, torch.transpose(torch.cat(
22 | # [input_e, torch_t.FloatTensor(input_e.size(0), 1).fill_(1).requires_grad_(False)],
23 | # dim=1), 0, 1))
24 | #
25 | # return score1
26 | #
27 | # class Dep_score(nn.Module):
28 | # def __init__(self, hparams, num_labels):
29 | # super(Dep_score, self).__init__()
30 | #
31 | # self.dropout_out = nn.Dropout2d(p=0.33)
32 | # self.hparams = hparams
33 | # out_dim = hparams.d_biaffine#d_biaffine
34 | # self.arc_h = nn.Linear(hparams.d_model, hparams.d_biaffine)
35 | # self.arc_c = nn.Linear(hparams.d_model, hparams.d_biaffine)
36 | #
37 | # self.attention = BiAAttention(hparams)
38 | #
39 | # self.type_h = nn.Linear(hparams.d_model, hparams.d_label_hidden)
40 | # self.type_c = nn.Linear(hparams.d_model, hparams.d_label_hidden)
41 | # self.bilinear = BiLinear(hparams.d_label_hidden, hparams.d_label_hidden, num_labels)
42 | #
43 | # def forward(self, outputs, outpute):
44 | # # output from rnn [batch, length, hidden_size]
45 | #
46 | # # apply dropout for output
47 | # # [batch, length, hidden_size] --> [batch, hidden_size, length] --> [batch, length, hidden_size]
48 | # outpute = self.dropout_out(outpute.transpose(1, 0)).transpose(1, 0)
49 | # outputs = self.dropout_out(outputs.transpose(1, 0)).transpose(1, 0)
50 | #
51 | # # output size [batch, length, arc_space]
52 | # arc_h = nn.functional.relu(self.arc_h(outputs))
53 | # arc_c = nn.functional.relu(self.arc_c(outpute))
54 | #
55 | # # output size [batch, length, type_space]
56 | # type_h = nn.functional.relu(self.type_h(outputs))
57 | # type_c = nn.functional.relu(self.type_c(outpute))
58 | #
59 | # # apply dropout
60 | # # [batch, length, dim] --> [batch, 2 * length, dim]
61 | # arc = torch.cat([arc_h, arc_c], dim=0)
62 | # type = torch.cat([type_h, type_c], dim=0)
63 | #
64 | # arc = self.dropout_out(arc.transpose(1, 0)).transpose(1, 0)
65 | # arc_h, arc_c = arc.chunk(2, 0)
66 | #
67 | # type = self.dropout_out(type.transpose(1, 0)).transpose(1, 0)
68 | # type_h, type_c = type.chunk(2, 0)
69 | # type_h = type_h.contiguous()
70 | # type_c = type_c.contiguous()
71 | #
72 | # out_arc = self.attention(arc_h, arc_c)
73 | # out_type = self.bilinear(type_h, type_c)
74 | #
75 | # return out_arc, out_type
76 |
--------------------------------------------------------------------------------
/src/model/module/scorer/module/quadra_linear.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.parameter import Parameter
4 | import pdb
5 | import opt_einsum
6 |
7 | ## TODO: finish a quadra-linear label scoring function?
8 | class QlinearScorer(nn.Module):
9 | """
10 | Outer-product (factorized) version of a quadrilinear function.
11 | Quadrilinear attention layer.
12 | """
13 |
14 | def __init__(self, input_size_1, input_size_2, input_size_3, input_size_4, factorize=True,
15 | **kwargs):
16 | """
17 | Args:
18 | input_size_1: int
19 | the dimension of the first input.
20 | input_size_2, input_size_3, input_size_4: int
21 | the dimensions of the second, third and fourth inputs.
22 | factorize: bool
23 | if True, use a low-rank (CP) factorization of the 4-way weight tensor.
24 | **kwargs:
25 | additional keyword arguments.
26 |
27 | """
28 | # super(QlinearScorer, self).__init__()
29 | # self.input_size_1 = input_size_1 + 1
30 | # self.input_size_2 = input_size_2 + 1
31 | # self.input_size_3 = input_size_3 + 1
32 | # self.input_size_4 = input_size_3 + 1
33 | # self.rank = rank
34 | # self.init_std = init_std
35 | # self.factorize = factorize
36 | # if not factorize:
37 | # self.W = Parameter(torch.Tensor(self.input_size_1, self.input_size_2, self.input_size_3, self.input_size_4))
38 | # else:
39 | # self.W_1 = Parameter(torch.Tensor(self.input_size_1, self.rank))
40 | # self.W_2 = Parameter(torch.Tensor(self.input_size_2, self.rank))
41 | # self.W_3 = Parameter(torch.Tensor(self.input_size_3, self.rank))
42 | # self.W_4 = Parameter(torch.Tensor(self.input_size_4, self.rank))
43 | # self.W_5 = Parameter(torch.Tensor(self.rank, ))
44 | # # if self.biaffine:
45 | # # self.U = Parameter(torch.Tensor(self.num_labels, self.input_size_decoder, self.input_size_encoder))
46 | # # else:
47 | # # self.register_parameter('U', None)
48 | #
49 | # self.reset_parameters()
50 |
51 | # def reset_parameters(self):
52 | # if not self.factorize:
53 | # nn.init.xavier_normal_(self.W)
54 | # else:
55 | # nn.init.xavier_normal_(self.W_1, gain=self.init_std)
56 | # nn.init.xavier_normal_(self.W_2, gain=self.init_std)
57 | # nn.init.xavier_normal_(self.W_3, gain=self.init_std)
58 | # nn.init.xavier_normal_(self.W_4, gain=self.init_std)
59 |
60 | def forward(self, layer1, layer2, layer3, layer4):
61 | """
62 | Args:
63 | layer1, layer2, layer3, layer4: tensors of shape [batch, length, input_size_i]
64 | Returns: Tensor
65 | the score tensor with shape = [batch, length]
66 | """
67 | assert layer1.size(0) == layer2.size(0), 'batch sizes of the inputs are required to be equal.'
68 | layer_shape = layer1.size()
69 | one_shape = list(layer_shape[:2]) + [1]
70 | ones = torch.ones(one_shape, device=layer1.device)
71 | layer1 = torch.cat([layer1, ones], -1)
72 | layer2 = torch.cat([layer2, ones], -1)
73 | layer3 = torch.cat([layer3, ones], -1)
74 | layer4 = torch.cat([layer4, ones], -1)
75 |
76 | layer = opt_einsum.contract('bnx,xr,bny,yr,bnz,zr,bns,sr->bn', layer1, self.W_1, layer2, self.W_2, layer3, self.W_3, layer4, self.W_4, backend='torch')
77 | return layer
78 |
--------------------------------------------------------------------------------
/src/model/module/scorer/module/triaffine.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from supar.modules import MLP, Triaffine
3 |
4 | class TriaffineScorer(nn.Module):
5 | def __init__(self, n_in=800, n_out=400, bias_x=True, bias_y=False, dropout=0.33):
6 | super(TriaffineScorer, self).__init__()
7 | self.l = MLP(n_in=n_in, n_out=n_out, dropout=dropout)
8 | self.m = MLP(n_in=n_in, n_out=n_out, dropout=dropout)
9 | self.r = MLP(n_in=n_in, n_out=n_out, dropout=dropout)
10 | self.attn = Triaffine(n_in=n_out, bias_x=bias_x, bias_y=bias_y)
11 |
12 | def forward(self, h):
13 | left = self.l(h)
14 | mid = self.m(h)
15 | right = self.r(h)
16 | # (sib, dependent, head)
17 | return self.attn(left, mid, right).permute(0, 2, 3, 1)
18 |
19 | def forward2(self, word, span):
20 | left = self.l(word)
21 | mid = self.m(span)
22 | right = self.r(span)
23 | # scores (head word, left boundary, right boundary) triples,
24 | # as used for the span-head model in DepScorer.forward.
25 |
26 |
27 | return self.attn(mid, right, left).permute(0, 2, 3, 1)
28 |
29 |
30 |
31 | # class TriaffineScorer
32 | # class TriaffineScorer(nn.Module):
33 |
34 |
35 |
36 |
37 |
38 |
--------------------------------------------------------------------------------
/supar/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 |
5 | __version__ = '1.0.0'
6 |
7 |
8 | PRETRAINED = {
9 | 'biaffine-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.biaffine.dependency.char.zip',
10 | 'biaffine-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.biaffine.dependency.char.zip',
11 | 'biaffine-dep-bert-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.biaffine.dependency.bert.zip',
12 | 'biaffine-dep-bert-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.biaffine.dependency.bert.zip',
13 | 'crfnp-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crfnp.dependency.char.zip',
14 | 'crfnp-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crfnp.dependency.char.zip',
15 | 'crf-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf.dependency.char.zip',
16 | 'crf-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf.dependency.char.zip',
17 | 'crf2o-dep-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf2o.dependency.char.zip',
18 | 'crf2o-dep-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf2o.dependency.char.zip',
19 | 'crf-con-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf.constituency.char.zip',
20 | 'crf-con-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf.constituency.char.zip',
21 | 'crf-con-bert-en': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ptb.crf.constituency.bert.zip',
22 | 'crf-con-bert-zh': 'https://github.com/yzhangcs/parser/releases/download/v1.0.0/ctb7.crf.constituency.bert.zip'
23 | }
24 |
--------------------------------------------------------------------------------
/supar/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .affine import Biaffine, Triaffine
4 | from .bert import BertEmbedding
5 | from .char_lstm import CharLSTM
6 | from .dropout import IndependentDropout, SharedDropout
7 | from .lstm import LSTM
8 | from .mlp import MLP
9 | from .scalar_mix import ScalarMix
10 | from .treecrf import (CRF2oDependency, CRFConstituency, CRFDependency,
11 | MatrixTree)
12 | from .variational_inference import (LBPSemanticDependency,
13 | MFVISemanticDependency)
14 | from .transformer import TransformerEmbedding
15 |
16 | __all__ = ['LSTM', 'MLP', 'BertEmbedding', 'Biaffine', 'CharLSTM', 'CRF2oDependency', 'CRFConstituency', 'CRFDependency',
17 | 'IndependentDropout', 'LBPSemanticDependency', 'MatrixTree',
18 | 'MFVISemanticDependency', 'ScalarMix', 'SharedDropout', 'Triaffine', 'TransformerEmbedding']
19 |
--------------------------------------------------------------------------------
/supar/modules/char_lstm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.utils.rnn import pack_padded_sequence
6 |
7 |
8 | class CharLSTM(nn.Module):
9 | r"""
10 | CharLSTM aims to generate character-level embeddings for tokens.
11 | It summarizes the information of the characters in each token into an embedding using an LSTM layer.
12 |
13 | Args:
14 | n_chars (int):
15 | The number of characters.
16 | n_embed (int):
17 | The size of each embedding vector as input to LSTM.
18 | n_out (int):
19 | The size of each output vector.
20 | pad_index (int):
21 | The index of the padding token in the vocabulary. Default: 0.
22 | """
23 |
24 | def __init__(self, n_chars, n_embed, n_out, pad_index=0, input_dropout=0.):
25 | super().__init__()
26 |
27 | self.n_chars = n_chars
28 | self.n_embed = n_embed
29 | self.n_out = n_out
30 | self.pad_index = pad_index
31 |
32 | # the embedding layer
33 | self.embed = nn.Embedding(num_embeddings=n_chars,
34 | embedding_dim=n_embed)
35 | # the lstm layer
36 | self.lstm = nn.LSTM(input_size=n_embed,
37 | hidden_size=n_out//2,
38 | batch_first=True,
39 | bidirectional=True)
40 |
41 | self.input_dropout = nn.Dropout(input_dropout)
42 |
43 | def __repr__(self):
44 | return f"{self.__class__.__name__}({self.n_chars}, {self.n_embed}, n_out={self.n_out}, pad_index={self.pad_index})"
45 |
46 | def forward(self, x):
47 | r"""
48 | Args:
49 | x (~torch.Tensor): ``[batch_size, seq_len, fix_len]``.
50 | Characters of all tokens.
51 | Each token holds no more than `fix_len` characters, and the excess is cut off directly.
52 | Returns:
53 | ~torch.Tensor:
54 | The embeddings of shape ``[batch_size, seq_len, n_out]`` derived from the characters.
55 | """
56 | # [batch_size, seq_len, fix_len]
57 | mask = x.ne(self.pad_index)
58 | # [batch_size, seq_len]
59 | lens = mask.sum(-1)
60 | char_mask = lens.gt(0)
61 |
62 | # [n, fix_len, n_embed]
63 | x = self.embed(x[char_mask])
64 | x = self.input_dropout(x)
65 | x = pack_padded_sequence(x, lens[char_mask].cpu(), True, False)
66 | x, (h, _) = self.lstm(x)
67 | # [n, fix_len, n_out]
68 | h = torch.cat(torch.unbind(h), -1)
69 | # [batch_size, seq_len, n_out]
70 | embed = h.new_zeros(*lens.shape, self.n_out)
71 | embed = embed.masked_scatter_(char_mask.unsqueeze(-1), h)
72 |
73 | return embed
74 |
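
A quick usage sketch matching the shapes documented above; all sizes are arbitrary. Token slots whose characters are entirely pad_index are treated as padding and come back as zero vectors:

import torch
from supar.modules import CharLSTM

char_lstm = CharLSTM(n_chars=50, n_embed=30, n_out=100, pad_index=0)

x = torch.randint(1, 50, (2, 5, 6))   # [batch_size, seq_len, fix_len]
x[1, 3:] = 0                          # the last two token slots of sentence 2 are padding
out = char_lstm(x)
print(out.shape)                      # torch.Size([2, 5, 100]); padded slots are all zeros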
--------------------------------------------------------------------------------
/supar/modules/dropout.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class SharedDropout(nn.Module):
8 | r"""
9 | SharedDropout differs from the vanilla dropout strategy in that
10 | the dropout mask is shared across one dimension.
11 |
12 | Args:
13 | p (float):
14 | The probability of an element to be zeroed. Default: 0.5.
15 | batch_first (bool):
16 | If ``True``, the input and output tensors are provided as ``[batch_size, seq_len, *]``.
17 | Default: ``True``.
18 |
19 | Examples:
20 | >>> x = torch.ones(1, 3, 5)
21 | >>> nn.Dropout()(x)
22 | tensor([[[0., 2., 2., 0., 0.],
23 | [2., 2., 0., 2., 2.],
24 | [2., 2., 2., 2., 0.]]])
25 | >>> SharedDropout()(x)
26 | tensor([[[2., 0., 2., 0., 2.],
27 | [2., 0., 2., 0., 2.],
28 | [2., 0., 2., 0., 2.]]])
29 | """
30 |
31 | def __init__(self, p=0.5, batch_first=True):
32 | super().__init__()
33 |
34 | self.p = p
35 | self.batch_first = batch_first
36 |
37 | def __repr__(self):
38 | s = f"p={self.p}"
39 | if self.batch_first:
40 | s += f", batch_first={self.batch_first}"
41 |
42 | return f"{self.__class__.__name__}({s})"
43 |
44 | def forward(self, x):
45 | r"""
46 | Args:
47 | x (~torch.Tensor):
48 | A tensor of any shape.
49 | Returns:
50 | The returned tensor is of the same shape as `x`.
51 | """
52 |
53 | if self.training:
54 | if self.batch_first:
55 | mask = self.get_mask(x[:, 0], self.p).unsqueeze(1)
56 | else:
57 | mask = self.get_mask(x[0], self.p)
58 | x = x * mask
59 |
60 | return x
61 |
62 | @staticmethod
63 | def get_mask(x, p):
64 | return x.new_empty(x.shape).bernoulli_(1 - p) / (1 - p)
65 |
66 |
67 | class IndependentDropout(nn.Module):
68 | r"""
69 | Each of the :math:`N` input tensors is dropped with its own independent mask.
70 | When :math:`N-M` of them are dropped, the remaining :math:`M` ones are scaled by a factor of :math:`N/M` to compensate,
71 | and when all of them are dropped together, zeros are returned.
72 |
73 | Args:
74 | p (float):
75 | The probability of an element to be zeroed. Default: 0.5.
76 |
77 | Examples:
78 | >>> x, y = torch.ones(1, 3, 5), torch.ones(1, 3, 5)
79 | >>> x, y = IndependentDropout()([x, y])
80 | >>> x
81 | tensor([[[1., 1., 1., 1., 1.],
82 | [0., 0., 0., 0., 0.],
83 | [2., 2., 2., 2., 2.]]])
84 | >>> y
85 | tensor([[[1., 1., 1., 1., 1.],
86 | [2., 2., 2., 2., 2.],
87 | [0., 0., 0., 0., 0.]]])
88 | """
89 |
90 | def __init__(self, p=0.5):
91 | super().__init__()
92 |
93 | self.p = p
94 |
95 | def __repr__(self):
96 | return f"{self.__class__.__name__}(p={self.p})"
97 |
98 | def forward(self, items):
99 | r"""
100 | Args:
101 | items (list[~torch.Tensor]):
102 | A list of tensors that have the same shape except the last dimension.
103 | Returns:
104 | The returned tensors are of the same shape as `items`.
105 | """
106 |
107 | if self.training:
108 | masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p) for x in items]
109 | total = sum(masks)
110 | scale = len(items) / total.max(torch.ones_like(total))
111 | masks = [mask * scale for mask in masks]
112 | items = [item * mask.unsqueeze(-1) for item, mask in zip(items, masks)]
113 |
114 | return items
115 |
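
These two modules are what LSTMencoder above plugs in for embed_dropout_type 'independent' and 'shared': IndependentDropout consumes the list of per-source embeddings before concatenation, while SharedDropout reuses one mask across the sequence dimension. A short sketch (the modules are in training mode, so the masks are random):

import torch
from supar.modules.dropout import IndependentDropout, SharedDropout

word, pos = torch.ones(2, 5, 4), torch.ones(2, 5, 3)

# one mask per source, rescaled so the expected total mass is preserved
word, pos = IndependentDropout(p=0.33)([word, pos])

# one mask per (batch, feature) position, reused at every timestep
x = SharedDropout(p=0.33)(torch.ones(2, 5, 7))
print((x[:, 0] == x[:, 1]).all())     # tensor(True)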
--------------------------------------------------------------------------------
/supar/modules/mlp.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch.nn as nn
4 | from supar.modules.dropout import SharedDropout
5 |
6 |
7 | class MLP(nn.Module):
8 | r"""
9 | Applies a linear transformation together with a non-linear activation to the incoming tensor:
10 | :math:`y = \mathrm{Activation}(x A^T + b)`
11 |
12 | Args:
13 | n_in (int):
14 | The size of each input feature.
15 | n_out (int):
16 | The size of each output feature.
17 | dropout (float):
18 | If non-zero, introduce a :class:`SharedDropout` layer on the output with this dropout ratio. Default: 0.
19 | activation (bool):
20 | Whether to use activations. Default: True.
21 | """
22 |
23 | def __init__(self, n_in, n_out, dropout=0, activation=True):
24 | super().__init__()
25 |
26 | self.n_in = n_in
27 | self.n_out = n_out
28 | self.linear = nn.Linear(n_in, n_out)
29 | self.activation = nn.LeakyReLU(negative_slope=0.1) if activation else nn.Identity()
30 | self.dropout = SharedDropout(p=dropout)
31 |
32 | self.reset_parameters()
33 |
34 | def __repr__(self):
35 | s = f"n_in={self.n_in}, n_out={self.n_out}"
36 | if self.dropout.p > 0:
37 | s += f", dropout={self.dropout.p}"
38 |
39 | return f"{self.__class__.__name__}({s})"
40 |
41 | def reset_parameters(self):
42 | nn.init.orthogonal_(self.linear.weight)
43 | nn.init.zeros_(self.linear.bias)
44 |
45 | def forward(self, x):
46 | r"""
47 | Args:
48 | x (~torch.Tensor):
49 | The size of each input feature is `n_in`.
50 |
51 | Returns:
52 | A tensor with the size of each output feature `n_out`.
53 | """
54 |
55 | x = self.linear(x)
56 | x = self.activation(x)
57 | x = self.dropout(x)
58 |
59 | return x
60 |
--------------------------------------------------------------------------------
/supar/modules/scalar_mix.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class ScalarMix(nn.Module):
8 | r"""
9 | Computes a parameterised scalar mixture of :math:`N` tensors, :math:`mixture = \gamma * \sum_{k}(s_k * tensor_k)`
10 | where :math:`s = \mathrm{softmax}(w)`, with :math:`w` and :math:`\gamma` scalar parameters.
11 |
12 | Args:
13 | n_layers (int):
14 | The number of layers to be mixed, i.e., :math:`N`.
15 | dropout (float):
16 | The dropout ratio of the layer weights.
17 | If dropout > 0, then for each scalar weight, adjust its softmax weight mass to 0
18 | with the dropout probability (i.e., setting the unnormalized weight to -inf).
19 | This effectively redistributes the dropped probability mass to all other weights.
20 | Default: 0.
21 | """
22 |
23 | def __init__(self, n_layers, dropout=0):
24 | super().__init__()
25 |
26 | self.n_layers = n_layers
27 |
28 | self.weights = nn.Parameter(torch.zeros(n_layers))
29 | self.gamma = nn.Parameter(torch.tensor([1.0]))
30 | self.dropout = nn.Dropout(dropout)
31 |
32 | def __repr__(self):
33 | s = f"n_layers={self.n_layers}"
34 | if self.dropout.p > 0:
35 | s += f", dropout={self.dropout.p}"
36 |
37 | return f"{self.__class__.__name__}({s})"
38 |
39 | def forward(self, tensors):
40 | r"""
41 | Args:
42 | tensors (list[~torch.Tensor]):
43 | :math:`N` tensors to be mixed.
44 |
45 | Returns:
46 | The mixture of :math:`N` tensors.
47 | """
48 |
49 | normed_weights = self.dropout(self.weights.softmax(-1))
50 | weighted_sum = sum(w * h for w, h in zip(normed_weights, tensors))
51 |
52 | return self.gamma * weighted_sum
53 |
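
A usage sketch of ScalarMix with random stand-ins for the N layer outputs (e.g., hidden states from a multi-layer encoder); sizes are arbitrary:

import torch
from supar.modules import ScalarMix

mix = ScalarMix(n_layers=4, dropout=0.1)
layers = [torch.randn(2, 10, 768) for _ in range(4)]   # e.g. four encoder layers
print(mix(layers).shape)                               # torch.Size([2, 10, 768])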
--------------------------------------------------------------------------------
/supar/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from . import alg, field, fn, metric, transform
4 | from .alg import chuliu_edmonds, cky, eisner, eisner2o, kmeans, mst, tarjan
5 | from .config import Config
6 | from .data import Dataset
7 | from .embedding import Embedding
8 | from .field import ChartField, Field, RawField, SubwordField
9 | from .transform import CoNLL, Transform, Tree
10 | from .vocab import Vocab
11 |
12 | __all__ = ['ChartField', 'CoNLL', 'Config', 'Dataset', 'Embedding', 'Field',
13 | 'RawField', 'SubwordField', 'Transform', 'Tree', 'Vocab',
14 | 'alg', 'field', 'fn', 'metric', 'chuliu_edmonds', 'cky',
15 | 'eisner', 'eisner2o', 'kmeans', 'mst', 'tarjan', 'transform']
16 |
--------------------------------------------------------------------------------
/supar/utils/common.py:
--------------------------------------------------------------------------------
1 | PAD = '<pad>'
2 | UNK = '<unk>'
3 | BOS = '<bos>'
4 | EOS = '<eos>'
5 | subword_bos = '<_bos>'
6 | subword_eos = '<_eos>'
7 | no_label = ''
8 |
9 |
--------------------------------------------------------------------------------
/supar/utils/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from ast import literal_eval
4 | from configparser import ConfigParser
5 |
6 |
7 | class Config(object):
8 |
9 | def __init__(self, conf=None, **kwargs):
10 | super(Config, self).__init__()
11 |
12 | config = ConfigParser()
13 | config.read(conf or [])
14 | self.update({**dict((name, literal_eval(value))
15 | for section in config.sections()
16 | for name, value in config.items(section)),
17 | **kwargs})
18 |
19 | def __repr__(self):
20 | s = line = "-" * 20 + "-+-" + "-" * 30 + "\n"
21 | s += f"{'Param':20} | {'Value':^30}\n" + line
22 | for name, value in vars(self).items():
23 | s += f"{name:20} | {str(value):^30}\n"
24 | s += line
25 |
26 | return s
27 |
28 | def __getitem__(self, key):
29 | return getattr(self, key)
30 |
31 | def __getstate__(self):
32 | return vars(self)
33 |
34 | def __setstate__(self, state):
35 | self.__dict__.update(state)
36 |
37 | def keys(self):
38 | return vars(self).keys()
39 |
40 | def items(self):
41 | return vars(self).items()
42 |
43 | def update(self, kwargs):
44 | for key in ('self', 'cls', '__class__'):
45 | kwargs.pop(key, None)
46 | kwargs.update(kwargs.pop('kwargs', dict()))
47 | for name, value in kwargs.items():
48 | setattr(self, name, value)
49 |
50 | return self
51 |
52 | def pop(self, key, val=None):
53 | return self.__dict__.pop(key, val)
54 |
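
Config can also be built without an INI file, in which case it simply holds the keyword overrides; a tiny sketch with hypothetical keys:

from supar.utils.config import Config

conf = Config(None, n_embed=100, lr=2e-3)   # no INI file, only keyword overrides
print(conf.n_embed, conf['lr'])             # attribute and item access both work
conf.update({'n_embed': 300})
print(dict(conf.items())['n_embed'])        # 300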
--------------------------------------------------------------------------------
/supar/utils/embedding.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 |
5 |
6 | class Embedding(object):
7 |
8 | def __init__(self, tokens, vectors, unk=None):
9 | self.tokens = tokens
10 | self.vectors = torch.tensor(vectors)
11 | self.pretrained = {w: v for w, v in zip(tokens, vectors)}
12 | self.unk = unk
13 |
14 | def __len__(self):
15 | return len(self.tokens)
16 |
17 | def __contains__(self, token):
18 | return token in self.pretrained
19 |
20 | @property
21 | def dim(self):
22 | return self.vectors.size(1)
23 |
24 | @property
25 | def unk_index(self):
26 | if self.unk is not None:
27 | return self.tokens.index(self.unk)
28 | else:
29 | raise AttributeError
30 |
31 | @classmethod
32 | def load(cls, path, unk=None):
33 | with open(path, 'r', encoding="utf-8") as f:
34 | lines = [line for line in f]
35 |
36 | if len(lines[0].split()) == 2:
37 | lines = lines[1:]
38 |
39 | splits = [line.split() for line in lines]
40 | tokens, vectors = zip(*[(s[0], list(map(float, s[1:])))
41 | for s in splits])
42 |
43 | return cls(tokens, vectors, unk=unk)
44 |
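
Embedding.load expects whitespace-separated "token v1 v2 ..." lines and skips a leading "count dim" header if one is present. A toy sketch that writes and reads a 3-token, 4-dimensional file (the file name is only for illustration):

from supar.utils.embedding import Embedding

with open('toy.vec', 'w', encoding='utf-8') as f:
    f.write('3 4\n')                          # optional "count dim" header, skipped by load()
    f.write('the 0.1 0.2 0.3 0.4\n')
    f.write('cat 0.5 0.6 0.7 0.8\n')
    f.write('<unk> 0.0 0.0 0.0 0.0\n')

emb = Embedding.load('toy.vec', unk='<unk>')
print(len(emb), emb.dim, emb.unk_index)       # 3 4 2
print('cat' in emb)                           # True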
--------------------------------------------------------------------------------
/supar/utils/fn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import unicodedata
4 |
5 |
6 | def ispunct(token):
7 | return all(unicodedata.category(char).startswith('P')
8 | for char in token)
9 |
10 |
11 | def isfullwidth(token):
12 | return all(unicodedata.east_asian_width(char) in ['W', 'F', 'A']
13 | for char in token)
14 |
15 |
16 | def islatin(token):
17 | return all('LATIN' in unicodedata.name(char)
18 | for char in token)
19 |
20 |
21 | def isdigit(token):
22 | return all('DIGIT' in unicodedata.name(char)
23 | for char in token)
24 |
25 |
26 | def tohalfwidth(token):
27 | return unicodedata.normalize('NFKC', token)
28 |
29 |
30 | def stripe(x, n, w, offset=(0, 0), dim=1):
31 | # r"""
32 | # Returns a diagonal stripe of the tensor.
33 | #
34 | # Args:
35 | # x (~torch.Tensor): the input tensor with 2 or more dims.
36 | # n (int): the length of the stripe.
37 | # w (int): the width of the stripe.
38 | # offset (tuple): the offset of the first two dims.
39 | # dim (int): 1 if returns a horizontal stripe; 0 otherwise.
40 | #
41 | # Returns:
42 | # a diagonal stripe of the tensor.
43 | # Examples:
44 | # >>> x = torch.arange(25).view(5, 5)
45 | # >>> x
46 | # tensor([[ 0, 1, 2, 3, 4],
47 | # [ 5, 6, 7, 8, 9],
48 | # [10, 11, 12, 13, 14],
49 | # [15, 16, 17, 18, 19],
50 | # [20, 21, 22, 23, 24]])
51 | # >>> stripe(x, 2, 3)
52 | # tensor([[0, 1, 2],
53 | # [6, 7, 8]])
54 | # >>> stripe(x, 2, 3, (1, 1))
55 | # tensor([[ 6, 7, 8],
56 | # [12, 13, 14]])
57 | # >>> stripe(x, 2, 3, (1, 1), 0)
58 | # tensor([[ 6, 11, 16],
59 | # [12, 17, 22]])
60 | # """
61 |
62 | x, seq_len = x.contiguous(), x.size(1)
63 | stride, numel = list(x.stride()), x[0, 0].numel()
64 | stride[0] = (seq_len + 1) * numel
65 | stride[1] = (1 if dim == 1 else seq_len) * numel
66 | return x.as_strided(size=(n, w, *x.shape[2:]),
67 | stride=stride,
68 | storage_offset=(offset[0]*seq_len+offset[1])*numel)
69 |
70 |
71 | def pad(tensors, padding_value=0, total_length=None, padding_side='right'):
72 | size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors)
73 | for i in range(len(tensors[0].size()))]
74 | if total_length is not None:
75 | assert total_length >= size[1]
76 | size[1] = total_length
77 | out_tensor = tensors[0].data.new(*size).fill_(padding_value)
78 | for i, tensor in enumerate(tensors):
79 | out_tensor[i][[slice(-i, None) if padding_side == 'left' else slice(0, i) for i in tensor.size()]] = tensor
80 | return out_tensor
81 |
82 |
83 | if __name__ == '__main__':
84 | import torch
85 | x = torch.arange(25).view(5, 5)
86 | print(stripe(x, 2, 3, (3, 1)))
87 |
--------------------------------------------------------------------------------
/supar/utils/logging.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import logging
4 | import os
5 |
6 | from supar.utils.parallel import is_master
7 | from tqdm import tqdm
8 |
9 |
10 | def get_logger(name):
11 | return logging.getLogger(name)
12 |
13 |
14 | def init_logger(logger,
15 | path=None,
16 | mode='w',
17 | level=None,
18 | handlers=None,
19 | verbose=True):
20 | level = level or logging.WARNING
21 | if not handlers:
22 | handlers = [logging.StreamHandler()]
23 | if path:
24 | os.makedirs(os.path.dirname(path), exist_ok=True)
25 | handlers.append(logging.FileHandler(path, mode))
26 | logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
27 | datefmt='%Y-%m-%d %H:%M:%S',
28 | level=level,
29 | handlers=handlers)
30 | logger.setLevel(logging.INFO if is_master() and verbose else logging.WARNING)
31 |
32 |
33 | def progress_bar(iterator,
34 | ncols=None,
35 | bar_format='{l_bar}{bar:18}| {n_fmt}/{total_fmt} {elapsed}<{remaining}, {rate_fmt}{postfix}',
36 | leave=True):
37 | return tqdm(iterator,
38 | ncols=ncols,
39 | bar_format=bar_format,
40 | ascii=True,
41 | disable=(not (logger.level == logging.INFO and is_master())),
42 | leave=leave)
43 |
44 |
45 | logger = get_logger('supar')
46 |
--------------------------------------------------------------------------------
/supar/utils/parallel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | from random import Random
5 |
6 | import torch
7 | import torch.distributed as dist
8 | import torch.nn as nn
9 |
10 |
11 | class DistributedDataParallel(nn.parallel.DistributedDataParallel):
12 |
13 | def __init__(self, module, **kwargs):
14 | super().__init__(module, **kwargs)
15 |
16 | def __getattr__(self, name):
17 | wrapped = super().__getattr__('module')
18 | if hasattr(wrapped, name):
19 | return getattr(wrapped, name)
20 | return super().__getattr__(name)
21 |
22 |
23 | def init_device(device, local_rank=-1, backend='nccl', host=None, port=None):
24 | os.environ['CUDA_VISIBLE_DEVICES'] = device
25 | if torch.cuda.device_count() > 1:
26 | host = host or os.environ.get('MASTER_ADDR', 'localhost')
27 | port = port or os.environ.get('MASTER_PORT', str(Random(0).randint(10000, 20000)))
28 | os.environ['MASTER_ADDR'] = host
29 | os.environ['MASTER_PORT'] = port
30 | dist.init_process_group(backend)
31 | torch.cuda.set_device(local_rank)
32 |
33 |
34 | def is_master():
35 | return not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0
36 |
--------------------------------------------------------------------------------
/supar/utils/vocab.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from collections import defaultdict
4 | from collections.abc import Iterable
5 |
6 |
7 | class Vocab(object):
8 | r"""
9 | Defines a vocabulary object that will be used to numericalize a field.
10 |
11 | Args:
12 | counter (~collections.Counter):
13 | :class:`~collections.Counter` object holding the frequencies of each value found in the data.
14 | min_freq (int):
15 | The minimum frequency needed to include a token in the vocabulary. Default: 1.
16 | specials (list[str]):
17 | The list of special tokens (e.g., pad, unk, bos and eos) that will be prepended to the vocabulary. Default: [].
18 | unk_index (int):
19 | The index of unk token. Default: 0.
20 |
21 | Attributes:
22 | itos:
23 | A list of token strings indexed by their numerical identifiers.
24 | stoi:
25 | A :class:`~collections.defaultdict` object mapping token strings to numerical identifiers.
26 | """
27 |
28 | def __init__(self, counter, min_freq=1, specials=[], unk_index=0):
29 | self.itos = list(specials)
30 | self.stoi = defaultdict(lambda: unk_index)
31 |
32 | self.stoi.update({token: i for i, token in enumerate(self.itos)})
33 | self.extend([token for token, freq in counter.items()
34 | if freq >= min_freq])
35 | self.unk_index = unk_index
36 | self.n_init = len(self)
37 |
38 | def __len__(self):
39 | return len(self.itos)
40 |
41 | def __getitem__(self, key):
42 | if isinstance(key, str):
43 | return self.stoi[key]
44 | elif not isinstance(key, Iterable):
45 | return self.itos[key]
46 | elif isinstance(key[0], str):
47 | return [self.stoi[i] for i in key]
48 | else:
49 | return [self.itos[i] for i in key]
50 |
51 | def __contains__(self, token):
52 | return token in self.stoi
53 |
54 | def __getstate__(self):
55 | # avoid pickling defaultdict
56 | attrs = dict(self.__dict__)
57 | # cast to regular dict
58 | attrs['stoi'] = dict(self.stoi)
59 | return attrs
60 |
61 | def __setstate__(self, state):
62 | stoi = defaultdict(lambda: self.unk_index)
63 | stoi.update(state['stoi'])
64 | state['stoi'] = stoi
65 | self.__dict__.update(state)
66 |
67 |
68 | ### do not change
69 | def extend(self, tokens):
70 | try:
71 | self.stoi = defaultdict(lambda: self.unk_index)
72 | for word in self.itos:
73 | self.stoi[word]
74 | except:
75 | pass
76 |
77 | self.itos.extend(sorted(set(tokens).difference(self.stoi)))
78 | self.stoi.update({token: i for i, token in enumerate(self.itos)})
79 |
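
A short sketch of building and querying a Vocab from a Counter, following the docstring above; the tokens and thresholds are arbitrary:

from collections import Counter
from supar.utils.vocab import Vocab

counter = Counter(['the', 'the', 'cat', 'sat', 'sat', 'sat'])
vocab = Vocab(counter, min_freq=2, specials=['<pad>', '<unk>'], unk_index=1)

print(len(vocab))          # 4: the two specials plus 'sat' and 'the'
print(vocab['the'])        # 3
print(vocab['oov-token'])  # 1, i.e. the unk index
print(vocab[[0, 1]])       # ['<pad>', '<unk>']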
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from pathlib import Path
4 | from typing import List
5 |
6 | import hydra
7 | import omegaconf
8 | import pytorch_lightning as pl
9 | from hydra.core.hydra_config import HydraConfig
10 | import wandb
11 | from pytorch_lightning import LightningModule, LightningDataModule, Callback, Trainer
12 | from pytorch_lightning.loggers import LightningLoggerBase
13 | from pytorch_lightning import seed_everything
14 |
15 | # hydra imports
16 | from omegaconf import DictConfig
17 | from hydra.utils import log
18 | import hydra
19 | from omegaconf import OmegaConf
20 |
21 | # normal imports
22 | from typing import List
23 | import warnings
24 | import logging
25 | from pytorch_lightning.loggers import WandbLogger
26 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
27 | import sys
28 | # sys.setdefaultencoding() does not exist, here!
29 | from omegaconf import OmegaConf, open_dict
30 |
31 |
32 | def train(config):
33 | # log the fully resolved config
34 | log.info(OmegaConf.to_container(config,resolve=True))
35 | config.root = hydra.utils.get_original_cwd()
36 | limit_train_batches = 1.0
37 |
38 | hydra_dir = str(os.getcwd())
39 | seed_everything(config.seed)
40 | os.chdir(hydra.utils.get_original_cwd())
41 |
42 | # Instantiate datamodule
43 | hydra.utils.log.info(os.getcwd())
44 | hydra.utils.log.info(f"Instantiating <{config.datamodule.target}>")
45 | # Instantiate callbacks and logger.
46 | callbacks: List[Callback] = []
47 | logger: List[LightningLoggerBase] = []
48 |
49 | datamodule: pl.LightningDataModule = hydra.utils.instantiate(
50 | config.datamodule.target, config.datamodule, _recursive_=False
51 | )
52 |
53 | log.info("created datamodule")
54 | datamodule.setup()
55 | model = hydra.utils.instantiate(config.runner, cfg = config, fields=datamodule.fields, datamodule=datamodule, _recursive_=False)
56 |
57 | os.chdir(hydra_dir)
58 | # Train the model ⚡
59 | if "callbacks" in config:
60 | for _, cb_conf in config["callbacks"].items():
61 | if "_target_" in cb_conf:
62 | log.info(f"Instantiating callback <{cb_conf._target_}>")
63 | callbacks.append(hydra.utils.instantiate(cb_conf))
64 |
65 | if config.checkpoint:
66 | callbacks.append(
67 | ModelCheckpoint(
68 | monitor='valid/score',
69 | mode='max',
70 | save_last=False,
71 | filename='checkpoint'
72 | )
73 | )
74 | log.info("Instantiating callback, ModelCheckpoint")
75 |
76 |
77 | if config.wandb:
78 | logger.append(hydra.utils.instantiate(config.logger))
79 |
80 | log.info(f"Instantiating trainer <{config.trainer._target_}>")
81 | trainer: Trainer = hydra.utils.instantiate(
82 | config.trainer, callbacks=callbacks, logger=logger,
83 | replace_sampler_ddp=False,
84 | # accelerator='ddp' if distributed else None,
85 | accumulate_grad_batches=config.accumulation,
86 | limit_train_batches=limit_train_batches,
87 | checkpoint_callback=False
88 | )
89 |
90 | log.info(f"Starting training!")
91 | if config.wandb:
92 | logger[-1].experiment.save(str(hydra_dir) + "/.hydra/*", base_path=str(hydra_dir))
93 |
94 | trainer.fit(model, datamodule)
95 | log.info(f"Finalizing!")
96 |
97 | if config.wandb:
98 | logger[-1].experiment.save(str(hydra_dir) + "/*.log", base_path=str(hydra_dir))
99 | wandb.finish()
100 |
101 | log.info(f'hydra_path:{os.getcwd()}')
102 |
103 |
104 | @hydra.main(config_path="configs/", config_name="config.yaml")
105 | def main(config):
106 | train(config)
107 |
108 | if __name__ == "__main__":
109 | main()
110 |
111 |
112 |
--------------------------------------------------------------------------------