├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── task.png
├── config
│   ├── config_test.yaml
│   ├── config_train.yaml
│   ├── data
│   │   ├── vlparse.yaml
│   │   └── vlparse_lang_only.yaml
│   ├── exp
│   │   ├── lang_only.yaml
│   │   └── vlgae.yaml
│   ├── hydra
│   │   └── job_logging
│   │       ├── custom.yaml
│   │       └── nofile.yaml
│   ├── model
│   │   ├── embedding
│   │   │   └── en.yaml
│   │   ├── lang_only.yaml
│   │   ├── metric
│   │   │   ├── attachment.yaml
│   │   │   └── attachment_box_rel.yaml
│   │   ├── optimize
│   │   │   ├── constant.yaml
│   │   │   └── linear.yaml
│   │   └── vlgae.yaml
│   └── trainer
│       ├── callbacks
│       │   ├── best_watcher.yaml
│       │   ├── early_stopping.yaml
│       │   ├── lr_monitor.yaml
│       │   ├── progressbar.yaml
│       │   ├── wandb.yaml
│       │   └── weights_summary.yaml
│       ├── debug.yaml
│       ├── logger
│       │   └── wandb.yaml
│       ├── test.yaml
│       └── train.yaml
├── data
│   ├── data_format.json
│   └── vlparse.json
├── eval.py
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── datamodule
│   │   ├── __init__.py
│   │   ├── datamodule.py
│   │   ├── sampler.py
│   │   ├── task
│   │   │   ├── __init__.py
│   │   │   ├── dep.py
│   │   │   └── vlparse.py
│   │   └── vocabulary.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── dmv.py
│   │   ├── dmv_helper
│   │   │   ├── __init__.py
│   │   │   ├── good_init.py
│   │   │   ├── good_init_nn.py
│   │   │   └── km_init.py
│   │   ├── embedding
│   │   │   ├── __init__.py
│   │   │   ├── embedding.py
│   │   │   ├── fastnlp_embedding.py
│   │   │   └── transformers_embedding.py
│   │   ├── joint.py
│   │   ├── ldndmv.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── affine.py
│   │   │   ├── affine_scorer.py
│   │   │   ├── common.py
│   │   │   ├── dmv_spec.py
│   │   │   ├── dropout.py
│   │   │   ├── multivariate_kl.py
│   │   │   ├── scalar_mix.py
│   │   │   └── variational_lstm.py
│   │   ├── text_encoder
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── blank_encoder.py
│   │   │   ├── mlp_encoder.py
│   │   │   ├── multi_encoder.py
│   │   │   └── rnn_encoder.py
│   │   ├── torch_struct
│   │   │   ├── __init__.py
│   │   │   ├── deptree.py
│   │   │   ├── distributions.py
│   │   │   ├── dmv.py
│   │   │   ├── helpers.py
│   │   │   └── semirings
│   │   │       ├── __init__.py
│   │   │       ├── checkpoint.py
│   │   │       ├── fast_semirings.py
│   │   │       ├── keops.py
│   │   │       ├── sample.py
│   │   │       ├── semirings.py
│   │   │       └── sparse_max.py
│   │   └── vis_encoder
│   │       ├── __init__.py
│   │       ├── base.py
│   │       └── box_rel.py
│   ├── pipeline.py
│   └── utility
│       ├── _metric_legacy.py
│       ├── alg.py
│       ├── config.py
│       ├── defaultlist.py
│       ├── fn.py
│       ├── logger.py
│       ├── meta.py
│       ├── metric.py
│       ├── pl_callback.py
│       ├── scheduler.py
│       ├── spacy_helper.py
│       └── var_pool.py
├── test.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /outputs
2 | /.vscode
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | ### VisualStudioCode
134 | .vscode/*
135 | !.vscode/settings.json
136 | !.vscode/tasks.json
137 | !.vscode/launch.json
138 | !.vscode/extensions.json
139 | *.code-workspace
140 | **/.vscode
141 |
142 | # JetBrains
143 | .idea/
144 |
145 | # Lightning-Hydra-Template
146 | /configs/local/default.yaml
147 | # /data/
148 | /logs/
149 | /wandb/
150 | .env
151 | .autoenv
152 |
153 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Beijing Institute for General Artificial Intelligence
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VLGAE
2 | Official implementation of the CVPR 2022 paper "Unsupervised Vision-Language Parsing: Seamlessly Bridging Visual Scene Graphs with Language Structures via Dependency Relationships".
3 |
4 |
5 | ![task](assets/task.png)
6 |
7 |
--------------------------------------------------------------------------------
/assets/task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LouChao98/VLGAE/d71d07a31c3e4c04616070a053729956108fb83d/assets/task.png
--------------------------------------------------------------------------------
/config/config_test.yaml:
--------------------------------------------------------------------------------
1 | # config entries that do not belong to any submodule
2 | seed: 1
3 | project: untitled
4 | name: ${name_guard:@@@AUTO@@@}
5 | watch_field: val/loss
6 | watch_mode: min
7 | root: ${hydra:runtime.cwd}
8 | output_name: ~
9 |
10 | pipeline:
11 | _target_: src.pipeline.Pipeline
12 |
13 | load_from_checkpoint: ~
14 | loss_reduction_mode: token
15 |
16 | hydra:
17 | run:
18 | dir: .
19 | output_subdir: null
20 | job:
21 | env_set:
22 | TOKENIZERS_PARALLELISM: 'false'
23 | HF_DATASETS_OFFLINE: '1'
24 | TRANSFORMERS_OFFLINE: '1'
25 | TORCH_WARN_ONCE: '1'
26 | NUMEXPR_MAX_THREADS: '8'
27 | DEBUG_MODE: ''
28 |
29 | defaults:
30 | - _self_
31 | - trainer: train
32 | - data: vlparse
33 | - model: vlgae
--------------------------------------------------------------------------------
/config/config_train.yaml:
--------------------------------------------------------------------------------
1 | # config entries that do not belong to any submodule
2 | seed: ~
3 | project: untitled
4 | name: ${name_guard:@@@AUTO@@@}
5 | watch_field: val/loss
6 | watch_mode: min
7 | root: ${hydra:runtime.cwd}
8 | load_cfg_from_checkpoint: ~
9 |
10 | pipeline:
11 | _target_: src.pipeline.Pipeline
12 |
13 | load_from_checkpoint: ~
14 | loss_reduction_mode: token
15 |
16 | hydra:
17 | sweep:
18 | dir: outputs/multirun/${now:%Y-%m-%d_%H-%M-%S}
19 | subdir: ${path_guard:${hydra.job.override_dirname}}
20 | run:
21 | dir: outputs/${path_guard:${name}}/${now:%Y-%m-%d_%H-%M-%S}
22 | output_subdir: config
23 | job:
24 | env_set:
25 | TOKENIZERS_PARALLELISM: 'false'
26 | HF_DATASETS_OFFLINE: '1'
27 | TRANSFORMERS_OFFLINE: '1'
28 | TORCH_WARN_ONCE: '1'
29 | NUMEXPR_MAX_THREADS: '8'
30 | DEBUG_MODE: '0'
31 |
32 | defaults:
33 | - _self_
34 | - trainer: train
35 | - data: vlparse
36 | - model: vlgae
--------------------------------------------------------------------------------
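
Throughout `config/`, nodes that carry a `_target_` key (e.g. `pipeline._target_: src.pipeline.Pipeline` above, or the optimizers under `config/model/optimize/`) are turned into objects with Hydra's `instantiate`. A minimal sketch of that convention, using a hand-written stand-in node that mirrors `config/model/optimize/*.yaml` (the node and values below are illustrative, not taken from the repo):

```python
# Minimal sketch of the `_target_` convention used by the configs in config/.
# The node below is a hand-written stand-in mirroring config/model/optimize/*.yaml.
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

node = OmegaConf.create({
    "_target_": "torch.optim.Adam",
    "lr": 1.0e-3,
    "betas": [0.9, 0.999],
    "weight_decay": 0.0,
    "eps": 1.0e-12,
})

params = [torch.nn.Parameter(torch.zeros(3))]
optim = instantiate(node, params=params)  # extra kwargs are forwarded to the target
print(type(optim).__name__)  # Adam
```

Passing the parameters as an extra keyword argument is the usual way to instantiate optimizers from such a node, since the parameter list only exists at runtime.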
/config/data/vlparse.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | datamodule:
4 | _target_: src.datamodule.task.VLParseDataModule
5 | train_path: ${root}/data/vlparse/train
6 | train_init_path: ${root}/data/vlparse/init
7 | dev_path: ${root}/data/vlparse/val
8 | test_path: ${root}/data/vlparse/test
9 |
10 | use_img: false
11 | use_gold_scene_graph: false
12 | sg_path: ${root}/data/vlparse/vlparse.json
13 |
14 | use_tag: true
15 | num_lex: 200
16 | num_token: 99999
17 | ignore_stop_word: false
18 |
19 | normalize_word: true
20 | build_no_create_entry: true
21 | max_len:
22 | train: 10
23 |
24 | train_dataloader:
25 | token_size: 5000
26 | num_bucket: 10
27 | batch_size: 64
28 | dev_dataloader:
29 | token_size: 5000
30 | num_bucket: 8
31 | batch_size: 64
32 | test_dataloader:
33 | token_size: 5000
34 | num_bucket: 8
35 | batch_size: 64
36 |
37 | trainer:
38 | val_check_interval: 0.5
--------------------------------------------------------------------------------
/config/data/vlparse_lang_only.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | datamodule:
4 | _target_: src.datamodule.task.DepDataModule
5 | train_path: ${root}/data/vlparse/train.conll
6 | train_init_path: ${root}/vlparse/init.conll
7 | dev_path: ${root}/data/vlparse/val.conll
8 | test_path: ${root}/data/vlparse/test.conll
9 |
10 | use_tag: true
11 | num_lex: 200
12 | num_token: 99999
13 | ignore_stop_word: false
14 |
15 | normalize_word: true
16 | build_no_create_entry: true
17 |
18 | train_dataloader:
19 | token_size: 5000
20 | num_bucket: 10
21 | batch_size: 64
22 | dev_dataloader:
23 | num_bucket: 8
24 | token_size: 10000
25 | test_dataloader:
26 | num_bucket: 8
27 | token_size: 10000
28 | max_len:
29 | train: 15
30 |
31 | trainer:
32 | val_check_interval: 0.5
--------------------------------------------------------------------------------
/config/exp/lang_only.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /data: vlparse_lang_only
5 | - override /model: lang_only
6 |
7 | datamodule:
8 | num_lex: 0
9 | ignore_stop_word: true
10 |
11 | dataloader:
12 | default:
13 | batch_size: 16
14 |
15 | encoder:
16 | hidden_size: 400
17 | num_layers: 3
18 | lstm_dropout: 0.2
19 |
20 | model:
21 | init_method: 'y'
22 | context_mode: 'hx'
23 | init_epoch: 3
24 |
25 | mid_ff:
26 | n_bottleneck: 0
27 | n_mid: 100
28 | dropout: 0.2
29 | root_emb_dim: 10
30 | dec_emb_dim: 10
31 |
32 | variational_mode: 'none'
33 | z_dim: 64
34 |
35 | optimizer:
36 | args:
37 | lr: 0.0005
38 |
39 | _rank: 32
40 | _dropout: 0.5
41 | _hidden_size: 384
42 | project: unnamed
--------------------------------------------------------------------------------
/config/exp/vlgae.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - override /data: vlparse
5 | - override /model: vlgae
6 | - override /model/optimize@optimize: linear
7 |
8 | datamodule:
9 | num_lex: 0
10 | max_len:
11 | train: 50
12 |
13 | trainer:
14 | val_check_interval: 0.5
15 | max_epochs: 50
16 |
17 | optimizer:
18 | args:
19 | lr: 1.0e-3
20 |
21 | project: unnamed
--------------------------------------------------------------------------------
/config/hydra/job_logging/custom.yaml:
--------------------------------------------------------------------------------
1 | # @package hydra.job_logging
2 |
3 | version: 1
4 | formatters:
5 | console:
6 | (): src.utility.logger.ColorFormatter
7 | format: '%(message)s'
8 | detail:
9 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
10 | datefmt: '%y-%m-%d %H:%M:%S'
11 | handlers:
12 | console:
13 | class: src.utility.logger.TqdmLoggingHandler
14 | formatter: console
15 | level: DEBUG
16 | file:
17 | class: logging.FileHandler
18 | formatter: detail
19 | filename: ${hydra.job.name}.log
20 | root:
21 | handlers: [console, file]
22 | loggers:
23 | fastNLP:
24 | handlers: [console, file]
25 | lightning:
26 | handlers: [console, file]
27 | nni:
28 | handlers: [console, file]
29 | disable_existing_loggers: false
30 |
--------------------------------------------------------------------------------
/config/hydra/job_logging/nofile.yaml:
--------------------------------------------------------------------------------
1 | # @package hydra.job_logging
2 |
3 | version: 1
4 | formatters:
5 | console:
6 | (): src.utility.logger.ColorFormatter
7 | format: '[%(name)s] %(message)s'
8 | handlers:
9 | console:
10 | class: src.utility.logger.TqdmLoggingHandler
11 | formatter: console
12 | level: DEBUG
13 | root:
14 | handlers: [console]
15 | loggers:
16 | fastNLP:
17 | handlers: [console]
18 | lightning:
19 | handlers: [console]
20 | nni:
21 | handlers: [console]
22 | disable_existing_loggers: false
23 |
--------------------------------------------------------------------------------
/config/model/embedding/en.yaml:
--------------------------------------------------------------------------------
1 | # @package embedding
2 |
3 | # embedding args
4 | use_word: true
5 | use_tag: true
6 | use_subword: false
7 | dropout: 0.
8 |
9 | # embedding item args
10 | word_embedding:
11 | args:
12 | _target_: fastNLP.embeddings.StaticEmbedding
13 | model_dir_or_name: ${..._emb_mapping.glove100}
14 | min_freq: 2
15 | lower: true
16 | adaptor_args:
17 | _target_: src.model.embedding.FastNLPEmbeddingVariationalAdaptor
18 | mode: basic
19 | out_dim: 0
20 | field: word
21 | normalize_method: mean+std
22 | normalize_time: begin
23 | tag_embedding:
24 | args:
25 | _target_: fastNLP.embeddings.StaticEmbedding
26 | embedding_dim: 100
27 | init_embed: normal
28 | adaptor_args:
29 | _target_: src.model.embedding.FastNLPEmbeddingAdaptor
30 | field: tag
31 | normalize_method: mean+std
32 | normalize_time: begin
33 | transformer:
34 | args:
35 | _target_: src.model.embedding.TransformersEmbedding
36 | model: bert-base-cased
37 | n_layers: 1
38 | n_out: 0
39 | requires_grad: false
40 | adaptor_args:
41 | _target_: src.model.embedding.TransformersAdaptor
42 | field: subword
43 | requires_vocab: false
44 |
45 |
46 | # others
47 | _emb_mapping:
48 | glove100: ${root}/data/glove/glove.6B.100d.txt
49 | glove300: ${root}/data/glove/glove.840B.300d.txt
50 | glove6b_300: ${root}/data/glove/glove.6B.300d.txt
51 | bio: ${root}/data/bio_nlp_vec/PubMed-shuffle-win-30.txt
52 | jose100: ${root}/data/jose/jose_100d.txt
--------------------------------------------------------------------------------
/config/model/lang_only.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - embedding: en
5 | - metric: attachment
6 | - optimize: linear
7 |
8 | encoder:
9 | _target_: src.model.text_encoder.RNNEncoder
10 | reproject_emb: 0
11 | reproject_out: 0
12 | mix: false
13 | pre_shared_dropout: 0.1
14 | pre_dropout: 0.1
15 | post_shared_dropout: 0.1
16 | post_dropout: 0.1
17 | hidden_size: 200
18 | proj_size: 0
19 | num_layers: 2
20 | output_layers: -1
21 | init_version: zy
22 | shared_dropout: true
23 | lstm_dropout: 0.33
24 |
25 | _hidden_size: 500
26 | _dropout: 0.33
27 | _rank: 32
28 |
29 | model:
30 | _target_: src.model.DiscriminativeNDMV
31 | _recursive_: false
32 | context_mode: hx
33 | init_method: 'y'
34 | init_epoch: 3
35 | viterbi_training: true
36 | mbr_decoding: false
37 | extended_valence: true
38 | function_mask: false
39 |
40 | variational_mode: 'none'
41 | z_dim: 0
42 |
43 | mid_ff:
44 | _target_: src.model.nn.DMVSkipConnectEncoder
45 | n_bottleneck: 0
46 | n_mid: 0
47 | dropout: 0.
48 |
49 | head_ff:
50 | _target_: src.model.nn.MLP
51 | n_hidden: ${_hidden_size}
52 | dropout: ${_dropout}
53 | child_ff:
54 | _target_: src.model.nn.MLP
55 | n_hidden: ${_hidden_size}
56 | dropout: ${_dropout}
57 | root_ff:
58 | _target_: src.model.nn.MLP
59 | n_hidden: ${_hidden_size}
60 | dropout: ${_dropout}
61 | dec_ff:
62 | _target_: src.model.nn.MLP
63 | n_hidden: ${_hidden_size}
64 | dropout: ${_dropout}
65 |
66 | attach_rank: ${_rank}
67 | dec_rank: ${_rank}
68 | root_rank: ${_rank}
69 |
70 | root_emb_dim: 50
71 | dec_emb_dim: 50
--------------------------------------------------------------------------------
/config/model/metric/attachment.yaml:
--------------------------------------------------------------------------------
1 | # @package metric
2 |
3 | _target_: src.utility.metric.DependencyParsingMetric
4 |
--------------------------------------------------------------------------------
/config/model/metric/attachment_box_rel.yaml:
--------------------------------------------------------------------------------
1 | # @package metric
2 | _target_: src.utility.metric.MultiMetric
3 |
4 | dep:
5 | _target_: src.utility.metric.DependencyParsingMetric
6 | extra_vocab: ${..extra_vocab}
7 |
8 | img:
9 | _target_: src.utility.metric.FactorImageMatchingMetric
10 | extra_vocab: ${..extra_vocab}
11 |
12 | match:
13 | _target_: src.utility.metric.BoxRelMatchingMetric
14 | extra_vocab: ${..extra_vocab}
15 |
--------------------------------------------------------------------------------
/config/model/optimize/constant.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | optimizer:
4 | groups:
5 | - pattern: dependency.embedding.transformer
6 | lr: 1.0e-5
7 | args:
8 | _target_: torch.optim.Adam
9 | lr: 1.0e-3
10 | betas: [ 0.9, 0.999 ]
11 | weight_decay: 0.
12 | eps: 1.0e-12
13 |
14 | scheduler: ~
--------------------------------------------------------------------------------
/config/model/optimize/linear.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | optimizer:
4 | groups: []
5 | args:
6 | _target_: torch.optim.Adam
7 | lr: 1.0e-3
8 | betas: [ 0.9, 0.999 ]
9 | weight_decay: 0.
10 | eps: 1.0e-12
11 |
12 | scheduler:
13 | interval: step
14 | frequency: 1
15 | args:
16 | _target_: src.utility.scheduler.get_exponential_lr_scheduler
17 | gamma: 0.75**(1/2000)
18 |
--------------------------------------------------------------------------------
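
Note that `gamma` above is written as the arithmetic expression `0.75**(1/2000)`: the per-step decay factor is about 0.99986, so the learning rate shrinks to 0.75x every 2000 steps. The real `src.utility.scheduler.get_exponential_lr_scheduler` is not included in this dump; the snippet below is only a sketch of such a factory, under the assumption that it evaluates the string expression and wraps `torch.optim.lr_scheduler.ExponentialLR`:

```python
# Sketch only: a hypothetical stand-in for get_exponential_lr_scheduler.
# Assumption: gamma may be given either as a float or as a string expression
# such as "0.75**(1/2000)" (decay to 0.75x of the learning rate every 2000 steps).
import torch


def get_exponential_lr_scheduler(optimizer, gamma):
    if isinstance(gamma, str):
        gamma = eval(gamma, {"__builtins__": {}}, {})  # "0.75**(1/2000)" -> ~0.9998562
    return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)


optim = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0e-3)
sched = get_exponential_lr_scheduler(optim, "0.75**(1/2000)")
print(sched.gamma)  # ~0.9998562; after 2000 sched.step() calls the lr is ~0.75e-3
```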
/config/model/vlgae.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | defaults:
4 | - metric: attachment_box_rel
5 | - embedding: en
6 | - optimize: constant
7 |
8 | _match_hidden_size: 128
9 | _hidden_size: 256
10 | _dropout: 0.33
11 | _rank: 16
12 |
13 | embedding:
14 | use_word: false
15 | use_subword: true
16 | use_tag: true
17 | tag_embedding:
18 | args:
19 | embedding_dim: 32
20 |
21 | encoder:
22 | _target_: src.model.text_encoder.MLPEncoder
23 | dropout: 0.33
24 | shared_dropout: 0
25 | n_hidden: ${_hidden_size}
26 |
27 | vis_encoder:
28 | _target_: src.model.vis_encoder.VisBoxRelSimpleEncoder
29 | n_in: 2048
30 | n_hidden: ${_hidden_size}
31 | dropout: 0.
32 | activate: true
33 | use_attr: true
34 | use_img: false
35 | img_feat: true
36 |
37 | model:
38 | _target_: src.model.DependencyBoxRel
39 | _recursive_: false
40 |
41 | add_rel: true
42 | add_attr: true
43 | add_image: true
44 | add_marginal: true
45 |
46 | margin: 1
47 | language_factor_mode: word+maxdep
48 | visual_factor_mode: unprune
49 | visual_factor_cfg:
50 | n_hidden: ${_match_hidden_size}
51 | feat_fuse_mode: attention
52 | feat_fuse_args:
53 | num_heads: 4
54 | dropout: 0.33
55 | replace: false
56 | aug_with_matching: true
57 | gather_logit_mode: simple
58 | gather_logit_args: ~
59 | loss_grounding_mode: factor|ce
60 | loss_grounding_args:
61 | use_pos_prior: true
62 | vis2txt: 1
63 | decode_grounding_mode: on_factor
64 | decode_grounding_args:
65 | use_pos_prior: true
66 | use_heuristic: true
67 | grounding_interpolation: 0.5
68 |
69 | word_encoder:
70 | _target_: src.model.nn.MLP
71 | n_hidden: ${_match_hidden_size}
72 | dropout: 0.33
73 | activate: false
74 |
75 | init_method: 'y'
76 | init_epoch: 5
77 |
78 | dep_model_cfg:
79 | _target_: src.model.DiscriminativeNDMV
80 | _recursive_: false
81 | context_mode: 'mean'
82 | init_method: ${..init_method}
83 | init_epoch: ${..init_epoch}
84 | viterbi_training: true
85 | mbr_decoding: false
86 | extended_valence: true
87 | function_mask: false
88 |
89 | variational_mode: 'none'
90 | z_dim: 0
91 |
92 | mid_ff:
93 | _target_: src.model.nn.DMVSkipConnectEncoder
94 | n_bottleneck: 150
95 | n_mid: 0
96 | dropout: 0.3
97 |
98 | head_ff:
99 | _target_: src.model.nn.MLP
100 | n_hidden: ${_hidden_size}
101 | dropout: ${_dropout}
102 | child_ff:
103 | _target_: src.model.nn.MLP
104 | n_hidden: ${_hidden_size}
105 | dropout: ${_dropout}
106 | root_ff:
107 | _target_: src.model.nn.MLP
108 | n_hidden: ${_hidden_size}
109 | dropout: ${_dropout}
110 | dec_ff:
111 | _target_: src.model.nn.MLP
112 | n_hidden: ${_hidden_size}
113 | dropout: ${_dropout}
114 |
115 | attach_rank: ${_rank}
116 | dec_rank: ${_rank}
117 | root_rank: ${_rank}
118 |
119 | root_emb_dim: 10
120 | dec_emb_dim: 10
--------------------------------------------------------------------------------
/config/trainer/callbacks/best_watcher.yaml:
--------------------------------------------------------------------------------
1 | best_watcher:
2 | _target_: src.utility.pl_callback.BestWatcherCallback
3 | monitor: ${watch_field}
4 | mode: ${watch_mode}
5 | hint: true
6 | save:
7 | dirpath: checkpoint
8 | filename: "{epoch}-{step}-{${watch_field}:.2f}"
9 | start_patience: 2
10 | write: 'new'
11 | report: true
12 |
--------------------------------------------------------------------------------
/config/trainer/callbacks/early_stopping.yaml:
--------------------------------------------------------------------------------
1 | early_stopping:
2 | _target_: pytorch_lightning.callbacks.EarlyStopping
3 | monitor: ${watch_field}
4 | mode: ${watch_mode}
5 | patience: 100
6 |
--------------------------------------------------------------------------------
/config/trainer/callbacks/lr_monitor.yaml:
--------------------------------------------------------------------------------
1 | lr_monitor:
2 | _target_: src.utility.pl_callback.LearningRateMonitorWithEarlyStop
3 |   logging_interval: 'epoch' # None, step, or epoch; None = follow the scheduler's interval
4 | minimum_lr: 1e-8
5 |
--------------------------------------------------------------------------------
/config/trainer/callbacks/progressbar.yaml:
--------------------------------------------------------------------------------
1 | progress_bar:
2 | _target_: src.utility.pl_callback.MyProgressBar
3 | refresh_rate: 1
4 | process_position: 0
5 |
6 | #progress_bar:
7 | # _target_: pytorch_lightning.callbacks.RichProgressBar
--------------------------------------------------------------------------------
/config/trainer/callbacks/wandb.yaml:
--------------------------------------------------------------------------------
1 | wandb:
2 | _target_: src.utility.pl_callback.WatchModelWithWandb
3 | log: ${in_debugger:gradients,null} # all, gradients, parameters, None
4 | log_freq: 100
5 |
--------------------------------------------------------------------------------
/config/trainer/callbacks/weights_summary.yaml:
--------------------------------------------------------------------------------
1 | weights_summary:
2 | _target_: pytorch_lightning.callbacks.ModelSummary
3 | max_depth: ${in_debugger:5,2}
--------------------------------------------------------------------------------
/config/trainer/debug.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - logger: ~
3 | - callbacks:
4 | - progressbar
5 | - early_stopping
6 | - weights_summary
7 | # - swa
8 | - override /hydra/job_logging@_global_.hydra.job_logging: nofile
9 |
10 | hydra:
11 | job:
12 | env_set:
13 | DEBUG_MODE: '1'
14 |
15 | _target_: src.utility.fn.instantiate_trainer
16 |
17 | fast_dev_run: 3
18 | checkpoint_callback: false
19 |
20 | gpus: 1
21 | gradient_clip_val: 5.
22 | track_grad_norm: -1
23 | # max_epochs: 1000 # due to fast_dev_run
24 | max_steps: -1
25 | val_check_interval: 1.0 # int: check every n training batches; float: fraction of an epoch
26 | accumulate_grad_batches: 1
27 | precision: 32
28 | # num_sanity_val_steps: 2 # due to fast_dev_run
29 | resume_from_checkpoint: ~
30 | detect_anomaly: true
31 | deterministic: false
32 |
33 | # following are settings you should not touch in most cases
34 | accelerator: ${accelerator:${.gpus}}
35 | replace_sampler_ddp: false
36 | multiple_trainloader_mode: min_size
37 | enable_model_summary: false
--------------------------------------------------------------------------------
/config/trainer/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 |
2 | _target_: pytorch_lightning.loggers.WandbLogger
3 | name: ${name}
4 | project: ${project}
5 | tags: []
6 | save_code: false
7 | save_dir: ${root}/outputs
8 |
--------------------------------------------------------------------------------
/config/trainer/test.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - logger: ~
3 | - callbacks:
4 | - progressbar
5 | - override /hydra/job_logging@_global_.hydra.job_logging: nofile
6 |
7 | _target_: src.utility.fn.instantiate_trainer
8 |
9 | enable_checkpointing: false
10 | logger: ~
11 |
12 | gpus: 1
13 | precision: 32
14 | resume_from_checkpoint: ~
15 |
16 | # following are settings you should not touch in most cases
17 | accelerator: ${accelerator:${.gpus}}
18 | detect_anomaly: false
19 | replace_sampler_ddp: false
20 | enable_model_summary: false
21 |
--------------------------------------------------------------------------------
/config/trainer/train.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - logger: ~
3 | - callbacks:
4 | - progressbar
5 | # - early_stopping
6 |       - lr_monitor # may trigger early stopping (see minimum_lr)
7 | - best_watcher
8 | - weights_summary
9 | - override /hydra/job_logging@_global_.hydra.job_logging: custom
10 |
11 | _target_: src.utility.fn.instantiate_trainer
12 |
13 | gpus: 1
14 | gradient_clip_val: 5.
15 | track_grad_norm: -1
16 | max_epochs: 50
17 | max_steps: -1
18 | val_check_interval: 1.0 # int: check every n training batches; float: fraction of an epoch
19 | accumulate_grad_batches: 1
20 | precision: 32
21 | num_sanity_val_steps: ${in_debugger:1,5}
22 | resume_from_checkpoint: ~
23 | detect_anomaly: false
24 | deterministic: false
25 |
26 | # following are settings you should not touch in most cases
27 | accelerator: gpu
28 | strategy: ${accelerator:${.gpus}}
29 | replace_sampler_ddp: false
30 | multiple_trainloader_mode: min_size
31 | enable_model_summary: false
32 |
--------------------------------------------------------------------------------
/data/data_format.json:
--------------------------------------------------------------------------------
1 | {
2 |   "<image id>": {
3 |     "image": {
4 |       "coco_id": 0, // MSCOCO id
5 |       "vg_id": 0, // VisualGenome id
6 |       "height": 0,
7 |       "width": 0
8 |     },
9 |     "box": {
10 |       "<box id>": {
11 |         "width": 0.0, // percentage of image width
12 |         "height": 0.0, // percentage of image height
13 |         "x": 0.0, // percentage of image width
14 |         "y": 0.0, // percentage of image height
15 |         "label": "region label from VisualGenome",
16 |         "attribute": "list of attributes separated by semicolon"
17 |       },
18 |       ...
19 |     },
20 |     "relationship": {
21 |       "<relationship id>": {
22 |         "from": "<box id>",
23 |         "to": "<box id>",
24 |         "label": "relationship label from VisualGenome"
25 |       },
26 |       ...
27 |     },
28 |     "sentence": {
29 |       "<sentence id>": {
30 |         "text": "the sentence",
31 |         "pos": "part-of-speech tags",
32 |         "dephead": "dependency heads",
33 |         "span": {
34 |           "<span id>": { // object
35 |             "label": "object",
36 |             "start": 0, // inclusive character offset
37 |             "end": 0, // exclusive character offset
38 |             "attribute_start": 0, // inclusive character offset
39 |             "attribute_end": 0, // exclusive character offset, (0,0)=no attribute
40 |             "text": "text",
41 |             "attribute_text": "attribute_text",
42 |             "alignment": ["<box id>"]
43 |           },
44 |           "<span id>": { // relationship
45 |             ... // analogous fields for relationship spans
46 |             "alignment": ["<relationship id>"]
47 |           },
48 |           ...
49 |         }
50 |       },
51 |       ...
52 |     }
53 |   },
54 |   ...
55 | }
--------------------------------------------------------------------------------
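
For reference, a short sketch of walking a file that follows the `data_format.json` schema above. The path is hypothetical and the field names come from the documented schema; the real `data/vlparse.json` is plain JSON (the `//` comments above are documentation only, not valid JSON).

```python
# Sketch: iterate a file following the data_format.json schema above.
# The path is hypothetical; field names mirror the documented schema.
import json

with open("data/vlparse_format_example.json") as f:  # hypothetical file
    data = json.load(f)

for image_id, entry in data.items():
    print(image_id, "coco:", entry["image"]["coco_id"], "vg:", entry["image"]["vg_id"])
    for box_id, box in entry["box"].items():
        # x/y/width/height are percentages of the image size
        corners = (box["x"], box["y"], box["x"] + box["width"], box["y"] + box["height"])
        print("  box", box_id, box["label"], corners)
    for rel_id, rel in entry["relationship"].items():
        print("  rel", rel_id, rel["label"])
    for sent_id, sent in entry["sentence"].items():
        for span_id, span in sent["span"].items():
            # character offsets into the sentence text; alignment links spans to the scene graph
            phrase = sent["text"][span["start"]:span["end"]]
            print("  span", span["label"], repr(phrase), span["alignment"])
```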
/eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | import conllu
4 | import argparse
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument(
8 | "--gold", action="store_true", help="whether to use gold boxes instead of proposals"
9 | )
10 | parser.add_argument(
11 | "--file",
12 | help="path to the prediction",
13 | default="outputs/0_latest_run/dev.predict.txt",
14 | )
15 | parser.add_argument(
16 | "--dataroot",
17 | help="path to VLParse",
18 | default="data/vlparse",
19 | )
20 | args = parser.parse_args()
21 |
22 | id_list_path = f"{args.dataroot}/id_list/val.txt"
23 | predict_path = args.file
24 |
25 | if args.gold:
26 | with open(f"{args.dataroot}/dev_gold_boxes.json") as f:
27 | img2boxes = json.load(f)
28 | else:
29 | with open(f"{args.dataroot}/dev_roi_boxes.json") as f:
30 | img2boxes = json.load(f)
31 | img2boxes = {int(key): value for key, value in img2boxes.items()}
32 |
33 | with open(f"{args.dataroot}/vlparse.json") as f:
34 | gold = json.load(f)
35 | gold = {item["coco_id"]: item for item in gold if isinstance(item, dict)}
36 |
37 |
38 | id_list = Path(id_list_path).read_text().splitlines()
39 | img_ids = [int(item) for item in id_list for _ in range(5)]
40 | sent_ids = [item for _ in id_list for item in range(5)]
41 | predict = list(
42 | conllu.parse_incr(open(predict_path), fields=["ID", "FORM", "POS", "HEAD", "ALIGN"])
43 | )
44 | has_vg = [item in gold for item in img_ids]
45 | img_ids = [item for item, flag in zip(img_ids, has_vg) if flag]
46 | sent_ids = [item for item, flag in zip(sent_ids, has_vg) if flag]
47 | # predict = [item for item, flag in zip(predict, has_vg) if flag]
48 | print(len(sent_ids), len(predict))
49 |
50 |
51 | def get_position(item):
52 | return item["x"], item["y"], item["x"] + item["width"], item["y"] + item["height"]
53 |
54 |
55 | def bb_intersection_over_union(boxA, boxB):
56 | # boxA = [int(x) for x in boxA]
57 | # boxB = [int(x) for x in boxB]
58 |
59 | xA = max(boxA[0], boxB[0])
60 | yA = max(boxA[1], boxB[1])
61 | xB = min(boxA[2], boxB[2])
62 | yB = min(boxA[3], boxB[3])
63 |
64 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
65 |
66 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
67 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
68 |
69 | iou = interArea / float(boxAArea + boxBArea - interArea)
70 |
71 | return iou
72 |
73 |
74 | obj_correct = 0
75 | obj_total = 0
76 | attr_correct = 0
77 | attr_total = 0
78 | rel_correct = 0
79 | rel_total = 0
80 |
81 |
82 | def test(boxA, boxB):
83 | if bb_intersection_over_union(boxA, boxB) >= 0.5:
84 | return True
85 | return False
86 |
87 |
88 | for idx in range(len(predict)):
89 | img_id, sent_id = img_ids[idx], sent_ids[idx]
90 |
91 | # obj
92 | for word_idx, data in gold[img_id]["txt2sg"][sent_id].items():
93 | if data["type"] != "OBJ":
94 | continue
95 | correct_flag = False
96 | for item in predict[idx][int(word_idx)]["ALIGN"].split("|"):
97 | pred_type, pred_id = item.split()
98 | if pred_type == "obj":
99 | word_predict = img2boxes[img_id][int(pred_id)]
100 | correct_flag = False
101 | for obj_id, _ in data["candidates"]:
102 | position = get_position(gold[img_id]["obj"][obj_id])
103 | if test(word_predict, position):
104 | correct_flag = True
105 | break
106 | if correct_flag:
107 | obj_correct += 1
108 | break
109 | obj_total += 1
110 |
111 | # attr
112 | for word_idx, data in gold[img_id]["txt2sg"][sent_id].items():
113 | if data["type"] != "ATTR":
114 | continue
115 | correct_flag = False
116 | for item in predict[idx][int(word_idx)]["ALIGN"].split("|"):
117 | pred_type, pred_id = item.split()
118 | if pred_type == "attr":
119 | try:
120 | word_predict = img2boxes[img_id][int(pred_id)]
121 | except IndexError:
122 | print(img_id, sent_id)
123 | correct_flag = False
124 | for obj_id, _ in data["candidates"]:
125 | position = get_position(gold[img_id]["obj"][obj_id])
126 | if test(word_predict, position):
127 | correct_flag = True
128 | break
129 | if correct_flag:
130 | attr_correct += 1
131 | break
132 | attr_total += 1
133 |
134 | # rel
135 | for word_idx, data in gold[img_id]["txt2sg"][sent_id].items():
136 | if data["type"] != "REL":
137 | continue
138 | correct_flag = False
139 | for item in predict[idx][int(word_idx)]["ALIGN"].split("|"):
140 | pred_type, pred_id = item.split()
141 | if pred_type == "rel":
142 | obj1, obj2 = pred_id.split("-")
143 | obj1 = img2boxes[img_id][int(obj1)]
144 | obj2 = img2boxes[img_id][int(obj2)]
145 |
146 | correct_flag = False
147 | for rel_id, _ in data["candidates"]:
148 | rel_item = gold[img_id]["rel"][rel_id - len(gold[img_id]["obj"])]
149 | assert rel_item["id"] == rel_id
150 | gold_obj1 = get_position(gold[img_id]["obj"][rel_item["subj"]])
151 | gold_obj2 = get_position(gold[img_id]["obj"][rel_item["obj"]])
152 |
153 | if test(obj1, gold_obj1) and test(obj2, gold_obj2):
154 | correct_flag = True
155 | break
156 | if test(obj2, gold_obj1) and test(obj1, gold_obj2):
157 | correct_flag = True
158 | break
159 | if correct_flag:
160 | rel_correct += 1
161 | break
162 | rel_total += 1
163 |
164 |
165 | print("obj", obj_correct / obj_total, obj_total)
166 | print("attr", attr_correct / attr_total, attr_total)
167 | print("rel", rel_correct / rel_total, rel_total)
168 | print(
169 | "0-order",
170 | (obj_correct + attr_correct + rel_correct) / (obj_total + attr_total + rel_total),
171 | )
172 |
--------------------------------------------------------------------------------
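
`eval.py` counts a predicted box as matching a gold region when `bb_intersection_over_union` is at least 0.5; the `+1` terms treat box corners as inclusive pixel coordinates. A tiny worked example with illustrative boxes:

```python
# Worked example of the IoU threshold test in eval.py (illustrative boxes only).
boxA = (0, 0, 10, 10)   # (x1, y1, x2, y2)
boxB = (5, 5, 15, 15)

inter = max(0, 10 - 5 + 1) * max(0, 10 - 5 + 1)   # overlap is 6 x 6 = 36
areaA = (10 - 0 + 1) * (10 - 0 + 1)               # 121
areaB = (15 - 5 + 1) * (15 - 5 + 1)               # 121
iou = inter / float(areaA + areaB - inter)        # 36 / 206 ≈ 0.175
print(iou >= 0.5)  # False -> this prediction would not be counted as correct
```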
/requirements.txt:
--------------------------------------------------------------------------------
1 | hydra-core
2 | pytorch-lightning
3 | transformers
4 | easydict
5 | colorama
6 | fastnlp
7 | nltk
8 | wandb
9 | matplotlib
10 | seaborn
11 | conllu
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import Optional, Mapping
4 |
5 | import numpy as np
6 | import pytorch_lightning
7 | import torch
8 | from easydict import EasyDict
9 | from hydra._internal.utils import is_under_debugger as _is_under_debugger
10 | from hydra.utils import HydraConfig
11 | from omegaconf import ListConfig, OmegaConf
12 |
13 | from src.utility.logger import get_logger_func
14 |
15 | _warn, _info, _debug = get_logger_func('src')
16 |
17 | g_cfg = EasyDict({
18 | 'MANUAL': 1,
19 | })  # global configuration object
20 | trainer: Optional[pytorch_lightning.Trainer] = None
21 | debugging = False
22 |
23 | # >>> setup logger
24 |
25 | pl_logger = logging.getLogger('lightning')
26 | pl_logger.propagate = False
27 |
28 | fastnlp_logger = logging.getLogger('fastNLP')
29 | fastnlp_logger.propagate = False
30 |
31 | wandb_logger = logging.getLogger('wandb')
32 | # wandb_logger.propagate = False
33 |
34 | # >>> setup OmegaConf
35 |
36 | # OmegaConf.register_new_resolver('in', lambda x, y: x in y)
37 | OmegaConf.register_new_resolver('lang', lambda x: x.split('_')[0])
38 | OmegaConf.register_new_resolver('last', lambda x: x.split('/')[-1])
39 | OmegaConf.register_new_resolver('div2', lambda x: x // 2)
40 | # OmegaConf.register_new_resolver('cat', lambda x, y: x + y)
41 |
42 | _hit_debug = True
43 |
44 |
45 | def is_under_debugger():
46 | if os.environ.get('DEBUG_MODE', '').lower() in ('true', 't', '1', 'yes', 'y'):
47 | result = True
48 | else:
49 | result = _is_under_debugger()
50 | global _hit_debug, debugging
51 | if result and _hit_debug:
52 | _warn("Debug mode.")
53 | _hit_debug = False
54 | debugging = True
55 | return result
56 |
57 |
58 | OmegaConf.register_new_resolver('in_debugger', lambda x, default=None: x if is_under_debugger() else default)
59 |
60 |
61 | def path_guard(x: str):
62 | x = x.split(',')
63 | x.sort()
64 | x = '_'.join(x)
65 | x = x.replace('/', '-')
66 | x = x.replace('=', '-')
67 | return x[:240]
68 |
69 |
70 | OmegaConf.register_new_resolver('path_guard', path_guard)
71 |
72 |
73 | def half_int(x):
74 | assert x % 2 == 0
75 | return x // 2
76 |
77 |
78 | OmegaConf.register_new_resolver('half_int', half_int)
79 |
80 |
81 | def name_guard(fallback):
82 | try:
83 | return HydraConfig.get().job.override_dirname
84 | except ValueError as v:
85 | if 'HydraConfig was not set' in str(v):
86 | return fallback
87 | raise v
88 |
89 |
90 | OmegaConf.register_new_resolver('name_guard', name_guard)
91 |
92 |
93 | def choose_accelerator(gpus):
94 | if isinstance(gpus, int):
95 | return 'ddp' if gpus > 1 else None
96 | elif isinstance(gpus, str):
97 | return 'ddp' if len(gpus.split(',')) > 1 else None
98 | elif isinstance(gpus, (list, ListConfig)):
99 | return 'ddp' if len(gpus) > 1 else None
100 | elif gpus is None:
101 | return None
102 | raise ValueError(f'Unrecognized {gpus=} ({type(gpus)})')
103 |
104 |
105 | OmegaConf.register_new_resolver('accelerator', choose_accelerator)
106 |
107 |
108 | # >>> setup inf
109 |
110 | INF = 1e20
111 |
112 |
113 | def setup_inf(v):
114 | global INF
115 | import src.model.torch_struct as stt
116 | INF = v
117 | stt.semirings.semirings.NEGINF = -INF
118 |
119 |
120 | setup_inf(1e20)
121 |
122 |
123 | # pl patch
124 |
125 | def _extract_batch_size(batch):
126 | if isinstance(batch, torch.Tensor):
127 | yield batch.shape[0]
128 | elif isinstance(batch, np.ndarray):
129 | yield batch.shape[0]
130 | elif isinstance(batch, str):
131 | yield len(batch)
132 | elif isinstance(batch, Mapping):
133 | for sample in batch:
134 | yield from _extract_batch_size(sample)
135 | else:
136 | x, y = batch
137 | yield len(x['id'])
138 |
139 |
140 | from pytorch_lightning.utilities import data as pludata
141 |
142 | pludata._extract_batch_size = _extract_batch_size
143 |
--------------------------------------------------------------------------------
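
`src/__init__.py` is imported for its side effects: it registers the OmegaConf resolvers used by the YAML configs (`name_guard`, `path_guard`, `half_int`, `in_debugger`, `accelerator`, ...), silences some loggers, and monkey-patches pytorch-lightning's batch-size extraction. A small sketch of the resolvers in action (assumes the repository's dependencies are installed, since importing `src` also pulls in torch and pytorch-lightning):

```python
# Sketch: the resolvers registered in src/__init__.py are available in any
# OmegaConf config once `import src` has run.
import src  # noqa: F401  (side effect: the OmegaConf.register_new_resolver calls above)
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "hidden": 512,
    "half": "${half_int:${hidden}}",  # resolved lazily on access
})
print(cfg.half)  # 256

# path_guard sanitizes override strings into directory-safe run names
print(src.path_guard("model=vlgae,data=vlparse"))  # data-vlparse_model-vlgae
```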
/src/datamodule/__init__.py:
--------------------------------------------------------------------------------
1 | from src.datamodule.datamodule import DataModule
--------------------------------------------------------------------------------
/src/datamodule/sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 | from math import ceil
4 | from typing import List
5 |
6 | import torch
7 | from fastNLP import RandomSampler, SequentialSampler
8 |
9 |
10 | from src.utility.logger import get_logger_func
11 |
12 | _warn, _info, _debug = get_logger_func("sampler")
13 |
14 |
15 | class ConstantTokenNumSampler:
16 | def __init__(
17 | self,
18 | seq_len: List[int],
19 | max_token: int = 4096,
20 | max_sentence: int = -1,
21 | num_bucket: int = 16,
22 | single_sent_threshold: int = -1,
23 | sort_in_batch: bool = True,
24 | shuffle: bool = True,
25 | force_same_len: bool = False,
26 | ):
27 | """
28 |         :param List[int] seq_len: list of sample lengths.
29 |         :param int max_token: maximum number of tokens in each batch.
30 |         :param int max_sentence: maximum number of sentences in each batch, applied together with max_token; <=0 disables it.
31 |         :param int num_bucket: split the data into num_bucket buckets by length.
32 |         :param int single_sent_threshold: sentences longer than this threshold are forced into batches of size 1; -1 disables it.
33 |         :param bool sort_in_batch: sort sentences within a batch by descending length.
34 |         :param bool shuffle: whether to shuffle.
35 |         :param bool force_same_len: ignore num_bucket; use one bucket per length so that all sentences in a batch have the same length.
36 | """
37 |
38 | assert (
39 | len(seq_len) >= num_bucket
40 | ), "The number of samples should be larger than buckets."
41 | assert (
42 | num_bucket > 1 or force_same_len
43 | ), "Use RandomSampler if you do not need bucket."
44 |
45 | self.seq_len = seq_len
46 | self.max_token = max_token
47 | self.max_sentence = max_sentence if max_sentence > 0 else 10000000000000000
48 | self.single_sent_threshold = single_sent_threshold
49 | self.sort_in_batch = sort_in_batch and not force_same_len
50 | self.shuffle = shuffle
51 |         self.epoch = 0  # +=1 every time __iter__ is called.
52 |
53 |         # sizes: List[int], pseudo size (centroid length) of each bucket.
54 |         # buckets: List[List[int]], each one is a bucket containing dataset indices.
55 | if force_same_len:
56 | self.sizes = list(set(seq_len))
57 | len2idx = dict((l, i) for i, l in enumerate(self.sizes))
58 | self.buckets = [[] for _ in range(len(self.sizes))]
59 | for i, l in enumerate(seq_len):
60 | self.buckets[len2idx[l]].append(i)
61 | else:
62 | self.sizes, self.buckets = self.kmeans(seq_len, num_bucket)
63 |
64 | # chunks: List[int], n chunk for each bucket
65 | self.chunks = [
66 | min(
67 | len(bucket),
68 | max(
69 | ceil(size * len(bucket) / max_token),
70 | ceil(len(bucket) / max_sentence),
71 | ),
72 | )
73 | for size, bucket in zip(self.sizes, self.buckets)
74 | ]
75 |
76 | self._batches = []
77 | self._all_batches = [] # including other workers
78 | self._exhausted = True
79 | self._init_iter_with_retry() # init here for valid __len__ at any time.
80 |
81 | def __iter__(self):
82 | self._init_iter_with_retry()
83 | yield from self._batches
84 | self._exhausted = True
85 |
86 | def __len__(self):
87 | return len(self._batches)
88 |
89 | def _init_iter(self):
90 | if self.shuffle:
91 | self.epoch += 1
92 | g = torch.Generator()
93 | g.manual_seed(self.epoch)
94 | range_fn = partial(torch.randperm, generator=g)
95 | else:
96 | range_fn = torch.arange
97 |
98 | batches = []
99 | for i in range(len(self.buckets)):
100 | split_sizes = [
101 | (len(self.buckets[i]) - j - 1) // self.chunks[i] + 1
102 | for j in range(self.chunks[i])
103 | ]
104 | for batch in range_fn(len(self.buckets[i])).split(split_sizes):
105 | batches.append([self.buckets[i][j] for j in batch])
106 | batches = [
107 | batch
108 | for i in range_fn(len(batches))
109 | for batch in self._process_batch(batches[i])
110 | ]
111 |
112 | self._batches = batches
113 | self._all_batches = batches
114 | self._exhausted = False
115 |
116 | def _init_iter_with_retry(self, max_try=5):
117 | _count = 0
118 | while self._exhausted:
119 | _count += 1
120 | if _count == max_try:
121 | raise ValueError("Failed to init iteration.")
122 | self._init_iter()
123 |
124 | def _process_batch(self, batch):
125 | # apply sort_in_batch and single_sent_threshold
126 | singles = []
127 | if self.single_sent_threshold != -1:
128 | new_batch = []
129 | for inst_idx in batch:
130 | if self.seq_len[inst_idx] >= self.single_sent_threshold:
131 | singles.append([inst_idx])
132 | else:
133 | new_batch.append(inst_idx)
134 | batch = new_batch
135 | if self.sort_in_batch:
136 | batch.sort(key=lambda i: -self.seq_len[i])
137 | if len(batch):
138 | return [batch] + singles
139 | else:
140 | return singles
141 |
142 | def set_epoch(self, epoch: int):
143 | # This is not a subclass of DistributedSampler, so will never be called by pytorch-lightning.
144 |         breakpoint()  # does anything ever call this?
145 | self.epoch = epoch
146 |
147 | @staticmethod
148 | def kmeans(x, k, max_it=32):
149 | """From https://github.com/yzhangcs/parser/blob/main/supar/utils/alg.py#L7"""
150 |
151 | # the number of clusters must not be greater than the number of datapoints
152 | x, k = torch.tensor(x, dtype=torch.float), min(len(x), k)
153 | # collect unique datapoints
154 | d = x.unique()
155 | # initialize k centroids randomly
156 | c = d[torch.randperm(len(d))[:k]]
157 | # assign each datapoint to the cluster with the closest centroid
158 | dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1)
159 |
160 | for _ in range(max_it):
161 | # if an empty cluster is encountered,
162 |             # choose the farthest datapoint from the biggest cluster and move it to the empty one
163 | mask = torch.arange(k).unsqueeze(-1).eq(y)
164 | none = torch.where(~mask.any(-1))[0].tolist()
165 | while len(none) > 0:
166 | for i in none:
167 | # the biggest cluster
168 | b = torch.where(mask[mask.sum(-1).argmax()])[0]
169 | # the datapoint farthest from the centroid of cluster b
170 | f = dists[b].argmax()
171 | # update the assigned cluster of f
172 | y[b[f]] = i
173 | # re-calculate the mask
174 | mask = torch.arange(k).unsqueeze(-1).eq(y)
175 | none = torch.where(~mask.any(-1))[0].tolist()
176 | # update the centroids
177 | c, old = (x * mask).sum(-1) / mask.sum(-1), c
178 | # re-assign all datapoints to clusters
179 | dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1)
180 | # stop iteration early if the centroids converge
181 | if c.equal(old):
182 | break
183 | # assign all datapoints to the new-generated clusters
184 | # the empty ones are discarded
185 | assigned = y.unique().tolist()
186 | # get the centroids of the assigned clusters
187 | centroids = c[assigned].tolist()
188 | # map all values of datapoints to buckets
189 | clusters = [torch.where(y.eq(i))[0].tolist() for i in assigned]
190 |
191 | return centroids, clusters
192 |
193 |
194 | class BasicSampler:
195 | """RandomSampler and SequentialSampler"""
196 |
197 | def __init__(
198 | self,
199 | seq_len,
200 | batch_size,
201 | single_sent_threshold=-1,
202 | sort_in_batch=True,
203 | shuffle=True,
204 | ):
205 | self.seq_len = seq_len
206 | self.batch_size = batch_size
207 | self.single_sent_threshold = single_sent_threshold
208 | self.sort_in_batch = sort_in_batch
209 | self.shuffle = shuffle
210 | self.epoch = 0
211 |
212 | self._sampler = RandomSampler() if shuffle else SequentialSampler()
213 |
214 | def __iter__(self):
215 | batch = []
216 | for i in self._sampler(self.seq_len):
217 | batch.append(i)
218 | if len(batch) == self.batch_size:
219 | yield from self._process_batch(batch)
220 | batch.clear()
221 | if batch:
222 | yield from self._process_batch(batch)
223 |
224 | def __len__(self):
225 | return math.ceil(len(self.seq_len) / self.batch_size)
226 |
227 | def _process_batch(self, batch):
228 | # apply sort_in_batch and single_sent_threshold
229 | singles = []
230 | if self.single_sent_threshold != -1:
231 | new_batch = []
232 | for inst_idx in batch:
233 | if self.seq_len[inst_idx] >= self.single_sent_threshold:
234 | singles.append([inst_idx])
235 | else:
236 | new_batch.append(inst_idx)
237 | batch = new_batch
238 | if self.sort_in_batch:
239 | batch.sort(key=lambda i: -self.seq_len[i])
240 | if len(batch):
241 | return [batch] + singles
242 | else:
243 | return singles
244 |
245 | def set_epoch(self, epoch: int):
246 | # This is not a subclass of DistributedSampler
247 | # this function will never be called by pytorch-lightning.
248 | self.epoch = epoch
249 |
--------------------------------------------------------------------------------
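
`ConstantTokenNumSampler` is a batch sampler: it clusters sentences into `num_bucket` length buckets with k-means and yields lists of dataset indices whose total token count stays roughly within `max_token` (very long sentences can be split off into their own batches via `single_sent_threshold`). A usage sketch with made-up lengths (assumes torch and fastNLP are installed, as the module itself requires):

```python
# Usage sketch for ConstantTokenNumSampler with made-up sentence lengths.
# Each yielded batch is a list of dataset indices.
import random

from src.datamodule.sampler import ConstantTokenNumSampler

random.seed(0)
seq_len = [random.randint(3, 40) for _ in range(200)]  # fake sentence lengths

sampler = ConstantTokenNumSampler(seq_len, max_token=256, num_bucket=8, shuffle=True)
print(len(sampler), "batches in the first epoch")
for batch in sampler:
    n_tokens = sum(seq_len[i] for i in batch)
    print(f"{len(batch):3d} sentences, ~{n_tokens} tokens")  # roughly bounded by max_token
```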
/src/datamodule/task/__init__.py:
--------------------------------------------------------------------------------
1 | from .dep import DepDataModule
2 | from .vlparse import VLParseDataModule
--------------------------------------------------------------------------------
/src/datamodule/task/dep.py:
--------------------------------------------------------------------------------
1 | from collections import Counter
2 |
3 | from fastNLP import DataSet
4 | from fastNLP.io import ConllLoader
5 | from nltk.corpus import stopwords
6 |
7 | import src
8 |
9 | from src.datamodule.datamodule import DataModule
10 | from src.datamodule.vocabulary import Vocabulary
11 | from src.utility.alg import isprojective
12 | from src.utility.logger import get_logger_func
13 | import omegaconf
14 |
15 | _warn, _info, _debug = get_logger_func('runner')
16 |
17 |
18 | class DepDataModule(DataModule):
19 | INPUTS = ('id', 'word', 'token', 'seq_len') # word for encoder, token for dmv
20 | TARGETS = ('arc', )
21 | LOADER = ConllLoader
22 |
23 | def __init__(
24 | self,
25 | use_tag=True,
26 |         num_lex=0,  # limit the number of lexicalized words kept in tokens; tags are not counted.
27 |         num_token=99999,  # limit the total number of tokens, counted over (lex, tag) pairs.
28 | ignore_stop_word=False,
29 | headers=None,
30 | indexes=None,
31 | **kwargs):
32 | assert num_lex > 0 or use_tag, 'Nothing to build token'
33 |
34 | headers = headers or ['raw_word', 'tag', 'arc']
35 | indexes = indexes or [1, 2, 3]
36 | loader = self.LOADER(headers, indexes=indexes, dropna=False, sep='\t')
37 |
38 | self.use_tag = use_tag
39 | if use_tag:
40 | assert 'tag' in headers
41 | self.INPUTS = self.INPUTS + ('tag', )
42 | self.EXTRA_VOCAB = self.EXTRA_VOCAB + ('tag', )
43 |
44 | self.num_lex = num_lex
45 | self.num_token = num_token
46 | self.ignore_stop_word = ignore_stop_word
47 | super().__init__(loader=loader, **kwargs)
48 | self.vocabs['token'] = None # set to manual init
49 |
50 | self.token2word = None
51 | self.token2tag = None
52 | if self.use_tag and self.num_lex > 0:
53 | self.token_mode = 'joint'
54 | elif self.use_tag:
55 | self.token_mode = 'tag'
56 | else:
57 | self.token_mode = 'word'
58 |
59 | def _load(self, path, name):
60 | ds: DataSet = self.loader._load(path)
61 |
62 | if self.token_mode == 'joint':
63 | ds.apply(lambda x: [f'{w.lower()}:{p}' for w, p in zip(x['raw_word'], x['tag'])], new_field_name='token')
64 | elif self.token_mode == 'tag':
65 | ds.apply(lambda x: x['tag'], new_field_name='token')
66 | else:
67 | ds.apply(lambda x: list(map(str.lower, x['raw_word'])), new_field_name='token')
68 |
69 | if name in ('train', 'train_init', 'dev', 'val', 'test'):
70 | ds['arc'].int()
71 | orig_len = len(ds)
72 |             ds = ds.drop(lambda i: not isprojective(i['arc']), inplace=False)  # keep only projective trees
73 | cleaned_len = len(ds)
74 | if cleaned_len < orig_len:
75 | _warn(f'Data contains nonprojective trees. {path}')
76 | else:
77 | raise NotImplementedError
78 |
79 | return ds
80 |
81 | def post_init_vocab(self, datasets):
82 | count = Counter()
83 | word_count = Counter()
84 |
85 | if self.token_mode == 'tag':
86 | self.vocabs['token'] = self.vocabs['tag']
87 | self.token2tag = list(range(len(self.vocabs['token'])))
88 | return
89 |
90 | for ds in self.get_create_entry_ds():
91 | for inst in ds:
92 | word_count.update(map(str.lower, inst['word']))
93 | if self.token_mode == 'joint':
94 | count.update(zip(map(str.lower, inst['word']), inst['tag']))
95 |
96 | if self.ignore_stop_word:
97 | sw = set(stopwords.words('english'))
98 | used_word = [w for w, i in word_count.most_common(self.num_lex + len(sw)) if w not in sw]
99 | used_word = set(used_word[:self.num_lex])
100 | else:
101 | used_word = set(w for w, i in word_count.most_common(self.num_lex))
102 |
103 | processed_count = {}
104 | if self.token_mode == 'joint':
105 | for (w, p), c in count.most_common():
106 | if w in used_word:
107 | processed_count[f'{w}:{p}'] = c
108 | if len(processed_count) == self.num_token:
109 | break
110 | for p in self.vocabs['tag'].word2idx:
111 |                 if p in ('<pad>', '<unk>'): continue
112 |                 processed_count[f'<unk>:{p}'] = 100000
113 | else:
114 | for w, c in word_count.most_common():
115 | if w in used_word:
116 | processed_count[w] = c
117 | if len(processed_count) == self.num_token:
118 | break
119 |
120 | token_vocab = Vocabulary()
121 | token_vocab.word_count = Counter(processed_count)
122 | token_vocab.build_vocab()
123 | self.vocabs['token'] = token_vocab
124 |
125 | if self.token_mode == 'joint':
126 | w, t = zip(*[token_vocab.idx2word[i].rsplit(':', 1) for i in range(2, len(token_vocab))])
127 |             w = ['<pad>', '<unk>'] + list(w)
128 |             t = ['<pad>', '<unk>'] + list(t)
129 | self.token2word = [self.vocabs['word'][i] for i in w]
130 | self.token2tag = [self.vocabs['tag'][i] for i in t]
131 | else:
132 | self.token2word = [self.vocabs['word'][token_vocab.idx2word[i]] for i in range(len(token_vocab))]
133 |
134 | def train_dataloader(self):
135 | loaders = {'train': self.dataloader('train')}
136 | for key in self.datasets:
137 | if key in ('train', 'dev', 'test'):
138 | continue
139 | if key == 'train_init':
140 | try:
141 | n_init = src.g_cfg.model.init_epoch
142 | do_init = src.g_cfg.model.init_method == 'y' and n_init > 0
143 | except (KeyError, omegaconf.errors.ConfigAttributeError):
144 | _warn('ignoring train_init due to missing cfg.')
145 | continue
146 | if do_init:
147 | loaders['train'] = _TrainInitLoader(self.dataloader('train_init'), loaders['train'], n_init)
148 | loaders[key] = self.dataloader(key)
149 | _info(f'Returning {len(loaders)} loader(s) as train_dataloader.')
150 | return loaders
151 |
152 |
153 | class _TrainInitLoader:
154 | def __init__(self, init_loader, normal_loader, n_init) -> None:
155 | self.init_loader = init_loader
156 | self.normal_loader = normal_loader
157 | self.n_init = n_init
158 | self.current = 1
159 |
160 | def __iter__(self):
161 | if self.current <= self.n_init:
162 | self.current += 1
163 | _warn('Initializing')
164 | yield from self.init_loader
165 | else:
166 | yield from self.normal_loader
167 |
--------------------------------------------------------------------------------
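
In `DepDataModule`, the DMV operates on `token`s rather than raw words: in `joint` mode each position becomes a lower-cased `word:tag` pair, only pairs whose word survives the `num_lex` cut-off enter the token vocabulary, and one back-off entry per tag is added for everything else. A tiny illustration of the token construction (the sentence and tags are made up):

```python
# Illustration of the 'joint' token construction in DepDataModule._load
# (the sentence and tags below are made up).
raw_word = ["A", "dog", "chases", "the", "frisbee"]
tag = ["DT", "NN", "VBZ", "DT", "NN"]

token = [f"{w.lower()}:{p}" for w, p in zip(raw_word, tag)]
print(token)  # ['a:DT', 'dog:NN', 'chases:VBZ', 'the:DT', 'frisbee:NN']
```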
/src/datamodule/task/vlparse.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | import json
3 | import os
4 | from pathlib import Path
5 | from typing import Any, Dict, List, Tuple
6 |
7 | import numpy as np
8 | import torch
9 | from fastNLP.core import DataSet
10 | from omegaconf import DictConfig, ListConfig
11 | from torch import Tensor
12 |
13 | from src.datamodule.task.dep import DepDataModule
14 | from src.utility.logger import get_logger_func
15 |
16 | InputDict = Dict[str, Tensor]
17 | TensorDict = Dict[str, Tensor]
18 | AnyDict = Dict[str, Any]
19 | GenDict = (dict, DictConfig)
20 | GenList = (list, ListConfig)
21 |
22 | _warn, _info, _debug = get_logger_func("datamodule")
23 |
24 |
25 | def get_box(obj):
26 | return [obj["x"], obj["y"], obj["x"] + obj["width"], obj["y"] + obj["height"]]
27 |
28 |
29 | class _COCODetFeatLazyLoader:
30 | def __init__(self, root, sg_data, sample, gold):
31 | self.root = root
32 | self.sg_data = sg_data
33 | self.sample = sample
34 | self.gold = gold
35 |
36 | def __call__(self, batch: List[Tuple[int, Any]]):
37 | box_feats, boxes, masks, rel_masks = [], [], [], []
38 | max_len = 0
39 | for _, inst in batch:
40 | if (self.root / f"{inst['img_id']}.npy").exists():
41 | feat = np.load(str(self.root / f"{inst['img_id']}.npy"))
42 | if self.sample > 0 and self.sample < len(feat):
43 | sample_id = np.random.choice(
44 | np.arange(len(feat)), self.sample, False
45 | )
46 | feat = feat[sample_id]
47 | else:
48 | feat = feat[:35]
49 | sample_id = np.arange(len(feat))
50 | box_feat, box = feat[:, :-4], feat[:, -4:]
51 | box_feat = torch.tensor(box_feat, dtype=torch.float)
52 | box = torch.tensor(box)
53 |
54 | box_feats.append(box_feat)
55 | boxes.append(box)
56 |
57 | if self.gold:
58 | inst_mask, inst_rel_mask = self.build_gold_mask(inst, sample_id)
59 | masks.append(inst_mask)
60 | rel_masks.append(inst_rel_mask)
61 | else:
62 | masks.append(torch.ones(len(box_feat), dtype=torch.bool))
63 | rel_masks.append(None)
64 | max_len = max(len(box_feat), max_len)
65 | else:
66 | assert False
67 |
68 | box_feats_output = torch.zeros(len(box_feats), max_len, 2048)
69 | boxes_output = torch.zeros(len(boxes), max_len, 4)
70 | masks_output = torch.zeros(len(masks), max_len, dtype=torch.bool)
71 | rel_masks_output = (
72 | None
73 | if len(rel_masks) == 0
74 | else torch.zeros(len(rel_masks), max_len, max_len, dtype=torch.bool)
75 | )
76 | for i, (bf, b, m, rm) in enumerate(zip(box_feats, boxes, masks, rel_masks)):
77 | if bf is not None:
78 | box_feats_output[i, : len(bf)] = bf
79 | boxes_output[i, : len(b)] = b
80 | masks_output[i, : len(m)] = m
81 | if rm is not None:
82 | rel_masks_output[i, : rm.shape[0], : rm.shape[1]] = rm
83 |
84 | return (
85 | {
86 | "vis_box_feat": box_feats_output,
87 | "vis_box_mask": masks_output,
88 | "vis_rel_mask": rel_masks_output,
89 | "vis_available": masks_output[:, 0],
90 | },
91 | {"vis_box": boxes_output},
92 | )
93 |
94 | def build_gold_mask(self, inst, sample_id):
95 | sg_inst = self.sg_data[inst["img_id"]]
96 | if len(sg_inst["obj"]) == 0:
97 | return torch.zeros(0, dtype=torch.bool), torch.zeros(0, 0, dtype=torch.bool)
98 | mask = torch.ones(min(len(sample_id), len(sg_inst["obj"])), dtype=torch.bool)
99 | rel_mask = torch.zeros(
100 | len(sg_inst["obj"]), len(sg_inst["obj"]), dtype=torch.bool
101 | )
102 | for item in sg_inst["rel"]:
103 | rel_mask[item["subj"], item["obj"]] = 1
104 | sample_id = torch.from_numpy(sample_id)
105 | rel_mask = rel_mask.gather(
106 | 1, sample_id.unsqueeze(0).expand(rel_mask.shape[1], -1)
107 | ).gather(0, sample_id.unsqueeze(-1).expand(-1, len(sample_id)))
108 | return mask, rel_mask
109 |
110 |
111 | class VLParseDataModule(DepDataModule):
112 | TARGETS = ("arc", "sg_type", "sg_box", "sg_mask")
113 | # train: text(.conll), proposed box(det_feats/.npy), img(.npy)
114 | # dev: text(.conll), proposed box(det_feats/.npy), img(.npy), scene graph(../.json)
115 | # test: text(.conll), proposed box(det_feats/.npy), scene graph(../.json)
116 |
117 | def __init__(self, use_img, use_gold_scene_graph, sg_path, **kwargs):
118 |
119 | self.use_img = use_img # use native image feature
120 | if self.use_img:
121 | self.INPUTS = self.INPUTS + ("vis_img",)
122 | self.use_gold_scene_graph = use_gold_scene_graph # return gold box and rels
123 |
124 | with open(sg_path) as f: # load scene graph
125 | sg_data = json.load(f)
126 | self.sg_data = {inst["coco_id"]: inst for inst in sg_data}
127 |
128 | if use_gold_scene_graph:
129 | with open(os.path.split(sg_path)[0] + "/vlparse_train_sg_raw.json") as f:
130 | sg_data = json.load(f)
131 | self.sg_data |= {inst["coco_id"]: inst for inst in sg_data}
132 |
133 | super().__init__(**kwargs)
134 |
135 | def _load(self, path, name) -> DataSet:
136 | # text: xxx.conll, a conllu format file
137 | # img: xxx.npy, each item is a prefetched feature. [n_img x hidden_size]
138 | # det_feats/.npy, box feat for each img shape: 100 x (1024+4)
139 | # id_list/xxx.txt, each line is an img_id and sent_id pair. Sentences with the same img_id are assumed to be stored together.
140 | ds: DataSet = super()._load(path + ".conll", name)
141 |
142 | # load ids
143 | folder, filename = os.path.split(path)
144 | with open(Path(folder) / "id_list" / (filename + ".txt")) as f:
145 | img_id = [int(line.strip()) for line in f]
146 | if len(img_id) != len(ds):
147 | img_id = [id_ for id_ in img_id for _ in range(5)]
148 | ds.add_field("img_id", img_id)
149 | ds.add_field("img_sent_id", [i % 5 for i, _ in enumerate(img_id)])
150 |
151 | # native image feature
152 | with self.tolerant_exception(["test"], name):
153 | if self.use_img:
154 | img_feat = np.load(path + ".npy").repeat(5, 0)
155 | ds.add_field("vis_img", img_feat, is_input=True)
156 |
157 | # prepare target, (and input if gold_sg) from sg data
158 | ds.apply_more(self.process_sg)
159 |
160 | ds.add_collate_fn(
161 | _COCODetFeatLazyLoader(
162 | Path(folder)
163 | / ("gold_feats" if self.use_gold_scene_graph else "det_feats"),
164 | self.sg_data,
165 | 35 if name in ("train", "train_init") else 0,
166 | self.use_gold_scene_graph,
167 | ),
168 | "det_feat_loader",
169 | )
170 | if name in ("dev", "test") or self.use_gold_scene_graph:
171 | ds.drop(lambda x: not x["has_sg"])
172 | return ds
173 |
174 | def process_sg(self, inst):
175 | if inst["img_id"] not in self.sg_data:
176 | txt2sg = {}
177 | rels = []
178 | else:
179 | sg = self.sg_data[inst["img_id"]]
180 | rels = sg["rel"]
181 | txt2sg = sg["txt2sg"][inst["img_sent_id"]]
182 | id2node = {node["id"]: node for node in chain(sg["obj"], sg["rel"])}
183 | typestr2id = {"OBJ": 1, "ATTR": 2, "REL": 3}
184 | gold_box, tok_type = [], []
185 |
186 | # only collect the grounded box for each word here
187 | for i in range(len(inst["raw_word"])):
188 | if (i := str(i)) in txt2sg:
189 | alignment = txt2sg[i]
190 | tok_type.append(typestr2id[alignment["type"]])
191 | if tok_type[-1] == 3:
192 | node = id2node[alignment["preferred"]]
193 | subj, obj = id2node[node["subj"]], id2node[node["obj"]]
194 | gold_box.append(get_box(subj) + get_box(obj))
195 | else:
196 | gold_box.append(
197 | get_box(id2node[alignment["preferred"]]) + [0.0] * 4
198 | )
199 | else:
200 | tok_type.append(0)
201 | gold_box.append([0.0] * 8)
202 |
203 | sg_rel = [[item["subj"], item["obj"]] for item in rels]
204 | return {
205 | "sg_type": tok_type,
206 | "sg_box": gold_box,
207 | "vis_rel": sg_rel, # this is for inputs. When eval we just need sg_box.
208 | "sg_mask": [t != 0 for t in tok_type],
209 | "has_sg": inst["img_id"] in self.sg_data,
210 | }
211 |
--------------------------------------------------------------------------------
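A minimal check of the get_box helper defined at the top of this file, assuming a toy scene-graph object dict with the same keys the file reads:

obj = {"x": 10, "y": 20, "width": 30, "height": 40}
assert get_box(obj) == [10, 20, 40, 60]  # [x1, y1, x2, y2] corner format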
/src/datamodule/vocabulary.py:
--------------------------------------------------------------------------------
1 | from fastNLP import Vocabulary as _fastNLP_Vocabulary
2 | from fastNLP.core.vocabulary import _check_build_vocab
3 |
4 |
5 | class Vocabulary(_fastNLP_Vocabulary):
6 | @_check_build_vocab
7 | def __getitem__(self, w: str):
8 | if w.endswith("::"):
9 | w = [w[:-2], ":"]
10 | else:
11 | w = w.rsplit(":", 1)
12 | w[0] = w[0].lower()
13 | if (_w := ":".join(w)) in self._word2idx:
14 | return self._word2idx[_w]
15 | if (_w := ":" + w[1]) in self._word2idx:
16 | return self._word2idx[_w]
17 | # no need to check
18 | raise ValueError("word `{}` not in vocabulary".format(":".join(w)))
19 |
--------------------------------------------------------------------------------
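A standalone sketch of the key normalization performed by Vocabulary.__getitem__ above (normalize_key is a hypothetical helper written only for illustration; the real class then tries both keys against self._word2idx, primary first):

def normalize_key(w: str):
    if w.endswith("::"):
        parts = [w[:-2], ":"]
    else:
        parts = w.rsplit(":", 1)
    parts[0] = parts[0].lower()
    return ":".join(parts), ":" + parts[1]  # (primary key, tag-only back-off key)

assert normalize_key("Dog:NOUN") == ("dog:NOUN", ":NOUN")
assert normalize_key("is::") == ("is::", "::")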
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import ModelBase, JointModelBase
2 | from .dmv import DMV
3 | from .ldndmv import DiscriminativeNDMV
4 | from .joint import DependencyBoxRel
--------------------------------------------------------------------------------
/src/model/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import functools
4 | from collections import defaultdict
5 | from io import IOBase
6 | from typing import Any, Dict, List, Tuple
7 |
8 | import torch.nn as nn
9 | from fastNLP import DataSet, Vocabulary
10 | from hydra.utils import instantiate
11 | from torch import Tensor
12 |
13 | import src
14 | from src.datamodule import DataModule
15 | from src.model.embedding import Embedding
16 | from src.model.text_encoder import EncoderBase
17 | from src.utility.defaultlist import defaultlist
18 | from src.utility.fn import get_coeff_iter
19 | from src.utility.logger import get_logger_func
20 | from src.utility.var_pool import VarPool
21 | from abc import ABC
22 | from typing import Dict, Any, Type, Tuple
23 |
24 | from src.utility.config import Config
25 | from hydra.utils import instantiate
26 | from omegaconf import open_dict, OmegaConf
27 | from torch import Tensor
28 |
29 |
30 | from src.model.vis_encoder.base import VisEncoderBase
31 |
32 | InputDict = Dict[str, Tensor]
33 | TensorDict = Dict[str, Tensor]
34 | AnyDict = Dict[str, Any]
35 |
36 | _warn, _info, _debug = get_logger_func("model")
37 |
38 |
39 | class ModelBase(nn.Module):
40 | datamodule: DataModule
41 | embedding: Embedding
42 | encoder: EncoderBase
43 | _function_group = {}
44 |
45 | def __init__(self):
46 | super(ModelBase, self).__init__()
47 | self._dynamic_cfg = {}
48 |
49 | def setup(self, dm: DataModule):
50 | self.datamodule = dm
51 | self.embedding = Embedding(**src.g_cfg.embedding, dm=dm)
52 | self.encoder = instantiate(src.g_cfg.encoder, embedding=self.embedding)
53 | self.embedding.__dict__["bounded_model"] = self
54 | self.encoder.__dict__["bounded_model"] = self
55 |
56 | def forward(
57 | self, inputs: InputDict, vp: VarPool, embed=None, encoded=None, return_all=False
58 | ):
59 | dyn_cfg = self.apply_dynamic_cfg()
60 | src.trainer.lightning_module.log_dict(dyn_cfg)
61 | if embed is None:
62 | embed = self.embedding(inputs, vp)
63 | if encoded is None or encoded["__need_encode"]:
64 | if encoded is None:
65 | encoded = {}
66 | else:
67 | del encoded["__need_encode"]
68 | encoded |= self.encoder(embed, vp)
69 | encoded["emb"] = embed
70 | score = self._forward(inputs, encoded, vp)
71 | if return_all:
72 | return embed, encoded, score
73 | return score
74 |
75 | def _forward(self, inputs: InputDict, encoded: TensorDict, vp: VarPool):
76 | raise NotImplementedError
77 |
78 | def loss(
79 | self, x: TensorDict, gold: InputDict, vp: VarPool
80 | ) -> Tuple[Tensor, TensorDict]:
81 | raise NotImplementedError
82 |
83 | def decode(self, x: TensorDict, vp: VarPool) -> AnyDict:
84 | raise NotImplementedError
85 |
86 | def normalize_embedding(self, now):
87 | self.embedding.normalize(now)
88 |
89 | def preprocess_write(self, output: List[Dict[str, Any]]):
90 | batch_size = len(output[0]["id"]) # check one batch
91 | safe_to_sort = all(
92 | (len(p) == batch_size) for p in output[0]["predict"].values()
93 | )
94 |
95 | if safe_to_sort:
96 | # All predictions are returned in the order of idx, but you have to remove padding yourself.
97 | sorted_predicts = defaultdict(defaultlist)
98 | for batch in output:
99 | id_, predict = batch["id"], batch["predict"]
100 | for key, value in predict.items():
101 | if isinstance(value, Tensor):
102 | value = value.detach().cpu().numpy()
103 | for one_id, one_value in zip(id_, value):
104 | sorted_predicts[key][one_id] = one_value
105 | return sorted_predicts
106 | else:
107 | raise NotImplementedError("Can not preprocess automatically.")
108 |
109 | def write_prediction(
110 | self, s: IOBase, predicts, dataset: DataSet, vocabs: Dict[str, Vocabulary]
111 | ) -> IOBase:
112 | raise NotImplementedError
113 |
114 | # noinspection PyMethodMayBeStatic
115 | def set_varpool(self, vp: VarPool) -> VarPool:
116 | return vp
117 |
118 | @classmethod
119 | def add_impl_to_group(cls, group, spec, pre_hook=None):
120 | def decorator(func):
121 | if group not in cls._function_group:
122 | cls._function_group[group] = {}
123 | assert spec not in cls._function_group[group], spec
124 | cls._function_group[group][spec] = (func, pre_hook)
125 |
126 | @functools.wraps(func)
127 | def wrapper(*args, **kwargs):
128 | return func(*args, **kwargs)
129 |
130 | return wrapper
131 |
132 | return decorator
133 |
134 | def set_impl_in_group(self, group, spec):
135 | try:
136 | impl, pre_hook = self._function_group[group][spec]
137 | except Exception as e:
138 | _warn(f"Failed to load {group}: {spec}")
139 | raise e
140 | if pre_hook is not None:
141 | getattr(self, pre_hook)()
142 | setattr(self, group, functools.partial(impl, self))
143 |
144 | def add_dynamic_cfg(self, name, command):
145 | """name: |"""
146 | if name in self._dynamic_cfg:
147 | _warn(f"Overwriting {name} with {command}")
148 | self._dynamic_cfg[name] = get_coeff_iter(
149 | command, idx_getter=lambda: src.trainer.current_epoch
150 | )
151 |
152 | def apply_dynamic_cfg(self):
153 | params = {key: next(value) for key, value in self._dynamic_cfg.items()}
154 | for key, value in params.items():
155 | obj_nev, cfg_nev = key.split("|")
156 | o = self
157 | for attr_name in obj_nev.split("."):
158 | o = getattr(o, attr_name)
159 | s = o
160 | cfg_nev = cfg_nev.split(".")
161 | for k in cfg_nev[:-1]:
162 | s = s[k]
163 | s[cfg_nev[-1]] = value
164 | return params
165 |
166 | def process_checkpoint(self, ckpt):
167 | return ckpt
168 |
169 |
170 | class JointModelBase(ModelBase, ABC):
171 | # assume only one datamodule
172 | # assume image does not require embedding
173 | # assume all visual-side modules/parameters are named with a 'vis_' prefix.
174 |
175 | # I prefer not to separate the joint model into a language-side model and a visual-side model
176 | # because it is hard to foresee possible interactions between the two sides, and
177 | # for now the visual-side model is very simple.
178 |
179 | # language part, inherit from ModelBase
180 | # datamodule: DataModule
181 | # embedding: Embedding
182 | # encoder: EncoderBase
183 |
184 | # visual part
185 | vis_encoder: VisEncoderBase
186 |
187 | def setup(self, dm: DataModule):
188 | if getattr(self, "__setup_handled", False) is not True:
189 | _warn("setup() was called directly. Consider using _setup() instead.")
190 | self.datamodule = dm
191 | # self.embedding = Embedding(**src.g_cfg.embedding, dm=dm)
192 | # self.embedding.__dict__['bounded_model'] = self
193 | self.encoder = instantiate(src.g_cfg.encoder, embedding=self.embedding)
194 | self.encoder.__dict__["bounded_model"] = self
195 | self.vis_encoder = instantiate(src.g_cfg.vis_encoder)
196 | if self.vis_encoder is None:
197 | _warn("vis_encoder is disabled.")
198 | else:
199 | self.vis_encoder.__dict__["bounded_model"] = self
200 |
201 | def _setup(self, dm: DataModule, cfg_class: Type[Config], allow_missing=None):
202 | setattr(self, "__setup_handled", True)
203 | self.cfg = cfg = cfg_class.build(self.cfg, allow_missing=allow_missing)
204 | with open_dict(cfg.dep_model_cfg):
205 | cfg.dep_model_cfg = OmegaConf.merge(cfg.dep_model_cfg, dm.get_vocab_count())
206 | self.dependency = instantiate(cfg.dep_model_cfg)
207 | self.dependency.setup(dm)
208 | JointModelBase.setup(self, dm)
209 | return cfg
210 |
211 | @property
212 | def embedding(self):
213 | return self.dependency.embedding
214 |
215 | def forward(
216 | self,
217 | inputs: InputDict,
218 | vp: VarPool,
219 | embed=None,
220 | encoded=None,
221 | vis_encoded=None,
222 | return_all=False,
223 | ):
224 | if vis_encoded is None:
225 | vis_input = {
226 | key: value for key, value in inputs.items() if key.startswith("vis_")
227 | }
228 | if len(vis_input) > 0:
229 | vis_encoded = self.vis_encoder(vis_input, vp)
230 | else:
231 | vis_encoded = {}
232 | encoded = encoded if encoded is not None else {"__need_encode": True}
233 | for key, value in vis_encoded.items():
234 | encoded[f"vis_{key}"] = value
235 | embed, encoded, score = super().forward(inputs, vp, embed, encoded, True)
236 | vis_score = self._vis_forward(inputs, vis_encoded, encoded, score, vp)
237 | score = {**score, **vis_score}
238 | if return_all:
239 | return embed, encoded, score
240 | else:
241 | return score
242 |
243 | def _forward(self, inputs: InputDict, encoded: TensorDict, vp: VarPool):
244 | return self.dependency._forward(inputs, encoded, vp)
245 |
246 | def _vis_forward(
247 | self,
248 | inputs: InputDict,
249 | encoded: TensorDict,
250 | language_encoded: TensorDict,
251 | lang_score: TensorDict,
252 | vp: VarPool,
253 | ):
254 | raise NotImplementedError
255 |
256 |
--------------------------------------------------------------------------------
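A toy illustration of how apply_dynamic_cfg (defined in base.py above) resolves a key of the form '<attr.path>|<cfg.key.path>': the part before '|' is walked as attributes, the part after as nested dict-like config keys. The classes and the set_by_path helper below are hypothetical, written only to mirror that loop:

class _Sub:
    def __init__(self):
        self.cfg = {"dropout": {"p": 0.5}}

class _Toy:
    def __init__(self):
        self.encoder = _Sub()

    def set_by_path(self, key, value):  # mirrors the loop body of apply_dynamic_cfg
        obj_path, cfg_path = key.split("|")
        o = self
        for attr_name in obj_path.split("."):
            o = getattr(o, attr_name)
        keys = cfg_path.split(".")
        s = o
        for k in keys[:-1]:
            s = s[k]
        s[keys[-1]] = value

toy = _Toy()
toy.set_by_path("encoder.cfg|dropout.p", 0.1)
assert toy.encoder.cfg["dropout"]["p"] == 0.1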
/src/model/dmv.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from io import IOBase
3 | from typing import Any, Dict, Tuple, Optional
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from fastNLP import DataSet, Vocabulary
9 | from hydra.conf import MISSING
10 | from torch import Tensor
11 | from torch.optim import Optimizer
12 |
13 | from src.datamodule.task import DepDataModule
14 | from src.model import ModelBase
15 | from src.model.dmv_helper import km_init, good_init
16 | from src.model.torch_struct import DMV1o, DependencyCRF
17 | from src.utility.config import Config
18 | from src.utility.logger import get_logger_func
19 | from src.utility.var_pool import VarPool
20 |
21 | InputDict = Dict[str, Tensor]
22 | TensorDict = Dict[str, Tensor]
23 | AnyDict = Dict[str, Any]
24 |
25 | _warn, _info, _debug = get_logger_func('model')
26 |
27 |
28 | @dataclass
29 | class DMVConfig(Config):
30 | viterbi_training: bool
31 | mbr_decoding: bool
32 | init_method: str # km, good, random
33 | smooth: float
34 |
35 | # ============================= AUTO FIELDS =============================
36 | n_word: int = MISSING
37 | n_tag: int = MISSING
38 | n_token: int = MISSING
39 |
40 |
41 | class DMV(ModelBase):
42 | _instance = None # workaround for DMVMStepOptimizer
43 |
44 | def __init__(self, **cfg):
45 | super().__init__()
46 | # noinspection PyTypeChecker
47 | self.cfg: DMVConfig = cfg
48 | self.root_param: Optional[nn.Parameter] = None
49 | self.trans_param: Optional[nn.Parameter] = None
50 | self.dec_param: Optional[nn.Parameter] = None
51 | self.optimizer: Optional[DMVMStepOptimizer] = None
52 |
53 | if DMV._instance is not None:
54 | _warn('overwriting DMV._instance')
55 | DMV._instance = self
56 |
57 | def setup(self, dm: DepDataModule):
58 | self.datamodule = dm
59 | self.cfg = cfg = DMVConfig.build(self.cfg, allow_missing={'n_word', 'n_tag'})
60 |
61 | if cfg.init_method == 'km':
62 | d, t, r = km_init(dm.datasets['train'], cfg.n_token, cfg.smooth)
63 | elif cfg.init_method == 'good':
64 | d, t, r = good_init(dm.datasets['train'], cfg.n_token, cfg.smooth)
65 | else:
66 | d = np.random.randn(cfg.n_token, 2, 2, 2)
67 | r = np.random.randn(cfg.n_token)
68 | t = np.random.randn(cfg.n_token, cfg.n_token, 2, 2)
69 |
70 | self.root_param = nn.Parameter(torch.from_numpy(r))
71 | # head, child, dir, valence
72 | self.trans_param = nn.Parameter(torch.from_numpy(t))
73 | # head, dir, valence, decision
74 | self.dec_param = nn.Parameter(torch.from_numpy(d))
75 |
76 | def forward(self, inputs: InputDict, vp: VarPool, embed=None, encoded=None, return_all=False):
77 | assert embed is None
78 | assert encoded is None
79 | assert not return_all
80 | return self._forward(inputs, {}, vp)
81 |
82 | def _forward(self, inputs: InputDict, encoded: TensorDict, vp: VarPool):
83 | b, l, n = vp.batch_size, vp.max_len, self.cfg.n_token
84 | token_array = inputs['token']
85 |
86 | t = self.trans_param.unsqueeze(0).expand(b, n, n, 2, 2)
87 | head_token_index = token_array.view(b, l, 1, 1, 1).expand(b, l, n, 2, 2)
88 | child_token_index = token_array.view(b, 1, l, 1, 1).expand(b, l, l, 2, 2)
89 | t = torch.gather(torch.gather(t, 1, head_token_index), 2, child_token_index)
90 | index = torch.triu(torch.ones(l, l, dtype=torch.long, device=t.device)) \
91 | .view(1, l, l, 1, 1).expand(b, l, l, 1, 2)
92 | t = torch.gather(t, 3, index).squeeze(3)
93 |
94 | d = self.dec_param.unsqueeze(0).expand(b, n, 2, 2, 2)
95 | head_pos_index = token_array.view(b, l, 1, 1, 1).expand(b, l, 2, 2, 2)
96 | d = torch.gather(d, 1, head_pos_index)
97 |
98 | r = self.root_param.unsqueeze(0).expand(b, n)
99 | r = torch.gather(r, 1, token_array)
100 |
101 | merged_d, merged_t = DMV1o.merge(d, t, r)
102 | return {'merged_dec': merged_d, 'merged_attach': merged_t}
103 |
104 | def loss(self, x: TensorDict, gold: InputDict, vp: VarPool) -> Tuple[Tensor, TensorDict]:
105 | dist = DMV1o([x['merged_dec'], x['merged_attach']], vp.seq_len)
106 | if self.cfg.viterbi_training:
107 | ll = dist.max.sum()
108 | else:
109 | ll = dist.partition.sum()
110 | return -ll, {'ll': ll}
111 |
112 | # noinspection DuplicatedCode
113 | @torch.enable_grad()
114 | def decode(self, x: TensorDict, vp: VarPool) -> AnyDict:
115 | if self.optimizer:
116 | self.optimizer.apply()
117 | mdec = x['merged_dec'].detach().requires_grad_()
118 | mattach = x['merged_attach'].detach().requires_grad_()
119 | dist = DMV1o([mdec, mattach], vp.seq_len)
120 | if self.cfg.mbr_decoding:
121 | arc = torch.autograd.grad(dist.partition.sum(), mattach)[0].sum(-1)
122 | dist = DependencyCRF(arc, vp.seq_len)
123 | arc = dist.argmax.nonzero()
124 | predicted = vp.seq_len.new_zeros(vp.batch_size, vp.max_len)
125 | predicted[arc[:, 0], arc[:, 2] - 1] = arc[:, 1]
126 | else:
127 | arc = dist.argmax.sum(-1).nonzero()
128 | predicted = vp.seq_len.new_zeros(vp.batch_size, vp.max_len)
129 | predicted[arc[:, 0], arc[:, 2] - 1] = arc[:, 1]
130 | return {'arc': predicted}
131 |
132 | def normalize_embedding(self, now):
133 | pass
134 |
135 | # noinspection DuplicatedCode
136 | def write_prediction(self, s: IOBase, predicts, dataset: DataSet, vocabs: Dict[str, Vocabulary]) -> IOBase:
137 | for i, length in enumerate(dataset['seq_len'].content):
138 | word, arc = dataset[i]['raw_word'], predicts['arc'][i]
139 | for line_id, (word, arc) in enumerate(zip(word, arc), start=1):
140 | line = '\t'.join([str(line_id), word, '-', str(arc)])
141 | s.write(f'{line}\n')
142 | s.write('\n')
143 | return s
144 |
145 |
146 | class DMVMStepOptimizer(Optimizer):
147 | def __init__(self, params, smooth: float):
148 | self.dmv = DMV._instance
149 | self.dmv.optimizer = self
150 |
151 | self._root, self._dec, self._trans = None, None, None
152 | self.smooth = smooth
153 | self.can_apply = False
154 | super().__init__(self.dmv.parameters(), {})
155 |
156 | def step(self, closure=None):
157 | loss = None
158 | if closure is not None:
159 | with torch.enable_grad():
160 | loss = closure()
161 |
162 | if self._root is None:
163 | self._root = torch.zeros_like(self.dmv.root_param)
164 | self._dec = torch.zeros_like(self.dmv.dec_param)
165 | self._trans = torch.zeros_like(self.dmv.trans_param)
166 |
167 | self._root -= self.dmv.root_param.grad
168 | self._dec -= self.dmv.dec_param.grad
169 | self._trans -= self.dmv.trans_param.grad
170 | self.can_apply = True
171 |
172 | def apply(self):
173 | if self.can_apply:
174 | self.dmv.root_param.data, self._root = \
175 | torch.log(self._root + self.smooth).log_softmax(0), self.dmv.root_param.data
176 | self.dmv.dec_param.data, self._dec = \
177 | torch.log(self._dec + self.smooth).log_softmax(3), self.dmv.dec_param.data
178 | self.dmv.trans_param.data, self._trans = \
179 | torch.log(self._trans + self.smooth).log_softmax(1), self.dmv.trans_param.data
180 | self.reset()
181 |
182 | def reset(self):
183 | self._root.zero_()
184 | self._dec.zero_()
185 | self._trans.zero_()
186 | self.can_apply = False
187 |
--------------------------------------------------------------------------------
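A small numpy sketch of the count-to-parameter normalization shared by DMVMStepOptimizer.apply() above and the good_init initializer below: counts are smoothed, then normalized in log space along the axis that indexes the outcome of each multinomial (toy counts; the axis layout follows the comments in DMV.setup):

import numpy as np

n_token, smooth = 4, 1.0
root_counts = np.random.rand(n_token)                  # child token chosen as root
trans_counts = np.random.rand(n_token, n_token, 2, 2)  # head, child, direction, valence
dec_counts = np.random.rand(n_token, 2, 2, 2)          # head, direction, valence, GO/STOP

root_param = np.log((root_counts + smooth) / (root_counts + smooth).sum())
trans_param = np.log((trans_counts + smooth) / (trans_counts + smooth).sum(1, keepdims=True))
dec_param = np.log((dec_counts + smooth) / (dec_counts + smooth).sum(3, keepdims=True))
# every slice along the normalized axis is now a log-probability distribution,
# e.g. np.exp(dec_param).sum(3) is all ones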
/src/model/dmv_helper/__init__.py:
--------------------------------------------------------------------------------
1 | from src.model.dmv_helper.good_init import good_init
2 | from src.model.dmv_helper.good_init_nn import generate_rule_1o, LinearPadder, SquarePadder
3 | from src.model.dmv_helper.km_init import km_init
4 |
--------------------------------------------------------------------------------
/src/model/dmv_helper/good_init.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from fastNLP import DataSet, AutoPadder
3 |
4 |
5 | from src.model.torch_struct.dmv import HASCHILD, NOCHILD, STOP, GO
6 |
7 |
8 | def recovery_one(heads):
9 | left_most = np.arange(len(heads))
10 | right_most = np.arange(len(heads))
11 | for idx, each_head in enumerate(heads):
12 | if each_head in (0, len(heads) + 1): # skip head is ROOT
13 | continue
14 | each_head -= 1
15 | if idx < left_most[each_head]:
16 | left_most[each_head] = idx
17 | if idx > right_most[each_head]:
18 | right_most[each_head] = idx
19 |
20 | valences = np.empty((len(heads), 2), dtype=np.int64)
21 | head_valences = np.empty(len(heads), dtype=np.int64)
22 |
23 | for idx, each_head in enumerate(heads):
24 | each_head -= 1
25 | valences[idx, 0] = NOCHILD if left_most[idx] == idx else HASCHILD
26 | valences[idx, 1] = NOCHILD if right_most[idx] == idx else HASCHILD
27 | if each_head > idx: # when each_head == -1, its head_valence is never used
28 | head_valences[idx] = NOCHILD if left_most[each_head] == idx else HASCHILD
29 | else:
30 | head_valences[idx] = NOCHILD if right_most[each_head] == idx else HASCHILD
31 | return valences, head_valences
32 |
33 |
34 | def good_init(dataset: DataSet, n_token: int, smooth: float):
35 | """process all sentences in one batch."""
36 | max_len = max(dataset['seq_len'].content)
37 | heads = np.zeros((len(dataset), max_len + 1), dtype=np.int64)
38 | valences = np.zeros((len(dataset), max_len + 1, 2), dtype=np.int64)
39 | head_valences = np.zeros((len(dataset), max_len + 1), dtype=np.int64)
40 | root_counter = np.zeros((n_token,))
41 |
42 | for idx, instance in enumerate(dataset):
43 | one_heads = np.asarray(instance['arc'])
44 | one_valences, one_head_valences = recovery_one(one_heads)
45 | heads[idx, 1:instance['seq_len'] + 1] = one_heads
46 | valences[idx, 1:instance['seq_len'] + 1] = one_valences
47 | head_valences[idx, 1:instance['seq_len'] + 1] = one_head_valences
48 |
49 | batch_size, sentence_len = heads.shape
50 | len_array = np.asarray(dataset['seq_len'].content)
51 | token_array = AutoPadder()(dataset['token'].content, 'token', np.int64, 1)
52 | batch_arange = np.arange(batch_size)
53 |
54 | batch_trans_trace = np.zeros((batch_size, max_len, max_len, 2, 2))
55 | batch_dec_trace = np.zeros((batch_size, max_len, max_len, 2, 2, 2))
56 |
57 | for m in range(1, sentence_len):
58 | h = heads[:, m]
59 | direction = (h <= m).astype(np.int64)
60 | h_valence = head_valences[:, m]
61 | m_valence = valences[:, m]
62 | m_child_valence = h_valence
63 |
64 | len_mask = ((h <= len_array) & (m <= len_array))
65 |
66 | batch_dec_trace[batch_arange, m - 1, m - 1, 0, m_valence[:, 0], STOP] = len_mask
67 | batch_dec_trace[batch_arange, m - 1, m - 1, 1, m_valence[:, 1], STOP] = len_mask
68 |
69 | head_mask = h == 0
70 | mask = head_mask * len_mask
71 | if mask.any():
72 | np.add.at(root_counter, token_array[:, m - 1], mask)
73 |
74 | head_mask = ~head_mask
75 | mask = head_mask * len_mask
76 | if mask.any():
77 | batch_trans_trace[batch_arange, h - 1, m - 1, direction, m_child_valence] = mask
78 | batch_dec_trace[batch_arange, h - 1, m - 1, direction, h_valence, GO] = mask
79 |
80 | dec_post_dim = (2, 2, 2)
81 | dec_counter = np.zeros((n_token, *dec_post_dim))
82 | index = (token_array.flatten(),)
83 | np.add.at(dec_counter, index, np.sum(batch_dec_trace, 2).reshape(-1, *dec_post_dim))
84 |
85 | trans_post_dim = (2, 2)
86 | head_ids = np.tile(np.expand_dims(token_array, 2), (1, 1, max_len))
87 | child_ids = np.tile(np.expand_dims(token_array, 1), (1, max_len, 1))
88 | trans_counter = np.zeros((n_token, n_token, *trans_post_dim))
89 | index = (head_ids.flatten(), child_ids.flatten())
90 | np.add.at(trans_counter, index, batch_trans_trace.reshape(-1, *trans_post_dim))
91 |
92 | root_counter += smooth
93 | root_sum = root_counter.sum()
94 | root_param = np.log(root_counter / root_sum)
95 |
96 | trans_counter += smooth
97 | trans_sum = trans_counter.sum(axis=1, keepdims=True)
98 | trans_param = np.log(trans_counter / trans_sum)
99 |
100 | dec_counter += smooth
101 | dec_sum = dec_counter.sum(axis=3, keepdims=True)
102 | dec_param = np.log(dec_counter / dec_sum)
103 | return dec_param, trans_param, root_param
104 |
--------------------------------------------------------------------------------
/src/model/dmv_helper/good_init_nn.py:
--------------------------------------------------------------------------------
1 | # unlike good_init.py, this file contains helpers to initialize nn without dmv.
2 |
3 | from typing import List
4 |
5 | import numpy as np
6 | from fastNLP.core.field import Padder
7 |
8 | from src.model.torch_struct.dmv import LEFT, RIGHT, HASCHILD, NOCHILD, GO, STOP
9 |
10 |
11 | class LinearPadder(Padder):
12 | def __call__(self, contents, field_name, field_ele_dtype, dim: int):
13 | max_sent_length = max(r.shape[0] for r in contents)
14 | batch_size = len(contents)
15 | out = np.full((batch_size, max_sent_length, *contents[0].shape[1:]), fill_value=self.pad_val, dtype=np.float64)
16 | for b_idx, rule in enumerate(contents):
17 | sent_len = rule.shape[0]
18 | out[b_idx, :sent_len] = rule
19 | return out
20 |
21 |
22 | class SquarePadder(Padder):
23 | def __call__(self, contents, field_name, field_ele_dtype, dim: int):
24 | max_sent_length = max(r.shape[0] for r in contents)
25 | batch_size = len(contents)
26 | out = np.full((batch_size, max_sent_length, max_sent_length, *contents[0].shape[2:]), fill_value=self.pad_val,
27 | dtype=np.float64)
28 | for b_idx, rule in enumerate(contents):
29 | sent_len = rule.shape[0]
30 | out[b_idx, :sent_len, :sent_len] = rule
31 | return out
32 |
33 |
34 | def generate_rule_1o(heads: List[int]):
35 | """
36 | For first-order DMV: generate the grammar rule counts used by a "predicted" parse tree from another parser.
37 | :param heads: the head of each position (1-based, 0 = ROOT)
38 | :return: a dict with decision, attach and root rule counts
39 | """
40 | seq_len = len(heads)
41 | decision = np.zeros(shape=(seq_len, 2, 2, 2))
42 | attach = np.zeros(shape=(seq_len, seq_len, 2))
43 | root = np.zeros(shape=(seq_len,))
44 | root[heads.index(0)] = 1
45 |
46 | left_most_child = list(range(seq_len))
47 | right_most_child = list(range(seq_len))
48 | for child, head in enumerate(heads):
49 | head = head - 1
50 | if head == -1:
51 | continue
52 | elif child < head:
53 | if child < left_most_child[head]:
54 | left_most_child[head] = child
55 | else:
56 | if child > right_most_child[head]:
57 | right_most_child[head] = child
58 |
59 | for child, head in enumerate(heads):
60 | head = head - 1
61 |
62 | if child < head:
63 | most_child, d = left_most_child, LEFT
64 | else:
65 | most_child, d = right_most_child, RIGHT
66 |
67 | valence = NOCHILD if most_child[head] == child else HASCHILD
68 | decision[head][d][valence][GO] += 1
69 | if head != -1:
70 | attach[head][child][valence] += 1
71 |
72 | valence = NOCHILD if left_most_child[child] == child else HASCHILD
73 | decision[child][LEFT][valence][STOP] += 1
74 |
75 | valence = NOCHILD if right_most_child[child] == child else HASCHILD
76 | decision[child][RIGHT][valence][STOP] += 1
77 |
78 | return {'dec_rule': decision, 'attach_rule': attach, 'root_rule': root}
79 |
--------------------------------------------------------------------------------
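A toy call of generate_rule_1o above (heads are 1-based with 0 marking ROOT; here word 2 heads words 1 and 3), mainly to document the returned shapes:

rules = generate_rule_1o([2, 0, 2])
print(rules["dec_rule"].shape)     # (3, 2, 2, 2): position, direction, valence, GO/STOP
print(rules["attach_rule"].shape)  # (3, 3, 2):   head position, child position, valence
print(rules["root_rule"])          # [0. 1. 0.]:  one-hot at the ROOT-attached word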
/src/model/dmv_helper/km_init.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from fastNLP import DataSet, DataSetIter
3 | from numpy import ndarray
4 |
5 | from src.datamodule.sampler import ConstantTokenNumSampler
6 | from src.model.torch_struct.dmv import HASCHILD, NOCHILD, STOP, GO
7 |
8 | harmonic_sum = [0., 1.]
9 |
10 |
11 | def get_harmonic_sum(n: int):
12 | global harmonic_sum
13 | while n >= len(harmonic_sum):
14 | harmonic_sum.append(harmonic_sum[-1] + 1 / len(harmonic_sum))
15 | return harmonic_sum[n]
16 |
17 |
18 | def update_decision(change: ndarray, norm_counter: ndarray, token_array: ndarray, dec_param: ndarray):
19 | for i in range(token_array.shape[1]):
20 | pos = token_array[:, i]
21 | for _direction in (0, 1):
22 | if change[i, _direction] > 0:
23 | np.add.at(norm_counter, (pos, _direction, NOCHILD, GO), 1.)
24 | np.add.at(norm_counter, (pos, _direction, HASCHILD, GO), -1.)
25 | np.add.at(dec_param, (pos, _direction, HASCHILD, GO), change[i, _direction])
26 | np.add.at(norm_counter, (pos, _direction, NOCHILD, STOP), -1.)
27 | np.add.at(norm_counter, (pos, _direction, HASCHILD, STOP), 1.)
28 | np.add.at(dec_param, (pos, _direction, NOCHILD, STOP), 1.)
29 | else:
30 | np.add.at(dec_param, (pos, _direction, NOCHILD, STOP), 1.)
31 |
32 |
33 | def first_child_update(norm_counter: ndarray, dec_param: ndarray):
34 | all_param = dec_param.flatten()
35 | all_norm = norm_counter.flatten()
36 | mask = (all_param <= 0) | (0 <= all_norm)
37 | ratio = -all_param / all_norm
38 | ratio[mask] = 1.
39 | return np.min(ratio)
40 |
41 |
42 | def km_init(dataset: DataSet, n_token: int, smooth: float):
43 | # Do not ask why; I do not know more than you do.
44 | dec_param = np.zeros((n_token, 2, 2, 2))
45 | root_param = np.zeros((n_token,))
46 | trans_param = np.zeros((n_token, n_token, 2, 2))
47 |
48 | norm_counter = np.full(dec_param.shape, smooth)
49 | change = np.zeros((max(dataset['seq_len'].content), 2))
50 | sampler = ConstantTokenNumSampler(dataset['seq_len'].content, 1000000, -1, 0, force_same_len=True)
51 | data_iter = DataSetIter(dataset, batch_sampler=sampler, as_numpy=True)
52 | for x, y in data_iter:
53 | token_array = x['token']
54 | batch_size, word_num = token_array.shape
55 | change.fill(0.)
56 | np.add.at(root_param, (token_array, ), 1. / word_num)
57 | if word_num > 1:
58 | for child_i in range(word_num):
59 | child_sum = get_harmonic_sum(child_i - 0) + get_harmonic_sum(word_num - child_i - 1)
60 | scale = (word_num - 1) / word_num / child_sum
61 | for head_i in range(word_num):
62 | if child_i == head_i:
63 | continue
64 | direction = 1 if head_i <= child_i else 0
65 | head_pos = token_array[:, head_i]
66 | child_pos = token_array[:, child_i]
67 | diff = scale / abs(head_i - child_i)
68 | np.add.at(trans_param, (head_pos, child_pos, direction), diff)
69 | change[head_i, direction] += diff
70 | update_decision(change, norm_counter, token_array, dec_param)
71 |
72 | trans_param += smooth
73 | dec_param += smooth
74 | root_param += smooth
75 |
76 | es = first_child_update(norm_counter, dec_param)
77 | norm_counter *= 0.9 * es
78 | dec_param += norm_counter
79 |
80 | root_param_sum = root_param.sum()
81 | trans_param_sum = trans_param.sum(1, keepdims=True)
82 | decision_param_sum = dec_param.sum(3, keepdims=True)
83 |
84 | root_param /= root_param_sum
85 | trans_param /= trans_param_sum
86 | dec_param /= decision_param_sum
87 |
88 | return np.log(dec_param), np.log(trans_param), np.log(root_param)
89 |
--------------------------------------------------------------------------------
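A short note on the cached helper above: get_harmonic_sum(n) returns the partial harmonic sum H(n) = 1 + 1/2 + ... + 1/n (with H(0) = 0), which km_init uses to spread attachment mass over candidate heads in proportion to 1/distance.

print(get_harmonic_sum(1), get_harmonic_sum(3))  # 1.0, 1.8333...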
/src/model/embedding/__init__.py:
--------------------------------------------------------------------------------
1 | from .embedding import EmbeddingAdaptor, Embedding
2 | from .fastnlp_embedding import FastNLPEmbeddingAdaptor, FastNLPCharEmbeddingAdaptor, FastNLPEmbeddingVariationalAdaptor
3 | from .transformers_embedding import TransformersAdaptor, TransformersEmbedding
4 |
--------------------------------------------------------------------------------
/src/model/embedding/embedding.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from typing import TYPE_CHECKING, Any, Dict, List
5 |
6 | import torch
7 | import torch.nn as nn
8 | from hydra.utils import instantiate
9 | from torch import Tensor
10 |
11 | from src.model.nn import IndependentDropout
12 | from src.utility.config import Config
13 | from src.utility.logger import get_logger_func
14 |
15 | if TYPE_CHECKING:
16 | from src.model import ModelBase
17 | from src.datamodule import DataModule
18 | from src.utility.var_pool import VarPool
19 |
20 | AnyDict = Dict[str, Any]
21 |
22 | _warn, _info, _debug = get_logger_func('embedding')
23 |
24 |
25 | @dataclass
26 | class EmbeddingItem:
27 | name: str
28 | field: str
29 | emb: EmbeddingAdaptor
30 |
31 |
32 | @dataclass
33 | class EmbeddingConfig(Config):
34 | use_word: bool
35 | use_tag: bool
36 | use_subword: bool # I believe we need only one subword field.
37 | dropout: float # with multiple embeddings, for each position, drop some of them entirely.
38 | # all other items are treated as EmbeddingItemConfig
39 |
40 |
41 | @dataclass
42 | class EmbeddingItemConfig(Config):
43 | args: AnyDict
44 | adaptor_args: AnyDict
45 | field: str
46 | requires_vocab: bool = True # pass vocab to embedding
47 | normalize_word: bool = False # pass the normalize func (used by the datamodule) to the embedding
48 | normalize_method: str = 'mean+std' # mean+std, mean, std, none
49 | normalize_time: str = 'nowhere' # when to normalize the embedding: nowhere, begin, epoch, batch
50 |
51 |
52 | class Embedding(torch.nn.Module):
53 | """Embedding, plus apply to different fields."""
54 | bounded_model: ModelBase
55 |
56 | def __init__(self, dm: DataModule, **cfg):
57 | super().__init__()
58 | flags, emb_cfg = EmbeddingConfig.build(cfg, ignore_unknown=True)
59 | flags: EmbeddingConfig
60 |
61 | vocabs = dm.vocabs
62 | datasets = dm.datasets
63 |
64 | self.disabled_fields = set()
65 | if not flags.use_word:
66 | self.disabled_fields.add('word')
67 | if not flags.use_subword:
68 | self.disabled_fields.add('subword')
69 | if not flags.use_tag:
70 | self.disabled_fields.add('pos')
71 |
72 | # instantiate embeddings
73 | self.embeds: List[EmbeddingItem] = []
74 | self.normalize_dict = {'nowhere': [], 'begin': [], 'epoch': [], 'batch': []}
75 | for name, cfg in emb_cfg.items():
76 | if name.startswith('_') or cfg is None:
77 | continue
78 | cfg: EmbeddingItemConfig = EmbeddingItemConfig.build(cfg)
79 | if cfg.field in self.disabled_fields:
80 | continue
81 | instantiate_args = {}
82 | if cfg.requires_vocab:
83 | instantiate_args['vocab'] = vocabs[cfg.field]
84 | if cfg.normalize_word:
85 | instantiate_args['word_transform'] = dm.normalize_one_word_func
86 | emb = instantiate(cfg.args, **instantiate_args)
87 | emb = instantiate(cfg.adaptor_args, emb=emb)
88 | emb.process(vocabs, datasets)
89 | self.add_module(name, emb)
90 | self.embeds.append(EmbeddingItem(name, cfg.field, emb))
91 | self.normalize_dict[cfg.normalize_time].append((name, cfg.normalize_method))
92 |
93 | _info(f'Emb: {", ".join(e.name for e in self.embeds)}')
94 | _info(f'Normalize plan: {self.normalize_dict}')
95 | self.embed_size = sum(e.embed_size for e in self)
96 |
97 | if flags.dropout > 0:
98 | self.dropout_func = IndependentDropout(flags.dropout)
99 | else:
100 | self.dropout_func = lambda *x: x
101 |
102 | def forward(self, x, vp: VarPool):
103 | emb = list(self.dropout_func(*[item.emb(x[item.field], vp) for item in self.embeds]))
104 | seq_len = max(e.shape[1] for e in emb)
105 | assert all(e.shape[1] in (1, seq_len) for e in emb)
106 | for item, h in zip(self.embeds, emb):
107 | vp[item.name] = h
108 | for i in range(len(emb)):
109 | if emb[i].shape[1] == 1:
110 | emb[i] = emb[i].expand(-1, seq_len, -1)
111 | # from src.utility.fn import draw_att
112 | # draw_att(torch.cat(emb, dim=-1)[0])
113 | return torch.cat(emb, dim=-1)
114 |
115 | def normalize(self, now):
116 | for name, method in self.normalize_dict[now]:
117 | getattr(self, name).normalize(method)
118 |
119 | def __getitem__(self, key):
120 | return self.embeds[key].emb
121 |
122 | def __iter__(self):
123 | return map(lambda e: e.emb, self.embeds)
124 |
125 | def __len__(self):
126 | return len(self.embeds)
127 |
128 |
129 | class EmbeddingAdaptor(nn.Module):
130 | device_indicator: Tensor
131 | singleton_emb = {}
132 |
133 | def __init__(self, emb):
134 | super().__init__()
135 | self.emb = emb
136 | self.register_buffer('device_indicator', torch.zeros(1))
137 |
138 | self._normalize_warned = False
139 |
140 | @property
141 | def embed_size(self):
142 | raise NotImplementedError
143 |
144 | @property
145 | def device(self):
146 | return self.device_indicator.device
147 |
148 | def process(self, vocabs, datasets):
149 | return
150 |
151 | def forward(self, inputs: List[Any], vp: VarPool):
152 | raise NotImplementedError
153 |
154 | def normalize(self, method: str):
155 | if not self._normalize_warned:
156 | _warn(f"{type(self)} didn't implement normalize.")
157 | self._normalize_warned = True
158 |
159 | @staticmethod
160 | def _normalize(data: Tensor, method: str):
161 | with torch.no_grad():
162 | if method == 'mean+std':
163 | std, mean = torch.std_mean(data, dim=0, keepdim=True)
164 | data.sub_(mean).divide_(std)
165 | elif method == 'mean':
166 | mean = torch.mean(data, dim=0, keepdim=True)
167 | data.sub_(mean)
168 | elif method == 'std':
169 | std = torch.std(data, dim=0, keepdim=True)
170 | data.divide_(std)
171 | else:
172 | raise ValueError(f'Unrecognized normalize method: {method}')
173 |
174 | @classmethod
175 | def get_singleton(cls, name, emb):
176 | if name in EmbeddingAdaptor.singleton_emb:
177 | return EmbeddingAdaptor.singleton_emb[name]
178 | EmbeddingAdaptor.singleton_emb[name] = emb = cls(emb)
179 | return emb
180 |
--------------------------------------------------------------------------------
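A tiny check of the in-place 'mean+std' normalization implemented by EmbeddingAdaptor._normalize above (toy tensor; the adaptor runs the same operations under torch.no_grad()):

import torch

data = torch.randn(1000, 8) * 3 + 5          # toy embedding matrix with shifted scale
std, mean = torch.std_mean(data, dim=0, keepdim=True)
data.sub_(mean).divide_(std)                 # what method == 'mean+std' does
print(data.mean(0).abs().max())              # ~0: per-dimension means are centred
print(data.std(0))                           # ~1: per-dimension stds are unit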
/src/model/embedding/fastnlp_embedding.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | from fastNLP.embeddings import StaticEmbedding, TokenEmbedding, CNNCharEmbedding, LSTMCharEmbedding
6 | from torch import Tensor
7 | from torch.nn import Parameter
8 |
9 | from src.model.embedding.embedding import EmbeddingAdaptor
10 | from src.model.nn.multivariate_kl import MultVariateKLD
11 | from src.utility.var_pool import VarPool
12 |
13 |
14 | class FastNLPEmbeddingAdaptor(EmbeddingAdaptor):
15 |
16 | def __init__(self, emb: TokenEmbedding):
17 | super().__init__(emb)
18 | self._embed_size = self.emb.embed_size
19 | self._word_dropout = emb.word_dropout
20 | self._dropout = emb.dropout_layer.p
21 | self._normalize_weight = None
22 |
23 | @property
24 | def embed_size(self):
25 | return self._embed_size
26 |
27 | def forward(self, field: Tensor, vp: VarPool):
28 | return self.emb(field)
29 |
30 | def normalize(self, method):
31 | emb: torch.nn.Embedding = self.emb.embedding
32 | if hasattr(self.emb, 'mapped_counts'):
33 | self.emb: StaticEmbedding
34 | if self._normalize_weight is None:
35 | self._normalize_weight = (self.emb.mapped_counts / self.emb.mapped_counts.sum()).unsqueeze(-1)
36 | mean = (emb.weight.data * self._normalize_weight).sum()
37 | if method == 'mean':
38 | emb.weight.data.sub_(mean)
39 | else:
40 | std = (((emb.weight.data - mean).pow(2.) * self._normalize_weight).sum() + 1e-6).sqrt()
41 | if method == 'mean+std':
42 | emb.weight.data.sub_(mean)
43 | emb.weight.data.div_(std)
44 | else:
45 | padding_idx = self.emb.get_word_vocab().padding_idx
46 | start_idx = 1 if padding_idx == 0 else 0
47 | self._normalize(emb.weight.data[start_idx:], method)
48 |
49 | class FastNLPEmbeddingVariationalAdaptor(FastNLPEmbeddingAdaptor):
50 | def __init__(self, emb: TokenEmbedding, mode: str, out_dim: int):
51 | # mode: basic, vae or ib
52 | super(FastNLPEmbeddingVariationalAdaptor, self).__init__(emb)
53 | self.mode = mode
54 | if self.mode != 'basic':
55 | self._embed_size = out_dim
56 | self.enc = nn.Linear(emb.embed_size, 2 * out_dim)
57 | if self.mode == 'ib':
58 | self.gaussian_kl = MultVariateKLD('sum')
59 | self.target_mean = Parameter(torch.zeros(1, out_dim))
60 | self.target_lvar = Parameter(torch.zeros(1, out_dim))
61 |
62 | def forward(self, field: Tensor, vp: VarPool):
63 | if self.mode == 'basic':
64 | return super().forward(field, vp)
65 |
66 | mean, lvar = torch.chunk(self.enc(self.emb(field)), 2, dim=-1)
67 | if self.training:
68 | z = torch.empty_like(mean).normal_()
69 | z = (0.5 * lvar).exp() * z + mean
70 | else:
71 | z = mean
72 | vp.kl = self.kl(mean, lvar)
73 | return z
74 |
75 | def kl(self, mean, lvar):
76 | if self.mode == 'ib':
77 | _mean, _lvar = mean.view(-1, self.embed_size), lvar.view(-1, self.embed_size)
78 | _b = len(_mean)
79 | return self.gaussian_kl(_mean, self.target_mean.expand(_b, -1), _lvar, self.target_lvar.expand(_b, -1))
80 | else:
81 | return -0.5 * (lvar - torch.pow(mean, 2) - torch.exp(lvar) + 1).sum()
82 |
83 |
84 | class FastNLPCharEmbeddingAdaptor(FastNLPEmbeddingAdaptor):
85 | def normalize(self, method):
86 | self.emb: Union[CNNCharEmbedding, LSTMCharEmbedding]
87 | emb = self.emb.char_embedding
88 | start_idx = 1 if self.emb.char_pad_index == 0 else 0
89 | self._normalize(emb.weight.data[start_idx:], method)
90 |
--------------------------------------------------------------------------------
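A short sketch of the reparameterization and the VAE-style KL term used by FastNLPEmbeddingVariationalAdaptor above (toy tensors; the 'ib' mode instead measures KL against learned target moments via MultVariateKLD):

import torch

mean = torch.zeros(2, 4)                                  # predicted means
lvar = torch.zeros(2, 4)                                  # predicted log-variances
z = (0.5 * lvar).exp() * torch.randn_like(mean) + mean    # training-time sample
kl = -0.5 * (lvar - mean.pow(2) - lvar.exp() + 1).sum()   # KL(q(z|x) || N(0, I))
print(z.shape, kl)                                        # kl == 0 here since q is already N(0, I)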
/src/model/embedding/transformers_embedding.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from fastNLP import Padder
5 | from torch import Tensor
6 |
7 | from src.model.embedding.embedding import EmbeddingAdaptor
8 | from src.model.nn.scalar_mix import ScalarMix
9 | from src.utility.fn import pad
10 | from src.utility.var_pool import VarPool
11 |
12 |
13 | class TransformersAdaptor(EmbeddingAdaptor):
14 | def __init__(self, emb):
15 | super().__init__(emb)
16 | self.emb: TransformersEmbedding
17 | self._embed_size = self.emb.n_out
18 | self._dropout = self.emb.dropout
19 |
20 | @property
21 | def embed_size(self):
22 | return self._embed_size
23 |
24 | def process(self, vocabs, datasets):
25 | enable_transformers_embedding(datasets, self.emb.tokenizer)
26 |
27 | def forward(self, field: Tensor, vp: VarPool):
28 | return self.emb(field)[:, 1: -1]
29 |
30 |
31 | def enable_transformers_embedding(datasets, tokenizer, fix_len=20):
32 | def get_subwords(_words):
33 | sws = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(w)[:fix_len]) for w in _words]
34 | sws = [[tokenizer.cls_token_id]] + sws + [[tokenizer.sep_token_id]]
35 | sws = list(map(lambda x: torch.tensor(x, dtype=torch.long), sws))
36 | return pad(sws, tokenizer.pad_token_id).numpy()
37 |
38 | for ds in datasets.values():
39 | ds.apply_field(get_subwords,
40 | 'raw_word',
41 | 'subword',
42 | is_input=True,
43 | padder=SubWordsPadder(tokenizer.pad_token_id))
44 |
45 |
46 | class SubWordsPadder(Padder):
47 | def __call__(self, contents, field_name, field_ele_dtype, dim: int):
48 | batch_size, dtype = len(contents), type(contents[0][0][0])
49 | max_len0, max_len1 = max(c.shape[0] for c in contents), max(c.shape[1] for c in contents)
50 | padded_array = np.full((batch_size, max_len0, max_len1), fill_value=self.pad_val, dtype=dtype)
51 | for b_idx, matrix in enumerate(contents):
52 | padded_array[b_idx, :matrix.shape[0], :matrix.shape[1]] = matrix
53 | return padded_array
54 |
55 |
56 | class TransformersEmbedding(nn.Module):
57 | r""" By Zhang Yu
58 | A module that directly utilizes the pretrained models in `transformers`_ to produce BERT representations.
59 | While mainly tailored to provide input preparation and post-processing for the BERT model,
60 | it is also compatible with other pretrained language models like XLNet, RoBERTa and ELECTRA, etc.
61 |
62 | Args:
63 | model (str):
64 | Path or name of the pretrained models registered in `transformers`_, e.g., ``'bert-base-cased'``.
65 | n_layers (int):
66 | The number of layers from the model to use.
67 | If 0, uses all layers.
68 | n_out (int):
69 | The requested size of the embeddings. Default: 0.
70 | If 0, uses the size of the pretrained embedding model.
71 | stride (int):
72 | A sequence longer than the max length will be split into several small pieces
73 | with a window size of ``stride``. Default: 256.
74 | pooling (str):
75 | Pooling way to get from token piece embeddings to token embedding.
76 | Either take the first subtoken ('first'), the last subtoken ('last'), or a mean over all ('mean').
77 | Default: 'mean'.
78 | dropout (float):
79 | The dropout ratio of BERT layers. Default: 0.
80 | This value will be passed into the :class:`ScalarMix` layer.
81 | requires_grad (bool):
82 | If ``True``, the model parameters will be updated together with the downstream task.
83 | Default: ``False``.
84 |
85 | .. _transformers:
86 | https://github.com/huggingface/transformers
87 | """
88 |
89 | def __init__(self,
90 | model,
91 | n_layers,
92 | n_out=0,
93 | stride=256,
94 | pooling='mean',
95 | dropout=0,
96 | requires_grad=False):
97 | super().__init__()
98 |
99 | from transformers import AutoConfig, AutoModel, AutoTokenizer
100 | self.bert = AutoModel.from_pretrained(model,
101 | config=AutoConfig.from_pretrained(model, output_hidden_states=True))
102 | self.bert = self.bert.requires_grad_(requires_grad)
103 |
104 | self.model = model
105 | self.n_layers = n_layers or self.bert.config.num_hidden_layers
106 | self.hidden_size = self.bert.config.hidden_size
107 | self.n_out = n_out or self.hidden_size
108 | self.stride = stride
109 | self.pooling = pooling
110 | self.dropout = dropout
111 | self.requires_grad = requires_grad
112 | self.max_len = int(max(0, self.bert.config.max_position_embeddings) or 1e12) - 2
113 |
114 | self.tokenizer = AutoTokenizer.from_pretrained(model)
115 | self.pad_index = self.tokenizer.pad_token_id
116 | # assert self.pad_index == pad_index
117 |
118 | self.scalar_mix = ScalarMix(self.n_layers, dropout)
119 | self.projection = nn.Linear(self.hidden_size, self.n_out, False) \
120 | if self.hidden_size != self.n_out else nn.Identity()
121 |
122 | def forward(self, subwords):
123 | r"""
124 | Args:
125 | subwords (~torch.Tensor): ``[batch_size, seq_len, fix_len]``.
126 | Returns:
127 | ~torch.Tensor:
128 | BERT embeddings of shape ``[batch_size, seq_len, n_out]``.
129 | """
130 | mask = subwords.ne(self.pad_index)
131 | lens = mask.sum((1, 2))
132 | # [batch_size, n_subwords]
133 | subwords = pad(subwords[mask].split(lens.tolist()), self.pad_index, padding_side=self.tokenizer.padding_side)
134 | bert_mask = pad(mask[mask].split(lens.tolist()), 0, padding_side=self.tokenizer.padding_side)
135 |
136 | # return the hidden states of all layers
137 | bert = self.bert(subwords[:, :self.max_len], attention_mask=bert_mask[:, :self.max_len].float())[-1]
138 | # [n_layers, batch_size, max_len, hidden_size]
139 | bert = bert[-self.n_layers:]
140 | # [batch_size, max_len, hidden_size]
141 | bert = self.scalar_mix(bert)
142 | # [batch_size, n_subwords, hidden_size]
143 | for i in range(self.stride,
144 | (subwords.shape[1] - self.max_len + self.stride - 1) // self.stride * self.stride + 1,
145 | self.stride):
146 | part = self.bert(
147 | subwords[:, i:i + self.max_len],
148 | attention_mask=bert_mask[:, i:i + self.max_len].float(),
149 | )[-1]
150 | bert = torch.cat((bert, self.scalar_mix(part[-self.n_layers:])[:, self.max_len - self.stride:]), 1)
151 |
152 | # [batch_size, n_subwords]
153 | bert_lens = mask.sum(-1)
154 | bert_lens = bert_lens.masked_fill_(bert_lens.eq(0), 1)
155 | # [batch_size, seq_len, fix_len, hidden_size]
156 | embed = bert.new_zeros(*mask.shape, self.hidden_size).masked_scatter_(mask.unsqueeze(-1), bert[bert_mask])
157 | # [batch_size, seq_len, hidden_size]
158 | if self.pooling == 'first':
159 | embed = embed[:, :, 0]
160 | elif self.pooling == 'last':
161 | embed = embed \
162 | .gather(2, (bert_lens - 1).unsqueeze(-1).repeat(1, 1, self.hidden_size).unsqueeze(2)) \
163 | .squeeze(2)
164 | else:
165 | embed = embed.sum(2) / bert_lens.unsqueeze(-1)
166 | embed = self.projection(embed)
167 |
168 | return embed
169 |
--------------------------------------------------------------------------------
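A toy illustration of the three subword-pooling modes in TransformersEmbedding.forward above (hypothetical shapes, random values; in the real module padded piece positions are zeros, which is what makes the 'mean' branch a proper average):

import torch

hidden = 4
embed = torch.randn(1, 2, 3, hidden)   # [batch_size, seq_len, fix_len, hidden_size]
bert_lens = torch.tensor([[2, 3]])     # subword pieces per word

first = embed[:, :, 0]                                                  # pooling == 'first'
last = embed.gather(
    2, (bert_lens - 1).unsqueeze(-1).repeat(1, 1, hidden).unsqueeze(2)
).squeeze(2)                                                            # pooling == 'last'
mean = embed.sum(2) / bert_lens.unsqueeze(-1)                           # pooling == 'mean'
print(first.shape, last.shape, mean.shape)  # all [1, 2, 4]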
/src/model/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .affine import Biaffine
2 | from .common import ResLayer, MLP
3 | from .dmv_spec import DMVSkipConnectEncoder, DMVFactorizedBilinear
4 | from .dropout import SharedDropout, IndependentDropout
5 | from .scalar_mix import ScalarMix
6 | from .variational_lstm import VariationalLSTM
7 | from .affine_scorer import BiaffineScorer
8 |
--------------------------------------------------------------------------------
/src/model/nn/affine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 | from torch.nn.parameter import Parameter
5 |
6 |
7 | class Biaffine(nn.Module):
8 | r"""
9 | Biaffine layer for first-order scoring :cite:`dozat-etal-2017-biaffine`.
10 |
11 | This function has a tensor of weights :math:`W` and bias terms if needed.
12 | The score :math:`s(x, y)` of the vector pair :math:`(x, y)` is computed as :math:`x^T W y`.
13 | :math:`x` and :math:`y` can be concatenated with bias terms.
14 |
15 | Args:
16 | n_in (int):
17 | The size of the input feature.
18 | n_out (int):
19 | The number of output channels.
20 | bias_x (bool):
21 | If ``True``, adds a bias term for tensor :math:`x`. Default: ``True``.
22 | bias_y (bool):
23 | If ``True``, adds a bias term for tensor :math:`y`. Default: ``True``.
24 | """
25 |
26 | def __init__(self, n_in, n_out=1, bias_x=True, bias_y=True):
27 | super().__init__()
28 |
29 | self.n_in = n_in
30 | self.n_out = n_out
31 | self.bias_x = bias_x
32 | self.bias_y = bias_y
33 | self.weight = nn.Parameter(torch.Tensor(n_out, n_in + bias_x, n_in + bias_y))
34 |
35 | self.reset_parameters()
36 |
37 | def __repr__(self):
38 | s = f'n_in={self.n_in}'
39 | if self.n_out > 1:
40 | s += f', n_out={self.n_out}'
41 | if self.bias_x:
42 | s += f', bias_x={self.bias_x}'
43 | if self.bias_y:
44 | s += f', bias_y={self.bias_y}'
45 |
46 | return f'{self.__class__.__name__}({s})'
47 |
48 | def reset_parameters(self):
49 | nn.init.zeros_(self.weight)
50 |
51 | def forward(self, x, y):
52 | r"""
53 | Args:
54 | x (torch.Tensor): ``[batch_size, seq_len, n_in]``.
55 | y (torch.Tensor): ``[batch_size, seq_len, n_in]``.
56 |
57 | Returns:
58 | ~torch.Tensor:
59 | A scoring tensor of shape ``[batch_size, n_out, seq_len, seq_len]``.
60 | If ``n_out=1``, the dimension for ``n_out`` will be squeezed automatically.
61 | """
62 |
63 | if self.bias_x:
64 | x = torch.cat((x, torch.ones_like(x[..., :1])), -1)
65 | if self.bias_y:
66 | y = torch.cat((y, torch.ones_like(y[..., :1])), -1)
67 | # [batch_size, n_out, seq_len, seq_len]
68 | s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y)
69 | # remove dim 1 if n_out == 1
70 | s = s.squeeze(1)
71 |
72 | return s
73 |
74 |
75 |
--------------------------------------------------------------------------------
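A quick shape check of the Biaffine layer above (toy sizes; the weight is zero-initialized, so only the shapes are informative here):

import torch

biaffine = Biaffine(n_in=4, n_out=2)
x = torch.randn(3, 7, 4)     # [batch_size, seq_len, n_in]
y = torch.randn(3, 7, 4)
print(biaffine(x, y).shape)  # torch.Size([3, 2, 7, 7]); with n_out=1 the channel dim is squeezed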
/src/model/nn/affine_scorer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 | from .affine import Biaffine
8 | from .common import MLP
9 |
10 |
11 | class BiaffineScorer(nn.Module):
12 | def __init__(self,
13 | n_in,
14 | hidden_dim,
15 | out_dim,
16 | mlp_dropout,
17 | mlp_activate,
18 | scale):
19 | super().__init__()
20 | self.mlp_dropout = mlp_dropout
21 | self.mlp1 = MLP(n_in // 2, hidden_dim, mlp_dropout, mlp_activate)
22 | self.mlp2 = MLP(n_in // 2, hidden_dim, mlp_dropout, mlp_activate)
23 | self.affine = Biaffine(hidden_dim, out_dim, bias_x=True, bias_y=out_dim > 1)
24 | self.register_buffer('scale', 1 / torch.tensor(hidden_dim if scale else 1).pow(0.25))
25 | self.n_out = out_dim
26 |
27 | def reset_parameters(self):
28 | nn.init.zeros_(self.affine.weight)
29 | self.affine.weight.diagonal().fill_(1.)
30 |
31 | def forward(self, x, x2):
32 | h1 = self.mlp1(x) * self.scale
33 | h2 = self.mlp2(x2) * self.scale
34 | out = self.affine(h1, h2).permute(0, 2, 3, 1)
35 | return out
36 |
--------------------------------------------------------------------------------
/src/model/nn/common.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch import Tensor
3 |
4 | from src.model.nn.dropout import SharedDropout
5 |
6 |
7 | class ResLayer(nn.Module):
8 | def __init__(self, n_in, n_hidden, activate=True):
9 | super(ResLayer, self).__init__()
10 | self.linear = nn.Sequential(
11 | nn.Linear(n_in, n_hidden),
12 | nn.ReLU(),
13 | nn.Linear(n_hidden, n_hidden),
14 | nn.ReLU(),
15 | )
16 | self.n_out = n_hidden
17 | self.activation = nn.LeakyReLU() if activate else nn.Identity()
18 |
19 | def forward(self, x):
20 | return self.activation(self.linear(x)) + x
21 |
22 |
23 | class MLP(nn.Module):
24 | def __init__(self, n_in, n_hidden, dropout=0, activate=True):
25 | super(MLP, self).__init__()
26 |
27 | self.n_in = n_in
28 | self.n_hidden = n_hidden
29 |
30 | self.linear = nn.Linear(n_in, n_hidden)
31 | self.activation = nn.LeakyReLU() if activate else nn.Identity()
32 | self.dropout = SharedDropout(p=dropout) if dropout > 0 else nn.Identity()
33 | self.n_out = n_hidden
34 | self.reset_parameters()
35 |
36 | def __repr__(self):
37 | s = f"n_in={self.n_in}, n_out={self.n_hidden}"
38 | if isinstance(self.dropout, SharedDropout):
39 | s += f", dropout={self.dropout.p}"
40 |
41 | return f"{self.__class__.__name__}({s})"
42 |
43 | def reset_parameters(self):
44 | nn.init.orthogonal_(self.linear.weight)
45 | nn.init.zeros_(self.linear.bias)
46 |
47 | def forward(self, x: Tensor) -> Tensor:
48 | x = self.linear(x)
49 | x = self.activation(x)
50 | x = self.dropout(x)
51 | return x
52 |
--------------------------------------------------------------------------------
/src/model/nn/dmv_spec.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 |
5 |
6 | class DMVSkipConnectEncoder(nn.Module):
7 | def __init__(self, hidden_size, n_bottleneck=0, n_mid=0, dropout=0.):
8 | super().__init__()
9 | self.hidden_size = hidden_size
10 | self.activate = nn.LeakyReLU()
11 | self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
12 | self.n_out = hidden_size
13 |
14 | # To encode valence information
15 | if n_bottleneck == 0:
16 | self.HASCHILD_linear = nn.Linear(self.hidden_size, self.hidden_size)
17 | self.NOCHILD_linear = nn.Linear(self.hidden_size, self.hidden_size)
18 | else:
19 | self.HASCHILD_linear = self.create_bottleneck(self.hidden_size, n_bottleneck)
20 | self.NOCHILD_linear = self.create_bottleneck(self.hidden_size, n_bottleneck)
21 | self.valence_linear = nn.Linear(self.hidden_size, self.hidden_size)
22 |
23 | # To encode direction information
24 | if n_bottleneck == 0:
25 | self.LEFT_linear = nn.Linear(self.hidden_size, self.hidden_size)
26 | self.RIGHT_linear = nn.Linear(self.hidden_size, self.hidden_size)
27 | else:
28 | self.LEFT_linear = self.create_bottleneck(self.hidden_size, n_bottleneck)
29 | self.RIGHT_linear = self.create_bottleneck(self.hidden_size, n_bottleneck)
30 | self.direction_linear = nn.Linear(self.hidden_size, self.hidden_size)
31 |
32 | # To produce final hidden representation
33 | n_mid = n_mid if n_mid else hidden_size
34 | self.linear1 = nn.Linear(self.hidden_size, n_mid)
35 | self.linear2 = nn.Linear(n_mid, self.hidden_size)
36 |
37 | def forward(self, x: Tensor):
38 | # input: ... x len x hidden1
39 | # output: ... x len x dir x val x hidden2
40 | has_child = self.HASCHILD_linear(x) + x
41 | no_child = self.NOCHILD_linear(x) + x
42 | h = torch.cat([no_child.unsqueeze(-2), has_child.unsqueeze(-2)], dim=-2)
43 | h = self.activate(self.valence_linear(self.activate(h)))
44 |
45 | x = x.unsqueeze(-2)
46 | left_h = self.LEFT_linear(h) + x
47 | right_h = self.RIGHT_linear(h) + x
48 | h = torch.cat([left_h.unsqueeze(-3), right_h.unsqueeze(-3)], dim=-3)
49 | h = self.activate(self.direction_linear(self.activate(h)))
50 |
51 | h = self.dropout(h)
52 | return self.linear2(self.activate(self.linear1(h)))
53 |
54 | @staticmethod
55 | def create_bottleneck(n_in_out, n_bottleneck):
56 | return nn.Sequential(nn.Linear(n_in_out, n_bottleneck), nn.Linear(n_bottleneck, n_in_out))
57 |
58 |
59 | class DMVFactorizedBilinear(nn.Module):
60 | def __init__(self, n_in, n_in2=None, r=64):
61 | super(DMVFactorizedBilinear, self).__init__()
62 | self.n_in = n_in
63 | self.n_in2 = n_in2 if n_in2 else n_in
64 | self.r = r
65 | self.project1 = nn.Linear(self.n_in, self.r)
66 | self.project2 = nn.Linear(self.n_in2, self.r)
67 |
68 | def forward(self, x1, x2):
69 | x1 = self.project1(x1)
70 | x2 = self.project2(x2)
71 | if len(x1.shape) == 5:
72 | return torch.einsum("bhdve, bcdve -> bhcdv", x1, x2)
73 | elif len(x1.shape) == 4:
74 | return torch.einsum("hdve, cdve -> hcdv", x1, x2)
75 | else:
76 | raise NotImplementedError
77 |
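A minimal shape sketch (my addition, not part of the repo) for DMVSkipConnectEncoder, matching the comment in forward: the encoder expands ... x len x hidden into ... x len x direction x valence x hidden.

import torch
from src.model.nn.dmv_spec import DMVSkipConnectEncoder

enc = DMVSkipConnectEncoder(hidden_size=16)
x = torch.randn(2, 7, 16)    # batch x len x hidden
h = enc(x)
print(h.shape)               # torch.Size([2, 7, 2, 2, 16]): batch x len x direction x valence x hidden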
--------------------------------------------------------------------------------
/src/model/nn/dropout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class SharedDropout(nn.Module):
6 | r"""
7 | SharedDropout differs from the vanilla dropout strategy in that
8 | the dropout mask is shared across one dimension.
9 |
10 | Args:
11 | p (float):
12 | The probability of an element to be zeroed. Default: 0.5.
13 | batch_first (bool):
14 | If ``True``, the input and output tensors are provided as ``[batch_size, seq_len, *]``.
15 | Default: ``True``.
16 |
17 | Examples:
18 | >>> x = torch.ones(1, 3, 5)
19 | >>> nn.Dropout()(x)
20 | tensor([[[0., 2., 2., 0., 0.],
21 | [2., 2., 0., 2., 2.],
22 | [2., 2., 2., 2., 0.]]])
23 | >>> SharedDropout()(x)
24 | tensor([[[2., 0., 2., 0., 2.],
25 | [2., 0., 2., 0., 2.],
26 | [2., 0., 2., 0., 2.]]])
27 | """
28 |
29 | def __init__(self, p=0.5, batch_first=True):
30 | super().__init__()
31 |
32 | self.p = p
33 | self.batch_first = batch_first
34 |
35 | def __repr__(self):
36 | s = f'p={self.p}'
37 | if self.batch_first:
38 | s += f', batch_first={self.batch_first}'
39 |
40 | return f'{self.__class__.__name__}({s})'
41 |
42 | def forward(self, x):
43 | r"""
44 | Args:
45 | x (~torch.Tensor):
46 | A tensor of any shape.
47 | Returns:
48 | The returned tensor is of the same shape as `x`.
49 | """
50 |
51 | if self.training:
52 | if self.batch_first:
53 | mask = self.get_mask(x[:, 0], self.p).unsqueeze(1)
54 | else:
55 | mask = self.get_mask(x[0], self.p)
56 | x = x * mask
57 |
58 | return x
59 |
60 | @staticmethod
61 | def get_mask(x, p):
62 | return x.new_empty(x.shape).bernoulli_(1 - p) / (1 - p)
63 |
64 |
65 | class IndependentDropout(nn.Module):
66 | r"""
67 | For :math:`N` tensors, they use different dropout masks respectively.
68 | When :math:`N-M` of them are dropped, the remaining :math:`M` ones are scaled by a factor of :math:`N/M`
69 | to compensate, and when all of them are dropped together, zeros are returned.
70 |
71 | Args:
72 | p (float):
73 | The probability of an element to be zeroed. Default: 0.5.
74 |
75 | Examples:
76 | >>> x, y = torch.ones(1, 3, 5), torch.ones(1, 3, 5)
77 | >>> x, y = IndependentDropout()(x, y)
78 | >>> x
79 | tensor([[[1., 1., 1., 1., 1.],
80 | [0., 0., 0., 0., 0.],
81 | [2., 2., 2., 2., 2.]]])
82 | >>> y
83 | tensor([[[1., 1., 1., 1., 1.],
84 | [2., 2., 2., 2., 2.],
85 | [0., 0., 0., 0., 0.]]])
86 | """
87 |
88 | def __init__(self, p=0.5):
89 | super().__init__()
90 |
91 | self.p = p
92 |
93 | def __repr__(self):
94 | return f'{self.__class__.__name__}(p={self.p})'
95 |
96 | def forward(self, *items):
97 | r"""
98 | Args:
99 | items (list[~torch.Tensor]):
100 | A list of tensors that have the same shape except the last dimension.
101 | Returns:
102 | The returned tensors are of the same shape as `items`.
103 | """
104 |
105 | if self.training:
106 | masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p) for x in items]
107 | total = sum(masks)
108 | scale = len(items) / total.max(torch.ones_like(total))
109 | masks = [mask * scale for mask in masks]
110 | items = [item * mask.unsqueeze(-1) for item, mask in zip(items, masks)]
111 |
112 | return items
113 |
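A small usage sketch (my addition, assuming the import path above): SharedDropout reuses one mask for every position of a sequence, while IndependentDropout rescales surviving embeddings by N/M as described in its docstring.

import torch
from src.model.nn.dropout import SharedDropout, IndependentDropout

torch.manual_seed(0)
x = torch.ones(2, 4, 6)
y = SharedDropout(p=0.5)(x)            # modules are in training mode by default
assert torch.equal(y[0, 0], y[0, 1])   # one mask shared across the sequence dimension

word, tag = IndependentDropout(p=0.5)(torch.ones(2, 4, 6), torch.ones(2, 4, 6))
# positions where only one of the two embeddings survives are rescaled by 2 (= N / M)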
--------------------------------------------------------------------------------
/src/model/nn/multivariate_kl.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class MultVariateKLD(torch.nn.Module):
5 | def __init__(self, reduction):
6 | super(MultVariateKLD, self).__init__()
7 | self.reduction = reduction
8 |
9 | def forward(self, mu1, mu2, logvar_1, logvar_2):
10 | mu1, mu2 = mu1.type(dtype=torch.float64), mu2.type(dtype=torch.float64)
11 | sigma_1 = logvar_1.exp().type(dtype=torch.float64)
12 | sigma_2 = logvar_2.exp().type(dtype=torch.float64)
13 |
14 | sigma_diag_1 = torch.diag_embed(sigma_1, offset=0, dim1=-2, dim2=-1)
15 | sigma_diag_2 = torch.diag_embed(sigma_2, offset=0, dim1=-2, dim2=-1)
16 |
17 | sigma_diag_2_inv = sigma_diag_2.inverse()
18 |
19 | # log(det(sigma2^T)/det(sigma1))
20 | term_1 = (sigma_diag_2.det() / sigma_diag_1.det()).log()
21 | # term_1[term_1.ne(term_1)] = 0
22 |
23 | # trace(inv(sigma2)*sigma1)
24 | term_2 = torch.diagonal((torch.matmul(sigma_diag_2_inv, sigma_diag_1)), dim1=-2, dim2=-1).sum(-1)
25 |
26 | # (mu2-mu1)^T*inv(sigma2)*(mu2-mu1)
27 | term_3 = torch.matmul(torch.matmul((mu2 - mu1).unsqueeze(-1).transpose(2, 1), sigma_diag_2_inv),
28 | (mu2 - mu1).unsqueeze(-1)).flatten()
29 |
30 | # dimension of embedded space (number of mus and sigmas)
31 | n = mu1.shape[1]
32 |
33 | # Calc kl divergence on entire batch
34 | kl = 0.5 * (term_1 - n + term_2 + term_3)
35 |
36 | # Calculate mean kl_d loss
37 | if self.reduction == 'mean':
38 | kl_agg = torch.mean(kl)
39 | elif self.reduction == 'sum':
40 | kl_agg = torch.sum(kl)
41 | else:
42 | raise NotImplementedError(f'Reduction type not implemented: {self.reduction}')
43 |
44 | return kl_agg
45 |
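A sanity-check sketch (my addition, not in the repo): for diagonal Gaussians, the formula above should agree with torch.distributions' analytic KL(N(mu1, diag(exp(logvar1))) || N(mu2, diag(exp(logvar2)))).

import torch
from torch.distributions import Independent, Normal, kl_divergence
from src.model.nn.multivariate_kl import MultVariateKLD

torch.manual_seed(0)
mu1, mu2 = torch.randn(3, 5), torch.randn(3, 5)
logvar1, logvar2 = torch.randn(3, 5), torch.randn(3, 5)

kl_a = MultVariateKLD(reduction='sum')(mu1, mu2, logvar1, logvar2)
p = Independent(Normal(mu1, (0.5 * logvar1).exp()), 1)   # std = exp(logvar / 2)
q = Independent(Normal(mu2, (0.5 * logvar2).exp()), 1)
kl_b = kl_divergence(p, q).sum()
print(torch.allclose(kl_a.float(), kl_b))                # True up to float error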
--------------------------------------------------------------------------------
/src/model/nn/scalar_mix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ScalarMix(nn.Module):
6 | r"""
7 | Computes a parameterised scalar mixture of :math:`N` tensors, :math:`mixture = \gamma * \sum_{k}(s_k * tensor_k)`
8 | where :math:`s = \mathrm{softmax}(w)`, with :math:`w` and :math:`\gamma` scalar parameters.
9 |
10 | Args:
11 | n_layers (int):
12 | The number of layers to be mixed, i.e., :math:`N`.
13 | dropout (float):
14 | The dropout ratio of the layer weights.
15 | If dropout > 0, each softmax-normalized weight is zeroed with the dropout
16 | probability and the surviving weights are rescaled by 1/(1 - dropout),
17 | i.e., standard dropout is applied to the normalized weights.
18 | Default: 0.
19 | """
20 |
21 | def __init__(self, n_layers, dropout=0):
22 | super().__init__()
23 |
24 | self.n_layers = n_layers
25 |
26 | self.weights = nn.Parameter(torch.zeros(n_layers))
27 | self.gamma = nn.Parameter(torch.tensor([1.0]))
28 | self.dropout_func = nn.Dropout(dropout)
29 |
30 | def __repr__(self):
31 | s = f'n_layers={self.n_layers}'
32 | if self.dropout_func.p > 0:
33 | s += f', dropout={self.dropout_func.p}'
34 |
35 | return f'{self.__class__.__name__}({s})'
36 |
37 | def forward(self, tensors):
38 | r"""
39 | Args:
40 | tensors (list[~torch.Tensor]):
41 | :math:`N` tensors to be mixed.
42 |
43 | Returns:
44 | The mixture of :math:`N` tensors.
45 | """
46 |
47 | normed_weights = self.dropout_func(self.weights.softmax(-1))
48 | weighted_sum = sum(w * h for w, h in zip(normed_weights, tensors))
49 |
50 | return self.gamma * weighted_sum
51 |
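Usage sketch (my addition): mixing three equally-sized layer outputs, e.g. hidden states from a multi-layer encoder.

import torch
from src.model.nn.scalar_mix import ScalarMix

layers = [torch.randn(2, 7, 16) for _ in range(3)]
mix = ScalarMix(n_layers=3, dropout=0.1)
out = mix(layers)        # gamma * sum_k softmax(w)_k * layer_k
print(out.shape)         # torch.Size([2, 7, 16])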
--------------------------------------------------------------------------------
/src/model/nn/variational_lstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.rnn import PackedSequence
4 | from src.model.nn.dropout import SharedDropout
5 |
6 |
7 | class VariationalLSTM(nn.Module):
8 | r"""
9 | VariationalLSTM :cite:`yarin-etal-2016-dropout` is a variant of the vanilla bidirectional LSTM
10 | adopted by the Biaffine Parser, differing only in its dropout strategy.
11 | It drops nodes in the LSTM layers (input and recurrent connections)
12 | and applies the same dropout mask at every recurrent timestep.
13 | APIs are roughly the same as :class:`~torch.nn.LSTM` except that we only allow
14 | :class:`~torch.nn.utils.rnn.PackedSequence` as input.
15 | Args:
16 | input_size (int):
17 | The number of expected features in the input.
18 | hidden_size (int):
19 | The number of features in the hidden state `h`.
20 | num_layers (int):
21 | The number of recurrent layers. Default: 1.
22 | dropout (float):
23 | If non-zero, introduces a :class:`SharedDropout` layer on the outputs of each LSTM layer (except last).
24 | Default: 0.
25 | """
26 |
27 | def __init__(self, input_size, hidden_size, num_layers=1, dropout=0, cell=nn.LSTMCell, init='zy'):
28 | super().__init__()
29 |
30 | self.input_size = input_size
31 | self.hidden_size = hidden_size
32 | self.num_layers = num_layers
33 | self.dropout = dropout
34 | self.init = init
35 |
36 | self.f_cells = nn.ModuleList()
37 | self.b_cells = nn.ModuleList()
38 | for _ in range(self.num_layers):
39 | self.f_cells.append(cell(input_size=input_size, hidden_size=hidden_size))
40 | self.b_cells.append(cell(input_size=input_size, hidden_size=hidden_size))
41 | input_size = hidden_size * 2
42 |
43 | self.reset_parameters()
44 |
45 | def __repr__(self):
46 | s = f'{self.input_size}, {2 * self.hidden_size}'
47 | if self.num_layers > 1:
48 | s += f', num_layers={self.num_layers}'
49 | if self.dropout > 0:
50 | s += f', dropout={self.dropout}'
51 |
52 | return f'{self.__class__.__name__}({s})'
53 |
54 | def reset_parameters(self):
55 | if self.init == 'zy':
56 | for name, param in self.named_parameters():
57 | if name.startswith('lstm'):
58 | # apply orthogonal_ to weight
59 | if len(param.shape) > 1:
60 | nn.init.orthogonal_(param)
61 | # apply zeros_ to bias
62 | else:
63 | nn.init.zeros_(param)
64 | elif self.init == 'biased':
65 | for name, param in self.named_parameters():
66 | if name.startswith('lstm'):
67 | # apply orthogonal_ to weight
68 | if len(param.shape) > 1:
69 | nn.init.xavier_uniform_(param)
70 | else:
71 | # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871
72 | param.data.fill_(0.)
73 | n = param.shape[0]
74 | start, end = n // 4, n // 2
75 | param.data[start:end].fill_(1.)
76 | else:
77 | raise ValueError(f'Bad init, {self.init=}')
78 |
79 | def layer_forward(self, x, hx, cell, batch_sizes, reverse=False):
80 | hx_0 = hx_i = hx
81 | hx_n, output = [], []
82 | steps = reversed(range(len(x))) if reverse else range(len(x))
83 | if self.training:
84 | hid_mask = SharedDropout.get_mask(hx_0[0], self.dropout)
85 |
86 | for t in steps:
87 | last_batch_size, batch_size = len(hx_i[0]), batch_sizes[t]
88 | if last_batch_size < batch_size:
89 | hx_i = [torch.cat((h, ih[last_batch_size:batch_size])) for h, ih in zip(hx_i, hx_0)]
90 | else:
91 | hx_n.append([h[batch_size:] for h in hx_i])
92 | hx_i = [h[:batch_size] for h in hx_i]
93 | hx_i = [h for h in cell(x[t], hx_i)]
94 | output.append(hx_i[0])
95 | if self.training:
96 | hx_i[0] = hx_i[0] * hid_mask[:batch_size]
97 | if reverse:
98 | hx_n = hx_i
99 | output.reverse()
100 | else:
101 | hx_n.append(hx_i)
102 | hx_n = [torch.cat(h) for h in zip(*reversed(hx_n))]
103 | output = torch.cat(output)
104 |
105 | return output, hx_n
106 |
107 | def forward(self, sequence: PackedSequence, hx=None):
108 | r"""
109 | Args:
110 | sequence (~torch.nn.utils.rnn.PackedSequence):
111 | A packed variable length sequence.
112 | hx (~torch.Tensor, ~torch.Tensor):
113 | A tuple composed of two tensors `h` and `c`.
114 | `h` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the initial hidden state
115 | for each element in the batch.
116 | `c` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the initial cell state
117 | for each element in the batch.
118 | If `hx` is not provided, both `h` and `c` default to zero.
119 | Default: ``None``.
120 | Returns:
121 | ~torch.nn.utils.rnn.PackedSequence, (~List[torch.Tensor], ~torch.Tensor):
122 | The first is a list of packed variable length sequence for each layer.
123 | The second is a tuple of tensors `h` and `c`.
124 | `h` of shape ``[num_layers*num_directions, batch_size, hidden_size]``
125 | holds the hidden state for `t=seq_len`.
126 | Like output, the layers can be separated using
127 | ``h.view(num_layers, num_directions, batch_size, hidden_size)``
128 | and similarly for c.
129 | `c` of shape ``[num_layers*num_directions, batch_size, hidden_size]``
130 | holds the cell state for `t=seq_len`.
131 | """
132 | x, batch_sizes = sequence.data, sequence.batch_sizes.tolist()
133 | batch_size = batch_sizes[0]
134 | h_n, c_n, hiddens = [], [], []
135 |
136 | if hx is None:
137 | ih = x.new_zeros(self.num_layers * 2, batch_size, self.hidden_size)
138 | h, c = ih, ih
139 | else:
140 | h, c = hx
141 | h = h.view(self.num_layers, 2, batch_size, self.hidden_size)
142 | c = c.view(self.num_layers, 2, batch_size, self.hidden_size)
143 |
144 | for i in range(self.num_layers):
145 | x = torch.split(x, batch_sizes)
146 | if self.training and i > 0:
147 | mask = SharedDropout.get_mask(x[0], self.dropout)
148 | x = [i * mask[:len(i)] for i in x]
149 | x_i, (h_i, c_i) = self.layer_forward(x, (h[i, 0], c[i, 0]), self.f_cells[i], batch_sizes)
150 | x_b, (h_b, c_b) = self.layer_forward(x, (h[i, 1], c[i, 1]), self.b_cells[i], batch_sizes, True)
151 | x_i = torch.cat((x_i, x_b), -1)
152 | h_i = torch.stack((h_i, h_b))
153 | c_i = torch.stack((c_i, c_b))
154 | x = x_i
155 | h_n.append(h_i)
156 | c_n.append(c_i)
157 | hiddens.append(
158 | PackedSequence(x_i, sequence.batch_sizes, sequence.sorted_indices, sequence.unsorted_indices))
159 |
160 | hx = torch.cat(h_n, 0), torch.cat(c_n, 0)
161 | return hiddens, hx
162 |
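A minimal usage sketch (my addition): unlike nn.LSTM, only a PackedSequence is accepted, and one PackedSequence per layer is returned.

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from src.model.nn.variational_lstm import VariationalLSTM

lstm = VariationalLSTM(input_size=8, hidden_size=6, num_layers=2, dropout=0.33)
x = torch.randn(3, 5, 8)                              # batch x seq x feat
packed = pack_padded_sequence(x, torch.tensor([5, 4, 2]), batch_first=True)
outputs, (h, c) = lstm(packed)                        # list of PackedSequence, final states
top, _ = pad_packed_sequence(outputs[-1], batch_first=True)
print(top.shape)                                      # torch.Size([3, 5, 12]) = 2 * hidden_size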
--------------------------------------------------------------------------------
/src/model/text_encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from src.model.text_encoder.base import EncoderBase
2 | from src.model.text_encoder.rnn_encoder import RNNEncoder
3 | from src.model.text_encoder.mlp_encoder import MLPEncoder
4 | from src.model.text_encoder.blank_encoder import BlankEncoder
--------------------------------------------------------------------------------
/src/model/text_encoder/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | import torch.nn as nn
6 |
7 | if TYPE_CHECKING:
8 | from src.model.embedding import Embedding
9 | from src.model import ModelBase
10 |
11 |
12 | class EncoderBase(nn.Module):
13 | bounded_embedding: Embedding
14 | bounded_model: ModelBase
15 |
16 | def __init__(self, embedding: Embedding):
17 | super().__init__()
18 | self.__dict__['bounded_embedding'] = embedding
19 |
20 | def forward(self, x, ctx):
21 | raise NotImplementedError
22 |
23 | def get_dim(self, field):
24 | raise NotImplementedError(f'Unrecognized {field=}')
25 |
26 |
--------------------------------------------------------------------------------
/src/model/text_encoder/blank_encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 |
5 | import torch.nn as nn
6 | from torch import Tensor
7 |
8 | from src.model.embedding import Embedding
9 | from src.model.text_encoder.base import EncoderBase
10 | from src.model.nn import SharedDropout
11 | from src.utility.config import Config
12 | from src.utility.logger import get_logger_func
13 | from src.utility.var_pool import VarPool
14 |
15 | _warn, _info, _debug = get_logger_func('encoder')
16 |
17 |
18 | @dataclass
19 | class BlankEncoderConfig(Config):
20 | dropout: float
21 | shared_dropout: float
22 |
23 |
24 | class BlankEncoder(EncoderBase):
25 |
26 | def __init__(self, embedding: Embedding, **cfg):
27 | super().__init__(embedding)
28 | self.cfg = cfg = BlankEncoderConfig.build(cfg)
29 | self.output_size = embedding.embed_size
30 | self.dropout = nn.Dropout(cfg.dropout) if cfg.dropout > 0 else nn.Identity()
31 | self.shared_dropout = SharedDropout(cfg.shared_dropout) if cfg.shared_dropout > 0 else nn.Identity()
32 |
33 | def forward(self, x: Tensor, vp: VarPool, hiddens=None):
34 | x = self.dropout(x)
35 | x = self.shared_dropout(x)
36 | return {'x': x}
37 |
38 | def get_dim(self, field):
39 | return self.output_size
40 |
--------------------------------------------------------------------------------
/src/model/text_encoder/mlp_encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 |
5 | import torch.nn as nn
6 | from torch import Tensor
7 |
8 | from src.model.embedding import Embedding
9 | from src.model.text_encoder.base import EncoderBase
10 | from src.model.nn import SharedDropout
11 | from src.utility.config import Config
12 | from src.utility.logger import get_logger_func
13 | from src.utility.var_pool import VarPool
14 |
15 | _warn, _info, _debug = get_logger_func('encoder')
16 |
17 |
18 | @dataclass
19 | class MLPEncoderConfig(Config):
20 | dropout: float
21 | n_hidden: int
22 | shared_dropout: float
23 |
24 |
25 | class MLPEncoder(EncoderBase):
26 |
27 | def __init__(self, embedding: Embedding, **cfg):
28 | super().__init__(embedding)
29 | self.cfg = cfg = MLPEncoderConfig.build(cfg)
30 | self.output_size = cfg.n_hidden
31 | self.linear = nn.Linear(embedding.embed_size, self.output_size, bias=False)
32 | self.dropout = nn.Dropout(cfg.dropout) if cfg.dropout > 0 else nn.Identity()
33 | self.shared_dropout = SharedDropout(cfg.shared_dropout) if cfg.shared_dropout > 0 else nn.Identity()
34 |
35 | def forward(self, x: Tensor, vp: VarPool, hiddens=None):
36 | x = self.dropout(x)
37 | x = self.shared_dropout(x)
38 | x = self.linear(x)
39 | return {'x': x}
40 |
41 | def get_dim(self, field):
42 | return self.output_size
43 |
--------------------------------------------------------------------------------
/src/model/text_encoder/multi_encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 |
5 | from .base import EncoderBase
6 |
7 |
8 | class MultiEncoder(EncoderBase):
9 | """Compose multiple encoders into different output fields."""
10 |
11 | def __init__(self, embedding, mapping, ff=None, **encoders):
12 | """
13 | :param mapping: a dict indicating how to construct x.
14 | for example: mapping = {
15 | 'arc': ['shared_lstm.x', 'arc_lstm.x'],
16 | 'rel': ['shared_lstm.x', 'rel_lstm.x']
17 | }
18 | :param ff: a dict indicating passthrough variables
19 | e.g. ff = {
20 | 'hiddens': 'shared_lstm.hiddens'
21 | }
22 | :type mapping: dict
23 | """
24 | super().__init__(embedding)
25 |
26 | self.all_encoders = []
27 | for key, value in encoders.items():
28 | if key.startswith('_'):
29 | continue
30 | self.add_module(key, value)
31 | self.all_encoders.append(key)
32 |
33 | self.mapping = {} # {'shared_lstm': {'x': ['arc', 'rel']}, ...}
34 | self.output_fields = list(mapping.keys())
35 | self.dims = {o: 0 for o in self.output_fields}
36 | self.detailed_dims = {o: [] for o in self.output_fields}
37 | for target, sources in mapping.items():
38 | for source in sources:
39 | source_name, source_field = source.split('.')
40 | self.dims[target] += encoders[source_name].get_dim(source_field)
41 | self.detailed_dims[target].append(encoders[source_name].get_dim(source_field))
42 | if source_name not in self.mapping:
43 | self.mapping[source_name] = {}
44 | if source_field not in self.mapping[source_name]:
45 | self.mapping[source_name][source_field] = []
46 | self.mapping[source_name][source_field].append(target)
47 | self.ff = {}
48 | if ff is not None:
49 | for target, source in ff.items():
50 | source_name, source_field = source.split('.')
51 | assert target not in mapping, 'Conflict'
52 | if source_name not in self.ff:
53 | self.ff[source_name] = {}
54 | if source_field not in self.ff[source_name]:
55 | self.ff[source_name][source_field] = []
56 | self.ff[source_name][source_field].append(target)
57 |
58 | def forward(self, x, ctx):
59 | outputs = {key: [] for key in self.output_fields}
60 | for source_name in self.all_encoders:
61 | encoder_out = getattr(self, source_name)(x, ctx)
62 | if source_name in self.mapping:
63 | for encoder_field, targets in self.mapping[source_name].items():
64 | for target in targets:
65 | outputs[target].append(encoder_out[encoder_field])
66 | if source_name in self.ff:
67 | for encoder_field, targets in self.ff[source_name].items():
68 | for target in targets:
69 | outputs[target] = encoder_out[encoder_field]
70 | outputs = {
71 | key: torch.cat(value, dim=-1) if key in self.output_fields else value
72 | for key, value in outputs.items()
73 | }
74 |
75 | return outputs
76 |
77 | def get_dim(self, field):
78 | return self.dims[field]
79 |
80 |
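A sketch of the routing convention only (my addition; 'shared_lstm', 'arc_lstm' and 'rel_lstm' are hypothetical encoder names): every 'source.field' entry is looked up in that encoder's output dict and the pieces are concatenated per target, which is what forward() above does.

import torch

mapping = {'arc': ['shared_lstm.x', 'arc_lstm.x'],
           'rel': ['shared_lstm.x', 'rel_lstm.x']}
encoder_out = {                                   # pretend outputs of three sub-encoders
    'shared_lstm': {'x': torch.randn(2, 5, 8)},
    'arc_lstm':    {'x': torch.randn(2, 5, 4)},
    'rel_lstm':    {'x': torch.randn(2, 5, 4)},
}
outputs = {target: torch.cat([encoder_out[s.split('.')[0]][s.split('.')[1]] for s in sources], dim=-1)
           for target, sources in mapping.items()}
print({k: v.shape for k, v in outputs.items()})   # both 'arc' and 'rel' are (2, 5, 12)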
--------------------------------------------------------------------------------
/src/model/text_encoder/rnn_encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from typing import List, Union
5 |
6 | import torch
7 | import torch.nn as nn
8 | from omegaconf import MISSING
9 | from torch import Tensor
10 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
11 |
12 | from src.model.embedding import Embedding
13 | from src.model.text_encoder.base import EncoderBase
14 | from src.model.nn import ScalarMix, SharedDropout, VariationalLSTM
15 | from src.utility.config import Config
16 | from src.utility.logger import get_logger_func
17 | from src.utility.var_pool import VarPool
18 |
19 | _warn, _info, _debug = get_logger_func('encoder')
20 | RNN_TYPE_DICT = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN}
21 | RNNCELL_TYPE_DICT = {'lstm': nn.LSTMCell, 'gru': nn.GRUCell, 'rnn': nn.RNNCell}
22 |
23 |
24 | @dataclass
25 | class LSTMEncoderConfig(Config):
26 | reproject_emb: int = 0 # reproject layer before lstm
27 | # ============================= dropout ==============================
28 | pre_shared_dropout: float = 0.
29 | pre_dropout: float = 0.
30 | post_shared_dropout: float = 0.
31 | post_dropout: float = 0.
32 |
33 | # =============================== lstm ===============================
34 | rnn_type: str = 'lstm' # lstm, gru or rnn
35 | hidden_size: Union[int, List[int]] = MISSING # hidden size for each layer
36 | proj_size: int = 0 # projective size
37 | num_layers: int = MISSING # total layers
38 | output_layers: Union[int, List[int]] = -1 # which layers are returned, starting from 0
39 | init_version: str = 'biased'
40 | shared_dropout: bool = True
41 | lstm_dropout: float = 0.33 # only between layers, unlike zhangyu.
42 | no_eos: bool = False # simulate the absence of an eos token
43 | sorted: bool = True
44 |
45 | # ============================== output ==============================
46 | mix: bool = False # whether to use a ScalarMix when there are multiple outputs
47 | reproject_out: int = 0
48 | cat_emb: bool = False
49 |
50 |
51 | class RNNEncoder(EncoderBase):
52 |
53 | def __init__(self, embedding: Embedding, **cfg):
54 | super().__init__(embedding)
55 | self.cfg = cfg = LSTMEncoderConfig.build(cfg)
56 |
57 | # check output_layers
58 | output_layers: List[int] = [cfg.output_layers] if isinstance(cfg.output_layers, int) else cfg.output_layers
59 | output_layers = sorted(cfg.num_layers + o if o < 0 else o for o in output_layers)
60 | assert output_layers[0] >= 0 and output_layers[-1] < cfg.num_layers
61 | if output_layers[-1] < cfg.num_layers - 1:
62 | cfg.num_layers = output_layers[-1] + 1
63 | _warn(f'max index of output_layers is smaller than num_layers - 1, so num_layers is reduced to {cfg.num_layers}')
64 | self.output_layers = output_layers
65 |
66 | self.embedding2nn = nn.Linear(embedding.embed_size, cfg.reproject_emb) if cfg.reproject_emb else nn.Identity()
67 |
68 | # ============================= dropout ==============================
69 |
70 | self.pre_shared_dropout = SharedDropout(cfg.pre_shared_dropout) if cfg.pre_shared_dropout else nn.Identity()
71 | self.pre_dropout = nn.Dropout(cfg.pre_dropout) if cfg.pre_dropout else nn.Identity()
72 | self.post_shared_dropout = SharedDropout(cfg.post_shared_dropout) if cfg.post_shared_dropout else nn.Identity()
73 | self.post_dropout = nn.Dropout(cfg.post_dropout) if cfg.post_dropout else nn.Identity()
74 |
75 | # =============================== lstm ===============================
76 |
77 | input_size = cfg.reproject_emb if cfg.reproject_emb > 0 else embedding.embed_size
78 | if cfg.shared_dropout:
79 | assert isinstance(cfg.hidden_size, int), 'Not supported'
80 | assert cfg.proj_size == 0, 'Not supported'
81 | self.lstm = VariationalLSTM(input_size, cfg.hidden_size, cfg.num_layers, cfg.lstm_dropout,
82 | RNNCELL_TYPE_DICT[cfg.rnn_type])
83 | self.output_size = 2 * cfg.hidden_size
84 | else:
85 | # figure out how many layers in each sub modules
86 | layer_for_each_rnn = [x - y for x, y in zip(output_layers, [-1] + output_layers[:-1])]
87 |
88 | # check hiddens
89 | if isinstance(cfg.hidden_size, int):
90 | hiddens = [cfg.hidden_size for _ in layer_for_each_rnn]
91 | else:
92 | hiddens = cfg.hidden_size
93 | assert len(hiddens) == len(layer_for_each_rnn)
94 |
95 | # construct nn
96 | self.lstm_dropout = nn.Dropout(cfg.lstm_dropout)
97 | self.lstm = nn.ModuleList()
98 | rnn_type = RNN_TYPE_DICT[cfg.rnn_type]
99 | for n_layer, hidden in zip(layer_for_each_rnn, hiddens):
100 | sub_lstm = rnn_type(input_size,
101 | hidden,
102 | n_layer,
103 | dropout=cfg.lstm_dropout if n_layer > 1 else 0,
104 | bidirectional=True,
105 | proj_size=cfg.proj_size if hidden > cfg.proj_size > 0 else 0)
106 | self.lstm.append(sub_lstm)
107 | input_size = 2 * cfg.proj_size if cfg.proj_size else 2 * hidden
108 | self.output_size = 2 * cfg.proj_size if cfg.proj_size else 2 * hiddens[-1]
109 |
110 | if cfg.mix:
111 | assert isinstance(cfg.hidden_size, int) or all(h == cfg.hidden_size[0] for h in cfg.hidden_size), \
112 | 'mix can only be used when all layers have the same hidden size.'
113 | self.mix = ScalarMix(len(output_layers))
114 | else:
115 | self.output_size *= len(output_layers)
116 |
117 | if cfg.reproject_out:
118 | self.nn2out = nn.Linear(self.output_size, cfg.reproject_out)
119 | self.output_size = cfg.reproject_out
120 | else:
121 | self.nn2out = nn.Identity()
122 |
123 | if cfg.cat_emb:
124 | self.output_size += embedding.embed_size
125 |
126 | self.reset_parameters(cfg.init_version)
127 |
128 | def reset_parameters(self, init_method):
129 | if init_method == 'zy':
130 | for name, param in self.named_parameters():
131 | if name.startswith('lstm'):
132 | # apply orthogonal_ to weight
133 | if len(param.shape) > 1:
134 | nn.init.orthogonal_(param)
135 | # apply zeros_ to bias
136 | else:
137 | nn.init.zeros_(param)
138 | elif init_method == 'biased':
139 | for name, param in self.named_parameters():
140 | if name.startswith('lstm'):
141 | # apply orthogonal_ to weight
142 | if len(param.shape) > 1:
143 | nn.init.xavier_uniform_(param)
144 | else:
145 | # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871
146 | param.data.fill_(0.)
147 | n = param.shape[0]
148 | start, end = n // 4, n // 2
149 | param.data[start:end].fill_(1.)
150 | # else:
151 | # raise ValueError(f'Bad init_version, {self.cfg.init_version=}')
152 |
153 | def forward(self, x: Tensor, vp: VarPool, hiddens=None):
154 | """
155 | :param x: output of embedding
156 | :param vp: the varpool
157 | :param hiddens: initial hidden states (e.g., for truncated BPTT)
158 | :return: a dict containing
159 | x, Tensor: the concatenated or mixed representation.
160 | all: List[Tensor], a list containing all outputs specified in self.output_layers.
161 | hx: the output state for all layers.
162 | """
163 | if isinstance(x, list): x = torch.cat(x, dim=-1)
164 |
165 | emb = x
166 | x = self.embedding2nn(x)
167 | x = self.pre_shared_dropout(x)
168 | x = self.pre_dropout(x)
169 | xs, hx = self.lstm_forward(x, vp, hiddens)
170 | if self.cfg.mix:
171 | x = self.mix(xs)
172 | else:
173 | x = torch.cat(xs, dim=-1)
174 | x = self.post_dropout(x)
175 | x = self.post_shared_dropout(x)
176 | if self.cfg.no_eos:
177 | x = torch.cat([x, torch.zeros(x.shape[0], 1, x.shape[2], device=x.device)], dim=1)
178 | x = self.nn2out(x)
179 |
180 | if self.cfg.cat_emb:
181 | x = torch.cat([x, emb], dim=-1)
182 |
183 | # from src.utility.fn import draw_att
184 | # draw_att(x[0])
185 |
186 | return {'x': x, 'all': xs, 'hiddens': hx}
187 |
188 | def lstm_forward(self, x: Tensor, vp: VarPool, hiddens=None):
189 | if self.cfg.no_eos:
190 | x = x[:, :-1]
191 | x = pack_padded_sequence(x, vp.seq_len_cpu - 1, True, enforce_sorted=self.cfg.sorted)
192 | else:
193 | x = pack_padded_sequence(x, vp.seq_len_cpu, True, enforce_sorted=self.cfg.sorted)
194 |
195 | if self.cfg.shared_dropout:
196 | outputs, (hx, _) = self.lstm(x, hiddens)
197 | outputs = [outputs[i] for i in self.output_layers]
198 | outputs = [pad_packed_sequence(o, True)[0] for o in outputs]
199 | else:
200 | layer_count = -1
201 | outputs = []
202 | output_layers = self.output_layers.copy()
203 | hx = []
204 | hiddens = hiddens if hiddens is not None else [None] * len(self.lstm)
205 |
206 | for layer, hidden in zip(self.lstm, hiddens):
207 | output: PackedSequence
208 | output, (hx_, _) = layer(x, hidden)
209 | hx.append(hx_)
210 |
211 | layer_count += layer.num_layers
212 | if layer_count == output_layers[0]:
213 | output_layers.pop(0)
214 | outputs.append(pad_packed_sequence(output, True)[0])
215 |
216 | data = self.lstm_dropout(output.data)
217 | x = PackedSequence(data, output.batch_sizes, output.sorted_indices, output.unsorted_indices)
218 | hx = torch.cat(hx, 0)
219 | return outputs, hx
220 |
221 | def get_dim(self, field):
222 | if field == 'x' or field == 'all':
223 | return self.output_size
224 | return super().get_dim(field)
225 |
--------------------------------------------------------------------------------
/src/model/torch_struct/__init__.py:
--------------------------------------------------------------------------------
1 | from .distributions import DMV1o, DependencyCRF, StructDistribution
2 | from .semirings import (
3 | CheckpointSemiring,
4 | CheckpointShardSemiring,
5 | EntropySemiring,
6 | FastLogSemiring,
7 | FastMaxSemiring,
8 | FastSampleSemiring,
9 | GumbelCRFSemiring,
10 | KMaxSemiring,
11 | LogSemiring,
12 | MaxSemiring,
13 | MultiSampledSemiring,
14 | SampledSemiring,
15 | SparseMaxSemiring,
16 | StdSemiring,
17 | TempMax,
18 | )
19 |
20 | version = "0.4"
21 |
22 | # For flake8 compatibility.
23 | __all__ = [
24 | LogSemiring,
25 | StdSemiring,
26 | SampledSemiring,
27 | MaxSemiring,
28 | SparseMaxSemiring,
29 | KMaxSemiring,
30 | FastLogSemiring,
31 | FastMaxSemiring,
32 | FastSampleSemiring,
33 | EntropySemiring,
34 | MultiSampledSemiring,
35 | GumbelCRFSemiring,
36 | StructDistribution,
37 | DMV1o,
38 | DependencyCRF,
39 | CheckpointSemiring,
40 | CheckpointShardSemiring,
41 | TempMax,
42 | ]
43 |
--------------------------------------------------------------------------------
/src/model/torch_struct/dmv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 |
4 | from .helpers import _Struct
5 | from .semirings import Semiring
6 |
7 | NOCHILD = 1
8 | HASCHILD = 0
9 | LEFT = 0
10 | RIGHT = 1
11 | GO = 0
12 | STOP = 1
13 | DIR_NUM = 2
14 | VAL_NUM = 2
15 | DEC_NUM = 2
16 |
17 |
18 | class DMV1oStruct(_Struct):
19 | def _dp(self, scores, lengths=None, force_grad=False, cache=False):
20 | # dec, attach
21 | s: Semiring = self.semiring
22 |
23 | if isinstance(scores[0], torch.Tensor):
24 | # attach_score: batch, N, N, valence
25 | # dec_score: batch, N, direction, valence, decision
26 | attach: Tensor = s.convert(scores[1])
27 | dec: Tensor = s.convert(scores[0])
28 | else:
29 | attach: Tensor = s.convert([scores[0][1], scores[1][1]])
30 | dec: Tensor = s.convert([scores[0][0], scores[1][0]])
31 |
32 | _, batch, N, *_ = dec.shape
33 | # diagonal for left, diagonal(1) for right.
34 | I = s.zero_(attach.new_empty((s.size(), batch, N + 1, N + 1, VAL_NUM)))
35 | C = s.zero_(attach.new_empty((s.size(), batch, N + 1, N + 1, VAL_NUM)))
36 | attach_left = s.mul(attach, dec[:, :, :, None, LEFT, :, GO])
37 | attach_right = s.mul(attach, dec[:, :, :, None, RIGHT, :, GO])
38 |
39 | diag_minus1(C, 0, 2, 3).copy_(dec[:, :, :, LEFT, :, STOP].transpose(-2, -1))
40 | C.diagonal(1, 2, 3).copy_(dec[:, :, :, RIGHT, :, STOP].transpose(-2, -1))
41 | _zero = C.new_tensor(s.zero)
42 | if _zero.ndim == 0:
43 | _zero = _zero.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
44 | else:
45 | _zero = _zero.unsqueeze(-1).unsqueeze(-1)
46 |
47 | for w in range(1, N):
48 | n = N - w
49 |
50 | x = s.sum(s.mul(stripe_val(C, n, w, (0, 1, NOCHILD)), stripe_val(C, n, w, (w, 1, HASCHILD))))
51 | x = s.times(x.unsqueeze(-2), attach_left.diagonal(-w, -3, -2))
52 | diag_minus1(I, -w, -3, -2).copy_(x)
53 |
54 | x = s.sum(s.mul(stripe_val(C, n, w, (0, 1, HASCHILD)), stripe_val(C, n, w, (w, 1, NOCHILD))))
55 | x = s.times(x.unsqueeze(-2), attach_right.diagonal(w, -3, -2))
56 | I.diagonal(w + 1, -3, -2).copy_(x)
57 |
58 | x = s.sum(s.mul(stripe_val(C, n, w, (0, 0, NOCHILD), 0, True), stripe_noval(I, n, w, (w, 0))), -2)
59 | diag_minus1(C, -w, -3, -2).copy_(x.transpose(-2, -1))
60 |
61 | x = s.sum(s.mul(stripe_noval(I, n, w, (0, 2)), stripe_val(C, n, w, (1, w + 1, NOCHILD), 0, True)), -2)
62 | C.diagonal(w + 1, -3, -2).copy_(x.transpose(-2, -1))
63 | C[:, lengths.ne(w), 0, w + 1] = _zero
64 |
65 | v = torch.gather(C[:, :, 0, :, NOCHILD], -1, (lengths[None, ..., None] + 1).expand(s.size(), -1, -1))
66 | return v, [dec, attach], [C, I]
67 |
68 | def _arrange_marginals(self, marg):
69 | return marg[1] # return attach
70 |
71 |
72 | def stripe_val(x: Tensor, n, w, offset=(0, 0, 0), dim=1, keep_val=False):
73 | # x: s x b x N x N x valence
74 | # on the last three dim, N x N x valence
75 | # n and w are for N x N
76 | assert x.shape[-1] == 2
77 | assert x.is_contiguous(), 'x must be contiguous, otherwise writes to the strided view will be lost.'
78 | seq_len = x.shape[-2]
79 | if keep_val:
80 | size = (*x.shape[:-3], n, w, 1)
81 | stride = list(x.stride())
82 | stride[-3] = (seq_len + 1) * 2
83 | stride[-2] = (1 if dim == 1 else seq_len) * 2
84 | else:
85 | stride = list(x.stride())[:-1]
86 | stride[-2] = (seq_len + 1) * 2
87 | stride[-1] = (1 if dim == 1 else seq_len) * 2
88 | size = (*x.shape[:-3], n, w)
89 | return x.as_strided(size=size,
90 | stride=stride,
91 | storage_offset=x.storage_offset() + (offset[0] * seq_len * 2 + offset[1] * 2 + offset[2]))
92 |
93 |
94 | def stripe_noval(x: Tensor, n, w, offset=(0, 0), dim=1):
95 | # x: s x b x N x N x valence
96 | # on the last three dim, N x N x valence
97 | # n and w are for N x N
98 | assert x.shape[-1] == 2
99 | assert x.is_contiguous(), 'x must be contiguous, otherwise writes to the strided view will be lost.'
100 | seq_len = x.shape[-2]
101 | stride = list(x.stride())
102 | stride[-3] = (seq_len + 1) * 2
103 | stride[-2] = (1 if dim == 1 else seq_len) * 2
104 | return x.as_strided(size=(*x.shape[:-3], n, w, 2),
105 | stride=stride,
106 | storage_offset=x.storage_offset() + (offset[0] * seq_len * 2 + offset[1] * 2))
107 |
108 |
109 | def diag_minus1(x: Tensor, offset, dim1, dim2) -> Tensor:
110 | # assume a[..., dim1, ..., dim2, ...]
111 | stride = list(x.stride())
112 | if offset > 0:
113 | storage_offset = stride[dim2] * offset
114 | else:
115 | storage_offset = stride[dim1] * abs(offset)
116 | to_append = stride[dim1] + stride[dim2]
117 | if dim2 < 0:
118 | stride.pop(dim1)
119 | stride.pop(dim2)
120 | else:
121 | stride.pop(dim2)
122 | stride.pop(dim1) # TODO: handle mixed-sign dims (+/- or -/+); currently only +/+ and -/- are supported
123 | stride.append(to_append)
124 | size = list(x.size())
125 | to_append = size[dim1] - 1 - abs(offset)
126 | if dim2 < 0:
127 | size.pop(dim1)
128 | size.pop(dim2)
129 | else:
130 | size.pop(dim2)
131 | size.pop(dim1)
132 | size.append(to_append)
133 | return x.as_strided(size, stride, storage_offset=storage_offset)
134 |
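A sanity-check sketch (my addition, an assumption on my part rather than documented behaviour): on square trailing dims, diag_minus1 should match torch.diagonal with the last diagonal element dropped.

import torch
from src.model.torch_struct.dmv import diag_minus1

x = torch.arange(2 * 5 * 5 * 2.).view(2, 5, 5, 2)
a = diag_minus1(x, 1, -3, -2)
b = x.diagonal(1, -3, -2)[..., :-1]
print(a.shape, torch.equal(a, b))   # torch.Size([2, 2, 3]) True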
--------------------------------------------------------------------------------
/src/model/torch_struct/helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Tuple, Union
3 |
4 | import torch
5 | from torch import Tensor
6 | from torch.autograd import Function
7 |
8 | from .semirings import LogSemiring, Semiring
9 |
10 |
11 | class Get(Function):
12 | @staticmethod
13 | def forward(ctx, chart, grad_chart, indices):
14 | ctx.save_for_backward(grad_chart)
15 | out = chart[indices]
16 | ctx.indices = indices
17 | return out
18 |
19 | @staticmethod
20 | def backward(ctx, grad_output):
21 | (grad_chart, ) = ctx.saved_tensors
22 | grad_chart[ctx.indices] += grad_output
23 | return grad_chart, None, None
24 |
25 |
26 | class Set(torch.autograd.Function):
27 | @staticmethod
28 | def forward(ctx, chart, indices, vals):
29 | chart[indices] = vals
30 | ctx.indices = indices
31 | return chart
32 |
33 | @staticmethod
34 | def backward(ctx, grad_output):
35 | z = grad_output[ctx.indices]
36 | return None, None, z
37 |
38 |
39 | class Chart:
40 | def __init__(self, size, potentials, semiring: Semiring, cache=True):
41 | self.data = semiring.zero_(
42 | torch.empty(*((semiring.size(), ) + size), dtype=potentials.dtype, device=potentials.device))
43 | self.grad = self.data.detach().clone().fill_(0.0)
44 | self.cache = cache
45 | self.semiring = semiring
46 |
47 | def __getitem__(self, ind):
48 | I = slice(None)
49 | if self.cache:
50 | return Get.apply(self.data, self.grad, (I, I) + ind)
51 | else:
52 | return self.data[(I, I) + ind]
53 |
54 | def __setitem__(self, ind, new):
55 | I = slice(None)
56 | if self.cache:
57 | self.data = Set.apply(self.data, (I, I) + ind, new)
58 | else:
59 | self.data[(I, I) + ind] = new
60 |
61 | def get(self, ind):
62 | return Get.apply(self.data, self.grad, ind)
63 |
64 | def set(self, ind, new):
65 | self.data = Set.apply(self.data, ind, new)
66 |
67 |
68 | class _Struct:
69 | def __init__(self, semiring: Semiring = LogSemiring):
70 | self.semiring = semiring
71 |
72 | def score(self, potentials: Tensor, parts: Tensor, batch_dims=(0, )) -> Tensor:
73 | """gather all score in parts"""
74 | score = torch.mul(potentials, parts)
75 | batch = tuple((score.shape[b] for b in batch_dims))
76 | return self.semiring.prod(score.view(batch + (-1, )))
77 |
78 | def _bin_length(self, length: int) -> Tuple[int, int]:
79 | log_N = int(math.ceil(math.log(length, 2)))
80 | bin_N = int(math.pow(2, log_N))
81 | return log_N, bin_N
82 |
83 | def _get_dimension_and_requires_grad(self, edge: Union[List[Tensor], Tensor]) -> Tuple[int, ...]:
84 | if isinstance(edge, (list, tuple)):
85 | for t in edge:
86 | t.requires_grad_(True)
87 | return edge[0].shape
88 | else:
89 | edge.requires_grad_(True)
90 | return edge.shape
91 |
92 | def _chart(self, size, potentials, force_grad):
93 | return self._make_chart(1, size, potentials, force_grad)[0]
94 |
95 | def _make_chart(self, N, size, potentials, force_grad=False):
96 | return [(self.semiring.zero_(
97 | torch.zeros(*((self.semiring.size(), ) + size), dtype=potentials.dtype,
98 | device=potentials.device)).requires_grad_(force_grad and not potentials.requires_grad))
99 | for _ in range(N)]
100 |
101 | def sum(self, edge, lengths=None, _raw=False):
102 | """
103 | Compute the (semiring) sum over all structures model.
104 |
105 | Parameters:
106 | edge : generic params (see class)
107 | lengths: None or b long tensor mask
108 |
109 | Returns:
110 | v: b tensor of total sum
111 | """
112 |
113 | v = self._dp(edge, lengths)[0]
114 | if _raw:
115 | return v
116 | return self.semiring.unconvert(v)
117 |
118 | def marginals(self, edge, lengths=None, _autograd=True, _raw=False, _combine=False):
119 | """
120 | Compute the marginals of a structured model.
121 |
122 | Parameters:
123 | params : generic params (see class)
124 | lengths: None or b long tensor mask
125 | Returns:
126 | marginals: b x (N-1) x C x C table
127 |
128 | """
129 | if (_autograd or self.semiring is not LogSemiring or not hasattr(self, '_dp_backward')):
130 | v, edges, _ = self._dp(edge, lengths=lengths, force_grad=True, cache=not _raw)
131 | if _raw:
132 | all_m = []
133 | for k in range(v.shape[0]):
134 | obj = v[k].sum(dim=0)
135 |
136 | marg = torch.autograd.grad(
137 | obj,
138 | edges,
139 | create_graph=True,
140 | only_inputs=True,
141 | allow_unused=False,
142 | )
143 | all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
144 | return torch.stack(all_m, dim=0)
145 | elif _combine:
146 | obj = v.sum(dim=0).sum(dim=0)
147 | marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False)
148 | a_m = self._arrange_marginals(marg)
149 | return a_m
150 | else:
151 | obj = self.semiring.unconvert(v).sum(dim=0)
152 | marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False)
153 | a_m = self._arrange_marginals(marg)
154 | return self.semiring.unconvert(a_m)
155 | else:
156 | v, _, alpha = self._dp(edge, lengths=lengths, force_grad=True)
157 | return self._dp_backward(edge, lengths, alpha)
158 |
159 | @staticmethod
160 | def to_parts(spans, extra, lengths=None):
161 | return spans
162 |
163 | @staticmethod
164 | def from_parts(spans):
165 | return spans, None
166 |
167 | def _arrange_marginals(self, marg):
168 | return marg[0]
169 |
170 | def _dp(self, scores, lengths=None, force_grad=False, cache=True):
171 | raise NotImplementedError
172 |
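The marginals() method above relies on the identity that differentiating a log-partition with respect to the (log) potentials yields marginal probabilities. A toy illustration of that identity (my addition; independent binary "edges" rather than a real chart):

import torch

logits = torch.randn(4, requires_grad=True)
log_Z = torch.nn.functional.softplus(logits).sum()   # log-partition of independent Bernoullis
marginals, = torch.autograd.grad(log_Z, logits)
print(torch.allclose(marginals, logits.sigmoid()))   # True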
--------------------------------------------------------------------------------
/src/model/torch_struct/semirings/__init__.py:
--------------------------------------------------------------------------------
1 | from .checkpoint import CheckpointSemiring, CheckpointShardSemiring
2 | from .fast_semirings import FastLogSemiring, FastMaxSemiring, FastSampleSemiring
3 | from .sample import GumbelCRFSemiring, MultiSampledSemiring, SampledSemiring
4 | from .semirings import (CrossEntropySemiring, EntropySemiring, KLDivergenceSemiring, KMaxSemiring, LogSemiring,
5 | MaxSemiring, RiskSemiring, Semiring, StdSemiring, TempMax)
6 | from .sparse_max import SparseMaxSemiring
7 |
8 | # For flake8 compatibility.
9 | __all__ = [
10 | Semiring,
11 | FastLogSemiring,
12 | FastMaxSemiring,
13 | FastSampleSemiring,
14 | LogSemiring,
15 | StdSemiring,
16 | SampledSemiring,
17 | MaxSemiring,
18 | SparseMaxSemiring,
19 | KMaxSemiring,
20 | EntropySemiring,
21 | CrossEntropySemiring,
22 | KLDivergenceSemiring,
23 | MultiSampledSemiring,
24 | CheckpointSemiring,
25 | CheckpointShardSemiring,
26 | TempMax,
27 | ]
28 |
--------------------------------------------------------------------------------
/src/model/torch_struct/semirings/checkpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | try:
4 | import genbmm
5 | from genbmm import BandedMatrix
6 | except ImportError:
7 | pass
8 |
9 |
10 | def broadcast_size(a, b):
11 | return torch.tensor([max(i, j) for i, j in zip(a.shape, b.shape)]).prod()
12 |
13 |
14 | def matmul_size(a, b):
15 | size = [max(i, j) for i, j in zip(a.shape[:-2], b.shape[:-2])]
16 | size.append(a.shape[-2])
17 | size.append(b.shape[-1])
18 | return size
19 |
20 |
21 | def CheckpointSemiring(cls, min_size=0):
22 | class _Check(torch.autograd.Function):
23 | @staticmethod
24 | def forward(ctx, a, b):
25 | ctx.save_for_backward(a, b)
26 | return cls.matmul(a, b)
27 |
28 | @staticmethod
29 | def backward(ctx, grad_output):
30 | a, b = ctx.saved_tensors
31 | with torch.enable_grad():
32 | q = cls.matmul(a, b)
33 | return torch.autograd.grad(q, (a, b), grad_output)
34 |
35 | class _CheckBand(torch.autograd.Function):
36 | @staticmethod
37 | def forward(ctx, a, a_lu, a_ld, b, b_lu, b_ld):
38 | ctx.save_for_backward(a, b, torch.LongTensor([a_lu, a_ld, b_lu, b_ld]))
39 | a = BandedMatrix(a, a_lu, a_ld)
40 | b = BandedMatrix(b, b_lu, b_ld)
41 | return cls.matmul(a, b).data
42 |
43 | @staticmethod
44 | def backward(ctx, grad_output):
45 | a, b, bands = ctx.saved_tensors
46 | a_lu, a_ld, b_lu, b_ld = bands.tolist()
47 | with torch.enable_grad():
48 | q = cls.matmul(BandedMatrix(a, a_lu, a_ld), BandedMatrix(b, b_lu, b_ld))
49 | grad_a, grad_b = torch.autograd.grad(q.data, (a, b), grad_output)
50 | return grad_a, None, None, grad_b, None, None
51 |
52 | class _CheckpointSemiring(cls):
53 | @staticmethod
54 | def matmul(a, b):
55 | if isinstance(a, genbmm.BandedMatrix):
56 | lu = a.lu + b.lu
57 | ld = a.ld + b.ld
58 | c = _CheckBand.apply(a.data, a.lu, a.ld, b.data, b.lu, b.ld)
59 | return BandedMatrix(c, lu, ld, cls.zero)
60 |
61 | if broadcast_size(a, b) > min_size:
62 | return _Check.apply(a, b)
63 | else:
64 | return cls.matmul(a, b)
65 |
66 | return _CheckpointSemiring
67 |
68 |
69 | def CheckpointShardSemiring(cls, max_size, min_size=0):
70 | class _Check(torch.autograd.Function):
71 | @staticmethod
72 | def forward(ctx, a, b):
73 | ctx.save_for_backward(a, b)
74 | size = matmul_size(a, b)
75 | return accumulate_(
76 | a,
77 | b,
78 | size,
79 | lambda a, b: cls.matmul(a, b),
80 | preserve=len(size),
81 | step=max_size // (b.shape[-2] * a.shape[-1]) + 2,
82 | )
83 |
84 | @staticmethod
85 | def backward(ctx, grad_output):
86 | a, b = ctx.saved_tensors
87 | grad_a, grad_b = unaccumulate_(
88 | a,
89 | b,
90 | grad_output,
91 | len(grad_output.shape),
92 | lambda a, b: cls.matmul(a, b),
93 | step=max_size // (b.shape[-2] * a.shape[-1]) + 2,
94 | )
95 | return grad_a, grad_b
96 |
97 | class _CheckpointSemiring(cls):
98 | @staticmethod
99 | def matmul(a, b):
100 | size = torch.tensor([max(i, j) for i, j in zip(a.shape, b.shape)]).prod()
101 | if size < min_size:
102 | return cls.matmul(a, b)
103 | else:
104 | return _Check.apply(a, b)
105 |
106 | return _CheckpointSemiring
107 |
108 |
109 | def ones(x):
110 | one = []
111 | for i, v in enumerate(x.shape[:-1]):
112 | if v == 1:
113 | one.append(i)
114 | return one
115 |
116 |
117 | def mind(one, inds):
118 | inds = list(inds)
119 | for v in one:
120 | inds[v] = inds[v].clone().fill_(0)
121 | return inds
122 |
123 |
124 | def accumulate_(a, b, size, fn, preserve, step=10000):
125 | slices = []
126 | total = 1
127 | for s in size[:preserve]:
128 | slices.append(slice(s))
129 | total *= s
130 | if step > total:
131 | return fn(a, b)
132 |
133 | ret = torch.zeros(*size, dtype=a.dtype, device=a.device)
134 |
135 | a = a.expand(*size[:-2], a.shape[-2], a.shape[-1])
136 | b = b.expand(*size[:-2], b.shape[-2], b.shape[-1])
137 |
138 | a2 = a.contiguous().view(-1, a.shape[-2], a.shape[-1])
139 | b2 = b.contiguous().view(-1, b.shape[-2], b.shape[-1])
140 | ret = ret.view(-1, a.shape[-2], b.shape[-1])
141 | for p in range(0, ret.shape[0], step):
142 | ret[p:p + step, :] = fn(a2[p:p + step], b2[p:p + step])
143 | ret = ret.view(*size)
144 | return ret
145 |
146 |
147 | def unaccumulate_(a, b, grad_output, preserve, fn, step=10000):
148 | slices = []
149 | total = 1
150 | size = grad_output.shape[:preserve]
151 | for s in grad_output.shape[:preserve]:
152 | slices.append(slice(s))
153 | total *= s
154 |
155 | if step > total:
156 | with torch.enable_grad():
157 | a_in = a.clone().requires_grad_(True)
158 | b_in = b.clone().requires_grad_(True)
159 | q = fn(a_in, b_in)
160 | ag, bg = torch.autograd.grad(q, (a_in, b_in), grad_output)
161 | return ag, bg
162 |
163 | a2 = a.expand(*size[:-2], a.shape[-2], a.shape[-1])
164 | b2 = b.expand(*size[:-2], b.shape[-2], b.shape[-1])
165 | a2 = a2.contiguous().view(-1, a.shape[-2], a.shape[-1])
166 | b2 = b2.contiguous().view(-1, b.shape[-2], b.shape[-1])
167 |
168 | a_grad = a2.clone().fill_(0)
169 | b_grad = b2.clone().fill_(0)
170 |
171 | grad_output = grad_output.view(-1, a.shape[-2], b.shape[-1])
172 | for p in range(0, grad_output.shape[0], step):
173 | with torch.enable_grad():
174 | a_in = a2[p:p + step].clone().requires_grad_(True)
175 | b_in = b2[p:p + step].clone().requires_grad_(True)
176 | q = fn(a_in, b_in)
177 | ag, bg = torch.autograd.grad(q, (a_in, b_in), grad_output[p:p + step])
178 | a_grad[p:p + step] += ag
179 | b_grad[p:p + step] += bg
180 |
181 | a_grad = a_grad.view(*size[:-2], a.shape[-2], a.shape[-1])
182 | b_grad = b_grad.view(*size[:-2], b.shape[-2], b.shape[-1])
183 | a_ones = ones(a)
184 | b_ones = ones(b)
185 | f1, f2 = a_grad.sum(a_ones, keepdim=True), b_grad.sum(b_ones, keepdim=True)
186 | return f1, f2
187 |
188 |
189 | # def unaccumulate_(a, b, grad_output, fn, step=10000):
190 | # slices = []
191 | # a_grad = a.clone().fill_(0)
192 | # b_grad = b.clone().fill_(0)
193 |
194 | # total = 1
195 | # for s in grad_output.shape:
196 | # slices.append(slice(s))
197 | # total *= s
198 | # a_one, b_one = ones(a), ones(b)
199 |
200 | # indices = torch.tensor(np.mgrid[slices]).view(len(grad_output.shape), -1)
201 |
202 | # for p in range(0, total, step):
203 | # ind = indices[:, p : p + step].unbind()
204 | # a_ind = mind(a_one, ind)
205 | # b_ind = mind(b_one, ind)
206 |
207 | # q = fn(a[tuple(a_ind)], b[tuple(b_ind)], grad_output[tuple(ind)])
208 | # a_grad.index_put_(tuple(a_ind), q, accumulate=True)
209 | # b_grad.index_put_(tuple(b_ind), q, accumulate=True)
210 | # return a_grad, b_grad
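Usage sketch (my addition): wrapping a semiring so that matmul saves only its two inputs and recomputes the forward pass during backward, in the spirit of torch.utils.checkpoint.

from src.model.torch_struct.semirings import CheckpointSemiring, LogSemiring

CheckpointedLog = CheckpointSemiring(LogSemiring, min_size=10_000)
# CheckpointedLog.matmul(a, b) falls back to LogSemiring.matmul for small products and
# otherwise routes through _Check, which re-runs the matmul inside backward().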
211 |
--------------------------------------------------------------------------------
/src/model/torch_struct/semirings/fast_semirings.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions
3 |
4 | from .sample import _SampledLogSumExp
5 | from .semirings import _BaseLog
6 |
7 | try:
8 | import genbmm
9 | except ImportError:
10 | pass
11 |
12 |
13 | def matmul_size(a, b):
14 | size = [max(i, j) for i, j in zip(a.shape[:-2], b.shape[:-2])]
15 | size.append(a.shape[-2])
16 | size.append(b.shape[-1])
17 | return size
18 |
19 |
20 | def broadcast(a, b):
21 | size = matmul_size(a, b)
22 | a = a.expand(*size[:-2], a.shape[-2], a.shape[-1])
23 | b = b.expand(*size[:-2], b.shape[-2], b.shape[-1])
24 | a2 = a.contiguous().view(-1, a.shape[-2], a.shape[-1])
25 | b2 = b.contiguous().view(-1, b.shape[-2], b.shape[-1])
26 | return a2, b2, size
27 |
28 |
29 | class FastLogSemiring(_BaseLog):
30 | """
31 | Implements the log-space semiring (logsumexp, +, -inf, 0).
32 |
33 | Gradients give marginals.
34 | """
35 | @staticmethod
36 | def sum(xs, dim=-1):
37 | return torch.logsumexp(xs, dim=dim)
38 |
39 | @staticmethod
40 | def matmul(a, b, dims=1):
41 | if isinstance(a, genbmm.BandedMatrix):
42 | return b.multiply_log(a.transpose())
43 | else:
44 | a2, b2, size = broadcast(a, b)
45 | return genbmm.logbmm(a2, b2).view(size)
46 |
47 |
48 | class FastMaxSemiring(_BaseLog):
49 | @staticmethod
50 | def sum(xs, dim=-1):
51 | return torch.max(xs, dim=dim)[0]
52 |
53 | @staticmethod
54 | def matmul(a, b, dims=1):
55 | a2, b2, size = broadcast(a, b)
56 | return genbmm.maxbmm(a2, b2).view(size)
57 |
58 |
59 | class FastSampleSemiring(_BaseLog):
60 | @staticmethod
61 | def sum(xs, dim=-1):
62 | return _SampledLogSumExp.apply(xs, dim)
63 |
64 | @staticmethod
65 | def matmul(a, b, dims=1):
66 | a2, b2, size = broadcast(a, b)
67 | return genbmm.samplebmm(a2, b2).view(size)
68 |
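For reference (my addition): when genbmm is unavailable, the log-space matmul used by FastLogSemiring is equivalent to a broadcasted logsumexp.

import torch

def log_matmul(a, b):
    # (..., m, k) @ (..., k, n) in the (logsumexp, +) semiring
    return torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(-3), dim=-2)

a, b = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
print(torch.allclose(log_matmul(a, b), torch.log(torch.exp(a) @ torch.exp(b))))   # True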
--------------------------------------------------------------------------------
/src/model/torch_struct/semirings/keops.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions
3 |
4 | from .semirings import _BaseLog
5 |
6 | try:
7 | from pykeops.torch import LazyTensor
8 | except ImportError:
9 | pass
10 |
11 |
12 | class LogSemiringKO(_BaseLog):
13 | """
14 | Implements the log-space semiring (logsumexp, +, -inf, 0).
15 |
16 | Gradients give marginals.
17 | """
18 | @staticmethod
19 | def sum(a, dim=-1):
20 | a_lazy = LazyTensor(a.unsqueeze(-1).unsqueeze(-1).contiguous())
21 | c = a_lazy.sum(-1).logsumexp(a.dim() - 1).squeeze(-1).squeeze(-1)
22 | return c
23 |
24 | @classmethod
25 | def dot(cls, a, b):
26 | """
27 | Dot product along last dim. (Faster than calling sum and times.)
28 | """
29 | a_lazy = LazyTensor(a.unsqueeze(-1).unsqueeze(-1).contiguous())
30 | b_lazy = LazyTensor(b.unsqueeze(-1).unsqueeze(-1).contiguous())
31 | c = (a_lazy + b_lazy).sum(-1).logsumexp(a.dim() - 1).squeeze(-1).squeeze(-1)
32 | return c
33 |
34 |
35 | class _Max(torch.autograd.Function):
36 | @staticmethod
37 | def forward(ctx, a, b):
38 | one_hot = b.shape[-1]
39 | a_lazy = LazyTensor(a.unsqueeze(-1).unsqueeze(-1).contiguous())
40 | b_lazy = LazyTensor(b.unsqueeze(-1).unsqueeze(-1).contiguous())
41 | c = (a_lazy + b_lazy).sum(-1).max(a.dim() - 1).squeeze(-1).squeeze(-1)
42 | ac = (a_lazy + b_lazy).sum(-1).argmax(a.dim() - 1).squeeze(-1).squeeze(-1)
43 | ctx.save_for_backward(ac, torch.tensor(one_hot))
44 | return c
45 |
46 | @staticmethod
47 | def backward(ctx, grad_output):
48 | ac, size = ctx.saved_tensors
49 | back = torch.nn.functional.one_hot(ac, size).type_as(grad_output)
50 | ret = grad_output.unsqueeze(-1).mul(back)
51 | return ret, ret
52 |
53 |
54 | class MaxSemiringKO(_BaseLog):
55 | @classmethod
56 | def sum(cls, xs, dim=-1):
57 | assert dim == -1
58 | return cls.dot(xs, xs.clone().fill_(0))
59 |
60 | @classmethod
61 | def dot(cls, a, b):
62 | """
63 | Dot product along last dim. (Faster than calling sum and times.)
64 | """
65 | return _Max.apply(a, b)
66 |
--------------------------------------------------------------------------------
/src/model/torch_struct/semirings/sample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributions
3 |
4 | from .semirings import _BaseLog
5 |
6 |
7 | class _SampledLogSumExp(torch.autograd.Function):
8 | @staticmethod
9 | def forward(ctx, input, dim):
10 | ctx.save_for_backward(input, torch.tensor(dim))
11 | return torch.logsumexp(input, dim=dim)
12 |
13 | @staticmethod
14 | def backward(ctx, grad_output):
15 | logits, dim = ctx.saved_tensors
16 | grad_input = None
17 | if ctx.needs_input_grad[0]:
18 |
19 | def sample(ls):
20 | pre_shape = ls.shape
21 | draws = torch.multinomial(ls.softmax(-1).view(-1, pre_shape[-1]), 1, True)
22 | draws.squeeze(1)
23 | return (torch.nn.functional.one_hot(draws, pre_shape[-1]).view(*pre_shape).type_as(ls))
24 |
25 | if dim == -1:
26 | s = sample(logits)
27 | else:
28 | dim = dim if dim >= 0 else logits.dim() + dim
29 | perm = [i for i in range(logits.dim()) if i != dim] + [dim]
30 | rev_perm = [a for a, b in sorted(enumerate(perm), key=lambda a: a[1])]
31 | s = sample(logits.permute(perm)).permute(rev_perm)
32 |
33 | grad_input = grad_output.unsqueeze(dim).mul(s)
34 | return grad_input, None
35 |
36 |
37 | class SampledSemiring(_BaseLog):
38 | """
39 | Implements a sampling semiring (logsumexp, +, -inf, 0).
40 |
41 | "Gradients" give sample.
42 |
43 | This is an exact forward-filtering, backward-sampling approach.
44 | """
45 | @staticmethod
46 | def sum(xs, dim=-1):
47 | return _SampledLogSumExp.apply(xs, dim)
48 |
49 |
50 | def GumbelCRFSemiring(temp):
51 | class ST(torch.autograd.Function):
52 | @staticmethod
53 | def forward(ctx, logits, dim):
54 | out = torch.nn.functional.one_hot(logits.max(-1)[1], dim)
55 | out = out.type_as(logits)
56 | ctx.save_for_backward(logits, out)
57 | return out
58 |
59 | @staticmethod
60 | def backward(ctx, grad_output):
61 | logits, out = ctx.saved_tensors
62 | with torch.enable_grad():
63 | ret = torch.autograd.grad(logits.softmax(-1), logits, out * grad_output)[0]
64 | return ret, None
65 |
66 | class _GumbelCRFLogSumExp(torch.autograd.Function):
67 | @staticmethod
68 | def forward(ctx, input, dim):
69 | ctx.save_for_backward(input, torch.tensor(dim))
70 | return torch.logsumexp(input, dim=dim)
71 |
72 | @staticmethod
73 | def backward(ctx, grad_output):
74 | logits, dim = ctx.saved_tensors
75 | grad_input = None
76 | if ctx.needs_input_grad[0]:
77 |
78 | def sample(ls):
79 | update = (ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1], ))) / temp
80 | out = ST.apply(update, ls.shape[-1])
81 | return out
82 |
83 | if dim == -1:
84 | s = sample(logits)
85 | else:
86 | dim = dim if dim >= 0 else logits.dim() + dim
87 | perm = [i for i in range(logits.dim()) if i != dim] + [dim]
88 | rev_perm = [a for a, b in sorted(enumerate(perm), key=lambda a: a[1])]
89 | s = sample(logits.permute(perm)).permute(rev_perm)
90 |
91 | grad_input = grad_output.unsqueeze(dim).mul(s)
92 | return grad_input, None
93 |
94 | class _GumbelCRFSemiring(_BaseLog):
95 | @staticmethod
96 | def sum(xs, dim=-1):
97 | return _GumbelCRFLogSumExp.apply(xs, dim)
98 |
99 | return _GumbelCRFSemiring
100 |
101 |
102 | bits = torch.tensor([pow(2, i) for i in range(1, 18)])
103 |
104 |
105 | class _MultiSampledLogSumExp(torch.autograd.Function):
106 | @staticmethod
107 | def forward(ctx, input, dim):
108 | part = torch.logsumexp(input, dim=dim)
109 | ctx.save_for_backward(input, part, torch.tensor(dim))
110 | return part
111 |
112 | @staticmethod
113 | def backward(ctx, grad_output):
114 |
115 | logits, part, dim = ctx.saved_tensors
116 | grad_input = None
117 | if ctx.needs_input_grad[0]:
118 |
119 | def sample(ls):
120 | pre_shape = ls.shape
121 | draws = torch.multinomial(ls.softmax(-1).view(-1, pre_shape[-1]), 16, True)
122 | draws = draws.transpose(0, 1)
123 | return (torch.nn.functional.one_hot(draws, pre_shape[-1]).view(16, *pre_shape).type_as(ls))
124 |
125 | if dim == -1:
126 | s = sample(logits)
127 | else:
128 | dim = dim if dim >= 0 else logits.dim() + dim
129 | perm = [i for i in range(logits.dim()) if i != dim] + [dim]
130 | rev_perm = [0] + [a + 1 for a, b in sorted(enumerate(perm), key=lambda a: a[1])]
131 | s = sample(logits.permute(perm)).permute(rev_perm)
132 |
133 | dim = dim if dim >= 0 else logits.dim() + dim
134 | final = (grad_output % 2).unsqueeze(0)
135 | mbits = bits[:].type_as(grad_output)
136 | on = grad_output.unsqueeze(0) % mbits.view(17, *[1] * grad_output.dim())
137 | on = on[1:] - on[:-1]
138 | old_bits = (on + final == 0).unsqueeze(dim + 1)
139 |
140 | grad_input = (mbits[:-1].view(16, *[1] * (s.dim() - 1)).mul(s.masked_fill_(old_bits, 0)))
141 |
142 | return torch.sum(grad_input, dim=0), None
143 |
144 |
145 | class MultiSampledSemiring(_BaseLog):
146 | """
147 | Implements a multi-sampling semiring (logsumexp, +, -inf, 0).
148 |
149 | "Gradients" give up to 16 samples with replacement.
150 | """
151 | @staticmethod
152 | def sum(xs, dim=-1):
153 | return _MultiSampledLogSumExp.apply(xs, dim)
154 |
155 | @staticmethod
156 | def to_discrete(xs, j):
157 | i = j
158 | final = xs % 2
159 | mbits = bits.type_as(xs)
160 | return (((xs % mbits[i + 1]) - (xs % mbits[i]) + final) != 0).type_as(xs)
161 |
--------------------------------------------------------------------------------
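As a usage sketch (imports follow the repository layout): the forward pass of both semirings is an ordinary `logsumexp`, and sampling happens entirely in the backward pass, where the "gradient" is a one-hot draw (or, for the multi-sample variant, sixteen draws packed into one integer-coded tensor):

```python
import torch

from src.model.torch_struct.semirings.sample import MultiSampledSemiring, SampledSemiring

logits = torch.randn(2, 5, requires_grad=True)

# Backward draws one category per row from softmax(logits).
SampledSemiring.sum(logits, dim=-1).sum().backward()
print(logits.grad)  # each row is a one-hot sample

logits.grad = None
# Backward packs 16 draws into a single bit-coded gradient;
# to_discrete(grad, j) unpacks one of them as a 0/1 tensor.
MultiSampledSemiring.sum(logits, dim=-1).sum().backward()
one_sample = MultiSampledSemiring.to_discrete(logits.grad, 1)
```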
/src/model/torch_struct/semirings/sparse_max.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .semirings import _BaseLog
4 |
5 |
6 | class SparseMaxSemiring(_BaseLog):
7 | """
8 |
9 | Implements differentiable dynamic programming with a sparsemax semiring (sparsemax, +, -inf, 0).
10 |
11 |     Sparsemax gradients give a sparser set of marginal-like terms.
12 |
13 |     * From softmax to sparsemax: A sparse model of attention and multi-label classification :cite:`martins2016softmax`
14 | * Differentiable dynamic programming for structured prediction and attention :cite:`mensch2018differentiable`
15 | """
16 | @staticmethod
17 | def sum(xs, dim=-1):
18 | return _SimplexProject.apply(xs, dim)
19 |
20 |
21 | class _SimplexProject(torch.autograd.Function):
22 | @staticmethod
23 | def forward(ctx, input, dim, z=1):
24 | w_star = project_simplex(input, dim)
25 | ctx.save_for_backward(input, w_star.clone(), torch.tensor(dim))
26 | x = input.mul(w_star).sum(dim) - w_star.norm(p=2, dim=dim)
27 | return x
28 |
29 | @staticmethod
30 | def backward(ctx, grad_output):
31 | input, w_star, dim = ctx.saved_tensors
32 | w_star.requires_grad_(True)
33 |
34 | grad_input = None
35 | if ctx.needs_input_grad[0]:
36 | wstar = _SparseMaxGrad.apply(w_star, dim)
37 | grad_input = grad_output.unsqueeze(dim).mul(wstar)
38 | return grad_input, None, None
39 |
40 |
41 | class _SparseMaxGrad(torch.autograd.Function):
42 | @staticmethod
43 | def forward(ctx, w_star, dim):
44 | ctx.save_for_backward(w_star, dim)
45 | return w_star
46 |
47 | @staticmethod
48 | def backward(ctx, grad_output):
49 | w_star, dim = ctx.saved_tensors
50 | return sparsemax_grad(grad_output, w_star, dim.item()), None
51 |
52 |
53 | def project_simplex(v, dim, z=1):
54 | v_sorted, _ = torch.sort(v, dim=dim, descending=True)
55 | cssv = torch.cumsum(v_sorted, dim=dim) - z
56 |     ind = torch.arange(1, 1 + v.shape[dim]).to(dtype=v.dtype, device=v.device)
57 | cond = v_sorted - cssv / ind >= 0
58 | k = cond.sum(dim=dim, keepdim=True)
59 | tau = cssv.gather(dim, k - 1) / k.to(dtype=v.dtype)
60 | w = torch.clamp(v - tau, min=0)
61 | return w
62 |
63 |
64 | def sparsemax_grad(dout, w_star, dim):
65 | out = dout.clone()
66 | supp = w_star > 0
67 | out[w_star <= 0] = 0
68 | nnz = supp.to(dtype=dout.dtype).sum(dim=dim, keepdim=True)
69 | out = out - (out.sum(dim=dim, keepdim=True) / nnz)
70 | out[w_star <= 0] = 0
71 | return out
72 |
--------------------------------------------------------------------------------
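A small worked example of `project_simplex`, the Euclidean projection onto the probability simplex that sparsemax is built on: unlike softmax, coordinates far below the maximum are driven exactly to zero.

```python
import torch

from src.model.torch_struct.semirings.sparse_max import project_simplex

v = torch.tensor([1.0, 0.8, -1.0])
print(project_simplex(v, dim=-1))  # tensor([0.6000, 0.4000, 0.0000]) -- sparse, sums to 1
print(v.softmax(-1))               # every coordinate stays strictly positive
```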
/src/model/vis_encoder/__init__.py:
--------------------------------------------------------------------------------
1 | from src.model.vis_encoder.base import VisEncoderBase
2 | from src.model.vis_encoder.box_rel import VisBoxRelSimpleEncoder
--------------------------------------------------------------------------------
/src/model/vis_encoder/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | import torch.nn as nn
6 |
7 | if TYPE_CHECKING:
8 | from src.model import ModelBase
9 |
10 |
11 | class VisEncoderBase(nn.Module):
12 | bounded_model: ModelBase
13 |
14 | def __init__(self):
15 | super(VisEncoderBase, self).__init__()
16 |
17 | def forward(self, x, ctx):
18 | raise NotImplementedError
19 |
20 | def get_dim(self, field):
21 | raise NotImplementedError(f'Unrecognized {field=}')
22 |
23 |
--------------------------------------------------------------------------------
/src/model/vis_encoder/box_rel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as tnn
3 |
4 | from src.model.vis_encoder import VisEncoderBase
5 | from src.model.nn import MLP, BiaffineScorer
6 |
7 |
8 | class VisBoxRelSimpleEncoder(VisEncoderBase):
9 | def __init__(self, n_in, n_hidden, dropout, activate, use_attr, use_img, img_feat):
10 | super().__init__()
11 |
12 | self.use_img = use_img
13 | if use_img:
14 | self.img_fc = MLP(n_in, n_hidden, dropout, activate)
15 |
16 | self.img_feat = img_feat
17 | if img_feat:
18 | n_in *= 2
19 | self.box_fc = MLP(n_in, n_hidden, dropout, activate)
20 | self.rel_fc = MLP(n_in, n_hidden, dropout, activate)
21 | # self.rel_fc = BiaffineScorer(n_in * 2, n_hidden, n_hidden, dropout, activate, 1)
22 |
23 | self.use_attr = use_attr
24 | if use_attr:
25 | self.attr_fc = MLP(n_in, n_hidden, dropout, activate)
26 | self.n_hidden = n_hidden
27 | self.dropout = dropout
28 |
29 | def forward(self, x, ctx):
30 |
31 | if self.img_feat:
32 | feat: torch.Tensor = x["vis_box_feat"]
33 | B, N, H = feat.shape
34 | box = feat
35 | inputs = torch.cat(
36 | [box, feat.mean(1, keepdim=True).expand(-1, feat.shape[1], -1)], dim=-1
37 | )
38 | else:
39 | inputs = x["vis_box_feat"]
40 | B, N, H = inputs.shape
41 |         inputs = inputs.view(B, N, -1)  # last dim is 2*H when img_feat doubles the features
42 | _rel_inp = (inputs.unsqueeze(1) + inputs.unsqueeze(2)) / 2
43 | x_rel = self.rel_fc(_rel_inp)
44 | # x_rel = self.rel_fc(inputs, inputs)
45 | rel = x_rel.view(len(x_rel), -1, self.n_hidden)
46 |
47 | out = {"box": self.box_fc(inputs), "rel": rel}
48 | if self.use_attr:
49 | out["attr"] = self.attr_fc(inputs)
50 | if self.use_img:
51 | out["img"] = self.img_fc(x["vis_box_feat"].mean(1, keepdim=True))
52 | return out
53 |
54 | def get_dim(self, field):
55 | return self.n_hidden
56 |
57 |
--------------------------------------------------------------------------------
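A shape sketch of the encoder's output dictionary. The constructor arguments mirror how they are used above, but the concrete values (2048-dim features, 36 boxes, `activate=True`) are placeholders, and the exact `MLP` signature should be checked against `src/model/nn`:

```python
import torch

from src.model.vis_encoder import VisBoxRelSimpleEncoder

# hypothetical configuration, for illustration only
enc = VisBoxRelSimpleEncoder(n_in=2048, n_hidden=256, dropout=0.33, activate=True,
                             use_attr=False, use_img=True, img_feat=False)
out = enc({"vis_box_feat": torch.randn(2, 36, 2048)}, ctx=None)
# out["box"]: (2, 36, 256)       per-box features
# out["rel"]: (2, 36 * 36, 256)  pairwise (box, box) features, flattened
# out["img"]: (2, 1, 256)        mean-pooled whole-image feature
```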
/src/utility/config.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from dataclasses import dataclass
3 |
4 | from omegaconf import MISSING, DictConfig
5 |
6 | from src.utility.logger import get_logger_func
7 |
8 | _warn, _info, _debug = get_logger_func('config')
9 |
10 |
11 | @dataclass
12 | class Config:
13 | @classmethod
14 | def build(cls, env, ignore_unknown=False, allow_missing=None):
15 | if isinstance(env, (dict, DictConfig)):
16 | if 'cfg' in env and isinstance(env['cfg'], cls):
17 |                 # cfg has already been built elsewhere; reuse it directly
18 | return env['cfg']
19 |
20 | matched = {k: v for k, v in env.items() if k in inspect.signature(cls).parameters}
21 | unmatched = {k: env[k]
22 | for k in env.keys() - matched.keys()
23 | if not k.startswith('n_')} # n_* will be set automatically
24 | if unmatched and not ignore_unknown:
25 | raise ValueError(f'Unrecognized cfg: {unmatched}')
26 | # noinspection PyArgumentList
27 |             cfg = cls(**matched)
28 |
29 | allow_missing = allow_missing or set()
30 | for key, value in cfg.__dict__.items():
31 | if not key.startswith('_') and key not in allow_missing:
32 | assert value is not MISSING, f'{key} is MISSING.'
33 |
34 | if ignore_unknown:
35 | return cfg, unmatched
36 | return cfg
37 | elif isinstance(env, cls):
38 | return env
39 | raise TypeError
40 |
41 | def __setitem__(self, key, value):
42 | if not hasattr(self, key):
43 | _warn(f"Adding new key: {key}")
44 | setattr(self, key, value)
45 |
46 | def __getitem__(self, item):
47 | return getattr(self, item)
--------------------------------------------------------------------------------
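A minimal sketch of how `Config.build` behaves (the `EncoderCfg` dataclass below is made up for illustration): unknown keys raise unless `ignore_unknown=True`, in which case they are returned alongside the config, and `n_*` keys are always tolerated because they are filled in automatically elsewhere.

```python
from dataclasses import dataclass

from src.utility.config import Config


@dataclass
class EncoderCfg(Config):  # illustrative subclass, not part of the repo
    hidden: int = 256
    dropout: float = 0.1


cfg = EncoderCfg.build({'hidden': 512, 'dropout': 0.3})
assert cfg.hidden == 512 and cfg['dropout'] == 0.3

# With ignore_unknown=True, unmatched keys are handed back instead of raising.
cfg, extra = EncoderCfg.build({'hidden': 512, 'typo_key': 1}, ignore_unknown=True)
assert extra == {'typo_key': 1}
```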
/src/utility/defaultlist.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 |
4 | class defaultlist(list):
5 | """
6 | __version__ = "1.0.0"
7 | __author__ = 'c0fec0de'
8 | __author_email__ = 'c0fec0de@gmail.com'
9 |     __description__ = "collections.defaultdict equivalent implementation of list."
10 | __url__ = "https://github.com/c0fec0de/defaultlist"
11 | """
12 |
13 | # noinspection PyMissingConstructor
14 | def __init__(self, factory=None):
15 | """
16 | List extending automatically to the maximum requested length.
17 | Keyword Args:
18 | factory: Function called for every missing index.
19 | """
20 | self.__factory = factory or defaultlist.__nonefactory
21 |
22 | @staticmethod
23 | def __nonefactory():
24 | return None
25 |
26 | def __fill(self, index):
27 | missing = index - len(self) + 1
28 | if missing > 0:
29 | # noinspection PyMethodFirstArgAssignment
30 | self += [self.__factory() for _ in range(missing)]
31 |
32 | def __setitem__(self, index, value):
33 | self.__fill(index)
34 | list.__setitem__(self, index, value)
35 |
36 | def __getitem__(self, index):
37 | if isinstance(index, slice):
38 | return self.__getslice(index.start, index.stop, index.step)
39 | else:
40 | self.__fill(index)
41 | return list.__getitem__(self, index)
42 |
43 | def __getslice__(self, start, stop, step=None): # pragma: no cover
44 | # python 2.x legacy
45 | if stop == sys.maxint:
46 | stop = None
47 | return self.__getslice(start, stop, step)
48 |
49 | def __normidx(self, idx, default):
50 | if idx is None:
51 | idx = default
52 | elif idx < 0:
53 | idx += len(self)
54 | return idx
55 |
56 | def __getslice(self, start, stop, step):
57 | end = max((start or 0, stop or 0, 0))
58 | if end:
59 | self.__fill(end)
60 | start = self.__normidx(start, 0)
61 | stop = self.__normidx(stop, len(self))
62 | step = step or 1
63 | r = defaultlist(factory=self.__factory)
64 | for idx in range(start, stop, step):
65 | r.append(list.__getitem__(self, idx))
66 | return r
67 |
68 | def __add__(self, other):
69 | if isinstance(other, list):
70 | r = self.copy()
71 | r += other
72 | return r
73 | else:
74 | return list.__add__(self, other)
75 |
76 | def copy(self):
77 | """Return a shallow copy of the list. Equivalent to a[:]."""
78 | r = defaultlist(factory=self.__factory)
79 | r += self
80 | return r
81 |
--------------------------------------------------------------------------------
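Usage mirrors `collections.defaultdict`: indexing past the end grows the list with factory-produced placeholders instead of raising `IndexError`.

```python
from src.utility.defaultlist import defaultlist

counts = defaultlist(int)      # missing slots are filled with int() == 0
counts[4] += 1
print(counts)                  # [0, 0, 0, 0, 1]

buckets = defaultlist(list)
buckets[2].append('x')         # auto-extends to length 3
print(buckets)                 # [[], [], ['x']]
print(buckets[0:2])            # slicing returns another defaultlist
```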
/src/utility/fn.py:
--------------------------------------------------------------------------------
1 | import errno
2 | import logging
3 | import os
4 | from functools import wraps
5 | from typing import Any, Dict, Callable, Optional, Iterator
6 |
7 | from hydra.utils import instantiate
8 | from omegaconf import ListConfig, DictConfig
9 | from pytorch_lightning import Trainer
10 | from torch import Tensor
11 |
12 |
13 | def not_distributed_guard():
14 | import torch.distributed as dist
15 | assert not dist.is_initialized()
16 |
17 |
18 | def endless_iter(i: Iterator, shuffle: Optional[Callable] = None, inplace_shuffle: Optional[Callable] = None):
19 | while True:
20 | if shuffle is not None:
21 | i = shuffle(i)
22 | if inplace_shuffle is not None:
23 | inplace_shuffle(i)
24 | for x in i:
25 | yield x
26 |
27 |
28 | def dict_apply(d: Dict[Any, Any], func=None, key_func=None):
29 | assert func or key_func
30 | if func is None:
31 | return {key_func(key): value for key, value in d.items()}
32 | elif key_func is None:
33 | return {key: func(value) for key, value in d.items()}
34 | return {key_func(key): func(value) for key, value in d.items()}
35 |
36 |
37 | def hydra_instantiate_func_helper(func):
38 | """convert func() to func()()"""
39 |
40 | @wraps(func)
41 | def wrapper(*args, **kwargs):
42 | def mid():
43 | return func(*args, **kwargs)
44 |
45 | return mid
46 |
47 | return wrapper
48 |
49 |
50 | def reduce_loss(mode, loss, num_token, num_sentence) -> Tensor:
51 | if not isinstance(loss, list):
52 | loss, num_token, num_sentence = [loss], [num_token], [num_sentence]
53 | assert len(loss) >= 1, 'Nothing to reduce. You should handle this error outside this function.'
54 | if mode == 'token':
55 | # average over tokens in a batch
56 | return sum(loss) / (sum(num_token) + 1e-12)
57 | elif mode == 'sentence':
58 | # first average over tokens in a sentence.
59 | # then average sentences over a batch
60 | # return sum((l / s).sum() for l, s in zip(loss, seq_len)) / (sum(len(s) for s in seq_len))
61 | raise NotImplementedError('Deprecated')
62 | elif mode == 'batch':
63 | # average over sentences in a batch
64 | return sum(loss) / (sum(num_sentence) + 1e-12)
65 | elif mode == 'sum':
66 | return sum(loss)
67 | raise ValueError
68 |
69 |
70 | def split_list(raw, size):
71 | out = []
72 | offset = 0
73 | for s in size:
74 | out.append(raw[offset: offset + s])
75 | offset += s
76 | assert offset == len(raw)
77 | return out
78 |
79 |
80 | def instantiate_no_recursive(*args, **kwargs):
81 | return instantiate(*args, **kwargs, _recursive_=False)
82 |
83 |
84 | def get_coeff_iter(command, idx_getter=None, validator=None):
85 | # 1. not (list, tuple, ListConfig): constant alpha
86 |     # 2. List[str]: each item is "[value]@[epoch]", e.g. ["0@0", "0.5@100"]. The value is linearly
87 |     #    interpolated between the given epochs; the first item must be at epoch 0.
88 | if not isinstance(command, (list, tuple, ListConfig)):
89 | # -123456789 is never reached, so it is endless
90 | assert command != -123456789
91 | return iter(lambda: command, -123456789)
92 |
93 | if idx_getter is None:
94 | _i = 0
95 |
96 | def auto_inc():
97 | nonlocal _i
98 | i, _i = _i, _i + 1
99 | return i
100 |
101 | idx_getter = auto_inc
102 |
103 | def calculate_alpha(value_and_step):
104 | prev_v, prev_s = value_and_step[0].split('@')
105 | prev_v, prev_s = float(prev_v), int(prev_s)
106 | assert prev_s == 0, 'the first step must be 0'
107 | idx = idx_getter()
108 | for i in range(1, len(value_and_step)):
109 | next_v, next_s = value_and_step[i].split('@')
110 | next_v, next_s = float(next_v), int(next_s)
111 | rate = (next_v - prev_v) / (next_s - prev_s)
112 | while idx <= next_s:
113 | value = prev_v + rate * (idx - prev_s)
114 | if validator is not None:
115 |                     assert validator(value), f'Bad value in coeff_iter. Got {value}.'
116 | yield value
117 | idx = idx_getter()
118 | prev_v, prev_s = next_v, next_s
119 | while True:
120 | yield prev_v
121 |
122 | return iter(calculate_alpha(command))
123 |
124 |
125 | def instantiate_trainer(callbacks=None, **kwargs):
126 | if callbacks is not None:
127 | NoneType = type(None)
128 | callbacks = list(filter(lambda x: not isinstance(x, (dict, DictConfig, NoneType)), callbacks.values()))
129 | return Trainer(callbacks=callbacks, **kwargs)
130 |
131 |
132 | def pad(tensors, padding_value=0, total_length=None, padding_side='right'):
133 | size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors) for i in range(len(tensors[0].size()))]
134 | if total_length is not None:
135 | assert total_length >= size[1]
136 | size[1] = total_length
137 | out_tensor = tensors[0].data.new(*size).fill_(padding_value)
138 | for i, tensor in enumerate(tensors):
139 | out_tensor[i][[slice(-i, None) if padding_side == 'left' else slice(0, i) for i in tensor.size()]] = tensor
140 | return out_tensor
141 |
142 |
143 | def filter_list(data, mask):
144 | if isinstance(mask[0], list):
145 | out = []
146 | for subdata, submask in zip(data, mask):
147 | out.append(filter_list(subdata, submask))
148 | return out
149 | elif isinstance(mask[0], int):
150 | return [subdata for subdata, submask in zip(data, mask) if submask]
151 | raise ValueError(f'Bad mask value: {mask}')
152 |
153 |
154 | def draw_att(data: Tensor, path=None):
155 | assert data.ndim == 2
156 | import seaborn as sns
157 | import matplotlib.pyplot as plt
158 | data = data.detach().cpu().numpy()
159 | sns.heatmap(data=data, center=0, mask=data < -100)
160 | if path:
161 | plt.savefig(path)
162 | else:
163 | plt.show()
164 |
165 |
166 | def merge_outputs(a, b):
167 | assert a.keys() == b.keys()
168 | for key in a:
169 | adata, bdata = a[key], b[key]
170 | if len(adata) > len(bdata):
171 | bdata.extend([None] * (len(adata) - len(bdata)))
172 | else:
173 | adata.extend([None] * (len(bdata) - len(adata)))
174 | a[key] = [ai if ai is not None else bi for ai, bi in zip(a[key], b[key])]
175 | return a
176 |
177 |
178 | def symlink_force(target, link_name):
179 | try:
180 | os.symlink(target, link_name)
181 | except OSError as e:
182 | if e.errno == errno.EEXIST:
183 | os.remove(link_name)
184 | os.symlink(target, link_name)
185 | else:
186 | raise e
187 |
188 |
189 | def listloggers():
190 | rootlogger = logging.getLogger()
191 | print(rootlogger)
192 | for h in rootlogger.handlers:
193 | print(' %s' % h)
194 |
195 | for nm, lgr in logging.Logger.manager.loggerDict.items():
196 | print('+ [%-20s] %s ' % (nm, lgr))
197 | if not isinstance(lgr, logging.PlaceHolder):
198 | for h in lgr.handlers:
199 | print(' %s' % h)
--------------------------------------------------------------------------------
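A sketch of the `get_coeff_iter` schedule format documented in the comments above: a bare number yields a constant stream, while a list of `value@epoch` strings is interpolated linearly between the given steps and then held at the last value.

```python
from src.utility.fn import get_coeff_iter

constant = get_coeff_iter(0.5)
assert [next(constant) for _ in range(3)] == [0.5, 0.5, 0.5]

ramp = get_coeff_iter(["0@0", "1@4"])
print([round(next(ramp), 2) for _ in range(7)])
# [0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0]
```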
/src/utility/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 |
4 | from colorama import Fore
5 | from pytorch_lightning.utilities import rank_zero_only
6 | from tqdm.auto import tqdm
7 |
8 |
9 | class TqdmLoggingHandler(logging.Handler):
10 | def __init__(self, level=logging.INFO):
11 | super().__init__(level)
12 |
13 | def emit(self, record):
14 | try:
15 | msg = self.format(record)
16 | tqdm.write(msg, file=sys.stdout)
17 | except (KeyboardInterrupt, SystemExit):
18 | raise
19 | except:
20 | self.handleError(record)
21 |
22 |
23 | class ColorFormatter(logging.Formatter):
24 | def format(self, record):
25 |
26 | # Save the original format configured by the user
27 | # when the logger formatter was instantiated
28 | format_orig = self._style._fmt
29 |
30 | # Replace the original format with one customized by logging level
31 | if record.levelno == logging.DEBUG:
32 | self._style._fmt = Fore.YELLOW + format_orig + Fore.RESET
33 |
34 | elif record.levelno >= logging.WARNING:
35 | self._style._fmt = Fore.RED + format_orig + Fore.RESET
36 |
37 | # Call the original formatter class to do the grunt work
38 | result = logging.Formatter.format(self, record)
39 |
40 | # Restore the original format configured by the user
41 | self._style._fmt = format_orig
42 |
43 | return result
44 |
45 |
46 | def get_logger_func(name):
47 | log = logging.getLogger(name)
48 |
49 | def _warn(*args, stacklevel: int = 2, **kwargs):
50 | kwargs["stacklevel"] = stacklevel
51 | log.warning(*args, **kwargs)
52 |
53 | def _info(*args, stacklevel: int = 2, **kwargs):
54 | kwargs["stacklevel"] = stacklevel
55 | log.info(*args, **kwargs)
56 |
57 | def _debug(*args, stacklevel: int = 2, **kwargs):
58 | kwargs["stacklevel"] = stacklevel
59 | log.debug(*args, **kwargs)
60 |
61 | return _warn, _info, _debug
62 | # return rank_zero_only(_warn), rank_zero_only(_info), rank_zero_only(_debug)
63 |
--------------------------------------------------------------------------------
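A sketch of how these pieces are typically wired together: the tqdm-aware handler keeps log lines from corrupting progress bars, and `ColorFormatter` only swaps the format string per record level.

```python
import logging

from src.utility.logger import ColorFormatter, TqdmLoggingHandler, get_logger_func

handler = TqdmLoggingHandler()
handler.setFormatter(ColorFormatter('[%(levelname)s] %(name)s: %(message)s'))

logger = logging.getLogger('demo')
logger.setLevel(logging.INFO)
logger.addHandler(handler)

_warn, _info, _debug = get_logger_func('demo')
_info('written via tqdm.write, safe to interleave with progress bars')
_warn('rendered in red by ColorFormatter')
```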
/src/utility/meta.py:
--------------------------------------------------------------------------------
1 | class Singleton(type):
2 | _instances = {}
3 |
4 | def __call__(cls, *args, **kwargs):
5 | if cls not in cls._instances:
6 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
7 | return cls._instances[cls]
8 |
--------------------------------------------------------------------------------
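The metaclass caches one instance per class, so repeated constructor calls return the same object; a hypothetical use:

```python
from src.utility.meta import Singleton


class GlobalRegistry(metaclass=Singleton):  # illustrative class, not part of the repo
    def __init__(self):
        self.items = []


assert GlobalRegistry() is GlobalRegistry()  # __init__ runs only once
GlobalRegistry().items.append('x')
assert GlobalRegistry().items == ['x']
```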
/src/utility/scheduler.py:
--------------------------------------------------------------------------------
1 | # noinspection PyUnresolvedReferences
2 | import logging
3 | # noinspection PyUnresolvedReferences
4 | import math
5 |
6 | # noinspection PyUnresolvedReferences
7 | import numpy as np
8 | from torch.optim import lr_scheduler
9 | # noinspection PyUnresolvedReferences
10 | from transformers import (get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup,
11 | get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup,
12 | get_polynomial_decay_schedule_with_warmup)
13 |
14 | from src.utility.logger import get_logger_func
15 |
16 | _warn, _info, _debug = get_logger_func('scheduler')
17 |
18 |
19 | def get_exponential_lr_scheduler(optimizer, gamma, **kwargs):
20 | if isinstance(gamma, str):
21 | gamma = eval(gamma)
22 | _debug(f'gamma is converted to {gamma} {type(gamma)}')
23 | kwargs['gamma'] = gamma
24 | return lr_scheduler.ExponentialLR(optimizer, **kwargs)
25 |
26 |
27 | def get_reduce_lr_on_plateau_scheduler(optimizer, **kwargs):
28 | return lr_scheduler.ReduceLROnPlateau(optimizer, **kwargs)
29 |
30 |
31 | def get_lr_lambda_scheduler(optimizer, lr_lambda, **kwargs):
32 | if isinstance(lr_lambda, str):
33 | lr_lambda = eval(lr_lambda)
34 | _debug(f'lr_lambda is converted to {lr_lambda} {type(lr_lambda)}')
35 | kwargs['lr_lambda'] = lr_lambda
36 | return lr_scheduler.LambdaLR(optimizer, **kwargs)
37 |
--------------------------------------------------------------------------------
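The string forms of `gamma` and `lr_lambda` are `eval`-ed inside this module, which is why `math` and `numpy` are imported above despite looking unused. A hedged example (the decay target is arbitrary):

```python
import torch

from src.utility.scheduler import get_exponential_lr_scheduler

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

# Decay the learning rate to 10% of its value over 100 steps; the string is
# evaluated with this module's globals, so `math` is in scope.
scheduler = get_exponential_lr_scheduler(optimizer, gamma='math.exp(math.log(0.1) / 100)')
```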
/src/utility/spacy_helper.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List
2 |
3 | from spacy.tokens import Doc
4 |
5 |
6 | class PretokenizedTokenizer:
7 | """Custom tokenizer to be used in spaCy when the text is already pretokenized."""
8 |
9 | def __init__(self, vocab):
10 | """Initialize tokenizer with a given vocab
11 | :param vocab: an existing vocabulary (see https://spacy.io/api/vocab)
12 | """
13 | self.vocab = vocab
14 |
15 | def __call__(self, inp: Union[List[str], str]):
16 | """Call the tokenizer on input `inp`.
17 | :param inp: either a string to be split on whitespace, or a list of tokens
18 | :return: the created Doc object
19 | """
20 | if isinstance(inp, str):
21 | words = inp.split()
22 | spaces = [True] * (len(words) - 1) + ([True] if inp[-1].isspace() else [False])
23 | return Doc(self.vocab, words=words, spaces=spaces)
24 | elif isinstance(inp, list):
25 | return Doc(self.vocab, words=inp)
26 | else:
27 | raise ValueError("Unexpected input format. Expected string to be split on whitespace, or list of tokens.")
28 |
--------------------------------------------------------------------------------
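A minimal sketch: the tokenizer just wraps pre-split tokens in a `Doc`, so it can be exercised on its own or assigned as `nlp.tokenizer`; the blank pipeline below is only for illustration.

```python
import spacy

from src.utility.spacy_helper import PretokenizedTokenizer

nlp = spacy.blank('en')
nlp.tokenizer = PretokenizedTokenizer(nlp.vocab)

doc = nlp.tokenizer(['a', 'pre', 'tokenized', 'sentence'])
print([t.text for t in doc])  # ['a', 'pre', 'tokenized', 'sentence']
```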
/src/utility/var_pool.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Union
2 |
3 | from fastNLP.core.utils import seq_len_to_mask
4 | from torch import Tensor
5 |
6 |
7 | class VarPool:
8 | def __init__(self, **kwargs):
9 | self._pool = {}
10 | self._lazy_func = {}
11 | self._circle_trace = []
12 |
13 | for key, value in kwargs.items():
14 | self._pool[key] = value
15 |
16 | self.add_lazy('seq_len', 'batch_size', lambda x: len(x))
17 | self.add_lazy('seq_len', 'max_len', lambda x: max(x))
18 | self.add_lazy('seq_len', 'num_token', lambda x: sum(x))
19 | self.add_lazy(['seq_len', 'max_len'], 'mask', lambda x, y: seq_len_to_mask(x, y))
20 |
21 | def add_lazy(self, source: Union[str, List[str]], target: str, func: Callable, overwrite=False):
22 | assert overwrite or target not in self._lazy_func, f'{target=}'
23 | if isinstance(source, str):
24 | source = [source]
25 | self._lazy_func[target] = (source, func)
26 |
27 | def select(self, mask):
28 | new_vp = VarPool()
29 | for key, value in self._pool.items():
30 | if key in ('batch_size', 'max_len'):
31 | continue
32 | if key.endswith('_cpu') or key.endswith('_cuda'):
33 | continue
34 | if not isinstance(value, Tensor):
35 | continue
36 | new_vp.add_lazy([], key, lambda v=value: v[mask], overwrite=True)
37 | for key, value in self._lazy_func.items():
38 | if key not in new_vp._lazy_func and not key.endswith('cuda') and not key.endswith('cpu'):
39 | new_vp.add_lazy(value[0], key, value[1], overwrite=True)
40 | return new_vp
41 |
42 | def __getitem__(self, item):
43 | if item in self._pool:
44 | return self._pool[item]
45 | if item in self._lazy_func:
46 | source, func = self._lazy_func[item]
47 | self._circle_trace.append(item)
48 | assert not any(map(lambda s: s in self._circle_trace, source))
49 | source = [self[s] for s in source]
50 | self._circle_trace.pop()
51 | target = func(*source)
52 | self[item] = target
53 | return target
54 | name, device = item.rsplit('_', 1)
55 | if device in ('cuda', 'cpu'):
56 | value = self[name].to(device)
57 | self._pool[item] = value
58 | return value
59 | raise KeyError(f'No {item}.')
60 |
61 | def __setitem__(self, key, value):
62 | self._pool[key] = value
63 | if isinstance(value, Tensor):
64 | self.add_lazy(key, key + '_cuda', lambda x: x if x.device.type == 'cuda' else x.cuda())
65 | self.add_lazy(key, key + '_cpu', lambda x: x if x.device.type == 'cpu' else x.cpu())
66 |
67 | def __getattr__(self, item):
68 | return self[item]
69 |
70 | def __setattr__(self, key, value):
71 | if key.startswith('_'):
72 | super().__setattr__(key, value)
73 | else:
74 | self._pool[key] = value
75 |
76 | def __contains__(self, key):
77 | return key in self._pool or key in self._lazy_func
78 |
--------------------------------------------------------------------------------
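`VarPool` acts as a lazily-evaluated blackboard: values are computed from their declared sources on first access, cached, and the `_cpu`/`_cuda` suffixes move tensors between devices on demand. A small sketch:

```python
import torch

from src.utility.var_pool import VarPool

vp = VarPool(seq_len=torch.tensor([3, 5]), x=torch.randn(2, 5, 4))

print(vp.batch_size)       # 2, derived lazily from seq_len
print(int(vp.max_len))     # 5
print(vp.mask.shape)       # torch.Size([2, 5]), via seq_len_to_mask
print(vp['x_cpu'].device)  # CPU copy, cached after the first access

vp.add_lazy(['x', 'mask'], 'x_masked', lambda x, m: x * m.unsqueeze(-1))
print(vp.x_masked.shape)   # torch.Size([2, 5, 4])
```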
/test.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os.path
4 | from pathlib import Path
5 |
6 | import hydra
7 | import pytorch_lightning as pl
8 | from hydra import compose
9 | from hydra.utils import HydraConfig, instantiate
10 | from omegaconf import DictConfig
11 | from omegaconf import OmegaConf
12 |
13 | import src
14 | from src import datamodule
15 | from src.datamodule import DataModule
16 | from src.pipeline import Pipeline
17 | from src.utility.fn import instantiate_no_recursive
18 | from src.utility.pl_callback import NNICallback
19 |
20 | log = logging.getLogger(__name__)
21 |
22 |
23 | @hydra.main('config', 'config_test')
24 | def test(cfg: DictConfig):
25 | if (seed := cfg.seed) is not None:
26 | pl.seed_everything(seed)
27 |
28 | if cfg.pipeline.load_from_checkpoint is None:
30 |         log.warning('Testing a randomly-initialized model.')
30 |
31 | if (p := cfg.pipeline.load_from_checkpoint) is not None:
32 | p = Path(p)
33 | if len(p.parts) >= 2 and p.parts[-2] == 'checkpoint':
34 | config_folder = p.parents[1] / 'config'
35 | else:
36 | config_folder = p.parent / 'config'
37 | if config_folder.exists():
38 | # Load saved config.
39 |         # Note that this only loads overrides. Inconsistencies can occur if you change a sub-config file.
40 | # From Hydra's author:
41 | # https://stackoverflow.com/questions/67170653/how-to-load-hydra-parameters-from-previous-jobs-without-having-to-use-argparse/67172466?noredirect=1
42 | log.info('Loading saved overrides')
43 | original_overrides = OmegaConf.load(config_folder / 'overrides.yaml')
44 | current_overrides = HydraConfig.get().overrides.task
45 | # hydra_config = OmegaConf.load(config_folder / 'hydra.yaml')
46 | config_name = 'config_test' # hydra_config.hydra.job.config_name
47 | overrides = original_overrides + current_overrides
48 | # noinspection PyTypeChecker
49 | cfg = compose(config_name, overrides=overrides)
50 | if os.path.exists(config_folder / 'nni.json'):
51 | with open(config_folder / 'nni.json') as f:
52 | nni_overrides = json.load(f)
53 | NNICallback.setup_cfg(nni_overrides, cfg)
54 | log.info(OmegaConf.to_yaml(cfg))
55 |
56 | src.g_cfg = cfg
57 |
58 | trainer: pl.Trainer = instantiate(cfg.trainer)
59 | src.trainer = trainer
60 |
61 | datamodule: DataModule = instantiate_no_recursive(cfg.datamodule)
62 | pipeline: Pipeline = instantiate_no_recursive(cfg.pipeline, dm=datamodule)
63 | output_name = cfg.get('output_name', 'predict')
64 | datamodule.setup('test')
65 |
66 | trainer.test(pipeline, dataloaders=datamodule.dataloader('train'))
67 | pipeline.write_prediction(output_name + '_train.conll', 'train', pipeline._test_outputs[0])
68 | trainer.test(pipeline, dataloaders=datamodule.dataloader('dev'))
69 | pipeline.write_prediction(output_name + '_dev.conll', 'dev', pipeline._test_outputs[0])
70 | trainer.test(pipeline, dataloaders=datamodule.dataloader('test'))
71 | pipeline.write_prediction(output_name + '_test.conll', 'test', pipeline._test_outputs[0])
72 |
73 |
74 | if __name__ == '__main__':
75 | test()
76 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | import os
5 | import os.path
6 | import random
7 | import string
8 | from pathlib import Path
9 |
10 | import hydra
11 | import pytorch_lightning as pl
12 | from hydra import compose
13 | from hydra.utils import HydraConfig, instantiate
14 | from omegaconf import DictConfig
15 | from omegaconf import OmegaConf
16 |
17 | import src
18 | from src.datamodule import DataModule
19 | from src.pipeline import Pipeline
20 | from src.utility.fn import instantiate_no_recursive
21 | from src.utility.fn import symlink_force
22 | from src.utility.logger import get_logger_func
23 | from src.utility.pl_callback import BestWatcherCallback
24 | from src.utility.pl_callback import NNICallback
25 |
26 | _warn, _info, _debug = get_logger_func('main')
27 |
28 |
29 | @hydra.main('config', 'config_train')
30 | def train(cfg: DictConfig):
31 | src.g_cfg = cfg
32 | _info(f'Working directory: {os.getcwd()}')
33 |
34 | outputs_root = os.path.join(cfg.root, 'outputs')
35 | if os.path.exists(outputs_root):
36 | symlink_force(os.getcwd(), os.path.join(outputs_root, '0_latest_run'))
37 |
38 | if cfg.name == '@@@AUTO@@@':
39 |         # For cases where we cannot set name={hydra:job.override_dirname} in config.yaml, e.g., multirun
40 | cfg.name = HydraConfig.get().job.override_dirname
41 |
42 | # init multirun
43 | if (num := HydraConfig.get().job.get('num')) is not None and num > 1:
44 |         # set the wandb group; when using joblib, this is set by joblib instead.
45 | if 'MULTIRUN_ID' not in os.environ:
46 | os.environ['MULTIRUN_ID'] = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(4))
47 | if 'logger' in cfg.trainer and 'tags' in cfg.trainer.logger:
48 | cfg.trainer.logger.tags.append(os.environ['MULTIRUN_ID'])
49 |
50 | if (config_folder := cfg.load_cfg_from_checkpoint) is not None:
51 | # Load saved config.
52 |         # Note that this only loads overrides. Inconsistencies can occur if you change a sub-config file.
53 | # From Hydra's author:
54 | # https://stackoverflow.com/questions/67170653/how-to-load-hydra-parameters-from-previous-jobs-without-having-to-use-argparse/67172466?noredirect=1
55 | _info('Loading saved overrides')
56 | config_folder = Path(config_folder)
57 | original_overrides = OmegaConf.load(config_folder / 'overrides.yaml')
58 | current_overrides = HydraConfig.get().overrides.task
59 | # hydra_config = OmegaConf.load(config_folder / 'hydra.yaml')
60 | config_name = 'conf' # hydra_config.hydra.job.config_name
61 | overrides = original_overrides + current_overrides
62 | # noinspection PyTypeChecker
63 | cfg = compose(config_name, overrides=overrides)
64 | if os.path.exists(config_folder / 'nni.json'):
65 | with open(config_folder / 'nni.json') as f:
66 | nni_overrides = json.load(f)
67 | NNICallback.setup_cfg(nni_overrides, cfg)
68 | _info(OmegaConf.to_yaml(cfg))
69 | src.g_cfg = cfg
70 |
71 | if (seed := cfg.seed) is not None:
72 | pl.seed_everything(seed)
73 | # torch.use_deterministic_algorithms(True)
74 | # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
75 |
76 | assert not (cfg.pipeline.load_from_checkpoint is not None and cfg.trainer.resume_from_checkpoint is not None), \
77 | 'You should not use load_from_checkpoint and resume_from_checkpoint at the same time.'
78 | assert not cfg.watch_field.startswith('test/'), 'You should not use test set to tune hparams.'
79 |
80 | trainer: pl.Trainer = instantiate(cfg.trainer)
81 | src.trainer = trainer
82 | if 'optimized_metric' in cfg:
83 | assert any(isinstance(c, BestWatcherCallback) for c in trainer.callbacks)
84 |
85 | datamodule: DataModule = instantiate_no_recursive(cfg.datamodule)
86 | pipeline: Pipeline = instantiate_no_recursive(cfg.pipeline, dm=datamodule)
87 | trainer.fit(pipeline, datamodule)
88 |
89 | ckpt_path = "best"
90 | trainer.test(model=pipeline, datamodule=datamodule, ckpt_path=ckpt_path)
91 |
92 | _info(f'Working directory: {os.getcwd()}')
93 |
94 | # Return metric score for hyperparameter optimization
95 | callbacks = trainer.callbacks
96 | for c in callbacks:
97 | if isinstance(c, BestWatcherCallback):
98 | if c.best_model_path:
99 | _info(f'Best ckpt: {c.best_model_path}')
100 | if 'optimized_metric' in cfg:
101 | return c.best_model_metric[cfg.optimized_metric]
102 | break
103 |
104 |
105 | if __name__ == '__main__':
106 | train()
107 |
--------------------------------------------------------------------------------