├── README.md
├── bashutil.sh
├── configs
├── generating.yaml
└── training.yaml
├── generate.sh
├── mybycha
├── .editorconfig
├── .flake8
├── .gitignore
├── bycha
│ ├── __init__.py
│ ├── criteria
│ │ ├── __init__.py
│ │ ├── abstract_criterion.py
│ │ ├── auto_encoding_loss.py
│ │ ├── base_criterion.py
│ │ ├── cross_entropy.py
│ │ ├── focal_cross_entropy.py
│ │ ├── label_smoothed_cross_entropy.py
│ │ ├── label_smoothed_ctc.py
│ │ ├── mse.py
│ │ ├── multitask_criterion.py
│ │ └── self_contained_loss.py
│ ├── dataloaders
│ │ ├── __init__.py
│ │ ├── abstract_dataLoader.py
│ │ ├── binarized_dataloader.py
│ │ ├── in_memory_dataloader.py
│ │ └── streaming_dataloader.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── abstract_dataset.py
│ │ ├── data_map_dataset.py
│ │ ├── in_memory_dataset.py
│ │ ├── json_dataset.py
│ │ ├── parallel_text_dataset.py
│ │ ├── streaming_dataset.py
│ │ ├── streaming_json_dataset.py
│ │ ├── streaming_parallel_text_dataset.py
│ │ ├── streaming_text_dataset.py
│ │ ├── text_dataset.py
│ │ └── tfrecord_dataset.py
│ ├── entries
│ │ ├── __init__.py
│ │ ├── binarize_data.py
│ │ ├── build_tokenizer.py
│ │ ├── export.py
│ │ ├── preprocess.py
│ │ ├── run.py
│ │ ├── serve.py
│ │ ├── serve_model.py
│ │ └── util.py
│ ├── evaluators
│ │ ├── __init__.py
│ │ ├── abstract_evaluator.py
│ │ ├── evaluator.py
│ │ └── multi_evaluator.py
│ ├── generators
│ │ ├── __init__.py
│ │ ├── abstract_generator.py
│ │ ├── extraction_generator.py
│ │ ├── generator.py
│ │ ├── self_contained_generator.py
│ │ └── sequence_generator.py
│ ├── metrics
│ │ ├── __init__.py
│ │ ├── abstract_metric.py
│ │ ├── accuracy.py
│ │ ├── bleu.py
│ │ ├── f1.py
│ │ ├── matthews_corr.py
│ │ ├── pairwise_metric.py
│ │ ├── pearson_corr.py
│ │ └── spearman_corr.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── abstract_encoder_decoder_model.py
│ │ ├── abstract_model.py
│ │ ├── bert_model.py
│ │ ├── encoder_decoder_model.py
│ │ ├── huggingface
│ │ │ ├── __init__.py
│ │ │ ├── huggingface_extractive_question_answering_model.py
│ │ │ ├── huggingface_pretrain_bart_model.py
│ │ │ ├── huggingface_pretrain_mbart_model.py
│ │ │ └── huggingface_sequence_classification_model.py
│ │ ├── seq2seq.py
│ │ ├── sequence_classification_model.py
│ │ ├── torch_transformer.py
│ │ └── variational_auto_encoder.py
│ ├── modules
│ │ ├── __init__.py
│ │ ├── decoders
│ │ │ ├── __init__.py
│ │ │ ├── abstract_decoder.py
│ │ │ ├── autopruning_decoder.py
│ │ │ ├── dlcl_decoder.py
│ │ │ ├── layers
│ │ │ │ ├── __init__.py
│ │ │ │ ├── abstract_decoder_layer.py
│ │ │ │ ├── autopruning_decoder_layer.py
│ │ │ │ ├── lstm_decoder_layer.py
│ │ │ │ ├── moe_decoder_layer.py
│ │ │ │ ├── nonauto_transformer_decoder_layer.py
│ │ │ │ ├── strucdrop_decoder_layer.py
│ │ │ │ └── transformer_decoder_layer.py
│ │ │ ├── lstm_decoder.py
│ │ │ ├── moe_decoder.py
│ │ │ ├── nonauto_transformer_decoder.py
│ │ │ ├── strucdrop_decoder.py
│ │ │ └── transformer_decoder.py
│ │ ├── encoders
│ │ │ ├── __init__.py
│ │ │ ├── abstract_encoder.py
│ │ │ ├── autopruning_encoder.py
│ │ │ ├── dlcl_encoder.py
│ │ │ ├── huggingface_encoder.py
│ │ │ ├── key_value_transformer_encoder.py
│ │ │ ├── layers
│ │ │ │ ├── __init__.py
│ │ │ │ ├── abstract_encoder_layer.py
│ │ │ │ ├── autopruning_encoder_layer.py
│ │ │ │ ├── moe_encoder_layer.py
│ │ │ │ ├── strucdrop_encoder_layer.py
│ │ │ │ └── transformer_encoder_layer.py
│ │ │ ├── lstm_encoder.py
│ │ │ ├── moe_encoder.py
│ │ │ ├── multi_encoder_wrapper.py
│ │ │ ├── strucdrop_encoder.py
│ │ │ ├── transformer_encoder.py
│ │ │ └── vae_encoder.py
│ │ ├── layers
│ │ │ ├── __init__.py
│ │ │ ├── autopruning_ffn.py
│ │ │ ├── bert_layer_norm.py
│ │ │ ├── classifier.py
│ │ │ ├── dlcl.py
│ │ │ ├── embedding.py
│ │ │ ├── feed_forward.py
│ │ │ ├── gaussian.py
│ │ │ ├── gumbel.py
│ │ │ ├── layerdrop.py
│ │ │ ├── learned_positional_embedding.py
│ │ │ ├── moe.py
│ │ │ └── sinusoidal_positional_embedding.py
│ │ ├── search
│ │ │ ├── __init__.py
│ │ │ ├── abstract_search.py
│ │ │ ├── beam_search.py
│ │ │ ├── forward_sampling.py
│ │ │ ├── greedy_search.py
│ │ │ └── sequence_search.py
│ │ └── utils.py
│ ├── optim
│ │ ├── __init__.py
│ │ └── optimizer.py
│ ├── samplers
│ │ ├── __init__.py
│ │ ├── abstract_sampler.py
│ │ ├── batch_shuffle_sampler.py
│ │ ├── bucket_sampler.py
│ │ ├── distributed_sampler.py
│ │ ├── sequential_sampler.py
│ │ └── shuffled_sampler.py
│ ├── services
│ │ ├── __init__.py
│ │ ├── idls
│ │ │ ├── __init__.py
│ │ │ ├── bycha.thrift
│ │ │ └── model_infer.thrift
│ │ ├── model_server.py
│ │ └── server.py
│ ├── tasks
│ │ ├── __init__.py
│ │ ├── abstract_task.py
│ │ ├── auto_encoding_task.py
│ │ ├── base_task.py
│ │ ├── extractive_question_answering_task.py
│ │ ├── masked_lm_task.py
│ │ ├── seq2seq_task.py
│ │ ├── sequence_classification_task.py
│ │ ├── sequence_regression_task.py
│ │ └── translation_task.py
│ ├── tokenizers
│ │ ├── __init__.py
│ │ ├── abstract_tokenizer.py
│ │ ├── fastbpe.py
│ │ ├── huggingface_tokenizer.py
│ │ ├── sentencepiece.py
│ │ ├── utils.py
│ │ └── vocabulary.py
│ ├── trainers
│ │ ├── __init__.py
│ │ ├── abstract_trainer.py
│ │ ├── moe_trainer.py
│ │ ├── trainer.py
│ │ └── trainer_autopruning.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── data.py
│ │ ├── io.py
│ │ ├── ops.py
│ │ ├── profiling.py
│ │ ├── rate_schedulers
│ │ ├── __init__.py
│ │ ├── abstract_rate_scheduler.py
│ │ ├── constant_rate_scheduler.py
│ │ ├── inverse_square_root_rate_scheduler.py
│ │ ├── logistic_scheduler.py
│ │ ├── noam_scheduler.py
│ │ └── polynomial_decay_scheduler.py
│ │ ├── registry.py
│ │ ├── runtime.py
│ │ ├── tensor.py
│ │ └── txc_utils.py
├── requirements.txt
└── setup.py
├── pics
├── overview.png
└── sketch_and_generate.png
├── preparation
├── fragmenizer
│ ├── __init__.py
│ ├── atom_fragmenizer.py
│ ├── brics_fragmenizer.py
│ ├── brics_ring_r_fragmenizer.py
│ └── ring_r_fragmenizer.py
├── get_fragment_vocab.py
├── get_training_data.py
└── utils
│ ├── __init__.py
│ ├── common.py
│ ├── molecule_preparation.py
│ ├── shape_utils.py
│ ├── tfbio_data.py
│ ├── utils.py
│ └── utils_full.py
├── shape_pretraining
├── __init__.py
├── common.py
├── io.py
├── shape_pretraining_criterion_no_regression.py
├── shape_pretraining_dataloader_shard.py
├── shape_pretraining_dataset.py
├── shape_pretraining_dataset_pocket.py
├── shape_pretraining_dataset_shard.py
├── shape_pretraining_decoder_iterative_no_regression.py
├── shape_pretraining_encoder.py
├── shape_pretraining_iterator_no_regression.py
├── shape_pretraining_model.py
├── shape_pretraining_search_forward_sampling_dock_dedup_iterative_no_regression.py
├── shape_pretraining_search_iterative_no_regression.py
├── shape_pretraining_task_no_regression.py
├── shape_pretraining_task_no_regression_pocket.py
├── tfbio_data.py
└── utils.py
├── sketch
├── shape_utils.py
└── sketching.py
└── train.sh
/README.md:
--------------------------------------------------------------------------------
1 | # DESERT
2 | Zero-Shot 3D Drug Design by Sketching and Generating (NeurIPS 2022)
3 |
4 |
6 |
7 |

8 |
9 |
10 |

11 |
12 |
13 | P.S. Because the project is tightly coupled to ByteDance infrastructure, we cannot guarantee that it will run painlessly on your device.
14 |
15 | ## Requirement
16 | Our method is powered by an old version of [ParaGen](https://github.com/bytedance/ParaGen) (previously named ByCha).
17 |
18 | Install it with
19 | ```bash
20 | cd mybycha
21 | pip install -e .
22 | pip install horovod
23 | pip install lightseq
24 | ```
25 | You also need to install
26 | ```bash
27 | conda install -c "conda-forge/label/cf202003" openbabel # recommend using anaconda for this project
28 | pip install rdkit-pypi
29 | pip install pybel scikit-image pebble meeko==0.1.dev1 vina pytransform3d
30 | ```
31 |
32 | ## Pre-training
33 |
34 | ### Data Preparation
35 | Our training data was extracted from the open molecule database [ZINC](https://zinc.docking.org/). You need to download it first.
36 |
37 | To get the fragment vocabulary
38 | ```bash
39 | cd preparation
40 | python get_fragment_vocab.py # fill blank paths in the file first
41 | ```
42 |
43 | To get the training data
44 | ```bash
45 | python get_training_data.py # fill blank paths in the file first
46 | ```
47 |
48 | We also provide partial training data and vocabulary [Here](https://drive.google.com/drive/folders/1T2tKgILJAIMK6uTuhh3-qV-Ib0JVgaBs?usp=sharing).
49 |
50 | ### Training Shape2Mol Model
51 |
52 | You need to fill blank paths in configs/training.yaml and train.sh.
53 |
54 | ```bash
55 | bash train.sh
56 | ```
57 |
58 | We also provide a trained checkpoint [Here](https://drive.google.com/file/d/1YCRORU5aMJEMO8hDT_o9uKCXmXTL5_5N/view?usp=sharing).
59 |
60 | ## Design Molecules
61 |
62 | ### Sketching
63 |
64 | For a given protein, you first need to extract its pocket using [CAVITY](http://www.pkumdl.cn:8000/cavityplus/computation.php).
65 |
66 | Sampling molecular shapes with
67 | ```bash
68 | cd sketch
69 | python sketching.py # fill blank paths in the file first
70 | ```
71 |
72 | ### Generating
73 |
74 | ```bash
75 | bash generate.sh # fill blank paths in the file first
76 | ```
77 |
78 | ## Citation
79 | ```
80 | @inproceedings{long2022DESERT,
81 | title={Zero-Shot 3D Drug Design by Sketching and Generating},
82 | author={Long, Siyu and Zhou, Yi and Dai, Xinyu and Zhou, Hao},
83 | booktitle={NeurIPS},
84 | year={2022}
85 | }
86 | ```
87 |
--------------------------------------------------------------------------------
/bashutil.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
# Parse "--key value" pairs from the command line into the associative array
# BASH_ARGS, then assign each entry to a shell variable of the same name.
#
# The caller must declare BASH_ARGS (declare -A) with one entry per accepted
# option; the entry's initial value acts as the default. An option that is not
# a key of BASH_ARGS, or an option given without a value, prints an error to
# stderr and exits with status 1. The parsed values are echoed to stderr.
function parse_args(){
    while [[ "$#" -gt 0 ]]; do
        found=0
        for key in "${!BASH_ARGS[@]}"; do
            if [[ "--$key" == "$1" ]]; then
                found=1
                break
            fi
        done
        if [[ $found == 0 ]]; then
            echo "arg $1 not defined!" >&2
            exit 1
        fi
        if [[ "$#" -lt 2 ]]; then
            echo "arg $1 requires a value!" >&2
            exit 1
        fi
        BASH_ARGS[$key]=$2
        shift 2
    done

    echo "======== PARSED BASH ARGS ========" >&2
    for key in "${!BASH_ARGS[@]}"; do
        echo " $key = ${BASH_ARGS[$key]}" >&2
        # printf -v assigns the value verbatim: unlike eval, values containing
        # spaces or shell metacharacters are neither word-split nor executed.
        printf -v "$key" '%s' "${BASH_ARGS[$key]}"
    done
    echo "==================================" >&2
}
27 |
--------------------------------------------------------------------------------
/configs/generating.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | class: ShapePretrainingTaskNoRegressionPocket
3 | mode: evaluate
4 |
5 | # for training efficiency
6 | max_seq_len: 20
7 |
8 | # for get molecule shape
9 | grid_resolution: 0.5
10 | max_dist_stamp: 4.0
11 | max_dist: 6.75
12 | patch_size: 4
13 |
14 | # for molecule augmentation
15 | rotation_bin: 24
16 | max_translation: 1.0
17 |
18 | delta_input: False
19 |
20 | data:
21 | train:
22 | class: ShapePretrainingDatasetShard
23 | path: ---TRAINING DATA PATH---
24 | vocab_path: ---VOCAB PATH---
25 | sample_each_shard: 500000
26 | shuffle: True
27 | valid:
28 | class: ShapePretrainingDatasetPocket
29 | path:
30 | samples: ---VALID DATA PATH---
31 | vocab: ---VOCAB PATH---
32 | test:
33 | class: ShapePretrainingDatasetPocket
34 | path:
35 | samples: ---TEST DATA PATH---
36 | vocab: ---VOCAB PATH---
37 |
38 | dataloader:
39 | train:
40 | class: ShapePretrainingDataLoaderShard
41 | max_samples: 64
42 | valid:
43 | class: InMemoryDataLoader
44 | sampler:
45 | class: SequentialSampler
46 | max_samples: 128
47 | test:
48 | class: InMemoryDataLoader
49 | sampler:
50 | class: SequentialSampler
51 | max_samples: 128
52 |
53 | trainer:
54 | class: Trainer
55 | optimizer:
56 | class: AdamW
57 | lr:
58 | class: InverseSquareRootRateScheduler
59 | rate: 5e-4
60 | warmup_steps: 4000
61 | clip_norm: 0.
62 | betas: (0.9, 0.98)
63 | eps: 1e-8
64 | weight_decay: 1e-2
65 | update_frequency: 4
66 | max_steps: 300000
67 | log_interval: 100
68 | validate_interval_step: 500
69 | assess_reverse: True
70 |
71 | model:
72 | class: ShapePretrainingModel
73 | encoder:
74 | class: ShapePretrainingEncoder
75 | patch_size: 4
76 | num_layers: 12
77 | d_model: 1024
78 | n_head: 8
79 | dim_feedforward: 4096
80 | dropout: 0.1
81 | activation: 'relu'
82 | learn_pos: True
83 | decoder:
84 | class: ShapePretrainingDecoderIterativeNoRegression
85 | num_layers: 12
86 | d_model: 1024
87 | n_head: 8
88 | dim_feedforward: 4096
89 | dropout: 0.1
90 | activation: 'relu'
91 | learn_pos: True
92 | iterative_num: 1
93 | max_dist: 6.75
94 | grid_resolution: 0.5
95 | iterative_block:
96 | class: ShapePretrainingIteratorNoRegression
97 | num_layers: 3
98 | d_model: 1024
99 | n_head: 8
100 | dim_feedforward: 4096
101 | dropout: 0.1
102 | activation: 'relu'
103 | learn_pos: True
104 | d_model: 1024
105 | share_embedding: decoder-input-output
106 |
107 | criterion:
108 | class: ShapePretrainingCriterionNoRegression
109 |
110 | generator:
111 | class: SequenceGenerator
112 | search:
113 | class: ShapePretrainingSearchForwardSamplingDockDedupIterativeNoRegression
114 | maxlen_coef: (0.0, 20.0)
115 | num_return_sequence: 1000
116 | fnum_return_sequence: 100
117 |
118 | evaluator:
119 | class: Evaluator
120 | metric:
121 | acc:
122 | class: Accuracy
123 |
124 | env:
125 | device: cuda
126 | fp16: True
127 |
--------------------------------------------------------------------------------
/configs/training.yaml:
--------------------------------------------------------------------------------
1 | task:
2 | class: ShapePretrainingTaskNoRegression
3 | mode: train
4 |
5 | # for training efficiency
6 | max_seq_len: 20
7 |
8 | # for get molecule shape
9 | grid_resolution: 0.5
10 | max_dist_stamp: 4.0
11 | max_dist: 6.75
12 | patch_size: 4
13 |
14 | # for molecule augmentation
15 | rotation_bin: 24
16 | max_translation: 1.0
17 |
18 | delta_input: False
19 |
20 | data:
21 | train:
22 | class: ShapePretrainingDatasetShard
23 | path: ---TRAINING DATA PATH---
24 | vocab_path: ---VOCAB PATH---
25 | sample_each_shard: 500000
26 | shuffle: True
27 | valid:
28 | class: ShapePretrainingDataset
29 | path:
30 | samples: ---VALID DATA PATH---
31 | vocab: ---VOCAB PATH---
32 | test:
33 | class: ShapePretrainingDataset
34 | path:
35 | samples: ---TEST DATA PATH---
36 | vocab: ---VOCAB PATH---
37 |
38 | dataloader:
39 | train:
40 | class: ShapePretrainingDataLoaderShard
41 | max_samples: 64
42 | valid:
43 | class: InMemoryDataLoader
44 | sampler:
45 | class: SequentialSampler
46 | max_samples: 128
47 | test:
48 | class: InMemoryDataLoader
49 | sampler:
50 | class: SequentialSampler
51 | max_samples: 128
52 |
53 | trainer:
54 | class: Trainer
55 | optimizer:
56 | class: AdamW
57 | lr:
58 | class: InverseSquareRootRateScheduler
59 | rate: 5e-4
60 | warmup_steps: 4000
61 | clip_norm: 0.
62 | betas: (0.9, 0.98)
63 | eps: 1e-8
64 | weight_decay: 1e-2
65 | update_frequency: 4
66 | max_steps: 300000
67 | log_interval: 100
68 | validate_interval_step: 500
69 | assess_reverse: True
70 |
71 | model:
72 | class: ShapePretrainingModel
73 | encoder:
74 | class: ShapePretrainingEncoder
75 | patch_size: 4
76 | num_layers: 12
77 | d_model: 1024
78 | n_head: 8
79 | dim_feedforward: 4096
80 | dropout: 0.1
81 | activation: 'relu'
82 | learn_pos: True
83 | decoder:
84 | class: ShapePretrainingDecoderIterativeNoRegression
85 | num_layers: 12
86 | d_model: 1024
87 | n_head: 8
88 | dim_feedforward: 4096
89 | dropout: 0.1
90 | activation: 'relu'
91 | learn_pos: True
92 | iterative_num: 1
93 | max_dist: 6.75
94 | grid_resolution: 0.5
95 | iterative_block:
96 | class: ShapePretrainingIteratorNoRegression
97 | num_layers: 3
98 | d_model: 1024
99 | n_head: 8
100 | dim_feedforward: 4096
101 | dropout: 0.1
102 | activation: 'relu'
103 | learn_pos: True
104 | d_model: 1024
105 | share_embedding: decoder-input-output
106 |
107 | criterion:
108 | class: ShapePretrainingCriterionNoRegression
109 |
110 | generator:
111 | class: SequenceGenerator
112 | search:
113 | class: ShapePretrainingSearchIterativeNoRegression
114 | maxlen_coef: (0.0, 20.0)
115 |
116 | evaluator:
117 | class: Evaluator
118 | metric:
119 | acc:
120 | class: Accuracy
121 |
122 | env:
123 | device: cuda
124 | fp16: True
125 |
--------------------------------------------------------------------------------
/generate.sh:
--------------------------------------------------------------------------------
1 | bycha-run \
2 | --config configs/generating.yaml \
3 | --lib shape_pretraining \
4 | --task.mode evaluate \
5 | --task.data.train.path data \
6 | --task.data.valid.path.samples ❗❗❗FILL_THIS(MOLECULE SHAPES SAMPLED FROM CAVITY)❗❗❗ \
7 | --task.data.test.path.samples ❗❗❗FILL_THIS❗❗❗ \
8 | --task.dataloader.train.max_samples 1 \
9 | --task.dataloader.valid.sampler.max_samples 1 \
10 | --task.dataloader.test.sampler.max_samples 1 \
11 | --task.model.path ❗❗❗FILL_THIS❗❗❗ \
12 | --task.evaluator.save_hypo_dir ❗❗❗FILL_THIS❗❗❗
--------------------------------------------------------------------------------
/mybycha/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | end_of_line = lf
5 | insert_final_newline = true
6 | max_line_length = 120
7 |
8 | [Makefile]
9 | indent_style = tab
10 |
11 | [*.py]
12 | indent_style = space
13 | indent_size = 4
14 |
15 | [*.{js,ts,html}]
16 | indent_style = space
17 | indent_size = 2
--------------------------------------------------------------------------------
/mybycha/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length=120
3 | exclude =
4 | venv/,
5 | venv_py/,
6 | .eggs,
7 | .tox
8 | ignore = D400,D300,D205,D200,D105,D100,D101,D103,D107
9 |
--------------------------------------------------------------------------------
/mybycha/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | **/*.pyc
3 | **/*.DS_Store
4 | *.egg-info/
5 | *.swp
6 | .eggs/
7 | .idea/
8 | .tox/
9 | .pytest_cache/*
10 | venv/
11 | venv_py/
12 | .mypy_cache/
13 | __pycache__/
14 | /build
15 | /dist
16 | checkpoints/
17 | *.pt
--------------------------------------------------------------------------------
/mybycha/bycha/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/mybycha/bycha/__init__.py
--------------------------------------------------------------------------------
/mybycha/bycha/criteria/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
6 | from .abstract_criterion import AbstractCriterion
7 |
register_criterion, create_criterion, registry = setup_registry('criterion', AbstractCriterion)

# Import every public criterion module (``*.py``) and sub-package in this
# directory so that their ``@register_criterion`` decorators run and populate
# the registry created above.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    # skip private/dunder entries (e.g. __init__.py, __pycache__) and hidden files
    if file.startswith(('_', '.')):
        continue
    if file.endswith('.py'):
        # strip only the trailing extension; ``str.find`` would cut at the
        # first '.py' occurrence anywhere in the name
        module_name = file[:-len('.py')]
    elif os.path.isdir(path):
        module_name = file
    else:
        continue
    module = importlib.import_module('bycha.criteria.' + module_name)
20 |
--------------------------------------------------------------------------------
/mybycha/bycha/criteria/abstract_criterion.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger(__name__)
3 |
4 | from torch.nn import Module
5 |
6 | from bycha.utils.runtime import Environment
7 |
8 |
class AbstractCriterion(Module):
    """
    Base class of every criterion (training loss) within ByCha.

    Concrete criteria override ``_build`` to attach what they need (typically the
    model) and ``forward`` to compute the loss. ``step_update`` is an optional hook
    for step-level updates.
    """

    def __init__(self):
        super().__init__()
        # model whose outputs are scored; concrete criteria assign it in ``_build``
        self._model = None

    def build(self, *args, **kwargs):
        """
        Construct the criterion for model training.

        All arguments are forwarded to ``_build``; typically ``model`` is among them.
        Afterwards the criterion is moved onto the runtime device if it is a CUDA device.
        """
        self._build(*args, **kwargs)

        device = Environment().device
        if not device.startswith('cuda'):
            return
        logger.info('move criterion to {}'.format(device))
        self.cuda(device)

    def _build(self, *args, **kwargs):
        """Construction hook for subclasses; the base implementation does nothing."""

    def forward(self, *args, **kwargs):
        """
        Compute a loss from neural model input; must be overridden.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def step_update(self, *args, **kwargs):
        """Step-level update hook; the base implementation does nothing."""
44 |
--------------------------------------------------------------------------------
/mybycha/bycha/criteria/auto_encoding_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | from bycha.criteria import register_criterion
5 | from bycha.criteria.base_criterion import BaseCriterion
6 | from bycha.utils.rate_schedulers import create_rate_scheduler
7 |
8 |
@register_criterion
class AutoEncodingLoss(BaseCriterion):
    """
    Auto-encoding loss: a token reconstruction loss plus a model regularization term
    weighted by a scheduled rate ``beta``; label smoothing is mixed in during training.

    Args:
        epsilon: label smoothing rate (applied only in training mode)
        beta: rate-scheduler configuration for the regularization weight
    """

    def __init__(self, epsilon=0.1, beta=1.):
        super().__init__()
        self._epsilon = epsilon
        # raw scheduler configuration; instantiated in ``_build``
        self._beta_configs = beta

        self._padding_idx = None
        self._beta = None

    def _build(self, model, padding_idx=-1):
        """
        Attach the model and create the regularization-weight scheduler.

        Args:
            model: neural model; must provide ``reg_loss()`` and ``nll(rec_loss, reg_loss)``
            padding_idx: target index ignored by the reconstruction loss
        """
        self._padding_idx = padding_idx
        self._model = model

        self._beta = create_rate_scheduler(self._beta_configs)
        self._beta.build()

    def _reconstruct_loss(self, lprobs, target, reduce=False):
        """
        Token-level negative log-likelihood and smoothing loss, ignoring padding.

        Args:
            lprobs: log-probabilities with exactly one more dimension than ``target``
            target: gold token indices
            reduce: if True, average the nll over non-padding tokens; otherwise return
                the flattened per-token nll (zero at padding positions)

        Returns:
            - nll loss (flattened per token, or averaged if ``reduce``)
            - number of non-padding tokens
            - smoothing loss averaged over non-padding tokens
        """
        assert target.dim() == lprobs.dim() - 1

        lprobs, target = lprobs.view(-1, lprobs.size(-1)), target.view(-1)
        padding_mask = target.eq(self._padding_idx)
        ntokens = (~padding_mask).sum()
        # calculate nll loss; padding positions gather index 0 so that a negative
        # padding_idx (e.g. the default -1) stays a valid index, then are zeroed
        gather_index = target.masked_fill(padding_mask, 0).unsqueeze(dim=-1)
        nll_loss = -lprobs.gather(dim=-1, index=gather_index).squeeze(dim=-1)
        nll_loss.masked_fill_(padding_mask, 0.)
        if reduce:
            nll_loss = nll_loss.sum() / ntokens

        # calculate smoothed loss: mean negative log-probability over the vocabulary
        smooth_loss = -lprobs.mean(dim=-1)
        smooth_loss.masked_fill_(padding_mask, 0.)
        smooth_loss = smooth_loss.sum() / ntokens

        return nll_loss, ntokens, smooth_loss

    def step_update(self, step):
        """
        Perform step-level update of the regularization weight.

        Args:
            step: running step
        """
        self._beta.step_update(step)

    def compute_loss(self, lprobs, net_output):
        """
        Compute loss from a batch of samples

        Args:
            lprobs: neural network output logits (normalized with log_softmax here)
            net_output: indexable whose element 0 is the 2-D target tensor
                (batch, length)
        Returns:
            - loss for network backward and optimization
            - logging information
        """
        lprobs = F.log_softmax(lprobs, dim=-1)
        # fetch target with default index 0
        target = net_output[0]

        bsz, _ = target.size()
        rec_loss, n_tokens, smooth_loss = self._reconstruct_loss(lprobs, target, reduce=False)
        # per-sequence reconstruction loss, summed over positions
        rec_loss = torch.sum(rec_loss.view(bsz, -1), dim=-1)
        # presumably one regularization value per sequence -- TODO confirm against the model
        reg_loss = self._model.reg_loss()

        loss = torch.sum(rec_loss + self._beta.rate * reg_loss) / n_tokens
        if self.training:
            loss = (1. - self._epsilon) * loss + self._epsilon * smooth_loss

        nll_loss = torch.sum(self._model.nll(rec_loss, reg_loss)) / n_tokens  # real nll loss

        logging_states = {
            'reg_weight': self._beta.rate,
            'loss': loss.data.item(),
            'nll_loss': nll_loss.data.item(),
            # log_softmax yields natural logs, so perplexity is e ** nll (not 2 ** nll);
            # float32 exp turns an overflow into inf instead of raising
            'ppl': nll_loss.detach().float().exp().item(),
            'reg_loss': torch.mean(reg_loss).item(),
            'rec_loss': (torch.sum(rec_loss) / n_tokens).item()
        }

        return loss, logging_states
96 |
--------------------------------------------------------------------------------
/mybycha/bycha/criteria/base_criterion.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Tuple
2 |
3 | from bycha.criteria import AbstractCriterion
4 |
5 |
class BaseCriterion(AbstractCriterion):
    """
    Criterion that runs the attached model and delegates scoring to ``compute_loss``.

    Subclasses implement ``compute_loss(model_output, **net_output)`` returning
    ``(loss, logging_states)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, net_input, net_output):
        """
        Run the model on a batch and compute its loss.

        Args:
            net_input: model input; a dict is unpacked as keyword arguments, a list or
                tuple as positional arguments, anything else is passed as one argument
            net_output (dict): oracle targets, unpacked as keyword arguments of
                ``compute_loss``

        Returns:
            tuple:
                - **loss**: loss for network backward and optimization
                - **logging_states**: logging information
        """
        if isinstance(net_input, dict):
            model_output = self._model(**net_input)
        elif isinstance(net_input, (list, tuple)):
            model_output = self._model(*net_input)
        else:
            model_output = self._model(net_input)
        loss, logging_states = self.compute_loss(model_output, **net_output)
        return loss, logging_states

    def compute_loss(self, *args, **kwargs):
        """
        Compute the loss from model results; must be overridden.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError
41 |
42 |
--------------------------------------------------------------------------------
/mybycha/bycha/criteria/cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from bycha.criteria import register_criterion
6 | from bycha.criteria.base_criterion import BaseCriterion
7 |
8 |
@register_criterion
class CrossEntropy(BaseCriterion):
    """
    Cross entropy loss over model logits.

    Args:
        weight: optional per-class rescaling weights, forwarded to ``nn.NLLLoss``
        logging_metric: extra metric to log: ``'acc'`` (accuracy over non-padding
            targets) or ``'ppl'`` (perplexity); any other value logs only the loss
    """

    def __init__(self, weight=None, logging_metric='acc'):
        super().__init__()
        self._weight = torch.FloatTensor(weight) if weight is not None else weight
        self._logging_metric = logging_metric
        self._padding_idx = None
        self._nll_loss = None

    def _build(self, model, padding_idx=-1):
        """
        Build a cross entropy loss over model.

        Args:
            model: a neural model for compute cross entropy.
            padding_idx: targets equal to padding_idx are ignored by nll_loss and accuracy
        """
        self._model = model
        self._padding_idx = padding_idx
        self._nll_loss = nn.NLLLoss(weight=self._weight, ignore_index=padding_idx)

    def compute_loss(self, lprobs, target):
        """
        Compute loss from a batch of samples

        Args:
            lprobs: neural network output logits (normalized with log_softmax here)
            target: oracle target for a network input

        Returns:
            - loss for network backward and optimization
            - logging information
        """
        lprobs = F.log_softmax(lprobs, dim=-1)

        # compute nll loss over flattened positions; padding targets are ignored
        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = target.view(-1)
        nll_loss = self._nll_loss(lprobs, target)

        # record logging
        logging_states = {
            'loss': nll_loss.data.item(),
        }
        if self._logging_metric == 'acc':
            # score only positions the loss scores (non-padding targets)
            valid = target.ne(self._padding_idx)
            correct = (lprobs.max(dim=-1)[1].eq(target) & valid).sum().item()
            tot = valid.sum().item()
            logging_states['acc'] = correct / tot if tot > 0 else 0.
        elif self._logging_metric == 'ppl':
            # log_softmax yields natural logs, so perplexity is e ** nll (not 2 ** nll);
            # float32 exp turns an overflow into inf instead of raising
            logging_states['ppl'] = nll_loss.detach().float().exp().item()
        return nll_loss, logging_states
65 |
--------------------------------------------------------------------------------
/mybycha/bycha/criteria/focal_cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 |
3 | from bycha.criteria import register_criterion
4 | from bycha.criteria.base_criterion import BaseCriterion
5 |
6 |
@register_criterion
class FocalCrossEntropy(BaseCriterion):
    """
    Focal cross entropy.

    Each non-padding token contributes ``-(1 - p_t) ** gamma * log(p_t)``, where ``p_t``
    is the predicted probability of its gold label; the loss is averaged over
    non-padding tokens.

    Args:
        gamma: focal loss rate (focusing parameter); ``0`` gives plain cross entropy
    """

    def __init__(self, gamma: float = 2.0):
        super().__init__()
        self._gamma = gamma

        self._padding_idx = None

    def _build(self, model, padding_idx=-1):
        """
        Build a focal cross entropy loss over model.

        Args:
            model: a neural model for compute cross entropy.
            padding_idx: targets equal to padding_idx are ignored by the loss and accuracy
        """
        self._model = model
        self._padding_idx = padding_idx

    def compute_loss(self, lprobs, target):
        """
        Compute loss from a batch of samples

        Args:
            lprobs: neural network output logits (normalized with log_softmax here)
            target: oracle target for a network input
        Returns:
            - loss for network backward and optimization
            - logging information
        """
        lprobs = F.log_softmax(lprobs, dim=-1)
        target_padding_mask = target.eq(self._padding_idx)
        assert target.dim() == lprobs.dim() - 1

        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = target.view(-1)
        target_padding_mask = target_padding_mask.view(-1)
        non_padding_mask = ~target_padding_mask
        ntokens = non_padding_mask.sum()

        # token-level accuracy on the flattened tensors, normalized by the number of
        # scored (non-padding) tokens rather than by the batch size
        correct = (lprobs.max(dim=-1)[1].eq(target) & non_padding_mask).sum().item()
        tot = ntokens.item()

        # focal-weighted nll of the gold label; padding positions gather index 0 so a
        # negative padding_idx (e.g. the default -1) stays a valid index, then are zeroed
        gather_index = target.masked_fill(target_padding_mask, 0).unsqueeze(dim=-1)
        gold_lprobs = lprobs.gather(dim=-1, index=gather_index).squeeze(dim=-1)
        weight = (1 - gold_lprobs.exp()) ** self._gamma
        loss = - weight * gold_lprobs
        loss.masked_fill_(target_padding_mask, 0.)
        loss = loss.sum() / ntokens

        # record logging
        logging_states = {
            'loss': loss.data.item(),
            'ntokens': tot,
            'acc': correct / tot if tot > 0 else 0.,
        }

        return loss, logging_states
71 |
--------------------------------------------------------------------------------
/mybycha/bycha/criteria/label_smoothed_cross_entropy.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 |
3 | from bycha.criteria import register_criterion
4 | from bycha.criteria.base_criterion import BaseCriterion
5 |
6 |
@register_criterion
class LabelSmoothedCrossEntropy(BaseCriterion):
    """
    Label smoothed cross entropy.

    In training mode the loss is ``(1 - epsilon) * nll + epsilon * smooth``, where
    ``smooth`` is the mean negative log-probability over all classes; in evaluation
    mode, or when ``epsilon <= 0``, the plain nll is returned.

    Args:
        epsilon: label smoothing rate
    """

    def __init__(self, epsilon: float = 0.1):
        super().__init__()
        self._epsilon = epsilon

        self._padding_idx = None

    def _build(self, model, padding_idx=-1):
        """
        Build a label smoothed cross entropy loss over model.

        Args:
            model: a neural model for compute cross entropy.
            padding_idx: targets equal to padding_idx are ignored when computing nll_loss
        """
        self._model = model
        self._padding_idx = padding_idx

    def compute_loss(self, lprobs, target):
        """
        Compute loss from a batch of samples

        Args:
            lprobs: neural network output logits (normalized with log_softmax here)
            target: oracle target for a network input
        Returns:
            - loss for network backward and optimization
            - logging information
        """
        lprobs = F.log_softmax(lprobs, dim=-1)
        target_padding_mask = target.eq(self._padding_idx)
        assert target.dim() == lprobs.dim() - 1
        # infer task type: a 1-D target holds one label per sample
        is_classification_task = len(target.size()) == 1

        lprobs = lprobs.view(-1, lprobs.size(-1))
        target = target.view(-1)
        target_padding_mask = target_padding_mask.view(-1)
        ntokens = (~target_padding_mask).sum()

        # calculate nll loss; padding positions gather index 0 so a negative
        # padding_idx (e.g. the default -1) stays a valid index, then are zeroed
        gather_index = target.masked_fill(target_padding_mask, 0).unsqueeze(dim=-1)
        nll_loss = -lprobs.gather(dim=-1, index=gather_index).squeeze(dim=-1)
        nll_loss.masked_fill_(target_padding_mask, 0.)
        nll_loss = nll_loss.sum() / ntokens

        # calculate smoothed loss
        if self._epsilon > 0.:
            smooth_loss = -lprobs.mean(dim=-1)
            smooth_loss.masked_fill_(target_padding_mask, 0.)
            smooth_loss = smooth_loss.sum() / ntokens

            # average nll loss and smoothed loss, weighted by epsilon
            loss = (1. - self._epsilon) * nll_loss + self._epsilon * smooth_loss if self.training else nll_loss
        else:
            loss = nll_loss

        # record logging
        logging_states = {
            'loss': loss.data.item(),
            'nll_loss': nll_loss.data.item(),
            'ntokens': ntokens.data.item(),
        }
        if is_classification_task:
            correct = (lprobs.max(dim=-1)[1] == target).sum().data.item()
            tot = target.size(0)
            logging_states['acc'] = correct / tot
        else:
            # log_softmax yields natural logs, so perplexity is e ** nll (not 2 ** nll);
            # float32 exp turns an overflow into inf instead of raising
            logging_states['ppl'] = nll_loss.detach().float().exp().item()

        return loss, logging_states
85 |
--------------------------------------------------------------------------------
/mybycha/bycha/criteria/mse.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from bycha.criteria import register_criterion
4 | from bycha.criteria.base_criterion import BaseCriterion
5 |
6 |
@register_criterion
class MSE(BaseCriterion):
    """
    Mean squared error criterion.

    Prediction and target are both flattened before the loss is computed, so
    they only need to hold the same number of elements.
    """

    def __init__(self):
        super().__init__()
        self._mse_loss = None

    def _build(self, model):
        """
        Build a mean-squared-error loss over model.

        Args:
            model: the neural model whose outputs are scored.
        """
        self._model = model
        self._mse_loss = nn.MSELoss()

    def compute_loss(self, pred, target):
        """
        Compute mean squared error from a batch of samples.

        Args:
            pred: neural network output
            target: oracle target for a network input
        Returns:
            - loss for network backward and optimization
            - logging information
        """
        # reshape (unlike view) also accepts non-contiguous tensors
        pred = pred.reshape(-1)
        # align the target dtype with the prediction (e.g. integer or float64
        # targets), since MSELoss expects matching floating dtypes
        target = target.reshape(-1).to(pred.dtype)
        mse_loss = self._mse_loss(pred, target)

        # record logging
        logging_states = {
            'loss': mse_loss.data.item(),
        }
        return mse_loss, logging_states
49 |
--------------------------------------------------------------------------------
/mybycha/bycha/criteria/multitask_criterion.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from bycha.criteria import AbstractCriterion, create_criterion, register_criterion
4 |
5 |
@register_criterion
class MultiTaskCriterion(AbstractCriterion):
    """
    Weighted sum of several named criteria, one per task.

    Args:
        criterions: mapping from task name to that task's criterion config.
            A config may carry an optional ``weight`` entry (default 1) that
            scales the task's loss; the remaining entries are passed to
            ``create_criterion``.
    """

    def __init__(self, criterions):
        super().__init__()
        self._criterion_configs = criterions

        self._names = list(self._criterion_configs)
        self._criterions, self._weights = None, None

    def _build(self, model, padding_idx=-1):
        """
        Build multi-task criterion by dispatching args to each criterion

        Args:
            model: neural model
            padding_idx: pad idx to ignore, forwarded to every sub-criterion
        """
        import copy

        self._model = model
        self._criterions, self._weights = {}, {}
        for name in self._names:
            # work on a shallow copy: popping 'weight' from the user's config
            # would make a second build silently fall back to weight 1
            criterion_config = copy.copy(self._criterion_configs[name])
            self._weights[name] = criterion_config.pop('weight') if 'weight' in criterion_config else 1
            self._criterions[name] = create_criterion(criterion_config)
            self._criterions[name].build(model, padding_idx)

    def forward(self, net_input, net_output):
        """
        Compute the weighted multi-task loss from a batch of samples

        Args:
            net_input: keyword arguments for the model, which must return a
                dict mapping each task name to its output
            net_output (dict): mapping from task name to the keyword arguments
                (e.g. the target) passed to that task's ``compute_loss``
        Returns:
            - loss for network backward and optimization
            - logging information, with per-task entries prefixed by
              ``'<task name>.'`` and the weighted total under ``'loss'``
        """
        lprobs_dict = self._model(**net_input)
        assert isinstance(lprobs_dict, Dict), 'A multitask learning model must return a dict of log-probability'
        tot_loss, complete_logging_states = 0, {}
        for name in self._names:
            criterion = self._criterions[name]
            loss, logging_states = criterion.compute_loss(lprobs_dict[name], **net_output[name])
            tot_loss += self._weights[name] * loss
            complete_logging_states.update(
                {f'{name}.{key}': val for key, val in logging_states.items()}
            )
        complete_logging_states['loss'] = tot_loss.data.item()
        return tot_loss, complete_logging_states
58 |
59 |
--------------------------------------------------------------------------------
/mybycha/bycha/criteria/self_contained_loss.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | from bycha.criteria import AbstractCriterion, register_criterion
4 |
5 |
@register_criterion
class SelfContainedLoss(AbstractCriterion):
    """
    Criterion that delegates loss computation to the model itself.

    The model's ``loss`` method receives the batch's ``net_input`` as keyword
    arguments and returns either a loss tensor or a ``(loss, logging_states)``
    tuple.
    """

    def __init__(self):
        super().__init__()

    def build(self, model, *args, **kwargs):
        """
        Attach the model whose ``loss`` method computes the objective.

        Args:
            model: a neural model exposing a ``loss`` method.
            *args, **kwargs: accepted for interface compatibility (e.g. a
                padding index passed to every criterion) and ignored.
        """
        self._model = model

    def forward(self, net_input):
        """
        Compute loss via model itself

        Args:
            net_input (dict): neural network input, passed to ``model.loss``
        Returns:
            - loss for network backward and optimization
            - logging information (``{'loss': ...}`` when the model returns
              only a loss)
        """
        output = self._model.loss(**net_input)
        if isinstance(output, tuple):
            assert len(output) == 2, 'if a tuple returned, it must be (loss, logging_states)'
            loss, logging_states = output
        else:
            loss = output
            logging_states = {'loss': loss.data.item()}
        return loss, logging_states
43 |
--------------------------------------------------------------------------------
/mybycha/bycha/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.ops import deepcopy_on_ref
5 | from bycha.utils.registry import setup_registry
6 |
7 | from .abstract_dataLoader import AbstractDataLoader
8 |
# Dataloader registry: `register_dataloader` decorates dataloader classes and
# `_build_dataloader` instantiates one from a config dict.
register_dataloader, _build_dataloader, registry = setup_registry('dataloader', AbstractDataLoader)


def build_dataloader(configs, dataset, sampler=None, collate_fn=None, post_collate=False):
    """
    Create a dataloader from a copy of ``configs`` wired to ``dataset``.

    Args:
        configs: dataloader configs (copied, never modified)
        dataset: dataset storing samples
        sampler: optional sample strategy; only set when given
        collate_fn: collate function
        post_collate: when True, ``collate_fn`` is passed as ``post_collate_fn``
            (applied after fetching) instead of as ``collate_fn``

    Returns:
        AbstractDataLoader
    """
    loader_configs = deepcopy_on_ref(configs)
    loader_configs.update({
        'dataset': dataset,
        'collate_fn': None if post_collate else collate_fn,
        'post_collate_fn': collate_fn if post_collate else None
    })
    if sampler is not None:
        loader_configs['sampler'] = sampler
    return _build_dataloader(loader_configs)


# Import every public submodule/subpackage so the dataloader classes they
# define get registered via their decorators.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    if file.startswith(('_', '.')):
        continue
    if not (file.endswith('.py') or os.path.isdir(path)):
        continue
    module_name = file[:file.find('.py')] if file.endswith('.py') else file
    module = importlib.import_module('bycha.dataloaders.' + module_name)
48 |
--------------------------------------------------------------------------------
/mybycha/bycha/dataloaders/binarized_dataloader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import random
4 |
5 | from bycha.dataloaders import register_dataloader, AbstractDataLoader
6 | from bycha.utils.io import UniIO
7 | from bycha.utils.runtime import Environment
8 | from bycha.utils.tensor import list2tensor
9 |
10 |
@register_dataloader
class BinarizedDataLoader(AbstractDataLoader):
    """
    Dataloader that replays pre-binarized batches, one JSON document per line
    (presumably as written by the binarize_data entry), and shards them
    across distributed ranks.

    Args:
        path: path to load binarized data
        preload: parse and tensorize every batch at construction time
        length_interval: stored on the instance; not used by this class
        max_shuffle_size: batches buffered per rank before being shuffled
            and dispatched
    """

    def __init__(self,
                 path,
                 preload=False,
                 length_interval=8,
                 max_shuffle_size=1,
                 **kwargs):
        super().__init__(None)
        self._path = path
        self._preload = preload
        self._batches = None
        self._length_interval = length_interval
        self._max_shuffle_size = max_shuffle_size

        env = Environment()
        self._rank = env.rank
        self._distributed_wolrds = env.distributed_world
        self._max_buffered_batch_num = self._max_shuffle_size * self._distributed_wolrds
        self._buffered_batches = []

        if preload:
            self._batches = []
            with UniIO(self._path) as fin:
                for batch in fin:
                    self._batches.append(list2tensor(json.loads(batch)))
            # pad (cycling through the batches) up to a multiple of the world
            # size so every rank receives the same number of batches; a single
            # slice is not enough when the world size exceeds twice the count
            total_size = int(math.ceil(len(self._batches) * 1.0 / self._distributed_wolrds)) * self._distributed_wolrds
            while len(self._batches) < total_size:
                self._batches += self._batches[:(total_size - len(self._batches))]

    def reset(self, *args, **kwargs):
        """
        Prepare a new pass: reopen the stream (streaming mode) or reshuffle
        the preloaded batches, and drop any buffered batches.
        """
        if not self._preload:
            # close the previous pass's stream before reopening it
            if self._batches is not None:
                self._batches.close()
            self._batches = UniIO(self._path)
        elif self._max_shuffle_size > 0:
            random.shuffle(self._batches)
        self._buffered_batches.clear()
        return self

    def __iter__(self):
        for batch in self._batches:
            if not self._preload:
                batch = list2tensor(json.loads(batch))
            self._buffered_batches.append(batch)
            if len(self._buffered_batches) == self._max_buffered_batch_num:
                yield from self._dispatch()
        # flush the remainder; fewer than world-size leftovers are dropped
        if len(self._buffered_batches) >= self._distributed_wolrds:
            yield from self._dispatch()

    def _dispatch(self):
        """Shuffle the buffer and yield this rank's strided share of it."""
        random.shuffle(self._buffered_batches)
        batch_num = len(self._buffered_batches) // self._distributed_wolrds * self._distributed_wolrds
        self._buffered_batches = self._buffered_batches[self._rank:batch_num:self._distributed_wolrds]
        for s in self._buffered_batches:
            yield s
        self._buffered_batches.clear()

    def __len__(self):
        # the length is only known when batches were preloaded
        return len(self._batches) // self._distributed_wolrds if self._preload else 0
81 |
--------------------------------------------------------------------------------
/mybycha/bycha/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
6 | from .abstract_dataset import AbstractDataset
7 |
# Dataset registry: `register_dataset` decorates dataset classes and
# `create_dataset` instantiates one from a config dict.
register_dataset, create_dataset, registry = setup_registry('dataset', AbstractDataset)

# Import every public submodule/subpackage so the dataset classes they define
# get registered via their decorators.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    if file.startswith(('_', '.')):
        continue
    if not (file.endswith('.py') or os.path.isdir(path)):
        continue
    module_name = file[:file.find('.py')] if file.endswith('.py') else file
    module = importlib.import_module('bycha.datasets.' + module_name)
20 |
21 |
--------------------------------------------------------------------------------
/mybycha/bycha/datasets/in_memory_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from bycha.datasets import AbstractDataset
4 |
5 |
class InMemoryDataset(AbstractDataset):
    """
    An in-memory dataset which loads data into memory before running a task.
    InMemoryDataset is suitable for datasets of relatively low capacity.

    Args:
        path: data path to read
        sort_samples (bool): sort samples before running a task.
            It would be useful in inference without degrading performance.
        max_size: maximum size of loaded data
    """

    def __init__(self,
                 path,
                 sort_samples=False,
                 max_size=0,):
        super().__init__(path, max_size=max_size)

        self._data = None
        self._sort_samples = sort_samples
        # cursor used by __next__; defined here so __next__/reset work even
        # before build() is called
        self._pos = 0

    def build(self, collate_fn=None, preprocessed=False, **kwargs):
        """
        Build input stream and load data into memory

        Args:
            collate_fn: callback defined by a specific task
            preprocessed: data has been preprocessed
        """
        self._collate_fn = collate_fn
        self._preprocessed = preprocessed

        if self._path:
            self._load()
        self._pos = 0

    def _load(self):
        """
        Load data into memory; implemented by subclasses
        """
        raise NotImplementedError

    def shuffle(self):
        """
        shuffle preload data in place
        """
        random.shuffle(self._data)

    def __getitem__(self, index):
        """
        fetch an item at index

        Args:
            index: index of item to fetch

        Returns:
            sample: data of index in preload data list
        """
        return self._data[index]

    def __iter__(self):
        yield from self._data

    def __next__(self):
        """
        fetch next sample

        Returns:
            sample: next sample

        Raises:
            StopIteration: when every loaded sample has been returned
        """
        if self._pos >= len(self._data):
            raise StopIteration
        sample = self._data[self._pos]
        self._pos += 1
        return sample

    def reset(self):
        """
        Reset io for a new round of iteration
        """
        self._pos = 0

    def finalize(self):
        """
        Release the loaded samples after finishing reading
        """
        super().finalize()
        # tolerate finalize() being called more than once; only a missing
        # attribute is expected here, anything else should surface
        try:
            del self._data
        except AttributeError:
            pass
96 |
97 |
--------------------------------------------------------------------------------
/mybycha/bycha/datasets/json_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | logger = logging.getLogger(__name__)
4 |
5 | from bycha.datasets import register_dataset
6 | from bycha.datasets.in_memory_dataset import InMemoryDataset
7 | from bycha.utils.data import count_sample_token
8 | from bycha.utils.io import UniIO
9 | from bycha.utils.runtime import progress_bar
10 |
11 |
@register_dataset
class JsonDataset(InMemoryDataset):
    """
    JsonDataset is an in-memory dataset for reading data saved with json.dumps,
    one JSON document per line.

    Args:
        path: data path to read
        sort_samples (bool): sort samples before running a task.
            It would be useful in inference without degrading performance.
        max_size: maximum size of loaded data
    """

    def __init__(self,
                 path,
                 sort_samples=False,
                 max_size=0):
        super().__init__(path, sort_samples=sort_samples, max_size=max_size)

    def _load(self):
        """
        Preload all the data into memory. In the loading process, data are
        processed by ``_full_callback`` and optionally sorted by token count.
        Samples whose processing raises are logged and skipped.
        """
        fin = UniIO(path=self._path)
        try:
            self._data = []
            accepted, discarded = 0, 0
            for i, sample in enumerate(progress_bar(fin, streaming=True, desc='Loading Samples')):
                if 0 < self._max_size <= i:
                    break
                try:
                    sample = sample.strip('\n')
                    self._data.append(self._full_callback(sample))
                    accepted += 1
                except Exception:
                    logger.warning('sample {} is discarded'.format(sample))
                    discarded += 1
            if self._sort_samples:
                self._data.sort(key=count_sample_token)
            self._length = len(self._data)
            logger.info(f'Totally accept {accepted} samples, discard {discarded} samples')
        finally:
            # close the input even if reading or sorting fails part-way
            fin.close()

    def _callback(self, sample):
        """
        Callback for json data

        Args:
            sample: one line of raw JSON text

        Returns:
            sample: the decoded JSON value
        """
        sample = json.loads(sample)
        return sample
65 |
--------------------------------------------------------------------------------
/mybycha/bycha/datasets/parallel_text_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import logging
3 | logger = logging.getLogger(__name__)
4 |
5 | from bycha.datasets import register_dataset
6 | from bycha.datasets.in_memory_dataset import InMemoryDataset
7 | from bycha.utils.data import count_sample_token
8 | from bycha.utils.io import UniIO
9 | from bycha.utils.runtime import progress_bar, Environment
10 |
11 |
@register_dataset
class ParallelTextDataset(InMemoryDataset):
    """
    ParallelTextDataset is an in-memory dataset for reading data saved in parallel files.

    Args:
        path: a dict of data with their path. `path` can be `None` to build the process pipeline only.
        sort_samples (bool): sort samples before running a task.
            It would be useful in inference without degrading performance.
        max_size: maximum size of loaded data
    """

    def __init__(self,
                 path: Dict[str, str] = None,
                 sort_samples: bool = False,
                 max_size: int = 0,):
        super().__init__(path, sort_samples=sort_samples, max_size=max_size)
        # `path` may be None (pipeline only), in which case there are no sources
        self._sources = list(path.keys()) if path else []

    def _callback(self, sample):
        """
        Callback for parallel data

        Args:
            sample: dict mapping each source name to its raw line

        Returns:
            sample (dict): a dict of samples consisting of parallel data of different sources
        """
        if self._preprocessed:
            # each value is a line of whitespace-separated token literals:
            # split it first, since iterating the raw string would eval single
            # characters (and fail on the separating spaces).
            # NOTE(review): eval executes arbitrary expressions from the data
            # file; only load trusted preprocessed data.
            sample = {
                key: [eval(v) for v in (val.split() if isinstance(val, str) else val)]
                for key, val in sample.items()
            }
        return sample

    def _load(self):
        """
        Preload all the data into memory. In the loading process, aligned lines
        from every source are processed by ``_full_callback`` and optionally
        sorted by token count.
        """
        ori_fin = [UniIO(self._path[src]) for src in self._sources]
        try:
            fin = zip(*ori_fin)
            self._data = []
            accepted, discarded = 0, 0
            for i, sample in enumerate(progress_bar(fin, streaming=True, desc='Loading Samples')):
                if 0 < self._max_size <= i:
                    break
                try:
                    sample = self._full_callback({
                        src: s.strip('\n')
                        for src, s in zip(self._sources, sample)
                    })
                    self._data.append(sample)
                    accepted += 1
                except Exception:
                    env = Environment()
                    # in debug mode surface the failure instead of skipping the sample
                    if env.debug:
                        raise
                    logger.warning('sample {} is discarded'.format(sample))
                    discarded += 1
            if self._sort_samples:
                self._data.sort(key=count_sample_token)
            self._length = len(self._data)
            logger.info(f'Totally accept {accepted} samples, discard {discarded} samples')
        finally:
            # close every source file even if loading fails part-way
            for handle in ori_fin:
                handle.close()

        # NOTE(review): samples were already processed by _full_callback above;
        # presumably cleared so they are not collated a second time — confirm.
        self._collate_fn = None
77 |
--------------------------------------------------------------------------------
/mybycha/bycha/datasets/streaming_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from torch.utils.data import IterableDataset
4 |
5 | from bycha.datasets import AbstractDataset
6 |
7 |
class StreamingDataset(AbstractDataset, IterableDataset):
    """
    Base class for datasets that read samples lazily from an input stream
    rather than loading them into memory.

    Args:
        path: a dict of data with their path. `path` can be `None` to build the process pipeline only.
    """

    def __init__(self,
                 path: Dict[str, str],):
        super().__init__(path)
        # input stream, opened by subclasses
        self._fin = None
        self._length = None

    def shuffle(self):
        """
        No-op: a stream cannot be shuffled in place.
        """
        return None

    def __getitem__(self, index):
        """
        Return the next sample of the stream.

        Args:
            index: ignored, because a stream only supports sequential access

        Returns:
            sample: whatever ``next(self)`` produces
        """
        return next(self)

    def __next__(self):
        """
        Fetch the next sample; subclasses must override this.
        """
        raise NotImplementedError

    def reset(self):
        """
        Hook for subclasses to restart their stream; does nothing here.
        """
        return None

    def __len__(self):
        """
        The length of a stream is unknown, so 0 is reported.

        Returns:
            0
        """
        return 0
64 |
65 |
--------------------------------------------------------------------------------
/mybycha/bycha/datasets/streaming_json_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | logger = logging.getLogger(__name__)
4 |
5 | from bycha.datasets import register_dataset
6 | from bycha.datasets.streaming_dataset import StreamingDataset
7 | from bycha.utils.io import UniIO
8 |
9 |
@register_dataset
class StreamingJsonDataset(StreamingDataset):
    """
    StreamingJsonDataset is a streaming dataset for reading data saved with json.dumps,
    one JSON document per line.

    Args:
        path: path of the data to read. `path` can be `None` to build the process pipeline only.
    """

    def __init__(self, path):
        super().__init__(path)

    def build(self, collate_fn=None, preprocessed=False):
        """
        Build input stream

        Args:
            collate_fn: callback defined by a specific task
            preprocessed: data has been processed
        """
        self._collate_fn = collate_fn
        self._preprocessed = preprocessed

        if self._path:
            self._fin = UniIO(self._path)

    def __iter__(self):
        """
        Iterate over processed samples; samples whose processing raises are
        logged and skipped.

        Yields:
            sample: next processed sample
        """
        for sample in self._fin:
            try:
                sample = self._full_callback(sample)
            except Exception as e:
                logger.warning(e)
                continue
            yield sample

    def _callback(self, sample):
        """
        Callback for json data

        Args:
            sample: one line of raw JSON text

        Returns:
            sample: the decoded JSON value
        """
        sample = json.loads(sample)
        return sample

    def reset(self):
        """
        reset the dataset by reopening the input stream
        """
        self._pos = 0
        # close the previous stream first so repeated resets don't leak handles
        if self._fin is not None:
            self._fin.close()
        self._fin = UniIO(self._path)

    def finalize(self):
        """
        Finalize dataset after finish reading
        """
        # nothing to close if build() never opened a stream
        if self._fin is not None:
            self._fin.close()
75 |
--------------------------------------------------------------------------------
/mybycha/bycha/datasets/streaming_parallel_text_dataset.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from bycha.datasets import register_dataset
4 | from bycha.datasets.streaming_dataset import StreamingDataset
5 | from bycha.utils.io import UniIO
6 |
7 |
@register_dataset
class StreamingParallelTextDataset(StreamingDataset):
    """
    StreamingParallelTextDataset is a streaming dataset for reading data saved in parallel files.

    Args:
        path: a dict of data with their path. `path` can be `None` to build the process pipeline only.
    """

    def __init__(self,
                 path: Dict[str, str] = None,):
        super().__init__(path)
        # `path` may be None (pipeline only), in which case there are no sources
        self._sources = list(path.keys()) if path else []
        self._ori_fin = None

    def build(self, collate_fn=None, preprocessed=False):
        """
        Build input stream

        Args:
            collate_fn: callback defined by a specific task
            preprocessed: whether the data has been preprocessed
        """
        self._collate_fn = collate_fn
        self._preprocessed = preprocessed

        if self._path:
            self._open_streams()

    def _open_streams(self):
        """Open one stream per source and zip them into aligned samples."""
        self._ori_fin = [UniIO(self._path[src]) for src in self._sources]
        self._fin = zip(*self._ori_fin)

    def _close_streams(self):
        """Close the per-source streams, if any were opened."""
        if self._ori_fin is not None:
            for handle in self._ori_fin:
                handle.close()

    def __iter__(self):
        """
        Iterate over aligned samples from every source.

        Yields:
            sample: next processed sample
        """
        for sample in self._fin:
            yield self._full_callback({src: s.strip('\n') for src, s in zip(self._sources, sample)})

    def _callback(self, sample):
        """
        Callback for parallel data

        Args:
            sample: dict mapping each source name to its raw line

        Returns:
            sample (dict): a dict of samples consisting of parallel data of different sources
        """
        if self._preprocessed:
            # each value is a line of whitespace-separated token literals:
            # split it first, since iterating the raw string would eval single
            # characters (and fail on the separating spaces).
            # NOTE(review): eval executes arbitrary expressions from the data
            # file; only load trusted preprocessed data.
            sample = {
                key: [eval(v) for v in (val.split() if isinstance(val, str) else val)]
                for key, val in sample.items()
            }
        return sample

    def finalize(self):
        """
        Finalize dataset after finish reading
        """
        self._close_streams()

    def reset(self):
        """
        reset the dataset by reopening every source
        """
        self._pos = 0
        # close the previous streams first so repeated resets don't leak handles
        self._close_streams()
        self._open_streams()
77 |
78 |
--------------------------------------------------------------------------------
/mybycha/bycha/datasets/streaming_text_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger(__name__)
3 |
4 | from bycha.datasets import register_dataset
5 | from bycha.datasets.streaming_dataset import StreamingDataset
6 | from bycha.utils.io import UniIO
7 |
8 |
@register_dataset
class StreamingTextDataset(StreamingDataset):
    """
    StreamingTextDataset is a streaming dataset for reading data in textual format.

    Args:
        path: path to load the data
    """

    def __init__(self,
                 path,):
        super().__init__(path)

    def build(self, collate_fn=None, preprocessed=False):
        """
        Build input stream

        Args:
            collate_fn: callback defined by a specific task
            preprocessed: whether the data has been preprocessed
        """
        self._collate_fn = collate_fn
        # keep the flag like the sibling streaming datasets do
        self._preprocessed = preprocessed

        if self._path:
            self._fin = UniIO(self._path)

    def __iter__(self):
        """
        Iterate over processed samples. A StopIteration raised while processing
        ends the iteration; other errors are logged and the sample is skipped.

        Yields:
            sample: next processed sample
        """
        for sample in self._fin:
            try:
                sample = self._full_callback(sample)
            except StopIteration:
                # PEP 479: raising StopIteration inside a generator turns into a
                # RuntimeError, so finish the generator with return instead
                return
            except Exception as e:
                logger.warning(e)
                continue
            yield sample

    def reset(self):
        """
        reset the dataset by reopening the input stream
        """
        self._pos = 0
        # close the previous stream first so repeated resets don't leak handles
        if self._fin is not None:
            self._fin.close()
        self._fin = UniIO(self._path)

    def _callback(self, sample):
        """
        Callback for textual data

        Args:
            sample: one raw line of text

        Returns:
            sample (str): the line without its trailing newline and surrounding whitespace
        """
        sample = sample.strip('\n').strip()
        return sample

    def finalize(self):
        """
        Finalize dataset after finish reading
        """
        # nothing to close if build() never opened a stream
        if self._fin is not None:
            self._fin.close()
75 |
76 |
--------------------------------------------------------------------------------
/mybycha/bycha/datasets/text_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger(__name__)
3 |
4 | from bycha.datasets import register_dataset
5 | from bycha.datasets.in_memory_dataset import InMemoryDataset
6 | from bycha.utils.data import count_sample_token
7 | from bycha.utils.io import UniIO
8 | from bycha.utils.runtime import progress_bar
9 |
10 |
@register_dataset
class TextDataset(InMemoryDataset):
    """
    TextDataset is an in-memory dataset for reading data in textual format.

    Args:
        path: path to load the data
        sort_samples (bool): sort samples before running a task.
            It would be useful in inference without degrading performance.
        max_size: maximum size of loaded data
    """

    def __init__(self,
                 path: str = None,
                 sort_samples: bool = False,
                 max_size: int = 0,):
        super().__init__(path, sort_samples=sort_samples, max_size=max_size)

    def _callback(self, sample):
        """
        Callback for textual data

        Args:
            sample: one raw line of text

        Returns:
            sample: the stripped line, or the list of evaluated whitespace-separated
                tokens when the data is preprocessed
        """
        sample = sample.strip('\n')
        if self._preprocessed:
            # NOTE(review): eval executes arbitrary expressions from the data
            # file; only load trusted preprocessed data.
            sample = [eval(v) for v in sample.split()]
        return sample

    def _load(self):
        """
        Preload all the data into memory. In the loading process, data are
        processed by ``_full_callback`` and optionally sorted by token count.
        Samples whose processing raises are logged and skipped.
        """
        fin = UniIO(path=self._path)
        try:
            self._data = []
            accepted, discarded = 0, 0
            for i, sample in enumerate(progress_bar(fin, streaming=True, desc='Loading Samples')):
                if 0 < self._max_size <= i:
                    break
                try:
                    self._data.append(self._full_callback(sample))
                    accepted += 1
                except Exception:
                    logger.warning('sample {} is discarded'.format(sample))
                    discarded += 1
            if self._sort_samples:
                self._data.sort(key=count_sample_token)
            self._length = len(self._data)
            logger.info(f'Totally accept {accepted} samples, discard {discarded} samples')
        finally:
            # close the input even if reading or sorting fails part-way
            fin.close()
65 |
--------------------------------------------------------------------------------
/mybycha/bycha/datasets/tfrecord_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | logger = logging.getLogger(__name__)
4 |
5 | from bycha.datasets import register_dataset
6 | from bycha.datasets.in_memory_dataset import InMemoryDataset
7 | from bycha.utils.data import count_sample_token
8 | from bycha.utils.runtime import progress_bar
9 |
10 |
@register_dataset
class TFRecordDataset(InMemoryDataset):
    """
    A tfrecord dataset is an in-memory dataset for reading data in tfrecord

    Args:
        path: data path to read
        index_path: path to load data map
        description: description for tfrecord
        sort_samples (bool): sort samples before running a task.
            It would be useful in inference without degrading performance.
        max_size: maximum size of loaded data
    """

    def __init__(self,
                 path,
                 index_path=None,
                 description=None,
                 sort_samples=False,
                 max_size=0):
        super().__init__(path, sort_samples=sort_samples, max_size=max_size)
        self._index_path = index_path
        self._description = description

    def _load(self):
        """
        Preload all the data into memory. In the loading process, data are
        processed by ``_full_callback`` and optionally sorted by token count.
        Samples whose processing raises are logged and skipped.
        """
        import tfrecord
        fin = tfrecord.tfrecord_loader(self._path,
                                       self._index_path,
                                       self._description)
        self._data = []
        accepted, discarded = 0, 0
        for i, sample in enumerate(progress_bar(fin, streaming=True, desc='Loading Samples')):
            if 0 < self._max_size <= i:
                break
            try:
                self._data.append(self._full_callback(sample))
                accepted += 1
            except Exception:
                logger.warning('sample {} is discarded'.format(sample))
                discarded += 1
        if self._sort_samples:
            self._data.sort(key=count_sample_token)
        self._length = len(self._data)
        logger.info(f'Totally accept {accepted} samples, discard {discarded} samples')

    def _callback(self, sample):
        """
        Callback for tfrecord records

        Args:
            sample: a record from the tfrecord loader

        Returns:
            sample: the record; textual (JSON-encoded) records are decoded
        """
        # tfrecord.tfrecord_loader yields already-parsed feature dicts, on which
        # json.loads would raise TypeError and discard every sample; only
        # textual records need JSON decoding
        if isinstance(sample, (str, bytes, bytearray)):
            sample = json.loads(sample)
        return sample
71 |
--------------------------------------------------------------------------------
/mybycha/bycha/entries/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/mybycha/bycha/entries/__init__.py
--------------------------------------------------------------------------------
/mybycha/bycha/entries/binarize_data.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from bycha.entries.util import parse_config
4 | from bycha.tasks import create_task, AbstractTask
5 | from bycha.utils.ops import recursive
6 |
def main():
    """
    Entry point: dump every batch of the task's 'train' dataloader to the
    configured ``output_path``, one JSON document per line.
    """
    confs = parse_config()
    task = create_task(confs.pop('task'))
    assert isinstance(task, AbstractTask)
    task.build()
    train_batches = task._build_dataloader('train', mode='train')
    output_path = confs['output_path']
    # presumably applies x.tolist() to each tensor leaf so the batch is JSON-serialisable
    tensors_to_lists = recursive(lambda x: x.tolist())
    with open(output_path, 'w') as fout:
        for batch in train_batches:
            line = json.dumps(tensors_to_lists(batch))
            fout.write(f'{line}\n')


if __name__ == '__main__':
    main()
24 |
--------------------------------------------------------------------------------
/mybycha/bycha/entries/build_tokenizer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from bycha.tokenizers import registry, AbstractTokenizer
3 | from bycha.utils.runtime import build_env
4 | from bycha.entries.util import parse_config
5 |
6 |
def main():
    """
    Entry point: learn a tokenizer from the configured data.

    The config names the tokenizer ``class``, the training ``data`` and an
    ``output_path``; an optional ``env`` section initialises the runtime, and
    the remaining entries are forwarded to the tokenizer's ``learn``.
    """
    configs = parse_config()
    if 'env' in configs:
        env_conf = configs.pop('env')
        build_env(configs, **env_conf)
    cls = registry[configs.pop('class').lower()]
    assert issubclass(cls, AbstractTokenizer)
    # create the output directory; dirname is '' for a bare filename, and
    # os.makedirs('') would raise FileNotFoundError
    output_dir = os.path.dirname(configs['output_path'])
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    data = configs.pop('data')
    cls.learn(data, **configs)


if __name__ == '__main__':
    main()
21 |
--------------------------------------------------------------------------------
/mybycha/bycha/entries/export.py:
--------------------------------------------------------------------------------
1 | from bycha.tasks import create_task
2 | from bycha.utils.runtime import build_env
3 | from bycha.entries.util import parse_config
4 |
5 |
def main():
    """
    Build a task from the config file and export it.

    The ``export`` section must contain ``path``; its remaining entries are passed
    as keyword arguments to ``task.export``. An optional ``env`` section is used to
    build the runtime environment for the task.
    """
    confs = parse_config()
    if 'env' in confs:
        build_env(confs['task'], **confs['env'])
    export_conf = confs.pop('export')
    task = create_task(confs.pop('task'))
    task.build()
    export_path = export_conf.pop("path")
    task.export(export_path, **export_conf)


if __name__ == '__main__':
    main()
19 |
--------------------------------------------------------------------------------
/mybycha/bycha/entries/preprocess.py:
--------------------------------------------------------------------------------
1 | from bycha.entries.util import parse_config
2 | from bycha.tasks import create_task, AbstractTask
3 | from bycha.datasets import create_dataset, AbstractDataset
4 |
5 |
def main():
    """
    Preprocess raw data splits and write them out.

    The task is built from the config's ``task`` section. For every entry of
    ``data`` (each with ``path``, ``output_path`` and optional ``data_map_path``),
    a dataset is created from the ``dataset`` section pointed at that entry's
    ``path``, built with the task's collate function (``preprocessed=False``), and
    written to ``output_path``.
    """
    confs = parse_config()
    task = create_task(confs.pop('task'))
    assert isinstance(task, AbstractTask)
    task.build()
    dataset_conf = confs['dataset']
    for conf in confs['data'].values():
        output_path = conf['output_path']
        data_map_path = conf.get('data_map_path')
        # Give each split its own shallow copy instead of mutating the shared
        # ``dataset`` section, so per-split values (and any keys create_dataset
        # may consume) cannot leak into the next split.
        split_dataset_conf = dict(dataset_conf, path=conf['path'])
        dataset = create_dataset(split_dataset_conf)
        assert isinstance(dataset, AbstractDataset)
        dataset.build(collate_fn=task._data_collate_fn, preprocessed=False)
        dataset.write(path=output_path, data_map_path=data_map_path)


if __name__ == '__main__':
    main()
24 |
--------------------------------------------------------------------------------
/mybycha/bycha/entries/run.py:
--------------------------------------------------------------------------------
1 | from bycha.tasks import create_task
2 | from bycha.utils.runtime import build_env
3 | from bycha.entries.util import parse_config
4 |
5 |
def main():
    """
    Entry point: build the task described by the config file and run it.

    When the config has an ``env`` section, the runtime environment is built for
    the task before the task is created.
    """
    confs = parse_config()
    if 'env' in confs:
        build_env(confs['task'], **confs['env'])
    task_conf = confs.pop('task')
    task = create_task(task_conf)
    task.build()
    task.run()


if __name__ == '__main__':
    main()
17 |
--------------------------------------------------------------------------------
/mybycha/bycha/entries/serve.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | logger = logging.getLogger(__name__)
4 |
5 | import euler
6 |
7 | from bycha.entries.util import parse_config
8 | from bycha.services import Server, Service
9 | from bycha.utils.runtime import build_env
10 |
11 |
def main():
    """
    Run an euler thrift service that forwards ``serve`` requests to a ``Server``.

    The environment is always built with ``device='cpu'``. The listening port
    comes from the ``SERVER_PORT`` env var (default 18001). The number of workers
    comes from the config's ``worker`` entry (default 8).
    """
    configs = parse_config()

    env = configs.pop('env')
    env['device'] = 'cpu'
    build_env(configs['task'], **env)

    # configs is a dict, so ``worker`` must be read as a key; getattr() on a dict
    # never sees it and would always fall back to 8.
    workers_count = configs.get('worker', 8)

    server = Server(configs)
    app = euler.Server(Service)

    @app.register('serve')
    def serve(ctx, req):
        return server.serve(req)

    server_port = int(os.environ.get('SERVER_PORT', 18001))
    logger.info('Starting thrift server in python on PORT {}...'.format(server_port))
    app.run("tcp://0.0.0.0:{}".format(server_port),
            transport="buffered",
            workers_count=workers_count)
    logger.info('exit!')


if __name__ == '__main__':
    main()
36 |
--------------------------------------------------------------------------------
/mybycha/bycha/entries/serve_model.py:
--------------------------------------------------------------------------------
1 | from bycha.entries.util import parse_config
2 | import os
3 | import logging
4 | logger = logging.getLogger(__name__)
5 |
6 | from thriftpy.rpc import make_server
7 | import thriftpy
8 |
9 | from bycha.services.model_server import ModelServer
10 | from bycha.tasks import create_task
11 | from bycha.utils.runtime import build_env
12 |
13 |
def main():
    """
    Serve a task's generator through a thriftpy RPC server on localhost.

    The port is read from the ``GRPC_PORT`` env var (default 6000). The thrift IDL
    is loaded from the ``MODEL_INFER_THRIFT`` env var. When that variable is
    unset, the IDL is loaded from the original deployment path.
    """
    configs = parse_config()
    if 'env' in configs:
        build_env(configs['task'], **configs['env'])
    task = create_task(configs.pop('task'))
    task.build()
    model = ModelServer(task._generator)
    # Despite its name, GRPC_PORT is the port of the thriftpy server created below.
    grpc_port = int(os.environ.get('GRPC_PORT', 6000))
    # The IDL location used to be hard-coded to one deployment layout; allow
    # overriding it while keeping that layout as the default.
    idl_path = os.environ.get('MODEL_INFER_THRIFT',
                              '/opt/tiger/ByCha/bycha/services/idls/model_infer.thrift')
    model_infer_thrift = thriftpy.load(idl_path, module_name="model_infer_thrift")
    server = make_server(model_infer_thrift.ModelInfer,
                         model,
                         'localhost',
                         grpc_port,
                         client_timeout=None)
    logger.info('Starting Serving Model')
    server.serve()


if __name__ == '__main__':
    main()
35 |
--------------------------------------------------------------------------------
/mybycha/bycha/entries/util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 |
4 | from bycha.utils.data import possible_eval
5 | from bycha.utils.io import UniIO
6 |
7 |
def parse_config():
    """
    Parse configurations from a YAML config file plus command-line overrides.

    ``--config PATH`` selects the YAML file. ``${name}`` placeholders are then
    substituted from its ``define`` section (see ``stringizing``). Every other
    ``--a.b.c VALUE`` or ``--a.b.c=VALUE`` argument overrides the nested entry
    ``confs['a']['b']['c']``. Missing intermediate dicts are created, and values
    go through ``possible_eval``. Overrides are applied after substitution, so
    they are not substituted. ``--lib PKG`` is prepended to
    ``confs['env']['custom_libs']`` (comma-separated).

    Returns:
        dict: the parsed configuration (``{}`` for an empty config file)

    Raises:
        ValueError: if an override value appears before any ``--key``
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', metavar='N', type=str, help='config path')
    parser.add_argument('--lib', metavar='N', default=None, type=str, help='customization package')
    args, unknown = parser.parse_known_args()
    with UniIO(args.config) as fin:
        # NOTE(review): FullLoader can construct some Python objects; only load trusted configs.
        # An empty YAML file loads as None; treat it as an empty config.
        confs = yaml.load(fin, Loader=yaml.FullLoader) or {}
    stringizing(confs)

    kv_pairs = []
    current_key = None
    for ele in unknown:
        if ele.startswith("--"):
            # Support both ``--key value`` and ``--key=value``.
            current_key, sep, value = ele[2:].partition("=")
            if sep:
                kv_pairs.append((current_key, value))
        elif current_key is None:
            raise ValueError('override value {!r} is not preceded by a --key argument'.format(ele))
        else:
            kv_pairs.append((current_key, ele))

    for key, raw_value in kv_pairs:
        value = possible_eval(raw_value)
        *parents, last_key = key.split(".")
        node = confs
        for k in parents:
            node = node.setdefault(k, {})
        node[last_key] = value

    if args.lib:
        env = confs.setdefault('env', {})
        custom_libs = [args.lib]
        if 'custom_libs' in env:
            custom_libs.append(env['custom_libs'])
        env['custom_libs'] = ','.join(custom_libs)
    return confs
46 |
def stringizing(conf: dict):
    """
    Substitute ``${name}`` placeholders using the config's ``define`` section.

    The ``define`` section is popped from ``conf``. Every string found in the
    remaining nested dicts and lists is rewritten in place:

    * a string that is exactly one placeholder whose defined value is not a
      string becomes that value itself, so its type is kept;
    * otherwise each placeholder occurrence is replaced by ``str(value)``.

    Args:
        conf: configuration dict, modified in place; left untouched when it has
            no ``define`` section
    """
    if "define" not in conf:
        return
    definition = {"${" + k + "}": v for k, v in (conf.pop("define") or {}).items()}

    def _substitute(value):
        if isinstance(value, str):
            if value in definition and not isinstance(definition[value], str):
                return definition[value]
            for placeholder, replacement in definition.items():
                if placeholder in value:
                    # str() so non-string defines (ints, floats, ...) can be embedded.
                    value = value.replace(placeholder, str(replacement))
            return value
        if isinstance(value, dict):
            for k, v in value.items():
                value[k] = _substitute(v)
        elif isinstance(value, list):
            for i, v in enumerate(value):
                value[i] = _substitute(v)
        return value

    _substitute(conf)
62 |
--------------------------------------------------------------------------------
/mybycha/bycha/evaluators/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
6 | from .abstract_evaluator import AbstractEvaluator
7 |
register_evaluator, create_evaluator, registry = setup_registry('evaluator', AbstractEvaluator)

# Import every public submodule / subpackage of this package so the
# @register_evaluator decorators inside them run and fill the registry.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    if file.startswith(('_', '.')):
        continue
    if file.endswith('.py'):
        module_name = file[:file.find('.py')]
    elif os.path.isdir(path):
        module_name = file
    else:
        continue
    module = importlib.import_module('bycha.evaluators.' + module_name)
20 |
--------------------------------------------------------------------------------
/mybycha/bycha/evaluators/abstract_evaluator.py:
--------------------------------------------------------------------------------
class AbstractEvaluator:
    """
    Base interface of an evaluation scheduler.

    Subclasses must implement ``build``, ``finalize``, ``_step`` and ``eval``.
    The ``_step_reset``/``_step_update``/``_eval_reset``/``_eval_update`` hooks
    are optional and do nothing by default.
    """

    def __init__(self):
        pass

    def build(self, *args, **kwargs):
        """
        Build the evaluator from the given configs and components.
        """
        raise NotImplementedError

    def finalize(self):
        """
        Finalize the evaluator once evaluation is finished.
        """
        raise NotImplementedError

    def _step_reset(self, *args, **kwargs):
        """
        Hook: reset per-step states (no-op by default).
        """

    def _step(self, samples):
        """
        Evaluate a single batch.

        Args:
            samples: a batch of samples
        """
        raise NotImplementedError

    def _step_update(self, *args, **kwargs):
        """
        Hook: update per-step states (no-op by default).
        """

    def _eval_reset(self, *args, **kwargs):
        """
        Hook: reset states before the whole evaluation process (no-op by default).
        """

    def _eval_update(self, *args, **kwargs):
        """
        Hook: update states after the whole evaluation process (no-op by default).
        """

    def eval(self):
        """
        Run the evaluation process.
        """
        raise NotImplementedError
59 |
--------------------------------------------------------------------------------
/mybycha/bycha/evaluators/multi_evaluator.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | from bycha.evaluators import AbstractEvaluator, register_evaluator, create_evaluator
4 |
5 |
@register_evaluator
class MultiTaskEvaluator(AbstractEvaluator):
    """
    MultiTaskEvaluator runs several evaluators and merges their scores.

    Every configured evaluator is built with the same arguments. Its scores are
    reported under keys prefixed with the upper-cased evaluator name, e.g.
    ``'DEV.bleu'``.

    Args:
        evaluators (dict): evaluator name -> evaluator configuration passed to
            ``create_evaluator``
    """

    def __init__(self,
                 evaluators: Dict,
                 ):
        super().__init__()
        self._evaluator_configs = evaluators

        self._evaluators = None
        self._task_callback = None

    def build(self, generator, dataloaders, tokenizer, task_callback=None, postprocess=None):
        """
        Build evaluators with the given arguments.
        Arguments are dispatched to all the evaluators respectively.

        Args:
            generator (bycha.generators.AbstractGenerator): the inference model to generate hypothesis
            dataloaders (dict[bycha.dataloaders.AbstractDataLoader]): a set of dataloaders to evaluate
            tokenizer (bycha.tokenizers.AbstractTokenizer): a tokenizer
            task_callback: optional callback used to set up the task context before each evaluator runs
            postprocess: postprocess pipeline to obtain final hypothesis from predicted results (torch.Tensor)
        """
        self._evaluators = {}
        for name, config in self._evaluator_configs.items():
            evaluator = create_evaluator(config)
            self._evaluators[name.upper()] = evaluator
            evaluator.build(generator=generator,
                            dataloaders=dataloaders,
                            tokenizer=tokenizer,
                            task_callback=task_callback,
                            postprocess=postprocess)
        self._task_callback = task_callback

    def eval(self):
        """
        Perform evaluation with each evaluator.

        Returns:
            dict: ``'{NAME}.{key}' -> value`` merged from every evaluator's results
        """
        scores = {}
        for name, evaluator in self._evaluators.items():
            # task_callback is optional in build(); calling None would raise TypeError.
            if self._task_callback is not None:
                self._task_callback(training=False, infering=True)
            states = evaluator.eval()
            scores.update({'{}.{}'.format(name, key): val for key, val in states.items()})
        return scores

    def finalize(self):
        """
        Finalize every sub-evaluator (nothing to do before ``build``).
        """
        for evaluator in (self._evaluators or {}).values():
            evaluator.finalize()
56 |
--------------------------------------------------------------------------------
/mybycha/bycha/generators/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
6 | from .abstract_generator import AbstractGenerator
7 |
register_generator, create_generator, registry = setup_registry('generator', AbstractGenerator)

# Import every public submodule / subpackage of this package so the
# @register_generator decorators inside them run and fill the registry.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    if file.startswith(('_', '.')):
        continue
    if file.endswith('.py'):
        module_name = file[:file.find('.py')]
    elif os.path.isdir(path):
        module_name = file
    else:
        continue
    module = importlib.import_module('bycha.generators.' + module_name)
20 |
21 |
--------------------------------------------------------------------------------
/mybycha/bycha/generators/abstract_generator.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger(__name__)
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from bycha.utils.ops import inspect_fn
8 | from bycha.utils.runtime import Environment
9 | from bycha.utils.io import UniIO, mkdir
10 |
11 |
class AbstractGenerator(nn.Module):
    """
    Base class wrapping a model together with its inference algorithm.

    A generator is either built from an in-memory model (``build_from_model``)
    or restored from a TorchScript file (``path``). It can be exported and then
    used directly for inference or serving.

    Args:
        path: path to restore a traced model from; ``None`` builds from a model
    """

    def __init__(self, path):
        super().__init__()

        self._path = path
        self._traced_model = None
        self._model = None
        self._mode = 'infer'

    def build(self, *args, **kwargs):
        """
        Load the generator from ``path`` when given, otherwise build it from a
        model (arguments are forwarded to ``build_from_model``). Then move it to
        the environment's device if that device is CUDA.
        """
        if self._path is None:
            self.build_from_model(*args, **kwargs)
        else:
            self.load()

        self._env = Environment()
        device = self._env.device
        if device.startswith('cuda'):
            logger.info('move model to {}'.format(device))
            self.cuda(device)

    def build_from_model(self, *args, **kwargs):
        """
        Build the generator from a model; must be implemented by subclasses.
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """
        Run inference, preferring the traced model when one has been loaded and
        falling back to the eager ``_forward`` implementation otherwise.
        """
        runner = self._forward if self._traced_model is None else self._traced_model
        return runner(*args, **kwargs)

    def _forward(self, *args, **kwargs):
        """
        Eager (non-traced) inference; must be implemented by subclasses.
        """
        raise NotImplementedError

    def export(self, path, net_input, **kwargs):
        """
        Trace the wrapped model and save it as ``<path>/model``.

        Args:
            path: directory to store the serialized model
            net_input: example input used to trace the model's ``forward``
        """
        self.eval()
        with torch.no_grad():
            logger.info('trace model {}'.format(type(self._model).__name__))
            traced = torch.jit.trace_module(self._model, {'forward': net_input})
        mkdir(path)
        model_file = '{}/model'.format(path)
        logger.info('save model to {}'.format(model_file))
        with UniIO(model_file, 'wb') as fout:
            torch.jit.save(traced, fout)

    def load(self):
        """
        Load the serialized (TorchScript) model stored at ``path``.
        """
        logger.info('load model from {}'.format(self._path))
        with UniIO(self._path, 'rb') as fin:
            self._traced_model = torch.jit.load(fin)

    def reset(self, *args, **kwargs):
        """
        Reset generator states (no-op by default).
        """

    @property
    def input_slots(self):
        """
        Input slot names, auto-detected from ``_forward``.
        """
        return inspect_fn(self._forward)
102 |
--------------------------------------------------------------------------------
/mybycha/bycha/generators/generator.py:
--------------------------------------------------------------------------------
1 | from bycha.generators import AbstractGenerator, register_generator
2 | from bycha.utils.ops import inspect_fn
3 |
4 |
@register_generator
class Generator(AbstractGenerator):
    """
    Generator wraps a model with a simple prediction step.
    It has the same interface as the model and can be directly exported and used
    for inference or serving.

    Args:
        path: path to export or load generator
        is_regression: whether the task is a regression task; model outputs are
            then returned unchanged
        is_binary_classification: whether outputs are thresholded at 0.5 into
            {0, 1} labels; otherwise (non-regression) the argmax over the last
            dimension is returned
    """

    def __init__(self,
                 path=None,
                 is_regression=False,
                 is_binary_classification=False):
        super().__init__(path)
        self._is_regression = is_regression
        self._is_binary_classification = is_binary_classification

        self._model = None

    def build_from_model(self, model):
        """
        Build generator from model

        Args:
            model (bycha.models.AbstractModel): a neural model
        """
        self._model = model

    def _forward(self, *args, **kwargs):
        """
        Run the model and turn its output into predictions.

        Keyword arguments are accepted and forwarded to the model, because
        ``AbstractGenerator.forward`` passes ``**kwargs`` through to ``_forward``.

        Args:
            *args, **kwargs: inference inputs of the wrapped model
        """
        output = self._model(*args, **kwargs)
        if self._is_regression:
            return output
        if self._is_binary_classification:
            return (output > 0.5).long()
        _, output = output.max(dim=-1)
        return output

    def reset(self, mode):
        """
        Reset generator states.

        Args:
            mode: running mode; any mode other than 'train' switches to eval mode
        """
        if mode != 'train':
            self.eval()
        self._mode = mode
        self._model.reset(mode)

    @property
    def model(self):
        return self._model

    @property
    def input_slots(self):
        """
        Generator input slots, auto-detected from the wrapped model's ``forward``.
        """
        return inspect_fn(self._model.forward)
73 |
--------------------------------------------------------------------------------
/mybycha/bycha/generators/self_contained_generator.py:
--------------------------------------------------------------------------------
1 | from bycha.generators import AbstractGenerator, register_generator
2 |
3 |
@register_generator
class SelfContainedGenerator(AbstractGenerator):
    """
    Generator that delegates inference to the model's own ``generate`` method.
    It has the same interface as the model and can be directly exported and used
    for inference or serving.

    Args:
        path: path to export or load generator
        **kwargs: keyword arguments passed to ``model.generate`` on every call;
            they take precedence over same-named call-time keyword arguments
    """

    def __init__(self,
                 path=None, **kwargs):
        super().__init__(path)
        self._kwargs = kwargs
        self._model = None

    def build_from_model(self, model, *args, **kwargs):
        """
        Build generator from model; extra arguments are ignored.

        Args:
            model (bycha.models.AbstractModel): a neural model
        """
        self._model = model

    def _forward(self, *args, **kwargs):
        """
        Call ``model.generate`` with the inputs merged with the configured kwargs.
        """
        generate_kwargs = {**kwargs, **self._kwargs}
        return self._model.generate(*args, **generate_kwargs)

    @property
    def model(self):
        return self._model
41 |
--------------------------------------------------------------------------------
/mybycha/bycha/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
6 | from .abstract_metric import AbstractMetric
7 | from .pairwise_metric import PairwiseMetric
8 |
register_metric, create_metric, registry = setup_registry('metric', AbstractMetric)

# Import every public submodule / subpackage of this package so the
# @register_metric decorators inside them run and fill the registry.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    if file.startswith(('_', '.')):
        continue
    if file.endswith('.py'):
        module_name = file[:file.find('.py')]
    elif os.path.isdir(path):
        module_name = file
    else:
        continue
    module = importlib.import_module('bycha.metrics.' + module_name)
21 |
--------------------------------------------------------------------------------
/mybycha/bycha/metrics/abstract_metric.py:
--------------------------------------------------------------------------------
class AbstractMetric:
    """
    Base class for metrics that score produced hypotheses against references.

    ``_score`` caches the evaluation result; it is ``None`` until computed.
    """

    def __init__(self):
        self._score = None

    def build(self, *args, **kwargs):
        """
        Build the metric; by default this only resets it.
        """
        self.reset()

    def reset(self):
        """
        Reset the metric for a new round of evaluation (no-op by default).
        """

    def add_all(self, *args, **kwargs):
        """
        Add a collection of hypotheses and references to the metric buffer.
        """
        raise NotImplementedError

    def add(self, *args, **kwargs):
        """
        Add one hypothesis with its reference to the metric buffer.
        """
        raise NotImplementedError

    def eval(self):
        """
        Score the buffered hypotheses against their references.
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def get_item(self, idx, to_str=False):
        """
        Fetch the item at the given index.

        Args:
            idx: index of a pair of hypothesis and reference
            to_str: format the pair as a string before returning it

        Returns:
            item: the pair as a tuple, or as a string when ``to_str`` is set
        """
        raise NotImplementedError
54 |
55 |
--------------------------------------------------------------------------------
/mybycha/bycha/metrics/accuracy.py:
--------------------------------------------------------------------------------
1 | from bycha.metrics import PairwiseMetric, register_metric
2 |
3 |
@register_metric
class Accuracy(PairwiseMetric):
    """
    Accuracy evaluates accuracy of produced hypotheses labels by comparing with references.
    """

    def __init__(self):
        super().__init__()

    def eval(self):
        """
        Calculate the accuracy of produced hypotheses comparing with references.
        The result is cached until the metric is reset.

        Returns:
            score (float): fraction of hypotheses equal to their reference; 0.0
                when no pairs have been added (instead of dividing by zero)
        """
        if self._score is None:
            total = len(self.hypos)
            correct = sum(1 for hypo, ref in zip(self.hypos, self.refs) if hypo == ref)
            self._score = correct / total if total else 0.0
        return self._score
27 |
--------------------------------------------------------------------------------
/mybycha/bycha/metrics/bleu.py:
--------------------------------------------------------------------------------
1 | import sacrebleu
2 |
3 | from bycha.metrics import PairwiseMetric, register_metric
4 |
5 |
@register_metric
class BLEU(PairwiseMetric):
    """
    BLEU evaluates corpus-level BLEU (via sacrebleu) of produced hypotheses
    against references.

    Args:
        no_tok: disable sacrebleu tokenization (``tokenize='none'``)
        lang: language used to choose the sacrebleu tokenizer when ``no_tok`` is False
    """

    def __init__(self, no_tok=False, lang='en'):
        super().__init__()
        self._no_tok = no_tok
        self._lang = lang

        tokenize = 'none' if self._no_tok else get_tokenize_by_lang(self._lang)
        self._sacrebleu_kwargs = {'tokenize': tokenize}

    def build(self, *args, **kwargs):
        """
        Build the metric by resetting it.
        """
        self.reset()

    def add(self, hypo, ref):
        """
        Buffer one hypothesis with its reference(s). A single string reference is
        wrapped into a one-element list.
        """
        refs = [ref] if isinstance(ref, str) else ref
        self._hypos.append(hypo)
        self._refs.append(refs)

    def eval(self):
        """
        Compute (and cache) the corpus BLEU score of the buffered pairs.
        """
        if self._score is None:
            # Transpose per-hypothesis reference lists into one stream per
            # reference index, as sacrebleu expects; zip stops at the smallest
            # number of references.
            ref_streams = list(zip(*self._refs))
            result = sacrebleu.corpus_bleu(self._hypos, ref_streams, **self._sacrebleu_kwargs)
            self._score = result.score
        return self._score
49 |
50 |
def get_tokenize_by_lang(lang):
    """
    Pick the sacrebleu tokenizer name for a language code.

    Returns 'zh' for Chinese, 'char' for Korean and '13a' for anything else.
    """
    if lang == 'zh':
        return 'zh'
    if lang == 'ko':
        return 'char'
    return '13a'
58 |
--------------------------------------------------------------------------------
/mybycha/bycha/metrics/f1.py:
--------------------------------------------------------------------------------
1 | from bycha.metrics import PairwiseMetric, register_metric
2 |
3 |
@register_metric
class F1(PairwiseMetric):
    """
    F1 evaluates the F1 score of produced hypothesis labels for one target label
    by comparing them with references.

    Args:
        target_label: the label treated as the positive class; ``int`` labels use
            a vectorized torch implementation
    """

    def __init__(self, target_label):
        super().__init__()
        self._target_label = target_label

        self._precision, self._recall = 0, 0

    def eval(self):
        """
        Calculate the f1-score of produced hypotheses comparing with references.
        Precision and recall are stored and exposed via ``precision``/``recall``.

        Returns:
            score (float): evaluation score
        """
        if self._score is not None:
            return self._score
        if isinstance(self._target_label, int):
            self._precision, self._recall = self._fast_precision_recall()
        else:
            self._precision, self._recall = self._precision_recall()
        self._score = self._precision * self._recall * 2 / (self._precision + self._recall)
        return self._score

    def _precision_recall(self):
        """
        Count-based precision/recall for the target label.

        A false positive is a target-label prediction for a non-target reference.
        Confusions between two non-target labels are not counted against the
        target, matching ``_fast_precision_recall``.
        """
        # 1e-8 on true positives keeps the divisions (and F1) finite with no hits.
        true_positive, false_positive, false_negative = 1e-8, 0, 0
        for hypo, ref in zip(self.hypos, self.refs):
            ref_is_target = ref == self._target_label
            hypo_is_target = hypo == self._target_label
            if ref_is_target and hypo_is_target:
                true_positive += 1
            elif ref_is_target:
                false_negative += 1
            elif hypo_is_target:
                false_positive += 1
        precision = true_positive / (true_positive + false_positive)
        recall = true_positive / (true_positive + false_negative)
        return precision, recall

    def _fast_precision_recall(self):
        """
        Vectorized precision/recall for integer labels, run on CUDA when the
        environment device is CUDA.
        """
        import torch
        hypos = torch.LongTensor(self.hypos)
        refs = torch.LongTensor(self.refs)

        from bycha.utils.runtime import Environment
        env = Environment()
        if env.device.startswith('cuda'):
            hypos, refs = hypos.cuda(), refs.cuda()
        with torch.no_grad():
            true_mask = refs.eq(self._target_label)
            pos_mask = hypos.eq(self._target_label)
            true_positive = true_mask.masked_fill(~pos_mask, False).long().sum().data.item() + 1e-8
            false_positive = (~true_mask).masked_fill(~pos_mask, False).long().sum().data.item()
            false_negative = true_mask.masked_fill(pos_mask, False).long().sum().data.item()
            precision = true_positive / (true_positive + false_positive)
            recall = true_positive / (true_positive + false_negative)
        return precision, recall

    @property
    def precision(self):
        return self._precision

    @property
    def recall(self):
        return self._recall
75 |
--------------------------------------------------------------------------------
/mybycha/bycha/metrics/matthews_corr.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import matthews_corrcoef
2 |
3 | from bycha.metrics import PairwiseMetric, register_metric
4 |
5 |
@register_metric
class MatthewsCorr(PairwiseMetric):
    """
    MatthewsCorr evaluates matthews correlation of produced hypotheses labels by comparing with references.
    """

    def __init__(self):
        super().__init__()

    def eval(self):
        """
        Calculate the Matthews correlation coefficient of produced hypotheses
        comparing with references (references as ``y_true``, hypotheses as
        ``y_pred``). The result is cached until the metric is reset.

        Returns:
            score (float): evaluation score
        """
        if self._score is not None:
            return self._score
        else:
            self._score = matthews_corrcoef(self._refs, self._hypos)
            return self._score
26 |
--------------------------------------------------------------------------------
/mybycha/bycha/metrics/pairwise_metric.py:
--------------------------------------------------------------------------------
1 | from bycha.metrics import AbstractMetric
2 |
3 |
class PairwiseMetric(AbstractMetric):
    """
    PairwiseMetric evaluates a pairwise comparison between hypotheses and references.

    Pairs are buffered in two parallel lists by ``add``/``add_all``. Subclasses
    implement ``eval`` on top of ``hypos`` and ``refs``.
    """

    def __init__(self):
        super().__init__()
        self._hypos = []
        self._refs = []

    def reset(self):
        """
        Clear the buffered pairs and the cached score.
        """
        self._hypos.clear()
        self._refs.clear()
        self._score = None

    def add_all(self, hypos, refs):
        """
        Add hypotheses and references pairwise (stops at the shorter input).
        """
        for pair in zip(hypos, refs):
            self.add(*pair)

    def add(self, hypo, ref):
        """
        Buffer one hypothesis with its reference.
        """
        self._hypos.append(hypo)
        self._refs.append(ref)

    def eval(self):
        """
        Score the buffered hypotheses against their references.
        """
        raise NotImplementedError

    def __len__(self):
        return len(self._hypos)

    def __getitem__(self, idx):
        return self._hypos[idx], self._refs[idx]

    def get_item(self, idx, to_str=False):
        """
        Fetch the item at the given index.

        Args:
            idx: index of a pair of hypothesis and reference
            to_str: format the pair as a string before returning it

        Returns:
            item: the (hypothesis, reference) pair, or its string form when ``to_str`` is set
        """
        item = self[idx]
        if to_str:
            return '\n\tHypothesis: {}\n\tGround Truth: {}\n'.format(item[0], item[1])
        return item

    @property
    def hypos(self):
        return self._hypos

    @property
    def refs(self):
        return self._refs
70 |
--------------------------------------------------------------------------------
/mybycha/bycha/metrics/pearson_corr.py:
--------------------------------------------------------------------------------
1 | from scipy.stats import pearsonr
2 | import numpy as np
3 |
4 | from bycha.metrics import PairwiseMetric, register_metric
5 |
6 |
@register_metric
class PearsonCorr(PairwiseMetric):
    """
    PearsonCorr evaluates pearson's correlation of produced hypotheses labels by comparing with references.
    """

    def __init__(self):
        super().__init__()

    def eval(self):
        """
        Calculate the Pearson correlation coefficient of produced hypotheses
        comparing with references. The result is cached until the metric is reset.

        Returns:
            score (float): evaluation score
        """
        if self._score is not None:
            return self._score
        else:
            self._score = pearsonr(np.array(self._hypos), np.array(self._refs))[0]
            return self._score
27 |
28 |
--------------------------------------------------------------------------------
/mybycha/bycha/metrics/spearman_corr.py:
--------------------------------------------------------------------------------
1 | from scipy.stats import spearmanr
2 | import numpy as np
3 |
4 | from bycha.metrics import PairwiseMetric, register_metric
5 |
6 |
@register_metric
class SpearmanCorr(PairwiseMetric):
    """
    SpearmanCorr scores produced hypotheses by their Spearman rank correlation
    with the references.
    """

    def __init__(self):
        super().__init__()

    def eval(self):
        """
        Compute (and cache until reset) the Spearman correlation between the
        buffered hypotheses and references.

        Returns:
            score (float): evaluation score
        """
        if self._score is None:
            hypos = np.array(self._hypos)
            refs = np.array(self._refs)
            self._score = spearmanr(hypos, refs)[0]
        return self._score
27 |
--------------------------------------------------------------------------------
/mybycha/bycha/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
6 | from .abstract_model import AbstractModel
7 |
register_model, create_model, registry = setup_registry('model', AbstractModel)

# Import every public submodule / subpackage of this package so the
# @register_model decorators inside them run and fill the registry.
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    if file.startswith(('_', '.')):
        continue
    if file.endswith('.py'):
        model_name = file[:file.find('.py')]
    elif os.path.isdir(path):
        model_name = file
    else:
        continue
    module = importlib.import_module('bycha.models.' + model_name)
20 |
--------------------------------------------------------------------------------
/mybycha/bycha/models/abstract_encoder_decoder_model.py:
--------------------------------------------------------------------------------
1 | from bycha.models import AbstractModel
2 |
3 |
class AbstractEncoderDecoderModel(AbstractModel):
    """
    AbstractEncoderDecoderModel defines the interface of an encoder-decoder model.
    It must contain two attributes: encoder and decoder.

    Args:
        path: path to restore the model, forwarded to ``AbstractModel``
        *args, **kwargs: extra arguments, kept as ``self._args`` / ``self._kwargs``
    """

    def __init__(self, path, *args, **kwargs):
        super().__init__(path)
        self._args, self._kwargs = args, kwargs

        self._encoder = None
        self._decoder = None

    @property
    def encoder(self):
        """The encoder (``None`` until a subclass sets ``_encoder``)."""
        return self._encoder

    @property
    def decoder(self):
        """The decoder (``None`` until a subclass sets ``_decoder``)."""
        return self._decoder
24 |
--------------------------------------------------------------------------------
/mybycha/bycha/models/huggingface/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 |
# Import every public module / sub-package of ``bycha.models.huggingface`` so
# that their ``@register_model`` decorators run.
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    # entries starting with '_' (e.g. __init__.py, __pycache__) or '.' are skipped
    if file.startswith(('_', '.')):
        continue
    is_source_file = file.endswith('.py')
    if not is_source_file and not os.path.isdir(path):
        continue
    # for .py files keep the text before the first '.py'; packages keep their directory name
    model_name = file.partition('.py')[0] if is_source_file else file
    module = importlib.import_module('bycha.models.huggingface.' + model_name)
15 |
--------------------------------------------------------------------------------
/mybycha/bycha/models/huggingface/huggingface_extractive_question_answering_model.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModelForQuestionAnswering
2 |
3 | from bycha.models import register_model
4 | from bycha.models.abstract_encoder_decoder_model import AbstractEncoderDecoderModel
5 | from bycha.modules.layers.classifier import HuggingfaceClassifier
6 |
7 |
@register_model
class HuggingfaceExtractiveQuestionAnsweringModel(AbstractEncoderDecoderModel):
    """
    HuggingfaceExtractiveQuestionAnsweringModel is an extractive question answering model built on
    huggingface ``AutoModelForQuestionAnswering`` models.

    Args:
        pretrained_model: name or path of the pretrained model in huggingface
        has_answerable: additionally build a 2-way classification head, intended for
            answerability prediction. NOTE(review): the head is built but is not yet used
            by ``forward`` / ``loss``.
        path: path to restore model
    """

    def __init__(self, pretrained_model, has_answerable=False, path=None):
        # the parent initialises self._encoder / self._decoder to None
        super().__init__(path=path)
        self._pretrained_model = pretrained_model
        self._has_answerable = has_answerable

        self._config = None
        self._model = None
        self._special_tokens = None
        # always define the attribute (None when has_answerable is False) so it can be
        # inspected without checking has_answerable first
        self._classification_head = None

    def _build(self, vocab_size, special_tokens):
        """
        Build model with vocabulary size and special tokens

        Args:
            vocab_size: vocabulary size of input sequence (unused: the vocabulary comes
                from the pretrained model)
            special_tokens: special tokens of input sequence; ``special_tokens['pad']``
                is used to build the attention mask
        """
        self._config = AutoConfig.from_pretrained(self._pretrained_model)
        self._model = AutoModelForQuestionAnswering.from_pretrained(self._pretrained_model,
                                                                    config=self._config)
        self._special_tokens = special_tokens

        if self._has_answerable:
            # The QA model object has no ``d_model`` attribute; the hidden width lives on
            # the config (``hidden_size`` is aliased to ``d_model`` for BART-like configs).
            self._classification_head = HuggingfaceClassifier(self._config.hidden_size, 2)

    def forward(self, input, answerable=None, start_positions=None, end_positions=None):
        """
        Compute output with neural input

        Args:
            input: input token ids
            answerable: gold answerable labels (currently unused)
            start_positions: gold start position
            end_positions: gold end position

        Returns:
            - the huggingface model output; when start/end positions are given its first
              element is the span-extraction loss (see ``loss``)
        """
        output = self._model(input,
                             attention_mask=input.ne(self._special_tokens['pad']),
                             start_positions=start_positions,
                             end_positions=end_positions)
        return output

    def loss(self, input, answerable=None, start_positions=None, end_positions=None):
        """
        Compute loss from network inputs

        Args:
            input: input token ids
            answerable: gold answerable labels (currently unused)
            start_positions: gold start position
            end_positions: gold end position

        Returns:
            - loss (first element of the model output)
        """
        output = self(input, answerable, start_positions, end_positions)
        return output[0]
81 |
82 |
--------------------------------------------------------------------------------
/mybycha/bycha/models/huggingface/huggingface_pretrain_bart_model.py:
--------------------------------------------------------------------------------
1 | from transformers import BartConfig, BartForConditionalGeneration
2 |
3 | from bycha.models import register_model
4 | from bycha.models.abstract_model import AbstractModel
5 |
6 |
@register_model
class HuggingfacePretrainBartModel(AbstractModel):
    """
    HuggingfacePretrainBartModel is a bart model for pretraining built on
    huggingface ``BartForConditionalGeneration``; weights are randomly initialised
    from a ``BartConfig`` sized by the task vocabulary.

    Args:
        path: path to restore model (optional, forwarded to ``AbstractModel``)
    """

    def __init__(self, path=None):
        super().__init__(path)

        self._config = None
        self._model = None
        self._special_tokens = None

    def _build(self, vocab_size, special_tokens):
        """
        Build model with vocabulary size and special tokens

        Args:
            vocab_size: vocabulary size of input sequence
            special_tokens: special tokens of input sequence; ``special_tokens['pad']``
                becomes the model's pad token id
        """
        self._config = BartConfig(vocab_size=vocab_size, pad_token_id=special_tokens['pad'])
        self._model = BartForConditionalGeneration(self._config)
        self._special_tokens = special_tokens

    def forward(self, enc_input, dec_input):
        """
        Compute output with neural input

        Args:
            enc_input: encoder input token ids
            dec_input: decoder input token ids

        Returns:
            - first element of the huggingface output; no labels are passed, so this is
              the next-token logits rather than a loss or log-probabilities
        """
        output = self._model(enc_input,
                             attention_mask=enc_input.ne(self._special_tokens['pad']),
                             decoder_input_ids=dec_input,
                             # key/value caching only in inference mode; self._mode is
                             # presumably maintained by AbstractModel
                             use_cache=self._mode == 'infer')
        return output[0]
50 |
--------------------------------------------------------------------------------
/mybycha/bycha/models/huggingface/huggingface_pretrain_mbart_model.py:
--------------------------------------------------------------------------------
1 | from transformers import MBartConfig, MBartForConditionalGeneration
2 |
3 | from bycha.models import register_model
4 | from bycha.models.abstract_model import AbstractModel
5 |
6 |
@register_model
class HuggingfacePretrainMBartModel(AbstractModel):
    """
    HuggingfacePretrainMBartModel is an mbart model built on huggingface
    ``MBartForConditionalGeneration``, either randomly initialised or loaded from
    a pretrained checkpoint.

    Args:
        path: path to restore model
        pretrained_path: name or path of a huggingface mbart checkpoint to load
    """

    def __init__(self, path=None, pretrained_path=None):
        super().__init__(path)

        self._config = None
        self._model = None
        self._special_tokens = None
        self._pretrained_path = pretrained_path

    def _build(self, vocab_size, special_tokens):
        """
        Build model with vocabulary size and special tokens

        Args:
            vocab_size: vocabulary size of input sequence
            special_tokens: special tokens of input sequence
        """
        self._config = MBartConfig(vocab_size=vocab_size, pad_token_id=special_tokens['pad'])
        if self._pretrained_path:
            # ``from_pretrained`` is a classmethod: call it on the class instead of first
            # building (and discarding) a full randomly initialised model.
            # NOTE(review): the loaded model uses the checkpoint's own config, so
            # ``self._config`` does not describe it in this branch.
            self._model = MBartForConditionalGeneration.from_pretrained(self._pretrained_path)
        else:
            self._model = MBartForConditionalGeneration(self._config)
        self._special_tokens = special_tokens

    def forward(self, src, tgt):
        """
        Compute output with neural input

        Args:
            src: encoder input token ids
            tgt: decoder input token ids

        Returns:
            - first element of the huggingface output; no labels are passed, so this is
              the next-token logits
        """
        output = self._model(src,
                             attention_mask=src.ne(self._special_tokens['pad']),
                             decoder_input_ids=tgt,
                             use_cache=self._mode == 'infer')
        return output[0]

    def generate(self, src, tgt_langtok_id, max_length, beam):
        """
        Generate target sequences with huggingface ``generate``.

        Args:
            src: encoder input token ids
            tgt_langtok_id: token id used as ``decoder_start_token_id``
            max_length: maximum generation length
            beam: number of beams

        Returns:
            - generated token ids as returned by huggingface ``generate``
        """
        return self._model.generate(input_ids=src,
                                    decoder_start_token_id=tgt_langtok_id,
                                    max_length=max_length,
                                    num_beams=beam)
57 |
--------------------------------------------------------------------------------
/mybycha/bycha/models/huggingface/huggingface_sequence_classification_model.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModelForSequenceClassification
2 |
3 | from bycha.models import register_model
4 | from bycha.models.abstract_model import AbstractModel
5 |
6 |
@register_model
class HuggingfaceSequenceClassificationModel(AbstractModel):
    """
    HuggingfaceSequenceClassificationModel is a sequence classification architecture built on
    huggingface sequence classification models.

    Args:
        pretrained_model: pretrained_model in huggingface
        num_labels: number of labels
        path: path to restore model (optional, forwarded to ``AbstractModel``)
    """

    def __init__(self, pretrained_model, num_labels=2, path=None):
        super().__init__(path)
        self._pretrained_model = pretrained_model
        self._num_labels = num_labels

        self._config = None
        self._model = None
        self._special_tokens = None

    def _build(self, vocab_size, special_tokens):
        """
        Build model with vocabulary size and special tokens

        Args:
            vocab_size: vocabulary size of input sequence (unused: the vocabulary comes
                from the pretrained model)
            special_tokens: special tokens of input sequence
        """
        self._config = AutoConfig.from_pretrained(
            self._pretrained_model,
            num_labels=self._num_labels
        )
        self._model = AutoModelForSequenceClassification.from_pretrained(
            self._pretrained_model,
            config=self._config,
        )
        self._special_tokens = special_tokens

    def forward(self, input):
        """
        Compute output with neural input

        Args:
            input: input token ids

        Returns:
            - classification logits; with ``num_labels == 1`` (e.g. regression) the
              trailing label dimension is squeezed away
        """
        output = self._model(input, attention_mask=input.ne(self._special_tokens['pad']))
        return output.logits if self._num_labels > 1 else output.logits.squeeze(dim=-1)
58 |
--------------------------------------------------------------------------------
/mybycha/bycha/models/seq2seq.py:
--------------------------------------------------------------------------------
1 | from bycha.models import register_model
2 | from bycha.models.encoder_decoder_model import EncoderDecoderModel
3 |
4 |
@register_model
class Seq2Seq(EncoderDecoderModel):
    """
    Seq2Seq is a plain encoder-decoder model: the source is encoded once and the
    decoder attends to the resulting memory.

    Args:
        encoder: encoder configurations to build an encoder
        decoder: decoder configurations to build a decoder
        d_model: feature embedding
        share_embedding: how the embedding is shared [all, decoder-input-output, None].
            `all` indicates that source embedding, target embedding and target
            output projection are the same.
            `decoder-input-output` indicates that only target embedding and target
            output projection are the same.
            `None` indicates that none of them are the same.
        path: path to restore model
    """

    def __init__(self,
                 encoder,
                 decoder,
                 d_model,
                 share_embedding=None,
                 path=None):
        super().__init__(encoder=encoder,
                         decoder=decoder,
                         d_model=d_model,
                         share_embedding=share_embedding,
                         path=path)

    def forward(self, src, tgt):
        """
        Compute output with neural input

        Args:
            src: source sequence
            tgt: previous tokens at target side (the time-shifted target sequence in training)

        Returns:
            - whatever the decoder returns for the given target prefix
        """
        encoded, encoded_padding_mask = self._encoder(src=src)
        return self._decoder(tgt=tgt,
                             memory=encoded,
                             memory_padding_mask=encoded_padding_mask)
51 |
--------------------------------------------------------------------------------
/mybycha/bycha/models/sequence_classification_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from bycha.models import AbstractModel, register_model
4 | from bycha.modules.encoders import create_encoder
5 | from bycha.modules.layers.embedding import Embedding
6 | from bycha.modules.layers.classifier import HuggingfaceClassifier
7 |
8 |
@register_model
class SequenceClassificationModel(AbstractModel):
    """
    SequenceClassificationModel classifies one or more input sequences with a shared
    encoder followed by a single classifier over the concatenated encodings.

    Args:
        encoder: encoder configuration
        labels: number of labels
        dropout: dropout used in the classifier
        source_num: the number of input source sequences
        path: path to restore model
    """

    def __init__(self, encoder, labels, dropout=0., source_num=1, path=None):
        super().__init__(path)
        self._encoder_config = encoder
        self._labels = labels
        self._source_num = source_num
        self._dropout = dropout

        self._encoder = None
        self._classifier = None
        self._path = path

    def _build(self, vocab_size, special_tokens):
        """
        Build encoder and classifier.

        Args:
            vocab_size: vocabulary size of input sequence
            special_tokens: special tokens of input sequence
        """
        self._build_encoder(vocab_size=vocab_size, special_tokens=special_tokens)
        self._build_classifier()

    def _build_encoder(self, vocab_size, special_tokens):
        """
        Create the encoder from its configuration and build it with a fresh token embedding.

        Args:
            vocab_size: vocabulary size of input sequence
            special_tokens: special tokens of input sequence; ``special_tokens['pad']``
                is the embedding padding index
        """
        self._encoder = create_encoder(self._encoder_config)
        token_embedding = Embedding(vocab_size=vocab_size,
                                    d_model=self.encoder.d_model,
                                    padding_idx=special_tokens['pad'])
        self._encoder.build(embed=token_embedding, special_tokens=special_tokens)

    def _build_classifier(self):
        """
        Build the classifier; its input is one encoder output per source, concatenated.
        """
        in_features = self.encoder.out_dim * self._source_num
        self._classifier = HuggingfaceClassifier(in_features, self._labels, dropout=self._dropout)

    @property
    def encoder(self):
        return self._encoder

    @property
    def classifier(self):
        return self._classifier

    def forward(self, *inputs):
        """
        Compute output with neural input

        Args:
            *inputs: input source sequences

        Returns:
            - classifier output over the label space
        """
        # Takes the LAST element of each encoder output; assumes it is a pooled sequence
        # representation (e.g. HuggingFaceEncoder with return_seed=True) -- TODO confirm.
        # NOTE(review): if the encoder keeps AbstractEncoder.forward, its first result is
        # cached in 'train' mode until reset, so with several inputs every call may return
        # the first input's encoding -- verify for source_num > 1.
        features = torch.cat([self.encoder(source)[-1] for source in inputs], dim=-1)
        return self.classifier(features)
91 |
92 |
--------------------------------------------------------------------------------
/mybycha/bycha/models/variational_auto_encoder.py:
--------------------------------------------------------------------------------
1 | from bycha.models import register_model
2 | from bycha.models.seq2seq import Seq2Seq
3 |
4 |
@register_model
class VariationalAutoEncoders(Seq2Seq):
    """
    VariationalAutoEncoders extends Seq2Seq with a latent space; the regularization
    and NLL computations are delegated to the encoder.

    Args:
        encoder: encoder configurations to build an encoder
        decoder: decoder configurations to build a decoder
        d_model: feature embedding
        share_embedding: how the embedding is shared [all, decoder-input-output, None].
            `all` indicates that source embedding, target embedding and target
            output projection are the same.
            `decoder-input-output` indicates that only target embedding and target
            output projection are the same.
            `None` indicates that none of them are the same.
        path: path to restore model
    """

    def __init__(self, encoder, decoder, d_model, share_embedding=None, path=None):
        super().__init__(encoder=encoder,
                         decoder=decoder,
                         d_model=d_model,
                         share_embedding=share_embedding,
                         path=path)

    def reg_loss(self):
        """
        Auto-encoding regularization loss, as computed by the encoder.

        Returns:
            - the encoder's ``reg_loss()`` (presumably the KL term between prior and posterior)
        """
        latent_encoder = self.encoder
        return latent_encoder.reg_loss()

    def nll(self, rec_loss, reg_losses, method="elbo"):
        """
        NLL loss, as computed by the encoder.

        Args:
            rec_loss: reconstruction loss
            reg_losses: regularization loss
            method: estimation method passed through to the encoder

        Returns:
            - the encoder's ``nll(rec_loss, reg_losses, method)``
        """
        latent_encoder = self.encoder
        return latent_encoder.nll(rec_loss, reg_losses, method)
58 |
59 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/mybycha/bycha/modules/__init__.py
--------------------------------------------------------------------------------
/mybycha/bycha/modules/decoders/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
6 | from .abstract_decoder import AbstractDecoder
7 |
register_decoder, create_decoder, registry = setup_registry('decoder', AbstractDecoder)

# Import every public decoder module / sub-package so that their
# ``@register_decoder`` decorators run.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    # entries starting with '_' (e.g. __init__.py, __pycache__) or '.' are skipped
    if file.startswith(('_', '.')):
        continue
    is_source_file = file.endswith('.py')
    if not is_source_file and not os.path.isdir(path):
        continue
    # for .py files keep the text before the first '.py'; packages keep their directory name
    module_name = file.partition('.py')[0] if is_source_file else file
    module = importlib.import_module('bycha.modules.decoders.' + module_name)
20 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/decoders/abstract_decoder.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 |
3 |
class AbstractDecoder(Module):
    """
    AbstractDecoder is the abstract base for decoders and defines their general interface.

    Args:
        name: decoder name
    """

    def __init__(self, name=None):
        super().__init__()
        self._name = name
        self._cache = {}
        self._mode = 'train'

    def build(self, *args, **kwargs):
        """
        Build decoder with task instance; subclasses must override this.
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """
        Process forward of decoder; subclasses must override this.
        """
        raise NotImplementedError

    def reset(self, mode):
        """
        Clear the decoder cache and switch running mode.

        Args:
            mode: running mode in [train, valid, infer]
        """
        # cleared in place: a dict previously handed in via set_cache is emptied too
        self._cache.clear()
        self._mode = mode

    def get_cache(self):
        """
        Retrieve inner cache

        Returns:
            - the cache dict itself (not a copy)
        """
        return self._cache

    def set_cache(self, cache):
        """
        Replace the inner cache with an outside dict (stored by reference).

        Args:
            cache: cache dict from outside
        """
        self._cache = cache
57 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/decoders/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .abstract_decoder_layer import AbstractDecoderLayer
2 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/decoders/layers/abstract_decoder_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
class AbstractDecoderLayer(nn.Module):
    """
    AbstractDecoderLayer is an abstract class for decoder layers, holding a per-layer
    state cache and the current running mode.
    """

    def __init__(self):
        super().__init__()
        self._cache = {}
        self._mode = 'train'
        # empty parameter; presumably a typed placeholder for the cache (see reset) -- confirm
        self._dummy_param = nn.Parameter(torch.empty(0))

    def reset(self, mode: str):
        """
        Reset the decoder layer cache and switch running mode.

        After a reset the cache holds a single ``"prev"`` entry pointing at the
        empty dummy parameter.

        Args:
            mode: running mode in [train, valid, infer]
        """
        fresh_cache: Dict[str, torch.Tensor] = {"prev": self._dummy_param}
        self._cache = fresh_cache
        self._mode = mode

    def _update_cache(self, *args, **kwargs):
        """
        Update cache with current states; a no-op hook for subclasses.
        """
        pass

    def get_cache(self):
        """
        Retrieve inner cache

        Returns:
            - the cache dict itself (not a copy)
        """
        return self._cache

    def set_cache(self, cache: Dict[str, torch.Tensor]):
        """
        Replace the inner cache with an outside dict (stored by reference).

        Args:
            cache: cache dict from outside
        """
        self._cache = cache
51 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
6 | from .abstract_encoder import AbstractEncoder
7 |
register_encoder, create_encoder, registry = setup_registry('encoder', AbstractEncoder)


# Import every public encoder module / sub-package so that their
# ``@register_encoder`` decorators run.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    # entries starting with '_' (e.g. __init__.py, __pycache__) or '.' are skipped
    if file.startswith(('_', '.')):
        continue
    is_source_file = file.endswith('.py')
    if not is_source_file and not os.path.isdir(path):
        continue
    # for .py files keep the text before the first '.py'; packages keep their directory name
    module_name = file.partition('.py')[0] if is_source_file else file
    module = importlib.import_module('bycha.modules.encoders.' + module_name)
21 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/encoders/abstract_encoder.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 |
3 |
class AbstractEncoder(Module):
    """
    AbstractEncoder is the abstract base for encoders and defines their general interface.

    Subclasses implement ``_forward``; ``forward`` wraps it with caching in 'train' mode.

    Args:
        name: encoder name
    """

    def __init__(self, name=None):
        super().__init__()
        self._name = name
        self._cache = {}
        self._mode = 'train'

    def build(self, *args, **kwargs):
        """
        Build encoder with task instance; subclasses must override this.
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """
        Process forward of encoder.

        Outside 'train' mode ``_forward`` is called every time. In 'train' mode the first
        result is stored under ``'out'`` and returned for every later call until ``reset``;
        the arguments of those later calls are ignored.
        """
        if self._mode != 'train':
            return self._forward(*args, **kwargs)
        if 'out' not in self._cache:
            self._cache['out'] = self._forward(*args, **kwargs)
        return self._cache['out']

    def _forward(self, *args, **kwargs):
        """
        Forward computation for subclasses to override; ``forward`` may cache its result.
        """
        raise NotImplementedError

    @property
    def name(self):
        return self._name

    @property
    def d_model(self):
        raise NotImplementedError

    @property
    def out_dim(self):
        raise NotImplementedError

    def _cache_states(self, name, state):
        """
        Store a state in the encoder cache.

        Args:
            name: state key
            state: state value
        """
        self._cache[name] = state

    def reset(self, mode):
        """
        Clear the encoder cache and switch running mode.

        Args:
            mode: running mode in [train, valid, infer]
        """
        # cleared in place: a dict previously handed in via set_cache is emptied too
        self._cache.clear()
        self._mode = mode

    def set_cache(self, cache):
        """
        Replace the inner cache with an outside dict (stored by reference).

        Args:
            cache: cache dict from outside
        """
        self._cache = cache

    def get_cache(self):
        """
        Retrieve inner cache

        Returns:
            - the cache dict itself (not a copy)
        """
        return self._cache
91 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/encoders/huggingface_encoder.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 | from transformers import AutoConfig, AutoModel
3 |
4 | from bycha.modules.encoders import AbstractEncoder, register_encoder
5 |
6 |
@register_encoder
class HuggingFaceEncoder(AbstractEncoder):
    """
    HuggingFaceEncoder is a wrapped encoder from huggingface models

    Args:
        pretrained_model: name of pretrained model, see huggingface for supported models
        freeze: freeze the base model's parameters in training
        return_seed: additionally return the pooled sequence representation
        name: encoder name
    """

    def __init__(self,
                 pretrained_model,
                 freeze=False,
                 return_seed=False,
                 name=None):
        super().__init__(name=name)
        self._pretrained_model_name = pretrained_model
        self._freeze = freeze
        self._return_seed = return_seed

        self._special_tokens = None
        self._padding_idx = None
        self._configs = None
        self._huggingface_model = None

    def build(self, special_tokens=None, vocab_size=None):
        """
        Build computational graph

        Args:
            special_tokens: special tokens defined in vocabulary; ``special_tokens['pad']``
                is the padding index (falls back to the config's ``pad_token_id``)
            vocab_size: vocabulary size of embedding; when given it must match the
                pretrained config

        Raises:
            ValueError: if ``vocab_size`` is given and differs from the config's vocab size
        """
        self._special_tokens = special_tokens
        self._configs = AutoConfig.from_pretrained(self._pretrained_model_name)
        # NOTE(review): ``from_config`` creates a randomly initialised model; pretrained
        # weights are not loaded here (``AutoModel.from_pretrained`` would load them).
        self._huggingface_model = AutoModel.from_config(self._configs)
        if self._freeze:
            self.freeze_params()

        # _forward needs the padding index; previously it was read but never assigned
        if special_tokens is not None and 'pad' in special_tokens:
            self._padding_idx = special_tokens['pad']
        else:
            self._padding_idx = self._configs.pad_token_id

        if vocab_size is not None and self._configs.vocab_size != vocab_size:
            raise ValueError(
                f'vocab_size {vocab_size} does not match pretrained config '
                f'vocab_size {self._configs.vocab_size}'
            )

    def freeze_params(self):
        """
        Freeze parameters of the huggingface base model
        """
        for param in self._huggingface_model.base_model.parameters():
            param.requires_grad = False

    def _forward(self, text: Tensor):
        r"""
        Args:
            text: tokens in src side.
              :math:`(N, S)` where N is the batch size, S is the source sequence length.

        Returns:
            - source token hidden representation.
              :math:`(S, N, E)` where S is the source sequence length, N is the batch size,
              E is the embedding size.
            - padding mask :math:`(N, S)`, True at padding positions.
            - pooled sequence representation (only when ``return_seed``; may be None if
              the model has no pooler).
        """
        padding_mask = text.eq(self._padding_idx)
        # pass attention_mask by keyword: it is not the second positional argument for
        # every architecture
        model_out = self._huggingface_model(text, attention_mask=(~padding_mask).long())
        if isinstance(model_out, dict):
            # ``ModelOutput`` is an OrderedDict subclass; fields that are None are absent
            x = model_out['last_hidden_state']
            seed = model_out.get('pooler_output')
        else:
            # tuple output: (last_hidden_state, pooler_output, ...)
            x = model_out[0]
            seed = model_out[1] if len(model_out) > 1 else None
        x = x.transpose(0, 1)
        if self._return_seed:
            return x, padding_mask, seed
        return x, padding_mask

    @property
    def out_dim(self):
        return self._configs.hidden_size
83 |
84 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/encoders/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .abstract_encoder_layer import AbstractEncoderLayer
2 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/encoders/layers/abstract_encoder_layer.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
class AbstractEncoderLayer(nn.Module):
    """
    AbstractEncoderLayer is an abstract class for encoder layers, holding a per-layer
    state cache and the current running mode.
    """

    def __init__(self):
        super().__init__()
        self._cache = {}
        self._mode = 'train'

    def reset(self, mode):
        """
        Switch running mode and empty the layer cache.

        Args:
            mode: running mode in [train, valid, infer]
        """
        self._mode = mode
        # emptied in place, so a dict handed in via set_cache is cleared as well
        self._cache.clear()

    def _update_cache(self, *args, **kwargs):
        """
        Update internal cache from outside states; a no-op hook for subclasses.
        """
        pass

    def get_cache(self):
        """
        Retrieve inner cache

        Returns:
            - the cache dict itself (not a copy)
        """
        return self._cache

    def set_cache(self, cache):
        """
        Replace the inner cache with an outside dict (stored by reference).

        Args:
            cache: cache dict from outside
        """
        self._cache = cache
47 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/encoders/layers/moe_encoder_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import torch
3 | from torch import Tensor
4 | from torch import nn
5 |
6 | from bycha.modules.encoders.layers import AbstractEncoderLayer
7 | from bycha.modules.layers.moe import MoE
8 |
9 |
class MoEEncoderLayer(AbstractEncoderLayer):
    """
    MoEEncoderLayer performs one transformer encoder layer whose feed-forward network is
    replaced by a mixture-of-experts block: self-attention, then MoE feed-forward.

    Args:
        d_model: feature dimension
        nhead: head numbers of multihead attention
        dim_feedforward: dimensionality of inner vector space
        dropout: dropout rate applied to each sub-layer output
        attention_dropout: dropout rate inside multihead attention
        activation: activation function used in the experts
        normalize_before: use pre-norm fashion, default as post-norm.
            Pre-norm suit deep nets while post-norm achieve better results when nets are shallow.
        num_experts: number of experts passed to ``MoE``
        sparse: ``sparse`` flag passed to ``MoE``
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 attention_dropout=0,
                 activation="relu",
                 normalize_before=False,
                 num_experts=1,
                 sparse=True,):
        super().__init__()
        self.normalize_before = normalize_before
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=attention_dropout)

        self.moe = MoE(d_model,
                       dim_feedforward=dim_feedforward,
                       activation=activation,
                       num_experts=num_experts,
                       sparse=sparse)

        self.self_attn_norm = nn.LayerNorm(d_model)
        self.ffn_norm = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self,
                src: Tensor,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None):
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
              :math:`(S, B, D)`, where S is sequence length, B is batch size and D is feature dimension
            src_mask: the attention mask for the src sequence (optional).
              :math:`(S, S)`, where S is sequence length.
            src_key_padding_mask: the mask for the src keys per batch (optional).
              :math: `(B, S)`, where B is batch size and S is sequence length

        Returns:
            - the layer output and the auxiliary loss returned by ``MoE``
        """
        # self-attention sub-layer
        shortcut = src
        hidden = self.self_attn_norm(src) if self.normalize_before else src
        hidden = self.self_attn(hidden, hidden, hidden, attn_mask=src_mask,
                                key_padding_mask=src_key_padding_mask)[0]
        hidden = shortcut + self.dropout1(hidden)
        if not self.normalize_before:
            hidden = self.self_attn_norm(hidden)

        # mixture-of-experts feed-forward sub-layer
        shortcut = hidden
        if self.normalize_before:
            hidden = self.ffn_norm(hidden)
        hidden, moe_loss = self.moe(hidden)
        hidden = shortcut + self.dropout2(hidden)
        if not self.normalize_before:
            hidden = self.ffn_norm(hidden)
        return hidden, moe_loss
79 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/encoders/layers/transformer_encoder_layer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from torch import Tensor
4 | from torch import nn
5 |
6 | from bycha.modules.encoders.layers import AbstractEncoderLayer
7 | from bycha.modules.layers.feed_forward import FFN
8 |
9 |
class TransformerEncoderLayer(AbstractEncoderLayer):
    """
    TransformerEncoderLayer performs one layer of transformer operation, namely self-attention and feed-forward network.

    Args:
        d_model: feature dimension
        nhead: head numbers of multihead attention
        dim_feedforward: dimensionality of inner vector space
        dropout: dropout rate applied to each sub-layer output
        attention_dropout: dropout rate inside multihead attention
        activation: activation function used in feed-forward network
        normalize_before: use pre-norm fashion, default as post-norm.
            Pre-norm suit deep nets while post-norm achieve better results when nets are shallow.
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 attention_dropout=0,
                 activation="relu",
                 normalize_before=False,):
        super().__init__()
        self.normalize_before = normalize_before
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=attention_dropout)
        # position-wise feed-forward network
        self.ffn = FFN(d_model, dim_feedforward=dim_feedforward, activation=activation)

        self.self_attn_norm = nn.LayerNorm(d_model)
        self.ffn_norm = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self,
                src: Tensor,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
              :math:`(S, B, D)`, where S is sequence length, B is batch size and D is feature dimension
            src_mask: the attention mask for the src sequence (optional).
              :math:`(S, S)`, where S is sequence length.
            src_key_padding_mask: the mask for the src keys per batch (optional).
              :math: `(B, S)`, where B is batch size and S is sequence length
        """
        # self-attention sub-layer
        shortcut = src
        hidden = self.self_attn_norm(src) if self.normalize_before else src
        hidden = self.self_attn(hidden, hidden, hidden, attn_mask=src_mask,
                                key_padding_mask=src_key_padding_mask)[0]
        hidden = shortcut + self.dropout1(hidden)
        if not self.normalize_before:
            hidden = self.self_attn_norm(hidden)

        # feed-forward sub-layer
        shortcut = hidden
        if self.normalize_before:
            hidden = self.ffn_norm(hidden)
        hidden = shortcut + self.dropout2(self.ffn(hidden))
        if not self.normalize_before:
            hidden = self.ffn_norm(hidden)
        return hidden
76 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/encoders/lstm_encoder.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 | from torch.nn import LSTM
3 | import torch.nn as nn
4 |
5 | from bycha.modules.encoders import AbstractEncoder, register_encoder
6 |
7 |
@register_encoder
class LSTMEncoder(AbstractEncoder):
    """
    LSTMEncoder encodes a token sequence with a (optionally bidirectional) multi-layer LSTM.

    Args:
        num_layers: number of stacked LSTM layers
        d_model: feature dimension of the token embedding (LSTM input size)
        hidden_size: hidden size within LSTM
        dropout: dropout rate applied to the embeddings and between stacked LSTM layers
        bidirectional: encode sequence with bidirectional LSTM
        return_seed: if truthy, `_forward` additionally returns a sequence-level
            representation (mean of the LSTM outputs over time)
        name: module name
    """

    def __init__(self,
                 num_layers,
                 d_model=512,
                 hidden_size=1024,
                 dropout=0.1,
                 bidirectional=True,
                 return_seed=None,
                 name=None):
        super().__init__()
        self._num_layers = num_layers
        self._d_model = d_model
        self._hidden_size = hidden_size
        self._dropout = dropout
        self._bidirectional = bidirectional
        self._return_seed = return_seed
        self._name = name

        self._special_tokens = None
        self._embed, self._embed_dropout = None, None
        self._layer = None
        # Kept for backward compatibility; no longer consulted by `_forward`.
        self._pool_seed = None

    def build(self, embed, special_tokens):
        """
        Build computational modules.

        Args:
            embed: token embedding
            special_tokens: special tokens defined in vocabulary (must contain 'pad')
        """
        self._embed = embed
        self._special_tokens = special_tokens
        self._embed_dropout = nn.Dropout(self._dropout)
        self._layer = LSTM(input_size=self._d_model,
                           hidden_size=self._hidden_size,
                           num_layers=self._num_layers,
                           # torch.nn.LSTM only applies dropout between stacked layers and
                           # warns when num_layers == 1, so pass 0 there (same computation).
                           dropout=self._dropout if self._num_layers > 1 else 0.,
                           bidirectional=self._bidirectional)

    def _forward(self,
                 src: Tensor):
        r"""
        Args:
            src: tokens in src side.
                :math:`(N, S)` where N is the batch size, S is the source sequence length.

        Returns:
            - source token hidden representation.
                :math:`(S, N, H)` where S is the source sequence length, N is the batch size,
                H is `out_dim` (2 * hidden_size when bidirectional, else hidden_size).
            - source padding mask, :math:`(N, S)`, True at pad positions.
            - (only when `return_seed` is truthy) sequence representation :math:`(N, H)`,
                the mean over all S positions (padding positions included).
        """
        x = self._embed(src)
        x = self._embed_dropout(x)

        src_padding_mask = src.eq(self._special_tokens['pad'])
        # LSTM is built without batch_first, so it expects (S, N, E).
        x = x.transpose(0, 1)
        x = self._layer(x)[0]

        # Previously this checked `self._pool_seed`, which is never set, so
        # `return_seed` had no effect.
        if self._return_seed:
            return x, src_padding_mask, x.mean(dim=0)
        return x, src_padding_mask

    @property
    def d_model(self):
        return self._d_model

    @property
    def out_dim(self):
        return self._hidden_size * 2 if self._bidirectional else self._hidden_size
93 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/mybycha/bycha/modules/layers/__init__.py
--------------------------------------------------------------------------------
/mybycha/bycha/modules/layers/autopruning_ffn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from bycha.modules.utils import get_activation_fn
6 | from bycha.modules.layers.gumbel import gumbel_softmax_topk
7 |
class AutoPruningFFN(nn.Module):
    """
    Feed-forward neural network whose inner dimension can be pruned during training.

    In training mode, when `weights` and `sorted_indeces` are given, a Gumbel top-1
    sample over `weights` selects a prune ratio. The leading entries of
    `sorted_indeces` are then dropped from the inner dimension. Otherwise (or in
    eval mode) the full network is used.

    Args:
        d_model: input feature dimension
        dim_feedforward: dimensionality of inner vector space
        dim_out: output feature dimensionality
        activation: activation function
        bias: requires bias in output linear function (fc1 always has a bias)
    """

    def __init__(self,
                 d_model,
                 dim_feedforward=None,
                 dim_out=None,
                 activation="relu",
                 bias=True):
        super().__init__()
        self._dim_feedforward = dim_feedforward or d_model
        self._dim_out = dim_out or d_model
        self._bias = bias

        self._fc1 = nn.Linear(d_model, self._dim_feedforward)
        self._fc2 = nn.Linear(self._dim_feedforward, self._dim_out, bias=self._bias)
        self._activation = get_activation_fn(activation)

    def forward(self, x, weights=None, sorted_indeces=None, tau=0.):
        """
        Args:
            x: feature to perform feed-forward net
                :math:`(*, D)`, where D is feature dimension
            weights: logits over candidate prune ratios; index i prunes i * 10% of the
                inner dimension. Assumed 1-D — TODO confirm with callers.
            sorted_indeces: inner-dimension indices ordered so that the first entries are
                the ones pruned first
            tau: Gumbel-softmax temperature; must be > 0 when pruning is active

        Returns:
            - feed forward output
                :math:`(*, D_out)`, where D_out is `dim_out`

        Raises:
            ValueError: if pruning is active and `tau <= 0` (the Gumbel sampler would
                otherwise divide by zero and retry forever).
        """
        if weights is None or sorted_indeces is None or not self.training:
            x = self._fc1(x)
            x = self._activation(x)
            x = self._fc2(x)
            return x

        if tau <= 0:
            raise ValueError(f"tau must be positive when pruning is active, got {tau}")

        gumbel_onehot, _ = gumbel_softmax_topk(weights, tau=tau)
        one_index = gumbel_onehot.max(-1, keepdim=True)[1].item()
        # Straight-through gate: value 1 in the forward pass, but carries the gradient
        # back to `weights`.
        gate = gumbel_onehot[one_index]
        prune_ratio = gate * one_index * 0.1
        prune_num = torch.floor(prune_ratio * self._dim_feedforward).long()
        kept = sorted_indeces[prune_num:]

        # fc1 is always built with a bias (nn.Linear default); `bias` only concerns fc2.
        # Dropping fc1's bias here made training compute a different function than eval.
        x = F.linear(x, self._fc1.weight[kept], self._fc1.bias[kept])
        x = x * gate
        x = self._activation(x)
        x = F.linear(x, self._fc2.weight[:, kept], self._fc2.bias)
        return x
63 |
64 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/layers/bert_layer_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
class BertLayerNorm(nn.Module):
    """
    Layer normalization in the BERT / TensorFlow style: epsilon is added to the
    variance inside the square root.

    Args:
        hidden_size: size of the last (normalized) feature axis
        eps: small constant added to the variance for numerical stability
    """

    def __init__(self, hidden_size, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        r"""
        Normalize ``x`` over its last dimension, then scale by ``gamma`` and shift by ``beta``.

        Args:
            x: input features
                :math:`(*, D)`, where D is the feature dimension

        Returns:
            - normalized features with the same shape as ``x``
        """
        centered = x - x.mean(-1, keepdim=True)
        # Biased (population) variance over the feature axis.
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.gamma * normalized + self.beta
37 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/layers/dlcl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.parameter import Parameter
4 |
5 |
class DynamicLinearCombinationLayer(nn.Module):
    r"""
    DLCL: make the input to be a linear combination of previous outputs
    For pre-norm, x_{l+1} = \sum_{k=0}^{l} W_{k} * LN(y_{k})
    For post-norm, x_{l+1} = LN(\sum_{k=0}^{l} W_{k} * y_{k})
    where x_{l}, y_{l} are the input and output of l-th layer

    For pre-norm, LN should be performed in previous layer
    For post-norm, LN is performed in this layer

    The combination weights are learnable and initialized uniformly to 1 / idx.

    Args:
        idx: this is the `idx`-th layer, i.e. the number of outputs being combined (> 0)
        post_ln: post layernorm applied to the combination, or None

    Raises:
        ValueError: if `idx` is not positive.
    """
    def __init__(self, idx, post_ln=None):
        super(DynamicLinearCombinationLayer, self).__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if idx <= 0:
            raise ValueError(f"idx must be positive, got {idx}")
        self.linear = nn.Linear(idx, 1, bias=False)
        # Public equivalent of the private `nn.init._no_grad_fill_`.
        nn.init.constant_(self.linear.weight, 1.0 / idx)
        self.post_ln = post_ln

    def forward(self, y):
        """
        Args:
            y: stacked previous outputs, SequenceLength x BatchSize x Dim x idx

        Returns:
            - combined features, SequenceLength x BatchSize x Dim
        """
        x = self.linear(y)
        x = x.squeeze(dim=-1)
        if self.post_ln is not None:
            x = self.post_ln(x)
        return x
37 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/layers/embedding.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
class Embedding(nn.Embedding):
    """
    Embedding is a wrapped class of torch.nn.Embedding with normal initialization on weight
    (std = d_model ** -0.5) and zero initialization on the pad row.

    Args:
        vocab_size: vocabulary size
        d_model: feature dimensionality
        padding_idx: index of pad, which is a special token to ignore
    """

    def __init__(self, vocab_size, d_model, padding_idx=None):
        super().__init__(vocab_size, d_model, padding_idx=padding_idx)
        nn.init.normal_(self.weight, mean=0, std=d_model ** -0.5)
        # Compare against None: padding_idx == 0 is a valid (and common) pad index and
        # must be zeroed too. nn.Embedding stores a normalized (non-negative) index.
        if self.padding_idx is not None:
            nn.init.constant_(self.weight[self.padding_idx], 0)
20 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/layers/feed_forward.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from bycha.modules.utils import get_activation_fn
4 |
5 |
class FFN(nn.Module):
    """
    Two-layer position-wise feed-forward network: linear -> activation -> linear.

    Args:
        d_model: input feature dimension
        dim_feedforward: inner dimensionality (falls back to d_model when falsy)
        dim_out: output dimensionality (falls back to d_model when falsy)
        activation: activation function name resolved by get_activation_fn
        bias: whether the output linear layer has a bias
    """

    def __init__(self,
                 d_model,
                 dim_feedforward=None,
                 dim_out=None,
                 activation="relu",
                 bias=True):
        super().__init__()
        inner_dim = dim_feedforward if dim_feedforward else d_model
        output_dim = dim_out if dim_out else d_model

        self._fc1 = nn.Linear(d_model, inner_dim)
        self._fc2 = nn.Linear(inner_dim, output_dim, bias=bias)
        self._activation = get_activation_fn(activation)

    def forward(self, x):
        """
        Args:
            x: input features
                :math:`(*, D)`, where D is feature dimension

        Returns:
            - transformed features
                :math:`(*, D_out)`, where D_out is the output dimensionality
        """
        return self._fc2(self._activation(self._fc1(x)))
46 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/layers/gaussian.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
class Gaussian(nn.Module):
    """
    Gaussian projects features to Gaussian parameters: a mean and a log-variance,
    each produced by its own linear layer.

    Args:
        d_model: input feature dimension
        latent_size: dimensionality of the Gaussian (size of mean / logvar)
    """

    def __init__(self, d_model, latent_size):
        super().__init__()
        self._dmodel, self._latent_size = d_model, latent_size

        self.post_mean = nn.Linear(d_model, latent_size)
        self.post_logvar = nn.Linear(d_model, latent_size)

    def forward(self, x):
        """
        Args:
            x: input features
                :math:`(*, D)`, where D is feature dimension

        Returns:
            - gaussian mean
                :math:`(*, L)`, where L is latent_size
            - gaussian logvar
                :math:`(*, L)`, where L is latent_size
        """
        mean = self.post_mean(x)
        logvar = self.post_logvar(x)
        return mean, logvar
36 |
37 |
38 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/layers/gumbel.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
def gumbel_softmax_topk(logits, k=1, tau=1, hard=True, dim=-1):
    """
    Draw a Gumbel-softmax sample over `logits` and optionally harden it to top-k.

    Args:
        logits: unnormalized scores; expected to be finite (any inf makes every
            resample rejected, so the retry loop would never exit)
        k: number of entries kept when `hard` is True
        tau: softmax temperature, must be > 0
        hard: return a straight-through k-hot tensor instead of the soft sample
        dim: dimension over which softmax and top-k are taken

    Returns:
        - the (soft or straight-through hard) sample, same shape as `logits`
        - top-k indices along `dim` when `hard`, else None

    Raises:
        ValueError: if `tau <= 0` (would otherwise loop forever on inf values).
    """
    if tau <= 0:
        raise ValueError(f"tau must be positive, got {tau}")

    while True:
        gumbels = -torch.empty_like(logits).exponential_().log()
        gumbels = (logits + gumbels) / tau
        y_soft = gumbels.softmax(dim)
        # Resample in the rare case the noise produced inf / nan values.
        if not (torch.isinf(gumbels).any() or torch.isinf(y_soft).any() or torch.isnan(y_soft).any()):
            break

    index = None
    if hard:
        # topk must run along the same `dim` as softmax and scatter_.
        _, index = torch.topk(y_soft, k, dim=dim)
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        # Straight-through estimator: hard values forward, soft gradients backward.
        ret = y_hard - y_soft.detach() + y_soft
    else:
        ret = y_soft
    return ret, index
22 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/layers/layerdrop.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 |
class LayerDropModuleList(nn.ModuleList):
    """
    Implementation of LayerDrop based on torch.nn.ModuleList

    Iterating the list in training mode skips each layer with its drop probability.
    Each training iteration also advances the schedule step used by `gamma`.
    Iterating in eval mode yields every layer and leaves the schedule untouched.
    Note that *any* iteration in training mode (not only the forward pass) drops
    layers and advances the step.

    Usage:
        Replace torch.nn.ModuleList with LayerDropModuleList
        For example:
            layers = nn.ModuleList([TransformerEncoderLayer
                                    for _ in range(num_layers)])

            -> layers = LayerDropModuleList(p=0.2, gamma=1e-4)
               layers.extend([TransformerEncoderLayer
                              for _ in range(num_layers)])

    Args:
        p: target drop probability
        gamma: warm-up speed; when > 0 the probability grows as p * (1 - exp(-gamma * step))
        mode_depth: probability distribution across layers ('transformer', 'bert' or None)
        modules: an iterable of modules to add
    """

    def __init__(self, p, gamma=0., mode_depth=None, modules=None):
        super().__init__(modules)
        self.p = p
        self._gamma = gamma
        self._mode_depth = mode_depth
        self._step = -1

    def __iter__(self):
        if not self.training:
            # Evaluation: keep every layer and do not advance the training schedule
            # or consume RNG state.
            for m in super().__iter__():
                m.layerdrop = 0.
                yield m
            return

        self._step += 1
        layer_num = len(self)
        dropout_probs = torch.empty(layer_num).uniform_()
        if self._gamma > 0:
            p_now = self.p - self.p * math.exp(-self._gamma * self._step)
        else:
            p_now = self.p
        p_now = max(0., p_now)

        p_layers = [p_now] * layer_num
        if self._mode_depth == 'transformer':
            # Highest at the middle layers, lower towards both ends.
            p_layers = [2 * min(i + 1, layer_num - i) / layer_num * p_now for i in range(layer_num)]
        elif self._mode_depth == 'bert':
            # Linearly increasing with depth.
            p_layers = [p_now * i / layer_num for i in range(1, layer_num + 1)]

        for i, m in enumerate(super().__iter__()):
            m.layerdrop = p_layers[i]
            if dropout_probs[i] > m.layerdrop:
                yield m
54 |
def config_to_params(config):
    """
    Extract LayerDrop settings from a config mapping.

    Args:
        config: mapping that may contain 'prob', 'gamma' and 'mode_depth', or None

    Returns:
        (layerdrop, gamma, mode_depth); missing keys default to (0., 0., None)
    """
    settings = {'prob': 0., 'gamma': 0., 'mode_depth': None}
    if config is not None:
        for key in settings:
            if key in config:
                settings[key] = config[key]
    return settings['prob'], settings['gamma'], settings['mode_depth']
66 |
67 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/layers/learned_positional_embedding.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch import Tensor
7 |
8 |
class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.

    Args:
        num_embeddings: number of embeddings (maximum number of positions)
        embedding_dim: embedding dimension
        padding_idx: without `post_mask`, forwarded to nn.Embedding / F.embedding.
            With `post_mask`, it is the id of the pad *token* in the token input.
        post_mask: if True, positions are counted over non-pad tokens only and the
            embeddings at pad positions are zeroed after lookup
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, post_mask: bool = False):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5)
        # if post_mask = True, then padding_idx = id of pad token in token embedding, we first mark padding
        # tokens using padding_idx, then generate embedding matrix using positional embedding, finally set
        # marked positions with zero
        self._post_mask = post_mask

    def forward(
        self,
        input: Tensor,
        positions: Optional[Tensor] = None
    ):
        """
        Args:
            input: an input LongTensor
                :math:`(*, L)`, where L is sequence length
            positions: pre-defined positions
                :math:`(*, L)`, where L is sequence length

        Returns:
            - positional embedding indexed from input
                :math:`(*, L, D)`, where L is sequence length and D is dimensionality

        Raises:
            ValueError: if `post_mask` is set but no `padding_idx` was given.
        """
        if self._post_mask:
            if self.padding_idx is None:
                raise ValueError("post_mask=True requires padding_idx (the pad token id)")
            mask = input.ne(self.padding_idx).long()
            if positions is None:
                # Leading pads would get position -1 (an invalid index); they are
                # zeroed below anyway, so clamp them to 0.
                positions = (torch.cumsum(mask, dim=-1) - 1).clamp(min=0).long()
            emb = F.embedding(
                positions,
                self.weight,
                None,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.sparse,
            )  # [*, L, D]
            return emb * mask.unsqueeze(-1)

        if positions is None:
            # Positions 0, 1, ..., L-1 along the last (sequence) dimension.
            positions = (torch.cumsum(torch.ones_like(input), dim=-1) - 1).long()
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
70 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/layers/sinusoidal_positional_embedding.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.onnx.operators
5 | from torch import nn
6 |
7 |
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Every input slot gets a position 0, 1, 2, ... along dim 1, padding included.
    The token ids themselves are only used for shape, dtype-free.
    The output concatenates sin and cos features, so its width is
    2 * (embedding_dim // 2).

    The underlying embedder holds no trainable parameters. It is shared between
    instances created with the same (embedding_dim, num_embeddings).
    """

    class __SinusoidalPositionalEmbedding(nn.Module):

        def __init__(self, embedding_dim, num_embeddings=1024):
            super().__init__()
            self._embedding_dim = embedding_dim
            self._num_embeddings = num_embeddings

            num_timescales = self._embedding_dim // 2
            # max(..., 1) avoids a division by zero when there is a single timescale.
            log_timescale_increment = torch.FloatTensor([math.log(10000.) / max(num_timescales - 1, 1)])
            inv_timescales = (torch.arange(num_timescales) * -log_timescale_increment).exp()
            self.register_buffer('_inv_timescales', inv_timescales)

        def forward(
            self,
            input,
        ):
            """Input is expected to be of size [bsz x seqlen]."""
            mask = torch.ones_like(input).type_as(self._inv_timescales)
            positions = torch.cumsum(mask, dim=1) - 1

            scaled_time = positions[:, :, None] * self._inv_timescales[None, None, :]
            signal = torch.cat([scaled_time.sin(), scaled_time.cos()], dim=-1)
            return signal.detach()

    # First embedder ever created; kept for backward compatibility only.
    __embed__ = None
    # Shared embedders keyed by (embedding_dim, num_embeddings). A single class-wide
    # instance used to be reused regardless of embedding_dim, which returned
    # embeddings of the wrong width.
    _embedders = {}

    def __init__(self, embedding_dim, num_embeddings=1024):
        super().__init__()
        key = (embedding_dim, num_embeddings)
        cache = SinusoidalPositionalEmbedding._embedders
        if key not in cache:
            cache[key] = SinusoidalPositionalEmbedding.__SinusoidalPositionalEmbedding(
                embedding_dim=embedding_dim,
                num_embeddings=num_embeddings
            )
        self.embedding = cache[key]
        if SinusoidalPositionalEmbedding.__embed__ is None:
            SinusoidalPositionalEmbedding.__embed__ = self.embedding

    def forward(self, input):
        return self.embedding(input)
51 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/search/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
6 | from .abstract_search import AbstractSearch
7 |
register_search, create_search, registry = setup_registry('search', AbstractSearch)

# Import every public submodule / subpackage so their @register_search decorators run.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    if file.startswith(('_', '.')):
        continue
    is_py_file = file.endswith('.py')
    if not (is_py_file or os.path.isdir(path)):
        continue
    module_name = file[:file.find('.py')] if is_py_file else file
    module = importlib.import_module('bycha.modules.search.' + module_name)
20 |
21 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/search/abstract_search.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Module
2 |
3 |
class AbstractSearch(Module):
    """
    Base class for search algorithms that wrap a neural model to perform
    special-purpose inference. Subclasses implement `build` and `forward`.
    """

    def __init__(self):
        super(AbstractSearch, self).__init__()
        self._mode = 'infer'

    def build(self, *args, **kwargs):
        """
        Set up the search algorithm from task instances. Must be overridden.
        """
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        """
        Run the search. Must be overridden.
        """
        raise NotImplementedError

    def reset(self, mode):
        """
        Switch the running mode of the search algorithm.

        Args:
            mode: running mode in [train, valid, infer]
        """
        self._mode = mode
33 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/search/greedy_search.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import torch
3 |
4 | from bycha.modules.search import register_search
5 | from bycha.modules.search.sequence_search import SequenceSearch
6 | from bycha.modules.utils import create_init_scores
7 |
8 |
@register_search
class GreedySearch(SequenceSearch):
    """
    GreedySearch is greedy search on sequence generation.

    At every step the highest-scoring token is appended to each sequence. Once a
    sequence has produced `eos`, it is frozen: later steps append `pad` (or `eos`
    when no pad index was given to `build`) and add nothing to its score. Decoding
    stops when every sequence is finished or the max length is reached.

    Args:
        maxlen_coef (a, b): maxlen computation coefficient.
            The max length is computed as `(S * a + b)`, where S is source sequence length.
    """

    def __init__(self, maxlen_coef=(1.2, 10)):
        super().__init__()

        self._maxlen_a, self._maxlen_b = maxlen_coef

    def forward(self,
                prev_tokens,
                memory,
                memory_padding_mask,
                target_mask: Optional[torch.Tensor] = None,
                prev_scores: Optional[torch.Tensor] = None):
        """
        Decoding full-step sequence with greedy search

        Args:
            prev_tokens: previous tokens or prefix of sequence
                :math:`(B, T)` where B is batch size and T is prefix length
            memory: memory for attention.
                :math:`(M, N, E)`, where M is the memory sequence length, N is the batch size,
                E is the feature dimension
            memory_padding_mask: memory sequence padding mask.
                :math:`(N, M)` where M is the memory sequence length, N is the batch size.
            target_mask: target mask indicating blacklist tokens
                :math:`(B, V)` where B is batch size and V is vocab size
            prev_scores: scores of previous tokens
                :math:`(B)` where B is batch size

        Returns:
            - accumulated scores of the chosen tokens, :math:`(B)`. These are the decoder's
                raw outputs at the last position (presumably log-probabilities — depends on
                the decoder); eos and post-eos steps contribute 0.
            - generated sequence, prefix included
        """
        batch_size = prev_tokens.size(0)
        scores = create_init_scores(prev_tokens, memory) if prev_scores is None else prev_scores
        # Token appended to sequences that already emitted eos.
        filler = self._pad if self._pad is not None else self._eos
        finished = torch.zeros(batch_size, dtype=torch.bool, device=prev_tokens.device)
        for _ in range(int(memory.size(0) * self._maxlen_a + self._maxlen_b)):
            logits = self._decoder(prev_tokens, memory, memory_padding_mask)
            logits = logits[:, -1, :]
            if target_mask is not None:
                logits = logits.masked_fill(target_mask, float('-inf'))
            next_word_scores, words = logits.max(dim=-1)
            # A sequence is done if it emits eos now or already did earlier. Previously
            # finished sequences kept receiving tokens and accumulating score.
            done = words.eq(self._eos) | finished
            if done.all():
                break
            scores = scores + next_word_scores.masked_fill(done, 0.).view(-1)
            words = words.masked_fill(finished, filler)
            prev_tokens = torch.cat([prev_tokens, words.unsqueeze(dim=-1)], dim=-1)
            finished = done
        return scores, prev_tokens
62 |
--------------------------------------------------------------------------------
/mybycha/bycha/modules/search/sequence_search.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from torch import Tensor
4 |
5 | from bycha.modules.search import AbstractSearch
6 |
7 |
class SequenceSearch(AbstractSearch):
    """
    Base class for algorithms that produce a complete sequence. It is built from a
    one-step decoder and drives that decoder step by step into full-sequence
    generation according to some strategy.
    """

    def __init__(self):
        super().__init__()

        self._decoder = None
        self._bos = None
        self._eos = None
        self._pad = None

    def build(self, decoder, bos, eos, pad, *args, **kwargs):
        """
        Attach the decoder and special-token indices used during search.

        Args:
            decoder: one-step decoder of the neural model
            bos: begin-of-sentence index
            eos: end-of-sentence index
            pad: pad index
        """
        self._decoder = decoder
        self._bos = bos
        self._eos = eos
        self._pad = pad

    def forward(self,
                prev_tokens: Tensor,
                memory: Tensor,
                memory_padding_mask: Tensor,
                target_mask: Optional[Tensor] = None,
                prev_scores: Optional[Tensor] = None):
        """
        Decode a full sequence. Must be overridden by concrete strategies.

        Args:
            prev_tokens: previous tokens or prefix of sequence
            memory: memory for attention.
                :math:`(M, N, E)`, where M is the memory sequence length, N is the batch size,
                E is the feature dimension
            memory_padding_mask: memory sequence padding mask.
                :math:`(N, M)` where M is the memory sequence length, N is the batch size.
            target_mask: target mask indicating blacklist tokens
                :math:`(B, V)` where B is batch size and V is vocab size
            prev_scores: scores of previous tokens
                :math:`(B)` where B is batch size

        Returns:
            - scores of the generated sequences
            - generated sequences
        """
        raise NotImplementedError

    def reset(self, mode):
        """
        Switch running mode for both the search and its decoder.

        Args:
            mode: running mode in [train, valid, infer]
        """
        self._mode = mode
        self._decoder.reset(mode)
68 |
69 |
--------------------------------------------------------------------------------
/mybycha/bycha/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 | from bycha.utils.runtime import Environment
6 |
7 | from .abstract_sampler import AbstractSampler
8 | from .distributed_sampler import DistributedSampler
9 |
register_sampler, _create_sampler, registry = setup_registry('sampler', AbstractSampler)


def create_sampler(configs, is_training=False):
    """
    Build a sampler from its configuration.

    When training in a distributed setting (more than one worker), the sampler
    is wrapped in a DistributedSampler.

    Args:
        configs: sampler configuration
        is_training: whether the sampler is used for training.

    Returns:
        a data sampler
    """
    base_sampler = _create_sampler(configs)
    env = Environment()
    is_distributed = env.distributed_world > 1
    if is_distributed and is_training:
        return DistributedSampler(base_sampler)
    return base_sampler
30 |
31 |
# Import every public submodule / subpackage so their @register_sampler decorators run.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    if file.startswith(('_', '.')):
        continue
    is_py_file = file.endswith('.py')
    if not (is_py_file or os.path.isdir(path)):
        continue
    module_name = file[:file.find('.py')] if is_py_file else file
    module = importlib.import_module('bycha.samplers.' + module_name)
42 |
--------------------------------------------------------------------------------
/mybycha/bycha/samplers/batch_shuffle_sampler.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from bycha.samplers import register_sampler
4 | from bycha.samplers.sequential_sampler import SequentialSampler
5 |
6 |
@register_sampler
class BatchShuffleSampler(SequentialSampler):
    """
    BatchShuffleSampler groups samples into batches in sequential order, then
    randomizes the order in which those batches are read.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def batch_sampler(self):
        """
        Sequentially pre-computed batches, returned in shuffled order.

        Returns:
            batches: a list of batches of index
        """
        shuffled_batches = super().batch_sampler
        # In-place shuffle with the module-level `random` generator.
        random.shuffle(shuffled_batches)
        return shuffled_batches
28 |
--------------------------------------------------------------------------------
/mybycha/bycha/samplers/bucket_sampler.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from bycha.samplers import AbstractSampler, register_sampler
4 | from bycha.utils.runtime import Environment
5 |
6 |
@register_sampler
class BucketSampler(AbstractSampler):
    """
    BucketSampler orders samples by (optionally noisy) token count, so that batches
    contain samples of similar size; batch reading order is then shuffled.

    Args:
        noise: relative noise applied to each sample's token count when ordering.
    """

    def __init__(self, noise=0., **kwargs):
        super().__init__(**kwargs)
        self._noise = noise

    def build(self, data_source):
        """
        Build sampler over data_source

        Args:
            data_source: a list of samples, each with a 'token_num' entry
        """
        self._data_source = data_source
        self._length = len(self._data_source)
        self.reset(0)

    def reset(self, epoch, *args, **kwargs):
        """
        Re-seed the global `random` module with (env seed + epoch) and recompute the
        size-descending reading order.

        Args:
            epoch: iteration epoch
        """
        random.seed(Environment().seed + epoch)
        noisy_sizes = [(position, self._inject_noise(item['token_num']))
                       for position, item in enumerate(self._data_source)]
        by_size = sorted(noisy_sizes, key=lambda pair: pair[1], reverse=True)
        self._permutation = [position for position, _ in by_size]

    @property
    def batch_sampler(self):
        """
        Pre-computed batches, returned in shuffled order.

        Returns:
            batches: a list of batches of index
        """
        shuffled_batches = super().batch_sampler
        random.shuffle(shuffled_batches)
        return shuffled_batches

    def _inject_noise(self, x):
        """
        Perturb a size by a random integer in [-int(x * noise), int(x * noise)].

        Args:
            x: size

        Returns:
            disturbed size (unchanged when noise is not positive)
        """
        if not self._noise > 0:
            return x
        spread = int(x * self._noise)
        return x + random.randint(-spread, spread)
72 |
--------------------------------------------------------------------------------
/mybycha/bycha/samplers/sequential_sampler.py:
--------------------------------------------------------------------------------
1 | from bycha.samplers import AbstractSampler, register_sampler
2 |
3 |
@register_sampler
class SequentialSampler(AbstractSampler):
    """
    SequentialSampler visits samples in their original index order.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, data_source):
        """
        Build sampler over data_source

        Args:
            data_source: a list of data
        """
        self._data_source = data_source
        self._permutation = list(range(len(data_source)))
        self._length = len(self._permutation)
23 |
--------------------------------------------------------------------------------
/mybycha/bycha/samplers/shuffled_sampler.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from bycha.samplers import AbstractSampler, register_sampler
4 | from bycha.utils.runtime import Environment
5 |
6 |
@register_sampler
class ShuffleSampler(AbstractSampler):
    """
    ShuffleSampler visits samples in a random order that is redrawn on every reset.

    The order produced by `reset(epoch)` depends only on the environment seed and
    `epoch`, not on how many resets happened before. Resuming at a given epoch
    therefore reproduces the same order as an uninterrupted run.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._env = Environment()

    def build(self, data_source):
        """
        Build sampler over data_source

        Args:
            data_source: a list of data
        """
        self._data_source = data_source
        self._permutation = list(range(len(self._data_source)))
        self._length = len(self._permutation)
        self.reset(0)

    def reset(self, epoch, *args, **kwargs):
        """
        Re-seed the global `random` module with (env seed + epoch) and reshuffle the
        reading order for the next round of iteration.

        Args:
            epoch: iteration epoch
        """
        random.seed(self._env.seed + epoch)
        # Start from the identity order each time. Shuffling the previous permutation
        # made the order depend on the whole reset history.
        self._permutation = list(range(self._length))
        random.shuffle(self._permutation)
35 |
--------------------------------------------------------------------------------
/mybycha/bycha/services/__init__.py:
--------------------------------------------------------------------------------
1 | import euler
2 |
3 | from .idls.bycha_thrift import Request, Response, Service
4 | from .server import Server
5 |
--------------------------------------------------------------------------------
/mybycha/bycha/services/idls/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/mybycha/bycha/services/idls/__init__.py
--------------------------------------------------------------------------------
/mybycha/bycha/services/idls/bycha.thrift:
--------------------------------------------------------------------------------
1 | namespace py Thrift
2 |
// Request accepted by Service.serve.
struct Request{
    // a json-dumped string for a batch of samples
    1: required string samples;
}
7 |
// Response returned by Service.serve.
struct Response{
    // a json-dumped string for a batch of results
    1: string results;
    // a json-dumped string of debugging information (e.g. the raw network output);
    // may be left unset when debugging is disabled
    2: string debug_info;
    // status code for this request.
    // NOTE(review): declared as i32, but it was previously described as a message
    // ("Success" or others); the set of values is not defined here - confirm.
    // Field id 3 is not defined in this struct.
    4: i32 code;
}
15 |
service Service{
    // serve a batch of samples: json-dumped samples come in Request.samples and
    // json-dumped results go out in Response.results
    Response serve(1:Request req)
}
20 |
--------------------------------------------------------------------------------
/mybycha/bycha/services/idls/model_infer.thrift:
--------------------------------------------------------------------------------
1 | namespace py background_model_infer
2 |
// Backend model inference service: `samples` is a serialized network input and the
// returned binary is the serialized network output (the Python client/server in this
// package pickle both sides - confirm before using another client language).
service ModelInfer {
    binary infer(1:binary samples)
}
--------------------------------------------------------------------------------
/mybycha/bycha/services/model_server.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import pickle
3 | import logging
4 | logger = logging.getLogger(__name__)
5 |
6 | import torch
7 |
8 | from bycha.utils.ops import auto_map_args
9 | from bycha.utils.runtime import Environment
10 | from bycha.utils.tensor import to_device
11 |
12 |
class ModelServer:
    """
    ModelServer is the thrift handler that runs the neural model at backend.

    Args:
        generator: neural inference model; it is switched to eval mode on construction
    """

    def __init__(self, generator):
        self._generator = generator
        self._env = Environment()

        self._generator.eval()

    def infer(self, net_input):
        """
        Inference with neural model.

        Args:
            net_input: pickled network input. A dict is mapped to positional arguments through
                the generator's ``input_slots``; otherwise it is unpacked as positional arguments.

        Returns:
            - pickled network output (moved to cpu), or ``None`` if anything failed
        """
        try:
            # SECURITY: pickle.loads can execute arbitrary code; this endpoint must only be
            # reachable by trusted peers (the front-end Server connects via localhost).
            net_input = pickle.loads(net_input)
            if isinstance(net_input, dict):
                net_input = auto_map_args(net_input, self._generator.input_slots)
            net_input = to_device(net_input, self._env.device, fp16=self._env.fp16)
            with torch.no_grad():
                self._generator.reset('infer')
                net_output = self._generator(*net_input)
            net_output = to_device(net_output, 'cpu')
            return pickle.dumps(net_output)
        except Exception as e:
            # Best effort: report the failure (with traceback) and return no result.
            logger.warning(str(e), exc_info=True)
            return None
51 |
--------------------------------------------------------------------------------
/mybycha/bycha/services/server.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pickle
4 | import logging
5 | logger = logging.getLogger(__name__)
6 |
7 | from thriftpy.rpc import make_client
8 | import thriftpy
9 |
10 | from bycha.services import Request, Response
11 | from bycha.tasks import create_task
12 | from bycha.utils.runtime import build_env, Environment
13 |
14 |
class Server:
    """
    Server makes the task an interactive service that can be deployed online to serve requests.

    The neural model itself is not built here: inference is delegated to a backend thrift
    ``ModelInfer`` service on localhost (see ``_build_backend_generator_service``).

    Args:
        configs: configurations to build a task; its ``task`` entry is popped (mutating configs)
    """

    def __init__(self, configs):
        self._configs = configs

        if 'env' in self._configs:
            build_env(**self._configs['env'])
        task_config = self._configs.pop('task')
        # The model runs in the backend service, so build the task without generator/model.
        # Tolerate configs that already omit them.
        task_config.pop('generator', None)
        task_config.pop('model', None)

        self._task = create_task(task_config)
        self._task.build()
        self._task.reset(training=False, infering=True)

        self._env = Environment()

    def serve(self, request):
        """
        Serve a request

        Args:
            request (Request): a request for serving.
                It must contain a jsonable attribute named `samples` indicating a batch of unprocessed samples.

        Returns:
            response (Response): a response to the given request. On any failure the error is
                logged and a response with empty ``results`` is returned.
        """
        response = Response(results='')
        generator = None
        try:
            logger.info('receive request {}'.format(request))
            generator = _build_backend_generator_service()
            samples = json.loads(request.samples)
            samples = self._task.preprocess(samples)
            samples = pickle.dumps(samples['net_input'])
            results = pickle.loads(generator.infer(samples))
            # Only convert the raw output when it is actually reported: .tolist() is not
            # available on every output type and should not fail non-debug requests.
            debug_info = json.dumps({'net_output': results.tolist()}) if self._env.debug else None
            results = self._task.postprocess(results)
            response = Response(results=json.dumps(results), debug_info=debug_info)
            logger.info('return response {}'.format(response))
        except Exception as e:
            logger.warning(str(e), exc_info=True)
        finally:
            # Each request opens a new client; close it so sockets are not leaked.
            if generator is not None:
                try:
                    generator.close()
                except Exception as e:
                    logger.warning('failed to close backend client: {}'.format(e))
        # Returned outside `finally` so KeyboardInterrupt/SystemExit are no longer swallowed.
        return response
68 |
69 |
def _build_backend_generator_service(thrift_path=None, host='localhost', port=None, timeout=100000000):
    """
    Create a thrift client connecting to the backend neural model service.

    Args:
        thrift_path: path of ``model_infer.thrift``. Defaults to the IDL shipped in the ``idls``
            directory next to this module, falling back to the legacy deployment path.
        host: backend host
        port: backend port; defaults to the ``GRPC_PORT`` environment variable, else 6000
        timeout: client timeout passed to ``make_client``

    Returns:
        - a thrift client connecting neural model
    """
    if thrift_path is None:
        thrift_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'idls', 'model_infer.thrift')
        if not os.path.exists(thrift_path):
            thrift_path = "/opt/tiger/ByCha/bycha/services/idls/model_infer.thrift"
    if port is None:
        port = int(os.environ.get('GRPC_PORT', 6000))
    model_infer_thrift = thriftpy.load(thrift_path, module_name="model_infer_thrift")
    return make_client(model_infer_thrift.ModelInfer, host, port, timeout=timeout)
85 |
86 |
--------------------------------------------------------------------------------
/mybycha/bycha/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
TRAIN, VALID, EVALUATE, SERVE = 'TRAIN', 'VALID', 'EVALUATE', 'SERVE'

# NOTE(review): the phase constants are defined before the submodule imports below,
# presumably because task modules import them back from ``bycha.tasks`` - keep this order.
from .abstract_task import AbstractTask

register_task, create_task, registry = setup_registry('task', AbstractTask)

# Import every public .py module / subdirectory of this package (presumably so their
# registration decorators run). Private ('_') and hidden ('.') entries are skipped.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    if file.startswith(('_', '.')):
        continue
    if not (file.endswith('.py') or os.path.isdir(path)):
        continue
    if file.endswith('.py'):
        module_name = file[:file.find('.py')]
    else:
        module_name = file
    module = importlib.import_module('bycha.tasks.' + module_name)
22 |
23 |
--------------------------------------------------------------------------------
/mybycha/bycha/tokenizers/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
6 | from .abstract_tokenizer import AbstractTokenizer
7 |
register_tokenizer, create_tokenizer, registry = setup_registry('tokenizer', AbstractTokenizer)

# Import every public .py module / subdirectory so @register_tokenizer decorators run.
# Private ('_') and hidden ('.') entries are skipped.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    if file.startswith(('_', '.')):
        continue
    if not (file.endswith('.py') or os.path.isdir(path)):
        continue
    if file.endswith('.py'):
        module_name = file[:file.find('.py')]
    else:
        module_name = file
    module = importlib.import_module('bycha.tokenizers.' + module_name)
20 |
21 |
--------------------------------------------------------------------------------
/mybycha/bycha/tokenizers/abstract_tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
class AbstractTokenizer:
    """
    Tokenizer provides a tokenization pipeline.

    Methods and properties that raise ``NotImplementedError`` here must be provided by
    subclasses; ``build``, ``tok``, ``detok`` and ``special_tokens`` have usable defaults.
    """

    def __init__(self):
        pass

    def build(self, *args, **kwargs):
        """
        Build tokenizer. No-op by default.
        """
        pass

    def __len__(self):
        raise NotImplementedError

    def encode(self, *args) -> List[int]:
        """
        Encode a textual sentence into a list of index.
        """
        raise NotImplementedError

    def decode(self, x: List[int]) -> str:
        """
        Decode a list of index back into a textual sentence

        Args:
            x: a list of index
        """
        raise NotImplementedError

    def tok(self, *args) -> List[str]:
        """
        Tokenize a textual sentence without index mapping.

        The default implementation does no real tokenization: it concatenates the given
        iterables into one list (a plain string argument contributes its characters).

        Returns:
            - a list of tokens
        """
        out = []
        for ext in args:
            out += ext
        return out

    def detok(self, x: str) -> str:
        """
        Detokenize a textual sentence without index mapping. Identity by default.

        Args:
            x: a textual sentence
        """
        return x

    def token2index(self, *args) -> List[int]:
        """
        Only map a textual sentence to index
        """
        raise NotImplementedError

    def index2token(self, x: List[int]) -> str:
        """
        Only map a list of index back into a textual sentence

        Args:
            x: a list of index
        """
        raise NotImplementedError

    @staticmethod
    def learn(*args, **kwargs):
        """
        Learn a tokenizer from data set.
        """
        raise NotImplementedError

    @property
    def special_tokens(self):
        """Mapping of 'bos'/'eos'/'pad'/'unk' to the corresponding special-token indices."""
        return {
            'bos': self.bos,
            'eos': self.eos,
            'pad': self.pad,
            'unk': self.unk
        }

    @property
    def bos(self):
        raise NotImplementedError

    @property
    def eos(self):
        raise NotImplementedError

    @property
    def unk(self):
        raise NotImplementedError

    @property
    def pad(self):
        raise NotImplementedError

    @property
    def bos_token(self):
        raise NotImplementedError

    @property
    def eos_token(self):
        raise NotImplementedError

    @property
    def unk_token(self):
        raise NotImplementedError

    @property
    def pad_token(self):
        raise NotImplementedError
116 |
--------------------------------------------------------------------------------
/mybycha/bycha/tokenizers/huggingface_tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import logging
3 | logger = logging.getLogger(__name__)
4 |
5 | from transformers import AutoTokenizer
6 |
7 | from bycha.tokenizers import AbstractTokenizer, register_tokenizer
8 |
9 |
@register_tokenizer
class HuggingfaceTokenizer(AbstractTokenizer):
    """
    HuggingfaceTokenizer delegates tokenization to a `huggingface/transformers` tokenizer.
    see huggingface/transformers(https://github.com/huggingface/transformers)

    Args:
        tokenizer_name: name or path given to ``AutoTokenizer.from_pretrained``; any extra
            positional/keyword arguments are forwarded to it as well
    """

    def __init__(self, tokenizer_name, *args, **kwargs):
        super().__init__()
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args, **kwargs)

    @staticmethod
    def learn(*args, **kwargs):
        """
        Not supported: huggingface tokenizers are pretrained and loaded directly.
        """
        logger.info('learn vocab not supported for huggingface tokenizer')
        raise NotImplementedError

    def __len__(self):
        return len(self._tokenizer)

    def encode(self, input, *args, **kwargs) -> List[int]:
        """
        Encode ``input`` into a list of index with the wrapped tokenizer.

        A one-element ``input`` is unwrapped first, so a single sentence is encoded on its own.
        """
        text = input[0] if len(input) == 1 else input
        return self._tokenizer.encode(text, *args, **kwargs)

    def decode(self, *args, **kwargs) -> str:
        """
        Decode index back into text with the wrapped tokenizer.
        """
        return self._tokenizer.decode(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        return self._tokenizer(*args, **kwargs)

    def token2index(self, *args) -> List[int]:
        """
        Encode and drop the first and last index.

        NOTE(review): this assumes ``encode`` adds exactly one leading and one trailing special
        token; for tokenizers that add none, real tokens are dropped - confirm per tokenizer.
        """
        ids = self.encode(*args)
        return ids[1:-1]

    @property
    def max_length(self):
        return self._tokenizer.model_max_length

    @property
    def bos(self):
        return self._tokenizer.bos_token_id

    @property
    def eos(self):
        return self._tokenizer.eos_token_id

    @property
    def unk(self):
        return self._tokenizer.unk_token_id

    @property
    def pad(self):
        return self._tokenizer.pad_token_id

    @property
    def bos_token(self):
        return self._tokenizer.bos_token

    @property
    def eos_token(self):
        return self._tokenizer.eos_token

    @property
    def unk_token(self):
        return self._tokenizer.unk_token

    @property
    def pad_token(self):
        return self._tokenizer.pad_token
93 |
--------------------------------------------------------------------------------
/mybycha/bycha/tokenizers/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 |
# Matches runs of whitespace. Raw string: "\s" is an invalid escape in a normal string
# literal (SyntaxWarning on Python 3.12+); the compiled pattern is unchanged.
SPACE_NORMALIZER = re.compile(r"\s+")

# NOTE(review): all four entries are empty strings here; they look like special-token
# symbols whose text may have been lost - confirm against the original source.
SPECIAL_SYMBOLS = ['', '', '', '']
6 |
7 |
--------------------------------------------------------------------------------
/mybycha/bycha/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
6 | from .abstract_trainer import AbstractTrainer
7 |
register_trainer, create_trainer, registry = setup_registry('trainer', AbstractTrainer)

# Import every public .py module / subdirectory so @register_trainer decorators run.
# Private ('_') and hidden ('.') entries are skipped.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    if file.startswith(('_', '.')):
        continue
    if not (file.endswith('.py') or os.path.isdir(path)):
        continue
    if file.endswith('.py'):
        module_name = file[:file.find('.py')]
    else:
        module_name = file
    module = importlib.import_module('bycha.trainers.' + module_name)
20 |
21 |
--------------------------------------------------------------------------------
/mybycha/bycha/trainers/moe_trainer.py:
--------------------------------------------------------------------------------
1 | from bycha.trainers.trainer import Trainer
2 | from bycha.trainers import register_trainer
3 |
@register_trainer
class MoETrainer(Trainer):
    """
    Trainer that adds a weighted load-balancing (MoE) auxiliary loss to the criterion loss.

    Args:
        load_balance_alpha: weight of the averaged encoder/decoder MoE loss
    """

    def __init__(self,
                 load_balance_alpha=0.,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._load_balance_alpha = load_balance_alpha

    def _forward_loss(self, samples):
        """
        Forward neural model and compute the loss of given samples

        Args:
            samples: a batch of samples

        Returns:
            - derived loss as torch.Tensor
            - states for updating log
        """
        loss, logging_states = self._criterion(**samples)
        loss, logging_states = self._load_balance_loss(loss, logging_states)
        return loss, logging_states

    def _load_balance_loss(self, loss, logging_states):
        """
        Add ``load_balance_alpha`` times the mean of encoder and decoder MoE losses to ``loss``.

        Assumes ``self._model`` exposes ``_encoder.moe_loss`` and ``_decoder.moe_loss`` tensors
        (TODO confirm for every MoE model). The averaged value is logged as ``moe_loss``.
        """
        moe_loss = (self._model._encoder.moe_loss + self._model._decoder.moe_loss) / 2
        # Out-of-place add: an in-place `+=` on the criterion's autograd tensor can fail at
        # backward time or silently change other references to that tensor.
        loss = loss + moe_loss * self._load_balance_alpha
        logging_states['moe_loss'] = moe_loss.detach().item()
        return loss, logging_states
34 |
35 |
--------------------------------------------------------------------------------
/mybycha/bycha/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/mybycha/bycha/utils/__init__.py
--------------------------------------------------------------------------------
/mybycha/bycha/utils/rate_schedulers/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 |
4 | from bycha.utils.registry import setup_registry
5 |
6 | from .abstract_rate_scheduler import AbstractRateScheduler
7 |
register_rate_scheduler, _create_rate_scheduler, registry = setup_registry('rate_scheduler', AbstractRateScheduler)


def create_rate_scheduler(configs):
    """
    Create a rate scheduler from configurations.

    Args:
        configs: a configuration dict with a ``class`` key, or a bare number (int or float),
            which is used as the rate of a ``ConstantRateScheduler``

    Returns:
        - a rate scheduler instance
    """
    # Integers (e.g. `rate: 1` in YAML) are numbers too; bool is excluded because it is an int
    # subclass but never a meaningful rate.
    if isinstance(configs, (int, float)) and not isinstance(configs, bool):
        configs = {'class': 'ConstantRateScheduler', 'rate': float(configs)}
    return _create_rate_scheduler(configs)
16 |
17 |
# Import every public .py module / subdirectory so @register_rate_scheduler decorators run.
# Private ('_') and hidden ('.') entries are skipped.
modules_dir = os.path.dirname(__file__)
for file in os.listdir(modules_dir):
    path = os.path.join(modules_dir, file)
    if file.startswith(('_', '.')):
        continue
    if not (file.endswith('.py') or os.path.isdir(path)):
        continue
    if file.endswith('.py'):
        module_name = file[:file.find('.py')]
    else:
        module_name = file
    module = importlib.import_module('bycha.utils.rate_schedulers.' + module_name)
28 |
29 |
30 |
--------------------------------------------------------------------------------
/mybycha/bycha/utils/rate_schedulers/abstract_rate_scheduler.py:
--------------------------------------------------------------------------------
class AbstractRateScheduler:
    """
    Base class for auxiliary tools that adjust a rate over training steps and epochs.

    Every hook here is a no-op; subclasses override the ones they need and expose the current
    value through ``rate``.

    Args:
        rate: initial rate (extra positional/keyword arguments are accepted and ignored)
    """

    def __init__(self, rate: float = 0., *args, **kwargs):
        self._rate: float = rate

    def build(self, *args, **kwargs):
        """Build the rate scheduler. No-op by default."""

    def step_update(self, step, *args, **kwargs):
        """
        Update the inner rate from outside states at each step. No-op by default.

        Args:
            step: training step
        """

    def step_reset(self, step, *args, **kwargs):
        """
        Reset the inner rate from outside states at each step. No-op by default.

        Args:
            step: training step
        """

    def epoch_update(self, epoch, *args, **kwargs):
        """
        Update the inner rate from outside states at each epoch. No-op by default.

        Args:
            epoch: training epoch
        """

    def epoch_reset(self, epoch, *args, **kwargs):
        """
        Reset the inner rate from outside states at each epoch. No-op by default.

        Args:
            epoch: training epoch
        """

    @property
    def rate(self):
        """Current rate."""
        return self._rate
57 |
--------------------------------------------------------------------------------
/mybycha/bycha/utils/rate_schedulers/constant_rate_scheduler.py:
--------------------------------------------------------------------------------
1 | from bycha.utils.rate_schedulers import AbstractRateScheduler, register_rate_scheduler
2 |
3 |
@register_rate_scheduler
class ConstantRateScheduler(AbstractRateScheduler):
    """
    ConstantRateScheduler keeps the rate fixed; none of the update hooks change it.

    Args:
        rate: the constant rate
    """

    def __init__(self, rate):
        super().__init__(rate=rate)
15 |
--------------------------------------------------------------------------------
/mybycha/bycha/utils/rate_schedulers/inverse_square_root_rate_scheduler.py:
--------------------------------------------------------------------------------
1 | from bycha.utils.rate_schedulers import AbstractRateScheduler, register_rate_scheduler
2 |
3 |
@register_rate_scheduler
class InverseSquareRootRateScheduler(AbstractRateScheduler):
    """
    InverseSquareRootRateScheduler first linearly warms up the rate, then decays it in
    proportion to the inverse square root of the step.

    Args:
        rate: maximum rate, reached at ``warmup_steps``
        warmup_steps: number of updates in warming up; values below 1 are treated as 1
    """

    def __init__(self, rate, warmup_steps=1000):
        super().__init__(rate)
        self._warmup_steps = warmup_steps
        # Peak rate kept separately: build() resets self._rate to 0, so deriving the schedule
        # from self._rate made a second build() call produce an all-zero schedule.
        self._max_rate = rate

        self._lr_step, self._decay_factor = None, None

    def build(self, *args, **kwargs):
        """
        Build rate scheduler: derive the warm-up slope and decay factor and start at rate 0.

        Extra arguments are accepted to match ``AbstractRateScheduler.build``.
        """
        # warmup_steps == 0 previously raised ZeroDivisionError; treat it as a 1-step warm-up.
        effective_warmup = max(self._warmup_steps, 1)
        self._lr_step = self._max_rate / effective_warmup
        self._decay_factor = self._max_rate * effective_warmup ** 0.5
        self._rate = 0.

    def step_update(self, step, *args, **kwargs):
        """
        Update inner rate with outside states at each step

        Args:
            step: training step
        """
        if step < self._warmup_steps:
            self._rate = step * self._lr_step
        else:
            # max(step, 1) only matters when warmup_steps < 1; it avoids 0 ** -0.5.
            self._rate = self._decay_factor * max(step, 1) ** -0.5
39 |
40 |
--------------------------------------------------------------------------------
/mybycha/bycha/utils/rate_schedulers/logistic_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from bycha.utils.rate_schedulers import AbstractRateScheduler, register_rate_scheduler
4 |
5 |
@register_rate_scheduler
class LogisticScheduler(AbstractRateScheduler):
    """
    LogisticScheduler sets the rate to the logistic curve 1 / (1 + exp(-k * (step - x0))),
    which rises from 0 towards 1 (for k > 0) and equals 0.5 at ``step == x0``.

    Args:
        k: steepness of the curve
        x0: step at the curve's midpoint
    """

    def __init__(self, k=0.0025, x0=4000):
        super().__init__(0.)
        self._k = k
        self._x0 = x0

    def step_update(self, step, *args, **kwargs):
        """
        Update inner rate with outside states at each step

        Args:
            step: training step
        """
        z = self._k * (step - self._x0)
        # Numerically stable logistic: math.exp only receives non-positive arguments, so a
        # large |z| cannot raise OverflowError (e.g. k=1, x0=4000 at step 0).
        if z >= 0:
            self._rate = 1 / (1 + math.exp(-z))
        else:
            e = math.exp(z)
            self._rate = e / (1 + e)
29 |
--------------------------------------------------------------------------------
/mybycha/bycha/utils/rate_schedulers/noam_scheduler.py:
--------------------------------------------------------------------------------
1 | from bycha.utils.rate_schedulers import AbstractRateScheduler, register_rate_scheduler
2 |
3 |
@register_rate_scheduler
class NoamScheduler(AbstractRateScheduler):
    """
    NoamScheduler is a scheduling methods proposed by Noam:
    rate = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

    Args:
        d_model: neural model feature dimension
        warmup_steps: training steps in warming up
    """

    def __init__(self, d_model, warmup_steps=4000):
        super().__init__(0.)
        self._warmup_steps = warmup_steps
        self._d_model = d_model

    def step_update(self, step, *args, **kwargs):
        """
        Update inner rate with outside states at each step

        Args:
            step: training step; non-positive steps yield rate 0
        """
        if step <= 0:
            # step ** -0.5 is undefined at 0 (ZeroDivisionError) and complex below 0; the
            # warm-up term tends to 0 there, so use that limit.
            self._rate = 0.
            return
        self._rate = (self._d_model ** -0.5) * min(step ** -0.5, step * (self._warmup_steps ** -1.5))
27 |
--------------------------------------------------------------------------------
/mybycha/bycha/utils/rate_schedulers/polynomial_decay_scheduler.py:
--------------------------------------------------------------------------------
1 | from bycha.utils.rate_schedulers import AbstractRateScheduler, register_rate_scheduler
2 |
3 |
@register_rate_scheduler
class PolynomialDecayScheduler(AbstractRateScheduler):
    """
    PolynomialDecayScheduler linearly warms the rate up to ``max_rate``, then decays it
    polynomially down to ``end_rate`` at ``total_steps``, and keeps it there afterwards.

    Args:
        max_rate: maximum rate
        total_steps: total training steps
        warmup_steps: number of updates in warming up (0 disables warm-up)
        end_rate: minimum rate at end
        power: polynomial decaying power
    """

    def __init__(self, max_rate, total_steps, warmup_steps=4000, end_rate=0.0, power=1.0,):
        super().__init__(0.)
        self._max_rate = max_rate
        self._total_steps = total_steps
        self._warmup_steps = warmup_steps
        self._end_rate = end_rate
        self._power = power

    def step_update(self, step, *args, **kwargs):
        """
        Recompute the rate for the given step.

        Args:
            step: training step
        """
        if self._warmup_steps > 0 and step <= self._warmup_steps:
            self._rate = step / float(self._warmup_steps) * self._max_rate
            return
        if step >= self._total_steps:
            self._rate = self._end_rate
            return
        progress = (step - self._warmup_steps) / (self._total_steps - self._warmup_steps)
        self._rate = (self._max_rate - self._end_rate) * (1 - progress) ** self._power + self._end_rate
42 |
--------------------------------------------------------------------------------
/mybycha/bycha/utils/registry.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | logger = logging.getLogger(__name__)
4 |
5 | from bycha.utils.data import possible_eval
6 | from bycha.utils.io import jsonable
7 | from bycha.utils.ops import deepcopy_on_ref
8 |
9 | MODULE_REGISTRY = {}
10 |
11 |
def setup_registry(registry, base_cls, force_extend=True):
    """
    Set up registry for a certain class

    Args:
        registry: registry name; calls with the same name share one registry dict
        base_cls: base class of a certain class
        force_extend: force a new class extend the base class

    Returns:
        - decorator to register a subclass
        - function to create a subclass with configurations
        - registry dictionary (lower-cased class name -> class)
    """

    if registry not in MODULE_REGISTRY:
        MODULE_REGISTRY[registry] = {}

    def register_cls(cls):
        """
        Register a class under its lower-cased name.

        Args:
            cls: a new class for registration

        Returns:
            - ``cls`` unchanged, so this works as a class decorator

        Raises:
            ValueError: if the name is already registered, or if ``force_extend`` is set and
                ``cls`` does not subclass ``base_cls``
        """
        name = cls.__name__.lower()
        if name in MODULE_REGISTRY[registry]:
            raise ValueError('Cannot register duplicate {} class ({})'.format(registry, name))
        if force_extend and not issubclass(cls, base_cls):
            raise ValueError('Class {} must extend {}'.format(name, base_cls.__name__))
        MODULE_REGISTRY[registry][name] = cls
        return cls

    def create_cls(configs=None):
        """
        Create a class with configuration

        Args:
            configs: configuration dictionary for building class. It is copied with
                ``deepcopy_on_ref`` first; its ``class`` entry names a registered class
                (case-insensitive) and every other entry is passed through ``possible_eval``
                as a keyword argument.

        Returns:
            - an instance of class
        """
        configs = deepcopy_on_ref(configs)
        name = configs.pop('class')
        json_configs = {k: v for k, v in configs.items() if jsonable(k) and jsonable(v)}
        logger.info('Creating {} class with configs \n{}\n'.format(name, json.dumps(json_configs, indent=4, sort_keys=True)))
        assert name.lower() in MODULE_REGISTRY[registry], f"{name} is not implemented in ByCha"
        cls = MODULE_REGISTRY[registry][name.lower()]
        kwargs = {k: possible_eval(v) for k, v in configs.items()}
        return cls(**kwargs)

    return register_cls, create_cls, MODULE_REGISTRY[registry]
69 |
--------------------------------------------------------------------------------
/mybycha/bycha/utils/txc_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from bycha.utils.io import UniIO
4 |
5 |
def build_txc(config):
    """
    Build a txc structure from config

    Args:
        config: txc configuration path (json); a falsy value disables txc

    Returns:
        - process function, or ``None`` when ``config`` is falsy. The function assigns its
          positional arguments to the structure's inputs in ``input_names`` order (``zip``
          silently ignores any surplus on either side), runs it, and returns the first output.
    """
    if not config:
        # Return before touching txc so the optional dependency is not needed (and its
        # runtime env is not built) when no structure is configured.
        return None
    from txc.modules.structures import create_structure
    from txc.runtime import build_env
    build_env(mode='infer')
    with UniIO(config) as fin:
        config = json.load(fin)
    txc = create_structure(config)

    def _process(*args):
        for x, unit_in_arg in zip(args, txc.input_names):
            txc.buffer_in[unit_in_arg] = x
        txc.run()
        return txc.fetch_buffer_out(name=txc.output_names[0])

    return _process
33 |
--------------------------------------------------------------------------------
/mybycha/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | transformers
3 | tensorflow==2.4.0
4 | tqdm
5 | sacremoses
6 | sacrebleu
7 | sentencepiece
8 | fastBPE
9 | scipy
10 | scikit-learn
11 | mosestokenizer
12 | nltk
13 | pyyaml
14 | mpi4py
15 | numpy==1.19.2
16 | more-itertools
17 | tabulate
18 | lightseq
19 |
--------------------------------------------------------------------------------
/mybycha/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
def _read_requirements(filename="requirements.txt"):
    """
    Read install requirements from *filename*, resolved next to this setup.py.

    Surrounding whitespace (including newlines) is stripped; blank and ``#`` comment lines
    are skipped.
    """
    import os

    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), filename)
    with open(path, encoding="utf-8") as fin:
        stripped = [line.strip() for line in fin]
    return [line for line in stripped if line and not line.startswith("#")]


setup(
    name="bycha",
    version="0.1.0",
    keywords=["Natural Language Processing", "Machine Learning"],
    license="MIT Licence",
    packages=find_packages(),
    include_package_data=True,
    platforms="any",
    install_requires=_read_requirements(),
    zip_safe=False,

    scripts=[],
    entry_points={
        'console_scripts': [
            'bycha-run = bycha.entries.run:main',
            'bycha-export = bycha.entries.export:main',
            'bycha-preprocess = bycha.entries.preprocess:main',
            'bycha-serve = bycha.entries.serve:main',
            'bycha-serve-model = bycha.entries.serve_model:main',
            'bycha-build-tokenizer = bycha.entries.build_tokenizer:main',
            'bycha-binarize-data = bycha.entries.binarize_data:main'
        ]
    }
)
27 |
--------------------------------------------------------------------------------
/pics/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/pics/overview.png
--------------------------------------------------------------------------------
/pics/sketch_and_generate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/pics/sketch_and_generate.png
--------------------------------------------------------------------------------
/preparation/fragmenizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .brics_fragmenizer import BRICS_Fragmenizer
2 | from .ring_r_fragmenizer import RING_R_Fragmenizer
3 | from .brics_ring_r_fragmenizer import BRICS_RING_R_Fragmenizer
4 | from .atom_fragmenizer import ATOM_Fragmenizer
5 |
--------------------------------------------------------------------------------
/preparation/fragmenizer/atom_fragmenizer.py:
--------------------------------------------------------------------------------
1 | from rdkit import Chem
2 |
class ATOM_Fragmenizer():
    """Fragmenize a molecule by breaking every one of its bonds."""

    def __init__(self):
        self.type = 'Atom_Fragmenizers'

    def get_bonds(self, mol):
        """Return all bonds of *mol* as a list."""
        return list(mol.GetBonds())

    def fragmenize(self, mol, dummyStart=1):
        """
        Break all bonds of *mol*, labelling the i-th broken bond's dummy atoms ``dummyStart + i``.

        Returns:
            - the fragmented molecule (``mol`` itself when it has no bonds)
            - the last dummy label used (``dummyStart - 1`` when nothing was broken)
        """
        bonds = self.get_bonds(mol)
        if not bonds:
            return mol, dummyStart - 1
        bond_ids = [bond.GetIdx() for bond in bonds]
        dummyLabels = [(dummyStart + i,) * 2 for i in range(len(bond_ids))]
        break_mol = Chem.FragmentOnBonds(mol, bond_ids, dummyLabels=dummyLabels)
        return break_mol, dummyStart + len(dummyLabels) - 1
22 |
--------------------------------------------------------------------------------
/preparation/fragmenizer/brics_fragmenizer.py:
--------------------------------------------------------------------------------
1 | from rdkit import Chem
2 | from rdkit.Chem.BRICS import FindBRICSBonds
3 |
class BRICS_Fragmenizer():
    """Fragmenize a molecule by breaking its BRICS bonds."""

    def __init__(self):
        # Previously misspelled `__inti__`, so it never ran and `self.type` was never set.
        self.type = 'BRICS_Fragmenizers'

    def get_bonds(self, mol):
        """Return the (atom_idx, atom_idx) pairs of the BRICS bonds found in *mol*."""
        return [bond[0] for bond in FindBRICSBonds(mol)]

    def fragmenize(self, mol, dummyStart=1):
        """
        Break the BRICS bonds of *mol*, labelling the i-th broken bond's dummy atoms
        ``dummyStart + i``.

        Returns:
            - the fragmented molecule (``mol`` itself when no BRICS bond is found)
            - the last dummy label used (``dummyStart - 1`` when nothing was broken)
        """
        # get bonds need to be broken (shared with get_bonds instead of duplicating it)
        bonds = self.get_bonds(mol)

        # whether the molecule can really be broken
        if len(bonds) != 0:
            bond_ids = [mol.GetBondBetweenAtoms(x, y).GetIdx() for x, y in bonds]

            # break the bonds & set the dummy labels for the bonds
            dummyLabels = [(i + dummyStart, i + dummyStart) for i in range(len(bond_ids))]
            break_mol = Chem.FragmentOnBonds(mol, bond_ids, dummyLabels=dummyLabels)
            dummyEnd = dummyStart + len(dummyLabels) - 1
        else:
            break_mol = mol
            dummyEnd = dummyStart - 1

        return break_mol, dummyEnd
29 |
--------------------------------------------------------------------------------
/preparation/fragmenizer/brics_ring_r_fragmenizer.py:
--------------------------------------------------------------------------------
1 | from fragmenizer import BRICS_Fragmenizer, RING_R_Fragmenizer
2 | from rdkit import Chem
3 |
4 | # from brics_fragmenizer import BRICS_Fragmenizer
5 | # from ring_r_fragmenizer import RING_R_Fragmenizer
6 |
class BRICS_RING_R_Fragmenizer():
    """Fragmenizer that cuts the union of BRICS bonds and ring/substituent bonds."""

    def __init__(self):
        self.type = 'BRICS_RING_R_Fragmenizer'
        self.brics_fragmenizer = BRICS_Fragmenizer()
        self.ring_r_fragmenizer = RING_R_Fragmenizer()

    def fragmenize(self, mol, dummyStart=1):
        """Break every bond proposed by either sub-fragmenizer.

        A bond proposed by both is cut only once, because candidates are
        de-duplicated by RDKit bond index. Each cut bond receives the same
        dummy label on both ends, numbered consecutively from ``dummyStart``.

        Returns:
            ``(break_mol, dummyEnd)``: the fragmented molecule (``mol`` itself
            when nothing is cut) and the last dummy label used
            (``dummyStart - 1`` when nothing was cut).
        """
        candidate_pairs = self.brics_fragmenizer.get_bonds(mol)
        candidate_pairs = candidate_pairs + self.ring_r_fragmenizer.get_bonds(mol)

        if not candidate_pairs:
            return mol, dummyStart - 1

        unique_ids = list(set([mol.GetBondBetweenAtoms(a, b).GetIdx() for a, b in candidate_pairs]))
        labels = [(dummyStart + k,) * 2 for k in range(len(unique_ids))]
        fragmented = Chem.FragmentOnBonds(mol, unique_ids, dummyLabels=labels)
        return fragmented, dummyStart + len(labels) - 1
29 |
--------------------------------------------------------------------------------
/preparation/fragmenizer/ring_r_fragmenizer.py:
--------------------------------------------------------------------------------
1 | from rdkit import Chem
2 | from utils import get_rings, get_other_atom_idx, find_parts_bonds
3 | from rdkit.Chem.rdchem import BondType
4 |
class RING_R_Fragmenizer():
    """Fragmenizer that cuts the acyclic single bonds joining each ring to the rest of the molecule."""

    def __init__(self):
        self.type = 'RING_R_Fragmenizer'

    def bonds_filter(self, mol, bonds):
        """Keep only the candidate ``(atom_idx, atom_idx)`` pairs that may be cut.

        A pair survives when its bond is a single bond, neither end is a
        dummy atom (symbol ``'*'``), and the bond is not itself in a ring.
        """
        filtered_bonds = []
        for bond in bonds:
            rd_bond = mol.GetBondBetweenAtoms(bond[0], bond[1])
            # Compare enum values by equality rather than object identity.
            if rd_bond.GetBondType() != BondType.SINGLE:
                continue
            if (mol.GetAtomWithIdx(bond[0]).GetSymbol() == '*'
                    or mol.GetAtomWithIdx(bond[1]).GetSymbol() == '*'):
                continue
            if rd_bond.IsInRing():
                continue
            filtered_bonds.append(bond)
        return filtered_bonds

    def get_bonds(self, mol):
        """Return the cuttable bonds between each ring and the atoms outside it."""
        bonds = []
        for ring in get_rings(mol):
            rest_atom_idx = get_other_atom_idx(mol, ring)
            bonds += find_parts_bonds(mol, [rest_atom_idx, ring])
        return self.bonds_filter(mol, bonds)

    def fragmenize(self, mol, dummyStart=1):
        """Break the bonds returned by ``get_bonds``, labelling both ends of each cut.

        Duplicate candidates are collapsed by RDKit bond index. Cut bonds get
        consecutive dummy labels starting at ``dummyStart``.

        Returns:
            ``(break_mol, dummyEnd)``: the fragmented molecule (``mol`` itself
            when nothing is cut) and the last dummy label used
            (``dummyStart - 1`` when nothing was cut).
        """
        bonds = self.get_bonds(mol)
        if len(bonds) == 0:
            return mol, dummyStart - 1

        bond_ids = list(set([mol.GetBondBetweenAtoms(x, y).GetIdx() for x, y in bonds]))
        dummyLabels = [(i + dummyStart, i + dummyStart) for i in range(len(bond_ids))]
        break_mol = Chem.FragmentOnBonds(mol, bond_ids, dummyLabels=dummyLabels)
        dummyEnd = dummyStart + len(dummyLabels) - 1
        return break_mol, dummyEnd
55 |
--------------------------------------------------------------------------------
/preparation/get_fragment_vocab.py:
--------------------------------------------------------------------------------
1 | import os
2 | from rdkit import Chem
3 | import gzip
4 | import pickle
5 | from datetime import datetime
6 | from fragmenizer import BRICS_RING_R_Fragmenizer
7 | from utils import centralize, canonical_frag_smi
8 |
data_path = 'ZINC SDF PATH'
save_pkl_path = 'YOUR PATH'
save_pkl_pattern = 'BRICS_RING_R.{}.pkl'
save_pkl_file = os.path.join(save_pkl_path, save_pkl_pattern)
# os.listdir order is arbitrary. Sort it so the processing order is
# reproducible, and with it the insertion order of ``vocab``, which
# determines the fragment ids assigned later.
file_list = sorted(os.listdir(data_path))

# canonical fragment SMILES -> first (centralized) fragment mol seen with it
vocab = dict()

save_interval = 1000 * 10000  # checkpoint the vocab every 10M molecules
print_interval = 1000
mo_cnt = 0

fragmenizer = BRICS_RING_R_Fragmenizer()

start = datetime.now()
for f_idx, file_name in enumerate(file_list):
    file_path = os.path.join(data_path, file_name)
    # Use ``with`` so each gzip handle is closed once its file is consumed.
    with gzip.open(file_path) as gzip_data:
        with Chem.ForwardSDMolSupplier(gzip_data) as mos:
            for mo in mos:
                if mo is None:  # unparsable record
                    continue
                mo_cnt += 1

                frags, _ = fragmenizer.fragmenize(mo)
                frags = Chem.GetMolFrags(frags, asMols=True)
                for frag in frags:
                    frag = centralize(frag)
                    frag_smi = canonical_frag_smi(Chem.MolToSmiles(frag))

                    if frag_smi not in vocab:
                        vocab[frag_smi] = frag

                if mo_cnt % save_interval == 0:
                    with open(save_pkl_file.format(mo_cnt), 'wb') as fw:
                        pickle.dump(vocab, fw, protocol=pickle.HIGHEST_PROTOCOL)

                if mo_cnt % print_interval == 0:
                    now = datetime.now()
                    time_interval = (now - start).total_seconds()
                    print('current {} file {} molecule {:.3f} ms/mol'.format(f_idx, mo_cnt, time_interval * 1000 / print_interval))
                    start = datetime.now()

with open(save_pkl_file.format(mo_cnt), 'wb') as fw:
    pickle.dump(vocab, fw, protocol=pickle.HIGHEST_PROTOCOL)
54 |
def mapping_star(mol):
    """Number the dummy (``'*'``) atoms of ``mol`` in atom order.

    Returns a dict that maps in both directions. Each 1-based counter maps to
    that atom's SMARTS string, and each SMARTS string maps back to its
    counter. If two dummy atoms share a SMARTS string, the later counter wins.
    """
    star_mapping = dict()
    dummy_smarts = (a.GetSmarts() for a in mol.GetAtoms() if a.GetSymbol() == '*')
    for star_cnt, smarts in enumerate(dummy_smarts, start=1):
        star_mapping[star_cnt] = smarts
        star_mapping[smarts] = star_cnt
    return star_mapping
64 |
# Special tokens take ids 0-5. Every fragment then gets the next free id in
# ``vocab`` insertion order. Each entry is [fragment mol, star mapping, id].
vocab_w_mapping = {
    'PAD': [None, None, 0],
    'UNK': [None, None, 1],
    'BOS': [None, None, 2],
    'EOS': [None, None, 3],
    'BOB': [None, None, 4],
    'EOB': [None, None, 5],
}
for key, frag in vocab.items():
    vocab_w_mapping[key] = [frag, mapping_star(frag), len(vocab_w_mapping)]

# Persist the id-annotated vocabulary rather than the raw ``vocab``.
# Otherwise the mapping built above would be discarded.
with open(save_pkl_file.format('vocab'), 'wb') as fw:
    pickle.dump(vocab_w_mapping, fw, protocol=pickle.HIGHEST_PROTOCOL)
72 |
--------------------------------------------------------------------------------
/preparation/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import find_parts_bonds
2 | from .utils import get_other_atom_idx
3 | from .utils import get_rings
4 | from .utils import get_bonds
5 | from .utils import centralize
6 | from .utils import canonical_frag_smi
7 | from .utils import get_center
8 | from .utils import get_align_points
9 | from .utils import get_tree
10 | from .utils import tree_linearize
11 | from .utils import get_surrogate_frag
12 | from .utils import get_atom_mapping_between_frag_and_surrogate
13 | from .utils_full import get_dock_fast_with_smiles_with_mol
14 |
--------------------------------------------------------------------------------
/preparation/utils/common.py:
--------------------------------------------------------------------------------
# Van der Waals radius per element symbol (units not stated here).
ATOM_RADIUS = dict(
    C=1.908,
    F=1.75,
    Cl=1.948,
    Br=2.22,
    I=2.35,
    N=1.824,
    O=1.6612,
    P=2.1,
    S=2.0,
    Si=2.2,  # not accurate
    H=1.0,
)

# Atomic number per element symbol.
ATOMIC_NUMBER = dict(
    C=6,
    F=9,
    Cl=17,
    Br=35,
    I=53,
    N=7,
    O=8,
    P=15,
    S=16,
    Si=14,
    H=1,
)

# Inverse lookup: atomic number -> element symbol.
ATOMIC_NUMBER_REVERSE = {number: symbol for symbol, number in ATOMIC_NUMBER.items()}
32 |
--------------------------------------------------------------------------------
/shape_pretraining/__init__.py:
--------------------------------------------------------------------------------
1 | from .shape_pretraining_dataset import ShapePretrainingDataset
2 | from .shape_pretraining_model import ShapePretrainingModel
3 | from .shape_pretraining_encoder import ShapePretrainingEncoder
4 | from .shape_pretraining_dataset_shard import ShapePretrainingDatasetShard
5 | from .shape_pretraining_dataloader_shard import ShapePretrainingDataLoaderShard
6 | from .shape_pretraining_task_no_regression import ShapePretrainingTaskNoRegression
7 | from .shape_pretraining_task_no_regression_pocket import ShapePretrainingTaskNoRegressionPocket
8 | from .shape_pretraining_criterion_no_regression import ShapePretrainingCriterionNoRegression
9 | from .shape_pretraining_decoder_iterative_no_regression import ShapePretrainingDecoderIterativeNoRegression
10 | from .shape_pretraining_iterator_no_regression import ShapePretrainingIteratorNoRegression
11 | from .shape_pretraining_search_iterative_no_regression import ShapePretrainingSearchIterativeNoRegression
12 | from .shape_pretraining_dataset_pocket import ShapePretrainingDatasetPocket
--------------------------------------------------------------------------------
/shape_pretraining/common.py:
--------------------------------------------------------------------------------
# Van der Waals radius per element symbol (units not stated here).
# Hydrogen is not included in this table.
ATOM_RADIUS = dict(
    C=1.908,
    F=1.75,
    Cl=1.948,
    Br=2.22,
    I=2.35,
    N=1.824,
    O=1.6612,
    P=2.1,
    S=2.0,
    Si=2.2,  # not accurate
)

# Atomic number per element symbol.
ATOMIC_NUMBER = dict(
    C=6,
    F=9,
    Cl=17,
    Br=35,
    I=53,
    N=7,
    O=8,
    P=15,
    S=16,
    Si=14,
)

# Inverse lookup: atomic number -> element symbol.
ATOMIC_NUMBER_REVERSE = {number: symbol for symbol, number in ATOMIC_NUMBER.items()}
30 |
--------------------------------------------------------------------------------
/shape_pretraining/io.py:
--------------------------------------------------------------------------------
1 | from bycha.utils.io import _InputStream, _OutputStream, _InputBytes, _OutputBytes
2 | import pickle
3 | from bycha.utils.runtime import logger
4 | import random
5 | from bycha.utils.ops import local_seed
6 |
class _MyInputBytes(_InputBytes):
    """Binary reader that yields the items of one pickled object per file.

    Each handle in ``self._fins`` (set up by ``_InputBytes``) is read once
    with ``pickle.load``. The items of the loaded object are then yielded one
    by one, file after file. With ``shuffle=True`` each file's items are
    shuffled in place, seeded by ``fake_epoch * len(self._fins) + file_index``.
    The order is therefore reproducible for a given epoch and differs
    between epochs.

    Args:
        fake_epoch: epoch counter mixed into the shuffle seed.
        *args, **kwargs: forwarded to ``_InputBytes``.
        shuffle: whether to shuffle each file's items (requires the loaded
            object to be a mutable sequence).
    """

    def __init__(self, fake_epoch, *args, shuffle=False, **kwargs):
        super().__init__(*args, **kwargs)
        self._data_iter = None
        self._shuffle = shuffle
        self._fake_epoch = fake_epoch

    def __next__(self):
        # Iterative rather than recursive, so many consecutive empty files
        # cannot exhaust the recursion limit. Only the StopIteration raised by
        # the per-file iterator is treated as "end of this file"; errors from
        # pickle.load or shuffle propagate instead of being masked.
        while self._idx < len(self._fins):
            if self._data_iter is None:
                # NOTE(review): pickle.load can execute arbitrary code; only
                # read trusted data shards.
                data = pickle.load(self._fins[self._idx])
                if self._shuffle:
                    seed = (self._fake_epoch * len(self._fins)) + self._idx
                    with local_seed(seed):
                        # Seed explicitly and restore the global ``random``
                        # state afterwards so other users are unaffected.
                        old_state = random.getstate()
                        random.seed(seed)
                        random.shuffle(data)
                        random.setstate(old_state)
                self._data_iter = iter(data)
            try:
                return next(self._data_iter)
            except StopIteration:
                # Current file exhausted: move on to the next one.
                self._idx += 1
                self._data_iter = None
        raise StopIteration

    def reset(self):
        """Rewind every file handle so iteration restarts from the first file."""
        self._idx = 0
        for fin in self._fins:
            fin.seek(0)
        self._data_iter = None
42 |
class MyUniIO(_InputStream, _OutputStream, _MyInputBytes, _OutputBytes):
    """Factory that opens ``path`` with the stream class matching ``mode``.

    ``__new__`` returns an instance of one of the stream classes below rather
    than a ``MyUniIO``. Because of that, Python never calls
    ``MyUniIO.__init__``.

    - read + binary: ``_MyInputBytes`` (uses ``fake_epoch`` and ``shuffle``)
    - read + text: ``_InputStream`` (``fake_epoch``/``shuffle`` ignored)
    - write + binary: ``_OutputBytes``
    - write + text: ``_OutputStream``

    Raises:
        ValueError: if ``mode`` contains neither ``'r'`` nor ``'w'``.
    """

    def __init__(self, path, fake_epoch, mode='r', encoding='utf8', shuffle=False):
        pass

    def __new__(cls, path, fake_epoch, mode='r', encoding='utf8', shuffle=False):
        lowered = mode.lower()
        if 'r' in lowered:
            if 'b' in lowered:
                return _MyInputBytes(fake_epoch, path, mode=mode, shuffle=shuffle)
            return _InputStream(path, encoding=encoding)
        if 'w' in lowered:
            if 'b' in lowered:
                return _OutputBytes(path, mode=mode)
            return _OutputStream(path, encoding=encoding)
        logger.warning(f'Not support file mode: {mode}')
        raise ValueError(f'Not support file mode: {mode}')
58 |
--------------------------------------------------------------------------------
/shape_pretraining/shape_pretraining_dataset.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from bycha.datasets import register_dataset
3 | from bycha.datasets.in_memory_dataset import InMemoryDataset
4 | from bycha.utils.runtime import logger, progress_bar
5 | from .utils import get_mol_centroid, centralize
6 |
@register_dataset
class ShapePretrainingDataset(InMemoryDataset):
    """In-memory dataset of pickled molecule/fragment/tree samples.

    Args:
        path: passed to ``InMemoryDataset``. ``self._path`` is then indexed
            with ``'vocab'`` and ``'samples'``, each naming a pickle file.
    """

    def __init__(self,
                 path):
        super().__init__(path)

        # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
        vocab_path = self._path['vocab']
        with open(vocab_path, 'rb') as fr:
            self._vocab = pickle.load(fr)

    def _load(self):
        """Load all samples and keep those that ``_full_callback`` accepts."""
        self._data = []

        samples_path = self._path['samples']

        with open(samples_path, 'rb') as fr:
            samples = pickle.load(fr)

        accepted, discarded = 0, 0
        for i, sample in enumerate(progress_bar(samples, desc='Loading Samples...')):
            try:
                processed = self._full_callback(sample)
            except Exception as e:
                # Best-effort loading: skip malformed samples, but record why.
                logger.warning('sample {} is discarded: {}'.format(i, e))
                discarded += 1
                continue
            self._data.append(processed)
            accepted += 1

        self._length = len(self._data)
        logger.info(f'Totally accept {accepted} samples, discard {discarded} samples')

    def _callback(self, sample):
        """Centralize a sample's molecule and shift its fragment translations to match.

        ``sample`` is indexed as ``(mol, fragments, tree_list)``. Each fragment
        is indexed as ``(vocab_id, vocab_key, frag_smi, trans_vec, rotate_mat)``.
        Each non-None ``trans_vec`` has the molecule centroid (computed before
        centralizing) subtracted from it.

        Returns:
            dict with keys ``'mol'``, ``'frag_list'`` and ``'tree_list'``.
        """
        mol = sample[0]
        centroid = get_mol_centroid(mol)
        mol = centralize(mol)

        fragment_list = []
        for fragment in sample[1]:
            trans_vec = fragment[3]
            if trans_vec is not None:
                trans_vec = trans_vec - centroid
            fragment_list.append({
                'vocab_id': fragment[0],
                'vocab_key': fragment[1],
                'frag_smi': fragment[2],
                'trans_vec': trans_vec,
                'rotate_mat': fragment[4]
            })

        tree_list = sample[2]

        return {
            'mol': mol,
            'frag_list': fragment_list,
            'tree_list': tree_list
        }
64 |
--------------------------------------------------------------------------------
/shape_pretraining/shape_pretraining_dataset_pocket.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from bycha.datasets import register_dataset
3 | from bycha.datasets.in_memory_dataset import InMemoryDataset
4 | from bycha.utils.runtime import logger, progress_bar
5 | from .utils import get_mol_centroid, centralize
6 |
@register_dataset
class ShapePretrainingDatasetPocket(InMemoryDataset):
    """In-memory dataset whose samples are wrapped as ``{'shape': sample}``.

    Args:
        path: passed to ``InMemoryDataset``. ``self._path`` is then indexed
            with ``'vocab'`` and ``'samples'``, each naming a pickle file.
    """

    def __init__(self,
                 path):
        super().__init__(path)

        # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
        vocab_path = self._path['vocab']
        with open(vocab_path, 'rb') as fr:
            self._vocab = pickle.load(fr)

    def _load(self):
        """Load all samples and keep those that ``_full_callback`` accepts."""
        self._data = []

        samples_path = self._path['samples']

        with open(samples_path, 'rb') as fr:
            samples = pickle.load(fr)

        accepted, discarded = 0, 0
        for i, sample in enumerate(progress_bar(samples, desc='Loading Samples...')):
            try:
                processed = self._full_callback(sample)
            except Exception as e:
                # Best-effort loading: skip malformed samples, but record why.
                logger.warning('sample {} is discarded: {}'.format(i, e))
                discarded += 1
                continue
            self._data.append(processed)
            accepted += 1

        self._length = len(self._data)
        logger.info(f'Totally accept {accepted} samples, discard {discarded} samples')

    def _callback(self, sample):
        """Wrap a raw sample unchanged under the ``'shape'`` key."""
        return {
            'shape': sample
        }
41 |
--------------------------------------------------------------------------------
/shape_pretraining/shape_pretraining_dataset_shard.py:
--------------------------------------------------------------------------------
1 | from bycha.datasets import register_dataset
2 | from bycha.datasets.streaming_dataset import StreamingDataset
3 | from .io import MyUniIO
4 | from bycha.utils.runtime import logger
5 | import pickle
6 | from .utils import get_mol_centroid, centralize, set_atom_prop
7 |
@register_dataset
class ShapePretrainingDatasetShard(StreamingDataset):
    """Streaming dataset that reads pickled sample shards through ``MyUniIO``.

    Args:
        path: shard path(s) handed to ``MyUniIO`` in binary read mode.
        vocab_path: pickle file holding the fragment vocabulary.
        sample_each_shard: stored as ``self._sample_each_shard`` (not used in
            this class).
        shuffle: shuffle the samples of each shard, reseeded on every ``reset``.
    """

    def __init__(self,
                 path,
                 vocab_path,
                 sample_each_shard,
                 shuffle=False):
        super().__init__(path)
        self._sample_each_shard = sample_each_shard
        self._shuffle = shuffle
        self._fake_epoch = 0

        # NOTE(review): pickle.load can execute arbitrary code; trusted files only.
        with open(vocab_path, 'rb') as fr:
            self._vocab = pickle.load(fr)

    def build(self, collate_fn=None, preprocessed=False):
        self._collate_fn = collate_fn
        if self._path:
            self._fin = MyUniIO(self._path, self._fake_epoch, mode='rb', shuffle=self._shuffle)

    def __iter__(self):
        """Yield processed samples, logging and skipping those that fail."""
        for sample in self._fin:
            try:
                sample = self._full_callback(sample)
            except StopIteration:
                # Raising StopIteration inside a generator is turned into a
                # RuntimeError (PEP 479); end the iteration cleanly instead.
                return
            except Exception as e:
                logger.warning(e)
                continue
            yield sample

    def reset(self):
        """Reopen the shards with the current fake epoch, then advance it."""
        self._pos = 0
        # Release the previous handles before reopening, otherwise every
        # epoch leaks one set of open files.
        old_fin = getattr(self, '_fin', None)
        if old_fin is not None:
            old_fin.close()
        self._fin = MyUniIO(self._path, self._fake_epoch, mode='rb', shuffle=self._shuffle)
        self._fake_epoch += 1

    def _callback(self, sample):
        """Centralize a sample's molecule and shift its fragment translations to match.

        ``sample`` is indexed as ``(mol, fragments, tree_list)``. Each fragment
        is indexed as ``(vocab_id, vocab_key, frag_smi, trans_vec, rotate_mat)``.
        Every atom of the centralized molecule is tagged with its index
        (as a string) under the ``'origin_atom_idx'`` property.

        Returns:
            dict with keys ``'mol'``, ``'frag_list'`` and ``'tree_list'``.
        """
        mol = sample[0]
        centroid = get_mol_centroid(mol)
        mol = centralize(mol)

        for atom in mol.GetAtoms():
            set_atom_prop(atom, 'origin_atom_idx', str(atom.GetIdx()))

        fragment_list = []
        for fragment in sample[1]:
            trans_vec = fragment[3]
            if trans_vec is not None:
                trans_vec = trans_vec - centroid
            fragment_list.append({
                'vocab_id': fragment[0],
                'vocab_key': fragment[1],
                'frag_smi': fragment[2],
                'trans_vec': trans_vec,
                'rotate_mat': fragment[4]
            })

        tree_list = sample[2]

        return {
            'mol': mol,
            'frag_list': fragment_list,
            'tree_list': tree_list
        }

    def finalize(self):
        self._fin.close()
76 |
--------------------------------------------------------------------------------
/shape_pretraining/shape_pretraining_encoder.py:
--------------------------------------------------------------------------------
1 | from bycha.modules.encoders import register_encoder
2 | from bycha.modules.encoders.transformer_encoder import TransformerEncoder
3 | from bycha.modules.layers.feed_forward import FFN
4 | import torch
5 |
@register_encoder
class ShapePretrainingEncoder(TransformerEncoder):
    """Transformer encoder over shape patches instead of token ids.

    Each input position is projected to ``d_model`` by an FFN whose input
    width is ``patch_size ** 3``, replacing the usual embedding lookup.

    Args:
        patch_size: patch edge length; sets the FFN input width.
    """

    def __init__(self, patch_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._patch_size = patch_size

    def build(self,
              embed,
              special_tokens):
        super().build(embed, special_tokens)
        self._patch_ffn = FFN(self._patch_size**3, self._d_model, self._d_model)

    def _forward(self, src):
        """Encode a batch of shape patches.

        Args:
            src: patch tensor whose first two dims are (batch, seq_len). The
                last dim is presumably ``patch_size ** 3`` to match the FFN;
                TODO confirm against the data loader.

        Returns:
            ``(x, padding_mask)``, or ``(x[1:], padding_mask[:, 1:], x[0])``
            when ``self._return_seed`` is set. ``x`` is sequence-first after
            the transpose, and the padding mask is all ``False``.
        """
        bz, sl = src.size(0), src.size(1)

        x = self._patch_ffn(src)
        if self._embed_scale is not None:
            x = x * self._embed_scale
        if self._pos_embed is not None:
            # Build positions directly on the target device (no host copy).
            pos = torch.arange(sl, device=x.device).unsqueeze(0).repeat(bz, 1)
            x = x + self._pos_embed(pos)
        if self._embed_norm is not None:
            x = self._embed_norm(x)
        x = self._embed_dropout(x)

        # Every patch is real input, so nothing is masked.
        src_padding_mask = torch.zeros((bz, sl), dtype=torch.bool, device=x.device)
        x = x.transpose(0, 1)
        for layer in self._layers:
            x = layer(x, src_key_padding_mask=src_padding_mask)

        if self._norm is not None:
            x = self._norm(x)

        if self._return_seed:
            encoder_out = x[1:], src_padding_mask[:, 1:], x[0]
        else:
            encoder_out = x, src_padding_mask

        return encoder_out
45 |
--------------------------------------------------------------------------------
/shape_pretraining/shape_pretraining_iterator_no_regression.py:
--------------------------------------------------------------------------------
1 | from bycha.modules.encoders import register_encoder
2 | from bycha.modules.encoders.transformer_encoder import TransformerEncoder
3 | from bycha.modules.layers.feed_forward import FFN
4 | import torch
5 | import torch.nn as nn
6 |
@register_encoder
class ShapePretrainingIteratorNoRegression(TransformerEncoder):
    """Transformer that re-encodes argmax predictions and re-scores them.

    Fragment, translation and rotation predictions are embedded and summed,
    run through the encoder layers, and projected back to three score
    tensors. Each output projection is weight-tied to its embedding.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def build(self,
              embed,
              special_tokens,
              trans_size,
              rotat_size):
        """Build the layers plus translation/rotation embeddings and tied output heads.

        Args:
            embed: fragment embedding; its weight is ``(vocab, d_model)``-shaped
                as used by the projections below.
            special_tokens: forwarded to ``TransformerEncoder.build``.
            trans_size: number of translation classes.
            rotat_size: number of rotation classes.
        """
        super().build(embed, special_tokens)
        from bycha.modules.layers.embedding import Embedding

        self._trans_emb = Embedding(vocab_size=trans_size,
                                    d_model=embed.weight.shape[1])
        self._rotat_emb = Embedding(vocab_size=rotat_size,
                                    d_model=embed.weight.shape[1])

        # Output projections share weights with the corresponding embeddings.
        self._logits_output_proj = nn.Linear(embed.weight.shape[1],
                                             embed.weight.shape[0],
                                             bias=False)
        self._logits_output_proj.weight = embed.weight
        self._trans_output_proj = nn.Linear(self._trans_emb.weight.shape[1],
                                            self._trans_emb.weight.shape[0],
                                            bias=False)
        self._trans_output_proj.weight = self._trans_emb.weight
        self._rotat_output_proj = nn.Linear(self._rotat_emb.weight.shape[1],
                                            self._rotat_emb.weight.shape[0],
                                            bias=False)
        self._rotat_output_proj.weight = self._rotat_emb.weight

    def _forward(self, logits, trans, r_mat, padding_mask):
        """Refine predicted scores.

        Args:
            logits, trans, r_mat: score tensors with batch-first layout
                ``(batch, seq_len, classes)``; the argmax over the last dim
                is embedded.
            padding_mask: key padding mask passed to every layer.

        Returns:
            ``(logits, trans, r_mat)`` re-projected, batch-first.
        """
        bz, sl = logits.size(0), logits.size(1)
        logits_pred = logits.argmax(-1)
        trans_pred = trans.argmax(-1)
        r_mat_pred = r_mat.argmax(-1)

        x = self._embed(logits_pred)
        x = x + self._trans_emb(trans_pred)
        x = x + self._rotat_emb(r_mat_pred)

        if self._embed_scale is not None:
            x = x * self._embed_scale
        if self._pos_embed is not None:
            # Build positions directly on the target device (no host copy).
            pos = torch.arange(sl, device=x.device).unsqueeze(0).repeat(bz, 1)
            x = x + self._pos_embed(pos)
        if self._embed_norm is not None:
            x = self._embed_norm(x)
        x = self._embed_dropout(x)

        src_padding_mask = padding_mask
        x = x.transpose(0, 1)
        for layer in self._layers:
            x = layer(x, src_key_padding_mask=src_padding_mask)

        if self._norm is not None:
            x = self._norm(x)

        x = x.transpose(0, 1)
        logits = self._logits_output_proj(x)
        trans = self._trans_output_proj(x)
        r_mat = self._rotat_output_proj(x)

        return logits, trans, r_mat
71 |
--------------------------------------------------------------------------------
/shape_pretraining/shape_pretraining_model.py:
--------------------------------------------------------------------------------
1 | from bycha.models import register_model
2 | from bycha.models.encoder_decoder_model import EncoderDecoderModel
3 | import torch
4 |
@register_model
class ShapePretrainingModel(EncoderDecoderModel):
    """Encoder-decoder that decodes fragments conditioned on encoded shape patches.

    ``no_shape``, ``no_trans`` and ``no_rotat`` are ablation switches. When
    set, the encoder memory, input fragment translations or input fragment
    rotations, respectively, are replaced by zeros of the same shape before
    decoding.
    """

    def __init__(self,
                 encoder,
                 decoder,
                 d_model,
                 share_embedding=None,
                 path=None,
                 no_shape=False,
                 no_trans=False,
                 no_rotat=False):
        super().__init__(encoder=encoder, decoder=decoder, d_model=d_model,
                         share_embedding=share_embedding, path=path)
        self._no_shape, self._no_trans, self._no_rotat = no_shape, no_trans, no_rotat

    def forward(self,
                shape,
                shape_patches,
                input_frag_idx,
                input_frag_idx_mask,
                input_frag_trans,
                input_frag_trans_mask,
                input_frag_r_mat,
                input_frag_r_mat_mask):
        """Encode ``shape_patches`` and decode fragment, translation and rotation scores.

        ``shape`` and the three ``*_mask`` arguments are accepted for
        interface compatibility but are not used here.

        Returns:
            tuple ``(logits, trans, r_mat)`` from the decoder.
        """
        def zero_if(tensor, disabled):
            return torch.zeros_like(tensor) if disabled else tensor

        memory, memory_padding_mask = self._encoder(src=shape_patches)
        memory = zero_if(memory, self._no_shape)
        input_frag_trans = zero_if(input_frag_trans, self._no_trans)
        input_frag_r_mat = zero_if(input_frag_r_mat, self._no_rotat)

        logits, trans, r_mat = self._decoder(
            input_frag_idx=input_frag_idx,
            input_frag_trans=input_frag_trans,
            input_frag_r_mat=input_frag_r_mat,
            memory=memory,
            memory_padding_mask=memory_padding_mask,
        )
        return logits, trans, r_mat
47 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Launch bycha training for the shape_pretraining library.
# Run from the script's own directory so the relative paths below resolve.
DIR=$(dirname "$(readlink -f "$0")")
cd "$DIR" || exit 1
source bashutil.sh

# Defaults; parse_args (from bashutil.sh) presumably lets command-line
# options override these and sets $config / $lib -- TODO confirm.
declare -A BASH_ARGS=(
    # bycha config
    [config]=./configs/training.yaml
    [lib]=shape_pretraining
    [args]=
)
parse_args "$@"

# Replace each FILL_THIS placeholder before running.
# ${BASH_ARGS[args]} is left unquoted on purpose so extra options are split
# into separate words.
MPIRUN bycha-run \
    --config "$config" \
    --lib "$lib" \
    --task.trainer.tensorboard_dir ❗❗❗FILL_THIS❗❗❗ \
    --task.trainer.save_model_dir ❗❗❗FILL_THIS❗❗❗ \
    --task.trainer.restore_path ❗❗❗FILL_THIS❗❗❗ \
    ${BASH_ARGS[args]}
21 |
22 |
--------------------------------------------------------------------------------