├── README.md ├── bashutil.sh ├── configs ├── generating.yaml └── training.yaml ├── generate.sh ├── mybycha ├── .editorconfig ├── .flake8 ├── .gitignore ├── bycha │ ├── __init__.py │ ├── criteria │ │ ├── __init__.py │ │ ├── abstract_criterion.py │ │ ├── auto_encoding_loss.py │ │ ├── base_criterion.py │ │ ├── cross_entropy.py │ │ ├── focal_cross_entropy.py │ │ ├── label_smoothed_cross_entropy.py │ │ ├── label_smoothed_ctc.py │ │ ├── mse.py │ │ ├── multitask_criterion.py │ │ └── self_contained_loss.py │ ├── dataloaders │ │ ├── __init__.py │ │ ├── abstract_dataLoader.py │ │ ├── binarized_dataloader.py │ │ ├── in_memory_dataloader.py │ │ └── streaming_dataloader.py │ ├── datasets │ │ ├── __init__.py │ │ ├── abstract_dataset.py │ │ ├── data_map_dataset.py │ │ ├── in_memory_dataset.py │ │ ├── json_dataset.py │ │ ├── parallel_text_dataset.py │ │ ├── streaming_dataset.py │ │ ├── streaming_json_dataset.py │ │ ├── streaming_parallel_text_dataset.py │ │ ├── streaming_text_dataset.py │ │ ├── text_dataset.py │ │ └── tfrecord_dataset.py │ ├── entries │ │ ├── __init__.py │ │ ├── binarize_data.py │ │ ├── build_tokenizer.py │ │ ├── export.py │ │ ├── preprocess.py │ │ ├── run.py │ │ ├── serve.py │ │ ├── serve_model.py │ │ └── util.py │ ├── evaluators │ │ ├── __init__.py │ │ ├── abstract_evaluator.py │ │ ├── evaluator.py │ │ └── multi_evaluator.py │ ├── generators │ │ ├── __init__.py │ │ ├── abstract_generator.py │ │ ├── extraction_generator.py │ │ ├── generator.py │ │ ├── self_contained_generator.py │ │ └── sequence_generator.py │ ├── metrics │ │ ├── __init__.py │ │ ├── abstract_metric.py │ │ ├── accuracy.py │ │ ├── bleu.py │ │ ├── f1.py │ │ ├── matthews_corr.py │ │ ├── pairwise_metric.py │ │ ├── pearson_corr.py │ │ └── spearman_corr.py │ ├── models │ │ ├── __init__.py │ │ ├── abstract_encoder_decoder_model.py │ │ ├── abstract_model.py │ │ ├── bert_model.py │ │ ├── encoder_decoder_model.py │ │ ├── huggingface │ │ │ ├── __init__.py │ │ │ ├── 
huggingface_extractive_question_answering_model.py │ │ │ ├── huggingface_pretrain_bart_model.py │ │ │ ├── huggingface_pretrain_mbart_model.py │ │ │ └── huggingface_sequence_classification_model.py │ │ ├── seq2seq.py │ │ ├── sequence_classification_model.py │ │ ├── torch_transformer.py │ │ └── variational_auto_encoder.py │ ├── modules │ │ ├── __init__.py │ │ ├── decoders │ │ │ ├── __init__.py │ │ │ ├── abstract_decoder.py │ │ │ ├── autopruning_decoder.py │ │ │ ├── dlcl_decoder.py │ │ │ ├── layers │ │ │ │ ├── __init__.py │ │ │ │ ├── abstract_decoder_layer.py │ │ │ │ ├── autopruning_decoder_layer.py │ │ │ │ ├── lstm_decoder_layer.py │ │ │ │ ├── moe_decoder_layer.py │ │ │ │ ├── nonauto_transformer_decoder_layer.py │ │ │ │ ├── strucdrop_decoder_layer.py │ │ │ │ └── transformer_decoder_layer.py │ │ │ ├── lstm_decoder.py │ │ │ ├── moe_decoder.py │ │ │ ├── nonauto_transformer_decoder.py │ │ │ ├── strucdrop_decoder.py │ │ │ └── transformer_decoder.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ ├── abstract_encoder.py │ │ │ ├── autopruning_encoder.py │ │ │ ├── dlcl_encoder.py │ │ │ ├── huggingface_encoder.py │ │ │ ├── key_value_transformer_encoder.py │ │ │ ├── layers │ │ │ │ ├── __init__.py │ │ │ │ ├── abstract_encoder_layer.py │ │ │ │ ├── autopruning_encoder_layer.py │ │ │ │ ├── moe_encoder_layer.py │ │ │ │ ├── strucdrop_encoder_layer.py │ │ │ │ └── transformer_encoder_layer.py │ │ │ ├── lstm_encoder.py │ │ │ ├── moe_encoder.py │ │ │ ├── multi_encoder_wrapper.py │ │ │ ├── strucdrop_encoder.py │ │ │ ├── transformer_encoder.py │ │ │ └── vae_encoder.py │ │ ├── layers │ │ │ ├── __init__.py │ │ │ ├── autopruning_ffn.py │ │ │ ├── bert_layer_norm.py │ │ │ ├── classifier.py │ │ │ ├── dlcl.py │ │ │ ├── embedding.py │ │ │ ├── feed_forward.py │ │ │ ├── gaussian.py │ │ │ ├── gumbel.py │ │ │ ├── layerdrop.py │ │ │ ├── learned_positional_embedding.py │ │ │ ├── moe.py │ │ │ └── sinusoidal_positional_embedding.py │ │ ├── search │ │ │ ├── __init__.py │ │ │ ├── abstract_search.py │ │ │ ├── 
beam_search.py │ │ │ ├── forward_sampling.py │ │ │ ├── greedy_search.py │ │ │ └── sequence_search.py │ │ └── utils.py │ ├── optim │ │ ├── __init__.py │ │ └── optimizer.py │ ├── samplers │ │ ├── __init__.py │ │ ├── abstract_sampler.py │ │ ├── batch_shuffle_sampler.py │ │ ├── bucket_sampler.py │ │ ├── distributed_sampler.py │ │ ├── sequential_sampler.py │ │ └── shuffled_sampler.py │ ├── services │ │ ├── __init__.py │ │ ├── idls │ │ │ ├── __init__.py │ │ │ ├── bycha.thrift │ │ │ └── model_infer.thrift │ │ ├── model_server.py │ │ └── server.py │ ├── tasks │ │ ├── __init__.py │ │ ├── abstract_task.py │ │ ├── auto_encoding_task.py │ │ ├── base_task.py │ │ ├── extractive_question_answering_task.py │ │ ├── masked_lm_task.py │ │ ├── seq2seq_task.py │ │ ├── sequence_classification_task.py │ │ ├── sequence_regression_task.py │ │ └── translation_task.py │ ├── tokenizers │ │ ├── __init__.py │ │ ├── abstract_tokenizer.py │ │ ├── fastbpe.py │ │ ├── huggingface_tokenizer.py │ │ ├── sentencepiece.py │ │ ├── utils.py │ │ └── vocabulary.py │ ├── trainers │ │ ├── __init__.py │ │ ├── abstract_trainer.py │ │ ├── moe_trainer.py │ │ ├── trainer.py │ │ └── trainer_autopruning.py │ └── utils │ │ ├── __init__.py │ │ ├── data.py │ │ ├── io.py │ │ ├── ops.py │ │ ├── profiling.py │ │ ├── rate_schedulers │ │ ├── __init__.py │ │ ├── abstract_rate_scheduler.py │ │ ├── constant_rate_scheduler.py │ │ ├── inverse_square_root_rate_scheduler.py │ │ ├── logistic_scheduler.py │ │ ├── noam_scheduler.py │ │ └── polynomial_decay_scheduler.py │ │ ├── registry.py │ │ ├── runtime.py │ │ ├── tensor.py │ │ └── txc_utils.py ├── requirements.txt └── setup.py ├── pics ├── overview.png └── sketch_and_generate.png ├── preparation ├── fragmenizer │ ├── __init__.py │ ├── atom_fragmenizer.py │ ├── brics_fragmenizer.py │ ├── brics_ring_r_fragmenizer.py │ └── ring_r_fragmenizer.py ├── get_fragment_vocab.py ├── get_training_data.py └── utils │ ├── __init__.py │ ├── common.py │ ├── molecule_preparation.py │ ├── 
shape_utils.py │ ├── tfbio_data.py │ ├── utils.py │ └── utils_full.py ├── shape_pretraining ├── __init__.py ├── common.py ├── io.py ├── shape_pretraining_criterion_no_regression.py ├── shape_pretraining_dataloader_shard.py ├── shape_pretraining_dataset.py ├── shape_pretraining_dataset_pocket.py ├── shape_pretraining_dataset_shard.py ├── shape_pretraining_decoder_iterative_no_regression.py ├── shape_pretraining_encoder.py ├── shape_pretraining_iterator_no_regression.py ├── shape_pretraining_model.py ├── shape_pretraining_search_forward_sampling_dock_dedup_iterative_no_regression.py ├── shape_pretraining_search_iterative_no_regression.py ├── shape_pretraining_task_no_regression.py ├── shape_pretraining_task_no_regression_pocket.py ├── tfbio_data.py └── utils.py ├── sketch ├── shape_utils.py └── sketching.py └── train.sh /README.md: -------------------------------------------------------------------------------- 1 | # DESERT 2 | Zero-Shot 3D Drug Design by Sketching and Generating (NeurIPS 2022) 3 | 4 | 6 |
7 | 8 |
9 |
10 | 11 |
12 | 13 | P.s. Because the project is tightly coupled to ByteDance infrastructure, we cannot guarantee that it will run painlessly on your device. 14 | 15 | ## Requirement 16 | Our method is powered by an old version of [ParaGen](https://github.com/bytedance/ParaGen) (previously named ByCha). 17 | 18 | Install it with 19 | ```bash 20 | cd mybycha 21 | pip install -e . 22 | pip install horovod 23 | pip install lightseq 24 | ``` 25 | You also need to install 26 | ```bash 27 | conda install -c "conda-forge/label/cf202003" openbabel # recommend using anaconda for this project 28 | pip install rdkit-pypi 29 | pip install pybel scikit-image pebble meeko==0.1.dev1 vina pytransform3d 30 | ``` 31 | 32 | ## Pre-training 33 | 34 | ### Data Preparation 35 | Our training data was extracted from the open molecule database [ZINC](https://zinc.docking.org/). You need to download it first. 36 | 37 | To get the fragment vocabulary 38 | ```bash 39 | cd preparation 40 | python get_fragment_vocab.py # fill blank paths in the file first 41 | ``` 42 | 43 | To get the training data 44 | ```bash 45 | python get_training_data.py # fill blank paths in the file first 46 | ``` 47 | 48 | We also provide partial training data and vocabulary [Here](https://drive.google.com/drive/folders/1T2tKgILJAIMK6uTuhh3-qV-Ib0JVgaBs?usp=sharing). 49 | 50 | ### Training Shape2Mol Model 51 | 52 | You need to fill in the blank paths in configs/training.yaml and train.sh. 53 | 54 | ```bash 55 | bash train.sh 56 | ``` 57 | 58 | We also provide a trained checkpoint [Here](https://drive.google.com/file/d/1YCRORU5aMJEMO8hDT_o9uKCXmXTL5_5N/view?usp=sharing). 59 | 60 | ## Design Molecules 61 | 62 | ### Sketching 63 | 64 | For a given protein, you first need to extract its pocket using [CAVITY](http://www.pkumdl.cn:8000/cavityplus/computation.php). 
65 | 66 | Sampling molecular shapes with 67 | ```bash 68 | cd sketch 69 | python sketching.py # fill blank paths in the file first 70 | ``` 71 | 72 | ### Generating 73 | 74 | ```bash 75 | bash generate.sh # fill blank paths in the file first 76 | ``` 77 | 78 | ## Citation 79 | ``` 80 | @inproceedings{long2022DESERT, 81 | title={Zero-Shot 3D Drug Design by Sketching and Generating}, 82 | author={Long, Siyu and Zhou, Yi and Dai, Xinyu and Zhou, Hao}, 83 | booktitle={NeurIPS}, 84 | year={2022} 85 | } 86 | ``` 87 | -------------------------------------------------------------------------------- /bashutil.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | function parse_args(){ 5 | # Parse "--key value" pairs into BASH_ARGS (presumably declared -A by the caller; every key must already exist in it), then assign each entry to a shell variable named after its key. 6 | while [[ "$#" -gt 0 ]]; do 7 | found=0 8 | for key in "${!BASH_ARGS[@]}"; do 9 | if [[ "--$key" == "$1" ]] ; then 10 | BASH_ARGS[$key]=$2 11 | found=1 12 | fi 13 | done 14 | if [[ $found == 0 ]]; then 15 | echo "arg $1 not defined!" >&2 16 | exit 1 17 | fi 18 | if [[ "$#" -lt 2 ]]; then 19 | echo "arg $1 requires a value!" >&2 20 | exit 1 21 | fi 22 | shift; shift 23 | done 24 | 25 | echo "======== PARSED BASH ARGS ========" >&2 26 | for key in "${!BASH_ARGS[@]}"; do 27 | echo " $key = ${BASH_ARGS[$key]}" >&2 28 | # printf -v stores the value literally; the previous eval re-parsed it, breaking on spaces and executing embedded shell syntax 29 | printf -v "$key" '%s' "${BASH_ARGS[$key]}" 30 | done 31 | echo "==================================" >&2 32 | } 33 | -------------------------------------------------------------------------------- /configs/generating.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | class: ShapePretrainingTaskNoRegressionPocket 3 | mode: evaluate 4 | 5 | # for training efficiency 6 | max_seq_len: 20 7 | 8 | # for get molecule shape 9 | grid_resolution: 0.5 10 | max_dist_stamp: 4.0 11 | max_dist: 6.75 12 | patch_size: 4 13 | 14 | # for molecule augmentation 15 | rotation_bin: 24 16 | max_translation: 1.0 17 | 18 | delta_input: False 19 | 20 | data: 21 | train: 22 | class: ShapePretrainingDatasetShard 23 | path: ---TRAINING DATA PATH--- 24 | vocab_path: ---VOCAB PATH--- 25 | 
sample_each_shard: 500000 26 | shuffle: True 27 | valid: 28 | class: ShapePretrainingDatasetPocket 29 | path: 30 | samples: ---VALID DATA PATH--- 31 | vocab: ---VOCAB PATH--- 32 | test: 33 | class: ShapePretrainingDatasetPocket 34 | path: 35 | samples: ---TEST DATA PATH--- 36 | vocab: ---VOCAB PATH--- 37 | 38 | dataloader: 39 | train: 40 | class: ShapePretrainingDataLoaderShard 41 | max_samples: 64 42 | valid: 43 | class: InMemoryDataLoader 44 | sampler: 45 | class: SequentialSampler 46 | max_samples: 128 47 | test: 48 | class: InMemoryDataLoader 49 | sampler: 50 | class: SequentialSampler 51 | max_samples: 128 52 | 53 | trainer: 54 | class: Trainer 55 | optimizer: 56 | class: AdamW 57 | lr: 58 | class: InverseSquareRootRateScheduler 59 | rate: 5e-4 60 | warmup_steps: 4000 61 | clip_norm: 0. 62 | betas: (0.9, 0.98) 63 | eps: 1e-8 64 | weight_decay: 1e-2 65 | update_frequency: 4 66 | max_steps: 300000 67 | log_interval: 100 68 | validate_interval_step: 500 69 | assess_reverse: True 70 | 71 | model: 72 | class: ShapePretrainingModel 73 | encoder: 74 | class: ShapePretrainingEncoder 75 | patch_size: 4 76 | num_layers: 12 77 | d_model: 1024 78 | n_head: 8 79 | dim_feedforward: 4096 80 | dropout: 0.1 81 | activation: 'relu' 82 | learn_pos: True 83 | decoder: 84 | class: ShapePretrainingDecoderIterativeNoRegression 85 | num_layers: 12 86 | d_model: 1024 87 | n_head: 8 88 | dim_feedforward: 4096 89 | dropout: 0.1 90 | activation: 'relu' 91 | learn_pos: True 92 | iterative_num: 1 93 | max_dist: 6.75 94 | grid_resolution: 0.5 95 | iterative_block: 96 | class: ShapePretrainingIteratorNoRegression 97 | num_layers: 3 98 | d_model: 1024 99 | n_head: 8 100 | dim_feedforward: 4096 101 | dropout: 0.1 102 | activation: 'relu' 103 | learn_pos: True 104 | d_model: 1024 105 | share_embedding: decoder-input-output 106 | 107 | criterion: 108 | class: ShapePretrainingCriterionNoRegression 109 | 110 | generator: 111 | class: SequenceGenerator 112 | search: 113 | class: 
ShapePretrainingSearchForwardSamplingDockDedupIterativeNoRegression 114 | maxlen_coef: (0.0, 20.0) 115 | num_return_sequence: 1000 116 | fnum_return_sequence: 100 117 | 118 | evaluator: 119 | class: Evaluator 120 | metric: 121 | acc: 122 | class: Accuracy 123 | 124 | env: 125 | device: cuda 126 | fp16: True 127 | -------------------------------------------------------------------------------- /configs/training.yaml: -------------------------------------------------------------------------------- 1 | task: 2 | class: ShapePretrainingTaskNoRegression 3 | mode: train 4 | 5 | # for training efficiency 6 | max_seq_len: 20 7 | 8 | # for get molecule shape 9 | grid_resolution: 0.5 10 | max_dist_stamp: 4.0 11 | max_dist: 6.75 12 | patch_size: 4 13 | 14 | # for molecule augmentation 15 | rotation_bin: 24 16 | max_translation: 1.0 17 | 18 | delta_input: False 19 | 20 | data: 21 | train: 22 | class: ShapePretrainingDatasetShard 23 | path: ---TRAINING DATA PATH--- 24 | vocab_path: ---VOCAB PATH--- 25 | sample_each_shard: 500000 26 | shuffle: True 27 | valid: 28 | class: ShapePretrainingDataset 29 | path: 30 | samples: ---VALID DATA PATH--- 31 | vocab: ---VOCAB PATH--- 32 | test: 33 | class: ShapePretrainingDataset 34 | path: 35 | samples: ---TEST DATA PATH--- 36 | vocab: ---VOCAB PATH--- 37 | 38 | dataloader: 39 | train: 40 | class: ShapePretrainingDataLoaderShard 41 | max_samples: 64 42 | valid: 43 | class: InMemoryDataLoader 44 | sampler: 45 | class: SequentialSampler 46 | max_samples: 128 47 | test: 48 | class: InMemoryDataLoader 49 | sampler: 50 | class: SequentialSampler 51 | max_samples: 128 52 | 53 | trainer: 54 | class: Trainer 55 | optimizer: 56 | class: AdamW 57 | lr: 58 | class: InverseSquareRootRateScheduler 59 | rate: 5e-4 60 | warmup_steps: 4000 61 | clip_norm: 0. 
62 | betas: (0.9, 0.98) 63 | eps: 1e-8 64 | weight_decay: 1e-2 65 | update_frequency: 4 66 | max_steps: 300000 67 | log_interval: 100 68 | validate_interval_step: 500 69 | assess_reverse: True 70 | 71 | model: 72 | class: ShapePretrainingModel 73 | encoder: 74 | class: ShapePretrainingEncoder 75 | patch_size: 4 76 | num_layers: 12 77 | d_model: 1024 78 | n_head: 8 79 | dim_feedforward: 4096 80 | dropout: 0.1 81 | activation: 'relu' 82 | learn_pos: True 83 | decoder: 84 | class: ShapePretrainingDecoderIterativeNoRegression 85 | num_layers: 12 86 | d_model: 1024 87 | n_head: 8 88 | dim_feedforward: 4096 89 | dropout: 0.1 90 | activation: 'relu' 91 | learn_pos: True 92 | iterative_num: 1 93 | max_dist: 6.75 94 | grid_resolution: 0.5 95 | iterative_block: 96 | class: ShapePretrainingIteratorNoRegression 97 | num_layers: 3 98 | d_model: 1024 99 | n_head: 8 100 | dim_feedforward: 4096 101 | dropout: 0.1 102 | activation: 'relu' 103 | learn_pos: True 104 | d_model: 1024 105 | share_embedding: decoder-input-output 106 | 107 | criterion: 108 | class: ShapePretrainingCriterionNoRegression 109 | 110 | generator: 111 | class: SequenceGenerator 112 | search: 113 | class: ShapePretrainingSearchIterativeNoRegression 114 | maxlen_coef: (0.0, 20.0) 115 | 116 | evaluator: 117 | class: Evaluator 118 | metric: 119 | acc: 120 | class: Accuracy 121 | 122 | env: 123 | device: cuda 124 | fp16: True 125 | -------------------------------------------------------------------------------- /generate.sh: -------------------------------------------------------------------------------- 1 | bycha-run \ 2 | --config configs/generating.yaml \ 3 | --lib shape_pretraining \ 4 | --task.mode evaluate \ 5 | --task.data.train.path data \ 6 | --task.data.valid.path.samples ❗❗❗FILL_THIS(MOLECULE SHAPES SAMPLED FROM CAVITY)❗❗❗ \ 7 | --task.data.test.path.samples ❗❗❗FILL_THIS❗❗❗ \ 8 | --task.dataloader.train.max_samples 1 \ 9 | --task.dataloader.valid.sampler.max_samples 1 \ 10 | 
--task.dataloader.test.sampler.max_samples 1 \ 11 | --task.model.path ❗❗❗FILL_THIS❗❗❗ \ 12 | --task.evaluator.save_hypo_dir ❗❗❗FILL_THIS❗❗❗ -------------------------------------------------------------------------------- /mybycha/.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | insert_final_newline = true 6 | max_line_length = 120 7 | 8 | [Makefile] 9 | indent_style = tab 10 | 11 | [*.py] 12 | indent_style = space 13 | indent_size = 4 14 | 15 | [*.{js,ts,html}] 16 | indent_style = space 17 | indent_size = 2 -------------------------------------------------------------------------------- /mybycha/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length=120 3 | exclude = 4 | venv/, 5 | venv_py/, 6 | .eggs, 7 | .tox 8 | ignore = D400,D300,D205,D200,D105,D100,D101,D103,D107 9 | -------------------------------------------------------------------------------- /mybycha/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | **/*.pyc 3 | **/*.DS_Store 4 | *.egg-info/ 5 | *.swp 6 | .eggs/ 7 | .idea/ 8 | .tox/ 9 | .pytest_cache/* 10 | venv/ 11 | venv_py/ 12 | .mypy_cache/ 13 | __pycache__/ 14 | /build 15 | /dist 16 | checkpoints/ 17 | *.pt -------------------------------------------------------------------------------- /mybycha/bycha/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/mybycha/bycha/__init__.py -------------------------------------------------------------------------------- /mybycha/bycha/criteria/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.registry import setup_registry 5 | 6 | from 
.abstract_criterion import AbstractCriterion 7 | 8 | register_criterion, create_criterion, registry = setup_registry('criterion', AbstractCriterion) 9 | 10 | modules_dir = os.path.dirname(__file__) 11 | for file in os.listdir(modules_dir): 12 | path = os.path.join(modules_dir, file) 13 | if ( 14 | not file.startswith('_') 15 | and not file.startswith('.') 16 | and (file.endswith('.py') or os.path.isdir(path)) 17 | ): 18 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 19 | module = importlib.import_module('bycha.criteria.' + module_name) 20 | -------------------------------------------------------------------------------- /mybycha/bycha/criteria/abstract_criterion.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger(__name__) 3 | 4 | from torch.nn import Module 5 | 6 | from bycha.utils.runtime import Environment 7 | 8 | 9 | class AbstractCriterion(Module): 10 | """ 11 | Criterion is the base class for all the criterion within ByCha. 12 | """ 13 | 14 | def __init__(self): 15 | super().__init__() 16 | self._model = None 17 | 18 | def build(self, *args, **kwargs): 19 | """ 20 | Construct a criterion for model training. 21 | Typically, `model` should be provided. 22 | """ 23 | self._build(*args, **kwargs) 24 | 25 | e = Environment() 26 | if e.device.startswith('cuda'): 27 | logger.info('move criterion to {}'.format(e.device)) 28 | self.cuda(e.device) 29 | 30 | def _build(self, *args, **kwargs): 31 | pass 32 | 33 | def forward(self, *args, **kwargs): 34 | """ 35 | Compute the loss from neural model input, and produce a loss. 
36 | """ 37 | raise NotImplementedError 38 | 39 | def step_update(self, *args, **kwargs): 40 | """ 41 | Perform step-level update 42 | """ 43 | pass 44 | -------------------------------------------------------------------------------- /mybycha/bycha/criteria/auto_encoding_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from bycha.criteria import register_criterion 5 | from bycha.criteria.base_criterion import BaseCriterion 6 | from bycha.utils.rate_schedulers import create_rate_scheduler 7 | 8 | 9 | @register_criterion 10 | class AutoEncodingLoss(BaseCriterion): 11 | """ 12 | Label smoothed cross entropy 13 | 14 | Args: 15 | epsilon: label smoothing rate 16 | """ 17 | 18 | def __init__(self, epsilon=0.1, beta=1.): 19 | super().__init__() 20 | self._epsilon = epsilon 21 | self._beta_configs = beta 22 | 23 | self._padding_idx = None 24 | self._beta = None 25 | 26 | def _build(self, model, padding_idx=-1): 27 | self._padding_idx = padding_idx 28 | self._model = model 29 | 30 | self._beta = create_rate_scheduler(self._beta_configs) 31 | self._beta.build() 32 | 33 | def _reconstruct_loss(self, lprobs, target, reduce=False): 34 | assert target.dim() == lprobs.dim() - 1 35 | 36 | lprobs, target = lprobs.view(-1, lprobs.size(-1)), target.view(-1) 37 | padding_mask = target.eq(self._padding_idx) 38 | ntokens = (~padding_mask).sum() 39 | # calculate nll loss 40 | nll_loss = -lprobs.gather(dim=-1, index=target.unsqueeze(dim=-1)).squeeze(dim=-1) 41 | nll_loss.masked_fill_(padding_mask, 0.) 42 | if reduce: 43 | nll_loss = nll_loss.sum() / ntokens 44 | 45 | # calculate smoothed loss 46 | smooth_loss = -lprobs.mean(dim=-1) 47 | smooth_loss.masked_fill_(padding_mask, 0.) 
48 | smooth_loss = smooth_loss.sum() / ntokens 49 | 50 | return nll_loss, ntokens, smooth_loss 51 | 52 | def step_update(self, step): 53 | """ 54 | Perform step-level update 55 | 56 | Args: 57 | step: running step 58 | """ 59 | self._beta.step_update(step) 60 | 61 | def compute_loss(self, lprobs, net_output): 62 | """ 63 | Compute loss from a batch of samples 64 | 65 | Args: 66 | lprobs: neural network output logits 67 | net_output: neural net output 68 | Returns: 69 | - loss for network backward and optimization 70 | - logging information 71 | """ 72 | lprobs = F.log_softmax(lprobs, dim=-1) 73 | # fetch target with default index 0 74 | target = net_output[0] 75 | 76 | bsz, sql = target.size() 77 | rec_loss, n_tokens, smooth_loss = self._reconstruct_loss(lprobs, target, reduce=False) 78 | rec_loss = torch.sum(rec_loss.view(bsz, -1), dim=-1) 79 | reg_loss = self._model.reg_loss() 80 | 81 | loss = torch.sum(rec_loss + self._beta.rate * reg_loss) / n_tokens 82 | loss = (1. - self._epsilon) * loss + self._epsilon * smooth_loss if self.training else loss 83 | 84 | nll_loss = torch.sum(self._model.nll(rec_loss, reg_loss)) / n_tokens # real nll loss 85 | 86 | logging_states = { 87 | 'reg_weight': self._beta.rate, 88 | 'loss': loss.data.item(), 89 | 'nll_loss': nll_loss.data.item(), 90 | 'ppl': 2 ** (nll_loss.data.item()), 91 | 'reg_loss': torch.mean(reg_loss).item(), 92 | 'rec_loss': (torch.sum(rec_loss)/n_tokens).item() 93 | } 94 | 95 | return loss, logging_states 96 | -------------------------------------------------------------------------------- /mybycha/bycha/criteria/base_criterion.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | from bycha.criteria import AbstractCriterion 4 | 5 | 6 | class BaseCriterion(AbstractCriterion): 7 | """ 8 | BaseCriterion is the base class for all the criterion within ByCha. 
9 | """ 10 | 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def forward(self, net_input, net_output): 15 | """ 16 | Compute loss from a batch of samples 17 | 18 | Args: 19 | net_input: model input; a dict is passed as keyword arguments, a list/tuple as positional arguments, anything else as a single argument 20 | net_output (dict): oracle targets, forwarded to the subclass's compute_loss as keyword arguments 21 | Returns: 22 | tuple: 23 | - **loss**: loss for network backward and optimization 24 | - **logging_states**: logging information 25 | """ 26 | if isinstance(net_input, Dict): 27 | lprobs = self._model(**net_input) 28 | elif isinstance(net_input, List) or isinstance(net_input, Tuple): 29 | lprobs = self._model(*net_input) 30 | else: 31 | lprobs = self._model(net_input) 32 | # net_output entries (e.g. target=...) become keyword arguments of compute_loss 33 | loss, logging_states = self.compute_loss(lprobs, **net_output) 34 | return loss, logging_states 35 | 36 | def compute_loss(self, *args, **kwargs): 37 | """ 38 | Compute loss from model results 39 | """ 40 | raise NotImplementedError 41 | 42 | -------------------------------------------------------------------------------- /mybycha/bycha/criteria/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from bycha.criteria import register_criterion 6 | from bycha.criteria.base_criterion import BaseCriterion 7 | 8 | 9 | @register_criterion 10 | class CrossEntropy(BaseCriterion): 11 | """ 12 | Cross Entropy Loss. 13 | 14 | """ 15 | 16 | def __init__(self, weight=None, logging_metric='acc'): 17 | super().__init__() 18 | self._weight = torch.FloatTensor(weight) if weight is not None else weight 19 | self._logging_metric = logging_metric 20 | self._padding_idx = None 21 | self._nll_loss = None 22 | 23 | def _build(self, model, padding_idx=-1): 24 | """ 25 | Build a cross entropy loss over model. 26 | 27 | Args: 28 | model: a neural model for compute cross entropy.
29 | padding_idx: labels of padding_idx are all ignored to computed nll_loss 30 | """ 31 | self._model = model 32 | self._padding_idx = padding_idx 33 | self._nll_loss = nn.NLLLoss(weight=self._weight, ignore_index=padding_idx) 34 | 35 | def compute_loss(self, lprobs, target): 36 | """ 37 | Compute loss from a batch of samples 38 | 39 | Args: 40 | lprobs: neural network output logits 41 | target: oracle target for a network input 42 | 43 | Returns: 44 | - loss for network backward and optimization 45 | - logging information 46 | """ 47 | lprobs = F.log_softmax(lprobs, dim=-1) 48 | 49 | # compute nll loss 50 | lprobs = lprobs.view(-1, lprobs.size(-1)) 51 | target = target.view(-1) 52 | nll_loss = self._nll_loss(lprobs, target) 53 | 54 | # record logging 55 | logging_states = { 56 | 'loss': nll_loss.data.item(), 57 | } 58 | if self._logging_metric == 'acc': 59 | correct = (lprobs.max(dim=-1)[1] == target).sum().data.item() 60 | tot = target.size(0) 61 | logging_states['acc'] = correct / tot 62 | elif self._logging_metric == 'ppl': 63 | logging_states['ppl'] = 2 ** (nll_loss.data.item()) 64 | return nll_loss, logging_states 65 | -------------------------------------------------------------------------------- /mybycha/bycha/criteria/focal_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | from bycha.criteria import register_criterion 4 | from bycha.criteria.base_criterion import BaseCriterion 5 | 6 | 7 | @register_criterion 8 | class FocalCrossEntropy(BaseCriterion): 9 | """ 10 | Label smoothed cross entropy 11 | 12 | Args: 13 | gamma: focal loss rate 14 | """ 15 | 16 | def __init__(self, gamma: float = 2.0): 17 | super().__init__() 18 | self._gamma = gamma 19 | 20 | self._padding_idx = None 21 | 22 | def _build(self, model, padding_idx=-1): 23 | """ 24 | Build a label smoothed cross entropy loss over model. 25 | 26 | Args: 27 | model: a neural model for compute cross entropy. 
28 | padding_idx: labels of padding_idx are all ignored to computed nll_loss 29 | """ 30 | self._model = model 31 | self._padding_idx = padding_idx 32 | 33 | def compute_loss(self, lprobs, target): 34 | """ 35 | Compute loss from a batch of samples 36 | 37 | Args: 38 | lprobs: neural network output logits 39 | target: oracle target for a network input 40 | Returns: 41 | - loss for network backward and optimization 42 | - logging information 43 | """ 44 | lprobs = F.log_softmax(lprobs, dim=-1) 45 | correct = (lprobs.max(dim=-1)[1] == target).sum().data.item() 46 | tot = target.size(0) 47 | 48 | target_padding_mask = target.eq(self._padding_idx) 49 | assert target.dim() == lprobs.dim() - 1 50 | 51 | lprobs = lprobs.view(-1, lprobs.size(-1)) 52 | target = target.view(-1) 53 | target_padding_mask = target_padding_mask.view(-1) 54 | ntokens = (~target_padding_mask).sum() 55 | 56 | # calculate nll loss 57 | lprobs = lprobs.gather(dim=-1, index=target.unsqueeze(dim=-1)).squeeze(dim=-1) 58 | weight = (1 - lprobs.exp()) ** self._gamma 59 | loss = - weight * lprobs 60 | loss.masked_fill_(target_padding_mask, 0.) 
61 | loss = loss.sum() / ntokens 62 | 63 | # record logging 64 | logging_states = { 65 | 'loss': loss.data.item(), 66 | 'ntokens': ntokens.data.item(), 67 | 'acc': correct / tot 68 | } 69 | 70 | return loss, logging_states 71 | -------------------------------------------------------------------------------- /mybycha/bycha/criteria/label_smoothed_cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | from bycha.criteria import register_criterion 4 | from bycha.criteria.base_criterion import BaseCriterion 5 | 6 | 7 | @register_criterion 8 | class LabelSmoothedCrossEntropy(BaseCriterion): 9 | """ 10 | Label smoothed cross entropy 11 | 12 | Args: 13 | epsilon: label smoothing rate 14 | """ 15 | 16 | def __init__(self, epsilon: float = 0.1): 17 | super().__init__() 18 | self._epsilon = epsilon 19 | 20 | self._padding_idx = None 21 | 22 | def _build(self, model, padding_idx=-1): 23 | """ 24 | Build a label smoothed cross entropy loss over model. 25 | 26 | Args: 27 | model: a neural model for compute cross entropy. 
28 | padding_idx: labels of padding_idx are all ignored to computed nll_loss 29 | """ 30 | self._model = model 31 | self._padding_idx = padding_idx 32 | 33 | def compute_loss(self, lprobs, target): 34 | """ 35 | Compute loss from a batch of samples 36 | 37 | Args: 38 | lprobs: neural network output logits 39 | target: oracle target for a network input 40 | Returns: 41 | - loss for network backward and optimization 42 | - logging information 43 | """ 44 | lprobs = F.log_softmax(lprobs, dim=-1) 45 | target_padding_mask = target.eq(self._padding_idx) 46 | assert target.dim() == lprobs.dim() - 1 47 | # infer task type 48 | is_classification_task = len(target.size()) == 1 49 | 50 | lprobs = lprobs.view(-1, lprobs.size(-1)) 51 | target = target.view(-1) 52 | target_padding_mask = target_padding_mask.view(-1) 53 | ntokens = (~target_padding_mask).sum() 54 | 55 | # calculate nll loss 56 | nll_loss = -lprobs.gather(dim=-1, index=target.unsqueeze(dim=-1)).squeeze(dim=-1) 57 | nll_loss.masked_fill_(target_padding_mask, 0.) 58 | nll_loss = nll_loss.sum() / ntokens 59 | 60 | # calculate smoothed loss 61 | if self._epsilon > 0.: 62 | smooth_loss = -lprobs.mean(dim=-1) 63 | smooth_loss.masked_fill_(target_padding_mask, 0.) 64 | smooth_loss = smooth_loss.sum() / ntokens 65 | 66 | # average nll loss and smoothed loss, weighted by epsilon 67 | loss = (1. 
- self._epsilon) * nll_loss + self._epsilon * smooth_loss if self.training else nll_loss 68 | else: 69 | loss = nll_loss 70 | 71 | # record logging 72 | logging_states = { 73 | 'loss': loss.data.item(), 74 | 'nll_loss': nll_loss.data.item(), 75 | 'ntokens': ntokens.data.item(), 76 | } 77 | if is_classification_task: 78 | correct = (lprobs.max(dim=-1)[1] == target).sum().data.item() 79 | tot = target.size(0) 80 | logging_states['acc'] = correct / tot 81 | else: 82 | logging_states['ppl'] = 2 ** (nll_loss.data.item()) 83 | 84 | return loss, logging_states 85 | -------------------------------------------------------------------------------- /mybycha/bycha/criteria/mse.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from bycha.criteria import register_criterion 4 | from bycha.criteria.base_criterion import BaseCriterion 5 | 6 | 7 | @register_criterion 8 | class MSE(BaseCriterion): 9 | """ 10 | Mean square error 11 | 12 | """ 13 | 14 | def __init__(self): 15 | super().__init__() 16 | self._mse_loss = None 17 | 18 | def _build(self, model): 19 | """ 20 | Build a cross entropy loss over model. 21 | 22 | Args: 23 | model: a neural model for compute cross entropy. 
24 | """ 25 | self._model = model 26 | self._mse_loss = nn.MSELoss() 27 | 28 | def compute_loss(self, pred, target): 29 | """ 30 | Compute loss from a batch of samples 31 | 32 | Args: 33 | pred: neural network output 34 | target: oracle target for a network input 35 | Returns: 36 | - loss for network backward and optimization 37 | - logging information 38 | """ 39 | # compute nll loss 40 | pred = pred.view(-1) 41 | target = target.view(-1) 42 | mse_loss = self._mse_loss(pred, target) 43 | 44 | # record logging 45 | logging_states = { 46 | 'loss': mse_loss.data.item(), 47 | } 48 | return mse_loss, logging_states 49 | -------------------------------------------------------------------------------- /mybycha/bycha/criteria/multitask_criterion.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from bycha.criteria import AbstractCriterion, create_criterion, register_criterion 4 | 5 | 6 | @register_criterion 7 | class MultiTaskCriterion(AbstractCriterion): 8 | """ 9 | Criterion is the base class for all the criterion within ByCha. 
10 | """ 11 | 12 | def __init__(self, criterions): 13 | super().__init__() 14 | self._criterion_configs = criterions 15 | 16 | self._names = [name for name in self._criterion_configs] 17 | self._criterions, self._weights = None, None 18 | 19 | def _build(self, model, padding_idx=-1): 20 | """ 21 | Build multi-task criterion by dispatch args to each criterion 22 | 23 | Args: 24 | model: neural model 25 | padding_idx: pad idx to ignore 26 | """ 27 | self._model = model 28 | self._criterions, self._weights = {}, {} 29 | for name in self._names: 30 | criterion_config = self._criterion_configs[name] 31 | self._weights[name] = criterion_config.pop('weight') if 'weight' in criterion_config else 1 32 | self._criterions[name] = create_criterion(self._criterion_configs[name]) 33 | self._criterions[name].build(model, padding_idx) 34 | 35 | def forward(self, net_input, net_output): 36 | """ 37 | Compute loss from a batch of samples 38 | 39 | Args: 40 | net_input: neural network input and is used for compute the logits 41 | net_output (dict): oracle target for a network input 42 | Returns: 43 | - loss for network backward and optimization 44 | - logging information 45 | """ 46 | lprobs_dict = self._model(**net_input) 47 | assert isinstance(lprobs_dict, Dict), 'A multitask learning model must return a dict of log-probability' 48 | # fetch target with default index 0 49 | tot_loss, complete_logging_states = 0, {} 50 | for name in self._names: 51 | lprobs, net_out, criterion = lprobs_dict[name], net_output[name], self._criterions[name] 52 | loss, logging_states = criterion.compute_loss(lprobs, **net_out) 53 | tot_loss += self._weights[name] * loss 54 | logging_states = {f'{name}.{key}': val for key, val in logging_states.items()} 55 | complete_logging_states.update(logging_states) 56 | complete_logging_states['loss'] = tot_loss.data.item() 57 | return tot_loss, complete_logging_states 58 | 59 | -------------------------------------------------------------------------------- 
/mybycha/bycha/criteria/self_contained_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from bycha.criteria import AbstractCriterion, register_criterion 4 | 5 | 6 | @register_criterion 7 | class SelfContainedLoss(AbstractCriterion): 8 | """ 9 | SelfContainedLoss. 10 | 11 | """ 12 | 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def build(self, model, *args, **kwargs): 17 | """ 18 | Build a cross entropy loss over model. 19 | 20 | Args: 21 | model: a neural model for compute cross entropy. 22 | """ 23 | self._model = model 24 | 25 | def forward(self, net_input): 26 | """ 27 | Compute loss via model itself 28 | 29 | Args: 30 | net_input (dict): neural network input and is used for compute the logits 31 | Returns: 32 | - loss for network backward and optimization 33 | - logging information 34 | """ 35 | output = self._model.loss(**net_input) 36 | if isinstance(output, Tuple): 37 | assert len(output) == 2, 'if a tuple returned, it must be (loss, logging_states)' 38 | loss, logging_states = output 39 | else: 40 | loss = output 41 | logging_states = {'loss': loss.data.item()} 42 | return loss, logging_states 43 | -------------------------------------------------------------------------------- /mybycha/bycha/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.ops import deepcopy_on_ref 5 | from bycha.utils.registry import setup_registry 6 | 7 | from .abstract_dataLoader import AbstractDataLoader 8 | 9 | register_dataloader, _build_dataloader, registry = setup_registry('dataloader', AbstractDataLoader) 10 | 11 | 12 | def build_dataloader(configs, dataset, sampler=None, collate_fn=None, post_collate=False): 13 | """ 14 | Build a dataloader 15 | 16 | Args: 17 | configs: dataloader configs 18 | dataset: dataset storing samples 19 | sampler: sample strategy 20 | collate_fn: 
collate function during data fetching with torch.utils.data.DataLoader 21 | post_collate: whether to perform collate_fn after data fetching 22 | 23 | Returns: 24 | AbstractDataLoader 25 | """ 26 | configs = deepcopy_on_ref(configs) 27 | configs.update({ 28 | 'dataset': dataset, 29 | 'collate_fn': collate_fn if not post_collate else None, 30 | 'post_collate_fn': collate_fn if post_collate else None 31 | }) 32 | if sampler is not None: 33 | configs['sampler'] = sampler 34 | dataloader = _build_dataloader(configs) 35 | return dataloader 36 | 37 | 38 | modules_dir = os.path.dirname(__file__) 39 | for file in os.listdir(modules_dir): 40 | path = os.path.join(modules_dir, file) 41 | if ( 42 | not file.startswith('_') 43 | and not file.startswith('.') 44 | and (file.endswith('.py') or os.path.isdir(path)) 45 | ): 46 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 47 | module = importlib.import_module('bycha.dataloaders.' + module_name) 48 | -------------------------------------------------------------------------------- /mybycha/bycha/dataloaders/binarized_dataloader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import random 4 | 5 | from bycha.dataloaders import register_dataloader, AbstractDataLoader 6 | from bycha.utils.io import UniIO 7 | from bycha.utils.runtime import Environment 8 | from bycha.utils.tensor import list2tensor 9 | 10 | 11 | @register_dataloader 12 | class BinarizedDataLoader(AbstractDataLoader): 13 | """ 14 | AbstractDataLoader to sample and process data from dataset 15 | 16 | Args: 17 | path: path to load binarized data 18 | """ 19 | 20 | def __init__(self, 21 | path, 22 | preload=False, 23 | length_interval=8, 24 | max_shuffle_size=1, 25 | **kwargs): 26 | super().__init__(None) 27 | self._path = path 28 | self._preload = preload 29 | self._batches = None 30 | self._length_interval = length_interval 31 | self._max_shuffle_size = max_shuffle_size 32 | 
33 | env = Environment() 34 | self._rank = env.rank 35 | self._distributed_wolrds = env.distributed_world 36 | self._max_buffered_batch_num = self._max_shuffle_size * self._distributed_wolrds 37 | self._buffered_batches = [] 38 | 39 | if preload: 40 | self._batches = [] 41 | with UniIO(self._path) as fin: 42 | for batch in fin: 43 | batch = json.loads(batch) 44 | batch = list2tensor(batch) 45 | self._batches.append(batch) 46 | total_size = int(math.ceil(len(self._batches) * 1.0 / self._distributed_wolrds)) * self._distributed_wolrds 47 | self._batches += self._batches[:(total_size - len(self._batches))] 48 | 49 | def reset(self, *args, **kwargs): 50 | if not self._preload: 51 | self._batches = UniIO(self._path) 52 | else: 53 | if self._max_shuffle_size > 0: 54 | random.shuffle(self._batches) 55 | self._buffered_batches.clear() 56 | return self 57 | 58 | def __iter__(self): 59 | for batch in self._batches: 60 | if not self._preload: 61 | batch = json.loads(batch) 62 | batch = list2tensor(batch) 63 | self._buffered_batches.append(batch) 64 | if len(self._buffered_batches) == self._max_buffered_batch_num: 65 | for s in self._dispatch(): 66 | yield s 67 | if len(self._buffered_batches) >= self._distributed_wolrds: 68 | for s in self._dispatch(): 69 | yield s 70 | 71 | def _dispatch(self): 72 | random.shuffle(self._buffered_batches) 73 | batch_num = len(self._buffered_batches) // self._distributed_wolrds * self._distributed_wolrds 74 | self._buffered_batches = self._buffered_batches[self._rank:batch_num:self._distributed_wolrds] 75 | for s in self._buffered_batches: 76 | yield s 77 | self._buffered_batches.clear() 78 | 79 | def __len__(self): 80 | return len(self._batches) // self._distributed_wolrds if self._preload else 0 81 | -------------------------------------------------------------------------------- /mybycha/bycha/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | 
from bycha.utils.registry import setup_registry 5 | 6 | from .abstract_dataset import AbstractDataset 7 | 8 | register_dataset, create_dataset, registry = setup_registry('dataset', AbstractDataset) 9 | 10 | modules_dir = os.path.dirname(__file__) 11 | for file in os.listdir(modules_dir): 12 | path = os.path.join(modules_dir, file) 13 | if ( 14 | not file.startswith('_') 15 | and not file.startswith('.') 16 | and (file.endswith('.py') or os.path.isdir(path)) 17 | ): 18 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 19 | module = importlib.import_module('bycha.datasets.' + module_name) 20 | 21 | -------------------------------------------------------------------------------- /mybycha/bycha/datasets/in_memory_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from bycha.datasets import AbstractDataset 4 | 5 | 6 | class InMemoryDataset(AbstractDataset): 7 | """ 8 | An in-memory dataset which load data in memory before running task. 9 | InMemoryDataset is suitable for dataset of relatively low capacity. 10 | 11 | Args: 12 | path: data path to read 13 | sort_samples (bool): sort samples before running a task. 14 | It would be useful in inference without degrading performance. 
15 | max_size: maximum size of loaded data 16 | """ 17 | 18 | def __init__(self, 19 | path, 20 | sort_samples=False, 21 | max_size=0,): 22 | super().__init__(path, max_size=max_size) 23 | 24 | self._data = None 25 | self._sort_samples = sort_samples 26 | 27 | def build(self, collate_fn=None, preprocessed=False, **kwargs): 28 | """ 29 | Build input stream and load data into memory 30 | 31 | Args: 32 | collate_fn: callback defined by a specific task 33 | preprocessed: data has been preprocessed 34 | """ 35 | self._collate_fn = collate_fn 36 | self._preprocessed = preprocessed 37 | 38 | if self._path: 39 | self._load() 40 | self._pos = 0 41 | 42 | def _load(self): 43 | """ 44 | Load data into memory 45 | """ 46 | raise NotImplementedError 47 | 48 | def shuffle(self): 49 | """ 50 | shuffle preload data 51 | """ 52 | random.shuffle(self._data) 53 | 54 | def __getitem__(self, index): 55 | """ 56 | fetch an item at index 57 | 58 | Args: 59 | index: index of item to fetch 60 | 61 | Returns: 62 | sample: data of index in preload data list 63 | """ 64 | return self._data[index] 65 | 66 | def __iter__(self): 67 | for sample in self._data: 68 | yield sample 69 | 70 | def __next__(self): 71 | """ 72 | fetch next sample 73 | 74 | Returns: 75 | sample: next sample 76 | """ 77 | if self._pos < len(self._data): 78 | sample = self._data[self._pos] 79 | self._pos += 1 80 | else: 81 | raise StopIteration 82 | return sample 83 | 84 | def reset(self): 85 | """ 86 | Reset io for a new round of iteration 87 | """ 88 | self._pos = 0 89 | 90 | def finalize(self): 91 | super().finalize() 92 | try: 93 | del self._data 94 | except: 95 | pass 96 | 97 | -------------------------------------------------------------------------------- /mybycha/bycha/datasets/json_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | logger = logging.getLogger(__name__) 4 | 5 | from bycha.datasets import register_dataset 6 | from 
bycha.datasets.in_memory_dataset import InMemoryDataset 7 | from bycha.utils.data import count_sample_token 8 | from bycha.utils.io import UniIO 9 | from bycha.utils.runtime import progress_bar 10 | 11 | 12 | @register_dataset 13 | class JsonDataset(InMemoryDataset): 14 | """ 15 | JsonDataset is an in-memory dataset for reading data saved with json.dumps. 16 | 17 | Args: 18 | path: data path to read 19 | sort_samples (bool): sort samples before running a task. 20 | It would be useful in inference without degrading performance. 21 | max_size: maximum size of loaded data 22 | """ 23 | 24 | def __init__(self, 25 | path, 26 | sort_samples=False, 27 | max_size=0): 28 | super().__init__(path, sort_samples=sort_samples, max_size=max_size) 29 | 30 | def _load(self): 31 | """ 32 | Preload all the data into memory. In the loading process, data are preprocess and sorted. 33 | """ 34 | fin = UniIO(path=self._path) 35 | self._data = [] 36 | accecpted, discarded = 0, 0 37 | for i, sample in enumerate(progress_bar(fin, streaming=True, desc='Loading Samples')): 38 | if 0 < self._max_size <= i: 39 | break 40 | try: 41 | sample = sample.strip('\n') 42 | self._data.append(self._full_callback(sample)) 43 | accecpted += 1 44 | except Exception: 45 | logger.warning('sample {} is discarded'.format(sample)) 46 | discarded += 1 47 | if self._sort_samples: 48 | self._data.sort(key=lambda x: count_sample_token(x)) 49 | self._length = len(self._data) 50 | logger.info(f'Totally accept {accecpted} samples, discard {discarded} samples') 51 | fin.close() 52 | 53 | def _callback(self, sample): 54 | """ 55 | Callback for json data 56 | 57 | Args: 58 | sample: data in raw format 59 | 60 | Returns: 61 | sample (dict): a dict of samples consisting of parallel data of different sources 62 | """ 63 | sample = json.loads(sample) 64 | return sample 65 | -------------------------------------------------------------------------------- /mybycha/bycha/datasets/parallel_text_dataset.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import logging 3 | logger = logging.getLogger(__name__) 4 | 5 | from bycha.datasets import register_dataset 6 | from bycha.datasets.in_memory_dataset import InMemoryDataset 7 | from bycha.utils.data import count_sample_token 8 | from bycha.utils.io import UniIO 9 | from bycha.utils.runtime import progress_bar, Environment 10 | 11 | 12 | @register_dataset 13 | class ParallelTextDataset(InMemoryDataset): 14 | """ 15 | ParallelTextDataset is an in-memory dataset for reading data saved in parallel files. 16 | 17 | Args: 18 | path: a dict of data with their path. `path` can be `None` to build the process pipeline only. 19 | sort_samples (bool): sort samples before running a task. 20 | It would be useful in inference without degrading performance. 21 | max_size: maximum size of loaded data 22 | """ 23 | 24 | def __init__(self, 25 | path: Dict[str, str] = None, 26 | sort_samples: bool = False, 27 | max_size: int = 0,): 28 | super().__init__(path, sort_samples=sort_samples, max_size=max_size) 29 | self._sources = list(path.keys()) 30 | 31 | def _callback(self, sample): 32 | """ 33 | Callback for parallel data 34 | 35 | Args: 36 | sample: data in raw format 37 | 38 | Returns: 39 | sample (dict): a dict of samples consisting of parallel data of different sources 40 | """ 41 | if self._preprocessed: 42 | sample = {key: [eval(v) for v in val] for key, val in sample.items()} 43 | return sample 44 | 45 | def _load(self): 46 | """ 47 | Preload all the data into memory. In the loading process, data are collate_fnd and sorted. 
48 | """ 49 | ori_fin = [UniIO(self._path[src]) for src in self._sources] 50 | fin = zip(*ori_fin) 51 | self._data = [] 52 | accepted, discarded = 0, 0 53 | for i, sample in enumerate(progress_bar(fin, streaming=True, desc='Loading Samples')): 54 | if 0 < self._max_size <= i: 55 | break 56 | try: 57 | sample = self._full_callback({ 58 | src: s.strip('\n') 59 | for src, s in zip(self._sources, sample) 60 | }) 61 | self._data.append(sample) 62 | accepted += 1 63 | except Exception as e: 64 | env = Environment() 65 | if env.debug: 66 | raise e 67 | logger.warning('sample {} is discarded'.format(sample)) 68 | discarded += 1 69 | if self._sort_samples: 70 | self._data.sort(key=lambda x: count_sample_token(x)) 71 | self._length = len(self._data) 72 | logger.info(f'Totally accept {accepted} samples, discard {discarded} samples') 73 | for fin in ori_fin: 74 | fin.close() 75 | 76 | self._collate_fn = None 77 | -------------------------------------------------------------------------------- /mybycha/bycha/datasets/streaming_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from torch.utils.data import IterableDataset 4 | 5 | from bycha.datasets import AbstractDataset 6 | 7 | 8 | class StreamingDataset(AbstractDataset, IterableDataset): 9 | """ 10 | Tackle with io and create parallel data 11 | 12 | Args: 13 | path: a dict of data with their path. `path` can be `None` to build the process pipeline only. 
14 | 15 | """ 16 | 17 | def __init__(self, 18 | path: Dict[str, str],): 19 | super().__init__(path) 20 | self._fin = None 21 | self._length = None 22 | 23 | def shuffle(self): 24 | """ 25 | shuffle preload data 26 | """ 27 | pass 28 | 29 | def __getitem__(self, index): 30 | """ 31 | fetch an item with index 32 | 33 | Args: 34 | index: index of item to fetch 35 | 36 | Returns: 37 | sample: data of index in preload data list 38 | """ 39 | return next(self) 40 | 41 | def __next__(self): 42 | """ 43 | fetch next sample 44 | 45 | Returns: 46 | sample: next sample 47 | """ 48 | raise NotImplementedError 49 | 50 | def reset(self): 51 | """ 52 | Reset io for a new round of iteration 53 | """ 54 | pass 55 | 56 | def __len__(self): 57 | """ 58 | Compute dataset length 59 | 60 | Returns: 61 | dataset length 62 | """ 63 | return 0 64 | 65 | -------------------------------------------------------------------------------- /mybycha/bycha/datasets/streaming_json_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | logger = logging.getLogger(__name__) 4 | 5 | from bycha.datasets import register_dataset 6 | from bycha.datasets.streaming_dataset import StreamingDataset 7 | from bycha.utils.io import UniIO 8 | 9 | 10 | @register_dataset 11 | class StreamingJsonDataset(StreamingDataset): 12 | """ 13 | StreamingJsonDataset is a streaming dataset for reading data saved with json.dumps. 14 | 15 | Args: 16 | path: a dict of data with their path. `path` can be `None` to build the process pipeline only. 
17 | """ 18 | 19 | def __init__(self, path): 20 | super().__init__(path) 21 | 22 | def build(self, collate_fn=None, preprocessed=False): 23 | """ 24 | Build input stream 25 | 26 | Args: 27 | collate_fn: callback defined by a specific task 28 | preprocessed: data has been processed 29 | """ 30 | self._collate_fn = collate_fn 31 | self._preprocessed = preprocessed 32 | 33 | if self._path: 34 | self._fin = UniIO(self._path) 35 | 36 | def __iter__(self): 37 | """ 38 | fetch next sample 39 | 40 | Returns: 41 | sample: next sample 42 | """ 43 | for sample in self._fin: 44 | try: 45 | sample = self._full_callback(sample) 46 | yield sample 47 | except Exception as e: 48 | logger.warning(e) 49 | 50 | def _callback(self, sample): 51 | """ 52 | Callback for json data 53 | 54 | Args: 55 | sample: data in raw format 56 | 57 | Returns: 58 | sample (dict): a dict of samples consisting of parallel data of different sources 59 | """ 60 | sample = json.loads(sample) 61 | return sample 62 | 63 | def reset(self): 64 | """ 65 | reset the dataset 66 | """ 67 | self._pos = 0 68 | self._fin = UniIO(self._path) 69 | 70 | def finalize(self): 71 | """ 72 | Finalize dataset after finish reading 73 | """ 74 | self._fin.close() 75 | -------------------------------------------------------------------------------- /mybycha/bycha/datasets/streaming_parallel_text_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from bycha.datasets import register_dataset 4 | from bycha.datasets.streaming_dataset import StreamingDataset 5 | from bycha.utils.io import UniIO 6 | 7 | 8 | @register_dataset 9 | class StreamingParallelTextDataset(StreamingDataset): 10 | """ 11 | StreamingParallelTextDataset is a streaming dataset for reading data saved in parallel files. 12 | 13 | Args: 14 | path: a dict of data with their path. `path` can be `None` to build the process pipeline only. 
15 | """ 16 | 17 | def __init__(self, 18 | path: Dict[str, str] = None,): 19 | super().__init__(path) 20 | self._sources = path.keys() 21 | self._ori_fin = None 22 | 23 | def build(self, collate_fn=None, preprocessed=False): 24 | """ 25 | Build input stream 26 | 27 | Args: 28 | collate_fn: callback defined by a specific task 29 | preprocessed: whether the data has been preprocessed 30 | """ 31 | self._collate_fn = collate_fn 32 | self._preprocessed = preprocessed 33 | 34 | if self._path: 35 | self._ori_fin = [UniIO(self._path[src]) for src in self._sources] 36 | self._fin = zip(*self._ori_fin) 37 | 38 | def __iter__(self): 39 | """ 40 | fetch next sample 41 | 42 | Returns: 43 | sample: next sample 44 | """ 45 | for sample in self._fin: 46 | sample = self._full_callback({src: s.strip('\n') for src, s in zip(self._sources, sample)}) 47 | yield sample 48 | 49 | def _callback(self, sample): 50 | """ 51 | Callback for parallel data 52 | 53 | Args: 54 | sample: data in raw format 55 | 56 | Returns: 57 | sample (dict): a dict of samples consisting of parallel data of different sources 58 | """ 59 | if self._preprocessed: 60 | sample = {key: [eval(v) for v in val] for key, val in sample.items()} 61 | return sample 62 | 63 | def finalize(self): 64 | """ 65 | Finalize dataset after finish reading 66 | """ 67 | for fin in self._ori_fin: 68 | fin.close() 69 | 70 | def reset(self): 71 | """ 72 | reset the dataset 73 | """ 74 | self._pos = 0 75 | self._ori_fin = [UniIO(self._path[src]) for src in self._sources] 76 | self._fin = zip(*self._ori_fin) 77 | 78 | -------------------------------------------------------------------------------- /mybycha/bycha/datasets/streaming_text_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger(__name__) 3 | 4 | from bycha.datasets import register_dataset 5 | from bycha.datasets.streaming_dataset import StreamingDataset 6 | from bycha.utils.io import UniIO 7 
| 8 | 9 | @register_dataset 10 | class StreamingTextDataset(StreamingDataset): 11 | """ 12 | StreamingTextDataset is a streaming dataset for reading data in textual format. 13 | 14 | Args: 15 | path: path to load the data 16 | """ 17 | 18 | def __init__(self, 19 | path,): 20 | super().__init__(path) 21 | 22 | def build(self, collate_fn=None, preprocessed=False): 23 | """ 24 | Build input stream 25 | 26 | Args: 27 | collate_fn: callback defined by a specific task 28 | """ 29 | self._collate_fn = collate_fn 30 | 31 | if self._path: 32 | self._fin = UniIO(self._path) 33 | 34 | def __iter__(self): 35 | """ 36 | fetch next sample 37 | 38 | Returns: 39 | sample: next sample 40 | """ 41 | for sample in self._fin: 42 | try: 43 | sample = self._full_callback(sample) 44 | yield sample 45 | except StopIteration: 46 | raise StopIteration 47 | except Exception as e: 48 | logger.warning(e) 49 | 50 | def reset(self): 51 | """ 52 | reset the dataset 53 | """ 54 | self._pos = 0 55 | self._fin = UniIO(self._path) 56 | 57 | def _callback(self, sample): 58 | """ 59 | Callback for json data 60 | 61 | Args: 62 | sample: data in raw format 63 | 64 | Returns: 65 | sample (dict): a dict of samples consisting of parallel data of different sources 66 | """ 67 | sample = sample.strip('\n').strip() 68 | return sample 69 | 70 | def finalize(self): 71 | """ 72 | Finalize dataset after finish reading 73 | """ 74 | self._fin.close() 75 | 76 | -------------------------------------------------------------------------------- /mybycha/bycha/datasets/text_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger(__name__) 3 | 4 | from bycha.datasets import register_dataset 5 | from bycha.datasets.in_memory_dataset import InMemoryDataset 6 | from bycha.utils.data import count_sample_token 7 | from bycha.utils.io import UniIO 8 | from bycha.utils.runtime import progress_bar 9 | 10 | 11 | @register_dataset 12 | class 
TextDataset(InMemoryDataset): 13 | """ 14 | TextDataset is an in-memory dataset for reading data in textual format. 15 | 16 | Args: 17 | path: path to load the data 18 | sort_samples (bool): sort samples before running a task. 19 | It would be useful in inference without degrading performance. 20 | max_size: maximum size of loaded data 21 | """ 22 | 23 | def __init__(self, 24 | path: str = None, 25 | sort_samples: bool = False, 26 | max_size: int = 0,): 27 | super().__init__(path, sort_samples=sort_samples, max_size=max_size) 28 | 29 | def _callback(self, sample): 30 | """ 31 | Callback for textual data 32 | 33 | Args: 34 | sample: data in raw format 35 | 36 | Returns: 37 | sample (dict): a dict of samples consisting of parallel data of different sources 38 | """ 39 | sample = sample.strip('\n') 40 | if self._preprocessed: 41 | sample = [eval(v) for v in sample.split()] 42 | return sample 43 | 44 | def _load(self): 45 | """ 46 | Preload all the data into memory. In the loading process, data are collate_fnd and sorted. 
47 | """ 48 | fin = UniIO(path=self._path) 49 | self._data = [] 50 | accecpted, discarded = 0, 0 51 | for i, sample in enumerate(progress_bar(fin, streaming=True, desc='Loading Samples')): 52 | if 0 < self._max_size <= i: 53 | break 54 | try: 55 | self._data.append(self._full_callback(sample)) 56 | accecpted += 1 57 | except Exception: 58 | logger.warning('sample {} is discarded'.format(sample)) 59 | discarded += 1 60 | if self._sort_samples: 61 | self._data.sort(key=lambda x: count_sample_token(x)) 62 | self._length = len(self._data) 63 | logger.info(f'Totally accept {accecpted} samples, discard {discarded} samples') 64 | fin.close() 65 | -------------------------------------------------------------------------------- /mybycha/bycha/datasets/tfrecord_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | logger = logging.getLogger(__name__) 4 | 5 | from bycha.datasets import register_dataset 6 | from bycha.datasets.in_memory_dataset import InMemoryDataset 7 | from bycha.utils.data import count_sample_token 8 | from bycha.utils.runtime import progress_bar 9 | 10 | 11 | @register_dataset 12 | class TFRecordDataset(InMemoryDataset): 13 | """ 14 | A tfrecord dataset is an in-memory dataset for reading data in tfrecord 15 | 16 | Args: 17 | path: data path to read 18 | index_path: path to load data map 19 | description: decription for tfrecord 20 | sort_samples (bool): sort samples before running a task. 21 | It would be useful in inference without degrading performance. 22 | max_size: maximum size of loaded data 23 | """ 24 | 25 | def __init__(self, 26 | path, 27 | index_path=None, 28 | description=None, 29 | sort_samples=False, 30 | max_size=0): 31 | super().__init__(path, sort_samples=sort_samples, max_size=max_size) 32 | self._index_path = index_path 33 | self._description = description 34 | 35 | def _load(self): 36 | """ 37 | Preload all the data into memory. 
In the loading process, data are preprocessed and sorted. 38 | """ 39 | import tfrecord 40 | fin = tfrecord.tfrecord_loader(self._path, 41 | self._index_path, 42 | self._description) 43 | self._data = [] 44 | accecpted, discarded = 0, 0 45 | for i, sample in enumerate(progress_bar(fin, streaming=True, desc='Loading Samples')): 46 | if 0 < self._max_size <= i: 47 | break 48 | try: 49 | self._data.append(self._full_callback(sample)) 50 | accecpted += 1 51 | except Exception: 52 | logger.warning('sample {} is discarded'.format(sample)) 53 | discarded += 1 54 | if self._sort_samples: 55 | self._data.sort(key=lambda x: count_sample_token(x)) 56 | self._length = len(self._data) 57 | logger.info(f'Totally accept {accecpted} samples, discard {discarded} samples') 58 | 59 | def _callback(self, sample): 60 | """ 61 | Callback for json data 62 | 63 | Args: 64 | sample: data in raw format 65 | 66 | Returns: 67 | sample (dict): a dict of samples consisting of parallel data of different sources 68 | """ 69 | sample = json.loads(sample) 70 | return sample 71 | -------------------------------------------------------------------------------- /mybycha/bycha/entries/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/mybycha/bycha/entries/__init__.py -------------------------------------------------------------------------------- /mybycha/bycha/entries/binarize_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from bycha.entries.util import parse_config 4 | from bycha.tasks import create_task, AbstractTask 5 | from bycha.utils.ops import recursive 6 | 7 | def main(): 8 | confs = parse_config() 9 | task = create_task(confs.pop('task')) 10 | assert isinstance(task, AbstractTask) 11 | task.build() 12 | dataloader = task._build_dataloader('train', mode='train') 13 | output_path = 
confs['output_path'] 14 | to_list = recursive(lambda x: x.tolist()) 15 | with open(output_path, 'w') as fout: 16 | for batch in dataloader: 17 | batch = to_list(batch) 18 | batch = json.dumps(batch) 19 | fout.write(f'{batch}\n') 20 | 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /mybycha/bycha/entries/build_tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from bycha.tokenizers import registry, AbstractTokenizer 3 | from bycha.utils.runtime import build_env 4 | from bycha.entries.util import parse_config 5 | 6 | 7 | def main(): 8 | configs = parse_config() 9 | if 'env' in configs: 10 | env_conf = configs.pop('env') 11 | build_env(configs, **env_conf) 12 | cls = registry[configs.pop('class').lower()] 13 | assert issubclass(cls, AbstractTokenizer) 14 | os.makedirs('/'.join(configs['output_path'].split('/')[:-1]), exist_ok=True) 15 | data = configs.pop('data') 16 | cls.learn(data, **configs) 17 | 18 | 19 | if __name__ == '__main__': 20 | main() 21 | -------------------------------------------------------------------------------- /mybycha/bycha/entries/export.py: -------------------------------------------------------------------------------- 1 | from bycha.tasks import create_task 2 | from bycha.utils.runtime import build_env 3 | from bycha.entries.util import parse_config 4 | 5 | 6 | def main(): 7 | confs = parse_config() 8 | if 'env' in confs: 9 | build_env(confs['task'], **confs['env']) 10 | export_conf = confs.pop('export') 11 | task = create_task(confs.pop('task')) 12 | task.build() 13 | path = export_conf.pop("path") 14 | task.export(path, **export_conf) 15 | 16 | 17 | if __name__ == '__main__': 18 | main() 19 | -------------------------------------------------------------------------------- /mybycha/bycha/entries/preprocess.py: -------------------------------------------------------------------------------- 1 | 
from bycha.entries.util import parse_config 2 | from bycha.tasks import create_task, AbstractTask 3 | from bycha.datasets import create_dataset, AbstractDataset 4 | 5 | 6 | def main(): 7 | confs = parse_config() 8 | task = create_task(confs.pop('task')) 9 | assert isinstance(task, AbstractTask) 10 | task.build() 11 | dataset_conf = confs['dataset'] 12 | for _, conf in confs['data'].items(): 13 | output_path = conf['output_path'] 14 | data_map_path = conf['data_map_path'] if 'data_map_path' in conf else None 15 | dataset_conf['path'] = conf['path'] 16 | dataset = create_dataset(dataset_conf) 17 | assert isinstance(dataset, AbstractDataset) 18 | dataset.build(collate_fn=task._data_collate_fn, preprocessed=False) 19 | dataset.write(path=output_path, data_map_path=data_map_path) 20 | 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /mybycha/bycha/entries/run.py: -------------------------------------------------------------------------------- 1 | from bycha.tasks import create_task 2 | from bycha.utils.runtime import build_env 3 | from bycha.entries.util import parse_config 4 | 5 | 6 | def main(): 7 | confs = parse_config() 8 | if 'env' in confs: 9 | build_env(confs['task'], **confs['env']) 10 | task = create_task(confs.pop('task')) 11 | task.build() 12 | task.run() 13 | 14 | 15 | if __name__ == '__main__': 16 | main() 17 | -------------------------------------------------------------------------------- /mybycha/bycha/entries/serve.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | logger = logging.getLogger(__name__) 4 | 5 | import euler 6 | 7 | from bycha.entries.util import parse_config 8 | from bycha.services import Server, Service 9 | from bycha.utils.runtime import build_env 10 | 11 | 12 | def main(): 13 | configs = parse_config() 14 | 15 | env = configs.pop('env') 16 | env['device'] = 'cpu' 17 | 
build_env(configs['task'], **env) 18 | 19 | server = Server(configs) 20 | app = euler.Server(Service) 21 | 22 | @app.register('serve') 23 | def serve(ctx, req): 24 | return server.serve(req) 25 | 26 | server_port = int(os.environ.get('SERVER_PORT', 18001)) 27 | logger.info('Starting thrift server in python on PORT {}...'.format(server_port)) 28 | app.run("tcp://0.0.0.0:{}".format(server_port), 29 | transport="buffered", 30 | workers_count=getattr(configs, 'worker', 8)) 31 | logger.info('exit!') 32 | 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /mybycha/bycha/entries/serve_model.py: -------------------------------------------------------------------------------- 1 | from bycha.entries.util import parse_config 2 | import os 3 | import logging 4 | logger = logging.getLogger(__name__) 5 | 6 | from thriftpy.rpc import make_server 7 | import thriftpy 8 | 9 | from bycha.services.model_server import ModelServer 10 | from bycha.tasks import create_task 11 | from bycha.utils.runtime import build_env 12 | 13 | 14 | def main(): 15 | configs = parse_config() 16 | if 'env' in configs: 17 | build_env(configs['task'], **configs['env']) 18 | task = create_task(configs.pop('task')) 19 | task.build() 20 | generator = task._generator 21 | model = ModelServer(generator) 22 | grpc_port = int(os.environ.get('GRPC_PORT', 6000)) 23 | model_infer_thrift = thriftpy.load("/opt/tiger/ByCha/bycha/services/idls/model_infer.thrift", module_name="model_infer_thrift") 24 | server = make_server(model_infer_thrift.ModelInfer, 25 | model, 26 | 'localhost', 27 | grpc_port, 28 | client_timeout=None) 29 | logger.info('Starting Serving Model') 30 | server.serve() 31 | 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /mybycha/bycha/entries/util.py: -------------------------------------------------------------------------------- 1 | 
import argparse 2 | import yaml 3 | 4 | from bycha.utils.data import possible_eval 5 | from bycha.utils.io import UniIO 6 | 7 | 8 | def parse_config(): 9 | """ 10 | Parse configurations from config file and override arguments. 11 | Returns: 12 | 13 | """ 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--config', metavar='N', type=str, help='config path') 16 | parser.add_argument('--lib', metavar='N', default=None, type=str, help='customization package') 17 | args, unknown = parser.parse_known_args() 18 | with UniIO(args.config) as fin: 19 | confs = yaml.load(fin, Loader=yaml.FullLoader) 20 | stringizing(confs) 21 | kv_pairs = [] 22 | current_key = None 23 | for ele in unknown: 24 | if ele.startswith("--"): 25 | current_key = ele[2:] 26 | else: 27 | kv_pairs.append((current_key, ele)) 28 | for pair in kv_pairs: 29 | ks = pair[0].split(".") 30 | v = possible_eval(pair[1]) 31 | tmp = confs 32 | last_key = ks[-1] 33 | for k in ks[:-1]: 34 | if k not in tmp: 35 | tmp[k] = {} 36 | tmp = tmp[k] 37 | tmp[last_key] = v 38 | if args.lib: 39 | if 'env' not in confs: 40 | confs['env'] = {} 41 | custom_libs = [args.lib] 42 | if 'custom_libs' in confs['env']: 43 | custom_libs.append(confs['env']['custom_libs']) 44 | confs['env']['custom_libs'] = ','.join(custom_libs) 45 | return confs 46 | 47 | def stringizing(conf: dict): 48 | def _stringizing(def_dct: dict, conf_dct: dict): 49 | for k, v in conf_dct.items(): 50 | if isinstance(v, str): 51 | for def_k, def_v in def_dct.items(): 52 | if def_k in v: 53 | v = v.replace(def_k, def_v) 54 | conf_dct[k] = v 55 | if isinstance(v, dict): 56 | _stringizing(def_dct, v) 57 | 58 | if "define" in conf: 59 | definition = {"${" + k + "}": v for k,v in conf['define'].items()} 60 | conf.pop("define") 61 | _stringizing(definition, conf) 62 | -------------------------------------------------------------------------------- /mybycha/bycha/evaluators/__init__.py: 
-------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.registry import setup_registry 5 | 6 | from .abstract_evaluator import AbstractEvaluator 7 | 8 | register_evaluator, create_evaluator, registry = setup_registry('evaluator', AbstractEvaluator) 9 | 10 | modules_dir = os.path.dirname(__file__) 11 | for file in os.listdir(modules_dir): 12 | path = os.path.join(modules_dir, file) 13 | if ( 14 | not file.startswith('_') 15 | and not file.startswith('.') 16 | and (file.endswith('.py') or os.path.isdir(path)) 17 | ): 18 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 19 | module = importlib.import_module('bycha.evaluators.' + module_name) 20 | -------------------------------------------------------------------------------- /mybycha/bycha/evaluators/abstract_evaluator.py: -------------------------------------------------------------------------------- 1 | class AbstractEvaluator: 2 | """ 3 | Evaluation scheduler 4 | """ 5 | 6 | def __init__(self, ): 7 | pass 8 | 9 | def build(self, *args, **kwargs): 10 | """ 11 | Build evaluator from the given configs and components 12 | """ 13 | raise NotImplementedError 14 | 15 | def finalize(self): 16 | """ 17 | Finalize evaluator after finishing evaluation 18 | """ 19 | raise NotImplementedError 20 | 21 | def _step_reset(self, *args, **kwargs): 22 | """ 23 | Reset states by step 24 | """ 25 | pass 26 | 27 | def _step(self, samples): 28 | """ 29 | Evaluate one batch of samples 30 | 31 | Args: 32 | samples: a batch of samples 33 | """ 34 | raise NotImplementedError 35 | 36 | def _step_update(self, *args, **kwargs): 37 | """ 38 | Update states by step 39 | """ 40 | pass 41 | 42 | def _eval_reset(self, *args, **kwargs): 43 | """ 44 | Reset states before the overall evaluation process 45 | """ 46 | pass 47 | 48 | def _eval_update(self, *args, **kwargs): 49 | """ 50 | Update states after the overall evaluation process 51 | """ 
52 | pass 53 | 54 | def eval(self): 55 | """ 56 | Evaluation process 57 | """ 58 | raise NotImplementedError 59 | -------------------------------------------------------------------------------- /mybycha/bycha/evaluators/multi_evaluator.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from bycha.evaluators import AbstractEvaluator, register_evaluator, create_evaluator 4 | 5 | 6 | @register_evaluator 7 | class MultiTaskEvaluator(AbstractEvaluator): 8 | """ 9 | MultiTaskEvaluator for evaluation, which wrapped from Evaluator with different situation. 10 | 11 | Args: 12 | evaluators (dict): evaluator configurations for building multiple evaluators 13 | """ 14 | 15 | def __init__(self, 16 | evaluators: Dict, 17 | ): 18 | super().__init__() 19 | self._evaluator_configs = evaluators 20 | 21 | self._evaluators = None 22 | self._task_callback = None 23 | 24 | def build(self, generator, dataloaders, tokenizer, task_callback=None, postprocess=None): 25 | """ 26 | Build evaluators with given arguments. 27 | Arguments are dispatched to all the evaluators respectively. 
28 | 29 | Args: 30 | generator (bycha.generators.AbstractGenerator): the inference model to generate hypothesis 31 | dataloaders (dict[bycha.dataloaders.AbstractDataLoader]): a set of dataloaders to evaluate 32 | tokenizer (bycha.tokenizers.AbstractTokenizer): a tokenizer 33 | task_callback: building context in task during for evaluation via a callback function 34 | postprocess: postprocess pipeline to obtain final hypothesis from predicted results (torch.Tensor) 35 | """ 36 | self._evaluators = {} 37 | for name, config in self._evaluator_configs.items(): 38 | self._evaluators[name.upper()] = create_evaluator(config) 39 | self._evaluators[name.upper()].build(generator=generator, 40 | dataloaders=dataloaders, 41 | tokenizer=tokenizer, 42 | task_callback=task_callback, 43 | postprocess=postprocess) 44 | self._task_callback = task_callback 45 | 46 | def eval(self): 47 | """ 48 | Perform evaluation for each task; 49 | """ 50 | scores = {} 51 | for name, evaluator in self._evaluators.items(): 52 | self._task_callback(training=False, infering=True) 53 | states = evaluator.eval() 54 | scores.update({'{}.{}'.format(name, key): val for key, val in states.items()}) 55 | return scores 56 | -------------------------------------------------------------------------------- /mybycha/bycha/generators/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.registry import setup_registry 5 | 6 | from .abstract_generator import AbstractGenerator 7 | 8 | register_generator, create_generator, registry = setup_registry('generator', AbstractGenerator) 9 | 10 | modules_dir = os.path.dirname(__file__) 11 | for file in os.listdir(modules_dir): 12 | path = os.path.join(modules_dir, file) 13 | if ( 14 | not file.startswith('_') 15 | and not file.startswith('.') 16 | and (file.endswith('.py') or os.path.isdir(path)) 17 | ): 18 | module_name = file[:file.find('.py')] if file.endswith('.py') else 
file 19 | module = importlib.import_module('bycha.generators.' + module_name) 20 | 21 | -------------------------------------------------------------------------------- /mybycha/bycha/generators/abstract_generator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger(__name__) 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from bycha.utils.ops import inspect_fn 8 | from bycha.utils.runtime import Environment 9 | from bycha.utils.io import UniIO, mkdir 10 | 11 | 12 | class AbstractGenerator(nn.Module): 13 | """ 14 | AbstractGenerator wrap a model with inference algorithms. 15 | It can be directly exported and used for inference or serving. 16 | 17 | Args: 18 | path: path to restore traced model 19 | """ 20 | 21 | def __init__(self, path): 22 | super().__init__() 23 | 24 | self._path = path 25 | self._traced_model = None 26 | self._model = None 27 | self._mode = 'infer' 28 | 29 | def build(self, *args, **kwargs): 30 | """ 31 | Build or load a generator 32 | """ 33 | if self._path is not None: 34 | self.load() 35 | else: 36 | self.build_from_model(*args, **kwargs) 37 | 38 | self._env = Environment() 39 | if self._env.device.startswith('cuda'): 40 | logger.info('move model to {}'.format(self._env.device)) 41 | self.cuda(self._env.device) 42 | 43 | def build_from_model(self, *args, **kwargs): 44 | """ 45 | Build generator from model 46 | """ 47 | raise NotImplementedError 48 | 49 | def forward(self, *args, **kwargs): 50 | """ 51 | Infer a sample in evaluation mode. 52 | We auto detect whether the inference model is traced, and use appropriate model to perform inference. 53 | """ 54 | if self._traced_model is not None: 55 | return self._traced_model(*args, **kwargs) 56 | else: 57 | return self._forward(*args, **kwargs) 58 | 59 | def _forward(self, *args, **kwargs): 60 | """ 61 | Infer a sample in evaluation mode with torch model. 
62 | """ 63 | raise NotImplementedError 64 | 65 | def export(self, path, net_input, **kwargs): 66 | """ 67 | Export self to `path` by export model directly 68 | 69 | Args: 70 | path: path to store serialized model 71 | net_input: fake net_input for tracing the model 72 | """ 73 | self.eval() 74 | with torch.no_grad(): 75 | logger.info('trace model {}'.format(self._model.__class__.__name__)) 76 | model = torch.jit.trace_module(self._model, {'forward': net_input}) 77 | mkdir(path) 78 | logger.info('save model to {}/model'.format(path)) 79 | with UniIO('{}/model'.format(path), 'wb') as fout: 80 | torch.jit.save(model, fout) 81 | 82 | def load(self): 83 | """ 84 | Load a serialized model from path 85 | """ 86 | logger.info('load model from {}'.format(self._path)) 87 | with UniIO(self._path, 'rb') as fin: 88 | self._traced_model = torch.jit.load(fin) 89 | 90 | def reset(self, *args, **kwargs): 91 | """ 92 | Reset generator states. 93 | """ 94 | pass 95 | 96 | @property 97 | def input_slots(self): 98 | """ 99 | Generator input slots that is auto-detected 100 | """ 101 | return inspect_fn(self._forward) 102 | -------------------------------------------------------------------------------- /mybycha/bycha/generators/generator.py: -------------------------------------------------------------------------------- 1 | from bycha.generators import AbstractGenerator, register_generator 2 | from bycha.utils.ops import inspect_fn 3 | 4 | 5 | @register_generator 6 | class Generator(AbstractGenerator): 7 | """ 8 | Generator wrap a model with inference algorithms. 9 | Generator has the same function and interface as model. 10 | It can be directly exported and used for inference or serving. 
11 | 12 | Args: 13 | path: path to export or load generator 14 | is_regression: whether the task is a regression task 15 | """ 16 | 17 | def __init__(self, 18 | path=None, 19 | is_regression=False, 20 | is_binary_classification=False): 21 | super().__init__(path) 22 | self._is_regression = is_regression 23 | self._is_binary_classification = is_binary_classification 24 | 25 | self._model = None 26 | 27 | def build_from_model(self, model): 28 | """ 29 | Build generator from model 30 | 31 | Args: 32 | model (bycha.models.AbstractModel): a neural model 33 | """ 34 | self._model = model 35 | 36 | def _forward(self, *args): 37 | """ 38 | Infer a sample as model in evaluation mode, and predict results from logits predicted by model 39 | 40 | Args: 41 | inputs: inference inputs 42 | """ 43 | output = self._model(*args) 44 | if not self._is_regression: 45 | if self._is_binary_classification: 46 | output = (output > 0.5).long() 47 | else: 48 | _, output = output.max(dim=-1) 49 | return output 50 | 51 | def reset(self, mode): 52 | """ 53 | Reset generator states. 54 | 55 | Args: 56 | mode: running mode 57 | """ 58 | if mode != 'train': 59 | self.eval() 60 | self._mode = mode 61 | self._model.reset(mode) 62 | 63 | @property 64 | def model(self): 65 | return self._model 66 | 67 | @property 68 | def input_slots(self): 69 | """ 70 | Generator input slots that is auto-detected 71 | """ 72 | return inspect_fn(self._model.forward) 73 | -------------------------------------------------------------------------------- /mybycha/bycha/generators/self_contained_generator.py: -------------------------------------------------------------------------------- 1 | from bycha.generators import AbstractGenerator, register_generator 2 | 3 | 4 | @register_generator 5 | class SelfContainedGenerator(AbstractGenerator): 6 | """ 7 | SelfContainedGenerator use self-implemented generate function within model. 8 | Generator has the same function and interface as model. 
9 | It can be directly exported and used for inference or serving. 10 | 11 | Args: 12 | path: path to export or load generator 13 | """ 14 | 15 | def __init__(self, 16 | path=None, **kwargs): 17 | super().__init__(path) 18 | self._kwargs = kwargs 19 | self._model = None 20 | 21 | def build_from_model(self, model, *args, **kwargs): 22 | """ 23 | Build generator from model 24 | 25 | Args: 26 | model (bycha.models.AbstractModel): a neural model 27 | """ 28 | self._model = model 29 | 30 | def _forward(self, *args, **kwargs): 31 | """ 32 | Infer a sample as model in evaluation mode, and predict results from logits predicted by model 33 | """ 34 | kwargs.update(self._kwargs) 35 | output = self._model.generate(*args, **kwargs) 36 | return output 37 | 38 | @property 39 | def model(self): 40 | return self._model 41 | -------------------------------------------------------------------------------- /mybycha/bycha/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.registry import setup_registry 5 | 6 | from .abstract_metric import AbstractMetric 7 | from .pairwise_metric import PairwiseMetric 8 | 9 | register_metric, create_metric, registry = setup_registry('metric', AbstractMetric) 10 | 11 | modules_dir = os.path.dirname(__file__) 12 | for file in os.listdir(modules_dir): 13 | path = os.path.join(modules_dir, file) 14 | if ( 15 | not file.startswith('_') 16 | and not file.startswith('.') 17 | and (file.endswith('.py') or os.path.isdir(path)) 18 | ): 19 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 20 | module = importlib.import_module('bycha.metrics.' 
+ module_name) 21 | -------------------------------------------------------------------------------- /mybycha/bycha/metrics/abstract_metric.py: -------------------------------------------------------------------------------- 1 | class AbstractMetric: 2 | """ 3 | Metric evaluates the performance with produced hypotheses and references. 4 | """ 5 | 6 | def __init__(self): 7 | self._score = None 8 | 9 | def build(self, *args, **kwargs): 10 | """ 11 | Build metric 12 | """ 13 | self.reset() 14 | 15 | def reset(self): 16 | """ 17 | Reset metric for a new round of evaluation 18 | """ 19 | pass 20 | 21 | def add_all(self, *args, **kwargs): 22 | raise NotImplementedError 23 | 24 | def add(self, *args, **kwargs): 25 | """ 26 | Add parallel hypotheses and references to metric buffer 27 | """ 28 | raise NotImplementedError 29 | 30 | def eval(self): 31 | """ 32 | Evaluate the performance with buffered hypotheses and references. 33 | """ 34 | raise NotImplementedError 35 | 36 | def __len__(self): 37 | raise NotImplementedError 38 | 39 | def __getitem__(self, idx): 40 | raise NotImplementedError 41 | 42 | def get_item(self, idx, to_str=False): 43 | """ 44 | fetch a item at given index 45 | 46 | Args: 47 | idx: index of a pair of hypothesis and reference 48 | to_str: transform the pair to str format before return it 49 | 50 | Returns: 51 | item: a pair of item in tuple or string format 52 | """ 53 | raise NotImplementedError 54 | 55 | -------------------------------------------------------------------------------- /mybycha/bycha/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | from bycha.metrics import PairwiseMetric, register_metric 2 | 3 | 4 | @register_metric 5 | class Accuracy(PairwiseMetric): 6 | """ 7 | Accuracy evaluates accuracy of produced hypotheses labels by comparing with references. 
8 | """ 9 | 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def eval(self): 14 | """ 15 | Calculate the accuracy of produced hypotheses comparing with references 16 | Returns: 17 | score (float): evaluation score 18 | """ 19 | if self._score is not None: 20 | return self._score 21 | else: 22 | correct = 0 23 | for hypo, ref in zip(self.hypos, self.refs): 24 | correct += 1 if hypo == ref else 0 25 | self._score = correct / len(self.hypos) 26 | return self._score 27 | -------------------------------------------------------------------------------- /mybycha/bycha/metrics/bleu.py: -------------------------------------------------------------------------------- 1 | import sacrebleu 2 | 3 | from bycha.metrics import PairwiseMetric, register_metric 4 | 5 | 6 | @register_metric 7 | class BLEU(PairwiseMetric): 8 | """ 9 | BLEU evaluates BLEU scores of produced hypotheses by comparing with references. 10 | """ 11 | 12 | def __init__(self, no_tok=False, lang='en'): 13 | super().__init__() 14 | self._no_tok = no_tok 15 | self._lang = lang 16 | 17 | self._sacrebleu_kwargs = {} 18 | if self._no_tok: 19 | self._sacrebleu_kwargs['tokenize'] = 'none' 20 | else: 21 | self._sacrebleu_kwargs['tokenize'] = get_tokenize_by_lang(self._lang) 22 | 23 | def build(self, *args, **kwargs): 24 | """ 25 | Build metric 26 | """ 27 | self.reset() 28 | 29 | def add(self, hypo, ref): 30 | """ 31 | Add parallel hypotheses and references to metric buffer 32 | """ 33 | if isinstance(ref, str): 34 | ref = [ref] 35 | self._hypos.append(hypo) 36 | self._refs.append(ref) 37 | 38 | def eval(self): 39 | """ 40 | Evaluate the performance with buffered hypotheses and references. 
41 | """ 42 | if self._score is not None: 43 | return self._score 44 | else: 45 | refs = list(zip(*self._refs)) 46 | bleu = sacrebleu.corpus_bleu(self._hypos, refs, **self._sacrebleu_kwargs) 47 | self._score = bleu.score 48 | return self._score 49 | 50 | 51 | def get_tokenize_by_lang(lang): 52 | if lang in ['zh']: 53 | return 'zh' 54 | elif lang in ['ko']: 55 | return 'char' 56 | else: 57 | return '13a' 58 | -------------------------------------------------------------------------------- /mybycha/bycha/metrics/f1.py: -------------------------------------------------------------------------------- 1 | from bycha.metrics import PairwiseMetric, register_metric 2 | 3 | 4 | @register_metric 5 | class F1(PairwiseMetric): 6 | """ 7 | F1 evaluates F1 of produced hypotheses labels by comparing with references. 8 | """ 9 | 10 | def __init__(self, target_label): 11 | super().__init__() 12 | self._target_label = target_label 13 | 14 | self._precision, self._recall = 0, 0 15 | 16 | def eval(self): 17 | """ 18 | Calculate the f1-score of produced hypotheses comparing with references 19 | Returns: 20 | score (float): evaluation score 21 | """ 22 | if self._score is not None: 23 | return self._score 24 | else: 25 | if isinstance(self._target_label, int): 26 | self._precision, self._recall = self._fast_precision_recall() 27 | else: 28 | self._precision, self._recall = self._precision_recall() 29 | self._score = self._precision * self._recall * 2 / (self._precision + self.recall) 30 | return self._score 31 | 32 | def _precision_recall(self): 33 | true_positive, false_positive, true_negative, false_negative = 1e-8, 0, 0, 0 34 | for hypo, ref in zip(self.hypos, self.refs): 35 | if ref == self._target_label: 36 | if hypo == ref: 37 | true_positive += 1 38 | else: 39 | false_negative += 1 40 | else: 41 | if hypo == ref: 42 | true_negative += 1 43 | else: 44 | false_positive += 1 45 | precision = true_positive / (true_positive + false_positive) 46 | recall = true_positive / 
(true_positive + false_negative) 47 | return precision, recall 48 | 49 | def _fast_precision_recall(self): 50 | import torch 51 | hypos = torch.LongTensor(self.hypos) 52 | refs = torch.LongTensor(self.refs) 53 | 54 | from bycha.utils.runtime import Environment 55 | env = Environment() 56 | if env.device.startswith('cuda'): 57 | hypos, refs = hypos.cuda(), refs.cuda() 58 | with torch.no_grad(): 59 | true_mask = refs.eq(self._target_label) 60 | pos_mask = hypos.eq(self._target_label) 61 | true_positive = true_mask.masked_fill(~pos_mask, False).long().sum().data.item() + 1e-8 62 | false_positive = (~true_mask).masked_fill(~pos_mask, False).long().sum().data.item() 63 | false_negative = true_mask.masked_fill(pos_mask, False).long().sum().data.item() 64 | precision = true_positive / (true_positive + false_positive) 65 | recall = true_positive / (true_positive + false_negative) 66 | return precision, recall 67 | 68 | @property 69 | def precision(self): 70 | return self._precision 71 | 72 | @property 73 | def recall(self): 74 | return self._recall 75 | -------------------------------------------------------------------------------- /mybycha/bycha/metrics/matthews_corr.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics import matthews_corrcoef 2 | 3 | from bycha.metrics import PairwiseMetric, register_metric 4 | 5 | 6 | @register_metric 7 | class MatthewsCorr(PairwiseMetric): 8 | """ 9 | MatthewsCorr evaluates matthews correlation of produced hypotheses labels by comparing with references. 
10 | """ 11 | 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def eval(self): 16 | """ 17 | Calculate the spearman correlation of produced hypotheses comparing with references 18 | Returns: 19 | score (float): evaluation score 20 | """ 21 | if self._score is not None: 22 | return self._score 23 | else: 24 | self._score = matthews_corrcoef(self._refs, self._hypos) 25 | return self._score 26 | -------------------------------------------------------------------------------- /mybycha/bycha/metrics/pairwise_metric.py: -------------------------------------------------------------------------------- 1 | from bycha.metrics import AbstractMetric 2 | 3 | 4 | class PairwiseMetric(AbstractMetric): 5 | """ 6 | PairwiseMtric evaluates pairwise comparison between refs and hypos. 7 | """ 8 | 9 | def __init__(self): 10 | super().__init__() 11 | self._hypos, self._refs = [], [] 12 | 13 | def reset(self): 14 | """ 15 | Reset metric for a new round of evaluation 16 | """ 17 | self._hypos.clear() 18 | self._refs.clear() 19 | self._score = None 20 | 21 | def add_all(self, hypos, refs): 22 | """ 23 | Add all hypos and refs 24 | """ 25 | for hypo, ref in zip(hypos, refs): 26 | self.add(hypo, ref) 27 | 28 | def add(self, hypo, ref): 29 | """ 30 | Add parallel hypotheses and references to metric buffer 31 | """ 32 | self._hypos.append(hypo) 33 | self._refs.append(ref) 34 | 35 | def eval(self): 36 | """ 37 | Evaluate the performance with buffered hypotheses and references. 
38 | """ 39 | raise NotImplementedError 40 | 41 | def __len__(self): 42 | return len(self._hypos) 43 | 44 | def __getitem__(self, idx): 45 | return self._hypos[idx], self._refs[idx] 46 | 47 | def get_item(self, idx, to_str=False): 48 | """ 49 | fetch a item at given index 50 | 51 | Args: 52 | idx: index of a pair of hypothesis and reference 53 | to_str: transform the pair to str format before return it 54 | 55 | Returns: 56 | item: a pair of item in tuple or string format 57 | """ 58 | ret = self[idx] 59 | if to_str: 60 | ret = '\n\tHypothesis: {}\n\tGround Truth: {}\n'.format(ret[0], ret[1]) 61 | return ret 62 | 63 | @property 64 | def hypos(self): 65 | return self._hypos 66 | 67 | @property 68 | def refs(self): 69 | return self._refs 70 | -------------------------------------------------------------------------------- /mybycha/bycha/metrics/pearson_corr.py: -------------------------------------------------------------------------------- 1 | from scipy.stats import pearsonr 2 | import numpy as np 3 | 4 | from bycha.metrics import PairwiseMetric, register_metric 5 | 6 | 7 | @register_metric 8 | class PearsonCorr(PairwiseMetric): 9 | """ 10 | PearsonCorr evaluates pearson's correlation of produced hypotheses labels by comparing with references. 
11 | """ 12 | 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def eval(self): 17 | """ 18 | Calculate the spearman correlation of produced hypotheses comparing with references 19 | Returns: 20 | score (float): evaluation score 21 | """ 22 | if self._score is not None: 23 | return self._score 24 | else: 25 | self._score = pearsonr(np.array(self._hypos), np.array(self._refs))[0] 26 | return self._score 27 | 28 | -------------------------------------------------------------------------------- /mybycha/bycha/metrics/spearman_corr.py: -------------------------------------------------------------------------------- 1 | from scipy.stats import spearmanr 2 | import numpy as np 3 | 4 | from bycha.metrics import PairwiseMetric, register_metric 5 | 6 | 7 | @register_metric 8 | class SpearmanCorr(PairwiseMetric): 9 | """ 10 | SpearmanCorr evaluates spearman's correlation of produced hypotheses labels by comparing with references. 11 | """ 12 | 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def eval(self): 17 | """ 18 | Calculate the spearman correlation of produced hypotheses comparing with references 19 | Returns: 20 | score (float): evaluation score 21 | """ 22 | if self._score is not None: 23 | return self._score 24 | else: 25 | self._score = spearmanr(np.array(self._hypos), np.array(self._refs))[0] 26 | return self._score 27 | -------------------------------------------------------------------------------- /mybycha/bycha/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.registry import setup_registry 5 | 6 | from .abstract_model import AbstractModel 7 | 8 | register_model, create_model, registry = setup_registry('model', AbstractModel) 9 | 10 | models_dir = os.path.dirname(__file__) 11 | for file in os.listdir(models_dir): 12 | path = os.path.join(models_dir, file) 13 | if ( 14 | not file.startswith('_') 15 | and not file.startswith('.') 16 
| and (file.endswith('.py') or os.path.isdir(path)) 17 | ): 18 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 19 | module = importlib.import_module('bycha.models.' + model_name) 20 | -------------------------------------------------------------------------------- /mybycha/bycha/models/abstract_encoder_decoder_model.py: -------------------------------------------------------------------------------- 1 | from bycha.models import AbstractModel 2 | 3 | 4 | class AbstractEncoderDecoderModel(AbstractModel): 5 | """ 6 | AbstractEncoderDecoderModel defines interface for encoder-decoder model. 7 | It must contains two attributes: encoder and decoder. 8 | """ 9 | 10 | def __init__(self, path, *args, **kwargs): 11 | super().__init__(path) 12 | self._args = args 13 | self._kwargs = kwargs 14 | 15 | self._encoder, self._decoder = None, None 16 | 17 | @property 18 | def encoder(self): 19 | return self._encoder 20 | 21 | @property 22 | def decoder(self): 23 | return self._decoder 24 | -------------------------------------------------------------------------------- /mybycha/bycha/models/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | 5 | models_dir = os.path.dirname(__file__) 6 | for file in os.listdir(models_dir): 7 | path = os.path.join(models_dir, file) 8 | if ( 9 | not file.startswith('_') 10 | and not file.startswith('.') 11 | and (file.endswith('.py') or os.path.isdir(path)) 12 | ): 13 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 14 | module = importlib.import_module('bycha.models.huggingface.' 
+ model_name) 15 | -------------------------------------------------------------------------------- /mybycha/bycha/models/huggingface/huggingface_extractive_question_answering_model.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModelForQuestionAnswering 2 | 3 | from bycha.models import register_model 4 | from bycha.models.abstract_encoder_decoder_model import AbstractEncoderDecoderModel 5 | from bycha.modules.layers.classifier import HuggingfaceClassifier 6 | 7 | 8 | @register_model 9 | class HuggingfaceExtractiveQuestionAnsweringModel(AbstractEncoderDecoderModel): 10 | """ 11 | HuggingfaceExtractiveQuestionAnsweringModel is a extractive question answering model built on 12 | huggingface extractive question answering models. 13 | 14 | Args: 15 | pretrained_model: pretrained_model in huggingface 16 | has_answerable: has answerable problem 17 | path: path to restore model 18 | """ 19 | 20 | def __init__(self, pretrained_model, has_answerable=False, path=None): 21 | super().__init__(path=path) 22 | self._pretrained_model = pretrained_model 23 | self._has_answerable = has_answerable 24 | 25 | self._config = None 26 | self._model = None 27 | self._special_tokens = None 28 | self._encoder, self._decoder = None, None 29 | if self._has_answerable: 30 | self._classification_head = None 31 | 32 | def _build(self, vocab_size, special_tokens): 33 | """ 34 | Build model with vocabulary size and special tokens 35 | 36 | Args: 37 | vocab_size: vocabulary size of input sequence 38 | special_tokens: special tokens of input sequence 39 | """ 40 | self._config = AutoConfig.from_pretrained(self._pretrained_model) 41 | self._model = AutoModelForQuestionAnswering.from_pretrained(self._pretrained_model, config=self._config,) 42 | self._special_tokens = special_tokens 43 | 44 | if self._has_answerable: 45 | self._classification_head = HuggingfaceClassifier(self._model.d_model, 2) 46 | 47 | def 
forward(self, input, answerable=None, start_positions=None, end_positions=None): 48 | """ 49 | Compute output with neural input 50 | 51 | Args: 52 | input: input sequence 53 | answerable: gold answerable 54 | start_positions: gold start position 55 | end_positions: gold end position 56 | 57 | Returns: 58 | - log probability of start and end position 59 | """ 60 | output = self._model(input, 61 | attention_mask=input.ne(self._special_tokens['pad']), 62 | start_positions=start_positions, 63 | end_positions=end_positions) 64 | return output 65 | 66 | def loss(self, input, answerable=None, start_positions=None, end_positions=None): 67 | """ 68 | Compute loss from network inputs 69 | 70 | Args: 71 | input: input sequence 72 | answerable: gold answerable 73 | start_positions: gold start position 74 | end_positions: gold end position 75 | 76 | Returns: 77 | - loss 78 | """ 79 | output = self(input, answerable, start_positions, end_positions) 80 | return output[0] 81 | 82 | -------------------------------------------------------------------------------- /mybycha/bycha/models/huggingface/huggingface_pretrain_bart_model.py: -------------------------------------------------------------------------------- 1 | from transformers import BartConfig, BartForConditionalGeneration 2 | 3 | from bycha.models import register_model 4 | from bycha.models.abstract_model import AbstractModel 5 | 6 | 7 | @register_model 8 | class HuggingfacePretrainBartModel(AbstractModel): 9 | """ 10 | HuggingfacePretrainBartModel is a pretrained bart model built on 11 | huggingface pretrained bart models. 
12 | """ 13 | 14 | def __init__(self): 15 | super().__init__() 16 | 17 | self._config = None 18 | self._model = None 19 | self._special_tokens = None 20 | 21 | def _build(self, vocab_size, special_tokens): 22 | """ 23 | Build model with vocabulary size and special tokens 24 | 25 | Args: 26 | vocab_size: vocabulary size of input sequence 27 | special_tokens: special tokens of input sequence 28 | """ 29 | self._config = BartConfig(vocab_size=vocab_size, pad_token_id=special_tokens['pad']) 30 | self._model = BartForConditionalGeneration(self._config) 31 | self._special_tokens = special_tokens 32 | 33 | def forward(self, enc_input, dec_input): 34 | """ 35 | Compute output with neural input 36 | 37 | Args: 38 | enc_input: encoder input sequence 39 | dec_input: decoder input sequence 40 | 41 | Returns: 42 | - log probability of next tokens in sequences 43 | """ 44 | output = self._model(enc_input, 45 | attention_mask=enc_input.ne(self._special_tokens['pad']), 46 | decoder_input_ids=dec_input, 47 | use_cache=self._mode == 'infer') 48 | output = output[0] 49 | return output 50 | -------------------------------------------------------------------------------- /mybycha/bycha/models/huggingface/huggingface_pretrain_mbart_model.py: -------------------------------------------------------------------------------- 1 | from transformers import MBartConfig, MBartForConditionalGeneration 2 | 3 | from bycha.models import register_model 4 | from bycha.models.abstract_model import AbstractModel 5 | 6 | 7 | @register_model 8 | class HuggingfacePretrainMBartModel(AbstractModel): 9 | """ 10 | HuggingfacePretrainMBartModel is a pretrained mBART model built on 11 | huggingface pretrained mBART models (optionally loaded from pretrained_path).
12 | """ 13 | 14 | def __init__(self, path=None, pretrained_path=None): 15 | super().__init__(path) 16 | 17 | self._config = None 18 | self._model = None 19 | self._special_tokens = None 20 | self._pretrained_path = pretrained_path 21 | 22 | def _build(self, vocab_size, special_tokens): 23 | """ 24 | Build model with vocabulary size and special tokens 25 | 26 | Args: 27 | vocab_size: vocabulary size of input sequence 28 | special_tokens: special tokens of input sequence 29 | """ 30 | self._config = MBartConfig(vocab_size=vocab_size, pad_token_id=special_tokens['pad']) 31 | if self._pretrained_path: 32 | self._model = MBartForConditionalGeneration(self._config).from_pretrained(self._pretrained_path) 33 | else: 34 | self._model = MBartForConditionalGeneration(self._config) 35 | self._special_tokens = special_tokens 36 | 37 | def forward(self, src, tgt): 38 | """ 39 | Compute output with neural input 40 | 41 | Args: 42 | src: encoder input sequence 43 | tgt: decoder input sequence 44 | 45 | Returns: 46 | - log probability of next tokens in sequences 47 | """ 48 | output = self._model(src, 49 | attention_mask=src.ne(self._special_tokens['pad']), 50 | decoder_input_ids=tgt, 51 | use_cache=self._mode == 'infer') 52 | output = output[0] 53 | return output 54 | 55 | def generate(self, src, tgt_langtok_id, max_length, beam): 56 | return self._model.generate(input_ids=src, decoder_start_token_id=tgt_langtok_id, max_length=max_length, num_beams=beam) 57 | -------------------------------------------------------------------------------- /mybycha/bycha/models/huggingface/huggingface_sequence_classification_model.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModelForSequenceClassification 2 | 3 | from bycha.models import register_model 4 | from bycha.models.abstract_model import AbstractModel 5 | 6 | 7 | @register_model 8 | class HuggingfaceSequenceClassificationModel(AbstractModel): 9 | """ 10 | 
HuggingfaceSequenceClassificationModel is a sequence classification architecture built on 11 | huggingface sequence classification models. 12 | 13 | Args: 14 | pretrained_model: pretrained_model in huggingface 15 | num_labels: number of labels 16 | """ 17 | 18 | def __init__(self, pretrained_model, num_labels=2): 19 | super().__init__() 20 | self._pretrained_model = pretrained_model 21 | self._num_labels = num_labels 22 | 23 | self._config = None 24 | self._model = None 25 | self._special_tokens = None 26 | 27 | def _build(self, vocab_size, special_tokens): 28 | """ 29 | Build model with vocabulary size and special tokens 30 | 31 | Args: 32 | vocab_size: vocabulary size of input sequence 33 | special_tokens: special tokens of input sequence 34 | """ 35 | self._config = AutoConfig.from_pretrained( 36 | self._pretrained_model, 37 | num_labels=self._num_labels 38 | ) 39 | self._model = AutoModelForSequenceClassification.from_pretrained( 40 | self._pretrained_model, 41 | config=self._config, 42 | ) 43 | self._special_tokens = special_tokens 44 | 45 | def forward(self, input): 46 | """ 47 | Compute output with neural input 48 | 49 | Args: 50 | input: input source sequences 51 | 52 | Returns: 53 | - log probability of labels 54 | """ 55 | output = self._model(input, attention_mask=input.ne(self._special_tokens['pad'])) 56 | output = output.logits if self._num_labels > 1 else output.logits.squeeze(dim=-1) 57 | return output 58 | -------------------------------------------------------------------------------- /mybycha/bycha/models/seq2seq.py: -------------------------------------------------------------------------------- 1 | from bycha.models import register_model 2 | from bycha.models.encoder_decoder_model import EncoderDecoderModel 3 | 4 | 5 | @register_model 6 | class Seq2Seq(EncoderDecoderModel): 7 | """ 8 | EncoderDecoderModel defines overall encoder-decoder architecture. 
9 | 10 | Args: 11 | encoder: encoder configurations to build an encoder 12 | decoder: decoder configurations to build an decoder 13 | d_model: feature embedding 14 | share_embedding: how the embedding is share [all, decoder-input-output, None]. 15 | `all` indicates that source embedding, target embedding and target 16 | output projection are the same. 17 | `decoder-input-output` indicates that only target embedding and target 18 | output projection are the same. 19 | `None` indicates that none of them are the same. 20 | path: path to restore model 21 | """ 22 | 23 | def __init__(self, 24 | encoder, 25 | decoder, 26 | d_model, 27 | share_embedding=None, 28 | path=None): 29 | super().__init__(encoder=encoder, 30 | decoder=decoder, 31 | d_model=d_model, 32 | share_embedding=share_embedding, 33 | path=path) 34 | 35 | def forward(self, src, tgt): 36 | """ 37 | Compute output with neural input 38 | 39 | Args: 40 | src: source sequence 41 | tgt: previous tokens at target side, which is a time-shifted target sequence in training 42 | 43 | Returns: 44 | - log probability of next token at target side 45 | """ 46 | memory, memory_padding_mask = self._encoder(src=src) 47 | logits = self._decoder(tgt=tgt, 48 | memory=memory, 49 | memory_padding_mask=memory_padding_mask) 50 | return logits 51 | -------------------------------------------------------------------------------- /mybycha/bycha/models/sequence_classification_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from bycha.models import AbstractModel, register_model 4 | from bycha.modules.encoders import create_encoder 5 | from bycha.modules.layers.embedding import Embedding 6 | from bycha.modules.layers.classifier import HuggingfaceClassifier 7 | 8 | 9 | @register_model 10 | class SequenceClassificationModel(AbstractModel): 11 | """ 12 | SequenceClassificationModel is a general sequence classification architecture consisting of 13 | one encoder and one 
classifier. 14 | 15 | Args: 16 | encoder: encoder configuration 17 | labels: number of labels 18 | dropout: dropout 19 | source_num: the number of input source sequence 20 | path: path to restore model 21 | """ 22 | 23 | def __init__(self, 24 | encoder, 25 | labels, 26 | dropout=0., 27 | source_num=1, 28 | path=None): 29 | super().__init__(path) 30 | self._encoder_config = encoder 31 | self._labels = labels 32 | self._source_num = source_num 33 | self._dropout = dropout 34 | 35 | self._encoder, self._classifier = None, None 36 | self._path = path 37 | 38 | def _build(self, vocab_size, special_tokens): 39 | """ 40 | Build model with vocabulary size and special tokens 41 | 42 | Args: 43 | vocab_size: vocabulary size of input sequence 44 | special_tokens: special tokens of input sequence 45 | """ 46 | self._build_encoder(vocab_size=vocab_size, special_tokens=special_tokens) 47 | self._build_classifier() 48 | 49 | def _build_encoder(self, vocab_size, special_tokens): 50 | """ 51 | Build encoder with vocabulary size and special tokens 52 | 53 | Args: 54 | vocab_size: vocabulary size of input sequence 55 | special_tokens: special tokens of input sequence 56 | """ 57 | self._encoder = create_encoder(self._encoder_config) 58 | embed = Embedding(vocab_size=vocab_size, 59 | d_model=self.encoder.d_model, 60 | padding_idx=special_tokens['pad']) 61 | self._encoder.build(embed=embed, special_tokens=special_tokens) 62 | 63 | def _build_classifier(self): 64 | """ 65 | Build classifer on label space 66 | """ 67 | self._classifier = HuggingfaceClassifier(self.encoder.out_dim * self._source_num, self._labels, dropout=self._dropout) 68 | 69 | @property 70 | def encoder(self): 71 | return self._encoder 72 | 73 | @property 74 | def classifier(self): 75 | return self._classifier 76 | 77 | def forward(self, *inputs): 78 | """ 79 | Compute output with neural input 80 | 81 | Args: 82 | *inputs: input source sequences 83 | 84 | Returns: 85 | - log probability of labels 86 | """ 87 | x = 
[self.encoder(t)[-1] for t in inputs] 88 | x = torch.cat(x, dim=-1) 89 | logits = self.classifier(x) 90 | return logits 91 | 92 | -------------------------------------------------------------------------------- /mybycha/bycha/models/variational_auto_encoder.py: -------------------------------------------------------------------------------- 1 | from bycha.models import register_model 2 | from bycha.models.seq2seq import Seq2Seq 3 | 4 | 5 | @register_model 6 | class VariationalAutoEncoders(Seq2Seq): 7 | """ 8 | VariationalAutoEncoders is an extension to Seq2Seq model with latent space . 9 | 10 | Args: 11 | encoder: encoder configurations to build an encoder 12 | decoder: decoder configurations to build an decoder 13 | d_model: feature embedding 14 | share_embedding: how the embedding is share [all, decoder-input-output, None]. 15 | `all` indicates that source embedding, target embedding and target 16 | output projection are the same. 17 | `decoder-input-output` indicates that only target embedding and target 18 | output projection are the same. 19 | `None` indicates that none of them are the same. 
20 | path: path to restore model 21 | """ 22 | 23 | def __init__(self, 24 | encoder, 25 | decoder, 26 | d_model, 27 | share_embedding=None, 28 | path=None, 29 | ): 30 | super().__init__(encoder=encoder, 31 | decoder=decoder, 32 | d_model=d_model, 33 | share_embedding=share_embedding, 34 | path=path) 35 | 36 | def reg_loss(self): 37 | """ 38 | Auto-Encoding regularization loss 39 | 40 | Returns: 41 | - KL loss between prior and posterior 42 | """ 43 | return self.encoder.reg_loss() 44 | 45 | def nll(self, rec_loss, reg_losses, method="elbo"): 46 | """ 47 | NLL loss 48 | 49 | Args: 50 | rec_loss: reconstruction loss 51 | reg_losses: regularization loss 52 | method: generation method 53 | 54 | Returns: 55 | - NLL loss 56 | """ 57 | return self.encoder.nll(rec_loss, reg_losses, method) 58 | 59 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/mybycha/bycha/modules/__init__.py -------------------------------------------------------------------------------- /mybycha/bycha/modules/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.registry import setup_registry 5 | 6 | from .abstract_decoder import AbstractDecoder 7 | 8 | register_decoder, create_decoder, registry = setup_registry('decoder', AbstractDecoder) 9 | 10 | modules_dir = os.path.dirname(__file__) 11 | for file in os.listdir(modules_dir): 12 | path = os.path.join(modules_dir, file) 13 | if ( 14 | not file.startswith('_') 15 | and not file.startswith('.') 16 | and (file.endswith('.py') or os.path.isdir(path)) 17 | ): 18 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 19 | module = importlib.import_module('bycha.modules.decoders.' 
+ module_name) 20 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/decoders/abstract_decoder.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | 3 | 4 | class AbstractDecoder(Module): 5 | """ 6 | AbstractDecoder is the abstract for decoders, and defines general interface for decoders. 7 | 8 | Args: 9 | name: decoder name 10 | """ 11 | 12 | def __init__(self, name=None): 13 | super().__init__() 14 | self._name = name 15 | self._cache = {} 16 | self._mode = 'train' 17 | 18 | def build(self, *args, **kwargs): 19 | """ 20 | Build decoder with task instance 21 | """ 22 | raise NotImplementedError 23 | 24 | def forward(self, *args, **kwargs): 25 | """ 26 | Process forward of decoder. 27 | """ 28 | raise NotImplementedError 29 | 30 | def reset(self, mode): 31 | """ 32 | Reset decoder cache and switch running mode 33 | 34 | Args: 35 | mode: running mode in [train, valid, infer] 36 | """ 37 | self._cache.clear() 38 | self._mode = mode 39 | 40 | def get_cache(self): 41 | """ 42 | Retrieve inner cache 43 | 44 | Returns: 45 | - cached states as a Dict 46 | """ 47 | return self._cache 48 | 49 | def set_cache(self, cache): 50 | """ 51 | Set cache from outside 52 | 53 | Args: 54 | cache: cache dict from outside 55 | """ 56 | self._cache = cache 57 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/decoders/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .abstract_decoder_layer import AbstractDecoderLayer 2 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/decoders/layers/abstract_decoder_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class 
AbstractDecoderLayer(nn.Module): 8 | """ 9 | AbstractDecoderLayer is an abstract class for decoder layers. 10 | """ 11 | 12 | def __init__(self): 13 | super().__init__() 14 | self._cache = dict() 15 | self._mode = 'train' 16 | self._dummy_param = nn.Parameter(torch.empty(0)) 17 | 18 | def reset(self, mode: str): 19 | """ 20 | Reset encoder layer and switch running mode 21 | 22 | Args: 23 | mode: running mode in [train, valid, infer] 24 | """ 25 | self._cache: Dict[str, torch.Tensor] = {"prev": self._dummy_param} 26 | self._mode = mode 27 | 28 | def _update_cache(self, *args, **kwargs): 29 | """ 30 | Update cache with current states 31 | """ 32 | pass 33 | 34 | def get_cache(self): 35 | """ 36 | Retrieve inner cache 37 | 38 | Returns: 39 | - cached states as a Dict 40 | """ 41 | return self._cache 42 | 43 | def set_cache(self, cache: Dict[str, torch.Tensor]): 44 | """ 45 | Set cache from outside 46 | 47 | Args: 48 | cache: cache dict from outside 49 | """ 50 | self._cache = cache 51 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.registry import setup_registry 5 | 6 | from .abstract_encoder import AbstractEncoder 7 | 8 | register_encoder, create_encoder, registry = setup_registry('encoder', AbstractEncoder) 9 | 10 | 11 | modules_dir = os.path.dirname(__file__) 12 | for file in os.listdir(modules_dir): 13 | path = os.path.join(modules_dir, file) 14 | if ( 15 | not file.startswith('_') 16 | and not file.startswith('.') 17 | and (file.endswith('.py') or os.path.isdir(path)) 18 | ): 19 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 20 | module = importlib.import_module('bycha.modules.encoders.' 
+ module_name) 21 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/encoders/abstract_encoder.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | 3 | 4 | class AbstractEncoder(Module): 5 | """ 6 | AbstractEncoder is the abstract for encoders, and defines general interface for encoders. 7 | 8 | Args: 9 | name: encoder name 10 | """ 11 | 12 | def __init__(self, name=None): 13 | super().__init__() 14 | self._name = name 15 | self._cache = {} 16 | self._mode = 'train' 17 | 18 | def build(self, *args, **kwargs): 19 | """ 20 | Build encoder with task instance 21 | """ 22 | raise NotImplementedError 23 | 24 | def forward(self, *args, **kwargs): 25 | """ 26 | Process forward of encoder. Outputs are cached until the encoder is reset. 27 | """ 28 | if self._mode == 'train': 29 | if 'out' not in self._cache: 30 | out = self._forward(*args, **kwargs) 31 | self._cache['out'] = out 32 | return self._cache['out'] 33 | else: 34 | return self._forward(*args, **kwargs) 35 | 36 | def _forward(self, *args, **kwargs): 37 | """ 38 | Forward function to override. Its results can be auto cached in forward. 
39 | """ 40 | raise NotImplementedError 41 | 42 | @property 43 | def name(self): 44 | return self._name 45 | 46 | @property 47 | def d_model(self): 48 | raise NotImplementedError 49 | 50 | @property 51 | def out_dim(self): 52 | raise NotImplementedError 53 | 54 | def _cache_states(self, name, state): 55 | """ 56 | Cache a state into encoder cache 57 | 58 | Args: 59 | name: state key 60 | state: state value 61 | """ 62 | self._cache[name] = state 63 | 64 | def reset(self, mode): 65 | """ 66 | Reset encoder and switch running mode 67 | 68 | Args: 69 | mode: running mode in [train, valid, infer] 70 | """ 71 | self._cache.clear() 72 | self._mode = mode 73 | 74 | def set_cache(self, cache): 75 | """ 76 | Set cache from outside 77 | 78 | Args: 79 | cache: cache dict from outside 80 | """ 81 | self._cache = cache 82 | 83 | def get_cache(self): 84 | """ 85 | Retrieve inner cache 86 | 87 | Returns: 88 | - cached states as a Dict 89 | """ 90 | return self._cache 91 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/encoders/huggingface_encoder.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from transformers import AutoConfig, AutoModel 3 | 4 | from bycha.modules.encoders import AbstractEncoder, register_encoder 5 | 6 | 7 | @register_encoder 8 | class HuggingFaceEncoder(AbstractEncoder): 9 | """ 10 | HuggingFaceEncoder is a wrapped encoder from huggingface pretrained models 11 | 12 | Args: 13 | pretrained_model: name of pretrained model, see huggingface for supported models 14 | freeze: freeze pretrained model in training 15 | return_seed: return with sequence representation 16 | name: encoder name 17 | """ 18 | 19 | def __init__(self, 20 | pretrained_model, 21 | freeze=False, 22 | return_seed=False, 23 | name=None): 24 | super().__init__(name=name) 25 | self._pretrained_model_name = pretrained_model 26 | self._freeze = freeze 27 | self._return_seed = 
return_seed 28 | 29 | self._special_tokens = None 30 | self._configs = None 31 | self._huggingface_model = None 32 | 33 | def build(self, special_tokens=None, vocab_size=None): 34 | """ 35 | Build computational graph 36 | 37 | Args: 38 | special_tokens: special_tokens: special tokens defined in vocabulary 39 | vocab_size: vocabulary size of embedding 40 | """ 41 | self._special_tokens = special_tokens 42 | self._configs = AutoConfig.from_pretrained(self._pretrained_model_name) 43 | self._huggingface_model = AutoModel.from_config(self._configs) 44 | if self._freeze: 45 | self.freeze_params() 46 | 47 | assert self._configs.vocab_size == vocab_size 48 | 49 | def freeze_params(self): 50 | """ 51 | Freeze parameters of pretrained model 52 | """ 53 | for param in self._huggingface_model.base_model.parameters(): 54 | param.requires_grad = False 55 | 56 | def _forward(self, text: Tensor): 57 | r""" 58 | Args: 59 | text: tokens in src side. 60 | :math:`(N, S)` where N is the batch size, S is the source sequence length. 61 | 62 | Returns: 63 | - source token hidden representation. 64 | :math:`(S, N, E)` where S is the source sequence length, N is the batch size, 65 | E is the embedding size. 
66 | """ 67 | padding_mask = text.eq(self._padding_idx) 68 | model_out = self._huggingface_model(text, ~padding_mask) 69 | try: 70 | x, seed = model_out['last_hidden_state'], model_out['pooler_output'] 71 | except: 72 | x, seed = model_out 73 | finally: 74 | x = x.transpose(0, 1) 75 | if self._return_seed: 76 | return x, padding_mask, seed 77 | else: 78 | return x, padding_mask 79 | 80 | @property 81 | def out_dim(self): 82 | return self._configs.hidden_size 83 | 84 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/encoders/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .abstract_encoder_layer import AbstractEncoderLayer 2 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/encoders/layers/abstract_encoder_layer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class AbstractEncoderLayer(nn.Module): 5 | """ 6 | AbstractEncoderLayer is an abstract class for encoder layers. 
7 | """ 8 | 9 | def __init__(self): 10 | super().__init__() 11 | self._cache = {} 12 | self._mode = 'train' 13 | 14 | def reset(self, mode): 15 | """ 16 | Reset encoder layer and switch running mode 17 | 18 | Args: 19 | mode: running mode in [train, valid, infer] 20 | """ 21 | self._mode = mode 22 | self._cache.clear() 23 | 24 | def _update_cache(self, *args, **kwargs): 25 | """ 26 | Update internal cache from outside states 27 | """ 28 | pass 29 | 30 | def get_cache(self): 31 | """ 32 | Retrieve inner cache 33 | 34 | Returns: 35 | - cached states as a Dict 36 | """ 37 | return self._cache 38 | 39 | def set_cache(self, cache): 40 | """ 41 | Set cache from outside 42 | 43 | Args: 44 | cache: cache dict from outside 45 | """ 46 | self._cache = cache 47 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/encoders/layers/moe_encoder_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | from torch import Tensor 4 | from torch import nn 5 | 6 | from bycha.modules.encoders.layers import AbstractEncoderLayer 7 | from bycha.modules.layers.moe import MoE 8 | 9 | 10 | class MoEEncoderLayer(AbstractEncoderLayer): 11 | """ 12 | TransformerEncoderLayer performs one layer of transformer operation, namely self-attention and feed-forward network. 13 | 14 | Args: 15 | d_model: feature dimension 16 | nhead: head numbers of multihead attention 17 | dim_feedforward: dimensionality of inner vector space 18 | dropout: dropout rate 19 | activation: activation function used in feed-forward network 20 | normalize_before: use pre-norm fashion, default as post-norm. 21 | Pre-norm suit deep nets while post-norm achieve better results when nets are shallow. 
22 | """ 23 | 24 | def __init__(self, 25 | d_model, 26 | nhead, 27 | dim_feedforward=2048, 28 | dropout=0.1, 29 | attention_dropout=0, 30 | activation="relu", 31 | normalize_before=False, 32 | num_experts=1, 33 | sparse=True,): 34 | super(MoEEncoderLayer, self).__init__() 35 | self.normalize_before = normalize_before 36 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=attention_dropout) 37 | 38 | self.moe = MoE(d_model, dim_feedforward=dim_feedforward, activation=activation, 39 | num_experts=num_experts, sparse=sparse) 40 | 41 | self.self_attn_norm = nn.LayerNorm(d_model) 42 | self.ffn_norm = nn.LayerNorm(d_model) 43 | self.dropout1 = nn.Dropout(dropout) 44 | self.dropout2 = nn.Dropout(dropout) 45 | 46 | def forward(self, 47 | src: Tensor, 48 | src_mask: Optional[Tensor] = None, 49 | src_key_padding_mask: Optional[Tensor] = None): 50 | r"""Pass the input through the encoder layer. 51 | 52 | Args: 53 | src: the sequence to the encoder layer (required). 54 | :math:`(S, B, D)`, where S is sequence length, B is batch size and D is feature dimension 55 | src_mask: the attention mask for the src sequence (optional). 56 | :math:`(S, S)`, where S is sequence length. 57 | src_key_padding_mask: the mask for the src keys per batch (optional). 
58 | :math: `(B, S)`, where B is batch size and S is sequence length 59 | """ 60 | residual = src 61 | if self.normalize_before: 62 | src = self.self_attn_norm(src) 63 | src = self.self_attn(src, src, src, attn_mask=src_mask, 64 | key_padding_mask=src_key_padding_mask)[0] 65 | src = self.dropout1(src) 66 | src = residual + src 67 | if not self.normalize_before: 68 | src = self.self_attn_norm(src) 69 | 70 | residual = src 71 | if self.normalize_before: 72 | src = self.ffn_norm(src) 73 | src, loss = self.moe(src) 74 | src = self.dropout2(src) 75 | src = residual + src 76 | if not self.normalize_before: 77 | src = self.ffn_norm(src) 78 | return src, loss 79 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/encoders/layers/transformer_encoder_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torch import Tensor 4 | from torch import nn 5 | 6 | from bycha.modules.encoders.layers import AbstractEncoderLayer 7 | from bycha.modules.layers.feed_forward import FFN 8 | 9 | 10 | class TransformerEncoderLayer(AbstractEncoderLayer): 11 | """ 12 | TransformerEncoderLayer performs one layer of transformer operation, namely self-attention and feed-forward network. 13 | 14 | Args: 15 | d_model: feature dimension 16 | nhead: head numbers of multihead attention 17 | dim_feedforward: dimensionality of inner vector space 18 | dropout: dropout rate 19 | activation: activation function used in feed-forward network 20 | normalize_before: use pre-norm fashion, default as post-norm. 21 | Pre-norm suit deep nets while post-norm achieve better results when nets are shallow. 
22 | """ 23 | 24 | def __init__(self, 25 | d_model, 26 | nhead, 27 | dim_feedforward=2048, 28 | dropout=0.1, 29 | attention_dropout=0, 30 | activation="relu", 31 | normalize_before=False,): 32 | super(TransformerEncoderLayer, self).__init__() 33 | self.normalize_before = normalize_before 34 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=attention_dropout) 35 | # Implementation of Feedforward model 36 | self.ffn = FFN(d_model, dim_feedforward=dim_feedforward, activation=activation) 37 | 38 | self.self_attn_norm = nn.LayerNorm(d_model) 39 | self.ffn_norm = nn.LayerNorm(d_model) 40 | self.dropout1 = nn.Dropout(dropout) 41 | self.dropout2 = nn.Dropout(dropout) 42 | 43 | def forward(self, 44 | src: Tensor, 45 | src_mask: Optional[Tensor] = None, 46 | src_key_padding_mask: Optional[Tensor] = None) -> Tensor: 47 | r"""Pass the input through the encoder layer. 48 | 49 | Args: 50 | src: the sequence to the encoder layer (required). 51 | :math:`(S, B, D)`, where S is sequence length, B is batch size and D is feature dimension 52 | src_mask: the attention mask for the src sequence (optional). 53 | :math:`(S, S)`, where S is sequence length. 54 | src_key_padding_mask: the mask for the src keys per batch (optional). 
55 | :math: `(B, S)`, where B is batch size and S is sequence length 56 | """ 57 | residual = src 58 | if self.normalize_before: 59 | src = self.self_attn_norm(src) 60 | src = self.self_attn(src, src, src, attn_mask=src_mask, 61 | key_padding_mask=src_key_padding_mask)[0] 62 | src = self.dropout1(src) 63 | src = residual + src 64 | if not self.normalize_before: 65 | src = self.self_attn_norm(src) 66 | 67 | residual = src 68 | if self.normalize_before: 69 | src = self.ffn_norm(src) 70 | src = self.ffn(src) 71 | src = self.dropout2(src) 72 | src = residual + src 73 | if not self.normalize_before: 74 | src = self.ffn_norm(src) 75 | return src 76 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/encoders/lstm_encoder.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from torch.nn import LSTM 3 | import torch.nn as nn 4 | 5 | from bycha.modules.encoders import AbstractEncoder, register_encoder 6 | 7 | 8 | @register_encoder 9 | class LSTMEncoder(AbstractEncoder): 10 | """ 11 | LSTMEncoder is a LSTM encoder. 
12 | 13 | Args: 14 | num_layers: number of encoder layers 15 | d_model: feature dimension 16 | hidden_size: hidden size within LSTM 17 | dropout: dropout rate 18 | bidirectional: encode sequence with bidirectional LSTM 19 | return_seed: return with sequence representation 20 | name: module name 21 | """ 22 | 23 | def __init__(self, 24 | num_layers, 25 | d_model=512, 26 | hidden_size=1024, 27 | dropout=0.1, 28 | bidirectional=True, 29 | return_seed=None, 30 | name=None): 31 | super().__init__() 32 | self._num_layers = num_layers 33 | self._d_model = d_model 34 | self._hidden_size = hidden_size 35 | self._dropout = dropout 36 | self._bidirectional = bidirectional 37 | self._return_seed = return_seed 38 | self._name = name 39 | 40 | self._special_tokens = None 41 | self._embed, self._embed_dropout = None, None 42 | self._layer = None 43 | self._pool_seed = None 44 | 45 | def build(self, embed, special_tokens): 46 | """ 47 | Build computational modules. 48 | 49 | Args: 50 | embed: token embedding 51 | special_tokens: special tokens defined in vocabulary 52 | """ 53 | self._embed = embed 54 | self._special_tokens = special_tokens 55 | self._embed_dropout = nn.Dropout(self._dropout) 56 | self._layer = LSTM(input_size=self._d_model, 57 | hidden_size=self._hidden_size, 58 | num_layers=self._num_layers, 59 | dropout=self._dropout, 60 | bidirectional=self._bidirectional) 61 | 62 | def _forward(self, 63 | src: Tensor): 64 | r""" 65 | Args: 66 | src: tokens in src side. 67 | :math:`(N, S)` where N is the batch size, S is the source sequence length. 68 | 69 | Returns: 70 | - source token hidden representation. 71 | :math:`(S, N, E)` where S is the source sequence length, N is the batch size, 72 | E is the embedding size. 
73 | """ 74 | x = self._embed(src) 75 | x = self._embed_dropout(x) 76 | 77 | src_padding_mask = src.eq(self._special_tokens['pad']) 78 | x = x.transpose(0, 1) 79 | x = self._layer(x)[0] 80 | 81 | if self._pool_seed: 82 | return x, src_padding_mask, x.mean(dim=0) 83 | else: 84 | return x, src_padding_mask 85 | 86 | @property 87 | def d_model(self): 88 | return self._d_model 89 | 90 | @property 91 | def out_dim(self): 92 | return self._hidden_size * 2 if self._bidirectional else self._hidden_size 93 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/mybycha/bycha/modules/layers/__init__.py -------------------------------------------------------------------------------- /mybycha/bycha/modules/layers/autopruning_ffn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from bycha.modules.utils import get_activation_fn 6 | from bycha.modules.layers.gumbel import gumbel_softmax_topk 7 | 8 | class AutoPruningFFN(nn.Module): 9 | """ 10 | Feed-forward neural network 11 | 12 | Args: 13 | d_model: input feature dimension 14 | dim_feedforward: dimensionality of inner vector space 15 | dim_out: output feature dimensionality 16 | activation: activation function 17 | bias: requires bias in output linear function 18 | """ 19 | 20 | def __init__(self, 21 | d_model, 22 | dim_feedforward=None, 23 | dim_out=None, 24 | activation="relu", 25 | bias=True): 26 | super().__init__() 27 | self._dim_feedforward = dim_feedforward or d_model 28 | self._dim_out = dim_out or d_model 29 | self._bias = bias 30 | 31 | self._fc1 = nn.Linear(d_model, self._dim_feedforward) 32 | self._fc2 = nn.Linear(self._dim_feedforward, self._dim_out, 
bias=self._bias) 33 | self._activation = get_activation_fn(activation) 34 | 35 | def forward(self, x, weights=None, sorted_indeces=None, tau=0.): 36 | """ 37 | Args: 38 | x: feature to perform feed-forward net 39 | :math:`(*, D)`, where D is feature dimension 40 | 41 | Returns: 42 | - feed forward output 43 | :math:`(*, D)`, where D is feature dimension 44 | """ 45 | if weights is None or sorted_indeces is None or not self.training: 46 | x = self._fc1(x) 47 | x = self._activation(x) 48 | x = self._fc2(x) 49 | return x 50 | 51 | gumbel_onehot, _ = gumbel_softmax_topk(weights, tau=tau) 52 | one_index = gumbel_onehot.max(-1, keepdim=True)[1].item() 53 | prune_ratio = gumbel_onehot[one_index] * one_index * 0.1 54 | prune_num = torch.floor(prune_ratio * self._dim_feedforward).long() 55 | 56 | x = F.linear(x, self._fc1.weight[sorted_indeces[prune_num:]], 57 | self._fc1.bias[sorted_indeces[prune_num:]]) if self._bias else \ 58 | F.linear(x, self._fc1.weight[sorted_indeces[prune_num:]]) 59 | x *= gumbel_onehot[one_index] 60 | x = self._activation(x) 61 | x = F.linear(x, self._fc2.weight[:, sorted_indeces[prune_num:]], self._fc2.bias) 62 | return x 63 | 64 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/layers/bert_layer_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BertLayerNorm(nn.Module): 6 | """ 7 | BertLayerNorm is layer norm used in BERT. 8 | It is a layernorm module in the TF style (epsilon inside the square root). 
9 | 10 | Args: 11 | hidden_size: dimensionality of hidden space 12 | """ 13 | 14 | def __init__(self, hidden_size, eps=1e-12): 15 | 16 | super(BertLayerNorm, self).__init__() 17 | self.gamma = nn.Parameter(torch.ones(hidden_size)) 18 | self.beta = nn.Parameter(torch.zeros(hidden_size)) 19 | self.variance_epsilon = eps 20 | 21 | def forward(self, x): 22 | r""" 23 | 24 | 25 | Args: 26 | x: feature to perform layer norm 27 | :math:`(*, D)`, where D is the feature dimension 28 | 29 | Returns: 30 | - normalized feature 31 | :math:`(*, D)`, where D is the feature dimension 32 | """ 33 | u = x.mean(-1, keepdim=True) 34 | s = (x - u).pow(2).mean(-1, keepdim=True) 35 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 36 | return self.gamma * x + self.beta 37 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/layers/dlcl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | 5 | 6 | class DynamicLinearCombinationLayer(nn.Module): 7 | """ 8 | DLCL: make the input to be a linear combination of previous outputs 9 | For pre-norm, x_{l+1} = \sum_{k=0}^{l} W_{k} * LN(y_{k}) 10 | For post-norm, x_{l+1} = LN(\sum_{k=0}^{l} W_{k} * y_{k}) 11 | where x_{l}, y_{l} are the input and output of l-th layer 12 | 13 | For pre-norm, LN should be performed in previous layer 14 | For post-norm, LN is performed in this layer 15 | 16 | Args: 17 | idx: this is the `idx`-th layer 18 | post_ln: post layernorm 19 | """ 20 | def __init__(self, idx, post_ln=None): 21 | super(DynamicLinearCombinationLayer, self).__init__() 22 | assert (idx > 0) 23 | self.linear = nn.Linear(idx, 1, bias=False) 24 | nn.init._no_grad_fill_(self.linear.weight, 1.0 / idx) 25 | self.post_ln = post_ln 26 | 27 | def forward(self, y): 28 | """ 29 | Args: 30 | y: SequenceLength x BatchSize x Dim x idx 31 | """ 32 | x = self.linear(y) 33 | x = 
x.squeeze(dim=-1) 34 | if self.post_ln is not None: 35 | x = self.post_ln(x) 36 | return x 37 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/layers/embedding.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Embedding(nn.Embedding): 5 | """ 6 | Embedding is a wrapped class of torch.nn.Embedding with normal initialization on weight 7 | and zero initialization on pad. 8 | 9 | Args: 10 | vocab_size: vocabulary size 11 | d_model: feature dimensionality 12 | padding_idx: index of pad, which is a special token to ignore 13 | """ 14 | 15 | def __init__(self, vocab_size, d_model, padding_idx=None): 16 | super().__init__(vocab_size, d_model, padding_idx=padding_idx) 17 | nn.init.normal_(self.weight, mean=0, std=d_model ** -0.5) 18 | if padding_idx: 19 | nn.init.constant_(self.weight[padding_idx], 0) 20 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/layers/feed_forward.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from bycha.modules.utils import get_activation_fn 4 | 5 | 6 | class FFN(nn.Module): 7 | """ 8 | Feed-forward neural network 9 | 10 | Args: 11 | d_model: input feature dimension 12 | dim_feedforward: dimensionality of inner vector space 13 | dim_out: output feature dimensionality 14 | activation: activation function 15 | bias: requires bias in output linear function 16 | """ 17 | 18 | def __init__(self, 19 | d_model, 20 | dim_feedforward=None, 21 | dim_out=None, 22 | activation="relu", 23 | bias=True): 24 | super().__init__() 25 | dim_feedforward = dim_feedforward or d_model 26 | dim_out = dim_out or d_model 27 | 28 | self._fc1 = nn.Linear(d_model, dim_feedforward) 29 | self._fc2 = nn.Linear(dim_feedforward, dim_out, bias=bias) 30 | self._activation = get_activation_fn(activation) 31 | 32 | def 
forward(self, x): 33 | """ 34 | Args: 35 | x: feature to perform feed-forward net 36 | :math:`(*, D)`, where D is feature dimension 37 | 38 | Returns: 39 | - feed forward output 40 | :math:`(*, D)`, where D is feature dimension 41 | """ 42 | x = self._fc1(x) 43 | x = self._activation(x) 44 | x = self._fc2(x) 45 | return x 46 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/layers/gaussian.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Gaussian(nn.Module): 5 | """ 6 | Gaussian predict gaussian characteristics, namely mean and logvar 7 | 8 | Args: 9 | d_model: feature dimension 10 | latent_size: dimensionality of gaussian distribution 11 | """ 12 | 13 | def __init__(self, d_model, latent_size): 14 | super().__init__() 15 | self._dmodel = d_model 16 | self._latent_size = latent_size 17 | 18 | self.post_mean = nn.Linear(d_model, latent_size) 19 | self.post_logvar = nn.Linear(d_model, latent_size) 20 | 21 | def forward(self, x): 22 | """ 23 | Args: 24 | x: feature to perform gaussian 25 | :math:`(*, D)`, where D is feature dimension 26 | 27 | Returns: 28 | - gaussian mean 29 | :math:`(*, D)`, where D is feature dimension 30 | - gaussian logvar 31 | :math:`(*, D)`, where D is feature dimension 32 | """ 33 | post_mean = self.post_mean(x) 34 | post_logvar = self.post_logvar(x) 35 | return post_mean, post_logvar 36 | 37 | 38 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/layers/gumbel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def gumbel_softmax_topk(logits, k=1, tau=1, hard=True, dim=-1): 4 | while True: 5 | gumbels = -torch.empty_like(logits).exponential_().log() 6 | #print(gumbels, gumbels.device) 7 | gumbels = (logits + gumbels) / tau 8 | y_soft = gumbels.softmax(dim) 9 | if (torch.isinf(gumbels).any()) or 
(torch.isinf(y_soft).any()) or (torch.isnan(y_soft).any()): 10 | continue 11 | else: 12 | break 13 | 14 | index = None 15 | if hard: 16 | _, index = torch.topk(y_soft, k) 17 | y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0) 18 | ret = y_hard - y_soft.detach() + y_soft 19 | else: 20 | ret = y_soft 21 | return ret, index 22 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/layers/layerdrop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | 6 | class LayerDropModuleList(nn.ModuleList): 7 | """ 8 | Implementation of LayerDrop based on torch.nn.ModuleList 9 | 10 | Usage: 11 | Replace torch.nn.ModuleList with LayerDropModuleList 12 | For example: 13 | layers = nn.ModuleList([TransformerEncoderLayer 14 | for _ in range(num_layers)]) 15 | 16 | -> layers = LayerDropModuleList(p=0.2, gamma=1e-4) 17 | layers.extend([TransformerEncoderLayer 18 | for _ in range(num_layers)]) 19 | 20 | Args: 21 | p: initial drop probability 22 | gamma: attenuation speed of drop probability 23 | mode_depth: probability distribution across layers 24 | modules: an iterable of modules to add 25 | """ 26 | 27 | def __init__(self, p, gamma=0., mode_depth=None, modules=None): 28 | super().__init__(modules) 29 | self.p = p 30 | self._gamma = gamma 31 | self._mode_depth = mode_depth 32 | self._step = -1 33 | 34 | def __iter__(self): 35 | self._step += 1 36 | layer_num = len(self) 37 | dropout_probs = torch.empty(layer_num).uniform_() 38 | if self.training and self._gamma > 0: 39 | p_now = self.p - self.p * math.exp(-self._gamma * self._step) 40 | else: 41 | p_now = self.p 42 | p_now = max(0., p_now) 43 | 44 | p_layers = [p_now] * layer_num 45 | if self._mode_depth == 'transformer': 46 | p_layers = [2*min(i+1, layer_num-i)/layer_num*p_now for i in range(layer_num)] 47 | elif self._mode_depth == 'bert': 48 | p_layers = [p_now*i/layer_num 
for i in range(1, layer_num+1)] 49 | 50 | for i, m in enumerate(super().__iter__()): 51 | m.layerdrop = p_layers[i] if self.training else 0. 52 | if not self.training or (dropout_probs[i] > m.layerdrop): 53 | yield m 54 | 55 | def config_to_params(config): 56 | layerdrop, gamma, mode_depth = 0., 0., None 57 | if config is None: 58 | return layerdrop, gamma, mode_depth 59 | if 'prob' in config: 60 | layerdrop = config['prob'] 61 | if 'gamma' in config: 62 | gamma = config['gamma'] 63 | if 'mode_depth' in config: 64 | mode_depth = config['mode_depth'] 65 | return layerdrop, gamma, mode_depth 66 | 67 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/layers/learned_positional_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | 9 | class LearnedPositionalEmbedding(nn.Embedding): 10 | """ 11 | This module learns positional embeddings up to a fixed maximum size. 
12 | 13 | Args: 14 | num_embeddings: number of embeddings 15 | embedding_dim: embedding dimension 16 | """ 17 | 18 | def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = None, post_mask=False): 19 | super().__init__(num_embeddings, embedding_dim, padding_idx) 20 | nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5) 21 | # if post_mask = True, then padding_idx = id of pad token in token embedding, we first mark padding 22 | # tokens using padding_idx, then generate embedding matrix using positional embedding, finally set 23 | # marked positions with zero 24 | self._post_mask = post_mask 25 | 26 | def forward( 27 | self, 28 | input: Tensor, 29 | positions: Optional[Tensor] = None 30 | ): 31 | """ 32 | Args: 33 | input: an input LongTensor 34 | :math:`(*, L)`, where L is sequence length 35 | positions: pre-defined positions 36 | :math:`(*, L)`, where L is sequence length 37 | 38 | Returns: 39 | - positional embedding indexed from input 40 | :math:`(*, L, D)`, where L is sequence length and D is dimensionality 41 | """ 42 | if self._post_mask: 43 | mask = input.ne(self.padding_idx).long() 44 | if positions is None: 45 | positions = (torch.cumsum(mask, dim=1) - 1).long() 46 | emb = F.embedding( 47 | positions, 48 | self.weight, 49 | None, 50 | self.max_norm, 51 | self.norm_type, 52 | self.scale_grad_by_freq, 53 | self.sparse, 54 | )#[B,L,H] 55 | emb = emb * mask.unsqueeze(-1) 56 | return emb 57 | else: 58 | if positions is None: 59 | mask = torch.ones_like(input) 60 | positions = (torch.cumsum(mask, dim=1) - 1).long() 61 | return F.embedding( 62 | positions, 63 | self.weight, 64 | self.padding_idx, 65 | self.max_norm, 66 | self.norm_type, 67 | self.scale_grad_by_freq, 68 | self.sparse, 69 | ) 70 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/layers/sinusoidal_positional_embedding.py: -------------------------------------------------------------------------------- 1 | import 
math 2 | 3 | import torch 4 | import torch.onnx.operators 5 | from torch import nn 6 | 7 | 8 | class SinusoidalPositionalEmbedding(nn.Module): 9 | """This module produces sinusoidal positional embeddings of any length. 10 | 11 | Padding symbols are ignored. 12 | """ 13 | 14 | class __SinusoidalPositionalEmbedding(nn.Module): 15 | 16 | def __init__(self, embedding_dim, num_embeddings=1024): 17 | super().__init__() 18 | self._embedding_dim = embedding_dim 19 | self._num_embeddings = num_embeddings 20 | 21 | num_timescales = self._embedding_dim // 2 22 | log_timescale_increment = torch.FloatTensor([math.log(10000.) / (num_timescales - 1)]) 23 | inv_timescales = nn.Parameter((torch.arange(num_timescales) * -log_timescale_increment).exp(), requires_grad=False) 24 | self.register_buffer('_inv_timescales', inv_timescales) 25 | 26 | def forward( 27 | self, 28 | input, 29 | ): 30 | """Input is expected to be of size [bsz x seqlen].""" 31 | mask = torch.ones_like(input).type_as(self._inv_timescales) 32 | positions = torch.cumsum(mask, dim=1) - 1 33 | 34 | scaled_time = positions[:, :, None] * self._inv_timescales[None, None, :] 35 | signal = torch.cat([scaled_time.sin(), scaled_time.cos()], dim=-1) 36 | return signal.detach() 37 | 38 | __embed__ = None 39 | 40 | def __init__(self, embedding_dim, num_embeddings=1024): 41 | super().__init__() 42 | if not SinusoidalPositionalEmbedding.__embed__: 43 | SinusoidalPositionalEmbedding.__embed__ = SinusoidalPositionalEmbedding.__SinusoidalPositionalEmbedding( 44 | embedding_dim=embedding_dim, 45 | num_embeddings=num_embeddings 46 | ) 47 | self.embedding = SinusoidalPositionalEmbedding.__embed__ 48 | 49 | def forward(self, input): 50 | return self.embedding(input) 51 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/search/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from 
bycha.utils.registry import setup_registry 5 | 6 | from .abstract_search import AbstractSearch 7 | 8 | register_search, create_search, registry = setup_registry('search', AbstractSearch) 9 | 10 | modules_dir = os.path.dirname(__file__) 11 | for file in os.listdir(modules_dir): 12 | path = os.path.join(modules_dir, file) 13 | if ( 14 | not file.startswith('_') 15 | and not file.startswith('.') 16 | and (file.endswith('.py') or os.path.isdir(path)) 17 | ): 18 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 19 | module = importlib.import_module('bycha.modules.search.' + module_name) 20 | 21 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/search/abstract_search.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module 2 | 3 | 4 | class AbstractSearch(Module): 5 | """ 6 | AbstractSearch is search algorithm on original neural model to perform special inference. 7 | """ 8 | 9 | def __init__(self): 10 | super().__init__() 11 | self._mode = 'infer' 12 | 13 | def build(self, *args, **kwargs): 14 | """ 15 | Build search algorithm with task instance 16 | """ 17 | raise NotImplementedError 18 | 19 | def forward(self, *args, **kwargs): 20 | """ 21 | Process forward of search algorithm. 
22 | """ 23 | raise NotImplementedError 24 | 25 | def reset(self, mode): 26 | """ 27 | Reset encoder and switch running mode 28 | 29 | Args: 30 | mode: running mode in [train, valid, infer] 31 | """ 32 | self._mode = mode 33 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/search/greedy_search.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | 4 | from bycha.modules.search import register_search 5 | from bycha.modules.search.sequence_search import SequenceSearch 6 | from bycha.modules.utils import create_init_scores 7 | 8 | 9 | @register_search 10 | class GreedySearch(SequenceSearch): 11 | """ 12 | GreedySearch is greedy search on sequence generation. 13 | 14 | Args: 15 | maxlen_coef (a, b): maxlen computation coefficient. 16 | The max length is computed as `(S * a + b)`, where S is source sequence length. 17 | """ 18 | 19 | def __init__(self, maxlen_coef=(1.2, 10)): 20 | super().__init__() 21 | 22 | self._maxlen_a, self._maxlen_b = maxlen_coef 23 | 24 | def forward(self, 25 | prev_tokens, 26 | memory, 27 | memory_padding_mask, 28 | target_mask: Optional[torch.Tensor] = None, 29 | prev_scores: Optional[torch.Tensor] = None): 30 | """ 31 | Decoding full-step sequence with greedy search 32 | 33 | Args: 34 | prev_tokens: previous tokens or prefix of sequence 35 | memory: memory for attention. 36 | :math:`(M, N, E)`, where M is the memory sequence length, N is the batch size, 37 | memory_padding_mask: memory sequence padding mask. 38 | :math:`(N, M)` where M is the memory sequence length, N is the batch size. 
39 | target_mask: target mask indicating blacklist tokens 40 | :math:`(B, V)` where B is batch size and V is vocab size 41 | prev_scores: scores of previous tokens 42 | :math:`(B)` where B is batch size 43 | 44 | Returns: 45 | - log probability of generated sequence 46 | - generated sequence 47 | """ 48 | batch_size = prev_tokens.size(0) 49 | scores = create_init_scores(prev_tokens, memory) if prev_scores is None else prev_scores 50 | for _ in range(int(memory.size(0) * self._maxlen_a + self._maxlen_b)): 51 | logits = self._decoder(prev_tokens, memory, memory_padding_mask) 52 | logits = logits[:, -1, :] 53 | if target_mask is not None: 54 | logits = logits.masked_fill(target_mask, float('-inf')) 55 | next_word_scores, words = logits.max(dim=-1) 56 | eos_mask = words.eq(self._eos) 57 | if eos_mask.long().sum() == batch_size: 58 | break 59 | scores = scores + next_word_scores.masked_fill_(eos_mask, 0.).view(-1) 60 | prev_tokens = torch.cat([prev_tokens, words.unsqueeze(dim=-1)], dim=-1) 61 | return scores, prev_tokens 62 | -------------------------------------------------------------------------------- /mybycha/bycha/modules/search/sequence_search.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torch import Tensor 4 | 5 | from bycha.modules.search import AbstractSearch 6 | 7 | 8 | class SequenceSearch(AbstractSearch): 9 | """ 10 | SequenceSearch algorithms are used to generate a complete sequence with strategies. 11 | It usually built from a one-step neural model and fledges the model to a full-step generation. 12 | """ 13 | 14 | def __init__(self): 15 | super().__init__() 16 | 17 | self._decoder = None 18 | self._bos, self._eos, self._pad = None, None, None 19 | 20 | def build(self, decoder, bos, eos, pad, *args, **kwargs): 21 | """ 22 | Build the search algorithm with task instances. 23 | 24 | Args: 25 | decoder: decoder of neural model. 
26 | bos: begin-of-sentence index 27 | eos: end-of-sentence index 28 | pad: pad index 29 | """ 30 | self._decoder = decoder 31 | self._bos, self._eos, self._pad = bos, eos, pad 32 | 33 | def forward(self, 34 | prev_tokens: Tensor, 35 | memory: Tensor, 36 | memory_padding_mask: Tensor, 37 | target_mask: Optional[Tensor] = None, 38 | prev_scores: Optional[Tensor] = None): 39 | """ 40 | Decoding full-step sequence 41 | 42 | Args: 43 | prev_tokens: previous tokens or prefix of sequence 44 | memory: memory for attention. 45 | :math:`(M, N, E)`, where M is the memory sequence length, N is the batch size, 46 | memory_padding_mask: memory sequence padding mask. 47 | :math:`(N, M)` where M is the memory sequence length, N is the batch size. 48 | target_mask: target mask indicating blacklist tokens 49 | :math:`(B, V)` where B is batch size and V is vocab size 50 | prev_scores: scores of previous tokens 51 | :math:`(B)` where B is batch size 52 | 53 | Returns: 54 | - log probability of generated sequence 55 | - generated sequence 56 | """ 57 | raise NotImplementedError 58 | 59 | def reset(self, mode): 60 | """ 61 | Reset encoder and switch running mode 62 | 63 | Args: 64 | mode: running mode in [train, valid, infer] 65 | """ 66 | self._mode = mode 67 | self._decoder.reset(mode) 68 | 69 | -------------------------------------------------------------------------------- /mybycha/bycha/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.registry import setup_registry 5 | from bycha.utils.runtime import Environment 6 | 7 | from .abstract_sampler import AbstractSampler 8 | from .distributed_sampler import DistributedSampler 9 | 10 | register_sampler, _create_sampler, registry = setup_registry('sampler', AbstractSampler) 11 | 12 | 13 | def create_sampler(configs, is_training=False): 14 | """ 15 | Create a sampler. 
16 | Note in distributed training, sampler should be further wrapped with a DistributedSampler. 17 | 18 | Args: 19 | configs: sampler configuration 20 | is_training: whether the sampler is used for training. 21 | 22 | Returns: 23 | a data sampler 24 | """ 25 | sampler = _create_sampler(configs) 26 | env = Environment() 27 | if env.distributed_world > 1 and is_training: 28 | sampler = DistributedSampler(sampler) 29 | return sampler 30 | 31 | 32 | modules_dir = os.path.dirname(__file__) 33 | for file in os.listdir(modules_dir): 34 | path = os.path.join(modules_dir, file) 35 | if ( 36 | not file.startswith('_') 37 | and not file.startswith('.') 38 | and (file.endswith('.py') or os.path.isdir(path)) 39 | ): 40 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 41 | module = importlib.import_module('bycha.samplers.' + module_name) 42 | -------------------------------------------------------------------------------- /mybycha/bycha/samplers/batch_shuffle_sampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from bycha.samplers import register_sampler 4 | from bycha.samplers.sequential_sampler import SequentialSampler 5 | 6 | 7 | @register_sampler 8 | class BatchShuffleSampler(SequentialSampler): 9 | """ 10 | BatchShuffleSampler pre-compute all the batch sequentially, 11 | and shuffle the reading order of batches before an new round of iteration. 
12 | """ 13 | 14 | def __init__(self, **kwargs): 15 | super().__init__(**kwargs) 16 | 17 | @property 18 | def batch_sampler(self): 19 | """ 20 | Pre-calculate batches within sampler with strategy 21 | 22 | Returns: 23 | batches: a list of batches of index 24 | """ 25 | batches = super().batch_sampler 26 | random.shuffle(batches) 27 | return batches 28 | -------------------------------------------------------------------------------- /mybycha/bycha/samplers/bucket_sampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from bycha.samplers import AbstractSampler, register_sampler 4 | from bycha.utils.runtime import Environment 5 | 6 | 7 | @register_sampler 8 | class BucketSampler(AbstractSampler): 9 | """ 10 | BucketSampler put samples of similar size into a bucket to lift computational efficiency and to accelerate training. 11 | 12 | Args: 13 | noise: inject noise when create buckets for each iteration. 14 | """ 15 | 16 | def __init__(self, noise=0., **kwargs): 17 | super().__init__(**kwargs) 18 | self._noise = noise 19 | 20 | def build(self, data_source): 21 | """ 22 | Build sampler over data_source 23 | 24 | Args: 25 | data_source: a list of data 26 | """ 27 | self._data_source = data_source 28 | self._length = len(self._data_source) 29 | self.reset(0) 30 | 31 | def reset(self, epoch, *args, **kwargs): 32 | """ 33 | Resetting sampler states / shuffle reading order for next round of iteration 34 | 35 | Args: 36 | epoch: iteration epoch 37 | """ 38 | env = Environment() 39 | random.seed(env.seed + epoch) 40 | token_nums = [(i, self._inject_noise(sample['token_num'])) for i, sample in enumerate(self._data_source)] 41 | token_nums.sort(key=lambda x: x[1], reverse=True) 42 | self._permutation = [idx for idx, _ in token_nums] 43 | 44 | @property 45 | def batch_sampler(self): 46 | """ 47 | Pre-calculate batches within sampler with strategy 48 | 49 | Returns: 50 | batches: a list of batches of index 51 | """ 
52 | batches = super().batch_sampler 53 | random.shuffle(batches) 54 | return batches 55 | 56 | def _inject_noise(self, x): 57 | """ 58 | Disturb size 59 | 60 | Args: 61 | x: size 62 | 63 | Returns: 64 | disturbed size 65 | """ 66 | if self._noise > 0: 67 | variance = int(x * self._noise) 68 | r = random.randint(-variance, variance) 69 | return x + r 70 | else: 71 | return x 72 | -------------------------------------------------------------------------------- /mybycha/bycha/samplers/sequential_sampler.py: -------------------------------------------------------------------------------- 1 | from bycha.samplers import AbstractSampler, register_sampler 2 | 3 | 4 | @register_sampler 5 | class SequentialSampler(AbstractSampler): 6 | """ 7 | SequentialSampler iterates on samples sequentially. 8 | """ 9 | 10 | def __init__(self, **kwargs): 11 | super().__init__(**kwargs) 12 | 13 | def build(self, data_source): 14 | """ 15 | Build sampler over data_source 16 | 17 | Args: 18 | data_source: a list of data 19 | """ 20 | self._data_source = data_source 21 | self._permutation = [_ for _ in range(len(self._data_source))] 22 | self._length = len(self._permutation) 23 | -------------------------------------------------------------------------------- /mybycha/bycha/samplers/shuffled_sampler.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from bycha.samplers import AbstractSampler, register_sampler 4 | from bycha.utils.runtime import Environment 5 | 6 | 7 | @register_sampler 8 | class ShuffleSampler(AbstractSampler): 9 | """ 10 | ShuffleSampler shuffle the order before fetching samples. 
11 | """ 12 | 13 | def __init__(self, **kwargs): 14 | super().__init__(**kwargs) 15 | self._env = Environment() 16 | 17 | def build(self, data_source): 18 | """ 19 | Build sampler over data_source 20 | 21 | Args: 22 | data_source: a list of data 23 | """ 24 | self._data_source = data_source 25 | self._permutation = [_ for _ in range(len(self._data_source))] 26 | self._length = len(self._permutation) 27 | self.reset(0) 28 | 29 | def reset(self, epoch, *args, **kwargs): 30 | """ 31 | Resetting sampler states / shuffle reading order for next round of iteration 32 | """ 33 | random.seed(self._env.seed + epoch) 34 | random.shuffle(self._permutation) 35 | -------------------------------------------------------------------------------- /mybycha/bycha/services/__init__.py: -------------------------------------------------------------------------------- 1 | import euler 2 | 3 | from .idls.bycha_thrift import Request, Response, Service 4 | from .server import Server 5 | -------------------------------------------------------------------------------- /mybycha/bycha/services/idls/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/mybycha/bycha/services/idls/__init__.py -------------------------------------------------------------------------------- /mybycha/bycha/services/idls/bycha.thrift: -------------------------------------------------------------------------------- 1 | namespace py Thrift 2 | 3 | struct Request{ 4 | // a json-dumped string for a batch of samples 5 | 1: required string samples; 6 | } 7 | 8 | struct Response{ 9 | // a json-dumped string for a batch of results 10 | 1: string results; 11 | 2: string debug_info; 12 | // message for this request, "Success" or others 13 | 4: i32 code; 14 | } 15 | 16 | service Service{ 17 | // infer the result score of title 18 | Response serve(1:Request req) 19 | } 20 | 
-------------------------------------------------------------------------------- /mybycha/bycha/services/idls/model_infer.thrift: -------------------------------------------------------------------------------- 1 | namespace py background_model_infer 2 | 3 | service ModelInfer { 4 | binary infer(1:binary samples) 5 | } -------------------------------------------------------------------------------- /mybycha/bycha/services/model_server.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import pickle 3 | import logging 4 | logger = logging.getLogger(__name__) 5 | 6 | import torch 7 | 8 | from bycha.utils.ops import auto_map_args 9 | from bycha.utils.runtime import Environment 10 | from bycha.utils.tensor import to_device 11 | 12 | 13 | class ModelServer: 14 | """ 15 | ModelServer is a thrift server running neural model at backend. 16 | 17 | Args: 18 | generator: neural inference model 19 | """ 20 | 21 | def __init__(self, generator): 22 | self._generator = generator 23 | self._env = Environment() 24 | 25 | self._generator.eval() 26 | 27 | def infer(self, net_input): 28 | """ 29 | Inference with neural model. 
30 | 31 | Args: 32 | net_input: neural model 33 | 34 | Returns: 35 | - neural output 36 | """ 37 | try: 38 | net_input = pickle.loads(net_input) 39 | if isinstance(net_input, Dict): 40 | net_input = auto_map_args(net_input, self._generator.input_slots) 41 | net_input = to_device(net_input, self._env.device, fp16=self._env.fp16) 42 | with torch.no_grad(): 43 | self._generator.reset('infer') 44 | net_output = self._generator(*net_input) 45 | net_output = to_device(net_output, 'cpu') 46 | net_output = pickle.dumps(net_output) 47 | return net_output 48 | except Exception as e: 49 | logger.warning(str(e)) 50 | return None 51 | -------------------------------------------------------------------------------- /mybycha/bycha/services/server.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import logging 5 | logger = logging.getLogger(__name__) 6 | 7 | from thriftpy.rpc import make_client 8 | import thriftpy 9 | 10 | from bycha.services import Request, Response 11 | from bycha.tasks import create_task 12 | from bycha.utils.runtime import build_env, Environment 13 | 14 | 15 | class Server: 16 | """ 17 | Server make the task a interactive service, and can be deployed online to serve requests. 18 | 19 | Args: 20 | configs: configurations to build a task 21 | """ 22 | 23 | def __init__(self, configs): 24 | self._configs = configs 25 | 26 | if 'env' in self._configs: 27 | build_env(**self._configs['env']) 28 | task_config = self._configs.pop('task') 29 | task_config.pop('generator') 30 | task_config.pop('model') 31 | 32 | self._task = create_task(task_config) 33 | self._task.build() 34 | self._task.reset(training=False, infering=True) 35 | 36 | self._env = Environment() 37 | 38 | def serve(self, request): 39 | """ 40 | Serve a request 41 | 42 | Args: 43 | request (Request): a request for serving. 
44 | It must contain a jsonable attribute named `samples` indicating a batch of unprocessed samples. 45 | 46 | Returns: 47 | response (Response): a response to the given request. 48 | """ 49 | response = Response(results='') 50 | try: 51 | logger.info('receive request {}'.format(request)) 52 | generator = _build_backend_generator_service() 53 | samples = request.samples 54 | samples = json.loads(samples) 55 | samples = self._task.preprocess(samples) 56 | samples = pickle.dumps(samples['net_input']) 57 | results = generator.infer(samples) 58 | results = pickle.loads(results) 59 | debug_info = {'net_output': results.tolist()} 60 | results = self._task.postprocess(results) 61 | response = Response(results=json.dumps(results), 62 | debug_info=json.dumps(debug_info) if self._env.debug else None) 63 | logger.info('return response {}'.format(response)) 64 | except Exception as e: 65 | logger.warning(str(e)) 66 | finally: 67 | return response 68 | 69 | 70 | def _build_backend_generator_service(): 71 | """ 72 | Create a service to connect backend neural model 73 | 74 | Returns: 75 | - a thrift client connecting neural model 76 | """ 77 | backgronud_model_infer_thrift = thriftpy.load("/opt/tiger/ByCha/bycha/services/idls/model_infer.thrift", 78 | module_name="model_infer_thrift") 79 | grpc_port = int(os.environ.get('GRPC_PORT', 6000)) 80 | generator = make_client(backgronud_model_infer_thrift.ModelInfer, 81 | 'localhost', 82 | grpc_port, 83 | timeout=100000000) 84 | return generator 85 | 86 | -------------------------------------------------------------------------------- /mybycha/bycha/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.registry import setup_registry 5 | 6 | TRAIN, VALID, EVALUATE, SERVE = 'TRAIN', 'VALID', 'EVALUATE', 'SERVE' 7 | 8 | from .abstract_task import AbstractTask 9 | 10 | register_task, create_task, registry = setup_registry('task', 
AbstractTask) 11 | 12 | modules_dir = os.path.dirname(__file__) 13 | for file in os.listdir(modules_dir): 14 | path = os.path.join(modules_dir, file) 15 | if ( 16 | not file.startswith('_') 17 | and not file.startswith('.') 18 | and (file.endswith('.py') or os.path.isdir(path)) 19 | ): 20 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 21 | module = importlib.import_module('bycha.tasks.' + module_name) 22 | 23 | -------------------------------------------------------------------------------- /mybycha/bycha/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.registry import setup_registry 5 | 6 | from .abstract_tokenizer import AbstractTokenizer 7 | 8 | register_tokenizer, create_tokenizer, registry = setup_registry('tokenizer', AbstractTokenizer) 9 | 10 | modules_dir = os.path.dirname(__file__) 11 | for file in os.listdir(modules_dir): 12 | path = os.path.join(modules_dir, file) 13 | if ( 14 | not file.startswith('_') 15 | and not file.startswith('.') 16 | and (file.endswith('.py') or os.path.isdir(path)) 17 | ): 18 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 19 | module = importlib.import_module('bycha.tokenizers.' + module_name) 20 | 21 | -------------------------------------------------------------------------------- /mybycha/bycha/tokenizers/abstract_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class AbstractTokenizer: 5 | """ 6 | Tokenizer provides a tokenization pipeline. 7 | """ 8 | 9 | def __init__(self): 10 | pass 11 | 12 | def build(self, *args, **kwargs): 13 | """ 14 | Build tokenizer. 15 | """ 16 | pass 17 | 18 | def __len__(self): 19 | raise NotImplementedError 20 | 21 | def encode(self, *args) -> List[int]: 22 | """ 23 | Encode a textual sentence into a list of index. 
24 | """ 25 | raise NotImplementedError 26 | 27 | def decode(self, x: List[int]) -> str: 28 | """ 29 | Decode a list of index back into a textual sentence 30 | 31 | Args: 32 | x: a list of index 33 | """ 34 | raise NotImplementedError 35 | 36 | def tok(self, *args) -> str: 37 | """ 38 | Tokenize a textual sentence without index mapping. 39 | """ 40 | out = [] 41 | for ext in args: 42 | out += ext 43 | return out 44 | 45 | def detok(self, x: str) -> str: 46 | """ 47 | Detokenize a textual sentence without index mapping. 48 | 49 | Args: 50 | x: a textual sentence 51 | """ 52 | return x 53 | 54 | def token2index(self, *args) -> List[int]: 55 | """ 56 | Only map a textual sentence to index 57 | """ 58 | raise NotImplementedError 59 | 60 | def index2token(self, x: List[int]) -> str: 61 | """ 62 | Only map a list of index back into a textual sentence 63 | 64 | Args: 65 | x: a list of index 66 | """ 67 | raise NotImplementedError 68 | 69 | @staticmethod 70 | def learn(*args, **kwargs): 71 | """ 72 | Learn a tokenizer from data set. 
73 | """ 74 | raise NotImplementedError 75 | 76 | @property 77 | def special_tokens(self): 78 | return { 79 | 'bos': self.bos, 80 | 'eos': self.eos, 81 | 'pad': self.pad, 82 | 'unk': self.unk 83 | } 84 | 85 | @property 86 | def bos(self): 87 | raise NotImplementedError 88 | 89 | @property 90 | def eos(self): 91 | raise NotImplementedError 92 | 93 | @property 94 | def unk(self): 95 | raise NotImplementedError 96 | 97 | @property 98 | def pad(self): 99 | raise NotImplementedError 100 | 101 | @property 102 | def bos_token(self): 103 | raise NotImplementedError 104 | 105 | @property 106 | def eos_token(self): 107 | raise NotImplementedError 108 | 109 | @property 110 | def unk_token(self): 111 | raise NotImplementedError 112 | 113 | @property 114 | def pad_token(self): 115 | raise NotImplementedError 116 | -------------------------------------------------------------------------------- /mybycha/bycha/tokenizers/huggingface_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import logging 3 | logger = logging.getLogger(__name__) 4 | 5 | from transformers import AutoTokenizer 6 | 7 | from bycha.tokenizers import AbstractTokenizer, register_tokenizer 8 | 9 | 10 | @register_tokenizer 11 | class HuggingfaceTokenizer(AbstractTokenizer): 12 | """ 13 | HuggingfaceTokenizer use `huggingface/transformers` lib to do tokenization 14 | see huggingface/transformers(https://github.com/huggingface/transformers) 15 | 16 | Args: 17 | tokenizer_name: tokenizer names 18 | """ 19 | 20 | def __init__(self, tokenizer_name, *args, **kwargs): 21 | super().__init__() 22 | self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args, **kwargs) 23 | 24 | @staticmethod 25 | def learn(*args, **kwargs): 26 | """ 27 | HuggingfaceTokenizer are used for pretrained model, and is usually directly load from huggingface. 
28 | """ 29 | logger.info('learn vocab not supported for huggingface tokenizer') 30 | raise NotImplementedError 31 | 32 | def __len__(self): 33 | return len(self._tokenizer) 34 | 35 | def encode(self, input, *args, **kwargs) -> List[int]: 36 | """ 37 | Encode a textual sentence into a list of index. 38 | """ 39 | if len(input) == 1: 40 | input = input[0] 41 | return self._tokenizer.encode(input, *args, **kwargs) 42 | 43 | def decode(self, *args, **kwargs) -> str: 44 | """ 45 | Decode a list of index back into a textual sentence 46 | """ 47 | return self._tokenizer.decode(*args, **kwargs) 48 | 49 | def __call__(self, *args, **kwargs): 50 | return self._tokenizer(*args, **kwargs) 51 | 52 | def token2index(self, *args) -> List[int]: 53 | """ 54 | Only map a textual sentence to index 55 | """ 56 | return self.encode(*args)[1:-1] 57 | 58 | @property 59 | def max_length(self): 60 | return self._tokenizer.model_max_length 61 | 62 | @property 63 | def bos(self): 64 | return self._tokenizer.bos_token_id 65 | 66 | @property 67 | def eos(self): 68 | return self._tokenizer.eos_token_id 69 | 70 | @property 71 | def unk(self): 72 | return self._tokenizer.unk_token_id 73 | 74 | @property 75 | def pad(self): 76 | return self._tokenizer.pad_token_id 77 | 78 | @property 79 | def bos_token(self): 80 | return self._tokenizer.bos_token 81 | 82 | @property 83 | def eos_token(self): 84 | return self._tokenizer.eos_token 85 | 86 | @property 87 | def unk_token(self): 88 | return self._tokenizer.unk_token 89 | 90 | @property 91 | def pad_token(self): 92 | return self._tokenizer.pad_token 93 | -------------------------------------------------------------------------------- /mybycha/bycha/tokenizers/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | SPACE_NORMALIZER = re.compile("\s+") 4 | 5 | SPECIAL_SYMBOLS = ['', '', '', ''] 6 | 7 | -------------------------------------------------------------------------------- 
/mybycha/bycha/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.registry import setup_registry 5 | 6 | from .abstract_trainer import AbstractTrainer 7 | 8 | register_trainer, create_trainer, registry = setup_registry('trainer', AbstractTrainer) 9 | 10 | modules_dir = os.path.dirname(__file__) 11 | for file in os.listdir(modules_dir): 12 | path = os.path.join(modules_dir, file) 13 | if ( 14 | not file.startswith('_') 15 | and not file.startswith('.') 16 | and (file.endswith('.py') or os.path.isdir(path)) 17 | ): 18 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 19 | module = importlib.import_module('bycha.trainers.' + module_name) 20 | 21 | -------------------------------------------------------------------------------- /mybycha/bycha/trainers/moe_trainer.py: -------------------------------------------------------------------------------- 1 | from bycha.trainers.trainer import Trainer 2 | from bycha.trainers import register_trainer 3 | 4 | @register_trainer 5 | class MoETrainer(Trainer): 6 | 7 | def __init__(self, 8 | load_balance_alpha=0., 9 | *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | self._load_balance_alpha = load_balance_alpha 12 | 13 | def _forward_loss(self, samples): 14 | """ 15 | Forward neural model and compute the loss of given samples 16 | 17 | Args: 18 | samples: a batch of samples 19 | 20 | Returns: 21 | - derived loss as torch.Tensor 22 | - states for updating log 23 | """ 24 | loss, logging_states = self._criterion(**samples) 25 | loss, logging_states = self._load_balance_loss(loss, logging_states) 26 | return loss, logging_states 27 | 28 | def _load_balance_loss(self, loss, logging_states): 29 | moe_loss = self._model._encoder.moe_loss + self._model._decoder.moe_loss 30 | moe_loss /= 2 31 | loss += moe_loss*self._load_balance_alpha 32 | logging_states['moe_loss'] = moe_loss.data.item() 33 | 
return loss, logging_states 34 | 35 | -------------------------------------------------------------------------------- /mybycha/bycha/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/mybycha/bycha/utils/__init__.py -------------------------------------------------------------------------------- /mybycha/bycha/utils/rate_schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from bycha.utils.registry import setup_registry 5 | 6 | from .abstract_rate_scheduler import AbstractRateScheduler 7 | 8 | register_rate_scheduler, _create_rate_scheduler, registry = setup_registry('rate_scheduler', AbstractRateScheduler) 9 | 10 | 11 | def create_rate_scheduler(configs): 12 | if isinstance(configs, float): 13 | configs = {'class': 'ConstantRateScheduler', 'rate': configs} 14 | rate_schduler = _create_rate_scheduler(configs) 15 | return rate_schduler 16 | 17 | 18 | modules_dir = os.path.dirname(__file__) 19 | for file in os.listdir(modules_dir): 20 | path = os.path.join(modules_dir, file) 21 | if ( 22 | not file.startswith('_') 23 | and not file.startswith('.') 24 | and (file.endswith('.py') or os.path.isdir(path)) 25 | ): 26 | module_name = file[:file.find('.py')] if file.endswith('.py') else file 27 | module = importlib.import_module('bycha.utils.rate_schedulers.' + module_name) 28 | 29 | 30 | -------------------------------------------------------------------------------- /mybycha/bycha/utils/rate_schedulers/abstract_rate_scheduler.py: -------------------------------------------------------------------------------- 1 | class AbstractRateScheduler: 2 | """ 3 | AbstractRateScheduler is an auxiliary tools for adjust rate. 
4 | 5 | Args: 6 | rate: initial rate 7 | """ 8 | 9 | def __init__(self, rate: float = 0., *args, **kwargs): 10 | self._rate: float = rate 11 | 12 | def build(self, *args, **kwargs): 13 | """ 14 | Build rate scheduler 15 | """ 16 | pass 17 | 18 | def step_update(self, step, *args, **kwargs): 19 | """ 20 | Update inner rate with outside states at each step 21 | 22 | Args: 23 | step: training step 24 | """ 25 | pass 26 | 27 | def step_reset(self, step, *args, **kwargs): 28 | """ 29 | Reset inner rate with outside states at each step 30 | 31 | Args: 32 | step: training step 33 | """ 34 | pass 35 | 36 | def epoch_update(self, epoch, *args, **kwargs): 37 | """ 38 | Update inner rate with outside states at each epoch 39 | 40 | Args: 41 | epoch: training epoch 42 | """ 43 | pass 44 | 45 | def epoch_reset(self, epoch, *args, **kwargs): 46 | """ 47 | Update inner rate with outside states at each epoch 48 | 49 | Args: 50 | epoch: training epoch 51 | """ 52 | pass 53 | 54 | @property 55 | def rate(self): 56 | return self._rate 57 | -------------------------------------------------------------------------------- /mybycha/bycha/utils/rate_schedulers/constant_rate_scheduler.py: -------------------------------------------------------------------------------- 1 | from bycha.utils.rate_schedulers import AbstractRateScheduler, register_rate_scheduler 2 | 3 | 4 | @register_rate_scheduler 5 | class ConstantRateScheduler(AbstractRateScheduler): 6 | """ 7 | ConstantRateScheduler do no schedule rate. 
8 | 9 | Args: 10 | rate: constant rate 11 | """ 12 | 13 | def __init__(self, rate): 14 | super().__init__(rate) 15 | -------------------------------------------------------------------------------- /mybycha/bycha/utils/rate_schedulers/inverse_square_root_rate_scheduler.py: -------------------------------------------------------------------------------- 1 | from bycha.utils.rate_schedulers import AbstractRateScheduler, register_rate_scheduler 2 | 3 | 4 | @register_rate_scheduler 5 | class InverseSquareRootRateScheduler(AbstractRateScheduler): 6 | """ 7 | InverseSquareRootRateScheduler first linearly warm up rate and decay the rate in square root. 8 | 9 | Args: 10 | rate: maximum rate 11 | warmup_steps: number of updates in warming up 12 | """ 13 | 14 | def __init__(self, rate, warmup_steps=1000): 15 | super().__init__(rate) 16 | self._warmup_steps = warmup_steps 17 | 18 | self._lr_step, self._decay_factor = None, None 19 | 20 | def build(self): 21 | """ 22 | Build rate scheduler 23 | """ 24 | self._lr_step = self._rate / self._warmup_steps 25 | self._decay_factor = self._rate * self._warmup_steps ** 0.5 26 | self._rate = 0. 27 | 28 | def step_update(self, step, *args, **kwargs): 29 | """ 30 | Update inner rate with outside states at each step 31 | 32 | Args: 33 | step: training step 34 | """ 35 | if step < self._warmup_steps: 36 | self._rate = step * self._lr_step 37 | else: 38 | self._rate = self._decay_factor * step ** -0.5 39 | 40 | -------------------------------------------------------------------------------- /mybycha/bycha/utils/rate_schedulers/logistic_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from bycha.utils.rate_schedulers import AbstractRateScheduler, register_rate_scheduler 4 | 5 | 6 | @register_rate_scheduler 7 | class LogisticScheduler(AbstractRateScheduler): 8 | """ 9 | LogisticScheduler scheduler the rate with logistic decay. 
10 | 11 | Args: 12 | k: decaying weight 13 | x0: bias 14 | """ 15 | 16 | def __init__(self, k=0.0025, x0=4000): 17 | super().__init__(0.) 18 | self._k = k 19 | self._x0 = x0 20 | 21 | def step_update(self, step, *args, **kwargs): 22 | """ 23 | Update inner rate with outside states at each step 24 | 25 | Args: 26 | step: training step 27 | """ 28 | self._rate = 1 / (1 + math.exp(-self._k * (step - self._x0))) 29 | -------------------------------------------------------------------------------- /mybycha/bycha/utils/rate_schedulers/noam_scheduler.py: -------------------------------------------------------------------------------- 1 | from bycha.utils.rate_schedulers import AbstractRateScheduler, register_rate_scheduler 2 | 3 | 4 | @register_rate_scheduler 5 | class NoamScheduler(AbstractRateScheduler): 6 | """ 7 | NoamScheduler is a scheduling methods proposed by Noam 8 | 9 | Args: 10 | d_model: neural model feature dimension 11 | warmup_steps: training steps in warming up 12 | """ 13 | 14 | def __init__(self, d_model, warmup_steps=4000): 15 | super().__init__(0.) 16 | self._warmup_steps = warmup_steps 17 | self._d_model = d_model 18 | 19 | def step_update(self, step, *args, **kwargs): 20 | """ 21 | Update inner rate with outside states at each step 22 | 23 | Args: 24 | step: training step 25 | """ 26 | self._rate = (self._d_model ** -0.5) * min([step ** -0.5, step * (self._warmup_steps ** -1.5)]) 27 | -------------------------------------------------------------------------------- /mybycha/bycha/utils/rate_schedulers/polynomial_decay_scheduler.py: -------------------------------------------------------------------------------- 1 | from bycha.utils.rate_schedulers import AbstractRateScheduler, register_rate_scheduler 2 | 3 | 4 | @register_rate_scheduler 5 | class PolynomialDecayScheduler(AbstractRateScheduler): 6 | """ 7 | PolynomialDecaySchedulaer first linearly warm up rate, then decay the rate polynomailly and 8 | finally keep at an minimum rate. 
9 | 10 | Args: 11 | max_rate: maximum rate 12 | total_steps: total training steps 13 | warmup_steps: number of updates in warming up 14 | end_rate: minimum rate at end 15 | power: polynomial decaying power 16 | """ 17 | 18 | def __init__(self, max_rate, total_steps, warmup_steps=4000, end_rate=0.0, power=1.0,): 19 | super().__init__(0.) 20 | self._max_rate = max_rate 21 | self._total_steps = total_steps 22 | self._warmup_steps = warmup_steps 23 | self._end_rate = end_rate 24 | self._power = power 25 | 26 | def step_update(self, step, *args, **kwargs): 27 | """ 28 | Update inner rate with outside states at each step 29 | 30 | Args: 31 | step: training step 32 | """ 33 | if self._warmup_steps > 0 and step <= self._warmup_steps: 34 | warmup_factor = step / float(self._warmup_steps) 35 | self._rate = warmup_factor * self._max_rate 36 | elif step >= self._total_steps: 37 | self._rate = self._end_rate 38 | else: 39 | rate_range = self._max_rate - self._end_rate 40 | pct_remaining = 1 - (step - self._warmup_steps) / (self._total_steps - self._warmup_steps) 41 | self._rate = rate_range * pct_remaining ** self._power + self._end_rate 42 | -------------------------------------------------------------------------------- /mybycha/bycha/utils/registry.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | logger = logging.getLogger(__name__) 4 | 5 | from bycha.utils.data import possible_eval 6 | from bycha.utils.io import jsonable 7 | from bycha.utils.ops import deepcopy_on_ref 8 | 9 | MODULE_REGISTRY = {} 10 | 11 | 12 | def setup_registry(registry, base_cls, force_extend=True): 13 | """ 14 | Set up registry for a certain class 15 | 16 | Args: 17 | registry: registry name 18 | base_cls: base class of a certain class 19 | force_extend: force a new class extend the base class 20 | 21 | Returns: 22 | - decorator to register a subclass 23 | - function to create a subclass with configurations 24 | - registry 
dictionary 25 | """ 26 | 27 | if registry not in MODULE_REGISTRY: 28 | MODULE_REGISTRY[registry] = {} 29 | 30 | def register_cls(cls): 31 | """ 32 | Register a class with its name 33 | 34 | Args: 35 | cls: a new class fro registration 36 | """ 37 | name = cls.__name__.lower() 38 | if name in MODULE_REGISTRY[registry]: 39 | raise ValueError('Cannot register duplicate {} class ({})'.format(registry, name)) 40 | if force_extend and not issubclass(cls, base_cls): 41 | raise ValueError('Class {} must extend {}'.format(name, base_cls.__name__)) 42 | if name in MODULE_REGISTRY[registry]: 43 | raise ValueError('Cannot register class with duplicate class name ({})'.format(name)) 44 | MODULE_REGISTRY[registry][name] = cls 45 | return cls 46 | 47 | def create_cls(configs=None): 48 | """ 49 | Create a class with configuration 50 | 51 | Args: 52 | configs: configuration dictionary for building class 53 | 54 | Returns: 55 | - an instance of class 56 | """ 57 | configs = deepcopy_on_ref(configs) 58 | name = configs.pop('class') 59 | json_configs = {k: v for k, v in configs.items() if jsonable(k) and jsonable(v)} 60 | logger.info('Creating {} class with configs \n{}\n'.format(name, json.dumps(json_configs, indent=4, sort_keys=True))) 61 | assert name.lower() in MODULE_REGISTRY[registry], f"{name} is not implemented in ByCha" 62 | cls = MODULE_REGISTRY[registry][name.lower()] 63 | kwargs = {} 64 | for k, v in configs.items(): 65 | kwargs[k] = possible_eval(v) 66 | return cls(**kwargs) 67 | 68 | return register_cls, create_cls, MODULE_REGISTRY[registry] 69 | -------------------------------------------------------------------------------- /mybycha/bycha/utils/txc_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from bycha.utils.io import UniIO 4 | 5 | 6 | def build_txc(config): 7 | """ 8 | Build a txc structure from config 9 | 10 | Args: 11 | config: txc configuration path 12 | 13 | Returns: 14 | - process function 15 
| """ 16 | from txc.modules.structures import create_structure 17 | from txc.runtime import build_env 18 | build_env(mode='infer') 19 | if not config: 20 | return None 21 | with UniIO(config) as fin: 22 | config = json.load(fin) 23 | txc = create_structure(config) 24 | 25 | def _process(*args): 26 | for x, unit_in_arg in zip(args, txc.input_names): 27 | txc.buffer_in[unit_in_arg] = x 28 | txc.run() 29 | out = txc.fetch_buffer_out(name=txc.output_names[0]) 30 | return out 31 | 32 | return _process 33 | -------------------------------------------------------------------------------- /mybycha/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers 3 | tensorflow==2.4.0 4 | tqdm 5 | sacremoses 6 | sacrebleu 7 | sentencepiece 8 | fastBPE 9 | scipy 10 | scikit-learn 11 | mosestokenizer 12 | nltk 13 | pyyaml 14 | mpi4py 15 | numpy==1.19.2 16 | more-itertools 17 | tabulate 18 | lightseq 19 | -------------------------------------------------------------------------------- /mybycha/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="bycha", 5 | version="0.1.0", 6 | keywords=["Natural Language Processing", "Machine Learning"], 7 | license="MIT Licence", 8 | packages=find_packages(), 9 | include_package_data=True, 10 | platforms="any", 11 | install_requires=open("requirements.txt").readlines(), 12 | zip_safe=False, 13 | 14 | scripts=[], 15 | entry_points={ 16 | 'console_scripts': [ 17 | 'bycha-run = bycha.entries.run:main', 18 | 'bycha-export = bycha.entries.export:main', 19 | 'bycha-preprocess = bycha.entries.preprocess:main', 20 | 'bycha-serve = bycha.entries.serve:main', 21 | 'bycha-serve-model = bycha.entries.serve_model:main', 22 | 'bycha-build-tokenizer = bycha.entries.build_tokenizer:main', 23 | 'bycha-binarize-data = bycha.entries.binarize_data:main' 24 | ] 25 | } 26 | ) 27 | 
-------------------------------------------------------------------------------- /pics/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/pics/overview.png -------------------------------------------------------------------------------- /pics/sketch_and_generate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/longlongman/DESERT/830562e13a0089e9bb3d77956ab70e606316ae78/pics/sketch_and_generate.png -------------------------------------------------------------------------------- /preparation/fragmenizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .brics_fragmenizer import BRICS_Fragmenizer 2 | from .ring_r_fragmenizer import RING_R_Fragmenizer 3 | from .brics_ring_r_fragmenizer import BRICS_RING_R_Fragmenizer 4 | from .atom_fragmenizer import ATOM_Fragmenizer 5 | -------------------------------------------------------------------------------- /preparation/fragmenizer/atom_fragmenizer.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | 3 | class ATOM_Fragmenizer(): 4 | def __init__(self): 5 | self.type = 'Atom_Fragmenizers' 6 | 7 | def get_bonds(self, mol): 8 | bonds = mol.GetBonds() 9 | return list(bonds) 10 | 11 | def fragmenize(self, mol, dummyStart=1): 12 | bonds = self.get_bonds(mol) 13 | if len(bonds) != 0: 14 | bond_ids = [bond.GetIdx() for bond in bonds] 15 | dummyLabels = [(i + dummyStart, i + dummyStart) for i in range(len(bond_ids))] 16 | break_mol = Chem.FragmentOnBonds(mol, bond_ids, dummyLabels=dummyLabels) 17 | dummyEnd = dummyStart + len(dummyLabels) - 1 18 | else: 19 | break_mol = mol 20 | dummyEnd = dummyStart - 1 21 | return break_mol, dummyEnd 22 | 
-------------------------------------------------------------------------------- /preparation/fragmenizer/brics_fragmenizer.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | from rdkit.Chem.BRICS import FindBRICSBonds 3 | 4 | class BRICS_Fragmenizer(): 5 | def __inti__(self): 6 | self.type = 'BRICS_Fragmenizers' 7 | 8 | def get_bonds(self, mol): 9 | bonds = [bond[0] for bond in list(FindBRICSBonds(mol))] 10 | return bonds 11 | 12 | def fragmenize(self, mol, dummyStart=1): 13 | # get bonds need to be break 14 | bonds = [bond[0] for bond in list(FindBRICSBonds(mol))] 15 | 16 | # whether the molecule can really be break 17 | if len(bonds) != 0: 18 | bond_ids = [mol.GetBondBetweenAtoms(x, y).GetIdx() for x, y in bonds] 19 | 20 | # break the bonds & set the dummy labels for the bonds 21 | dummyLabels = [(i + dummyStart, i + dummyStart) for i in range(len(bond_ids))] 22 | break_mol = Chem.FragmentOnBonds(mol, bond_ids, dummyLabels=dummyLabels) 23 | dummyEnd = dummyStart + len(dummyLabels) - 1 24 | else: 25 | break_mol = mol 26 | dummyEnd = dummyStart - 1 27 | 28 | return break_mol, dummyEnd 29 | -------------------------------------------------------------------------------- /preparation/fragmenizer/brics_ring_r_fragmenizer.py: -------------------------------------------------------------------------------- 1 | from fragmenizer import BRICS_Fragmenizer, RING_R_Fragmenizer 2 | from rdkit import Chem 3 | 4 | # from brics_fragmenizer import BRICS_Fragmenizer 5 | # from ring_r_fragmenizer import RING_R_Fragmenizer 6 | 7 | class BRICS_RING_R_Fragmenizer(): 8 | def __init__(self): 9 | self.type = 'BRICS_RING_R_Fragmenizer' 10 | self.brics_fragmenizer = BRICS_Fragmenizer() 11 | self.ring_r_fragmenizer = RING_R_Fragmenizer() 12 | 13 | def fragmenize(self, mol, dummyStart=1): 14 | brics_bonds = self.brics_fragmenizer.get_bonds(mol) 15 | ring_r_bonds = self.ring_r_fragmenizer.get_bonds(mol) 16 | bonds = brics_bonds 
+ ring_r_bonds 17 | 18 | if len(bonds) != 0: 19 | bond_ids = [mol.GetBondBetweenAtoms(x, y).GetIdx() for x, y in bonds] 20 | bond_ids = list(set(bond_ids)) 21 | dummyLabels = [(i + dummyStart, i + dummyStart) for i in range(len(bond_ids))] 22 | break_mol = Chem.FragmentOnBonds(mol, bond_ids, dummyLabels=dummyLabels) 23 | dummyEnd = dummyStart + len(dummyLabels) - 1 24 | else: 25 | break_mol = mol 26 | dummyEnd = dummyStart - 1 27 | 28 | return break_mol, dummyEnd 29 | -------------------------------------------------------------------------------- /preparation/fragmenizer/ring_r_fragmenizer.py: -------------------------------------------------------------------------------- 1 | from rdkit import Chem 2 | from utils import get_rings, get_other_atom_idx, find_parts_bonds 3 | from rdkit.Chem.rdchem import BondType 4 | 5 | class RING_R_Fragmenizer(): 6 | def __init__(self): 7 | self.type = 'RING_R_Fragmenizer' 8 | 9 | def bonds_filter(self, mol, bonds): 10 | filted_bonds = [] 11 | for bond in bonds: 12 | bond_type = mol.GetBondBetweenAtoms(bond[0], bond[1]).GetBondType() 13 | if not bond_type is BondType.SINGLE: 14 | continue 15 | f_atom = mol.GetAtomWithIdx(bond[0]) 16 | s_atom = mol.GetAtomWithIdx(bond[1]) 17 | if f_atom.GetSymbol() == '*' or s_atom.GetSymbol() == '*': 18 | continue 19 | if mol.GetBondBetweenAtoms(bond[0], bond[1]).IsInRing(): 20 | continue 21 | filted_bonds.append(bond) 22 | return filted_bonds 23 | 24 | def get_bonds(self, mol): 25 | bonds = [] 26 | rings = get_rings(mol) 27 | if len(rings) > 0: 28 | for ring in rings: 29 | rest_atom_idx = get_other_atom_idx(mol, ring) 30 | bonds += find_parts_bonds(mol, [rest_atom_idx, ring]) 31 | bonds = self.bonds_filter(mol, bonds) 32 | return bonds 33 | 34 | def fragmenize(self, mol, dummyStart=1): 35 | rings = get_rings(mol) 36 | if len(rings) > 0: 37 | bonds = [] 38 | for ring in rings: 39 | rest_atom_idx = get_other_atom_idx(mol, ring) 40 | bonds += find_parts_bonds(mol, [rest_atom_idx, ring]) 41 | bonds = 
self.bonds_filter(mol, bonds) 42 | if len(bonds) > 0: 43 | bond_ids = [mol.GetBondBetweenAtoms(x, y).GetIdx() for x, y in bonds] 44 | bond_ids = list(set(bond_ids)) 45 | dummyLabels = [(i + dummyStart, i + dummyStart) for i in range(len(bond_ids))] 46 | break_mol = Chem.FragmentOnBonds(mol, bond_ids, dummyLabels=dummyLabels) 47 | dummyEnd = dummyStart + len(dummyLabels) - 1 48 | else: 49 | break_mol = mol 50 | dummyEnd = dummyStart - 1 51 | else: 52 | break_mol = mol 53 | dummyEnd = dummyStart - 1 54 | return break_mol, dummyEnd 55 | -------------------------------------------------------------------------------- /preparation/get_fragment_vocab.py: -------------------------------------------------------------------------------- 1 | import os 2 | from rdkit import Chem 3 | import gzip 4 | import pickle 5 | from datetime import datetime 6 | from fragmenizer import BRICS_RING_R_Fragmenizer 7 | from utils import centralize, canonical_frag_smi 8 | 9 | data_path = 'ZINC SDF PATH' 10 | save_pkl_path = 'YOUR PATH' 11 | save_pkl_pattern = 'BRICS_RING_R.{}.pkl' 12 | save_pkl_file = os.path.join(save_pkl_path, save_pkl_pattern) 13 | file_list = os.listdir(data_path)[:] 14 | 15 | vocab = dict() 16 | 17 | save_interval = 1000 * 10000 18 | print_interval = 1000 19 | mo_cnt = 0 20 | 21 | fragmenizer = BRICS_RING_R_Fragmenizer() 22 | 23 | start = datetime.now() 24 | for f_idx, file_name in enumerate(file_list): 25 | file_path = os.path.join(data_path, file_name) 26 | gzip_data = gzip.open(file_path) 27 | with Chem.ForwardSDMolSupplier(gzip_data) as mos: 28 | for mo in mos: 29 | if mo is None: 30 | continue 31 | mo_cnt += 1 32 | 33 | frags, _ = fragmenizer.fragmenize(mo) 34 | frags = Chem.GetMolFrags(frags, asMols=True) 35 | for frag in frags: 36 | frag = centralize(frag) 37 | frag_smi = canonical_frag_smi(Chem.MolToSmiles(frag)) 38 | 39 | if frag_smi not in vocab: 40 | vocab[frag_smi] = frag 41 | 42 | if mo_cnt % save_interval == 0: 43 | with open(save_pkl_file.format(mo_cnt), 
'wb') as fw: 44 | pickle.dump(vocab, fw, protocol=pickle.HIGHEST_PROTOCOL) 45 | 46 | if mo_cnt % print_interval == 0: 47 | now = datetime.now() 48 | time_interval = (now - start).total_seconds() 49 | print('current {} file {} molecule {:.3f} ms/mol'.format(f_idx, mo_cnt, time_interval * 1000 / print_interval)) 50 | start = datetime.now() 51 | 52 | with open(save_pkl_file.format(mo_cnt), 'wb') as fw: 53 | pickle.dump(vocab, fw, protocol=pickle.HIGHEST_PROTOCOL) 54 | 55 | def mapping_star(mol): 56 | star_mapping = dict() 57 | star_cnt = 0 58 | for atom in mol.GetAtoms(): 59 | if atom.GetSymbol() == '*': 60 | star_cnt += 1 61 | star_mapping[star_cnt] = atom.GetSmarts() 62 | star_mapping[atom.GetSmarts()] = star_cnt 63 | return star_mapping 64 | 65 | vocab_w_mapping = {'PAD': [None, None, 0],'UNK': [None, None, 1], 'BOS': [None, None, 2], 'EOS': [None, None, 3], 'BOB': [None, None, 4], 'EOB': [None, None, 5]} 66 | for key in vocab.keys(): 67 | star_mapping = mapping_star(vocab[key]) 68 | vocab_w_mapping[key] = [vocab[key], star_mapping, len(vocab_w_mapping)] 69 | 70 | with open(save_pkl_file.format('vocab'), 'wb') as fw: 71 | pickle.dump(vocab, fw, protocol=pickle.HIGHEST_PROTOCOL) 72 | -------------------------------------------------------------------------------- /preparation/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import find_parts_bonds 2 | from .utils import get_other_atom_idx 3 | from .utils import get_rings 4 | from .utils import get_bonds 5 | from .utils import centralize 6 | from .utils import canonical_frag_smi 7 | from .utils import get_center 8 | from .utils import get_align_points 9 | from .utils import get_tree 10 | from .utils import tree_linearize 11 | from .utils import get_surrogate_frag 12 | from .utils import get_atom_mapping_between_frag_and_surrogate 13 | from .utils_full import get_dock_fast_with_smiles_with_mol 14 | 
-------------------------------------------------------------------------------- /preparation/utils/common.py: -------------------------------------------------------------------------------- 1 | # van der Waals radius 2 | ATOM_RADIUS = { 3 | 'C': 1.908, 4 | 'F': 1.75, 5 | 'Cl': 1.948, 6 | 'Br': 2.22, 7 | 'I': 2.35, 8 | 'N': 1.824, 9 | 'O': 1.6612, 10 | 'P': 2.1, 11 | 'S': 2.0, 12 | 'Si': 2.2, # not accurate 13 | 'H': 1.0 14 | } 15 | 16 | # atomic number 17 | ATOMIC_NUMBER = { 18 | 'C': 6, 19 | 'F': 9, 20 | 'Cl': 17, 21 | 'Br': 35, 22 | 'I': 53, 23 | 'N': 7, 24 | 'O': 8, 25 | 'P': 15, 26 | 'S': 16, 27 | 'Si': 14, 28 | 'H': 1 29 | } 30 | 31 | ATOMIC_NUMBER_REVERSE = {v: k for k, v in ATOMIC_NUMBER.items()} 32 | -------------------------------------------------------------------------------- /shape_pretraining/__init__.py: -------------------------------------------------------------------------------- 1 | from .shape_pretraining_dataset import ShapePretrainingDataset 2 | from .shape_pretraining_model import ShapePretrainingModel 3 | from .shape_pretraining_encoder import ShapePretrainingEncoder 4 | from .shape_pretraining_dataset_shard import ShapePretrainingDatasetShard 5 | from .shape_pretraining_dataloader_shard import ShapePretrainingDataLoaderShard 6 | from .shape_pretraining_task_no_regression import ShapePretrainingTaskNoRegression 7 | from .shape_pretraining_task_no_regression_pocket import ShapePretrainingTaskNoRegressionPocket 8 | from .shape_pretraining_criterion_no_regression import ShapePretrainingCriterionNoRegression 9 | from .shape_pretraining_decoder_iterative_no_regression import ShapePretrainingDecoderIterativeNoRegression 10 | from .shape_pretraining_iterator_no_regression import ShapePretrainingIteratorNoRegression 11 | from .shape_pretraining_search_iterative_no_regression import ShapePretrainingSearchIterativeNoRegression 12 | from .shape_pretraining_dataset_pocket import ShapePretrainingDatasetPocket 
-------------------------------------------------------------------------------- /shape_pretraining/common.py: -------------------------------------------------------------------------------- 1 | # van der Waals radius 2 | ATOM_RADIUS = { 3 | 'C': 1.908, 4 | 'F': 1.75, 5 | 'Cl': 1.948, 6 | 'Br': 2.22, 7 | 'I': 2.35, 8 | 'N': 1.824, 9 | 'O': 1.6612, 10 | 'P': 2.1, 11 | 'S': 2.0, 12 | 'Si': 2.2 # not accurate 13 | } 14 | 15 | # atomic number 16 | ATOMIC_NUMBER = { 17 | 'C': 6, 18 | 'F': 9, 19 | 'Cl': 17, 20 | 'Br': 35, 21 | 'I': 53, 22 | 'N': 7, 23 | 'O': 8, 24 | 'P': 15, 25 | 'S': 16, 26 | 'Si': 14 27 | } 28 | 29 | ATOMIC_NUMBER_REVERSE = {v: k for k, v in ATOMIC_NUMBER.items()} 30 | -------------------------------------------------------------------------------- /shape_pretraining/io.py: -------------------------------------------------------------------------------- 1 | from bycha.utils.io import _InputStream, _OutputStream, _InputBytes, _OutputBytes 2 | import pickle 3 | from bycha.utils.runtime import logger 4 | import random 5 | from bycha.utils.ops import local_seed 6 | 7 | class _MyInputBytes(_InputBytes): 8 | def __init__(self, fake_epoch, *args, shuffle=False, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self._data_iter = None 11 | self._shuffle = shuffle 12 | self._fake_epoch = fake_epoch 13 | 14 | def __next__(self): 15 | try: 16 | if self._idx >= len(self._fins): 17 | raise IndexError 18 | if not self._data_iter: 19 | data = pickle.load(self._fins[self._idx]) 20 | if self._shuffle: 21 | with local_seed((self._fake_epoch * len(self._fins)) + self._idx): 22 | old_state = random.getstate() 23 | random.seed((self._fake_epoch * len(self._fins)) + self._idx) 24 | random.shuffle(data) 25 | random.setstate(old_state) 26 | self._data_iter = iter(data) 27 | sample = next(self._data_iter) 28 | return sample 29 | except StopIteration: 30 | self._idx += 1 31 | self._data_iter = None 32 | sample = self.__next__() 33 | return sample 34 | except IndexError: 35 
| raise StopIteration 36 | 37 | def reset(self): 38 | self._idx = 0 39 | for fin in self._fins: 40 | fin.seek(0) 41 | self._data_iter = None 42 | 43 | class MyUniIO(_InputStream, _OutputStream, _MyInputBytes, _OutputBytes): 44 | def __init__(self, path, fake_epoch, mode='r', encoding='utf8', shuffle=False): 45 | pass 46 | 47 | def __new__(cls, path, fake_epoch, mode='r', encoding='utf8', shuffle=False): 48 | if 'r' in mode.lower(): 49 | if 'b' in mode.lower(): 50 | return _MyInputBytes(fake_epoch, path, mode=mode, shuffle=shuffle) 51 | return _InputStream(path, encoding=encoding) 52 | elif 'w' in mode.lower(): 53 | if 'b' in mode.lower(): 54 | return _OutputBytes(path, mode=mode) 55 | return _OutputStream(path, encoding=encoding) 56 | logger.warning(f'Not support file mode: {mode}') 57 | raise ValueError 58 | -------------------------------------------------------------------------------- /shape_pretraining/shape_pretraining_dataset.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from bycha.datasets import register_dataset 3 | from bycha.datasets.in_memory_dataset import InMemoryDataset 4 | from bycha.utils.runtime import logger, progress_bar 5 | from .utils import get_mol_centroid, centralize 6 | 7 | @register_dataset 8 | class ShapePretrainingDataset(InMemoryDataset): 9 | def __init__(self, 10 | path): 11 | super().__init__(path) 12 | 13 | vocab_path = self._path['vocab'] 14 | with open(vocab_path, 'rb') as fr: 15 | self._vocab = pickle.load(fr) 16 | 17 | def _load(self): 18 | self._data = [] 19 | 20 | samples_path = self._path['samples'] 21 | 22 | with open(samples_path, 'rb') as fr: 23 | samples = pickle.load(fr) 24 | 25 | accecpted, discarded = 0, 0 26 | for i, sample in enumerate(progress_bar(samples, desc='Loading Samples...')): 27 | try: 28 | self._data.append(self._full_callback(sample)) 29 | accecpted += 1 30 | except Exception: 31 | logger.warning('sample {} is discarded'.format(i)) 32 | 
discarded += 1 33 | 34 | self._length = len(self._data) 35 | logger.info(f'Totally accept {accecpted} samples, discard {discarded} samples') 36 | 37 | def _callback(self, sample): 38 | # centralize a molecule and translate its fragments 39 | mol = sample[0] 40 | centroid = get_mol_centroid(mol) 41 | mol = centralize(mol) 42 | 43 | fragment_list = [] 44 | for fragment in sample[1]: 45 | if not fragment[3] is None: 46 | trans_vec = fragment[3] - centroid 47 | else: 48 | trans_vec = fragment[3] 49 | fragment_list.append({ 50 | 'vocab_id': fragment[0], 51 | 'vocab_key': fragment[1], 52 | 'frag_smi': fragment[2], 53 | 'trans_vec': trans_vec, 54 | 'rotate_mat': fragment[4] 55 | }) 56 | 57 | tree_list = sample[2] 58 | 59 | return { 60 | 'mol': mol, 61 | 'frag_list': fragment_list, 62 | 'tree_list': tree_list 63 | } 64 | -------------------------------------------------------------------------------- /shape_pretraining/shape_pretraining_dataset_pocket.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from bycha.datasets import register_dataset 3 | from bycha.datasets.in_memory_dataset import InMemoryDataset 4 | from bycha.utils.runtime import logger, progress_bar 5 | from .utils import get_mol_centroid, centralize 6 | 7 | @register_dataset 8 | class ShapePretrainingDatasetPocket(InMemoryDataset): 9 | def __init__(self, 10 | path): 11 | super().__init__(path) 12 | 13 | vocab_path = self._path['vocab'] 14 | with open(vocab_path, 'rb') as fr: 15 | self._vocab = pickle.load(fr) 16 | 17 | def _load(self): 18 | self._data = [] 19 | 20 | samples_path = self._path['samples'] 21 | 22 | with open(samples_path, 'rb') as fr: 23 | samples = pickle.load(fr) 24 | 25 | accecpted, discarded = 0, 0 26 | for i, sample in enumerate(progress_bar(samples, desc='Loading Samples...')): 27 | try: 28 | self._data.append(self._full_callback(sample)) 29 | accecpted += 1 30 | except Exception: 31 | logger.warning('sample {} is 
discarded'.format(i)) 32 | discarded += 1 33 | 34 | self._length = len(self._data) 35 | logger.info(f'Totally accept {accecpted} samples, discard {discarded} samples') 36 | 37 | def _callback(self, sample): 38 | return { 39 | 'shape': sample 40 | } 41 | -------------------------------------------------------------------------------- /shape_pretraining/shape_pretraining_dataset_shard.py: -------------------------------------------------------------------------------- 1 | from bycha.datasets import register_dataset 2 | from bycha.datasets.streaming_dataset import StreamingDataset 3 | from .io import MyUniIO 4 | from bycha.utils.runtime import logger 5 | import pickle 6 | from .utils import get_mol_centroid, centralize, set_atom_prop 7 | 8 | @register_dataset 9 | class ShapePretrainingDatasetShard(StreamingDataset): 10 | def __init__(self, 11 | path, 12 | vocab_path, 13 | sample_each_shard, 14 | shuffle=False): 15 | super().__init__(path) 16 | self._sample_each_shard = sample_each_shard 17 | self._shuffle = shuffle 18 | self._fake_epoch = 0 19 | 20 | with open(vocab_path, 'rb') as fr: 21 | self._vocab = pickle.load(fr) 22 | 23 | def build(self, collate_fn=None, preprocessed=False): 24 | self._collate_fn = collate_fn 25 | if self._path: 26 | self._fin = MyUniIO(self._path, self._fake_epoch, mode='rb', shuffle=self._shuffle) 27 | 28 | def __iter__(self): 29 | for sample in self._fin: 30 | try: 31 | sample = self._full_callback(sample) 32 | yield sample 33 | except StopIteration: 34 | raise StopIteration 35 | except Exception as e: 36 | logger.warning(e) 37 | 38 | def reset(self): 39 | self._pos = 0 40 | self._fin = MyUniIO(self._path, self._fake_epoch, mode='rb', shuffle=self._shuffle) 41 | self._fake_epoch += 1 42 | 43 | def _callback(self, sample): 44 | # centralize a molecule and translate its fragments 45 | mol = sample[0] 46 | centroid = get_mol_centroid(mol) 47 | mol = centralize(mol) 48 | 49 | for atom in mol.GetAtoms(): 50 | set_atom_prop(atom, 
'origin_atom_idx', str(atom.GetIdx())) 51 | 52 | fragment_list = [] 53 | for fragment in sample[1]: 54 | if not fragment[3] is None: 55 | trans_vec = fragment[3] - centroid 56 | else: 57 | trans_vec = fragment[3] 58 | fragment_list.append({ 59 | 'vocab_id': fragment[0], 60 | 'vocab_key': fragment[1], 61 | 'frag_smi': fragment[2], 62 | 'trans_vec': trans_vec, 63 | 'rotate_mat': fragment[4] 64 | }) 65 | 66 | tree_list = sample[2] 67 | 68 | return { 69 | 'mol': mol, 70 | 'frag_list': fragment_list, 71 | 'tree_list': tree_list 72 | } 73 | 74 | def finalize(self): 75 | self._fin.close() 76 | -------------------------------------------------------------------------------- /shape_pretraining/shape_pretraining_encoder.py: -------------------------------------------------------------------------------- 1 | from bycha.modules.encoders import register_encoder 2 | from bycha.modules.encoders.transformer_encoder import TransformerEncoder 3 | from bycha.modules.layers.feed_forward import FFN 4 | import torch 5 | 6 | @register_encoder 7 | class ShapePretrainingEncoder(TransformerEncoder): 8 | def __init__(self, patch_size, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self._patch_size = patch_size 11 | 12 | def build(self, 13 | embed, 14 | special_tokens): 15 | super().build(embed, special_tokens) 16 | self._patch_ffn = FFN(self._patch_size**3, self._d_model, self._d_model) 17 | 18 | def _forward(self, src): 19 | bz, sl = src.size(0), src.size(1) 20 | 21 | x = self._patch_ffn(src) 22 | if self._embed_scale is not None: 23 | x = x * self._embed_scale 24 | if self._pos_embed is not None: 25 | pos = torch.arange(sl).unsqueeze(0).repeat(bz, 1).to(x.device) 26 | x = x + self._pos_embed(pos) 27 | if self._embed_norm is not None: 28 | x = self._embed_norm(x) 29 | x = self._embed_dropout(x) 30 | 31 | src_padding_mask = torch.zeros((bz, sl), dtype=torch.bool).to(x.device) 32 | x = x.transpose(0, 1) 33 | for layer in self._layers: 34 | x = layer(x, 
src_key_padding_mask=src_padding_mask) 35 | 36 | if self._norm is not None: 37 | x = self._norm(x) 38 | 39 | if self._return_seed: 40 | encoder_out = x[1:], src_padding_mask[:, 1:], x[0] 41 | else: 42 | encoder_out = x, src_padding_mask 43 | 44 | return encoder_out 45 | -------------------------------------------------------------------------------- /shape_pretraining/shape_pretraining_iterator_no_regression.py: -------------------------------------------------------------------------------- 1 | from bycha.modules.encoders import register_encoder 2 | from bycha.modules.encoders.transformer_encoder import TransformerEncoder 3 | from bycha.modules.layers.feed_forward import FFN 4 | import torch 5 | import torch.nn as nn 6 | 7 | @register_encoder 8 | class ShapePretrainingIteratorNoRegression(TransformerEncoder): 9 | def __init__(self, *args, **kwargs): 10 | super().__init__(*args, **kwargs) 11 | 12 | def build(self, 13 | embed, 14 | special_tokens, 15 | trans_size, 16 | rotat_size): 17 | super().build(embed, special_tokens) 18 | from bycha.modules.layers.embedding import Embedding 19 | 20 | self._trans_emb = Embedding(vocab_size=trans_size, 21 | d_model=embed.weight.shape[1]) 22 | self._rotat_emb = Embedding(vocab_size=rotat_size, 23 | d_model=embed.weight.shape[1]) 24 | 25 | self._logits_output_proj = nn.Linear(embed.weight.shape[1], 26 | embed.weight.shape[0], 27 | bias=False) 28 | self._logits_output_proj.weight = embed.weight 29 | self._trans_output_proj = nn.Linear(self._trans_emb.weight.shape[1], 30 | self._trans_emb.weight.shape[0], 31 | bias=False) 32 | self._trans_output_proj.weight = self._trans_emb.weight 33 | self._rotat_output_proj = nn.Linear(self._rotat_emb.weight.shape[1], 34 | self._rotat_emb.weight.shape[0], 35 | bias=False) 36 | self._rotat_output_proj.weight = self._rotat_emb.weight 37 | 38 | def _forward(self, logits, trans, r_mat, padding_mask): 39 | bz, sl = logits.size(0), logits.size(1) 40 | logits_pred = logits.argmax(-1) 41 | trans_pred = 
trans.argmax(-1) 42 | r_mat_pred = r_mat.argmax(-1) 43 | 44 | x = self._embed(logits_pred) 45 | x = x + self._trans_emb(trans_pred) 46 | x = x + self._rotat_emb(r_mat_pred) 47 | 48 | if self._embed_scale is not None: 49 | x = x * self._embed_scale 50 | if self._pos_embed is not None: 51 | pos = torch.arange(sl).unsqueeze(0).repeat(bz, 1).to(x.device) 52 | x = x + self._pos_embed(pos) 53 | if self._embed_norm is not None: 54 | x = self._embed_norm(x) 55 | x = self._embed_dropout(x) 56 | 57 | src_padding_mask = padding_mask 58 | x = x.transpose(0, 1) 59 | for layer in self._layers: 60 | x = layer(x, src_key_padding_mask=src_padding_mask) 61 | 62 | if self._norm is not None: 63 | x = self._norm(x) 64 | 65 | x = x.transpose(0, 1) 66 | logits = self._logits_output_proj(x) 67 | trans = self._trans_output_proj(x) 68 | r_mat = self._rotat_output_proj(x) 69 | 70 | return logits, trans, r_mat 71 | -------------------------------------------------------------------------------- /shape_pretraining/shape_pretraining_model.py: -------------------------------------------------------------------------------- 1 | from bycha.models import register_model 2 | from bycha.models.encoder_decoder_model import EncoderDecoderModel 3 | import torch 4 | 5 | @register_model 6 | class ShapePretrainingModel(EncoderDecoderModel): 7 | def __init__(self, 8 | encoder, 9 | decoder, 10 | d_model, 11 | share_embedding=None, 12 | path=None, 13 | no_shape=False, 14 | no_trans=False, 15 | no_rotat=False): 16 | super().__init__(encoder=encoder, 17 | decoder=decoder, 18 | d_model=d_model, 19 | share_embedding=share_embedding, 20 | path=path) 21 | self._no_shape = no_shape 22 | self._no_trans = no_trans 23 | self._no_rotat = no_rotat 24 | 25 | def forward(self, 26 | shape, 27 | shape_patches, 28 | input_frag_idx, 29 | input_frag_idx_mask, 30 | input_frag_trans, 31 | input_frag_trans_mask, 32 | input_frag_r_mat, 33 | input_frag_r_mat_mask): 34 | memory, memory_padding_mask = self._encoder(src=shape_patches) 
35 | if self._no_shape: 36 | memory = torch.zeros_like(memory) 37 | if self._no_trans: 38 | input_frag_trans = torch.zeros_like(input_frag_trans) 39 | if self._no_rotat: 40 | input_frag_r_mat = torch.zeros_like(input_frag_r_mat) 41 | logits, trans, r_mat = self._decoder(input_frag_idx=input_frag_idx, 42 | input_frag_trans=input_frag_trans, 43 | input_frag_r_mat=input_frag_r_mat, 44 | memory=memory, 45 | memory_padding_mask=memory_padding_mask) 46 | return (logits, trans, r_mat) 47 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DIR=$(dirname $(readlink -f "$0")) 3 | cd $DIR 4 | source bashutil.sh 5 | 6 | declare -A BASH_ARGS=( 7 | # bycha config 8 | [config]=./configs/training.yaml 9 | [lib]=shape_pretraining 10 | [args]= 11 | ) 12 | parse_args "$@" 13 | 14 | MPIRUN bycha-run \ 15 | --config $config \ 16 | --lib $lib \ 17 | --task.trainer.tensorboard_dir ❗❗❗FILL_THIS❗❗❗ \ 18 | --task.trainer.save_model_dir ❗❗❗FILL_THIS❗❗❗ \ 19 | --task.trainer.restore_path ❗❗❗FILL_THIS❗❗❗ \ 20 | ${BASH_ARGS[args]} 21 | 22 | --------------------------------------------------------------------------------