├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── task.png
├── config
│   ├── config_test.yaml
│   ├── config_train.yaml
│   ├── data
│   │   ├── vlparse.yaml
│   │   └── vlparse_lang_only.yaml
│   ├── exp
│   │   ├── lang_only.yaml
│   │   └── vlgae.yaml
│   ├── hydra
│   │   └── job_logging
│   │       ├── custom.yaml
│   │       └── nofile.yaml
│   ├── model
│   │   ├── embedding
│   │   │   └── en.yaml
│   │   ├── lang_only.yaml
│   │   ├── metric
│   │   │   ├── attachment.yaml
│   │   │   └── attachment_box_rel.yaml
│   │   ├── optimize
│   │   │   ├── constant.yaml
│   │   │   └── linear.yaml
│   │   └── vlgae.yaml
│   └── trainer
│       ├── callbacks
│       │   ├── best_watcher.yaml
│       │   ├── early_stopping.yaml
│       │   ├── lr_monitor.yaml
│       │   ├── progressbar.yaml
│       │   ├── wandb.yaml
│       │   └── weights_summary.yaml
│       ├── debug.yaml
│       ├── logger
│       │   └── wandb.yaml
│       ├── test.yaml
│       └── train.yaml
├── data
│   ├── data_format.json
│   └── vlparse.json
├── eval.py
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── datamodule
│   │   ├── __init__.py
│   │   ├── datamodule.py
│   │   ├── sampler.py
│   │   ├── task
│   │   │   ├── __init__.py
│   │   │   ├── dep.py
│   │   │   └── vlparse.py
│   │   └── vocabulary.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── dmv.py
│   │   ├── dmv_helper
│   │   │   ├── __init__.py
│   │   │   ├── good_init.py
│   │   │   ├── good_init_nn.py
│   │   │   └── km_init.py
│   │   ├── embedding
│   │   │   ├── __init__.py
│   │   │   ├── embedding.py
│   │   │   ├── fastnlp_embedding.py
│   │   │   └── transformers_embedding.py
│   │   ├── joint.py
│   │   ├── ldndmv.py
│   │   ├── nn
│   │   │   ├── __init__.py
│   │   │   ├── affine.py
│   │   │   ├── affine_scorer.py
│   │   │   ├── common.py
│   │   │   ├── dmv_spec.py
│   │   │   ├── dropout.py
│   │   │   ├── multivariate_kl.py
│   │   │   ├── scalar_mix.py
│   │   │   └── variational_lstm.py
│   │   ├── text_encoder
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── blank_encoder.py
│   │   │   ├── mlp_encoder.py
│   │   │   ├── multi_encoder.py
│   │   │   └── rnn_encoder.py
│   │   ├── torch_struct
│   │   │   ├── __init__.py
│   │   │   ├── deptree.py
│   │   │   ├── distributions.py
│   │   │   ├── dmv.py
│   │   │   ├── helpers.py
│   │   │   └── semirings
│   │   │       ├── __init__.py
│   │   │       ├── checkpoint.py
│   │   │       ├── fast_semirings.py
│   │   │       ├── keops.py
│   │   │       ├── sample.py
│   │   │       ├── semirings.py
│   │   │       └── sparse_max.py
│   │   └── vis_encoder
│   │       ├── __init__.py
│   │       ├── base.py
│   │       └── box_rel.py
│   ├── pipeline.py
│   └── utility
│       ├── _metric_legacy.py
│       ├── alg.py
│       ├── config.py
│       ├── defaultlist.py
│       ├── fn.py
│       ├── logger.py
│       ├── meta.py
│       ├── metric.py
│       ├── pl_callback.py
│       ├── scheduler.py
│       ├── spacy_helper.py
│       └── var_pool.py
├── test.py
└── train.py
/.gitignore: -------------------------------------------------------------------------------- 1 | /outputs 2 | /.vscode 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | ### VisualStudioCode 134 | .vscode/* 135 | !.vscode/settings.json 136 | !.vscode/tasks.json 137 | !.vscode/launch.json 138 | !.vscode/extensions.json 139 | *.code-workspace 140 | **/.vscode 141 | 142 | # JetBrains 143 | .idea/ 144 | 145 | # Lightning-Hydra-Template 146 | /configs/local/default.yaml 147 | # /data/ 148 | /logs/ 149 | /wandb/ 150 | .env 151 | .autoenv 152 | 153 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Beijing Institute for General Artificial Intelligence 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VLGAE 2 | Official Implementation for CVPR 2022 paper "Unsupervised Vision-Language Parsing: Seamlessly Bridging Visual Scene Graphs with Language Structures via Dependency Relationships" 3 | 4 |
5 | ![task](assets/task.png) 6 |
7 | -------------------------------------------------------------------------------- /assets/task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LouChao98/VLGAE/d71d07a31c3e4c04616070a053729956108fb83d/assets/task.png -------------------------------------------------------------------------------- /config/config_test.yaml: -------------------------------------------------------------------------------- 1 | # some cfg should not belong to any submodule 2 | seed: 1 3 | project: untitled 4 | name: ${name_guard:@@@AUTO@@@} 5 | watch_field: val/loss 6 | watch_mode: min 7 | root: ${hydra:runtime.cwd} 8 | output_name: ~ 9 | 10 | pipeline: 11 | _target_: src.pipeline.Pipeline 12 | 13 | load_from_checkpoint: ~ 14 | loss_reduction_mode: token 15 | 16 | hydra: 17 | run: 18 | dir: . 19 | output_subdir: null 20 | job: 21 | env_set: 22 | TOKENIZERS_PARALLELISM: 'false' 23 | HF_DATASETS_OFFLINE: '1' 24 | TRANSFORMERS_OFFLINE: '1' 25 | TORCH_WARN_ONCE: '1' 26 | NUMEXPR_MAX_THREADS: '8' 27 | DEBUG_MODE: '' 28 | 29 | defaults: 30 | - _self_ 31 | - trainer: train 32 | - data: vlparse 33 | - model: vlgae -------------------------------------------------------------------------------- /config/config_train.yaml: -------------------------------------------------------------------------------- 1 | # some cfg should not belong to any submodule 2 | seed: ~ 3 | project: untitled 4 | name: ${name_guard:@@@AUTO@@@} 5 | watch_field: val/loss 6 | watch_mode: min 7 | root: ${hydra:runtime.cwd} 8 | load_cfg_from_checkpoint: ~ 9 | 10 | pipeline: 11 | _target_: src.pipeline.Pipeline 12 | 13 | load_from_checkpoint: ~ 14 | loss_reduction_mode: token 15 | 16 | hydra: 17 | sweep: 18 | dir: outputs/multirun/${now:%Y-%m-%d_%H-%M-%S} 19 | subdir: ${path_guard:${hydra.job.override_dirname}} 20 | run: 21 | dir: outputs/${path_guard:${name}}/${now:%Y-%m-%d_%H-%M-%S} 22 | output_subdir: config 23 | job: 24 | env_set: 25 | TOKENIZERS_PARALLELISM: 'false' 26 | HF_DATASETS_OFFLINE: '1' 27 | TRANSFORMERS_OFFLINE: '1' 28 | TORCH_WARN_ONCE: '1' 29 | NUMEXPR_MAX_THREADS: '8' 30 | DEBUG_MODE: '0' 31 | 32 | defaults: 33 | - _self_ 34 | - trainer: train 35 | - data: vlparse 36 | - model: vlgae -------------------------------------------------------------------------------- /config/data/vlparse.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | datamodule: 4 | _target_: src.datamodule.task.VLParseDataModule 5 | train_path: ${root}/data/vlparse/train 6 | train_init_path: ${root}/data/vlparse/init 7 | dev_path: ${root}/data/vlparse/val 8 | test_path: ${root}/data/vlparse/test 9 | 10 | use_img: false 11 | use_gold_scene_graph: false 12 | sg_path: ${root}/data/vlparse/vlparse.json 13 | 14 | use_tag: true 15 | num_lex: 200 16 | num_token: 99999 17 | ignore_stop_word: false 18 | 19 | normalize_word: true 20 | build_no_create_entry: true 21 | max_len: 22 | train: 10 23 | 24 | train_dataloader: 25 | token_size: 5000 26 | num_bucket: 10 27 | batch_size: 64 28 | dev_dataloader: 29 | token_size: 5000 30 | num_bucket: 8 31 | batch_size: 64 32 | test_dataloader: 33 | token_size: 5000 34 | num_bucket: 8 35 | batch_size: 64 36 | 37 | trainer: 38 | val_check_interval: 0.5 -------------------------------------------------------------------------------- /config/data/vlparse_lang_only.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | datamodule: 4 | 
_target_: src.datamodule.task.DepDataModule 5 | train_path: ${root}/data/vlparse/train.conll 6 | train_init_path: ${root}/vlparse/init.conll 7 | dev_path: ${root}/data/vlparse/val.conll 8 | test_path: ${root}/data/vlparse/test.conll 9 | 10 | use_tag: true 11 | num_lex: 200 12 | num_token: 99999 13 | ignore_stop_word: false 14 | 15 | normalize_word: true 16 | build_no_create_entry: true 17 | 18 | train_dataloader: 19 | token_size: 5000 20 | num_bucket: 10 21 | batch_size: 64 22 | dev_dataloader: 23 | num_bucket: 8 24 | token_size: 10000 25 | test_dataloader: 26 | num_bucket: 8 27 | token_size: 10000 28 | max_len: 29 | train: 15 30 | 31 | trainer: 32 | val_check_interval: 0.5 -------------------------------------------------------------------------------- /config/exp/lang_only.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /data: vlparse_lang_only 5 | - override /model: lang_only 6 | 7 | datamodule: 8 | num_lex: 0 9 | ignore_stop_word: true 10 | 11 | dataloader: 12 | default: 13 | batch_size: 16 14 | 15 | encoder: 16 | hidden_size: 400 17 | num_layers: 3 18 | lstm_dropout: 0.2 19 | 20 | model: 21 | init_method: 'y' 22 | context_mode: 'hx' 23 | init_epoch: 3 24 | 25 | mid_ff: 26 | n_bottleneck: 0 27 | n_mid: 100 28 | dropout: 0.2 29 | root_emb_dim: 10 30 | dec_emb_dim: 10 31 | 32 | variational_mode: 'none' 33 | z_dim: 64 34 | 35 | optimizer: 36 | args: 37 | lr: 0.0005 38 | 39 | _rank: 32 40 | _dropout: 0.5 41 | _hidden_size: 384 42 | project: unnamed -------------------------------------------------------------------------------- /config/exp/vlgae.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - override /data: vlparse 5 | - override /model: vlgae 6 | - override /model/optimize@optimize: linear 7 | 8 | datamodule: 9 | num_lex: 0 10 | max_len: 11 | train: 50 12 | 13 | trainer: 14 | val_check_interval: 0.5 15 | max_epochs: 50 16 | 17 | optimizer: 18 | args: 19 | lr: 1.0e-3 20 | 21 | project: unnamed -------------------------------------------------------------------------------- /config/hydra/job_logging/custom.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra.job_logging 2 | 3 | version: 1 4 | formatters: 5 | console: 6 | (): src.utility.logger.ColorFormatter 7 | format: '%(message)s' 8 | detail: 9 | format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s' 10 | datefmt: '%y-%m-%d %H:%M:%S' 11 | handlers: 12 | console: 13 | class: src.utility.logger.TqdmLoggingHandler 14 | formatter: console 15 | level: DEBUG 16 | file: 17 | class: logging.FileHandler 18 | formatter: detail 19 | filename: ${hydra.job.name}.log 20 | root: 21 | handlers: [console, file] 22 | loggers: 23 | fastNLP: 24 | handlers: [console, file] 25 | lightning: 26 | handlers: [console, file] 27 | nni: 28 | handlers: [console, file] 29 | disable_existing_loggers: false 30 | -------------------------------------------------------------------------------- /config/hydra/job_logging/nofile.yaml: -------------------------------------------------------------------------------- 1 | # @package hydra.job_logging 2 | 3 | version: 1 4 | formatters: 5 | console: 6 | (): src.utility.logger.ColorFormatter 7 | format: '[%(name)s] %(message)s' 8 | handlers: 9 | console: 10 | class: src.utility.logger.TqdmLoggingHandler 11 | formatter: console 12 | level: DEBUG 13 | root: 14 | handlers: [console] 15 | 
loggers: 16 | fastNLP: 17 | handlers: [console] 18 | lightning: 19 | handlers: [console] 20 | nni: 21 | handlers: [console] 22 | disable_existing_loggers: false 23 | -------------------------------------------------------------------------------- /config/model/embedding/en.yaml: -------------------------------------------------------------------------------- 1 | # @package embedding 2 | 3 | # embedding args 4 | use_word: true 5 | use_tag: true 6 | use_subword: false 7 | dropout: 0. 8 | 9 | # embedding item args 10 | word_embedding: 11 | args: 12 | _target_: fastNLP.embeddings.StaticEmbedding 13 | model_dir_or_name: ${..._emb_mapping.glove100} 14 | min_freq: 2 15 | lower: true 16 | adaptor_args: 17 | _target_: src.model.embedding.FastNLPEmbeddingVariationalAdaptor 18 | mode: basic 19 | out_dim: 0 20 | field: word 21 | normalize_method: mean+std 22 | normalize_time: begin 23 | tag_embedding: 24 | args: 25 | _target_: fastNLP.embeddings.StaticEmbedding 26 | embedding_dim: 100 27 | init_embed: normal 28 | adaptor_args: 29 | _target_: src.model.embedding.FastNLPEmbeddingAdaptor 30 | field: tag 31 | normalize_method: mean+std 32 | normalize_time: begin 33 | transformer: 34 | args: 35 | _target_: src.model.embedding.TransformersEmbedding 36 | model: bert-base-cased 37 | n_layers: 1 38 | n_out: 0 39 | requires_grad: false 40 | adaptor_args: 41 | _target_: src.model.embedding.TransformersAdaptor 42 | field: subword 43 | requires_vocab: false 44 | 45 | 46 | # others 47 | _emb_mapping: 48 | glove100: ${root}/data/glove/glove.6B.100d.txt 49 | glove300: ${root}/data/glove/glove.840B.300d.txt 50 | glove6b_300: ${root}/data/glove/glove.6B.300d.txt 51 | bio: ${root}/data/bio_nlp_vec/PubMed-shuffle-win-30.txt 52 | jose100: ${root}/data/jose/jose_100d.txt -------------------------------------------------------------------------------- /config/model/lang_only.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - embedding: en 5 | - metric: attachment 6 | - optimize: linear 7 | 8 | encoder: 9 | _target_: src.model.text_encoder.RNNEncoder 10 | reproject_emb: 0 11 | reproject_out: 0 12 | mix: false 13 | pre_shared_dropout: 0.1 14 | pre_dropout: 0.1 15 | post_shared_dropout: 0.1 16 | post_dropout: 0.1 17 | hidden_size: 200 18 | proj_size: 0 19 | num_layers: 2 20 | output_layers: -1 21 | init_version: zy 22 | shared_dropout: true 23 | lstm_dropout: 0.33 24 | 25 | _hidden_size: 500 26 | _dropout: 0.33 27 | _rank: 32 28 | 29 | model: 30 | _target_: src.model.DiscriminativeNDMV 31 | _recursive_: false 32 | context_mode: hx 33 | init_method: 'y' 34 | init_epoch: 3 35 | viterbi_training: true 36 | mbr_decoding: false 37 | extended_valence: true 38 | function_mask: false 39 | 40 | variational_mode: 'none' 41 | z_dim: 0 42 | 43 | mid_ff: 44 | _target_: src.model.nn.DMVSkipConnectEncoder 45 | n_bottleneck: 0 46 | n_mid: 0 47 | dropout: 0. 
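  # A hedged reading of how the blocks below are used (inferred from the names and from
  # src/model/ldndmv.py, not from documented behavior): head_ff and child_ff produce the
  # features behind head-child attachment scores, root_ff the root-selection scores, and
  # dec_ff the DMV stop/continue decision scores; attach_rank, dec_rank and root_rank
  # appear to set the rank of the corresponding low-rank score parameterizations.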
48 | 49 | head_ff: 50 | _target_: src.model.nn.MLP 51 | n_hidden: ${_hidden_size} 52 | dropout: ${_dropout} 53 | child_ff: 54 | _target_: src.model.nn.MLP 55 | n_hidden: ${_hidden_size} 56 | dropout: ${_dropout} 57 | root_ff: 58 | _target_: src.model.nn.MLP 59 | n_hidden: ${_hidden_size} 60 | dropout: ${_dropout} 61 | dec_ff: 62 | _target_: src.model.nn.MLP 63 | n_hidden: ${_hidden_size} 64 | dropout: ${_dropout} 65 | 66 | attach_rank: ${_rank} 67 | dec_rank: ${_rank} 68 | root_rank: ${_rank} 69 | 70 | root_emb_dim: 50 71 | dec_emb_dim: 50 -------------------------------------------------------------------------------- /config/model/metric/attachment.yaml: -------------------------------------------------------------------------------- 1 | # @package metric 2 | 3 | _target_: src.utility.metric.DependencyParsingMetric 4 | -------------------------------------------------------------------------------- /config/model/metric/attachment_box_rel.yaml: -------------------------------------------------------------------------------- 1 | # @package metric 2 | _target_: src.utility.metric.MultiMetric 3 | 4 | dep: 5 | _target_: src.utility.metric.DependencyParsingMetric 6 | extra_vocab: ${..extra_vocab} 7 | 8 | img: 9 | _target_: src.utility.metric.FactorImageMatchingMetric 10 | extra_vocab: ${..extra_vocab} 11 | 12 | match: 13 | _target_: src.utility.metric.BoxRelMatchingMetric 14 | extra_vocab: ${..extra_vocab} 15 | -------------------------------------------------------------------------------- /config/model/optimize/constant.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | optimizer: 4 | groups: 5 | - pattern: dependency.embedding.transformer 6 | lr: 1.0e-5 7 | args: 8 | _target_: torch.optim.Adam 9 | lr: 1.0e-3 10 | betas: [ 0.9, 0.999 ] 11 | weight_decay: 0. 12 | eps: 1.0e-12 13 | 14 | scheduler: ~ -------------------------------------------------------------------------------- /config/model/optimize/linear.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | optimizer: 4 | groups: [] 5 | args: 6 | _target_: torch.optim.Adam 7 | lr: 1.0e-3 8 | betas: [ 0.9, 0.999 ] 9 | weight_decay: 0. 10 | eps: 1.0e-12 11 | 12 | scheduler: 13 | interval: step 14 | frequency: 1 15 | args: 16 | _target_: src.utility.scheduler.get_exponential_lr_scheduler 17 | gamma: 0.75**(1/2000) 18 | -------------------------------------------------------------------------------- /config/model/vlgae.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - metric: attachment_box_rel 5 | - embedding: en 6 | - optimize: constant 7 | 8 | _match_hidden_size: 128 9 | _hidden_size: 256 10 | _dropout: 0.33 11 | _rank: 16 12 | 13 | embedding: 14 | use_word: false 15 | use_subword: true 16 | use_tag: true 17 | tag_embedding: 18 | args: 19 | embedding_dim: 32 20 | 21 | encoder: 22 | _target_: src.model.text_encoder.MLPEncoder 23 | dropout: 0.33 24 | shared_dropout: 0 25 | n_hidden: ${_hidden_size} 26 | 27 | vis_encoder: 28 | _target_: src.model.vis_encoder.VisBoxRelSimpleEncoder 29 | n_in: 2048 30 | n_hidden: ${_hidden_size} 31 | dropout: 0. 
32 | activate: true 33 | use_attr: true 34 | use_img: false 35 | img_feat: true 36 | 37 | model: 38 | _target_: src.model.DependencyBoxRel 39 | _recursive_: false 40 | 41 | add_rel: true 42 | add_attr: true 43 | add_image: true 44 | add_marginal: true 45 | 46 | margin: 1 47 | language_factor_mode: word+maxdep 48 | visual_factor_mode: unprune 49 | visual_factor_cfg: 50 | n_hidden: ${_match_hidden_size} 51 | feat_fuse_mode: attention 52 | feat_fuse_args: 53 | num_heads: 4 54 | dropout: 0.33 55 | replace: false 56 | aug_with_matching: true 57 | gather_logit_mode: simple 58 | gather_logit_args: ~ 59 | loss_grounding_mode: factor|ce 60 | loss_grounding_args: 61 | use_pos_prior: true 62 | vis2txt: 1 63 | decode_grounding_mode: on_factor 64 | decode_grounding_args: 65 | use_pos_prior: true 66 | use_heuristic: true 67 | grounding_interpolation: 0.5 68 | 69 | word_encoder: 70 | _target_: src.model.nn.MLP 71 | n_hidden: ${_match_hidden_size} 72 | dropout: 0.33 73 | activate: false 74 | 75 | init_method: 'y' 76 | init_epoch: 5 77 | 78 | dep_model_cfg: 79 | _target_: src.model.DiscriminativeNDMV 80 | _recursive_: false 81 | context_mode: 'mean' 82 | init_method: ${..init_method} 83 | init_epoch: ${..init_epoch} 84 | viterbi_training: true 85 | mbr_decoding: false 86 | extended_valence: true 87 | function_mask: false 88 | 89 | variational_mode: 'none' 90 | z_dim: 0 91 | 92 | mid_ff: 93 | _target_: src.model.nn.DMVSkipConnectEncoder 94 | n_bottleneck: 150 95 | n_mid: 0 96 | dropout: 0.3 97 | 98 | head_ff: 99 | _target_: src.model.nn.MLP 100 | n_hidden: ${_hidden_size} 101 | dropout: ${_dropout} 102 | child_ff: 103 | _target_: src.model.nn.MLP 104 | n_hidden: ${_hidden_size} 105 | dropout: ${_dropout} 106 | root_ff: 107 | _target_: src.model.nn.MLP 108 | n_hidden: ${_hidden_size} 109 | dropout: ${_dropout} 110 | dec_ff: 111 | _target_: src.model.nn.MLP 112 | n_hidden: ${_hidden_size} 113 | dropout: ${_dropout} 114 | 115 | attach_rank: ${_rank} 116 | dec_rank: ${_rank} 117 | root_rank: ${_rank} 118 | 119 | root_emb_dim: 10 120 | dec_emb_dim: 10 -------------------------------------------------------------------------------- /config/trainer/callbacks/best_watcher.yaml: -------------------------------------------------------------------------------- 1 | best_watcher: 2 | _target_: src.utility.pl_callback.BestWatcherCallback 3 | monitor: ${watch_field} 4 | mode: ${watch_mode} 5 | hint: true 6 | save: 7 | dirpath: checkpoint 8 | filename: "{epoch}-{step}-{${watch_field}:.2f}" 9 | start_patience: 2 10 | write: 'new' 11 | report: true 12 | -------------------------------------------------------------------------------- /config/trainer/callbacks/early_stopping.yaml: -------------------------------------------------------------------------------- 1 | early_stopping: 2 | _target_: pytorch_lightning.callbacks.EarlyStopping 3 | monitor: ${watch_field} 4 | mode: ${watch_mode} 5 | patience: 100 6 | -------------------------------------------------------------------------------- /config/trainer/callbacks/lr_monitor.yaml: -------------------------------------------------------------------------------- 1 | lr_monitor: 2 | _target_: src.utility.pl_callback.LearningRateMonitorWithEarlyStop 3 | logging_interval: 'epoch' # None, step, epoch. 
None=following scheduler 4 | minimum_lr: 1e-8 5 | -------------------------------------------------------------------------------- /config/trainer/callbacks/progressbar.yaml: -------------------------------------------------------------------------------- 1 | progress_bar: 2 | _target_: src.utility.pl_callback.MyProgressBar 3 | refresh_rate: 1 4 | process_position: 0 5 | 6 | #progress_bar: 7 | # _target_: pytorch_lightning.callbacks.RichProgressBar -------------------------------------------------------------------------------- /config/trainer/callbacks/wandb.yaml: -------------------------------------------------------------------------------- 1 | wandb: 2 | _target_: src.utility.pl_callback.WatchModelWithWandb 3 | log: ${in_debugger:gradients,null} # all, gradients, parameters, None 4 | log_freq: 100 5 | -------------------------------------------------------------------------------- /config/trainer/callbacks/weights_summary.yaml: -------------------------------------------------------------------------------- 1 | weights_summary: 2 | _target_: pytorch_lightning.callbacks.ModelSummary 3 | max_depth: ${in_debugger:5,2} -------------------------------------------------------------------------------- /config/trainer/debug.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - logger: ~ 3 | - callbacks: 4 | - progressbar 5 | - early_stopping 6 | - weights_summary 7 | # - swa 8 | - override /hydra/job_logging@_global_.hydra.job_logging: nofile 9 | 10 | hydra: 11 | job: 12 | env_set: 13 | DEBUG_MODE: '1' 14 | 15 | _target_: src.utility.fn.instantiate_trainer 16 | 17 | fast_dev_run: 3 18 | checkpoint_callback: false 19 | 20 | gpus: 1 21 | gradient_clip_val: 5. 22 | track_grad_norm: -1 23 | # max_epochs: 1000 # due to fast_dev_run 24 | max_steps: -1 25 | val_check_interval: 1.0 # int for n epoch, float for in epoch 26 | accumulate_grad_batches: 1 27 | precision: 32 28 | # num_sanity_val_steps: 2 # due to fast_dev_run 29 | resume_from_checkpoint: ~ 30 | detect_anomaly: true 31 | deterministic: false 32 | 33 | # following are settings you should not touch in most cases 34 | accelerator: ${accelerator:${.gpus}} 35 | replace_sampler_ddp: false 36 | multiple_trainloader_mode: min_size 37 | enable_model_summary: false -------------------------------------------------------------------------------- /config/trainer/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | 2 | _target_: pytorch_lightning.loggers.WandbLogger 3 | name: ${name} 4 | project: ${project} 5 | tags: [] 6 | save_code: false 7 | save_dir: ${root}/outputs 8 | -------------------------------------------------------------------------------- /config/trainer/test.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - logger: ~ 3 | - callbacks: 4 | - progressbar 5 | - override /hydra/job_logging@_global_.hydra.job_logging: nofile 6 | 7 | _target_: src.utility.fn.instantiate_trainer 8 | 9 | enable_checkpointing: false 10 | logger: ~ 11 | 12 | gpus: 1 13 | precision: 32 14 | resume_from_checkpoint: ~ 15 | 16 | # following are settings you should not touch in most cases 17 | accelerator: ${accelerator:${.gpus}} 18 | detect_anomaly: false 19 | replace_sampler_ddp: false 20 | enable_model_summary: false 21 | -------------------------------------------------------------------------------- /config/trainer/train.yaml: -------------------------------------------------------------------------------- 
1 | defaults: 2 | - logger: ~ 3 | - callbacks: 4 | - progressbar 5 | # - early_stopping 6 | - lr_monitor # may early-stopping 7 | - best_watcher 8 | - weights_summary 9 | - override /hydra/job_logging@_global_.hydra.job_logging: custom 10 | 11 | _target_: src.utility.fn.instantiate_trainer 12 | 13 | gpus: 1 14 | gradient_clip_val: 5. 15 | track_grad_norm: -1 16 | max_epochs: 50 17 | max_steps: -1 18 | val_check_interval: 1.0 # int for n step, float for in epoch 19 | accumulate_grad_batches: 1 20 | precision: 32 21 | num_sanity_val_steps: ${in_debugger:1,5} 22 | resume_from_checkpoint: ~ 23 | detect_anomaly: false 24 | deterministic: false 25 | 26 | # following are settings you should not touch in most cases 27 | accelerator: gpu 28 | strategy: ${accelerator:${.gpus}} 29 | replace_sampler_ddp: false 30 | multiple_trainloader_mode: min_size 31 | enable_model_summary: false 32 | -------------------------------------------------------------------------------- /data/data_format.json: -------------------------------------------------------------------------------- 1 | { 2 | "": { 3 | "image": { 4 | "coco_id": 0, // MSCOCO id 5 | "vg_id": 0, // VisualGenome id 6 | "height": 0, 7 | "width": 0 8 | }, 9 | "box": { 10 | "": { 11 | "width": 0.0, // percentage of image width 12 | "height": 0.0, // percentage of image height 13 | "x": 0.0, // percentage of image width 14 | "y": 0.0, // percentage of image height 15 | "label": "region label from VisualGenome", 16 | "attribute": "list of attributes separated by semicolon" 17 | }, 18 | ... 19 | }, 20 | "relationship":{ 21 | "", 23 | "to": "", 24 | "label": "relationship label from VisualGenome" 25 | }, 26 | ... 27 | }, 28 | "sentence": { 29 | "": { 30 | "text": "the sentence", 31 | "pos": "part-of-speech tags", 32 | "dephead": "dependency heads", 33 | "span": { 34 | "": { // object 35 | "label": "object", 36 | "start": 0, // inclusive character offset 37 | "end": 0, // exclusive character offset 38 | "attribute_start": 0, // inclusive character offset 39 | "attribute_end": 0, // exclusive character offset, (0,0)=no attribute 40 | "text": "text", 41 | "attribute_text": "attribute_text", 42 | "alignment": [""] 43 | }, 44 | ""], 50 | "alignment": [""] 51 | }, 52 | ... 53 | } 54 | }, 55 | ... 56 | } 57 | }, 58 | ... 
59 | } -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | import conllu 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument( 8 | "--gold", action="store_true", help="whether to use gold boxes instead of proposals" 9 | ) 10 | parser.add_argument( 11 | "--file", 12 | help="path to the prediction", 13 | default="outputs/0_latest_run/dev.predict.txt", 14 | ) 15 | parser.add_argument( 16 | "--dataroot", 17 | help="path to VLParse", 18 | default="data/vlparse", 19 | ) 20 | args = parser.parse_args() 21 | 22 | id_list_path = f"{args.dataroot}/id_list/val.txt" 23 | predict_path = args.file 24 | 25 | if args.gold: 26 | with open(f"{args.dataroot}/dev_gold_boxes.json") as f: 27 | img2boxes = json.load(f) 28 | else: 29 | with open(f"{args.dataroot}/dev_roi_boxes.json") as f: 30 | img2boxes = json.load(f) 31 | img2boxes = {int(key): value for key, value in img2boxes.items()} 32 | 33 | with open(f"{args.dataroot}/vlparse.json") as f: 34 | gold = json.load(f) 35 | gold = {item["coco_id"]: item for item in gold if isinstance(item, dict)} 36 | 37 | 38 | id_list = [line for line in Path(id_list_path).read_text().splitlines()] 39 | img_ids = [int(item) for item in id_list for _ in range(5)] 40 | sent_ids = [item for _ in id_list for item in range(5)] 41 | predict = list( 42 | conllu.parse_incr(open(predict_path), fields=["ID", "FORM", "POS", "HEAD", "ALIGN"]) 43 | ) 44 | has_vg = [item in gold for item in img_ids] 45 | img_ids = [item for item, flag in zip(img_ids, has_vg) if flag] 46 | sent_ids = [item for item, flag in zip(sent_ids, has_vg) if flag] 47 | # predict = [item for item, flag in zip(predict, has_vg) if flag] 48 | print(len(sent_ids), len(predict)) 49 | 50 | 51 | def get_position(item): 52 | return item["x"], item["y"], item["x"] + item["width"], item["y"] + item["height"] 53 | 54 | 55 | def bb_intersection_over_union(boxA, boxB): 56 | # boxA = [int(x) for x in boxA] 57 | # boxB = [int(x) for x in boxB] 58 | 59 | xA = max(boxA[0], boxB[0]) 60 | yA = max(boxA[1], boxB[1]) 61 | xB = min(boxA[2], boxB[2]) 62 | yB = min(boxA[3], boxB[3]) 63 | 64 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) 65 | 66 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) 67 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) 68 | 69 | iou = interArea / float(boxAArea + boxBArea - interArea) 70 | 71 | return iou 72 | 73 | 74 | obj_correct = 0 75 | obj_total = 0 76 | attr_correct = 0 77 | attr_total = 0 78 | rel_correct = 0 79 | rel_total = 0 80 | 81 | 82 | def test(boxA, boxB): 83 | if bb_intersection_over_union(boxA, boxB) >= 0.5: 84 | return True 85 | return False 86 | 87 | 88 | for idx in range(len(predict)): 89 | img_id, sent_id = img_ids[idx], sent_ids[idx] 90 | 91 | # obj 92 | for word_idx, data in gold[img_id]["txt2sg"][sent_id].items(): 93 | if data["type"] != "OBJ": 94 | continue 95 | correct_flag = False 96 | for item in predict[idx][int(word_idx)]["ALIGN"].split("|"): 97 | pred_type, pred_id = item.split() 98 | if pred_type == "obj": 99 | word_predict = img2boxes[img_id][int(pred_id)] 100 | correct_flag = False 101 | for obj_id, _ in data["candidates"]: 102 | position = get_position(gold[img_id]["obj"][obj_id]) 103 | if test(word_predict, position): 104 | correct_flag = True 105 | break 106 | if correct_flag: 107 | obj_correct += 1 108 | break 109 | obj_total += 1 110 | 111 | # 
attr 112 | for word_idx, data in gold[img_id]["txt2sg"][sent_id].items(): 113 | if data["type"] != "ATTR": 114 | continue 115 | correct_flag = False 116 | for item in predict[idx][int(word_idx)]["ALIGN"].split("|"): 117 | pred_type, pred_id = item.split() 118 | if pred_type == "attr": 119 | try: 120 | word_predict = img2boxes[img_id][int(pred_id)] 121 | except IndexError: 122 | print(img_id, sent_id) 123 | correct_flag = False 124 | for obj_id, _ in data["candidates"]: 125 | position = get_position(gold[img_id]["obj"][obj_id]) 126 | if test(word_predict, position): 127 | correct_flag = True 128 | break 129 | if correct_flag: 130 | attr_correct += 1 131 | break 132 | attr_total += 1 133 | 134 | # rel 135 | for word_idx, data in gold[img_id]["txt2sg"][sent_id].items(): 136 | if data["type"] != "REL": 137 | continue 138 | correct_flag = False 139 | for item in predict[idx][int(word_idx)]["ALIGN"].split("|"): 140 | pred_type, pred_id = item.split() 141 | if pred_type == "rel": 142 | obj1, obj2 = pred_id.split("-") 143 | obj1 = img2boxes[img_id][int(obj1)] 144 | obj2 = img2boxes[img_id][int(obj2)] 145 | 146 | correct_flag = False 147 | for rel_id, _ in data["candidates"]: 148 | rel_item = gold[img_id]["rel"][rel_id - len(gold[img_id]["obj"])] 149 | assert rel_item["id"] == rel_id 150 | gold_obj1 = get_position(gold[img_id]["obj"][rel_item["subj"]]) 151 | gold_obj2 = get_position(gold[img_id]["obj"][rel_item["obj"]]) 152 | 153 | if test(obj1, gold_obj1) and test(obj2, gold_obj2): 154 | correct_flag = True 155 | break 156 | if test(obj2, gold_obj1) and test(obj1, gold_obj2): 157 | correct_flag = True 158 | break 159 | if correct_flag: 160 | rel_correct += 1 161 | break 162 | rel_total += 1 163 | 164 | 165 | print("obj", obj_correct / obj_total, obj_total) 166 | print("attr", attr_correct / attr_total, attr_total) 167 | print("rel", rel_correct / rel_total, rel_total) 168 | print( 169 | "0-order", 170 | (obj_correct + attr_correct + rel_correct) / (obj_total + attr_total + rel_total), 171 | ) 172 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core 2 | pytorch-lightning 3 | transformers 4 | easydict 5 | colorama 6 | fastnlp 7 | nltk 8 | wandb 9 | matplotlib 10 | seaborn 11 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Optional, Mapping 4 | 5 | import numpy as np 6 | import pytorch_lightning 7 | import torch 8 | from easydict import EasyDict 9 | from hydra._internal.utils import is_under_debugger as _is_under_debugger 10 | from hydra.utils import HydraConfig 11 | from omegaconf import ListConfig, OmegaConf 12 | 13 | from src.utility.logger import get_logger_func 14 | 15 | _warn, _info, _debug = get_logger_func('src') 16 | 17 | g_cfg = EasyDict({ 18 | 'MANUAL': 1, 19 | }) # globel configuration obj 20 | trainer: Optional[pytorch_lightning.Trainer] = None 21 | debugging = False 22 | 23 | # >>> setup logger 24 | 25 | pl_logger = logging.getLogger('lightning') 26 | pl_logger.propagate = False 27 | 28 | fastnlp_logger = logging.getLogger('fastNLP') 29 | fastnlp_logger.propagate = False 30 | 31 | wandb_logger = logging.getLogger('wandb') 32 | # wandb_logger.propagate = False 33 | 34 | # >>> setup OmegaConf 35 | 36 | # OmegaConf.register_new_resolver('in', lambda x, y: x in 
y) 37 | OmegaConf.register_new_resolver('lang', lambda x: x.split('_')[0]) 38 | OmegaConf.register_new_resolver('last', lambda x: x.split('/')[-1]) 39 | OmegaConf.register_new_resolver('div2', lambda x: x // 2) 40 | # OmegaConf.register_new_resolver('cat', lambda x, y: x + y) 41 | 42 | _hit_debug = True 43 | 44 | 45 | def is_under_debugger(): 46 | if os.environ.get('DEBUG_MODE', '').lower() in ('true', 't', '1', 'yes', 'y'): 47 | result = True 48 | else: 49 | result = _is_under_debugger() 50 | global _hit_debug, debugging 51 | if result and _hit_debug: 52 | _warn("Debug mode.") 53 | _hit_debug = False 54 | debugging = True 55 | return result 56 | 57 | 58 | OmegaConf.register_new_resolver('in_debugger', lambda x, default=None: x if is_under_debugger() else default) 59 | 60 | 61 | def path_guard(x: str): 62 | x = x.split(',') 63 | x.sort() 64 | x = '_'.join(x) 65 | x = x.replace('/', '-') 66 | x = x.replace('=', '-') 67 | return x[:240] 68 | 69 | 70 | OmegaConf.register_new_resolver('path_guard', path_guard) 71 | 72 | 73 | def half_int(x): 74 | assert x % 2 == 0 75 | return x // 2 76 | 77 | 78 | OmegaConf.register_new_resolver('half_int', half_int) 79 | 80 | 81 | def name_guard(fallback): 82 | try: 83 | return HydraConfig.get().job.override_dirname 84 | except ValueError as v: 85 | if 'HydraConfig was not set' in str(v): 86 | return fallback 87 | raise v 88 | 89 | 90 | OmegaConf.register_new_resolver('name_guard', name_guard) 91 | 92 | 93 | def choose_accelerator(gpus): 94 | if isinstance(gpus, int): 95 | return 'ddp' if gpus > 1 else None 96 | elif isinstance(gpus, str): 97 | return 'ddp' if len(gpus.split(',')) > 1 else None 98 | elif isinstance(gpus, (list, ListConfig)): 99 | return 'ddp' if len(gpus) > 1 else None 100 | elif gpus is None: 101 | return None 102 | raise ValueError(f'Unrecognized {gpus=} ({type(gpus)})') 103 | 104 | 105 | OmegaConf.register_new_resolver('accelerator', choose_accelerator) 106 | 107 | 108 | # >>> setup inf 109 | 110 | INF = 1e20 111 | 112 | 113 | def setup_inf(v): 114 | global INF 115 | import src.model.torch_struct as stt 116 | INF = v 117 | stt.semirings.semirings.NEGINF = -INF 118 | 119 | 120 | setup_inf(1e20) 121 | 122 | 123 | # pl patch 124 | 125 | def _extract_batch_size(batch): 126 | if isinstance(batch, torch.Tensor): 127 | yield batch.shape[0] 128 | elif isinstance(batch, np.ndarray): 129 | yield batch.shape[0] 130 | elif isinstance(batch, str): 131 | yield len(batch) 132 | elif isinstance(batch, Mapping): 133 | for sample in batch: 134 | yield from _extract_batch_size(sample) 135 | else: 136 | x, y = batch 137 | yield len(x['id']) 138 | 139 | 140 | from pytorch_lightning.utilities import data as pludata 141 | 142 | pludata._extract_batch_size = _extract_batch_size 143 | -------------------------------------------------------------------------------- /src/datamodule/__init__.py: -------------------------------------------------------------------------------- 1 | from src.datamodule.datamodule import DataModule -------------------------------------------------------------------------------- /src/datamodule/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | from math import ceil 4 | from typing import List 5 | 6 | import torch 7 | from fastNLP import RandomSampler, SequentialSampler 8 | 9 | 10 | from src.utility.logger import get_logger_func 11 | 12 | _warn, _info, _debug = get_logger_func("sampler") 13 | 14 | 15 | class ConstantTokenNumSampler: 16 | def 
__init__( 17 | self, 18 | seq_len: List[int], 19 | max_token: int = 4096, 20 | max_sentence: int = -1, 21 | num_bucket: int = 16, 22 | single_sent_threshold: int = -1, 23 | sort_in_batch: bool = True, 24 | shuffle: bool = True, 25 | force_same_len: bool = False, 26 | ): 27 | """ 28 | :param List[int] seq_len: list of sample lengths. 29 | :param int max_token: maximum number of tokens per batch 30 | :param int max_sentence: maximum number of sentences per batch, applied together with max_token; <=0 disables it 31 | :param int num_bucket: split the data into num_bucket buckets by length 32 | :param int single_sent_threshold: sentences with length >= this threshold are forced into single-sentence batches; -1 disables it 33 | :param bool sort_in_batch: sort sentences within a batch by decreasing length 34 | :param bool shuffle: shuffle 35 | :param bool force_same_len: ignore num_bucket; use one bucket per length so that all sentences in a batch have the same length 36 | """ 37 | 38 | assert ( 39 | len(seq_len) >= num_bucket 40 | ), "The number of samples should be larger than buckets." 41 | assert ( 42 | num_bucket > 1 or force_same_len 43 | ), "Use RandomSampler if you do not need bucket." 44 | 45 | self.seq_len = seq_len 46 | self.max_token = max_token 47 | self.max_sentence = max_sentence if max_sentence > 0 else 10000000000000000 48 | self.single_sent_threshold = single_sent_threshold 49 | self.sort_in_batch = sort_in_batch and not force_same_len 50 | self.shuffle = shuffle 51 | self.epoch = 0 # +=1 every time __iter__ is called. 52 | 53 | # sizes: List[int], pseudo size of each bucket. 54 | # buckets: List[List[int]], each one is a bucket, containing idx. 55 | if force_same_len: 56 | self.sizes = list(set(seq_len)) 57 | len2idx = dict((l, i) for i, l in enumerate(self.sizes)) 58 | self.buckets = [[] for _ in range(len(self.sizes))] 59 | for i, l in enumerate(seq_len): 60 | self.buckets[len2idx[l]].append(i) 61 | else: 62 | self.sizes, self.buckets = self.kmeans(seq_len, num_bucket) 63 | 64 | # chunks: List[int], number of chunks for each bucket 65 | self.chunks = [ 66 | min( 67 | len(bucket), 68 | max( 69 | ceil(size * len(bucket) / max_token), 70 | ceil(len(bucket) / max_sentence), 71 | ), 72 | ) 73 | for size, bucket in zip(self.sizes, self.buckets) 74 | ] 75 | 76 | self._batches = [] 77 | self._all_batches = [] # including other workers 78 | self._exhausted = True 79 | self._init_iter_with_retry() # init here for valid __len__ at any time.
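    # Summary of the batching scheme set up above (describing the methods below):
    # sentences are grouped into buckets by k-means over their lengths (or one bucket per
    # distinct length when force_same_len=True); each bucket i is then split into
    # self.chunks[i] batches so that a batch holds roughly max_token tokens (estimated
    # from the bucket centroid) and at most max_sentence sentences. __iter__ reshuffles
    # with a generator seeded by self.epoch, and _process_batch moves sentences at or
    # above single_sent_threshold into singleton batches and sorts each batch by
    # decreasing length.
    # Typical use (the values here are illustrative, mirroring the configs in this repo):
    #   sampler = ConstantTokenNumSampler(seq_len=lengths, max_token=5000, num_bucket=10)
    #   for batch_indices in sampler:
    #       ...  # batch_indices index into the dataset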
80 | 81 | def __iter__(self): 82 | self._init_iter_with_retry() 83 | yield from self._batches 84 | self._exhausted = True 85 | 86 | def __len__(self): 87 | return len(self._batches) 88 | 89 | def _init_iter(self): 90 | if self.shuffle: 91 | self.epoch += 1 92 | g = torch.Generator() 93 | g.manual_seed(self.epoch) 94 | range_fn = partial(torch.randperm, generator=g) 95 | else: 96 | range_fn = torch.arange 97 | 98 | batches = [] 99 | for i in range(len(self.buckets)): 100 | split_sizes = [ 101 | (len(self.buckets[i]) - j - 1) // self.chunks[i] + 1 102 | for j in range(self.chunks[i]) 103 | ] 104 | for batch in range_fn(len(self.buckets[i])).split(split_sizes): 105 | batches.append([self.buckets[i][j] for j in batch]) 106 | batches = [ 107 | batch 108 | for i in range_fn(len(batches)) 109 | for batch in self._process_batch(batches[i]) 110 | ] 111 | 112 | self._batches = batches 113 | self._all_batches = batches 114 | self._exhausted = False 115 | 116 | def _init_iter_with_retry(self, max_try=5): 117 | _count = 0 118 | while self._exhausted: 119 | _count += 1 120 | if _count == max_try: 121 | raise ValueError("Failed to init iteration.") 122 | self._init_iter() 123 | 124 | def _process_batch(self, batch): 125 | # apply sort_in_batch and single_sent_threshold 126 | singles = [] 127 | if self.single_sent_threshold != -1: 128 | new_batch = [] 129 | for inst_idx in batch: 130 | if self.seq_len[inst_idx] >= self.single_sent_threshold: 131 | singles.append([inst_idx]) 132 | else: 133 | new_batch.append(inst_idx) 134 | batch = new_batch 135 | if self.sort_in_batch: 136 | batch.sort(key=lambda i: -self.seq_len[i]) 137 | if len(batch): 138 | return [batch] + singles 139 | else: 140 | return singles 141 | 142 | def set_epoch(self, epoch: int): 143 | # This is not a subclass of DistributedSampler, so will never be called by pytorch-lightning. 144 | breakpoint() # any case call this? 
145 | self.epoch = epoch 146 | 147 | @staticmethod 148 | def kmeans(x, k, max_it=32): 149 | """From https://github.com/yzhangcs/parser/blob/main/supar/utils/alg.py#L7""" 150 | 151 | # the number of clusters must not be greater than the number of datapoints 152 | x, k = torch.tensor(x, dtype=torch.float), min(len(x), k) 153 | # collect unique datapoints 154 | d = x.unique() 155 | # initialize k centroids randomly 156 | c = d[torch.randperm(len(d))[:k]] 157 | # assign each datapoint to the cluster with the closest centroid 158 | dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1) 159 | 160 | for _ in range(max_it): 161 | # if an empty cluster is encountered, 162 | # choose the farthest datapoint from the biggest cluster and move that the empty one 163 | mask = torch.arange(k).unsqueeze(-1).eq(y) 164 | none = torch.where(~mask.any(-1))[0].tolist() 165 | while len(none) > 0: 166 | for i in none: 167 | # the biggest cluster 168 | b = torch.where(mask[mask.sum(-1).argmax()])[0] 169 | # the datapoint farthest from the centroid of cluster b 170 | f = dists[b].argmax() 171 | # update the assigned cluster of f 172 | y[b[f]] = i 173 | # re-calculate the mask 174 | mask = torch.arange(k).unsqueeze(-1).eq(y) 175 | none = torch.where(~mask.any(-1))[0].tolist() 176 | # update the centroids 177 | c, old = (x * mask).sum(-1) / mask.sum(-1), c 178 | # re-assign all datapoints to clusters 179 | dists, y = torch.abs_(x.unsqueeze(-1) - c).min(-1) 180 | # stop iteration early if the centroids converge 181 | if c.equal(old): 182 | break 183 | # assign all datapoints to the new-generated clusters 184 | # the empty ones are discarded 185 | assigned = y.unique().tolist() 186 | # get the centroids of the assigned clusters 187 | centroids = c[assigned].tolist() 188 | # map all values of datapoints to buckets 189 | clusters = [torch.where(y.eq(i))[0].tolist() for i in assigned] 190 | 191 | return centroids, clusters 192 | 193 | 194 | class BasicSampler: 195 | """RandomSampler and SequentialSampler""" 196 | 197 | def __init__( 198 | self, 199 | seq_len, 200 | batch_size, 201 | single_sent_threshold=-1, 202 | sort_in_batch=True, 203 | shuffle=True, 204 | ): 205 | self.seq_len = seq_len 206 | self.batch_size = batch_size 207 | self.single_sent_threshold = single_sent_threshold 208 | self.sort_in_batch = sort_in_batch 209 | self.shuffle = shuffle 210 | self.epoch = 0 211 | 212 | self._sampler = RandomSampler() if shuffle else SequentialSampler() 213 | 214 | def __iter__(self): 215 | batch = [] 216 | for i in self._sampler(self.seq_len): 217 | batch.append(i) 218 | if len(batch) == self.batch_size: 219 | yield from self._process_batch(batch) 220 | batch.clear() 221 | if batch: 222 | yield from self._process_batch(batch) 223 | 224 | def __len__(self): 225 | return math.ceil(len(self.seq_len) / self.batch_size) 226 | 227 | def _process_batch(self, batch): 228 | # apply sort_in_batch and single_sent_threshold 229 | singles = [] 230 | if self.single_sent_threshold != -1: 231 | new_batch = [] 232 | for inst_idx in batch: 233 | if self.seq_len[inst_idx] >= self.single_sent_threshold: 234 | singles.append([inst_idx]) 235 | else: 236 | new_batch.append(inst_idx) 237 | batch = new_batch 238 | if self.sort_in_batch: 239 | batch.sort(key=lambda i: -self.seq_len[i]) 240 | if len(batch): 241 | return [batch] + singles 242 | else: 243 | return singles 244 | 245 | def set_epoch(self, epoch: int): 246 | # This is not a subclass of DistributedSampler 247 | # this function will never be called by pytorch-lightning. 
248 | self.epoch = epoch 249 | -------------------------------------------------------------------------------- /src/datamodule/task/__init__.py: -------------------------------------------------------------------------------- 1 | from .dep import DepDataModule 2 | from .vlparse import VLParseDataModule -------------------------------------------------------------------------------- /src/datamodule/task/dep.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | from fastNLP import DataSet 4 | from fastNLP.io import ConllLoader 5 | from nltk.corpus import stopwords 6 | 7 | import src 8 | 9 | from src.datamodule.datamodule import DataModule 10 | from src.datamodule.vocabulary import Vocabulary 11 | from src.utility.alg import isprojective 12 | from src.utility.logger import get_logger_func 13 | import omegaconf 14 | 15 | _warn, _info, _debug = get_logger_func('runner') 16 | 17 | 18 | class DepDataModule(DataModule): 19 | INPUTS = ('id', 'word', 'token', 'seq_len') # word for encoder, token for dmv 20 | TARGETS = ('arc', ) 21 | LOADER = ConllLoader 22 | 23 | def __init__( 24 | self, 25 | use_tag=True, 26 | num_lex=0, # limit word in token. not consider tag. 27 | num_token=99999, # limit total token. consider (lex, tag) pair. 28 | ignore_stop_word=False, 29 | headers=None, 30 | indexes=None, 31 | **kwargs): 32 | assert num_lex > 0 or use_tag, 'Nothing to build token' 33 | 34 | headers = headers or ['raw_word', 'tag', 'arc'] 35 | indexes = indexes or [1, 2, 3] 36 | loader = self.LOADER(headers, indexes=indexes, dropna=False, sep='\t') 37 | 38 | self.use_tag = use_tag 39 | if use_tag: 40 | assert 'tag' in headers 41 | self.INPUTS = self.INPUTS + ('tag', ) 42 | self.EXTRA_VOCAB = self.EXTRA_VOCAB + ('tag', ) 43 | 44 | self.num_lex = num_lex 45 | self.num_token = num_token 46 | self.ignore_stop_word = ignore_stop_word 47 | super().__init__(loader=loader, **kwargs) 48 | self.vocabs['token'] = None # set to manual init 49 | 50 | self.token2word = None 51 | self.token2tag = None 52 | if self.use_tag and self.num_lex > 0: 53 | self.token_mode = 'joint' 54 | elif self.use_tag: 55 | self.token_mode = 'tag' 56 | else: 57 | self.token_mode = 'word' 58 | 59 | def _load(self, path, name): 60 | ds: DataSet = self.loader._load(path) 61 | 62 | if self.token_mode == 'joint': 63 | ds.apply(lambda x: [f'{w.lower()}:{p}' for w, p in zip(x['raw_word'], x['tag'])], new_field_name='token') 64 | elif self.token_mode == 'tag': 65 | ds.apply(lambda x: x['tag'], new_field_name='token') 66 | else: 67 | ds.apply(lambda x: list(map(str.lower, x['raw_word'])), new_field_name='token') 68 | 69 | if name in ('train', 'train_init', 'dev', 'val', 'test'): 70 | ds['arc'].int() 71 | orig_len = len(ds) 72 | ds.drop(lambda i: not isprojective(i['arc']), inplace=False) 73 | cleaned_len = len(ds) 74 | if cleaned_len < orig_len: 75 | _warn(f'Data contains nonprojective trees. 
{path}') 76 | else: 77 | raise NotImplementedError 78 | 79 | return ds 80 | 81 | def post_init_vocab(self, datasets): 82 | count = Counter() 83 | word_count = Counter() 84 | 85 | if self.token_mode == 'tag': 86 | self.vocabs['token'] = self.vocabs['tag'] 87 | self.token2tag = list(range(len(self.vocabs['token']))) 88 | return 89 | 90 | for ds in self.get_create_entry_ds(): 91 | for inst in ds: 92 | word_count.update(map(str.lower, inst['word'])) 93 | if self.token_mode == 'joint': 94 | count.update(zip(map(str.lower, inst['word']), inst['tag'])) 95 | 96 | if self.ignore_stop_word: 97 | sw = set(stopwords.words('english')) 98 | used_word = [w for w, i in word_count.most_common(self.num_lex + len(sw)) if w not in sw] 99 | used_word = set(used_word[:self.num_lex]) 100 | else: 101 | used_word = set(w for w, i in word_count.most_common(self.num_lex)) 102 | 103 | processed_count = {} 104 | if self.token_mode == 'joint': 105 | for (w, p), c in count.most_common(): 106 | if w in used_word: 107 | processed_count[f'{w}:{p}'] = c 108 | if len(processed_count) == self.num_token: 109 | break 110 | for p in self.vocabs['tag'].word2idx: 111 | if p in ('', ''): continue 112 | processed_count[f':{p}'] = 100000 113 | else: 114 | for w, c in word_count.most_common(): 115 | if w in used_word: 116 | processed_count[w] = c 117 | if len(processed_count) == self.num_token: 118 | break 119 | 120 | token_vocab = Vocabulary() 121 | token_vocab.word_count = Counter(processed_count) 122 | token_vocab.build_vocab() 123 | self.vocabs['token'] = token_vocab 124 | 125 | if self.token_mode == 'joint': 126 | w, t = zip(*[token_vocab.idx2word[i].rsplit(':', 1) for i in range(2, len(token_vocab))]) 127 | w = ['', ''] + list(w) 128 | t = ['', ''] + list(t) 129 | self.token2word = [self.vocabs['word'][i] for i in w] 130 | self.token2tag = [self.vocabs['tag'][i] for i in t] 131 | else: 132 | self.token2word = [self.vocabs['word'][token_vocab.idx2word[i]] for i in range(len(token_vocab))] 133 | 134 | def train_dataloader(self): 135 | loaders = {'train': self.dataloader('train')} 136 | for key in self.datasets: 137 | if key in ('train', 'dev', 'test'): 138 | continue 139 | if key == 'train_init': 140 | try: 141 | n_init = src.g_cfg.model.init_epoch 142 | do_init = src.g_cfg.model.init_method == 'y' and n_init > 0 143 | except (KeyError, omegaconf.errors.ConfigAttributeError): 144 | _warn('ignoring train_init due to missing cfg.') 145 | continue 146 | if do_init: 147 | loaders['train'] = _TrainInitLoader(self.dataloader('train_init'), loaders['train'], n_init) 148 | loaders[key] = self.dataloader(key) 149 | _info(f'Returning {len(loaders)} loader(s) as train_dataloader.') 150 | return loaders 151 | 152 | 153 | class _TrainInitLoader: 154 | def __init__(self, init_loader, normal_loader, n_init) -> None: 155 | self.init_loader = init_loader 156 | self.normal_loader = normal_loader 157 | self.n_init = n_init 158 | self.current = 1 159 | 160 | def __iter__(self): 161 | if self.current <= self.n_init: 162 | self.current += 1 163 | _warn('Initializing') 164 | yield from self.init_loader 165 | else: 166 | yield from self.normal_loader 167 | -------------------------------------------------------------------------------- /src/datamodule/task/vlparse.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | import json 3 | import os 4 | from pathlib import Path 5 | from typing import Any, Dict, List, Tuple 6 | 7 | import numpy as np 8 | import torch 9 | from fastNLP.core import 
DataSet 10 | from omegaconf import DictConfig, ListConfig 11 | from torch import Tensor 12 | 13 | from src.datamodule.task.dep import DepDataModule 14 | from src.utility.logger import get_logger_func 15 | 16 | InputDict = Dict[str, Tensor] 17 | TensorDict = Dict[str, Tensor] 18 | AnyDict = Dict[str, Any] 19 | GenDict = (dict, DictConfig) 20 | GenList = (list, ListConfig) 21 | 22 | _warn, _info, _debug = get_logger_func("datamodule") 23 | 24 | 25 | def get_box(obj): 26 | return [obj["x"], obj["y"], obj["x"] + obj["width"], obj["y"] + obj["height"]] 27 | 28 | 29 | class _COCODetFeatLazyLoader: 30 | def __init__(self, root, sg_data, sample, gold): 31 | self.root = root 32 | self.sg_data = sg_data 33 | self.sample = sample 34 | self.gold = gold 35 | 36 | def __call__(self, batch: List[Tuple[int, Any]]): 37 | box_feats, boxes, masks, rel_masks = [], [], [], [] 38 | max_len = 0 39 | for _, inst in batch: 40 | if (self.root / f"{inst['img_id']}.npy").exists(): 41 | feat = np.load(str(self.root / f"{inst['img_id']}.npy")) 42 | if self.sample > 0 and self.sample < len(feat): 43 | sample_id = np.random.choice( 44 | np.arange(len(feat)), self.sample, False 45 | ) 46 | feat = feat[sample_id] 47 | else: 48 | feat = feat[:35] 49 | sample_id = np.arange(len(feat)) 50 | box_feat, box = feat[:, :-4], feat[:, -4:] 51 | box_feat = torch.tensor(box_feat, dtype=torch.float) 52 | box = torch.tensor(box) 53 | 54 | box_feats.append(box_feat) 55 | boxes.append(box) 56 | 57 | if self.gold: 58 | inst_mask, inst_rel_mask = self.build_gold_mask(inst, sample_id) 59 | masks.append(inst_mask) 60 | rel_masks.append(inst_rel_mask) 61 | else: 62 | masks.append(torch.ones(len(box_feat), dtype=torch.bool)) 63 | rel_masks.append(None) 64 | max_len = max(len(box_feat), max_len) 65 | else: 66 | assert False 67 | 68 | box_feats_output = torch.zeros(len(box_feats), max_len, 2048) 69 | boxes_output = torch.zeros(len(boxes), max_len, 4) 70 | masks_output = torch.zeros(len(masks), max_len, dtype=torch.bool) 71 | rel_masks_output = ( 72 | None 73 | if len(rel_masks) == 0 74 | else torch.zeros(len(rel_masks), max_len, max_len, dtype=torch.bool) 75 | ) 76 | for i, (bf, b, m, rm) in enumerate(zip(box_feats, boxes, masks, rel_masks)): 77 | if bf is not None: 78 | box_feats_output[i, : len(bf)] = bf 79 | boxes_output[i, : len(b)] = b 80 | masks_output[i, : len(m)] = m 81 | if rm is not None: 82 | rel_masks_output[i, : rm.shape[0], : rm.shape[1]] = rm 83 | 84 | return ( 85 | { 86 | "vis_box_feat": box_feats_output, 87 | "vis_box_mask": masks_output, 88 | "vis_rel_mask": rel_masks_output, 89 | "vis_available": masks_output[:, 0], 90 | }, 91 | {"vis_box": boxes_output}, 92 | ) 93 | 94 | def build_gold_mask(self, inst, sample_id): 95 | sg_inst = self.sg_data[inst["img_id"]] 96 | if len(sg_inst["obj"]) == 0: 97 | return torch.zeros(0, dtype=torch.bool), torch.zeros(0, 0, dtype=torch.bool) 98 | mask = torch.ones(min(len(sample_id), len(sg_inst["obj"])), dtype=torch.bool) 99 | rel_mask = torch.zeros( 100 | len(sg_inst["obj"]), len(sg_inst["obj"]), dtype=torch.bool 101 | ) 102 | for item in sg_inst["rel"]: 103 | rel_mask[item["subj"], item["obj"]] = 1 104 | sample_id = torch.from_numpy(sample_id) 105 | rel_mask = rel_mask.gather( 106 | 1, sample_id.unsqueeze(0).expand(rel_mask.shape[1], -1) 107 | ).gather(0, sample_id.unsqueeze(-1).expand(-1, len(sample_id))) 108 | return mask, rel_mask 109 | 110 | 111 | class VLParseDataModule(DepDataModule): 112 | TARGETS = ("arc", "sg_type", "sg_box", "sg_mask") 113 | # train: text(.conll), proposed 
box(det_feats/.npy), img(.npy) 114 | # dev: text(.conll), proposed box(det_feats/.npy), img(.npy), scene graph(../.json) 115 | # test: text(.conll), proposed box(det_feats/.npy), scene graph(../.json) 116 | 117 | def __init__(self, use_img, use_gold_scene_graph, sg_path, **kwargs): 118 | 119 | self.use_img = use_img # use native image feature 120 | if self.use_img: 121 | self.INPUTS = self.INPUTS + ("vis_img",) 122 | self.use_gold_scene_graph = use_gold_scene_graph # return gold box and rels 123 | 124 | with open(sg_path) as f: # load scene graph 125 | sg_data = json.load(f) 126 | self.sg_data = {inst["coco_id"]: inst for inst in sg_data} 127 | 128 | if use_gold_scene_graph: 129 | with open(os.path.split(sg_path)[0] + "/vlparse_train_sg_raw.json") as f: 130 | sg_data = json.load(f) 131 | self.sg_data |= {inst["coco_id"]: inst for inst in sg_data} 132 | 133 | super().__init__(**kwargs) 134 | 135 | def _load(self, path, name) -> DataSet: 136 | # text: xxx.conll, a conllu format file 137 | # img: xxx.npy, each item is prefeteched feat. [n_img x hidden_size] 138 | # det_feats/.npy, box feat for each img shape: 100 x (1024+4) 139 | # id_list/xxx.txt, each line is a img_id and sent_id pair. assume sent with same img_id are put together. 140 | ds: DataSet = super()._load(path + ".conll", name) 141 | 142 | # load ids 143 | folder, filename = os.path.split(path) 144 | with open(Path(folder) / "id_list" / (filename + ".txt")) as f: 145 | img_id = [int(line.strip()) for line in f] 146 | if len(img_id) != len(ds): 147 | img_id = [id_ for id_ in img_id for _ in range(5)] 148 | ds.add_field("img_id", img_id) 149 | ds.add_field("img_sent_id", [i % 5 for i, _ in enumerate(img_id)]) 150 | 151 | # native image feature 152 | with self.tolerant_exception(["test"], name): 153 | if self.use_img: 154 | img_feat = np.load(path + ".npy").repeat(5, 0) 155 | ds.add_field("vis_img", img_feat, is_input=True) 156 | 157 | # prepare target, (and input if gold_sg) from sg data 158 | ds.apply_more(self.process_sg) 159 | 160 | ds.add_collate_fn( 161 | _COCODetFeatLazyLoader( 162 | Path(folder) 163 | / ("gold_feats" if self.use_gold_scene_graph else "det_feats"), 164 | self.sg_data, 165 | 35 if name in ("train", "train_init") else 0, 166 | self.use_gold_scene_graph, 167 | ), 168 | "det_feat_loader", 169 | ) 170 | if name in ("dev", "test") or self.use_gold_scene_graph: 171 | ds.drop(lambda x: not x["has_sg"]) 172 | return ds 173 | 174 | def process_sg(self, inst): 175 | if inst["img_id"] not in self.sg_data: 176 | txt2sg = {} 177 | rels = [] 178 | else: 179 | sg = self.sg_data[inst["img_id"]] 180 | rels = sg["rel"] 181 | txt2sg = sg["txt2sg"][inst["img_sent_id"]] 182 | id2node = {node["id"]: node for node in chain(sg["obj"], sg["rel"])} 183 | typestr2id = {"OBJ": 1, "ATTR": 2, "REL": 3} 184 | gold_box, tok_type = [], [] 185 | 186 | # here only collect grounded box per words 187 | for i in range(len(inst["raw_word"])): 188 | if (i := str(i)) in txt2sg: 189 | alignment = txt2sg[i] 190 | tok_type.append(typestr2id[alignment["type"]]) 191 | if tok_type[-1] == 3: 192 | node = id2node[alignment["preferred"]] 193 | subj, obj = id2node[node["subj"]], id2node[node["obj"]] 194 | gold_box.append(get_box(subj) + get_box(obj)) 195 | else: 196 | gold_box.append( 197 | get_box(id2node[alignment["preferred"]]) + [0.0] * 4 198 | ) 199 | else: 200 | tok_type.append(0) 201 | gold_box.append([0.0] * 8) 202 | 203 | sg_rel = [[item["subj"], item["obj"]] for item in rels] 204 | return { 205 | "sg_type": tok_type, 206 | "sg_box": gold_box, 207 | 
"vis_rel": sg_rel, # this is for inputs. When eval we just need sg_box. 208 | "sg_mask": [t != 0 for t in tok_type], 209 | "has_sg": inst["img_id"] in self.sg_data, 210 | } 211 | -------------------------------------------------------------------------------- /src/datamodule/vocabulary.py: -------------------------------------------------------------------------------- 1 | from fastNLP import Vocabulary as _fastNLP_Vocabulary 2 | from fastNLP.core.vocabulary import _check_build_vocab 3 | 4 | 5 | class Vocabulary(_fastNLP_Vocabulary): 6 | @_check_build_vocab 7 | def __getitem__(self, w: str): 8 | if w.endswith("::"): 9 | w = [w[:-2], ":"] 10 | else: 11 | w = w.rsplit(":", 1) 12 | w[0] = w[0].lower() 13 | if (_w := ":".join(w)) in self._word2idx: 14 | return self._word2idx[_w] 15 | if (_w := ":" + w[1]) in self._word2idx: 16 | return self._word2idx[_w] 17 | # no need to check 18 | raise ValueError("word `{}` not in vocabulary".format(w)) 19 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ModelBase, JointModelBase 2 | from .dmv import DMV 3 | from .ldndmv import DiscriminativeNDMV 4 | from .joint import DependencyBoxRel -------------------------------------------------------------------------------- /src/model/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import functools 4 | from collections import defaultdict 5 | from io import IOBase 6 | from typing import Any, Dict, List, Tuple 7 | 8 | import torch.nn as nn 9 | from fastNLP import DataSet, Vocabulary 10 | from hydra.utils import instantiate 11 | from torch import Tensor 12 | 13 | import src 14 | from src.datamodule import DataModule 15 | from src.model.embedding import Embedding 16 | from src.model.text_encoder import EncoderBase 17 | from src.utility.defaultlist import defaultlist 18 | from src.utility.fn import get_coeff_iter 19 | from src.utility.logger import get_logger_func 20 | from src.utility.var_pool import VarPool 21 | from abc import ABC 22 | from typing import Dict, Any, Type, Tuple 23 | 24 | from src.utility.config import Config 25 | from hydra.utils import instantiate 26 | from omegaconf import open_dict, OmegaConf 27 | from torch import Tensor 28 | 29 | 30 | from src.model.vis_encoder.base import VisEncoderBase 31 | 32 | InputDict = Dict[str, Tensor] 33 | TensorDict = Dict[str, Tensor] 34 | AnyDict = Dict[str, Any] 35 | 36 | _warn, _info, _debug = get_logger_func("model") 37 | 38 | 39 | class ModelBase(nn.Module): 40 | datamodule: DataModule 41 | embedding: Embedding 42 | encoder: EncoderBase 43 | _function_group = {} 44 | 45 | def __init__(self): 46 | super(ModelBase, self).__init__() 47 | self._dynamic_cfg = {} 48 | 49 | def setup(self, dm: DataModule): 50 | self.datamodule = dm 51 | self.embedding = Embedding(**src.g_cfg.embedding, dm=dm) 52 | self.encoder = instantiate(src.g_cfg.encoder, embedding=self.embedding) 53 | self.embedding.__dict__["bounded_model"] = self 54 | self.encoder.__dict__["bounded_model"] = self 55 | 56 | def forward( 57 | self, inputs: InputDict, vp: VarPool, embed=None, encoded=None, return_all=False 58 | ): 59 | dyn_cfg = self.apply_dynamic_cfg() 60 | src.trainer.lightning_module.log_dict(dyn_cfg) 61 | if embed is None: 62 | embed = self.embedding(inputs, vp) 63 | if encoded is None or encoded["__need_encode"]: 64 | if encoded is None: 65 | encoded 
= {} 66 | else: 67 | del encoded["__need_encode"] 68 | encoded |= self.encoder(embed, vp) 69 | encoded["emb"] = embed 70 | score = self._forward(inputs, encoded, vp) 71 | if return_all: 72 | return embed, encoded, score 73 | return score 74 | 75 | def _forward(self, inputs: InputDict, encoded: TensorDict, vp: VarPool): 76 | raise NotImplementedError 77 | 78 | def loss( 79 | self, x: TensorDict, gold: InputDict, vp: VarPool 80 | ) -> Tuple[Tensor, TensorDict]: 81 | raise NotImplementedError 82 | 83 | def decode(self, x: TensorDict, vp: VarPool) -> AnyDict: 84 | raise NotImplementedError 85 | 86 | def normalize_embedding(self, now): 87 | self.embedding.normalize(now) 88 | 89 | def preprocess_write(self, output: List[Dict[str, Any]]): 90 | batch_size = len(output[0]["id"]) # check one batch 91 | safe_to_sort = all( 92 | (len(p) == batch_size) for p in output[0]["predict"].values() 93 | ) 94 | 95 | if safe_to_sort: 96 | # I will put all predicts in the order of idx, but you have to remove padding by yourself. 97 | sorted_predicts = defaultdict(defaultlist) 98 | for batch in output: 99 | id_, predict = batch["id"], batch["predict"] 100 | for key, value in predict.items(): 101 | if isinstance(value, Tensor): 102 | value = value.detach().cpu().numpy() 103 | for one_id, one_value in zip(id_, value): 104 | sorted_predicts[key][one_id] = one_value 105 | return sorted_predicts 106 | else: 107 | raise NotImplementedError("Can not preprocess automatically.") 108 | 109 | def write_prediction( 110 | self, s: IOBase, predicts, dataset: DataSet, vocabs: Dict[str, Vocabulary] 111 | ) -> IOBase: 112 | raise NotImplementedError 113 | 114 | # noinspection PyMethodMayBeStatic 115 | def set_varpool(self, vp: VarPool) -> VarPool: 116 | return vp 117 | 118 | @classmethod 119 | def add_impl_to_group(cls, group, spec, pre_hook=None): 120 | def decorator(func): 121 | if group not in cls._function_group: 122 | cls._function_group[group] = {} 123 | assert spec not in cls._function_group[group], spec 124 | cls._function_group[group][spec] = (func, pre_hook) 125 | 126 | @functools.wraps(func) 127 | def wrapper(*args, **kwargs): 128 | return func(*args, **kwargs) 129 | 130 | return wrapper 131 | 132 | return decorator 133 | 134 | def set_impl_in_group(self, group, spec): 135 | try: 136 | impl, pre_hook = self._function_group[group][spec] 137 | except Exception as e: 138 | _warn(f"Failed to load {group}: {spec}") 139 | raise e 140 | if pre_hook is not None: 141 | getattr(self, pre_hook)() 142 | setattr(self, group, functools.partial(impl, self)) 143 | 144 | def add_dynamic_cfg(self, name, command): 145 | """name: |""" 146 | if name in self._dynamic_cfg: 147 | _warn(f"Overwriting {name} with {command}") 148 | self._dynamic_cfg[name] = get_coeff_iter( 149 | command, idx_getter=lambda: src.trainer.current_epoch 150 | ) 151 | 152 | def apply_dynamic_cfg(self): 153 | params = {key: next(value) for key, value in self._dynamic_cfg.items()} 154 | for key, value in params.items(): 155 | obj_nev, cfg_nev = key.split("|") 156 | o = self 157 | for attr_name in obj_nev.split("."): 158 | o = getattr(o, attr_name) 159 | s = o 160 | cfg_nev = cfg_nev.split(".") 161 | for k in cfg_nev[:-1]: 162 | s = s[k] 163 | s[cfg_nev[-1]] = value 164 | return params 165 | 166 | def process_checkpoint(self, ckpt): 167 | return ckpt 168 | 169 | 170 | class JointModelBase(ModelBase, ABC): 171 | # assume only one datamodule 172 | # assume image does not require embedding 173 | # assume all visual-side module/parameter are named with 'vis_' prefix. 
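# A minimal sketch (not from the repository; keys and shapes are illustrative) of what
# the 'vis_' naming convention enables: forward() below can split one mixed batch into
# language-side and visual-side inputs purely by key prefix, e.g.
#
#   inputs = {"token": tok, "seq_len": lens, "vis_box_feat": feats, "vis_box_mask": mask}
#   vis_input = {k: v for k, v in inputs.items() if k.startswith("vis_")}
#   lang_input = {k: v for k, v in inputs.items() if not k.startswith("vis_")}
#
# The same prefix test can be reused to pick out visual-side parameters or submodules
# (e.g. for separate optimizer groups) without keeping an explicit registry.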
174 | 175 | # I prefer not seperate the joint model into a language-side model and a visual-side model 176 | # because it is hard to foresee possible interaction between two sides and 177 | # for now the visual-side model is very simple. 178 | 179 | # language part, inherit from ModelBase 180 | # datamodule: DataModule 181 | # embedding: Embedding 182 | # encoder: EncoderBase 183 | 184 | # visual part 185 | vis_encoder: VisEncoderBase 186 | 187 | def setup(self, dm: DataModule): 188 | if getattr(self, "__setup_handled") is not True: 189 | _warn("You call setup() directly. Consider to use _setup()") 190 | self.datamodule = dm 191 | # self.embedding = Embedding(**src.g_cfg.embedding, dm=dm) 192 | # self.embedding.__dict__['bounded_model'] = self 193 | self.encoder = instantiate(src.g_cfg.encoder, embedding=self.embedding) 194 | self.encoder.__dict__["bounded_model"] = self 195 | self.vis_encoder = instantiate(src.g_cfg.vis_encoder) 196 | if self.vis_encoder is None: 197 | _warn("vis_encoder is disabled.") 198 | else: 199 | self.vis_encoder.__dict__["bounded_model"] = self 200 | 201 | def _setup(self, dm: DataModule, cfg_class: Type[Config], allow_missing=None): 202 | setattr(self, "__setup_handled", True) 203 | self.cfg = cfg = cfg_class.build(self.cfg, allow_missing=allow_missing) 204 | with open_dict(cfg.dep_model_cfg): 205 | cfg.dep_model_cfg = OmegaConf.merge(cfg.dep_model_cfg, dm.get_vocab_count()) 206 | self.dependency = instantiate(cfg.dep_model_cfg) 207 | self.dependency.setup(dm) 208 | JointModelBase.setup(self, dm) 209 | return cfg 210 | 211 | @property 212 | def embedding(self): 213 | return self.dependency.embedding 214 | 215 | def forward( 216 | self, 217 | inputs: InputDict, 218 | vp: VarPool, 219 | embed=None, 220 | encoded=None, 221 | vis_encoded=None, 222 | return_all=False, 223 | ): 224 | if vis_encoded is None: 225 | vis_input = { 226 | key: value for key, value in inputs.items() if key.startswith("vis_") 227 | } 228 | if len(vis_input) > 0: 229 | vis_encoded = self.vis_encoder(vis_input, vp) 230 | else: 231 | vis_encoded = {} 232 | encoded = encoded if encoded is not None else {"__need_encode": True} 233 | for key, value in vis_encoded.items(): 234 | encoded[f"vis_{key}"] = value 235 | embed, encoded, score = super().forward(inputs, vp, embed, encoded, True) 236 | vis_score = self._vis_forward(inputs, vis_encoded, encoded, score, vp) 237 | score = {**score, **vis_score} 238 | if return_all: 239 | return embed, encoded, score 240 | else: 241 | return score 242 | 243 | def _forward(self, inputs: InputDict, encoded: TensorDict, vp: VarPool): 244 | return self.dependency._forward(inputs, encoded, vp) 245 | 246 | def _vis_forward( 247 | self, 248 | inputs: InputDict, 249 | encoded: TensorDict, 250 | language_encoded: TensorDict, 251 | lang_score: TensorDict, 252 | vp: VarPool, 253 | ): 254 | raise NotImplementedError 255 | 256 | -------------------------------------------------------------------------------- /src/model/dmv.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from io import IOBase 3 | from typing import Any, Dict, Tuple, Optional 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from fastNLP import DataSet, Vocabulary 9 | from hydra.conf import MISSING 10 | from torch import Tensor 11 | from torch.optim import Optimizer 12 | 13 | from src.datamodule.task import DepDataModule 14 | from src.model import ModelBase 15 | from src.model.dmv_helper import km_init, 
good_init 16 | from src.model.torch_struct import DMV1o, DependencyCRF 17 | from src.utility.config import Config 18 | from src.utility.logger import get_logger_func 19 | from src.utility.var_pool import VarPool 20 | 21 | InputDict = Dict[str, Tensor] 22 | TensorDict = Dict[str, Tensor] 23 | AnyDict = Dict[str, Any] 24 | 25 | _warn, _info, _debug = get_logger_func('model') 26 | 27 | 28 | @dataclass 29 | class DMVConfig(Config): 30 | viterbi_training: bool 31 | mbr_decoding: bool 32 | init_method: str # km, good, random 33 | smooth: float 34 | 35 | # ============================= AUTO FIELDS ============================= 36 | n_word: int = MISSING 37 | n_tag: int = MISSING 38 | n_token: int = MISSING 39 | 40 | 41 | class DMV(ModelBase): 42 | _instance = None # work around for DMVMStepOptimizer 43 | 44 | def __init__(self, **cfg): 45 | super().__init__() 46 | # noinspection PyTypeChecker 47 | self.cfg: DMVConfig = cfg 48 | self.root_param: Optional[nn.Parameter] = None 49 | self.trans_param: Optional[nn.Parameter] = None 50 | self.dec_param: Optional[nn.Parameter] = None 51 | self.optimizer: Optional[DMVMStepOptimizer] = None 52 | 53 | if DMV._instance is not None: 54 | _warn('overwriting DMV._instance') 55 | DMV._instance = self 56 | 57 | def setup(self, dm: DepDataModule): 58 | self.datamodule = dm 59 | self.cfg = cfg = DMVConfig.build(self.cfg, allow_missing={'n_word', 'n_tag'}) 60 | 61 | if cfg.init_method == 'km': 62 | d, t, r = km_init(dm.datasets['train'], cfg.n_token, cfg.smooth) 63 | elif cfg.init_method == 'good': 64 | d, t, r = good_init(dm.datasets['train'], cfg.n_token, cfg.smooth) 65 | else: 66 | d = np.random.randn(cfg.n_token, 2, 2, 2) 67 | r = np.random.randn(cfg.n_token) 68 | t = np.random.randn(cfg.n_token, cfg.n_token, 2, 2) 69 | 70 | self.root_param = nn.Parameter(torch.from_numpy(r)) 71 | # head, child, dir, valence 72 | self.trans_param = nn.Parameter(torch.from_numpy(t)) 73 | # head, dir, valence, decision 74 | self.dec_param = nn.Parameter(torch.from_numpy(d)) 75 | 76 | def forward(self, inputs: InputDict, vp: VarPool, embed=None, encoded=None, return_all=False): 77 | assert embed is None 78 | assert encoded is None 79 | assert not return_all 80 | return self._forward(inputs, {}, vp) 81 | 82 | def _forward(self, inputs: InputDict, encoded: TensorDict, vp: VarPool): 83 | b, l, n = vp.batch_size, vp.max_len, self.cfg.n_token 84 | token_array = inputs['token'] 85 | 86 | t = self.trans_param.unsqueeze(0).expand(b, n, n, 2, 2) 87 | head_token_index = token_array.view(b, l, 1, 1, 1).expand(b, l, n, 2, 2) 88 | child_token_index = token_array.view(b, 1, l, 1, 1).expand(b, l, l, 2, 2) 89 | t = torch.gather(torch.gather(t, 1, head_token_index), 2, child_token_index) 90 | index = torch.triu(torch.ones(l, l, dtype=torch.long, device=t.device)) \ 91 | .view(1, l, l, 1, 1).expand(b, l, l, 1, 2) 92 | t = torch.gather(t, 3, index).squeeze(3) 93 | 94 | d = self.dec_param.unsqueeze(0).expand(b, n, 2, 2, 2) 95 | head_pos_index = token_array.view(b, l, 1, 1, 1).expand(b, l, 2, 2, 2) 96 | d = torch.gather(d, 1, head_pos_index) 97 | 98 | r = self.root_param.unsqueeze(0).expand(b, n) 99 | r = torch.gather(r, 1, token_array) 100 | 101 | merged_d, merged_t = DMV1o.merge(d, t, r) 102 | return {'merged_dec': merged_d, 'merged_attach': merged_t} 103 | 104 | def loss(self, x: TensorDict, gold: InputDict, vp: VarPool) -> Tuple[Tensor, TensorDict]: 105 | dist = DMV1o([x['merged_dec'], x['merged_attach']], vp.seq_len) 106 | if self.cfg.viterbi_training: 107 | ll = dist.max.sum() 108 | else: 109 | 
ll = dist.partition.sum() 110 | return -ll, {'ll': ll} 111 | 112 | # noinspection DuplicatedCode 113 | @torch.enable_grad() 114 | def decode(self, x: TensorDict, vp: VarPool) -> AnyDict: 115 | if self.optimizer: 116 | self.optimizer.apply() 117 | mdec = x['merged_dec'].detach().requires_grad_() 118 | mattach = x['merged_attach'].detach().requires_grad_() 119 | dist = DMV1o([mdec, mattach], vp.seq_len) 120 | if self.cfg.mbr_decoding: 121 | arc = torch.autograd.grad(dist.partition.sum(), mattach)[0].sum(-1) 122 | dist = DependencyCRF(arc, vp.seq_len) 123 | arc = dist.argmax.nonzero() 124 | predicted = vp.seq_len.new_zeros(vp.batch_size, vp.max_len) 125 | predicted[arc[:, 0], arc[:, 2] - 1] = arc[:, 1] 126 | else: 127 | arc = dist.argmax.sum(-1).nonzero() 128 | predicted = vp.seq_len.new_zeros(vp.batch_size, vp.max_len) 129 | predicted[arc[:, 0], arc[:, 2] - 1] = arc[:, 1] 130 | return {'arc': predicted} 131 | 132 | def normalize_embedding(self, now): 133 | pass 134 | 135 | # noinspection DuplicatedCode 136 | def write_prediction(self, s: IOBase, predicts, dataset: DataSet, vocabs: Dict[str, Vocabulary]) -> IOBase: 137 | for i, length in enumerate(dataset['seq_len'].content): 138 | word, arc = dataset[i]['raw_word'], predicts['arc'][i] 139 | for line_id, (word, arc) in enumerate(zip(word, arc), start=1): 140 | line = '\t'.join([str(line_id), word, '-', str(arc)]) 141 | s.write(f'{line}\n') 142 | s.write('\n') 143 | return s 144 | 145 | 146 | class DMVMStepOptimizer(Optimizer): 147 | def __init__(self, params, smooth: float): 148 | self.dmv = DMV._instance 149 | self.dmv.optimizer = self 150 | 151 | self._root, self._dec, self._trans = None, None, None 152 | self.smooth = smooth 153 | self.can_apply = False 154 | super().__init__(self.dmv.parameters(), {}) 155 | 156 | def step(self, closure=None): 157 | loss = None 158 | if closure is not None: 159 | with torch.enable_grad(): 160 | loss = closure() 161 | 162 | if self._root is None: 163 | self._root = torch.zeros_like(self.dmv.root_param) 164 | self._dec = torch.zeros_like(self.dmv.dec_param) 165 | self._trans = torch.zeros_like(self.dmv.trans_param) 166 | 167 | self._root -= self.dmv.root_param.grad 168 | self._dec -= self.dmv.dec_param.grad 169 | self._trans -= self.dmv.trans_param.grad 170 | self.can_apply = True 171 | 172 | def apply(self): 173 | if self.can_apply: 174 | self.dmv.root_param.data, self._root = \ 175 | torch.log(self._root + self.smooth).log_softmax(0), self.dmv.root_param.data 176 | self.dmv.dec_param.data, self._dec = \ 177 | torch.log(self._dec + self.smooth).log_softmax(3), self.dmv.dec_param.data 178 | self.dmv.trans_param.data, self._trans = \ 179 | torch.log(self._trans + self.smooth).log_softmax(1), self.dmv.trans_param.data 180 | self.reset() 181 | 182 | def reset(self): 183 | self._root.zero_() 184 | self._dec.zero_() 185 | self._trans.zero_() 186 | self.can_apply = False 187 | -------------------------------------------------------------------------------- /src/model/dmv_helper/__init__.py: -------------------------------------------------------------------------------- 1 | from src.model.dmv_helper.good_init import good_init 2 | from src.model.dmv_helper.good_init_nn import generate_rule_1o, LinearPadder, SquarePadder 3 | from src.model.dmv_helper.km_init import km_init 4 | -------------------------------------------------------------------------------- /src/model/dmv_helper/good_init.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fastNLP 
import DataSet, AutoPadder 3 | 4 | 5 | from src.model.torch_struct.dmv import HASCHILD, NOCHILD, STOP, GO 6 | 7 | 8 | def recovery_one(heads): 9 | left_most = np.arange(len(heads)) 10 | right_most = np.arange(len(heads)) 11 | for idx, each_head in enumerate(heads): 12 | if each_head in (0, len(heads) + 1): # skip head is ROOT 13 | continue 14 | each_head -= 1 15 | if idx < left_most[each_head]: 16 | left_most[each_head] = idx 17 | if idx > right_most[each_head]: 18 | right_most[each_head] = idx 19 | 20 | valences = np.empty((len(heads), 2), dtype=np.int) 21 | head_valences = np.empty(len(heads), dtype=np.int) 22 | 23 | for idx, each_head in enumerate(heads): 24 | each_head -= 1 25 | valences[idx, 0] = NOCHILD if left_most[idx] == idx else HASCHILD 26 | valences[idx, 1] = NOCHILD if right_most[idx] == idx else HASCHILD 27 | if each_head > idx: # each_head = -1 `s head_valence is never used 28 | head_valences[idx] = NOCHILD if left_most[each_head] == idx else HASCHILD 29 | else: 30 | head_valences[idx] = NOCHILD if right_most[each_head] == idx else HASCHILD 31 | return valences, head_valences 32 | 33 | 34 | def good_init(dataset: DataSet, n_token: int, smooth: float): 35 | """process all sentences in one batch.""" 36 | max_len = max(dataset['seq_len'].content) 37 | heads = np.zeros((len(dataset), max_len + 1), dtype=np.int) 38 | valences = np.zeros((len(dataset), max_len + 1, 2), dtype=np.int) 39 | head_valences = np.zeros((len(dataset), max_len + 1), dtype=np.int) 40 | root_counter = np.zeros((n_token,)) 41 | 42 | for idx, instance in enumerate(dataset): 43 | one_heads = np.asarray(instance['arc']) 44 | one_valences, one_head_valences = recovery_one(one_heads) 45 | heads[idx, 1:instance['seq_len'] + 1] = one_heads 46 | valences[idx, 1:instance['seq_len'] + 1] = one_valences 47 | head_valences[idx, 1:instance['seq_len'] + 1] = one_head_valences 48 | 49 | batch_size, sentence_len = heads.shape 50 | len_array = np.asarray(dataset['seq_len'].content) 51 | token_array = AutoPadder()(dataset['token'].content, 'token', np.int, 1) 52 | batch_arange = np.arange(batch_size) 53 | 54 | batch_trans_trace = np.zeros((batch_size, max_len, max_len, 2, 2)) 55 | batch_dec_trace = np.zeros((batch_size, max_len, max_len, 2, 2, 2)) 56 | 57 | for m in range(1, sentence_len): 58 | h = heads[:, m] 59 | direction = (h <= m).astype(np.long) 60 | h_valence = head_valences[:, m] 61 | m_valence = valences[:, m] 62 | m_child_valence = h_valence 63 | 64 | len_mask = ((h <= len_array) & (m <= len_array)) 65 | 66 | batch_dec_trace[batch_arange, m - 1, m - 1, 0, m_valence[:, 0], STOP] = len_mask 67 | batch_dec_trace[batch_arange, m - 1, m - 1, 1, m_valence[:, 1], STOP] = len_mask 68 | 69 | head_mask = h == 0 70 | mask = head_mask * len_mask 71 | if mask.any(): 72 | np.add.at(root_counter, token_array[:, m - 1], mask) 73 | 74 | head_mask = ~head_mask 75 | mask = head_mask * len_mask 76 | if mask.any(): 77 | batch_trans_trace[batch_arange, h - 1, m - 1, direction, m_child_valence] = mask 78 | batch_dec_trace[batch_arange, h - 1, m - 1, direction, h_valence, GO] = mask 79 | 80 | dec_post_dim = (2, 2, 2) 81 | dec_counter = np.zeros((n_token, *dec_post_dim)) 82 | index = (token_array.flatten(),) 83 | np.add.at(dec_counter, index, np.sum(batch_dec_trace, 2).reshape(-1, *dec_post_dim)) 84 | 85 | trans_post_dim = (2, 2) 86 | head_ids = np.tile(np.expand_dims(token_array, 2), (1, 1, max_len)) 87 | child_ids = np.tile(np.expand_dims(token_array, 1), (1, max_len, 1)) 88 | trans_counter = np.zeros((n_token, n_token, *trans_post_dim)) 
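# Illustration (toy sizes, not part of the original routine): np.add.at performs an
# unbuffered scatter-add, so duplicate (head_token, child_token) index pairs all
# accumulate, which plain fancy-index assignment would not do. For example:
#
#   counter = np.zeros((3, 3))                # n_token = 3
#   heads   = np.array([0, 0, 1])             # head token ids per position pair
#   childs  = np.array([1, 1, 2])             # child token ids per position pair
#   np.add.at(counter, (heads, childs), 1.0)  # counter[0, 1] == 2.0, counter[1, 2] == 1.0
#
# The real call below adds batch_trans_trace (one count per sentence/position pair,
# with trailing (direction, valence) dims) instead of the constant 1.0.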
89 | index = (head_ids.flatten(), child_ids.flatten()) 90 | np.add.at(trans_counter, index, batch_trans_trace.reshape(-1, *trans_post_dim)) 91 | 92 | root_counter += smooth 93 | root_sum = root_counter.sum() 94 | root_param = np.log(root_counter / root_sum) 95 | 96 | trans_counter += smooth 97 | trans_sum = trans_counter.sum(axis=1, keepdims=True) 98 | trans_param = np.log(trans_counter / trans_sum) 99 | 100 | dec_counter += smooth 101 | dec_sum = dec_counter.sum(axis=3, keepdims=True) 102 | dec_param = np.log(dec_counter / dec_sum) 103 | return dec_param, trans_param, root_param 104 | -------------------------------------------------------------------------------- /src/model/dmv_helper/good_init_nn.py: -------------------------------------------------------------------------------- 1 | # unlike good_init.py, this file contains helpers to initialize nn without dmv. 2 | 3 | from typing import List 4 | 5 | import numpy as np 6 | from fastNLP.core.field import Padder 7 | 8 | from src.model.torch_struct.dmv import LEFT, RIGHT, HASCHILD, NOCHILD, GO, STOP 9 | 10 | 11 | class LinearPadder(Padder): 12 | def __call__(self, contents, field_name, field_ele_dtype, dim: int): 13 | max_sent_length = max(r.shape[0] for r in contents) 14 | batch_size = len(contents) 15 | out = np.full((batch_size, max_sent_length, *contents[0].shape[1:]), fill_value=self.pad_val, dtype=np.float) 16 | for b_idx, rule in enumerate(contents): 17 | sent_len = rule.shape[0] 18 | out[b_idx, :sent_len] = rule 19 | return out 20 | 21 | 22 | class SquarePadder(Padder): 23 | def __call__(self, contents, field_name, field_ele_dtype, dim: int): 24 | max_sent_length = max(r.shape[0] for r in contents) 25 | batch_size = len(contents) 26 | out = np.full((batch_size, max_sent_length, max_sent_length, *contents[0].shape[2:]), fill_value=self.pad_val, 27 | dtype=np.float) 28 | for b_idx, rule in enumerate(contents): 29 | sent_len = rule.shape[0] 30 | out[b_idx, :sent_len, :sent_len] = rule 31 | return out 32 | 33 | 34 | def generate_rule_1o(heads: List[int]): 35 | """ 36 | First-order DMV, generate the grammar rules used in the "predicted" parse tree from other parser. 
37 | :param heads: the head of each position 38 | :return: decision rule 39 | """ 40 | seq_len = len(heads) 41 | decision = np.zeros(shape=(seq_len, 2, 2, 2)) 42 | attach = np.zeros(shape=(seq_len, seq_len, 2)) 43 | root = np.zeros(shape=(seq_len,)) 44 | root[heads.index(0)] = 1 45 | 46 | left_most_child = list(range(seq_len)) 47 | right_most_child = list(range(seq_len)) 48 | for child, head in enumerate(heads): 49 | head = head - 1 50 | if head == -1: 51 | continue 52 | elif child < head: 53 | if child < left_most_child[head]: 54 | left_most_child[head] = child 55 | else: 56 | if child > right_most_child[head]: 57 | right_most_child[head] = child 58 | 59 | for child, head in enumerate(heads): 60 | head = head - 1 61 | 62 | if child < head: 63 | most_child, d = left_most_child, LEFT 64 | else: 65 | most_child, d = right_most_child, RIGHT 66 | 67 | valence = NOCHILD if most_child[head] == child else HASCHILD 68 | decision[head][d][valence][GO] += 1 69 | if head != -1: 70 | attach[head][child][valence] += 1 71 | 72 | valence = NOCHILD if left_most_child[child] == child else HASCHILD 73 | decision[child][LEFT][valence][STOP] += 1 74 | 75 | valence = NOCHILD if right_most_child[child] == child else HASCHILD 76 | decision[child][RIGHT][valence][STOP] += 1 77 | 78 | return {'dec_rule': decision, 'attach_rule': attach, 'root_rule': root} 79 | -------------------------------------------------------------------------------- /src/model/dmv_helper/km_init.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from fastNLP import DataSet, DataSetIter 3 | from numpy import ndarray 4 | 5 | from src.datamodule.sampler import ConstantTokenNumSampler 6 | from src.model.torch_struct.dmv import HASCHILD, NOCHILD, STOP, GO 7 | 8 | harmonic_sum = [0., 1.] 9 | 10 | 11 | def get_harmonic_sum(n: int): 12 | global harmonic_sum 13 | while n >= len(harmonic_sum): 14 | harmonic_sum.append(harmonic_sum[-1] + 1 / len(harmonic_sum)) 15 | return harmonic_sum[n] 16 | 17 | 18 | def update_decision(change: ndarray, norm_counter: ndarray, token_array: ndarray, dec_param: ndarray): 19 | for i in range(token_array.shape[1]): 20 | pos = token_array[:, i] 21 | for _direction in (0, 1): 22 | if change[i, _direction] > 0: 23 | np.add.at(norm_counter, (pos, _direction, NOCHILD, GO), 1.) 24 | np.add.at(norm_counter, (pos, _direction, HASCHILD, GO), -1.) 25 | np.add.at(dec_param, (pos, _direction, HASCHILD, GO), change[i, _direction]) 26 | np.add.at(norm_counter, (pos, _direction, NOCHILD, STOP), -1.) 27 | np.add.at(norm_counter, (pos, _direction, HASCHILD, STOP), 1.) 28 | np.add.at(dec_param, (pos, _direction, NOCHILD, STOP), 1.) 29 | else: 30 | np.add.at(dec_param, (pos, _direction, NOCHILD, STOP), 1.) 31 | 32 | 33 | def first_child_update(norm_counter: ndarray, dec_param: ndarray): 34 | all_param = dec_param.flatten() 35 | all_norm = norm_counter.flatten() 36 | mask = (all_param <= 0) | (0 <= all_norm) 37 | ratio = -all_param / all_norm 38 | ratio[mask] = 1. 39 | return np.min(ratio) 40 | 41 | 42 | def km_init(dataset: DataSet, n_token: int, smooth: float): 43 | # do not ask why? I do not know more than you. 
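# What this appears to implement is the Klein-and-Manning-style "harmonic" initializer
# for DMV: each candidate head for a child gets weight proportional to 1 / |head - child|
# (with a fixed share reserved for ROOT), and the accumulated counts are normalized into
# probabilities at the end. A stripped-down sketch of the attachment part only
# (illustrative; the loop below additionally tracks direction/valence decision counts
# and the harmonic normalization):
#
#   import numpy as np
#   n = 4                                     # toy sentence length
#   pref = np.zeros((n, n))                   # pref[h, c]: preference of head h for child c
#   for c in range(n):
#       for h in range(n):
#           if h != c:
#               pref[h, c] = 1.0 / abs(h - c)
#   pref /= pref.sum(axis=0, keepdims=True)   # normalize over heads for every child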
44 | dec_param = np.zeros((n_token, 2, 2, 2)) 45 | root_param = np.zeros((n_token,)) 46 | trans_param = np.zeros((n_token, n_token, 2, 2)) 47 | 48 | norm_counter = np.full(dec_param.shape, smooth) 49 | change = np.zeros((max(dataset['seq_len'].content), 2)) 50 | sampler = ConstantTokenNumSampler(dataset['seq_len'].content, 1000000, -1, 0, force_same_len=True) 51 | data_iter = DataSetIter(dataset, batch_sampler=sampler, as_numpy=True) 52 | for x, y in data_iter: 53 | token_array = x['token'] 54 | batch_size, word_num = token_array.shape 55 | change.fill(0.) 56 | np.add.at(root_param, (token_array, ), 1. / word_num) 57 | if word_num > 1: 58 | for child_i in range(word_num): 59 | child_sum = get_harmonic_sum(child_i - 0) + get_harmonic_sum(word_num - child_i - 1) 60 | scale = (word_num - 1) / word_num / child_sum 61 | for head_i in range(word_num): 62 | if child_i == head_i: 63 | continue 64 | direction = 1 if head_i <= child_i else 0 65 | head_pos = token_array[:, head_i] 66 | child_pos = token_array[:, child_i] 67 | diff = scale / abs(head_i - child_i) 68 | np.add.at(trans_param, (head_pos, child_pos, direction), diff) 69 | change[head_i, direction] += diff 70 | update_decision(change, norm_counter, token_array, dec_param) 71 | 72 | trans_param += smooth 73 | dec_param += smooth 74 | root_param += smooth 75 | 76 | es = first_child_update(norm_counter, dec_param) 77 | norm_counter *= 0.9 * es 78 | dec_param += norm_counter 79 | 80 | root_param_sum = root_param.sum() 81 | trans_param_sum = trans_param.sum(1, keepdims=True) 82 | decision_param_sum = dec_param.sum(3, keepdims=True) 83 | 84 | root_param /= root_param_sum 85 | trans_param /= trans_param_sum 86 | dec_param /= decision_param_sum 87 | 88 | return np.log(dec_param), np.log(trans_param), np.log(root_param) 89 | -------------------------------------------------------------------------------- /src/model/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | from .embedding import EmbeddingAdaptor, Embedding 2 | from .fastnlp_embedding import FastNLPEmbeddingAdaptor, FastNLPCharEmbeddingAdaptor, FastNLPEmbeddingVariationalAdaptor 3 | from .transformers_embedding import TransformersAdaptor, TransformersEmbedding 4 | -------------------------------------------------------------------------------- /src/model/embedding/embedding.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import TYPE_CHECKING, Any, Dict, List 5 | 6 | import torch 7 | import torch.nn as nn 8 | from hydra.utils import instantiate 9 | from torch import Tensor 10 | 11 | from src.model.nn import IndependentDropout 12 | from src.utility.config import Config 13 | from src.utility.logger import get_logger_func 14 | 15 | if TYPE_CHECKING: 16 | from src.model import ModelBase 17 | from src.datamodule import DataModule 18 | from src.utility.var_pool import VarPool 19 | 20 | AnyDict = Dict[str, Any] 21 | 22 | _warn, _info, _debug = get_logger_func('embedding') 23 | 24 | 25 | @dataclass 26 | class EmbeddingItem: 27 | name: str 28 | field: str 29 | emb: EmbeddingAdaptor 30 | 31 | 32 | @dataclass 33 | class EmbeddingConfig(Config): 34 | use_word: bool 35 | use_tag: bool 36 | use_subword: bool # I believe we need only one subwords field.' 37 | dropout: 0. # when multi embedding, for each position, drop some entirely. 
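# Hypothetical Hydra-style entry (key and value names are illustrative, not copied from
# the shipped configs) showing how one per-field embedding could be declared:
#
#   word_emb:
#     field: word
#     args: {_target_: fastNLP.embeddings.StaticEmbedding, model_dir_or_name: null, embedding_dim: 100}
#     adaptor_args: {_target_: src.model.embedding.FastNLPEmbeddingAdaptor}
#     normalize_time: begin
#     normalize_method: mean+std
#
# (the vocab argument is injected automatically because requires_vocab defaults to True below)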
38 | # all other items are treated as EmbeddingItemConfig 39 | 40 | 41 | @dataclass 42 | class EmbeddingItemConfig(Config): 43 | args: AnyDict 44 | adaptor_args: AnyDict 45 | field: str 46 | requires_vocab: bool = True # pass vocab to embedding 47 | normalize_word: bool = False # pass the normalize_func(used by datamodule) to Embedding 48 | normalize_method: str = 'mean+std' # mean+std, mean, std, none 49 | normalize_time: str = 'nowhere' # when to normalize embedding, none, begin, epoch, batch 50 | 51 | 52 | class Embedding(torch.nn.Module): 53 | """Embedding, plus apply to different fields.""" 54 | bounded_model: ModelBase 55 | 56 | def __init__(self, dm: DataModule, **cfg): 57 | super().__init__() 58 | flags, emb_cfg = EmbeddingConfig.build(cfg, ignore_unknown=True) 59 | flags: EmbeddingConfig 60 | 61 | vocabs = dm.vocabs 62 | datasets = dm.datasets 63 | 64 | self.disabled_fields = set() 65 | if not flags.use_word: 66 | self.disabled_fields.add('word') 67 | if not flags.use_subword: 68 | self.disabled_fields.add('subword') 69 | if not flags.use_tag: 70 | self.disabled_fields.add('pos') 71 | 72 | # instantiate embeddings 73 | self.embeds: List[EmbeddingItem] = [] 74 | self.normalize_dict = {'nowhere': [], 'begin': [], 'epoch': [], 'batch': []} 75 | for name, cfg in emb_cfg.items(): 76 | if name.startswith('_') or cfg is None: 77 | continue 78 | cfg: EmbeddingItemConfig = EmbeddingItemConfig.build(cfg) 79 | if cfg.field in self.disabled_fields: 80 | continue 81 | instantiate_args = {} 82 | if cfg.requires_vocab: 83 | instantiate_args['vocab'] = vocabs[cfg.field] 84 | if cfg.normalize_word: 85 | instantiate_args['word_transform'] = dm.normalize_one_word_func 86 | emb = instantiate(cfg.args, **instantiate_args) 87 | emb = instantiate(cfg.adaptor_args, emb=emb) 88 | emb.process(vocabs, datasets) 89 | self.add_module(name, emb) 90 | self.embeds.append(EmbeddingItem(name, cfg.field, emb)) 91 | self.normalize_dict[cfg.normalize_time].append((name, cfg.normalize_method)) 92 | 93 | _info(f'Emb: {", ".join(e.name for e in self.embeds)}') 94 | _info(f'Normalize plan: {self.normalize_dict}') 95 | self.embed_size = sum(e.embed_size for e in self) 96 | 97 | if flags.dropout > 0: 98 | self.dropout_func = IndependentDropout(flags.dropout) 99 | else: 100 | self.dropout_func = lambda *x: x 101 | 102 | def forward(self, x, vp: VarPool): 103 | emb = list(self.dropout_func(*[item.emb(x[item.field], vp) for item in self.embeds])) 104 | seq_len = max(e.shape[1] for e in emb) 105 | assert all(e.shape[1] in (1, seq_len) for e in emb) 106 | for item, h in zip(self.embeds, emb): 107 | vp[item.name] = h 108 | for i in range(len(emb)): 109 | if emb[i].shape[1] == 1: 110 | emb[i] = emb[i].expand(-1, seq_len, -1) 111 | # from src.utility.fn import draw_att 112 | # draw_att(torch.cat(emb, dim=-1)[0]) 113 | return torch.cat(emb, dim=-1) 114 | 115 | def normalize(self, now): 116 | for name, method in self.normalize_dict[now]: 117 | getattr(self, name).normalize(method) 118 | 119 | def __getitem__(self, key): 120 | return self.embeds[key].emb 121 | 122 | def __iter__(self): 123 | return map(lambda e: e.emb, self.embeds) 124 | 125 | def __len__(self): 126 | return len(self.embeds) 127 | 128 | 129 | class EmbeddingAdaptor(nn.Module): 130 | device_indicator: Tensor 131 | singleton_emb = {} 132 | 133 | def __init__(self, emb): 134 | super().__init__() 135 | self.emb = emb 136 | self.register_buffer('device_indicator', torch.zeros(1)) 137 | 138 | self._normalize_warned = False 139 | 140 | @property 141 | def embed_size(self): 
142 | raise NotImplementedError 143 | 144 | @property 145 | def device(self): 146 | return self.device_indicator.device 147 | 148 | def process(self, vocabs, datasets): 149 | return 150 | 151 | def forward(self, inputs: List[Any], vp: VarPool): 152 | raise NotImplementedError 153 | 154 | def normalize(self, method: str): 155 | if not self._normalize_warned: 156 | _warn(f"{type(self)} didn't implement normalize.") 157 | self._normalize_warned = True 158 | 159 | @staticmethod 160 | def _normalize(data: Tensor, method: str): 161 | with torch.no_grad(): 162 | if method == 'mean+std': 163 | std, mean = torch.std_mean(data, dim=0, keepdim=True) 164 | data.sub_(mean).divide_(std) 165 | elif method == 'mean': 166 | mean = torch.mean(data, dim=0, keepdim=True) 167 | data.sub_(mean) 168 | elif method == 'std': 169 | std = torch.std(data, dim=0, keepdim=True) 170 | data.divide_(std) 171 | else: 172 | raise ValueError(f'Unrecognized normalize method: {method}') 173 | 174 | @classmethod 175 | def get_singleton(cls, name, emb): 176 | if name in EmbeddingAdaptor.singleton_emb: 177 | return EmbeddingAdaptor.singleton_emb[name] 178 | EmbeddingAdaptor.singleton_emb[name] = emb = cls(emb) 179 | return emb 180 | -------------------------------------------------------------------------------- /src/model/embedding/fastnlp_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from fastNLP.embeddings import StaticEmbedding, TokenEmbedding, CNNCharEmbedding, LSTMCharEmbedding 6 | from torch import Tensor 7 | from torch.nn import Parameter 8 | 9 | from src.model.embedding.embedding import EmbeddingAdaptor 10 | from src.model.nn.multivariate_kl import MultVariateKLD 11 | from src.utility.var_pool import VarPool 12 | 13 | 14 | class FastNLPEmbeddingAdaptor(EmbeddingAdaptor): 15 | 16 | def __init__(self, emb: TokenEmbedding): 17 | super().__init__(emb) 18 | self._embed_size = self.emb.embed_size 19 | self._word_dropout = emb.word_dropout 20 | self._dropout = emb.dropout_layer.p 21 | self._normalize_weight = None 22 | 23 | @property 24 | def embed_size(self): 25 | return self._embed_size 26 | 27 | def forward(self, field: Tensor, vp: VarPool): 28 | return self.emb(field) 29 | 30 | def normalize(self, method): 31 | emb: torch.nn.Embedding = self.emb.embedding 32 | if hasattr(self.emb, 'mapped_counts'): 33 | self.emb: StaticEmbedding 34 | if self._normalize_weight is None: 35 | self._normalize_weight = (self.emb.mapped_counts / self.emb.mapped_counts.sum()).unsqueeze(-1) 36 | mean = (emb.weight.data * self._normalize_weight).sum() 37 | if method == 'mean': 38 | emb.weight.data.sub_(mean) 39 | else: 40 | std = (((emb.weight.data - mean).pow(2.) 
* self._normalize_weight).sum() + 1e-6).sqrt() 41 | if method == 'mean+std': 42 | emb.weight.data.sub_(mean) 43 | emb.weight.data.div_(std) 44 | else: 45 | padding_idx = self.emb.get_word_vocab().padding_idx 46 | start_idx = 1 if padding_idx == 0 else 0 47 | self._normalize(emb.weight.data[start_idx:], method) 48 | 49 | class FastNLPEmbeddingVariationalAdaptor(FastNLPEmbeddingAdaptor): 50 | def __init__(self, emb: TokenEmbedding, mode: str, out_dim: int): 51 | # mode: vae or ib 52 | super(FastNLPEmbeddingVariationalAdaptor, self).__init__(emb) 53 | self.mode = mode 54 | if self.mode != 'basic': 55 | self._embed_size = out_dim 56 | self.enc = nn.Linear(emb.embed_size, 2 * out_dim) 57 | if self.mode == 'ib': 58 | self.gaussian_kl = MultVariateKLD('sum') 59 | self.target_mean = Parameter(torch.zeros(1, out_dim)) 60 | self.target_lvar = Parameter(torch.zeros(1, out_dim)) 61 | 62 | def forward(self, field: Tensor, vp: VarPool): 63 | if self.mode == 'basic': 64 | return super().forward(field, vp) 65 | 66 | mean, lvar = torch.chunk(self.enc(self.emb(field)), 2, dim=-1) 67 | if self.training: 68 | z = torch.empty_like(mean).normal_() 69 | z = (0.5 * lvar).exp() * z + mean 70 | else: 71 | z = mean 72 | vp.kl = self.kl(mean, lvar) 73 | return z 74 | 75 | def kl(self, mean, lvar): 76 | if self.mode == 'ib': 77 | _mean, _lvar = mean.view(-1, self.embed_size), lvar.view(-1, self.embed_size) 78 | _b = len(_mean) 79 | return self.gaussian_kl(_mean, self.target_mean.expand(_b, -1), _lvar, self.target_lvar.expand(_b, -1)) 80 | else: 81 | return -0.5 * (lvar - torch.pow(mean, 2) - torch.exp(lvar) + 1).sum() 82 | 83 | 84 | class FastNLPCharEmbeddingAdaptor(FastNLPEmbeddingAdaptor): 85 | def normalize(self, method): 86 | self.emb: Union[CNNCharEmbedding, LSTMCharEmbedding] 87 | emb = self.emb.char_embedding 88 | start_idx = 1 if self.emb.char_pad_index == 0 else 0 89 | self._normalize(emb.weight.data[start_idx:], method) 90 | -------------------------------------------------------------------------------- /src/model/embedding/transformers_embedding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from fastNLP import Padder 5 | from torch import Tensor 6 | 7 | from src.model.embedding.embedding import EmbeddingAdaptor 8 | from src.model.nn.scalar_mix import ScalarMix 9 | from src.utility.fn import pad 10 | from src.utility.var_pool import VarPool 11 | 12 | 13 | class TransformersAdaptor(EmbeddingAdaptor): 14 | def __init__(self, emb): 15 | super().__init__(emb) 16 | self.emb: TransformersEmbedding 17 | self._embed_size = self.emb.n_out 18 | self._dropout = self.emb.dropout 19 | 20 | @property 21 | def embed_size(self): 22 | return self._embed_size 23 | 24 | def process(self, vocabs, datasets): 25 | enable_transformers_embedding(datasets, self.emb.tokenizer) 26 | 27 | def forward(self, field: Tensor, vp: VarPool): 28 | return self.emb(field)[:, 1: -1] 29 | 30 | 31 | def enable_transformers_embedding(datasets, tokenizer, fix_len=20): 32 | def get_subwords(_words): 33 | sws = [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(w)[:fix_len]) for w in _words] 34 | sws = [[tokenizer.cls_token_id]] + sws + [[tokenizer.sep_token_id]] 35 | sws = list(map(lambda x: torch.tensor(x, dtype=torch.long), sws)) 36 | return pad(sws, tokenizer.pad_token_id).numpy() 37 | 38 | for ds in datasets.values(): 39 | ds.apply_field(get_subwords, 40 | 'raw_word', 41 | 'subword', 42 | is_input=True, 43 | 
padder=SubWordsPadder(tokenizer.pad_token_id)) 44 | 45 | 46 | class SubWordsPadder(Padder): 47 | def __call__(self, contents, field_name, field_ele_dtype, dim: int): 48 | batch_size, dtype = len(contents), type(contents[0][0][0]) 49 | max_len0, max_len1 = max(c.shape[0] for c in contents), max(c.shape[1] for c in contents) 50 | padded_array = np.full((batch_size, max_len0, max_len1), fill_value=self.pad_val, dtype=dtype) 51 | for b_idx, matrix in enumerate(contents): 52 | padded_array[b_idx, :matrix.shape[0], :matrix.shape[1]] = matrix 53 | return padded_array 54 | 55 | 56 | class TransformersEmbedding(nn.Module): 57 | r""" By Zhang Yu 58 | A nn that directly utilizes the pretrained models in `transformers`_ to produce BERT representations. 59 | While mainly tailored to provide input preparation and post-processing for the BERT model, 60 | it is also compatiable with other pretrained language models like XLNet, RoBERTa and ELECTRA, etc. 61 | 62 | Args: 63 | model (str): 64 | Path or name of the pretrained models registered in `transformers`_, e.g., ``'bert-base-cased'``. 65 | n_layers (int): 66 | The number of layers from the model to use. 67 | If 0, uses all layers. 68 | n_out (int): 69 | The requested size of the embeddings. Default: 0. 70 | If 0, uses the size of the pretrained embedding model. 71 | stride (int): 72 | A sequence longer than max length will be splitted into several small pieces 73 | with a window size of ``stride``. Default: 10. 74 | pooling (str): 75 | Pooling way to get from token piece embeddings to token embedding. 76 | Either take the first subtoken ('first'), the last subtoken ('last'), or a mean over all ('mean'). 77 | Default: 'mean'. 78 | dropout (float): 79 | The dropout ratio of BERT layers. Default: 0. 80 | This value will be passed into the :class:`ScalarMix` layer. 81 | requires_grad (bool): 82 | If ``True``, the model parameters will be updated together with the downstream task. 83 | Default: ``False``. 84 | 85 | .. _transformers: 86 | https://github.com/huggingface/transformers 87 | """ 88 | 89 | def __init__(self, 90 | model, 91 | n_layers, 92 | n_out=0, 93 | stride=256, 94 | pooling='mean', 95 | dropout=0, 96 | requires_grad=False): 97 | super().__init__() 98 | 99 | from transformers import AutoConfig, AutoModel, AutoTokenizer 100 | self.bert = AutoModel.from_pretrained(model, 101 | config=AutoConfig.from_pretrained(model, output_hidden_states=True)) 102 | self.bert = self.bert.requires_grad_(requires_grad) 103 | 104 | self.model = model 105 | self.n_layers = n_layers or self.bert.config.num_hidden_layers 106 | self.hidden_size = self.bert.config.hidden_size 107 | self.n_out = n_out or self.hidden_size 108 | self.stride = stride 109 | self.pooling = pooling 110 | self.dropout = dropout 111 | self.requires_grad = requires_grad 112 | self.max_len = int(max(0, self.bert.config.max_position_embeddings) or 1e12) - 2 113 | 114 | self.tokenizer = AutoTokenizer.from_pretrained(model) 115 | self.pad_index = self.tokenizer.pad_token_id 116 | # assert self.pad_index == pad_index 117 | 118 | self.scalar_mix = ScalarMix(self.n_layers, dropout) 119 | self.projection = nn.Linear(self.hidden_size, self.n_out, False) \ 120 | if self.hidden_size != self.n_out else nn.Identity() 121 | 122 | def forward(self, subwords): 123 | r""" 124 | Args: 125 | subwords (~torch.Tensor): ``[batch_size, seq_len, fix_len]``. 126 | Returns: 127 | ~torch.Tensor: 128 | BERT embeddings of shape ``[batch_size, seq_len, n_out]``. 
129 | """ 130 | mask = subwords.ne(self.pad_index) 131 | lens = mask.sum((1, 2)) 132 | # [batch_size, n_subwords] 133 | subwords = pad(subwords[mask].split(lens.tolist()), self.pad_index, padding_side=self.tokenizer.padding_side) 134 | bert_mask = pad(mask[mask].split(lens.tolist()), 0, padding_side=self.tokenizer.padding_side) 135 | 136 | # return the hidden states of all layers 137 | bert = self.bert(subwords[:, :self.max_len], attention_mask=bert_mask[:, :self.max_len].float())[-1] 138 | # [n_layers, batch_size, max_len, hidden_size] 139 | bert = bert[-self.n_layers:] 140 | # [batch_size, max_len, hidden_size] 141 | bert = self.scalar_mix(bert) 142 | # [batch_size, n_subwords, hidden_size] 143 | for i in range(self.stride, 144 | (subwords.shape[1] - self.max_len + self.stride - 1) // self.stride * self.stride + 1, 145 | self.stride): 146 | part = self.bert( 147 | subwords[:, i:i + self.max_len], 148 | attention_mask=bert_mask[:, i:i + self.max_len].float(), 149 | )[-1] 150 | bert = torch.cat((bert, self.scalar_mix(part[-self.n_layers:])[:, self.max_len - self.stride:]), 1) 151 | 152 | # [batch_size, n_subwords] 153 | bert_lens = mask.sum(-1) 154 | bert_lens = bert_lens.masked_fill_(bert_lens.eq(0), 1) 155 | # [batch_size, seq_len, fix_len, hidden_size] 156 | embed = bert.new_zeros(*mask.shape, self.hidden_size).masked_scatter_(mask.unsqueeze(-1), bert[bert_mask]) 157 | # [batch_size, seq_len, hidden_size] 158 | if self.pooling == 'first': 159 | embed = embed[:, :, 0] 160 | elif self.pooling == 'last': 161 | embed = embed \ 162 | .gather(2, (bert_lens - 1).unsqueeze(-1).repeat(1, 1, self.hidden_size).unsqueeze(2)) \ 163 | .squeeze(2) 164 | else: 165 | embed = embed.sum(2) / bert_lens.unsqueeze(-1) 166 | embed = self.projection(embed) 167 | 168 | return embed 169 | -------------------------------------------------------------------------------- /src/model/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .affine import Biaffine 2 | from .common import ResLayer, MLP 3 | from .dmv_spec import DMVSkipConnectEncoder, DMVFactorizedBilinear 4 | from .dropout import SharedDropout, IndependentDropout 5 | from .scalar_mix import ScalarMix 6 | from .variational_lstm import VariationalLSTM 7 | from .affine_scorer import BiaffineScorer 8 | -------------------------------------------------------------------------------- /src/model/nn/affine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | from torch.nn.parameter import Parameter 5 | 6 | 7 | class Biaffine(nn.Module): 8 | r""" 9 | Biaffine layer for first-order scoring :cite:`dozat-etal-2017-biaffine`. 10 | 11 | This function has a tensor of weights :math:`W` and bias terms if needed. 12 | The score :math:`s(x, y)` of the vector pair :math:`(x, y)` is computed as :math:`x^T W y`. 13 | :math:`x` and :math:`y` can be concatenated with bias terms. 14 | 15 | Args: 16 | n_in (int): 17 | The size of the input feature. 18 | n_out (int): 19 | The number of output channels. 20 | bias_x (bool): 21 | If ``True``, adds a bias term for tensor :math:`x`. Default: ``True``. 22 | bias_y (bool): 23 | If ``True``, adds a bias term for tensor :math:`y`. Default: ``True``. 
24 | """ 25 | 26 | def __init__(self, n_in, n_out=1, bias_x=True, bias_y=True): 27 | super().__init__() 28 | 29 | self.n_in = n_in 30 | self.n_out = n_out 31 | self.bias_x = bias_x 32 | self.bias_y = bias_y 33 | self.weight = nn.Parameter(torch.Tensor(n_out, n_in + bias_x, n_in + bias_y)) 34 | 35 | self.reset_parameters() 36 | 37 | def __repr__(self): 38 | s = f'n_in={self.n_in}' 39 | if self.n_out > 1: 40 | s += f', n_out={self.n_out}' 41 | if self.bias_x: 42 | s += f', bias_x={self.bias_x}' 43 | if self.bias_y: 44 | s += f', bias_y={self.bias_y}' 45 | 46 | return f'{self.__class__.__name__}({s})' 47 | 48 | def reset_parameters(self): 49 | nn.init.zeros_(self.weight) 50 | 51 | def forward(self, x, y): 52 | r""" 53 | Args: 54 | x (torch.Tensor): ``[batch_size, seq_len, n_in]``. 55 | y (torch.Tensor): ``[batch_size, seq_len, n_in]``. 56 | 57 | Returns: 58 | ~torch.Tensor: 59 | A scoring tensor of shape ``[batch_size, n_out, seq_len, seq_len]``. 60 | If ``n_out=1``, the dimension for ``n_out`` will be squeezed automatically. 61 | """ 62 | 63 | if self.bias_x: 64 | x = torch.cat((x, torch.ones_like(x[..., :1])), -1) 65 | if self.bias_y: 66 | y = torch.cat((y, torch.ones_like(y[..., :1])), -1) 67 | # [batch_size, n_out, seq_len, seq_len] 68 | s = torch.einsum('bxi,oij,byj->boxy', x, self.weight, y) 69 | # remove dim 1 if n_out == 1 70 | s = s.squeeze(1) 71 | 72 | return s 73 | 74 | 75 | -------------------------------------------------------------------------------- /src/model/nn/affine_scorer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .affine import Biaffine 8 | from .common import MLP 9 | 10 | 11 | class BiaffineScorer(nn.Module): 12 | def __init__(self, 13 | n_in, 14 | hidden_dim, 15 | out_dim, 16 | mlp_dropout, 17 | mlp_activate, 18 | scale): 19 | super().__init__() 20 | self.mlp_dropout = mlp_dropout 21 | self.mlp1 = MLP(n_in // 2, hidden_dim, mlp_dropout, mlp_activate) 22 | self.mlp2 = MLP(n_in // 2, hidden_dim, mlp_dropout, mlp_activate) 23 | self.affine = Biaffine(hidden_dim, out_dim, bias_x=True, bias_y=out_dim > 1) 24 | self.register_buffer('scale', 1 / torch.tensor(hidden_dim if scale else 1).pow(0.25)) 25 | self.n_out = out_dim 26 | 27 | def reset_parameters(self): 28 | nn.init.zeros_(self.affine.weight) 29 | self.affine.weight.diagonal().one_() 30 | 31 | def forward(self, x, x2): 32 | h1 = self.mlp1(x) * self.scale 33 | h2 = self.mlp2(x2) * self.scale 34 | out = self.affine(h1, h2).permute(0, 2, 3, 1) 35 | return out 36 | -------------------------------------------------------------------------------- /src/model/nn/common.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | 4 | from src.model.nn.dropout import SharedDropout 5 | 6 | 7 | class ResLayer(nn.Module): 8 | def __init__(self, n_in, n_hidden, activate=True): 9 | super(ResLayer, self).__init__() 10 | self.linear = nn.Sequential( 11 | nn.Linear(n_in, n_hidden), 12 | nn.ReLU(), 13 | nn.Linear(n_hidden, n_hidden), 14 | nn.ReLU(), 15 | ) 16 | self.n_out = n_hidden 17 | self.activation = nn.LeakyReLU() if activate else nn.Identity() 18 | 19 | def forward(self, x): 20 | return self.activation(self.linear(x)) + x 21 | 22 | 23 | class MLP(nn.Module): 24 | def __init__(self, n_in, n_hidden, dropout=0, activate=True): 25 | super(MLP, self).__init__() 26 | 27 | self.n_in = n_in 28 | 
self.n_hidden = n_hidden 29 | 30 | self.linear = nn.Linear(n_in, n_hidden) 31 | self.activation = nn.LeakyReLU() if activate else nn.Identity() 32 | self.dropout = SharedDropout(p=dropout) if dropout > 0 else nn.Identity() 33 | self.n_out = n_hidden 34 | self.reset_parameters() 35 | 36 | def __repr__(self): 37 | s = f"n_in={self.n_in}, n_out={self.n_hidden}" 38 | if isinstance(self.dropout, SharedDropout): 39 | s += f", dropout={self.dropout.p}" 40 | 41 | return f"{self.__class__.__name__}({s})" 42 | 43 | def reset_parameters(self): 44 | nn.init.orthogonal_(self.linear.weight) 45 | nn.init.zeros_(self.linear.bias) 46 | 47 | def forward(self, x: Tensor) -> Tensor: 48 | x = self.linear(x) 49 | x = self.activation(x) 50 | x = self.dropout(x) 51 | return x 52 | -------------------------------------------------------------------------------- /src/model/nn/dmv_spec.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | 6 | class DMVSkipConnectEncoder(nn.Module): 7 | def __init__(self, hidden_size, n_bottleneck=0, n_mid=0, dropout=0.): 8 | super().__init__() 9 | self.hidden_size = hidden_size 10 | self.activate = nn.LeakyReLU() 11 | self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity() 12 | self.n_out = hidden_size 13 | 14 | # To encode valence information 15 | if n_bottleneck == 0: 16 | self.HASCHILD_linear = nn.Linear(self.hidden_size, self.hidden_size) 17 | self.NOCHILD_linear = nn.Linear(self.hidden_size, self.hidden_size) 18 | else: 19 | self.HASCHILD_linear = self.create_bottleneck(self.hidden_size, n_bottleneck) 20 | self.NOCHILD_linear = self.create_bottleneck(self.hidden_size, n_bottleneck) 21 | self.valence_linear = nn.Linear(self.hidden_size, self.hidden_size) 22 | 23 | # To encode direction information 24 | if n_bottleneck == 0: 25 | self.LEFT_linear = nn.Linear(self.hidden_size, self.hidden_size) 26 | self.RIGHT_linear = nn.Linear(self.hidden_size, self.hidden_size) 27 | else: 28 | self.LEFT_linear = self.create_bottleneck(self.hidden_size, n_bottleneck) 29 | self.RIGHT_linear = self.create_bottleneck(self.hidden_size, n_bottleneck) 30 | self.direction_linear = nn.Linear(self.hidden_size, self.hidden_size) 31 | 32 | # To produce final hidden representation 33 | n_mid = n_mid if n_mid else hidden_size 34 | self.linear1 = nn.Linear(self.hidden_size, n_mid) 35 | self.linear2 = nn.Linear(n_mid, self.hidden_size) 36 | 37 | def forward(self, x: Tensor): 38 | # input: ... x len x hidden1 39 | # output: ... 
x len x dir x val x hidden2 40 | has_child = self.HASCHILD_linear(x) + x 41 | no_child = self.NOCHILD_linear(x) + x 42 | h = torch.cat([no_child.unsqueeze(-2), has_child.unsqueeze(-2)], dim=-2) 43 | h = self.activate(self.valence_linear(self.activate(h))) 44 | 45 | x = x.unsqueeze(-2) 46 | left_h = self.LEFT_linear(h) + x 47 | right_h = self.RIGHT_linear(h) + x 48 | h = torch.cat([left_h.unsqueeze(-3), right_h.unsqueeze(-3)], dim=-3) 49 | h = self.activate(self.direction_linear(self.activate(h))) 50 | 51 | h = self.dropout(h) 52 | return self.linear2(self.activate(self.linear1(h))) 53 | 54 | @staticmethod 55 | def create_bottleneck(n_in_out, n_bottleneck): 56 | return nn.Sequential(nn.Linear(n_in_out, n_bottleneck), nn.Linear(n_bottleneck, n_in_out)) 57 | 58 | 59 | class DMVFactorizedBilinear(nn.Module): 60 | def __init__(self, n_in, n_in2=None, r=64): 61 | super(DMVFactorizedBilinear, self).__init__() 62 | self.n_in = n_in 63 | self.n_in2 = n_in2 if n_in2 else n_in 64 | self.r = r 65 | self.project1 = nn.Linear(self.n_in, self.r) 66 | self.project2 = nn.Linear(self.n_in2, self.r) 67 | 68 | def forward(self, x1, x2): 69 | x1 = self.project1(x1) 70 | x2 = self.project2(x2) 71 | if len(x1.shape) == 5: 72 | return torch.einsum("bhdve, bcdve -> bhcdv", x1, x2) 73 | elif len(x1.shape) == 4: 74 | return torch.einsum("hdve, cdve -> hcdv", x1, x2) 75 | else: 76 | raise NotImplementedError 77 | -------------------------------------------------------------------------------- /src/model/nn/dropout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SharedDropout(nn.Module): 6 | r""" 7 | SharedDropout differs from the vanilla dropout strategy in that 8 | the dropout mask is shared across one dimension. 9 | 10 | Args: 11 | p (float): 12 | The probability of an element to be zeroed. Default: 0.5. 13 | batch_first (bool): 14 | If ``True``, the input and output tensors are provided as ``[batch_size, seq_len, *]``. 15 | Default: ``True``. 16 | 17 | Examples: 18 | >>> x = torch.ones(1, 3, 5) 19 | >>> nn.Dropout()(x) 20 | tensor([[[0., 2., 2., 0., 0.], 21 | [2., 2., 0., 2., 2.], 22 | [2., 2., 2., 2., 0.]]]) 23 | >>> SharedDropout()(x) 24 | tensor([[[2., 0., 2., 0., 2.], 25 | [2., 0., 2., 0., 2.], 26 | [2., 0., 2., 0., 2.]]]) 27 | """ 28 | 29 | def __init__(self, p=0.5, batch_first=True): 30 | super().__init__() 31 | 32 | self.p = p 33 | self.batch_first = batch_first 34 | 35 | def __repr__(self): 36 | s = f'p={self.p}' 37 | if self.batch_first: 38 | s += f', batch_first={self.batch_first}' 39 | 40 | return f'{self.__class__.__name__}({s})' 41 | 42 | def forward(self, x): 43 | r""" 44 | Args: 45 | x (~torch.Tensor): 46 | A tensor of any shape. 47 | Returns: 48 | The returned tensor is of the same shape as `x`. 49 | """ 50 | 51 | if self.training: 52 | if self.batch_first: 53 | mask = self.get_mask(x[:, 0], self.p).unsqueeze(1) 54 | else: 55 | mask = self.get_mask(x[0], self.p) 56 | x = x * mask 57 | 58 | return x 59 | 60 | @staticmethod 61 | def get_mask(x, p): 62 | return x.new_empty(x.shape).bernoulli_(1 - p) / (1 - p) 63 | 64 | 65 | class IndependentDropout(nn.Module): 66 | r""" 67 | For :math:`N` tensors, they use different dropout masks respectively. 68 | When :math:`N-M` of them are dropped, the remaining :math:`M` ones are scaled by a factor of :math:`N/M` 69 | to compensate, and when all of them are dropped together, zeros are returned. 
70 | 71 | Args: 72 | p (float): 73 | The probability of an element to be zeroed. Default: 0.5. 74 | 75 | Examples: 76 | >>> x, y = torch.ones(1, 3, 5), torch.ones(1, 3, 5) 77 | >>> x, y = IndependentDropout()(x, y) 78 | >>> x 79 | tensor([[[1., 1., 1., 1., 1.], 80 | [0., 0., 0., 0., 0.], 81 | [2., 2., 2., 2., 2.]]]) 82 | >>> y 83 | tensor([[[1., 1., 1., 1., 1.], 84 | [2., 2., 2., 2., 2.], 85 | [0., 0., 0., 0., 0.]]]) 86 | """ 87 | 88 | def __init__(self, p=0.5): 89 | super().__init__() 90 | 91 | self.p = p 92 | 93 | def __repr__(self): 94 | return f'{self.__class__.__name__}(p={self.p})' 95 | 96 | def forward(self, *items): 97 | r""" 98 | Args: 99 | items (list[~torch.Tensor]): 100 | A list of tensors that have the same shape except the last dimension. 101 | Returns: 102 | The returned tensors are of the same shape as `items`. 103 | """ 104 | 105 | if self.training: 106 | masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p) for x in items] 107 | total = sum(masks) 108 | scale = len(items) / total.max(torch.ones_like(total)) 109 | masks = [mask * scale for mask in masks] 110 | items = [item * mask.unsqueeze(-1) for item, mask in zip(items, masks)] 111 | 112 | return items 113 | -------------------------------------------------------------------------------- /src/model/nn/multivariate_kl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class MultVariateKLD(torch.nn.Module): 5 | def __init__(self, reduction): 6 | super(MultVariateKLD, self).__init__() 7 | self.reduction = reduction 8 | 9 | def forward(self, mu1, mu2, logvar_1, logvar_2): 10 | mu1, mu2 = mu1.type(dtype=torch.float64), mu2.type(dtype=torch.float64) 11 | sigma_1 = logvar_1.exp().type(dtype=torch.float64) 12 | sigma_2 = logvar_2.exp().type(dtype=torch.float64) 13 | 14 | sigma_diag_1 = torch.diag_embed(sigma_1, offset=0, dim1=-2, dim2=-1) 15 | sigma_diag_2 = torch.diag_embed(sigma_2, offset=0, dim1=-2, dim2=-1) 16 | 17 | sigma_diag_2_inv = sigma_diag_2.inverse() 18 | 19 | # log(det(sigma2^T)/det(sigma1)) 20 | term_1 = (sigma_diag_2.det() / sigma_diag_1.det()).log() 21 | # term_1[term_1.ne(term_1)] = 0 22 | 23 | # trace(inv(sigma2)*sigma1) 24 | term_2 = torch.diagonal((torch.matmul(sigma_diag_2_inv, sigma_diag_1)), dim1=-2, dim2=-1).sum(-1) 25 | 26 | # (mu2-m1)^T*inv(sigma2)*(mu2-mu1) 27 | term_3 = torch.matmul(torch.matmul((mu2 - mu1).unsqueeze(-1).transpose(2, 1), sigma_diag_2_inv), 28 | (mu2 - mu1).unsqueeze(-1)).flatten() 29 | 30 | # dimension of embedded space (number of mus and sigmas) 31 | n = mu1.shape[1] 32 | 33 | # Calc kl divergence on entire batch 34 | kl = 0.5 * (term_1 - n + term_2 + term_3) 35 | 36 | # Calculate mean kl_d loss 37 | if self.reduction == 'mean': 38 | kl_agg = torch.mean(kl) 39 | elif self.reduction == 'sum': 40 | kl_agg = torch.sum(kl) 41 | else: 42 | raise NotImplementedError(f'Reduction type not implemented: {self.reduction}') 43 | 44 | return kl_agg 45 | -------------------------------------------------------------------------------- /src/model/nn/scalar_mix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ScalarMix(nn.Module): 6 | r""" 7 | Computes a parameterised scalar mixture of :math:`N` tensors, :math:`mixture = \gamma * \sum_{k}(s_k * tensor_k)` 8 | where :math:`s = \mathrm{softmax}(w)`, with :math:`w` and :math:`\gamma` scalar parameters. 9 | 10 | Args: 11 | n_layers (int): 12 | The number of layers to be mixed, i.e., :math:`N`. 
13 | dropout (float): 14 | The dropout ratio of the layer weights. 15 | If dropout > 0, then for each scalar weight, adjust its softmax weight mass to 0 16 | with the dropout probability (i.e., setting the unnormalized weight to -inf). 17 | This effectively redistributes the dropped probability mass to all other weights. 18 | Default: 0. 19 | """ 20 | 21 | def __init__(self, n_layers, dropout=0): 22 | super().__init__() 23 | 24 | self.n_layers = n_layers 25 | 26 | self.weights = nn.Parameter(torch.zeros(n_layers)) 27 | self.gamma = nn.Parameter(torch.tensor([1.0])) 28 | self.dropout_func = nn.Dropout(dropout) 29 | 30 | def __repr__(self): 31 | s = f'n_layers={self.n_layers}' 32 | if self.dropout_func.p > 0: 33 | s += f', dropout={self.dropout_func.p}' 34 | 35 | return f'{self.__class__.__name__}({s})' 36 | 37 | def forward(self, tensors): 38 | r""" 39 | Args: 40 | tensors (list[~torch.Tensor]): 41 | :math:`N` tensors to be mixed. 42 | 43 | Returns: 44 | The mixture of :math:`N` tensors. 45 | """ 46 | 47 | normed_weights = self.dropout_func(self.weights.softmax(-1)) 48 | weighted_sum = sum(w * h for w, h in zip(normed_weights, tensors)) 49 | 50 | return self.gamma * weighted_sum 51 | -------------------------------------------------------------------------------- /src/model/nn/variational_lstm.py: --------------------------------------------------------------------------------
1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import PackedSequence 4 | from src.model.nn.dropout import SharedDropout 5 | 6 | 7 | class VariationalLSTM(nn.Module): 8 | r""" 9 | VariationalLSTM :cite:`yarin-etal-2016-dropout` is a variant of the vanilla bidirectional LSTM 10 | adopted by the Biaffine Parser, with the only difference being the dropout strategy. 11 | It drops nodes in the LSTM layers (input and recurrent connections) 12 | and applies the same dropout mask at every recurrent timestep. 13 | APIs are roughly the same as :class:`~torch.nn.LSTM` except that we only allow 14 | :class:`~torch.nn.utils.rnn.PackedSequence` as input. 15 | Args: 16 | input_size (int): 17 | The number of expected features in the input. 18 | hidden_size (int): 19 | The number of features in the hidden state `h`. 20 | num_layers (int): 21 | The number of recurrent layers. Default: 1. 22 | dropout (float): 23 | If non-zero, introduces a :class:`SharedDropout` layer on the outputs of each LSTM layer (except last). 24 | Default: 0.
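Examples:
    Illustrative sketch (sizes are arbitrary):
    >>> import torch
    >>> from torch.nn.utils.rnn import pack_padded_sequence
    >>> lstm = VariationalLSTM(input_size=100, hidden_size=200, num_layers=2, dropout=0.33)
    >>> x = pack_padded_sequence(torch.randn(4, 12, 100), torch.tensor([12, 9, 7, 5]), batch_first=True)
    >>> outputs, (h, c) = lstm(x)
    >>> # outputs: a list with one PackedSequence per layer
    >>> # h, c: final states of shape [num_layers * 2, batch_size, hidden_size]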
25 | """ 26 | 27 | def __init__(self, input_size, hidden_size, num_layers=1, dropout=0, cell=nn.LSTMCell, init='zy'): 28 | super().__init__() 29 | 30 | self.input_size = input_size 31 | self.hidden_size = hidden_size 32 | self.num_layers = num_layers 33 | self.dropout = dropout 34 | self.init = init 35 | 36 | self.f_cells = nn.ModuleList() 37 | self.b_cells = nn.ModuleList() 38 | for _ in range(self.num_layers): 39 | self.f_cells.append(cell(input_size=input_size, hidden_size=hidden_size)) 40 | self.b_cells.append(cell(input_size=input_size, hidden_size=hidden_size)) 41 | input_size = hidden_size * 2 42 | 43 | self.reset_parameters() 44 | 45 | def __repr__(self): 46 | s = f'{self.input_size}, {2 * self.hidden_size}' 47 | if self.num_layers > 1: 48 | s += f', num_layers={self.num_layers}' 49 | if self.dropout > 0: 50 | s += f', dropout={self.dropout}' 51 | 52 | return f'{self.__class__.__name__}({s})' 53 | 54 | def reset_parameters(self): 55 | if self.init == 'zy': 56 | for name, param in self.named_parameters(): 57 | if name.startswith('lstm'): 58 | # apply orthogonal_ to weight 59 | if len(param.shape) > 1: 60 | nn.init.orthogonal_(param) 61 | # apply zeros_ to bias 62 | else: 63 | nn.init.zeros_(param) 64 | elif self.init == 'biased': 65 | for name, param in self.named_parameters(): 66 | if name.startswith('lstm'): 67 | # apply orthogonal_ to weight 68 | if len(param.shape) > 1: 69 | nn.init.xavier_uniform_(param) 70 | else: 71 | # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871 72 | param.data.fill_(0.) 73 | n = param.shape[0] 74 | start, end = n // 4, n // 2 75 | param.data[start:end].fill_(1.) 76 | else: 77 | raise ValueError(f'Bad init_version, {self.cfg.init_version=}') 78 | 79 | def layer_forward(self, x, hx, cell, batch_sizes, reverse=False): 80 | hx_0 = hx_i = hx 81 | hx_n, output = [], [] 82 | steps = reversed(range(len(x))) if reverse else range(len(x)) 83 | if self.training: 84 | hid_mask = SharedDropout.get_mask(hx_0[0], self.dropout) 85 | 86 | for t in steps: 87 | last_batch_size, batch_size = len(hx_i[0]), batch_sizes[t] 88 | if last_batch_size < batch_size: 89 | hx_i = [torch.cat((h, ih[last_batch_size:batch_size])) for h, ih in zip(hx_i, hx_0)] 90 | else: 91 | hx_n.append([h[batch_size:] for h in hx_i]) 92 | hx_i = [h[:batch_size] for h in hx_i] 93 | hx_i = [h for h in cell(x[t], hx_i)] 94 | output.append(hx_i[0]) 95 | if self.training: 96 | hx_i[0] = hx_i[0] * hid_mask[:batch_size] 97 | if reverse: 98 | hx_n = hx_i 99 | output.reverse() 100 | else: 101 | hx_n.append(hx_i) 102 | hx_n = [torch.cat(h) for h in zip(*reversed(hx_n))] 103 | output = torch.cat(output) 104 | 105 | return output, hx_n 106 | 107 | def forward(self, sequence: PackedSequence, hx=None): 108 | r""" 109 | Args: 110 | sequence (~torch.nn.utils.rnn.PackedSequence): 111 | A packed variable length sequence. 112 | hx (~torch.Tensor, ~torch.Tensor): 113 | A tuple composed of two tensors `h` and `c`. 114 | `h` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the initial hidden state 115 | for each element in the batch. 116 | `c` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` holds the initial cell state 117 | for each element in the batch. 118 | If `hx` is not provided, both `h` and `c` default to zero. 119 | Default: ``None``. 120 | Returns: 121 | ~torch.nn.utils.rnn.PackedSequence, (~List[torch.Tensor], ~torch.Tensor): 122 | The first is a list of packed variable length sequence for each layer. 
123 | The second is a tuple of tensors `h` and `c`. 124 | `h` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` 125 | holds the hidden state for `t=seq_len`. 126 | Like output, the layers can be separated using 127 | ``h.view(num_layers, num_directions, batch_size, hidden_size)`` 128 | and similarly for c. 129 | `c` of shape ``[num_layers*num_directions, batch_size, hidden_size]`` 130 | holds the cell state for `t=seq_len`. 131 | """ 132 | x, batch_sizes = sequence.data, sequence.batch_sizes.tolist() 133 | batch_size = batch_sizes[0] 134 | h_n, c_n, hiddens = [], [], [] 135 | 136 | if hx is None: 137 | ih = x.new_zeros(self.num_layers * 2, batch_size, self.hidden_size) 138 | h, c = ih, ih 139 | else: 140 | h, c = hx 141 | h = h.view(self.num_layers, 2, batch_size, self.hidden_size) 142 | c = c.view(self.num_layers, 2, batch_size, self.hidden_size) 143 | 144 | for i in range(self.num_layers): 145 | x = torch.split(x, batch_sizes) 146 | if self.training and i > 0: 147 | mask = SharedDropout.get_mask(x[0], self.dropout) 148 | x = [i * mask[:len(i)] for i in x] 149 | x_i, (h_i, c_i) = self.layer_forward(x, (h[i, 0], c[i, 0]), self.f_cells[i], batch_sizes) 150 | x_b, (h_b, c_b) = self.layer_forward(x, (h[i, 1], c[i, 1]), self.b_cells[i], batch_sizes, True) 151 | x_i = torch.cat((x_i, x_b), -1) 152 | h_i = torch.stack((h_i, h_b)) 153 | c_i = torch.stack((c_i, c_b)) 154 | x = x_i 155 | h_n.append(h_i) 156 | c_n.append(c_i) 157 | hiddens.append( 158 | PackedSequence(x_i, sequence.batch_sizes, sequence.sorted_indices, sequence.unsorted_indices)) 159 | 160 | hx = torch.cat(h_n, 0), torch.cat(c_n, 0) 161 | return hiddens, hx 162 | -------------------------------------------------------------------------------- /src/model/text_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from src.model.text_encoder.base import EncoderBase 2 | from src.model.text_encoder.rnn_encoder import RNNEncoder 3 | from src.model.text_encoder.mlp_encoder import MLPEncoder 4 | from src.model.text_encoder.blank_encoder import BlankEncoder -------------------------------------------------------------------------------- /src/model/text_encoder/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import torch.nn as nn 6 | 7 | if TYPE_CHECKING: 8 | from src.model.embedding import Embedding 9 | from src.model import ModelBase 10 | 11 | 12 | class EncoderBase(nn.Module): 13 | bounded_embedding: Embedding 14 | bounded_model: ModelBase 15 | 16 | def __init__(self, embedding: Embedding): 17 | super().__init__() 18 | self.__dict__['bounded_embedding'] = embedding 19 | 20 | def forward(self, x, ctx): 21 | raise NotImplementedError 22 | 23 | def get_dim(self, field): 24 | raise NotImplementedError(f'Unrecognized {field=}') 25 | 26 | -------------------------------------------------------------------------------- /src/model/text_encoder/blank_encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | from src.model.embedding import Embedding 9 | from src.model.text_encoder.base import EncoderBase 10 | from src.model.nn import SharedDropout 11 | from src.utility.config import Config 12 | from src.utility.logger import get_logger_func 13 | from src.utility.var_pool 
import VarPool 14 | 15 | _warn, _info, _debug = get_logger_func('encoder') 16 | 17 | 18 | @dataclass 19 | class BlankEncoderConfig(Config): 20 | dropout: float 21 | shared_dropout: float 22 | 23 | 24 | class BlankEncoder(EncoderBase): 25 | 26 | def __init__(self, embedding: Embedding, **cfg): 27 | super().__init__(embedding) 28 | self.cfg = cfg = BlankEncoderConfig.build(cfg) 29 | self.output_size = embedding.embed_size 30 | self.dropout = nn.Dropout(cfg.dropout) if cfg.dropout > 0 else nn.Identity() 31 | self.shared_dropout = SharedDropout(cfg.shared_dropout) if cfg.shared_dropout > 0 else nn.Identity() 32 | 33 | def forward(self, x: Tensor, vp: VarPool, hiddens=None): 34 | x = self.dropout(x) 35 | x = self.shared_dropout(x) 36 | return {'x': x} 37 | 38 | def get_dim(self, field): 39 | return self.output_size 40 | -------------------------------------------------------------------------------- /src/model/text_encoder/mlp_encoder.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | 5 | import torch.nn as nn 6 | from torch import Tensor 7 | 8 | from src.model.embedding import Embedding 9 | from src.model.text_encoder.base import EncoderBase 10 | from src.model.nn import SharedDropout 11 | from src.utility.config import Config 12 | from src.utility.logger import get_logger_func 13 | from src.utility.var_pool import VarPool 14 | 15 | _warn, _info, _debug = get_logger_func('encoder') 16 | 17 | 18 | @dataclass 19 | class MLPEncoderConfig(Config): 20 | dropout: float 21 | n_hidden: int 22 | shared_dropout: float 23 | 24 | 25 | class MLPEncoder(EncoderBase): 26 | 27 | def __init__(self, embedding: Embedding, **cfg): 28 | super().__init__(embedding) 29 | self.cfg = cfg = MLPEncoderConfig.build(cfg) 30 | self.output_size = cfg.n_hidden 31 | self.linear = nn.Linear(embedding.embed_size, self.output_size, bias=False) 32 | self.dropout = nn.Dropout(cfg.dropout) if cfg.dropout > 0 else nn.Identity() 33 | self.shared_dropout = SharedDropout(cfg.shared_dropout) if cfg.shared_dropout > 0 else nn.Identity() 34 | 35 | def forward(self, x: Tensor, vp: VarPool, hiddens=None): 36 | x = self.dropout(x) 37 | x = self.shared_dropout(x) 38 | x = self.linear(x) 39 | return {'x': x} 40 | 41 | def get_dim(self, field): 42 | return self.output_size 43 | -------------------------------------------------------------------------------- /src/model/text_encoder/multi_encoder.py: --------------------------------------------------------------------------------
1 | from __future__ import annotations 2 | 3 | import torch 4 | 5 | from .base import EncoderBase 6 | 7 | 8 | class MultiEncoder(EncoderBase): 9 | """Compose encoders into different outputs.""" 10 | 11 | def __init__(self, embedding, mapping, ff=None, **encoders): 12 | """ 13 | :param mapping: a dict indicating how to construct x. 14 | for example: mapping = { 15 | 'arc': ['shared_lstm.x', 'arc_lstm.x'], 16 | 'rel': ['shared_lstm.x', 'rel_lstm.x'] 17 | } 18 | :param ff: a dict indicating passthrough variables 19 | e.g.
ff = { 20 | 'hiddens': 'shared_lstm.hiddens' 21 | } 22 | :type mapping: dict 23 | """ 24 | super().__init__(embedding) 25 | 26 | self.all_encoders = [] 27 | for key, value in encoders.items(): 28 | if key.startswith('_'): 29 | continue 30 | self.add_module(key, value) 31 | self.all_encoders.append(key) 32 | 33 | self.mapping = {} # {'shared_lstm': {'x': ['arc', 'rel']}, ...} 34 | self.output_fields = list(mapping.keys()) 35 | self.dims = {o: 0 for o in self.output_fields} 36 | self.detailed_dims = {o: [] for o in self.output_fields} 37 | for target, sources in mapping.items(): 38 | for source in sources: 39 | source_name, source_field = source.split('.') 40 | self.dims[target] += encoders[source_name].get_dim(source_field) 41 | self.detailed_dims[target].append(encoders[source_name].get_dim(source_field)) 42 | if source_name not in self.mapping: 43 | self.mapping[source_name] = {} 44 | if source_field not in self.mapping[source_name]: 45 | self.mapping[source_name][source_field] = [] 46 | self.mapping[source_name][source_field].append(target) 47 | self.ff = {} 48 | if ff is not None: 49 | for target, source in ff.items(): 50 | source_name, source_field = source.split('.') 51 | assert target not in mapping, 'Conflict' 52 | if source_name not in self.ff: 53 | self.ff[source_name] = {} 54 | if source_field not in self.ff[source_name]: 55 | self.ff[source_name][source_field] = [] 56 | self.ff[source_name][source_field].append(target) 57 | 58 | def forward(self, x, ctx): 59 | outputs = {key: [] for key in self.output_fields} 60 | for source_name in self.all_encoders: 61 | encoder_out = getattr(self, source_name)(x, ctx) 62 | if source_name in self.mapping: 63 | for encoder_field, targets in self.mapping[source_name].items(): 64 | for target in targets: 65 | outputs[target].append(encoder_out[encoder_field]) 66 | if source_name in self.ff: 67 | for encoder_field, targets in self.ff[source_name].items(): 68 | for target in targets: 69 | outputs[target] = encoder_out[encoder_field] 70 | outputs = { 71 | key: torch.cat(value, dim=-1) if key in self.output_fields else value 72 | for key, value in outputs.items() 73 | } 74 | 75 | return outputs 76 | 77 | def get_dim(self, field): 78 | return self.dims[field] 79 | 80 | -------------------------------------------------------------------------------- /src/model/text_encoder/rnn_encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from typing import List, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | from omegaconf import MISSING 9 | from torch import Tensor 10 | from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence 11 | 12 | from src.model.embedding import Embedding 13 | from src.model.text_encoder.base import EncoderBase 14 | from src.model.nn import ScalarMix, SharedDropout, VariationalLSTM 15 | from src.utility.config import Config 16 | from src.utility.logger import get_logger_func 17 | from src.utility.var_pool import VarPool 18 | 19 | _warn, _info, _debug = get_logger_func('encoder') 20 | RNN_TYPE_DICT = {'lstm': nn.LSTM, 'gru': nn.GRU, 'rnn': nn.RNN} 21 | RNNCELL_TYPE_DICT = {'lstm': nn.LSTMCell, 'gru': nn.GRUCell, 'rnn': nn.RNNCell} 22 | 23 | 24 | @dataclass 25 | class LSTMEncoderConfig(Config): 26 | reproject_emb: int = 0 # reproject layer before lstm 27 | # ============================= dropout ============================== 28 | pre_shared_dropout: float = 0. 
29 | pre_dropout: float = 0. 30 | post_shared_dropout: float = 0. 31 | post_dropout: float = 0. 32 | 33 | # =============================== lstm =============================== 34 | rnn_type: str = 'lstm' # lstm, gru or rnn 35 | hidden_size: Union[int, List[int]] = MISSING # hidden size for each layer 36 | proj_size: int = 0 # projective size 37 | num_layers: int = MISSING # total layers 38 | output_layers: Union[int, List[int]] = -1 # which layers are return, start from 0 39 | init_version: str = 'biased' 40 | shared_dropout: bool = True 41 | lstm_dropout: float = 0.33 # only between layers, unlike zhangyu. 42 | no_eos: bool = False # simulate no 43 | sorted: bool = True 44 | 45 | # ============================== output ============================== 46 | mix: bool = False # whether to use a ScaleMix when multiple outputs 47 | reproject_out: int = 0 48 | cat_emb: bool = False 49 | 50 | 51 | class RNNEncoder(EncoderBase): 52 | 53 | def __init__(self, embedding: Embedding, **cfg): 54 | super().__init__(embedding) 55 | self.cfg = cfg = LSTMEncoderConfig.build(cfg) 56 | 57 | # check output_layers 58 | output_layers: List[int] = [cfg.output_layers] if isinstance(cfg.output_layers, int) else cfg.output_layers 59 | output_layers = sorted(cfg.num_layers + o if o < 0 else o for o in output_layers) 60 | assert output_layers[0] >= 0 and output_layers[-1] < cfg.num_layers 61 | if output_layers[-1] < cfg.num_layers - 1: 62 | cfg.num_layers = output_layers[-1] + 1 63 | _warn(f'max index of output_layers is smaller to n_layers, n_layers is set to {cfg.num_layers}') 64 | self.output_layers = output_layers 65 | 66 | self.embedding2nn = nn.Linear(embedding.embed_size, cfg.reproject_emb) if cfg.reproject_emb else nn.Identity() 67 | 68 | # ============================= dropout ============================== 69 | 70 | self.pre_shared_dropout = SharedDropout(cfg.pre_shared_dropout) if cfg.pre_shared_dropout else nn.Identity() 71 | self.pre_dropout = nn.Dropout(cfg.pre_dropout) if cfg.pre_dropout else nn.Identity() 72 | self.post_shared_dropout = SharedDropout(cfg.post_shared_dropout) if cfg.post_shared_dropout else nn.Identity() 73 | self.post_dropout = nn.Dropout(cfg.post_dropout) if cfg.post_dropout else nn.Identity() 74 | 75 | # =============================== lstm =============================== 76 | 77 | input_size = cfg.reproject_emb if cfg.reproject_emb > 0 else embedding.embed_size 78 | if cfg.shared_dropout: 79 | assert isinstance(cfg.hidden_size, int), 'Not supported' 80 | assert cfg.proj_size == 0, 'Not supported' 81 | self.lstm = VariationalLSTM(input_size, cfg.hidden_size, cfg.num_layers, cfg.lstm_dropout, 82 | RNNCELL_TYPE_DICT[cfg.rnn_type]) 83 | self.output_size = 2 * cfg.hidden_size 84 | else: 85 | # figure out how many layers in each sub modules 86 | layer_for_each_rnn = [x - y for x, y in zip(output_layers, [-1] + output_layers[:-1])] 87 | 88 | # check hiddens 89 | if isinstance(cfg.hidden_size, int): 90 | hiddens = [cfg.hidden_size for _ in layer_for_each_rnn] 91 | else: 92 | hiddens = cfg.hidden_size 93 | assert len(hiddens) == len(layer_for_each_rnn) 94 | 95 | # construct nn 96 | self.lstm_dropout = nn.Dropout(cfg.lstm_dropout) 97 | self.lstm = nn.ModuleList() 98 | rnn_type = RNN_TYPE_DICT[cfg.rnn_type] 99 | for n_layer, hidden in zip(layer_for_each_rnn, hiddens): 100 | sub_lstm = rnn_type(input_size, 101 | hidden, 102 | n_layer, 103 | dropout=cfg.lstm_dropout if n_layer > 1 else 0, 104 | bidirectional=True, 105 | proj_size=cfg.proj_size if hidden > cfg.proj_size > 0 else 0) 106 | 
self.lstm.append(sub_lstm) 107 | input_size = 2 * cfg.proj_size if cfg.proj_size else 2 * hidden 108 | self.output_size = 2 * cfg.proj_size if cfg.proj_size else 2 * hiddens[-1] 109 | 110 | if cfg.mix: 111 | assert isinstance(cfg.hidden_size, int) or all(h == cfg.hidden_size[0] for h in cfg.hidden_size), \ 112 | 'Only if has same dim for all layers, mix can be used.' 113 | self.mix = ScalarMix(len(output_layers)) 114 | else: 115 | self.output_size *= len(output_layers) 116 | 117 | if cfg.reproject_out: 118 | self.nn2out = nn.Linear(self.output_size, cfg.reproject_out) 119 | self.output_size = cfg.reproject_out 120 | else: 121 | self.nn2out = nn.Identity() 122 | 123 | if cfg.cat_emb: 124 | self.output_size += embedding.embed_size 125 | 126 | self.reset_parameters(cfg.init_version) 127 | 128 | def reset_parameters(self, init_method): 129 | if init_method == 'zy': 130 | for name, param in self.named_parameters(): 131 | if name.startswith('lstm'): 132 | # apply orthogonal_ to weight 133 | if len(param.shape) > 1: 134 | nn.init.orthogonal_(param) 135 | # apply zeros_ to bias 136 | else: 137 | nn.init.zeros_(param) 138 | elif init_method == 'biased': 139 | for name, param in self.named_parameters(): 140 | if name.startswith('lstm'): 141 | # apply orthogonal_ to weight 142 | if len(param.shape) > 1: 143 | nn.init.xavier_uniform_(param) 144 | else: 145 | # based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871 146 | param.data.fill_(0.) 147 | n = param.shape[0] 148 | start, end = n // 4, n // 2 149 | param.data[start:end].fill_(1.) 150 | # else: 151 | # raise ValueError(f'Bad init_version, {self.cfg.init_version=}') 152 | 153 | def forward(self, x: Tensor, vp: VarPool, hiddens=None): 154 | """ 155 | :param x: output of embedding 156 | :param vp: the varpool 157 | :param hiddens: ttbp 158 | :return: a dict contains 159 | x, Tensor: the concated or mixed representation. 160 | all: List[Tensor], a list contains all outputs specified in self.output_layers. 161 | hx: the output state for all layers. 
162 | """ 163 | if isinstance(x, list): x = torch.cat(x, dim=-1) 164 | 165 | emb = x 166 | x = self.embedding2nn(x) 167 | x = self.pre_shared_dropout(x) 168 | x = self.pre_dropout(x) 169 | xs, hx = self.lstm_forward(x, vp, hiddens) 170 | if self.cfg.mix: 171 | x = self.mix(xs) 172 | else: 173 | x = torch.cat(xs, dim=-1) 174 | x = self.post_dropout(x) 175 | x = self.post_shared_dropout(x) 176 | if self.cfg.no_eos: 177 | x = torch.cat([x, torch.zeros(x.shape[0], 1, x.shape[2], device=x.device)], dim=1) 178 | x = self.nn2out(x) 179 | 180 | if self.cfg.cat_emb: 181 | x = torch.cat([x, emb], dim=-1) 182 | 183 | # from src.utility.fn import draw_att 184 | # draw_att(x[0]) 185 | 186 | return {'x': x, 'all': xs, 'hiddens': hx} 187 | 188 | def lstm_forward(self, x: Tensor, vp: VarPool, hiddens=None): 189 | if self.cfg.no_eos: 190 | x = x[:, :-1] 191 | x = pack_padded_sequence(x, vp.seq_len_cpu - 1, True, enforce_sorted=self.cfg.sorted) 192 | else: 193 | x = pack_padded_sequence(x, vp.seq_len_cpu, True, enforce_sorted=self.cfg.sorted) 194 | 195 | if self.cfg.shared_dropout: 196 | outputs, (hx, _) = self.lstm(x, hiddens) 197 | outputs = [outputs[i] for i in self.output_layers] 198 | outputs = [pad_packed_sequence(o, True)[0] for o in outputs] 199 | else: 200 | layer_count = -1 201 | outputs = [] 202 | output_layers = self.output_layers.copy() 203 | hx = [] 204 | hiddens = hiddens if hiddens is not None else [None] * len(self.lstm) 205 | 206 | for layer, hidden in zip(self.lstm, hiddens): 207 | output: PackedSequence 208 | output, (hx_, _) = layer(x, hidden) 209 | hx.append(hx_) 210 | 211 | layer_count += layer.num_layers 212 | if layer_count == output_layers[0]: 213 | output_layers.pop(0) 214 | outputs.append(pad_packed_sequence(output, True)[0]) 215 | 216 | data = self.lstm_dropout(output.data) 217 | x = PackedSequence(data, output.batch_sizes, output.sorted_indices, output.unsorted_indices) 218 | hx = torch.cat(hx, 0) 219 | return outputs, hx 220 | 221 | def get_dim(self, field): 222 | if field == 'x' or field == 'all': 223 | return self.output_size 224 | return super().get_dim(field) 225 | -------------------------------------------------------------------------------- /src/model/torch_struct/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributions import DMV1o, DependencyCRF, StructDistribution 2 | from .semirings import ( 3 | CheckpointSemiring, 4 | CheckpointShardSemiring, 5 | EntropySemiring, 6 | FastLogSemiring, 7 | FastMaxSemiring, 8 | FastSampleSemiring, 9 | GumbelCRFSemiring, 10 | KMaxSemiring, 11 | LogSemiring, 12 | MaxSemiring, 13 | MultiSampledSemiring, 14 | SampledSemiring, 15 | SparseMaxSemiring, 16 | StdSemiring, 17 | TempMax, 18 | ) 19 | 20 | version = "0.4" 21 | 22 | # For flake8 compatibility. 
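# Note: the entries below are the re-exported objects themselves rather than their names
# (strings), so this __all__ mainly marks the imports as used for linters; it would not
# support ``from ... import *``.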
23 | __all__ = [ 24 | LogSemiring, 25 | StdSemiring, 26 | SampledSemiring, 27 | MaxSemiring, 28 | SparseMaxSemiring, 29 | KMaxSemiring, 30 | FastLogSemiring, 31 | FastMaxSemiring, 32 | FastSampleSemiring, 33 | EntropySemiring, 34 | MultiSampledSemiring, 35 | GumbelCRFSemiring, 36 | StructDistribution, 37 | DMV1o, 38 | DependencyCRF, 39 | CheckpointSemiring, 40 | CheckpointShardSemiring, 41 | TempMax, 42 | ] 43 | -------------------------------------------------------------------------------- /src/model/torch_struct/dmv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | from .helpers import _Struct 5 | from .semirings import Semiring 6 | 7 | NOCHILD = 1 8 | HASCHILD = 0 9 | LEFT = 0 10 | RIGHT = 1 11 | GO = 0 12 | STOP = 1 13 | DIR_NUM = 2 14 | VAL_NUM = 2 15 | DEC_NUM = 2 16 | 17 | 18 | class DMV1oStruct(_Struct): 19 | def _dp(self, scores, lengths=None, force_grad=False, cache=False): 20 | # dec, attach 21 | s: Semiring = self.semiring 22 | 23 | if isinstance(scores[0], torch.Tensor): 24 | # attach_score: batch, N, N, valence 25 | # dec_score: batch, N, direction, valence, decision 26 | attach: Tensor = s.convert(scores[1]) 27 | dec: Tensor = s.convert(scores[0]) 28 | else: 29 | attach: Tensor = s.convert([scores[0][1], scores[1][1]]) 30 | dec: Tensor = s.convert([scores[0][0], scores[1][0]]) 31 | 32 | _, batch, N, *_ = dec.shape 33 | # diagonal for left, diagonal(1) for right. 34 | I = s.zero_(attach.new_empty((s.size(), batch, N + 1, N + 1, VAL_NUM))) 35 | C = s.zero_(attach.new_empty((s.size(), batch, N + 1, N + 1, VAL_NUM))) 36 | attach_left = s.mul(attach, dec[:, :, :, None, LEFT, :, GO]) 37 | attach_right = s.mul(attach, dec[:, :, :, None, RIGHT, :, GO]) 38 | 39 | diag_minus1(C, 0, 2, 3).copy_(dec[:, :, :, LEFT, :, STOP].transpose(-2, -1)) 40 | C.diagonal(1, 2, 3).copy_(dec[:, :, :, RIGHT, :, STOP].transpose(-2, -1)) 41 | _zero = C.new_tensor(s.zero) 42 | if _zero.ndim == 0: 43 | _zero = _zero.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) 44 | else: 45 | _zero = _zero.unsqueeze(-1).unsqueeze(-1) 46 | 47 | for w in range(1, N): 48 | n = N - w 49 | 50 | x = s.sum(s.mul(stripe_val(C, n, w, (0, 1, NOCHILD)), stripe_val(C, n, w, (w, 1, HASCHILD)))) 51 | x = s.times(x.unsqueeze(-2), attach_left.diagonal(-w, -3, -2)) 52 | diag_minus1(I, -w, -3, -2).copy_(x) 53 | 54 | x = s.sum(s.mul(stripe_val(C, n, w, (0, 1, HASCHILD)), stripe_val(C, n, w, (w, 1, NOCHILD)))) 55 | x = s.times(x.unsqueeze(-2), attach_right.diagonal(w, -3, -2)) 56 | I.diagonal(w + 1, -3, -2).copy_(x) 57 | 58 | x = s.sum(s.mul(stripe_val(C, n, w, (0, 0, NOCHILD), 0, True), stripe_noval(I, n, w, (w, 0))), -2) 59 | diag_minus1(C, -w, -3, -2).copy_(x.transpose(-2, -1)) 60 | 61 | x = s.sum(s.mul(stripe_noval(I, n, w, (0, 2)), stripe_val(C, n, w, (1, w + 1, NOCHILD), 0, True)), -2) 62 | C.diagonal(w + 1, -3, -2).copy_(x.transpose(-2, -1)) 63 | C[:, lengths.ne(w), 0, w + 1] = _zero 64 | 65 | v = torch.gather(C[:, :, 0, :, NOCHILD], -1, (lengths[None, ..., None] + 1).expand(s.size(), -1, -1)) 66 | return v, [dec, attach], [C, I] 67 | 68 | def _arrange_marginals(self, marg): 69 | return marg[1] # return attach 70 | 71 | 72 | def stripe_val(x: Tensor, n, w, offset=(0, 0, 0), dim=1, keep_val=False): 73 | # x: s x b x N x N x valence 74 | # on the last three dim, N x N x valence 75 | # n and w are for N x N 76 | assert x.shape[-1] == 2 77 | assert x.is_contiguous(), 'x must be contiguous, or write on new view will lost.' 
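    # The strides below are computed by hand for a contiguous (..., N, N, 2) chart:
    # (seq_len + 1) * 2 advances one step along the main diagonal (one row plus one
    # column, times the valence dimension of size 2). The result of as_strided is a
    # view, so reads and writes through it alias the original chart, which is why the
    # contiguity assert above matters.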
78 | seq_len = x.shape[-2] 79 | if keep_val: 80 | size = (*x.shape[:-3], n, w, 1) 81 | stride = list(x.stride()) 82 | stride[-3] = (seq_len + 1) * 2 83 | stride[-2] = (1 if dim == 1 else seq_len) * 2 84 | else: 85 | stride = list(x.stride())[:-1] 86 | stride[-2] = (seq_len + 1) * 2 87 | stride[-1] = (1 if dim == 1 else seq_len) * 2 88 | size = (*x.shape[:-3], n, w) 89 | return x.as_strided(size=size, 90 | stride=stride, 91 | storage_offset=x.storage_offset() + (offset[0] * seq_len * 2 + offset[1] * 2 + offset[2])) 92 | 93 | 94 | def stripe_noval(x: Tensor, n, w, offset=(0, 0), dim=1): 95 | # x: s x b x N x N x valence 96 | # on the last three dim, N x N x valence 97 | # n and w are for N x N 98 | assert x.shape[-1] == 2 99 | assert x.is_contiguous(), 'x must be contiguous, or write on new view will lost.' 100 | seq_len = x.shape[-2] 101 | stride = list(x.stride()) 102 | stride[-3] = (seq_len + 1) * 2 103 | stride[-2] = (1 if dim == 1 else seq_len) * 2 104 | return x.as_strided(size=(*x.shape[:-3], n, w, 2), 105 | stride=stride, 106 | storage_offset=x.storage_offset() + (offset[0] * seq_len * 2 + offset[1] * 2)) 107 | 108 | 109 | def diag_minus1(x: Tensor, offset, dim1, dim2) -> Tensor: 110 | # assume a[..., dim1, ..., dim2, ...] 111 | stride = list(x.stride()) 112 | if offset > 0: 113 | storage_offset = stride[dim2] * offset 114 | else: 115 | storage_offset = stride[dim1] * abs(offset) 116 | to_append = stride[dim1] + stride[dim2] 117 | if dim2 < 0: 118 | stride.pop(dim1) 119 | stride.pop(dim2) 120 | else: 121 | stride.pop(dim2) 122 | stride.pop(dim1) # todo handle +/- or -/+ (now only support +/+ ans -/-) 123 | stride.append(to_append) 124 | size = list(x.size()) 125 | to_append = size[dim1] - 1 - abs(offset) 126 | if dim2 < 0: 127 | size.pop(dim1) 128 | size.pop(dim2) 129 | else: 130 | size.pop(dim2) 131 | size.pop(dim1) 132 | size.append(to_append) 133 | return x.as_strided(size, stride, storage_offset=storage_offset) 134 | -------------------------------------------------------------------------------- /src/model/torch_struct/helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Tuple, Union 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch.autograd import Function 7 | 8 | from .semirings import LogSemiring, Semiring 9 | 10 | 11 | class Get(Function): 12 | @staticmethod 13 | def forward(ctx, chart, grad_chart, indices): 14 | ctx.save_for_backward(grad_chart) 15 | out = chart[indices] 16 | ctx.indices = indices 17 | return out 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | (grad_chart, ) = ctx.saved_tensors 22 | grad_chart[ctx.indices] += grad_output 23 | return grad_chart, None, None 24 | 25 | 26 | class Set(torch.autograd.Function): 27 | @staticmethod 28 | def forward(ctx, chart, indices, vals): 29 | chart[indices] = vals 30 | ctx.indices = indices 31 | return chart 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | z = grad_output[ctx.indices] 36 | return None, None, z 37 | 38 | 39 | class Chart: 40 | def __init__(self, size, potentials, semiring: Semiring, cache=True): 41 | self.data = semiring.zero_( 42 | torch.empty(*((semiring.size(), ) + size), dtype=potentials.dtype, device=potentials.device)) 43 | self.grad = self.data.detach().clone().fill_(0.0) 44 | self.cache = cache 45 | self.semiring = semiring 46 | 47 | def __getitem__(self, ind): 48 | I = slice(None) 49 | if self.cache: 50 | return Get.apply(self.data, self.grad, (I, I) + ind) 51 | else: 
52 | return self.data[(I, I) + ind] 53 | 54 | def __setitem__(self, ind, new): 55 | I = slice(None) 56 | if self.cache: 57 | self.data = Set.apply(self.data, (I, I) + ind, new) 58 | else: 59 | self.data[(I, I) + ind] = new 60 | 61 | def get(self, ind): 62 | return Get.apply(self.data, self.grad, ind) 63 | 64 | def set(self, ind, new): 65 | self.data = Set.apply(self.data, ind, new) 66 | 67 | 68 | class _Struct: 69 | def __init__(self, semiring: Semiring = LogSemiring): 70 | self.semiring = semiring 71 | 72 | def score(self, potentials: Tensor, parts: Tensor, batch_dims=(0, )) -> Tensor: 73 | """gather all score in parts""" 74 | score = torch.mul(potentials, parts) 75 | batch = tuple((score.shape[b] for b in batch_dims)) 76 | return self.semiring.prod(score.view(batch + (-1, ))) 77 | 78 | def _bin_length(self, length: int) -> Tuple[int, int]: 79 | log_N = int(math.ceil(math.log(length, 2))) 80 | bin_N = int(math.pow(2, log_N)) 81 | return log_N, bin_N 82 | 83 | def _get_dimension_and_requires_grad(self, edge: Union[List[Tensor], Tensor]) -> Tuple[int, ...]: 84 | if isinstance(edge, (list, tuple)): 85 | for t in edge: 86 | t.requires_grad_(True) 87 | return edge[0].shape 88 | else: 89 | edge.requires_grad_(True) 90 | return edge.shape 91 | 92 | def _chart(self, size, potentials, force_grad): 93 | return self._make_chart(1, size, potentials, force_grad)[0] 94 | 95 | def _make_chart(self, N, size, potentials, force_grad=False): 96 | return [(self.semiring.zero_( 97 | torch.zeros(*((self.semiring.size(), ) + size), dtype=potentials.dtype, 98 | device=potentials.device)).requires_grad_(force_grad and not potentials.requires_grad)) 99 | for _ in range(N)] 100 | 101 | def sum(self, edge, lengths=None, _raw=False): 102 | """ 103 | Compute the (semiring) sum over all structures model. 104 | 105 | Parameters: 106 | edge : generic params (see class) 107 | lengths: None or b long tensor mask 108 | 109 | Returns: 110 | v: b tensor of total sum 111 | """ 112 | 113 | v = self._dp(edge, lengths)[0] 114 | if _raw: 115 | return v 116 | return self.semiring.unconvert(v) 117 | 118 | def marginals(self, edge, lengths=None, _autograd=True, _raw=False, _combine=False): 119 | """ 120 | Compute the marginals of a structured model. 
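        In the default autograd path, the marginals are obtained by differentiating the
        (semiring) partition value computed by the same dynamic program as :meth:`sum`
        with respect to the edge potentials, rather than by running a separate backward
        recursion.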
121 | 122 | Parameters: 123 | params : generic params (see class) 124 | lengths: None or b long tensor mask 125 | Returns: 126 | marginals: b x (N-1) x C x C table 127 | 128 | """ 129 | if (_autograd or self.semiring is not LogSemiring or not hasattr(self, '_dp_backward')): 130 | v, edges, _ = self._dp(edge, lengths=lengths, force_grad=True, cache=not _raw) 131 | if _raw: 132 | all_m = [] 133 | for k in range(v.shape[0]): 134 | obj = v[k].sum(dim=0) 135 | 136 | marg = torch.autograd.grad( 137 | obj, 138 | edges, 139 | create_graph=True, 140 | only_inputs=True, 141 | allow_unused=False, 142 | ) 143 | all_m.append(self.semiring.unconvert(self._arrange_marginals(marg))) 144 | return torch.stack(all_m, dim=0) 145 | elif _combine: 146 | obj = v.sum(dim=0).sum(dim=0) 147 | marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False) 148 | a_m = self._arrange_marginals(marg) 149 | return a_m 150 | else: 151 | obj = self.semiring.unconvert(v).sum(dim=0) 152 | marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False) 153 | a_m = self._arrange_marginals(marg) 154 | return self.semiring.unconvert(a_m) 155 | else: 156 | v, _, alpha = self._dp(edge, lengths=lengths, force_grad=True) 157 | return self._dp_backward(edge, lengths, alpha) 158 | 159 | @staticmethod 160 | def to_parts(spans, extra, lengths=None): 161 | return spans 162 | 163 | @staticmethod 164 | def from_parts(spans): 165 | return spans, None 166 | 167 | def _arrange_marginals(self, marg): 168 | return marg[0] 169 | 170 | def _dp(self, scores, lengths=None, force_grad=False, cache=True): 171 | raise NotImplementedError 172 | -------------------------------------------------------------------------------- /src/model/torch_struct/semirings/__init__.py: -------------------------------------------------------------------------------- 1 | from .checkpoint import CheckpointSemiring, CheckpointShardSemiring 2 | from .fast_semirings import FastLogSemiring, FastMaxSemiring, FastSampleSemiring 3 | from .sample import GumbelCRFSemiring, MultiSampledSemiring, SampledSemiring 4 | from .semirings import (CrossEntropySemiring, EntropySemiring, KLDivergenceSemiring, KMaxSemiring, LogSemiring, 5 | MaxSemiring, RiskSemiring, Semiring, StdSemiring, TempMax) 6 | from .sparse_max import SparseMaxSemiring 7 | 8 | # For flake8 compatibility. 
9 | __all__ = [ 10 | Semiring, 11 | FastLogSemiring, 12 | FastMaxSemiring, 13 | FastSampleSemiring, 14 | LogSemiring, 15 | StdSemiring, 16 | SampledSemiring, 17 | MaxSemiring, 18 | SparseMaxSemiring, 19 | KMaxSemiring, 20 | EntropySemiring, 21 | CrossEntropySemiring, 22 | KLDivergenceSemiring, 23 | MultiSampledSemiring, 24 | CheckpointSemiring, 25 | CheckpointShardSemiring, 26 | TempMax, 27 | ] 28 | -------------------------------------------------------------------------------- /src/model/torch_struct/semirings/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | try: 4 | import genbmm 5 | from genbmm import BandedMatrix 6 | except ImportError: 7 | pass 8 | 9 | 10 | def broadcast_size(a, b): 11 | return torch.tensor([max(i, j) for i, j in zip(a.shape, b.shape)]).prod() 12 | 13 | 14 | def matmul_size(a, b): 15 | size = [max(i, j) for i, j in zip(a.shape[:-2], b.shape[:-2])] 16 | size.append(a.shape[-2]) 17 | size.append(b.shape[-1]) 18 | return size 19 | 20 | 21 | def CheckpointSemiring(cls, min_size=0): 22 | class _Check(torch.autograd.Function): 23 | @staticmethod 24 | def forward(ctx, a, b): 25 | ctx.save_for_backward(a, b) 26 | return cls.matmul(a, b) 27 | 28 | @staticmethod 29 | def backward(ctx, grad_output): 30 | a, b = ctx.saved_tensors 31 | with torch.enable_grad(): 32 | q = cls.matmul(a, b) 33 | return torch.autograd.grad(q, (a, b), grad_output) 34 | 35 | class _CheckBand(torch.autograd.Function): 36 | @staticmethod 37 | def forward(ctx, a, a_lu, a_ld, b, b_lu, b_ld): 38 | ctx.save_for_backward(a, b, torch.LongTensor([a_lu, a_ld, b_lu, b_ld])) 39 | a = BandedMatrix(a, a_lu, a_ld) 40 | b = BandedMatrix(b, b_lu, b_ld) 41 | return cls.matmul(a, b).data 42 | 43 | @staticmethod 44 | def backward(ctx, grad_output): 45 | a, b, bands = ctx.saved_tensors 46 | a_lu, a_ld, b_lu, b_ld = bands.tolist() 47 | with torch.enable_grad(): 48 | q = cls.matmul(BandedMatrix(a, a_lu, a_ld), BandedMatrix(b, b_lu, b_ld)) 49 | grad_a, grad_b = torch.autograd.grad(q.data, (a, b), grad_output) 50 | return grad_a, None, None, grad_b, None, None 51 | 52 | class _CheckpointSemiring(cls): 53 | @staticmethod 54 | def matmul(a, b): 55 | if isinstance(a, genbmm.BandedMatrix): 56 | lu = a.lu + b.lu 57 | ld = a.ld + b.ld 58 | c = _CheckBand.apply(a.data, a.lu, a.ld, b.data, b.lu, b.ld) 59 | return BandedMatrix(c, lu, ld, cls.zero) 60 | 61 | if broadcast_size(a, b) > min_size: 62 | return _Check.apply(a, b) 63 | else: 64 | return cls.matmul(a, b) 65 | 66 | return _CheckpointSemiring 67 | 68 | 69 | def CheckpointShardSemiring(cls, max_size, min_size=0): 70 | class _Check(torch.autograd.Function): 71 | @staticmethod 72 | def forward(ctx, a, b): 73 | ctx.save_for_backward(a, b) 74 | size = matmul_size(a, b) 75 | return accumulate_( 76 | a, 77 | b, 78 | size, 79 | lambda a, b: cls.matmul(a, b), 80 | preserve=len(size), 81 | step=max_size // (b.shape[-2] * a.shape[-1]) + 2, 82 | ) 83 | 84 | @staticmethod 85 | def backward(ctx, grad_output): 86 | a, b = ctx.saved_tensors 87 | grad_a, grad_b = unaccumulate_( 88 | a, 89 | b, 90 | grad_output, 91 | len(grad_output.shape), 92 | lambda a, b: cls.matmul(a, b), 93 | step=max_size // (b.shape[-2] * a.shape[-1]) + 2, 94 | ) 95 | return grad_a, grad_b 96 | 97 | class _CheckpointSemiring(cls): 98 | @staticmethod 99 | def matmul(a, b): 100 | size = torch.tensor([max(i, j) for i, j in zip(a.shape, b.shape)]).prod() 101 | if size < min_size: 102 | return cls.matmul(a, b) 103 | else: 104 | return _Check.apply(a, b) 105 | 
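    # As with CheckpointSemiring above, the wrapped matmul does not keep its product for
    # backward; the (sharded) computation is re-run under torch.enable_grad() in backward
    # instead, trading extra compute for a lower peak-memory footprint on large charts.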
106 | return _CheckpointSemiring 107 | 108 | 109 | def ones(x): 110 | one = [] 111 | for i, v in enumerate(x.shape[:-1]): 112 | if v == 1: 113 | one.append(i) 114 | return one 115 | 116 | 117 | def mind(one, inds): 118 | inds = list(inds) 119 | for v in one: 120 | inds[v] = inds[v].clone().fill_(0) 121 | return inds 122 | 123 | 124 | def accumulate_(a, b, size, fn, preserve, step=10000): 125 | slices = [] 126 | total = 1 127 | for s in size[:preserve]: 128 | slices.append(slice(s)) 129 | total *= s 130 | if step > total: 131 | return fn(a, b) 132 | 133 | ret = torch.zeros(*size, dtype=a.dtype, device=a.device) 134 | 135 | a = a.expand(*size[:-2], a.shape[-2], a.shape[-1]) 136 | b = b.expand(*size[:-2], b.shape[-2], b.shape[-1]) 137 | 138 | a2 = a.contiguous().view(-1, a.shape[-2], a.shape[-1]) 139 | b2 = b.contiguous().view(-1, b.shape[-2], b.shape[-1]) 140 | ret = ret.view(-1, a.shape[-2], b.shape[-1]) 141 | for p in range(0, ret.shape[0], step): 142 | ret[p:p + step, :] = fn(a2[p:p + step], b2[p:p + step]) 143 | ret = ret.view(*size) 144 | return ret 145 | 146 | 147 | def unaccumulate_(a, b, grad_output, preserve, fn, step=10000): 148 | slices = [] 149 | total = 1 150 | size = grad_output.shape[:preserve] 151 | for s in grad_output.shape[:preserve]: 152 | slices.append(slice(s)) 153 | total *= s 154 | 155 | if step > total: 156 | with torch.enable_grad(): 157 | a_in = a.clone().requires_grad_(True) 158 | b_in = b.clone().requires_grad_(True) 159 | q = fn(a, b) 160 | ag, bg = torch.autograd.grad(q, (a, b), grad_output) 161 | return ag, bg 162 | 163 | a2 = a.expand(*size[:-2], a.shape[-2], a.shape[-1]) 164 | b2 = b.expand(*size[:-2], b.shape[-2], b.shape[-1]) 165 | a2 = a2.contiguous().view(-1, a.shape[-2], a.shape[-1]) 166 | b2 = b2.contiguous().view(-1, b.shape[-2], b.shape[-1]) 167 | 168 | a_grad = a2.clone().fill_(0) 169 | b_grad = b2.clone().fill_(0) 170 | 171 | grad_output = grad_output.view(-1, a.shape[-2], b.shape[-1]) 172 | for p in range(0, grad_output.shape[0], step): 173 | with torch.enable_grad(): 174 | a_in = a2[p:p + step].clone().requires_grad_(True) 175 | b_in = b2[p:p + step].clone().requires_grad_(True) 176 | q = fn(a_in, b_in) 177 | ag, bg = torch.autograd.grad(q, (a_in, b_in), grad_output[p:p + step]) 178 | a_grad[p:p + step] += ag 179 | b_grad[p:p + step] += bg 180 | 181 | a_grad = a_grad.view(*size[:-2], a.shape[-2], a.shape[-1]) 182 | b_grad = b_grad.view(*size[:-2], b.shape[-2], b.shape[-1]) 183 | a_ones = ones(a) 184 | b_ones = ones(b) 185 | f1, f2 = a_grad.sum(a_ones, keepdim=True), b_grad.sum(b_ones, keepdim=True) 186 | return f1, f2 187 | 188 | 189 | # def unaccumulate_(a, b, grad_output, fn, step=10000): 190 | # slices = [] 191 | # a_grad = a.clone().fill_(0) 192 | # b_grad = b.clone().fill_(0) 193 | 194 | # total = 1 195 | # for s in grad_output.shape: 196 | # slices.append(slice(s)) 197 | # total *= s 198 | # a_one, b_one = ones(a), ones(b) 199 | 200 | # indices = torch.tensor(np.mgrid[slices]).view(len(grad_output.shape), -1) 201 | 202 | # for p in range(0, total, step): 203 | # ind = indices[:, p : p + step].unbind() 204 | # a_ind = mind(a_one, ind) 205 | # b_ind = mind(b_one, ind) 206 | 207 | # q = fn(a[tuple(a_ind)], b[tuple(b_ind)], grad_output[tuple(ind)]) 208 | # a_grad.index_put_(tuple(a_ind), q, accumulate=True) 209 | # b_grad.index_put_(tuple(b_ind), q, accumulate=True) 210 | # return a_grad, b_grad 211 | -------------------------------------------------------------------------------- /src/model/torch_struct/semirings/fast_semirings.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions 3 | 4 | from .sample import _SampledLogSumExp 5 | from .semirings import _BaseLog 6 | 7 | try: 8 | import genbmm 9 | except ImportError: 10 | pass 11 | 12 | 13 | def matmul_size(a, b): 14 | size = [max(i, j) for i, j in zip(a.shape[:-2], b.shape[:-2])] 15 | size.append(a.shape[-2]) 16 | size.append(b.shape[-1]) 17 | return size 18 | 19 | 20 | def broadcast(a, b): 21 | size = matmul_size(a, b) 22 | a = a.expand(*size[:-2], a.shape[-2], a.shape[-1]) 23 | b = b.expand(*size[:-2], b.shape[-2], b.shape[-1]) 24 | a2 = a.contiguous().view(-1, a.shape[-2], a.shape[-1]) 25 | b2 = b.contiguous().view(-1, b.shape[-2], b.shape[-1]) 26 | return a2, b2, size 27 | 28 | 29 | class FastLogSemiring(_BaseLog): 30 | """ 31 | Implements the log-space semiring (logsumexp, +, -inf, 0). 32 | 33 | Gradients give marginals. 34 | """ 35 | @staticmethod 36 | def sum(xs, dim=-1): 37 | return torch.logsumexp(xs, dim=dim) 38 | 39 | @staticmethod 40 | def matmul(a, b, dims=1): 41 | if isinstance(a, genbmm.BandedMatrix): 42 | return b.multiply_log(a.transpose()) 43 | else: 44 | a2, b2, size = broadcast(a, b) 45 | return genbmm.logbmm(a2, b2).view(size) 46 | 47 | 48 | class FastMaxSemiring(_BaseLog): 49 | @staticmethod 50 | def sum(xs, dim=-1): 51 | return torch.max(xs, dim=dim)[0] 52 | 53 | @staticmethod 54 | def matmul(a, b, dims=1): 55 | a2, b2, size = broadcast(a, b) 56 | return genbmm.maxbmm(a2, b2).view(size) 57 | 58 | 59 | class FastSampleSemiring(_BaseLog): 60 | @staticmethod 61 | def sum(xs, dim=-1): 62 | return _SampledLogSumExp.apply(xs, dim) 63 | 64 | @staticmethod 65 | def matmul(a, b, dims=1): 66 | a2, b2, size = broadcast(a, b) 67 | return genbmm.samplebmm(a2, b2).view(size) 68 | -------------------------------------------------------------------------------- /src/model/torch_struct/semirings/keops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions 3 | 4 | from .semirings import _BaseLog 5 | 6 | try: 7 | from pykeops.torch import LazyTensor 8 | except ImportError: 9 | pass 10 | 11 | 12 | class LogSemiringKO(_BaseLog): 13 | """ 14 | Implements the log-space semiring (logsumexp, +, -inf, 0). 15 | 16 | Gradients give marginals. 17 | """ 18 | @staticmethod 19 | def sum(a, dim=-1): 20 | a_lazy = LazyTensor(a.unsqueeze(-1).unsqueeze(-1).contiguous()) 21 | c = a_lazy.sum(-1).logsumexp(a.dim() - 1).squeeze(-1).squeeze(-1) 22 | return c 23 | 24 | @classmethod 25 | def dot(cls, a, b): 26 | """ 27 | Dot product along last dim. (Faster than calling sum and times.) 
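        Note that this backend relies on the optional ``pykeops`` dependency; the guarded
        import at the top of this module only silences the ImportError, so these semirings
        raise a NameError if ``pykeops`` is not installed.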
28 | """ 29 | a_lazy = LazyTensor(a.unsqueeze(-1).unsqueeze(-1).contiguous()) 30 | b_lazy = LazyTensor(b.unsqueeze(-1).unsqueeze(-1).contiguous()) 31 | c = (a_lazy + b_lazy).sum(-1).logsumexp(a.dim() - 1).squeeze(-1).squeeze(-1) 32 | return c 33 | 34 | 35 | class _Max(torch.autograd.Function): 36 | @staticmethod 37 | def forward(ctx, a, b): 38 | one_hot = b.shape[-1] 39 | a_lazy = LazyTensor(a.unsqueeze(-1).unsqueeze(-1).contiguous()) 40 | b_lazy = LazyTensor(b.unsqueeze(-1).unsqueeze(-1).contiguous()) 41 | c = (a_lazy + b_lazy).sum(-1).max(a.dim() - 1).squeeze(-1).squeeze(-1) 42 | ac = (a_lazy + b_lazy).sum(-1).argmax(a.dim() - 1).squeeze(-1).squeeze(-1) 43 | ctx.save_for_backward(ac, torch.tensor(one_hot)) 44 | return c 45 | 46 | @staticmethod 47 | def backward(ctx, grad_output): 48 | ac, size = ctx.saved_tensors 49 | back = torch.nn.functional.one_hot(ac, size).type_as(grad_output) 50 | ret = grad_output.unsqueeze(-1).mul(back) 51 | return ret, ret 52 | 53 | 54 | class MaxSemiringKO(_BaseLog): 55 | @classmethod 56 | def sum(cls, xs, dim=-1): 57 | assert dim == -1 58 | return cls.dot(xs, xs.clone().fill_(0)) 59 | 60 | @classmethod 61 | def dot(cls, a, b): 62 | """ 63 | Dot product along last dim. (Faster than calling sum and times.) 64 | """ 65 | return _Max.apply(a, b) 66 | -------------------------------------------------------------------------------- /src/model/torch_struct/semirings/sample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions 3 | 4 | from .semirings import _BaseLog 5 | 6 | 7 | class _SampledLogSumExp(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, input, dim): 10 | ctx.save_for_backward(input, torch.tensor(dim)) 11 | return torch.logsumexp(input, dim=dim) 12 | 13 | @staticmethod 14 | def backward(ctx, grad_output): 15 | logits, dim = ctx.saved_tensors 16 | grad_input = None 17 | if ctx.needs_input_grad[0]: 18 | 19 | def sample(ls): 20 | pre_shape = ls.shape 21 | draws = torch.multinomial(ls.softmax(-1).view(-1, pre_shape[-1]), 1, True) 22 | draws.squeeze(1) 23 | return (torch.nn.functional.one_hot(draws, pre_shape[-1]).view(*pre_shape).type_as(ls)) 24 | 25 | if dim == -1: 26 | s = sample(logits) 27 | else: 28 | dim = dim if dim >= 0 else logits.dim() + dim 29 | perm = [i for i in range(logits.dim()) if i != dim] + [dim] 30 | rev_perm = [a for a, b in sorted(enumerate(perm), key=lambda a: a[1])] 31 | s = sample(logits.permute(perm)).permute(rev_perm) 32 | 33 | grad_input = grad_output.unsqueeze(dim).mul(s) 34 | return grad_input, None 35 | 36 | 37 | class SampledSemiring(_BaseLog): 38 | """ 39 | Implements a sampling semiring (logsumexp, +, -inf, 0). 40 | 41 | "Gradients" give sample. 42 | 43 | This is an exact forward-filtering, backward-sampling approach. 
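    Concretely, differentiating the partition computed with this semiring (as the
    ``marginals`` machinery does) yields a single sampled structure rather than expected
    marginal counts.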
44 | """ 45 | @staticmethod 46 | def sum(xs, dim=-1): 47 | return _SampledLogSumExp.apply(xs, dim) 48 | 49 | 50 | def GumbelCRFSemiring(temp): 51 | class ST(torch.autograd.Function): 52 | @staticmethod 53 | def forward(ctx, logits, dim): 54 | out = torch.nn.functional.one_hot(logits.max(-1)[1], dim) 55 | out = out.type_as(logits) 56 | ctx.save_for_backward(logits, out) 57 | return out 58 | 59 | @staticmethod 60 | def backward(ctx, grad_output): 61 | logits, out = ctx.saved_tensors 62 | with torch.enable_grad(): 63 | ret = torch.autograd.grad(logits.softmax(-1), logits, out * grad_output)[0] 64 | return ret, None 65 | 66 | class _GumbelCRFLogSumExp(torch.autograd.Function): 67 | @staticmethod 68 | def forward(ctx, input, dim): 69 | ctx.save_for_backward(input, torch.tensor(dim)) 70 | return torch.logsumexp(input, dim=dim) 71 | 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | logits, dim = ctx.saved_tensors 75 | grad_input = None 76 | if ctx.needs_input_grad[0]: 77 | 78 | def sample(ls): 79 | update = (ls + torch.distributions.Gumbel(0, 1).sample((ls.shape[-1], ))) / temp 80 | out = ST.apply(update, ls.shape[-1]) 81 | return out 82 | 83 | if dim == -1: 84 | s = sample(logits) 85 | else: 86 | dim = dim if dim >= 0 else logits.dim() + dim 87 | perm = [i for i in range(logits.dim()) if i != dim] + [dim] 88 | rev_perm = [a for a, b in sorted(enumerate(perm), key=lambda a: a[1])] 89 | s = sample(logits.permute(perm)).permute(rev_perm) 90 | 91 | grad_input = grad_output.unsqueeze(dim).mul(s) 92 | return grad_input, None 93 | 94 | class _GumbelCRFSemiring(_BaseLog): 95 | @staticmethod 96 | def sum(xs, dim=-1): 97 | return _GumbelCRFLogSumExp.apply(xs, dim) 98 | 99 | return _GumbelCRFSemiring 100 | 101 | 102 | bits = torch.tensor([pow(2, i) for i in range(1, 18)]) 103 | 104 | 105 | class _MultiSampledLogSumExp(torch.autograd.Function): 106 | @staticmethod 107 | def forward(ctx, input, dim): 108 | part = torch.logsumexp(input, dim=dim) 109 | ctx.save_for_backward(input, part, torch.tensor(dim)) 110 | return part 111 | 112 | @staticmethod 113 | def backward(ctx, grad_output): 114 | 115 | logits, part, dim = ctx.saved_tensors 116 | grad_input = None 117 | if ctx.needs_input_grad[0]: 118 | 119 | def sample(ls): 120 | pre_shape = ls.shape 121 | draws = torch.multinomial(ls.softmax(-1).view(-1, pre_shape[-1]), 16, True) 122 | draws = draws.transpose(0, 1) 123 | return (torch.nn.functional.one_hot(draws, pre_shape[-1]).view(16, *pre_shape).type_as(ls)) 124 | 125 | if dim == -1: 126 | s = sample(logits) 127 | else: 128 | dim = dim if dim >= 0 else logits.dim() + dim 129 | perm = [i for i in range(logits.dim()) if i != dim] + [dim] 130 | rev_perm = [0] + [a + 1 for a, b in sorted(enumerate(perm), key=lambda a: a[1])] 131 | s = sample(logits.permute(perm)).permute(rev_perm) 132 | 133 | dim = dim if dim >= 0 else logits.dim() + dim 134 | final = (grad_output % 2).unsqueeze(0) 135 | mbits = bits[:].type_as(grad_output) 136 | on = grad_output.unsqueeze(0) % mbits.view(17, *[1] * grad_output.dim()) 137 | on = on[1:] - on[:-1] 138 | old_bits = (on + final == 0).unsqueeze(dim + 1) 139 | 140 | grad_input = (mbits[:-1].view(16, *[1] * (s.dim() - 1)).mul(s.masked_fill_(old_bits, 0))) 141 | 142 | return torch.sum(grad_input, dim=0), None 143 | 144 | 145 | class MultiSampledSemiring(_BaseLog): 146 | """ 147 | Implements a multi-sampling semiring (logsumexp, +, -inf, 0). 148 | 149 | "Gradients" give up to 16 samples with replacement. 
150 | """ 151 | @staticmethod 152 | def sum(xs, dim=-1): 153 | return _MultiSampledLogSumExp.apply(xs, dim) 154 | 155 | @staticmethod 156 | def to_discrete(xs, j): 157 | i = j 158 | final = xs % 2 159 | mbits = bits.type_as(xs) 160 | return (((xs % mbits[i + 1]) - (xs % mbits[i]) + final) != 0).type_as(xs) 161 | -------------------------------------------------------------------------------- /src/model/torch_struct/semirings/sparse_max.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .semirings import _BaseLog 4 | 5 | 6 | class SparseMaxSemiring(_BaseLog): 7 | """ 8 | 9 | Implements differentiable dynamic programming with a sparsemax semiring (sparsemax, +, -inf, 0). 10 | 11 | Sparse-max gradients give a more sparse set of marginal like terms. 12 | 13 | * From softmax to sparsemax- A sparse model of attention and multi-label classification :cite:`martins2016softmax` 14 | * Differentiable dynamic programming for structured prediction and attention :cite:`mensch2018differentiable` 15 | """ 16 | @staticmethod 17 | def sum(xs, dim=-1): 18 | return _SimplexProject.apply(xs, dim) 19 | 20 | 21 | class _SimplexProject(torch.autograd.Function): 22 | @staticmethod 23 | def forward(ctx, input, dim, z=1): 24 | w_star = project_simplex(input, dim) 25 | ctx.save_for_backward(input, w_star.clone(), torch.tensor(dim)) 26 | x = input.mul(w_star).sum(dim) - w_star.norm(p=2, dim=dim) 27 | return x 28 | 29 | @staticmethod 30 | def backward(ctx, grad_output): 31 | input, w_star, dim = ctx.saved_tensors 32 | w_star.requires_grad_(True) 33 | 34 | grad_input = None 35 | if ctx.needs_input_grad[0]: 36 | wstar = _SparseMaxGrad.apply(w_star, dim) 37 | grad_input = grad_output.unsqueeze(dim).mul(wstar) 38 | return grad_input, None, None 39 | 40 | 41 | class _SparseMaxGrad(torch.autograd.Function): 42 | @staticmethod 43 | def forward(ctx, w_star, dim): 44 | ctx.save_for_backward(w_star, dim) 45 | return w_star 46 | 47 | @staticmethod 48 | def backward(ctx, grad_output): 49 | w_star, dim = ctx.saved_tensors 50 | return sparsemax_grad(grad_output, w_star, dim.item()), None 51 | 52 | 53 | def project_simplex(v, dim, z=1): 54 | v_sorted, _ = torch.sort(v, dim=dim, descending=True) 55 | cssv = torch.cumsum(v_sorted, dim=dim) - z 56 | ind = torch.arange(1, 1 + v.shape[dim]).to(dtype=v.dtype) 57 | cond = v_sorted - cssv / ind >= 0 58 | k = cond.sum(dim=dim, keepdim=True) 59 | tau = cssv.gather(dim, k - 1) / k.to(dtype=v.dtype) 60 | w = torch.clamp(v - tau, min=0) 61 | return w 62 | 63 | 64 | def sparsemax_grad(dout, w_star, dim): 65 | out = dout.clone() 66 | supp = w_star > 0 67 | out[w_star <= 0] = 0 68 | nnz = supp.to(dtype=dout.dtype).sum(dim=dim, keepdim=True) 69 | out = out - (out.sum(dim=dim, keepdim=True) / nnz) 70 | out[w_star <= 0] = 0 71 | return out 72 | -------------------------------------------------------------------------------- /src/model/vis_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from src.model.vis_encoder.base import VisEncoderBase 2 | from src.model.vis_encoder.box_rel import VisBoxRelSimpleEncoder -------------------------------------------------------------------------------- /src/model/vis_encoder/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | import torch.nn as nn 6 | 7 | if TYPE_CHECKING: 8 | from src.model import ModelBase 9 | 10 | 11 | 
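# Note: bounded_model is only a class-level annotation; the owning ModelBase presumably
# attaches it after construction without registering it as a submodule (compare
# EncoderBase, which stores bounded_embedding via self.__dict__ for the same effect).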
class VisEncoderBase(nn.Module): 12 | bounded_model: ModelBase 13 | 14 | def __init__(self): 15 | super(VisEncoderBase, self).__init__() 16 | 17 | def forward(self, x, ctx): 18 | raise NotImplementedError 19 | 20 | def get_dim(self, field): 21 | raise NotImplementedError(f'Unrecognized {field=}') 22 | 23 | -------------------------------------------------------------------------------- /src/model/vis_encoder/box_rel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as tnn 3 | 4 | from src.model.vis_encoder import VisEncoderBase 5 | from src.model.nn import MLP, BiaffineScorer 6 | 7 | 8 | class VisBoxRelSimpleEncoder(VisEncoderBase): 9 | def __init__(self, n_in, n_hidden, dropout, activate, use_attr, use_img, img_feat): 10 | super().__init__() 11 | 12 | self.use_img = use_img 13 | if use_img: 14 | self.img_fc = MLP(n_in, n_hidden, dropout, activate) 15 | 16 | self.img_feat = img_feat 17 | if img_feat: 18 | n_in *= 2 19 | self.box_fc = MLP(n_in, n_hidden, dropout, activate) 20 | self.rel_fc = MLP(n_in, n_hidden, dropout, activate) 21 | # self.rel_fc = BiaffineScorer(n_in * 2, n_hidden, n_hidden, dropout, activate, 1) 22 | 23 | self.use_attr = use_attr 24 | if use_attr: 25 | self.attr_fc = MLP(n_in, n_hidden, dropout, activate) 26 | self.n_hidden = n_hidden 27 | self.dropout = dropout 28 | 29 | def forward(self, x, ctx): 30 | 31 | if self.img_feat: 32 | feat: torch.Tensor = x["vis_box_feat"] 33 | B, N, H = feat.shape 34 | box = feat 35 | inputs = torch.cat( 36 | [box, feat.mean(1, keepdim=True).expand(-1, feat.shape[1], -1)], dim=-1 37 | ) 38 | else: 39 | inputs = x["vis_box_feat"] 40 | B, N, H = inputs.shape 41 | inputs = inputs.view(B, N, H) 42 | _rel_inp = (inputs.unsqueeze(1) + inputs.unsqueeze(2)) / 2 43 | x_rel = self.rel_fc(_rel_inp) 44 | # x_rel = self.rel_fc(inputs, inputs) 45 | rel = x_rel.view(len(x_rel), -1, self.n_hidden) 46 | 47 | out = {"box": self.box_fc(inputs), "rel": rel} 48 | if self.use_attr: 49 | out["attr"] = self.attr_fc(inputs) 50 | if self.use_img: 51 | out["img"] = self.img_fc(x["vis_box_feat"].mean(1, keepdim=True)) 52 | return out 53 | 54 | def get_dim(self, field): 55 | return self.n_hidden 56 | 57 | -------------------------------------------------------------------------------- /src/utility/config.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from dataclasses import dataclass 3 | 4 | from omegaconf import MISSING, DictConfig 5 | 6 | from src.utility.logger import get_logger_func 7 | 8 | _warn, _info, _debug = get_logger_func('config') 9 | 10 | 11 | @dataclass 12 | class Config: 13 | @classmethod 14 | def build(cls, env, ignore_unknown=False, allow_missing=None): 15 | if isinstance(env, (dict, DictConfig)): 16 | if 'cfg' in env and isinstance(env['cfg'], cls): 17 | breakpoint() 18 | return env['cfg'] 19 | 20 | matched = {k: v for k, v in env.items() if k in inspect.signature(cls).parameters} 21 | unmatched = {k: env[k] 22 | for k in env.keys() - matched.keys() 23 | if not k.startswith('n_')} # n_* will be set automatically 24 | if unmatched and not ignore_unknown: 25 | raise ValueError(f'Unrecognized cfg: {unmatched}') 26 | # noinspection PyArgumentList 27 | cfg = cls(**{k: v for k, v in env.items() if k in inspect.signature(cls).parameters}) 28 | 29 | allow_missing = allow_missing or set() 30 | for key, value in cfg.__dict__.items(): 31 | if not key.startswith('_') and key not in allow_missing: 32 | assert value is not MISSING, 
f'{key} is MISSING.' 33 | 34 | if ignore_unknown: 35 | return cfg, unmatched 36 | return cfg 37 | elif isinstance(env, cls): 38 | return env 39 | raise TypeError 40 | 41 | def __setitem__(self, key, value): 42 | if not hasattr(self, key): 43 | _warn(f"Adding new key: {key}") 44 | setattr(self, key, value) 45 | 46 | def __getitem__(self, item): 47 | return getattr(self, item) -------------------------------------------------------------------------------- /src/utility/defaultlist.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | class defaultlist(list): 5 | """ 6 | __version__ = "1.0.0" 7 | __author__ = 'c0fec0de' 8 | __author_email__ = 'c0fec0de@gmail.com' 9 | __description__ = " collections.defaultdict equivalent implementation of list." 10 | __url__ = "https://github.com/c0fec0de/defaultlist" 11 | """ 12 | 13 | # noinspection PyMissingConstructor 14 | def __init__(self, factory=None): 15 | """ 16 | List extending automatically to the maximum requested length. 17 | Keyword Args: 18 | factory: Function called for every missing index. 19 | """ 20 | self.__factory = factory or defaultlist.__nonefactory 21 | 22 | @staticmethod 23 | def __nonefactory(): 24 | return None 25 | 26 | def __fill(self, index): 27 | missing = index - len(self) + 1 28 | if missing > 0: 29 | # noinspection PyMethodFirstArgAssignment 30 | self += [self.__factory() for _ in range(missing)] 31 | 32 | def __setitem__(self, index, value): 33 | self.__fill(index) 34 | list.__setitem__(self, index, value) 35 | 36 | def __getitem__(self, index): 37 | if isinstance(index, slice): 38 | return self.__getslice(index.start, index.stop, index.step) 39 | else: 40 | self.__fill(index) 41 | return list.__getitem__(self, index) 42 | 43 | def __getslice__(self, start, stop, step=None): # pragma: no cover 44 | # python 2.x legacy 45 | if stop == sys.maxint: 46 | stop = None 47 | return self.__getslice(start, stop, step) 48 | 49 | def __normidx(self, idx, default): 50 | if idx is None: 51 | idx = default 52 | elif idx < 0: 53 | idx += len(self) 54 | return idx 55 | 56 | def __getslice(self, start, stop, step): 57 | end = max((start or 0, stop or 0, 0)) 58 | if end: 59 | self.__fill(end) 60 | start = self.__normidx(start, 0) 61 | stop = self.__normidx(stop, len(self)) 62 | step = step or 1 63 | r = defaultlist(factory=self.__factory) 64 | for idx in range(start, stop, step): 65 | r.append(list.__getitem__(self, idx)) 66 | return r 67 | 68 | def __add__(self, other): 69 | if isinstance(other, list): 70 | r = self.copy() 71 | r += other 72 | return r 73 | else: 74 | return list.__add__(self, other) 75 | 76 | def copy(self): 77 | """Return a shallow copy of the list. 
Equivalent to a[:].""" 78 | r = defaultlist(factory=self.__factory) 79 | r += self 80 | return r 81 | -------------------------------------------------------------------------------- /src/utility/fn.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import logging 3 | import os 4 | from functools import wraps 5 | from typing import Any, Dict, Callable, Optional, Iterator 6 | 7 | from hydra.utils import instantiate 8 | from omegaconf import ListConfig, DictConfig 9 | from pytorch_lightning import Trainer 10 | from torch import Tensor 11 | 12 | 13 | def not_distributed_guard(): 14 | import torch.distributed as dist 15 | assert not dist.is_initialized() 16 | 17 | 18 | def endless_iter(i: Iterator, shuffle: Optional[Callable] = None, inplace_shuffle: Optional[Callable] = None): 19 | while True: 20 | if shuffle is not None: 21 | i = shuffle(i) 22 | if inplace_shuffle is not None: 23 | inplace_shuffle(i) 24 | for x in i: 25 | yield x 26 | 27 | 28 | def dict_apply(d: Dict[Any, Any], func=None, key_func=None): 29 | assert func or key_func 30 | if func is None: 31 | return {key_func(key): value for key, value in d.items()} 32 | elif key_func is None: 33 | return {key: func(value) for key, value in d.items()} 34 | return {key_func(key): func(value) for key, value in d.items()} 35 | 36 | 37 | def hydra_instantiate_func_helper(func): 38 | """convert func() to func()()""" 39 | 40 | @wraps(func) 41 | def wrapper(*args, **kwargs): 42 | def mid(): 43 | return func(*args, **kwargs) 44 | 45 | return mid 46 | 47 | return wrapper 48 | 49 | 50 | def reduce_loss(mode, loss, num_token, num_sentence) -> Tensor: 51 | if not isinstance(loss, list): 52 | loss, num_token, num_sentence = [loss], [num_token], [num_sentence] 53 | assert len(loss) >= 1, 'Nothing to reduce. You should handle this error outside this function.' 54 | if mode == 'token': 55 | # average over tokens in a batch 56 | return sum(loss) / (sum(num_token) + 1e-12) 57 | elif mode == 'sentence': 58 | # first average over tokens in a sentence. 59 | # then average sentences over a batch 60 | # return sum((l / s).sum() for l, s in zip(loss, seq_len)) / (sum(len(s) for s in seq_len)) 61 | raise NotImplementedError('Deprecated') 62 | elif mode == 'batch': 63 | # average over sentences in a batch 64 | return sum(loss) / (sum(num_sentence) + 1e-12) 65 | elif mode == 'sum': 66 | return sum(loss) 67 | raise ValueError 68 | 69 | 70 | def split_list(raw, size): 71 | out = [] 72 | offset = 0 73 | for s in size: 74 | out.append(raw[offset: offset + s]) 75 | offset += s 76 | assert offset == len(raw) 77 | return out 78 | 79 | 80 | def instantiate_no_recursive(*args, **kwargs): 81 | return instantiate(*args, **kwargs, _recursive_=False) 82 | 83 | 84 | def get_coeff_iter(command, idx_getter=None, validator=None): 85 | # 1. not (list, tuple, ListConfig): constant alpha 86 | # 2. List[str]: str should be [value]@[epoch]. eg "[0@0, 0.5@100]". Linearly to value at epoch. 
87 | # the first term must be @0 (from the beginning) 88 | if not isinstance(command, (list, tuple, ListConfig)): 89 | # -123456789 is never reached, so it is endless 90 | assert command != -123456789 91 | return iter(lambda: command, -123456789) 92 | 93 | if idx_getter is None: 94 | _i = 0 95 | 96 | def auto_inc(): 97 | nonlocal _i 98 | i, _i = _i, _i + 1 99 | return i 100 | 101 | idx_getter = auto_inc 102 | 103 | def calculate_alpha(value_and_step): 104 | prev_v, prev_s = value_and_step[0].split('@') 105 | prev_v, prev_s = float(prev_v), int(prev_s) 106 | assert prev_s == 0, 'the first step must be 0' 107 | idx = idx_getter() 108 | for i in range(1, len(value_and_step)): 109 | next_v, next_s = value_and_step[i].split('@') 110 | next_v, next_s = float(next_v), int(next_s) 111 | rate = (next_v - prev_v) / (next_s - prev_s) 112 | while idx <= next_s: 113 | value = prev_v + rate * (idx - prev_s) 114 | if validator is not None: 115 | assert validator(value), f'Bad value in coeff_iter. Get {value}.' 116 | yield value 117 | idx = idx_getter() 118 | prev_v, prev_s = next_v, next_s 119 | while True: 120 | yield prev_v 121 | 122 | return iter(calculate_alpha(command)) 123 | 124 | 125 | def instantiate_trainer(callbacks=None, **kwargs): 126 | if callbacks is not None: 127 | NoneType = type(None) 128 | callbacks = list(filter(lambda x: not isinstance(x, (dict, DictConfig, NoneType)), callbacks.values())) 129 | return Trainer(callbacks=callbacks, **kwargs) 130 | 131 | 132 | def pad(tensors, padding_value=0, total_length=None, padding_side='right'): 133 | size = [len(tensors)] + [max(tensor.size(i) for tensor in tensors) for i in range(len(tensors[0].size()))] 134 | if total_length is not None: 135 | assert total_length >= size[1] 136 | size[1] = total_length 137 | out_tensor = tensors[0].data.new(*size).fill_(padding_value) 138 | for i, tensor in enumerate(tensors): 139 | out_tensor[i][[slice(-i, None) if padding_side == 'left' else slice(0, i) for i in tensor.size()]] = tensor 140 | return out_tensor 141 | 142 | 143 | def filter_list(data, mask): 144 | if isinstance(mask[0], list): 145 | out = [] 146 | for subdata, submask in zip(data, mask): 147 | out.append(filter_list(subdata, submask)) 148 | return out 149 | elif isinstance(mask[0], int): 150 | return [subdata for subdata, submask in zip(data, mask) if submask] 151 | raise ValueError(f'Bad mask value: {mask}') 152 | 153 | 154 | def draw_att(data: Tensor, path=None): 155 | assert data.ndim == 2 156 | import seaborn as sns 157 | import matplotlib.pyplot as plt 158 | data = data.detach().cpu().numpy() 159 | sns.heatmap(data=data, center=0, mask=data < -100) 160 | if path: 161 | plt.savefig(path) 162 | else: 163 | plt.show() 164 | 165 | 166 | def merge_outputs(a, b): 167 | assert a.keys() == b.keys() 168 | for key in a: 169 | adata, bdata = a[key], b[key] 170 | if len(adata) > len(bdata): 171 | bdata.extend([None] * (len(adata) - len(bdata))) 172 | else: 173 | adata.extend([None] * (len(bdata) - len(adata))) 174 | a[key] = [ai if ai is not None else bi for ai, bi in zip(a[key], b[key])] 175 | return a 176 | 177 | 178 | def symlink_force(target, link_name): 179 | try: 180 | os.symlink(target, link_name) 181 | except OSError as e: 182 | if e.errno == errno.EEXIST: 183 | os.remove(link_name) 184 | os.symlink(target, link_name) 185 | else: 186 | raise e 187 | 188 | 189 | def listloggers(): 190 | rootlogger = logging.getLogger() 191 | print(rootlogger) 192 | for h in rootlogger.handlers: 193 | print(' %s' % h) 194 | 195 | for nm, lgr in 
logging.Logger.manager.loggerDict.items(): 196 | print('+ [%-20s] %s ' % (nm, lgr)) 197 | if not isinstance(lgr, logging.PlaceHolder): 198 | for h in lgr.handlers: 199 | print(' %s' % h) -------------------------------------------------------------------------------- /src/utility/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | from colorama import Fore 5 | from pytorch_lightning.utilities import rank_zero_only 6 | from tqdm.auto import tqdm 7 | 8 | 9 | class TqdmLoggingHandler(logging.Handler): 10 | def __init__(self, level=logging.INFO): 11 | super().__init__(level) 12 | 13 | def emit(self, record): 14 | try: 15 | msg = self.format(record) 16 | tqdm.write(msg, file=sys.stdout) 17 | except (KeyboardInterrupt, SystemExit): 18 | raise 19 | except: 20 | self.handleError(record) 21 | 22 | 23 | class ColorFormatter(logging.Formatter): 24 | def format(self, record): 25 | 26 | # Save the original format configured by the user 27 | # when the logger formatter was instantiated 28 | format_orig = self._style._fmt 29 | 30 | # Replace the original format with one customized by logging level 31 | if record.levelno == logging.DEBUG: 32 | self._style._fmt = Fore.YELLOW + format_orig + Fore.RESET 33 | 34 | elif record.levelno >= logging.WARNING: 35 | self._style._fmt = Fore.RED + format_orig + Fore.RESET 36 | 37 | # Call the original formatter class to do the grunt work 38 | result = logging.Formatter.format(self, record) 39 | 40 | # Restore the original format configured by the user 41 | self._style._fmt = format_orig 42 | 43 | return result 44 | 45 | 46 | def get_logger_func(name): 47 | log = logging.getLogger(name) 48 | 49 | def _warn(*args, stacklevel: int = 2, **kwargs): 50 | kwargs["stacklevel"] = stacklevel 51 | log.warning(*args, **kwargs) 52 | 53 | def _info(*args, stacklevel: int = 2, **kwargs): 54 | kwargs["stacklevel"] = stacklevel 55 | log.info(*args, **kwargs) 56 | 57 | def _debug(*args, stacklevel: int = 2, **kwargs): 58 | kwargs["stacklevel"] = stacklevel 59 | log.debug(*args, **kwargs) 60 | 61 | return _warn, _info, _debug 62 | # return rank_zero_only(_warn), rank_zero_only(_info), rank_zero_only(_debug) 63 | -------------------------------------------------------------------------------- /src/utility/meta.py: -------------------------------------------------------------------------------- 1 | class Singleton(type): 2 | _instances = {} 3 | 4 | def __call__(cls, *args, **kwargs): 5 | if cls not in cls._instances: 6 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) 7 | return cls._instances[cls] 8 | -------------------------------------------------------------------------------- /src/utility/scheduler.py: -------------------------------------------------------------------------------- 1 | # noinspection PyUnresolvedReferences 2 | import logging 3 | # noinspection PyUnresolvedReferences 4 | import math 5 | 6 | # noinspection PyUnresolvedReferences 7 | import numpy as np 8 | from torch.optim import lr_scheduler 9 | # noinspection PyUnresolvedReferences 10 | from transformers import (get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, 11 | get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup, 12 | get_polynomial_decay_schedule_with_warmup) 13 | 14 | from src.utility.logger import get_logger_func 15 | 16 | _warn, _info, _debug = get_logger_func('scheduler') 17 | 18 | 19 | def get_exponential_lr_scheduler(optimizer, gamma, **kwargs): 20 | if 
isinstance(gamma, str): 21 | gamma = eval(gamma) 22 | _debug(f'gamma is converted to {gamma} {type(gamma)}') 23 | kwargs['gamma'] = gamma 24 | return lr_scheduler.ExponentialLR(optimizer, **kwargs) 25 | 26 | 27 | def get_reduce_lr_on_plateau_scheduler(optimizer, **kwargs): 28 | return lr_scheduler.ReduceLROnPlateau(optimizer, **kwargs) 29 | 30 | 31 | def get_lr_lambda_scheduler(optimizer, lr_lambda, **kwargs): 32 | if isinstance(lr_lambda, str): 33 | lr_lambda = eval(lr_lambda) 34 | _debug(f'lr_lambda is converted to {lr_lambda} {type(lr_lambda)}') 35 | kwargs['lr_lambda'] = lr_lambda 36 | return lr_scheduler.LambdaLR(optimizer, **kwargs) 37 | -------------------------------------------------------------------------------- /src/utility/spacy_helper.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List 2 | 3 | from spacy.tokens import Doc 4 | 5 | 6 | class PretokenizedTokenizer: 7 | """Custom tokenizer to be used in spaCy when the text is already pretokenized.""" 8 | 9 | def __init__(self, vocab): 10 | """Initialize tokenizer with a given vocab 11 | :param vocab: an existing vocabulary (see https://spacy.io/api/vocab) 12 | """ 13 | self.vocab = vocab 14 | 15 | def __call__(self, inp: Union[List[str], str]): 16 | """Call the tokenizer on input `inp`. 17 | :param inp: either a string to be split on whitespace, or a list of tokens 18 | :return: the created Doc object 19 | """ 20 | if isinstance(inp, str): 21 | words = inp.split() 22 | spaces = [True] * (len(words) - 1) + ([True] if inp[-1].isspace() else [False]) 23 | return Doc(self.vocab, words=words, spaces=spaces) 24 | elif isinstance(inp, list): 25 | return Doc(self.vocab, words=inp) 26 | else: 27 | raise ValueError("Unexpected input format. 
Expected string to be split on whitespace, or list of tokens.") 28 | -------------------------------------------------------------------------------- /src/utility/var_pool.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Union 2 | 3 | from fastNLP.core.utils import seq_len_to_mask 4 | from torch import Tensor 5 | 6 | 7 | class VarPool: 8 | def __init__(self, **kwargs): 9 | self._pool = {} 10 | self._lazy_func = {} 11 | self._circle_trace = [] 12 | 13 | for key, value in kwargs.items(): 14 | self._pool[key] = value 15 | 16 | self.add_lazy('seq_len', 'batch_size', lambda x: len(x)) 17 | self.add_lazy('seq_len', 'max_len', lambda x: max(x)) 18 | self.add_lazy('seq_len', 'num_token', lambda x: sum(x)) 19 | self.add_lazy(['seq_len', 'max_len'], 'mask', lambda x, y: seq_len_to_mask(x, y)) 20 | 21 | def add_lazy(self, source: Union[str, List[str]], target: str, func: Callable, overwrite=False): 22 | assert overwrite or target not in self._lazy_func, f'{target=}' 23 | if isinstance(source, str): 24 | source = [source] 25 | self._lazy_func[target] = (source, func) 26 | 27 | def select(self, mask): 28 | new_vp = VarPool() 29 | for key, value in self._pool.items(): 30 | if key in ('batch_size', 'max_len'): 31 | continue 32 | if key.endswith('_cpu') or key.endswith('_cuda'): 33 | continue 34 | if not isinstance(value, Tensor): 35 | continue 36 | new_vp.add_lazy([], key, lambda v=value: v[mask], overwrite=True) 37 | for key, value in self._lazy_func.items(): 38 | if key not in new_vp._lazy_func and not key.endswith('cuda') and not key.endswith('cpu'): 39 | new_vp.add_lazy(value[0], key, value[1], overwrite=True) 40 | return new_vp 41 | 42 | def __getitem__(self, item): 43 | if item in self._pool: 44 | return self._pool[item] 45 | if item in self._lazy_func: 46 | source, func = self._lazy_func[item] 47 | self._circle_trace.append(item) 48 | assert not any(map(lambda s: s in self._circle_trace, source)) 49 | source = [self[s] for s in source] 50 | self._circle_trace.pop() 51 | target = func(*source) 52 | self[item] = target 53 | return target 54 | name, device = item.rsplit('_', 1) 55 | if device in ('cuda', 'cpu'): 56 | value = self[name].to(device) 57 | self._pool[item] = value 58 | return value 59 | raise KeyError(f'No {item}.') 60 | 61 | def __setitem__(self, key, value): 62 | self._pool[key] = value 63 | if isinstance(value, Tensor): 64 | self.add_lazy(key, key + '_cuda', lambda x: x if x.device.type == 'cuda' else x.cuda()) 65 | self.add_lazy(key, key + '_cpu', lambda x: x if x.device.type == 'cpu' else x.cpu()) 66 | 67 | def __getattr__(self, item): 68 | return self[item] 69 | 70 | def __setattr__(self, key, value): 71 | if key.startswith('_'): 72 | super().__setattr__(key, value) 73 | else: 74 | self._pool[key] = value 75 | 76 | def __contains__(self, key): 77 | return key in self._pool or key in self._lazy_func 78 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os.path 4 | from pathlib import Path 5 | 6 | import hydra 7 | import pytorch_lightning as pl 8 | from hydra import compose 9 | from hydra.utils import HydraConfig, instantiate 10 | from omegaconf import DictConfig 11 | from omegaconf import OmegaConf 12 | 13 | import src 14 | from src import datamodule 15 | from src.datamodule import DataModule 16 | from src.pipeline import Pipeline 17 | from 
src.utility.fn import instantiate_no_recursive 18 | from src.utility.pl_callback import NNICallback 19 | 20 | log = logging.getLogger(__name__) 21 | 22 | 23 | @hydra.main('config', 'config_test') 24 | def test(cfg: DictConfig): 25 | if (seed := cfg.seed) is not None: 26 | pl.seed_everything(seed) 27 | 28 | if cfg.pipeline.load_from_checkpoint is None: 29 | log.warning('Testing a random-initialized model.') 30 | 31 | if (p := cfg.pipeline.load_from_checkpoint) is not None: 32 | p = Path(p) 33 | if len(p.parts) >= 2 and p.parts[-2] == 'checkpoint': 34 | config_folder = p.parents[1] / 'config' 35 | else: 36 | config_folder = p.parent / 'config' 37 | if config_folder.exists(): 38 | # Load saved config. 39 | # Note that this only load overrides. Inconsistency happens if you change sub-config's file. 40 | # From Hydra's author: 41 | # https://stackoverflow.com/questions/67170653/how-to-load-hydra-parameters-from-previous-jobs-without-having-to-use-argparse/67172466?noredirect=1 42 | log.info('Loading saved overrides') 43 | original_overrides = OmegaConf.load(config_folder / 'overrides.yaml') 44 | current_overrides = HydraConfig.get().overrides.task 45 | # hydra_config = OmegaConf.load(config_folder / 'hydra.yaml') 46 | config_name = 'config_test' # hydra_config.hydra.job.config_name 47 | overrides = original_overrides + current_overrides 48 | # noinspection PyTypeChecker 49 | cfg = compose(config_name, overrides=overrides) 50 | if os.path.exists(config_folder / 'nni.json'): 51 | with open(config_folder / 'nni.json') as f: 52 | nni_overrides = json.load(f) 53 | NNICallback.setup_cfg(nni_overrides, cfg) 54 | log.info(OmegaConf.to_yaml(cfg)) 55 | 56 | src.g_cfg = cfg 57 | 58 | trainer: pl.Trainer = instantiate(cfg.trainer) 59 | src.trainer = trainer 60 | 61 | datamodule: DataModule = instantiate_no_recursive(cfg.datamodule) 62 | pipeline: Pipeline = instantiate_no_recursive(cfg.pipeline, dm=datamodule) 63 | output_name = cfg.get('output_name', 'predict') 64 | datamodule.setup('test') 65 | 66 | trainer.test(pipeline, dataloaders=datamodule.dataloader('train')) 67 | pipeline.write_prediction(output_name + '_train.conll', 'train', pipeline._test_outputs[0]) 68 | trainer.test(pipeline, dataloaders=datamodule.dataloader('dev')) 69 | pipeline.write_prediction(output_name + '_dev.conll', 'dev', pipeline._test_outputs[0]) 70 | trainer.test(pipeline, dataloaders=datamodule.dataloader('test')) 71 | pipeline.write_prediction(output_name + '_test.conll', 'test', pipeline._test_outputs[0]) 72 | 73 | 74 | if __name__ == '__main__': 75 | test() 76 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import os 5 | import os.path 6 | import random 7 | import string 8 | from pathlib import Path 9 | 10 | import hydra 11 | import pytorch_lightning as pl 12 | from hydra import compose 13 | from hydra.utils import HydraConfig, instantiate 14 | from omegaconf import DictConfig 15 | from omegaconf import OmegaConf 16 | 17 | import src 18 | from src.datamodule import DataModule 19 | from src.pipeline import Pipeline 20 | from src.utility.fn import instantiate_no_recursive 21 | from src.utility.fn import symlink_force 22 | from src.utility.logger import get_logger_func 23 | from src.utility.pl_callback import BestWatcherCallback 24 | from src.utility.pl_callback import NNICallback 25 | 26 | _warn, _info, _debug = get_logger_func('main') 27 | 
28 | 29 | @hydra.main('config', 'config_train') 30 | def train(cfg: DictConfig): 31 | src.g_cfg = cfg 32 | _info(f'Working directory: {os.getcwd()}') 33 | 34 | outputs_root = os.path.join(cfg.root, 'outputs') 35 | if os.path.exists(outputs_root): 36 | symlink_force(os.getcwd(), os.path.join(outputs_root, '0_latest_run')) 37 | 38 | if cfg.name == '@@@AUTO@@@': 39 | # For cases where we cannot set name={hydra:job.override_dirname} in config.yaml, e.g., multirun 40 | cfg.name = HydraConfig.get().job.override_dirname 41 | 42 | # init multirun 43 | if (num := HydraConfig.get().job.get('num')) is not None and num > 1: 44 | # Set the wandb group; when joblib is used, this is set by joblib instead. 45 | if 'MULTIRUN_ID' not in os.environ: 46 | os.environ['MULTIRUN_ID'] = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(4)) 47 | if 'logger' in cfg.trainer and 'tags' in cfg.trainer.logger: 48 | cfg.trainer.logger.tags.append(os.environ['MULTIRUN_ID']) 49 | 50 | if (config_folder := cfg.load_cfg_from_checkpoint) is not None: 51 | # Load saved config. 52 | # Note that this only loads overrides. Inconsistencies can arise if a sub-config file has changed. 53 | # From Hydra's author: 54 | # https://stackoverflow.com/questions/67170653/how-to-load-hydra-parameters-from-previous-jobs-without-having-to-use-argparse/67172466?noredirect=1 55 | _info('Loading saved overrides') 56 | config_folder = Path(config_folder) 57 | original_overrides = OmegaConf.load(config_folder / 'overrides.yaml') 58 | current_overrides = HydraConfig.get().overrides.task 59 | # hydra_config = OmegaConf.load(config_folder / 'hydra.yaml') 60 | config_name = 'config_train' # hydra_config.hydra.job.config_name 61 | overrides = original_overrides + current_overrides 62 | # noinspection PyTypeChecker 63 | cfg = compose(config_name, overrides=overrides) 64 | if os.path.exists(config_folder / 'nni.json'): 65 | with open(config_folder / 'nni.json') as f: 66 | nni_overrides = json.load(f) 67 | NNICallback.setup_cfg(nni_overrides, cfg) 68 | _info(OmegaConf.to_yaml(cfg)) 69 | src.g_cfg = cfg 70 | 71 | if (seed := cfg.seed) is not None: 72 | pl.seed_everything(seed) 73 | # torch.use_deterministic_algorithms(True) 74 | # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 75 | 76 | assert not (cfg.pipeline.load_from_checkpoint is not None and cfg.trainer.resume_from_checkpoint is not None), \ 77 | 'You should not use load_from_checkpoint and resume_from_checkpoint at the same time.' 78 | assert not cfg.watch_field.startswith('test/'), 'You should not use test set to tune hparams.'
79 | 80 | trainer: pl.Trainer = instantiate(cfg.trainer) 81 | src.trainer = trainer 82 | if 'optimized_metric' in cfg: 83 | assert any(isinstance(c, BestWatcherCallback) for c in trainer.callbacks) 84 | 85 | datamodule: DataModule = instantiate_no_recursive(cfg.datamodule) 86 | pipeline: Pipeline = instantiate_no_recursive(cfg.pipeline, dm=datamodule) 87 | trainer.fit(pipeline, datamodule) 88 | 89 | ckpt_path = "best" 90 | trainer.test(model=pipeline, datamodule=datamodule, ckpt_path=ckpt_path) 91 | 92 | _info(f'Working directory: {os.getcwd()}') 93 | 94 | # Return metric score for hyperparameter optimization 95 | callbacks = trainer.callbacks 96 | for c in callbacks: 97 | if isinstance(c, BestWatcherCallback): 98 | if c.best_model_path: 99 | _info(f'Best ckpt: {c.best_model_path}') 100 | if 'optimized_metric' in cfg: 101 | return c.best_model_metric[cfg.optimized_metric] 102 | break 103 | 104 | 105 | if __name__ == '__main__': 106 | train() 107 | --------------------------------------------------------------------------------
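Usage note on the coefficient schedules parsed by get_coeff_iter in src/utility/fn.py: its comment describes "[value]@[epoch]" strings that ramp linearly to each value at the given step, with the first entry anchored at step 0, while a plain number is treated as a constant. The sketch below is only a hint of how such a schedule could be consumed, not part of the repository; it assumes the project's dependencies (requirements.txt) are installed and that src is importable from the project root, and the variable names are illustrative.

from src.utility.fn import get_coeff_iter

# ["0@0", "1@10"]: start at 0 and rise linearly, reaching 1 at step 10, then stay at 1.
coeff = get_coeff_iter(["0@0", "1@10"])
ramp = [next(coeff) for _ in range(12)]  # 0.0, 0.1, ..., 1.0, 1.0

# Anything that is not a list/tuple/ListConfig yields a constant schedule.
constant = get_coeff_iter(0.5)
value = next(constant)  # always 0.5

With the default auto-incrementing index, each next() call advances one step; passing a custom idx_getter (for example, one reading an epoch counter) would let the same schedule advance per epoch instead, and an optional validator can assert bounds on every yielded value.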