├── .gitattributes
├── .gitignore
├── Image
│   └── DeepHop.jpg
├── LICENSE
├── README.md
├── data
│   ├── cond-test.txt
│   ├── cond-train.txt
│   ├── cond-val.txt
│   ├── src-test.txt
│   ├── src-train.txt
│   ├── src-val.txt
│   ├── tgt-test.txt
│   ├── tgt-train.txt
│   └── tgt-val.txt
├── deephop
│   ├── Graph3dConv.py
│   ├── calc_SC_RDKit.py
│   ├── calc_scaffold_smilarity.py
│   ├── data_loader.py
│   ├── env.yaml
│   ├── graph_embedding.py
│   ├── make_pair.py
│   ├── onmt
│   │   ├── GAT.py
│   │   ├── GATGATE.py
│   │   ├── GCN.py
│   │   ├── MPNNs
│   │   │   ├── MPNN.py
│   │   │   ├── MPNNv2.py
│   │   │   ├── MPNNv3.py
│   │   │   ├── MessageFunction.py
│   │   │   ├── UpdateFunction.py
│   │   │   └── nnet.py
│   │   ├── MYGCN.py
│   │   ├── __init__.py
│   │   ├── decoders
│   │   │   ├── __init__.py
│   │   │   ├── cnn_decoder.py
│   │   │   ├── decoder.py
│   │   │   ├── ensemble.py
│   │   │   └── transformer.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── audio_encoder.py
│   │   │   ├── cnn_encoder.py
│   │   │   ├── encoder.py
│   │   │   ├── image_encoder.py
│   │   │   ├── mean_encoder.py
│   │   │   ├── rnn_encoder.py
│   │   │   └── transformer.py
│   │   ├── inputters
│   │   │   ├── __init__.py
│   │   │   ├── audio_dataset.py
│   │   │   ├── dataset_base.py
│   │   │   ├── image_dataset.py
│   │   │   ├── inputter.py
│   │   │   └── text_dataset.py
│   │   ├── model_builder.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── model.py
│   │   │   ├── model_saver.py
│   │   │   ├── sru.py
│   │   │   └── stacked_rnn.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── average_attn.py
│   │   │   ├── conv_multi_step_attention.py
│   │   │   ├── copy_generator.py
│   │   │   ├── embeddings.py
│   │   │   ├── gate.py
│   │   │   ├── global_attention.py
│   │   │   ├── multi_headed_attn.py
│   │   │   ├── position_ffn.py
│   │   │   ├── sparse_activations.py
│   │   │   ├── sparse_losses.py
│   │   │   ├── structured_attention.py
│   │   │   ├── util_class.py
│   │   │   └── weight_norm.py
│   │   ├── myutils.py
│   │   ├── opts.py
│   │   ├── tests
│   │   │   ├── output_hyp.txt
│   │   │   ├── pull_request_chk.sh
│   │   │   ├── rebuild_test_models.sh
│   │   │   ├── test_attention.py
│   │   │   ├── test_models.py
│   │   │   ├── test_models.sh
│   │   │   ├── test_preprocess.py
│   │   │   └── test_simple.py
│   │   ├── train_single.py
│   │   ├── trainer.py
│   │   ├── translate
│   │   │   ├── __init__.py
│   │   │   ├── beam.py
│   │   │   ├── penalties.py
│   │   │   ├── translation.py
│   │   │   ├── translation_server.py
│   │   │   └── translator.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── cnn_factory.py
│   │       ├── distributed.py
│   │       ├── logging.py
│   │       ├── loss.py
│   │       ├── masking.py
│   │       ├── misc.py
│   │       ├── optimizers.py
│   │       ├── report_manager.py
│   │       ├── rnn_factory.py
│   │       └── statistics.py
│   ├── preprocess.py
│   ├── protein_emb.pkl
│   ├── pvalue_score_predictions.py
│   ├── replace_torchtext
│   │   ├── batch.py
│   │   ├── field.py
│   │   └── iterator.py
│   ├── split_data.py
│   ├── train.py
│   └── translate.py
├── known_pvalue_data
│   ├── CHEMBL1868.csv
│   ├── CHEMBL1906.csv
│   ├── CHEMBL1907601.csv
│   ├── CHEMBL1907602.csv
│   ├── CHEMBL1974.csv
│   ├── CHEMBL2094127.csv
│   ├── CHEMBL2095191.csv
│   ├── CHEMBL2111367.csv
│   ├── CHEMBL2147.csv
│   ├── CHEMBL2148.csv
│   ├── CHEMBL2276.csv
│   ├── CHEMBL2431.csv
│   ├── CHEMBL2534.csv
│   ├── CHEMBL258.csv
│   ├── CHEMBL2637.csv
│   ├── CHEMBL267.csv
│   ├── CHEMBL2695.csv
│   ├── CHEMBL2828.csv
│   ├── CHEMBL2835.csv
│   ├── CHEMBL2850.csv
│   ├── CHEMBL299.csv
│   ├── CHEMBL3130.csv
│   ├── CHEMBL3145.csv
│   ├── CHEMBL3231.csv
│   ├── CHEMBL3234.csv
│   ├── CHEMBL3267.csv
│   ├── CHEMBL3529.csv
│   ├── CHEMBL3582.csv
│   ├── CHEMBL3650.csv
│   ├── CHEMBL3778.csv
│   ├── CHEMBL3788.csv
│   ├── CHEMBL3861.csv
│   ├── CHEMBL3905.csv
│   ├── CHEMBL4005.csv
│   ├── CHEMBL4203.csv
│   ├── CHEMBL4204.csv
│   ├── CHEMBL4247.csv
│   ├── CHEMBL4501.csv
│   ├── CHEMBL4523.csv
│   ├── CHEMBL4578.csv
│   ├── CHEMBL4722.csv
│   ├── CHEMBL4816.csv
│   ├── CHEMBL4898.csv
│   ├── CHEMBL5145.csv
│   ├── CHEMBL5407.csv
│   └── CHEMBL5543.csv
└── score
    ├── data_loader.py
    ├── env.yaml
    ├── evaluate.py
    ├── mtdnn.py
    ├── summary_one_task.py
    ├── total_mtr
    │   ├── seed0
    │   │   ├── checkpoint
    │   │   ├── ckpt-43.data-00000-of-00002
    │   │   ├── ckpt-43.data-00001-of-00002
    │   │   └── ckpt-43.index
    │   ├── seed1
    │   │   ├── checkpoint
    │   │   ├── ckpt-53.data-00000-of-00002
    │   │   ├── ckpt-53.data-00001-of-00002
    │   │   └── ckpt-53.index
    │   └── seed2
    │       ├── checkpoint
    │       ├── ckpt-37.data-00000-of-00002
    │       ├── ckpt-37.data-00001-of-00002
    │       └── ckpt-37.index
    ├── train.py
    └── util.py

/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------

/Image/DeepHop.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prokia/deepHops/529870eba5ad486527113cbd5607535192deb212/Image/DeepHop.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 prokia

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# DeepHop

Supporting Information for the paper "[Deep Scaffold Hopping with Multimodal Transformer Neural Networks](https://chemrxiv.org/engage/chemrxiv/article-details/60c75035f96a005e48287d8f)"

DeepHop is a multi-modal molecular transformation framework. It accepts a hit molecule and a target protein sequence of interest as inputs, and designs molecular structures that are isofunctional to the source compound.

![DeepHop](Image/DeepHop.jpg)

## Installation

Create a conda environment for the QSAR scorer:
```shell script
conda env create -f score/env.yaml
```

Create a conda environment for DeepHop:
```shell script
conda env create -f deephop/env.yaml
```
Note that you should replace three source files (batch.py, field.py, iterator.py) of the torchtext library in "your deephop env path/python3.7/site-packages/torchtext/data" with the corresponding three files contained in "deephop/replace_torchtext", since we have modified them.

## Scaffold hopping pairs construction

For convenience of illustration, we assume that:
- your code is extracted in /data/u1/projects/mget_3d
- the environment for DeepHop is named deephop_env

```shell script
cd /data/u1/projects/mget_3d
conda activate deephop_env
```
You can use make_pair.py to generate hopping pairs, as sketched below.
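`make_pair.py` takes no command-line arguments: the input activity table and the output locations are hard-coded in its `__main__` block, so adjust those paths to your setup before running. A minimal invocation could look like this (the relative `tmp/` directory for per-target intermediate files is an assumption based on the script's source):

```shell script
cd /data/u1/projects/mget_3d/deephop
mkdir -p tmp
python make_pair.py
```
For every target with more than 300 molecules this writes a `tmp/<target_chembl_id>.csv` pair table (ref_smiles, target_smiles, delta_p, and the 2D/3D similarity scores).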

## Dataset split
```shell script
python split_data.py -out_dir data40_tue_3d/0.60 -protein_group data40 -target_uniq_rate 0.6 -hopping_pairs_dir hopping_pairs_with_scaffold
```

## Data preprocessing
```shell script
python preprocess.py -train_src data40_tue_3d/0.60/src-train.txt -train_tgt data40_tue_3d/0.60/tgt-train.txt -train_cond data40_tue_3d/0.60/cond-train.txt -valid_src data40_tue_3d/0.60/src-val.txt -valid_tgt data40_tue_3d/0.60/tgt-val.txt -valid_cond data40_tue_3d/0.60/cond-val.txt -save_data data40_tue_3d/0.60/seqdata -share_vocab -src_seq_length 1000 -tgt_seq_length 1000 -src_vocab_size 1000 -tgt_vocab_size 1000 -with_3d_confomer
```

## Model training
```shell script
python train.py -condition_dim 768 -use_graph_embedding -arch after_encoding -data data40_tue_3d/0.60/seqdata -save_model experiments/data40_tue_3d/after/models/model -seed 42 -save_checkpoint_steps 158 -keep_checkpoint 400 -train_steps 95193 -param_init 0 -param_init_glorot -max_generator_batches 32 -batch_size 8192 -batch_type tokens -normalization tokens -max_grad_norm 0 -accum_count 4 -optim adam -adam_beta1 0.9 -adam_beta2 0.998 -decay_method noam -warmup_steps 475 -learning_rate 2 -label_smoothing 0.0 -report_every 10 -layers 4 -rnn_size 256 -word_vec_size 256 -encoder_type transformer -decoder_type transformer -dropout 0.1 -position_encoding -share_embeddings -global_attention general -global_attention_function softmax -self_attn_type scaled-dot -heads 8 -transformer_ff 2048 -log_file experiments/data40_tue_3d/after/train.log -tensorboard -tensorboard_log_dir experiments/data40_tue_3d/after/logs -world_size 4 -gpu_ranks 0 1 2 3 -valid_steps 475 -valid_batch_size 32
```

## Hops generation
To generate output SMILES from a saved model:
```shell script
python translate.py -condition_dim 768 -use_graph_embedding -arch after_encoding -with_3d_confomer -model /data/u1/projects/mget_3d/experiments/data40_tue/3d_gcn/models/model_step_9500.pt -gpu 0 -src data40_tue_3d/src-test.txt -cond data40_tue_3d/cond-test.txt -output /data/u1/projects/mget_3d/summary_tue/data40/after/9500/pred.txt -beam_size 10 -n_best 10 -batch_size 16 -replace_unk -max_length 200 -fast -use_protein40
```

## Evaluation

To evaluate the model:
```shell script
python pvalue_score_predictions.py -beam_size 10 -src summary_tue/data40/after/9500/src-test-protein.txt -prediction /data/u1/projects/mget_3d/summary_tue/data40/after/9500/pred.txt -score_file /data/u1/projects/mget_3d/summary_tue/data40/after/9500/score.csv -invalid_smiles -cond summary_tue/data40/after/9500/cond-test-protein.txt -train_data_dir /data/u1/projects/mget_3d/data40_tue_3d -scorer_model_dir /data/u1/projects/score/total_mtr -pvalue_dir /data/u1/projects/mget_3d/score_train_data
```
The final result report is saved at /data/u1/projects/mget_3d/summary_tue/data40/after/9500/score_final.csv.

## Citation

Please cite the following paper if you use this code in your work.
```bibtex
@article{zheng2021deep,
  title={Deep scaffold hopping with multimodal transformer neural networks},
  author={Zheng, Shuangjia and Lei, Zengrong and Ai, Haitao and Chen, Hongming and Deng, Daiguo and Yang, Yuedong},
  journal={Journal of cheminformatics},
  volume={13},
  number={1},
  pages={1--15},
  year={2021},
  publisher={Springer}
}
```
--------------------------------------------------------------------------------

/deephop/Graph3dConv.py:
--------------------------------------------------------------------------------
# --*-- coding: utf-8 --*--

import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops

class GraphConv(MessagePassing):
    def __init__(self, coors, out_channels_1, out_features, label_dim=1, dropout=0):
        """
        label_dim - dimension of the node representation
        coors - dimension of the position vector
        out_channels_1 - dimension of the convolution on each representation channel
            * the output will have dimension label_dim * out_channels_1
        out_features - dimension of the node representation after GraphConv
        """
        super(GraphConv, self).__init__(aggr='add')
        self.lin_in = torch.nn.Linear(coors, label_dim * out_channels_1)
        self.lin_out = torch.nn.Linear(label_dim * out_channels_1, out_features)
        self.dropout = dropout

    def forward(self, x, pos, edge_index):
        """
        x - feature matrix of the whole graph [num_nodes, label_dim]
        pos - node position matrix [num_nodes, coors]
        edge_index - graph connectivity [2, num_edges]
        """
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))  # num_edges = num_edges + num_nodes
        return self.propagate(edge_index=edge_index, x=x, pos=pos, aggr='add')  # [N, out_channels, label_dim]

    def message(self, pos_i, pos_j, x_j):
        """
        pos_i [num_edges, coors]
        pos_j [num_edges, coors]
        x_j [num_edges, label_dim]
        """
        tmp = pos_j - pos_i
        L = self.lin_in(tmp)  # [num_edges, out_channels]
        num_nodes, label_dim = list(x_j.size())
        label_dim_out_channels_1 = list(L.size())[1]

        X = F.relu(L)
        Y = x_j
        X = torch.t(X)
        X = F.dropout(X, p=self.dropout, training=self.training)
        result = torch.t(
            (X.view(label_dim, -1, num_nodes) * torch.t(Y).unsqueeze(1)).reshape(label_dim_out_channels_1, num_nodes))
        return result

    def update(self, aggr_out):
        """
        aggr_out [num_nodes, label_dim, out_channels]
        """
        aggr_out = self.lin_out(aggr_out)  # [num_nodes, label_dim, out_features]
        aggr_out = F.relu(aggr_out)
        aggr_out = F.dropout(aggr_out, p=self.dropout, training=self.training)

        return aggr_out

class Graph3dConv(GraphConv):
    def forward(self, g, features):
        """
        g - a graph whose edges provide the connectivity
        features - node feature matrix with the 3-D coordinates in the last three columns
        """
        # split the 3-D coordinates off the end of the feature matrix
        x, pos = torch.split(features, [features.shape[-1] - 3, 3], dim=-1)
        edge_index = torch.stack(g.edges()).to(x.device)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))  # num_edges = num_edges + num_nodes

        return self.propagate(edge_index=edge_index, x=x, pos=pos, aggr='add')  # [N, out_channels, label_dim]
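# --- A minimal usage sketch for GraphConv (illustrative only: toy sizes and
# --- random data are assumptions; requires torch and torch_geometric) ---
if __name__ == '__main__':
    # a toy graph: 5 nodes, 16-dimensional node features, 3-D coordinates
    x = torch.randn(5, 16)
    pos = torch.randn(5, 3)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 4]], dtype=torch.long)

    conv = GraphConv(coors=3, out_channels_1=8, out_features=32, label_dim=16)
    # message() mixes relative positions (pos_j - pos_i) into every channel of
    # the neighbour features, which is how 3-D conformer geometry enters the model
    out = conv(x, pos, edge_index)
    print(out.shape)  # torch.Size([5, 32])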
--------------------------------------------------------------------------------

/deephop/calc_SC_RDKit.py:
--------------------------------------------------------------------------------
import os
from rdkit import Chem
from rdkit.Chem import AllChem, rdShapeHelpers
from rdkit.Chem.FeatMaps import FeatMaps
from rdkit import RDConfig

# Set up features to use in FeatureMap
fdefName = os.path.join(RDConfig.RDDataDir, 'BaseFeatures.fdef')
fdef = AllChem.BuildFeatureFactory(fdefName)

fmParams = {}
for k in fdef.GetFeatureFamilies():
    fparams = FeatMaps.FeatMapParams()
    fmParams[k] = fparams

keep = ('Donor', 'Acceptor', 'NegIonizable', 'PosIonizable',
        'ZnBinder', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe')

def get_FeatureMapScore(query_mol, ref_mol):
    featLists = []
    for m in [query_mol, ref_mol]:
        rawFeats = fdef.GetFeaturesForMol(m)
        # filter that list down to only include the ones we're interested in
        featLists.append([f for f in rawFeats if f.GetFamily() in keep])
    fms = [FeatMaps.FeatMap(feats=x, weights=[1] * len(x), params=fmParams) for x in featLists]
    fms[0].scoreMode = FeatMaps.FeatMapScoreMode.Best
    fm_score = fms[0].ScoreFeats(featLists[1]) / min(fms[0].GetNumFeatures(), len(featLists[1]))

    return fm_score

def calc_SC_RDKit_score(query_mol, ref_mol):
    fm_score = get_FeatureMapScore(query_mol, ref_mol)

    protrude_dist = rdShapeHelpers.ShapeProtrudeDist(query_mol, ref_mol,
                                                     allowReordering=False)
    # equal-weight combination of pharmacophore (feature map) and shape overlap
    SC_RDKit_score = 0.5 * fm_score + 0.5 * (1 - protrude_dist)

    return SC_RDKit_score
--------------------------------------------------------------------------------

/deephop/calc_scaffold_smilarity.py:
--------------------------------------------------------------------------------
# --*-- coding: utf-8 --*--

import multiprocessing

from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem.Scaffolds import MurckoScaffold
import os
from data_loader import load_data_frame

def calc_scaffold_similarity(s1: str, s2: str) -> float:
    mol1 = Chem.MolFromSmiles(s1)
    mol2 = Chem.MolFromSmiles(s2)
    if mol1 is None or mol2 is None:
        return -1.0
    try:
        scafold1 = MurckoScaffold.GetScaffoldForMol(mol1)
        scafold2 = MurckoScaffold.GetScaffoldForMol(mol2)
        f1 = AllChem.GetMorganFingerprint(scafold1, 3)
        f2 = AllChem.GetMorganFingerprint(scafold2, 3)
        return DataStructs.TanimotoSimilarity(f1, f2)
    except Exception:
        return -1.0


def process_one_protein(args):
    file, save_path = args
    save_dir = os.path.dirname(save_path)
    os.makedirs(save_dir, exist_ok=True)
    df = load_data_frame(file)
    df["delta_p"] = df["delta_p"].apply(lambda x: abs(x))
    df["score_scaffold"] = df[["ref_smiles", "target_smiles"]].apply(lambda x: calc_scaffold_similarity(x[0], x[1]),
                                                                     axis=1)
    df.to_csv(save_path, index=False)


if __name__ == '__main__':
    arg_list = []
    root_dir = "/data/u1/projects/mget/hopping_pairs"
    out_dir = "/data/u1/projects/mget/hopping_pairs_with_scaffold"
    for f in os.listdir(root_dir):
        arg_list.append((f"{root_dir}/{f}", f"{out_dir}/{f}"))

    pool = multiprocessing.Pool()
    results = pool.map(process_one_protein, arg_list)
    pool.close()
    pool.join()
--------------------------------------------------------------------------------
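An illustrative call of the two scoring helpers above (the SMILES are arbitrary examples, not from the dataset; both helpers return -1.0 when parsing or scoring fails):

```python
from rdkit import Chem
from rdkit.Chem import AllChem, rdMolAlign

from calc_SC_RDKit import calc_SC_RDKit_score
from calc_scaffold_smilarity import calc_scaffold_similarity

# 2D: Tanimoto similarity between Murcko scaffolds
print(calc_scaffold_similarity('c1ccccc1CCN', 'c1ccncc1CCN'))

# 3D: SC_RDKit needs molecules with embedded conformers
def embed(smiles):
    mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
    AllChem.EmbedMolecule(mol)
    AllChem.UFFOptimizeMolecule(mol)
    return mol

probe, ref = embed('c1ccncc1CCN'), embed('c1ccccc1CCN')
rdMolAlign.GetO3A(probe, ref).Align()  # align before scoring, as in make_pair.py
print(calc_SC_RDKit_score(probe, ref))
```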
/deephop/data_loader.py:
--------------------------------------------------------------------------------
# --*-- coding: utf-8 --*--

import csv
from typing import List

import pandas as pd
import rdkit.Chem.AllChem as Chem


def load_sdf_data(sdf_file):
    mols = Chem.SDMolSupplier(sdf_file)
    mols = [m for m in mols if m is not None]
    if len(mols) == 0:
        return None
    colums = list(mols[0].GetPropsAsDict().keys())
    colums.sort()
    colums.append('smiles')
    data_list = []
    for m in mols:
        kv = m.GetPropsAsDict()
        if 'smiles' not in kv.keys():
            kv['smiles'] = Chem.MolToSmiles(m)
        data_list.append([kv[k] for k in colums])

    return pd.DataFrame(data_list, columns=colums)


def detect_delimiter(source_file):
    with open(source_file) as r:
        first_line = r.readline()
    if '\t' in first_line:
        return '\t'
    if ',' in first_line:
        return ','
    return ' '


def has_header(head_line: List[str]):
    # if any entry parses as a SMILES string, the line is data, not a header
    for s in head_line:
        try:
            mol = Chem.MolFromSmiles(s)
            if mol is not None:
                return False
        except:
            continue
    return True


def get_csv_header(path: str) -> List[str]:
    """
    Returns the header of a data CSV file.

    :param path: Path to a CSV file.
    :return: A list of strings containing the strings in the comma-separated header.
    """
    with open(path) as f:
        header = next(csv.reader(f))

    return header


def get_xls_header(path: str) -> List[str]:
    pass


def load_data_frame(source_file) -> object:
    if source_file.endswith('csv') or source_file.endswith('txt') or source_file.endswith('smi'):
        df = pd.read_csv(source_file, delimiter=detect_delimiter(source_file))
    elif source_file.endswith('xls') or source_file.endswith('xlsx'):
        df = pd.read_excel(source_file)
    elif source_file.endswith('sdf'):
        df = load_sdf_data(source_file)
    else:
        print("cannot read %s" % source_file)
        df = None
    return df


if __name__ == '__main__':
    pass
--------------------------------------------------------------------------------

/deephop/graph_embedding.py:
--------------------------------------------------------------------------------
# --*-- coding: utf-8 --*--

from functools import partial
from typing import List

from split_data import TASKS
import pickle
import numpy

protein_embedding = pickle.load(open('protein_emb.pkl', 'rb'))

def __get_emb(use_graph_embedding: bool, embedding_dim: int, label_list: List[int]) -> numpy.ndarray:
    if embedding_dim == 0:
        return numpy.array([0.0])
    if use_graph_embedding:
        return numpy.stack([protein_embedding[TASKS[index]] for index in label_list])
    else:
        # one-hot encoding; build a fresh row per label so the rows do not
        # share storage (a single [[0.0] * dim] * n would alias every row)
        v = [[0.0] * embedding_dim for _ in label_list]
        for i, label in enumerate(label_list):
            v[i][label] = 1.0
        return numpy.array(v)

__condition_transformer = None

def init_condition_transformer(use_graph_embedding, embedding_dim):
    global __condition_transformer
    __condition_transformer = partial(__get_emb, use_graph_embedding, embedding_dim)

def get_emb(label_list):
    return __condition_transformer(label_list)
--------------------------------------------------------------------------------
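A short sketch of how these condition embeddings can be obtained (the label indices and the 768-dimensional embedding are assumptions consistent with the `-condition_dim 768` flag in the README; `protein_emb.pkl` must be present in the working directory, since it is loaded at import time):

```python
from graph_embedding import init_condition_transformer, get_emb

# graph-embedding mode, matching -use_graph_embedding -condition_dim 768
init_condition_transformer(True, 768)
cond = get_emb([0, 3, 7])  # one row per protein label index
print(cond.shape)          # e.g. (3, 768)
```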
/deephop/make_pair.py:
--------------------------------------------------------------------------------
# --*-- coding: utf-8 --*--

import multiprocessing
import os
from typing import List
import pandas as pd
from pandas import DataFrame
from rdkit import DataStructs, Chem
from rdkit.Chem import AllChem, rdMolAlign, Mol


from calc_SC_RDKit import calc_SC_RDKit_score
from data_loader import load_data_frame


def calc_2D_similarity(s1: str, s2: str) -> float:
    mol1 = Chem.MolFromSmiles(s1)
    mol2 = Chem.MolFromSmiles(s2)
    if mol1 is None or mol2 is None:
        return -1.0
    try:
        f1 = AllChem.GetMorganFingerprint(mol1, 3)
        f2 = AllChem.GetMorganFingerprint(mol2, 3)
        return DataStructs.TanimotoSimilarity(f1, f2)
    except Exception:
        return -1.0


def calc_3d_similarity(ref: str, gen: str) -> float:
    ref_mol = Chem.MolFromSmiles(ref)
    gen_mol = Chem.MolFromSmiles(gen)
    if ref_mol is None or gen_mol is None:
        return -1.0
    try:
        ref_mol = Chem.AddHs(ref_mol)
        Chem.AllChem.EmbedMolecule(ref_mol)
        Chem.AllChem.UFFOptimizeMolecule(ref_mol)

        gen_mol = Chem.AddHs(gen_mol)
        Chem.AllChem.EmbedMolecule(gen_mol)
        Chem.AllChem.UFFOptimizeMolecule(gen_mol)

        pyO3A = rdMolAlign.GetO3A(gen_mol, ref_mol).Align()
        return calc_SC_RDKit_score(gen_mol, ref_mol)
    except Exception:
        return -1.0


def prepare_conformer(smiles: str):
    try:
        ref_mol = Chem.MolFromSmiles(smiles)
        ref_mol = Chem.AddHs(ref_mol)
        Chem.AllChem.EmbedMolecule(ref_mol)
        Chem.AllChem.UFFOptimizeMolecule(ref_mol)
    except Exception:
        return None
    return ref_mol


def calc_3d_score(ref_mol: Mol, gen_mol: Mol):
    try:
        pyO3A = rdMolAlign.GetO3A(gen_mol, ref_mol).Align()
        return calc_SC_RDKit_score(gen_mol, ref_mol)
    except:
        return -1.0


def get_use_cpu_count():
    cpu_num = multiprocessing.cpu_count()
    use_cpu_num = max(int(cpu_num * 0.9), 1)
    return use_cpu_num


class MapFunc(object):
    def __init__(self, score_func, smiles_list):
        self.score_func = score_func
        self.smiles_list = smiles_list

    def __call__(self, args):
        ref = args[0]
        dest = args[1]
        return self.score_func(self.smiles_list[ref], self.smiles_list[dest])


def parallel_run(pair_list, func):
    pool = multiprocessing.Pool()
    results = pool.map(func, pair_list)

    pool.close()
    pool.join()

    return results


def make_pair(df, smilarity_2d_threshold: float = 0.6, smilarity_3d_threshold: float = 0.6,
              delta_pvalue: float = 1.0,
              split_ratio: List[float] = [0.8, 0.1, 0.1]) -> List[DataFrame]:
    data = []
    smiles_list = list(df['canonical_smiles'])
    total = len(smiles_list)
    # loc = df.columns.get_loc("pchembl_value")
    p_value_list = list(df['pchembl_value'])
    step_1_pairs = []
    # first make pairs, then split
    for j, ref in enumerate(smiles_list):
        for k in range(j + 1, total):
            # target = smiles_list[k]
            delta_p = p_value_list[k] - p_value_list[j]
            if abs(delta_p) < delta_pvalue:
                # delta_p is too small; skip this pair
                continue
            if delta_p > 0:
                step_1_pairs.append([j, k, delta_p])
            else:
                step_1_pairs.append([k, j, delta_p])

    func = MapFunc(calc_2D_similarity, smiles_list)
    score_2d_list = parallel_run(step_1_pairs, func)
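    # Stage 2 filter: keep only pairs whose 2D Tanimoto similarity lies in
    # (0, smilarity_2d_threshold) - i.e. both molecules parsed, but they are
    # dissimilar enough in 2D to count as a scaffold hop.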
    step_2_pairs = []
    for i, score_2d in enumerate(score_2d_list):
        if smilarity_2d_threshold > score_2d > 0:
            new_list = step_1_pairs[i]
            new_list.append(score_2d)
            step_2_pairs.append(new_list)

    # Stage 3 filter: embed a conformer once per molecule, then keep pairs whose
    # aligned 3D SC_RDKit score exceeds the 3D threshold.
    prepare_conformer_list = parallel_run(smiles_list, prepare_conformer)
    func_3d = MapFunc(calc_3d_score, prepare_conformer_list)
    score_3d_list = parallel_run(step_2_pairs, func_3d)
    for i, score_3d in enumerate(score_3d_list):
        if score_3d > smilarity_3d_threshold:
            ref, dest, p_value, score_2d = step_2_pairs[i]
            data.append([smiles_list[ref], smiles_list[dest], abs(p_value), score_2d, score_3d])

    return [pd.DataFrame(data=data, columns=['ref_smiles', 'target_smiles', 'delta_p', 'score_2d', 'score_3d', ])]


def get_data():
    df = load_data_frame('/home/aht/paper_code/shaungjia/chembl_webresource_client/scaffold_hopping_320target.csv')
    core_columns = ["canonical_smiles", "pchembl_value", "target_chembl_id"]
    df = df[df['canonical_smiles'].notnull() & df['pchembl_value'].notnull() & df['target_chembl_id'].notnull()]
    df = df[(df['canonical_smiles'].str.len() > 0) & (df['target_chembl_id'].str.len() > 0)]
    df = df.drop_duplicates(core_columns, keep='first')
    return df


if __name__ == '__main__':
    df = get_data()
    df = df.drop_duplicates(["canonical_smiles", "target_chembl_id"], keep='first')
    groups = df.groupby(['target_chembl_id'])
    data_list = []
    df_dicts = {k: v for k, v in groups}
    targets = [target_chembl_id for target_chembl_id, _ in df_dicts.items()]
    targets.sort(key=lambda k: len(df_dicts[k]), reverse=False)
    for target_chembl_id in targets:
        sub_df = df_dicts[target_chembl_id]
        tmp_save_file = f'tmp/{target_chembl_id}.csv'
        if os.path.isfile(tmp_save_file):
            continue
        # 300 is the threshold on the number of molecules
        if len(sub_df) > 300:
            print(f"process target {target_chembl_id}, size: {len(sub_df)}")
            # pd_train, pd_val, pd_test = make_pair(sub_df)
            splitted_data = make_pair(sub_df)
            for data in splitted_data:
                data['target_chembl_id'] = target_chembl_id
            data_list.append(splitted_data)
            assert len(splitted_data) == 1
            splitted_data[0].to_csv(f'tmp/{target_chembl_id}.csv', index=False)

    for i in range(len(data_list[0])):
        out_csv = f'/home/aht/paper_code/shaungjia/chembl_webresource_client/prepared_data/{i}.csv'
        result = pd.concat([v[i] for v in data_list])
        result.to_csv(out_csv, index=False)
--------------------------------------------------------------------------------

/deephop/onmt/GAT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GATLayer(nn.Module): 7 | def __init__(self, in_dim, out_dim, edge_dim): 8 | super(GATLayer, self).__init__() 9 | 10 | # equation (1) 11 | self.fc = nn.Linear(in_dim, out_dim, bias=False) 12 | # equation (2) 13 | self.attn_fc = nn.Linear(2 * out_dim + edge_dim, 1, bias=False) 14 | 15 | def edge_attention(self, edges): 16 | # edge UDF for equation (2) 17 | z2 = torch.cat([edges.src['z'], edges.dst['z'], edges.data['w']], dim=1) 18 | a = self.attn_fc(z2) 19 | return {'e': F.leaky_relu(a)} 20 | 21 | def message_func(self, edges): 22 | # message UDF for equation (3) & (4) 23 | return {'z': edges.src['z'], 'e': edges.data['e']} 24 | 25 | def 
reduce_func(self, nodes): 26 | # reduce UDF for equation (3) & (4) 27 | # equation (3) 28 | alpha = F.softmax(nodes.mailbox['e'], dim=1) 29 | # equation (4) 30 | h = torch.sum(alpha * nodes.mailbox['z'], dim=1) 31 | return {'h': h} 32 | 33 | def forward(self, g, h): 34 | # equation (1) 35 | z = self.fc(h) 36 | g.ndata['z'] = z 37 | # equation (2) 38 | g.apply_edges(self.edge_attention) 39 | # equation (3) & (4) 40 | g.update_all(self.message_func, self.reduce_func) 41 | return g.ndata.pop('h') 42 | 43 | 44 | 45 | class MultiHeadGATLayer(nn.Module): 46 | def __init__(self, in_dim, out_dim, edge_dim, num_heads, merge='mean'): 47 | super(MultiHeadGATLayer, self).__init__() 48 | self.heads = nn.ModuleList() 49 | for i in range(num_heads): 50 | self.heads.append(GATLayer(in_dim, out_dim,edge_dim)) 51 | self.merge = merge 52 | 53 | def forward(self, g, h): 54 | head_outs = [attn_head(g, h) for attn_head in self.heads] 55 | if self.merge == 'cat': 56 | # concat on the output feature dimension (dim=1) 57 | return torch.cat(head_outs, dim=1) 58 | else: 59 | # merge using average 60 | return torch.mean(torch.stack(head_outs), dim = 0) 61 | 62 | 63 | -------------------------------------------------------------------------------- /deephop/onmt/GATGATE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import dgl 5 | from dgl.nn.pytorch.conv import GATConv 6 | from onmt.GAT import * 7 | 8 | class GATGATE(nn.Module): 9 | def __init__(self, in_feats, out_feats, edge_dim, num_heads): 10 | super(GATGATE, self).__init__() 11 | self.in_feats = in_feats 12 | self.out_feats = out_feats 13 | self.GATLayers = nn.ModuleList([]) 14 | self.GATLayers.append(MultiHeadGATLayer(in_feats, out_feats, edge_dim, num_heads)) 15 | self.GATLayers.append(MultiHeadGATLayer(out_feats, out_feats, edge_dim, num_heads)) 16 | self.GATLayers.append(MultiHeadGATLayer(out_feats, out_feats, edge_dim, num_heads)) 17 | self.seq_fc1 = nn.Linear(out_feats, out_feats) 18 | self.seq_fc2 = nn.Linear(out_feats, out_feats) 19 | self.bias = nn.Parameter(torch.rand(1, out_feats)) 20 | torch.nn.init.uniform_(self.bias, a=-0.2, b=0.2) 21 | 22 | def forward(self, g, features): 23 | n = features.size(0) 24 | 25 | for i in range(len(self.GATLayers)): 26 | h = self.GATLayers[i](g, features) # N * Heads * D 27 | if i < len(self.GATLayers)-1: 28 | h = F.elu(h) 29 | # h = torch.mean(h, dim = 1) # N * D 30 | # print(i, ' layer', h.size()) 31 | if i==0: 32 | features = h 33 | continue 34 | z = torch.sigmoid(self.seq_fc1(h) + self.seq_fc2(features) + self.bias.expand(n, self.out_feats)) 35 | features = z * h + (1 - z) * features 36 | 37 | return features 38 | 39 | -------------------------------------------------------------------------------- /deephop/onmt/GCN.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import dgl.function as fn 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import time 8 | 9 | def gcn_reduce(nodes): 10 | msgs = torch.cat((nodes.mailbox['h'], nodes.data['h'].unsqueeze(1)), dim = 1) 11 | msgs = torch.mean(msgs, dim = 1) 12 | return {'h': msgs} 13 | 14 | def gcn_msg(edges): 15 | 16 | return {'h': edges.src['h']} 17 | 18 | 19 | 20 | class NodeApplyModule(nn.Module): 21 | def __init__(self, in_feats, out_feats): 22 | super(NodeApplyModule, self).__init__() 23 | self.fc = nn.Linear(in_feats, out_feats, bias = 
True) 24 | 25 | 26 | def forward(self, node): 27 | h = self.fc(node.data['h']) 28 | h = F.relu(h) 29 | return {'h' : h} 30 | 31 | 32 | class GCN(nn.Module): 33 | def __init__(self, in_feats, out_feats): 34 | super(GCN, self).__init__() 35 | self.apply_mod = NodeApplyModule(in_feats, out_feats) 36 | 37 | def forward(self, g, features): 38 | g.ndata['h'] = features 39 | 40 | g.update_all(gcn_msg, gcn_reduce) 41 | 42 | g.apply_nodes(func = self.apply_mod) 43 | 44 | 45 | return g.ndata.pop('h') 46 | 47 | 48 | # 2 layers GCN 49 | class Net(nn.Module): 50 | def __init__(self): 51 | super(Net, self).__init__() 52 | self.gcn1 = GCN(256, 256) 53 | self.gcn2 = GCN(256, 256) 54 | # self.fc = nn.Linear(70, 15) 55 | 56 | def forward(self, g, features): 57 | x = self.gcn1(g, features) 58 | x = self.gcn2(g, x) 59 | g.ndata['h'] = x 60 | # hg = dgl.mean_nodes(g, 'h') 61 | return x 62 | -------------------------------------------------------------------------------- /deephop/onmt/MPNNs/MPNN.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | from .MessageFunction import MessageFunction 5 | from .UpdateFunction import UpdateFunction 6 | from .ReadoutFunction import ReadoutFunction 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | 12 | 13 | class MPNN(nn.Module): 14 | """ 15 | MPNN as proposed by Gilmer et al.. 16 | 17 | This class implements the whole Gilmer et al. model following the functions Message, Update and Readout. 18 | 19 | Parameters 20 | ---------- 21 | in_n : int list 22 | Sizes for the node and edge features. 23 | hidden_state_size : int 24 | Size of the hidden states (the input will be padded with 0's to this size). 25 | message_size : int 26 | Message function output vector size. 27 | n_layers : int 28 | Number of iterations Message+Update (weight tying). 29 | l_target : int 30 | Size of the output. 31 | type : str (Optional) 32 | Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector. 
33 | """ 34 | 35 | def __init__(self, in_n, hidden_state_size, message_size, n_layers, l_target, type='regression'): 36 | super(MPNN, self).__init__() 37 | 38 | # Define message 39 | self.m = nn.ModuleList( 40 | [MessageFunction('mpnn', args={'edge_feat': in_n[1], 'in': hidden_state_size, 41 | 'out': message_size})]) 42 | 43 | # Define Update 44 | self.u = nn.ModuleList([UpdateFunction('mpnn',args={'in_m': message_size, 45 | 'out': hidden_state_size})]) 46 | 47 | # Define Readout 48 | self.r = ReadoutFunction('mpnn',args={'in': hidden_state_size, 49 | 'target': l_target}) 50 | 51 | self.type = type 52 | 53 | self.args = {} 54 | self.args['out'] = hidden_state_size 55 | 56 | self.n_layers = n_layers 57 | 58 | def forward(self, g, h_in, e): 59 | 60 | # only use the last layer embedding 61 | 62 | # Padding to some larger dimension d 63 | h_t = torch.cat([h_in, Variable( 64 | torch.zeros(h_in.size(0), h_in.size(1), self.args['out'] - h_in.size(2)).type_as(h_in.data))], 2) 65 | 66 | 67 | 68 | for t in range(0, self.n_layers): 69 | e_aux = e.view(-1, e.size(3)) 70 | 71 | h_aux = h_t.view(-1, h_t.size(2)) 72 | 73 | m = self.m[0].forward(h_t, h_aux, e_aux) 74 | m = m.view(h_t.size(0), h_t.size(1), -1, m.size(1)) 75 | 76 | # Nodes without edge set message to 0 77 | m = torch.unsqueeze(g, 3).expand_as(m) * m 78 | 79 | m = torch.sum(m, 1) 80 | 81 | h_t = self.u[0].forward(h_t, m) 82 | 83 | # Delete virtual nodes 84 | h_t = (torch.sum(h_in, 2)[..., None].expand_as(h_t) > 0).type_as(h_t) * h_t 85 | 86 | 87 | return h_t 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | # h = [] 97 | 98 | # # Padding to some larger dimension d 99 | # h_t = torch.cat([h_in, Variable( 100 | # torch.zeros(h_in.size(0), h_in.size(1), self.args['out'] - h_in.size(2)).type_as(h_in.data))], 2) 101 | 102 | # h.append(h_t.clone()) 103 | 104 | # # Layer 105 | # for t in range(0, self.n_layers): 106 | # e_aux = e.view(-1, e.size(3)) 107 | 108 | # h_aux = h[t].view(-1, h[t].size(2)) 109 | 110 | # m = self.m[0].forward(h[t], h_aux, e_aux) 111 | # m = m.view(h[0].size(0), h[0].size(1), -1, m.size(1)) 112 | 113 | # # Nodes without edge set message to 0 114 | # m = torch.unsqueeze(g, 3).expand_as(m) * m 115 | 116 | # m = torch.squeeze(torch.sum(m, 1),1) 117 | 118 | # h_t = self.u[0].forward(h[t], m) 119 | 120 | # # Delete virtual nodes 121 | # h_t = (torch.sum(h_in, 2)[..., None].expand_as(h_t) > 0).type_as(h_t) * h_t 122 | # h.append(h_t) 123 | 124 | # # Readout 125 | # #res = self.r.forward(h) 126 | 127 | # # if self.type == 'classification': 128 | # # res = nn.LogSoftmax()(res) 129 | # return h -------------------------------------------------------------------------------- /deephop/onmt/MPNNs/MPNNv2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | from .MessageFunction import MessageFunction 5 | from .UpdateFunction import UpdateFunction 6 | from .ReadoutFunction import ReadoutFunction 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | 12 | 13 | class MPNNv2(nn.Module): 14 | """ 15 | MPNN as proposed by Battaglia et al.. 16 | 17 | This class implements the whole Battaglia et al. model following the functions proposed by Gilmer et al. as 18 | Message, Update and Readout. 19 | 20 | Parameters 21 | ---------- 22 | in_n : int list 23 | Sizes for the node and edge features. 24 | out_message : int list 25 | Output sizes for the different Message functions. 
26 | out_update : int list 27 | Output sizes for the different Update functions. 28 | l_target : int 29 | Size of the output. 30 | type : str (Optional) 31 | Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector. 32 | """ 33 | 34 | def __init__(self, in_n, out_message, out_update, l_target, type='regression'): 35 | super(MPNNv2, self).__init__() 36 | 37 | n_layers = len(out_update) 38 | 39 | # Define message 1 & 2 40 | self.m = nn.ModuleList([MessageFunction('intnet', args={'in': 2*in_n[0] + in_n[1], 'out': out_message[i]}) 41 | if i == 0 else 42 | MessageFunction('intnet', args={'in': 2*out_update[i-1] + in_n[1], 'out': out_message[i]}) 43 | for i in range(n_layers)]) 44 | 45 | # Define Update 1 & 2 46 | self.u = nn.ModuleList([UpdateFunction('intnet', args={'in': in_n[0]+out_message[i], 'out': out_update[i]}) 47 | if i == 0 else 48 | UpdateFunction('intnet', args={'in': out_update[i-1]+out_message[i], 'out': out_update[i]}) 49 | for i in range(n_layers)]) 50 | 51 | # Define Readout 52 | self.r = ReadoutFunction('intnet', args={'in': out_update[-1], 'target': l_target}) 53 | 54 | self.type = type 55 | 56 | def forward(self, g, h_in, e): 57 | 58 | h = [] 59 | h.append(h_in) 60 | 61 | # Layer 62 | for t in range(0, len(self.m)): 63 | 64 | u_args = self.u[t].get_args() 65 | h_t = Variable(torch.zeros(h_in.size(0), h_in.size(1), u_args['out']).type_as(h[t].data)) 66 | 67 | # Apply one layer pass (Message + Update) 68 | for v in range(0, h_in.size(1)): 69 | 70 | m = self.m[t].forward(h[t][:, v, :], h[t], e[:, v, :, :]) 71 | 72 | # Nodes without edge set message to 0 73 | m = g[:, v, :,None].expand_as(m) * m 74 | 75 | m = torch.sum(m, 1) 76 | 77 | # Interaction Net 78 | opt = {} 79 | opt['x_v'] = Variable(torch.Tensor([]).type_as(m.data)) 80 | 81 | h_t[:, v, :] = self.u[t].forward(h[t][:, v, :], m, opt) 82 | 83 | h.append(h_t.clone()) 84 | 85 | # Readout 86 | res = self.r.forward(h) 87 | if self.type == 'classification': 88 | res = nn.LogSoftmax()(res) 89 | return res 90 | -------------------------------------------------------------------------------- /deephop/onmt/MPNNs/MPNNv3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | from .MessageFunction import MessageFunction 5 | from .UpdateFunction import UpdateFunction 6 | from .ReadoutFunction import ReadoutFunction 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.autograd import Variable 11 | 12 | 13 | 14 | class MPNNv3(nn.Module): 15 | """ 16 | MPNN as proposed by Duvenaud et al.. 17 | 18 | This class implements the whole Duvenaud et al. model following the functions proposed by Gilmer et al. as 19 | Message, Update and Readout. 20 | 21 | Parameters 22 | ---------- 23 | d : int list. 24 | Possible degrees for the input graph. 25 | in_n : int list 26 | Sizes for the node and edge features. 27 | out_update : int list 28 | Output sizes for the different Update functions. 29 | hidden_state_readout : int 30 | Input size for the neural net used inside the readout function. 31 | l_target : int 32 | Size of the output. 33 | type : str (Optional) 34 | Classification | [Regression (default)]. If classification, LogSoftmax layer is applied to the output vector. 
35 | """ 36 | 37 | def __init__(self, d, in_n, out_update, hidden_state_readout, l_target, type='regression'): 38 | super(MPNNv3, self).__init__() 39 | 40 | n_layers = len(out_update) 41 | 42 | # Define message 1 & 2 43 | self.m = nn.ModuleList([MessageFunction('duvenaud') for _ in range(n_layers)]) 44 | 45 | # Define Update 1 & 2 46 | self.u = nn.ModuleList([UpdateFunction('duvenaud', args={'deg': d, 'in': self.m[i].get_out_size(in_n[0], in_n[1]), 'out': out_update[0]}) if i == 0 else 47 | UpdateFunction('duvenaud', args={'deg': d, 'in': self.m[i].get_out_size(out_update[i-1], in_n[1]), 'out': out_update[i]}) for i in range(n_layers)]) 48 | 49 | # Define Readout 50 | self.r = ReadoutFunction('duvenaud', 51 | args={'layers': len(self.m) + 1, 52 | 'in': [in_n[0] if i == 0 else out_update[i-1] for i in range(n_layers+1)], 53 | 'out': hidden_state_readout, 54 | 'target': l_target}) 55 | 56 | self.type = type 57 | 58 | def forward(self, g, h_in, e, plotter=None): 59 | 60 | h = [] 61 | h.append(h_in) 62 | 63 | # Layer 64 | for t in range(0, len(self.m)): 65 | 66 | u_args = self.u[t].get_args() 67 | 68 | h_t = Variable(torch.zeros(h_in.size(0), h_in.size(1), u_args['out']).type_as(h[t].data)) 69 | 70 | # Apply one layer pass (Message + Update) 71 | for v in range(0, h_in.size(1)): 72 | 73 | m = self.m[t].forward(h[t][:, v, :], h[t], e[:, v, :]) 74 | 75 | # Nodes without edge set message to 0 76 | m = g[:, v, :, None].expand_as(m) * m 77 | 78 | m = torch.sum(m, 1) 79 | 80 | # Duvenaud 81 | deg = torch.sum(g[:, v, :].data, 1) 82 | 83 | # Separate degrees 84 | for i in range(len(u_args['deg'])): 85 | ind = deg == u_args['deg'][i] 86 | ind = Variable(torch.squeeze(torch.nonzero(torch.squeeze(ind))), volatile=True) 87 | 88 | opt = {'deg': i} 89 | 90 | # Update 91 | if len(ind) != 0: 92 | aux = self.u[t].forward(torch.index_select(h[t], 0, ind)[:, v, :], torch.index_select(m, 0, ind), opt) 93 | ind = ind.data.cpu().numpy() 94 | for j in range(len(ind)): 95 | h_t[ind[j], v, :] = aux[j, :] 96 | 97 | if plotter is not None: 98 | num_feat = h_t.size(2) 99 | color = h_t[0,:,:].data.cpu().numpy() 100 | for i in range(num_feat): 101 | plotter(color[:, i], 'layer_' + str(t) + '_element_' + str(i) + '.png') 102 | 103 | h.append(h_t.clone()) 104 | # Readout 105 | res = self.r.forward(h) 106 | if self.type == 'classification': 107 | res = nn.LogSoftmax()(res) 108 | return res 109 | -------------------------------------------------------------------------------- /deephop/onmt/MPNNs/MessageFunction.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | MessageFunction.py: Propagates a message depending on two nodes and their common edge. 
6 | 7 | Usage: 8 | 9 | """ 10 | 11 | from __future__ import print_function 12 | 13 | 14 | from .nnet import NNet 15 | 16 | import numpy as np 17 | import os 18 | import argparse 19 | import time 20 | import torch 21 | 22 | import torch.nn as nn 23 | from torch.autograd.variable import Variable 24 | 25 | 26 | class MessageFunction(nn.Module): 27 | 28 | # Constructor 29 | def __init__(self, message_def='mpnn', args={}): 30 | super(MessageFunction, self).__init__() 31 | self.m_definition = '' 32 | self.m_function = None 33 | self.args = {} 34 | self.__set_message(message_def, args) 35 | 36 | # Message from h_v to h_w through e_vw 37 | def forward(self, h_v, h_w, e_vw, args=None): 38 | return self.m_function(h_v, h_w, e_vw, args) 39 | 40 | # Set a message function 41 | def __set_message(self, message_def, args={}): 42 | self.m_definition = message_def.lower() 43 | 44 | self.m_function = { 45 | 'duvenaud': self.m_duvenaud, 46 | 'intnet': self.m_intnet, 47 | 'mpnn': self.m_mpnn, 48 | }.get(self.m_definition, None) 49 | 50 | if self.m_function is None: 51 | print('WARNING!: Message Function has not been set correctly\n\tIncorrect definition ' + message_def) 52 | quit() 53 | 54 | init_parameters = { 55 | 'duvenaud': self.init_duvenaud, 56 | 'intnet': self.init_intnet, 57 | 'mpnn': self.init_mpnn 58 | }.get(self.m_definition, lambda x: (nn.ParameterList([]), nn.ModuleList([]), {})) 59 | 60 | self.learn_args, self.learn_modules, self.args = init_parameters(args) 61 | 62 | self.m_size = { 63 | 'duvenaud': self.out_duvenaud, 64 | 'intnet': self.out_intnet, 65 | 'mpnn': self.out_mpnn 66 | }.get(self.m_definition, None) 67 | 68 | # Get the name of the used message function 69 | def get_definition(self): 70 | return self.m_definition 71 | 72 | # Get the message function arguments 73 | def get_args(self): 74 | return self.args 75 | 76 | # Get Output size 77 | def get_out_size(self, size_h, size_e, args=None): 78 | return self.m_size(size_h, size_e, args) 79 | 80 | 81 | # Duvenaud et al. (2015), Convolutional Networks for Learning Molecular Fingerprints 82 | def m_duvenaud(self, h_v, h_w, e_vw, args): 83 | m = torch.cat([h_w, e_vw], 2) 84 | return m 85 | 86 | def out_duvenaud(self, size_h, size_e, args): 87 | return size_h + size_e 88 | 89 | def init_duvenaud(self, params): 90 | learn_args = [] 91 | learn_modules = [] 92 | args = {} 93 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 94 | 95 | # Battaglia et al. (2016), Interaction Networks 96 | def m_intnet(self, h_v, h_w, e_vw, args): 97 | m = torch.cat([h_v[:, None, :].expand_as(h_w), h_w, e_vw], 2) 98 | b_size = m.size() 99 | 100 | m = m.view(-1, b_size[2]) 101 | 102 | m = self.learn_modules[0](m) 103 | m = m.view(b_size[0], b_size[1], -1) 104 | return m 105 | 106 | def out_intnet(self, size_h, size_e, args): 107 | return self.args['out'] 108 | 109 | def init_intnet(self, params): 110 | learn_args = [] 111 | learn_modules = [] 112 | args = {} 113 | args['in'] = params['in'] 114 | args['out'] = params['out'] 115 | learn_modules.append(NNet(n_in=params['in'], n_out=params['out'])) 116 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 117 | 118 | # Gilmer et al. 
(2017), Neural Message Passing for Quantum Chemistry 119 | def m_mpnn(self, h_v, h_w, e_vw, opt={}): 120 | # Matrices for each edge 121 | edge_output = self.learn_modules[0](e_vw) 122 | edge_output = edge_output.view(-1, self.args['out'], self.args['in']) 123 | 124 | #h_w_rows = h_w[..., None].expand(h_w.size(0), h_v.size(1), h_w.size(1)).contiguous() 125 | h_w_rows = h_w[..., None].expand(h_w.size(0), h_w.size(1), h_v.size(1)).contiguous() 126 | 127 | h_w_rows = h_w_rows.view(-1, self.args['in']) 128 | 129 | h_multiply = torch.bmm(edge_output, torch.unsqueeze(h_w_rows,2)) 130 | 131 | m_new = torch.squeeze(h_multiply) 132 | 133 | return m_new 134 | 135 | def out_mpnn(self, size_h, size_e, args): 136 | return self.args['out'] 137 | 138 | def init_mpnn(self, params): 139 | learn_args = [] 140 | learn_modules = [] 141 | args = {} 142 | 143 | args['in'] = params['in'] 144 | args['out'] = params['out'] 145 | 146 | # Define a parameter matrix A for each edge label. 147 | learn_modules.append(NNet(n_in=params['edge_feat'], n_out=(params['in']*params['out']))) 148 | 149 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 150 | 151 | -------------------------------------------------------------------------------- /deephop/onmt/MPNNs/UpdateFunction.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | UpdateFunction.py: Updates the nodes using the previous state and the message. 6 | 7 | Usage: 8 | 9 | """ 10 | 11 | from __future__ import print_function 12 | 13 | # Own modules 14 | 15 | from .MessageFunction import MessageFunction 16 | from .nnet import NNet 17 | 18 | import numpy as np 19 | import time 20 | import os 21 | import argparse 22 | import torch 23 | 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | from torch.autograd.variable import Variable 27 | 28 | #dtype = torch.cuda.FloatTensor 29 | dtype = torch.FloatTensor 30 | 31 | 32 | class UpdateFunction(nn.Module): 33 | 34 | # Constructor 35 | def __init__(self, update_def='nn', args={}): 36 | super(UpdateFunction, self).__init__() 37 | self.u_definition = '' 38 | self.u_function = None 39 | self.args = {} 40 | self.__set_update(update_def, args) 41 | 42 | # Update node hv given message mv 43 | def forward(self, h_v, m_v, opt={}): 44 | return self.u_function(h_v, m_v, opt) 45 | 46 | # Set update function 47 | def __set_update(self, update_def, args): 48 | self.u_definition = update_def.lower() 49 | 50 | self.u_function = { 51 | 'duvenaud': self.u_duvenaud, 52 | 'intnet': self.u_intnet, 53 | 'mpnn': self.u_mpnn 54 | }.get(self.u_definition, None) 55 | 56 | if self.u_function is None: 57 | print('WARNING!: Update Function has not been set correctly\n\tIncorrect definition ' + update_def) 58 | 59 | init_parameters = { 60 | 'duvenaud': self.init_duvenaud, 61 | 'intnet': self.init_intnet, 62 | 'mpnn': self.init_mpnn 63 | }.get(self.u_definition, lambda x: (nn.ParameterList([]), nn.ModuleList([]), {})) 64 | 65 | self.learn_args, self.learn_modules, self.args = init_parameters(args) 66 | 67 | # Get the name of the used update function 68 | def get_definition(self): 69 | return self.u_definition 70 | 71 | # Get the update function arguments 72 | def get_args(self): 73 | return self.args 74 | 75 | # Duvenaud 76 | def u_duvenaud(self, h_v, m_v, opt): 77 | 78 | param_sz = self.learn_args[0][opt['deg']].size() 79 | parameter_mat = torch.t(self.learn_args[0][opt['deg']])[None, ...].expand(m_v.size(0), param_sz[1], 
param_sz[0]) 80 | 81 | #print(parameter_mat.size()) 82 | #print(m_v.size()) 83 | #print(torch.transpose(m_v.unsqueeze(-2), 1, 2).size()) 84 | 85 | #aux = torch.bmm(parameter_mat, torch.transpose(m_v, 1, 2)) 86 | aux = torch.bmm(parameter_mat, torch.transpose(m_v.unsqueeze(-2), 1, 2)) 87 | 88 | return torch.transpose(torch.nn.Sigmoid()(aux), 1, 2) 89 | 90 | def init_duvenaud(self, params): 91 | learn_args = [] 92 | learn_modules = [] 93 | args = {} 94 | 95 | # Filter degree 0 (the message will be 0 and therefore there is no update 96 | args['deg'] = [i for i in params['deg'] if i!=0] 97 | args['in'] = params['in'] 98 | args['out'] = params['out'] 99 | 100 | # Define a parameter matrix H for each degree. 101 | learn_args.append(torch.nn.Parameter(torch.randn(len(args['deg']), args['in'], args['out']))) 102 | 103 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 104 | 105 | # Battaglia et al. (2016), Interaction Networks 106 | def u_intnet(self, h_v, m_v, opt): 107 | if opt['x_v'].ndimension(): 108 | input_tensor = torch.cat([h_v, opt['x_v'], torch.squeeze(m_v)], 1) 109 | else: 110 | input_tensor = torch.cat([h_v, torch.squeeze(m_v)], 1) 111 | 112 | return self.learn_modules[0](input_tensor) 113 | 114 | def init_intnet(self, params): 115 | learn_args = [] 116 | learn_modules = [] 117 | args = {} 118 | 119 | args['in'] = params['in'] 120 | args['out'] = params['out'] 121 | 122 | learn_modules.append(NNet(n_in=params['in'], n_out=params['out'])) 123 | 124 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 125 | 126 | def u_mpnn(self, h_v, m_v, opt={}): 127 | h_in = h_v.view(-1,h_v.size(2)) 128 | 129 | m_in = m_v.view(-1,m_v.size(2)) 130 | 131 | h_new = self.learn_modules[0](m_in[None,...],h_in[None,...])[0] # 0 or 1??? 
132 | return torch.squeeze(h_new).view(h_v.size()) 133 | 134 | def init_mpnn(self, params): 135 | learn_args = [] 136 | learn_modules = [] 137 | args = {} 138 | 139 | args['in_m'] = params['in_m'] 140 | args['out'] = params['out'] 141 | 142 | # GRU 143 | learn_modules.append(nn.GRU(params['in_m'], params['out'])) 144 | 145 | return nn.ParameterList(learn_args), nn.ModuleList(learn_modules), args 146 | -------------------------------------------------------------------------------- /deephop/onmt/MPNNs/nnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class NNet(nn.Module): 9 | 10 | def __init__(self, n_in, n_out, hlayers=(128, 256, 128)): 11 | super(NNet, self).__init__() 12 | self.n_hlayers = len(hlayers) 13 | self.fcs = nn.ModuleList([nn.Linear(n_in, hlayers[i]) if i == 0 else 14 | nn.Linear(hlayers[i-1], n_out) if i == self.n_hlayers else 15 | nn.Linear(hlayers[i-1], hlayers[i]) for i in range(self.n_hlayers+1)]) 16 | 17 | def forward(self, x): 18 | x = x.contiguous().view(-1, self.num_flat_features(x)) 19 | for i in range(self.n_hlayers): 20 | x = F.relu(self.fcs[i](x)) 21 | x = self.fcs[-1](x) 22 | return x 23 | 24 | def num_flat_features(self, x): 25 | size = x.size()[1:] # all dimensions except the batch dimension 26 | num_features = 1 27 | for s in size: 28 | num_features *= s 29 | return num_features 30 | 31 | 32 | def main(): 33 | net = NNet(n_in=100, n_out=20) 34 | print(net) 35 | 36 | if __name__=='__main__': 37 | main() 38 | -------------------------------------------------------------------------------- /deephop/onmt/MYGCN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import dgl 5 | from dgl.nn.pytorch.conv import GATConv 6 | 7 | class MGCN(nn.Module): 8 | def __init__(self, in_feats, out_feats, n_layers): 9 | super(MGCN, self).__init__() 10 | self.in_feats = in_feats 11 | self.out_feats = out_feats 12 | self.GATLayers = nn.ModuleList([]) 13 | self.GATLayers.append(GATConv(in_feats, out_feats, num_heads = 2, activation=None)) 14 | self.GATLayers.append(GATConv(out_feats, out_feats, num_heads = 2, activation=None)) 15 | self.GATLayers.append(GATConv(out_feats, out_feats, num_heads = 2, activation=None)) 16 | self.seq_fc1 = nn.Linear(out_feats, out_feats) 17 | self.seq_fc2 = nn.Linear(out_feats, out_feats) 18 | self.bias = nn.Parameter(torch.rand(1, out_feats)) 19 | torch.nn.init.uniform_(self.bias, a=-0.2, b=0.2) 20 | 21 | def forward(self, g, features): 22 | n = features.size(0) 23 | 24 | for i in range(len(self.GATLayers)): 25 | h = self.GATLayers[i](g, features) # N * Heads * D 26 | if i < len(self.GATLayers)-1: 27 | h = F.elu(h) 28 | h = torch.mean(h, dim = 1) # N * D 29 | if i==0: 30 | features = h 31 | continue 32 | z = torch.sigmoid(self.seq_fc1(h) + self.seq_fc2(features) + self.bias.expand(n, self.out_feats)) 33 | features = z * h + (1 - z) * features 34 | 35 | return features 36 | 37 | -------------------------------------------------------------------------------- /deephop/onmt/__init__.py: -------------------------------------------------------------------------------- 1 | """ Main entry point of the ONMT library """ 2 | from __future__ import division, print_function 3 | 4 | import onmt.inputters 5 | import onmt.encoders 6 | import onmt.decoders 7 | import onmt.models 8 | import 
onmt.utils 9 | import onmt.modules 10 | from onmt.trainer import Trainer 11 | import sys 12 | import onmt.utils.optimizers 13 | onmt.utils.optimizers.Optim = onmt.utils.optimizers.Optimizer 14 | sys.modules["onmt.Optim"] = onmt.utils.optimizers 15 | 16 | # For Flake 17 | __all__ = [onmt.inputters, onmt.encoders, onmt.decoders, onmt.models, 18 | onmt.utils, onmt.modules, "Trainer"] 19 | 20 | __version__ = "0.4.1" 21 | -------------------------------------------------------------------------------- /deephop/onmt/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining decoders.""" 2 | -------------------------------------------------------------------------------- /deephop/onmt/decoders/cnn_decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of the CNN Decoder part of 3 | "Convolutional Sequence to Sequence Learning" 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | import onmt.modules 9 | from onmt.decoders.decoder import DecoderState 10 | from onmt.utils.misc import aeq 11 | from onmt.utils.cnn_factory import shape_transform, GatedConv 12 | 13 | SCALE_WEIGHT = 0.5 ** 0.5 14 | 15 | 16 | class CNNDecoder(nn.Module): 17 | """ 18 | Decoder built on CNN, based on :cite:`DBLP:journals/corr/GehringAGYD17`. 19 | 20 | 21 | Consists of residual convolutional layers, with ConvMultiStepAttention. 22 | """ 23 | 24 | def __init__(self, num_layers, hidden_size, attn_type, 25 | copy_attn, cnn_kernel_width, dropout, embeddings): 26 | super(CNNDecoder, self).__init__() 27 | 28 | # Basic attributes. 29 | self.decoder_type = 'cnn' 30 | self.num_layers = num_layers 31 | self.hidden_size = hidden_size 32 | self.cnn_kernel_width = cnn_kernel_width 33 | self.embeddings = embeddings 34 | self.dropout = dropout 35 | 36 | # Build the CNN. 37 | input_size = self.embeddings.embedding_size 38 | self.linear = nn.Linear(input_size, self.hidden_size) 39 | self.conv_layers = nn.ModuleList() 40 | for _ in range(self.num_layers): 41 | self.conv_layers.append( 42 | GatedConv(self.hidden_size, self.cnn_kernel_width, 43 | self.dropout, True)) 44 | 45 | self.attn_layers = nn.ModuleList() 46 | for _ in range(self.num_layers): 47 | self.attn_layers.append( 48 | onmt.modules.ConvMultiStepAttention(self.hidden_size)) 49 | 50 | # CNNDecoder has its own attention mechanism. 51 | # Set up a separated copy attention layer, if needed. 52 | self._copy = False 53 | if copy_attn: 54 | self.copy_attn = onmt.modules.GlobalAttention( 55 | hidden_size, attn_type=attn_type) 56 | self._copy = True 57 | 58 | def forward(self, tgt, memory_bank, state, memory_lengths=None, step=None): 59 | """ See :obj:`onmt.modules.RNNDecoderBase.forward()`""" 60 | # NOTE: memory_lengths is only here for compatibility reasons 61 | # with onmt.modules.RNNDecoderBase.forward() 62 | # CHECKS 63 | assert isinstance(state, CNNDecoderState) 64 | _, tgt_batch, _ = tgt.size() 65 | _, contxt_batch, _ = memory_bank.size() 66 | aeq(tgt_batch, contxt_batch) 67 | # END CHECKS 68 | 69 | if state.previous_input is not None: 70 | tgt = torch.cat([state.previous_input, tgt], 0) 71 | 72 | # Initialize return variables. 
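# attns collects one list of tensors per attention type: "std" is always
# produced, and a "copy" entry is added only when the copy mechanism is enabled.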
73 | outputs = [] 74 | attns = {"std": []} 75 | assert not self._copy, "Copy mechanism not yet tested in conv2conv" 76 | if self._copy: 77 | attns["copy"] = [] 78 | 79 | emb = self.embeddings(tgt) 80 | assert emb.dim() == 3 # len x batch x embedding_dim 81 | 82 | tgt_emb = emb.transpose(0, 1).contiguous() 83 | # The output of CNNEncoder. 84 | src_memory_bank_t = memory_bank.transpose(0, 1).contiguous() 85 | # The combination of output of CNNEncoder and source embeddings. 86 | src_memory_bank_c = state.init_src.transpose(0, 1).contiguous() 87 | 88 | # Run the forward pass of the CNNDecoder. 89 | emb_reshape = tgt_emb.contiguous().view( 90 | tgt_emb.size(0) * tgt_emb.size(1), -1) 91 | linear_out = self.linear(emb_reshape) 92 | x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1) 93 | x = shape_transform(x) 94 | 95 | pad = torch.zeros(x.size(0), x.size(1), 96 | self.cnn_kernel_width - 1, 1) 97 | 98 | pad = pad.type_as(x) 99 | base_target_emb = x 100 | 101 | for conv, attention in zip(self.conv_layers, self.attn_layers): 102 | new_target_input = torch.cat([pad, x], 2) 103 | out = conv(new_target_input) 104 | c, attn = attention(base_target_emb, out, 105 | src_memory_bank_t, src_memory_bank_c) 106 | x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT 107 | output = x.squeeze(3).transpose(1, 2) 108 | 109 | # Process the result and update the attentions. 110 | outputs = output.transpose(0, 1).contiguous() 111 | if state.previous_input is not None: 112 | outputs = outputs[state.previous_input.size(0):] 113 | attn = attn[:, state.previous_input.size(0):].squeeze() 114 | attn = torch.stack([attn]) 115 | attns["std"] = attn 116 | if self._copy: 117 | attns["copy"] = attn 118 | 119 | # Update the state. 120 | state.update_state(tgt) 121 | 122 | return outputs, state, attns 123 | 124 | def init_decoder_state(self, _, memory_bank, enc_hidden, with_cache=False): 125 | """ 126 | Init decoder state. 127 | """ 128 | return CNNDecoderState(memory_bank, enc_hidden) 129 | 130 | 131 | class CNNDecoderState(DecoderState): 132 | """ 133 | Init CNN decoder state. 134 | """ 135 | 136 | def __init__(self, memory_bank, enc_hidden): 137 | self.init_src = (memory_bank + enc_hidden) * SCALE_WEIGHT 138 | self.previous_input = None 139 | 140 | @property 141 | def _all(self): 142 | """ 143 | Contains attributes that need to be updated in self.beam_update(). 144 | """ 145 | return (self.previous_input,) 146 | 147 | def detach(self): 148 | self.previous_input = self.previous_input.detach() 149 | 150 | def update_state(self, new_input): 151 | """ Called for every decoder forward pass. """ 152 | self.previous_input = new_input 153 | 154 | def repeat_beam_size_times(self, beam_size): 155 | """ Repeat beam_size times along batch dimension. """ 156 | self.init_src = self.init_src.data.repeat(1, beam_size, 1) 157 | -------------------------------------------------------------------------------- /deephop/onmt/decoders/ensemble.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ensemble decoding. 3 | 4 | Decodes using multiple models simultaneously, 5 | combining their prediction distributions by averaging. 6 | All models in the ensemble must share a target vocabulary. 
7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from onmt.decoders.decoder import DecoderState 13 | from onmt.encoders.encoder import EncoderBase 14 | from onmt.models import NMTModel 15 | import onmt.model_builder 16 | 17 | 18 | class EnsembleDecoderState(DecoderState): 19 | """ Dummy DecoderState that wraps a tuple of real DecoderStates """ 20 | def __init__(self, model_decoder_states): 21 | self.model_decoder_states = tuple(model_decoder_states) 22 | 23 | def beam_update(self, idx, positions, beam_size): 24 | for model_state in self.model_decoder_states: 25 | model_state.beam_update(idx, positions, beam_size) 26 | 27 | def repeat_beam_size_times(self, beam_size): 28 | """ Repeat beam_size times along batch dimension. """ 29 | for model_state in self.model_decoder_states: 30 | model_state.repeat_beam_size_times(beam_size) 31 | 32 | def __getitem__(self, index): 33 | return self.model_decoder_states[index] 34 | 35 | def map_batch_fn(self, fn): 36 | for model_state in self.model_decoder_states: 37 | model_state.map_batch_fn(fn) 38 | 39 | 40 | class EnsembleDecoderOutput(object): 41 | """ Wrapper around multiple decoder final hidden states """ 42 | def __init__(self, model_outputs): 43 | self.model_outputs = tuple(model_outputs) 44 | 45 | def squeeze(self, dim=None): 46 | """ 47 | Delegate squeeze to avoid modifying 48 | :obj:`Translator.translate_batch()` 49 | """ 50 | return EnsembleDecoderOutput([ 51 | x.squeeze(dim) for x in self.model_outputs]) 52 | 53 | def __getitem__(self, index): 54 | return self.model_outputs[index] 55 | 56 | 57 | class EnsembleEncoder(EncoderBase): 58 | """ Dummy Encoder that delegates to individual real Encoders """ 59 | def __init__(self, model_encoders): 60 | super(EnsembleEncoder, self).__init__() 61 | self.model_encoders = nn.ModuleList(model_encoders) 62 | 63 | def forward(self, src, lengths=None): 64 | enc_hidden, memory_bank, _ = zip(*[ 65 | model_encoder(src, lengths) 66 | for model_encoder in self.model_encoders]) 67 | return enc_hidden, memory_bank, lengths 68 | 69 | 70 | class EnsembleDecoder(nn.Module): 71 | """ Dummy Decoder that delegates to individual real Decoders """ 72 | def __init__(self, model_decoders): 73 | super(EnsembleDecoder, self).__init__() 74 | self.model_decoders = nn.ModuleList(model_decoders) 75 | 76 | def forward(self, tgt, memory_bank, state, memory_lengths=None, step=None): 77 | """ See :obj:`RNNDecoderBase.forward()` """ 78 | # Memory_lengths is a single tensor shared between all models. 79 | # This assumption will not hold if Translator is modified 80 | # to calculate memory_lengths as something other than the length 81 | # of the input. 
82 | outputs, states, attns = zip(*[ 83 | model_decoder( 84 | tgt, memory_bank[i], state[i], memory_lengths, step=step) 85 | for i, model_decoder in enumerate(self.model_decoders)]) 86 | mean_attns = self.combine_attns(attns) 87 | return (EnsembleDecoderOutput(outputs), 88 | EnsembleDecoderState(states), 89 | mean_attns) 90 | 91 | def combine_attns(self, attns): 92 | result = {} 93 | for key in attns[0].keys(): 94 | result[key] = torch.stack([attn[key] for attn in attns]).mean(0) 95 | return result 96 | 97 | def init_decoder_state(self, src, memory_bank, enc_hidden, with_cache=False): 98 | """ See :obj:`RNNDecoderBase.init_decoder_state()` """ 99 | return EnsembleDecoderState( 100 | [model_decoder.init_decoder_state(src, 101 | memory_bank[i], 102 | enc_hidden[i], with_cache) 103 | for i, model_decoder in enumerate(self.model_decoders)]) 104 | 105 | 106 | class EnsembleGenerator(nn.Module): 107 | """ 108 | Dummy Generator that delegates to individual real Generators, 109 | and then averages the resulting target distributions. 110 | """ 111 | def __init__(self, model_generators): 112 | self.model_generators = tuple(model_generators) 113 | super(EnsembleGenerator, self).__init__() 114 | 115 | def forward(self, hidden): 116 | """ 117 | Compute a distribution over the target dictionary 118 | by averaging distributions from models in the ensemble. 119 | All models in the ensemble must share a target vocabulary. 120 | """ 121 | distributions = [model_generator(hidden[i]) 122 | for i, model_generator 123 | in enumerate(self.model_generators)] 124 | return torch.stack(distributions).mean(0) 125 | 126 | 127 | 128 | class EnsembleModel(NMTModel): 129 | """ Dummy NMTModel wrapping individual real NMTModels """ 130 | def __init__(self, models): 131 | encoder = EnsembleEncoder(model.encoder for model in models) 132 | decoder = EnsembleDecoder(model.decoder for model in models) 133 | super(EnsembleModel, self).__init__(encoder, decoder) 134 | self.generator = EnsembleGenerator(model.generator for model in models) 135 | self.models = nn.ModuleList(models) 136 | 137 | 138 | def load_test_model(opt, dummy_opt): 139 | """ Read in multiple models for ensemble """ 140 | shared_fields = None 141 | shared_model_opt = None 142 | models = [] 143 | for model_path in opt.models: 144 | fields, model, model_opt = \ 145 | onmt.model_builder.load_test_model(opt, 146 | dummy_opt, 147 | model_path=model_path) 148 | if shared_fields is None: 149 | shared_fields = fields 150 | else: 151 | for key, field in fields.items(): 152 | if field is not None and 'vocab' in field.__dict__: 153 | assert field.vocab.stoi == shared_fields[key].vocab.stoi, \ 154 | 'Ensemble models must use the same preprocessed data' 155 | models.append(model) 156 | if shared_model_opt is None: 157 | shared_model_opt = model_opt 158 | ensemble_model = EnsembleModel(models) 159 | return shared_fields, ensemble_model, shared_model_opt 160 | -------------------------------------------------------------------------------- /deephop/onmt/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining encoders.""" 2 | from onmt.encoders.encoder import EncoderBase 3 | from onmt.encoders.transformer import TransformerEncoder 4 | from onmt.encoders.rnn_encoder import RNNEncoder 5 | from onmt.encoders.cnn_encoder import CNNEncoder 6 | from onmt.encoders.mean_encoder import MeanEncoder 7 | 8 | __all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder", 9 | "MeanEncoder"] 10 | 
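# Illustrative usage sketch of the shared encoder interface: every encoder
# above implements EncoderBase.forward(src, lengths) and returns
# (final_state, memory_bank, lengths). The Embeddings signature below assumes
# upstream OpenNMT-py 0.4; adjust if this fork's version differs.
#
#   import torch
#   from onmt.modules import Embeddings
#   from onmt.encoders import MeanEncoder
#
#   emb = Embeddings(word_vec_size=64, word_vocab_size=100, word_padding_idx=1)
#   enc = MeanEncoder(num_layers=2, embeddings=emb)
#   src = torch.randint(0, 100, (7, 3, 1))           # [src_len x batch x nfeat]
#   lengths = torch.full((3,), 7, dtype=torch.long)  # [batch]
#   final_state, memory_bank, _ = enc(src, lengths)  # memory_bank: [7 x 3 x 64]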
-------------------------------------------------------------------------------- /deephop/onmt/encoders/audio_encoder.py: -------------------------------------------------------------------------------- 1 | """ Audio encoder """ 2 | import math 3 | 4 | import torch.nn as nn 5 | 6 | from torch.nn.utils.rnn import pack_padded_sequence as pack 7 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 8 | 9 | from onmt.utils.rnn_factory import rnn_factory 10 | 11 | 12 | class AudioEncoder(nn.Module): 13 | """ 14 | A simple encoder convolutional -> recurrent neural network for 15 | audio input. 16 | 17 | Args: 18 | num_layers (int): number of encoder layers. 19 | bidirectional (bool): bidirectional encoder. 20 | rnn_size (int): size of hidden states of the rnn. 21 | dropout (float): dropout probability. 22 | sample_rate (float): input spec 23 | window_size (int): input spec 24 | 25 | """ 26 | def __init__(self, rnn_type, enc_layers, dec_layers, brnn, 27 | enc_rnn_size, dec_rnn_size, enc_pooling, dropout, 28 | sample_rate, window_size): 29 | super(AudioEncoder, self).__init__() 30 | self.enc_layers = enc_layers 31 | self.rnn_type = rnn_type 32 | self.dec_layers = dec_layers 33 | num_directions = 2 if brnn else 1 34 | self.num_directions = num_directions 35 | assert enc_rnn_size % num_directions == 0 36 | enc_rnn_size_real = enc_rnn_size // num_directions 37 | assert dec_rnn_size % num_directions == 0 38 | self.dec_rnn_size = dec_rnn_size 39 | dec_rnn_size_real = dec_rnn_size // num_directions 40 | self.dec_rnn_size_real = dec_rnn_size_real 41 | self.dec_rnn_size = dec_rnn_size 42 | input_size = int(math.floor((sample_rate * window_size) / 2) + 1) 43 | enc_pooling = enc_pooling.split(',') 44 | assert len(enc_pooling) == enc_layers or len(enc_pooling) == 1 45 | if len(enc_pooling) == 1: 46 | enc_pooling = enc_pooling * enc_layers 47 | enc_pooling = [int(p) for p in enc_pooling] 48 | self.enc_pooling = enc_pooling 49 | 50 | if dropout > 0: 51 | self.dropout = nn.Dropout(dropout) 52 | else: 53 | self.dropout = None 54 | self.W = nn.Linear(enc_rnn_size, dec_rnn_size, bias=False) 55 | self.batchnorm_0 = nn.BatchNorm1d(enc_rnn_size, affine=True) 56 | self.rnn_0, self.no_pack_padded_seq = \ 57 | rnn_factory(rnn_type, 58 | input_size=input_size, 59 | hidden_size=enc_rnn_size_real, 60 | num_layers=1, 61 | dropout=dropout, 62 | bidirectional=brnn) 63 | self.pool_0 = nn.MaxPool1d(enc_pooling[0]) 64 | for l in range(enc_layers - 1): 65 | batchnorm = nn.BatchNorm1d(enc_rnn_size, affine=True) 66 | rnn, _ = \ 67 | rnn_factory(rnn_type, 68 | input_size=enc_rnn_size, 69 | hidden_size=enc_rnn_size_real, 70 | num_layers=1, 71 | dropout=dropout, 72 | bidirectional=brnn) 73 | setattr(self, 'rnn_%d' % (l + 1), rnn) 74 | setattr(self, 'pool_%d' % (l + 1), 75 | nn.MaxPool1d(enc_pooling[l + 1])) 76 | setattr(self, 'batchnorm_%d' % (l + 1), batchnorm) 77 | 78 | def forward(self, src, lengths=None): 79 | "See :obj:`onmt.encoders.encoder.EncoderBase.forward()`" 80 | 81 | batch_size, _, nfft, t = src.size() 82 | src = src.transpose(0, 1).transpose(0, 3).contiguous() \ 83 | .view(t, batch_size, nfft) 84 | orig_lengths = lengths 85 | lengths = lengths.view(-1).tolist() 86 | 87 | for l in range(self.enc_layers): 88 | rnn = getattr(self, 'rnn_%d' % l) 89 | pool = getattr(self, 'pool_%d' % l) 90 | batchnorm = getattr(self, 'batchnorm_%d' % l) 91 | stride = self.enc_pooling[l] 92 | packed_emb = pack(src, lengths) 93 | memory_bank, tmp = rnn(packed_emb) 94 | memory_bank = unpack(memory_bank)[0] 95 | t, _, _ =
memory_bank.size() 96 | memory_bank = memory_bank.transpose(0, 2) 97 | memory_bank = pool(memory_bank) 98 | lengths = [int(math.floor((length - stride)/stride + 1)) 99 | for length in lengths] 100 | memory_bank = memory_bank.transpose(0, 2) 101 | src = memory_bank 102 | t, _, num_feat = src.size() 103 | src = batchnorm(src.contiguous().view(-1, num_feat)) 104 | src = src.view(t, -1, num_feat) 105 | if self.dropout and l + 1 != self.enc_layers: 106 | src = self.dropout(src) 107 | 108 | memory_bank = memory_bank.contiguous().view(-1, memory_bank.size(2)) 109 | memory_bank = self.W(memory_bank).view(-1, batch_size, 110 | self.dec_rnn_size) 111 | 112 | state = memory_bank.new_full((self.dec_layers * self.num_directions, 113 | batch_size, self.dec_rnn_size_real), 0) 114 | if self.rnn_type == 'LSTM': 115 | # The encoder hidden is (layers*directions) x batch x dim. 116 | encoder_final = (state, state) 117 | else: 118 | encoder_final = state 119 | return encoder_final, memory_bank, orig_lengths.new_tensor(lengths) 120 | -------------------------------------------------------------------------------- /deephop/onmt/encoders/cnn_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Convolutional Sequence to Sequence Learning" 3 | """ 4 | import torch.nn as nn 5 | 6 | from onmt.encoders.encoder import EncoderBase 7 | from onmt.utils.cnn_factory import shape_transform, StackedCNN 8 | 9 | SCALE_WEIGHT = 0.5 ** 0.5 10 | 11 | 12 | class CNNEncoder(EncoderBase): 13 | """ 14 | Encoder built on CNN based on 15 | :cite:`DBLP:journals/corr/GehringAGYD17`. 16 | """ 17 | 18 | def __init__(self, num_layers, hidden_size, 19 | cnn_kernel_width, dropout, embeddings): 20 | super(CNNEncoder, self).__init__() 21 | 22 | self.embeddings = embeddings 23 | input_size = embeddings.embedding_size 24 | self.linear = nn.Linear(input_size, hidden_size) 25 | self.cnn = StackedCNN(num_layers, hidden_size, 26 | cnn_kernel_width, dropout) 27 | 28 | def forward(self, input, lengths=None, hidden=None): 29 | """ See :obj:`onmt.modules.EncoderBase.forward()`""" 30 | self._check_args(input, lengths, hidden) 31 | 32 | emb = self.embeddings(input) 33 | # s_len, batch, emb_dim = emb.size() 34 | 35 | emb = emb.transpose(0, 1).contiguous() 36 | emb_reshape = emb.view(emb.size(0) * emb.size(1), -1) 37 | emb_remap = self.linear(emb_reshape) 38 | emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1) 39 | emb_remap = shape_transform(emb_remap) 40 | out = self.cnn(emb_remap) 41 | 42 | return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \ 43 | out.squeeze(3).transpose(0, 1).contiguous(), lengths 44 | -------------------------------------------------------------------------------- /deephop/onmt/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | """Base class for encoders and generic multi encoders.""" 2 | 3 | from __future__ import division 4 | 5 | import torch.nn as nn 6 | 7 | from onmt.utils.misc import aeq 8 | 9 | 10 | class EncoderBase(nn.Module): 11 | """ 12 | Base encoder class. Specifies the interface used by different encoder types 13 | and required by :obj:`onmt.Models.NMTModel`. 14 | 15 | .. 
mermaid:: 16 | 17 | graph BT 18 | A[Input] 19 | subgraph RNN 20 | C[Pos 1] 21 | D[Pos 2] 22 | E[Pos N] 23 | end 24 | F[Memory_Bank] 25 | G[Final] 26 | A-->C 27 | A-->D 28 | A-->E 29 | C-->F 30 | D-->F 31 | E-->F 32 | E-->G 33 | """ 34 | 35 | def _check_args(self, src, lengths=None, hidden=None): 36 | _, n_batch, _ = src.size() 37 | if lengths is not None: 38 | n_batch_, = lengths.size() 39 | aeq(n_batch, n_batch_) 40 | 41 | def forward(self, src, lengths=None): 42 | """ 43 | Args: 44 | src (:obj:`LongTensor`): 45 | padded sequences of sparse indices `[src_len x batch x nfeat]` 46 | lengths (:obj:`LongTensor`): length of each sequence `[batch]` 47 | 48 | 49 | Returns: 50 | (tuple of :obj:`FloatTensor`, :obj:`FloatTensor`): 51 | * final encoder state, used to initialize decoder 52 | * memory bank for attention, `[src_len x batch x hidden]` 53 | """ 54 | raise NotImplementedError 55 | -------------------------------------------------------------------------------- /deephop/onmt/encoders/image_encoder.py: -------------------------------------------------------------------------------- 1 | """ Image Encoder """ 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | 7 | class ImageEncoder(nn.Module): 8 | """ 9 | A simple encoder convolutional -> recurrent neural network for 10 | image src. 11 | 12 | Args: 13 | num_layers (int): number of encoder layers. 14 | bidirectional (bool): bidirectional encoder. 15 | rnn_size (int): size of hidden states of the rnn. 16 | dropout (float): dropout probability. 17 | """ 18 | 19 | def __init__(self, num_layers, bidirectional, rnn_size, dropout, 20 | image_chanel_size=3): 21 | super(ImageEncoder, self).__init__() 22 | self.num_layers = num_layers 23 | self.num_directions = 2 if bidirectional else 1 24 | self.hidden_size = rnn_size 25 | 26 | self.layer1 = nn.Conv2d(image_chanel_size, 64, kernel_size=(3, 3), 27 | padding=(1, 1), stride=(1, 1)) 28 | self.layer2 = nn.Conv2d(64, 128, kernel_size=(3, 3), 29 | padding=(1, 1), stride=(1, 1)) 30 | self.layer3 = nn.Conv2d(128, 256, kernel_size=(3, 3), 31 | padding=(1, 1), stride=(1, 1)) 32 | self.layer4 = nn.Conv2d(256, 256, kernel_size=(3, 3), 33 | padding=(1, 1), stride=(1, 1)) 34 | self.layer5 = nn.Conv2d(256, 512, kernel_size=(3, 3), 35 | padding=(1, 1), stride=(1, 1)) 36 | self.layer6 = nn.Conv2d(512, 512, kernel_size=(3, 3), 37 | padding=(1, 1), stride=(1, 1)) 38 | 39 | self.batch_norm1 = nn.BatchNorm2d(256) 40 | self.batch_norm2 = nn.BatchNorm2d(512) 41 | self.batch_norm3 = nn.BatchNorm2d(512) 42 | 43 | src_size = 512 44 | self.rnn = nn.LSTM(src_size, int(rnn_size / self.num_directions), 45 | num_layers=num_layers, 46 | dropout=dropout, 47 | bidirectional=bidirectional) 48 | self.pos_lut = nn.Embedding(1000, src_size) 49 | 50 | def load_pretrained_vectors(self, opt): 51 | """ Pass in needed options only when modify function definition.""" 52 | pass 53 | 54 | def forward(self, src, lengths=None): 55 | "See :obj:`onmt.encoders.encoder.EncoderBase.forward()`" 56 | 57 | batch_size = src.size(0) 58 | # (batch_size, 64, imgH, imgW) 59 | # layer 1 60 | src = F.relu(self.layer1(src[:, :, :, :] - 0.5), True) 61 | 62 | # (batch_size, 64, imgH/2, imgW/2) 63 | src = F.max_pool2d(src, kernel_size=(2, 2), stride=(2, 2)) 64 | 65 | # (batch_size, 128, imgH/2, imgW/2) 66 | # layer 2 67 | src = F.relu(self.layer2(src), True) 68 | 69 | # (batch_size, 128, imgH/2/2, imgW/2/2) 70 | src = F.max_pool2d(src, kernel_size=(2, 2), stride=(2, 2)) 71 | 72 | # (batch_size, 256, imgH/2/2, imgW/2/2) 73 | # layer 3
74 | # batch norm 1 75 | src = F.relu(self.batch_norm1(self.layer3(src)), True) 76 | 77 | # (batch_size, 256, imgH/2/2, imgW/2/2) 78 | # layer4 79 | src = F.relu(self.layer4(src), True) 80 | 81 | # (batch_size, 256, imgH/2/2/2, imgW/2/2) 82 | src = F.max_pool2d(src, kernel_size=(1, 2), stride=(1, 2)) 83 | 84 | # (batch_size, 512, imgH/2/2/2, imgW/2/2) 85 | # layer 5 86 | # batch norm 2 87 | src = F.relu(self.batch_norm2(self.layer5(src)), True) 88 | 89 | # (batch_size, 512, imgH/2/2/2, imgW/2/2/2) 90 | src = F.max_pool2d(src, kernel_size=(2, 1), stride=(2, 1)) 91 | 92 | # (batch_size, 512, imgH/2/2/2, imgW/2/2/2) 93 | src = F.relu(self.batch_norm3(self.layer6(src)), True) 94 | 95 | # # (batch_size, 512, H, W) 96 | all_outputs = [] 97 | for row in range(src.size(2)): 98 | inp = src[:, :, row, :].transpose(0, 2) \ 99 | .transpose(1, 2) 100 | row_vec = torch.Tensor(batch_size).type_as(inp.data) \ 101 | .long().fill_(row) 102 | pos_emb = self.pos_lut(row_vec) 103 | with_pos = torch.cat( 104 | (pos_emb.view(1, pos_emb.size(0), pos_emb.size(1)), inp), 0) 105 | outputs, hidden_t = self.rnn(with_pos) 106 | all_outputs.append(outputs) 107 | out = torch.cat(all_outputs, 0) 108 | 109 | return hidden_t, out, lengths 110 | -------------------------------------------------------------------------------- /deephop/onmt/encoders/mean_encoder.py: -------------------------------------------------------------------------------- 1 | """Define a minimal encoder.""" 2 | from __future__ import division 3 | 4 | from onmt.encoders.encoder import EncoderBase 5 | 6 | 7 | class MeanEncoder(EncoderBase): 8 | """A trivial non-recurrent encoder. Simply applies mean pooling. 9 | 10 | Args: 11 | num_layers (int): number of replicated layers 12 | embeddings (:obj:`onmt.modules.Embeddings`): embedding module to use 13 | """ 14 | 15 | def __init__(self, num_layers, embeddings): 16 | super(MeanEncoder, self).__init__() 17 | self.num_layers = num_layers 18 | self.embeddings = embeddings 19 | 20 | def forward(self, src, lengths=None): 21 | "See :obj:`EncoderBase.forward()`" 22 | self._check_args(src, lengths) 23 | 24 | emb = self.embeddings(src) 25 | _, batch, emb_dim = emb.size() 26 | mean = emb.mean(0).expand(self.num_layers, batch, emb_dim) 27 | memory_bank = emb 28 | encoder_final = (mean, mean) 29 | return encoder_final, memory_bank, lengths 30 | -------------------------------------------------------------------------------- /deephop/onmt/encoders/rnn_encoder.py: -------------------------------------------------------------------------------- 1 | """Define RNN-based encoders.""" 2 | from __future__ import division 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torch.nn.utils.rnn import pack_padded_sequence as pack 8 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 9 | 10 | from onmt.encoders.encoder import EncoderBase 11 | from onmt.utils.rnn_factory import rnn_factory 12 | 13 | 14 | class RNNEncoder(EncoderBase): 15 | """ A generic recurrent neural network encoder. 
16 | 17 | Args: 18 | rnn_type (:obj:`str`): 19 | style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU] 20 | bidirectional (bool) : use a bidirectional RNN 21 | num_layers (int) : number of stacked layers 22 | hidden_size (int) : hidden size of each layer 23 | dropout (float) : dropout value for :obj:`nn.Dropout` 24 | embeddings (:obj:`onmt.modules.Embeddings`): embedding module to use 25 | """ 26 | 27 | def __init__(self, rnn_type, bidirectional, num_layers, 28 | hidden_size, dropout=0.0, embeddings=None, 29 | use_bridge=False): 30 | super(RNNEncoder, self).__init__() 31 | assert embeddings is not None 32 | 33 | num_directions = 2 if bidirectional else 1 34 | assert hidden_size % num_directions == 0 35 | hidden_size = hidden_size // num_directions 36 | self.embeddings = embeddings 37 | 38 | self.rnn, self.no_pack_padded_seq = \ 39 | rnn_factory(rnn_type, 40 | input_size=embeddings.embedding_size, 41 | hidden_size=hidden_size, 42 | num_layers=num_layers, 43 | dropout=dropout, 44 | bidirectional=bidirectional) 45 | 46 | # Initialize the bridge layer 47 | self.use_bridge = use_bridge 48 | if self.use_bridge: 49 | self._initialize_bridge(rnn_type, 50 | hidden_size, 51 | num_layers) 52 | 53 | def forward(self, src, lengths=None): 54 | "See :obj:`EncoderBase.forward()`" 55 | self._check_args(src, lengths) 56 | 57 | emb = self.embeddings(src) 58 | # s_len, batch, emb_dim = emb.size() 59 | 60 | packed_emb = emb 61 | if lengths is not None and not self.no_pack_padded_seq: 62 | # Lengths data is wrapped inside a Tensor. 63 | lengths_list = lengths.view(-1).tolist() 64 | packed_emb = pack(emb, lengths_list) 65 | 66 | memory_bank, encoder_final = self.rnn(packed_emb) 67 | 68 | if lengths is not None and not self.no_pack_padded_seq: 69 | memory_bank = unpack(memory_bank)[0] 70 | 71 | if self.use_bridge: 72 | encoder_final = self._bridge(encoder_final) 73 | return encoder_final, memory_bank, lengths 74 | 75 | def _initialize_bridge(self, rnn_type, 76 | hidden_size, 77 | num_layers): 78 | 79 | # LSTM has hidden and cell state, other only one 80 | number_of_states = 2 if rnn_type == "LSTM" else 1 81 | # Total number of states 82 | self.total_hidden_dim = hidden_size * num_layers 83 | 84 | # Build a linear layer for each 85 | self.bridge = nn.ModuleList([nn.Linear(self.total_hidden_dim, 86 | self.total_hidden_dim, 87 | bias=True) 88 | for _ in range(number_of_states)]) 89 | 90 | def _bridge(self, hidden): 91 | """ 92 | Forward hidden state through bridge 93 | """ 94 | def bottle_hidden(linear, states): 95 | """ 96 | Transform from 3D to 2D, apply linear and return initial size 97 | """ 98 | size = states.size() 99 | result = linear(states.view(-1, self.total_hidden_dim)) 100 | return F.relu(result).view(size) 101 | 102 | if isinstance(hidden, tuple): # LSTM 103 | outs = tuple([bottle_hidden(layer, hidden[ix]) 104 | for ix, layer in enumerate(self.bridge)]) 105 | else: 106 | outs = bottle_hidden(self.bridge[0], hidden) 107 | return outs 108 | -------------------------------------------------------------------------------- /deephop/onmt/encoders/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Attention is All You Need" 3 | """ 4 | 5 | import torch.nn as nn 6 | from numpy.core.multiarray import ndarray 7 | from torch import Tensor 8 | 9 | from Graph3dConv import Graph3dConv 10 | from graph_embedding import get_emb 11 | from onmt.GCN import * 12 | from onmt.MYGCN import * 13 | import time 14 | import onmt 15 | from 
onmt.MPNNs.MPNN import * 16 | import onmt.myutils as myutils 17 | from onmt.encoders.encoder import EncoderBase 18 | # from onmt.utils.misc import aeq 19 | from onmt.modules.position_ffn import PositionwiseFeedForward 20 | # from onmt.encoders import myutils 21 | # from onmt.encoders.MPNN.MPNN import MPNN 22 | from onmt.GATGATE import * 23 | from onmt.modules.util_class import make_condtion 24 | 25 | 26 | class TransformerEncoderLayer(nn.Module): 27 | """ 28 | A single layer of the transformer encoder. 29 | 30 | Args: 31 | d_model (int): the dimension of keys/values/queries in 32 | MultiHeadedAttention, also the input size of 33 | the first-layer of the PositionwiseFeedForward. 34 | heads (int): the number of heads for MultiHeadedAttention. 35 | d_ff (int): the hidden size of the second layer of the PositionwiseFeedForward. 36 | dropout (float): dropout probability (0-1.0). 37 | """ 38 | 39 | def __init__(self, d_model, heads, d_ff, dropout): 40 | super(TransformerEncoderLayer, self).__init__() 41 | self.self_attn = onmt.modules.MultiHeadedAttention( 42 | heads, d_model, dropout=dropout) 43 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 44 | self.layer_norm = onmt.modules.LayerNorm(d_model) 45 | self.dropout = nn.Dropout(dropout) 46 | 47 | def forward(self, inputs, mask): 48 | """ 49 | Transformer Encoder Layer definition. 50 | 51 | Args: 52 | inputs (`FloatTensor`): `[batch_size x src_len x model_dim]` 53 | mask (`LongTensor`): `[batch_size x src_len x src_len]` 54 | 55 | Returns: 56 | (`FloatTensor`): 57 | 58 | * outputs `[batch_size x src_len x model_dim]` 59 | """ 60 | input_norm = self.layer_norm(inputs) 61 | context, _ = self.self_attn(input_norm, input_norm, input_norm, 62 | mask=mask) 63 | out = self.dropout(context) + inputs 64 | return self.feed_forward(out) 65 | 66 | 67 | class TransformerEncoder(EncoderBase): 68 | """ 69 | The Transformer encoder from "Attention is All You Need". 70 | 71 | 72 | ..
mermaid:: 73 | 74 | graph BT 75 | A[input] 76 | B[multi-head self-attn] 77 | C[feed forward] 78 | O[output] 79 | A --> B 80 | B --> C 81 | C --> O 82 | 83 | Args: 84 | num_layers (int): number of encoder layers 85 | d_model (int): size of the model 86 | heads (int): number of heads 87 | d_ff (int): size of the inner FF layer 88 | dropout (float): dropout parameters 89 | embeddings (:obj:`onmt.modules.Embeddings`): 90 | embeddings to use, should have positional encodings 91 | 92 | Returns: 93 | (`FloatTensor`, `FloatTensor`): 94 | 95 | * embeddings `[src_len x batch_size x model_dim]` 96 | * memory_bank `[src_len x batch_size x model_dim]` 97 | """ 98 | 99 | def __init__(self, num_layers, d_model, heads, d_ff, 100 | dropout, embeddings, condition_dim, arch): 101 | super(TransformerEncoder, self).__init__() 102 | self.num_layers = num_layers 103 | self.embeddings = embeddings 104 | assert condition_dim >= 0 105 | self.condition_dim = condition_dim 106 | self.transformer = nn.ModuleList( 107 | [TransformerEncoderLayer(d_model, heads, d_ff, dropout) 108 | for _ in range(num_layers)]) 109 | if arch in ['transformer', 'after_encoding']: 110 | d_model += condition_dim 111 | self.layer_norm = onmt.modules.LayerNorm(d_model) 112 | 113 | self.gcn = Graph3dConv(3, 160, 256, label_dim=21) 114 | if arch == 'before_linear': 115 | self.fc = nn.Linear(512+self.condition_dim, 256) 116 | if arch == 'transformer': 117 | self.fc = nn.Linear(256, 256) 118 | else: 119 | self.fc = nn.Linear(512, 256) 120 | self.arch = arch 121 | 122 | def forward(self, src, lengths=None): 123 | """ See :obj:`EncoderBase.forward()`""" 124 | gs = src[1] 125 | condition = src[2] 126 | src = src[0] 127 | 128 | emb = self.embeddings(src) 129 | 130 | out = emb.transpose(0, 1).contiguous() 131 | 132 | cat_list = [out] 133 | # In 'transformer' mode, the encoding from the GCN module is not added 134 | if self.arch != 'transformer': 135 | emb2 = myutils.gcn_emb(self.gcn, gs, src.device) 136 | emb2 = emb2.view(src.size(1), -1, 256) 137 | cat_list.append(emb2) 138 | 139 | if self.condition_dim > 0 and self.arch == 'before_linear': 140 | condition_of_every_atom = make_condtion(condition, src.size(0), out.device, out.dtype) 141 | cat_list.append(condition_of_every_atom) 142 | 143 | if len(cat_list) > 1: 144 | out = torch.cat(cat_list, dim=2) 145 | else: 146 | out = cat_list[0] 147 | out = self.fc(out) 148 | 149 | words = src[:, :, 0].transpose(0, 1) 150 | w_batch, w_len = words.size() 151 | padding_idx = self.embeddings.word_padding_idx 152 | mask = words.data.eq(padding_idx).unsqueeze(1) \ 153 | .expand(w_batch, w_len, w_len) 154 | # Run the forward pass of every layer of the transformer. 155 | for i in range(self.num_layers): 156 | out = self.transformer[i](out, mask) 157 | 158 | if self.arch in ['transformer', 'after_encoding']: 159 | condition_of_every_atom = make_condtion(condition, src.size(0), out.device, out.dtype) 160 | out = torch.cat((out, condition_of_every_atom), dim=2) 161 | emb_cond = condition_of_every_atom.transpose(0, 1).contiguous() 162 | emb = torch.cat((emb, emb_cond), dim=2) 163 | 164 | out = self.layer_norm(out) 165 | 166 | return emb, out.transpose(0, 1).contiguous(), lengths 167 | -------------------------------------------------------------------------------- /deephop/onmt/inputters/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining inputters. 2 | 3 | Inputters implement the logic of transforming raw data to vectorized inputs, 4 | e.g., from a line of text to a sequence of embeddings.
5 | """ 6 | from onmt.inputters.inputter import collect_feature_vocabs, make_features, \ 7 | collect_features, get_num_features, \ 8 | load_fields_from_vocab, get_fields, \ 9 | save_fields_to_vocab, build_dataset, \ 10 | build_vocab, merge_vocabs, OrderedIterator 11 | from onmt.inputters.dataset_base import DatasetBase, PAD_WORD, BOS_WORD, \ 12 | EOS_WORD, UNK 13 | from onmt.inputters.text_dataset import TextDataset, ShardedTextCorpusIterator 14 | from onmt.inputters.image_dataset import ImageDataset 15 | from onmt.inputters.audio_dataset import AudioDataset, \ 16 | ShardedAudioCorpusIterator 17 | 18 | 19 | __all__ = ['PAD_WORD', 'BOS_WORD', 'EOS_WORD', 'UNK', 'DatasetBase', 20 | 'collect_feature_vocabs', 'make_features', 21 | 'collect_features', 'get_num_features', 22 | 'load_fields_from_vocab', 'get_fields', 23 | 'save_fields_to_vocab', 'build_dataset', 24 | 'build_vocab', 'merge_vocabs', 'OrderedIterator', 25 | 'TextDataset', 'ImageDataset', 'AudioDataset', 26 | 'ShardedTextCorpusIterator', 'ShardedAudioCorpusIterator'] 27 | -------------------------------------------------------------------------------- /deephop/onmt/inputters/dataset_base.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Base dataset class and constants 4 | """ 5 | from itertools import chain 6 | import torchtext 7 | 8 | import onmt 9 | 10 | PAD_WORD = '' 11 | UNK_WORD = '' 12 | UNK = 0 13 | BOS_WORD = '' 14 | EOS_WORD = '' 15 | 16 | 17 | class DatasetBase(torchtext.data.Dataset): 18 | """ 19 | A dataset basically supports iteration over all the examples 20 | it contains. We currently have 3 datasets inheriting this base 21 | for 3 types of corpus respectively: "text", "img", "audio". 22 | 23 | Internally it initializes an `torchtext.data.Dataset` object with 24 | the following attributes: 25 | 26 | `examples`: a sequence of `torchtext.data.Example` objects. 27 | `fields`: a dictionary associating str keys with `torchtext.data.Field` 28 | objects, and not necessarily having the same keys as the input fields. 29 | """ 30 | 31 | def __getstate__(self): 32 | return self.__dict__ 33 | 34 | def __setstate__(self, _d): 35 | self.__dict__.update(_d) 36 | 37 | def __reduce_ex__(self, proto): 38 | "This is a hack. Something is broken with torch pickle." 39 | return super(DatasetBase, self).__reduce_ex__(proto) 40 | 41 | def load_fields(self, vocab_dict): 42 | """ Load fields from vocab.pt, and set the `fields` attribute. 43 | 44 | Args: 45 | vocab_dict (dict): a dict of loaded vocab from vocab.pt file. 46 | """ 47 | fields = onmt.inputters.inputter.load_fields_from_vocab( 48 | vocab_dict.items(), self.data_type) 49 | self.fields = dict([(k, f) for (k, f) in fields.items() 50 | if k in self.examples[0].__dict__]) 51 | 52 | @staticmethod 53 | def extract_text_features(tokens): 54 | """ 55 | Args: 56 | tokens: A list of tokens, where each token consists of a word, 57 | optionally followed by u"│"-delimited features. 58 | Returns: 59 | A sequence of words, a sequence of features, and num of features. 
60 | """ 61 | if not tokens: 62 | return [], [], -1 63 | 64 | specials = [PAD_WORD, UNK_WORD, BOS_WORD, EOS_WORD] 65 | words = [] 66 | features = [] 67 | n_feats = None 68 | for token in tokens: 69 | split_token = token.split(u"│") 70 | assert all([special != split_token[0] for special in specials]), \ 71 | "Dataset cannot contain Special Tokens" 72 | 73 | if split_token[0]: 74 | words += [split_token[0]] 75 | features += [split_token[1:]] 76 | 77 | if n_feats is None: 78 | n_feats = len(split_token) 79 | else: 80 | assert len(split_token) == n_feats, \ 81 | "all words must have the same number of features" 82 | features = list(zip(*features)) 83 | return tuple(words), features, n_feats - 1 84 | 85 | # Below are helper functions for intra-class use only. 86 | 87 | def _join_dicts(self, *args): 88 | """ 89 | Args: 90 | dictionaries with disjoint keys. 91 | 92 | Returns: 93 | a single dictionary that has the union of these keys. 94 | """ 95 | return dict(chain(*[d.items() for d in args])) 96 | 97 | def _peek(self, seq): 98 | """ 99 | Args: 100 | seq: an iterator. 101 | 102 | Returns: 103 | the first thing returned by calling next() on the iterator 104 | and an iterator created by re-chaining that value to the beginning 105 | of the iterator. 106 | """ 107 | first = next(seq) 108 | return first, chain([first], seq) 109 | 110 | def _construct_example_fromlist(self, data, fields): 111 | """ 112 | Args: 113 | data: the data to be set as the value of the attributes of 114 | the to-be-created `Example`, associating with respective 115 | `Field` objects with same key. 116 | fields: a dict of `torchtext.data.Field` objects. The keys 117 | are attributes of the to-be-created `Example`. 118 | 119 | Returns: 120 | the created `Example` object. 121 | """ 122 | ex = torchtext.data.Example() 123 | for (name, field), val in zip(fields, data): 124 | if field is not None: 125 | setattr(ex, name, field.preprocess(val)) 126 | else: 127 | setattr(ex, name, val) 128 | return ex 129 | -------------------------------------------------------------------------------- /deephop/onmt/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining models.""" 2 | from onmt.models.model_saver import build_model_saver, ModelSaver 3 | from onmt.models.model import NMTModel 4 | 5 | __all__ = ["build_model_saver", "ModelSaver", 6 | "NMTModel", "check_sru_requirement"] 7 | -------------------------------------------------------------------------------- /deephop/onmt/models/model.py: -------------------------------------------------------------------------------- 1 | """ Onmt NMT Model base class definition """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | from onmt.modules.util_class import make_condtion 6 | 7 | 8 | class NMTModel(nn.Module): 9 | """ 10 | Core trainable object in OpenNMT. Implements a trainable interface 11 | for a simple, generic encoder + decoder model. 
12 | 13 | Args: 14 | encoder (:obj:`EncoderBase`): an encoder object 15 | decoder (:obj:`RNNDecoderBase`): a decoder object 16 | multi 0: 43 | self.checkpoint_queue = deque([], maxlen=keep_checkpoint) 44 | 45 | def maybe_save(self, step): 46 | """ 47 | Main entry point for model saver 48 | It wraps the `_save` method with checks and applies `keep_checkpoint` 49 | related logic 50 | """ 51 | if self.keep_checkpoint == 0: 52 | return 53 | 54 | if (step - self.begin_step) % self.save_checkpoint_steps != 0: 55 | return 56 | 57 | if len(self.checkpoint_queue) > 30 and self.init_save_checkpoint_steps == self.save_checkpoint_steps: 58 | self.save_checkpoint_steps = self.save_checkpoint_steps*5 59 | 60 | chkpt, chkpt_name = self._save(step) 61 | 62 | if self.keep_checkpoint > 0: 63 | if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen: 64 | todel = self.checkpoint_queue.popleft() 65 | self._rm_checkpoint(todel) 66 | self.checkpoint_queue.append(chkpt_name) 67 | 68 | def _save(self, step): 69 | """ Save a resumable checkpoint. 70 | 71 | Args: 72 | step (int): step number 73 | 74 | Returns: 75 | checkpoint: the saved object 76 | checkpoint_name: name (or path) of the saved checkpoint 77 | """ 78 | raise NotImplementedError() 79 | 80 | def _rm_checkpoint(self, name): 81 | """ 82 | Remove a checkpoint 83 | 84 | Args: 85 | name (str): name that identifies the checkpoint 86 | (it may be a filepath) 87 | """ 88 | raise NotImplementedError() 89 | 90 | 91 | class ModelSaver(ModelSaverBase): 92 | """ 93 | Simple model saver to filesystem 94 | """ 95 | 96 | def __init__(self, base_path, model, model_opt, fields, optim, 97 | save_checkpoint_steps, keep_checkpoint=0): 98 | super(ModelSaver, self).__init__( 99 | base_path, model, model_opt, fields, optim, 100 | save_checkpoint_steps, keep_checkpoint) 101 | 102 | def _save(self, step): 103 | real_model = (self.model.module 104 | if isinstance(self.model, nn.DataParallel) 105 | else self.model) 106 | real_generator = (real_model.generator.module 107 | if isinstance(real_model.generator, nn.DataParallel) 108 | else real_model.generator) 109 | 110 | model_state_dict = real_model.state_dict() 111 | model_state_dict = {k: v for k, v in model_state_dict.items() 112 | if 'generator' not in k} 113 | generator_state_dict = real_generator.state_dict() 114 | checkpoint = { 115 | 'model': model_state_dict, 116 | 'generator': generator_state_dict, 117 | 'vocab': onmt.inputters.save_fields_to_vocab(self.fields), 118 | 'opt': self.model_opt, 119 | 'optim': self.optim, 120 | } 121 | 122 | logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step)) 123 | checkpoint_path = '%s_step_%d.pt' % (self.base_path, step) 124 | torch.save(checkpoint, checkpoint_path) 125 | return checkpoint, checkpoint_path 126 | 127 | def _rm_checkpoint(self, name): 128 | os.remove(name) 129 | -------------------------------------------------------------------------------- /deephop/onmt/models/stacked_rnn.py: -------------------------------------------------------------------------------- 1 | """ Implementation of ONMT RNN for Input Feeding Decoding """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class StackedLSTM(nn.Module): 7 | """ 8 | Our own implementation of stacked LSTM. 9 | Needed for the decoder, because we do input feeding.
10 | """ 11 | 12 | def __init__(self, num_layers, input_size, rnn_size, dropout): 13 | super(StackedLSTM, self).__init__() 14 | self.dropout = nn.Dropout(dropout) 15 | self.num_layers = num_layers 16 | self.layers = nn.ModuleList() 17 | 18 | for _ in range(num_layers): 19 | self.layers.append(nn.LSTMCell(input_size, rnn_size)) 20 | input_size = rnn_size 21 | 22 | def forward(self, input_feed, hidden): 23 | h_0, c_0 = hidden 24 | h_1, c_1 = [], [] 25 | for i, layer in enumerate(self.layers): 26 | h_1_i, c_1_i = layer(input_feed, (h_0[i], c_0[i])) 27 | input_feed = h_1_i 28 | if i + 1 != self.num_layers: 29 | input_feed = self.dropout(input_feed) 30 | h_1 += [h_1_i] 31 | c_1 += [c_1_i] 32 | 33 | h_1 = torch.stack(h_1) 34 | c_1 = torch.stack(c_1) 35 | 36 | return input_feed, (h_1, c_1) 37 | 38 | 39 | class StackedGRU(nn.Module): 40 | """ 41 | Our own implementation of stacked GRU. 42 | Needed for the decoder, because we do input feeding. 43 | """ 44 | 45 | def __init__(self, num_layers, input_size, rnn_size, dropout): 46 | super(StackedGRU, self).__init__() 47 | self.dropout = nn.Dropout(dropout) 48 | self.num_layers = num_layers 49 | self.layers = nn.ModuleList() 50 | 51 | for _ in range(num_layers): 52 | self.layers.append(nn.GRUCell(input_size, rnn_size)) 53 | input_size = rnn_size 54 | 55 | def forward(self, input_feed, hidden): 56 | h_1 = [] 57 | for i, layer in enumerate(self.layers): 58 | h_1_i = layer(input_feed, hidden[0][i]) 59 | input_feed = h_1_i 60 | if i + 1 != self.num_layers: 61 | input_feed = self.dropout(input_feed) 62 | h_1 += [h_1_i] 63 | 64 | h_1 = torch.stack(h_1) 65 | return input_feed, (h_1,) 66 | -------------------------------------------------------------------------------- /deephop/onmt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """ Attention and normalization modules """ 2 | from onmt.modules.util_class import LayerNorm, Elementwise 3 | from onmt.modules.gate import context_gate_factory, ContextGate 4 | from onmt.modules.global_attention import GlobalAttention 5 | from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention 6 | from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLossCompute 7 | from onmt.modules.multi_headed_attn import MultiHeadedAttention 8 | from onmt.modules.embeddings import Embeddings, PositionalEncoding 9 | from onmt.modules.weight_norm import WeightNormConv2d 10 | from onmt.modules.average_attn import AverageAttention 11 | 12 | __all__ = ["LayerNorm", "Elementwise", "context_gate_factory", "ContextGate", 13 | "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", 14 | "CopyGeneratorLossCompute", "MultiHeadedAttention", "Embeddings", 15 | "PositionalEncoding", "WeightNormConv2d", "AverageAttention"] 16 | -------------------------------------------------------------------------------- /deephop/onmt/modules/average_attn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Average Attention module """ 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from onmt.modules.position_ffn import PositionwiseFeedForward 8 | 9 | 10 | class AverageAttention(nn.Module): 11 | """ 12 | Average Attention module from 13 | "Accelerating Neural Transformer via an Average Attention Network" 14 | :cite:`https://arxiv.org/abs/1805.00631`. 
15 | 16 | Args: 17 | model_dim (int): the dimension of keys/values/queries, 18 | must be divisible by head_count 19 | dropout (float): dropout parameter 20 | """ 21 | 22 | def __init__(self, model_dim, dropout=0.1): 23 | self.model_dim = model_dim 24 | 25 | super(AverageAttention, self).__init__() 26 | 27 | self.average_layer = PositionwiseFeedForward(model_dim, model_dim, 28 | dropout) 29 | self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2) 30 | 31 | def cumulative_average_mask(self, batch_size, inputs_len): 32 | """ 33 | Builds the mask to compute the cumulative average as described in 34 | https://arxiv.org/abs/1805.00631 -- Figure 3 35 | 36 | Args: 37 | batch_size (int): batch size 38 | inputs_len (int): length of the inputs 39 | 40 | Returns: 41 | (`FloatTensor`): 42 | 43 | * A Tensor of shape `[batch_size x input_len x input_len]` 44 | """ 45 | 46 | triangle = torch.tril(torch.ones(inputs_len, inputs_len)) 47 | weights = torch.ones(1, inputs_len) / torch.arange( 48 | 1, inputs_len + 1, dtype=torch.float) 49 | mask = triangle * weights.transpose(0, 1) 50 | 51 | return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len) 52 | 53 | def cumulative_average(self, inputs, mask_or_step, 54 | layer_cache=None, step=None): 55 | """ 56 | Computes the cumulative average as described in 57 | https://arxiv.org/abs/1805.00631 -- Equations (1) (5) (6) 58 | 59 | Args: 60 | inputs (`FloatTensor`): sequence to average 61 | `[batch_size x input_len x dimension]` 62 | mask_or_step: if cache is set, this is assumed 63 | to be the current step of the 64 | dynamic decoding. Otherwise, it is the mask matrix 65 | used to compute the cumulative average. 66 | cache: a dictionary containing the cumulative average 67 | of the previous step. 68 | """ 69 | if layer_cache is not None: 70 | step = mask_or_step 71 | device = inputs.device 72 | average_attention = (inputs + step * 73 | layer_cache["prev_g"].to(device)) / (step + 1) 74 | layer_cache["prev_g"] = average_attention 75 | return average_attention 76 | else: 77 | mask = mask_or_step 78 | return torch.matmul(mask, inputs) 79 | 80 | def forward(self, inputs, mask=None, layer_cache=None, step=None): 81 | """ 82 | Args: 83 | inputs (`FloatTensor`): `[batch_size x input_len x model_dim]` 84 | 85 | Returns: 86 | (`FloatTensor`, `FloatTensor`): 87 | 88 | * gating_outputs `[batch_size x 1 x model_dim]` 89 | * average_outputs average attention `[batch_size x 1 x model_dim]` 90 | """ 91 | batch_size = inputs.size(0) 92 | inputs_len = inputs.size(1) 93 | 94 | device = inputs.device 95 | average_outputs = self.cumulative_average( 96 | inputs, self.cumulative_average_mask(batch_size, 97 | inputs_len).to(device).float() 98 | if layer_cache is None else step, layer_cache=layer_cache) 99 | average_outputs = self.average_layer(average_outputs) 100 | gating_outputs = self.gating_layer(torch.cat((inputs, 101 | average_outputs), -1)) 102 | input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2) 103 | gating_outputs = torch.sigmoid(input_gate) * inputs + \ 104 | torch.sigmoid(forget_gate) * average_outputs 105 | 106 | return gating_outputs, average_outputs 107 | -------------------------------------------------------------------------------- /deephop/onmt/modules/conv_multi_step_attention.py: -------------------------------------------------------------------------------- 1 | """ Multi Step Attention for CNN """ 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | SCALE_WEIGHT = 0.5 
** 0.5 9 | 10 | 11 | def seq_linear(linear, x): 12 | """ linear transform for 3-d tensor """ 13 | batch, hidden_size, length, _ = x.size() 14 | h = linear(torch.transpose(x, 1, 2).contiguous().view( 15 | batch * length, hidden_size)) 16 | return torch.transpose(h.view(batch, length, hidden_size, 1), 1, 2) 17 | 18 | 19 | class ConvMultiStepAttention(nn.Module): 20 | """ 21 | 22 | Conv attention takes a key matrix, a value matrix and a query vector. 23 | Attention weights are computed from the key matrix and the query vector 24 | and used for a weighted sum over the value matrix. The same operation 25 | is applied in each decoder conv layer. 26 | 27 | """ 28 | 29 | def __init__(self, input_size): 30 | super(ConvMultiStepAttention, self).__init__() 31 | self.linear_in = nn.Linear(input_size, input_size) 32 | self.mask = None 33 | 34 | def apply_mask(self, mask): 35 | """ Apply mask """ 36 | self.mask = mask 37 | 38 | def forward(self, base_target_emb, input_from_dec, encoder_out_top, 39 | encoder_out_combine): 40 | """ 41 | Args: 42 | base_target_emb: target emb tensor 43 | input_from_dec: output of the decoder conv layer 44 | encoder_out_top: the key matrix for calculation of attention weights, 45 | which is the top output of the encoder conv 46 | encoder_out_combine: 47 | the value matrix for the attention-weighted sum, 48 | which is the combination of the base emb and the top output of the encoder 49 | 50 | """ 51 | # checks 52 | # batch, channel, height, width = base_target_emb.size() 53 | batch, _, height, _ = base_target_emb.size() 54 | # batch_, channel_, height_, width_ = input_from_dec.size() 55 | batch_, _, height_, _ = input_from_dec.size() 56 | aeq(batch, batch_) 57 | aeq(height, height_) 58 | 59 | # enc_batch, enc_channel, enc_height = encoder_out_top.size() 60 | enc_batch, _, enc_height = encoder_out_top.size() 61 | # enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size() 62 | enc_batch_, _, enc_height_ = encoder_out_combine.size() 63 | 64 | aeq(enc_batch, enc_batch_) 65 | aeq(enc_height, enc_height_) 66 | 67 | preatt = seq_linear(self.linear_in, input_from_dec) 68 | target = (base_target_emb + preatt) * SCALE_WEIGHT 69 | target = torch.squeeze(target, 3) 70 | target = torch.transpose(target, 1, 2) 71 | pre_attn = torch.bmm(target, encoder_out_top) 72 | 73 | if self.mask is not None: 74 | pre_attn.data.masked_fill_(self.mask, -float('inf')) 75 | 76 | pre_attn = pre_attn.transpose(0, 2) 77 | attn = F.softmax(pre_attn, dim=-1) 78 | attn = attn.transpose(0, 2).contiguous() 79 | context_output = torch.bmm( 80 | attn, torch.transpose(encoder_out_combine, 1, 2)) 81 | context_output = torch.transpose( 82 | torch.unsqueeze(context_output, 3), 1, 2) 83 | return context_output, attn 84 | -------------------------------------------------------------------------------- /deephop/onmt/modules/gate.py: -------------------------------------------------------------------------------- 1 | """ ContextGate module """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def context_gate_factory(gate_type, embeddings_size, decoder_size, 7 | attention_size, output_size): 8 | """Returns the correct ContextGate class""" 9 | 10 | gate_types = {'source': SourceContextGate, 11 | 'target': TargetContextGate, 12 | 'both': BothContextGate} 13 | 14 | assert gate_type in gate_types, "Invalid ContextGate type: {0}".format( 15 | gate_type) 16 | return gate_types[gate_type](embeddings_size, decoder_size, attention_size, 17 | output_size) 18 | 19 | 20 | class ContextGate(nn.Module): 21 | """ 22 | Context gate is a decoder module that takes as input the previous
word 22 | embedding, the current decoder state and the attention state, and 23 | produces a gate. 24 | The gate can be used to select the input from the target side context 25 | (decoder state), from the source context (attention state) or both. 26 | """ 27 | 28 | def __init__(self, embeddings_size, decoder_size, 29 | attention_size, output_size): 30 | super(ContextGate, self).__init__() 31 | input_size = embeddings_size + decoder_size + attention_size 32 | self.gate = nn.Linear(input_size, output_size, bias=True) 33 | self.sig = nn.Sigmoid() 34 | self.source_proj = nn.Linear(attention_size, output_size) 35 | self.target_proj = nn.Linear(embeddings_size + decoder_size, 36 | output_size) 37 | 38 | def forward(self, prev_emb, dec_state, attn_state): 39 | input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1) 40 | z = self.sig(self.gate(input_tensor)) 41 | proj_source = self.source_proj(attn_state) 42 | proj_target = self.target_proj( 43 | torch.cat((prev_emb, dec_state), dim=1)) 44 | return z, proj_source, proj_target 45 | 46 | 47 | class SourceContextGate(nn.Module): 48 | """Apply the context gate only to the source context""" 49 | 50 | def __init__(self, embeddings_size, decoder_size, 51 | attention_size, output_size): 52 | super(SourceContextGate, self).__init__() 53 | self.context_gate = ContextGate(embeddings_size, decoder_size, 54 | attention_size, output_size) 55 | self.tanh = nn.Tanh() 56 | 57 | def forward(self, prev_emb, dec_state, attn_state): 58 | z, source, target = self.context_gate( 59 | prev_emb, dec_state, attn_state) 60 | return self.tanh(target + z * source) 61 | 62 | 63 | class TargetContextGate(nn.Module): 64 | """Apply the context gate only to the target context""" 65 | 66 | def __init__(self, embeddings_size, decoder_size, 67 | attention_size, output_size): 68 | super(TargetContextGate, self).__init__() 69 | self.context_gate = ContextGate(embeddings_size, decoder_size, 70 | attention_size, output_size) 71 | self.tanh = nn.Tanh() 72 | 73 | def forward(self, prev_emb, dec_state, attn_state): 74 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 75 | return self.tanh(z * target + source) 76 | 77 | 78 | class BothContextGate(nn.Module): 79 | """Apply the context gate to both contexts""" 80 | 81 | def __init__(self, embeddings_size, decoder_size, 82 | attention_size, output_size): 83 | super(BothContextGate, self).__init__() 84 | self.context_gate = ContextGate(embeddings_size, decoder_size, 85 | attention_size, output_size) 86 | self.tanh = nn.Tanh() 87 | 88 | def forward(self, prev_emb, dec_state, attn_state): 89 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 90 | return self.tanh((1. - z) * target + z * source) 91 | -------------------------------------------------------------------------------- /deephop/onmt/modules/position_ffn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Position feed-forward network from "Attention is All You Need" 3 | """ 4 | 5 | import torch.nn as nn 6 | 7 | import onmt 8 | 9 | 10 | class PositionwiseFeedForward(nn.Module): 11 | """ A two-layer Feed-Forward-Network with residual layer norm. 12 | 13 | Args: 14 | d_model (int): the size of input for the first-layer of the FFN. 15 | d_ff (int): the hidden layer size of the second-layer 16 | of the FFN. 17 | dropout (float): dropout probability (0-1.0).
18 | """ 19 | 20 | def __init__(self, d_model, d_ff, dropout=0.1): 21 | super(PositionwiseFeedForward, self).__init__() 22 | self.w_1 = nn.Linear(d_model, d_ff) 23 | self.w_2 = nn.Linear(d_ff, d_model) 24 | self.layer_norm = onmt.modules.LayerNorm(d_model) 25 | self.dropout_1 = nn.Dropout(dropout) 26 | self.relu = nn.ReLU() 27 | self.dropout_2 = nn.Dropout(dropout) 28 | 29 | def forward(self, x): 30 | """ 31 | Layer definition. 32 | 33 | Args: 34 | input: [ batch_size, input_len, model_dim ] 35 | 36 | 37 | Returns: 38 | output: [ batch_size, input_len, model_dim ] 39 | """ 40 | inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) 41 | output = self.dropout_2(self.w_2(inter)) 42 | return output + x 43 | -------------------------------------------------------------------------------- /deephop/onmt/modules/sparse_activations.py: -------------------------------------------------------------------------------- 1 | """ 2 | An implementation of sparsemax (Martins & Astudillo, 2016). See 3 | https://arxiv.org/pdf/1602.02068 for detailed description. 4 | """ 5 | 6 | import torch 7 | from torch.autograd import Function 8 | import torch.nn as nn 9 | 10 | 11 | def threshold_and_support(z, dim=0): 12 | """ 13 | z: any dimension 14 | dim: dimension along which to apply the sparsemax 15 | """ 16 | sorted_z, _ = torch.sort(z, descending=True, dim=dim) 17 | z_sum = sorted_z.cumsum(dim) - 1 # sort of a misnomer 18 | k = torch.arange(1, sorted_z.size(dim) + 1, device=z.device).float().view( 19 | torch.Size([-1] + [1] * (z.dim() - 1)) 20 | ).transpose(0, dim) 21 | support = k * sorted_z > z_sum 22 | 23 | k_z_indices = support.sum(dim=dim).unsqueeze(dim) 24 | k_z = k_z_indices.float() 25 | tau_z = z_sum.gather(dim, k_z_indices - 1) / k_z 26 | return tau_z, k_z 27 | 28 | 29 | class SparsemaxFunction(Function): 30 | 31 | @staticmethod 32 | def forward(ctx, input, dim=0): 33 | """ 34 | input (FloatTensor): any shape 35 | returns (FloatTensor): same shape with sparsemax computed on given dim 36 | """ 37 | ctx.dim = dim 38 | tau_z, k_z = threshold_and_support(input, dim=dim) 39 | output = torch.clamp(input - tau_z, min=0) 40 | ctx.save_for_backward(k_z, output) 41 | return output 42 | 43 | @staticmethod 44 | def backward(ctx, grad_output): 45 | k_z, output = ctx.saved_tensors 46 | dim = ctx.dim 47 | grad_input = grad_output.clone() 48 | grad_input[output == 0] = 0 49 | 50 | v_hat = (grad_input.sum(dim=dim) / k_z.squeeze()).unsqueeze(dim) 51 | grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 52 | return grad_input, None 53 | 54 | 55 | sparsemax = SparsemaxFunction.apply 56 | 57 | 58 | class Sparsemax(nn.Module): 59 | 60 | def __init__(self, dim=0): 61 | self.dim = dim 62 | super(Sparsemax, self).__init__() 63 | 64 | def forward(self, input): 65 | return sparsemax(input, self.dim) 66 | 67 | 68 | class LogSparsemax(nn.Module): 69 | 70 | def __init__(self, dim=0): 71 | self.dim = dim 72 | super(LogSparsemax, self).__init__() 73 | 74 | def forward(self, input): 75 | return torch.log(sparsemax(input, self.dim)) 76 | -------------------------------------------------------------------------------- /deephop/onmt/modules/sparse_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from onmt.modules.sparse_activations import threshold_and_support 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | class SparsemaxLossFunction(Function): 9 | 10 | @staticmethod 11 | def 
forward(ctx, input, target): 12 | """ 13 | input (FloatTensor): n x num_classes 14 | target (LongTensor): n, the indices of the target classes 15 | """ 16 | input_batch, classes = input.size() 17 | target_batch = target.size(0) 18 | aeq(input_batch, target_batch) 19 | 20 | z_k = input.gather(1, target.unsqueeze(1)).squeeze() 21 | tau_z, support_size = threshold_and_support(input, dim=1) 22 | support = input > tau_z 23 | x = torch.where( 24 | support, input**2 - tau_z**2, 25 | torch.tensor(0.0, device=input.device) 26 | ).sum(dim=1) 27 | ctx.save_for_backward(input, target, tau_z) 28 | # clamping necessary because of numerical errors: loss should be lower 29 | # bounded by zero, but negative values near zero are possible without 30 | # the clamp 31 | return torch.clamp(x / 2 - z_k + 0.5, min=0.0) 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | input, target, tau_z = ctx.saved_tensors 36 | sparsemax_out = torch.clamp(input - tau_z, min=0) 37 | delta = torch.zeros_like(sparsemax_out) 38 | delta.scatter_(1, target.unsqueeze(1), 1) 39 | return sparsemax_out - delta, None 40 | 41 | 42 | sparsemax_loss = SparsemaxLossFunction.apply 43 | 44 | 45 | class SparsemaxLoss(nn.Module): 46 | """ 47 | An implementation of sparsemax loss, first proposed in "From Softmax to 48 | Sparsemax: A Sparse Model of Attention and Multi-Label Classification" 49 | (Martins & Astudillo, 2016: https://arxiv.org/pdf/1602.02068). If using 50 | a sparse output layer, it is not possible to use negative log likelihood 51 | because the loss is infinite in the case the target is assigned zero 52 | probability. Inputs to SparsemaxLoss are arbitrary dense real-valued 53 | vectors (like in nn.CrossEntropyLoss), not probability vectors (like in 54 | nn.NLLLoss). 55 | """ 56 | 57 | def __init__(self, weight=None, ignore_index=-100, 58 | reduce=True, size_average=True): 59 | self.weight = weight 60 | self.ignore_index = ignore_index 61 | self.reduce = reduce 62 | self.size_average = size_average 63 | super(SparsemaxLoss, self).__init__() 64 | 65 | def forward(self, input, target): 66 | loss = sparsemax_loss(input, target) 67 | if self.ignore_index >= 0: 68 | ignored_positions = target == self.ignore_index 69 | size = float((target.size(0) - ignored_positions.sum()).item()) 70 | loss.masked_fill_(ignored_positions, 0.0) 71 | else: 72 | size = float(target.size(0)) 73 | if self.reduce: 74 | loss = loss.sum() 75 | if self.size_average: 76 | loss = loss / size 77 | return loss 78 | -------------------------------------------------------------------------------- /deephop/onmt/modules/structured_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.cuda 4 | from onmt.utils.logging import init_logger 5 | 6 | 7 | class MatrixTree(nn.Module): 8 | """Implementation of the matrix-tree theorem for computing marginals 9 | of non-projective dependency parsing. This attention layer is used 10 | in the paper "Learning Structured Text Representations." 
11 | 12 | 13 | :cite:`DBLP:journals/corr/LiuL17d` 14 | """ 15 | 16 | def __init__(self, eps=1e-5): 17 | self.eps = eps 18 | super(MatrixTree, self).__init__() 19 | 20 | def forward(self, input): 21 | laplacian = input.exp() + self.eps 22 | output = input.clone() 23 | for b in range(input.size(0)): 24 | lap = laplacian[b].masked_fill( 25 | torch.eye(input.size(1)).cuda().ne(0), 0) 26 | lap = -lap + torch.diag(lap.sum(0)) 27 | # store roots on diagonal 28 | lap[0] = input[b].diag().exp() 29 | inv_laplacian = lap.inverse() 30 | 31 | factor = inv_laplacian.diag().unsqueeze(1)\ 32 | .expand_as(input[b]).transpose(0, 1) 33 | term1 = input[b].exp().mul(factor).clone() 34 | term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone() 35 | term1[:, 0] = 0 36 | term2[0] = 0 37 | output[b] = term1 - term2 38 | roots_output = input[b].diag().exp().mul( 39 | inv_laplacian.transpose(0, 1)[0]) 40 | output[b] = output[b] + torch.diag(roots_output) 41 | return output 42 | 43 | 44 | if __name__ == "__main__": 45 | logger = init_logger('StructuredAttention.log') 46 | dtree = MatrixTree() 47 | q = torch.rand(1, 5, 5).cuda() 48 | marg = dtree.forward(q) 49 | logger.info(marg.sum(1)) 50 | -------------------------------------------------------------------------------- /deephop/onmt/modules/util_class.py: -------------------------------------------------------------------------------- 1 | """ Misc classes """ 2 | import torch 3 | import torch.nn as nn 4 | from numpy.core.multiarray import ndarray 5 | 6 | from graph_embedding import get_emb 7 | 8 | 9 | class LayerNorm(nn.Module): 10 | """ 11 | Layer Normalization class 12 | """ 13 | 14 | def __init__(self, features, eps=1e-6): 15 | super(LayerNorm, self).__init__() 16 | self.a_2 = nn.Parameter(torch.ones(features)) 17 | self.b_2 = nn.Parameter(torch.zeros(features)) 18 | self.eps = eps 19 | 20 | def forward(self, x): 21 | 22 | mean = x.mean(-1, keepdim=True) 23 | std = x.std(-1, keepdim=True) 24 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 25 | 26 | 27 | # At the moment this class is only used by embeddings.Embeddings look-up tables 28 | class Elementwise(nn.ModuleList): 29 | """ 30 | A simple network container. 31 | Parameters are a list of modules. 32 | Inputs are a 3d Tensor whose last dimension is the same length 33 | as the list. 34 | Outputs are the result of applying modules to inputs elementwise. 35 | An optional merge parameter allows the outputs to be reduced to a 36 | single Tensor. 
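Example (illustrative, not part of the original docstring): with
merge='concat' and three embedding sub-modules, an input of shape
[len, batch, 3] is split into three feature columns, each column is
looked up by its own module, and the results are concatenated along
the last dimension.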
37 | """ 38 | 39 | def __init__(self, merge=None, *args): 40 | assert merge in [None, 'first', 'concat', 'sum', 'mlp'] 41 | self.merge = merge 42 | super(Elementwise, self).__init__(*args) 43 | 44 | def forward(self, inputs): 45 | inputs_ = [feat.squeeze(2) for feat in inputs.split(1, dim=2)] 46 | assert len(self) == len(inputs_) 47 | outputs = [f(x) for f, x in zip(self, inputs_)] 48 | if self.merge == 'first': 49 | return outputs[0] 50 | elif self.merge == 'concat' or self.merge == 'mlp': 51 | return torch.cat(outputs, 2) 52 | elif self.merge == 'sum': 53 | return sum(outputs) 54 | else: 55 | return outputs 56 | 57 | def make_condtion(condition, num_words, device, dtype): 58 | if isinstance(condition, torch.Tensor): 59 | return condition 60 | if not isinstance(condition, ndarray) and not isinstance(condition[0], list): 61 | condition = get_emb(condition) 62 | if condition.ndim == 1: 63 | return None 64 | condition = torch.tensor(condition, device=device, dtype=dtype) 65 | to_every_atom = condition.repeat(1, num_words) 66 | condition_of_every_atom = to_every_atom.view(condition.size(0), -1, condition.size(1)) 67 | return condition_of_every_atom 68 | -------------------------------------------------------------------------------- /deephop/onmt/tests/rebuild_test_models.sh: -------------------------------------------------------------------------------- 1 | # # Retrain the models used for CI. 2 | # # Should be done rarely, indicates a major breaking change. 3 | my_python=python 4 | 5 | ############### TEST regular RNN choose either -rnn_type LSTM / GRU / SRU and set input_feed 0 for SRU 6 | if true; then 7 | rm data/*.pt 8 | $my_python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000 9 | 10 | $my_python train.py -data data/data -save_model tmp -world_size 1 -gpu_ranks 0 -rnn_size 256 -word_vec_size 256 -layers 1 -train_steps 10000 -optim adam -learning_rate 0.001 -rnn_type LSTM -input_feed 0 11 | #-truncated_decoder 5 12 | #-label_smoothing 0.1 13 | 14 | mv tmp*e10.pt onmt/tests/test_model.pt 15 | rm tmp*.pt 16 | fi 17 | # 18 | # 19 | ############### TEST CNN 20 | if false; then 21 | rm data/*.pt 22 | $my_python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000 23 | 24 | $my_python train.py -data data/data -save_model /tmp/tmp -world_size 1 -gpu_ranks 0 -rnn_size 256 -word_vec_size 256 -layers 2 -train_steps 10000 -optim adam -learning_rate 0.001 -encoder_type cnn -decoder_type cnn 25 | 26 | 27 | mv /tmp/tmp*e10.pt onmt/tests/test_model.pt 28 | 29 | rm /tmp/tmp*.pt 30 | fi 31 | # 32 | ################# MORPH DATA 33 | if true; then 34 | rm data/morph/*.pt 35 | $my_python preprocess.py -train_src data/morph/src.train -train_tgt data/morph/tgt.train -valid_src data/morph/src.valid -valid_tgt data/morph/tgt.valid -save_data data/morph/data 36 | 37 | $my_python train.py -data data/morph/data -save_model tmp -world_size 1 -gpu_ranks 0 -rnn_size 400 -word_vec_size 100 -layers 1 -train_steps 8000 -optim adam -learning_rate 0.001 38 | 39 | 40 | mv tmp*e8.pt onmt/tests/test_model2.pt 41 | 42 | rm tmp*.pt 43 | fi 44 | ############### TEST TRANSFORMER 45 | if false; then 46 | rm data/*.pt 47 | $my_python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt 
data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000 -share_vocab 48 | 49 | 50 | $my_python train.py -data data/data -save_model /tmp/tmp -batch_type tokens -batch_size 1024 -accum_count 4 \ 51 | -layers 4 -rnn_size 256 -word_vec_size 256 -encoder_type transformer -decoder_type transformer -share_embedding \ 52 | -train_steps 10000 -world_size 1 -gpu_ranks 0 -max_generator_batches 4 -dropout 0.1 -normalization tokens \ 53 | -max_grad_norm 0 -optim adam -decay_method noam -learning_rate 2 -label_smoothing 0.1 \ 54 | -position_encoding -param_init 0 -warmup_steps 100 -param_init_glorot -adam_beta2 0.998 55 | # 56 | mv /tmp/tmp*e10.pt onmt/tests/test_model.pt 57 | rm /tmp/tmp*.pt 58 | fi 59 | # 60 | if false; then 61 | $my_python translate.py -gpu 0 -model onmt/tests/test_model.pt \ 62 | -src data/src-val.txt -output onmt/tests/output_hyp.txt -beam 5 -batch_size 16 63 | 64 | fi 65 | 66 | 67 | -------------------------------------------------------------------------------- /deephop/onmt/tests/test_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Here come the tests for attention types and their compatibility 3 | """ 4 | import unittest 5 | import torch 6 | from torch.autograd import Variable 7 | 8 | import onmt 9 | 10 | 11 | class TestAttention(unittest.TestCase): 12 | 13 | def test_masked_global_attention(self): 14 | 15 | source_lengths = torch.IntTensor([7, 3, 5, 2]) 16 | # illegal_weights_mask = torch.ByteTensor([ 17 | # [0, 0, 0, 0, 0, 0, 0], 18 | # [0, 0, 0, 1, 1, 1, 1], 19 | # [0, 0, 0, 0, 0, 1, 1], 20 | # [0, 0, 1, 1, 1, 1, 1]]) 21 | 22 | batch_size = source_lengths.size(0) 23 | dim = 20 24 | 25 | memory_bank = Variable(torch.randn(batch_size, 26 | source_lengths.max(), dim)) 27 | hidden = Variable(torch.randn(batch_size, dim)) 28 | 29 | attn = onmt.modules.GlobalAttention(dim) 30 | 31 | _, alignments = attn(hidden, memory_bank, 32 | memory_lengths=source_lengths) 33 | # TODO: fix for pytorch 0.3 34 | # illegal_weights = alignments.masked_select(illegal_weights_mask) 35 | 36 | # self.assertEqual(0.0, illegal_weights.data.sum()) 37 | -------------------------------------------------------------------------------- /deephop/onmt/tests/test_preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import copy 7 | import unittest 8 | import glob 9 | import os 10 | import codecs 11 | from collections import Counter 12 | 13 | import torchtext 14 | 15 | import onmt 16 | import onmt.inputters 17 | import onmt.opts 18 | import preprocess 19 | 20 | 21 | parser = argparse.ArgumentParser(description='preprocess.py') 22 | onmt.opts.preprocess_opts(parser) 23 | 24 | SAVE_DATA_PREFIX = 'data/test_preprocess' 25 | 26 | default_opts = [ 27 | '-data_type', 'text', 28 | '-train_src', 'data/src-train.txt', 29 | '-train_tgt', 'data/tgt-train.txt', 30 | '-valid_src', 'data/src-val.txt', 31 | '-valid_tgt', 'data/tgt-val.txt', 32 | '-save_data', SAVE_DATA_PREFIX 33 | ] 34 | 35 | opt = parser.parse_known_args(default_opts)[0] 36 | 37 | 38 | class TestData(unittest.TestCase): 39 | def __init__(self, *args, **kwargs): 40 | super(TestData, self).__init__(*args, **kwargs) 41 | self.opt = opt 42 | 43 | def dataset_build(self, opt): 44 | fields = onmt.inputters.get_fields("text", 0, 0) 45 | 46 | if hasattr(opt, 'src_vocab') and len(opt.src_vocab) > 0: 47 | with 
codecs.open(opt.src_vocab, 'w', 'utf-8') as f: 48 | f.write('a\nb\nc\nd\ne\nf\n') 49 | if hasattr(opt, 'tgt_vocab') and len(opt.tgt_vocab) > 0: 50 | with codecs.open(opt.tgt_vocab, 'w', 'utf-8') as f: 51 | f.write('a\nb\nc\nd\ne\nf\n') 52 | 53 | train_data_files = preprocess.build_save_dataset('train', fields, opt) 54 | 55 | preprocess.build_save_vocab(train_data_files, fields, opt) 56 | 57 | preprocess.build_save_dataset('valid', fields, opt) 58 | 59 | # Remove the generated *pt files. 60 | for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'): 61 | os.remove(pt) 62 | if hasattr(opt, 'src_vocab') and os.path.exists(opt.src_vocab): 63 | os.remove(opt.src_vocab) 64 | if hasattr(opt, 'tgt_vocab') and os.path.exists(opt.tgt_vocab): 65 | os.remove(opt.tgt_vocab) 66 | 67 | def test_merge_vocab(self): 68 | va = torchtext.vocab.Vocab(Counter('abbccc')) 69 | vb = torchtext.vocab.Vocab(Counter('eeabbcccf')) 70 | 71 | merged = onmt.inputters.merge_vocabs([va, vb], 2) 72 | 73 | self.assertEqual(Counter({'c': 6, 'b': 4, 'a': 2, 'e': 2, 'f': 1}), 74 | merged.freqs) 75 | # 4 specicials + 2 words (since we pass 2 to merge_vocabs) 76 | self.assertEqual(6, len(merged.itos)) 77 | self.assertTrue('b' in merged.itos) 78 | 79 | 80 | def _add_test(param_setting, methodname): 81 | """ 82 | Adds a Test to TestData according to settings 83 | 84 | Args: 85 | param_setting: list of tuples of (param, setting) 86 | methodname: name of the method that gets called 87 | """ 88 | 89 | def test_method(self): 90 | if param_setting: 91 | opt = copy.deepcopy(self.opt) 92 | for param, setting in param_setting: 93 | setattr(opt, param, setting) 94 | else: 95 | opt = self.opt 96 | getattr(self, methodname)(opt) 97 | if param_setting: 98 | name = 'test_' + methodname + "_" + "_".join( 99 | str(param_setting).split()) 100 | else: 101 | name = 'test_' + methodname + '_standard' 102 | setattr(TestData, name, test_method) 103 | test_method.__name__ = name 104 | 105 | 106 | test_databuild = [[], 107 | [('src_vocab_size', 1), 108 | ('tgt_vocab_size', 1)], 109 | [('src_vocab_size', 10000), 110 | ('tgt_vocab_size', 10000)], 111 | [('src_seq_length', 1)], 112 | [('src_seq_length', 5000)], 113 | [('src_seq_length_trunc', 1)], 114 | [('src_seq_length_trunc', 5000)], 115 | [('tgt_seq_length', 1)], 116 | [('tgt_seq_length', 5000)], 117 | [('tgt_seq_length_trunc', 1)], 118 | [('tgt_seq_length_trunc', 5000)], 119 | [('shuffle', 0)], 120 | [('lower', True)], 121 | [('dynamic_dict', True)], 122 | [('share_vocab', True)], 123 | [('dynamic_dict', True), 124 | ('share_vocab', True)], 125 | [('dynamic_dict', True), 126 | ('max_shard_size', 500000)], 127 | [('src_vocab', '/tmp/src_vocab.txt'), 128 | ('tgt_vocab', '/tmp/tgt_vocab.txt')], 129 | ] 130 | 131 | for p in test_databuild: 132 | _add_test(p, 'dataset_build') 133 | 134 | # Test image preprocessing 135 | test_databuild = [[], 136 | [('tgt_vocab_size', 1)], 137 | [('tgt_vocab_size', 10000)], 138 | [('tgt_seq_length', 1)], 139 | [('tgt_seq_length', 5000)], 140 | [('tgt_seq_length_trunc', 1)], 141 | [('tgt_seq_length_trunc', 5000)], 142 | [('shuffle', 0)], 143 | [('lower', True)], 144 | [('shard_size', 5)], 145 | [('shard_size', 50)], 146 | [('tgt_vocab', '/tmp/tgt_vocab.txt')], 147 | ] 148 | test_databuild_common = [('data_type', 'img'), 149 | ('src_dir', '/tmp/im2text/images'), 150 | ('train_src', '/tmp/im2text/src-train-head.txt'), 151 | ('train_tgt', '/tmp/im2text/tgt-train-head.txt'), 152 | ('valid_src', '/tmp/im2text/src-val-head.txt'), 153 | ('valid_tgt', '/tmp/im2text/tgt-val-head.txt'), 154 | ] 
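# Editorial note (illustrative, not in the original test suite): each
# parameter list above is appended to the common image settings and
# registered through _add_test, e.g.
#   _add_test([('tgt_vocab_size', 1)] + test_databuild_common, 'dataset_build')
# creates a TestData method whose name is derived from the stringified
# parameter list.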
155 | for p in test_databuild: 156 | _add_test(p + test_databuild_common, 'dataset_build') 157 | 158 | # Test audio preprocessing 159 | test_databuild = [[], 160 | [('tgt_vocab_size', 1)], 161 | [('tgt_vocab_size', 10000)], 162 | [('src_seq_length', 1)], 163 | [('src_seq_length', 5000)], 164 | [('src_seq_length_trunc', 3200)], 165 | [('src_seq_length_trunc', 5000)], 166 | [('tgt_seq_length', 1)], 167 | [('tgt_seq_length', 5000)], 168 | [('tgt_seq_length_trunc', 1)], 169 | [('tgt_seq_length_trunc', 5000)], 170 | [('shuffle', 0)], 171 | [('lower', True)], 172 | [('shard_size', 5)], 173 | [('shard_size', 50)], 174 | [('tgt_vocab', '/tmp/tgt_vocab.txt')], 175 | ] 176 | test_databuild_common = [('data_type', 'audio'), 177 | ('src_dir', '/tmp/speech/an4_dataset'), 178 | ('train_src', '/tmp/speech/src-train-head.txt'), 179 | ('train_tgt', '/tmp/speech/tgt-train-head.txt'), 180 | ('valid_src', '/tmp/speech/src-val-head.txt'), 181 | ('valid_tgt', '/tmp/speech/tgt-val-head.txt'), 182 | ('sample_rate', 16000), 183 | ('window_size', 0.04), 184 | ('window_stride', 0.02), 185 | ('window', 'hamming'), 186 | ] 187 | for p in test_databuild: 188 | _add_test(p + test_databuild_common, 'dataset_build') 189 | -------------------------------------------------------------------------------- /deephop/onmt/tests/test_simple.py: -------------------------------------------------------------------------------- 1 | import onmt 2 | 3 | 4 | def test_load(): 5 | onmt 6 | pass 7 | -------------------------------------------------------------------------------- /deephop/onmt/train_single.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Training on a single process 4 | """ 5 | from __future__ import division 6 | 7 | import argparse 8 | import os 9 | import random 10 | import torch 11 | 12 | import onmt.opts as opts 13 | from graph_embedding import init_condition_transformer 14 | 15 | from onmt.inputters.inputter import build_dataset_iter, lazily_load_dataset, \ 16 | _load_fields, _collect_report_features 17 | from onmt.model_builder import build_model 18 | from onmt.utils.optimizers import build_optim 19 | from onmt.trainer import build_trainer 20 | from onmt.models import build_model_saver 21 | from onmt.utils.logging import init_logger, logger 22 | 23 | 24 | def _check_save_model_path(opt): 25 | save_model_path = os.path.abspath(opt.save_model) 26 | model_dirname = os.path.dirname(save_model_path) 27 | if not os.path.exists(model_dirname): 28 | os.makedirs(model_dirname) 29 | 30 | 31 | def _tally_parameters(model): 32 | n_params = sum([p.nelement() for p in model.parameters()]) 33 | enc = 0 34 | dec = 0 35 | for name, param in model.named_parameters(): 36 | if 'encoder' in name: 37 | enc += param.nelement() 38 | elif 'decoder' in name or 'generator' in name: 39 | dec += param.nelement() 40 | return n_params, enc, dec 41 | 42 | 43 | def training_opt_postprocessing(opt, device_id): 44 | if opt.word_vec_size != -1: 45 | opt.src_word_vec_size = opt.word_vec_size 46 | opt.tgt_word_vec_size = opt.word_vec_size 47 | 48 | if opt.layers != -1: 49 | opt.enc_layers = opt.layers 50 | opt.dec_layers = opt.layers 51 | 52 | if opt.rnn_size != -1: 53 | opt.enc_rnn_size = opt.rnn_size 54 | 55 | if opt.arch in ['transformer', 'after_encoding']: 56 | opt.dec_rnn_size = opt.rnn_size + opt.condition_dim 57 | else: 58 | opt.dec_rnn_size = opt.rnn_size 59 | if opt.model_type == 'text' and opt.enc_rnn_size != opt.dec_rnn_size: 60 | raise AssertionError("""We do not support
different encoder and 61 | decoder rnn sizes for translation now.""") 62 | 63 | opt.brnn = (opt.encoder_type == "brnn") 64 | 65 | if opt.rnn_type == "SRU" and not opt.gpu_ranks: 66 | raise AssertionError("Using SRU requires -gpu_ranks set.") 67 | 68 | if torch.cuda.is_available() and not opt.gpu_ranks: 69 | logger.info("WARNING: You have a CUDA device, \ 70 | should run with -gpu_ranks") 71 | 72 | if opt.seed > 0: 73 | torch.manual_seed(opt.seed) 74 | # this one is needed for torchtext random call (shuffled iterator) 75 | # in multi gpu it ensures datasets are read in the same order 76 | random.seed(opt.seed) 77 | # some cudnn methods can be random even after fixing the seed 78 | # unless you tell it to be deterministic 79 | torch.backends.cudnn.deterministic = True 80 | 81 | if device_id >= 0: 82 | torch.cuda.set_device(device_id) 83 | if opt.seed > 0: 84 | # These ensure same initialization in multi gpu mode 85 | torch.cuda.manual_seed(opt.seed) 86 | 87 | return opt 88 | 89 | 90 | def main(opt, device_id): 91 | opt = training_opt_postprocessing(opt, device_id) 92 | init_logger(opt.log_file) 93 | # Initialize the protein pocket feature (condition) encoder 94 | init_condition_transformer(opt.use_graph_embedding, opt.condition_dim) 95 | # Load checkpoint if we resume from a previous training. 96 | if opt.train_from: 97 | logger.info('Loading checkpoint from %s' % opt.train_from) 98 | checkpoint = torch.load(opt.train_from, 99 | map_location=lambda storage, loc: storage) 100 | model_opt = checkpoint['opt'] 101 | else: 102 | checkpoint = None 103 | model_opt = opt 104 | 105 | # Peek the first dataset to determine the data_type. 106 | # (All datasets have the same data_type). 107 | first_dataset = next(lazily_load_dataset("train", opt)) 108 | data_type = first_dataset.data_type 109 | 110 | # Load fields generated from preprocess phase. 111 | fields = _load_fields(first_dataset, data_type, opt, checkpoint) 112 | # Report src/tgt features. 113 | 114 | src_features, tgt_features = _collect_report_features(fields) 115 | for j, feat in enumerate(src_features): 116 | logger.info(' * src feature %d size = %d' 117 | % (j, len(fields[feat].vocab))) 118 | for j, feat in enumerate(tgt_features): 119 | logger.info(' * tgt feature %d size = %d' 120 | % (j, len(fields[feat].vocab))) 121 | 122 | # Build model. 123 | model = build_model(model_opt, opt, fields, checkpoint) 124 | n_params, enc, dec = _tally_parameters(model) 125 | logger.info('encoder: %d' % enc) 126 | logger.info('decoder: %d' % dec) 127 | logger.info('* number of parameters: %d' % n_params) 128 | _check_save_model_path(opt) 129 | 130 | # Build optimizer. 131 | optim = build_optim(model, opt, checkpoint) 132 | 133 | # Build model saver 134 | model_saver = build_model_saver(model_opt, opt, model, fields, optim) 135 | 136 | trainer = build_trainer(opt, device_id, model, fields, 137 | optim, data_type, model_saver=model_saver) 138 | 139 | def train_iter_fct(): 140 | return build_dataset_iter( 141 | lazily_load_dataset("train", opt), fields, opt) 142 | 143 | def valid_iter_fct(): 144 | return build_dataset_iter( 145 | lazily_load_dataset("valid", opt), fields, opt, is_train=False) 146 | 147 | # Do training.
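# Editorial note (an assumption based on the lazily_load_dataset calls
# above): each factory call rebuilds its iterator, so the preprocessed
# shards are lazily re-read from disk on every training/validation pass.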
148 | trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps, 149 | opt.valid_steps) 150 | 151 | if opt.tensorboard: 152 | trainer.report_manager.tensorboard_writer.close() 153 | 154 | 155 | if __name__ == "__main__": 156 | parser = argparse.ArgumentParser( 157 | description='train.py', 158 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 159 | 160 | opts.add_md_help_argument(parser) 161 | opts.model_opts(parser) 162 | opts.train_opts(parser) 163 | 164 | opt = parser.parse_args() 165 | main(opt) 166 | -------------------------------------------------------------------------------- /deephop/onmt/translate/__init__.py: -------------------------------------------------------------------------------- 1 | """ Modules for translation """ 2 | from onmt.translate.translator import Translator 3 | from onmt.translate.translation import Translation, TranslationBuilder 4 | from onmt.translate.beam import Beam, GNMTGlobalScorer 5 | from onmt.translate.penalties import PenaltyBuilder 6 | from onmt.translate.translation_server import TranslationServer, \ 7 | ServerModelError 8 | 9 | __all__ = ['Translator', 'Translation', 'Beam', 10 | 'GNMTGlobalScorer', 'TranslationBuilder', 11 | 'PenaltyBuilder', 'TranslationServer', 'ServerModelError'] 12 | -------------------------------------------------------------------------------- /deephop/onmt/translate/penalties.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | 4 | 5 | class PenaltyBuilder(object): 6 | """ 7 | Returns the Length and Coverage Penalty function for Beam Search. 8 | 9 | Args: 10 | length_pen (str): option name of length pen 11 | cov_pen (str): option name of cov pen 12 | """ 13 | 14 | def __init__(self, cov_pen, length_pen): 15 | self.length_pen = length_pen 16 | self.cov_pen = cov_pen 17 | 18 | def coverage_penalty(self): 19 | if self.cov_pen == "wu": 20 | return self.coverage_wu 21 | elif self.cov_pen == "summary": 22 | return self.coverage_summary 23 | else: 24 | return self.coverage_none 25 | 26 | def length_penalty(self): 27 | if self.length_pen == "wu": 28 | return self.length_wu 29 | elif self.length_pen == "avg": 30 | return self.length_average 31 | else: 32 | return self.length_none 33 | 34 | """ 35 | Below are all the different penalty terms implemented so far 36 | """ 37 | 38 | def coverage_wu(self, beam, cov, beta=0.): 39 | """ 40 | NMT coverage re-ranking score from 41 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 42 | """ 43 | penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(1) 44 | return beta * penalty 45 | 46 | def coverage_summary(self, beam, cov, beta=0.): 47 | """ 48 | Our summary penalty. 49 | """ 50 | penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(1) 51 | penalty -= cov.size(1) 52 | return beta * penalty 53 | 54 | def coverage_none(self, beam, cov, beta=0.): 55 | """ 56 | returns zero as penalty 57 | """ 58 | return beam.scores.clone().fill_(0.0) 59 | 60 | def length_wu(self, beam, logprobs, alpha=0.): 61 | """ 62 | NMT length re-ranking score from 63 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 64 | """ 65 | 66 | modifier = (((5 + len(beam.next_ys)) ** alpha) / 67 | ((5 + 1) ** alpha)) 68 | return (logprobs / modifier) 69 | 70 | def length_average(self, beam, logprobs, alpha=0.): 71 | """ 72 | Returns the average probability of tokens in a sequence. 
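(Worked example, added for illustration: a 4-token hypothesis with total
log-probability -8.0 is rescored to -8.0 / 4 = -2.0; note that alpha is
unused by this penalty.)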
73 | """ 74 | return logprobs / len(beam.next_ys) 75 | 76 | def length_none(self, beam, logprobs, alpha=0., beta=0.): 77 | """ 78 | Returns unmodified scores. 79 | """ 80 | return logprobs 81 | -------------------------------------------------------------------------------- /deephop/onmt/translate/translation.py: -------------------------------------------------------------------------------- 1 | """ Translation main class """ 2 | from __future__ import division, unicode_literals 3 | from __future__ import print_function 4 | 5 | import torch 6 | import onmt.inputters as inputters 7 | 8 | 9 | class TranslationBuilder(object): 10 | """ 11 | Build a word-based translation from the batch output 12 | of translator and the underlying dictionaries. 13 | 14 | Replacement based on "Addressing the Rare Word 15 | Problem in Neural Machine Translation" :cite:`Luong2015b` 16 | 17 | Args: 18 | data (DataSet): 19 | fields (dict of Fields): data fields 20 | n_best (int): number of translations produced 21 | replace_unk (bool): replace unknown words using attention 22 | has_tgt (bool): will the batch have gold targets 23 | """ 24 | 25 | def __init__(self, data, fields, n_best=1, replace_unk=False, 26 | has_tgt=False): 27 | self.data = data 28 | self.fields = fields 29 | self.n_best = n_best 30 | self.replace_unk = replace_unk 31 | self.has_tgt = has_tgt 32 | 33 | def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn): 34 | vocab = self.fields["tgt"].vocab 35 | tokens = [] 36 | for tok in pred: 37 | if tok < len(vocab): 38 | tokens.append(vocab.itos[tok]) 39 | else: 40 | tokens.append(src_vocab.itos[tok - len(vocab)]) 41 | if tokens[-1] == inputters.EOS_WORD: 42 | tokens = tokens[:-1] 43 | break 44 | if self.replace_unk and (attn is not None) and (src is not None): 45 | for i in range(len(tokens)): 46 | if tokens[i] == vocab.itos[inputters.UNK]: 47 | _, max_index = attn[i].max(0) 48 | tokens[i] = src_raw[max_index.item()] 49 | return tokens 50 | 51 | def from_batch(self, translation_batch): 52 | batch = translation_batch["batch"] 53 | assert(len(translation_batch["gold_score"]) == 54 | len(translation_batch["predictions"])) 55 | batch_size = batch.batch_size 56 | 57 | preds, pred_score, attn, gold_score, indices = list(zip( 58 | *sorted(zip(translation_batch["predictions"], 59 | translation_batch["scores"], 60 | translation_batch["attention"], 61 | translation_batch["gold_score"], 62 | batch.indices.data), 63 | key=lambda x: x[-1]))) 64 | 65 | # Sorting 66 | inds, perm = torch.sort(batch.indices.data) 67 | data_type = self.data.data_type 68 | if data_type == 'text': 69 | src = batch.src[0].data.index_select(1, perm) 70 | else: 71 | src = None 72 | 73 | if self.has_tgt: 74 | tgt = batch.tgt.data.index_select(1, perm) 75 | else: 76 | tgt = None 77 | 78 | translations = [] 79 | for b in range(batch_size): 80 | if data_type == 'text': 81 | src_vocab = self.data.src_vocabs[inds[b]] \ 82 | if self.data.src_vocabs else None 83 | src_raw = self.data.examples[inds[b]].src 84 | else: 85 | src_vocab = None 86 | src_raw = None 87 | pred_sents = [self._build_target_tokens( 88 | src[:, b] if src is not None else None, 89 | src_vocab, src_raw, 90 | preds[b][n], attn[b][n]) 91 | for n in range(self.n_best)] 92 | gold_sent = None 93 | if tgt is not None: 94 | gold_sent = self._build_target_tokens( 95 | src[:, b] if src is not None else None, 96 | src_vocab, src_raw, 97 | tgt[1:, b] if tgt is not None else None, None) 98 | 99 | translation = Translation(src[:, b] if src is not None else None, 100 | src_raw, 
pred_sents, 101 | attn[b], pred_score[b], gold_sent, 102 | gold_score[b]) 103 | translations.append(translation) 104 | 105 | return translations 106 | 107 | 108 | class Translation(object): 109 | """ 110 | Container for a translated sentence. 111 | 112 | Attributes: 113 | src (`LongTensor`): src word ids 114 | src_raw ([str]): raw src words 115 | 116 | pred_sents ([[str]]): words from the n-best translations 117 | pred_scores ([[float]]): log-probs of n-best translations 118 | attns ([`FloatTensor`]): attention dist for each translation 119 | gold_sent ([str]): words from gold translation 120 | gold_score ([float]): log-prob of gold translation 121 | 122 | """ 123 | 124 | def __init__(self, src, src_raw, pred_sents, 125 | attn, pred_scores, tgt_sent, gold_score): 126 | self.src = src 127 | self.src_raw = src_raw 128 | self.pred_sents = pred_sents 129 | self.attns = attn 130 | self.pred_scores = pred_scores 131 | self.gold_sent = tgt_sent 132 | self.gold_score = gold_score 133 | 134 | def log(self, sent_number): 135 | """ 136 | Log translation. 137 | """ 138 | 139 | output = '\nSENT {}: {}\n'.format(sent_number, self.src_raw) 140 | 141 | best_pred = self.pred_sents[0] 142 | best_score = self.pred_scores[0] 143 | pred_sent = ' '.join(best_pred) 144 | output += 'PRED {}: {}\n'.format(sent_number, pred_sent) 145 | output += "PRED SCORE: {:.4f}\n".format(best_score) 146 | 147 | if self.gold_sent is not None: 148 | tgt_sent = ' '.join(self.gold_sent) 149 | output += 'GOLD {}: {}\n'.format(sent_number, tgt_sent) 150 | output += ("GOLD SCORE: {:.4f}\n".format(self.gold_score)) 151 | if len(self.pred_sents) > 1: 152 | output += '\nBEST HYP:\n' 153 | for score, sent in zip(self.pred_scores, self.pred_sents): 154 | output += "[{:.4f}] {}\n".format(score, sent) 155 | 156 | return output 157 | -------------------------------------------------------------------------------- /deephop/onmt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining various utilities.""" 2 | from onmt.utils.misc import aeq, use_gpu 3 | from onmt.utils.report_manager import ReportMgr, build_report_manager 4 | from onmt.utils.statistics import Statistics 5 | from onmt.utils.optimizers import build_optim, MultipleOptimizer, \ 6 | Optimizer 7 | 8 | __all__ = ["aeq", "use_gpu", "ReportMgr", 9 | "build_report_manager", "Statistics", 10 | "build_optim", "MultipleOptimizer", "Optimizer"] 11 | -------------------------------------------------------------------------------- /deephop/onmt/utils/cnn_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Convolutional Sequence to Sequence Learning" 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | import torch.nn.functional as F 8 | 9 | import onmt.modules 10 | 11 | SCALE_WEIGHT = 0.5 ** 0.5 12 | 13 | 14 | def shape_transform(x): 15 | """ Transform the size of the tensors to fit conv input.
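(Illustrative note, added by the editor: an input of shape
[batch, len, channels] becomes [batch, channels, len, 1], i.e. the
4-D layout the gated conv layers below expect.)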
""" 16 | return torch.unsqueeze(torch.transpose(x, 1, 2), 3) 17 | 18 | 19 | class GatedConv(nn.Module): 20 | """ Gated convolution for CNN class """ 21 | 22 | def __init__(self, input_size, width=3, dropout=0.2, nopad=False): 23 | super(GatedConv, self).__init__() 24 | self.conv = onmt.modules.WeightNormConv2d( 25 | input_size, 2 * input_size, kernel_size=(width, 1), stride=(1, 1), 26 | padding=(width // 2 * (1 - nopad), 0)) 27 | init.xavier_uniform_(self.conv.weight, gain=(4 * (1 - dropout))**0.5) 28 | self.dropout = nn.Dropout(dropout) 29 | 30 | def forward(self, x_var): 31 | x_var = self.dropout(x_var) 32 | x_var = self.conv(x_var) 33 | out, gate = x_var.split(int(x_var.size(1) / 2), 1) 34 | out = out * F.sigmoid(gate) 35 | return out 36 | 37 | 38 | class StackedCNN(nn.Module): 39 | """ Stacked CNN class """ 40 | 41 | def __init__(self, num_layers, input_size, cnn_kernel_width=3, 42 | dropout=0.2): 43 | super(StackedCNN, self).__init__() 44 | self.dropout = dropout 45 | self.num_layers = num_layers 46 | self.layers = nn.ModuleList() 47 | for _ in range(num_layers): 48 | self.layers.append( 49 | GatedConv(input_size, cnn_kernel_width, dropout)) 50 | 51 | def forward(self, x): 52 | for conv in self.layers: 53 | x = x + conv(x) 54 | x *= SCALE_WEIGHT 55 | return x 56 | -------------------------------------------------------------------------------- /deephop/onmt/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ Pytorch Distributed utils 2 | This piece of code was heavily inspired by the equivalent of Fairseq-py 3 | https://github.com/pytorch/fairseq 4 | """ 5 | 6 | 7 | from __future__ import print_function 8 | 9 | import math 10 | import pickle 11 | import torch.distributed 12 | 13 | from onmt.utils.logging import logger 14 | 15 | 16 | def is_master(opt, device_id): 17 | return opt.gpu_ranks[device_id] == 0 18 | 19 | 20 | def multi_init(opt, device_id): 21 | dist_init_method = 'tcp://{master_ip}:{master_port}'.format( 22 | master_ip=opt.master_ip, 23 | master_port=opt.master_port) 24 | dist_world_size = opt.world_size 25 | torch.distributed.init_process_group( 26 | backend=opt.gpu_backend, init_method=dist_init_method, 27 | world_size=dist_world_size, rank=opt.gpu_ranks[device_id]) 28 | gpu_rank = torch.distributed.get_rank() 29 | if not is_master(opt, device_id): 30 | logger.disabled = True 31 | 32 | return gpu_rank 33 | 34 | 35 | def all_reduce_and_rescale_tensors(tensors, rescale_denom, 36 | buffer_size=10485760): 37 | """All-reduce and rescale tensors in chunks of the specified size. 38 | 39 | Args: 40 | tensors: list of Tensors to all-reduce 41 | rescale_denom: denominator for rescaling summed Tensors 42 | buffer_size: all-reduce chunk size in bytes 43 | """ 44 | # buffer size in bytes, determine equiv. 
# of elements based on data type 45 | buffer_t = tensors[0].new( 46 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 47 | buffer = [] 48 | 49 | def all_reduce_buffer(): 50 | # copy tensors into buffer_t 51 | offset = 0 52 | for t in buffer: 53 | numel = t.numel() 54 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 55 | offset += numel 56 | 57 | # all-reduce and rescale 58 | torch.distributed.all_reduce(buffer_t[:offset]) 59 | buffer_t.div_(rescale_denom) 60 | 61 | # copy all-reduced buffer back into tensors 62 | offset = 0 63 | for t in buffer: 64 | numel = t.numel() 65 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 66 | offset += numel 67 | 68 | filled = 0 69 | for t in tensors: 70 | sz = t.numel() * t.element_size() 71 | if sz > buffer_size: 72 | # tensor is bigger than buffer, all-reduce and rescale directly 73 | torch.distributed.all_reduce(t) 74 | t.div_(rescale_denom) 75 | elif filled + sz > buffer_size: 76 | # buffer is full, all-reduce and replace buffer with grad 77 | all_reduce_buffer() 78 | buffer = [t] 79 | filled = sz 80 | else: 81 | # add tensor to buffer 82 | buffer.append(t) 83 | filled += sz 84 | 85 | if len(buffer) > 0: 86 | all_reduce_buffer() 87 | 88 | 89 | def all_gather_list(data, max_size=4096): 90 | """Gathers arbitrary data from all nodes into a list.""" 91 | world_size = torch.distributed.get_world_size() 92 | if not hasattr(all_gather_list, '_in_buffer') or \ 93 | max_size != all_gather_list._in_buffer.size(): 94 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 95 | all_gather_list._out_buffers = [ 96 | torch.cuda.ByteTensor(max_size) 97 | for i in range(world_size) 98 | ] 99 | in_buffer = all_gather_list._in_buffer 100 | out_buffers = all_gather_list._out_buffers 101 | 102 | enc = pickle.dumps(data) 103 | enc_size = len(enc) 104 | if enc_size + 2 > max_size: 105 | raise ValueError( 106 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 107 | assert max_size < 255*256 108 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 109 | in_buffer[1] = enc_size % 255 110 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 111 | 112 | torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 113 | 114 | results = [] 115 | for i in range(world_size): 116 | out_buffer = out_buffers[i] 117 | size = (255 * out_buffer[0].item()) + out_buffer[1].item() 118 | 119 | bytes_list = bytes(out_buffer[2:size+2].tolist()) 120 | result = pickle.loads(bytes_list) 121 | results.append(result) 122 | return results 123 | -------------------------------------------------------------------------------- /deephop/onmt/utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import logging 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def init_logger(log_file=None): 10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | 14 | console_handler = logging.StreamHandler() 15 | console_handler.setFormatter(log_format) 16 | logger.handlers = [console_handler] 17 | 18 | if log_file and log_file != '': 19 | file_handler = logging.FileHandler(log_file) 20 | file_handler.setFormatter(log_format) 21 | logger.addHandler(file_handler) 22 | 23 | return logger 24 | -------------------------------------------------------------------------------- /deephop/onmt/utils/masking.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | class ChemVocabMask(object): 5 | """ 6 | Chemical reaction token mask. Can either be initialized from a saved state or 7 | using a vocab file. Used to hide tokens that do not appear in the source sequence. 8 | """ 9 | def __init__(self, vocab=None, from_file=None): 10 | 11 | if from_file is not None: 12 | checkpoint = torch.load(from_file) 13 | self.always_active = checkpoint['always_active'] 14 | self.atom_vocab_dict = checkpoint['atom_vocab_dict'] 15 | self.vocab_atom_dict = checkpoint['vocab_atom_dict'] 16 | self.vocab_vocab_dict = checkpoint['vocab_vocab_dict'] 17 | self.vocab = checkpoint['vocab'] 18 | 19 | elif vocab is not None: 20 | self.vocab = vocab 21 | self.initialize_dicts() 22 | 23 | def save_dicts(self, file_path): 24 | torch.save({ 25 | 'always_active': self.always_active, 26 | 'atom_vocab_dict': self.atom_vocab_dict, 27 | 'vocab_atom_dict': self.vocab_atom_dict, 28 | 'vocab_vocab_dict': self.vocab_vocab_dict, 29 | 'vocab': self.vocab 30 | }, file_path) 31 | 32 | def initialize_dicts(self): 33 | """ 34 | Use the vocab file generated by preprocess.py to initialize the 35 | atom_vocab, vocab_atom, and vocab_vocab dicts. 36 | :return: 37 | """ 38 | from rdkit import Chem 39 | always_active = [] 40 | atom_vocab_dict = {} 41 | for i, v in enumerate(self.vocab.itos): 42 | mol = Chem.MolFromSmiles(v) 43 | if mol is not None: 44 | atomic_num = mol.GetAtoms()[0].GetAtomicNum() 45 | 46 | if atomic_num in atom_vocab_dict.keys(): 47 | atom_vocab_dict[atomic_num].append(i) 48 | else: 49 | atom_vocab_dict[atomic_num] = [i] 50 | else:  # retry with the first alphabetic char upper-cased (e.g. aromatic 'c' -> 'C') 51 | new_v = '' 52 | first_alpha = True 53 | for c in v: 54 | if first_alpha and c.isalpha(): 55 | new_v += c.upper() 56 | first_alpha = False 57 | else: 58 | new_v += c 59 | mol = Chem.MolFromSmiles(new_v) 60 | 61 | if mol is not None: 62 | atomic_num = mol.GetAtoms()[0].GetAtomicNum() 63 | 64 | if atomic_num in atom_vocab_dict.keys(): 65 | atom_vocab_dict[atomic_num].append(i) 66 | else: 67 | atom_vocab_dict[atomic_num] = [i] 68 | else: 69 | always_active.append(i) 70 | self.always_active = always_active 71 | self.atom_vocab_dict = atom_vocab_dict 72 | vocab_atom_dict = {} 73 | for k, v in atom_vocab_dict.items(): 74 | for token in v: 75 | vocab_atom_dict[token] = k 76 | self.vocab_atom_dict = vocab_atom_dict 77 | 78 | vocab_vocab_dict = {} 79 | for k, v in atom_vocab_dict.items(): 80 | for i in v: 81 | vocab_vocab_dict[i] = v 82 | for i in always_active: 83 | vocab_vocab_dict[i] = always_active 84 | self.vocab_vocab_dict = vocab_vocab_dict 85 | 86 | def _get_valid_tokens_per_src_seq_in_batch(self, src): 87 | valid_tokens_per_seq = [ 88 | np.unique([vocab_list for voc in np.unique(s.cpu().numpy()) for vocab_list in self.vocab_vocab_dict[voc]]) for s 89 | in src.t()] 90 | return valid_tokens_per_seq 91 | 92 | def get_log_probs_masking_tensor(self, src, beam_size): 93 | """ 94 | Build a matrix of shape (beam_size * batch_size) x vocab_size in which valid tokens have entry 1 and all other tokens 1e15. 95 | Multiplying the log-prob matrix (whose entries are negative) by it drives invalid tokens towards -inf, so only valid tokens are predicted.
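Illustrative usage (editor's sketch; `fields`, `src` and `log_probs` are
hypothetical names):

    >>> m = ChemVocabMask(vocab=fields['tgt'].vocab)
    >>> mask = m.get_log_probs_masking_tensor(src, beam_size=5)  # [beam*batch, vocab]
    >>> masked_log_probs = log_probs * mask  # invalid tokens pushed towards -inf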
96 | """ 97 | valid_tokens_per_seq = self._get_valid_tokens_per_src_seq_in_batch(src) 98 | mask = torch.stack([(torch.ones(len(self.vocab)).index_fill(0, torch.tensor(valid_tokens), 0) \ 99 | * 1e15).index_fill(0,torch.tensor(valid_tokens), 1) 100 | for valid_tokens in valid_tokens_per_seq for i in range(beam_size)]) 101 | mask = mask.to(src.device) 102 | return mask 103 | 104 | def _get_unique_vocab_counts_from_source(self, src): 105 | unique_counts_dicts = [dict(zip(*np.unique(s.cpu().numpy(), return_counts=True))) for s in src.t()] 106 | return unique_counts_dicts -------------------------------------------------------------------------------- /deephop/onmt/utils/misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | def aeq(*args): 7 | """ 8 | Assert all arguments have the same value 9 | """ 10 | arguments = (arg for arg in args) 11 | first = next(arguments) 12 | assert all(arg == first for arg in arguments), \ 13 | "Not all arguments have the same value: " + str(args) 14 | 15 | 16 | def sequence_mask(lengths, max_len=None): 17 | """ 18 | Creates a boolean mask from sequence lengths. 19 | """ 20 | batch_size = lengths.numel() 21 | max_len = max_len or lengths.max() 22 | return (torch.arange(0, max_len) 23 | .type_as(lengths) 24 | .repeat(batch_size, 1) 25 | .lt(lengths.unsqueeze(1))) 26 | 27 | 28 | def tile(x, count, dim=0): 29 | """ 30 | Tiles x on dimension dim count times. 31 | """ 32 | perm = list(range(len(x.size()))) 33 | if dim != 0: 34 | perm[0], perm[dim] = perm[dim], perm[0] 35 | x = x.permute(perm).contiguous() 36 | out_size = list(x.size()) 37 | out_size[0] *= count 38 | batch = x.size(0) 39 | x = x.view(batch, -1) \ 40 | .transpose(0, 1) \ 41 | .repeat(count, 1) \ 42 | .transpose(0, 1) \ 43 | .contiguous() \ 44 | .view(*out_size) 45 | if dim != 0: 46 | x = x.permute(perm).contiguous() 47 | return x 48 | 49 | 50 | def use_gpu(opt): 51 | """ 52 | Creates a boolean if gpu used 53 | """ 54 | return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \ 55 | (hasattr(opt, 'gpu') and opt.gpu > -1) 56 | -------------------------------------------------------------------------------- /deephop/onmt/utils/report_manager.py: -------------------------------------------------------------------------------- 1 | """ Report manager utility """ 2 | from __future__ import print_function 3 | import time 4 | from datetime import datetime 5 | 6 | import onmt 7 | 8 | from onmt.utils.logging import logger 9 | 10 | 11 | def build_report_manager(opt): 12 | if opt.tensorboard: 13 | from tensorboardX import SummaryWriter 14 | writer = SummaryWriter(opt.tensorboard_log_dir 15 | + datetime.now().strftime("/%b-%d_%H-%M-%S"), 16 | comment="Unmt") 17 | else: 18 | writer = None 19 | 20 | report_mgr = ReportMgr(opt.report_every, start_time=-1, 21 | tensorboard_writer=writer) 22 | return report_mgr 23 | 24 | 25 | class ReportMgrBase(object): 26 | """ 27 | Report Manager Base class 28 | Inherited classes should override: 29 | * `_report_training` 30 | * `_report_step` 31 | """ 32 | 33 | def __init__(self, report_every, start_time=-1.): 34 | """ 35 | Args: 36 | report_every(int): Report status every this many sentences 37 | start_time(float): manually set report start time. 
Negative values 38 | mean that you will need to set it later or use `start()` 39 | """ 40 | self.report_every = report_every 41 | self.progress_step = 0 42 | self.start_time = start_time 43 | 44 | def start(self): 45 | self.start_time = time.time() 46 | 47 | def log(self, *args, **kwargs): 48 | logger.info(*args, **kwargs) 49 | 50 | def report_training(self, step, num_steps, learning_rate, 51 | report_stats, multigpu=False): 52 | """ 53 | This is the user-defined batch-level training progress 54 | report function. 55 | 56 | Args: 57 | step(int): current step count. 58 | num_steps(int): total number of batches. 59 | learning_rate(float): current learning rate. 60 | report_stats(Statistics): old Statistics instance. 61 | Returns: 62 | report_stats(Statistics): updated Statistics instance. 63 | """ 64 | if self.start_time < 0: 65 | raise ValueError("""ReportMgr needs to be started 66 | (set 'start_time' or use 'start()')""") 67 | 68 | if multigpu: 69 | report_stats = onmt.utils.Statistics.all_gather_stats(report_stats) 70 | 71 | if step % self.report_every == 0: 72 | self._report_training( 73 | step, num_steps, learning_rate, report_stats) 74 | self.progress_step += 1 75 | return onmt.utils.Statistics() 76 | 77 | def _report_training(self, *args, **kwargs): 78 | """ To be overridden """ 79 | raise NotImplementedError() 80 | 81 | def report_step(self, lr, step, train_stats=None, valid_stats=None): 82 | """ 83 | Report stats of a step 84 | 85 | Args: 86 | train_stats(Statistics): training stats 87 | valid_stats(Statistics): validation stats 88 | lr(float): current learning rate 89 | """ 90 | self._report_step( 91 | lr, step, train_stats=train_stats, valid_stats=valid_stats) 92 | 93 | def _report_step(self, *args, **kwargs): 94 | raise NotImplementedError() 95 | 96 | 97 | class ReportMgr(ReportMgrBase): 98 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None): 99 | """ 100 | A report manager that writes statistics on standard output as well as 101 | (optionally) TensorBoard 102 | 103 | Args: 104 | report_every(int): Report status every this many steps 105 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`): 106 | The TensorBoard Summary writer to use or None 107 | """ 108 | super(ReportMgr, self).__init__(report_every, start_time) 109 | self.tensorboard_writer = tensorboard_writer 110 | 111 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step): 112 | if self.tensorboard_writer is not None: 113 | stats.log_tensorboard( 114 | prefix, self.tensorboard_writer, learning_rate, step) 115 | 116 | def _report_training(self, step, num_steps, learning_rate, 117 | report_stats): 118 | """ 119 | See base class method `ReportMgrBase.report_training`. 120 | """ 121 | report_stats.output(step, num_steps, 122 | learning_rate, self.start_time) 123 | 124 | # Log the progress using the number of batches on the x-axis. 125 | self.maybe_log_tensorboard(report_stats, 126 | "progress", 127 | learning_rate, 128 | self.progress_step) 129 | report_stats = onmt.utils.Statistics() 130 | 131 | return report_stats 132 | 133 | def _report_step(self, lr, step, train_stats=None, valid_stats=None): 134 | """ 135 | See base class method `ReportMgrBase.report_step`.
136 | """ 137 | if train_stats is not None: 138 | self.log('Train perplexity: %g' % train_stats.ppl()) 139 | self.log('Train accuracy: %g' % train_stats.accuracy()) 140 | 141 | self.maybe_log_tensorboard(train_stats, 142 | "train", 143 | lr, 144 | step) 145 | 146 | if valid_stats is not None: 147 | self.log('Validation perplexity: %g' % valid_stats.ppl()) 148 | self.log('Validation accuracy: %g' % valid_stats.accuracy()) 149 | 150 | self.maybe_log_tensorboard(valid_stats, 151 | "valid", 152 | lr, 153 | step) 154 | -------------------------------------------------------------------------------- /deephop/onmt/utils/rnn_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | RNN tools 3 | """ 4 | from __future__ import division 5 | 6 | import torch.nn as nn 7 | import onmt.models 8 | 9 | 10 | def rnn_factory(rnn_type, **kwargs): 11 | """ rnn factory, Use pytorch version when available. """ 12 | no_pack_padded_seq = False 13 | if rnn_type == "SRU": 14 | # SRU doesn't support PackedSequence. 15 | no_pack_padded_seq = True 16 | rnn = onmt.models.sru.SRU(**kwargs) 17 | else: 18 | rnn = getattr(nn, rnn_type)(**kwargs) 19 | return rnn, no_pack_padded_seq 20 | -------------------------------------------------------------------------------- /deephop/onmt/utils/statistics.py: -------------------------------------------------------------------------------- 1 | """ Statistics calculation utility """ 2 | from __future__ import division 3 | import time 4 | import math 5 | import sys 6 | 7 | from torch.distributed import get_rank 8 | from onmt.utils.distributed import all_gather_list 9 | from onmt.utils.logging import logger 10 | 11 | 12 | class Statistics(object): 13 | """ 14 | Accumulator for loss statistics. 15 | Currently calculates: 16 | 17 | * accuracy 18 | * perplexity 19 | * elapsed time 20 | """ 21 | 22 | def __init__(self, loss=0, n_words=0, n_correct=0): 23 | self.loss = loss 24 | self.n_words = n_words 25 | self.n_correct = n_correct 26 | self.n_src_words = 0 27 | self.start_time = time.time() 28 | 29 | @staticmethod 30 | def all_gather_stats(stat, max_size=4096): 31 | """ 32 | Gather a `Statistics` object accross multiple process/nodes 33 | 34 | Args: 35 | stat(:obj:Statistics): the statistics object to gather 36 | accross all processes/nodes 37 | max_size(int): max buffer size to use 38 | 39 | Returns: 40 | `Statistics`, the update stats object 41 | """ 42 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size) 43 | return stats[0] 44 | 45 | @staticmethod 46 | def all_gather_stats_list(stat_list, max_size=4096): 47 | """ 48 | Gather a `Statistics` list accross all processes/nodes 49 | 50 | Args: 51 | stat_list(list([`Statistics`])): list of statistics objects to 52 | gather accross all processes/nodes 53 | max_size(int): max buffer size to use 54 | 55 | Returns: 56 | our_stats(list([`Statistics`])): list of updated stats 57 | """ 58 | # Get a list of world_size lists with len(stat_list) Statistics objects 59 | all_stats = all_gather_list(stat_list, max_size=max_size) 60 | 61 | our_rank = get_rank() 62 | our_stats = all_stats[our_rank] 63 | for other_rank, stats in enumerate(all_stats): 64 | if other_rank == our_rank: 65 | continue 66 | for i, stat in enumerate(stats): 67 | our_stats[i].update(stat, update_n_src_words=True) 68 | return our_stats 69 | 70 | def update(self, stat, update_n_src_words=False): 71 | """ 72 | Update statistics by suming values with another `Statistics` object 73 | 74 | Args: 75 | stat: another statistic 
object 76 | update_n_src_words(bool): whether to update (sum) `n_src_words` 77 | or not 78 | 79 | """ 80 | self.loss += stat.loss 81 | self.n_words += stat.n_words 82 | self.n_correct += stat.n_correct 83 | 84 | if update_n_src_words: 85 | self.n_src_words += stat.n_src_words 86 | 87 | def accuracy(self): 88 | """ compute accuracy """ 89 | return 100 * (self.n_correct / self.n_words) 90 | 91 | def xent(self): 92 | """ compute cross entropy """ 93 | return self.loss / self.n_words 94 | 95 | def ppl(self): 96 | """ compute perplexity """ 97 | return math.exp(min(self.loss / self.n_words, 100)) 98 | 99 | def elapsed_time(self): 100 | """ compute elapsed time """ 101 | return time.time() - self.start_time 102 | 103 | def output(self, step, num_steps, learning_rate, start): 104 | """Write out statistics to stdout. 105 | 106 | Args: 107 | step (int): current step 108 | n_batch (int): total batches 109 | start (int): start time of step. 110 | """ 111 | t = self.elapsed_time() 112 | logger.info( 113 | ("Step %2d/%5d; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + 114 | "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") 115 | % (step, num_steps, 116 | self.accuracy(), 117 | self.ppl(), 118 | self.xent(), 119 | learning_rate, 120 | self.n_src_words / (t + 1e-5), 121 | self.n_words / (t + 1e-5), 122 | time.time() - start)) 123 | sys.stdout.flush() 124 | 125 | def log_tensorboard(self, prefix, writer, learning_rate, step): 126 | """ display statistics to tensorboard """ 127 | t = self.elapsed_time() 128 | writer.add_scalar(prefix + "/xent", self.xent(), step) 129 | writer.add_scalar(prefix + "/ppl", self.ppl(), step) 130 | writer.add_scalar(prefix + "/accuracy", self.accuracy(), step) 131 | writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) 132 | writer.add_scalar(prefix + "/lr", learning_rate, step) 133 | -------------------------------------------------------------------------------- /deephop/protein_emb.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prokia/deepHops/529870eba5ad486527113cbd5607535192deb212/deephop/protein_emb.pkl -------------------------------------------------------------------------------- /deephop/replace_torchtext/batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Batch(object): 5 | """Defines a batch of examples along with its Fields. 6 | 7 | Attributes: 8 | batch_size: Number of examples in the batch. 9 | dataset: A reference to the dataset object the examples come from 10 | (which itself contains the dataset's Field objects). 11 | train: Deprecated: this attribute is left for backwards compatibility, 12 | however it is UNUSED as of the merger with pytorch 0.4. 13 | input_fields: The names of the fields that are used as input for the model 14 | target_fields: The names of the fields that are used as targets during 15 | model training 16 | 17 | Also stores the Variable for each column in the batch as an attribute. 
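(Illustrative note based on __init__ below: for a text batch, `batch.src`
holds the tensor produced by the corresponding field's process() call,
while fields named 'graph' or 'condition_target' are kept as raw Python
lists.)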
18 | """ 19 | 20 | def __init__(self, data=None, dataset=None, device=None): 21 | """Create a Batch from a list of examples.""" 22 | if data is not None: 23 | self.batch_size = len(data) 24 | self.dataset = dataset 25 | self.fields = dataset.fields.keys() # copy field names 26 | self.input_fields = [k for k, v in dataset.fields.items() if 27 | v is not None and not v.is_target] 28 | self.target_fields = [k for k, v in dataset.fields.items() if 29 | v is not None and v.is_target] 30 | 31 | for (name, field) in dataset.fields.items(): 32 | if field is not None: 33 | batch = [getattr(x, name) for x in data if hasattr(x, name)] 34 | if name in ['graph', 'condition_target']: 35 | setattr(self, name, batch) 36 | else: 37 | setattr(self, name, field.process(batch, device=device)) 38 | 39 | 40 | @classmethod 41 | def fromvars(cls, dataset, batch_size, train=None, **kwargs): 42 | """Create a Batch directly from a number of Variables.""" 43 | batch = cls() 44 | batch.batch_size = batch_size 45 | batch.dataset = dataset 46 | batch.fields = dataset.fields.keys() 47 | for k, v in kwargs.items(): 48 | setattr(batch, k, v) 49 | return batch 50 | 51 | def __repr__(self): 52 | return str(self) 53 | 54 | def __str__(self): 55 | if not self.__dict__: 56 | return 'Empty {} instance'.format(torch.typename(self)) 57 | 58 | fields_to_index = filter(lambda field: field is not None, self.fields) 59 | var_strs = '\n'.join(['\t[.' + name + ']' + ":" + _short_str(getattr(self, name)) 60 | for name in fields_to_index if hasattr(self, name)]) 61 | 62 | data_str = (' from {}'.format(self.dataset.name.upper()) 63 | if hasattr(self.dataset, 'name') and 64 | isinstance(self.dataset.name, str) else '') 65 | 66 | strt = '[{} of size {}{}]\n{}'.format(torch.typename(self), 67 | self.batch_size, data_str, var_strs) 68 | return '\n' + strt 69 | 70 | def __len__(self): 71 | return self.batch_size 72 | 73 | def _get_field_values(self, fields): 74 | if len(fields) == 0: 75 | return None 76 | elif len(fields) == 1: 77 | return getattr(self, fields[0]) 78 | else: 79 | return tuple(getattr(self, f) for f in fields) 80 | 81 | def __iter__(self): 82 | yield self._get_field_values(self.input_fields) 83 | yield self._get_field_values(self.target_fields) 84 | 85 | 86 | def _short_str(tensor): 87 | # unwrap variable to tensor 88 | if not torch.is_tensor(tensor): 89 | # (1) unpack variable 90 | if hasattr(tensor, 'data'): 91 | tensor = getattr(tensor, 'data') 92 | # (2) handle include_lengths 93 | elif isinstance(tensor, tuple): 94 | return str(tuple(_short_str(t) for t in tensor)) 95 | # (3) fallback to default str 96 | else: 97 | return str(tensor) 98 | 99 | # copied from torch _tensor_str 100 | size_str = 'x'.join(str(size) for size in tensor.size()) 101 | device_str = '' if not tensor.is_cuda else \ 102 | ' (GPU {})'.format(tensor.get_device()) 103 | strt = '[{} of size {}{}]'.format(torch.typename(tensor), 104 | size_str, device_str) 105 | return strt 106 | -------------------------------------------------------------------------------- /deephop/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Main training workflow 4 | """ 5 | from __future__ import division 6 | import argparse 7 | import os 8 | import signal 9 | import torch 10 | 11 | import onmt.opts as opts 12 | import onmt.utils.distributed 13 | 14 | from onmt.utils.logging import logger 15 | from onmt.train_single import main as single_main 16 | 17 | 18 | def main(opt): 19 | if opt.rnn_type == "SRU" 
/deephop/translate.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import division, unicode_literals
import argparse

from graph_embedding import init_condition_transformer
from onmt.utils.logging import init_logger
from onmt.translate.translator import build_translator

import onmt.inputters
import onmt.translate
import onmt
import onmt.model_builder
import onmt.modules
import onmt.opts
from split_data import TASKS, include_list
import os


def translating_opt_postprocessing(opt):
    if opt.rnn_size != -1:
        opt.enc_rnn_size = opt.rnn_size

        if opt.arch in ['transformer', 'after_encoding']:
            opt.dec_rnn_size = opt.rnn_size + opt.condition_dim
        else:
            opt.dec_rnn_size = opt.rnn_size
    if opt.model_type == 'text' and opt.enc_rnn_size != opt.dec_rnn_size:
        raise AssertionError("""We do not support different encoder and
                             decoder rnn sizes for translation now.""")

    return opt


def main(opt):
    # Initialize the protein pocket feature encoder
    init_condition_transformer(opt.use_graph_embedding, opt.condition_dim)
    base_dir = os.path.dirname(opt.output)
    os.makedirs(base_dir, exist_ok=True)
    src_path = f"{base_dir}/src-test-protein.txt"
    # tgt_path = f"{base_dir}/tgt-test-protein.txt"
    cond_path = f"{base_dir}/cond-test-protein.txt"
    if opt.proteins is not None and len(opt.proteins) > 0 and not os.path.isfile(src_path) and not os.path.isfile(cond_path):
        with open(opt.cond) as reader:
            cond_lines = [[i, int(s.rstrip())] for i, s in enumerate(reader.readlines())]
        protein_list = [TASKS.index(s) for s in opt.proteins]
        line_index = [i for i, v in cond_lines if v in protein_list]

        with open(opt.src) as reader:
            lines = reader.readlines()
            src_list = [lines[i] for i in line_index]
        with open(opt.cond) as reader:
            lines = reader.readlines()
            cond_list = [lines[i] for i in line_index]

        # deduplicate (SMILES, condition) pairs; SMILES never contain '_',
        # so the underscore join/split below is unambiguous
        uniq_set = set([f"{smi}_{cond}" for smi, cond in zip(src_list, cond_list)])
        with open(src_path, 'w') as src_writer, open(cond_path, 'w') as writer:
            for v in uniq_set:
                smi, cond = v.split('_')
                src_writer.writelines([smi])
                writer.writelines([cond])
        opt.src = src_path
        # opt.tgt = tgt_path
        opt.cond = cond_path
    translator = build_translator(opt, report_score=True)
    translator.translate(src_path=opt.src,
                         tgt_path=opt.tgt,
                         src_dir=opt.src_dir,
                         batch_size=opt.batch_size,
                         attn_debug=opt.attn_debug,
                         condition_path=opt.cond,
                         prepare_pt_file=opt.prepare_pt_file,
                         opt=opt)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='translate.py',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    onmt.opts.add_md_help_argument(parser)
    onmt.opts.translate_opts(parser)

    group = parser.add_argument_group('fulmz_ext')
    group.add_argument('-proteins', nargs='+', type=str,
                       help='translate only the test-set molecules belonging to the given proteins')
    group.add_argument('-use_protein40', action='store_true', default=False,
                       help='translate the test-set molecules of all forty proteins in the training set')
    group.add_argument('-use_graph_embedding', action='store_true', default=False,
                       help='if false, use the one-hot encoding as the protein embedding; '
                            'else use the embedding generated by ProteinBertModel')
    group.add_argument('-condition_dim', type=int, default=0, help='embedding size of the protein condition')
    group.add_argument('-arch', type=str, default='before_linear',
                       choices=['before_linear', 'after_encoding', 'after_decoding', 'no_cond', 'transformer'],
                       help='model architecture')
    group.add_argument('-prepare_pt_file', type=str, help='the cache directory of prepared data')
    group.add_argument('-with_3d_confomer', action='store_true', default=False,
                       help='search a conformation for each molecule to get the 3D atom positions '
                            'required by Graph3dConv')
    opt = parser.parse_args()
    if opt.use_protein40:
        opt.proteins = list(include_list)

    logger = init_logger(opt.log_file)
    main(opt)
--------------------------------------------------------------------------------
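`translating_opt_postprocessing` only widens the decoder for the architectures that concatenate the protein condition onto the encoder output. A quick check of the arithmetic, with illustrative sizes (not project defaults):

    rnn_size, condition_dim = 256, 64   # illustrative values
    for arch in ['after_encoding', 'transformer', 'before_linear']:
        dec_rnn_size = rnn_size + condition_dim if arch in ['transformer', 'after_encoding'] else rnn_size
        print(arch, dec_rnn_size)       # 320, 320, 256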
/score/data_loader.py:
--------------------------------------------------------------------------------
# --*-- coding: utf-8 --*--

import csv
from typing import List

import pandas as pd
import rdkit.Chem.AllChem as Chem


def load_sdf_data(sdf_file):
    mols = Chem.SDMolSupplier(sdf_file)
    mols = [m for m in mols if m is not None]
    if len(mols) == 0:
        return None
    columns = list(mols[0].GetPropsAsDict().keys())
    columns.sort()
    columns.append('smiles')
    data_list = []
    for m in mols:
        kv = m.GetPropsAsDict()
        if 'smiles' not in kv.keys():
            kv['smiles'] = Chem.MolToSmiles(m)
        data_list.append([kv[k] for k in columns])

    return pd.DataFrame(data_list, columns=columns)


def detect_delimiter(source_file):
    """
    Detect the delimiter of a text table from its first line.
    :param source_file:
    :return: delimiter
    """
    with open(source_file) as r:
        first_line = r.readline()
        if '\t' in first_line:
            return '\t'
        if ',' in first_line:
            return ','
        return ' '


def has_header(head_line: List[str]):
    # A line is treated as data rather than a header if any of its
    # cells parses as a valid molecule.
    for s in head_line:
        try:
            mol = Chem.MolFromSmiles(s)
            if mol is not None:
                return False
        except Exception:
            continue
    return True


def get_csv_header(path: str) -> List[str]:
    """
    Returns the header of a data CSV file.

    :param path: Path to a CSV file.
    :return: A list of strings containing the strings in the comma-separated header.
    """
    with open(path) as f:
        header = next(csv.reader(f))

    return header


def get_xls_header(path: str) -> List[str]:
    # Minimal implementation (the stub returned nothing): read only the
    # header row via pandas. Assumes an installed Excel engine, e.g. xlrd.
    return pd.read_excel(path, nrows=0).columns.tolist()


def load_data_frame(source_file) -> object:
    if source_file.endswith('csv') or source_file.endswith('txt') or source_file.endswith('smi'):
        df = pd.read_csv(source_file, delimiter=detect_delimiter(source_file))
    elif source_file.endswith('xls') or source_file.endswith('xlsx'):
        df = pd.read_excel(source_file)
    elif source_file.endswith('sdf'):
        df = load_sdf_data(source_file)
    else:
        print("cannot read %s" % source_file)
        df = None
    return df


if __name__ == '__main__':
    pass
--------------------------------------------------------------------------------
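`load_data_frame` dispatches on the file extension and delegates delimiter sniffing to `detect_delimiter`. A usage sketch against one of the activity tables shipped in the repository (the column names are whatever that CSV's header declares):

    from data_loader import detect_delimiter, load_data_frame

    path = 'known_pvalue_data/CHEMBL1974.csv'
    print(detect_delimiter(path))   # ',' for a standard CSV
    df = load_data_frame(path)
    print(df.columns.tolist())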
/score/env.yaml:
--------------------------------------------------------------------------------
name: deepchem
channels:
- deepchem
- omnia
- rdkit
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
- conda-forge
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
- https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
- defaults
dependencies:
- _py-xgboost-mutex=2.0=cpu_0
- biopython=1.76=py36h516909a_0
- blosc=1.18.1=he1b5a44_0
- brotlipy=0.7.0=py36h8c4c3a4_1000
- bzip2=1.0.8=h516909a_2
- ca-certificates=2020.4.5.1=hecc5488_0
- cairo=1.14.12=he6fea26_5
- certifi=2020.4.5.1=py36h9f0ad1d_0
- cffi=1.11.5=py36_0
- chardet=3.0.4=py36h9f0ad1d_1006
- cryptography=2.9.2=py36h45558ae_0
- cycler=0.10.0=py_2
- cython=0.29.17=py36h831f99a_0
- decorator=4.4.2=py_0
- flaky=3.6.1=py_0
- fontconfig=2.13.1=he4413a7_1000
- freetype=2.10.2=he06d7ca_0
- gettext=0.19.8.1=h5e8e0c9_1
- glib=2.55.0=0
- icu=58.2=hf484d3e_1000
- idna=2.9=py_1
- joblib=0.14.1=py_0
- jpeg=9c=h14c3975_1001
- kiwisolver=1.2.0=py36hdb11119_0
- libblas=3.8.0=14_openblas
- libcblas=3.8.0=14_openblas
- libgfortran-ng=7.3.0=hdf63c60_5
- libiconv=1.15=h516909a_1006
- liblapack=3.8.0=14_openblas
- libopenblas=0.3.7=h5ec1e0e_6
- libpng=1.6.37=hed695b0_1
- libtiff=4.1.0=hc7e4089_6
- libuuid=2.32.1=h14c3975_1000
- libwebp-base=1.1.0=h516909a_3
- libxcb=1.13=h14c3975_1002
- libxgboost=1.0.2=he1b5a44_1
- libxml2=2.9.9=h13577e0_2
- lz4-c=1.9.2=he1b5a44_1
- lzo=2.10=h14c3975_1000
- matplotlib=3.1.1=py36_0
- matplotlib-base=3.1.1=py36hfd891ef_0
- mock=4.0.2=py36h9f0ad1d_0
- networkx=2.4=py_1
- nose=1.3.7=py36h9f0ad1d_1004
- nose-timer=0.7.4=py_0
- numexpr=2.7.1=py36h830a2c2_1
- numpy=1.18.4=py36h7314795_0
- olefile=0.46=py_0
- openssl=1.1.1g=h516909a_0
- pandas=1.0.3=py36h830a2c2_1
- pcre=8.44=he1b5a44_0
- pillow=7.1.2=py36h8328e55_0
- pixman=0.34.0=h14c3975_1003
- pthread-stubs=0.4=h14c3975_1001
- py-xgboost=1.0.2=py36h9f0ad1d_1
- pycparser=2.20=py_0
- pyopenssl=19.1.0=py_1
- pyparsing=2.4.7=pyh9f0ad1d_0
- pyqt=4.11.4=py36_3
- pysocks=1.7.1=py36h9f0ad1d_1
- python-dateutil=2.8.1=py_0
- python_abi=3.6=1_cp36m
- pytz=2020.1=pyh9f0ad1d_0
- requests=2.23.0=pyh8c360ce_2
- scikit-learn=0.23.0=py36h0e1014b_0
- scipy=1.4.1=py36h2d22cac_3
- setuptools=46.3.0=py36h9f0ad1d_0
- sip=4.18=py36_1
- six=1.14.0=py_1
- tensorboardx=2.0=py_0
- termcolor=1.1.0=py_2
- threadpoolctl=2.0.0=pyh5ca1d4c_0
- tk=8.6.10=hed695b0_0
- tornado=6.0.4=py36h8c4c3a4_1
- urllib3=1.25.9=py_0
- xorg-kbproto=1.0.7=h14c3975_1002
- xorg-libice=1.0.10=h516909a_0
- xorg-libsm=1.2.3=h84519dc_1000
- xorg-libx11=1.6.9=h516909a_0
- xorg-libxau=1.0.9=h14c3975_0
- xorg-libxdmcp=1.1.3=h516909a_0
- xorg-libxext=1.3.4=h516909a_0
- xorg-libxrender=0.9.10=h516909a_1002
- xorg-renderproto=0.11.1=h14c3975_1002
- xorg-xextproto=7.3.0=h14c3975_1002
- xorg-xproto=7.0.31=h14c3975_1007
- zlib=1.2.11=h516909a_1006
- zstd=1.4.4=h6597ccf_3
- mdtraj=1.9.1=py36_1
- simdna=0.4.2=py_0
- qt=4.8.7=2
- _libgcc_mutex=0.1=main
- _tflow_select=2.1.0=gpu
- absl-py=0.9.0=py36_0
- astor=0.8.0=py36_0
- blinker=1.4=py36_0
- c-ares=1.15.0=h7b6447c_1001
- cachetools=3.1.1=py_0
- click=7.1.2=py_0
- cloudpickle=1.1.1=py_0
- cudatoolkit=10.1.243=h6bb024c_0
- cudnn=7.6.5=cuda10.1_0
- cupti=10.1.168=0
- gast=0.2.2=py36_0
- google-auth=1.14.1=py_0
- google-auth-oauthlib=0.4.1=py_2
- google-pasta=0.2.0=py_0
- grpcio=1.27.2=py36hf8bcb03_0
- h5py=2.10.0=py36h7918eee_0
- hdf5=1.10.4=hb1b8bf9_0
- keras-applications=1.0.8=py_0
- keras-preprocessing=1.1.0=py_1
- ld_impl_linux-64=2.33.1=h53a641e_7
- libboost=1.67.0=h46d08c1_4
- libedit=3.1.20181209=hc058e9b_0
- libffi=3.3=he6710b0_1
- libgcc-ng=9.1.0=hdf63c60_0
- libprotobuf=3.11.4=hd408876_0
- libstdcxx-ng=9.1.0=hdf63c60_0
- markdown=3.1.1=py36_0
- ncurses=6.2=he6710b0_1
- oauthlib=3.1.0=py_0
- opt_einsum=3.1.0=py_0
- pip=20.0.2=py36_3
- protobuf=3.11.4=py36he6710b0_0
- py-boost=1.67.0=py36h04863e7_4
- pyasn1=0.4.8=py_0
- pyasn1-modules=0.2.7=py_0
- pyjwt=1.7.1=py36_0
- pytables=3.6.1=py36h71ec239_0
- python=3.6.10=h7579374_2
- readline=8.0=h7b6447c_0
- requests-oauthlib=1.3.0=py_0
- rsa=4.0=py_0
- sqlite=3.31.1=h62c20be_1
- tensorboard=2.1.0=py3_0
- tensorflow=2.1.0=gpu_py36h2e5cdaa_0
- tensorflow-base=2.1.0=gpu_py36h6c5654b_0
- tensorflow-estimator=2.1.0=pyhd54b08b_0
- tensorflow-gpu=2.1.0=h0d30ee6_0
- tensorflow-probability=0.8.0=py_0
- werkzeug=1.0.1=py_0
- wheel=0.34.2=py36_0
- wrapt=1.12.1=py36h7b6447c_1
- xz=5.2.5=h7b6447c_0
- fftw3f=3.3.4=2
- openmm=7.4.1=py36_cuda101_rc_1
- pdbfixer=1.6=py36_0
- rdkit=2020.03.2.0=py36hc20afe1_1
- pip:
  - deepchem==0.0.0
  - pbr==5.4.5
  - tables==3.6.1
prefix: /home/ubuntu/anaconda3/envs/deepchem
--------------------------------------------------------------------------------
60 | """ 61 | with open(path) as f: 62 | header = next(csv.reader(f)) 63 | 64 | return header 65 | 66 | 67 | def get_xls_header(path: str) -> List[str]: 68 | pass 69 | 70 | 71 | def load_data_frame(source_file) -> object: 72 | if source_file.endswith('csv') or source_file.endswith('txt') or source_file.endswith('smi'): 73 | df = pd.read_csv(source_file, delimiter=detect_delimiter(source_file)) 74 | elif source_file.endswith('xls') or source_file.endswith('xlsx'): 75 | df = pd.read_excel(source_file) 76 | elif source_file.endswith('sdf'): 77 | df = load_sdf_data(source_file) 78 | else: 79 | print("can not read %s" % source_file) 80 | df = None 81 | return df 82 | 83 | 84 | if __name__ == '__main__': 85 | pass 86 | -------------------------------------------------------------------------------- /score/env.yaml: -------------------------------------------------------------------------------- 1 | name: deepchem 2 | channels: 3 | - deepchem 4 | - omnia 5 | - rdkit 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 8 | - conda-forge 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 10 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 11 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 12 | - defaults 13 | dependencies: 14 | - _py-xgboost-mutex=2.0=cpu_0 15 | - biopython=1.76=py36h516909a_0 16 | - blosc=1.18.1=he1b5a44_0 17 | - brotlipy=0.7.0=py36h8c4c3a4_1000 18 | - bzip2=1.0.8=h516909a_2 19 | - ca-certificates=2020.4.5.1=hecc5488_0 20 | - cairo=1.14.12=he6fea26_5 21 | - certifi=2020.4.5.1=py36h9f0ad1d_0 22 | - cffi=1.11.5=py36_0 23 | - chardet=3.0.4=py36h9f0ad1d_1006 24 | - cryptography=2.9.2=py36h45558ae_0 25 | - cycler=0.10.0=py_2 26 | - cython=0.29.17=py36h831f99a_0 27 | - decorator=4.4.2=py_0 28 | - flaky=3.6.1=py_0 29 | - fontconfig=2.13.1=he4413a7_1000 30 | - freetype=2.10.2=he06d7ca_0 31 | - gettext=0.19.8.1=h5e8e0c9_1 32 | - glib=2.55.0=0 33 | - icu=58.2=hf484d3e_1000 34 | - idna=2.9=py_1 35 | - joblib=0.14.1=py_0 36 | - jpeg=9c=h14c3975_1001 37 | - kiwisolver=1.2.0=py36hdb11119_0 38 | - libblas=3.8.0=14_openblas 39 | - libcblas=3.8.0=14_openblas 40 | - libgfortran-ng=7.3.0=hdf63c60_5 41 | - libiconv=1.15=h516909a_1006 42 | - liblapack=3.8.0=14_openblas 43 | - libopenblas=0.3.7=h5ec1e0e_6 44 | - libpng=1.6.37=hed695b0_1 45 | - libtiff=4.1.0=hc7e4089_6 46 | - libuuid=2.32.1=h14c3975_1000 47 | - libwebp-base=1.1.0=h516909a_3 48 | - libxcb=1.13=h14c3975_1002 49 | - libxgboost=1.0.2=he1b5a44_1 50 | - libxml2=2.9.9=h13577e0_2 51 | - lz4-c=1.9.2=he1b5a44_1 52 | - lzo=2.10=h14c3975_1000 53 | - matplotlib=3.1.1=py36_0 54 | - matplotlib-base=3.1.1=py36hfd891ef_0 55 | - mock=4.0.2=py36h9f0ad1d_0 56 | - networkx=2.4=py_1 57 | - nose=1.3.7=py36h9f0ad1d_1004 58 | - nose-timer=0.7.4=py_0 59 | - numexpr=2.7.1=py36h830a2c2_1 60 | - numpy=1.18.4=py36h7314795_0 61 | - olefile=0.46=py_0 62 | - openssl=1.1.1g=h516909a_0 63 | - pandas=1.0.3=py36h830a2c2_1 64 | - pcre=8.44=he1b5a44_0 65 | - pillow=7.1.2=py36h8328e55_0 66 | - pixman=0.34.0=h14c3975_1003 67 | - pthread-stubs=0.4=h14c3975_1001 68 | - py-xgboost=1.0.2=py36h9f0ad1d_1 69 | - pycparser=2.20=py_0 70 | - pyopenssl=19.1.0=py_1 71 | - pyparsing=2.4.7=pyh9f0ad1d_0 72 | - pyqt=4.11.4=py36_3 73 | - pysocks=1.7.1=py36h9f0ad1d_1 74 | - python-dateutil=2.8.1=py_0 75 | - python_abi=3.6=1_cp36m 76 | - pytz=2020.1=pyh9f0ad1d_0 77 | - requests=2.23.0=pyh8c360ce_2 78 | - scikit-learn=0.23.0=py36h0e1014b_0 79 | - scipy=1.4.1=py36h2d22cac_3 80 | - 
/score/summary_one_task.py:
--------------------------------------------------------------------------------
# --*-- coding: utf-8 --*--

from train import eval_test, TASKS
import os
import pandas as pd
import numpy as np

from util import find_best_model_checkpoint


def summary_one_dataset(mt_dir, tasks_under_dir):
    df_list = []
    for s in range(3):
        model_save_dir = f'{mt_dir}/seed{s}'
        checkpoint = find_best_model_checkpoint(model_save_dir)
        print(f'use checkpoint {checkpoint}')
        df = eval_test(model_save_dir, checkpoint, tasks_under_dir)
        df_list.append(df)

    def apply_func(row):
        # collect this target's r2/rmse from every seed and aggregate
        target, r2, rmse = row['target'], row['r2'], row['rmse']
        r2_list = [r2]
        rmse_list = [rmse]
        for df in df_list[1:]:
            r2_list.append(df.loc[df.target == target, 'r2'].tolist()[0])
            rmse_list.append(df.loc[df.target == target, 'rmse'].tolist()[0])
        return np.mean(r2_list), np.std(r2_list), np.mean(rmse_list), np.std(rmse_list)

    df = df_list[0].copy(deep=True)
    df[['r2', 'r2_std', 'rmse', 'rmse_std']] = df.apply(apply_func, axis=1, result_type='expand')
    return df


def summary_per_task():
    base_dir = '/model/per_task'
    df_list = []
    for f in os.listdir(base_dir):
        df = summary_one_dataset(os.path.join(base_dir, f), [f[:-2]])
        df_list.append(df)
    total_df = pd.concat(df_list, axis=0)
    total_df.to_csv('/home/aht/paper_code/shaungjia/code/score/result_experiment/mtdnn_per_task.csv', index=False)
--------------------------------------------------------------------------------
model_save_dir_root, gpu): 19 | # smiles = Chem.MolToSmiles(mol, isomericSmiles=True) 20 | featurizer = dc.feat.CircularFingerprint(size=1024) 21 | loader = dc.data.CSVLoader(tasks=[], smiles_field='smiles', featurizer=featurizer) 22 | 23 | test_file = io.StringIO(f"smiles\n{smiles}\n") 24 | test_dataset = loader.create_dataset(test_file, shard_size=8096) 25 | device = f"/gpu:{gpu}" 26 | pred_list = [] 27 | config = tf.compat.v1.ConfigProto() 28 | config.gpu_options.per_process_gpu_memory_fraction = 0.1 # 程序最多只能占用指定gpu50%的显存 29 | config.gpu_options.allow_growth = True # 不全部占满显存, 按需分配 30 | sess = tf.compat.v1.Session(config=config) 31 | backend.set_session(sess) 32 | 33 | for m in ['seed0', 'seed1', 'seed2']: 34 | model_save_dir = f"{model_save_dir_root}/{m}" 35 | 36 | checkpoint = find_best_model_checkpoint(model_save_dir) 37 | with tf.device(device): 38 | model = build_model(model_save_dir) 39 | model.restore(checkpoint) 40 | model.batch_size = 2 41 | # model.to(device) 42 | result = model.predict(test_dataset) 43 | if result.shape[-1] == 1: 44 | result = result.squeeze(-1) 45 | pred_list.append(result[0][55]) 46 | return mean(pred_list) 47 | 48 | 49 | def eval_test(model_save_dir_root, test_file, pred_path, gpu): 50 | featurizer = dc.feat.CircularFingerprint(size=1024) 51 | loader = dc.data.CSVLoader(tasks=[], smiles_field='smiles', featurizer=featurizer) 52 | test_dataset = loader.create_dataset(test_file, shard_size=8096) 53 | device = f"cuda:{gpu}" 54 | pred_list = [] 55 | for m in ['seed0', 'seed1', 'seed2']: 56 | model_save_dir = f"{model_save_dir_root}/{m}" 57 | model = build_model(model_save_dir) 58 | checkpoint = find_best_model_checkpoint(model_save_dir) 59 | model.restore(checkpoint) 60 | # model.to(device) 61 | result = model.predict(test_dataset) 62 | if result.shape[-1] == 1: 63 | pred = pd.DataFrame(result.squeeze(-1)) 64 | else: 65 | pred = pd.DataFrame(result) 66 | pred['smiles'] = test_dataset.ids.tolist() 67 | # dummmy column for next pipeline 68 | pred['label'] = 0 69 | pred_list.append(pred) 70 | pred = pd.concat(pred_list) 71 | score_colnames = [c for c in pred.columns] 72 | score_colnames.remove('smiles') 73 | score_colnames.remove('label') 74 | # average the values from 3 models as the score 75 | pred.groupby(['smiles'], sort=False)[score_colnames].apply(lambda x: mean(x)) 76 | # pred = pred.reset_index() 77 | cols = ['smiles', 'label'] 78 | cols.extend(score_colnames) 79 | pred.to_csv(pred_path, index=False, columns=cols) 80 | 81 | 82 | def run_pipe_server(model_save_dir_root): 83 | read_path = "/tmp/pipe_scorer.in" 84 | write_path = "/tmp/pipe_scorer.out" 85 | 86 | if os.path.exists(read_path): 87 | os.remove(read_path) 88 | if os.path.exists(write_path): 89 | os.remove(write_path) 90 | 91 | os.mkfifo(write_path) 92 | os.mkfifo(read_path) 93 | 94 | wf = os.open(write_path, os.O_SYNC | os.O_CREAT | os.O_RDWR) 95 | rf = os.open(read_path, os.O_RDONLY) 96 | featurizer = dc.feat.CircularFingerprint(size=1024) 97 | loader = dc.data.CSVLoader(tasks=[], smiles_field='smiles', featurizer=featurizer) 98 | 99 | gpu = 1 100 | device = f"/gpu:{gpu}" 101 | 102 | config = tf.compat.v1.ConfigProto() 103 | config.gpu_options.per_process_gpu_memory_fraction = 0.1 # 程序最多只能占用指定gpu50%的显存 104 | config.gpu_options.allow_growth = True # 不全部占满显存, 按需分配 105 | sess = tf.compat.v1.Session(config=config) 106 | backend.set_session(sess) 107 | model_list = [] 108 | for m in ['seed0', 'seed1', 'seed2']: 109 | model_save_dir = f"{model_save_dir_root}/{m}" 110 | 111 | checkpoint = 
find_best_model_checkpoint(model_save_dir) 112 | with tf.device(device): 113 | model = build_model(model_save_dir) 114 | model.restore(checkpoint) 115 | model.batch_size = 2 116 | model_list.append(model) 117 | 118 | while True: 119 | s = os.read(rf, 1024) 120 | # # cur_command.write(s) 121 | if len(s) == 0: 122 | time.sleep(1e-3) 123 | continue 124 | smiles = s.decode() 125 | if "exit" in smiles: 126 | os.close(rf) 127 | break 128 | else: 129 | print(f"SMILES: {s}") 130 | # smiles = s.decode() 131 | test_file = io.StringIO(f"smiles\n{smiles}\n") 132 | test_dataset = loader.create_dataset(test_file, shard_size=8096) 133 | pred_list = [] 134 | for model in model_list: 135 | result = model.predict(test_dataset) 136 | if result.shape[-1] == 1: 137 | result = result.squeeze(-1) 138 | pred_list.append(result[0][55]) 139 | score = mean(pred_list) 140 | os.write(wf, f"{score:.6f}".encode()) 141 | print(f"result: {score:.6f}") 142 | os.close(rf) 143 | os.close(wf) 144 | 145 | 146 | if __name__ == '__main__': 147 | # v = eval_one_mol(Chem.MolFromSmiles('COc1cc(C=Nn2c(-c3ccccc3)nc3ccccc3c2=O)cc(OC)c1OCC(=O)Nc1ccccc1F'), 148 | # '/home/aht/paper_code/shaungjia/code/score/model/total_mtr', 1) 149 | # print(v) 150 | parser = ArgumentParser(conflict_handler='resolve', description='Configure') 151 | parser.add_argument('--test_path', type=str, 152 | default='/home/aht/paper_code/shaungjia/code/score/model/do_chemprop/data/total_mtr', 153 | help='the directory of test data') 154 | parser.add_argument('--preds_path', type=str, 155 | default='/home/aht/paper_code/shaungjia/code/score/model/total_mtr', 156 | help='output directory') 157 | parser.add_argument('--model_save_dir', type=str, help='the model saved directory') 158 | parser.add_argument('--gpu', type=int, default=0, help='GPU rank') 159 | parser.add_argument('--scorer_pipe_server', action='store_true', default=False, help='accept the request of clients by named pipe') 160 | 161 | args = parser.parse_args() 162 | if args.scorer_pipe_server: 163 | run_pipe_server(args.model_save_dir) 164 | else: 165 | eval_test(args.model_save_dir, args.test_path, args.preds_path, args.gpu) 166 | -------------------------------------------------------------------------------- /score/summary_one_task.py: -------------------------------------------------------------------------------- 1 | # --*-- coding: utf-8 --*-- 2 | 3 | from train import eval_test, TASKS 4 | import os 5 | import pandas as pd 6 | import numpy as np 7 | 8 | from util import find_best_model_checkpoint 9 | 10 | def summary_one_dataset(mt_dir, tasks_under_dir): 11 | df_list = [] 12 | for s in range(3): 13 | model_save_dir = f'{mt_dir}/seed{s}' 14 | checkpoint = find_best_model_checkpoint(model_save_dir) 15 | print(f'use checkpoint {checkpoint}') 16 | df = eval_test(model_save_dir, checkpoint, tasks_under_dir) 17 | df_list.append(df) 18 | 19 | def apply_func(row): 20 | # target, r2, rmse = x 21 | target, r2, rmse = row['target'], row['r2'], row['rmse'] 22 | r2_list = [r2] 23 | rmse_list = [rmse] 24 | for df in df_list[1:]: 25 | r2_list.append(df.loc[df.target == target, 'r2'].tolist()[0]) 26 | rmse_list.append(df.loc[df.target == target, 'rmse'].tolist()[0]) 27 | return np.mean(r2_list), np.std(r2_list), np.mean(rmse_list), np.std(rmse_list) 28 | 29 | df = df_list[0].copy(deep=True) 30 | df[['r2', 'r2_std', 'rmse', 'rmse_std']] = df.apply(apply_func, axis=1, result_type='expand') 31 | return df 32 | 33 | 34 | def summary_per_task(): 35 | base_dir = '/model/per_task' 36 | df_list = [] 37 | for f in 
os.listdir(base_dir): 38 | df = summary_one_dataset(os.path.join(base_dir, f), [f[:-2]]) 39 | df_list.append(df) 40 | total_df = pd.concat(df_list, axis=0) 41 | total_df.to_csv('/home/aht/paper_code/shaungjia/code/score/result_experiment/mtdnn_per_task.csv', index=False) 42 | 43 | -------------------------------------------------------------------------------- /score/total_mtr/seed0/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "ckpt-43" 2 | all_model_checkpoint_paths: "ckpt-39" 3 | all_model_checkpoint_paths: "ckpt-40" 4 | all_model_checkpoint_paths: "ckpt-41" 5 | all_model_checkpoint_paths: "ckpt-42" 6 | all_model_checkpoint_paths: "ckpt-43" 7 | all_model_checkpoint_timestamps: 1594024383.2564054 8 | all_model_checkpoint_timestamps: 1594024413.111633 9 | all_model_checkpoint_timestamps: 1594024442.5917788 10 | all_model_checkpoint_timestamps: 1594024472.1298075 11 | all_model_checkpoint_timestamps: 1594024472.9814959 12 | last_preserved_timestamp: 1594023225.5010207 13 | -------------------------------------------------------------------------------- /score/total_mtr/seed0/ckpt-43.data-00000-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prokia/deepHops/529870eba5ad486527113cbd5607535192deb212/score/total_mtr/seed0/ckpt-43.data-00000-of-00002 -------------------------------------------------------------------------------- /score/total_mtr/seed0/ckpt-43.data-00001-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prokia/deepHops/529870eba5ad486527113cbd5607535192deb212/score/total_mtr/seed0/ckpt-43.data-00001-of-00002 -------------------------------------------------------------------------------- /score/total_mtr/seed0/ckpt-43.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prokia/deepHops/529870eba5ad486527113cbd5607535192deb212/score/total_mtr/seed0/ckpt-43.index -------------------------------------------------------------------------------- /score/total_mtr/seed1/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "ckpt-53" 2 | all_model_checkpoint_paths: "ckpt-49" 3 | all_model_checkpoint_paths: "ckpt-50" 4 | all_model_checkpoint_paths: "ckpt-51" 5 | all_model_checkpoint_paths: "ckpt-52" 6 | all_model_checkpoint_paths: "ckpt-53" 7 | all_model_checkpoint_timestamps: 1594026700.8525636 8 | all_model_checkpoint_timestamps: 1594026730.505618 9 | all_model_checkpoint_timestamps: 1594026760.2001607 10 | all_model_checkpoint_timestamps: 1594026789.985513 11 | all_model_checkpoint_timestamps: 1594026790.83649 12 | last_preserved_timestamp: 1594025251.675494 13 | -------------------------------------------------------------------------------- /score/total_mtr/seed1/ckpt-53.data-00000-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prokia/deepHops/529870eba5ad486527113cbd5607535192deb212/score/total_mtr/seed1/ckpt-53.data-00000-of-00002 -------------------------------------------------------------------------------- /score/total_mtr/seed1/ckpt-53.data-00001-of-00002: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/prokia/deepHops/529870eba5ad486527113cbd5607535192deb212/score/total_mtr/seed1/ckpt-53.data-00001-of-00002 -------------------------------------------------------------------------------- /score/total_mtr/seed1/ckpt-53.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prokia/deepHops/529870eba5ad486527113cbd5607535192deb212/score/total_mtr/seed1/ckpt-53.index -------------------------------------------------------------------------------- /score/total_mtr/seed2/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "ckpt-37" 2 | all_model_checkpoint_paths: "ckpt-33" 3 | all_model_checkpoint_paths: "ckpt-34" 4 | all_model_checkpoint_paths: "ckpt-35" 5 | all_model_checkpoint_paths: "ckpt-36" 6 | all_model_checkpoint_paths: "ckpt-37" 7 | all_model_checkpoint_timestamps: 1594028089.5555096 8 | all_model_checkpoint_timestamps: 1594028118.9732153 9 | all_model_checkpoint_timestamps: 1594028148.5008736 10 | all_model_checkpoint_timestamps: 1594028177.8692698 11 | all_model_checkpoint_timestamps: 1594028178.7763312 12 | last_preserved_timestamp: 1594027111.616544 13 | -------------------------------------------------------------------------------- /score/total_mtr/seed2/ckpt-37.data-00000-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prokia/deepHops/529870eba5ad486527113cbd5607535192deb212/score/total_mtr/seed2/ckpt-37.data-00000-of-00002 -------------------------------------------------------------------------------- /score/total_mtr/seed2/ckpt-37.data-00001-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prokia/deepHops/529870eba5ad486527113cbd5607535192deb212/score/total_mtr/seed2/ckpt-37.data-00001-of-00002 -------------------------------------------------------------------------------- /score/total_mtr/seed2/ckpt-37.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/prokia/deepHops/529870eba5ad486527113cbd5607535192deb212/score/total_mtr/seed2/ckpt-37.index -------------------------------------------------------------------------------- /score/util.py: -------------------------------------------------------------------------------- 1 | # --*-- coding: utf-8 --*-- 2 | 3 | import os 4 | 5 | def find_best_model_checkpoint(model_save_dir): 6 | all_ckpt = set() 7 | for f in os.listdir(model_save_dir): 8 | if not os.path.isfile(os.path.join(model_save_dir, f)) or not f.startswith('ckpt-'): 9 | continue 10 | all_ckpt.add(f.split('.')[0]) 11 | ckpt_list = list(all_ckpt) 12 | ckpt_list.sort() 13 | return f"{model_save_dir}/{ckpt_list[0]}" --------------------------------------------------------------------------------
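Given one of the `seedN` directories above (which retain only the newest checkpoint's files), the helper resolves to a TF checkpoint *prefix*, not a single file; a usage sketch:

    from util import find_best_model_checkpoint

    prefix = find_best_model_checkpoint('score/total_mtr/seed0')
    print(prefix)   # score/total_mtr/seed0/ckpt-43, the prefix of the .index/.data-* files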