├── .gitignore
├── figure
│   ├── model.png
│   ├── multi.png
│   └── single.png
├── multi_design_esr1
│   ├── figure.pdf
│   ├── ZINC
│   │   ├── separate_data.py
│   │   ├── cal_property.py
│   │   └── char.py
│   ├── utils.py
│   ├── args.py
│   ├── plot.py
│   ├── stats.py
│   ├── dataset.py
│   ├── sascorer.py
│   └── logger.py
├── single_design_esr1
│   ├── figure.pdf
│   ├── utils.py
│   ├── ZINC
│   │   ├── separate_data.py
│   │   ├── cal_property.py
│   │   └── char.py
│   ├── dataset.py
│   ├── args.py
│   ├── plot.py
│   ├── sascorer.py
│   └── logger.py
├── single_design_acaa1
│   ├── a.py
│   ├── ZINC
│   │   ├── separate_data.py
│   │   ├── cal_property.py
│   │   └── char.py
│   ├── utils.py
│   ├── dataset.py
│   ├── args.py
│   ├── plot.py
│   ├── sascorer.py
│   └── logger.py
├── multi_design_acaa1
│   ├── ZINC
│   │   ├── separate_data.py
│   │   ├── cal_property.py
│   │   └── char.py
│   ├── utils.py
│   ├── args.py
│   ├── stats.py
│   ├── dataset.py
│   ├── sascorer.py
│   └── logger.py
├── single_design_plogp
│   ├── ZINC
│   │   ├── separate_data.py
│   │   ├── cal_property.py
│   │   └── char.py
│   ├── utils.py
│   ├── dataset.py
│   ├── args.py
│   ├── plot.py
│   └── sascorer.py
├── single_design_qed
│   ├── ZINC
│   │   ├── separate_data.py
│   │   ├── cal_property.py
│   │   └── char.py
│   ├── utils.py
│   ├── dataset.py
│   ├── args.py
│   ├── stats.py
│   └── sascorer.py
└── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pt -------------------------------------------------------------------------------- /figure/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deqiankong/SGDS/HEAD/figure/model.png -------------------------------------------------------------------------------- /figure/multi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deqiankong/SGDS/HEAD/figure/multi.png -------------------------------------------------------------------------------- /figure/single.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deqiankong/SGDS/HEAD/figure/single.png -------------------------------------------------------------------------------- /multi_design_esr1/figure.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deqiankong/SGDS/HEAD/multi_design_esr1/figure.pdf -------------------------------------------------------------------------------- /single_design_esr1/figure.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deqiankong/SGDS/HEAD/single_design_esr1/figure.pdf -------------------------------------------------------------------------------- /single_design_acaa1/a.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | it = 22  # index of the saved score array to inspect (presumably a shift iteration) 4 | a = np.load(str(it) + '.npy') 5 | 6 | a = np.unique(a)  # drop duplicate scores 7 | kd = np.exp(a * (-1) / (0.00198720425864083 * 298.15)).flatten()  # Kd = exp(dG / RT); scores are stored negated, R = 0.00198720425864083 kcal/(mol*K), T = 298.15 K 8 | 9 | ind = np.argsort(kd) 10 | ind = ind[:200]  # keep the 200 smallest Kd values, i.e. the strongest binders 11 | 12 | b = kd[ind] 13 | print(np.mean(b), np.std(b)) 14 | print(b) 15 | -------------------------------------------------------------------------------- /single_design_esr1/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import datetime 5 | import torch 6 | import sys 7 | 8 | 9 | def get_exp_id(file): 10 | return os.path.splitext(os.path.basename(file))[0] 11 | 12 | 13 | def get_output_dir(exp_id, fs_prefix='./'): 14 | t = 
datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 15 | output_dir = os.path.join(fs_prefix + 'output/' + exp_id, t) 16 | os.makedirs(output_dir, exist_ok=True) 17 | return output_dir 18 | 19 | 20 | def setup_logging(name, output_dir, console=True): 21 | log_format = logging.Formatter("%(asctime)s : %(message)s") 22 | logger = logging.getLogger(name) 23 | logger.handlers = [] 24 | output_file = os.path.join(output_dir, 'output.log') 25 | file_handler = logging.FileHandler(output_file) 26 | file_handler.setFormatter(log_format) 27 | logger.addHandler(file_handler) 28 | if console: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setFormatter(log_format) 31 | logger.addHandler(console_handler) 32 | logger.setLevel(logging.INFO) 33 | return logger 34 | 35 | 36 | def copy_source(file, output_dir): 37 | shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file))) 38 | 39 | 40 | def copy_all_files(file, output_dir): 41 | dir_src = os.path.dirname(file) 42 | for filename in os.listdir(os.getcwd()): 43 | if filename.endswith('.py'): 44 | shutil.copy(os.path.join(dir_src, filename), output_dir) 45 | 46 | 47 | def set_gpu(gpu): 48 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 49 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) 50 | 51 | if torch.cuda.is_available(): 52 | torch.cuda.set_device(0) 53 | torch.backends.cudnn.benchmark = True 54 | 55 | -------------------------------------------------------------------------------- /multi_design_acaa1/ZINC/separate_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | 9 | file_input="data_5.txt" 10 | fp_input=open(file_input) 11 | lines=fp_input.readlines() 12 | fp_input.close() 13 | 14 | logP_list=[] 15 | SAS_list=[] 16 | QED_list=[] 17 | MW_list=[] 18 | TPSA_list=[] 19 | 20 | for i in range(len(lines)): 21 | line=lines[i] 22 | if line[0]=='#': 23 | continue 24 | arr=line.split() 25 | logP=float(arr[1]) 26 | SAS=float(arr[2]) 27 | QED=float(arr[3]) 28 | MW=float(arr[4]) 29 | TPSA=float(arr[5]) 30 | logP_list+=[logP] 31 | SAS_list+=[SAS] 32 | QED_list+=[QED] 33 | MW_list+=[MW] 34 | TPSA_list+=[TPSA] 35 | 36 | logP_array=np.array(logP_list) 37 | SAS_array=np.array(SAS_list) 38 | QED_array=np.array(QED_list) 39 | MW_array=np.array(MW_list) 40 | TPSA_array=np.array(TPSA_list) 41 | 42 | print(logP_array.min(),logP_array.max()) 43 | print(SAS_array.min(),SAS_array.max()) 44 | print(QED_array.min(),QED_array.max()) 45 | print(MW_array.min(),MW_array.max()) 46 | print(TPSA_array.min(),TPSA_array.max()) 47 | 48 | Ndata=len(lines)-1 49 | index=np.arange(0,Ndata) 50 | np.random.shuffle(index) 51 | Ntest=10000 52 | Ntrain=Ndata-Ntest 53 | train_index=index[0:Ntrain] 54 | test_index=index[Ntrain:Ndata] 55 | train_index.sort() 56 | test_index.sort() 57 | 58 | file_output="train_5.txt" 59 | fp_out=open(file_output,"w") 60 | line_out="#smi logP SAS QED MW TPSA\n" 61 | fp_out.write(line_out) 62 | for i in train_index: 63 | line=lines[i+1] 64 | fp_out.write(line) 65 | fp_out.close() 66 | 67 | 68 | file_output="test_5.txt" 69 | fp_out=open(file_output,"w") 70 | line_out="#smi logP SAS QED MW TPSA\n" 71 | fp_out.write(line_out) 72 | for i in test_index: 73 | line=lines[i+1] 74 | fp_out.write(line) 75 | fp_out.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /multi_design_esr1/ZINC/separate_data.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | 9 | file_input="data_5.txt" 10 | fp_input=open(file_input) 11 | lines=fp_input.readlines() 12 | fp_input.close() 13 | 14 | logP_list=[] 15 | SAS_list=[] 16 | QED_list=[] 17 | MW_list=[] 18 | TPSA_list=[] 19 | 20 | for i in range(len(lines)): 21 | line=lines[i] 22 | if line[0]=='#': 23 | continue 24 | arr=line.split() 25 | logP=float(arr[1]) 26 | SAS=float(arr[2]) 27 | QED=float(arr[3]) 28 | MW=float(arr[4]) 29 | TPSA=float(arr[5]) 30 | logP_list+=[logP] 31 | SAS_list+=[SAS] 32 | QED_list+=[QED] 33 | MW_list+=[MW] 34 | TPSA_list+=[TPSA] 35 | 36 | logP_array=np.array(logP_list) 37 | SAS_array=np.array(SAS_list) 38 | QED_array=np.array(QED_list) 39 | MW_array=np.array(MW_list) 40 | TPSA_array=np.array(TPSA_list) 41 | 42 | print(logP_array.min(),logP_array.max()) 43 | print(SAS_array.min(),SAS_array.max()) 44 | print(QED_array.min(),QED_array.max()) 45 | print(MW_array.min(),MW_array.max()) 46 | print(TPSA_array.min(),TPSA_array.max()) 47 | 48 | Ndata=len(lines)-1 49 | index=np.arange(0,Ndata) 50 | np.random.shuffle(index) 51 | Ntest=10000 52 | Ntrain=Ndata-Ntest 53 | train_index=index[0:Ntrain] 54 | test_index=index[Ntrain:Ndata] 55 | train_index.sort() 56 | test_index.sort() 57 | 58 | file_output="train_5.txt" 59 | fp_out=open(file_output,"w") 60 | line_out="#smi logP SAS QED MW TPSA\n" 61 | fp_out.write(line_out) 62 | for i in train_index: 63 | line=lines[i+1] 64 | fp_out.write(line) 65 | fp_out.close() 66 | 67 | 68 | file_output="test_5.txt" 69 | fp_out=open(file_output,"w") 70 | line_out="#smi logP SAS QED MW TPSA\n" 71 | fp_out.write(line_out) 72 | for i in test_index: 73 | line=lines[i+1] 74 | fp_out.write(line) 75 | fp_out.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /single_design_acaa1/ZINC/separate_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | 9 | file_input="data_5.txt" 10 | fp_input=open(file_input) 11 | lines=fp_input.readlines() 12 | fp_input.close() 13 | 14 | logP_list=[] 15 | SAS_list=[] 16 | QED_list=[] 17 | MW_list=[] 18 | TPSA_list=[] 19 | 20 | for i in range(len(lines)): 21 | line=lines[i] 22 | if line[0]=='#': 23 | continue 24 | arr=line.split() 25 | logP=float(arr[1]) 26 | SAS=float(arr[2]) 27 | QED=float(arr[3]) 28 | MW=float(arr[4]) 29 | TPSA=float(arr[5]) 30 | logP_list+=[logP] 31 | SAS_list+=[SAS] 32 | QED_list+=[QED] 33 | MW_list+=[MW] 34 | TPSA_list+=[TPSA] 35 | 36 | logP_array=np.array(logP_list) 37 | SAS_array=np.array(SAS_list) 38 | QED_array=np.array(QED_list) 39 | MW_array=np.array(MW_list) 40 | TPSA_array=np.array(TPSA_list) 41 | 42 | print(logP_array.min(),logP_array.max()) 43 | print(SAS_array.min(),SAS_array.max()) 44 | print(QED_array.min(),QED_array.max()) 45 | print(MW_array.min(),MW_array.max()) 46 | print(TPSA_array.min(),TPSA_array.max()) 47 | 48 | Ndata=len(lines)-1 49 | index=np.arange(0,Ndata) 50 | np.random.shuffle(index) 51 | Ntest=10000 52 | Ntrain=Ndata-Ntest 53 | train_index=index[0:Ntrain] 54 | test_index=index[Ntrain:Ndata] 55 | train_index.sort() 56 | test_index.sort() 57 | 58 | file_output="train_5.txt" 59 | fp_out=open(file_output,"w") 60 | line_out="#smi logP SAS QED MW TPSA\n" 61 | fp_out.write(line_out) 62 | for i in train_index: 63 | line=lines[i+1] 64 | 
fp_out.write(line) 65 | fp_out.close() 66 | 67 | 68 | file_output="test_5.txt" 69 | fp_out=open(file_output,"w") 70 | line_out="#smi logP SAS QED MW TPSA\n" 71 | fp_out.write(line_out) 72 | for i in test_index: 73 | line=lines[i+1] 74 | fp_out.write(line) 75 | fp_out.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /single_design_esr1/ZINC/separate_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | 9 | file_input="data_5.txt" 10 | fp_input=open(file_input) 11 | lines=fp_input.readlines() 12 | fp_input.close() 13 | 14 | logP_list=[] 15 | SAS_list=[] 16 | QED_list=[] 17 | MW_list=[] 18 | TPSA_list=[] 19 | 20 | for i in range(len(lines)): 21 | line=lines[i] 22 | if line[0]=='#': 23 | continue 24 | arr=line.split() 25 | logP=float(arr[1]) 26 | SAS=float(arr[2]) 27 | QED=float(arr[3]) 28 | MW=float(arr[4]) 29 | TPSA=float(arr[5]) 30 | logP_list+=[logP] 31 | SAS_list+=[SAS] 32 | QED_list+=[QED] 33 | MW_list+=[MW] 34 | TPSA_list+=[TPSA] 35 | 36 | logP_array=np.array(logP_list) 37 | SAS_array=np.array(SAS_list) 38 | QED_array=np.array(QED_list) 39 | MW_array=np.array(MW_list) 40 | TPSA_array=np.array(TPSA_list) 41 | 42 | print(logP_array.min(),logP_array.max()) 43 | print(SAS_array.min(),SAS_array.max()) 44 | print(QED_array.min(),QED_array.max()) 45 | print(MW_array.min(),MW_array.max()) 46 | print(TPSA_array.min(),TPSA_array.max()) 47 | 48 | Ndata=len(lines)-1 49 | index=np.arange(0,Ndata) 50 | np.random.shuffle(index) 51 | Ntest=10000 52 | Ntrain=Ndata-Ntest 53 | train_index=index[0:Ntrain] 54 | test_index=index[Ntrain:Ndata] 55 | train_index.sort() 56 | test_index.sort() 57 | 58 | file_output="train_5.txt" 59 | fp_out=open(file_output,"w") 60 | line_out="#smi logP SAS QED MW TPSA\n" 61 | fp_out.write(line_out) 62 | for i in train_index: 63 | line=lines[i+1] 64 | fp_out.write(line) 65 | fp_out.close() 66 | 67 | 68 | file_output="test_5.txt" 69 | fp_out=open(file_output,"w") 70 | line_out="#smi logP SAS QED MW TPSA\n" 71 | fp_out.write(line_out) 72 | for i in test_index: 73 | line=lines[i+1] 74 | fp_out.write(line) 75 | fp_out.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /single_design_plogp/ZINC/separate_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | 9 | file_input="data_5.txt" 10 | fp_input=open(file_input) 11 | lines=fp_input.readlines() 12 | fp_input.close() 13 | 14 | logP_list=[] 15 | SAS_list=[] 16 | QED_list=[] 17 | MW_list=[] 18 | TPSA_list=[] 19 | 20 | for i in range(len(lines)): 21 | line=lines[i] 22 | if line[0]=='#': 23 | continue 24 | arr=line.split() 25 | logP=float(arr[1]) 26 | SAS=float(arr[2]) 27 | QED=float(arr[3]) 28 | MW=float(arr[4]) 29 | TPSA=float(arr[5]) 30 | logP_list+=[logP] 31 | SAS_list+=[SAS] 32 | QED_list+=[QED] 33 | MW_list+=[MW] 34 | TPSA_list+=[TPSA] 35 | 36 | logP_array=np.array(logP_list) 37 | SAS_array=np.array(SAS_list) 38 | QED_array=np.array(QED_list) 39 | MW_array=np.array(MW_list) 40 | TPSA_array=np.array(TPSA_list) 41 | 42 | print(logP_array.min(),logP_array.max()) 43 | print(SAS_array.min(),SAS_array.max()) 44 | print(QED_array.min(),QED_array.max()) 45 | print(MW_array.min(),MW_array.max()) 46 | print(TPSA_array.min(),TPSA_array.max()) 47 | 48 | Ndata=len(lines)-1 49 
| index=np.arange(0,Ndata) 50 | np.random.shuffle(index) 51 | Ntest=10000 52 | Ntrain=Ndata-Ntest 53 | train_index=index[0:Ntrain] 54 | test_index=index[Ntrain:Ndata] 55 | train_index.sort() 56 | test_index.sort() 57 | 58 | file_output="train_5.txt" 59 | fp_out=open(file_output,"w") 60 | line_out="#smi logP SAS QED MW TPSA\n" 61 | fp_out.write(line_out) 62 | for i in train_index: 63 | line=lines[i+1] 64 | fp_out.write(line) 65 | fp_out.close() 66 | 67 | 68 | file_output="test_5.txt" 69 | fp_out=open(file_output,"w") 70 | line_out="#smi logP SAS QED MW TPSA\n" 71 | fp_out.write(line_out) 72 | for i in test_index: 73 | line=lines[i+1] 74 | fp_out.write(line) 75 | fp_out.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /single_design_qed/ZINC/separate_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | 9 | file_input="data_5.txt" 10 | fp_input=open(file_input) 11 | lines=fp_input.readlines() 12 | fp_input.close() 13 | 14 | logP_list=[] 15 | SAS_list=[] 16 | QED_list=[] 17 | MW_list=[] 18 | TPSA_list=[] 19 | 20 | for i in range(len(lines)): 21 | line=lines[i] 22 | if line[0]=='#': 23 | continue 24 | arr=line.split() 25 | logP=float(arr[1]) 26 | SAS=float(arr[2]) 27 | QED=float(arr[3]) 28 | MW=float(arr[4]) 29 | TPSA=float(arr[5]) 30 | logP_list+=[logP] 31 | SAS_list+=[SAS] 32 | QED_list+=[QED] 33 | MW_list+=[MW] 34 | TPSA_list+=[TPSA] 35 | 36 | logP_array=np.array(logP_list) 37 | SAS_array=np.array(SAS_list) 38 | QED_array=np.array(QED_list) 39 | MW_array=np.array(MW_list) 40 | TPSA_array=np.array(TPSA_list) 41 | 42 | print(logP_array.min(),logP_array.max()) 43 | print(SAS_array.min(),SAS_array.max()) 44 | print(QED_array.min(),QED_array.max()) 45 | print(MW_array.min(),MW_array.max()) 46 | print(TPSA_array.min(),TPSA_array.max()) 47 | 48 | Ndata=len(lines)-1 49 | index=np.arange(0,Ndata) 50 | np.random.shuffle(index) 51 | Ntest=10000 52 | Ntrain=Ndata-Ntest 53 | train_index=index[0:Ntrain] 54 | test_index=index[Ntrain:Ndata] 55 | train_index.sort() 56 | test_index.sort() 57 | 58 | file_output="train_5.txt" 59 | fp_out=open(file_output,"w") 60 | line_out="#smi logP SAS QED MW TPSA\n" 61 | fp_out.write(line_out) 62 | for i in train_index: 63 | line=lines[i+1] 64 | fp_out.write(line) 65 | fp_out.close() 66 | 67 | 68 | file_output="test_5.txt" 69 | fp_out=open(file_output,"w") 70 | line_out="#smi logP SAS QED MW TPSA\n" 71 | fp_out.write(line_out) 72 | for i in test_index: 73 | line=lines[i+1] 74 | fp_out.write(line) 75 | fp_out.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sampling with Gradual Distribution Shifting (SGDS) 2 | This is the repository for our paper "Molecule Design by Latent Space Energy-based Modeling and Gradual Distribution Shifting" in UAI 2023. 
[PDF](https://proceedings.mlr.press/v216/kong23a/kong23a.pdf) 3 | 4 | ![alt text](https://github.com/deqiankong/SGDS/blob/main/figure/model.png) 5 | 6 | In this paper, we studied the following property optimization tasks: 7 | * single-objective p-logP maximization (penalized logP; a sketch of the standard definition follows this list) 8 | * single-objective QED maximization 9 | * single-objective ESR1 binding affinity maximization 10 | * single-objective ACAA1 binding affinity maximization 11 | * multi-objective (ESR1, QED, SA) optimization 12 | * multi-objective (ACAA1, QED, SA) optimization 13 | 14 |
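Here p-logP denotes penalized logP. As a reference only, below is a minimal sketch assuming the commonly used definition (logP minus the SA score minus a penalty on large rings), computed with RDKit and the `sascorer.py` bundled in every `*/ZINC` folder; it is not necessarily the exact scorer behind the `PlogP*.npy` labels:
```
# Sketch only: assumes the standard penalized-logP definition.
from rdkit import Chem
from rdkit.Chem.Crippen import MolLogP
import sascorer

def penalized_logp(smi):
    m = Chem.MolFromSmiles(smi)
    rings = [len(r) for r in m.GetRingInfo().AtomRings()]
    ring_penalty = max(max(rings) - 6, 0) if rings else 0  # penalize rings larger than 6 atoms
    return MolLogP(m) - sascorer.calculateScore(m) - ring_penalty
```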
15 | ![alt text](https://github.com/deqiankong/SGDS/blob/main/figure/single.png) 16 | 17 | ![alt text](https://github.com/deqiankong/SGDS/blob/main/figure/multi.png)
18 | 19 | ## Environment 20 | We follow the previous work [LIMO](https://github.com/Rose-STL-Lab/LIMO) for setting up RDKit, Open Babel and AutoDock-GPU. We extend our gratitude to the authors for their significant contributions. 21 | 22 | ## Data 23 | We use SELFIES representations of ZINC250k together with the corresponding property values. All property values can be computed by either RDKit or AutoDock-GPU; a minimal RDKit sketch is given at the end of this README. 24 | 25 | ## Usage 26 | For model training on a given property (e.g., ESR1), 27 | ``` 28 | cd single_design_esr1 29 | python main.py 30 | ``` 31 | 32 | For the property optimization task, run 33 | ``` 34 | python single_design.py   # or multi_design.py for the multi-objective tasks 35 | ``` 36 | 37 | ## Cite 38 | ```
39 | @inproceedings{kong2023molecule,
40 |   title={Molecule Design by Latent Space Energy-Based Modeling and Gradual Distribution Shifting},
41 |   author={Kong, Deqian and Pang, Bo and Han, Tian and Wu, Ying Nian},
42 |   booktitle={Uncertainty in Artificial Intelligence},
43 |   pages={1109--1120},
44 |   year={2023},
45 |   organization={PMLR}
46 | }
47 | ```
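## Computing the RDKit properties (sketch)
The logP/SAS/QED/MW/TPSA labels can be reproduced with RDKit plus the bundled SA scorer; a minimal sketch mirroring `ZINC/cal_property.py` (the ESR1/ACAA1 binding-affinity labels come from AutoDock-GPU instead):
```
from rdkit import Chem
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.QED import qed
from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem.rdMolDescriptors import CalcTPSA
import sascorer  # shipped in each */ZINC folder

def mol_properties(smi):
    m = Chem.MolFromSmiles(smi)
    return {'logP': MolLogP(m), 'SAS': sascorer.calculateScore(m),
            'QED': qed(m), 'MW': ExactMolWt(m), 'TPSA': CalcTPSA(m)}
```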
48 | -------------------------------------------------------------------------------- /multi_design_esr1/ZINC/cal_property.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | from rdkit import Chem 9 | 10 | from rdkit.Chem.Descriptors import ExactMolWt 11 | from rdkit.Chem.Crippen import MolLogP 12 | from rdkit.Chem.Crippen import MolMR 13 | 14 | from rdkit.Chem.rdMolDescriptors import CalcNumHBD 15 | from rdkit.Chem.rdMolDescriptors import CalcNumHBA 16 | from rdkit.Chem.rdMolDescriptors import CalcTPSA 17 | 18 | from rdkit.Chem.QED import qed 19 | #sys.path.insert(0,'/home/shade/SA_Score') 20 | import sascorer 21 | 22 | 23 | file_input="250k_rndm_zinc_drugs_clean_3.csv" 24 | fp_input=open(file_input) 25 | lines=fp_input.readlines() 26 | fp_input.close() 27 | 28 | file_output="data_5.txt" 29 | fp_out=open(file_output,"w") 30 | 31 | line_out="#smi logP SAS QED MW TPSA\n" 32 | fp_out.write(line_out) 33 | logP_list=[] 34 | SAS_list=[] 35 | QED_list=[] 36 | MW_list=[] 37 | TPSA_list=[] 38 | 39 | for i in range(len(lines)): 40 | line=lines[i] 41 | if line[0]!='"': 42 | continue 43 | if line[1]!=",": 44 | smi=line[1:].strip() 45 | continue 46 | m=Chem.MolFromSmiles(smi) 47 | smi2=Chem.MolToSmiles(m) 48 | 49 | property0=line[2:].split(",") 50 | # logP=float(property0[0]) 51 | # SAS=float(property0[2]) 52 | # QED=float(property0[1]) 53 | 54 | logP=MolLogP(m) 55 | SAS=sascorer.calculateScore(m) 56 | QED=qed(m) 57 | 58 | MW=ExactMolWt(m) 59 | TPSA=CalcTPSA(m) 60 | line_out="%s %6.3f %6.3f %6.3f %6.3f %6.3f\n" %(smi2,logP,SAS,QED,MW,TPSA) 61 | fp_out.write(line_out) 62 | logP_list+=[logP] 63 | SAS_list+=[SAS] 64 | QED_list+=[QED] 65 | MW_list+=[MW] 66 | TPSA_list+=[TPSA] 67 | 68 | 69 | fp_out.close() 70 | 71 | logP_array=np.array(logP_list) 72 | SAS_array=np.array(SAS_list) 73 | QED_array=np.array(QED_list) 74 | MW_array=np.array(MW_list) 75 | TPSA_array=np.array(TPSA_list) 76 | 77 | print(logP_array.min(),logP_array.max()) 78 | print(SAS_array.min(),SAS_array.max()) 79 | print(QED_array.min(),QED_array.max()) 80 | print(MW_array.min(),MW_array.max()) 81 | print(TPSA_array.min(),TPSA_array.max()) 82 | 83 | -------------------------------------------------------------------------------- /single_design_qed/ZINC/cal_property.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | from rdkit import Chem 9 | 10 | from rdkit.Chem.Descriptors import ExactMolWt 11 | from rdkit.Chem.Crippen import MolLogP 12 | from rdkit.Chem.Crippen import MolMR 13 | 14 | from rdkit.Chem.rdMolDescriptors import CalcNumHBD 15 | from rdkit.Chem.rdMolDescriptors import CalcNumHBA 16 | from rdkit.Chem.rdMolDescriptors import CalcTPSA 17 | 18 | from rdkit.Chem.QED import qed 19 | #sys.path.insert(0,'/home/shade/SA_Score') 20 | import sascorer 21 | 22 | 23 | file_input="250k_rndm_zinc_drugs_clean_3.csv" 24 | fp_input=open(file_input) 25 | lines=fp_input.readlines() 26 | fp_input.close() 27 | 28 | file_output="data_5.txt" 29 | fp_out=open(file_output,"w") 30 | 31 | line_out="#smi logP SAS QED MW TPSA\n" 32 | fp_out.write(line_out) 33 | logP_list=[] 34 | SAS_list=[] 35 | QED_list=[] 36 | MW_list=[] 37 | TPSA_list=[] 38 | 39 | for i in range(len(lines)): 40 | line=lines[i] 41 | if line[0]!='"': 42 | continue 43 | if line[1]!=",": 44 | smi=line[1:].strip() 45 | continue 46 | 
m=Chem.MolFromSmiles(smi) 47 | smi2=Chem.MolToSmiles(m) 48 | 49 | property0=line[2:].split(",") 50 | # logP=float(property0[0]) 51 | # SAS=float(property0[2]) 52 | # QED=float(property0[1]) 53 | 54 | logP=MolLogP(m) 55 | SAS=sascorer.calculateScore(m) 56 | QED=qed(m) 57 | 58 | MW=ExactMolWt(m) 59 | TPSA=CalcTPSA(m) 60 | line_out="%s %6.3f %6.3f %6.3f %6.3f %6.3f\n" %(smi2,logP,SAS,QED,MW,TPSA) 61 | fp_out.write(line_out) 62 | logP_list+=[logP] 63 | SAS_list+=[SAS] 64 | QED_list+=[QED] 65 | MW_list+=[MW] 66 | TPSA_list+=[TPSA] 67 | 68 | 69 | fp_out.close() 70 | 71 | logP_array=np.array(logP_list) 72 | SAS_array=np.array(SAS_list) 73 | QED_array=np.array(QED_list) 74 | MW_array=np.array(MW_list) 75 | TPSA_array=np.array(TPSA_list) 76 | 77 | print(logP_array.min(),logP_array.max()) 78 | print(SAS_array.min(),SAS_array.max()) 79 | print(QED_array.min(),QED_array.max()) 80 | print(MW_array.min(),MW_array.max()) 81 | print(TPSA_array.min(),TPSA_array.max()) 82 | 83 | -------------------------------------------------------------------------------- /multi_design_acaa1/ZINC/cal_property.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | from rdkit import Chem 9 | 10 | from rdkit.Chem.Descriptors import ExactMolWt 11 | from rdkit.Chem.Crippen import MolLogP 12 | from rdkit.Chem.Crippen import MolMR 13 | 14 | from rdkit.Chem.rdMolDescriptors import CalcNumHBD 15 | from rdkit.Chem.rdMolDescriptors import CalcNumHBA 16 | from rdkit.Chem.rdMolDescriptors import CalcTPSA 17 | 18 | from rdkit.Chem.QED import qed 19 | #sys.path.insert(0,'/home/shade/SA_Score') 20 | import sascorer 21 | 22 | 23 | file_input="250k_rndm_zinc_drugs_clean_3.csv" 24 | fp_input=open(file_input) 25 | lines=fp_input.readlines() 26 | fp_input.close() 27 | 28 | file_output="data_5.txt" 29 | fp_out=open(file_output,"w") 30 | 31 | line_out="#smi logP SAS QED MW TPSA\n" 32 | fp_out.write(line_out) 33 | logP_list=[] 34 | SAS_list=[] 35 | QED_list=[] 36 | MW_list=[] 37 | TPSA_list=[] 38 | 39 | for i in range(len(lines)): 40 | line=lines[i] 41 | if line[0]!='"': 42 | continue 43 | if line[1]!=",": 44 | smi=line[1:].strip() 45 | continue 46 | m=Chem.MolFromSmiles(smi) 47 | smi2=Chem.MolToSmiles(m) 48 | 49 | property0=line[2:].split(",") 50 | # logP=float(property0[0]) 51 | # SAS=float(property0[2]) 52 | # QED=float(property0[1]) 53 | 54 | logP=MolLogP(m) 55 | SAS=sascorer.calculateScore(m) 56 | QED=qed(m) 57 | 58 | MW=ExactMolWt(m) 59 | TPSA=CalcTPSA(m) 60 | line_out="%s %6.3f %6.3f %6.3f %6.3f %6.3f\n" %(smi2,logP,SAS,QED,MW,TPSA) 61 | fp_out.write(line_out) 62 | logP_list+=[logP] 63 | SAS_list+=[SAS] 64 | QED_list+=[QED] 65 | MW_list+=[MW] 66 | TPSA_list+=[TPSA] 67 | 68 | 69 | fp_out.close() 70 | 71 | logP_array=np.array(logP_list) 72 | SAS_array=np.array(SAS_list) 73 | QED_array=np.array(QED_list) 74 | MW_array=np.array(MW_list) 75 | TPSA_array=np.array(TPSA_list) 76 | 77 | print(logP_array.min(),logP_array.max()) 78 | print(SAS_array.min(),SAS_array.max()) 79 | print(QED_array.min(),QED_array.max()) 80 | print(MW_array.min(),MW_array.max()) 81 | print(TPSA_array.min(),TPSA_array.max()) 82 | 83 | -------------------------------------------------------------------------------- /single_design_acaa1/ZINC/cal_property.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | from rdkit 
import Chem 9 | 10 | from rdkit.Chem.Descriptors import ExactMolWt 11 | from rdkit.Chem.Crippen import MolLogP 12 | from rdkit.Chem.Crippen import MolMR 13 | 14 | from rdkit.Chem.rdMolDescriptors import CalcNumHBD 15 | from rdkit.Chem.rdMolDescriptors import CalcNumHBA 16 | from rdkit.Chem.rdMolDescriptors import CalcTPSA 17 | 18 | from rdkit.Chem.QED import qed 19 | #sys.path.insert(0,'/home/shade/SA_Score') 20 | import sascorer 21 | 22 | 23 | file_input="250k_rndm_zinc_drugs_clean_3.csv" 24 | fp_input=open(file_input) 25 | lines=fp_input.readlines() 26 | fp_input.close() 27 | 28 | file_output="data_5.txt" 29 | fp_out=open(file_output,"w") 30 | 31 | line_out="#smi logP SAS QED MW TPSA\n" 32 | fp_out.write(line_out) 33 | logP_list=[] 34 | SAS_list=[] 35 | QED_list=[] 36 | MW_list=[] 37 | TPSA_list=[] 38 | 39 | for i in range(len(lines)): 40 | line=lines[i] 41 | if line[0]!='"': 42 | continue 43 | if line[1]!=",": 44 | smi=line[1:].strip() 45 | continue 46 | m=Chem.MolFromSmiles(smi) 47 | smi2=Chem.MolToSmiles(m) 48 | 49 | property0=line[2:].split(",") 50 | # logP=float(property0[0]) 51 | # SAS=float(property0[2]) 52 | # QED=float(property0[1]) 53 | 54 | logP=MolLogP(m) 55 | SAS=sascorer.calculateScore(m) 56 | QED=qed(m) 57 | 58 | MW=ExactMolWt(m) 59 | TPSA=CalcTPSA(m) 60 | line_out="%s %6.3f %6.3f %6.3f %6.3f %6.3f\n" %(smi2,logP,SAS,QED,MW,TPSA) 61 | fp_out.write(line_out) 62 | logP_list+=[logP] 63 | SAS_list+=[SAS] 64 | QED_list+=[QED] 65 | MW_list+=[MW] 66 | TPSA_list+=[TPSA] 67 | 68 | 69 | fp_out.close() 70 | 71 | logP_array=np.array(logP_list) 72 | SAS_array=np.array(SAS_list) 73 | QED_array=np.array(QED_list) 74 | MW_array=np.array(MW_list) 75 | TPSA_array=np.array(TPSA_list) 76 | 77 | print(logP_array.min(),logP_array.max()) 78 | print(SAS_array.min(),SAS_array.max()) 79 | print(QED_array.min(),QED_array.max()) 80 | print(MW_array.min(),MW_array.max()) 81 | print(TPSA_array.min(),TPSA_array.max()) 82 | 83 | -------------------------------------------------------------------------------- /single_design_esr1/ZINC/cal_property.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | from rdkit import Chem 9 | 10 | from rdkit.Chem.Descriptors import ExactMolWt 11 | from rdkit.Chem.Crippen import MolLogP 12 | from rdkit.Chem.Crippen import MolMR 13 | 14 | from rdkit.Chem.rdMolDescriptors import CalcNumHBD 15 | from rdkit.Chem.rdMolDescriptors import CalcNumHBA 16 | from rdkit.Chem.rdMolDescriptors import CalcTPSA 17 | 18 | from rdkit.Chem.QED import qed 19 | #sys.path.insert(0,'/home/shade/SA_Score') 20 | import sascorer 21 | 22 | 23 | file_input="250k_rndm_zinc_drugs_clean_3.csv" 24 | fp_input=open(file_input) 25 | lines=fp_input.readlines() 26 | fp_input.close() 27 | 28 | file_output="data_5.txt" 29 | fp_out=open(file_output,"w") 30 | 31 | line_out="#smi logP SAS QED MW TPSA\n" 32 | fp_out.write(line_out) 33 | logP_list=[] 34 | SAS_list=[] 35 | QED_list=[] 36 | MW_list=[] 37 | TPSA_list=[] 38 | 39 | for i in range(len(lines)): 40 | line=lines[i] 41 | if line[0]!='"': 42 | continue 43 | if line[1]!=",": 44 | smi=line[1:].strip() 45 | continue 46 | m=Chem.MolFromSmiles(smi) 47 | smi2=Chem.MolToSmiles(m) 48 | 49 | property0=line[2:].split(",") 50 | # logP=float(property0[0]) 51 | # SAS=float(property0[2]) 52 | # QED=float(property0[1]) 53 | 54 | logP=MolLogP(m) 55 | SAS=sascorer.calculateScore(m) 56 | QED=qed(m) 57 | 58 | MW=ExactMolWt(m) 59 | TPSA=CalcTPSA(m) 60 | 
line_out="%s %6.3f %6.3f %6.3f %6.3f %6.3f\n" %(smi2,logP,SAS,QED,MW,TPSA) 61 | fp_out.write(line_out) 62 | logP_list+=[logP] 63 | SAS_list+=[SAS] 64 | QED_list+=[QED] 65 | MW_list+=[MW] 66 | TPSA_list+=[TPSA] 67 | 68 | 69 | fp_out.close() 70 | 71 | logP_array=np.array(logP_list) 72 | SAS_array=np.array(SAS_list) 73 | QED_array=np.array(QED_list) 74 | MW_array=np.array(MW_list) 75 | TPSA_array=np.array(TPSA_list) 76 | 77 | print(logP_array.min(),logP_array.max()) 78 | print(SAS_array.min(),SAS_array.max()) 79 | print(QED_array.min(),QED_array.max()) 80 | print(MW_array.min(),MW_array.max()) 81 | print(TPSA_array.min(),TPSA_array.max()) 82 | 83 | -------------------------------------------------------------------------------- /single_design_plogp/ZINC/cal_property.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | from rdkit import Chem 9 | 10 | from rdkit.Chem.Descriptors import ExactMolWt 11 | from rdkit.Chem.Crippen import MolLogP 12 | from rdkit.Chem.Crippen import MolMR 13 | 14 | from rdkit.Chem.rdMolDescriptors import CalcNumHBD 15 | from rdkit.Chem.rdMolDescriptors import CalcNumHBA 16 | from rdkit.Chem.rdMolDescriptors import CalcTPSA 17 | 18 | from rdkit.Chem.QED import qed 19 | #sys.path.insert(0,'/home/shade/SA_Score') 20 | import sascorer 21 | 22 | 23 | file_input="250k_rndm_zinc_drugs_clean_3.csv" 24 | fp_input=open(file_input) 25 | lines=fp_input.readlines() 26 | fp_input.close() 27 | 28 | file_output="data_5.txt" 29 | fp_out=open(file_output,"w") 30 | 31 | line_out="#smi logP SAS QED MW TPSA\n" 32 | fp_out.write(line_out) 33 | logP_list=[] 34 | SAS_list=[] 35 | QED_list=[] 36 | MW_list=[] 37 | TPSA_list=[] 38 | 39 | for i in range(len(lines)): 40 | line=lines[i] 41 | if line[0]!='"': 42 | continue 43 | if line[1]!=",": 44 | smi=line[1:].strip() 45 | continue 46 | m=Chem.MolFromSmiles(smi) 47 | smi2=Chem.MolToSmiles(m) 48 | 49 | property0=line[2:].split(",") 50 | # logP=float(property0[0]) 51 | # SAS=float(property0[2]) 52 | # QED=float(property0[1]) 53 | 54 | logP=MolLogP(m) 55 | SAS=sascorer.calculateScore(m) 56 | QED=qed(m) 57 | 58 | MW=ExactMolWt(m) 59 | TPSA=CalcTPSA(m) 60 | line_out="%s %6.3f %6.3f %6.3f %6.3f %6.3f\n" %(smi2,logP,SAS,QED,MW,TPSA) 61 | fp_out.write(line_out) 62 | logP_list+=[logP] 63 | SAS_list+=[SAS] 64 | QED_list+=[QED] 65 | MW_list+=[MW] 66 | TPSA_list+=[TPSA] 67 | 68 | 69 | fp_out.close() 70 | 71 | logP_array=np.array(logP_list) 72 | SAS_array=np.array(SAS_list) 73 | QED_array=np.array(QED_list) 74 | MW_array=np.array(MW_list) 75 | TPSA_array=np.array(TPSA_list) 76 | 77 | print(logP_array.min(),logP_array.max()) 78 | print(SAS_array.min(),SAS_array.max()) 79 | print(QED_array.min(),QED_array.max()) 80 | print(MW_array.min(),MW_array.max()) 81 | print(TPSA_array.min(),TPSA_array.max()) 82 | 83 | -------------------------------------------------------------------------------- /single_design_qed/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import datetime 5 | import torch 6 | import sys 7 | 8 | 9 | def get_exp_id(file): 10 | return os.path.splitext(os.path.basename(file))[0] 11 | 12 | 13 | def get_output_dir(exp_id, fs_prefix='./'): 14 | t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 15 | output_dir = os.path.join(fs_prefix + 'output/' + exp_id, t) 16 | os.makedirs(output_dir, exist_ok=True) 17 | return output_dir 
18 | 19 | 20 | def setup_logging(name, output_dir, console=True): 21 | log_format = logging.Formatter("%(asctime)s : %(message)s") 22 | logger = logging.getLogger(name) 23 | logger.handlers = [] 24 | output_file = os.path.join(output_dir, 'output.log') 25 | file_handler = logging.FileHandler(output_file) 26 | file_handler.setFormatter(log_format) 27 | logger.addHandler(file_handler) 28 | if console: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setFormatter(log_format) 31 | logger.addHandler(console_handler) 32 | logger.setLevel(logging.INFO) 33 | return logger 34 | 35 | 36 | def copy_source(file, output_dir): 37 | shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file))) 38 | 39 | 40 | def copy_all_files(file, output_dir): 41 | dir_src = os.path.dirname(file) 42 | for filename in os.listdir(os.getcwd()): 43 | if filename.endswith('.py'): 44 | shutil.copy(os.path.join(dir_src, filename), output_dir) 45 | 46 | 47 | def set_gpu(gpu): 48 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 49 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) 50 | 51 | if torch.cuda.is_available(): 52 | torch.cuda.set_device(0) 53 | torch.backends.cudnn.benchmark = True 54 | 55 | 56 | if __name__ == '__main__': 57 | # exp_id = get_exp_id(__file__) 58 | exp_id = 'ebm_plot' 59 | output_dir = get_output_dir(exp_id, fs_prefix='../alienware_') 60 | print(exp_id) 61 | print(os.getcwd()) 62 | print(__file__) 63 | print(os.path.basename(__file__)) 64 | print(os.path.dirname(__file__)) 65 | # copy_source(__file__, output_dir) 66 | copy_all_files(__file__, output_dir) 67 | -------------------------------------------------------------------------------- /single_design_plogp/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import datetime 5 | import torch 6 | import sys 7 | 8 | 9 | def get_exp_id(file): 10 | return os.path.splitext(os.path.basename(file))[0] 11 | 12 | 13 | def get_output_dir(exp_id, fs_prefix='./'): 14 | t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 15 | output_dir = os.path.join(fs_prefix + 'output/' + exp_id, t) 16 | os.makedirs(output_dir, exist_ok=True) 17 | return output_dir 18 | 19 | 20 | def setup_logging(name, output_dir, console=True): 21 | log_format = logging.Formatter("%(asctime)s : %(message)s") 22 | logger = logging.getLogger(name) 23 | logger.handlers = [] 24 | output_file = os.path.join(output_dir, 'output.log') 25 | file_handler = logging.FileHandler(output_file) 26 | file_handler.setFormatter(log_format) 27 | logger.addHandler(file_handler) 28 | if console: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setFormatter(log_format) 31 | logger.addHandler(console_handler) 32 | logger.setLevel(logging.INFO) 33 | return logger 34 | 35 | 36 | def copy_source(file, output_dir): 37 | shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file))) 38 | 39 | 40 | def copy_all_files(file, output_dir): 41 | dir_src = os.path.dirname(file) 42 | for filename in os.listdir(os.getcwd()): 43 | if filename.endswith('.py'): 44 | shutil.copy(os.path.join(dir_src, filename), output_dir) 45 | 46 | 47 | def set_gpu(gpu): 48 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 49 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) 50 | 51 | if torch.cuda.is_available(): 52 | torch.cuda.set_device(0) 53 | torch.backends.cudnn.benchmark = True 54 | 55 | 56 | if __name__ == '__main__': 57 | # exp_id = get_exp_id(__file__) 58 | exp_id = 
'ebm_plot' 59 | output_dir = get_output_dir(exp_id, fs_prefix='../alienware_') 60 | print(exp_id) 61 | print(os.getcwd()) 62 | print(__file__) 63 | print(os.path.basename(__file__)) 64 | print(os.path.dirname(__file__)) 65 | # copy_source(__file__, output_dir) 66 | copy_all_files(__file__, output_dir) 67 | -------------------------------------------------------------------------------- /multi_design_acaa1/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import datetime 5 | import torch 6 | import sys 7 | 8 | 9 | def get_exp_id(file): 10 | return os.path.splitext(os.path.basename(file))[0] 11 | 12 | 13 | def get_output_dir(exp_id, fs_prefix='./'): 14 | t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 15 | output_dir = os.path.join(fs_prefix + 'output/' + exp_id, t) 16 | os.makedirs(output_dir, exist_ok=True) 17 | return output_dir 18 | 19 | 20 | def setup_logging(name, output_dir, console=True): 21 | log_format = logging.Formatter("%(asctime)s : %(message)s") 22 | logger = logging.getLogger(name) 23 | logger.handlers = [] 24 | output_file = os.path.join(output_dir, 'output.log') 25 | file_handler = logging.FileHandler(output_file) 26 | file_handler.setFormatter(log_format) 27 | logger.addHandler(file_handler) 28 | if console: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setFormatter(log_format) 31 | logger.addHandler(console_handler) 32 | logger.setLevel(logging.INFO) 33 | return logger 34 | 35 | 36 | def copy_source(file, output_dir): 37 | shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file))) 38 | 39 | 40 | def copy_all_files(file, output_dir): 41 | dir_src = os.path.dirname(file) 42 | for filename in os.listdir(os.getcwd()): 43 | if filename.endswith('.py'): 44 | shutil.copy(os.path.join(dir_src, filename), output_dir) 45 | 46 | 47 | def set_gpu(gpu): 48 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 49 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) 50 | 51 | if torch.cuda.is_available(): 52 | torch.cuda.set_device(0) 53 | torch.backends.cudnn.benchmark = True 54 | 55 | 56 | if __name__ == '__main__': 57 | # exp_id = get_exp_id(__file__) 58 | # exp_id = 'ebm_plot' 59 | # output_dir = get_output_dir(exp_id, fs_prefix='../alienware_') 60 | # print(exp_id) 61 | # print(os.getcwd()) 62 | # print(__file__) 63 | # print(os.path.basename(__file__)) 64 | # print(os.path.dirname(__file__)) 65 | # # copy_source(__file__, output_dir) 66 | # copy_all_files(__file__, output_dir) 67 | set_gpu(1) 68 | print(torch.cuda.device_count()) 69 | a = torch.tensor([1,1,1]).cuda() 70 | print(a) -------------------------------------------------------------------------------- /multi_design_esr1/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import datetime 5 | import torch 6 | import sys 7 | 8 | 9 | def get_exp_id(file): 10 | return os.path.splitext(os.path.basename(file))[0] 11 | 12 | 13 | def get_output_dir(exp_id, fs_prefix='./'): 14 | t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 15 | output_dir = os.path.join(fs_prefix + 'output/' + exp_id, t) 16 | os.makedirs(output_dir, exist_ok=True) 17 | return output_dir 18 | 19 | 20 | def setup_logging(name, output_dir, console=True): 21 | log_format = logging.Formatter("%(asctime)s : %(message)s") 22 | logger = logging.getLogger(name) 23 | logger.handlers = [] 24 | output_file = os.path.join(output_dir, 
'output.log') 25 | file_handler = logging.FileHandler(output_file) 26 | file_handler.setFormatter(log_format) 27 | logger.addHandler(file_handler) 28 | if console: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setFormatter(log_format) 31 | logger.addHandler(console_handler) 32 | logger.setLevel(logging.INFO) 33 | return logger 34 | 35 | 36 | def copy_source(file, output_dir): 37 | shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file))) 38 | 39 | 40 | def copy_all_files(file, output_dir): 41 | dir_src = os.path.dirname(file) 42 | for filename in os.listdir(os.getcwd()): 43 | if filename.endswith('.py'): 44 | shutil.copy(os.path.join(dir_src, filename), output_dir) 45 | 46 | 47 | def set_gpu(gpu): 48 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 49 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) 50 | 51 | if torch.cuda.is_available(): 52 | torch.cuda.set_device(0) 53 | torch.backends.cudnn.benchmark = True 54 | 55 | 56 | if __name__ == '__main__': 57 | # exp_id = get_exp_id(__file__) 58 | # exp_id = 'ebm_plot' 59 | # output_dir = get_output_dir(exp_id, fs_prefix='../alienware_') 60 | # print(exp_id) 61 | # print(os.getcwd()) 62 | # print(__file__) 63 | # print(os.path.basename(__file__)) 64 | # print(os.path.dirname(__file__)) 65 | # # copy_source(__file__, output_dir) 66 | # copy_all_files(__file__, output_dir) 67 | set_gpu(1) 68 | print(torch.cuda.device_count()) 69 | a = torch.tensor([1,1,1]).cuda() 70 | print(a) -------------------------------------------------------------------------------- /single_design_acaa1/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import datetime 5 | import torch 6 | import sys 7 | 8 | 9 | def get_exp_id(file): 10 | return os.path.splitext(os.path.basename(file))[0] 11 | 12 | 13 | def get_output_dir(exp_id, fs_prefix='./'): 14 | t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 15 | output_dir = os.path.join(fs_prefix + 'output/' + exp_id, t) 16 | os.makedirs(output_dir, exist_ok=True) 17 | return output_dir 18 | 19 | 20 | def setup_logging(name, output_dir, console=True): 21 | log_format = logging.Formatter("%(asctime)s : %(message)s") 22 | logger = logging.getLogger(name) 23 | logger.handlers = [] 24 | output_file = os.path.join(output_dir, 'output.log') 25 | file_handler = logging.FileHandler(output_file) 26 | file_handler.setFormatter(log_format) 27 | logger.addHandler(file_handler) 28 | if console: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setFormatter(log_format) 31 | logger.addHandler(console_handler) 32 | logger.setLevel(logging.INFO) 33 | return logger 34 | 35 | 36 | def copy_source(file, output_dir): 37 | shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file))) 38 | 39 | 40 | def copy_all_files(file, output_dir): 41 | dir_src = os.path.dirname(file) 42 | for filename in os.listdir(os.getcwd()): 43 | if filename.endswith('.py'): 44 | shutil.copy(os.path.join(dir_src, filename), output_dir) 45 | 46 | 47 | def set_gpu(gpu): 48 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 49 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) 50 | 51 | if torch.cuda.is_available(): 52 | torch.cuda.set_device(0) 53 | torch.backends.cudnn.benchmark = True 54 | 55 | 56 | if __name__ == '__main__': 57 | # exp_id = get_exp_id(__file__) 58 | # exp_id = 'ebm_plot' 59 | # output_dir = get_output_dir(exp_id, fs_prefix='../alienware_') 60 | # print(exp_id) 61 | # 
print(os.getcwd()) 62 | # print(__file__) 63 | # print(os.path.basename(__file__)) 64 | # print(os.path.dirname(__file__)) 65 | # # copy_source(__file__, output_dir) 66 | # copy_all_files(__file__, output_dir) 67 | set_gpu(1) 68 | print(torch.cuda.device_count()) 69 | a = torch.tensor([1,1,1]).cuda() 70 | print(a) -------------------------------------------------------------------------------- /single_design_qed/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import numpy as np 4 | from ZINC.char import char_list, char_dict 5 | import selfies as sf 6 | import json 7 | from rdkit.Chem import Draw 8 | from rdkit import Chem 9 | from rdkit.Chem.Crippen import MolLogP 10 | from rdkit.Chem.QED import qed 11 | from rdkit.Chem.Descriptors import ExactMolWt 12 | from rdkit.Chem.rdMolDescriptors import CalcTPSA 13 | import sascorer 14 | 15 | 16 | # max_len = 72 17 | # num_of_embeddings = 109 18 | class MolDataset(Dataset): 19 | def __init__(self, datadir, dname): 20 | Xdata_file = datadir + "/X" + dname + ".npy" 21 | self.Xdata = torch.tensor(np.load(Xdata_file), dtype=torch.long) # number-coded molecule 22 | Ldata_file = datadir + "/L" + dname + ".npy" 23 | self.Ldata = torch.tensor(np.load(Ldata_file), dtype=torch.long) # length of each molecule 24 | self.len = self.Xdata.shape[0] 25 | LogPdata_file = datadir + "/LogP" + dname + ".npy" 26 | self.LogP = torch.tensor(np.load(LogPdata_file), dtype=torch.float32) 27 | qed_data_file = datadir + "/qed_" + dname + ".npy" 28 | self.qed = torch.tensor(np.load(qed_data_file), dtype=torch.float32) 29 | 30 | def __getitem__(self, index): 31 | # Add sos=108 for each sequence and this sos is not shown in char_list and char_dict as in selfies. 32 | mol = self.Xdata[index] 33 | sos = torch.tensor([108], dtype=torch.long) 34 | mol = torch.cat([sos, mol], dim=0).contiguous() 35 | mask = torch.zeros(mol.shape[0] - 1) 36 | mask[:self.Ldata[index] + 1] = 1. 
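# added note: the mask flags the Ldata[index] real tokens plus one trailing position (presumably the end-of-sequence [nop]) as valid loss targets; padded positions stay 0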
37 | return (mol, mask, self.qed[index]) 38 | # return (mol, mask, self.LogP[index]) 39 | # print(self.Ldata[index]) 40 | # print(mask) 41 | # print(mol[:self.Ldata[index] + 2]) 42 | # return (mol, self.Ldata[index], self.LogP[index]) 43 | 44 | def __len__(self): 45 | return self.len 46 | 47 | def save_mol_png(self, label, filepath, size=(600, 600)): 48 | m_smi = self.label2sf2smi(label) 49 | m = Chem.MolFromSmiles(m_smi) 50 | Draw.MolToFile(m, filepath, size=size) 51 | 52 | def label2sf2smi(self, label): 53 | m_sf = sf.encoding_to_selfies(label, char_dict, enc_type='label') 54 | m_smi = sf.decoder(m_sf) 55 | m_smi = Chem.CanonSmiles(m_smi) 56 | return m_smi 57 | 58 | 59 | if __name__ == '__main__': 60 | datadir = '../data' 61 | ds = MolDataset(datadir, 'train') 62 | ds_loader = DataLoader(dataset=ds, batch_size=100, 63 | shuffle=True, drop_last=True, num_workers=2) 64 | 65 | max_len = 0 66 | for i in range(len(ds)): 67 | m, len, s = ds[i] 68 | m = m.numpy() 69 | mol = ds.label2sf2smi(m[1:]) 70 | print(mol) 71 | ds.save_mol_png(label=m[1:], filepath='../a.png') 72 | break 73 | 74 | # print(char_dict) 75 | # print(ds.sf2smi(m), s) 76 | -------------------------------------------------------------------------------- /single_design_esr1/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import numpy as np 4 | from ZINC.char import char_list, char_dict 5 | import selfies as sf 6 | import json 7 | from rdkit.Chem import Draw 8 | from rdkit import Chem 9 | from rdkit.Chem.Crippen import MolLogP 10 | from rdkit.Chem.QED import qed 11 | from rdkit.Chem.Descriptors import ExactMolWt 12 | from rdkit.Chem.rdMolDescriptors import CalcTPSA 13 | import sascorer 14 | import math 15 | 16 | 17 | # max_len = 72 18 | # num_of_embeddings = 109 19 | class MolDataset(Dataset): 20 | def __init__(self, datadir, dname): 21 | Xdata_file = datadir + "/X" + dname + ".npy" 22 | self.Xdata = torch.tensor(np.load(Xdata_file), dtype=torch.long) # number-coded molecule 23 | Ldata_file = datadir + "/L" + dname + ".npy" 24 | self.Ldata = torch.tensor(np.load(Ldata_file), dtype=torch.long) # length of each molecule 25 | self.len = self.Xdata.shape[0] 26 | sasdata_file = datadir + "/sas_" + dname + ".npy" 27 | self.sas = torch.tensor(np.load(sasdata_file), dtype=torch.float32) 28 | qed_data_file = datadir + "/qed_" + dname + ".npy" 29 | self.qed = torch.tensor(np.load(qed_data_file), dtype=torch.float32) 30 | 31 | ba0data_file = datadir + "/ba0_" + dname + ".npy" 32 | self.ba0 = torch.tensor(np.load(ba0data_file), dtype=torch.float32) 33 | # ba1data_file = datadir + "/ba1_" + dname + ".npy" 34 | # self.ba1 = torch.tensor(np.load(ba1data_file), dtype=torch.float32) 35 | # self.preprocess() 36 | 37 | def preprocess(self): 38 | ind = torch.nonzero(self.ba0) 39 | self.ba0 = self.ba0[ind].squeeze() 40 | self.sas = self.sas[ind].squeeze() 41 | self.qed = self.qed[ind].squeeze() 42 | self.Xdata = self.Xdata[ind].squeeze() 43 | self.Ldata = self.Ldata[ind].squeeze() 44 | self.len = self.Xdata.shape[0] 45 | 46 | def __getitem__(self, index): 47 | # Add sos=108 for each sequence and this sos is not shown in char_list and char_dict as in selfies. 48 | mol = self.Xdata[index] 49 | sos = torch.tensor([108], dtype=torch.long) 50 | mol = torch.cat([sos, mol], dim=0).contiguous() 51 | mask = torch.zeros(mol.shape[0] - 1) 52 | mask[:self.Ldata[index] + 1] = 1. 
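# added note: ba0 stores ESR1 docking scores (binding free energies in kcal/mol, lower = stronger binding; see delta_to_kd below), so they are negated to give a maximization target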
53 | return (mol, mask, -1 * self.ba0[index]) 54 | 55 | def __len__(self): 56 | return self.len 57 | 58 | def save_mol_png(self, label, filepath, size=(600, 600)): 59 | m_smi = self.label2sf2smi(label) 60 | m = Chem.MolFromSmiles(m_smi) 61 | Draw.MolToFile(m, filepath, size=size) 62 | 63 | def label2sf2smi(self, label): 64 | m_sf = sf.encoding_to_selfies(label, char_dict, enc_type='label') 65 | m_smi = sf.decoder(m_sf) 66 | m_smi = Chem.CanonSmiles(m_smi) 67 | return m_smi 68 | 69 | @staticmethod 70 | def delta_to_kd(x): 71 | return math.exp(x / (0.00198720425864083 * 298.15)) 72 | 73 | -------------------------------------------------------------------------------- /single_design_acaa1/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import numpy as np 4 | from ZINC.char import char_list, char_dict 5 | import selfies as sf 6 | import json 7 | from rdkit.Chem import Draw 8 | from rdkit import Chem 9 | from rdkit.Chem.Crippen import MolLogP 10 | from rdkit.Chem.QED import qed 11 | from rdkit.Chem.Descriptors import ExactMolWt 12 | from rdkit.Chem.rdMolDescriptors import CalcTPSA 13 | import sascorer 14 | import math 15 | 16 | 17 | # max_len = 72 18 | # num_of_embeddings = 109 19 | class MolDataset(Dataset): 20 | def __init__(self, datadir, dname): 21 | Xdata_file = datadir + "/X" + dname + ".npy" 22 | self.Xdata = torch.tensor(np.load(Xdata_file), dtype=torch.long) # number-coded molecule 23 | Ldata_file = datadir + "/L" + dname + ".npy" 24 | self.Ldata = torch.tensor(np.load(Ldata_file), dtype=torch.long) # length of each molecule 25 | self.len = self.Xdata.shape[0] 26 | sasdata_file = datadir + "/sas_" + dname + ".npy" 27 | self.sas = torch.tensor(np.load(sasdata_file), dtype=torch.float32) 28 | qed_data_file = datadir + "/qed_" + dname + ".npy" 29 | self.qed = torch.tensor(np.load(qed_data_file), dtype=torch.float32) 30 | 31 | # ba0data_file = datadir + "/ba0_" + dname + ".npy" 32 | # self.ba0 = torch.tensor(np.load(ba0data_file), dtype=torch.float32) 33 | ba1data_file = datadir + "/ba1_" + dname + ".npy" 34 | self.ba1 = torch.tensor(np.load(ba1data_file), dtype=torch.float32) 35 | # self.preprocess() 36 | 37 | def preprocess(self): 38 | ind = torch.nonzero(self.ba0) 39 | self.ba0 = self.ba0[ind].squeeze() 40 | self.sas = self.sas[ind].squeeze() 41 | self.qed = self.qed[ind].squeeze() 42 | self.Xdata = self.Xdata[ind].squeeze() 43 | self.Ldata = self.Ldata[ind].squeeze() 44 | self.len = self.Xdata.shape[0] 45 | 46 | def __getitem__(self, index): 47 | # Add sos=108 for each sequence and this sos is not shown in char_list and char_dict as in selfies. 48 | mol = self.Xdata[index] 49 | sos = torch.tensor([108], dtype=torch.long) 50 | mol = torch.cat([sos, mol], dim=0).contiguous() 51 | mask = torch.zeros(mol.shape[0] - 1) 52 | mask[:self.Ldata[index] + 1] = 1. 
53 | return (mol, mask, -1 * self.ba1[index]) 54 | 55 | def __len__(self): 56 | return self.len 57 | 58 | def save_mol_png(self, label, filepath, size=(600, 600)): 59 | m_smi = self.label2sf2smi(label) 60 | m = Chem.MolFromSmiles(m_smi) 61 | Draw.MolToFile(m, filepath, size=size) 62 | 63 | def label2sf2smi(self, label): 64 | m_sf = sf.encoding_to_selfies(label, char_dict, enc_type='label') 65 | m_smi = sf.decoder(m_sf) 66 | m_smi = Chem.CanonSmiles(m_smi) 67 | return m_smi 68 | 69 | @staticmethod 70 | def delta_to_kd(x): 71 | return math.exp(x / (0.00198720425864083 * 298.15)) 72 | 73 | 74 | -------------------------------------------------------------------------------- /single_design_plogp/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import numpy as np 4 | from ZINC.char import char_list, char_dict 5 | import selfies as sf 6 | import json 7 | from rdkit.Chem import Draw 8 | from rdkit import Chem 9 | from rdkit.Chem.Crippen import MolLogP 10 | from rdkit.Chem.QED import qed 11 | from rdkit.Chem.Descriptors import ExactMolWt 12 | from rdkit.Chem.rdMolDescriptors import CalcTPSA 13 | import sascorer 14 | 15 | 16 | # max_len = 72 17 | # num_of_embeddings = 109 18 | class MolDataset(Dataset): 19 | def __init__(self, datadir, dname): 20 | Xdata_file = datadir + "/X" + dname + ".npy" 21 | self.Xdata = torch.tensor(np.load(Xdata_file), dtype=torch.long) # number-coded molecule 22 | Ldata_file = datadir + "/L" + dname + ".npy" 23 | self.Ldata = torch.tensor(np.load(Ldata_file), dtype=torch.long) # length of each molecule 24 | self.len = self.Xdata.shape[0] 25 | LogPdata_file = datadir + "/LogP" + dname + ".npy" 26 | self.LogP = torch.tensor(np.load(LogPdata_file), dtype=torch.float32) 27 | qed_data_file = datadir + "/qed_" + dname + ".npy" 28 | self.qed = torch.tensor(np.load(qed_data_file), dtype=torch.float32) 29 | PlogPdata_file = datadir + "/PlogP" + dname + ".npy" 30 | self.PlogP = torch.tensor(np.load(PlogPdata_file), dtype=torch.float32) 31 | 32 | def __getitem__(self, index): 33 | # Add sos=108 for each sequence and this sos is not shown in char_list and char_dict as in selfies. 34 | mol = self.Xdata[index] 35 | sos = torch.tensor([108], dtype=torch.long) 36 | mol = torch.cat([sos, mol], dim=0).contiguous() 37 | mask = torch.zeros(mol.shape[0] - 1) 38 | mask[:self.Ldata[index] + 1] = 1. 
39 | return (mol, mask, self.PlogP[index]) 40 | # return (mol, mask, self.LogP[index]) 41 | # print(self.Ldata[index]) 42 | # print(mask) 43 | # print(mol[:self.Ldata[index] + 2]) 44 | # return (mol, self.Ldata[index], self.LogP[index]) 45 | 46 | def __len__(self): 47 | return self.len 48 | 49 | def save_mol_png(self, label, filepath, size=(600, 600)): 50 | m_smi = self.label2sf2smi(label) 51 | m = Chem.MolFromSmiles(m_smi) 52 | Draw.MolToFile(m, filepath, size=size) 53 | 54 | def label2sf2smi(self, label): 55 | m_sf = sf.encoding_to_selfies(label, char_dict, enc_type='label') 56 | m_smi = sf.decoder(m_sf) 57 | m_smi = Chem.CanonSmiles(m_smi) 58 | return m_smi 59 | 60 | 61 | if __name__ == '__main__': 62 | datadir = '../data' 63 | ds = MolDataset(datadir, 'train') 64 | ds_loader = DataLoader(dataset=ds, batch_size=100, 65 | shuffle=True, drop_last=True, num_workers=2) 66 | 67 | max_len = 0 68 | for i in range(len(ds)): 69 | m, len, s = ds[i] 70 | m = m.numpy() 71 | mol = ds.label2sf2smi(m[1:]) 72 | print(mol) 73 | ds.save_mol_png(label=m[1:], filepath='../a.png') 74 | break 75 | 76 | # print(char_dict) 77 | # print(ds.sf2smi(m), s) 78 | -------------------------------------------------------------------------------- /multi_design_acaa1/ZINC/char.py: -------------------------------------------------------------------------------- 1 | # char_list = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", 2 | # "n", "c", "o", "s", 3 | # "1", "2", "3", "4", "5", "6", "7", "8", 4 | # "(", ")", "[", "]", 5 | # "-", "=", "#", "/", "\\", "+", "@", "<", ">"] 6 | # 7 | # char_dict = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'P': 5, 8 | # 'S': 6, 'Cl': 7, 'Br': 8, 'I': 9, 9 | # 'n': 10, 'c': 11, 'o': 12, 's': 13, 10 | # '1': 14, '2': 15, '3': 16, '4': 17, '5': 18, '6': 19, '7': 20, '8': 21, 11 | # '(': 22, ')': 23, '[': 24, ']': 25, '-': 26, '=': 27, '#': 28, 12 | # '/': 29, '\\': 30, '+': 31, '@': 32, '<': 33, '>': 34} 13 | 14 | # char_dict=dict() 15 | # i=-1 16 | # for c in char_list: 17 | # i+=1 18 | # char_dict[c]=i 19 | # print(char_dict) 20 | # sys.exit() 21 | 22 | char_list = ['[=Branch2]', '[Ring1]', '[#Branch2]', '[\\NH1]', '[=O]', '[S@@+1]', '[=P@@]', '[-/Ring1]', '[=S]', 23 | '[=Ring1]', '[/C@@]', '[\\NH2+1]', '[\\O]', '[/Cl]', '[/C@]', '[=N+1]', '[=OH1+1]', '[/O+1]', '[#N]', 24 | '[=Branch1]', '[=C]', '[=N]', '[\\NH1+1]', '[P@@]', '[/S]', '[=S+1]', '[F]', '[S+1]', '[S@@]', '[=NH1+1]', 25 | '[/NH1-1]', '[\\S@]', '[\\N+1]', '[#N+1]', '[\\N]', '[I]', '[Branch2]', '[P@]', '[PH1]', '[CH2-1]', 26 | '[/C@H1]', '[Cl]', '[N+1]', '[Ring2]', '[\\O-1]', '[Br]', '[\\C@H1]', '[-/Ring2]', '[\\I]', '[=NH2+1]', 27 | '[C@@]', '[\\S-1]', '[\\C]', '[/N+1]', '[=PH2]', '[/S@]', '[\\S]', '[NH2+1]', '[nop]', '[NH1]', '[P]', 28 | '[Branch1]', '[\\Br]', '[=O+1]', '[-\\Ring1]', '[/N-1]', '[\\Cl]', '[P@@H1]', '[N]', '[=P]', '[NH1+1]', 29 | '[\\N-1]', '[/Br]', '[/NH1+1]', '[S]', '[N-1]', '[/NH2+1]', '[NH1-1]', '[#C]', '[C]', '[\\F]', '[/S-1]', 30 | '[/F]', '[/NH1]', '[=N-1]', '[NH3+1]', '[P+1]', '[=Ring2]', '[CH1-1]', '[S-1]', '[=P@]', '[/C]', '[=S@]', 31 | '[\\C@@H1]', '[O]', '[O-1]', '[/C@@H1]', '[#Branch1]', '[=SH1+1]', '[/O]', '[=S@@]', '[C@H1]', '[S@]', 32 | '[C@@H1]', '[/N]', '[C@]', '[/O-1]', '[PH1+1]', 'sos'] 33 | 34 | char_dict = {0: '[=Branch2]', 1: '[Ring1]', 2: '[#Branch2]', 3: '[\\NH1]', 4: '[=O]', 5: '[S@@+1]', 6: '[=P@@]', 35 | 7: '[-/Ring1]', 8: '[=S]', 9: '[=Ring1]', 10: '[/C@@]', 11: '[\\NH2+1]', 12: '[\\O]', 13: '[/Cl]', 36 | 14: '[/C@]', 15: '[=N+1]', 16: '[=OH1+1]', 17: '[/O+1]', 18: '[#N]', 19: 
'[=Branch1]', 20: '[=C]', 37 | 21: '[=N]', 22: '[\\NH1+1]', 23: '[P@@]', 24: '[/S]', 25: '[=S+1]', 26: '[F]', 27: '[S+1]', 28: '[S@@]', 38 | 29: '[=NH1+1]', 30: '[/NH1-1]', 31: '[\\S@]', 32: '[\\N+1]', 33: '[#N+1]', 34: '[\\N]', 35: '[I]', 39 | 36: '[Branch2]', 37: '[P@]', 38: '[PH1]', 39: '[CH2-1]', 40: '[/C@H1]', 41: '[Cl]', 42: '[N+1]', 40 | 43: '[Ring2]', 44: '[\\O-1]', 45: '[Br]', 46: '[\\C@H1]', 47: '[-/Ring2]', 48: '[\\I]', 49: '[=NH2+1]', 41 | 50: '[C@@]', 51: '[\\S-1]', 52: '[\\C]', 53: '[/N+1]', 54: '[=PH2]', 55: '[/S@]', 56: '[\\S]', 42 | 57: '[NH2+1]', 58: '[nop]', 59: '[NH1]', 60: '[P]', 61: '[Branch1]', 62: '[\\Br]', 63: '[=O+1]', 43 | 64: '[-\\Ring1]', 65: '[/N-1]', 66: '[\\Cl]', 67: '[P@@H1]', 68: '[N]', 69: '[=P]', 70: '[NH1+1]', 44 | 71: '[\\N-1]', 72: '[/Br]', 73: '[/NH1+1]', 74: '[S]', 75: '[N-1]', 76: '[/NH2+1]', 77: '[NH1-1]', 45 | 78: '[#C]', 79: '[C]', 80: '[\\F]', 81: '[/S-1]', 82: '[/F]', 83: '[/NH1]', 84: '[=N-1]', 85: '[NH3+1]', 46 | 86: '[P+1]', 87: '[=Ring2]', 88: '[CH1-1]', 89: '[S-1]', 90: '[=P@]', 91: '[/C]', 92: '[=S@]', 47 | 93: '[\\C@@H1]', 94: '[O]', 95: '[O-1]', 96: '[/C@@H1]', 97: '[#Branch1]', 98: '[=SH1+1]', 99: '[/O]', 48 | 100: '[=S@@]', 101: '[C@H1]', 102: '[S@]', 103: '[C@@H1]', 104: '[/N]', 105: '[C@]', 106: '[/O-1]', 49 | 107: '[PH1+1]'} 50 | -------------------------------------------------------------------------------- /multi_design_esr1/ZINC/char.py: -------------------------------------------------------------------------------- 1 | # char_list = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", 2 | # "n", "c", "o", "s", 3 | # "1", "2", "3", "4", "5", "6", "7", "8", 4 | # "(", ")", "[", "]", 5 | # "-", "=", "#", "/", "\\", "+", "@", "<", ">"] 6 | # 7 | # char_dict = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'P': 5, 8 | # 'S': 6, 'Cl': 7, 'Br': 8, 'I': 9, 9 | # 'n': 10, 'c': 11, 'o': 12, 's': 13, 10 | # '1': 14, '2': 15, '3': 16, '4': 17, '5': 18, '6': 19, '7': 20, '8': 21, 11 | # '(': 22, ')': 23, '[': 24, ']': 25, '-': 26, '=': 27, '#': 28, 12 | # '/': 29, '\\': 30, '+': 31, '@': 32, '<': 33, '>': 34} 13 | 14 | # char_dict=dict() 15 | # i=-1 16 | # for c in char_list: 17 | # i+=1 18 | # char_dict[c]=i 19 | # print(char_dict) 20 | # sys.exit() 21 | 22 | char_list = ['[=Branch2]', '[Ring1]', '[#Branch2]', '[\\NH1]', '[=O]', '[S@@+1]', '[=P@@]', '[-/Ring1]', '[=S]', 23 | '[=Ring1]', '[/C@@]', '[\\NH2+1]', '[\\O]', '[/Cl]', '[/C@]', '[=N+1]', '[=OH1+1]', '[/O+1]', '[#N]', 24 | '[=Branch1]', '[=C]', '[=N]', '[\\NH1+1]', '[P@@]', '[/S]', '[=S+1]', '[F]', '[S+1]', '[S@@]', '[=NH1+1]', 25 | '[/NH1-1]', '[\\S@]', '[\\N+1]', '[#N+1]', '[\\N]', '[I]', '[Branch2]', '[P@]', '[PH1]', '[CH2-1]', 26 | '[/C@H1]', '[Cl]', '[N+1]', '[Ring2]', '[\\O-1]', '[Br]', '[\\C@H1]', '[-/Ring2]', '[\\I]', '[=NH2+1]', 27 | '[C@@]', '[\\S-1]', '[\\C]', '[/N+1]', '[=PH2]', '[/S@]', '[\\S]', '[NH2+1]', '[nop]', '[NH1]', '[P]', 28 | '[Branch1]', '[\\Br]', '[=O+1]', '[-\\Ring1]', '[/N-1]', '[\\Cl]', '[P@@H1]', '[N]', '[=P]', '[NH1+1]', 29 | '[\\N-1]', '[/Br]', '[/NH1+1]', '[S]', '[N-1]', '[/NH2+1]', '[NH1-1]', '[#C]', '[C]', '[\\F]', '[/S-1]', 30 | '[/F]', '[/NH1]', '[=N-1]', '[NH3+1]', '[P+1]', '[=Ring2]', '[CH1-1]', '[S-1]', '[=P@]', '[/C]', '[=S@]', 31 | '[\\C@@H1]', '[O]', '[O-1]', '[/C@@H1]', '[#Branch1]', '[=SH1+1]', '[/O]', '[=S@@]', '[C@H1]', '[S@]', 32 | '[C@@H1]', '[/N]', '[C@]', '[/O-1]', '[PH1+1]', 'sos'] 33 | 34 | char_dict = {0: '[=Branch2]', 1: '[Ring1]', 2: '[#Branch2]', 3: '[\\NH1]', 4: '[=O]', 5: '[S@@+1]', 6: '[=P@@]', 35 | 7: '[-/Ring1]', 8: '[=S]', 
9: '[=Ring1]', 10: '[/C@@]', 11: '[\\NH2+1]', 12: '[\\O]', 13: '[/Cl]', 36 | 14: '[/C@]', 15: '[=N+1]', 16: '[=OH1+1]', 17: '[/O+1]', 18: '[#N]', 19: '[=Branch1]', 20: '[=C]', 37 | 21: '[=N]', 22: '[\\NH1+1]', 23: '[P@@]', 24: '[/S]', 25: '[=S+1]', 26: '[F]', 27: '[S+1]', 28: '[S@@]', 38 | 29: '[=NH1+1]', 30: '[/NH1-1]', 31: '[\\S@]', 32: '[\\N+1]', 33: '[#N+1]', 34: '[\\N]', 35: '[I]', 39 | 36: '[Branch2]', 37: '[P@]', 38: '[PH1]', 39: '[CH2-1]', 40: '[/C@H1]', 41: '[Cl]', 42: '[N+1]', 40 | 43: '[Ring2]', 44: '[\\O-1]', 45: '[Br]', 46: '[\\C@H1]', 47: '[-/Ring2]', 48: '[\\I]', 49: '[=NH2+1]', 41 | 50: '[C@@]', 51: '[\\S-1]', 52: '[\\C]', 53: '[/N+1]', 54: '[=PH2]', 55: '[/S@]', 56: '[\\S]', 42 | 57: '[NH2+1]', 58: '[nop]', 59: '[NH1]', 60: '[P]', 61: '[Branch1]', 62: '[\\Br]', 63: '[=O+1]', 43 | 64: '[-\\Ring1]', 65: '[/N-1]', 66: '[\\Cl]', 67: '[P@@H1]', 68: '[N]', 69: '[=P]', 70: '[NH1+1]', 44 | 71: '[\\N-1]', 72: '[/Br]', 73: '[/NH1+1]', 74: '[S]', 75: '[N-1]', 76: '[/NH2+1]', 77: '[NH1-1]', 45 | 78: '[#C]', 79: '[C]', 80: '[\\F]', 81: '[/S-1]', 82: '[/F]', 83: '[/NH1]', 84: '[=N-1]', 85: '[NH3+1]', 46 | 86: '[P+1]', 87: '[=Ring2]', 88: '[CH1-1]', 89: '[S-1]', 90: '[=P@]', 91: '[/C]', 92: '[=S@]', 47 | 93: '[\\C@@H1]', 94: '[O]', 95: '[O-1]', 96: '[/C@@H1]', 97: '[#Branch1]', 98: '[=SH1+1]', 99: '[/O]', 48 | 100: '[=S@@]', 101: '[C@H1]', 102: '[S@]', 103: '[C@@H1]', 104: '[/N]', 105: '[C@]', 106: '[/O-1]', 49 | 107: '[PH1+1]'} 50 | -------------------------------------------------------------------------------- /single_design_acaa1/ZINC/char.py: -------------------------------------------------------------------------------- 1 | # char_list = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", 2 | # "n", "c", "o", "s", 3 | # "1", "2", "3", "4", "5", "6", "7", "8", 4 | # "(", ")", "[", "]", 5 | # "-", "=", "#", "/", "\\", "+", "@", "<", ">"] 6 | # 7 | # char_dict = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'P': 5, 8 | # 'S': 6, 'Cl': 7, 'Br': 8, 'I': 9, 9 | # 'n': 10, 'c': 11, 'o': 12, 's': 13, 10 | # '1': 14, '2': 15, '3': 16, '4': 17, '5': 18, '6': 19, '7': 20, '8': 21, 11 | # '(': 22, ')': 23, '[': 24, ']': 25, '-': 26, '=': 27, '#': 28, 12 | # '/': 29, '\\': 30, '+': 31, '@': 32, '<': 33, '>': 34} 13 | 14 | # char_dict=dict() 15 | # i=-1 16 | # for c in char_list: 17 | # i+=1 18 | # char_dict[c]=i 19 | # print(char_dict) 20 | # sys.exit() 21 | 22 | char_list = ['[=Branch2]', '[Ring1]', '[#Branch2]', '[\\NH1]', '[=O]', '[S@@+1]', '[=P@@]', '[-/Ring1]', '[=S]', 23 | '[=Ring1]', '[/C@@]', '[\\NH2+1]', '[\\O]', '[/Cl]', '[/C@]', '[=N+1]', '[=OH1+1]', '[/O+1]', '[#N]', 24 | '[=Branch1]', '[=C]', '[=N]', '[\\NH1+1]', '[P@@]', '[/S]', '[=S+1]', '[F]', '[S+1]', '[S@@]', '[=NH1+1]', 25 | '[/NH1-1]', '[\\S@]', '[\\N+1]', '[#N+1]', '[\\N]', '[I]', '[Branch2]', '[P@]', '[PH1]', '[CH2-1]', 26 | '[/C@H1]', '[Cl]', '[N+1]', '[Ring2]', '[\\O-1]', '[Br]', '[\\C@H1]', '[-/Ring2]', '[\\I]', '[=NH2+1]', 27 | '[C@@]', '[\\S-1]', '[\\C]', '[/N+1]', '[=PH2]', '[/S@]', '[\\S]', '[NH2+1]', '[nop]', '[NH1]', '[P]', 28 | '[Branch1]', '[\\Br]', '[=O+1]', '[-\\Ring1]', '[/N-1]', '[\\Cl]', '[P@@H1]', '[N]', '[=P]', '[NH1+1]', 29 | '[\\N-1]', '[/Br]', '[/NH1+1]', '[S]', '[N-1]', '[/NH2+1]', '[NH1-1]', '[#C]', '[C]', '[\\F]', '[/S-1]', 30 | '[/F]', '[/NH1]', '[=N-1]', '[NH3+1]', '[P+1]', '[=Ring2]', '[CH1-1]', '[S-1]', '[=P@]', '[/C]', '[=S@]', 31 | '[\\C@@H1]', '[O]', '[O-1]', '[/C@@H1]', '[#Branch1]', '[=SH1+1]', '[/O]', '[=S@@]', '[C@H1]', '[S@]', 32 | '[C@@H1]', '[/N]', '[C@]', '[/O-1]', '[PH1+1]', 'sos'] 33 | 
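# The char_dict below mirrors char_list by position (char_dict[i] == char_list[i] for i in 0..107); index 108, the 'sos' start token, is deliberately left out, since the dataset code strips it before selfies decoding. An illustrative consistency check: assert all(char_dict[i] == t for i, t in enumerate(char_list[:-1]))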
34 | char_dict = {0: '[=Branch2]', 1: '[Ring1]', 2: '[#Branch2]', 3: '[\\NH1]', 4: '[=O]', 5: '[S@@+1]', 6: '[=P@@]', 35 | 7: '[-/Ring1]', 8: '[=S]', 9: '[=Ring1]', 10: '[/C@@]', 11: '[\\NH2+1]', 12: '[\\O]', 13: '[/Cl]', 36 | 14: '[/C@]', 15: '[=N+1]', 16: '[=OH1+1]', 17: '[/O+1]', 18: '[#N]', 19: '[=Branch1]', 20: '[=C]', 37 | 21: '[=N]', 22: '[\\NH1+1]', 23: '[P@@]', 24: '[/S]', 25: '[=S+1]', 26: '[F]', 27: '[S+1]', 28: '[S@@]', 38 | 29: '[=NH1+1]', 30: '[/NH1-1]', 31: '[\\S@]', 32: '[\\N+1]', 33: '[#N+1]', 34: '[\\N]', 35: '[I]', 39 | 36: '[Branch2]', 37: '[P@]', 38: '[PH1]', 39: '[CH2-1]', 40: '[/C@H1]', 41: '[Cl]', 42: '[N+1]', 40 | 43: '[Ring2]', 44: '[\\O-1]', 45: '[Br]', 46: '[\\C@H1]', 47: '[-/Ring2]', 48: '[\\I]', 49: '[=NH2+1]', 41 | 50: '[C@@]', 51: '[\\S-1]', 52: '[\\C]', 53: '[/N+1]', 54: '[=PH2]', 55: '[/S@]', 56: '[\\S]', 42 | 57: '[NH2+1]', 58: '[nop]', 59: '[NH1]', 60: '[P]', 61: '[Branch1]', 62: '[\\Br]', 63: '[=O+1]', 43 | 64: '[-\\Ring1]', 65: '[/N-1]', 66: '[\\Cl]', 67: '[P@@H1]', 68: '[N]', 69: '[=P]', 70: '[NH1+1]', 44 | 71: '[\\N-1]', 72: '[/Br]', 73: '[/NH1+1]', 74: '[S]', 75: '[N-1]', 76: '[/NH2+1]', 77: '[NH1-1]', 45 | 78: '[#C]', 79: '[C]', 80: '[\\F]', 81: '[/S-1]', 82: '[/F]', 83: '[/NH1]', 84: '[=N-1]', 85: '[NH3+1]', 46 | 86: '[P+1]', 87: '[=Ring2]', 88: '[CH1-1]', 89: '[S-1]', 90: '[=P@]', 91: '[/C]', 92: '[=S@]', 47 | 93: '[\\C@@H1]', 94: '[O]', 95: '[O-1]', 96: '[/C@@H1]', 97: '[#Branch1]', 98: '[=SH1+1]', 99: '[/O]', 48 | 100: '[=S@@]', 101: '[C@H1]', 102: '[S@]', 103: '[C@@H1]', 104: '[/N]', 105: '[C@]', 106: '[/O-1]', 49 | 107: '[PH1+1]'} 50 | -------------------------------------------------------------------------------- /single_design_esr1/ZINC/char.py: -------------------------------------------------------------------------------- 1 | # char_list = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", 2 | # "n", "c", "o", "s", 3 | # "1", "2", "3", "4", "5", "6", "7", "8", 4 | # "(", ")", "[", "]", 5 | # "-", "=", "#", "/", "\\", "+", "@", "<", ">"] 6 | # 7 | # char_dict = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'P': 5, 8 | # 'S': 6, 'Cl': 7, 'Br': 8, 'I': 9, 9 | # 'n': 10, 'c': 11, 'o': 12, 's': 13, 10 | # '1': 14, '2': 15, '3': 16, '4': 17, '5': 18, '6': 19, '7': 20, '8': 21, 11 | # '(': 22, ')': 23, '[': 24, ']': 25, '-': 26, '=': 27, '#': 28, 12 | # '/': 29, '\\': 30, '+': 31, '@': 32, '<': 33, '>': 34} 13 | 14 | # char_dict=dict() 15 | # i=-1 16 | # for c in char_list: 17 | # i+=1 18 | # char_dict[c]=i 19 | # print(char_dict) 20 | # sys.exit() 21 | 22 | char_list = ['[=Branch2]', '[Ring1]', '[#Branch2]', '[\\NH1]', '[=O]', '[S@@+1]', '[=P@@]', '[-/Ring1]', '[=S]', 23 | '[=Ring1]', '[/C@@]', '[\\NH2+1]', '[\\O]', '[/Cl]', '[/C@]', '[=N+1]', '[=OH1+1]', '[/O+1]', '[#N]', 24 | '[=Branch1]', '[=C]', '[=N]', '[\\NH1+1]', '[P@@]', '[/S]', '[=S+1]', '[F]', '[S+1]', '[S@@]', '[=NH1+1]', 25 | '[/NH1-1]', '[\\S@]', '[\\N+1]', '[#N+1]', '[\\N]', '[I]', '[Branch2]', '[P@]', '[PH1]', '[CH2-1]', 26 | '[/C@H1]', '[Cl]', '[N+1]', '[Ring2]', '[\\O-1]', '[Br]', '[\\C@H1]', '[-/Ring2]', '[\\I]', '[=NH2+1]', 27 | '[C@@]', '[\\S-1]', '[\\C]', '[/N+1]', '[=PH2]', '[/S@]', '[\\S]', '[NH2+1]', '[nop]', '[NH1]', '[P]', 28 | '[Branch1]', '[\\Br]', '[=O+1]', '[-\\Ring1]', '[/N-1]', '[\\Cl]', '[P@@H1]', '[N]', '[=P]', '[NH1+1]', 29 | '[\\N-1]', '[/Br]', '[/NH1+1]', '[S]', '[N-1]', '[/NH2+1]', '[NH1-1]', '[#C]', '[C]', '[\\F]', '[/S-1]', 30 | '[/F]', '[/NH1]', '[=N-1]', '[NH3+1]', '[P+1]', '[=Ring2]', '[CH1-1]', '[S-1]', '[=P@]', '[/C]', '[=S@]', 31 | '[\\C@@H1]', 
'[O]', '[O-1]', '[/C@@H1]', '[#Branch1]', '[=SH1+1]', '[/O]', '[=S@@]', '[C@H1]', '[S@]', 32 | '[C@@H1]', '[/N]', '[C@]', '[/O-1]', '[PH1+1]', 'sos'] 33 | 34 | char_dict = {0: '[=Branch2]', 1: '[Ring1]', 2: '[#Branch2]', 3: '[\\NH1]', 4: '[=O]', 5: '[S@@+1]', 6: '[=P@@]', 35 | 7: '[-/Ring1]', 8: '[=S]', 9: '[=Ring1]', 10: '[/C@@]', 11: '[\\NH2+1]', 12: '[\\O]', 13: '[/Cl]', 36 | 14: '[/C@]', 15: '[=N+1]', 16: '[=OH1+1]', 17: '[/O+1]', 18: '[#N]', 19: '[=Branch1]', 20: '[=C]', 37 | 21: '[=N]', 22: '[\\NH1+1]', 23: '[P@@]', 24: '[/S]', 25: '[=S+1]', 26: '[F]', 27: '[S+1]', 28: '[S@@]', 38 | 29: '[=NH1+1]', 30: '[/NH1-1]', 31: '[\\S@]', 32: '[\\N+1]', 33: '[#N+1]', 34: '[\\N]', 35: '[I]', 39 | 36: '[Branch2]', 37: '[P@]', 38: '[PH1]', 39: '[CH2-1]', 40: '[/C@H1]', 41: '[Cl]', 42: '[N+1]', 40 | 43: '[Ring2]', 44: '[\\O-1]', 45: '[Br]', 46: '[\\C@H1]', 47: '[-/Ring2]', 48: '[\\I]', 49: '[=NH2+1]', 41 | 50: '[C@@]', 51: '[\\S-1]', 52: '[\\C]', 53: '[/N+1]', 54: '[=PH2]', 55: '[/S@]', 56: '[\\S]', 42 | 57: '[NH2+1]', 58: '[nop]', 59: '[NH1]', 60: '[P]', 61: '[Branch1]', 62: '[\\Br]', 63: '[=O+1]', 43 | 64: '[-\\Ring1]', 65: '[/N-1]', 66: '[\\Cl]', 67: '[P@@H1]', 68: '[N]', 69: '[=P]', 70: '[NH1+1]', 44 | 71: '[\\N-1]', 72: '[/Br]', 73: '[/NH1+1]', 74: '[S]', 75: '[N-1]', 76: '[/NH2+1]', 77: '[NH1-1]', 45 | 78: '[#C]', 79: '[C]', 80: '[\\F]', 81: '[/S-1]', 82: '[/F]', 83: '[/NH1]', 84: '[=N-1]', 85: '[NH3+1]', 46 | 86: '[P+1]', 87: '[=Ring2]', 88: '[CH1-1]', 89: '[S-1]', 90: '[=P@]', 91: '[/C]', 92: '[=S@]', 47 | 93: '[\\C@@H1]', 94: '[O]', 95: '[O-1]', 96: '[/C@@H1]', 97: '[#Branch1]', 98: '[=SH1+1]', 99: '[/O]', 48 | 100: '[=S@@]', 101: '[C@H1]', 102: '[S@]', 103: '[C@@H1]', 104: '[/N]', 105: '[C@]', 106: '[/O-1]', 49 | 107: '[PH1+1]'} 50 | -------------------------------------------------------------------------------- /single_design_plogp/ZINC/char.py: -------------------------------------------------------------------------------- 1 | # char_list = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", 2 | # "n", "c", "o", "s", 3 | # "1", "2", "3", "4", "5", "6", "7", "8", 4 | # "(", ")", "[", "]", 5 | # "-", "=", "#", "/", "\\", "+", "@", "<", ">"] 6 | # 7 | # char_dict = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'P': 5, 8 | # 'S': 6, 'Cl': 7, 'Br': 8, 'I': 9, 9 | # 'n': 10, 'c': 11, 'o': 12, 's': 13, 10 | # '1': 14, '2': 15, '3': 16, '4': 17, '5': 18, '6': 19, '7': 20, '8': 21, 11 | # '(': 22, ')': 23, '[': 24, ']': 25, '-': 26, '=': 27, '#': 28, 12 | # '/': 29, '\\': 30, '+': 31, '@': 32, '<': 33, '>': 34} 13 | 14 | # char_dict=dict() 15 | # i=-1 16 | # for c in char_list: 17 | # i+=1 18 | # char_dict[c]=i 19 | # print(char_dict) 20 | # sys.exit() 21 | 22 | char_list = ['[=Branch2]', '[Ring1]', '[#Branch2]', '[\\NH1]', '[=O]', '[S@@+1]', '[=P@@]', '[-/Ring1]', '[=S]', 23 | '[=Ring1]', '[/C@@]', '[\\NH2+1]', '[\\O]', '[/Cl]', '[/C@]', '[=N+1]', '[=OH1+1]', '[/O+1]', '[#N]', 24 | '[=Branch1]', '[=C]', '[=N]', '[\\NH1+1]', '[P@@]', '[/S]', '[=S+1]', '[F]', '[S+1]', '[S@@]', '[=NH1+1]', 25 | '[/NH1-1]', '[\\S@]', '[\\N+1]', '[#N+1]', '[\\N]', '[I]', '[Branch2]', '[P@]', '[PH1]', '[CH2-1]', 26 | '[/C@H1]', '[Cl]', '[N+1]', '[Ring2]', '[\\O-1]', '[Br]', '[\\C@H1]', '[-/Ring2]', '[\\I]', '[=NH2+1]', 27 | '[C@@]', '[\\S-1]', '[\\C]', '[/N+1]', '[=PH2]', '[/S@]', '[\\S]', '[NH2+1]', '[nop]', '[NH1]', '[P]', 28 | '[Branch1]', '[\\Br]', '[=O+1]', '[-\\Ring1]', '[/N-1]', '[\\Cl]', '[P@@H1]', '[N]', '[=P]', '[NH1+1]', 29 | '[\\N-1]', '[/Br]', '[/NH1+1]', '[S]', '[N-1]', '[/NH2+1]', '[NH1-1]', '[#C]', 
'[C]', '[\\F]', '[/S-1]', 30 | '[/F]', '[/NH1]', '[=N-1]', '[NH3+1]', '[P+1]', '[=Ring2]', '[CH1-1]', '[S-1]', '[=P@]', '[/C]', '[=S@]', 31 | '[\\C@@H1]', '[O]', '[O-1]', '[/C@@H1]', '[#Branch1]', '[=SH1+1]', '[/O]', '[=S@@]', '[C@H1]', '[S@]', 32 | '[C@@H1]', '[/N]', '[C@]', '[/O-1]', '[PH1+1]', 'sos'] 33 | 34 | char_dict = {0: '[=Branch2]', 1: '[Ring1]', 2: '[#Branch2]', 3: '[\\NH1]', 4: '[=O]', 5: '[S@@+1]', 6: '[=P@@]', 35 | 7: '[-/Ring1]', 8: '[=S]', 9: '[=Ring1]', 10: '[/C@@]', 11: '[\\NH2+1]', 12: '[\\O]', 13: '[/Cl]', 36 | 14: '[/C@]', 15: '[=N+1]', 16: '[=OH1+1]', 17: '[/O+1]', 18: '[#N]', 19: '[=Branch1]', 20: '[=C]', 37 | 21: '[=N]', 22: '[\\NH1+1]', 23: '[P@@]', 24: '[/S]', 25: '[=S+1]', 26: '[F]', 27: '[S+1]', 28: '[S@@]', 38 | 29: '[=NH1+1]', 30: '[/NH1-1]', 31: '[\\S@]', 32: '[\\N+1]', 33: '[#N+1]', 34: '[\\N]', 35: '[I]', 39 | 36: '[Branch2]', 37: '[P@]', 38: '[PH1]', 39: '[CH2-1]', 40: '[/C@H1]', 41: '[Cl]', 42: '[N+1]', 40 | 43: '[Ring2]', 44: '[\\O-1]', 45: '[Br]', 46: '[\\C@H1]', 47: '[-/Ring2]', 48: '[\\I]', 49: '[=NH2+1]', 41 | 50: '[C@@]', 51: '[\\S-1]', 52: '[\\C]', 53: '[/N+1]', 54: '[=PH2]', 55: '[/S@]', 56: '[\\S]', 42 | 57: '[NH2+1]', 58: '[nop]', 59: '[NH1]', 60: '[P]', 61: '[Branch1]', 62: '[\\Br]', 63: '[=O+1]', 43 | 64: '[-\\Ring1]', 65: '[/N-1]', 66: '[\\Cl]', 67: '[P@@H1]', 68: '[N]', 69: '[=P]', 70: '[NH1+1]', 44 | 71: '[\\N-1]', 72: '[/Br]', 73: '[/NH1+1]', 74: '[S]', 75: '[N-1]', 76: '[/NH2+1]', 77: '[NH1-1]', 45 | 78: '[#C]', 79: '[C]', 80: '[\\F]', 81: '[/S-1]', 82: '[/F]', 83: '[/NH1]', 84: '[=N-1]', 85: '[NH3+1]', 46 | 86: '[P+1]', 87: '[=Ring2]', 88: '[CH1-1]', 89: '[S-1]', 90: '[=P@]', 91: '[/C]', 92: '[=S@]', 47 | 93: '[\\C@@H1]', 94: '[O]', 95: '[O-1]', 96: '[/C@@H1]', 97: '[#Branch1]', 98: '[=SH1+1]', 99: '[/O]', 48 | 100: '[=S@@]', 101: '[C@H1]', 102: '[S@]', 103: '[C@@H1]', 104: '[/N]', 105: '[C@]', 106: '[/O-1]', 49 | 107: '[PH1+1]'} 50 | -------------------------------------------------------------------------------- /single_design_qed/ZINC/char.py: -------------------------------------------------------------------------------- 1 | # char_list = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I", 2 | # "n", "c", "o", "s", 3 | # "1", "2", "3", "4", "5", "6", "7", "8", 4 | # "(", ")", "[", "]", 5 | # "-", "=", "#", "/", "\\", "+", "@", "<", ">"] 6 | # 7 | # char_dict = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'P': 5, 8 | # 'S': 6, 'Cl': 7, 'Br': 8, 'I': 9, 9 | # 'n': 10, 'c': 11, 'o': 12, 's': 13, 10 | # '1': 14, '2': 15, '3': 16, '4': 17, '5': 18, '6': 19, '7': 20, '8': 21, 11 | # '(': 22, ')': 23, '[': 24, ']': 25, '-': 26, '=': 27, '#': 28, 12 | # '/': 29, '\\': 30, '+': 31, '@': 32, '<': 33, '>': 34} 13 | 14 | # char_dict=dict() 15 | # i=-1 16 | # for c in char_list: 17 | # i+=1 18 | # char_dict[c]=i 19 | # print(char_dict) 20 | # sys.exit() 21 | 22 | char_list = ['[=Branch2]', '[Ring1]', '[#Branch2]', '[\\NH1]', '[=O]', '[S@@+1]', '[=P@@]', '[-/Ring1]', '[=S]', 23 | '[=Ring1]', '[/C@@]', '[\\NH2+1]', '[\\O]', '[/Cl]', '[/C@]', '[=N+1]', '[=OH1+1]', '[/O+1]', '[#N]', 24 | '[=Branch1]', '[=C]', '[=N]', '[\\NH1+1]', '[P@@]', '[/S]', '[=S+1]', '[F]', '[S+1]', '[S@@]', '[=NH1+1]', 25 | '[/NH1-1]', '[\\S@]', '[\\N+1]', '[#N+1]', '[\\N]', '[I]', '[Branch2]', '[P@]', '[PH1]', '[CH2-1]', 26 | '[/C@H1]', '[Cl]', '[N+1]', '[Ring2]', '[\\O-1]', '[Br]', '[\\C@H1]', '[-/Ring2]', '[\\I]', '[=NH2+1]', 27 | '[C@@]', '[\\S-1]', '[\\C]', '[/N+1]', '[=PH2]', '[/S@]', '[\\S]', '[NH2+1]', '[nop]', '[NH1]', '[P]', 28 | '[Branch1]', '[\\Br]', '[=O+1]', 
'[-\\Ring1]', '[/N-1]', '[\\Cl]', '[P@@H1]', '[N]', '[=P]', '[NH1+1]', 29 | '[\\N-1]', '[/Br]', '[/NH1+1]', '[S]', '[N-1]', '[/NH2+1]', '[NH1-1]', '[#C]', '[C]', '[\\F]', '[/S-1]', 30 | '[/F]', '[/NH1]', '[=N-1]', '[NH3+1]', '[P+1]', '[=Ring2]', '[CH1-1]', '[S-1]', '[=P@]', '[/C]', '[=S@]', 31 | '[\\C@@H1]', '[O]', '[O-1]', '[/C@@H1]', '[#Branch1]', '[=SH1+1]', '[/O]', '[=S@@]', '[C@H1]', '[S@]', 32 | '[C@@H1]', '[/N]', '[C@]', '[/O-1]', '[PH1+1]', 'sos'] 33 | 34 | char_dict = {0: '[=Branch2]', 1: '[Ring1]', 2: '[#Branch2]', 3: '[\\NH1]', 4: '[=O]', 5: '[S@@+1]', 6: '[=P@@]', 35 | 7: '[-/Ring1]', 8: '[=S]', 9: '[=Ring1]', 10: '[/C@@]', 11: '[\\NH2+1]', 12: '[\\O]', 13: '[/Cl]', 36 | 14: '[/C@]', 15: '[=N+1]', 16: '[=OH1+1]', 17: '[/O+1]', 18: '[#N]', 19: '[=Branch1]', 20: '[=C]', 37 | 21: '[=N]', 22: '[\\NH1+1]', 23: '[P@@]', 24: '[/S]', 25: '[=S+1]', 26: '[F]', 27: '[S+1]', 28: '[S@@]', 38 | 29: '[=NH1+1]', 30: '[/NH1-1]', 31: '[\\S@]', 32: '[\\N+1]', 33: '[#N+1]', 34: '[\\N]', 35: '[I]', 39 | 36: '[Branch2]', 37: '[P@]', 38: '[PH1]', 39: '[CH2-1]', 40: '[/C@H1]', 41: '[Cl]', 42: '[N+1]', 40 | 43: '[Ring2]', 44: '[\\O-1]', 45: '[Br]', 46: '[\\C@H1]', 47: '[-/Ring2]', 48: '[\\I]', 49: '[=NH2+1]', 41 | 50: '[C@@]', 51: '[\\S-1]', 52: '[\\C]', 53: '[/N+1]', 54: '[=PH2]', 55: '[/S@]', 56: '[\\S]', 42 | 57: '[NH2+1]', 58: '[nop]', 59: '[NH1]', 60: '[P]', 61: '[Branch1]', 62: '[\\Br]', 63: '[=O+1]', 43 | 64: '[-\\Ring1]', 65: '[/N-1]', 66: '[\\Cl]', 67: '[P@@H1]', 68: '[N]', 69: '[=P]', 70: '[NH1+1]', 44 | 71: '[\\N-1]', 72: '[/Br]', 73: '[/NH1+1]', 74: '[S]', 75: '[N-1]', 76: '[/NH2+1]', 77: '[NH1-1]', 45 | 78: '[#C]', 79: '[C]', 80: '[\\F]', 81: '[/S-1]', 82: '[/F]', 83: '[/NH1]', 84: '[=N-1]', 85: '[NH3+1]', 46 | 86: '[P+1]', 87: '[=Ring2]', 88: '[CH1-1]', 89: '[S-1]', 90: '[=P@]', 91: '[/C]', 92: '[=S@]', 47 | 93: '[\\C@@H1]', 94: '[O]', 95: '[O-1]', 96: '[/C@@H1]', 97: '[#Branch1]', 98: '[=SH1+1]', 99: '[/O]', 48 | 100: '[=S@@]', 101: '[C@H1]', 102: '[S@]', 103: '[C@@H1]', 104: '[/N]', 105: '[C@]', 106: '[/O-1]', 49 | 107: '[PH1+1]'} 50 | -------------------------------------------------------------------------------- /single_design_qed/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--debug", default=False, type=bool) 6 | parser.add_argument("--tb", default=False, type=bool) 7 | 8 | # Input data 9 | parser.add_argument('--data_dir', default='../data') 10 | parser.add_argument('--test_file', default='ZINC/test_5.txt') 11 | parser.add_argument('--train_from', default='') 12 | parser.add_argument('--max_len', default=110, type=int) 13 | parser.add_argument('--batch_size', default=1024 * 1, type=int) 14 | parser.add_argument('--eval_batch_size', default=512 * 1, type=int) 15 | 16 | # General options 17 | parser.add_argument('--z_n_iters', type=int, default=20) 18 | parser.add_argument('--z_step_size', type=float, default=0.5) 19 | parser.add_argument('--z_with_noise', type=int, default=1) 20 | parser.add_argument('--num_z_samples', type=int, default=10) 21 | parser.add_argument('--model', type=str, default='mol_ebm') 22 | parser.add_argument('--mask', type=bool, default=False) 23 | parser.add_argument('--single_design', default=True, type=bool) 24 | 25 | # EBM 26 | parser.add_argument('--prior_hidden_dim', type=int, default=200) 27 | parser.add_argument('--z_prior_with_noise', type=int, default=1) 28 | parser.add_argument('--prior_step_size', 
type=float, default=0.5) 29 | parser.add_argument('--z_n_iters_prior', type=int, default=20) 30 | parser.add_argument('--max_grad_norm_prior', default=1, type=float) 31 | parser.add_argument('--ebm_reg', default=0.0, type=float) 32 | parser.add_argument('--ref_dist', default='gaussian', type=str, choices=['gaussian', 'uniform']) 33 | parser.add_argument('--ref_sigma', type=float, default=0.5) 34 | parser.add_argument('--init_factor', type=float, default=1.) 35 | parser.add_argument('--noise_factor', type=float, default=0.5) 36 | 37 | # Decoder and MLP options 38 | parser.add_argument('--mlp_hidden_dim', default=200, type=int) 39 | parser.add_argument('--prop_coefficient', default=10., type=float) 40 | parser.add_argument('--latent_dim', default=100, type=int) 41 | parser.add_argument('--dec_word_dim', default=512, type=int) 42 | parser.add_argument('--dec_h_dim', default=1024, type=int) 43 | parser.add_argument('--dec_num_layers', default=1, type=int) 44 | parser.add_argument('--dec_dropout', default=0.2, type=float) 45 | parser.add_argument('--train_n2n', default=1, type=int) 46 | parser.add_argument('--train_kl', default=1, type=int) 47 | 48 | # Optimization options 49 | parser.add_argument('--log_dir', default='../log/') 50 | parser.add_argument('--checkpoint_dir', default='models') 51 | # parser.add_argument('--slurm', default=0, type=int) 52 | parser.add_argument('--warmup', default=0, type=int) 53 | parser.add_argument('--num_epochs', default=25, type=int) 54 | parser.add_argument('--min_epochs', default=15, type=int) 55 | parser.add_argument('--start_epoch', default=0, type=int) 56 | parser.add_argument('--eps', default=1e-5, type=float) 57 | parser.add_argument('--decay', default=0, type=int) 58 | parser.add_argument('--momentum', default=0.5, type=float) 59 | parser.add_argument('--lr', default=0.001, type=float) 60 | parser.add_argument('--prior_lr', default=0.0001, type=float) 61 | parser.add_argument('--max_grad_norm', default=5, type=float) 62 | parser.add_argument('--gpu', default=0, type=int) 63 | parser.add_argument('--seed', default=3435, type=int) 64 | parser.add_argument('--print_every', type=int, default=100) 65 | parser.add_argument('--sample_every', type=int, default=1000) 66 | parser.add_argument('--kl_every', type=int, default=100) 67 | parser.add_argument('--compute_kl', type=int, default=1) 68 | parser.add_argument('--test', type=int, default=0) 69 | return parser.parse_args() 70 | 71 | if __name__ == '__main__': 72 | args = get_args() 73 | print(args) 74 | -------------------------------------------------------------------------------- /single_design_plogp/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--debug", default=False, type=bool) 6 | parser.add_argument("--tb", default=False, type=bool) 7 | 8 | # Input data 9 | parser.add_argument('--data_dir', default='../data') 10 | parser.add_argument('--test_file', default='../ZINC/test_5.txt') 11 | parser.add_argument('--train_from', default='') 12 | parser.add_argument('--max_len', default=110, type=int) 13 | parser.add_argument('--batch_size', default=1024 * 1, type=int) 14 | parser.add_argument('--eval_batch_size', default=500 * 1, type=int) 15 | 16 | # General options 17 | parser.add_argument('--z_n_iters', type=int, default=20) 18 | parser.add_argument('--z_step_size', type=float, default=0.5) 19 | parser.add_argument('--z_with_noise', type=int, default=1) 20 | 
parser.add_argument('--num_z_samples', type=int, default=10) 21 | parser.add_argument('--model', type=str, default='mol_ebm') 22 | parser.add_argument('--mask', type=bool, default=False) 23 | parser.add_argument('--single_design', default=True, type=bool) 24 | 25 | # EBM 26 | parser.add_argument('--prior_hidden_dim', type=int, default=200) 27 | parser.add_argument('--z_prior_with_noise', type=int, default=1) 28 | parser.add_argument('--prior_step_size', type=float, default=0.5) 29 | parser.add_argument('--z_n_iters_prior', type=int, default=20) 30 | parser.add_argument('--max_grad_norm_prior', default=1, type=float) 31 | parser.add_argument('--ebm_reg', default=0.0, type=float) 32 | parser.add_argument('--ref_dist', default='gaussian', type=str, choices=['gaussian', 'uniform']) 33 | parser.add_argument('--ref_sigma', type=float, default=0.5) 34 | parser.add_argument('--init_factor', type=float, default=1.) 35 | parser.add_argument('--noise_factor', type=float, default=0.5) 36 | 37 | # Decoder and MLP options 38 | parser.add_argument('--mlp_hidden_dim', default=200, type=int) 39 | parser.add_argument('--prop_coefficient', default=10., type=float) 40 | parser.add_argument('--latent_dim', default=100, type=int) 41 | parser.add_argument('--dec_word_dim', default=512, type=int) 42 | parser.add_argument('--dec_h_dim', default=1024, type=int) 43 | parser.add_argument('--dec_num_layers', default=1, type=int) 44 | parser.add_argument('--dec_dropout', default=0.2, type=float) 45 | parser.add_argument('--train_n2n', default=1, type=int) 46 | parser.add_argument('--train_kl', default=1, type=int) 47 | 48 | # Optimization options 49 | parser.add_argument('--log_dir', default='../log/') 50 | parser.add_argument('--checkpoint_dir', default='models') 51 | # parser.add_argument('--slurm', default=0, type=int) 52 | parser.add_argument('--warmup', default=0, type=int) 53 | parser.add_argument('--num_epochs', default=25, type=int) 54 | parser.add_argument('--min_epochs', default=15, type=int) 55 | parser.add_argument('--start_epoch', default=0, type=int) 56 | parser.add_argument('--eps', default=1e-5, type=float) 57 | parser.add_argument('--decay', default=0, type=int) 58 | parser.add_argument('--momentum', default=0.5, type=float) 59 | parser.add_argument('--lr', default=0.001, type=float) 60 | parser.add_argument('--prior_lr', default=0.0001, type=float) 61 | parser.add_argument('--max_grad_norm', default=5, type=float) 62 | parser.add_argument('--gpu', default=0, type=int) 63 | parser.add_argument('--seed', default=3435, type=int) 64 | parser.add_argument('--print_every', type=int, default=100) 65 | parser.add_argument('--sample_every', type=int, default=1000) 66 | parser.add_argument('--kl_every', type=int, default=100) 67 | parser.add_argument('--compute_kl', type=int, default=1) 68 | parser.add_argument('--test', type=int, default=0) 69 | return parser.parse_args() 70 | 71 | if __name__ == '__main__': 72 | args = get_args() 73 | print(args) 74 | -------------------------------------------------------------------------------- /multi_design_esr1/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--debug", default=False, type=bool) 6 | parser.add_argument("--tb", default=False, type=bool) 7 | 8 | # Input data 9 | parser.add_argument('--data_dir', default='../data') 10 | parser.add_argument('--test_file', default='../ZINC/test_5.txt') 11 | 
parser.add_argument('--autodock_executable', type=str, default='~/AutoDock-GPU/bin/autodock_gpu_128wi') 12 | parser.add_argument('--protein_file', type=str, default='../1err/1err.maps.fld') 13 | parser.add_argument('--train_from', default='') 14 | parser.add_argument('--max_len', default=110, type=int) 15 | parser.add_argument('--batch_size', default=1024 * 1, type=int) 16 | parser.add_argument('--eval_batch_size', default=500 * 1, type=int) 17 | 18 | # General options 19 | parser.add_argument('--z_n_iters', type=int, default=20) 20 | parser.add_argument('--z_step_size', type=float, default=0.5) 21 | parser.add_argument('--z_with_noise', type=int, default=1) 22 | parser.add_argument('--num_z_samples', type=int, default=10) 23 | parser.add_argument('--model', type=str, default='mol_ebm') 24 | parser.add_argument('--mask', type=bool, default=False) 25 | parser.add_argument('--single_design', default=False, type=bool) 26 | parser.add_argument('--multi_design', default=True, type=bool) 27 | 28 | # EBM 29 | parser.add_argument('--prior_hidden_dim', type=int, default=200) 30 | parser.add_argument('--z_prior_with_noise', type=int, default=1) 31 | parser.add_argument('--prior_step_size', type=float, default=0.5) 32 | parser.add_argument('--z_n_iters_prior', type=int, default=20) 33 | parser.add_argument('--max_grad_norm_prior', default=1, type=float) 34 | parser.add_argument('--ebm_reg', default=0.0, type=float) 35 | parser.add_argument('--ref_dist', default='gaussian', type=str, choices=['gaussian', 'uniform']) 36 | parser.add_argument('--ref_sigma', type=float, default=0.5) 37 | parser.add_argument('--init_factor', type=float, default=1.) 38 | parser.add_argument('--noise_factor', type=float, default=0.5) 39 | 40 | # Decoder and MLP options 41 | parser.add_argument('--mlp_hidden_dim', default=50, type=int) 42 | parser.add_argument('--latent_dim', default=100, type=int) 43 | parser.add_argument('--dec_word_dim', default=512, type=int) 44 | parser.add_argument('--dec_h_dim', default=1024, type=int) 45 | parser.add_argument('--dec_num_layers', default=1, type=int) 46 | parser.add_argument('--dec_dropout', default=0.2, type=float) 47 | parser.add_argument('--train_n2n', default=1, type=int) 48 | parser.add_argument('--train_kl', default=1, type=int) 49 | 50 | # prop coefficients 51 | parser.add_argument('--prop_coefficient', default=10., type=float) 52 | parser.add_argument('--ba', default=10., type=float) 53 | parser.add_argument('--sas', default=10., type=float) 54 | parser.add_argument('--qed', default=10., type=float) 55 | 56 | # Optimization options 57 | parser.add_argument('--log_dir', default='../log/') 58 | parser.add_argument('--checkpoint_dir', default='models') 59 | # parser.add_argument('--slurm', default=0, type=int) 60 | parser.add_argument('--warmup', default=0, type=int) 61 | parser.add_argument('--num_epochs', default=30, type=int) 62 | parser.add_argument('--min_epochs', default=15, type=int) 63 | parser.add_argument('--start_epoch', default=0, type=int) 64 | parser.add_argument('--eps', default=1e-5, type=float) 65 | parser.add_argument('--decay', default=0, type=int) 66 | parser.add_argument('--momentum', default=0.5, type=float) 67 | parser.add_argument('--lr', default=0.001, type=float) 68 | parser.add_argument('--prior_lr', default=0.0001, type=float) 69 | parser.add_argument('--max_grad_norm', default=5, type=float) 70 | parser.add_argument('--gpu', default=1, type=int) 71 | parser.add_argument('--seed', default=3435, type=int) 72 | parser.add_argument('--print_every', 
type=int, default=100) 73 | parser.add_argument('--sample_every', type=int, default=1000) 74 | parser.add_argument('--kl_every', type=int, default=100) 75 | parser.add_argument('--compute_kl', type=int, default=1) 76 | parser.add_argument('--test', type=int, default=0) 77 | return parser.parse_args() 78 | 79 | if __name__ == '__main__': 80 | args = get_args() 81 | print(args) 82 | -------------------------------------------------------------------------------- /single_design_acaa1/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--debug", default=False, type=bool) 6 | parser.add_argument("--tb", default=False, type=bool) 7 | 8 | # Input data 9 | parser.add_argument('--data_dir', default='../data/ba1') 10 | parser.add_argument('--test_file', default='../ZINC/test_5.txt') 11 | parser.add_argument('--autodock_executable', type=str, default='~/AutoDock-GPU/bin/autodock_gpu_128wi') 12 | parser.add_argument('--protein_file', type=str, default='../2iik/2iik.maps.fld') 13 | parser.add_argument('--train_from', default='') 14 | parser.add_argument('--max_len', default=110, type=int) 15 | parser.add_argument('--batch_size', default=1024 * 1, type=int) 16 | parser.add_argument('--eval_batch_size', default=500 * 1, type=int) 17 | 18 | # General options 19 | parser.add_argument('--z_n_iters', type=int, default=20) 20 | parser.add_argument('--z_step_size', type=float, default=0.5) 21 | parser.add_argument('--z_with_noise', type=int, default=1) 22 | parser.add_argument('--num_z_samples', type=int, default=10) 23 | parser.add_argument('--model', type=str, default='mol_ebm') 24 | parser.add_argument('--mask', type=bool, default=False) 25 | parser.add_argument('--single_design', default=True, type=bool) 26 | parser.add_argument('--multi_design', default=False, type=bool) 27 | 28 | # EBM 29 | parser.add_argument('--prior_hidden_dim', type=int, default=200) 30 | parser.add_argument('--z_prior_with_noise', type=int, default=1) 31 | parser.add_argument('--prior_step_size', type=float, default=0.5) 32 | parser.add_argument('--z_n_iters_prior', type=int, default=20) 33 | parser.add_argument('--max_grad_norm_prior', default=1, type=float) 34 | parser.add_argument('--ebm_reg', default=0.0, type=float) 35 | parser.add_argument('--ref_dist', default='gaussian', type=str, choices=['gaussian', 'uniform']) 36 | parser.add_argument('--ref_sigma', type=float, default=0.5) 37 | parser.add_argument('--init_factor', type=float, default=1.) 
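# init_factor (above) scales the initial latent draw and noise_factor (below) scales the Gaussian noise injected at each Langevin step; this is our reading, as the sampler itself lives in the model file, which is not part of this listing.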
38 | parser.add_argument('--noise_factor', type=float, default=0.5) 39 | 40 | # Decoder and MLP options 41 | parser.add_argument('--mlp_hidden_dim', default=100, type=int) 42 | parser.add_argument('--latent_dim', default=100, type=int) 43 | parser.add_argument('--dec_word_dim', default=512, type=int) 44 | parser.add_argument('--dec_h_dim', default=1024, type=int) 45 | parser.add_argument('--dec_num_layers', default=1, type=int) 46 | parser.add_argument('--dec_dropout', default=0.2, type=float) 47 | parser.add_argument('--train_n2n', default=1, type=int) 48 | parser.add_argument('--train_kl', default=1, type=int) 49 | 50 | # prop coefficients 51 | parser.add_argument('--prop_coefficient', default=10., type=float) 52 | parser.add_argument('--ba', default=10., type=float) 53 | parser.add_argument('--sas', default=10., type=float) 54 | parser.add_argument('--qed', default=10., type=float) 55 | 56 | # Optimization options 57 | parser.add_argument('--log_dir', default='../log/') 58 | parser.add_argument('--checkpoint_dir', default='models') 59 | # parser.add_argument('--slurm', default=0, type=int) 60 | parser.add_argument('--warmup', default=0, type=int) 61 | parser.add_argument('--num_epochs', default=30, type=int) 62 | parser.add_argument('--min_epochs', default=15, type=int) 63 | parser.add_argument('--start_epoch', default=0, type=int) 64 | parser.add_argument('--eps', default=1e-5, type=float) 65 | parser.add_argument('--decay', default=0, type=int) 66 | parser.add_argument('--momentum', default=0.5, type=float) 67 | parser.add_argument('--lr', default=0.001, type=float) 68 | parser.add_argument('--prior_lr', default=0.0001, type=float) 69 | parser.add_argument('--max_grad_norm', default=5, type=float) 70 | parser.add_argument('--gpu', default=1, type=int) 71 | parser.add_argument('--seed', default=3435, type=int) 72 | parser.add_argument('--print_every', type=int, default=100) 73 | parser.add_argument('--sample_every', type=int, default=1000) 74 | parser.add_argument('--kl_every', type=int, default=100) 75 | parser.add_argument('--compute_kl', type=int, default=1) 76 | parser.add_argument('--test', type=int, default=0) 77 | return parser.parse_args() 78 | 79 | if __name__ == '__main__': 80 | args = get_args() 81 | print(args) 82 | -------------------------------------------------------------------------------- /single_design_esr1/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--debug", default=False, type=bool) 6 | parser.add_argument("--tb", default=False, type=bool) 7 | 8 | # Input data 9 | parser.add_argument('--data_dir', default='../data') 10 | parser.add_argument('--test_file', default='../ZINC/test_5.txt') 11 | parser.add_argument('--autodock_executable', type=str, default='../../AutoDock-GPU/bin/autodock_gpu_128wi') 12 | parser.add_argument('--protein_file', type=str, default='../1err/1err.maps.fld') 13 | parser.add_argument('--train_from', default='') 14 | parser.add_argument('--max_len', default=110, type=int) 15 | parser.add_argument('--batch_size', default=1024 * 1, type=int) 16 | parser.add_argument('--eval_batch_size', default=500 * 1, type=int) 17 | 18 | # General options 19 | parser.add_argument('--z_n_iters', type=int, default=20) 20 | parser.add_argument('--z_step_size', type=float, default=0.5) 21 | parser.add_argument('--z_with_noise', type=int, default=1) 22 | parser.add_argument('--num_z_samples', type=int, default=10) 
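# Caveat on the bool-typed flags below: argparse's type=bool converts any non-empty string (including 'False') to True, so --mask, --single_design and --multi_design behave as intended only through their defaults, not via command-line values.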
23 | parser.add_argument('--model', type=str, default='mol_ebm') 24 | parser.add_argument('--mask', type=bool, default=False) 25 | parser.add_argument('--single_design', default=True, type=bool) 26 | parser.add_argument('--multi_design', default=False, type=bool) 27 | 28 | # EBM 29 | parser.add_argument('--prior_hidden_dim', type=int, default=200) 30 | parser.add_argument('--z_prior_with_noise', type=int, default=1) 31 | parser.add_argument('--prior_step_size', type=float, default=0.5) 32 | parser.add_argument('--z_n_iters_prior', type=int, default=20) 33 | parser.add_argument('--max_grad_norm_prior', default=1, type=float) 34 | parser.add_argument('--ebm_reg', default=0.0, type=float) 35 | parser.add_argument('--ref_dist', default='gaussian', type=str, choices=['gaussian', 'uniform']) 36 | parser.add_argument('--ref_sigma', type=float, default=0.5) 37 | parser.add_argument('--init_factor', type=float, default=1.) 38 | parser.add_argument('--noise_factor', type=float, default=0.5) 39 | 40 | # Decoder and MLP options 41 | parser.add_argument('--mlp_hidden_dim', default=50, type=int) 42 | parser.add_argument('--latent_dim', default=100, type=int) 43 | parser.add_argument('--dec_word_dim', default=512, type=int) 44 | parser.add_argument('--dec_h_dim', default=1024, type=int) 45 | parser.add_argument('--dec_num_layers', default=1, type=int) 46 | parser.add_argument('--dec_dropout', default=0.2, type=float) 47 | parser.add_argument('--train_n2n', default=1, type=int) 48 | parser.add_argument('--train_kl', default=1, type=int) 49 | 50 | # prop coefficients 51 | parser.add_argument('--prop_coefficient', default=10., type=float) 52 | parser.add_argument('--ba', default=10., type=float) 53 | parser.add_argument('--sas', default=10., type=float) 54 | parser.add_argument('--qed', default=10., type=float) 55 | 56 | # Optimization options 57 | parser.add_argument('--log_dir', default='../log/') 58 | parser.add_argument('--checkpoint_dir', default='models') 59 | # parser.add_argument('--slurm', default=0, type=int) 60 | parser.add_argument('--warmup', default=0, type=int) 61 | parser.add_argument('--num_epochs', default=30, type=int) 62 | parser.add_argument('--min_epochs', default=15, type=int) 63 | parser.add_argument('--start_epoch', default=0, type=int) 64 | parser.add_argument('--eps', default=1e-5, type=float) 65 | parser.add_argument('--decay', default=0, type=int) 66 | parser.add_argument('--momentum', default=0.5, type=float) 67 | parser.add_argument('--lr', default=0.001, type=float) 68 | parser.add_argument('--prior_lr', default=0.0001, type=float) 69 | parser.add_argument('--max_grad_norm', default=5, type=float) 70 | parser.add_argument('--gpu', default=1, type=int) 71 | parser.add_argument('--seed', default=3435, type=int) 72 | parser.add_argument('--print_every', type=int, default=100) 73 | parser.add_argument('--sample_every', type=int, default=1000) 74 | parser.add_argument('--kl_every', type=int, default=100) 75 | parser.add_argument('--compute_kl', type=int, default=1) 76 | parser.add_argument('--test', type=int, default=0) 77 | return parser.parse_args() 78 | 79 | if __name__ == '__main__': 80 | args = get_args() 81 | print(args) 82 | -------------------------------------------------------------------------------- /multi_design_acaa1/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--debug", default=False, type=bool) 6 | 
parser.add_argument("--tb", default=False, type=bool) 7 | 8 | # Input data 9 | parser.add_argument('--data_dir', default='../data/ba1') 10 | parser.add_argument('--test_file', default='../ZINC/test_5.txt') 11 | parser.add_argument('--autodock_executable', type=str, default='~/AutoDock-GPU/bin/autodock_gpu_128wi') 12 | # parser.add_argument('--protein_file', type=str, default='../1err/1err.maps.fld') 13 | parser.add_argument('--protein_file', type=str, default='../2iik/2iik.maps.fld') 14 | parser.add_argument('--train_from', default='') 15 | parser.add_argument('--max_len', default=110, type=int) 16 | parser.add_argument('--batch_size', default=1024 * 1, type=int) 17 | parser.add_argument('--eval_batch_size', default=500 * 1, type=int) 18 | 19 | # General options 20 | parser.add_argument('--z_n_iters', type=int, default=20) 21 | parser.add_argument('--z_step_size', type=float, default=0.5) 22 | parser.add_argument('--z_with_noise', type=int, default=1) 23 | parser.add_argument('--num_z_samples', type=int, default=10) 24 | parser.add_argument('--model', type=str, default='mol_ebm') 25 | parser.add_argument('--mask', type=bool, default=False) 26 | parser.add_argument('--single_design', default=False, type=bool) 27 | parser.add_argument('--multi_design', default=True, type=bool) 28 | 29 | # EBM 30 | parser.add_argument('--prior_hidden_dim', type=int, default=200) 31 | parser.add_argument('--z_prior_with_noise', type=int, default=1) 32 | parser.add_argument('--prior_step_size', type=float, default=0.5) 33 | parser.add_argument('--z_n_iters_prior', type=int, default=20) 34 | parser.add_argument('--max_grad_norm_prior', default=1, type=float) 35 | parser.add_argument('--ebm_reg', default=0.0, type=float) 36 | parser.add_argument('--ref_dist', default='gaussian', type=str, choices=['gaussian', 'uniform']) 37 | parser.add_argument('--ref_sigma', type=float, default=0.5) 38 | parser.add_argument('--init_factor', type=float, default=1.) 
39 | parser.add_argument('--noise_factor', type=float, default=0.5) 40 | 41 | # Decoder and MLP options 42 | parser.add_argument('--mlp_hidden_dim', default=50, type=int) 43 | parser.add_argument('--latent_dim', default=100, type=int) 44 | parser.add_argument('--dec_word_dim', default=512, type=int) 45 | parser.add_argument('--dec_h_dim', default=1024, type=int) 46 | parser.add_argument('--dec_num_layers', default=1, type=int) 47 | parser.add_argument('--dec_dropout', default=0.2, type=float) 48 | parser.add_argument('--train_n2n', default=1, type=int) 49 | parser.add_argument('--train_kl', default=1, type=int) 50 | 51 | # prop coefficients 52 | parser.add_argument('--prop_coefficient', default=10., type=float) 53 | parser.add_argument('--ba', default=10., type=float) 54 | parser.add_argument('--sas', default=10., type=float) 55 | parser.add_argument('--qed', default=10., type=float) 56 | 57 | # Optimization options 58 | parser.add_argument('--log_dir', default='../log/') 59 | parser.add_argument('--checkpoint_dir', default='models') 60 | # parser.add_argument('--slurm', default=0, type=int) 61 | parser.add_argument('--warmup', default=0, type=int) 62 | parser.add_argument('--num_epochs', default=30, type=int) 63 | parser.add_argument('--min_epochs', default=15, type=int) 64 | parser.add_argument('--start_epoch', default=0, type=int) 65 | parser.add_argument('--eps', default=1e-5, type=float) 66 | parser.add_argument('--decay', default=0, type=int) 67 | parser.add_argument('--momentum', default=0.5, type=float) 68 | parser.add_argument('--lr', default=0.001, type=float) 69 | parser.add_argument('--prior_lr', default=0.0001, type=float) 70 | parser.add_argument('--max_grad_norm', default=5, type=float) 71 | parser.add_argument('--gpu', default=1, type=int) 72 | parser.add_argument('--seed', default=3435, type=int) 73 | parser.add_argument('--print_every', type=int, default=100) 74 | parser.add_argument('--sample_every', type=int, default=1000) 75 | parser.add_argument('--kl_every', type=int, default=100) 76 | parser.add_argument('--compute_kl', type=int, default=1) 77 | parser.add_argument('--test', type=int, default=0) 78 | return parser.parse_args() 79 | 80 | if __name__ == '__main__': 81 | args = get_args() 82 | print(args) 83 | -------------------------------------------------------------------------------- /single_design_acaa1/plot.py: -------------------------------------------------------------------------------- 1 | from matplotlib import rc 2 | # getting necessary libraries 3 | import numpy as np 4 | import pandas as pd 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | 8 | sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) 9 | 10 | n = 50 11 | m = 2000 12 | data = [] 13 | mean = np.zeros(n) 14 | for i in range(50): 15 | data.append(np.load(str(i) + ".npy")) 16 | data = np.vstack(data)[0:n + 1, 0:m] 17 | 18 | for j in range(n): 19 | mean[j] = np.mean(data[j, :]) 20 | 21 | print(data.shape) 22 | # print(data) 23 | 24 | epochInd = [i for i in range(50)] 25 | toDel = [i for i in range(0, 50, 2)] 26 | print(len(toDel)) 27 | dataList = [] 28 | for i in range(n): 29 | if i not in toDel: 30 | tmp = np.zeros((int(m), 3)) 31 | tmp[:, 1] = np.ones(int(m)) * epochInd[i] 32 | tmp[:, 2] = np.ones(int(m)) * mean[i] 33 | for j in range(int(m)): 34 | tmp[j, 0] = data[i, j] 35 | dataList.append(tmp) 36 | 37 | data = np.vstack(dataList) 38 | vals = data[:, 0].tolist() 39 | labels = data[:, 1].astype('int').astype('str').tolist() 40 | means = data[:, 2].tolist() 41 | # 
print(means) 42 | # print(labels) 43 | 44 | 45 | # Col = ["Vals", "Epoch"] 46 | # df = pd.DataFrame(data, columns = Col) 47 | df = pd.DataFrame(data={'Epoch': labels, 'Vals': vals, 'Epoch_Mean': means}) 48 | print(df) 49 | # 50 | # 51 | # sns.kdeplot(data=df, x="Vals", hue='Epoch') 52 | # plt.savefig("figure.pdf", dpi=300, bbox_inches='tight') 53 | 54 | 55 | # # 56 | # # # getting the data 57 | # temp = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/2016-weather-data-seattle.csv') # we retrieve the data from plotly's GitHub repository 58 | # temp['month'] = pd.to_datetime(temp['Date']).dt.month # we store the month in a separate column 59 | # 60 | # # we define a dictionary with months that we'll use later 61 | # month_dict = {1: 'january', 62 | # 2: 'february', 63 | # 3: 'march', 64 | # 4: 'april', 65 | # 5: 'may', 66 | # 6: 'june', 67 | # 7: 'july', 68 | # 8: 'august', 69 | # 9: 'september', 70 | # 10: 'october', 71 | # 11: 'november', 72 | # 12: 'december'} 73 | # 74 | # # we create a 'month' column 75 | # temp['month'] = temp['month'].map(month_dict) 76 | # # # 77 | # print(temp) 78 | # # we generate a pd.Series with the mean temperature for each month (used later for colors in the FacetGrid plot), and we create a new column in the temp dataframe 79 | # month_mean_serie = temp.groupby('month')['Mean_TemperatureC'].mean() 80 | # temp['mean_month'] = temp['month'].map(month_mean_serie) 81 | # 82 | # 83 | # we generate a color palette with Seaborn.color_palette() 84 | pal = sns.color_palette(palette='coolwarm', n_colors=20) 85 | 86 | # # in the sns.FacetGrid class, the 'hue' argument is the one that will be represented by colors with 'palette' 87 | g = sns.FacetGrid(df, row='Epoch', hue='Epoch_Mean', aspect=15, height=0.80, palette=pal) 88 | # 89 | # # then we add the density kdeplots for each epoch 90 | g.map(sns.kdeplot, 'Vals', 91 | bw_adjust=1, clip_on=False, warn_singular=False, 92 | fill=True, alpha=1, linewidth=1.5) 93 | 94 | # # # # here we add a white line that represents the contour of each kdeplot 95 | # g.map(sns.kdeplot, 'Vals', 96 | # bw_adjust=1, clip_on=False, warn_singular=False, 97 | # color="w", lw=2) 98 | # # 99 | # # here we add a horizontal line for each plot 100 | g.map(plt.axhline, y=0, 101 | lw=2, clip_on=False) 102 | # # 103 | # we loop over the FacetGrid figure axes (g.axes.flat) and add the epoch index as text with the right color 104 | # notice how ax.lines[-1].get_color() enables you to access the last line's color in each matplotlib.Axes 105 | for i, ax in enumerate(g.axes.flat): 106 | ax.text(33, 0.03, i, 107 | fontweight='bold', fontsize=15, 108 | color=ax.lines[-1].get_color()) 109 | 110 | # we use matplotlib.Figure.subplots_adjust() function to get the subplots to overlap 111 | g.fig.subplots_adjust(hspace=-0.88) 112 | 113 | # finally we remove axes titles, yticks and spines 114 | g.set_titles("") 115 | g.set(yticks=[]) 116 | g.despine(bottom=True, left=True) 117 | 118 | plt.setp(ax.get_xticklabels(), fontsize=15, fontweight='bold') 119 | plt.xlabel("ACAA1", fontweight='bold', fontsize=15) 120 | g.fig.suptitle('tmp', 121 | ha='right', 122 | fontsize=20, 123 | fontweight=20) 124 | # plt.xlim(xmin=0) 125 | # plt.show() 126 | plt.savefig("figure.pdf", dpi=300, bbox_inches='tight') 127 | -------------------------------------------------------------------------------- /single_design_esr1/plot.py: -------------------------------------------------------------------------------- 1 | from matplotlib import rc 2 | # getting
necessary libraries 3 | import numpy as np 4 | import pandas as pd 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | 8 | sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) 9 | 10 | n = 50 11 | m = 2000 12 | data = [] 13 | mean = np.zeros(n) 14 | for i in range(50): 15 | data.append(np.load(str(i) + ".npy")) 16 | data = np.vstack(data)[0:n + 1, 0:m] 17 | 18 | for j in range(n): 19 | mean[j] = np.mean(data[j, :]) 20 | 21 | print(data.shape) 22 | # print(data) 23 | 24 | epochInd = [i for i in range(50)] 25 | toDel = [i for i in range(0, 50, 2)] 26 | print(len(toDel)) 27 | dataList = [] 28 | for i in range(n): 29 | if i not in toDel: 30 | tmp = np.zeros((int(m), 3)) 31 | tmp[:, 1] = np.ones(int(m)) * epochInd[i] 32 | tmp[:, 2] = np.ones(int(m)) * mean[i] 33 | for j in range(int(m)): 34 | tmp[j, 0] = data[i, j] 35 | dataList.append(tmp) 36 | 37 | data = np.vstack(dataList) 38 | vals = data[:, 0].tolist() 39 | labels = data[:, 1].astype('int').astype('str').tolist() 40 | means = data[:, 2].tolist() 41 | # print(means) 42 | # print(labels) 43 | 44 | 45 | # Col = ["Vals", "Epoch"] 46 | # df = pd.DataFrame(data, columns = Col) 47 | df = pd.DataFrame(data={'Epoch': labels, 'Vals': vals, 'Epoch_Mean': means}) 48 | print(df) 49 | # 50 | # 51 | # sns.kdeplot(data=df, x="Vals", hue='Epoch') 52 | # plt.savefig("figure.pdf", dpi=300, bbox_inches='tight') 53 | 54 | 55 | # # 56 | # # # getting the data 57 | # temp = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/2016-weather-data-seattle.csv') # we retrieve the data from plotly's GitHub repository 58 | # temp['month'] = pd.to_datetime(temp['Date']).dt.month # we store the month in a separate column 59 | # 60 | # # we define a dictionary with months that we'll use later 61 | # month_dict = {1: 'january', 62 | # 2: 'february', 63 | # 3: 'march', 64 | # 4: 'april', 65 | # 5: 'may', 66 | # 6: 'june', 67 | # 7: 'july', 68 | # 8: 'august', 69 | # 9: 'september', 70 | # 10: 'october', 71 | # 11: 'november', 72 | # 12: 'december'} 73 | # 74 | # # we create a 'month' column 75 | # temp['month'] = temp['month'].map(month_dict) 76 | # # # 77 | # print(temp) 78 | # # we generate a pd.Series with the mean temperature for each month (used later for colors in the FacetGrid plot), and we create a new column in the temp dataframe 79 | # month_mean_serie = temp.groupby('month')['Mean_TemperatureC'].mean() 80 | # temp['mean_month'] = temp['month'].map(month_mean_serie) 81 | # 82 | # 83 | # we generate a color palette with Seaborn.color_palette() 84 | pal = sns.color_palette(palette='coolwarm', n_colors=20) 85 | 86 | # # in the sns.FacetGrid class, the 'hue' argument is the one that will be represented by colors with 'palette' 87 | g = sns.FacetGrid(df, row='Epoch', hue='Epoch_Mean', aspect=15, height=0.80, palette=pal) 88 | # 89 | # # then we add the density kdeplots for each epoch 90 | g.map(sns.kdeplot, 'Vals', 91 | bw_adjust=1, clip_on=False, warn_singular=False, 92 | fill=True, alpha=1, linewidth=1.5) 93 | 94 | # # # # here we add a white line that represents the contour of each kdeplot 95 | # g.map(sns.kdeplot, 'Vals', 96 | # bw_adjust=1, clip_on=False, warn_singular=False, 97 | # color="w", lw=2) 98 | # # 99 | # # here we add a horizontal line for each plot 100 | g.map(plt.axhline, y=0, 101 | lw=2, clip_on=False) 102 | # # 103 | # we loop over the FacetGrid figure axes (g.axes.flat) and add the epoch index as text with the right color 104 | # notice how ax.lines[-1].get_color() enables you to access the last
line's color in each matplotlib.Axes 105 | for i, ax in enumerate(g.axes.flat): 106 | ax.text(33, 0.03, i, 107 | fontweight='bold', fontsize=15, 108 | color=ax.lines[-1].get_color()) 109 | 110 | # we use matplotlib.Figure.subplots_adjust() function to get the subplots to overlap 111 | g.fig.subplots_adjust(hspace=-0.88) 112 | 113 | # finally we remove axes titles, yticks and spines 114 | g.set_titles("") 115 | g.set(yticks=[]) 116 | g.despine(bottom=True, left=True) 117 | 118 | plt.setp(ax.get_xticklabels(), fontsize=15, fontweight='bold') 119 | plt.xlabel("ESR1", fontweight='bold', fontsize=15) 120 | g.fig.suptitle('tmp', 121 | ha='right', 122 | fontsize=20, 123 | fontweight=20) 124 | # plt.xlim(xmin=0) 125 | # plt.show() 126 | plt.savefig("figure.pdf", dpi=300, bbox_inches='tight') 127 | -------------------------------------------------------------------------------- /single_design_plogp/plot.py: -------------------------------------------------------------------------------- 1 | from matplotlib import rc 2 | # getting necessary libraries 3 | import numpy as np 4 | import pandas as pd 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | 8 | sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) 9 | 10 | n = 50 11 | m = 10000 12 | data = [] 13 | mean = np.zeros(n) 14 | for i in range(50): 15 | data.append(np.load('two_step/' + str(i) + ".npy")) 16 | data = np.vstack(data)[0:n + 1, 0:m] 17 | 18 | for j in range(n): 19 | mean[j] = np.mean(data[j, :]) 20 | 21 | print(data.shape) 22 | # print(data) 23 | 24 | epochInd = [i for i in range(50)] 25 | toDel = [i for i in range(15, 50)] 26 | dataList = [] 27 | for i in range(n): 28 | if i not in toDel: 29 | tmp = np.zeros((int(m), 3)) 30 | tmp[:, 1] = np.ones(int(m)) * epochInd[i] 31 | tmp[:, 2] = np.ones(int(m)) * mean[i] 32 | for j in range(int(m)): 33 | tmp[j, 0] = data[i, j] 34 | dataList.append(tmp) 35 | 36 | data = np.vstack(dataList) 37 | vals = data[:, 0].tolist() 38 | labels = data[:, 1].astype('int').astype('str').tolist() 39 | means = data[:, 2].tolist() 40 | # print(means) 41 | # print(labels) 42 | 43 | 44 | # Col = ["Vals", "Epoch"] 45 | # df = pd.DataFrame(data, columns = Col) 46 | df = pd.DataFrame(data={'Epoch': labels, 'Vals': vals, 'Epoch_Mean': means}) 47 | print(df) 48 | # 49 | # 50 | # sns.kdeplot(data=df, x="Vals", hue='Epoch') 51 | # plt.savefig("figure.pdf", dpi=300, bbox_inches='tight') 52 | 53 | 54 | # # 55 | # # # getting the data 56 | # temp = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/2016-weather-data-seattle.csv') # we retrieve the data from plotly's GitHub repository 57 | # temp['month'] = pd.to_datetime(temp['Date']).dt.month # we store the month in a separate column 58 | # 59 | # # we define a dictionary with months that we'll use later 60 | # month_dict = {1: 'january', 61 | # 2: 'february', 62 | # 3: 'march', 63 | # 4: 'april', 64 | # 5: 'may', 65 | # 6: 'june', 66 | # 7: 'july', 67 | # 8: 'august', 68 | # 9: 'september', 69 | # 10: 'october', 70 | # 11: 'november', 71 | # 12: 'december'} 72 | # 73 | # # we create a 'month' column 74 | # temp['month'] = temp['month'].map(month_dict) 75 | # # # 76 | # print(temp) 77 | # # we generate a pd.Series with the mean temperature for each month (used later for colors in the FacetGrid plot), and we create a new column in the temp dataframe 78 | # month_mean_serie = temp.groupby('month')['Mean_TemperatureC'].mean() 79 | # temp['mean_month'] = temp['month'].map(month_mean_serie) 80 | # 81 | # 82 | # we generate a color
palette with Seaborn.color_palette() 83 | pal = sns.color_palette(palette='coolwarm', n_colors=20) 84 | 85 | # # in the sns.FacetGrid class, the 'hue' argument is the one that is the one that will be represented by colors with 'palette' 86 | g = sns.FacetGrid(df, row='Epoch', hue='Epoch_Mean', aspect=15, height=0.80, palette=pal) 87 | # 88 | # # then we add the densities kdeplots for each month 89 | g.map(sns.kdeplot, 'Vals', 90 | bw_adjust=1, clip_on=False, warn_singular=False, 91 | fill=True, alpha=1, linewidth=1.5) 92 | 93 | # # # # here we add a white line that represents the contour of each kdeplot 94 | # g.map(sns.kdeplot, 'Vals', 95 | # bw_adjust=1, clip_on=False, warn_singular=False, 96 | # color="w", lw=2) 97 | # # 98 | # # here we add a horizontal line for each plot 99 | g.map(plt.axhline, y=0, 100 | lw=2, clip_on=False) 101 | # # 102 | # we loop over the FacetGrid figure axes (g.axes.flat) and add the month as text with the right color 103 | # notice how ax.lines[-1].get_color() enables you to access the last line's color in each matplotlib.Axes 104 | for i, ax in enumerate(g.axes.flat): 105 | ax.text(33, 0.03, i, 106 | fontweight='bold', fontsize=15, 107 | color=ax.lines[-1].get_color()) 108 | 109 | # we use matplotlib.Figure.subplots_adjust() function to get the subplots to overlap 110 | g.fig.subplots_adjust(hspace=-0.88) 111 | 112 | # eventually we remove axes titles, yticks and spines 113 | g.set_titles("") 114 | g.set(yticks=[]) 115 | g.despine(bottom=True, left=True) 116 | 117 | plt.setp(ax.get_xticklabels(), fontsize=15, fontweight='bold') 118 | plt.xlabel("Penalized " + "logP", fontweight='bold', fontsize=15) 119 | g.fig.suptitle('t', 120 | ha='right', 121 | fontsize=20, 122 | fontweight=20) 123 | # plt.xlim(xmin=0) 124 | # plt.show() 125 | plt.savefig("plogp.pdf", dpi=300, bbox_inches='tight') 126 | -------------------------------------------------------------------------------- /multi_design_esr1/plot.py: -------------------------------------------------------------------------------- 1 | from matplotlib import rc 2 | # getting necessary libraries 3 | import numpy as np 4 | import pandas as pd 5 | import seaborn as sns 6 | import matplotlib.pyplot as plt 7 | 8 | sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) 9 | 10 | n = 25 11 | m = 2000 12 | data = [] 13 | mean = np.zeros(n) 14 | data = np.zeros([25, 10000]) 15 | 16 | for i in range(0, 25): 17 | x = np.load(str(i) + "_ba.npy") 18 | data[i, :len(x)] = x 19 | # data = np.vstack(data)[0:n + 1, 0:m] 20 | 21 | 22 | for j in range(n): 23 | mean[j] = np.mean(data[j, :]) 24 | 25 | print(data.shape) 26 | # print(data) 27 | 28 | epochInd = [i for i in range(25)] 29 | toDel = [i for i in range(0, 0)] 30 | print(len(toDel)) 31 | dataList = [] 32 | for i in range(n): 33 | if i not in toDel: 34 | tmp = np.zeros((int(m), 3)) 35 | tmp[:, 1] = np.ones(int(m)) * epochInd[i] 36 | tmp[:, 2] = np.ones(int(m)) * mean[i] 37 | for j in range(int(m)): 38 | tmp[j, 0] = data[i, j] 39 | dataList.append(tmp) 40 | 41 | data = np.vstack(dataList) 42 | vals = data[:, 0].tolist() 43 | labels = data[:, 1].astype('int').astype('str').tolist() 44 | means = data[:, 2].tolist() 45 | # print(means) 46 | # print(labels) 47 | 48 | 49 | # Col = ["Vals", "Epoch"] 50 | # df = pd.DataFrame(data, columns = Col) 51 | df = pd.DataFrame(data={'Epoch': labels, 'Vals': vals, 'Epoch_Mean': means}) 52 | print(df) 53 | # 54 | # 55 | # sns.kdeplot(data=df, x="Vals", hue='Epoch') 56 | # plt.savefig("figure.pdf", dpi=300, bbox_inches='tight') 57 | 58 | 
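# The ridgeline figure built below follows the commented-out Seattle-weather
# FacetGrid/kdeplot tutorial kept underneath for reference: one FacetGrid row
# per epoch, with 'Epoch_Mean' as the hue variable so the coolwarm palette
# encodes each epoch's mean score, and a negative hspace so the rows overlap.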
59 | # # 60 | # # # getting the data 61 | # temp = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/2016-weather-data-seattle.csv') # we retrieve the data from plotly's GitHub repository 62 | # temp['month'] = pd.to_datetime(temp['Date']).dt.month # we store the month in a separate column 63 | # 64 | # # we define a dictionary with months that we'll use later 65 | # month_dict = {1: 'january', 66 | # 2: 'february', 67 | # 3: 'march', 68 | # 4: 'april', 69 | # 5: 'may', 70 | # 6: 'june', 71 | # 7: 'july', 72 | # 8: 'august', 73 | # 9: 'september', 74 | # 10: 'october', 75 | # 11: 'november', 76 | # 12: 'december'} 77 | # 78 | # # we create a 'month' column 79 | # temp['month'] = temp['month'].map(month_dict) 80 | # # # 81 | # print(temp) 82 | # # we generate a pd.Series with the mean temperature for each month (used later for colors in the FacetGrid plot), and we create a new column in the temp dataframe 83 | # month_mean_serie = temp.groupby('month')['Mean_TemperatureC'].mean() 84 | # temp['mean_month'] = temp['month'].map(month_mean_serie) 85 | # 86 | # 87 | # we generate a color palette with seaborn.color_palette() 88 | pal = sns.color_palette(palette='coolwarm', n_colors=20) 89 | 90 | # # in the sns.FacetGrid class, the 'hue' argument is the one that will be represented by colors from 'palette' 91 | g = sns.FacetGrid(df, row='Epoch', hue='Epoch_Mean', aspect=15, height=0.80, palette=pal) 92 | # 93 | # # then we add the density kdeplots for each epoch 94 | g.map(sns.kdeplot, 'Vals', 95 | bw_adjust=1, clip_on=False, warn_singular=False, 96 | fill=True, alpha=1, linewidth=1.5) 97 | 98 | # # # # here we add a white line that represents the contour of each kdeplot 99 | # g.map(sns.kdeplot, 'Vals', 100 | # bw_adjust=1, clip_on=False, warn_singular=False, 101 | # color="w", lw=2) 102 | # # 103 | # # here we add a horizontal line for each plot 104 | g.map(plt.axhline, y=0, 105 | lw=2, clip_on=False) 106 | # # 107 | # we loop over the FacetGrid figure axes (g.axes.flat) and add the epoch index as text with the right color 108 | # notice how ax.lines[-1].get_color() enables you to access the last line's color in each matplotlib.Axes 109 | for i, ax in enumerate(g.axes.flat): 110 | ax.text(33, 0.03, i, 111 | fontweight='bold', fontsize=15, 112 | color=ax.lines[-1].get_color()) 113 | 114 | # we use matplotlib.Figure.subplots_adjust() function to get the subplots to overlap 115 | g.fig.subplots_adjust(hspace=-0.88) 116 | 117 | # eventually we remove axes titles, yticks and spines 118 | g.set_titles("") 119 | g.set(yticks=[]) 120 | g.despine(bottom=True, left=True) 121 | 122 | plt.setp(ax.get_xticklabels(), fontsize=15, fontweight='bold') 123 | plt.xlabel("ESR1", fontweight='bold', fontsize=15) 124 | g.fig.suptitle('tmp', 125 | ha='right', 126 | fontsize=20, 127 | fontweight=20) 128 | # plt.xlim(xmin=0) 129 | # plt.show() 130 | plt.savefig("figure.pdf", dpi=300, bbox_inches='tight') 131 | -------------------------------------------------------------------------------- /multi_design_acaa1/stats.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import seaborn as sns 3 | import pandas as pd 4 | from rdkit.Chem.QED import qed 5 | import sascorer 6 | 7 | 8 | def property_plot(logps, sass, qeds, data_properties): 9 | from matplotlib import pyplot as plt 10 | import seaborn as sns 11 | import pandas as pd 12 | data_logps, data_sass, data_qeds = data_properties 13 | assert len(data_logps) ==
len(data_sass) == len(data_qeds) 14 | assert len(logps) == len(sass) == len(qeds) 15 | assert len(data_logps) == len(logps) 16 | labels = ['Test'] * len(data_logps) + ['Train'] * len(logps) 17 | logp_df = pd.DataFrame(data={'model': labels, 'logP': data_logps + logps}) 18 | logp_df["logP"] = pd.to_numeric(logp_df["logP"]) 19 | sas_df = pd.DataFrame(data={'model': labels, 'SAS': data_sass + sass}) 20 | sas_df["SAS"] = pd.to_numeric(sas_df["SAS"]) 21 | qed_df = pd.DataFrame(data={'model': labels, 'QED': data_qeds + qeds}) 22 | qed_df["QED"] = pd.to_numeric(qed_df["QED"]) 23 | 24 | fig, ax = plt.subplots(nrows=1, ncols=3, figsize=[6.4 * 3, 4.8]) 25 | # fig.suptitle("Epoch: " + str(epoch), fontsize=12) 26 | logp_plot = sns.kdeplot(data=logp_df, x='logP', hue='model', ax=ax[0]) 27 | sas_plot = sns.kdeplot(data=sas_df, x='SAS', hue='model', ax=ax[1]) 28 | qed_plot = sns.kdeplot(data=qed_df, x='QED', hue='model', ax=ax[2]) 29 | 30 | plt.show() 31 | plt.close(fig) 32 | 33 | 34 | def get_data_properties(path, seq_length=110): 35 | with open(path) as f: 36 | data_lines = f.readlines() 37 | 38 | logps = [] 39 | sass = [] 40 | qeds = [] 41 | mols = [] 42 | title = "" 43 | for line in data_lines: 44 | if line[0] == "#": 45 | title = line[1:-1] 46 | title_list = title.split() 47 | print(title_list) 48 | continue 49 | arr = line.split() 50 | # print(arr) 51 | if len(arr) < 2: 52 | continue 53 | smiles = arr[0] 54 | if len(smiles) > seq_length: 55 | continue 56 | assert len(arr) == 6 57 | mols.append(arr[0]) 58 | logps.append(arr[1]) 59 | sass.append(arr[2]) 60 | qeds.append(arr[3]) 61 | # smiles0 = smiles.ljust(seq_length, '>') 62 | # smiles_list += [smiles] 63 | # Narr = len(arr) 64 | # # cdd = [] 65 | # for i in range(1, Narr): 66 | # if title_list[i] == "logP": 67 | # cdd += [float(arr[i])/10.0] 68 | # elif title_list[i] == "SAS": 69 | # cdd += [float(arr[i])/10.0] 70 | # elif title_list[i] == "QED": 71 | # cdd += [float(arr[i])/1.0] 72 | # elif title_list[i] == "MW": 73 | # cdd += [float(arr[i])/500.0] 74 | # elif title_list[i] == "TPSA": 75 | # cdd += [float(arr[i])/150.0] 76 | return mols, (logps, sass, qeds) 77 | 78 | 79 | def _convert_keys_to_int(dict): 80 | new_dict = {} 81 | for k, v in dict.items(): 82 | try: 83 | new_key = int(k) 84 | except ValueError: 85 | new_key = k 86 | new_dict[new_key] = v 87 | return new_dict 88 | 89 | 90 | if __name__ == '__main__': 91 | from args import get_args 92 | import numpy as np 93 | from rdkit import Chem 94 | from rdkit.Chem.Crippen import MolLogP 95 | import json 96 | import selfies as sf 97 | from tqdm import tqdm 98 | import random 99 | 100 | args = get_args() 101 | mols, data_properties = get_data_properties(args.test_file) 102 | # print(mols, len(mols)) 103 | 104 | # smiles 105 | # logps, sass, qeds = [], [], [] 106 | # for i, smi in enumerate(mols): 107 | # m = Chem.MolFromSmiles(smi) 108 | # p = MolLogP(m) 109 | # logps.append(p) 110 | # try: 111 | # SAS = sascorer.calculateScore(m) 112 | # except ZeroDivisionError: 113 | # SAS = 2.8 114 | # sass.append(SAS) 115 | # QED = qed(m) 116 | # qeds.append(QED) 117 | # 118 | # property_plot(logps, sass, qeds, data_properties) 119 | 120 | # selfies 121 | logps, sass, qeds = [], [], [] 122 | dir_test = '../data/Xtrain.npy' 123 | sf_test = np.load(dir_test) 124 | json_file = json.load(open('../data/info.json')) 125 | vocab_itos = json_file['vocab_itos'][0] 126 | # change the keywords of vocab_itos from string to int 127 | vocab_itos = _convert_keys_to_int(vocab_itos) 128 | 129 | randomlist = 
random.sample(range(0, len(sf_test)), 10000) 130 | random_data = sf_test[randomlist] 131 | for data in tqdm(random_data): 132 | m_sf = sf.encoding_to_selfies(data, vocab_itos, enc_type='label') 133 | m_smi = sf.decoder(m_sf) 134 | m = Chem.MolFromSmiles(m_smi) 135 | p = MolLogP(m) 136 | logps.append(p) 137 | try: 138 | SAS = sascorer.calculateScore(m) 139 | except ZeroDivisionError: 140 | SAS = 2.8 141 | sass.append(SAS) 142 | QED = qed(m) 143 | qeds.append(QED) 144 | property_plot(logps, sass, qeds, data_properties) 145 | -------------------------------------------------------------------------------- /multi_design_esr1/stats.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import seaborn as sns 3 | import pandas as pd 4 | from rdkit.Chem.QED import qed 5 | import sascorer 6 | 7 | 8 | def property_plot(logps, sass, qeds, data_properties): 9 | from matplotlib import pyplot as plt 10 | import seaborn as sns 11 | import pandas as pd 12 | data_logps, data_sass, data_qeds = data_properties 13 | assert len(data_logps) == len(data_sass) == len(data_qeds) 14 | assert len(logps) == len(sass) == len(qeds) 15 | assert len(data_logps) == len(logps) 16 | labels = ['Test'] * len(data_logps) + ['Train'] * len(logps) 17 | logp_df = pd.DataFrame(data={'model': labels, 'logP': data_logps + logps}) 18 | logp_df["logP"] = pd.to_numeric(logp_df["logP"]) 19 | sas_df = pd.DataFrame(data={'model': labels, 'SAS': data_sass + sass}) 20 | sas_df["SAS"] = pd.to_numeric(sas_df["SAS"]) 21 | qed_df = pd.DataFrame(data={'model': labels, 'QED': data_qeds + qeds}) 22 | qed_df["QED"] = pd.to_numeric(qed_df["QED"]) 23 | 24 | fig, ax = plt.subplots(nrows=1, ncols=3, figsize=[6.4 * 3, 4.8]) 25 | # fig.suptitle("Epoch: " + str(epoch), fontsize=12) 26 | logp_plot = sns.kdeplot(data=logp_df, x='logP', hue='model', ax=ax[0]) 27 | sas_plot = sns.kdeplot(data=sas_df, x='SAS', hue='model', ax=ax[1]) 28 | qed_plot = sns.kdeplot(data=qed_df, x='QED', hue='model', ax=ax[2]) 29 | 30 | plt.show() 31 | plt.close(fig) 32 | 33 | 34 | def get_data_properties(path, seq_length=110): 35 | with open(path) as f: 36 | data_lines = f.readlines() 37 | 38 | logps = [] 39 | sass = [] 40 | qeds = [] 41 | mols = [] 42 | title = "" 43 | for line in data_lines: 44 | if line[0] == "#": 45 | title = line[1:-1] 46 | title_list = title.split() 47 | print(title_list) 48 | continue 49 | arr = line.split() 50 | # print(arr) 51 | if len(arr) < 2: 52 | continue 53 | smiles = arr[0] 54 | if len(smiles) > seq_length: 55 | continue 56 | assert len(arr) == 6 57 | mols.append(arr[0]) 58 | logps.append(arr[1]) 59 | sass.append(arr[2]) 60 | qeds.append(arr[3]) 61 | # smiles0 = smiles.ljust(seq_length, '>') 62 | # smiles_list += [smiles] 63 | # Narr = len(arr) 64 | # # cdd = [] 65 | # for i in range(1, Narr): 66 | # if title_list[i] == "logP": 67 | # cdd += [float(arr[i])/10.0] 68 | # elif title_list[i] == "SAS": 69 | # cdd += [float(arr[i])/10.0] 70 | # elif title_list[i] == "QED": 71 | # cdd += [float(arr[i])/1.0] 72 | # elif title_list[i] == "MW": 73 | # cdd += [float(arr[i])/500.0] 74 | # elif title_list[i] == "TPSA": 75 | # cdd += [float(arr[i])/150.0] 76 | return mols, (logps, sass, qeds) 77 | 78 | 79 | def _convert_keys_to_int(dict): 80 | new_dict = {} 81 | for k, v in dict.items(): 82 | try: 83 | new_key = int(k) 84 | except ValueError: 85 | new_key = k 86 | new_dict[new_key] = v 87 | return new_dict 88 | 89 | 90 | if __name__ == '__main__': 91 | from args import
get_args 92 | import numpy as np 93 | from rdkit import Chem 94 | from rdkit.Chem.Crippen import MolLogP 95 | import json 96 | import selfies as sf 97 | from tqdm import tqdm 98 | import random 99 | 100 | args = get_args() 101 | mols, data_properties = get_data_properties(args.test_file) 102 | # print(mols, len(mols)) 103 | 104 | # smiles 105 | # logps, sass, qeds = [], [], [] 106 | # for i, smi in enumerate(mols): 107 | # m = Chem.MolFromSmiles(smi) 108 | # p = MolLogP(m) 109 | # logps.append(p) 110 | # try: 111 | # SAS = sascorer.calculateScore(m) 112 | # except ZeroDivisionError: 113 | # SAS = 2.8 114 | # sass.append(SAS) 115 | # QED = qed(m) 116 | # qeds.append(QED) 117 | # 118 | # property_plot(logps, sass, qeds, data_properties) 119 | 120 | # selfies 121 | logps, sass, qeds = [], [], [] 122 | dir_test = '../data/Xtrain.npy' 123 | sf_test = np.load(dir_test) 124 | json_file = json.load(open('../data/info.json')) 125 | vocab_itos = json_file['vocab_itos'][0] 126 | # change the keywords of vocab_itos from string to int 127 | vocab_itos = _convert_keys_to_int(vocab_itos) 128 | 129 | randomlist = random.sample(range(0, len(sf_test)), 10000) 130 | random_data = sf_test[randomlist] 131 | for data in tqdm(random_data): 132 | m_sf = sf.encoding_to_selfies(data, vocab_itos, enc_type='label') 133 | m_smi = sf.decoder(m_sf) 134 | m = Chem.MolFromSmiles(m_smi) 135 | p = MolLogP(m) 136 | logps.append(p) 137 | try: 138 | SAS = sascorer.calculateScore(m) 139 | except ZeroDivisionError: 140 | SAS = 2.8 141 | sass.append(SAS) 142 | QED = qed(m) 143 | qeds.append(QED) 144 | property_plot(logps, sass, qeds, data_properties) 145 | -------------------------------------------------------------------------------- /single_design_qed/stats.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import seaborn as sns 3 | import pandas as pd 4 | from rdkit.Chem.QED import qed 5 | import sascorer 6 | 7 | 8 | def property_plot(logps, sass, qeds, data_properties): 9 | from matplotlib import pyplot as plt 10 | import seaborn as sns 11 | import pandas as pd 12 | data_logps, data_sass, data_qeds = data_properties 13 | assert len(data_logps) == len(data_sass) == len(data_qeds) 14 | assert len(logps) == len(sass) == len(qeds) 15 | assert len(data_logps) == len(logps) 16 | labels = ['Test'] * len(data_logps) + ['Train'] * len(logps) 17 | logp_df = pd.DataFrame(data={'model': labels, 'logP': data_logps + logps}) 18 | logp_df["logP"] = pd.to_numeric(logp_df["logP"]) 19 | sas_df = pd.DataFrame(data={'model': labels, 'SAS': data_sass + sass}) 20 | sas_df["SAS"] = pd.to_numeric(sas_df["SAS"]) 21 | qed_df = pd.DataFrame(data={'model': labels, 'QED': data_qeds + qeds}) 22 | qed_df["QED"] = pd.to_numeric(qed_df["QED"]) 23 | 24 | fig, ax = plt.subplots(nrows=1, ncols=3, figsize=[6.4 * 3, 4.8]) 25 | # fig.suptitle("Epoch: " + str(epoch), fontsize=12) 26 | logp_plot = sns.kdeplot(data=logp_df, x='logP', hue='model', ax=ax[0]) 27 | sas_plot = sns.kdeplot(data=sas_df, x='SAS', hue='model', ax=ax[1]) 28 | qed_plot = sns.kdeplot(data=qed_df, x='QED', hue='model', ax=ax[2]) 29 | 30 | plt.show() 31 | plt.close(fig) 32 | 33 | 34 | def get_data_properties(path, seq_length=110): 35 | with open(path) as f: 36 | data_lines = f.readlines() 37 | 38 | logps = [] 39 | sass = [] 40 | qeds = [] 41 | mols = [] 42 | title = "" 43 | for line in data_lines: 44 | if line[0] == "#": 45 | title = line[1:-1] 46 | title_list = title.split() 47 | print(title_list)
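# the '#'-prefixed header names the six columns (SMILES string followed by
# logP, SAS, QED, MW, TPSA) in the order parsed below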
48 | continue 49 | arr = line.split() 50 | # print(arr) 51 | if len(arr) < 2: 52 | continue 53 | smiles = arr[0] 54 | if len(smiles) > seq_length: 55 | continue 56 | assert len(arr) == 6 57 | mols.append(arr[0]) 58 | logps.append(arr[1]) 59 | sass.append(arr[2]) 60 | qeds.append(arr[3]) 61 | # smiles0 = smiles.ljust(seq_length, '>') 62 | # smiles_list += [smiles] 63 | # Narr = len(arr) 64 | # # cdd = [] 65 | # for i in range(1, Narr): 66 | # if title_list[i] == "logP": 67 | # cdd += [float(arr[i])/10.0] 68 | # elif title_list[i] == "SAS": 69 | # cdd += [float(arr[i])/10.0] 70 | # elif title_list[i] == "QED": 71 | # cdd += [float(arr[i])/1.0] 72 | # elif title_list[i] == "MW": 73 | # cdd += [float(arr[i])/500.0] 74 | # elif title_list[i] == "TPSA": 75 | # cdd += [float(arr[i])/150.0] 76 | return mols, (logps, sass, qeds) 77 | 78 | 79 | def _convert_keys_to_int(dict): 80 | new_dict = {} 81 | for k, v in dict.items(): 82 | try: 83 | new_key = int(k) 84 | except ValueError: 85 | new_key = k 86 | new_dict[new_key] = v 87 | return new_dict 88 | 89 | 90 | if __name__ == '__main__': 91 | from args import get_args 92 | import numpy as np 93 | from rdkit import Chem 94 | from rdkit.Chem.Crippen import MolLogP 95 | import json 96 | import selfies as sf 97 | from tqdm import tqdm 98 | import random 99 | 100 | args = get_args() 101 | mols, data_properties = get_data_properties(args.test_file) 102 | # print(mols, len(mols)) 103 | 104 | # smiles 105 | # logps, sass, qeds = [], [], [] 106 | # for i, smi in enumerate(mols): 107 | # m = Chem.MolFromSmiles(smi) 108 | # p = MolLogP(m) 109 | # logps.append(p) 110 | # try: 111 | # SAS = sascorer.calculateScore(m) 112 | # except ZeroDivisionError: 113 | # SAS = 2.8 114 | # sass.append(SAS) 115 | # QED = qed(m) 116 | # qeds.append(QED) 117 | # 118 | # property_plot(logps, sass, qeds, data_properties) 119 | 120 | # selfies 121 | logps, sass, qeds = [], [], [] 122 | dir_test = '../data/Xtrain.npy' 123 | sf_test = np.load(dir_test) 124 | json_file = json.load(open('../data/info.json')) 125 | vocab_itos = json_file['vocab_itos'][0] 126 | # change the keywords of vocab_itos from string to int 127 | vocab_itos = _convert_keys_to_int(vocab_itos) 128 | 129 | randomlist = random.sample(range(0, len(sf_test)), 10000) 130 | random_data = sf_test[randomlist] 131 | for data in tqdm(random_data): 132 | m_sf = sf.encoding_to_selfies(data, vocab_itos, enc_type='label') 133 | m_smi = sf.decoder(m_sf) 134 | m = Chem.MolFromSmiles(m_smi) 135 | p = MolLogP(m) 136 | logps.append(p) 137 | try: 138 | SAS = sascorer.calculateScore(m) 139 | except ZeroDivisionError: 140 | SAS = 2.8 141 | sass.append(SAS) 142 | QED = qed(m) 143 | qeds.append(QED) 144 | property_plot(logps, sass, qeds, data_properties) 145 | -------------------------------------------------------------------------------- /multi_design_acaa1/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import numpy as np 4 | from ZINC.char import char_list, char_dict 5 | import selfies as sf 6 | import json 7 | from rdkit.Chem import Draw 8 | from rdkit import Chem 9 | from rdkit.Chem.Crippen import MolLogP 10 | from rdkit.Chem.QED import qed 11 | from rdkit.Chem.Descriptors import ExactMolWt 12 | from rdkit.Chem.rdMolDescriptors import CalcTPSA 13 | import sascorer 14 | import math 15 | 16 | 17 | # max_len = 72 18 | # num_of_embeddings = 109 19 | class MolDataset(Dataset): 20 | def __init__(self,
datadir, dname): 21 | Xdata_file = datadir + "/X" + dname + ".npy" 22 | self.Xdata = torch.tensor(np.load(Xdata_file), dtype=torch.long) # number-coded molecule 23 | Ldata_file = datadir + "/L" + dname + ".npy" 24 | self.Ldata = torch.tensor(np.load(Ldata_file), dtype=torch.long) # length of each molecule 25 | self.len = self.Xdata.shape[0] 26 | sasdata_file = datadir + "/sas_" + dname + ".npy" 27 | self.sas = torch.tensor(np.load(sasdata_file), dtype=torch.float32) 28 | qed_data_file = datadir + "/qed_" + dname + ".npy" 29 | self.qed = torch.tensor(np.load(qed_data_file), dtype=torch.float32) 30 | 31 | # ba0data_file = datadir + "/ba0_" + dname + ".npy" 32 | # self.ba0 = torch.tensor(np.load(ba0data_file), dtype=torch.float32) 33 | ba1data_file = datadir + "/ba1_" + dname + ".npy" 34 | self.ba1 = torch.tensor(np.load(ba1data_file), dtype=torch.float32) 35 | # self.preprocess() 36 | 37 | def preprocess(self): 38 | ind = torch.nonzero(self.ba0) 39 | self.ba0 = self.ba0[ind].squeeze() 40 | self.sas = self.sas[ind].squeeze() 41 | self.qed = self.qed[ind].squeeze() 42 | self.Xdata = self.Xdata[ind].squeeze() 43 | self.Ldata = self.Ldata[ind].squeeze() 44 | self.len = self.Xdata.shape[0] 45 | 46 | def __getitem__(self, index): 47 | # Add sos=108 for each sequence and this sos is not shown in char_list and char_dict as in selfies. 48 | mol = self.Xdata[index] 49 | sos = torch.tensor([108], dtype=torch.long) 50 | mol = torch.cat([sos, mol], dim=0).contiguous() 51 | mask = torch.zeros(mol.shape[0] - 1) 52 | mask[:self.Ldata[index] + 1] = 1. 53 | return (mol, mask, -1 * self.ba1[index]) 54 | 55 | def __len__(self): 56 | return self.len 57 | 58 | def save_mol_png(self, label, filepath, size=(600, 600)): 59 | m_smi = self.label2sf2smi(label) 60 | m = Chem.MolFromSmiles(m_smi) 61 | Draw.MolToFile(m, filepath, size=size) 62 | 63 | def label2sf2smi(self, label): 64 | m_sf = sf.encoding_to_selfies(label, char_dict, enc_type='label') 65 | m_smi = sf.decoder(m_sf) 66 | m_smi = Chem.CanonSmiles(m_smi) 67 | return m_smi 68 | 69 | @staticmethod 70 | def delta_to_kd(x): 71 | return math.exp(x / (0.00198720425864083 * 298.15)) 72 | 73 | 74 | if __name__ == '__main__': 75 | datadir = '../data' 76 | # Xdata_file = datadir + "/Xtrain.npy" 77 | # x_train = np.load(Xdata_file) 78 | # x_test = np.load('../data/Xtest.npy') 79 | # x_all = np.concatenate((x_train, x_test), axis=0) 80 | # print(x_all.shape) 81 | # 82 | # ba_train = np.load('../data/ba0_train.npy') 83 | # ba_test = np.load('../data/ba0_test.npy') 84 | # ba_all = np.concatenate((ba_train, ba_test), axis=0) 85 | # print(ba_all.shape) 86 | # ind = np.argsort(ba_all) 87 | # ind = ind[:10000] 88 | # 89 | # sas_train = np.load('../data/sas_train.npy') 90 | # sas_test = np.load('../data/sas_test.npy') 91 | # sas_all = np.concatenate((sas_train, sas_test), axis=0) 92 | # 93 | # l_train = np.load('../data/Ltrain.npy') 94 | # l_test = np.load('../data/Ltest.npy') 95 | # l_all = np.concatenate((l_train, l_test), axis=0) 96 | # 97 | # qed_train = np.load('../data/qed_train.npy') 98 | # qed_test = np.load('../data/qed_test.npy') 99 | # qed_all = np.concatenate((sas_train, sas_test), axis=0) 100 | # 101 | # sas_design = sas_all[ind] 102 | # ba_design = ba_all[ind] 103 | # x_design = x_all[ind] 104 | # qed_design = qed_all[ind] 105 | # l_design = l_all[ind] 106 | # 107 | # np.save('../data/Xdesign.npy', x_design) 108 | # np.save('../data/ba0_design.npy', ba_design) 109 | # np.save('../data/qed_design.npy', qed_design) 110 | # np.save('../data/sas_design.npy', 
sas_design) 111 | # np.save('../data/Ldesign.npy', l_design) 112 | 113 | ds = MolDataset(datadir, 'train') 114 | ds_loader = DataLoader(dataset=ds, batch_size=100, 115 | shuffle=True, drop_last=True, num_workers=2) 116 | 117 | s = ds.ba1 118 | print(len(s)) 119 | print(s[:10]) 120 | # print(ds.Xdata[:10]) 121 | # print(ds.len) 122 | # non_zero = torch.nonzero(s) 123 | # print(non_zero) 124 | # s = s[non_zero] 125 | # print(len(s)) 126 | # print(s[1000:2000]) 127 | # print(torch.max(s), torch.min(s)) 128 | # print(ds.delta_to_kd(torch.max(s)), ds.delta_to_kd(torch.min(s))) 129 | # max_len = 0 130 | # for i in range(len(ds)): 131 | # print(len(ds)) 132 | # m, len, ba, sas, qed = ds[i] 133 | # m = m.numpy() 134 | # mol = ds.label2sf2smi(m[1:]) 135 | # print(mol) 136 | # print(ba, sas, qed) 137 | # break 138 | 139 | # print(char_dict) 140 | # print(ds.sf2smi(m), s) 141 | -------------------------------------------------------------------------------- /multi_design_esr1/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import numpy as np 4 | from ZINC.char import char_list, char_dict 5 | import selfies as sf 6 | import json 7 | from rdkit.Chem import Draw 8 | from rdkit import Chem 9 | from rdkit.Chem.Crippen import MolLogP 10 | from rdkit.Chem.QED import qed 11 | from rdkit.Chem.Descriptors import ExactMolWt 12 | from rdkit.Chem.rdMolDescriptors import CalcTPSA 13 | import sascorer 14 | import math 15 | 16 | 17 | # max_len = 72 18 | # num_of_embeddings = 109 19 | class MolDataset(Dataset): 20 | def __init__(self, datadir, dname): 21 | Xdata_file = datadir + "/X" + dname + ".npy" 22 | self.Xdata = torch.tensor(np.load(Xdata_file), dtype=torch.long) # number-coded molecule 23 | Ldata_file = datadir + "/L" + dname + ".npy" 24 | self.Ldata = torch.tensor(np.load(Ldata_file), dtype=torch.long) # length of each molecule 25 | self.len = self.Xdata.shape[0] 26 | sasdata_file = datadir + "/sas_" + dname + ".npy" 27 | self.sas = torch.tensor(np.load(sasdata_file), dtype=torch.float32) 28 | qed_data_file = datadir + "/qed_" + dname + ".npy" 29 | self.qed = torch.tensor(np.load(qed_data_file), dtype=torch.float32) 30 | 31 | ba0data_file = datadir + "/ba0_" + dname + ".npy" 32 | self.ba0 = torch.tensor(np.load(ba0data_file), dtype=torch.float32) 33 | # ba1data_file = datadir + "/ba1_" + dname + ".npy" 34 | # self.ba1 = torch.tensor(np.load(ba1data_file), dtype=torch.float32) 35 | # self.preprocess() 36 | 37 | def preprocess(self): 38 | ind = torch.nonzero(self.ba0) 39 | self.ba0 = self.ba0[ind].squeeze() 40 | self.sas = self.sas[ind].squeeze() 41 | self.qed = self.qed[ind].squeeze() 42 | self.Xdata = self.Xdata[ind].squeeze() 43 | self.Ldata = self.Ldata[ind].squeeze() 44 | self.len = self.Xdata.shape[0] 45 | 46 | def __getitem__(self, index): 47 | # Add sos=108 for each sequence and this sos is not shown in char_list and char_dict as in selfies. 48 | mol = self.Xdata[index] 49 | sos = torch.tensor([108], dtype=torch.long) 50 | mol = torch.cat([sos, mol], dim=0).contiguous() 51 | mask = torch.zeros(mol.shape[0] - 1) 52 | mask[:self.Ldata[index] + 1] = 1. 
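# mask flags the Ldata+1 valid positions (the tokens plus one terminator);
# padded positions stay zero. The ba0 binding score is negated so that,
# presumably, a larger regression target corresponds to stronger binding.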
53 | return (mol, mask, -1 * self.ba0[index]) 54 | 55 | def __len__(self): 56 | return self.len 57 | 58 | def save_mol_png(self, label, filepath, size=(600, 600)): 59 | m_smi = self.label2sf2smi(label) 60 | m = Chem.MolFromSmiles(m_smi) 61 | Draw.MolToFile(m, filepath, size=size) 62 | 63 | def label2sf2smi(self, label): 64 | m_sf = sf.encoding_to_selfies(label, char_dict, enc_type='label') 65 | m_smi = sf.decoder(m_sf) 66 | m_smi = Chem.CanonSmiles(m_smi) 67 | return m_smi 68 | 69 | @staticmethod 70 | def delta_to_kd(x): 71 | return math.exp(x / (0.00198720425864083 * 298.15))  # Kd = exp(dG / RT); R in kcal/(mol*K), T = 298.15 K 72 | 73 | 74 | if __name__ == '__main__': 75 | datadir = '../data' 76 | # Xdata_file = datadir + "/Xtrain.npy" 77 | # x_train = np.load(Xdata_file) 78 | # x_test = np.load('../data/Xtest.npy') 79 | # x_all = np.concatenate((x_train, x_test), axis=0) 80 | # print(x_all.shape) 81 | # 82 | # ba_train = np.load('../data/ba0_train.npy') 83 | # ba_test = np.load('../data/ba0_test.npy') 84 | # ba_all = np.concatenate((ba_train, ba_test), axis=0) 85 | # print(ba_all.shape) 86 | # ind = np.argsort(ba_all) 87 | # ind = ind[:10000] 88 | # 89 | # sas_train = np.load('../data/sas_train.npy') 90 | # sas_test = np.load('../data/sas_test.npy') 91 | # sas_all = np.concatenate((sas_train, sas_test), axis=0) 92 | # 93 | # l_train = np.load('../data/Ltrain.npy') 94 | # l_test = np.load('../data/Ltest.npy') 95 | # l_all = np.concatenate((l_train, l_test), axis=0) 96 | # 97 | # qed_train = np.load('../data/qed_train.npy') 98 | # qed_test = np.load('../data/qed_test.npy') 99 | # qed_all = np.concatenate((sas_train, sas_test), axis=0) 100 | # 101 | # sas_design = sas_all[ind] 102 | # ba_design = ba_all[ind] 103 | # x_design = x_all[ind] 104 | # qed_design = qed_all[ind] 105 | # l_design = l_all[ind] 106 | # 107 | # np.save('../data/Xdesign.npy', x_design) 108 | # np.save('../data/ba0_design.npy', ba_design) 109 | # np.save('../data/qed_design.npy', qed_design) 110 | # np.save('../data/sas_design.npy', sas_design) 111 | # np.save('../data/Ldesign.npy', l_design) 112 | 113 | ds = MolDataset(datadir, 'train') 114 | ds_loader = DataLoader(dataset=ds, batch_size=100, 115 | shuffle=True, drop_last=True, num_workers=2) 116 | 117 | s = ds.ba0 118 | print(len(s)) 119 | print(s[:10]) 120 | # print(ds.Xdata[:10]) 121 | # print(ds.len) 122 | # non_zero = torch.nonzero(s) 123 | # print(non_zero) 124 | # s = s[non_zero] 125 | # print(len(s)) 126 | # print(s[1000:2000]) 127 | # print(torch.max(s), torch.min(s)) 128 | # print(ds.delta_to_kd(torch.max(s)), ds.delta_to_kd(torch.min(s))) 129 | # max_len = 0 130 | # for i in range(len(ds)): 131 | # print(len(ds)) 132 | # m, len, ba, sas, qed = ds[i] 133 | # m = m.numpy() 134 | # mol = ds.label2sf2smi(m[1:]) 135 | # print(mol) 136 | # print(ba, sas, qed) 137 | # break 138 | 139 | # print(char_dict) 140 | # print(ds.sf2smi(m), s) 141 | -------------------------------------------------------------------------------- /multi_design_acaa1/sascorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of synthetic accessibility score as described in: 3 | # 4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 5 | # Peter Ertl and Ansgar Schuffenhauer 6 | # Journal of Cheminformatics 1:8 (2009) 7 | # http://www.jcheminf.com/content/1/1/8 8 | # 9 | # several small modifications to the original paper are included 10 | # particularly slightly different formula for macrocyclic penalty 11 | #
and taking into account also molecule symmetry (fingerprint density) 12 | # 13 | # for a set of 10k diverse molecules the agreement between the original method 14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 15 | # 16 | # peter ertl & greg landrum, september 2013 17 | # 18 | 19 | 20 | from rdkit import Chem 21 | from rdkit.Chem import rdMolDescriptors 22 | import pickle 23 | 24 | import math 25 | from collections import defaultdict 26 | 27 | import os.path as op 28 | 29 | _fscores = None 30 | 31 | 32 | def readFragmentScores(name='fpscores'): 33 | import gzip 34 | global _fscores 35 | # generate the full path filename: 36 | if name == "fpscores": 37 | name = op.join(op.dirname(__file__), name) 38 | data = pickle.load(gzip.open('%s.pkl.gz' % name)) 39 | outDict = {} 40 | for i in data: 41 | for j in range(1, len(i)): 42 | outDict[i[j]] = float(i[0]) 43 | _fscores = outDict 44 | 45 | 46 | def numBridgeheadsAndSpiro(mol, ri=None): 47 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 48 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 49 | return nBridgehead, nSpiro 50 | 51 | 52 | def calculateScore(m): 53 | if _fscores is None: 54 | readFragmentScores() 55 | 56 | # fragment score 57 | fp = rdMolDescriptors.GetMorganFingerprint(m, 58 | 2) # <- 2 is the *radius* of the circular fingerprint 59 | fps = fp.GetNonzeroElements() 60 | score1 = 0. 61 | nf = 0 62 | for bitId, v in fps.items(): 63 | nf += v 64 | sfp = bitId 65 | score1 += _fscores.get(sfp, -4) * v 66 | score1 /= nf 67 | 68 | # features score 69 | nAtoms = m.GetNumAtoms() 70 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 71 | ri = m.GetRingInfo() 72 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 73 | nMacrocycles = 0 74 | for x in ri.AtomRings(): 75 | if len(x) > 8: 76 | nMacrocycles += 1 77 | 78 | sizePenalty = nAtoms**1.005 - nAtoms 79 | stereoPenalty = math.log10(nChiralCenters + 1) 80 | spiroPenalty = math.log10(nSpiro + 1) 81 | bridgePenalty = math.log10(nBridgeheads + 1) 82 | macrocyclePenalty = 0. 83 | # --------------------------------------- 84 | # This differs from the paper, which defines: 85 | # macrocyclePenalty = math.log10(nMacrocycles+1) 86 | # This form generates better results when 2 or more macrocycles are present 87 | if nMacrocycles > 0: 88 | macrocyclePenalty = math.log10(2) 89 | 90 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 91 | 92 | # correction for the fingerprint density 93 | # not in the original publication, added in version 1.1 94 | # to make highly symmetrical molecules easier to synthetise 95 | score3 = 0. 96 | if nAtoms > len(fps): 97 | score3 = math.log(float(nAtoms) / len(fps)) * .5 98 | 99 | sascore = score1 + score2 + score3 100 | 101 | # need to transform "raw" value into scale between 1 and 10 102 | min = -4.0 103 | max = 2.5 104 | sascore = 11. - (sascore - min + 1) / (max - min) * 9. 105 | # smooth the 10-end 106 | if sascore > 8.: 107 | sascore = 8. + math.log(sascore + 1. - 9.) 
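# worked example of the rescaling above, given raw scores in [-4.0, 2.5]:
# a raw 2.5 maps to ~0.6 (clipped up to 1.0 just below), while a raw -4.0
# maps to ~9.6, which the log smoothing above pulls down to ~8.5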
108 | if sascore > 10.: 109 | sascore = 10.0 110 | elif sascore < 1.: 111 | sascore = 1.0 112 | 113 | return sascore 114 | 115 | 116 | def processMols(mols): 117 | print('smiles\tName\tsa_score') 118 | for i, m in enumerate(mols): 119 | if m is None: 120 | continue 121 | 122 | s = calculateScore(m) 123 | 124 | smiles = Chem.MolToSmiles(m) 125 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 126 | 127 | 128 | if __name__ == '__main__': 129 | import sys 130 | import time 131 | 132 | t1 = time.time() 133 | readFragmentScores("fpscores") 134 | t2 = time.time() 135 | 136 | suppl = Chem.SmilesMolSupplier(sys.argv[1]) 137 | t3 = time.time() 138 | processMols(suppl) 139 | t4 = time.time() 140 | 141 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), 142 | file=sys.stderr) 143 | 144 | # 145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 146 | # All rights reserved. 147 | # 148 | # Redistribution and use in source and binary forms, with or without 149 | # modification, are permitted provided that the following conditions are 150 | # met: 151 | # 152 | # * Redistributions of source code must retain the above copyright 153 | # notice, this list of conditions and the following disclaimer. 154 | # * Redistributions in binary form must reproduce the above 155 | # copyright notice, this list of conditions and the following 156 | # disclaimer in the documentation and/or other materials provided 157 | # with the distribution. 158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 159 | # nor the names of its contributors may be used to endorse or promote 160 | # products derived from this software without specific prior written permission. 161 | # 162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
173 | # -------------------------------------------------------------------------------- /multi_design_esr1/sascorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of synthetic accessibility score as described in: 3 | # 4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 5 | # Peter Ertl and Ansgar Schuffenhauer 6 | # Journal of Cheminformatics 1:8 (2009) 7 | # http://www.jcheminf.com/content/1/1/8 8 | # 9 | # several small modifications to the original paper are included 10 | # particularly slightly different formula for marocyclic penalty 11 | # and taking into account also molecule symmetry (fingerprint density) 12 | # 13 | # for a set of 10k diverse molecules the agreement between the original method 14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 15 | # 16 | # peter ertl & greg landrum, september 2013 17 | # 18 | 19 | 20 | from rdkit import Chem 21 | from rdkit.Chem import rdMolDescriptors 22 | import pickle 23 | 24 | import math 25 | from collections import defaultdict 26 | 27 | import os.path as op 28 | 29 | _fscores = None 30 | 31 | 32 | def readFragmentScores(name='fpscores'): 33 | import gzip 34 | global _fscores 35 | # generate the full path filename: 36 | if name == "fpscores": 37 | name = op.join(op.dirname(__file__), name) 38 | data = pickle.load(gzip.open('%s.pkl.gz' % name)) 39 | outDict = {} 40 | for i in data: 41 | for j in range(1, len(i)): 42 | outDict[i[j]] = float(i[0]) 43 | _fscores = outDict 44 | 45 | 46 | def numBridgeheadsAndSpiro(mol, ri=None): 47 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 48 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 49 | return nBridgehead, nSpiro 50 | 51 | 52 | def calculateScore(m): 53 | if _fscores is None: 54 | readFragmentScores() 55 | 56 | # fragment score 57 | fp = rdMolDescriptors.GetMorganFingerprint(m, 58 | 2) # <- 2 is the *radius* of the circular fingerprint 59 | fps = fp.GetNonzeroElements() 60 | score1 = 0. 61 | nf = 0 62 | for bitId, v in fps.items(): 63 | nf += v 64 | sfp = bitId 65 | score1 += _fscores.get(sfp, -4) * v 66 | score1 /= nf 67 | 68 | # features score 69 | nAtoms = m.GetNumAtoms() 70 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 71 | ri = m.GetRingInfo() 72 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 73 | nMacrocycles = 0 74 | for x in ri.AtomRings(): 75 | if len(x) > 8: 76 | nMacrocycles += 1 77 | 78 | sizePenalty = nAtoms**1.005 - nAtoms 79 | stereoPenalty = math.log10(nChiralCenters + 1) 80 | spiroPenalty = math.log10(nSpiro + 1) 81 | bridgePenalty = math.log10(nBridgeheads + 1) 82 | macrocyclePenalty = 0. 83 | # --------------------------------------- 84 | # This differs from the paper, which defines: 85 | # macrocyclePenalty = math.log10(nMacrocycles+1) 86 | # This form generates better results when 2 or more macrocycles are present 87 | if nMacrocycles > 0: 88 | macrocyclePenalty = math.log10(2) 89 | 90 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 91 | 92 | # correction for the fingerprint density 93 | # not in the original publication, added in version 1.1 94 | # to make highly symmetrical molecules easier to synthetise 95 | score3 = 0. 
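# a molecule with fewer distinct fragment environments (len(fps)) than heavy
# atoms is highly symmetric; the bonus below lowers its final SA score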
96 | if nAtoms > len(fps): 97 | score3 = math.log(float(nAtoms) / len(fps)) * .5 98 | 99 | sascore = score1 + score2 + score3 100 | 101 | # need to transform "raw" value into scale between 1 and 10 102 | min = -4.0 103 | max = 2.5 104 | sascore = 11. - (sascore - min + 1) / (max - min) * 9. 105 | # smooth the 10-end 106 | if sascore > 8.: 107 | sascore = 8. + math.log(sascore + 1. - 9.) 108 | if sascore > 10.: 109 | sascore = 10.0 110 | elif sascore < 1.: 111 | sascore = 1.0 112 | 113 | return sascore 114 | 115 | 116 | def processMols(mols): 117 | print('smiles\tName\tsa_score') 118 | for i, m in enumerate(mols): 119 | if m is None: 120 | continue 121 | 122 | s = calculateScore(m) 123 | 124 | smiles = Chem.MolToSmiles(m) 125 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 126 | 127 | 128 | if __name__ == '__main__': 129 | import sys 130 | import time 131 | 132 | t1 = time.time() 133 | readFragmentScores("fpscores") 134 | t2 = time.time() 135 | 136 | suppl = Chem.SmilesMolSupplier(sys.argv[1]) 137 | t3 = time.time() 138 | processMols(suppl) 139 | t4 = time.time() 140 | 141 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), 142 | file=sys.stderr) 143 | 144 | # 145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 146 | # All rights reserved. 147 | # 148 | # Redistribution and use in source and binary forms, with or without 149 | # modification, are permitted provided that the following conditions are 150 | # met: 151 | # 152 | # * Redistributions of source code must retain the above copyright 153 | # notice, this list of conditions and the following disclaimer. 154 | # * Redistributions in binary form must reproduce the above 155 | # copyright notice, this list of conditions and the following 156 | # disclaimer in the documentation and/or other materials provided 157 | # with the distribution. 158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 159 | # nor the names of its contributors may be used to endorse or promote 160 | # products derived from this software without specific prior written permission. 161 | # 162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
173 | # -------------------------------------------------------------------------------- /single_design_acaa1/sascorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of synthetic accessibility score as described in: 3 | # 4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 5 | # Peter Ertl and Ansgar Schuffenhauer 6 | # Journal of Cheminformatics 1:8 (2009) 7 | # http://www.jcheminf.com/content/1/1/8 8 | # 9 | # several small modifications to the original paper are included 10 | # particularly slightly different formula for marocyclic penalty 11 | # and taking into account also molecule symmetry (fingerprint density) 12 | # 13 | # for a set of 10k diverse molecules the agreement between the original method 14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 15 | # 16 | # peter ertl & greg landrum, september 2013 17 | # 18 | 19 | 20 | from rdkit import Chem 21 | from rdkit.Chem import rdMolDescriptors 22 | import pickle 23 | 24 | import math 25 | from collections import defaultdict 26 | 27 | import os.path as op 28 | 29 | _fscores = None 30 | 31 | 32 | def readFragmentScores(name='fpscores'): 33 | import gzip 34 | global _fscores 35 | # generate the full path filename: 36 | if name == "fpscores": 37 | name = op.join(op.dirname(__file__), name) 38 | data = pickle.load(gzip.open('%s.pkl.gz' % name)) 39 | outDict = {} 40 | for i in data: 41 | for j in range(1, len(i)): 42 | outDict[i[j]] = float(i[0]) 43 | _fscores = outDict 44 | 45 | 46 | def numBridgeheadsAndSpiro(mol, ri=None): 47 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 48 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 49 | return nBridgehead, nSpiro 50 | 51 | 52 | def calculateScore(m): 53 | if _fscores is None: 54 | readFragmentScores() 55 | 56 | # fragment score 57 | fp = rdMolDescriptors.GetMorganFingerprint(m, 58 | 2) # <- 2 is the *radius* of the circular fingerprint 59 | fps = fp.GetNonzeroElements() 60 | score1 = 0. 61 | nf = 0 62 | for bitId, v in fps.items(): 63 | nf += v 64 | sfp = bitId 65 | score1 += _fscores.get(sfp, -4) * v 66 | score1 /= nf 67 | 68 | # features score 69 | nAtoms = m.GetNumAtoms() 70 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 71 | ri = m.GetRingInfo() 72 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 73 | nMacrocycles = 0 74 | for x in ri.AtomRings(): 75 | if len(x) > 8: 76 | nMacrocycles += 1 77 | 78 | sizePenalty = nAtoms**1.005 - nAtoms 79 | stereoPenalty = math.log10(nChiralCenters + 1) 80 | spiroPenalty = math.log10(nSpiro + 1) 81 | bridgePenalty = math.log10(nBridgeheads + 1) 82 | macrocyclePenalty = 0. 83 | # --------------------------------------- 84 | # This differs from the paper, which defines: 85 | # macrocyclePenalty = math.log10(nMacrocycles+1) 86 | # This form generates better results when 2 or more macrocycles are present 87 | if nMacrocycles > 0: 88 | macrocyclePenalty = math.log10(2) 89 | 90 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 91 | 92 | # correction for the fingerprint density 93 | # not in the original publication, added in version 1.1 94 | # to make highly symmetrical molecules easier to synthetise 95 | score3 = 0. 
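# minimal usage sketch (an illustration only; assumes fpscores.pkl.gz sits
# next to this module, as readFragmentScores() expects):
#   from rdkit import Chem
#   import sascorer
#   sascorer.calculateScore(Chem.MolFromSmiles('CCO'))  # low score = easy to synthesize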
96 | if nAtoms > len(fps): 97 | score3 = math.log(float(nAtoms) / len(fps)) * .5 98 | 99 | sascore = score1 + score2 + score3 100 | 101 | # need to transform "raw" value into scale between 1 and 10 102 | min = -4.0 103 | max = 2.5 104 | sascore = 11. - (sascore - min + 1) / (max - min) * 9. 105 | # smooth the 10-end 106 | if sascore > 8.: 107 | sascore = 8. + math.log(sascore + 1. - 9.) 108 | if sascore > 10.: 109 | sascore = 10.0 110 | elif sascore < 1.: 111 | sascore = 1.0 112 | 113 | return sascore 114 | 115 | 116 | def processMols(mols): 117 | print('smiles\tName\tsa_score') 118 | for i, m in enumerate(mols): 119 | if m is None: 120 | continue 121 | 122 | s = calculateScore(m) 123 | 124 | smiles = Chem.MolToSmiles(m) 125 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 126 | 127 | 128 | if __name__ == '__main__': 129 | import sys 130 | import time 131 | 132 | t1 = time.time() 133 | readFragmentScores("fpscores") 134 | t2 = time.time() 135 | 136 | suppl = Chem.SmilesMolSupplier(sys.argv[1]) 137 | t3 = time.time() 138 | processMols(suppl) 139 | t4 = time.time() 140 | 141 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), 142 | file=sys.stderr) 143 | 144 | # 145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 146 | # All rights reserved. 147 | # 148 | # Redistribution and use in source and binary forms, with or without 149 | # modification, are permitted provided that the following conditions are 150 | # met: 151 | # 152 | # * Redistributions of source code must retain the above copyright 153 | # notice, this list of conditions and the following disclaimer. 154 | # * Redistributions in binary form must reproduce the above 155 | # copyright notice, this list of conditions and the following 156 | # disclaimer in the documentation and/or other materials provided 157 | # with the distribution. 158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 159 | # nor the names of its contributors may be used to endorse or promote 160 | # products derived from this software without specific prior written permission. 161 | # 162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
173 | # -------------------------------------------------------------------------------- /single_design_esr1/sascorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of synthetic accessibility score as described in: 3 | # 4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 5 | # Peter Ertl and Ansgar Schuffenhauer 6 | # Journal of Cheminformatics 1:8 (2009) 7 | # http://www.jcheminf.com/content/1/1/8 8 | # 9 | # several small modifications to the original paper are included 10 | # particularly slightly different formula for marocyclic penalty 11 | # and taking into account also molecule symmetry (fingerprint density) 12 | # 13 | # for a set of 10k diverse molecules the agreement between the original method 14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 15 | # 16 | # peter ertl & greg landrum, september 2013 17 | # 18 | 19 | 20 | from rdkit import Chem 21 | from rdkit.Chem import rdMolDescriptors 22 | import pickle 23 | 24 | import math 25 | from collections import defaultdict 26 | 27 | import os.path as op 28 | 29 | _fscores = None 30 | 31 | 32 | def readFragmentScores(name='fpscores'): 33 | import gzip 34 | global _fscores 35 | # generate the full path filename: 36 | if name == "fpscores": 37 | name = op.join(op.dirname(__file__), name) 38 | data = pickle.load(gzip.open('%s.pkl.gz' % name)) 39 | outDict = {} 40 | for i in data: 41 | for j in range(1, len(i)): 42 | outDict[i[j]] = float(i[0]) 43 | _fscores = outDict 44 | 45 | 46 | def numBridgeheadsAndSpiro(mol, ri=None): 47 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 48 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 49 | return nBridgehead, nSpiro 50 | 51 | 52 | def calculateScore(m): 53 | if _fscores is None: 54 | readFragmentScores() 55 | 56 | # fragment score 57 | fp = rdMolDescriptors.GetMorganFingerprint(m, 58 | 2) # <- 2 is the *radius* of the circular fingerprint 59 | fps = fp.GetNonzeroElements() 60 | score1 = 0. 61 | nf = 0 62 | for bitId, v in fps.items(): 63 | nf += v 64 | sfp = bitId 65 | score1 += _fscores.get(sfp, -4) * v 66 | score1 /= nf 67 | 68 | # features score 69 | nAtoms = m.GetNumAtoms() 70 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 71 | ri = m.GetRingInfo() 72 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 73 | nMacrocycles = 0 74 | for x in ri.AtomRings(): 75 | if len(x) > 8: 76 | nMacrocycles += 1 77 | 78 | sizePenalty = nAtoms**1.005 - nAtoms 79 | stereoPenalty = math.log10(nChiralCenters + 1) 80 | spiroPenalty = math.log10(nSpiro + 1) 81 | bridgePenalty = math.log10(nBridgeheads + 1) 82 | macrocyclePenalty = 0. 83 | # --------------------------------------- 84 | # This differs from the paper, which defines: 85 | # macrocyclePenalty = math.log10(nMacrocycles+1) 86 | # This form generates better results when 2 or more macrocycles are present 87 | if nMacrocycles > 0: 88 | macrocyclePenalty = math.log10(2) 89 | 90 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 91 | 92 | # correction for the fingerprint density 93 | # not in the original publication, added in version 1.1 94 | # to make highly symmetrical molecules easier to synthetise 95 | score3 = 0. 
96 | if nAtoms > len(fps): 97 | score3 = math.log(float(nAtoms) / len(fps)) * .5 98 | 99 | sascore = score1 + score2 + score3 100 | 101 | # need to transform "raw" value into scale between 1 and 10 102 | min = -4.0 103 | max = 2.5 104 | sascore = 11. - (sascore - min + 1) / (max - min) * 9. 105 | # smooth the 10-end 106 | if sascore > 8.: 107 | sascore = 8. + math.log(sascore + 1. - 9.) 108 | if sascore > 10.: 109 | sascore = 10.0 110 | elif sascore < 1.: 111 | sascore = 1.0 112 | 113 | return sascore 114 | 115 | 116 | def processMols(mols): 117 | print('smiles\tName\tsa_score') 118 | for i, m in enumerate(mols): 119 | if m is None: 120 | continue 121 | 122 | s = calculateScore(m) 123 | 124 | smiles = Chem.MolToSmiles(m) 125 | print(smiles + "\t" + m.GetProp('_Name') + "\t%3f" % s) 126 | 127 | 128 | if __name__ == '__main__': 129 | import sys 130 | import time 131 | 132 | t1 = time.time() 133 | readFragmentScores("fpscores") 134 | t2 = time.time() 135 | 136 | suppl = Chem.SmilesMolSupplier(sys.argv[1]) 137 | t3 = time.time() 138 | processMols(suppl) 139 | t4 = time.time() 140 | 141 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), 142 | file=sys.stderr) 143 | 144 | # 145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 146 | # All rights reserved. 147 | # 148 | # Redistribution and use in source and binary forms, with or without 149 | # modification, are permitted provided that the following conditions are 150 | # met: 151 | # 152 | # * Redistributions of source code must retain the above copyright 153 | # notice, this list of conditions and the following disclaimer. 154 | # * Redistributions in binary form must reproduce the above 155 | # copyright notice, this list of conditions and the following 156 | # disclaimer in the documentation and/or other materials provided 157 | # with the distribution. 158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 159 | # nor the names of its contributors may be used to endorse or promote 160 | # products derived from this software without specific prior written permission. 161 | # 162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
173 | # -------------------------------------------------------------------------------- /single_design_plogp/sascorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of synthetic accessibility score as described in: 3 | # 4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 5 | # Peter Ertl and Ansgar Schuffenhauer 6 | # Journal of Cheminformatics 1:8 (2009) 7 | # http://www.jcheminf.com/content/1/1/8 8 | # 9 | # several small modifications to the original paper are included 10 | # particularly slightly different formula for marocyclic penalty 11 | # and taking into account also molecule symmetry (fingerprint density) 12 | # 13 | # for a set of 10k diverse molecules the agreement between the original method 14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 15 | # 16 | # peter ertl & greg landrum, september 2013 17 | # 18 | 19 | 20 | from rdkit import Chem 21 | from rdkit.Chem import rdMolDescriptors 22 | import pickle 23 | 24 | import math 25 | from collections import defaultdict 26 | 27 | import os.path as op 28 | 29 | _fscores = None 30 | 31 | 32 | def readFragmentScores(name='fpscores'): 33 | import gzip 34 | global _fscores 35 | # generate the full path filename: 36 | if name == "fpscores": 37 | name = op.join(op.dirname(__file__), name) 38 | data = pickle.load(gzip.open('%s.pkl.gz' % name)) 39 | outDict = {} 40 | for i in data: 41 | for j in range(1, len(i)): 42 | outDict[i[j]] = float(i[0]) 43 | _fscores = outDict 44 | 45 | 46 | def numBridgeheadsAndSpiro(mol, ri=None): 47 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 48 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 49 | return nBridgehead, nSpiro 50 | 51 | 52 | def calculateScore(m): 53 | if _fscores is None: 54 | readFragmentScores() 55 | 56 | # fragment score 57 | fp = rdMolDescriptors.GetMorganFingerprint(m, 58 | 2) # <- 2 is the *radius* of the circular fingerprint 59 | fps = fp.GetNonzeroElements() 60 | score1 = 0. 61 | nf = 0 62 | for bitId, v in fps.items(): 63 | nf += v 64 | sfp = bitId 65 | score1 += _fscores.get(sfp, -4) * v 66 | score1 /= nf 67 | 68 | # features score 69 | nAtoms = m.GetNumAtoms() 70 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 71 | ri = m.GetRingInfo() 72 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 73 | nMacrocycles = 0 74 | for x in ri.AtomRings(): 75 | if len(x) > 8: 76 | nMacrocycles += 1 77 | 78 | sizePenalty = nAtoms**1.005 - nAtoms 79 | stereoPenalty = math.log10(nChiralCenters + 1) 80 | spiroPenalty = math.log10(nSpiro + 1) 81 | bridgePenalty = math.log10(nBridgeheads + 1) 82 | macrocyclePenalty = 0. 83 | # --------------------------------------- 84 | # This differs from the paper, which defines: 85 | # macrocyclePenalty = math.log10(nMacrocycles+1) 86 | # This form generates better results when 2 or more macrocycles are present 87 | if nMacrocycles > 0: 88 | macrocyclePenalty = math.log10(2) 89 | 90 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 91 | 92 | # correction for the fingerprint density 93 | # not in the original publication, added in version 1.1 94 | # to make highly symmetrical molecules easier to synthetise 95 | score3 = 0. 
96 | if nAtoms > len(fps): 97 | score3 = math.log(float(nAtoms) / len(fps)) * .5 98 | 99 | sascore = score1 + score2 + score3 100 | 101 | # need to transform "raw" value into scale between 1 and 10 102 | minScore = -4.0 # named so the builtins min/max are not shadowed 103 | maxScore = 2.5 104 | sascore = 11. - (sascore - minScore + 1) / (maxScore - minScore) * 9. 105 | # smooth the 10-end 106 | if sascore > 8.: 107 | sascore = 8. + math.log(sascore + 1. - 9.) 108 | if sascore > 10.: 109 | sascore = 10.0 110 | elif sascore < 1.: 111 | sascore = 1.0 112 | 113 | return sascore 114 | 115 | 116 | def processMols(mols): 117 | print('smiles\tName\tsa_score') 118 | for i, m in enumerate(mols): 119 | if m is None: 120 | continue 121 | 122 | s = calculateScore(m) 123 | 124 | smiles = Chem.MolToSmiles(m) 125 | print(smiles + "\t" + m.GetProp('_Name') + "\t%.3f" % s) 126 | 127 | 128 | if __name__ == '__main__': 129 | import sys 130 | import time 131 | 132 | t1 = time.time() 133 | readFragmentScores("fpscores") 134 | t2 = time.time() 135 | 136 | suppl = Chem.SmilesMolSupplier(sys.argv[1]) 137 | t3 = time.time() 138 | processMols(suppl) 139 | t4 = time.time() 140 | 141 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), 142 | file=sys.stderr) 143 | 144 | # 145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 146 | # All rights reserved. 147 | # 148 | # Redistribution and use in source and binary forms, with or without 149 | # modification, are permitted provided that the following conditions are 150 | # met: 151 | # 152 | # * Redistributions of source code must retain the above copyright 153 | # notice, this list of conditions and the following disclaimer. 154 | # * Redistributions in binary form must reproduce the above 155 | # copyright notice, this list of conditions and the following 156 | # disclaimer in the documentation and/or other materials provided 157 | # with the distribution. 158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 159 | # nor the names of its contributors may be used to endorse or promote 160 | # products derived from this software without specific prior written permission. 161 | # 162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
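Command-line use, as wired in the __main__ block above: running e.g. "python sascorer.py molecules.smi" (the file name is illustrative) reads the SMILES file through Chem.SmilesMolSupplier and prints a tab-separated smiles/Name/sa_score table to stdout, with read and compute timings reported on stderr.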
173 | # -------------------------------------------------------------------------------- /single_design_qed/sascorer.py: -------------------------------------------------------------------------------- 1 | # 2 | # calculation of synthetic accessibility score as described in: 3 | # 4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions 5 | # Peter Ertl and Ansgar Schuffenhauer 6 | # Journal of Cheminformatics 1:8 (2009) 7 | # http://www.jcheminf.com/content/1/1/8 8 | # 9 | # several small modifications to the original paper are included, 10 | # in particular a slightly different formula for the macrocyclic penalty 11 | # and molecule symmetry (fingerprint density) is also taken into account 12 | # 13 | # for a set of 10k diverse molecules the agreement between the original method 14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97 15 | # 16 | # peter ertl & greg landrum, september 2013 17 | # 18 | 19 | 20 | from rdkit import Chem 21 | from rdkit.Chem import rdMolDescriptors 22 | import pickle 23 | 24 | import math 25 | from collections import defaultdict 26 | 27 | import os.path as op 28 | 29 | _fscores = None 30 | 31 | 32 | def readFragmentScores(name='fpscores'): 33 | import gzip 34 | global _fscores 35 | # generate the full path filename: 36 | if name == "fpscores": 37 | name = op.join(op.dirname(__file__), name) 38 | data = pickle.load(gzip.open('%s.pkl.gz' % name)) 39 | outDict = {} 40 | for i in data: 41 | for j in range(1, len(i)): 42 | outDict[i[j]] = float(i[0]) 43 | _fscores = outDict 44 | 45 | 46 | def numBridgeheadsAndSpiro(mol, ri=None): 47 | nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol) 48 | nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol) 49 | return nBridgehead, nSpiro 50 | 51 | 52 | def calculateScore(m): 53 | if _fscores is None: 54 | readFragmentScores() 55 | 56 | # fragment score 57 | fp = rdMolDescriptors.GetMorganFingerprint(m, 58 | 2) # <- 2 is the *radius* of the circular fingerprint 59 | fps = fp.GetNonzeroElements() 60 | score1 = 0. 61 | nf = 0 62 | for bitId, v in fps.items(): 63 | nf += v 64 | sfp = bitId 65 | score1 += _fscores.get(sfp, -4) * v 66 | score1 /= nf 67 | 68 | # features score 69 | nAtoms = m.GetNumAtoms() 70 | nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True)) 71 | ri = m.GetRingInfo() 72 | nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri) 73 | nMacrocycles = 0 74 | for x in ri.AtomRings(): 75 | if len(x) > 8: 76 | nMacrocycles += 1 77 | 78 | sizePenalty = nAtoms**1.005 - nAtoms 79 | stereoPenalty = math.log10(nChiralCenters + 1) 80 | spiroPenalty = math.log10(nSpiro + 1) 81 | bridgePenalty = math.log10(nBridgeheads + 1) 82 | macrocyclePenalty = 0. 83 | # --------------------------------------- 84 | # This differs from the paper, which defines: 85 | # macrocyclePenalty = math.log10(nMacrocycles+1) 86 | # This form generates better results when 2 or more macrocycles are present 87 | if nMacrocycles > 0: 88 | macrocyclePenalty = math.log10(2) 89 | 90 | score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty 91 | 92 | # correction for the fingerprint density 93 | # not in the original publication, added in version 1.1 94 | # to make highly symmetrical molecules easier to synthesize 95 | score3 = 0.
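    # note: in a highly symmetrical molecule many atoms share the same Morgan
    # environment, so len(fps) can be much smaller than nAtoms; the positive
    # correction computed below rewards exactly that symmetry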
96 | if nAtoms > len(fps): 97 | score3 = math.log(float(nAtoms) / len(fps)) * .5 98 | 99 | sascore = score1 + score2 + score3 100 | 101 | # need to transform "raw" value into scale between 1 and 10 102 | minScore = -4.0 # named so the builtins min/max are not shadowed 103 | maxScore = 2.5 104 | sascore = 11. - (sascore - minScore + 1) / (maxScore - minScore) * 9. 105 | # smooth the 10-end 106 | if sascore > 8.: 107 | sascore = 8. + math.log(sascore + 1. - 9.) 108 | if sascore > 10.: 109 | sascore = 10.0 110 | elif sascore < 1.: 111 | sascore = 1.0 112 | 113 | return sascore 114 | 115 | 116 | def processMols(mols): 117 | print('smiles\tName\tsa_score') 118 | for i, m in enumerate(mols): 119 | if m is None: 120 | continue 121 | 122 | s = calculateScore(m) 123 | 124 | smiles = Chem.MolToSmiles(m) 125 | print(smiles + "\t" + m.GetProp('_Name') + "\t%.3f" % s) 126 | 127 | 128 | if __name__ == '__main__': 129 | import sys 130 | import time 131 | 132 | t1 = time.time() 133 | readFragmentScores("fpscores") 134 | t2 = time.time() 135 | 136 | suppl = Chem.SmilesMolSupplier(sys.argv[1]) 137 | t3 = time.time() 138 | processMols(suppl) 139 | t4 = time.time() 140 | 141 | print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)), 142 | file=sys.stderr) 143 | 144 | # 145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc. 146 | # All rights reserved. 147 | # 148 | # Redistribution and use in source and binary forms, with or without 149 | # modification, are permitted provided that the following conditions are 150 | # met: 151 | # 152 | # * Redistributions of source code must retain the above copyright 153 | # notice, this list of conditions and the following disclaimer. 154 | # * Redistributions in binary form must reproduce the above 155 | # copyright notice, this list of conditions and the following 156 | # disclaimer in the documentation and/or other materials provided 157 | # with the distribution. 158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc. 159 | # nor the names of its contributors may be used to endorse or promote 160 | # products derived from this software without specific prior written permission. 161 | # 162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
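For orientation on the output scale: with minScore = -4.0 and maxScore = 2.5, a raw score of -4.0 maps to 11 - (1/6.5)*9 ≈ 9.62, which the log taper then pulls down to 8 + ln(1.62) ≈ 8.48, while a raw score of 2.5 maps to ≈ 0.62 and is clipped to 1.0; in practice calculateScore() therefore returns values in roughly [1.0, 8.5], with higher values meaning harder to synthesize.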
173 | # -------------------------------------------------------------------------------- /multi_design_acaa1/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from collections import OrderedDict 3 | import os 4 | import sys 5 | import shutil 6 | import os.path as osp 7 | import json 8 | 9 | import dateutil.tz 10 | 11 | LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json', 'csv'] 12 | 13 | DEBUG = 10 14 | INFO = 20 15 | WARN = 30 16 | ERROR = 40 17 | 18 | DISABLED = 50 19 | 20 | 21 | class OutputFormat(object): 22 | def writekvs(self, kvs): 23 | """ 24 | Write key-value pairs 25 | """ 26 | raise NotImplementedError 27 | 28 | def writeseq(self, args): 29 | """ 30 | Write a sequence of other data (e.g. a logging message) 31 | """ 32 | pass 33 | 34 | def close(self): 35 | return 36 | 37 | 38 | class HumanOutputFormat(OutputFormat): 39 | def __init__(self, file): 40 | self.file = file 41 | 42 | def writekvs(self, kvs): 43 | # Create strings for printing 44 | key2str = OrderedDict() 45 | for (key, val) in kvs.items(): 46 | valstr = '%-8.5g' % (val,) if hasattr(val, '__float__') else val 47 | key2str[self._truncate(key)] = self._truncate(valstr) 48 | 49 | # Find max widths 50 | keywidth = max(map(len, key2str.keys())) 51 | valwidth = max(map(len, key2str.values())) 52 | 53 | # Write out the data 54 | dashes = '-' * (keywidth + valwidth + 7) 55 | lines = [dashes] 56 | for (key, val) in key2str.items(): 57 | lines.append('| %s%s | %s%s |' % ( 58 | key, 59 | ' ' * (keywidth - len(key)), 60 | val, 61 | ' ' * (valwidth - len(val)), 62 | )) 63 | lines.append(dashes) 64 | self.file.write('\n'.join(lines) + '\n') 65 | 66 | # Flush the output to the file 67 | self.file.flush() 68 | 69 | def _truncate(self, s): 70 | return s[:20] + '...' 
if len(s) > 23 else s 71 | 72 | def writeseq(self, args): 73 | for arg in args: 74 | self.file.write(arg) 75 | self.file.write('\n') 76 | self.file.flush() 77 | 78 | 79 | class JSONOutputFormat(OutputFormat): 80 | def __init__(self, file): 81 | self.file = file 82 | 83 | def writekvs(self, kvs): 84 | for k, v in kvs.items(): 85 | if hasattr(v, 'dtype'): 86 | v = v.tolist() 87 | kvs[k] = v 88 | self.file.write(json.dumps(kvs) + '\n') 89 | self.file.flush() 90 | 91 | def close(self): 92 | self.file.close() 93 | 94 | 95 | class CSVOutputFormat(OutputFormat): 96 | def __init__(self, file): 97 | self.file = file 98 | self.keys = [] 99 | self.sep = ',' 100 | 101 | def writekvs(self, kvs): 102 | # Add our current row to the history; when new keys appear, the header is rewritten and earlier rows are right-padded with empty columns 103 | extra_keys = kvs.keys() - self.keys 104 | if extra_keys: 105 | self.keys.extend(extra_keys) 106 | self.file.seek(0) 107 | lines = self.file.readlines() 108 | self.file.seek(0) 109 | for (i, k) in enumerate(self.keys): 110 | if i > 0: 111 | self.file.write(',') 112 | self.file.write(k) 113 | self.file.write('\n') 114 | for line in lines[1:]: 115 | self.file.write(line[:-1]) 116 | self.file.write(self.sep * len(extra_keys)) 117 | self.file.write('\n') 118 | for (i, k) in enumerate(self.keys): 119 | if i > 0: 120 | self.file.write(',') 121 | v = kvs.get(k) 122 | if v is not None: 123 | self.file.write(str(v)) 124 | self.file.write('\n') 125 | self.file.flush() 126 | 127 | def close(self): 128 | self.file.close() 129 | 130 | 131 | def make_output_format(format, ev_dir): 132 | os.makedirs(ev_dir, exist_ok=True) 133 | if format == 'stdout': 134 | return HumanOutputFormat(sys.stdout) 135 | elif format == 'log': 136 | log_file = open(osp.join(ev_dir, 'log.txt'), 'wt') 137 | return HumanOutputFormat(log_file) 138 | elif format == 'json': 139 | json_file = open(osp.join(ev_dir, 'progress.json'), 'wt') 140 | return JSONOutputFormat(json_file) 141 | elif format == 'csv': 142 | csv_file = open(osp.join(ev_dir, 'progress.csv'), 'w+t') 143 | return CSVOutputFormat(csv_file) 144 | else: 145 | raise ValueError('Unknown format specified: %s' % (format,)) 146 | 147 | 148 | # ================================================================ 149 | # API 150 | # ================================================================ 151 | def log_params(params): 152 | assert isinstance(params, dict) 153 | json_file = open(osp.join(Logger.CURRENT.get_dir(), 'params.json'), 'wt') 154 | output_format = JSONOutputFormat(json_file) 155 | output_format.writekvs(params) 156 | output_format.close() 157 | 158 | def logkv(key, val): 159 | """ 160 | Log a value of some diagnostic 161 | Call this once for each diagnostic quantity, each iteration 162 | """ 163 | Logger.CURRENT.logkv(key, val) 164 | 165 | 166 | def dumpkvs(): 167 | """ 168 | Write all of the diagnostics from the current iteration 169 | 170 | Flushes the accumulated key/value pairs to every configured output 171 | format, then clears them for the next iteration. 172 | """ 173 | Logger.CURRENT.dumpkvs() 174 | 175 | 176 | # for backwards compatibility 177 | record_tabular = logkv 178 | dump_tabular = dumpkvs 179 | 180 | 181 | def log(*args, level=INFO): 182 | """ 183 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
184 | """ 185 | Logger.CURRENT.log(*args, level=level) 186 | 187 | 188 | def debug(*args): 189 | log(*args, level=DEBUG) 190 | 191 | 192 | def info(*args): 193 | log(*args, level=INFO) 194 | 195 | 196 | def warn(*args): 197 | log(*args, level=WARN) 198 | 199 | 200 | def error(*args): 201 | log(*args, level=ERROR) 202 | 203 | 204 | def set_level(level): 205 | """ 206 | Set logging threshold on current logger. 207 | """ 208 | Logger.CURRENT.set_level(level) 209 | 210 | 211 | def get_level(): 212 | """ 213 | Set logging threshold on current logger. 214 | """ 215 | return Logger.CURRENT.level 216 | 217 | 218 | def get_dir(): 219 | """ 220 | Get directory that log files are being written to. 221 | will be None if there is no output directory (i.e., if you didn't call start) 222 | """ 223 | return Logger.CURRENT.get_dir() 224 | 225 | 226 | def get_expt_dir(): 227 | sys.stderr.write( 228 | "get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" % (get_dir(),)) 229 | return get_dir() 230 | 231 | 232 | # ================================================================ 233 | # Backend 234 | # ================================================================ 235 | 236 | 237 | class Logger(object): 238 | # A logger with no output files. (See right below class definition) 239 | DEFAULT = None 240 | # So that you can still log to the terminal without setting up any output files 241 | CURRENT = None # Current logger being used by the free functions above 242 | 243 | def __init__(self, dir, output_formats): 244 | self.name2val = OrderedDict() # values this iteration 245 | self.level = INFO 246 | self.dir = dir 247 | self.output_formats = output_formats 248 | 249 | # Logging API, forwarded 250 | # ---------------------------------------- 251 | def logkv(self, key, val): 252 | self.name2val[key] = val 253 | 254 | def dumpkvs(self): 255 | for fmt in self.output_formats: 256 | fmt.writekvs(self.name2val) 257 | self.name2val.clear() 258 | 259 | def log(self, *args, level=INFO): 260 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 261 | timestamp = now.strftime('[%Y-%m-%d %H:%M:%S.%f %Z] ') 262 | if self.level <= level: 263 | self._do_log((timestamp,) + args) 264 | 265 | # Configuration 266 | # ---------------------------------------- 267 | def set_level(self, level): 268 | self.level = level 269 | 270 | def get_dir(self): 271 | return self.dir 272 | 273 | def close(self): 274 | for fmt in self.output_formats: 275 | fmt.close() 276 | 277 | # Misc 278 | # ---------------------------------------- 279 | def _do_log(self, args): 280 | for fmt in self.output_formats: 281 | fmt.writeseq(args) 282 | 283 | 284 | # ================================================================ 285 | 286 | Logger.DEFAULT = Logger( 287 | output_formats=[HumanOutputFormat(sys.stdout)], dir=None) 288 | Logger.CURRENT = Logger.DEFAULT 289 | 290 | 291 | class session(object): 292 | """ 293 | Context manager that sets up the loggers for an experiment. 
294 | """ 295 | def __init__(self, dir, format_strs=LOG_OUTPUT_FORMATS): 296 | self.dir = dir 297 | self.format_strs = format_strs 298 | 299 | def __enter__(self): 300 | os.makedirs(self.dir, exist_ok=True) 301 | output_formats = [make_output_format(f, self.dir) for f in self.format_strs] 302 | Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats) 303 | 304 | def __exit__(self, *args): 305 | Logger.CURRENT.close() 306 | Logger.CURRENT = Logger.DEFAULT 307 | 308 | 309 | # ================================================================ 310 | 311 | 312 | def main(): 313 | info("hi") 314 | debug("shouldn't appear") 315 | set_level(DEBUG) 316 | debug("should appear") 317 | dir = "/tmp/testlogging" 318 | if os.path.exists(dir): 319 | shutil.rmtree(dir) 320 | with session(dir=dir): 321 | record_tabular("a", 3) 322 | record_tabular("b", 2.5) 323 | dump_tabular() 324 | record_tabular("b", -2.5) 325 | record_tabular("a", 5.5) 326 | dump_tabular() 327 | info("^^^ should see a = 5.5") 328 | 329 | record_tabular("b", -2.5) 330 | dump_tabular() 331 | 332 | record_tabular("a", "longasslongasslongasslongasslongasslongassvalue") 333 | dump_tabular() 334 | 335 | 336 | if __name__ == "__main__": 337 | main() -------------------------------------------------------------------------------- /multi_design_esr1/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from collections import OrderedDict 3 | import os 4 | import sys 5 | import shutil 6 | import os.path as osp 7 | import json 8 | 9 | import dateutil.tz 10 | 11 | LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json', 'csv'] 12 | 13 | DEBUG = 10 14 | INFO = 20 15 | WARN = 30 16 | ERROR = 40 17 | 18 | DISABLED = 50 19 | 20 | 21 | class OutputFormat(object): 22 | def writekvs(self, kvs): 23 | """ 24 | Write key-value pairs 25 | """ 26 | raise NotImplementedError 27 | 28 | def writeseq(self, args): 29 | """ 30 | Write a sequence of other data (e.g. a logging message) 31 | """ 32 | pass 33 | 34 | def close(self): 35 | return 36 | 37 | 38 | class HumanOutputFormat(OutputFormat): 39 | def __init__(self, file): 40 | self.file = file 41 | 42 | def writekvs(self, kvs): 43 | # Create strings for printing 44 | key2str = OrderedDict() 45 | for (key, val) in kvs.items(): 46 | valstr = '%-8.5g' % (val,) if hasattr(val, '__float__') else val 47 | key2str[self._truncate(key)] = self._truncate(valstr) 48 | 49 | # Find max widths 50 | keywidth = max(map(len, key2str.keys())) 51 | valwidth = max(map(len, key2str.values())) 52 | 53 | # Write out the data 54 | dashes = '-' * (keywidth + valwidth + 7) 55 | lines = [dashes] 56 | for (key, val) in key2str.items(): 57 | lines.append('| %s%s | %s%s |' % ( 58 | key, 59 | ' ' * (keywidth - len(key)), 60 | val, 61 | ' ' * (valwidth - len(val)), 62 | )) 63 | lines.append(dashes) 64 | self.file.write('\n'.join(lines) + '\n') 65 | 66 | # Flush the output to the file 67 | self.file.flush() 68 | 69 | def _truncate(self, s): 70 | return s[:20] + '...' 
if len(s) > 23 else s 71 | 72 | def writeseq(self, args): 73 | for arg in args: 74 | self.file.write(arg) 75 | self.file.write('\n') 76 | self.file.flush() 77 | 78 | 79 | class JSONOutputFormat(OutputFormat): 80 | def __init__(self, file): 81 | self.file = file 82 | 83 | def writekvs(self, kvs): 84 | for k, v in kvs.items(): 85 | if hasattr(v, 'dtype'): 86 | v = v.tolist() 87 | kvs[k] = v 88 | self.file.write(json.dumps(kvs) + '\n') 89 | self.file.flush() 90 | 91 | def close(self): 92 | self.file.close() 93 | 94 | 95 | class CSVOutputFormat(OutputFormat): 96 | def __init__(self, file): 97 | self.file = file 98 | self.keys = [] 99 | self.sep = ',' 100 | 101 | def writekvs(self, kvs): 102 | # Add our current row to the history; when new keys appear, the header is rewritten and earlier rows are right-padded with empty columns 103 | extra_keys = kvs.keys() - self.keys 104 | if extra_keys: 105 | self.keys.extend(extra_keys) 106 | self.file.seek(0) 107 | lines = self.file.readlines() 108 | self.file.seek(0) 109 | for (i, k) in enumerate(self.keys): 110 | if i > 0: 111 | self.file.write(',') 112 | self.file.write(k) 113 | self.file.write('\n') 114 | for line in lines[1:]: 115 | self.file.write(line[:-1]) 116 | self.file.write(self.sep * len(extra_keys)) 117 | self.file.write('\n') 118 | for (i, k) in enumerate(self.keys): 119 | if i > 0: 120 | self.file.write(',') 121 | v = kvs.get(k) 122 | if v is not None: 123 | self.file.write(str(v)) 124 | self.file.write('\n') 125 | self.file.flush() 126 | 127 | def close(self): 128 | self.file.close() 129 | 130 | 131 | def make_output_format(format, ev_dir): 132 | os.makedirs(ev_dir, exist_ok=True) 133 | if format == 'stdout': 134 | return HumanOutputFormat(sys.stdout) 135 | elif format == 'log': 136 | log_file = open(osp.join(ev_dir, 'log.txt'), 'wt') 137 | return HumanOutputFormat(log_file) 138 | elif format == 'json': 139 | json_file = open(osp.join(ev_dir, 'progress.json'), 'wt') 140 | return JSONOutputFormat(json_file) 141 | elif format == 'csv': 142 | csv_file = open(osp.join(ev_dir, 'progress.csv'), 'w+t') 143 | return CSVOutputFormat(csv_file) 144 | else: 145 | raise ValueError('Unknown format specified: %s' % (format,)) 146 | 147 | 148 | # ================================================================ 149 | # API 150 | # ================================================================ 151 | def log_params(params): 152 | assert isinstance(params, dict) 153 | json_file = open(osp.join(Logger.CURRENT.get_dir(), 'params.json'), 'wt') 154 | output_format = JSONOutputFormat(json_file) 155 | output_format.writekvs(params) 156 | output_format.close() 157 | 158 | def logkv(key, val): 159 | """ 160 | Log a value of some diagnostic 161 | Call this once for each diagnostic quantity, each iteration 162 | """ 163 | Logger.CURRENT.logkv(key, val) 164 | 165 | 166 | def dumpkvs(): 167 | """ 168 | Write all of the diagnostics from the current iteration 169 | 170 | Flushes the accumulated key/value pairs to every configured output 171 | format, then clears them for the next iteration. 172 | """ 173 | Logger.CURRENT.dumpkvs() 174 | 175 | 176 | # for backwards compatibility 177 | record_tabular = logkv 178 | dump_tabular = dumpkvs 179 | 180 | 181 | def log(*args, level=INFO): 182 | """ 183 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
184 | """ 185 | Logger.CURRENT.log(*args, level=level) 186 | 187 | 188 | def debug(*args): 189 | log(*args, level=DEBUG) 190 | 191 | 192 | def info(*args): 193 | log(*args, level=INFO) 194 | 195 | 196 | def warn(*args): 197 | log(*args, level=WARN) 198 | 199 | 200 | def error(*args): 201 | log(*args, level=ERROR) 202 | 203 | 204 | def set_level(level): 205 | """ 206 | Set logging threshold on current logger. 207 | """ 208 | Logger.CURRENT.set_level(level) 209 | 210 | 211 | def get_level(): 212 | """ 213 | Set logging threshold on current logger. 214 | """ 215 | return Logger.CURRENT.level 216 | 217 | 218 | def get_dir(): 219 | """ 220 | Get directory that log files are being written to. 221 | will be None if there is no output directory (i.e., if you didn't call start) 222 | """ 223 | return Logger.CURRENT.get_dir() 224 | 225 | 226 | def get_expt_dir(): 227 | sys.stderr.write( 228 | "get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" % (get_dir(),)) 229 | return get_dir() 230 | 231 | 232 | # ================================================================ 233 | # Backend 234 | # ================================================================ 235 | 236 | 237 | class Logger(object): 238 | # A logger with no output files. (See right below class definition) 239 | DEFAULT = None 240 | # So that you can still log to the terminal without setting up any output files 241 | CURRENT = None # Current logger being used by the free functions above 242 | 243 | def __init__(self, dir, output_formats): 244 | self.name2val = OrderedDict() # values this iteration 245 | self.level = INFO 246 | self.dir = dir 247 | self.output_formats = output_formats 248 | 249 | # Logging API, forwarded 250 | # ---------------------------------------- 251 | def logkv(self, key, val): 252 | self.name2val[key] = val 253 | 254 | def dumpkvs(self): 255 | for fmt in self.output_formats: 256 | fmt.writekvs(self.name2val) 257 | self.name2val.clear() 258 | 259 | def log(self, *args, level=INFO): 260 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 261 | timestamp = now.strftime('[%Y-%m-%d %H:%M:%S.%f %Z] ') 262 | if self.level <= level: 263 | self._do_log((timestamp,) + args) 264 | 265 | # Configuration 266 | # ---------------------------------------- 267 | def set_level(self, level): 268 | self.level = level 269 | 270 | def get_dir(self): 271 | return self.dir 272 | 273 | def close(self): 274 | for fmt in self.output_formats: 275 | fmt.close() 276 | 277 | # Misc 278 | # ---------------------------------------- 279 | def _do_log(self, args): 280 | for fmt in self.output_formats: 281 | fmt.writeseq(args) 282 | 283 | 284 | # ================================================================ 285 | 286 | Logger.DEFAULT = Logger( 287 | output_formats=[HumanOutputFormat(sys.stdout)], dir=None) 288 | Logger.CURRENT = Logger.DEFAULT 289 | 290 | 291 | class session(object): 292 | """ 293 | Context manager that sets up the loggers for an experiment. 
294 | """ 295 | def __init__(self, dir, format_strs=LOG_OUTPUT_FORMATS): 296 | self.dir = dir 297 | self.format_strs = format_strs 298 | 299 | def __enter__(self): 300 | os.makedirs(self.dir, exist_ok=True) 301 | output_formats = [make_output_format(f, self.dir) for f in self.format_strs] 302 | Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats) 303 | 304 | def __exit__(self, *args): 305 | Logger.CURRENT.close() 306 | Logger.CURRENT = Logger.DEFAULT 307 | 308 | 309 | # ================================================================ 310 | 311 | 312 | def main(): 313 | info("hi") 314 | debug("shouldn't appear") 315 | set_level(DEBUG) 316 | debug("should appear") 317 | dir = "/tmp/testlogging" 318 | if os.path.exists(dir): 319 | shutil.rmtree(dir) 320 | with session(dir=dir): 321 | record_tabular("a", 3) 322 | record_tabular("b", 2.5) 323 | dump_tabular() 324 | record_tabular("b", -2.5) 325 | record_tabular("a", 5.5) 326 | dump_tabular() 327 | info("^^^ should see a = 5.5") 328 | 329 | record_tabular("b", -2.5) 330 | dump_tabular() 331 | 332 | record_tabular("a", "longasslongasslongasslongasslongasslongassvalue") 333 | dump_tabular() 334 | 335 | 336 | if __name__ == "__main__": 337 | main() -------------------------------------------------------------------------------- /single_design_acaa1/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from collections import OrderedDict 3 | import os 4 | import sys 5 | import shutil 6 | import os.path as osp 7 | import json 8 | 9 | import dateutil.tz 10 | 11 | LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json', 'csv'] 12 | 13 | DEBUG = 10 14 | INFO = 20 15 | WARN = 30 16 | ERROR = 40 17 | 18 | DISABLED = 50 19 | 20 | 21 | class OutputFormat(object): 22 | def writekvs(self, kvs): 23 | """ 24 | Write key-value pairs 25 | """ 26 | raise NotImplementedError 27 | 28 | def writeseq(self, args): 29 | """ 30 | Write a sequence of other data (e.g. a logging message) 31 | """ 32 | pass 33 | 34 | def close(self): 35 | return 36 | 37 | 38 | class HumanOutputFormat(OutputFormat): 39 | def __init__(self, file): 40 | self.file = file 41 | 42 | def writekvs(self, kvs): 43 | # Create strings for printing 44 | key2str = OrderedDict() 45 | for (key, val) in kvs.items(): 46 | valstr = '%-8.5g' % (val,) if hasattr(val, '__float__') else val 47 | key2str[self._truncate(key)] = self._truncate(valstr) 48 | 49 | # Find max widths 50 | keywidth = max(map(len, key2str.keys())) 51 | valwidth = max(map(len, key2str.values())) 52 | 53 | # Write out the data 54 | dashes = '-' * (keywidth + valwidth + 7) 55 | lines = [dashes] 56 | for (key, val) in key2str.items(): 57 | lines.append('| %s%s | %s%s |' % ( 58 | key, 59 | ' ' * (keywidth - len(key)), 60 | val, 61 | ' ' * (valwidth - len(val)), 62 | )) 63 | lines.append(dashes) 64 | self.file.write('\n'.join(lines) + '\n') 65 | 66 | # Flush the output to the file 67 | self.file.flush() 68 | 69 | def _truncate(self, s): 70 | return s[:20] + '...' 
if len(s) > 23 else s 71 | 72 | def writeseq(self, args): 73 | for arg in args: 74 | self.file.write(arg) 75 | self.file.write('\n') 76 | self.file.flush() 77 | 78 | 79 | class JSONOutputFormat(OutputFormat): 80 | def __init__(self, file): 81 | self.file = file 82 | 83 | def writekvs(self, kvs): 84 | for k, v in kvs.items(): 85 | if hasattr(v, 'dtype'): 86 | v = v.tolist() 87 | kvs[k] = v 88 | self.file.write(json.dumps(kvs) + '\n') 89 | self.file.flush() 90 | 91 | def close(self): 92 | self.file.close() 93 | 94 | 95 | class CSVOutputFormat(OutputFormat): 96 | def __init__(self, file): 97 | self.file = file 98 | self.keys = [] 99 | self.sep = ',' 100 | 101 | def writekvs(self, kvs): 102 | # Add our current row to the history; when new keys appear, the header is rewritten and earlier rows are right-padded with empty columns 103 | extra_keys = kvs.keys() - self.keys 104 | if extra_keys: 105 | self.keys.extend(extra_keys) 106 | self.file.seek(0) 107 | lines = self.file.readlines() 108 | self.file.seek(0) 109 | for (i, k) in enumerate(self.keys): 110 | if i > 0: 111 | self.file.write(',') 112 | self.file.write(k) 113 | self.file.write('\n') 114 | for line in lines[1:]: 115 | self.file.write(line[:-1]) 116 | self.file.write(self.sep * len(extra_keys)) 117 | self.file.write('\n') 118 | for (i, k) in enumerate(self.keys): 119 | if i > 0: 120 | self.file.write(',') 121 | v = kvs.get(k) 122 | if v is not None: 123 | self.file.write(str(v)) 124 | self.file.write('\n') 125 | self.file.flush() 126 | 127 | def close(self): 128 | self.file.close() 129 | 130 | 131 | def make_output_format(format, ev_dir): 132 | os.makedirs(ev_dir, exist_ok=True) 133 | if format == 'stdout': 134 | return HumanOutputFormat(sys.stdout) 135 | elif format == 'log': 136 | log_file = open(osp.join(ev_dir, 'log.txt'), 'wt') 137 | return HumanOutputFormat(log_file) 138 | elif format == 'json': 139 | json_file = open(osp.join(ev_dir, 'progress.json'), 'wt') 140 | return JSONOutputFormat(json_file) 141 | elif format == 'csv': 142 | csv_file = open(osp.join(ev_dir, 'progress.csv'), 'w+t') 143 | return CSVOutputFormat(csv_file) 144 | else: 145 | raise ValueError('Unknown format specified: %s' % (format,)) 146 | 147 | 148 | # ================================================================ 149 | # API 150 | # ================================================================ 151 | def log_params(params): 152 | assert isinstance(params, dict) 153 | json_file = open(osp.join(Logger.CURRENT.get_dir(), 'params.json'), 'wt') 154 | output_format = JSONOutputFormat(json_file) 155 | output_format.writekvs(params) 156 | output_format.close() 157 | 158 | def logkv(key, val): 159 | """ 160 | Log a value of some diagnostic 161 | Call this once for each diagnostic quantity, each iteration 162 | """ 163 | Logger.CURRENT.logkv(key, val) 164 | 165 | 166 | def dumpkvs(): 167 | """ 168 | Write all of the diagnostics from the current iteration 169 | 170 | Flushes the accumulated key/value pairs to every configured output 171 | format, then clears them for the next iteration. 172 | """ 173 | Logger.CURRENT.dumpkvs() 174 | 175 | 176 | # for backwards compatibility 177 | record_tabular = logkv 178 | dump_tabular = dumpkvs 179 | 180 | 181 | def log(*args, level=INFO): 182 | """ 183 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
184 | """ 185 | Logger.CURRENT.log(*args, level=level) 186 | 187 | 188 | def debug(*args): 189 | log(*args, level=DEBUG) 190 | 191 | 192 | def info(*args): 193 | log(*args, level=INFO) 194 | 195 | 196 | def warn(*args): 197 | log(*args, level=WARN) 198 | 199 | 200 | def error(*args): 201 | log(*args, level=ERROR) 202 | 203 | 204 | def set_level(level): 205 | """ 206 | Set logging threshold on current logger. 207 | """ 208 | Logger.CURRENT.set_level(level) 209 | 210 | 211 | def get_level(): 212 | """ 213 | Set logging threshold on current logger. 214 | """ 215 | return Logger.CURRENT.level 216 | 217 | 218 | def get_dir(): 219 | """ 220 | Get directory that log files are being written to. 221 | will be None if there is no output directory (i.e., if you didn't call start) 222 | """ 223 | return Logger.CURRENT.get_dir() 224 | 225 | 226 | def get_expt_dir(): 227 | sys.stderr.write( 228 | "get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" % (get_dir(),)) 229 | return get_dir() 230 | 231 | 232 | # ================================================================ 233 | # Backend 234 | # ================================================================ 235 | 236 | 237 | class Logger(object): 238 | # A logger with no output files. (See right below class definition) 239 | DEFAULT = None 240 | # So that you can still log to the terminal without setting up any output files 241 | CURRENT = None # Current logger being used by the free functions above 242 | 243 | def __init__(self, dir, output_formats): 244 | self.name2val = OrderedDict() # values this iteration 245 | self.level = INFO 246 | self.dir = dir 247 | self.output_formats = output_formats 248 | 249 | # Logging API, forwarded 250 | # ---------------------------------------- 251 | def logkv(self, key, val): 252 | self.name2val[key] = val 253 | 254 | def dumpkvs(self): 255 | for fmt in self.output_formats: 256 | fmt.writekvs(self.name2val) 257 | self.name2val.clear() 258 | 259 | def log(self, *args, level=INFO): 260 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 261 | timestamp = now.strftime('[%Y-%m-%d %H:%M:%S.%f %Z] ') 262 | if self.level <= level: 263 | self._do_log((timestamp,) + args) 264 | 265 | # Configuration 266 | # ---------------------------------------- 267 | def set_level(self, level): 268 | self.level = level 269 | 270 | def get_dir(self): 271 | return self.dir 272 | 273 | def close(self): 274 | for fmt in self.output_formats: 275 | fmt.close() 276 | 277 | # Misc 278 | # ---------------------------------------- 279 | def _do_log(self, args): 280 | for fmt in self.output_formats: 281 | fmt.writeseq(args) 282 | 283 | 284 | # ================================================================ 285 | 286 | Logger.DEFAULT = Logger( 287 | output_formats=[HumanOutputFormat(sys.stdout)], dir=None) 288 | Logger.CURRENT = Logger.DEFAULT 289 | 290 | 291 | class session(object): 292 | """ 293 | Context manager that sets up the loggers for an experiment. 
294 | """ 295 | def __init__(self, dir, format_strs=LOG_OUTPUT_FORMATS): 296 | self.dir = dir 297 | self.format_strs = format_strs 298 | 299 | def __enter__(self): 300 | os.makedirs(self.dir, exist_ok=True) 301 | output_formats = [make_output_format(f, self.dir) for f in self.format_strs] 302 | Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats) 303 | 304 | def __exit__(self, *args): 305 | Logger.CURRENT.close() 306 | Logger.CURRENT = Logger.DEFAULT 307 | 308 | 309 | # ================================================================ 310 | 311 | 312 | def main(): 313 | info("hi") 314 | debug("shouldn't appear") 315 | set_level(DEBUG) 316 | debug("should appear") 317 | dir = "/tmp/testlogging" 318 | if os.path.exists(dir): 319 | shutil.rmtree(dir) 320 | with session(dir=dir): 321 | record_tabular("a", 3) 322 | record_tabular("b", 2.5) 323 | dump_tabular() 324 | record_tabular("b", -2.5) 325 | record_tabular("a", 5.5) 326 | dump_tabular() 327 | info("^^^ should see a = 5.5") 328 | 329 | record_tabular("b", -2.5) 330 | dump_tabular() 331 | 332 | record_tabular("a", "longasslongasslongasslongasslongasslongassvalue") 333 | dump_tabular() 334 | 335 | 336 | if __name__ == "__main__": 337 | main() -------------------------------------------------------------------------------- /single_design_esr1/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from collections import OrderedDict 3 | import os 4 | import sys 5 | import shutil 6 | import os.path as osp 7 | import json 8 | 9 | import dateutil.tz 10 | 11 | LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json', 'csv'] 12 | 13 | DEBUG = 10 14 | INFO = 20 15 | WARN = 30 16 | ERROR = 40 17 | 18 | DISABLED = 50 19 | 20 | 21 | class OutputFormat(object): 22 | def writekvs(self, kvs): 23 | """ 24 | Write key-value pairs 25 | """ 26 | raise NotImplementedError 27 | 28 | def writeseq(self, args): 29 | """ 30 | Write a sequence of other data (e.g. a logging message) 31 | """ 32 | pass 33 | 34 | def close(self): 35 | return 36 | 37 | 38 | class HumanOutputFormat(OutputFormat): 39 | def __init__(self, file): 40 | self.file = file 41 | 42 | def writekvs(self, kvs): 43 | # Create strings for printing 44 | key2str = OrderedDict() 45 | for (key, val) in kvs.items(): 46 | valstr = '%-8.5g' % (val,) if hasattr(val, '__float__') else val 47 | key2str[self._truncate(key)] = self._truncate(valstr) 48 | 49 | # Find max widths 50 | keywidth = max(map(len, key2str.keys())) 51 | valwidth = max(map(len, key2str.values())) 52 | 53 | # Write out the data 54 | dashes = '-' * (keywidth + valwidth + 7) 55 | lines = [dashes] 56 | for (key, val) in key2str.items(): 57 | lines.append('| %s%s | %s%s |' % ( 58 | key, 59 | ' ' * (keywidth - len(key)), 60 | val, 61 | ' ' * (valwidth - len(val)), 62 | )) 63 | lines.append(dashes) 64 | self.file.write('\n'.join(lines) + '\n') 65 | 66 | # Flush the output to the file 67 | self.file.flush() 68 | 69 | def _truncate(self, s): 70 | return s[:20] + '...' 
if len(s) > 23 else s 71 | 72 | def writeseq(self, args): 73 | for arg in args: 74 | self.file.write(arg) 75 | self.file.write('\n') 76 | self.file.flush() 77 | 78 | 79 | class JSONOutputFormat(OutputFormat): 80 | def __init__(self, file): 81 | self.file = file 82 | 83 | def writekvs(self, kvs): 84 | for k, v in kvs.items(): 85 | if hasattr(v, 'dtype'): 86 | v = v.tolist() 87 | kvs[k] = v 88 | self.file.write(json.dumps(kvs) + '\n') 89 | self.file.flush() 90 | 91 | def close(self): 92 | self.file.close() 93 | 94 | 95 | class CSVOutputFormat(OutputFormat): 96 | def __init__(self, file): 97 | self.file = file 98 | self.keys = [] 99 | self.sep = ',' 100 | 101 | def writekvs(self, kvs): 102 | # Add our current row to the history; when new keys appear, the header is rewritten and earlier rows are right-padded with empty columns 103 | extra_keys = kvs.keys() - self.keys 104 | if extra_keys: 105 | self.keys.extend(extra_keys) 106 | self.file.seek(0) 107 | lines = self.file.readlines() 108 | self.file.seek(0) 109 | for (i, k) in enumerate(self.keys): 110 | if i > 0: 111 | self.file.write(',') 112 | self.file.write(k) 113 | self.file.write('\n') 114 | for line in lines[1:]: 115 | self.file.write(line[:-1]) 116 | self.file.write(self.sep * len(extra_keys)) 117 | self.file.write('\n') 118 | for (i, k) in enumerate(self.keys): 119 | if i > 0: 120 | self.file.write(',') 121 | v = kvs.get(k) 122 | if v is not None: 123 | self.file.write(str(v)) 124 | self.file.write('\n') 125 | self.file.flush() 126 | 127 | def close(self): 128 | self.file.close() 129 | 130 | 131 | def make_output_format(format, ev_dir): 132 | os.makedirs(ev_dir, exist_ok=True) 133 | if format == 'stdout': 134 | return HumanOutputFormat(sys.stdout) 135 | elif format == 'log': 136 | log_file = open(osp.join(ev_dir, 'log.txt'), 'wt') 137 | return HumanOutputFormat(log_file) 138 | elif format == 'json': 139 | json_file = open(osp.join(ev_dir, 'progress.json'), 'wt') 140 | return JSONOutputFormat(json_file) 141 | elif format == 'csv': 142 | csv_file = open(osp.join(ev_dir, 'progress.csv'), 'w+t') 143 | return CSVOutputFormat(csv_file) 144 | else: 145 | raise ValueError('Unknown format specified: %s' % (format,)) 146 | 147 | 148 | # ================================================================ 149 | # API 150 | # ================================================================ 151 | def log_params(params): 152 | assert isinstance(params, dict) 153 | json_file = open(osp.join(Logger.CURRENT.get_dir(), 'params.json'), 'wt') 154 | output_format = JSONOutputFormat(json_file) 155 | output_format.writekvs(params) 156 | output_format.close() 157 | 158 | def logkv(key, val): 159 | """ 160 | Log a value of some diagnostic 161 | Call this once for each diagnostic quantity, each iteration 162 | """ 163 | Logger.CURRENT.logkv(key, val) 164 | 165 | 166 | def dumpkvs(): 167 | """ 168 | Write all of the diagnostics from the current iteration 169 | 170 | Flushes the accumulated key/value pairs to every configured output 171 | format, then clears them for the next iteration. 172 | """ 173 | Logger.CURRENT.dumpkvs() 174 | 175 | 176 | # for backwards compatibility 177 | record_tabular = logkv 178 | dump_tabular = dumpkvs 179 | 180 | 181 | def log(*args, level=INFO): 182 | """ 183 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
184 | """ 185 | Logger.CURRENT.log(*args, level=level) 186 | 187 | 188 | def debug(*args): 189 | log(*args, level=DEBUG) 190 | 191 | 192 | def info(*args): 193 | log(*args, level=INFO) 194 | 195 | 196 | def warn(*args): 197 | log(*args, level=WARN) 198 | 199 | 200 | def error(*args): 201 | log(*args, level=ERROR) 202 | 203 | 204 | def set_level(level): 205 | """ 206 | Set logging threshold on current logger. 207 | """ 208 | Logger.CURRENT.set_level(level) 209 | 210 | 211 | def get_level(): 212 | """ 213 | Set logging threshold on current logger. 214 | """ 215 | return Logger.CURRENT.level 216 | 217 | 218 | def get_dir(): 219 | """ 220 | Get directory that log files are being written to. 221 | will be None if there is no output directory (i.e., if you didn't call start) 222 | """ 223 | return Logger.CURRENT.get_dir() 224 | 225 | 226 | def get_expt_dir(): 227 | sys.stderr.write( 228 | "get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" % (get_dir(),)) 229 | return get_dir() 230 | 231 | 232 | # ================================================================ 233 | # Backend 234 | # ================================================================ 235 | 236 | 237 | class Logger(object): 238 | # A logger with no output files. (See right below class definition) 239 | DEFAULT = None 240 | # So that you can still log to the terminal without setting up any output files 241 | CURRENT = None # Current logger being used by the free functions above 242 | 243 | def __init__(self, dir, output_formats): 244 | self.name2val = OrderedDict() # values this iteration 245 | self.level = INFO 246 | self.dir = dir 247 | self.output_formats = output_formats 248 | 249 | # Logging API, forwarded 250 | # ---------------------------------------- 251 | def logkv(self, key, val): 252 | self.name2val[key] = val 253 | 254 | def dumpkvs(self): 255 | for fmt in self.output_formats: 256 | fmt.writekvs(self.name2val) 257 | self.name2val.clear() 258 | 259 | def log(self, *args, level=INFO): 260 | now = datetime.datetime.now(dateutil.tz.tzlocal()) 261 | timestamp = now.strftime('[%Y-%m-%d %H:%M:%S.%f %Z] ') 262 | if self.level <= level: 263 | self._do_log((timestamp,) + args) 264 | 265 | # Configuration 266 | # ---------------------------------------- 267 | def set_level(self, level): 268 | self.level = level 269 | 270 | def get_dir(self): 271 | return self.dir 272 | 273 | def close(self): 274 | for fmt in self.output_formats: 275 | fmt.close() 276 | 277 | # Misc 278 | # ---------------------------------------- 279 | def _do_log(self, args): 280 | for fmt in self.output_formats: 281 | fmt.writeseq(args) 282 | 283 | 284 | # ================================================================ 285 | 286 | Logger.DEFAULT = Logger( 287 | output_formats=[HumanOutputFormat(sys.stdout)], dir=None) 288 | Logger.CURRENT = Logger.DEFAULT 289 | 290 | 291 | class session(object): 292 | """ 293 | Context manager that sets up the loggers for an experiment. 
294 | """ 295 | def __init__(self, dir, format_strs=LOG_OUTPUT_FORMATS): 296 | self.dir = dir 297 | self.format_strs = format_strs 298 | 299 | def __enter__(self): 300 | os.makedirs(self.dir, exist_ok=True) 301 | output_formats = [make_output_format(f, self.dir) for f in self.format_strs] 302 | Logger.CURRENT = Logger(dir=self.dir, output_formats=output_formats) 303 | 304 | def __exit__(self, *args): 305 | Logger.CURRENT.close() 306 | Logger.CURRENT = Logger.DEFAULT 307 | 308 | 309 | # ================================================================ 310 | 311 | 312 | def main(): 313 | info("hi") 314 | debug("shouldn't appear") 315 | set_level(DEBUG) 316 | debug("should appear") 317 | dir = "/tmp/testlogging" 318 | if os.path.exists(dir): 319 | shutil.rmtree(dir) 320 | with session(dir=dir): 321 | record_tabular("a", 3) 322 | record_tabular("b", 2.5) 323 | dump_tabular() 324 | record_tabular("b", -2.5) 325 | record_tabular("a", 5.5) 326 | dump_tabular() 327 | info("^^^ should see a = 5.5") 328 | 329 | record_tabular("b", -2.5) 330 | dump_tabular() 331 | 332 | record_tabular("a", "longasslongasslongasslongasslongasslongassvalue") 333 | dump_tabular() 334 | 335 | 336 | if __name__ == "__main__": 337 | main() --------------------------------------------------------------------------------