├── .gitignore ├── figure ├── model.png ├── multi.png └── single.png ├── multi_design_esr1 ├── figure.pdf ├── ZINC │ ├── separate_data.py │ ├── cal_property.py │ └── char.py ├── utils.py ├── args.py ├── plot.py ├── stats.py ├── dataset.py ├── sascorer.py └── logger.py ├── single_design_esr1 ├── figure.pdf ├── utils.py ├── ZINC │ ├── separate_data.py │ ├── cal_property.py │ └── char.py ├── dataset.py ├── args.py ├── plot.py ├── sascorer.py └── logger.py ├── single_design_acaa1 ├── a.py ├── ZINC │ ├── separate_data.py │ ├── cal_property.py │ └── char.py ├── utils.py ├── dataset.py ├── args.py ├── plot.py ├── sascorer.py └── logger.py ├── multi_design_acaa1 ├── ZINC │ ├── separate_data.py │ ├── cal_property.py │ └── char.py ├── utils.py ├── args.py ├── stats.py ├── dataset.py ├── sascorer.py └── logger.py ├── single_design_plogp ├── ZINC │ ├── separate_data.py │ ├── cal_property.py │ └── char.py ├── utils.py ├── dataset.py ├── args.py ├── plot.py └── sascorer.py ├── single_design_qed ├── ZINC │ ├── separate_data.py │ ├── cal_property.py │ └── char.py ├── utils.py ├── dataset.py ├── args.py ├── stats.py └── sascorer.py └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pt -------------------------------------------------------------------------------- /figure/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deqiankong/SGDS/HEAD/figure/model.png -------------------------------------------------------------------------------- /figure/multi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deqiankong/SGDS/HEAD/figure/multi.png -------------------------------------------------------------------------------- /figure/single.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deqiankong/SGDS/HEAD/figure/single.png -------------------------------------------------------------------------------- /multi_design_esr1/figure.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deqiankong/SGDS/HEAD/multi_design_esr1/figure.pdf -------------------------------------------------------------------------------- /single_design_esr1/figure.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deqiankong/SGDS/HEAD/single_design_esr1/figure.pdf -------------------------------------------------------------------------------- /single_design_acaa1/a.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | it = 22 4 | a = np.load(str(it) + '.npy') 5 | 6 | a = np.unique(a) 7 | kd = np.exp(a * (-1) / (0.00198720425864083 * 298.15)).flatten() 8 | 9 | ind = np.argsort(kd) 10 | ind = ind[:200] 11 | 12 | b = kd[ind] 13 | print(np.mean(b), np.std(b)) 14 | print(b) 15 | -------------------------------------------------------------------------------- /single_design_esr1/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import datetime 5 | import torch 6 | import sys 7 | 8 | 9 | def get_exp_id(file): 10 | return os.path.splitext(os.path.basename(file))[0] 11 | 12 | 13 | def get_output_dir(exp_id, fs_prefix='./'): 14 | t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 15 | output_dir = os.path.join(fs_prefix + 'output/' + exp_id, t) 16 | os.makedirs(output_dir, exist_ok=True) 17 | return output_dir 18 | 19 | 20 | def setup_logging(name, output_dir, console=True): 21 | log_format = logging.Formatter("%(asctime)s : %(message)s") 22 | logger = logging.getLogger(name) 23 | logger.handlers = [] 24 | output_file = os.path.join(output_dir, 
'output.log') 25 | file_handler = logging.FileHandler(output_file) 26 | file_handler.setFormatter(log_format) 27 | logger.addHandler(file_handler) 28 | if console: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setFormatter(log_format) 31 | logger.addHandler(console_handler) 32 | logger.setLevel(logging.INFO) 33 | return logger 34 | 35 | 36 | def copy_source(file, output_dir): 37 | shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file))) 38 | 39 | 40 | def copy_all_files(file, output_dir): 41 | dir_src = os.path.dirname(file) 42 | for filename in os.listdir(os.getcwd()): 43 | if filename.endswith('.py'): 44 | shutil.copy(os.path.join(dir_src, filename), output_dir) 45 | 46 | 47 | def set_gpu(gpu): 48 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 49 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) 50 | 51 | if torch.cuda.is_available(): 52 | torch.cuda.set_device(0) 53 | torch.backends.cudnn.benchmark = True 54 | 55 | -------------------------------------------------------------------------------- /multi_design_acaa1/ZINC/separate_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | 9 | file_input="data_5.txt" 10 | fp_input=open(file_input) 11 | lines=fp_input.readlines() 12 | fp_input.close() 13 | 14 | logP_list=[] 15 | SAS_list=[] 16 | QED_list=[] 17 | MW_list=[] 18 | TPSA_list=[] 19 | 20 | for i in range(len(lines)): 21 | line=lines[i] 22 | if line[0]=='#': 23 | continue 24 | arr=line.split() 25 | logP=float(arr[1]) 26 | SAS=float(arr[2]) 27 | QED=float(arr[3]) 28 | MW=float(arr[4]) 29 | TPSA=float(arr[5]) 30 | logP_list+=[logP] 31 | SAS_list+=[SAS] 32 | QED_list+=[QED] 33 | MW_list+=[MW] 34 | TPSA_list+=[TPSA] 35 | 36 | logP_array=np.array(logP_list) 37 | SAS_array=np.array(SAS_list) 38 | QED_array=np.array(QED_list) 39 | MW_array=np.array(MW_list) 40 | 
TPSA_array=np.array(TPSA_list) 41 | 42 | print(logP_array.min(),logP_array.max()) 43 | print(SAS_array.min(),SAS_array.max()) 44 | print(QED_array.min(),QED_array.max()) 45 | print(MW_array.min(),MW_array.max()) 46 | print(TPSA_array.min(),TPSA_array.max()) 47 | 48 | Ndata=len(lines)-1 49 | index=np.arange(0,Ndata) 50 | np.random.shuffle(index) 51 | Ntest=10000 52 | Ntrain=Ndata-Ntest 53 | train_index=index[0:Ntrain] 54 | test_index=index[Ntrain:Ndata] 55 | train_index.sort() 56 | test_index.sort() 57 | 58 | file_output="train_5.txt" 59 | fp_out=open(file_output,"w") 60 | line_out="#smi logP SAS QED MW TPSA\n" 61 | fp_out.write(line_out) 62 | for i in train_index: 63 | line=lines[i+1] 64 | fp_out.write(line) 65 | fp_out.close() 66 | 67 | 68 | file_output="test_5.txt" 69 | fp_out=open(file_output,"w") 70 | line_out="#smi logP SAS QED MW TPSA\n" 71 | fp_out.write(line_out) 72 | for i in test_index: 73 | line=lines[i+1] 74 | fp_out.write(line) 75 | fp_out.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /multi_design_esr1/ZINC/separate_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | 9 | file_input="data_5.txt" 10 | fp_input=open(file_input) 11 | lines=fp_input.readlines() 12 | fp_input.close() 13 | 14 | logP_list=[] 15 | SAS_list=[] 16 | QED_list=[] 17 | MW_list=[] 18 | TPSA_list=[] 19 | 20 | for i in range(len(lines)): 21 | line=lines[i] 22 | if line[0]=='#': 23 | continue 24 | arr=line.split() 25 | logP=float(arr[1]) 26 | SAS=float(arr[2]) 27 | QED=float(arr[3]) 28 | MW=float(arr[4]) 29 | TPSA=float(arr[5]) 30 | logP_list+=[logP] 31 | SAS_list+=[SAS] 32 | QED_list+=[QED] 33 | MW_list+=[MW] 34 | TPSA_list+=[TPSA] 35 | 36 | logP_array=np.array(logP_list) 37 | SAS_array=np.array(SAS_list) 38 | QED_array=np.array(QED_list) 39 | MW_array=np.array(MW_list) 40 | 
TPSA_array=np.array(TPSA_list) 41 | 42 | print(logP_array.min(),logP_array.max()) 43 | print(SAS_array.min(),SAS_array.max()) 44 | print(QED_array.min(),QED_array.max()) 45 | print(MW_array.min(),MW_array.max()) 46 | print(TPSA_array.min(),TPSA_array.max()) 47 | 48 | Ndata=len(lines)-1 49 | index=np.arange(0,Ndata) 50 | np.random.shuffle(index) 51 | Ntest=10000 52 | Ntrain=Ndata-Ntest 53 | train_index=index[0:Ntrain] 54 | test_index=index[Ntrain:Ndata] 55 | train_index.sort() 56 | test_index.sort() 57 | 58 | file_output="train_5.txt" 59 | fp_out=open(file_output,"w") 60 | line_out="#smi logP SAS QED MW TPSA\n" 61 | fp_out.write(line_out) 62 | for i in train_index: 63 | line=lines[i+1] 64 | fp_out.write(line) 65 | fp_out.close() 66 | 67 | 68 | file_output="test_5.txt" 69 | fp_out=open(file_output,"w") 70 | line_out="#smi logP SAS QED MW TPSA\n" 71 | fp_out.write(line_out) 72 | for i in test_index: 73 | line=lines[i+1] 74 | fp_out.write(line) 75 | fp_out.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /single_design_acaa1/ZINC/separate_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | 9 | file_input="data_5.txt" 10 | fp_input=open(file_input) 11 | lines=fp_input.readlines() 12 | fp_input.close() 13 | 14 | logP_list=[] 15 | SAS_list=[] 16 | QED_list=[] 17 | MW_list=[] 18 | TPSA_list=[] 19 | 20 | for i in range(len(lines)): 21 | line=lines[i] 22 | if line[0]=='#': 23 | continue 24 | arr=line.split() 25 | logP=float(arr[1]) 26 | SAS=float(arr[2]) 27 | QED=float(arr[3]) 28 | MW=float(arr[4]) 29 | TPSA=float(arr[5]) 30 | logP_list+=[logP] 31 | SAS_list+=[SAS] 32 | QED_list+=[QED] 33 | MW_list+=[MW] 34 | TPSA_list+=[TPSA] 35 | 36 | logP_array=np.array(logP_list) 37 | SAS_array=np.array(SAS_list) 38 | QED_array=np.array(QED_list) 39 | MW_array=np.array(MW_list) 40 | 
TPSA_array=np.array(TPSA_list) 41 | 42 | print(logP_array.min(),logP_array.max()) 43 | print(SAS_array.min(),SAS_array.max()) 44 | print(QED_array.min(),QED_array.max()) 45 | print(MW_array.min(),MW_array.max()) 46 | print(TPSA_array.min(),TPSA_array.max()) 47 | 48 | Ndata=len(lines)-1 49 | index=np.arange(0,Ndata) 50 | np.random.shuffle(index) 51 | Ntest=10000 52 | Ntrain=Ndata-Ntest 53 | train_index=index[0:Ntrain] 54 | test_index=index[Ntrain:Ndata] 55 | train_index.sort() 56 | test_index.sort() 57 | 58 | file_output="train_5.txt" 59 | fp_out=open(file_output,"w") 60 | line_out="#smi logP SAS QED MW TPSA\n" 61 | fp_out.write(line_out) 62 | for i in train_index: 63 | line=lines[i+1] 64 | fp_out.write(line) 65 | fp_out.close() 66 | 67 | 68 | file_output="test_5.txt" 69 | fp_out=open(file_output,"w") 70 | line_out="#smi logP SAS QED MW TPSA\n" 71 | fp_out.write(line_out) 72 | for i in test_index: 73 | line=lines[i+1] 74 | fp_out.write(line) 75 | fp_out.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /single_design_esr1/ZINC/separate_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | 9 | file_input="data_5.txt" 10 | fp_input=open(file_input) 11 | lines=fp_input.readlines() 12 | fp_input.close() 13 | 14 | logP_list=[] 15 | SAS_list=[] 16 | QED_list=[] 17 | MW_list=[] 18 | TPSA_list=[] 19 | 20 | for i in range(len(lines)): 21 | line=lines[i] 22 | if line[0]=='#': 23 | continue 24 | arr=line.split() 25 | logP=float(arr[1]) 26 | SAS=float(arr[2]) 27 | QED=float(arr[3]) 28 | MW=float(arr[4]) 29 | TPSA=float(arr[5]) 30 | logP_list+=[logP] 31 | SAS_list+=[SAS] 32 | QED_list+=[QED] 33 | MW_list+=[MW] 34 | TPSA_list+=[TPSA] 35 | 36 | logP_array=np.array(logP_list) 37 | SAS_array=np.array(SAS_list) 38 | QED_array=np.array(QED_list) 39 | MW_array=np.array(MW_list) 40 | 
TPSA_array=np.array(TPSA_list) 41 | 42 | print(logP_array.min(),logP_array.max()) 43 | print(SAS_array.min(),SAS_array.max()) 44 | print(QED_array.min(),QED_array.max()) 45 | print(MW_array.min(),MW_array.max()) 46 | print(TPSA_array.min(),TPSA_array.max()) 47 | 48 | Ndata=len(lines)-1 49 | index=np.arange(0,Ndata) 50 | np.random.shuffle(index) 51 | Ntest=10000 52 | Ntrain=Ndata-Ntest 53 | train_index=index[0:Ntrain] 54 | test_index=index[Ntrain:Ndata] 55 | train_index.sort() 56 | test_index.sort() 57 | 58 | file_output="train_5.txt" 59 | fp_out=open(file_output,"w") 60 | line_out="#smi logP SAS QED MW TPSA\n" 61 | fp_out.write(line_out) 62 | for i in train_index: 63 | line=lines[i+1] 64 | fp_out.write(line) 65 | fp_out.close() 66 | 67 | 68 | file_output="test_5.txt" 69 | fp_out=open(file_output,"w") 70 | line_out="#smi logP SAS QED MW TPSA\n" 71 | fp_out.write(line_out) 72 | for i in test_index: 73 | line=lines[i+1] 74 | fp_out.write(line) 75 | fp_out.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /single_design_plogp/ZINC/separate_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | 9 | file_input="data_5.txt" 10 | fp_input=open(file_input) 11 | lines=fp_input.readlines() 12 | fp_input.close() 13 | 14 | logP_list=[] 15 | SAS_list=[] 16 | QED_list=[] 17 | MW_list=[] 18 | TPSA_list=[] 19 | 20 | for i in range(len(lines)): 21 | line=lines[i] 22 | if line[0]=='#': 23 | continue 24 | arr=line.split() 25 | logP=float(arr[1]) 26 | SAS=float(arr[2]) 27 | QED=float(arr[3]) 28 | MW=float(arr[4]) 29 | TPSA=float(arr[5]) 30 | logP_list+=[logP] 31 | SAS_list+=[SAS] 32 | QED_list+=[QED] 33 | MW_list+=[MW] 34 | TPSA_list+=[TPSA] 35 | 36 | logP_array=np.array(logP_list) 37 | SAS_array=np.array(SAS_list) 38 | QED_array=np.array(QED_list) 39 | MW_array=np.array(MW_list) 40 | 
TPSA_array=np.array(TPSA_list) 41 | 42 | print(logP_array.min(),logP_array.max()) 43 | print(SAS_array.min(),SAS_array.max()) 44 | print(QED_array.min(),QED_array.max()) 45 | print(MW_array.min(),MW_array.max()) 46 | print(TPSA_array.min(),TPSA_array.max()) 47 | 48 | Ndata=len(lines)-1 49 | index=np.arange(0,Ndata) 50 | np.random.shuffle(index) 51 | Ntest=10000 52 | Ntrain=Ndata-Ntest 53 | train_index=index[0:Ntrain] 54 | test_index=index[Ntrain:Ndata] 55 | train_index.sort() 56 | test_index.sort() 57 | 58 | file_output="train_5.txt" 59 | fp_out=open(file_output,"w") 60 | line_out="#smi logP SAS QED MW TPSA\n" 61 | fp_out.write(line_out) 62 | for i in train_index: 63 | line=lines[i+1] 64 | fp_out.write(line) 65 | fp_out.close() 66 | 67 | 68 | file_output="test_5.txt" 69 | fp_out=open(file_output,"w") 70 | line_out="#smi logP SAS QED MW TPSA\n" 71 | fp_out.write(line_out) 72 | for i in test_index: 73 | line=lines[i+1] 74 | fp_out.write(line) 75 | fp_out.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /single_design_qed/ZINC/separate_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | 4 | USAGE=""" 5 | 6 | """ 7 | import numpy as np 8 | 9 | file_input="data_5.txt" 10 | fp_input=open(file_input) 11 | lines=fp_input.readlines() 12 | fp_input.close() 13 | 14 | logP_list=[] 15 | SAS_list=[] 16 | QED_list=[] 17 | MW_list=[] 18 | TPSA_list=[] 19 | 20 | for i in range(len(lines)): 21 | line=lines[i] 22 | if line[0]=='#': 23 | continue 24 | arr=line.split() 25 | logP=float(arr[1]) 26 | SAS=float(arr[2]) 27 | QED=float(arr[3]) 28 | MW=float(arr[4]) 29 | TPSA=float(arr[5]) 30 | logP_list+=[logP] 31 | SAS_list+=[SAS] 32 | QED_list+=[QED] 33 | MW_list+=[MW] 34 | TPSA_list+=[TPSA] 35 | 36 | logP_array=np.array(logP_list) 37 | SAS_array=np.array(SAS_list) 38 | QED_array=np.array(QED_list) 39 | MW_array=np.array(MW_list) 40 | 
TPSA_array=np.array(TPSA_list) 41 | 42 | print(logP_array.min(),logP_array.max()) 43 | print(SAS_array.min(),SAS_array.max()) 44 | print(QED_array.min(),QED_array.max()) 45 | print(MW_array.min(),MW_array.max()) 46 | print(TPSA_array.min(),TPSA_array.max()) 47 | 48 | Ndata=len(lines)-1 49 | index=np.arange(0,Ndata) 50 | np.random.shuffle(index) 51 | Ntest=10000 52 | Ntrain=Ndata-Ntest 53 | train_index=index[0:Ntrain] 54 | test_index=index[Ntrain:Ndata] 55 | train_index.sort() 56 | test_index.sort() 57 | 58 | file_output="train_5.txt" 59 | fp_out=open(file_output,"w") 60 | line_out="#smi logP SAS QED MW TPSA\n" 61 | fp_out.write(line_out) 62 | for i in train_index: 63 | line=lines[i+1] 64 | fp_out.write(line) 65 | fp_out.close() 66 | 67 | 68 | file_output="test_5.txt" 69 | fp_out=open(file_output,"w") 70 | line_out="#smi logP SAS QED MW TPSA\n" 71 | fp_out.write(line_out) 72 | for i in test_index: 73 | line=lines[i+1] 74 | fp_out.write(line) 75 | fp_out.close() 76 | 77 | 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sampling with Gradual Distribution Shifting (SGDS) 2 | This is the repository for our paper "Molecule Design by Latent Space Energy-based Modeling and Gradual Distribution Shifting" in UAI 2023. [PDF](https://proceedings.mlr.press/v216/kong23a/kong23a.pdf) 3 | 4 |  5 | 6 | In this paper, we studied the following property optimization tasks: 7 | * single-objective p-logP maximization 8 | * single-objective QED maximization 9 | * single-objective ESR1 binding affinity maximization 10 | * single-objective ACAA1 binding affinity maximization 11 | * multi-objective (ESR1, QED, SA) optmization 12 | * multi-objective (ACAA1, QED, SA) optmization 13 | 14 |
15 |
16 |
17 |
39 | @inproceedings{kong2023molecule,
40 | title={Molecule Design by Latent Space Energy-Based Modeling and Gradual Distribution Shifting},
41 | author={Kong, Deqian and Pang, Bo and Han, Tian and Wu, Ying Nian},
42 | booktitle={Uncertainty in Artificial Intelligence},
43 | pages={1109--1120},
44 | year={2023},
45 | organization={PMLR}
46 | }
47 |
48 |
--------------------------------------------------------------------------------
/multi_design_esr1/ZINC/cal_property.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""Compute molecular properties for the ZINC-250k SMILES set.

Reads 250k_rndm_zinc_drugs_clean_3.csv (a quoted-SMILES line followed by a
",logP,qed,SAS" property line per molecule), recomputes every property with
RDKit, and writes one "#smi logP SAS QED MW TPSA" record per molecule to
data_5.txt.  Finally prints the min/max of each property (useful when
choosing normalisation ranges).
"""
import sys

USAGE="""

"""
import numpy as np
from rdkit import Chem

from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.Crippen import MolMR

from rdkit.Chem.rdMolDescriptors import CalcNumHBD
from rdkit.Chem.rdMolDescriptors import CalcNumHBA
from rdkit.Chem.rdMolDescriptors import CalcTPSA

from rdkit.Chem.QED import qed
#sys.path.insert(0,'/home/shade/SA_Score')
import sascorer


file_input = "250k_rndm_zinc_drugs_clean_3.csv"
# Use context managers so file handles are released even on error.
with open(file_input) as fp_input:
    lines = fp_input.readlines()

file_output = "data_5.txt"

logP_list = []
SAS_list = []
QED_list = []
MW_list = []
TPSA_list = []

with open(file_output, "w") as fp_out:
    fp_out.write("#smi logP SAS QED MW TPSA\n")
    smi = None  # SMILES remembered from the previous '"...' line, if any
    for line in lines:
        if not line or line[0] != '"':
            continue
        if line[1] != ",":
            # A '"CC(=O)...' line carries the SMILES; its properties follow
            # on the next '",...' line.
            smi = line[1:].strip()
            continue
        if smi is None:
            # Property line with no preceding SMILES line -- malformed input.
            continue
        m = Chem.MolFromSmiles(smi)
        if m is None:
            # Unparsable SMILES: skip the record instead of crashing on
            # MolToSmiles(None).
            continue
        smi2 = Chem.MolToSmiles(m)  # canonical SMILES

        # Recompute every property with RDKit rather than trusting the CSV.
        logP = MolLogP(m)
        SAS = sascorer.calculateScore(m)
        QED = qed(m)
        MW = ExactMolWt(m)
        TPSA = CalcTPSA(m)
        fp_out.write("%s %6.3f %6.3f %6.3f %6.3f %6.3f\n"
                     % (smi2, logP, SAS, QED, MW, TPSA))
        logP_list += [logP]
        SAS_list += [SAS]
        QED_list += [QED]
        MW_list += [MW]
        TPSA_list += [TPSA]

logP_array = np.array(logP_list)
SAS_array = np.array(SAS_list)
QED_array = np.array(QED_list)
MW_array = np.array(MW_list)
TPSA_array = np.array(TPSA_list)

# Observed range of each property.
print(logP_array.min(), logP_array.max())
print(SAS_array.min(), SAS_array.max())
print(QED_array.min(), QED_array.max())
print(MW_array.min(), MW_array.max())
print(TPSA_array.min(), TPSA_array.max())
--------------------------------------------------------------------------------
/single_design_qed/ZINC/cal_property.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""Compute molecular properties for the ZINC-250k SMILES set.

Reads 250k_rndm_zinc_drugs_clean_3.csv (a quoted-SMILES line followed by a
",logP,qed,SAS" property line per molecule), recomputes every property with
RDKit, and writes one "#smi logP SAS QED MW TPSA" record per molecule to
data_5.txt.  Finally prints the min/max of each property (useful when
choosing normalisation ranges).
"""
import sys

USAGE="""

"""
import numpy as np
from rdkit import Chem

from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.Crippen import MolMR

from rdkit.Chem.rdMolDescriptors import CalcNumHBD
from rdkit.Chem.rdMolDescriptors import CalcNumHBA
from rdkit.Chem.rdMolDescriptors import CalcTPSA

from rdkit.Chem.QED import qed
#sys.path.insert(0,'/home/shade/SA_Score')
import sascorer


file_input = "250k_rndm_zinc_drugs_clean_3.csv"
# Use context managers so file handles are released even on error.
with open(file_input) as fp_input:
    lines = fp_input.readlines()

file_output = "data_5.txt"

logP_list = []
SAS_list = []
QED_list = []
MW_list = []
TPSA_list = []

with open(file_output, "w") as fp_out:
    fp_out.write("#smi logP SAS QED MW TPSA\n")
    smi = None  # SMILES remembered from the previous '"...' line, if any
    for line in lines:
        if not line or line[0] != '"':
            continue
        if line[1] != ",":
            # A '"CC(=O)...' line carries the SMILES; its properties follow
            # on the next '",...' line.
            smi = line[1:].strip()
            continue
        if smi is None:
            # Property line with no preceding SMILES line -- malformed input.
            continue
        m = Chem.MolFromSmiles(smi)
        if m is None:
            # Unparsable SMILES: skip the record instead of crashing on
            # MolToSmiles(None).
            continue
        smi2 = Chem.MolToSmiles(m)  # canonical SMILES

        # Recompute every property with RDKit rather than trusting the CSV.
        logP = MolLogP(m)
        SAS = sascorer.calculateScore(m)
        QED = qed(m)
        MW = ExactMolWt(m)
        TPSA = CalcTPSA(m)
        fp_out.write("%s %6.3f %6.3f %6.3f %6.3f %6.3f\n"
                     % (smi2, logP, SAS, QED, MW, TPSA))
        logP_list += [logP]
        SAS_list += [SAS]
        QED_list += [QED]
        MW_list += [MW]
        TPSA_list += [TPSA]

logP_array = np.array(logP_list)
SAS_array = np.array(SAS_list)
QED_array = np.array(QED_list)
MW_array = np.array(MW_list)
TPSA_array = np.array(TPSA_list)

# Observed range of each property.
print(logP_array.min(), logP_array.max())
print(SAS_array.min(), SAS_array.max())
print(QED_array.min(), QED_array.max())
print(MW_array.min(), MW_array.max())
print(TPSA_array.min(), TPSA_array.max())
--------------------------------------------------------------------------------
/multi_design_acaa1/ZINC/cal_property.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""Compute molecular properties for the ZINC-250k SMILES set.

Reads 250k_rndm_zinc_drugs_clean_3.csv (a quoted-SMILES line followed by a
",logP,qed,SAS" property line per molecule), recomputes every property with
RDKit, and writes one "#smi logP SAS QED MW TPSA" record per molecule to
data_5.txt.  Finally prints the min/max of each property (useful when
choosing normalisation ranges).
"""
import sys

USAGE="""

"""
import numpy as np
from rdkit import Chem

from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.Crippen import MolMR

from rdkit.Chem.rdMolDescriptors import CalcNumHBD
from rdkit.Chem.rdMolDescriptors import CalcNumHBA
from rdkit.Chem.rdMolDescriptors import CalcTPSA

from rdkit.Chem.QED import qed
#sys.path.insert(0,'/home/shade/SA_Score')
import sascorer


file_input = "250k_rndm_zinc_drugs_clean_3.csv"
# Use context managers so file handles are released even on error.
with open(file_input) as fp_input:
    lines = fp_input.readlines()

file_output = "data_5.txt"

logP_list = []
SAS_list = []
QED_list = []
MW_list = []
TPSA_list = []

with open(file_output, "w") as fp_out:
    fp_out.write("#smi logP SAS QED MW TPSA\n")
    smi = None  # SMILES remembered from the previous '"...' line, if any
    for line in lines:
        if not line or line[0] != '"':
            continue
        if line[1] != ",":
            # A '"CC(=O)...' line carries the SMILES; its properties follow
            # on the next '",...' line.
            smi = line[1:].strip()
            continue
        if smi is None:
            # Property line with no preceding SMILES line -- malformed input.
            continue
        m = Chem.MolFromSmiles(smi)
        if m is None:
            # Unparsable SMILES: skip the record instead of crashing on
            # MolToSmiles(None).
            continue
        smi2 = Chem.MolToSmiles(m)  # canonical SMILES

        # Recompute every property with RDKit rather than trusting the CSV.
        logP = MolLogP(m)
        SAS = sascorer.calculateScore(m)
        QED = qed(m)
        MW = ExactMolWt(m)
        TPSA = CalcTPSA(m)
        fp_out.write("%s %6.3f %6.3f %6.3f %6.3f %6.3f\n"
                     % (smi2, logP, SAS, QED, MW, TPSA))
        logP_list += [logP]
        SAS_list += [SAS]
        QED_list += [QED]
        MW_list += [MW]
        TPSA_list += [TPSA]

logP_array = np.array(logP_list)
SAS_array = np.array(SAS_list)
QED_array = np.array(QED_list)
MW_array = np.array(MW_list)
TPSA_array = np.array(TPSA_list)

# Observed range of each property.
print(logP_array.min(), logP_array.max())
print(SAS_array.min(), SAS_array.max())
print(QED_array.min(), QED_array.max())
print(MW_array.min(), MW_array.max())
print(TPSA_array.min(), TPSA_array.max())
--------------------------------------------------------------------------------
/single_design_acaa1/ZINC/cal_property.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""Compute molecular properties for the ZINC-250k SMILES set.

Reads 250k_rndm_zinc_drugs_clean_3.csv (a quoted-SMILES line followed by a
",logP,qed,SAS" property line per molecule), recomputes every property with
RDKit, and writes one "#smi logP SAS QED MW TPSA" record per molecule to
data_5.txt.  Finally prints the min/max of each property (useful when
choosing normalisation ranges).
"""
import sys

USAGE="""

"""
import numpy as np
from rdkit import Chem

from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.Crippen import MolMR

from rdkit.Chem.rdMolDescriptors import CalcNumHBD
from rdkit.Chem.rdMolDescriptors import CalcNumHBA
from rdkit.Chem.rdMolDescriptors import CalcTPSA

from rdkit.Chem.QED import qed
#sys.path.insert(0,'/home/shade/SA_Score')
import sascorer


file_input = "250k_rndm_zinc_drugs_clean_3.csv"
# Use context managers so file handles are released even on error.
with open(file_input) as fp_input:
    lines = fp_input.readlines()

file_output = "data_5.txt"

logP_list = []
SAS_list = []
QED_list = []
MW_list = []
TPSA_list = []

with open(file_output, "w") as fp_out:
    fp_out.write("#smi logP SAS QED MW TPSA\n")
    smi = None  # SMILES remembered from the previous '"...' line, if any
    for line in lines:
        if not line or line[0] != '"':
            continue
        if line[1] != ",":
            # A '"CC(=O)...' line carries the SMILES; its properties follow
            # on the next '",...' line.
            smi = line[1:].strip()
            continue
        if smi is None:
            # Property line with no preceding SMILES line -- malformed input.
            continue
        m = Chem.MolFromSmiles(smi)
        if m is None:
            # Unparsable SMILES: skip the record instead of crashing on
            # MolToSmiles(None).
            continue
        smi2 = Chem.MolToSmiles(m)  # canonical SMILES

        # Recompute every property with RDKit rather than trusting the CSV.
        logP = MolLogP(m)
        SAS = sascorer.calculateScore(m)
        QED = qed(m)
        MW = ExactMolWt(m)
        TPSA = CalcTPSA(m)
        fp_out.write("%s %6.3f %6.3f %6.3f %6.3f %6.3f\n"
                     % (smi2, logP, SAS, QED, MW, TPSA))
        logP_list += [logP]
        SAS_list += [SAS]
        QED_list += [QED]
        MW_list += [MW]
        TPSA_list += [TPSA]

logP_array = np.array(logP_list)
SAS_array = np.array(SAS_list)
QED_array = np.array(QED_list)
MW_array = np.array(MW_list)
TPSA_array = np.array(TPSA_list)

# Observed range of each property.
print(logP_array.min(), logP_array.max())
print(SAS_array.min(), SAS_array.max())
print(QED_array.min(), QED_array.max())
print(MW_array.min(), MW_array.max())
print(TPSA_array.min(), TPSA_array.max())
--------------------------------------------------------------------------------
/single_design_esr1/ZINC/cal_property.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""Compute molecular properties for the ZINC-250k SMILES set.

Reads 250k_rndm_zinc_drugs_clean_3.csv (a quoted-SMILES line followed by a
",logP,qed,SAS" property line per molecule), recomputes every property with
RDKit, and writes one "#smi logP SAS QED MW TPSA" record per molecule to
data_5.txt.  Finally prints the min/max of each property (useful when
choosing normalisation ranges).
"""
import sys

USAGE="""

"""
import numpy as np
from rdkit import Chem

from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.Crippen import MolMR

from rdkit.Chem.rdMolDescriptors import CalcNumHBD
from rdkit.Chem.rdMolDescriptors import CalcNumHBA
from rdkit.Chem.rdMolDescriptors import CalcTPSA

from rdkit.Chem.QED import qed
#sys.path.insert(0,'/home/shade/SA_Score')
import sascorer


file_input = "250k_rndm_zinc_drugs_clean_3.csv"
# Use context managers so file handles are released even on error.
with open(file_input) as fp_input:
    lines = fp_input.readlines()

file_output = "data_5.txt"

logP_list = []
SAS_list = []
QED_list = []
MW_list = []
TPSA_list = []

with open(file_output, "w") as fp_out:
    fp_out.write("#smi logP SAS QED MW TPSA\n")
    smi = None  # SMILES remembered from the previous '"...' line, if any
    for line in lines:
        if not line or line[0] != '"':
            continue
        if line[1] != ",":
            # A '"CC(=O)...' line carries the SMILES; its properties follow
            # on the next '",...' line.
            smi = line[1:].strip()
            continue
        if smi is None:
            # Property line with no preceding SMILES line -- malformed input.
            continue
        m = Chem.MolFromSmiles(smi)
        if m is None:
            # Unparsable SMILES: skip the record instead of crashing on
            # MolToSmiles(None).
            continue
        smi2 = Chem.MolToSmiles(m)  # canonical SMILES

        # Recompute every property with RDKit rather than trusting the CSV.
        logP = MolLogP(m)
        SAS = sascorer.calculateScore(m)
        QED = qed(m)
        MW = ExactMolWt(m)
        TPSA = CalcTPSA(m)
        fp_out.write("%s %6.3f %6.3f %6.3f %6.3f %6.3f\n"
                     % (smi2, logP, SAS, QED, MW, TPSA))
        logP_list += [logP]
        SAS_list += [SAS]
        QED_list += [QED]
        MW_list += [MW]
        TPSA_list += [TPSA]

logP_array = np.array(logP_list)
SAS_array = np.array(SAS_list)
QED_array = np.array(QED_list)
MW_array = np.array(MW_list)
TPSA_array = np.array(TPSA_list)

# Observed range of each property.
print(logP_array.min(), logP_array.max())
print(SAS_array.min(), SAS_array.max())
print(QED_array.min(), QED_array.max())
print(MW_array.min(), MW_array.max())
print(TPSA_array.min(), TPSA_array.max())
--------------------------------------------------------------------------------
/single_design_plogp/ZINC/cal_property.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
"""Compute molecular properties for the ZINC-250k SMILES set.

Reads 250k_rndm_zinc_drugs_clean_3.csv (a quoted-SMILES line followed by a
",logP,qed,SAS" property line per molecule), recomputes every property with
RDKit, and writes one "#smi logP SAS QED MW TPSA" record per molecule to
data_5.txt.  Finally prints the min/max of each property (useful when
choosing normalisation ranges).
"""
import sys

USAGE="""

"""
import numpy as np
from rdkit import Chem

from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem.Crippen import MolLogP
from rdkit.Chem.Crippen import MolMR

from rdkit.Chem.rdMolDescriptors import CalcNumHBD
from rdkit.Chem.rdMolDescriptors import CalcNumHBA
from rdkit.Chem.rdMolDescriptors import CalcTPSA

from rdkit.Chem.QED import qed
#sys.path.insert(0,'/home/shade/SA_Score')
import sascorer


file_input = "250k_rndm_zinc_drugs_clean_3.csv"
# Use context managers so file handles are released even on error.
with open(file_input) as fp_input:
    lines = fp_input.readlines()

file_output = "data_5.txt"

logP_list = []
SAS_list = []
QED_list = []
MW_list = []
TPSA_list = []

with open(file_output, "w") as fp_out:
    fp_out.write("#smi logP SAS QED MW TPSA\n")
    smi = None  # SMILES remembered from the previous '"...' line, if any
    for line in lines:
        if not line or line[0] != '"':
            continue
        if line[1] != ",":
            # A '"CC(=O)...' line carries the SMILES; its properties follow
            # on the next '",...' line.
            smi = line[1:].strip()
            continue
        if smi is None:
            # Property line with no preceding SMILES line -- malformed input.
            continue
        m = Chem.MolFromSmiles(smi)
        if m is None:
            # Unparsable SMILES: skip the record instead of crashing on
            # MolToSmiles(None).
            continue
        smi2 = Chem.MolToSmiles(m)  # canonical SMILES

        # Recompute every property with RDKit rather than trusting the CSV.
        logP = MolLogP(m)
        SAS = sascorer.calculateScore(m)
        QED = qed(m)
        MW = ExactMolWt(m)
        TPSA = CalcTPSA(m)
        fp_out.write("%s %6.3f %6.3f %6.3f %6.3f %6.3f\n"
                     % (smi2, logP, SAS, QED, MW, TPSA))
        logP_list += [logP]
        SAS_list += [SAS]
        QED_list += [QED]
        MW_list += [MW]
        TPSA_list += [TPSA]

logP_array = np.array(logP_list)
SAS_array = np.array(SAS_list)
QED_array = np.array(QED_list)
MW_array = np.array(MW_list)
TPSA_array = np.array(TPSA_list)

# Observed range of each property.
print(logP_array.min(), logP_array.max())
print(SAS_array.min(), SAS_array.max())
print(QED_array.min(), QED_array.max())
print(MW_array.min(), MW_array.max())
print(TPSA_array.min(), TPSA_array.max())
--------------------------------------------------------------------------------
/single_design_qed/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import shutil
4 | import datetime
5 | import torch
6 | import sys
7 |
8 |
def get_exp_id(file):
    """Return the experiment id for *file*: its basename without extension."""
    base = os.path.basename(file)
    return os.path.splitext(base)[0]
11 |
12 |
def get_output_dir(exp_id, fs_prefix='./'):
    """Create and return a timestamped directory ``<fs_prefix>output/<exp_id>/<ts>``.

    NOTE: *fs_prefix* is string-concatenated, not path-joined, so prefixes
    like ``'../alienware_'`` deliberately produce ``../alienware_output/...``.
    """
    stamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    out_dir = os.path.join(fs_prefix + 'output/' + exp_id, stamp)
    os.makedirs(out_dir, exist_ok=True)
    return out_dir
18 |
19 |
20 | def setup_logging(name, output_dir, console=True):
21 | log_format = logging.Formatter("%(asctime)s : %(message)s")
22 | logger = logging.getLogger(name)
23 | logger.handlers = []
24 | output_file = os.path.join(output_dir, 'output.log')
25 | file_handler = logging.FileHandler(output_file)
26 | file_handler.setFormatter(log_format)
27 | logger.addHandler(file_handler)
28 | if console:
29 | console_handler = logging.StreamHandler(sys.stdout)
30 | console_handler.setFormatter(log_format)
31 | logger.addHandler(console_handler)
32 | logger.setLevel(logging.INFO)
33 | return logger
34 |
35 |
36 | def copy_source(file, output_dir):
37 | shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))
38 |
39 |
40 | def copy_all_files(file, output_dir):
41 | dir_src = os.path.dirname(file)
42 | for filename in os.listdir(os.getcwd()):
43 | if filename.endswith('.py'):
44 | shutil.copy(os.path.join(dir_src, filename), output_dir)
45 |
46 |
47 | def set_gpu(gpu):
48 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
49 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
50 |
51 | if torch.cuda.is_available():
52 | torch.cuda.set_device(0)
53 | torch.backends.cudnn.benchmark = True
54 |
55 |
56 | if __name__ == '__main__':
57 | # exp_id = get_exp_id(__file__)
58 | exp_id = 'ebm_plot'
59 | output_dir = get_output_dir(exp_id, fs_prefix='../alienware_')
60 | print(exp_id)
61 | print(os.getcwd())
62 | print(__file__)
63 | print(os.path.basename(__file__))
64 | print(os.path.dirname(__file__))
65 | # copy_source(__file__, output_dir)
66 | copy_all_files(__file__, output_dir)
67 |
--------------------------------------------------------------------------------
/single_design_plogp/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import shutil
4 | import datetime
5 | import torch
6 | import sys
7 |
8 |
def get_exp_id(file):
    """Return the experiment id: `file`'s basename without its extension."""
    return os.path.splitext(os.path.basename(file))[0]


def get_output_dir(exp_id, fs_prefix='./'):
    """Create (if needed) and return a timestamped output directory
    <fs_prefix>output/<exp_id>/<YYYY-mm-dd-HH-MM-SS>."""
    t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    output_dir = os.path.join(fs_prefix + 'output/' + exp_id, t)
    os.makedirs(output_dir, exist_ok=True)
    return output_dir


def setup_logging(name, output_dir, console=True):
    """Return a logger writing to <output_dir>/output.log and, optionally,
    echoing to stdout."""
    log_format = logging.Formatter("%(asctime)s : %(message)s")
    logger = logging.getLogger(name)
    logger.handlers = []  # drop handlers left over from a previous call
    output_file = os.path.join(output_dir, 'output.log')
    file_handler = logging.FileHandler(output_file)
    file_handler.setFormatter(log_format)
    logger.addHandler(file_handler)
    if console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(log_format)
        logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)
    return logger


def copy_source(file, output_dir):
    """Copy the single source file into output_dir."""
    shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))


def copy_all_files(file, output_dir):
    """Copy every .py file that sits next to `file` into output_dir.

    Bug fix: the original listed os.getcwd() but joined the names onto
    os.path.dirname(file); when the two directories differed the copy
    either failed or grabbed the wrong files.  List and join the same
    (absolute) source directory.
    """
    dir_src = os.path.dirname(os.path.abspath(file))
    for filename in os.listdir(dir_src):
        if filename.endswith('.py'):
            shutil.copy(os.path.join(dir_src, filename), output_dir)


def set_gpu(gpu):
    """Expose only the requested GPU to CUDA and enable cudnn autotuning."""
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    if torch.cuda.is_available():
        torch.cuda.set_device(0)  # device 0 == the single visible GPU
        torch.backends.cudnn.benchmark = True


if __name__ == '__main__':
    # exp_id = get_exp_id(__file__)
    exp_id = 'ebm_plot'
    output_dir = get_output_dir(exp_id, fs_prefix='../alienware_')
    print(exp_id)
    print(os.getcwd())
    print(__file__)
    print(os.path.basename(__file__))
    print(os.path.dirname(__file__))
    # copy_source(__file__, output_dir)
    copy_all_files(__file__, output_dir)
--------------------------------------------------------------------------------
/multi_design_acaa1/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import shutil
4 | import datetime
5 | import torch
6 | import sys
7 |
8 |
def get_exp_id(file):
    """Return the experiment id: `file`'s basename without its extension."""
    return os.path.splitext(os.path.basename(file))[0]


def get_output_dir(exp_id, fs_prefix='./'):
    """Create (if needed) and return a timestamped output directory
    <fs_prefix>output/<exp_id>/<YYYY-mm-dd-HH-MM-SS>."""
    t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    output_dir = os.path.join(fs_prefix + 'output/' + exp_id, t)
    os.makedirs(output_dir, exist_ok=True)
    return output_dir


def setup_logging(name, output_dir, console=True):
    """Return a logger writing to <output_dir>/output.log and, optionally,
    echoing to stdout."""
    log_format = logging.Formatter("%(asctime)s : %(message)s")
    logger = logging.getLogger(name)
    logger.handlers = []  # drop handlers left over from a previous call
    output_file = os.path.join(output_dir, 'output.log')
    file_handler = logging.FileHandler(output_file)
    file_handler.setFormatter(log_format)
    logger.addHandler(file_handler)
    if console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(log_format)
        logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)
    return logger


def copy_source(file, output_dir):
    """Copy the single source file into output_dir."""
    shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))


def copy_all_files(file, output_dir):
    """Copy every .py file that sits next to `file` into output_dir.

    Bug fix: the original listed os.getcwd() but joined the names onto
    os.path.dirname(file); when the two directories differed the copy
    either failed or grabbed the wrong files.  List and join the same
    (absolute) source directory.
    """
    dir_src = os.path.dirname(os.path.abspath(file))
    for filename in os.listdir(dir_src):
        if filename.endswith('.py'):
            shutil.copy(os.path.join(dir_src, filename), output_dir)


def set_gpu(gpu):
    """Expose only the requested GPU to CUDA and enable cudnn autotuning."""
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    if torch.cuda.is_available():
        torch.cuda.set_device(0)  # device 0 == the single visible GPU
        torch.backends.cudnn.benchmark = True


if __name__ == '__main__':
    # Quick GPU sanity check.
    set_gpu(1)
    print(torch.cuda.device_count())
    a = torch.tensor([1, 1, 1]).cuda()
    print(a)
--------------------------------------------------------------------------------
/multi_design_esr1/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import shutil
4 | import datetime
5 | import torch
6 | import sys
7 |
8 |
def get_exp_id(file):
    """Return the experiment id: `file`'s basename without its extension."""
    return os.path.splitext(os.path.basename(file))[0]


def get_output_dir(exp_id, fs_prefix='./'):
    """Create (if needed) and return a timestamped output directory
    <fs_prefix>output/<exp_id>/<YYYY-mm-dd-HH-MM-SS>."""
    t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    output_dir = os.path.join(fs_prefix + 'output/' + exp_id, t)
    os.makedirs(output_dir, exist_ok=True)
    return output_dir


def setup_logging(name, output_dir, console=True):
    """Return a logger writing to <output_dir>/output.log and, optionally,
    echoing to stdout."""
    log_format = logging.Formatter("%(asctime)s : %(message)s")
    logger = logging.getLogger(name)
    logger.handlers = []  # drop handlers left over from a previous call
    output_file = os.path.join(output_dir, 'output.log')
    file_handler = logging.FileHandler(output_file)
    file_handler.setFormatter(log_format)
    logger.addHandler(file_handler)
    if console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(log_format)
        logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)
    return logger


def copy_source(file, output_dir):
    """Copy the single source file into output_dir."""
    shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))


def copy_all_files(file, output_dir):
    """Copy every .py file that sits next to `file` into output_dir.

    Bug fix: the original listed os.getcwd() but joined the names onto
    os.path.dirname(file); when the two directories differed the copy
    either failed or grabbed the wrong files.  List and join the same
    (absolute) source directory.
    """
    dir_src = os.path.dirname(os.path.abspath(file))
    for filename in os.listdir(dir_src):
        if filename.endswith('.py'):
            shutil.copy(os.path.join(dir_src, filename), output_dir)


def set_gpu(gpu):
    """Expose only the requested GPU to CUDA and enable cudnn autotuning."""
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    if torch.cuda.is_available():
        torch.cuda.set_device(0)  # device 0 == the single visible GPU
        torch.backends.cudnn.benchmark = True


if __name__ == '__main__':
    # Quick GPU sanity check.
    set_gpu(1)
    print(torch.cuda.device_count())
    a = torch.tensor([1, 1, 1]).cuda()
    print(a)
--------------------------------------------------------------------------------
/single_design_acaa1/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import shutil
4 | import datetime
5 | import torch
6 | import sys
7 |
8 |
def get_exp_id(file):
    """Return the experiment id: `file`'s basename without its extension."""
    return os.path.splitext(os.path.basename(file))[0]


def get_output_dir(exp_id, fs_prefix='./'):
    """Create (if needed) and return a timestamped output directory
    <fs_prefix>output/<exp_id>/<YYYY-mm-dd-HH-MM-SS>."""
    t = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    output_dir = os.path.join(fs_prefix + 'output/' + exp_id, t)
    os.makedirs(output_dir, exist_ok=True)
    return output_dir


def setup_logging(name, output_dir, console=True):
    """Return a logger writing to <output_dir>/output.log and, optionally,
    echoing to stdout."""
    log_format = logging.Formatter("%(asctime)s : %(message)s")
    logger = logging.getLogger(name)
    logger.handlers = []  # drop handlers left over from a previous call
    output_file = os.path.join(output_dir, 'output.log')
    file_handler = logging.FileHandler(output_file)
    file_handler.setFormatter(log_format)
    logger.addHandler(file_handler)
    if console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(log_format)
        logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)
    return logger


def copy_source(file, output_dir):
    """Copy the single source file into output_dir."""
    shutil.copyfile(file, os.path.join(output_dir, os.path.basename(file)))


def copy_all_files(file, output_dir):
    """Copy every .py file that sits next to `file` into output_dir.

    Bug fix: the original listed os.getcwd() but joined the names onto
    os.path.dirname(file); when the two directories differed the copy
    either failed or grabbed the wrong files.  List and join the same
    (absolute) source directory.
    """
    dir_src = os.path.dirname(os.path.abspath(file))
    for filename in os.listdir(dir_src):
        if filename.endswith('.py'):
            shutil.copy(os.path.join(dir_src, filename), output_dir)


def set_gpu(gpu):
    """Expose only the requested GPU to CUDA and enable cudnn autotuning."""
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)

    if torch.cuda.is_available():
        torch.cuda.set_device(0)  # device 0 == the single visible GPU
        torch.backends.cudnn.benchmark = True


if __name__ == '__main__':
    # Quick GPU sanity check.
    set_gpu(1)
    print(torch.cuda.device_count())
    a = torch.tensor([1, 1, 1]).cuda()
    print(a)
--------------------------------------------------------------------------------
/single_design_qed/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import numpy as np
4 | from ZINC.char import char_list, char_dict
5 | import selfies as sf
6 | import json
7 | from rdkit.Chem import Draw
8 | from rdkit import Chem
9 | from rdkit.Chem.Crippen import MolLogP
10 | from rdkit.Chem.QED import qed
11 | from rdkit.Chem.Descriptors import ExactMolWt
12 | from rdkit.Chem.rdMolDescriptors import CalcTPSA
13 | import sascorer
14 |
15 |
# max_len = 72
# num_of_embeddings = 109
class MolDataset(Dataset):
    """Number-coded SELFIES molecules paired with per-molecule QED scores.

    Loads pre-computed .npy arrays from `datadir`:
      X<dname>.npy    -- integer-coded SELFIES token sequences
      L<dname>.npy    -- number of real tokens in each sequence
      LogP<dname>.npy -- logP values (loaded but not returned by default)
      qed_<dname>.npy -- QED drug-likeness scores (the training target here)
    """

    def __init__(self, datadir, dname):
        Xdata_file = datadir + "/X" + dname + ".npy"
        self.Xdata = torch.tensor(np.load(Xdata_file), dtype=torch.long)  # number-coded molecule
        Ldata_file = datadir + "/L" + dname + ".npy"
        self.Ldata = torch.tensor(np.load(Ldata_file), dtype=torch.long)  # length of each molecule
        self.len = self.Xdata.shape[0]
        LogPdata_file = datadir + "/LogP" + dname + ".npy"
        self.LogP = torch.tensor(np.load(LogPdata_file), dtype=torch.float32)
        qed_data_file = datadir + "/qed_" + dname + ".npy"
        self.qed = torch.tensor(np.load(qed_data_file), dtype=torch.float32)

    def __getitem__(self, index):
        """Return (sos-prefixed token sequence, validity mask, qed score)."""
        # Add sos=108 for each sequence and this sos is not shown in char_list and char_dict as in selfies.
        mol = self.Xdata[index]
        sos = torch.tensor([108], dtype=torch.long)
        mol = torch.cat([sos, mol], dim=0).contiguous()
        # Mask covers the real tokens plus one position (presumably the
        # end-of-sequence prediction step -- TODO confirm in the trainer).
        mask = torch.zeros(mol.shape[0] - 1)
        mask[:self.Ldata[index] + 1] = 1.
        return (mol, mask, self.qed[index])
        # return (mol, mask, self.LogP[index])
        # print(self.Ldata[index])
        # print(mask)
        # print(mol[:self.Ldata[index] + 2])
        # return (mol, self.Ldata[index], self.LogP[index])

    def __len__(self):
        return self.len

    def save_mol_png(self, label, filepath, size=(600, 600)):
        """Render the molecule encoded by `label` to a PNG file."""
        m_smi = self.label2sf2smi(label)
        m = Chem.MolFromSmiles(m_smi)
        Draw.MolToFile(m, filepath, size=size)

    def label2sf2smi(self, label):
        """Decode an integer label sequence -> SELFIES -> canonical SMILES."""
        m_sf = sf.encoding_to_selfies(label, char_dict, enc_type='label')
        m_smi = sf.decoder(m_sf)
        m_smi = Chem.CanonSmiles(m_smi)
        return m_smi
58 |
if __name__ == '__main__':
    datadir = '../data'
    ds = MolDataset(datadir, 'train')
    ds_loader = DataLoader(dataset=ds, batch_size=100,
                           shuffle=True, drop_last=True, num_workers=2)

    # Smoke test: decode the first molecule back to SMILES and render it.
    for i in range(len(ds)):
        mol, mask, score = ds[i]  # was `m, len, s`, which shadowed builtin len()
        mol = mol.numpy()
        smi = ds.label2sf2smi(mol[1:])  # drop the leading sos token
        print(smi)
        ds.save_mol_png(label=mol[1:], filepath='../a.png')
        break
--------------------------------------------------------------------------------
/single_design_esr1/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import numpy as np
4 | from ZINC.char import char_list, char_dict
5 | import selfies as sf
6 | import json
7 | from rdkit.Chem import Draw
8 | from rdkit import Chem
9 | from rdkit.Chem.Crippen import MolLogP
10 | from rdkit.Chem.QED import qed
11 | from rdkit.Chem.Descriptors import ExactMolWt
12 | from rdkit.Chem.rdMolDescriptors import CalcTPSA
13 | import sascorer
14 | import math
15 |
16 |
# max_len = 72
# num_of_embeddings = 109
class MolDataset(Dataset):
    """Number-coded SELFIES molecules paired with a binding-affinity score.

    Loads pre-computed .npy arrays from `datadir`:
      X<dname>.npy    -- integer-coded SELFIES token sequences
      L<dname>.npy    -- number of real tokens in each sequence
      sas_<dname>.npy -- synthetic accessibility scores
      qed_<dname>.npy -- QED drug-likeness scores
      ba0_<dname>.npy -- binding-affinity scores for target 0 (presumably
                         ESR1 docking energies -- TODO confirm)
    """

    def __init__(self, datadir, dname):
        Xdata_file = datadir + "/X" + dname + ".npy"
        self.Xdata = torch.tensor(np.load(Xdata_file), dtype=torch.long)  # number-coded molecule
        Ldata_file = datadir + "/L" + dname + ".npy"
        self.Ldata = torch.tensor(np.load(Ldata_file), dtype=torch.long)  # length of each molecule
        self.len = self.Xdata.shape[0]
        sasdata_file = datadir + "/sas_" + dname + ".npy"
        self.sas = torch.tensor(np.load(sasdata_file), dtype=torch.float32)
        qed_data_file = datadir + "/qed_" + dname + ".npy"
        self.qed = torch.tensor(np.load(qed_data_file), dtype=torch.float32)

        ba0data_file = datadir + "/ba0_" + dname + ".npy"
        self.ba0 = torch.tensor(np.load(ba0data_file), dtype=torch.float32)
        # ba1data_file = datadir + "/ba1_" + dname + ".npy"
        # self.ba1 = torch.tensor(np.load(ba1data_file), dtype=torch.float32)
        # self.preprocess()

    def preprocess(self):
        """Drop entries whose ba0 score is exactly zero (presumably missing
        docking results -- TODO confirm).  Not called by default."""
        ind = torch.nonzero(self.ba0)
        self.ba0 = self.ba0[ind].squeeze()
        self.sas = self.sas[ind].squeeze()
        self.qed = self.qed[ind].squeeze()
        self.Xdata = self.Xdata[ind].squeeze()
        self.Ldata = self.Ldata[ind].squeeze()
        self.len = self.Xdata.shape[0]

    def __getitem__(self, index):
        """Return (sos-prefixed token sequence, validity mask, -ba0 score)."""
        # Add sos=108 for each sequence and this sos is not shown in char_list and char_dict as in selfies.
        mol = self.Xdata[index]
        sos = torch.tensor([108], dtype=torch.long)
        mol = torch.cat([sos, mol], dim=0).contiguous()
        mask = torch.zeros(mol.shape[0] - 1)
        mask[:self.Ldata[index] + 1] = 1.
        # Score is negated; presumably larger returned values mean stronger
        # binding -- verify against the trainer's objective.
        return (mol, mask, -1 * self.ba0[index])

    def __len__(self):
        return self.len

    def save_mol_png(self, label, filepath, size=(600, 600)):
        """Render the molecule encoded by `label` to a PNG file."""
        m_smi = self.label2sf2smi(label)
        m = Chem.MolFromSmiles(m_smi)
        Draw.MolToFile(m, filepath, size=size)

    def label2sf2smi(self, label):
        """Decode an integer label sequence -> SELFIES -> canonical SMILES."""
        m_sf = sf.encoding_to_selfies(label, char_dict, enc_type='label')
        m_smi = sf.decoder(m_sf)
        m_smi = Chem.CanonSmiles(m_smi)
        return m_smi

    @staticmethod
    def delta_to_kd(x):
        """Convert a binding free energy (kcal/mol) to Kd = exp(dG / RT);
        0.0019872 is the gas constant R in kcal/(mol*K), T = 298.15 K."""
        return math.exp(x / (0.00198720425864083 * 298.15))
--------------------------------------------------------------------------------
/single_design_acaa1/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import numpy as np
4 | from ZINC.char import char_list, char_dict
5 | import selfies as sf
6 | import json
7 | from rdkit.Chem import Draw
8 | from rdkit import Chem
9 | from rdkit.Chem.Crippen import MolLogP
10 | from rdkit.Chem.QED import qed
11 | from rdkit.Chem.Descriptors import ExactMolWt
12 | from rdkit.Chem.rdMolDescriptors import CalcTPSA
13 | import sascorer
14 | import math
15 |
16 |
# max_len = 72
# num_of_embeddings = 109
class MolDataset(Dataset):
    """Number-coded SELFIES molecules paired with a binding-affinity score.

    Loads pre-computed .npy arrays from `datadir`:
      X<dname>.npy    -- integer-coded SELFIES token sequences
      L<dname>.npy    -- number of real tokens in each sequence
      sas_<dname>.npy -- synthetic accessibility scores
      qed_<dname>.npy -- QED drug-likeness scores
      ba1_<dname>.npy -- binding-affinity scores for target 1 (presumably
                         ACAA1 docking energies -- TODO confirm)
    """

    def __init__(self, datadir, dname):
        Xdata_file = datadir + "/X" + dname + ".npy"
        self.Xdata = torch.tensor(np.load(Xdata_file), dtype=torch.long)  # number-coded molecule
        Ldata_file = datadir + "/L" + dname + ".npy"
        self.Ldata = torch.tensor(np.load(Ldata_file), dtype=torch.long)  # length of each molecule
        self.len = self.Xdata.shape[0]
        sasdata_file = datadir + "/sas_" + dname + ".npy"
        self.sas = torch.tensor(np.load(sasdata_file), dtype=torch.float32)
        qed_data_file = datadir + "/qed_" + dname + ".npy"
        self.qed = torch.tensor(np.load(qed_data_file), dtype=torch.float32)
        ba1data_file = datadir + "/ba1_" + dname + ".npy"
        self.ba1 = torch.tensor(np.load(ba1data_file), dtype=torch.float32)
        # self.preprocess()

    def preprocess(self):
        """Drop entries whose ba1 score is exactly zero (presumably missing
        docking results -- TODO confirm).  Not called by default.

        Bug fix: this variant loads ba1, but the original body filtered on
        self.ba0, which is never set here -- calling preprocess() raised
        AttributeError.  Filter on ba1 instead (mirrors the ba0-based
        sibling in single_design_esr1/dataset.py).
        """
        ind = torch.nonzero(self.ba1)
        self.ba1 = self.ba1[ind].squeeze()
        self.sas = self.sas[ind].squeeze()
        self.qed = self.qed[ind].squeeze()
        self.Xdata = self.Xdata[ind].squeeze()
        self.Ldata = self.Ldata[ind].squeeze()
        self.len = self.Xdata.shape[0]

    def __getitem__(self, index):
        """Return (sos-prefixed token sequence, validity mask, -ba1 score)."""
        # Add sos=108 for each sequence; this token is not part of
        # char_list/char_dict as in selfies.
        mol = self.Xdata[index]
        sos = torch.tensor([108], dtype=torch.long)
        mol = torch.cat([sos, mol], dim=0).contiguous()
        mask = torch.zeros(mol.shape[0] - 1)
        mask[:self.Ldata[index] + 1] = 1.
        # Score is negated; presumably larger returned values mean stronger
        # binding -- verify against the trainer's objective.
        return (mol, mask, -1 * self.ba1[index])

    def __len__(self):
        return self.len

    def save_mol_png(self, label, filepath, size=(600, 600)):
        """Render the molecule encoded by `label` to a PNG file."""
        m_smi = self.label2sf2smi(label)
        m = Chem.MolFromSmiles(m_smi)
        Draw.MolToFile(m, filepath, size=size)

    def label2sf2smi(self, label):
        """Decode an integer label sequence -> SELFIES -> canonical SMILES."""
        m_sf = sf.encoding_to_selfies(label, char_dict, enc_type='label')
        m_smi = sf.decoder(m_sf)
        m_smi = Chem.CanonSmiles(m_smi)
        return m_smi

    @staticmethod
    def delta_to_kd(x):
        """Convert a binding free energy (kcal/mol) to Kd = exp(dG / RT);
        0.0019872 is the gas constant R in kcal/(mol*K), T = 298.15 K."""
        return math.exp(x / (0.00198720425864083 * 298.15))
--------------------------------------------------------------------------------
/single_design_plogp/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import numpy as np
4 | from ZINC.char import char_list, char_dict
5 | import selfies as sf
6 | import json
7 | from rdkit.Chem import Draw
8 | from rdkit import Chem
9 | from rdkit.Chem.Crippen import MolLogP
10 | from rdkit.Chem.QED import qed
11 | from rdkit.Chem.Descriptors import ExactMolWt
12 | from rdkit.Chem.rdMolDescriptors import CalcTPSA
13 | import sascorer
14 |
15 |
# max_len = 72
# num_of_embeddings = 109
class MolDataset(Dataset):
    """Number-coded SELFIES molecules paired with penalized-logP scores.

    Loads pre-computed .npy arrays from `datadir`:
      X<dname>.npy     -- integer-coded SELFIES token sequences
      L<dname>.npy     -- number of real tokens in each sequence
      LogP<dname>.npy  -- plain logP values (loaded but not returned)
      qed_<dname>.npy  -- QED drug-likeness scores
      PlogP<dname>.npy -- penalized logP (the training target here)
    """

    def __init__(self, datadir, dname):
        Xdata_file = datadir + "/X" + dname + ".npy"
        self.Xdata = torch.tensor(np.load(Xdata_file), dtype=torch.long)  # number-coded molecule
        Ldata_file = datadir + "/L" + dname + ".npy"
        self.Ldata = torch.tensor(np.load(Ldata_file), dtype=torch.long)  # length of each molecule
        self.len = self.Xdata.shape[0]
        LogPdata_file = datadir + "/LogP" + dname + ".npy"
        self.LogP = torch.tensor(np.load(LogPdata_file), dtype=torch.float32)
        qed_data_file = datadir + "/qed_" + dname + ".npy"
        self.qed = torch.tensor(np.load(qed_data_file), dtype=torch.float32)
        PlogPdata_file = datadir + "/PlogP" + dname + ".npy"
        self.PlogP = torch.tensor(np.load(PlogPdata_file), dtype=torch.float32)

    def __getitem__(self, index):
        """Return (sos-prefixed token sequence, validity mask, PlogP score)."""
        # Add sos=108 for each sequence and this sos is not shown in char_list and char_dict as in selfies.
        mol = self.Xdata[index]
        sos = torch.tensor([108], dtype=torch.long)
        mol = torch.cat([sos, mol], dim=0).contiguous()
        # Mask covers the real tokens plus one position (presumably the
        # end-of-sequence prediction step -- TODO confirm in the trainer).
        mask = torch.zeros(mol.shape[0] - 1)
        mask[:self.Ldata[index] + 1] = 1.
        return (mol, mask, self.PlogP[index])
        # return (mol, mask, self.LogP[index])
        # print(self.Ldata[index])
        # print(mask)
        # print(mol[:self.Ldata[index] + 2])
        # return (mol, self.Ldata[index], self.LogP[index])

    def __len__(self):
        return self.len

    def save_mol_png(self, label, filepath, size=(600, 600)):
        """Render the molecule encoded by `label` to a PNG file."""
        m_smi = self.label2sf2smi(label)
        m = Chem.MolFromSmiles(m_smi)
        Draw.MolToFile(m, filepath, size=size)

    def label2sf2smi(self, label):
        """Decode an integer label sequence -> SELFIES -> canonical SMILES."""
        m_sf = sf.encoding_to_selfies(label, char_dict, enc_type='label')
        m_smi = sf.decoder(m_sf)
        m_smi = Chem.CanonSmiles(m_smi)
        return m_smi
60 |
if __name__ == '__main__':
    datadir = '../data'
    ds = MolDataset(datadir, 'train')
    ds_loader = DataLoader(dataset=ds, batch_size=100,
                           shuffle=True, drop_last=True, num_workers=2)

    # Smoke test: decode the first molecule back to SMILES and render it.
    for i in range(len(ds)):
        mol, mask, score = ds[i]  # was `m, len, s`, which shadowed builtin len()
        mol = mol.numpy()
        smi = ds.label2sf2smi(mol[1:])  # drop the leading sos token
        print(smi)
        ds.save_mol_png(label=mol[1:], filepath='../a.png')
        break
--------------------------------------------------------------------------------
/multi_design_acaa1/ZINC/char.py:
--------------------------------------------------------------------------------
# SELFIES vocabulary for the ZINC data set.  The final entry, 'sos', is the
# start-of-sequence marker (label 108) and is deliberately absent from
# char_dict, which covers only the 108 real tokens (labels 0-107).
char_list = ['[=Branch2]', '[Ring1]', '[#Branch2]', '[\\NH1]', '[=O]', '[S@@+1]', '[=P@@]', '[-/Ring1]', '[=S]',
             '[=Ring1]', '[/C@@]', '[\\NH2+1]', '[\\O]', '[/Cl]', '[/C@]', '[=N+1]', '[=OH1+1]', '[/O+1]', '[#N]',
             '[=Branch1]', '[=C]', '[=N]', '[\\NH1+1]', '[P@@]', '[/S]', '[=S+1]', '[F]', '[S+1]', '[S@@]', '[=NH1+1]',
             '[/NH1-1]', '[\\S@]', '[\\N+1]', '[#N+1]', '[\\N]', '[I]', '[Branch2]', '[P@]', '[PH1]', '[CH2-1]',
             '[/C@H1]', '[Cl]', '[N+1]', '[Ring2]', '[\\O-1]', '[Br]', '[\\C@H1]', '[-/Ring2]', '[\\I]', '[=NH2+1]',
             '[C@@]', '[\\S-1]', '[\\C]', '[/N+1]', '[=PH2]', '[/S@]', '[\\S]', '[NH2+1]', '[nop]', '[NH1]', '[P]',
             '[Branch1]', '[\\Br]', '[=O+1]', '[-\\Ring1]', '[/N-1]', '[\\Cl]', '[P@@H1]', '[N]', '[=P]', '[NH1+1]',
             '[\\N-1]', '[/Br]', '[/NH1+1]', '[S]', '[N-1]', '[/NH2+1]', '[NH1-1]', '[#C]', '[C]', '[\\F]', '[/S-1]',
             '[/F]', '[/NH1]', '[=N-1]', '[NH3+1]', '[P+1]', '[=Ring2]', '[CH1-1]', '[S-1]', '[=P@]', '[/C]', '[=S@]',
             '[\\C@@H1]', '[O]', '[O-1]', '[/C@@H1]', '[#Branch1]', '[=SH1+1]', '[/O]', '[=S@@]', '[C@H1]', '[S@]',
             '[C@@H1]', '[/N]', '[C@]', '[/O-1]', '[PH1+1]', 'sos']

# Label -> token lookup, derived from char_list so the two can never drift
# out of sync.  Equivalent to the hand-written literal it replaces.
char_dict = dict(enumerate(char_list[:-1]))
--------------------------------------------------------------------------------
/multi_design_esr1/ZINC/char.py:
--------------------------------------------------------------------------------
# SELFIES vocabulary for the ZINC data set.  The final entry, 'sos', is the
# start-of-sequence marker (label 108) and is deliberately absent from
# char_dict, which covers only the 108 real tokens (labels 0-107).
char_list = ['[=Branch2]', '[Ring1]', '[#Branch2]', '[\\NH1]', '[=O]', '[S@@+1]', '[=P@@]', '[-/Ring1]', '[=S]',
             '[=Ring1]', '[/C@@]', '[\\NH2+1]', '[\\O]', '[/Cl]', '[/C@]', '[=N+1]', '[=OH1+1]', '[/O+1]', '[#N]',
             '[=Branch1]', '[=C]', '[=N]', '[\\NH1+1]', '[P@@]', '[/S]', '[=S+1]', '[F]', '[S+1]', '[S@@]', '[=NH1+1]',
             '[/NH1-1]', '[\\S@]', '[\\N+1]', '[#N+1]', '[\\N]', '[I]', '[Branch2]', '[P@]', '[PH1]', '[CH2-1]',
             '[/C@H1]', '[Cl]', '[N+1]', '[Ring2]', '[\\O-1]', '[Br]', '[\\C@H1]', '[-/Ring2]', '[\\I]', '[=NH2+1]',
             '[C@@]', '[\\S-1]', '[\\C]', '[/N+1]', '[=PH2]', '[/S@]', '[\\S]', '[NH2+1]', '[nop]', '[NH1]', '[P]',
             '[Branch1]', '[\\Br]', '[=O+1]', '[-\\Ring1]', '[/N-1]', '[\\Cl]', '[P@@H1]', '[N]', '[=P]', '[NH1+1]',
             '[\\N-1]', '[/Br]', '[/NH1+1]', '[S]', '[N-1]', '[/NH2+1]', '[NH1-1]', '[#C]', '[C]', '[\\F]', '[/S-1]',
             '[/F]', '[/NH1]', '[=N-1]', '[NH3+1]', '[P+1]', '[=Ring2]', '[CH1-1]', '[S-1]', '[=P@]', '[/C]', '[=S@]',
             '[\\C@@H1]', '[O]', '[O-1]', '[/C@@H1]', '[#Branch1]', '[=SH1+1]', '[/O]', '[=S@@]', '[C@H1]', '[S@]',
             '[C@@H1]', '[/N]', '[C@]', '[/O-1]', '[PH1+1]', 'sos']

# Label -> token lookup, derived from char_list so the two can never drift
# out of sync.  Equivalent to the hand-written literal it replaces.
char_dict = dict(enumerate(char_list[:-1]))
--------------------------------------------------------------------------------
/single_design_acaa1/ZINC/char.py:
--------------------------------------------------------------------------------
"""SELFIES token vocabulary for the ZINC data pipeline.

``char_list`` is the full token alphabet; its final element, ``'sos'``, is
the start-of-sequence marker.  ``char_dict`` maps each integer index back to
its token for every entry except ``'sos'``.
"""

char_list = ['[=Branch2]', '[Ring1]', '[#Branch2]', '[\\NH1]', '[=O]', '[S@@+1]', '[=P@@]', '[-/Ring1]', '[=S]',
             '[=Ring1]', '[/C@@]', '[\\NH2+1]', '[\\O]', '[/Cl]', '[/C@]', '[=N+1]', '[=OH1+1]', '[/O+1]', '[#N]',
             '[=Branch1]', '[=C]', '[=N]', '[\\NH1+1]', '[P@@]', '[/S]', '[=S+1]', '[F]', '[S+1]', '[S@@]', '[=NH1+1]',
             '[/NH1-1]', '[\\S@]', '[\\N+1]', '[#N+1]', '[\\N]', '[I]', '[Branch2]', '[P@]', '[PH1]', '[CH2-1]',
             '[/C@H1]', '[Cl]', '[N+1]', '[Ring2]', '[\\O-1]', '[Br]', '[\\C@H1]', '[-/Ring2]', '[\\I]', '[=NH2+1]',
             '[C@@]', '[\\S-1]', '[\\C]', '[/N+1]', '[=PH2]', '[/S@]', '[\\S]', '[NH2+1]', '[nop]', '[NH1]', '[P]',
             '[Branch1]', '[\\Br]', '[=O+1]', '[-\\Ring1]', '[/N-1]', '[\\Cl]', '[P@@H1]', '[N]', '[=P]', '[NH1+1]',
             '[\\N-1]', '[/Br]', '[/NH1+1]', '[S]', '[N-1]', '[/NH2+1]', '[NH1-1]', '[#C]', '[C]', '[\\F]', '[/S-1]',
             '[/F]', '[/NH1]', '[=N-1]', '[NH3+1]', '[P+1]', '[=Ring2]', '[CH1-1]', '[S-1]', '[=P@]', '[/C]', '[=S@]',
             '[\\C@@H1]', '[O]', '[O-1]', '[/C@@H1]', '[#Branch1]', '[=SH1+1]', '[/O]', '[=S@@]', '[C@H1]', '[S@]',
             '[C@@H1]', '[/N]', '[C@]', '[/O-1]', '[PH1+1]', 'sos']

# Index -> token lookup derived from char_list so the two structures can never
# drift apart (the previous hand-written dict duplicated all 108 entries).
# 'sos' (index 108) is deliberately excluded, matching the original dict,
# which stopped at key 107 — presumably because 'sos' is only a decoder start
# symbol and is never decoded back to; confirm downstream code never looks up
# index 108.
char_dict = {index: token for index, token in enumerate(char_list) if token != 'sos'}
--------------------------------------------------------------------------------
/single_design_esr1/ZINC/char.py:
--------------------------------------------------------------------------------
"""SELFIES token vocabulary for the ZINC data pipeline.

``char_list`` is the full token alphabet; its final element, ``'sos'``, is
the start-of-sequence marker.  ``char_dict`` maps each integer index back to
its token for every entry except ``'sos'``.
"""

char_list = ['[=Branch2]', '[Ring1]', '[#Branch2]', '[\\NH1]', '[=O]', '[S@@+1]', '[=P@@]', '[-/Ring1]', '[=S]',
             '[=Ring1]', '[/C@@]', '[\\NH2+1]', '[\\O]', '[/Cl]', '[/C@]', '[=N+1]', '[=OH1+1]', '[/O+1]', '[#N]',
             '[=Branch1]', '[=C]', '[=N]', '[\\NH1+1]', '[P@@]', '[/S]', '[=S+1]', '[F]', '[S+1]', '[S@@]', '[=NH1+1]',
             '[/NH1-1]', '[\\S@]', '[\\N+1]', '[#N+1]', '[\\N]', '[I]', '[Branch2]', '[P@]', '[PH1]', '[CH2-1]',
             '[/C@H1]', '[Cl]', '[N+1]', '[Ring2]', '[\\O-1]', '[Br]', '[\\C@H1]', '[-/Ring2]', '[\\I]', '[=NH2+1]',
             '[C@@]', '[\\S-1]', '[\\C]', '[/N+1]', '[=PH2]', '[/S@]', '[\\S]', '[NH2+1]', '[nop]', '[NH1]', '[P]',
             '[Branch1]', '[\\Br]', '[=O+1]', '[-\\Ring1]', '[/N-1]', '[\\Cl]', '[P@@H1]', '[N]', '[=P]', '[NH1+1]',
             '[\\N-1]', '[/Br]', '[/NH1+1]', '[S]', '[N-1]', '[/NH2+1]', '[NH1-1]', '[#C]', '[C]', '[\\F]', '[/S-1]',
             '[/F]', '[/NH1]', '[=N-1]', '[NH3+1]', '[P+1]', '[=Ring2]', '[CH1-1]', '[S-1]', '[=P@]', '[/C]', '[=S@]',
             '[\\C@@H1]', '[O]', '[O-1]', '[/C@@H1]', '[#Branch1]', '[=SH1+1]', '[/O]', '[=S@@]', '[C@H1]', '[S@]',
             '[C@@H1]', '[/N]', '[C@]', '[/O-1]', '[PH1+1]', 'sos']

# Index -> token lookup derived from char_list so the two structures can never
# drift apart (the previous hand-written dict duplicated all 108 entries).
# 'sos' (index 108) is deliberately excluded, matching the original dict,
# which stopped at key 107 — presumably because 'sos' is only a decoder start
# symbol and is never decoded back to; confirm downstream code never looks up
# index 108.
char_dict = {index: token for index, token in enumerate(char_list) if token != 'sos'}
--------------------------------------------------------------------------------
/single_design_plogp/ZINC/char.py:
--------------------------------------------------------------------------------
"""SELFIES token vocabulary for the ZINC data pipeline.

``char_list`` is the full token alphabet; its final element, ``'sos'``, is
the start-of-sequence marker.  ``char_dict`` maps each integer index back to
its token for every entry except ``'sos'``.
"""

char_list = ['[=Branch2]', '[Ring1]', '[#Branch2]', '[\\NH1]', '[=O]', '[S@@+1]', '[=P@@]', '[-/Ring1]', '[=S]',
             '[=Ring1]', '[/C@@]', '[\\NH2+1]', '[\\O]', '[/Cl]', '[/C@]', '[=N+1]', '[=OH1+1]', '[/O+1]', '[#N]',
             '[=Branch1]', '[=C]', '[=N]', '[\\NH1+1]', '[P@@]', '[/S]', '[=S+1]', '[F]', '[S+1]', '[S@@]', '[=NH1+1]',
             '[/NH1-1]', '[\\S@]', '[\\N+1]', '[#N+1]', '[\\N]', '[I]', '[Branch2]', '[P@]', '[PH1]', '[CH2-1]',
             '[/C@H1]', '[Cl]', '[N+1]', '[Ring2]', '[\\O-1]', '[Br]', '[\\C@H1]', '[-/Ring2]', '[\\I]', '[=NH2+1]',
             '[C@@]', '[\\S-1]', '[\\C]', '[/N+1]', '[=PH2]', '[/S@]', '[\\S]', '[NH2+1]', '[nop]', '[NH1]', '[P]',
             '[Branch1]', '[\\Br]', '[=O+1]', '[-\\Ring1]', '[/N-1]', '[\\Cl]', '[P@@H1]', '[N]', '[=P]', '[NH1+1]',
             '[\\N-1]', '[/Br]', '[/NH1+1]', '[S]', '[N-1]', '[/NH2+1]', '[NH1-1]', '[#C]', '[C]', '[\\F]', '[/S-1]',
             '[/F]', '[/NH1]', '[=N-1]', '[NH3+1]', '[P+1]', '[=Ring2]', '[CH1-1]', '[S-1]', '[=P@]', '[/C]', '[=S@]',
             '[\\C@@H1]', '[O]', '[O-1]', '[/C@@H1]', '[#Branch1]', '[=SH1+1]', '[/O]', '[=S@@]', '[C@H1]', '[S@]',
             '[C@@H1]', '[/N]', '[C@]', '[/O-1]', '[PH1+1]', 'sos']

# Index -> token lookup derived from char_list so the two structures can never
# drift apart (the previous hand-written dict duplicated all 108 entries).
# 'sos' (index 108) is deliberately excluded, matching the original dict,
# which stopped at key 107 — presumably because 'sos' is only a decoder start
# symbol and is never decoded back to; confirm downstream code never looks up
# index 108.
char_dict = {index: token for index, token in enumerate(char_list) if token != 'sos'}
--------------------------------------------------------------------------------
/single_design_qed/ZINC/char.py:
--------------------------------------------------------------------------------
"""SELFIES token vocabulary for the ZINC data pipeline.

``char_list`` is the full token alphabet; its final element, ``'sos'``, is
the start-of-sequence marker.  ``char_dict`` maps each integer index back to
its token for every entry except ``'sos'``.
"""

char_list = ['[=Branch2]', '[Ring1]', '[#Branch2]', '[\\NH1]', '[=O]', '[S@@+1]', '[=P@@]', '[-/Ring1]', '[=S]',
             '[=Ring1]', '[/C@@]', '[\\NH2+1]', '[\\O]', '[/Cl]', '[/C@]', '[=N+1]', '[=OH1+1]', '[/O+1]', '[#N]',
             '[=Branch1]', '[=C]', '[=N]', '[\\NH1+1]', '[P@@]', '[/S]', '[=S+1]', '[F]', '[S+1]', '[S@@]', '[=NH1+1]',
             '[/NH1-1]', '[\\S@]', '[\\N+1]', '[#N+1]', '[\\N]', '[I]', '[Branch2]', '[P@]', '[PH1]', '[CH2-1]',
             '[/C@H1]', '[Cl]', '[N+1]', '[Ring2]', '[\\O-1]', '[Br]', '[\\C@H1]', '[-/Ring2]', '[\\I]', '[=NH2+1]',
             '[C@@]', '[\\S-1]', '[\\C]', '[/N+1]', '[=PH2]', '[/S@]', '[\\S]', '[NH2+1]', '[nop]', '[NH1]', '[P]',
             '[Branch1]', '[\\Br]', '[=O+1]', '[-\\Ring1]', '[/N-1]', '[\\Cl]', '[P@@H1]', '[N]', '[=P]', '[NH1+1]',
             '[\\N-1]', '[/Br]', '[/NH1+1]', '[S]', '[N-1]', '[/NH2+1]', '[NH1-1]', '[#C]', '[C]', '[\\F]', '[/S-1]',
             '[/F]', '[/NH1]', '[=N-1]', '[NH3+1]', '[P+1]', '[=Ring2]', '[CH1-1]', '[S-1]', '[=P@]', '[/C]', '[=S@]',
             '[\\C@@H1]', '[O]', '[O-1]', '[/C@@H1]', '[#Branch1]', '[=SH1+1]', '[/O]', '[=S@@]', '[C@H1]', '[S@]',
             '[C@@H1]', '[/N]', '[C@]', '[/O-1]', '[PH1+1]', 'sos']

# Index -> token lookup derived from char_list so the two structures can never
# drift apart (the previous hand-written dict duplicated all 108 entries).
# 'sos' (index 108) is deliberately excluded, matching the original dict,
# which stopped at key 107 — presumably because 'sos' is only a decoder start
# symbol and is never decoded back to; confirm downstream code never looks up
# index 108.
char_dict = {index: token for index, token in enumerate(char_list) if token != 'sos'}
--------------------------------------------------------------------------------
/single_design_qed/args.py:
--------------------------------------------------------------------------------
import argparse


def _str2bool(value):
    """Convert a command-line string to a bool.

    argparse's ``type=bool`` is a well-known trap: ``bool('False')`` is
    ``True`` because any non-empty string is truthy, so ``--debug False``
    used to *enable* debug.  This converter parses the usual spellings
    explicitly and raises on anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


def get_args():
    """Build and parse the CLI arguments for single-property (QED) design.

    Returns:
        argparse.Namespace with all options below; defaults are used when a
        flag is not supplied on the command line.
    """
    parser = argparse.ArgumentParser()
    # NOTE: was type=bool, which treated any non-empty string (incl. 'False')
    # as True; _str2bool keeps the '--flag True/False' CLI working correctly.
    parser.add_argument("--debug", default=False, type=_str2bool)
    parser.add_argument("--tb", default=False, type=_str2bool)

    # Input data
    parser.add_argument('--data_dir', default='../data')
    parser.add_argument('--test_file', default='ZINC/test_5.txt')
    parser.add_argument('--train_from', default='')
    parser.add_argument('--max_len', default=110, type=int)
    parser.add_argument('--batch_size', default=1024 * 1, type=int)
    parser.add_argument('--eval_batch_size', default=512 * 1, type=int)

    # General options
    parser.add_argument('--z_n_iters', type=int, default=20)
    parser.add_argument('--z_step_size', type=float, default=0.5)
    parser.add_argument('--z_with_noise', type=int, default=1)
    parser.add_argument('--num_z_samples', type=int, default=10)
    parser.add_argument('--model', type=str, default='mol_ebm')
    parser.add_argument('--mask', type=_str2bool, default=False)
    parser.add_argument('--single_design', default=True, type=_str2bool)

    # EBM
    parser.add_argument('--prior_hidden_dim', type=int, default=200)
    parser.add_argument('--z_prior_with_noise', type=int, default=1)
    parser.add_argument('--prior_step_size', type=float, default=0.5)
    parser.add_argument('--z_n_iters_prior', type=int, default=20)
    parser.add_argument('--max_grad_norm_prior', default=1, type=float)
    parser.add_argument('--ebm_reg', default=0.0, type=float)
    parser.add_argument('--ref_dist', default='gaussian', type=str, choices=['gaussian', 'uniform'])
    parser.add_argument('--ref_sigma', type=float, default=0.5)
    parser.add_argument('--init_factor', type=float, default=1.)
    parser.add_argument('--noise_factor', type=float, default=0.5)

    # Decoder and MLP options
    parser.add_argument('--mlp_hidden_dim', default=200, type=int)
    parser.add_argument('--prop_coefficient', default=10., type=float)
    parser.add_argument('--latent_dim', default=100, type=int)
    parser.add_argument('--dec_word_dim', default=512, type=int)
    parser.add_argument('--dec_h_dim', default=1024, type=int)
    parser.add_argument('--dec_num_layers', default=1, type=int)
    parser.add_argument('--dec_dropout', default=0.2, type=float)
    parser.add_argument('--train_n2n', default=1, type=int)
    parser.add_argument('--train_kl', default=1, type=int)

    # Optimization options
    parser.add_argument('--log_dir', default='../log/')
    parser.add_argument('--checkpoint_dir', default='models')
    parser.add_argument('--warmup', default=0, type=int)
    parser.add_argument('--num_epochs', default=25, type=int)
    parser.add_argument('--min_epochs', default=15, type=int)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--eps', default=1e-5, type=float)
    parser.add_argument('--decay', default=0, type=int)
    parser.add_argument('--momentum', default=0.5, type=float)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--prior_lr', default=0.0001, type=float)
    parser.add_argument('--max_grad_norm', default=5, type=float)
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--seed', default=3435, type=int)
    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--sample_every', type=int, default=1000)
    parser.add_argument('--kl_every', type=int, default=100)
    parser.add_argument('--compute_kl', type=int, default=1)
    parser.add_argument('--test', type=int, default=0)
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    print(args)
--------------------------------------------------------------------------------
/single_design_plogp/args.py:
--------------------------------------------------------------------------------
import argparse


def _str2bool(value):
    """Convert a command-line string to a bool.

    argparse's ``type=bool`` is a well-known trap: ``bool('False')`` is
    ``True`` because any non-empty string is truthy, so ``--debug False``
    used to *enable* debug.  This converter parses the usual spellings
    explicitly and raises on anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


def get_args():
    """Build and parse the CLI arguments for single-property (pLogP) design.

    Returns:
        argparse.Namespace with all options below; defaults are used when a
        flag is not supplied on the command line.
    """
    parser = argparse.ArgumentParser()
    # NOTE: was type=bool, which treated any non-empty string (incl. 'False')
    # as True; _str2bool keeps the '--flag True/False' CLI working correctly.
    parser.add_argument("--debug", default=False, type=_str2bool)
    parser.add_argument("--tb", default=False, type=_str2bool)

    # Input data
    parser.add_argument('--data_dir', default='../data')
    parser.add_argument('--test_file', default='../ZINC/test_5.txt')
    parser.add_argument('--train_from', default='')
    parser.add_argument('--max_len', default=110, type=int)
    parser.add_argument('--batch_size', default=1024 * 1, type=int)
    parser.add_argument('--eval_batch_size', default=500 * 1, type=int)

    # General options
    parser.add_argument('--z_n_iters', type=int, default=20)
    parser.add_argument('--z_step_size', type=float, default=0.5)
    parser.add_argument('--z_with_noise', type=int, default=1)
    parser.add_argument('--num_z_samples', type=int, default=10)
    parser.add_argument('--model', type=str, default='mol_ebm')
    parser.add_argument('--mask', type=_str2bool, default=False)
    parser.add_argument('--single_design', default=True, type=_str2bool)

    # EBM
    parser.add_argument('--prior_hidden_dim', type=int, default=200)
    parser.add_argument('--z_prior_with_noise', type=int, default=1)
    parser.add_argument('--prior_step_size', type=float, default=0.5)
    parser.add_argument('--z_n_iters_prior', type=int, default=20)
    parser.add_argument('--max_grad_norm_prior', default=1, type=float)
    parser.add_argument('--ebm_reg', default=0.0, type=float)
    parser.add_argument('--ref_dist', default='gaussian', type=str, choices=['gaussian', 'uniform'])
    parser.add_argument('--ref_sigma', type=float, default=0.5)
    parser.add_argument('--init_factor', type=float, default=1.)
    parser.add_argument('--noise_factor', type=float, default=0.5)

    # Decoder and MLP options
    parser.add_argument('--mlp_hidden_dim', default=200, type=int)
    parser.add_argument('--prop_coefficient', default=10., type=float)
    parser.add_argument('--latent_dim', default=100, type=int)
    parser.add_argument('--dec_word_dim', default=512, type=int)
    parser.add_argument('--dec_h_dim', default=1024, type=int)
    parser.add_argument('--dec_num_layers', default=1, type=int)
    parser.add_argument('--dec_dropout', default=0.2, type=float)
    parser.add_argument('--train_n2n', default=1, type=int)
    parser.add_argument('--train_kl', default=1, type=int)

    # Optimization options
    parser.add_argument('--log_dir', default='../log/')
    parser.add_argument('--checkpoint_dir', default='models')
    parser.add_argument('--warmup', default=0, type=int)
    parser.add_argument('--num_epochs', default=25, type=int)
    parser.add_argument('--min_epochs', default=15, type=int)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--eps', default=1e-5, type=float)
    parser.add_argument('--decay', default=0, type=int)
    parser.add_argument('--momentum', default=0.5, type=float)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--prior_lr', default=0.0001, type=float)
    parser.add_argument('--max_grad_norm', default=5, type=float)
    parser.add_argument('--gpu', default=0, type=int)
    parser.add_argument('--seed', default=3435, type=int)
    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--sample_every', type=int, default=1000)
    parser.add_argument('--kl_every', type=int, default=100)
    parser.add_argument('--compute_kl', type=int, default=1)
    parser.add_argument('--test', type=int, default=0)
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    print(args)
--------------------------------------------------------------------------------
/multi_design_esr1/args.py:
--------------------------------------------------------------------------------
import argparse


def _str2bool(value):
    """Convert a command-line string to a bool.

    argparse's ``type=bool`` is a well-known trap: ``bool('False')`` is
    ``True`` because any non-empty string is truthy, so ``--debug False``
    used to *enable* debug.  This converter parses the usual spellings
    explicitly and raises on anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


def get_args():
    """Build and parse the CLI arguments for multi-objective ESR1 design.

    Returns:
        argparse.Namespace with all options below; defaults are used when a
        flag is not supplied on the command line.
    """
    parser = argparse.ArgumentParser()
    # NOTE: was type=bool, which treated any non-empty string (incl. 'False')
    # as True; _str2bool keeps the '--flag True/False' CLI working correctly.
    parser.add_argument("--debug", default=False, type=_str2bool)
    parser.add_argument("--tb", default=False, type=_str2bool)

    # Input data
    parser.add_argument('--data_dir', default='../data')
    parser.add_argument('--test_file', default='../ZINC/test_5.txt')
    parser.add_argument('--autodock_executable', type=str, default='~/AutoDock-GPU/bin/autodock_gpu_128wi')
    parser.add_argument('--protein_file', type=str, default='../1err/1err.maps.fld')
    parser.add_argument('--train_from', default='')
    parser.add_argument('--max_len', default=110, type=int)
    parser.add_argument('--batch_size', default=1024 * 1, type=int)
    parser.add_argument('--eval_batch_size', default=500 * 1, type=int)

    # General options
    parser.add_argument('--z_n_iters', type=int, default=20)
    parser.add_argument('--z_step_size', type=float, default=0.5)
    parser.add_argument('--z_with_noise', type=int, default=1)
    parser.add_argument('--num_z_samples', type=int, default=10)
    parser.add_argument('--model', type=str, default='mol_ebm')
    parser.add_argument('--mask', type=_str2bool, default=False)
    parser.add_argument('--single_design', default=False, type=_str2bool)
    parser.add_argument('--multi_design', default=True, type=_str2bool)

    # EBM
    parser.add_argument('--prior_hidden_dim', type=int, default=200)
    parser.add_argument('--z_prior_with_noise', type=int, default=1)
    parser.add_argument('--prior_step_size', type=float, default=0.5)
    parser.add_argument('--z_n_iters_prior', type=int, default=20)
    parser.add_argument('--max_grad_norm_prior', default=1, type=float)
    parser.add_argument('--ebm_reg', default=0.0, type=float)
    parser.add_argument('--ref_dist', default='gaussian', type=str, choices=['gaussian', 'uniform'])
    parser.add_argument('--ref_sigma', type=float, default=0.5)
    parser.add_argument('--init_factor', type=float, default=1.)
    parser.add_argument('--noise_factor', type=float, default=0.5)

    # Decoder and MLP options
    parser.add_argument('--mlp_hidden_dim', default=50, type=int)
    parser.add_argument('--latent_dim', default=100, type=int)
    parser.add_argument('--dec_word_dim', default=512, type=int)
    parser.add_argument('--dec_h_dim', default=1024, type=int)
    parser.add_argument('--dec_num_layers', default=1, type=int)
    parser.add_argument('--dec_dropout', default=0.2, type=float)
    parser.add_argument('--train_n2n', default=1, type=int)
    parser.add_argument('--train_kl', default=1, type=int)

    # prop coefficients
    parser.add_argument('--prop_coefficient', default=10., type=float)
    parser.add_argument('--ba', default=10., type=float)
    parser.add_argument('--sas', default=10., type=float)
    parser.add_argument('--qed', default=10., type=float)

    # Optimization options
    parser.add_argument('--log_dir', default='../log/')
    parser.add_argument('--checkpoint_dir', default='models')
    parser.add_argument('--warmup', default=0, type=int)
    parser.add_argument('--num_epochs', default=30, type=int)
    parser.add_argument('--min_epochs', default=15, type=int)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--eps', default=1e-5, type=float)
    parser.add_argument('--decay', default=0, type=int)
    parser.add_argument('--momentum', default=0.5, type=float)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--prior_lr', default=0.0001, type=float)
    parser.add_argument('--max_grad_norm', default=5, type=float)
    parser.add_argument('--gpu', default=1, type=int)
    parser.add_argument('--seed', default=3435, type=int)
    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--sample_every', type=int, default=1000)
    parser.add_argument('--kl_every', type=int, default=100)
    parser.add_argument('--compute_kl', type=int, default=1)
    parser.add_argument('--test', type=int, default=0)
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    print(args)
--------------------------------------------------------------------------------
/single_design_acaa1/args.py:
--------------------------------------------------------------------------------
import argparse


def _str2bool(value):
    """Convert a command-line string to a bool.

    argparse's ``type=bool`` is a well-known trap: ``bool('False')`` is
    ``True`` because any non-empty string is truthy, so ``--debug False``
    used to *enable* debug.  This converter parses the usual spellings
    explicitly and raises on anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


def get_args():
    """Build and parse the CLI arguments for single-objective ACAA1 design.

    Returns:
        argparse.Namespace with all options below; defaults are used when a
        flag is not supplied on the command line.
    """
    parser = argparse.ArgumentParser()
    # NOTE: was type=bool, which treated any non-empty string (incl. 'False')
    # as True; _str2bool keeps the '--flag True/False' CLI working correctly.
    parser.add_argument("--debug", default=False, type=_str2bool)
    parser.add_argument("--tb", default=False, type=_str2bool)

    # Input data
    parser.add_argument('--data_dir', default='../data/ba1')
    parser.add_argument('--test_file', default='../ZINC/test_5.txt')
    parser.add_argument('--autodock_executable', type=str, default='~/AutoDock-GPU/bin/autodock_gpu_128wi')
    parser.add_argument('--protein_file', type=str, default='../2iik/2iik.maps.fld')
    parser.add_argument('--train_from', default='')
    parser.add_argument('--max_len', default=110, type=int)
    parser.add_argument('--batch_size', default=1024 * 1, type=int)
    parser.add_argument('--eval_batch_size', default=500 * 1, type=int)

    # General options
    parser.add_argument('--z_n_iters', type=int, default=20)
    parser.add_argument('--z_step_size', type=float, default=0.5)
    parser.add_argument('--z_with_noise', type=int, default=1)
    parser.add_argument('--num_z_samples', type=int, default=10)
    parser.add_argument('--model', type=str, default='mol_ebm')
    parser.add_argument('--mask', type=_str2bool, default=False)
    parser.add_argument('--single_design', default=True, type=_str2bool)
    parser.add_argument('--multi_design', default=False, type=_str2bool)

    # EBM
    parser.add_argument('--prior_hidden_dim', type=int, default=200)
    parser.add_argument('--z_prior_with_noise', type=int, default=1)
    parser.add_argument('--prior_step_size', type=float, default=0.5)
    parser.add_argument('--z_n_iters_prior', type=int, default=20)
    parser.add_argument('--max_grad_norm_prior', default=1, type=float)
    parser.add_argument('--ebm_reg', default=0.0, type=float)
    parser.add_argument('--ref_dist', default='gaussian', type=str, choices=['gaussian', 'uniform'])
    parser.add_argument('--ref_sigma', type=float, default=0.5)
    parser.add_argument('--init_factor', type=float, default=1.)
    parser.add_argument('--noise_factor', type=float, default=0.5)

    # Decoder and MLP options
    parser.add_argument('--mlp_hidden_dim', default=100, type=int)
    parser.add_argument('--latent_dim', default=100, type=int)
    parser.add_argument('--dec_word_dim', default=512, type=int)
    parser.add_argument('--dec_h_dim', default=1024, type=int)
    parser.add_argument('--dec_num_layers', default=1, type=int)
    parser.add_argument('--dec_dropout', default=0.2, type=float)
    parser.add_argument('--train_n2n', default=1, type=int)
    parser.add_argument('--train_kl', default=1, type=int)

    # prop coefficients
    parser.add_argument('--prop_coefficient', default=10., type=float)
    parser.add_argument('--ba', default=10., type=float)
    parser.add_argument('--sas', default=10., type=float)
    parser.add_argument('--qed', default=10., type=float)

    # Optimization options
    parser.add_argument('--log_dir', default='../log/')
    parser.add_argument('--checkpoint_dir', default='models')
    parser.add_argument('--warmup', default=0, type=int)
    parser.add_argument('--num_epochs', default=30, type=int)
    parser.add_argument('--min_epochs', default=15, type=int)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--eps', default=1e-5, type=float)
    parser.add_argument('--decay', default=0, type=int)
    parser.add_argument('--momentum', default=0.5, type=float)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--prior_lr', default=0.0001, type=float)
    parser.add_argument('--max_grad_norm', default=5, type=float)
    parser.add_argument('--gpu', default=1, type=int)
    parser.add_argument('--seed', default=3435, type=int)
    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--sample_every', type=int, default=1000)
    parser.add_argument('--kl_every', type=int, default=100)
    parser.add_argument('--compute_kl', type=int, default=1)
    parser.add_argument('--test', type=int, default=0)
    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    print(args)
--------------------------------------------------------------------------------
/single_design_esr1/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
def _str2bool(value):
    """Parse a command-line string as a boolean.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is ``True``
    because every non-empty string is truthy.  This helper accepts the
    usual spellings of true/false and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % (value,))


def get_args():
    """Build and parse the CLI options for single-objective ESR1 design.

    Returns:
        argparse.Namespace holding all options declared below.
    """
    parser = argparse.ArgumentParser()
    # Boolean flags use _str2bool (not type=bool) so '--debug False'
    # actually parses as False instead of silently becoming True.
    parser.add_argument("--debug", default=False, type=_str2bool)
    parser.add_argument("--tb", default=False, type=_str2bool)

    # Input data
    parser.add_argument('--data_dir', default='../data')
    parser.add_argument('--test_file', default='../ZINC/test_5.txt')
    parser.add_argument('--autodock_executable', type=str, default='../../AutoDock-GPU/bin/autodock_gpu_128wi')
    parser.add_argument('--protein_file', type=str, default='../1err/1err.maps.fld')
    parser.add_argument('--train_from', default='')
    parser.add_argument('--max_len', default=110, type=int)
    parser.add_argument('--batch_size', default=1024 * 1, type=int)
    parser.add_argument('--eval_batch_size', default=500 * 1, type=int)

    # General options (posterior/latent sampling)
    parser.add_argument('--z_n_iters', type=int, default=20)
    parser.add_argument('--z_step_size', type=float, default=0.5)
    parser.add_argument('--z_with_noise', type=int, default=1)
    parser.add_argument('--num_z_samples', type=int, default=10)
    parser.add_argument('--model', type=str, default='mol_ebm')
    parser.add_argument('--mask', type=_str2bool, default=False)
    parser.add_argument('--single_design', default=True, type=_str2bool)
    parser.add_argument('--multi_design', default=False, type=_str2bool)

    # EBM (latent-space energy-based prior) options
    parser.add_argument('--prior_hidden_dim', type=int, default=200)
    parser.add_argument('--z_prior_with_noise', type=int, default=1)
    parser.add_argument('--prior_step_size', type=float, default=0.5)
    parser.add_argument('--z_n_iters_prior', type=int, default=20)
    parser.add_argument('--max_grad_norm_prior', default=1, type=float)
    parser.add_argument('--ebm_reg', default=0.0, type=float)
    parser.add_argument('--ref_dist', default='gaussian', type=str, choices=['gaussian', 'uniform'])
    parser.add_argument('--ref_sigma', type=float, default=0.5)
    parser.add_argument('--init_factor', type=float, default=1.)
    parser.add_argument('--noise_factor', type=float, default=0.5)

    # Decoder and MLP options
    parser.add_argument('--mlp_hidden_dim', default=50, type=int)
    parser.add_argument('--latent_dim', default=100, type=int)
    parser.add_argument('--dec_word_dim', default=512, type=int)
    parser.add_argument('--dec_h_dim', default=1024, type=int)
    parser.add_argument('--dec_num_layers', default=1, type=int)
    parser.add_argument('--dec_dropout', default=0.2, type=float)
    parser.add_argument('--train_n2n', default=1, type=int)
    parser.add_argument('--train_kl', default=1, type=int)

    # Property-loss coefficients (binding affinity, SAS, QED)
    parser.add_argument('--prop_coefficient', default=10., type=float)
    parser.add_argument('--ba', default=10., type=float)
    parser.add_argument('--sas', default=10., type=float)
    parser.add_argument('--qed', default=10., type=float)

    # Optimization options
    parser.add_argument('--log_dir', default='../log/')
    parser.add_argument('--checkpoint_dir', default='models')
    parser.add_argument('--warmup', default=0, type=int)
    parser.add_argument('--num_epochs', default=30, type=int)
    parser.add_argument('--min_epochs', default=15, type=int)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--eps', default=1e-5, type=float)
    parser.add_argument('--decay', default=0, type=int)
    parser.add_argument('--momentum', default=0.5, type=float)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--prior_lr', default=0.0001, type=float)
    parser.add_argument('--max_grad_norm', default=5, type=float)
    parser.add_argument('--gpu', default=1, type=int)
    parser.add_argument('--seed', default=3435, type=int)
    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--sample_every', type=int, default=1000)
    parser.add_argument('--kl_every', type=int, default=100)
    parser.add_argument('--compute_kl', type=int, default=1)
    parser.add_argument('--test', type=int, default=0)
    return parser.parse_args()
78 |
if __name__ == '__main__':
    # Smoke test: parse the (default) CLI arguments and echo the namespace.
    args = get_args()
    print(args)
82 |
--------------------------------------------------------------------------------
/multi_design_acaa1/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
def _str2bool(value):
    """Parse a command-line string as a boolean.

    argparse's ``type=bool`` is a trap: ``bool('False')`` is ``True``
    because every non-empty string is truthy.  This helper accepts the
    usual spellings of true/false and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % (value,))


def get_args():
    """Build and parse the CLI options for multi-objective ACAA1 design.

    Returns:
        argparse.Namespace holding all options declared below.
    """
    parser = argparse.ArgumentParser()
    # Boolean flags use _str2bool (not type=bool) so '--debug False'
    # actually parses as False instead of silently becoming True.
    parser.add_argument("--debug", default=False, type=_str2bool)
    parser.add_argument("--tb", default=False, type=_str2bool)

    # Input data
    parser.add_argument('--data_dir', default='../data/ba1')
    parser.add_argument('--test_file', default='../ZINC/test_5.txt')
    parser.add_argument('--autodock_executable', type=str, default='~/AutoDock-GPU/bin/autodock_gpu_128wi')
    # parser.add_argument('--protein_file', type=str, default='../1err/1err.maps.fld')
    parser.add_argument('--protein_file', type=str, default='../2iik/2iik.maps.fld')
    parser.add_argument('--train_from', default='')
    parser.add_argument('--max_len', default=110, type=int)
    parser.add_argument('--batch_size', default=1024 * 1, type=int)
    parser.add_argument('--eval_batch_size', default=500 * 1, type=int)

    # General options (posterior/latent sampling)
    parser.add_argument('--z_n_iters', type=int, default=20)
    parser.add_argument('--z_step_size', type=float, default=0.5)
    parser.add_argument('--z_with_noise', type=int, default=1)
    parser.add_argument('--num_z_samples', type=int, default=10)
    parser.add_argument('--model', type=str, default='mol_ebm')
    parser.add_argument('--mask', type=_str2bool, default=False)
    parser.add_argument('--single_design', default=False, type=_str2bool)
    parser.add_argument('--multi_design', default=True, type=_str2bool)

    # EBM (latent-space energy-based prior) options
    parser.add_argument('--prior_hidden_dim', type=int, default=200)
    parser.add_argument('--z_prior_with_noise', type=int, default=1)
    parser.add_argument('--prior_step_size', type=float, default=0.5)
    parser.add_argument('--z_n_iters_prior', type=int, default=20)
    parser.add_argument('--max_grad_norm_prior', default=1, type=float)
    parser.add_argument('--ebm_reg', default=0.0, type=float)
    parser.add_argument('--ref_dist', default='gaussian', type=str, choices=['gaussian', 'uniform'])
    parser.add_argument('--ref_sigma', type=float, default=0.5)
    parser.add_argument('--init_factor', type=float, default=1.)
    parser.add_argument('--noise_factor', type=float, default=0.5)

    # Decoder and MLP options
    parser.add_argument('--mlp_hidden_dim', default=50, type=int)
    parser.add_argument('--latent_dim', default=100, type=int)
    parser.add_argument('--dec_word_dim', default=512, type=int)
    parser.add_argument('--dec_h_dim', default=1024, type=int)
    parser.add_argument('--dec_num_layers', default=1, type=int)
    parser.add_argument('--dec_dropout', default=0.2, type=float)
    parser.add_argument('--train_n2n', default=1, type=int)
    parser.add_argument('--train_kl', default=1, type=int)

    # Property-loss coefficients (binding affinity, SAS, QED)
    parser.add_argument('--prop_coefficient', default=10., type=float)
    parser.add_argument('--ba', default=10., type=float)
    parser.add_argument('--sas', default=10., type=float)
    parser.add_argument('--qed', default=10., type=float)

    # Optimization options
    parser.add_argument('--log_dir', default='../log/')
    parser.add_argument('--checkpoint_dir', default='models')
    parser.add_argument('--warmup', default=0, type=int)
    parser.add_argument('--num_epochs', default=30, type=int)
    parser.add_argument('--min_epochs', default=15, type=int)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--eps', default=1e-5, type=float)
    parser.add_argument('--decay', default=0, type=int)
    parser.add_argument('--momentum', default=0.5, type=float)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--prior_lr', default=0.0001, type=float)
    parser.add_argument('--max_grad_norm', default=5, type=float)
    parser.add_argument('--gpu', default=1, type=int)
    parser.add_argument('--seed', default=3435, type=int)
    parser.add_argument('--print_every', type=int, default=100)
    parser.add_argument('--sample_every', type=int, default=1000)
    parser.add_argument('--kl_every', type=int, default=100)
    parser.add_argument('--compute_kl', type=int, default=1)
    parser.add_argument('--test', type=int, default=0)
    return parser.parse_args()
79 |
if __name__ == '__main__':
    # Smoke test: parse the (default) CLI arguments and echo the namespace.
    args = get_args()
    print(args)
83 |
--------------------------------------------------------------------------------
/single_design_acaa1/plot.py:
--------------------------------------------------------------------------------
from matplotlib import rc
# getting necessary libraries
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Transparent axes background so the stacked per-epoch KDEs can overlap
# into a ridgeline ("joyplot") figure.
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

# n: number of epochs plotted; m: samples kept per epoch.
n = 50
m = 2000
data = []
mean = np.zeros(n)
# Each "<epoch>.npy" file holds the property scores sampled at that epoch.
for i in range(50):
    data.append(np.load(str(i) + ".npy"))
data = np.vstack(data)[0:n + 1, 0:m]

# Per-epoch mean, used below as the hue that colors each ridgeline.
for j in range(n):
    mean[j] = np.mean(data[j, :])

print(data.shape)
# print(data)

epochInd = [i for i in range(50)]
# Drop every other epoch so the figure stays readable (25 rows remain).
toDel = [i for i in range(0, 50, 2)]
print(len(toDel))
dataList = []
# Reshape to long form: one row per sample with columns
# [value, epoch index, epoch mean].
for i in range(n):
    if i not in toDel:
        tmp = np.zeros((int(m), 3))
        tmp[:, 1] = np.ones(int(m)) * epochInd[i]
        tmp[:, 2] = np.ones(int(m)) * mean[i]
        for j in range(int(m)):
            tmp[j, 0] = data[i, j]
        dataList.append(tmp)

data = np.vstack(dataList)
vals = data[:, 0].tolist()
# Epoch labels as strings so FacetGrid treats them as categorical rows.
labels = data[:, 1].astype('int').astype('str').tolist()
means = data[:, 2].tolist()

df = pd.DataFrame(data={'Epoch': labels, 'Vals': vals, 'Epoch_Mean': means})
print(df)

# (A long commented-out worked example based on Seattle weather data was
# removed here; see seaborn's ridgeline-plot gallery example.)

# Palette for the ridgelines; hue is the epoch mean, so rows shift through
# the palette as the mean drifts over training.
pal = sns.color_palette(palette='coolwarm', n_colors=20)

# One facet row per epoch, colored by that epoch's mean value.
g = sns.FacetGrid(df, row='Epoch', hue='Epoch_Mean', aspect=15, height=0.80, palette=pal)

# Filled density for each epoch's score distribution.
g.map(sns.kdeplot, 'Vals',
      bw_adjust=1, clip_on=False, warn_singular=False,
      fill=True, alpha=1, linewidth=1.5)

# Horizontal baseline under each density.
g.map(plt.axhline, y=0,
      lw=2, clip_on=False)

# Label each facet with its row index, reusing the density's line color.
# NOTE(review): the text is placed at fixed data coordinates (33, 0.03);
# confirm this lands inside the visible x-range for this property.
for i, ax in enumerate(g.axes.flat):
    ax.text(33, 0.03, i,
            fontweight='bold', fontsize=15,
            color=ax.lines[-1].get_color())

# Negative hspace makes consecutive facets overlap (the ridgeline effect).
g.fig.subplots_adjust(hspace=-0.88)

# Remove per-facet titles, y ticks and spines.
g.set_titles("")
g.set(yticks=[])
g.despine(bottom=True, left=True)

# NOTE(review): 'ax' here is the loop variable leaking from the facet loop,
# so only the last (bottom) axis' tick labels are restyled.
plt.setp(ax.get_xticklabels(), fontsize=15, fontweight='bold')
plt.xlabel("ACAA1", fontweight='bold', fontsize=15)
g.fig.suptitle('tmp',
               ha='right',
               fontsize=20,
               fontweight=20)
# plt.xlim(xmin=0)
# plt.show()
plt.savefig("figure.pdf", dpi=300, bbox_inches='tight')
127 |
--------------------------------------------------------------------------------
/single_design_esr1/plot.py:
--------------------------------------------------------------------------------
from matplotlib import rc
# getting necessary libraries
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Transparent axes background so the stacked per-epoch KDEs can overlap
# into a ridgeline ("joyplot") figure.
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

# n: number of epochs plotted; m: samples kept per epoch.
n = 50
m = 2000
data = []
mean = np.zeros(n)
# Each "<epoch>.npy" file holds the property scores sampled at that epoch.
for i in range(50):
    data.append(np.load(str(i) + ".npy"))
data = np.vstack(data)[0:n + 1, 0:m]

# Per-epoch mean, used below as the hue that colors each ridgeline.
for j in range(n):
    mean[j] = np.mean(data[j, :])

print(data.shape)
# print(data)

epochInd = [i for i in range(50)]
# Drop every other epoch so the figure stays readable (25 rows remain).
toDel = [i for i in range(0, 50, 2)]
print(len(toDel))
dataList = []
# Reshape to long form: one row per sample with columns
# [value, epoch index, epoch mean].
for i in range(n):
    if i not in toDel:
        tmp = np.zeros((int(m), 3))
        tmp[:, 1] = np.ones(int(m)) * epochInd[i]
        tmp[:, 2] = np.ones(int(m)) * mean[i]
        for j in range(int(m)):
            tmp[j, 0] = data[i, j]
        dataList.append(tmp)

data = np.vstack(dataList)
vals = data[:, 0].tolist()
# Epoch labels as strings so FacetGrid treats them as categorical rows.
labels = data[:, 1].astype('int').astype('str').tolist()
means = data[:, 2].tolist()

df = pd.DataFrame(data={'Epoch': labels, 'Vals': vals, 'Epoch_Mean': means})
print(df)

# (A long commented-out worked example based on Seattle weather data was
# removed here; see seaborn's ridgeline-plot gallery example.)

# Palette for the ridgelines; hue is the epoch mean, so rows shift through
# the palette as the mean drifts over training.
pal = sns.color_palette(palette='coolwarm', n_colors=20)

# One facet row per epoch, colored by that epoch's mean value.
g = sns.FacetGrid(df, row='Epoch', hue='Epoch_Mean', aspect=15, height=0.80, palette=pal)

# Filled density for each epoch's score distribution.
g.map(sns.kdeplot, 'Vals',
      bw_adjust=1, clip_on=False, warn_singular=False,
      fill=True, alpha=1, linewidth=1.5)

# Horizontal baseline under each density.
g.map(plt.axhline, y=0,
      lw=2, clip_on=False)

# Label each facet with its row index, reusing the density's line color.
# NOTE(review): the text is placed at fixed data coordinates (33, 0.03);
# confirm this lands inside the visible x-range for this property.
for i, ax in enumerate(g.axes.flat):
    ax.text(33, 0.03, i,
            fontweight='bold', fontsize=15,
            color=ax.lines[-1].get_color())

# Negative hspace makes consecutive facets overlap (the ridgeline effect).
g.fig.subplots_adjust(hspace=-0.88)

# Remove per-facet titles, y ticks and spines.
g.set_titles("")
g.set(yticks=[])
g.despine(bottom=True, left=True)

# NOTE(review): 'ax' here is the loop variable leaking from the facet loop,
# so only the last (bottom) axis' tick labels are restyled.  The x-label
# says "ACAA1" although this directory targets ESR1 — likely copy-paste;
# confirm before publishing the figure.
plt.setp(ax.get_xticklabels(), fontsize=15, fontweight='bold')
plt.xlabel("ACAA1", fontweight='bold', fontsize=15)
g.fig.suptitle('tmp',
               ha='right',
               fontsize=20,
               fontweight=20)
# plt.xlim(xmin=0)
# plt.show()
plt.savefig("figure.pdf", dpi=300, bbox_inches='tight')
127 |
--------------------------------------------------------------------------------
/single_design_plogp/plot.py:
--------------------------------------------------------------------------------
from matplotlib import rc
# getting necessary libraries
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Transparent axes background so the stacked per-epoch KDEs can overlap
# into a ridgeline ("joyplot") figure.
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

# n: number of epochs plotted; m: samples kept per epoch.
n = 50
m = 10000
data = []
mean = np.zeros(n)
# Each "two_step/<epoch>.npy" file holds the penalized-logP scores
# sampled at that epoch.
for i in range(50):
    data.append(np.load('two_step/' + str(i) + ".npy"))
data = np.vstack(data)[0:n + 1, 0:m]

# Per-epoch mean, used below as the hue that colors each ridgeline.
for j in range(n):
    mean[j] = np.mean(data[j, :])

print(data.shape)
# print(data)

epochInd = [i for i in range(50)]
# Only epochs 0-14 are drawn; the rest are excluded from the figure.
toDel = [i for i in range(15, 50)]
dataList = []
# Reshape to long form: one row per sample with columns
# [value, epoch index, epoch mean].
for i in range(n):
    if i not in toDel:
        tmp = np.zeros((int(m), 3))
        tmp[:, 1] = np.ones(int(m)) * epochInd[i]
        tmp[:, 2] = np.ones(int(m)) * mean[i]
        for j in range(int(m)):
            tmp[j, 0] = data[i, j]
        dataList.append(tmp)

data = np.vstack(dataList)
vals = data[:, 0].tolist()
# Epoch labels as strings so FacetGrid treats them as categorical rows.
labels = data[:, 1].astype('int').astype('str').tolist()
means = data[:, 2].tolist()

df = pd.DataFrame(data={'Epoch': labels, 'Vals': vals, 'Epoch_Mean': means})
print(df)

# (A long commented-out worked example based on Seattle weather data was
# removed here; see seaborn's ridgeline-plot gallery example.)

# Palette for the ridgelines; hue is the epoch mean, so rows shift through
# the palette as the mean drifts over training.
pal = sns.color_palette(palette='coolwarm', n_colors=20)

# One facet row per epoch, colored by that epoch's mean value.
g = sns.FacetGrid(df, row='Epoch', hue='Epoch_Mean', aspect=15, height=0.80, palette=pal)

# Filled density for each epoch's score distribution.
g.map(sns.kdeplot, 'Vals',
      bw_adjust=1, clip_on=False, warn_singular=False,
      fill=True, alpha=1, linewidth=1.5)

# Horizontal baseline under each density.
g.map(plt.axhline, y=0,
      lw=2, clip_on=False)

# Label each facet with its row index, reusing the density's line color.
# NOTE(review): the text is placed at fixed data coordinates (33, 0.03);
# confirm this lands inside the visible x-range for penalized logP.
for i, ax in enumerate(g.axes.flat):
    ax.text(33, 0.03, i,
            fontweight='bold', fontsize=15,
            color=ax.lines[-1].get_color())

# Negative hspace makes consecutive facets overlap (the ridgeline effect).
g.fig.subplots_adjust(hspace=-0.88)

# Remove per-facet titles, y ticks and spines.
g.set_titles("")
g.set(yticks=[])
g.despine(bottom=True, left=True)

# NOTE(review): 'ax' here is the loop variable leaking from the facet loop,
# so only the last (bottom) axis' tick labels are restyled.
plt.setp(ax.get_xticklabels(), fontsize=15, fontweight='bold')
plt.xlabel("Penalized " + "logP", fontweight='bold', fontsize=15)
g.fig.suptitle('t',
               ha='right',
               fontsize=20,
               fontweight=20)
# plt.xlim(xmin=0)
# plt.show()
plt.savefig("plogp.pdf", dpi=300, bbox_inches='tight')
126 |
--------------------------------------------------------------------------------
/multi_design_esr1/plot.py:
--------------------------------------------------------------------------------
from matplotlib import rc
# getting necessary libraries
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Transparent axes background so the stacked per-epoch KDEs can overlap
# into a ridgeline ("joyplot") figure.
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

# n: number of epochs plotted; m: samples plotted per epoch.
n = 25
m = 2000
data = []
mean = np.zeros(n)
# NOTE(review): the list assigned above is dead — 'data' is immediately
# rebound to a zero-initialized (25, 10000) array below.
data = np.zeros([25, 10000])

# Each "<epoch>_ba.npy" file holds the binding-affinity scores of that
# epoch; rows are zero-padded when a file has fewer than 10000 entries.
for i in range(0, 25):
    x = np.load(str(i) + "_ba.npy")
    data[i, :len(x)] = x
# data = np.vstack(data)[0:n + 1, 0:m]


# Per-epoch mean, used below as the hue that colors each ridgeline.
# NOTE(review): the mean is taken over the full zero-padded row, so it is
# biased toward 0 whenever fewer than 10000 scores were saved — confirm
# this is intended.
for j in range(n):
    mean[j] = np.mean(data[j, :])

print(data.shape)
# print(data)

epochInd = [i for i in range(25)]
# Empty exclusion list: every epoch is kept.
toDel = [i for i in range(0, 0)]
print(len(toDel))
dataList = []
# Reshape to long form: one row per sample with columns
# [value, epoch index, epoch mean].  Only the first m=2000 of the 10000
# stored columns are plotted.
for i in range(n):
    if i not in toDel:
        tmp = np.zeros((int(m), 3))
        tmp[:, 1] = np.ones(int(m)) * epochInd[i]
        tmp[:, 2] = np.ones(int(m)) * mean[i]
        for j in range(int(m)):
            tmp[j, 0] = data[i, j]
        dataList.append(tmp)

data = np.vstack(dataList)
vals = data[:, 0].tolist()
# Epoch labels as strings so FacetGrid treats them as categorical rows.
labels = data[:, 1].astype('int').astype('str').tolist()
means = data[:, 2].tolist()

df = pd.DataFrame(data={'Epoch': labels, 'Vals': vals, 'Epoch_Mean': means})
print(df)

# (A long commented-out worked example based on Seattle weather data was
# removed here; see seaborn's ridgeline-plot gallery example.)

# Palette for the ridgelines; hue is the epoch mean, so rows shift through
# the palette as the mean drifts over training.
pal = sns.color_palette(palette='coolwarm', n_colors=20)

# One facet row per epoch, colored by that epoch's mean value.
g = sns.FacetGrid(df, row='Epoch', hue='Epoch_Mean', aspect=15, height=0.80, palette=pal)

# Filled density for each epoch's score distribution.
g.map(sns.kdeplot, 'Vals',
      bw_adjust=1, clip_on=False, warn_singular=False,
      fill=True, alpha=1, linewidth=1.5)

# Horizontal baseline under each density.
g.map(plt.axhline, y=0,
      lw=2, clip_on=False)

# Label each facet with its row index, reusing the density's line color.
# NOTE(review): the text is placed at fixed data coordinates (33, 0.03);
# confirm this lands inside the visible x-range for this property.
for i, ax in enumerate(g.axes.flat):
    ax.text(33, 0.03, i,
            fontweight='bold', fontsize=15,
            color=ax.lines[-1].get_color())

# Negative hspace makes consecutive facets overlap (the ridgeline effect).
g.fig.subplots_adjust(hspace=-0.88)

# Remove per-facet titles, y ticks and spines.
g.set_titles("")
g.set(yticks=[])
g.despine(bottom=True, left=True)

# NOTE(review): 'ax' here is the loop variable leaking from the facet loop,
# so only the last (bottom) axis' tick labels are restyled.  The x-label
# says "ACAA1" although this directory targets ESR1 — likely copy-paste;
# confirm before publishing the figure.
plt.setp(ax.get_xticklabels(), fontsize=15, fontweight='bold')
plt.xlabel("ACAA1", fontweight='bold', fontsize=15)
g.fig.suptitle('tmp',
               ha='right',
               fontsize=20,
               fontweight=20)
# plt.xlim(xmin=0)
# plt.show()
plt.savefig("figure.pdf", dpi=300, bbox_inches='tight')
131 |
--------------------------------------------------------------------------------
/multi_design_acaa1/stats.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pyplot as plt
2 | import seaborn as sns
3 | import pandas as pd
4 | from rdkit.Chem.QED import qed
5 | import sascorer
6 |
7 |
def property_plot(logps, sass, qeds, data_properties):
    """Overlay KDE plots of logP / SAS / QED for generated vs. reference molecules.

    Args:
        logps, sass, qeds: per-molecule property values ("Train" curves).
        data_properties: (logps, sass, qeds) triple for the "Test" curves.

    All six lists must have equal length; values may be strings, they are
    coerced with pd.to_numeric before plotting.  Shows the figure
    interactively and returns None.
    """
    # The function-local re-imports of matplotlib/seaborn/pandas were
    # removed: they shadowed the identical module-level imports.
    data_logps, data_sass, data_qeds = data_properties
    assert len(data_logps) == len(data_sass) == len(data_qeds)
    assert len(logps) == len(sass) == len(qeds)
    assert len(data_logps) == len(logps)
    labels = ['Test'] * len(data_logps) + ['Train'] * len(logps)
    logp_df = pd.DataFrame(data={'model': labels, 'logP': data_logps + logps})
    logp_df["logP"] = pd.to_numeric(logp_df["logP"])
    sas_df = pd.DataFrame(data={'model': labels, 'SAS': data_sass + sass})
    sas_df["SAS"] = pd.to_numeric(sas_df["SAS"])
    qed_df = pd.DataFrame(data={'model': labels, 'QED': data_qeds + qeds})
    qed_df["QED"] = pd.to_numeric(qed_df["QED"])

    # Three side-by-side panels: logP, SAS, QED.
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=[6.4 * 3, 4.8])
    sns.kdeplot(data=logp_df, x='logP', hue='model', ax=ax[0])
    sns.kdeplot(data=sas_df, x='SAS', hue='model', ax=ax[1])
    sns.kdeplot(data=qed_df, x='QED', hue='model', ax=ax[2])

    plt.show()
    plt.close(fig)
32 |
33 |
def get_data_properties(path, seq_length=110):
    """Read a whitespace-separated molecule property table.

    Each data row looks like ``SMILES logP SAS QED MW TPSA``.  Lines
    starting with ``#`` are headers (their column names are echoed to
    stdout), rows with fewer than two fields are ignored, and rows whose
    SMILES string is longer than *seq_length* are skipped.

    Returns:
        (mols, (logps, sass, qeds)) — every element kept as the raw
        string read from the file.
    """
    with open(path) as handle:
        raw_lines = handle.readlines()

    mols, logps, sass, qeds = [], [], [], []
    for raw in raw_lines:
        if raw[0] == "#":
            # Header: strip the leading '#' and the trailing newline,
            # then echo the column names (preserved original behavior).
            print(raw[1:-1].split())
            continue
        fields = raw.split()
        if len(fields) < 2:
            continue
        if len(fields[0]) > seq_length:
            continue
        assert len(fields) == 6
        mols.append(fields[0])
        logps.append(fields[1])
        sass.append(fields[2])
        qeds.append(fields[3])
    return mols, (logps, sass, qeds)
77 |
78 |
79 | def _convert_keys_to_int(dict):
80 | new_dict = {}
81 | for k, v in dict.items():
82 | try:
83 | new_key = int(k)
84 | except ValueError:
85 | new_key = k
86 | new_dict[new_key] = v
87 | return new_dict
88 |
89 |
if __name__ == '__main__':
    from args import get_args
    import numpy as np
    from rdkit import Chem
    from rdkit.Chem.Crippen import MolLogP
    import json
    import selfies as sf
    from tqdm import tqdm
    import random

    args = get_args()
    # Reference ("Test") property lists come from the held-out test file.
    mols, data_properties = get_data_properties(args.test_file)

    # Compare property distributions of a random 10k sample of the
    # label-encoded training SELFIES against the test-file properties.
    logps, sass, qeds = [], [], []
    dir_test = '../data/Xtrain.npy'
    sf_test = np.load(dir_test)
    json_file = json.load(open('../data/info.json'))
    vocab_itos = json_file['vocab_itos'][0]
    # JSON keys are strings; the label decoding needs int keys.
    vocab_itos = _convert_keys_to_int(vocab_itos)

    randomlist = random.sample(range(0, len(sf_test)), 10000)
    random_data = sf_test[randomlist]
    for data in tqdm(random_data):
        # BUG FIX: the original decoded sf_test[i] (the first 10000 rows in
        # file order), silently ignoring the random sample drawn above.
        m_sf = sf.encoding_to_selfies(data, vocab_itos, enc_type='label')
        m_smi = sf.decoder(m_sf)
        m = Chem.MolFromSmiles(m_smi)
        logps.append(MolLogP(m))
        try:
            SAS = sascorer.calculateScore(m)
        except ZeroDivisionError:
            # Fallback SAS used throughout this repo for degenerate mols.
            SAS = 2.8
        sass.append(SAS)
        qeds.append(qed(m))
    property_plot(logps, sass, qeds, data_properties)
--------------------------------------------------------------------------------
/multi_design_esr1/stats.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pyplot as plt
2 | import seaborn as sns
3 | import pandas as pd
4 | from rdkit.Chem.QED import qed
5 | import sascorer
6 |
7 |
def property_plot(logps, sass, qeds, data_properties):
    """Overlay KDE plots of logP / SAS / QED for generated vs. reference molecules.

    Args:
        logps, sass, qeds: per-molecule property values ("Train" curves).
        data_properties: (logps, sass, qeds) triple for the "Test" curves.

    All six lists must have equal length; values may be strings, they are
    coerced with pd.to_numeric before plotting.  Shows the figure
    interactively and returns None.
    """
    # The function-local re-imports of matplotlib/seaborn/pandas were
    # removed: they shadowed the identical module-level imports.
    data_logps, data_sass, data_qeds = data_properties
    assert len(data_logps) == len(data_sass) == len(data_qeds)
    assert len(logps) == len(sass) == len(qeds)
    assert len(data_logps) == len(logps)
    labels = ['Test'] * len(data_logps) + ['Train'] * len(logps)
    logp_df = pd.DataFrame(data={'model': labels, 'logP': data_logps + logps})
    logp_df["logP"] = pd.to_numeric(logp_df["logP"])
    sas_df = pd.DataFrame(data={'model': labels, 'SAS': data_sass + sass})
    sas_df["SAS"] = pd.to_numeric(sas_df["SAS"])
    qed_df = pd.DataFrame(data={'model': labels, 'QED': data_qeds + qeds})
    qed_df["QED"] = pd.to_numeric(qed_df["QED"])

    # Three side-by-side panels: logP, SAS, QED.
    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=[6.4 * 3, 4.8])
    sns.kdeplot(data=logp_df, x='logP', hue='model', ax=ax[0])
    sns.kdeplot(data=sas_df, x='SAS', hue='model', ax=ax[1])
    sns.kdeplot(data=qed_df, x='QED', hue='model', ax=ax[2])

    plt.show()
    plt.close(fig)
32 |
33 |
def get_data_properties(path, seq_length=110):
    """Read a whitespace-separated molecule property table.

    Each data row looks like ``SMILES logP SAS QED MW TPSA``.  Lines
    starting with ``#`` are headers (their column names are echoed to
    stdout), rows with fewer than two fields are ignored, and rows whose
    SMILES string is longer than *seq_length* are skipped.

    Returns:
        (mols, (logps, sass, qeds)) — every element kept as the raw
        string read from the file.
    """
    with open(path) as handle:
        raw_lines = handle.readlines()

    mols, logps, sass, qeds = [], [], [], []
    for raw in raw_lines:
        if raw[0] == "#":
            # Header: strip the leading '#' and the trailing newline,
            # then echo the column names (preserved original behavior).
            print(raw[1:-1].split())
            continue
        fields = raw.split()
        if len(fields) < 2:
            continue
        if len(fields[0]) > seq_length:
            continue
        assert len(fields) == 6
        mols.append(fields[0])
        logps.append(fields[1])
        sass.append(fields[2])
        qeds.append(fields[3])
    return mols, (logps, sass, qeds)
77 |
78 |
79 | def _convert_keys_to_int(dict):
80 | new_dict = {}
81 | for k, v in dict.items():
82 | try:
83 | new_key = int(k)
84 | except ValueError:
85 | new_key = k
86 | new_dict[new_key] = v
87 | return new_dict
88 |
89 |
if __name__ == '__main__':
    from args import get_args
    import numpy as np
    from rdkit import Chem
    from rdkit.Chem.Crippen import MolLogP
    import json
    import selfies as sf
    from tqdm import tqdm
    import random

    args = get_args()
    # Reference ('Test') properties come straight from the property table.
    mols, data_properties = get_data_properties(args.test_file)

    # Score a random 10k subsample of the label-encoded SELFIES training set.
    logps, sass, qeds = [], [], []
    sf_test = np.load('../data/Xtrain.npy')
    # Fix: close the json file handle (was json.load(open(...))).
    with open('../data/info.json') as f:
        json_file = json.load(f)
    vocab_itos = json_file['vocab_itos'][0]
    # json keys are strings; the label -> token table needs int keys
    vocab_itos = _convert_keys_to_int(vocab_itos)

    randomlist = random.sample(range(0, len(sf_test)), 10000)
    random_data = sf_test[randomlist]
    for i, data in enumerate(tqdm(random_data)):
        # BUG FIX: decode the sampled row `data`; the old code decoded
        # `sf_test[i]`, i.e. the first 10k rows, ignoring the random sample.
        m_sf = sf.encoding_to_selfies(data, vocab_itos, enc_type='label')
        m_smi = sf.decoder(m_sf)
        m = Chem.MolFromSmiles(m_smi)
        p = MolLogP(m)
        logps.append(p)
        try:
            SAS = sascorer.calculateScore(m)
        except ZeroDivisionError:
            # sascorer divides by the fragment count; fall back to a typical value
            SAS = 2.8
        sass.append(SAS)
        QED = qed(m)
        qeds.append(QED)
    property_plot(logps, sass, qeds, data_properties)
145 |
--------------------------------------------------------------------------------
/single_design_qed/stats.py:
--------------------------------------------------------------------------------
1 | from matplotlib import pyplot as plt
2 | import seaborn as sns
3 | import pandas as pd
4 | from rdkit.Chem.QED import qed
5 | import sascorer
6 |
7 |
def property_plot(logps, sass, qeds, data_properties):
    """Draw side-by-side KDE plots comparing generated vs. reference
    property distributions (logP, SAS, QED).

    Args:
        logps, sass, qeds: property lists for the generated set (labelled
            'Train' in the legend).
        data_properties: tuple (logps, sass, qeds) for the reference set
            (labelled 'Test'); each list must match the generated lists
            in length.
    """
    # Removed the redundant function-local imports of matplotlib / seaborn /
    # pandas: the module-level imports already provide plt, sns and pd.
    data_logps, data_sass, data_qeds = data_properties
    assert len(data_logps) == len(data_sass) == len(data_qeds)
    assert len(logps) == len(sass) == len(qeds)
    assert len(data_logps) == len(logps)
    labels = ['Test'] * len(data_logps) + ['Train'] * len(logps)
    # Values may arrive as strings (the file parser keeps them raw),
    # hence the to_numeric conversions below.
    logp_df = pd.DataFrame(data={'model': labels, 'logP': data_logps + logps})
    logp_df["logP"] = pd.to_numeric(logp_df["logP"])
    sas_df = pd.DataFrame(data={'model': labels, 'SAS': data_sass + sass})
    sas_df["SAS"] = pd.to_numeric(sas_df["SAS"])
    qed_df = pd.DataFrame(data={'model': labels, 'QED': data_qeds + qeds})
    qed_df["QED"] = pd.to_numeric(qed_df["QED"])

    fig, ax = plt.subplots(nrows=1, ncols=3, figsize=[6.4 * 3, 4.8])
    # fig.suptitle("Epoch: " + str(epoch), fontsize=12)
    sns.kdeplot(data=logp_df, x='logP', hue='model', ax=ax[0])
    sns.kdeplot(data=sas_df, x='SAS', hue='model', ax=ax[1])
    sns.kdeplot(data=qed_df, x='QED', hue='model', ax=ax[2])

    plt.show()
    plt.close(fig)
32 |
33 |
def get_data_properties(path, seq_length=110):
    """Parse a whitespace-separated property table.

    Each data row is expected to have 6 columns: SMILES plus five property
    values; header lines start with '#'. SMILES longer than *seq_length*
    are skipped. Returns (mols, (logps, sass, qeds)) where every property
    list holds the raw string values from the file.
    """
    mols, logps, sass, qeds = [], [], [], []
    with open(path) as handle:
        for raw in handle.readlines():
            if raw.startswith("#"):
                # Header row: echo the column names (strip '#' and newline).
                print(raw[1:-1].split())
                continue
            fields = raw.split()
            if len(fields) < 2:
                continue  # blank / malformed line
            if len(fields[0]) > seq_length:
                continue  # molecule too long for the model
            assert len(fields) == 6
            mols.append(fields[0])
            logps.append(fields[1])
            sass.append(fields[2])
            qeds.append(fields[3])
    return mols, (logps, sass, qeds)
77 |
78 |
79 | def _convert_keys_to_int(dict):
80 | new_dict = {}
81 | for k, v in dict.items():
82 | try:
83 | new_key = int(k)
84 | except ValueError:
85 | new_key = k
86 | new_dict[new_key] = v
87 | return new_dict
88 |
89 |
if __name__ == '__main__':
    from args import get_args
    import numpy as np
    from rdkit import Chem
    from rdkit.Chem.Crippen import MolLogP
    import json
    import selfies as sf
    from tqdm import tqdm
    import random

    args = get_args()
    # Reference ('Test') properties come straight from the property table.
    mols, data_properties = get_data_properties(args.test_file)

    # Score a random 10k subsample of the label-encoded SELFIES training set.
    logps, sass, qeds = [], [], []
    sf_test = np.load('../data/Xtrain.npy')
    # Fix: close the json file handle (was json.load(open(...))).
    with open('../data/info.json') as f:
        json_file = json.load(f)
    vocab_itos = json_file['vocab_itos'][0]
    # json keys are strings; the label -> token table needs int keys
    vocab_itos = _convert_keys_to_int(vocab_itos)

    randomlist = random.sample(range(0, len(sf_test)), 10000)
    random_data = sf_test[randomlist]
    for i, data in enumerate(tqdm(random_data)):
        # BUG FIX: decode the sampled row `data`; the old code decoded
        # `sf_test[i]`, i.e. the first 10k rows, ignoring the random sample.
        m_sf = sf.encoding_to_selfies(data, vocab_itos, enc_type='label')
        m_smi = sf.decoder(m_sf)
        m = Chem.MolFromSmiles(m_smi)
        p = MolLogP(m)
        logps.append(p)
        try:
            SAS = sascorer.calculateScore(m)
        except ZeroDivisionError:
            # sascorer divides by the fragment count; fall back to a typical value
            SAS = 2.8
        sass.append(SAS)
        QED = qed(m)
        qeds.append(QED)
    property_plot(logps, sass, qeds, data_properties)
145 |
--------------------------------------------------------------------------------
/multi_design_acaa1/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import numpy as np
4 | from ZINC.char import char_list, char_dict
5 | import selfies as sf
6 | import json
7 | from rdkit.Chem import Draw
8 | from rdkit import Chem
9 | from rdkit.Chem.Crippen import MolLogP
10 | from rdkit.Chem.QED import qed
11 | from rdkit.Chem.Descriptors import ExactMolWt
12 | from rdkit.Chem.rdMolDescriptors import CalcTPSA
13 | import sascorer
14 | import math
15 |
16 |
17 | # max_len = 72
18 | # num_of_embeddings = 109
class MolDataset(Dataset):
    """Label-encoded SELFIES molecules with per-molecule properties.

    Loads, for split *dname* inside *datadir*:
      X<dname>.npy    -- integer-coded SELFIES sequences (one row per molecule)
      L<dname>.npy    -- true (unpadded) length of each sequence
      sas_<dname>.npy -- synthetic-accessibility scores
      qed_<dname>.npy -- QED drug-likeness scores
      ba1_<dname>.npy -- binding affinities for target 1 (this variant
                         uses ba1 only; ba0 is not loaded)
    """

    def __init__(self, datadir, dname):
        Xdata_file = datadir + "/X" + dname + ".npy"
        self.Xdata = torch.tensor(np.load(Xdata_file), dtype=torch.long)  # number-coded molecule
        Ldata_file = datadir + "/L" + dname + ".npy"
        self.Ldata = torch.tensor(np.load(Ldata_file), dtype=torch.long)  # length of each molecule
        self.len = self.Xdata.shape[0]
        sasdata_file = datadir + "/sas_" + dname + ".npy"
        self.sas = torch.tensor(np.load(sasdata_file), dtype=torch.float32)
        qed_data_file = datadir + "/qed_" + dname + ".npy"
        self.qed = torch.tensor(np.load(qed_data_file), dtype=torch.float32)
        ba1data_file = datadir + "/ba1_" + dname + ".npy"
        self.ba1 = torch.tensor(np.load(ba1data_file), dtype=torch.float32)
        # self.preprocess()  # optional: drop zero-affinity entries

    def preprocess(self):
        """Keep only molecules with a non-zero binding affinity.

        BUG FIX: this variant loads ba1 (self.ba0 is never set), so filter
        on self.ba1; the old code referenced the missing self.ba0 and would
        raise AttributeError if called.
        """
        ind = torch.nonzero(self.ba1)
        self.ba1 = self.ba1[ind].squeeze()
        self.sas = self.sas[ind].squeeze()
        self.qed = self.qed[ind].squeeze()
        self.Xdata = self.Xdata[ind].squeeze()
        self.Ldata = self.Ldata[ind].squeeze()
        self.len = self.Xdata.shape[0]

    def __getitem__(self, index):
        # Prepend sos=108 to each sequence; this sos token is not part of
        # char_list/char_dict as defined by selfies.
        mol = self.Xdata[index]
        sos = torch.tensor([108], dtype=torch.long)
        mol = torch.cat([sos, mol], dim=0).contiguous()
        mask = torch.zeros(mol.shape[0] - 1)
        mask[:self.Ldata[index] + 1] = 1.
        # Affinity is negated so that larger values are better downstream.
        return (mol, mask, -1 * self.ba1[index])

    def __len__(self):
        return self.len

    def save_mol_png(self, label, filepath, size=(600, 600)):
        """Render the molecule encoded by *label* to a PNG at *filepath*."""
        m_smi = self.label2sf2smi(label)
        m = Chem.MolFromSmiles(m_smi)
        Draw.MolToFile(m, filepath, size=size)

    def label2sf2smi(self, label):
        """Decode an integer label sequence -> SELFIES -> canonical SMILES."""
        m_sf = sf.encoding_to_selfies(label, char_dict, enc_type='label')
        m_smi = sf.decoder(m_sf)
        m_smi = Chem.CanonSmiles(m_smi)
        return m_smi

    @staticmethod
    def delta_to_kd(x):
        # Convert a binding free energy (kcal/mol) to Kd via exp(dG / RT),
        # R = 0.00198720425864083 kcal/(mol*K), T = 298.15 K.
        return math.exp(x / (0.00198720425864083 * 298.15))
72 |
73 |
if __name__ == '__main__':
    # Smoke test: load the training split and inspect the ba1 affinities.
    datadir = '../data'
    ds = MolDataset(datadir, 'train')
    ds_loader = DataLoader(dataset=ds, batch_size=100,
                           shuffle=True, drop_last=True, num_workers=2)

    affinities = ds.ba1
    print(len(affinities))
    print(affinities[:10])
141 |
--------------------------------------------------------------------------------
/multi_design_esr1/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import numpy as np
4 | from ZINC.char import char_list, char_dict
5 | import selfies as sf
6 | import json
7 | from rdkit.Chem import Draw
8 | from rdkit import Chem
9 | from rdkit.Chem.Crippen import MolLogP
10 | from rdkit.Chem.QED import qed
11 | from rdkit.Chem.Descriptors import ExactMolWt
12 | from rdkit.Chem.rdMolDescriptors import CalcTPSA
13 | import sascorer
14 | import math
15 |
16 |
17 | # max_len = 72
18 | # num_of_embeddings = 109
class MolDataset(Dataset):
    """Label-encoded SELFIES molecules with SAS / QED / ba0 properties.

    For split *dname* under *datadir*, loads X<dname>.npy (integer-coded
    sequences), L<dname>.npy (unpadded lengths), sas_/qed_/ba0_<dname>.npy
    (per-molecule property arrays). This variant uses ba0 only.
    """

    def __init__(self, datadir, dname):
        def _load(prefix):
            # Keep the path arithmetic in one place.
            return np.load(datadir + "/" + prefix + dname + ".npy")

        self.Xdata = torch.tensor(_load("X"), dtype=torch.long)   # number-coded molecules
        self.Ldata = torch.tensor(_load("L"), dtype=torch.long)   # true sequence lengths
        self.sas = torch.tensor(_load("sas_"), dtype=torch.float32)
        self.qed = torch.tensor(_load("qed_"), dtype=torch.float32)
        self.ba0 = torch.tensor(_load("ba0_"), dtype=torch.float32)
        self.len = self.Xdata.shape[0]

    def preprocess(self):
        """Drop entries whose ba0 affinity is exactly zero (not called by default)."""
        keep = torch.nonzero(self.ba0)
        for attr in ("ba0", "sas", "qed", "Xdata", "Ldata"):
            setattr(self, attr, getattr(self, attr)[keep].squeeze())
        self.len = self.Xdata.shape[0]

    def __getitem__(self, index):
        """Return (sos-prefixed sequence, validity mask, negated ba0)."""
        # sos token 108 is outside char_list/char_dict as defined by selfies.
        start = torch.tensor([108], dtype=torch.long)
        seq = torch.cat([start, self.Xdata[index]], dim=0).contiguous()
        mask = torch.zeros(seq.shape[0] - 1)
        mask[:self.Ldata[index] + 1] = 1.
        return (seq, mask, -1 * self.ba0[index])

    def __len__(self):
        return self.len

    def save_mol_png(self, label, filepath, size=(600, 600)):
        """Render the molecule encoded by *label* to a PNG at *filepath*."""
        molecule = Chem.MolFromSmiles(self.label2sf2smi(label))
        Draw.MolToFile(molecule, filepath, size=size)

    def label2sf2smi(self, label):
        """Decode an integer label sequence -> SELFIES -> canonical SMILES."""
        selfies_str = sf.encoding_to_selfies(label, char_dict, enc_type='label')
        return Chem.CanonSmiles(sf.decoder(selfies_str))

    @staticmethod
    def delta_to_kd(x):
        # exp(dG / RT) with R in kcal/(mol*K) and T = 298.15 K.
        return math.exp(x / (0.00198720425864083 * 298.15))
72 |
73 |
if __name__ == '__main__':
    # Smoke test: load the training split and inspect the affinities.
    datadir = '../data'
    ds = MolDataset(datadir, 'train')
    ds_loader = DataLoader(dataset=ds, batch_size=100,
                           shuffle=True, drop_last=True, num_workers=2)

    # BUG FIX: this esr1 variant loads ba0 only (ba1 is commented out in
    # __init__), so `ds.ba1` raised AttributeError. Use ds.ba0.
    s = ds.ba0
    print(len(s))
    print(s[:10])
141 |
--------------------------------------------------------------------------------
/multi_design_acaa1/sascorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5 | # Peter Ertl and Ansgar Schuffenhauer
6 | # Journal of Cheminformatics 1:8 (2009)
7 | # http://www.jcheminf.com/content/1/1/8
8 | #
9 | # several small modifications to the original paper are included
10 | # particularly slightly different formula for marocyclic penalty
11 | # and taking into account also molecule symmetry (fingerprint density)
12 | #
13 | # for a set of 10k diverse molecules the agreement between the original method
14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
15 | #
16 | # peter ertl & greg landrum, september 2013
17 | #
18 |
19 |
20 | from rdkit import Chem
21 | from rdkit.Chem import rdMolDescriptors
22 | import pickle
23 |
24 | import math
25 | from collections import defaultdict
26 |
27 | import os.path as op
28 |
29 | _fscores = None
30 |
31 |
def readFragmentScores(name='fpscores'):
    """Load the fragment-contribution table into the module-level _fscores dict.

    The pickle holds rows of the form [score, bitId, bitId, ...]; every bit id
    in a row maps to that row's score.
    """
    import gzip
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    # Fix: use a context manager so the gzip handle is closed
    # (was pickle.load(gzip.open(...)), which leaked the file object).
    with gzip.open('%s.pkl.gz' % name) as fh:
        data = pickle.load(fh)
    outDict = {}
    for i in data:
        for j in range(1, len(i)):
            outDict[i[j]] = float(i[0])
    _fscores = outDict
44 |
45 |
def numBridgeheadsAndSpiro(mol, ri=None):
    """Return (bridgehead count, spiro count) for *mol*.

    *ri* is accepted for API compatibility but unused.
    """
    return (rdMolDescriptors.CalcNumBridgeheadAtoms(mol),
            rdMolDescriptors.CalcNumSpiroAtoms(mol))
50 |
51 |
def calculateScore(m):
    """Compute the synthetic accessibility (SA) score of RDKit molecule *m*.

    Returns a float in [1, 10]: 1 = easy to synthesize, 10 = very hard.
    Lazily loads the fragment score table on first call.
    """
    if _fscores is None:
        readFragmentScores()

    # fragment score: count-weighted mean fragment contribution
    fp = rdMolDescriptors.GetMorganFingerprint(m,
                                               2)  # <- 2 is the *radius* of the circular fingerprint
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in fps.items():
        nf += v
        score1 += _fscores.get(bitId, -4) * v  # -4 penalizes unknown fragments
    score1 /= nf  # NOTE: raises ZeroDivisionError if the fingerprint is empty

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = 0
    for x in ri.AtomRings():
        if len(x) > 8:
            nMacrocycles += 1

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    #  macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthetise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    # (fix: renamed locals that shadowed the builtins min()/max())
    raw_min = -4.0
    raw_max = 2.5
    sascore = 11. - (sascore - raw_min + 1) / (raw_max - raw_min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore
114 |
115 |
def processMols(mols):
    """Print a tab-separated (smiles, name, SA score) line for each valid molecule."""
    print('smiles\tName\tsa_score')
    for m in mols:
        if m is None:
            continue  # unparseable entries come through as None
        score = calculateScore(m)
        print(Chem.MolToSmiles(m) + "\t" + m.GetProp('_Name') + "\t%3f" % score)
126 |
127 |
if __name__ == '__main__':
    import sys
    import time

    # Time the two phases (table load vs. scoring) separately.
    load_start = time.time()
    readFragmentScores("fpscores")
    load_end = time.time()

    supplier = Chem.SmilesMolSupplier(sys.argv[1])
    calc_start = time.time()
    processMols(supplier)
    calc_end = time.time()

    print('Reading took %.2f seconds. Calculating took %.2f seconds'
          % ((load_end - load_start), (calc_end - calc_start)),
          file=sys.stderr)
143 |
144 | #
145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
146 | # All rights reserved.
147 | #
148 | # Redistribution and use in source and binary forms, with or without
149 | # modification, are permitted provided that the following conditions are
150 | # met:
151 | #
152 | # * Redistributions of source code must retain the above copyright
153 | # notice, this list of conditions and the following disclaimer.
154 | # * Redistributions in binary form must reproduce the above
155 | # copyright notice, this list of conditions and the following
156 | # disclaimer in the documentation and/or other materials provided
157 | # with the distribution.
158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
159 | # nor the names of its contributors may be used to endorse or promote
160 | # products derived from this software without specific prior written permission.
161 | #
162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
173 | #
--------------------------------------------------------------------------------
/multi_design_esr1/sascorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5 | # Peter Ertl and Ansgar Schuffenhauer
6 | # Journal of Cheminformatics 1:8 (2009)
7 | # http://www.jcheminf.com/content/1/1/8
8 | #
9 | # several small modifications to the original paper are included
10 | # particularly slightly different formula for marocyclic penalty
11 | # and taking into account also molecule symmetry (fingerprint density)
12 | #
13 | # for a set of 10k diverse molecules the agreement between the original method
14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
15 | #
16 | # peter ertl & greg landrum, september 2013
17 | #
18 |
19 |
20 | from rdkit import Chem
21 | from rdkit.Chem import rdMolDescriptors
22 | import pickle
23 |
24 | import math
25 | from collections import defaultdict
26 |
27 | import os.path as op
28 |
29 | _fscores = None
30 |
31 |
def readFragmentScores(name='fpscores'):
    """Load the fragment-contribution table into the module-level _fscores dict.

    The pickle holds rows of the form [score, bitId, bitId, ...]; every bit id
    in a row maps to that row's score.
    """
    import gzip
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    # Fix: use a context manager so the gzip handle is closed
    # (was pickle.load(gzip.open(...)), which leaked the file object).
    with gzip.open('%s.pkl.gz' % name) as fh:
        data = pickle.load(fh)
    outDict = {}
    for i in data:
        for j in range(1, len(i)):
            outDict[i[j]] = float(i[0])
    _fscores = outDict
44 |
45 |
def numBridgeheadsAndSpiro(mol, ri=None):
    """Return (bridgehead count, spiro count) for *mol*.

    *ri* is accepted for API compatibility but unused.
    """
    return (rdMolDescriptors.CalcNumBridgeheadAtoms(mol),
            rdMolDescriptors.CalcNumSpiroAtoms(mol))
50 |
51 |
def calculateScore(m):
    """Compute the synthetic accessibility (SA) score of RDKit molecule *m*.

    Returns a float in [1, 10]: 1 = easy to synthesize, 10 = very hard.
    Lazily loads the fragment score table on first call.
    """
    if _fscores is None:
        readFragmentScores()

    # fragment score: count-weighted mean fragment contribution
    fp = rdMolDescriptors.GetMorganFingerprint(m,
                                               2)  # <- 2 is the *radius* of the circular fingerprint
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in fps.items():
        nf += v
        score1 += _fscores.get(bitId, -4) * v  # -4 penalizes unknown fragments
    score1 /= nf  # NOTE: raises ZeroDivisionError if the fingerprint is empty

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = 0
    for x in ri.AtomRings():
        if len(x) > 8:
            nMacrocycles += 1

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    #  macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthetise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    # (fix: renamed locals that shadowed the builtins min()/max())
    raw_min = -4.0
    raw_max = 2.5
    sascore = 11. - (sascore - raw_min + 1) / (raw_max - raw_min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore
114 |
115 |
def processMols(mols):
    """Print a tab-separated (smiles, name, SA score) line for each valid molecule."""
    print('smiles\tName\tsa_score')
    for m in mols:
        if m is None:
            continue  # unparseable entries come through as None
        score = calculateScore(m)
        print(Chem.MolToSmiles(m) + "\t" + m.GetProp('_Name') + "\t%3f" % score)
126 |
127 |
if __name__ == '__main__':
    import sys
    import time

    # Time the two phases (table load vs. scoring) separately.
    load_start = time.time()
    readFragmentScores("fpscores")
    load_end = time.time()

    supplier = Chem.SmilesMolSupplier(sys.argv[1])
    calc_start = time.time()
    processMols(supplier)
    calc_end = time.time()

    print('Reading took %.2f seconds. Calculating took %.2f seconds'
          % ((load_end - load_start), (calc_end - calc_start)),
          file=sys.stderr)
143 |
144 | #
145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
146 | # All rights reserved.
147 | #
148 | # Redistribution and use in source and binary forms, with or without
149 | # modification, are permitted provided that the following conditions are
150 | # met:
151 | #
152 | # * Redistributions of source code must retain the above copyright
153 | # notice, this list of conditions and the following disclaimer.
154 | # * Redistributions in binary form must reproduce the above
155 | # copyright notice, this list of conditions and the following
156 | # disclaimer in the documentation and/or other materials provided
157 | # with the distribution.
158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
159 | # nor the names of its contributors may be used to endorse or promote
160 | # products derived from this software without specific prior written permission.
161 | #
162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
173 | #
--------------------------------------------------------------------------------
/single_design_acaa1/sascorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5 | # Peter Ertl and Ansgar Schuffenhauer
6 | # Journal of Cheminformatics 1:8 (2009)
7 | # http://www.jcheminf.com/content/1/1/8
8 | #
9 | # several small modifications to the original paper are included
10 | # particularly slightly different formula for marocyclic penalty
11 | # and taking into account also molecule symmetry (fingerprint density)
12 | #
13 | # for a set of 10k diverse molecules the agreement between the original method
14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
15 | #
16 | # peter ertl & greg landrum, september 2013
17 | #
18 |
19 |
20 | from rdkit import Chem
21 | from rdkit.Chem import rdMolDescriptors
22 | import pickle
23 |
24 | import math
25 | from collections import defaultdict
26 |
27 | import os.path as op
28 |
29 | _fscores = None
30 |
31 |
def readFragmentScores(name='fpscores'):
    """Load the fragment-contribution table into the module-level _fscores dict.

    The pickle holds rows of the form [score, bitId, bitId, ...]; every bit id
    in a row maps to that row's score.
    """
    import gzip
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    # Fix: use a context manager so the gzip handle is closed
    # (was pickle.load(gzip.open(...)), which leaked the file object).
    with gzip.open('%s.pkl.gz' % name) as fh:
        data = pickle.load(fh)
    outDict = {}
    for i in data:
        for j in range(1, len(i)):
            outDict[i[j]] = float(i[0])
    _fscores = outDict
44 |
45 |
def numBridgeheadsAndSpiro(mol, ri=None):
    """Return (nBridgeheads, nSpiro) for *mol*; *ri* is accepted but unused."""
    spiro_count = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    bridgehead_count = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    return bridgehead_count, spiro_count
50 |
51 |
def calculateScore(m):
    """Compute the synthetic accessibility (SA) score of RDKit molecule *m*.

    Combines a Morgan-fingerprint fragment score with complexity penalties
    (size, stereo centers, spiro/bridgehead atoms, macrocycles) and a symmetry
    correction, then maps the raw value onto a 1..10 scale (1 = easy to make).
    """
    if _fscores is None:
        readFragmentScores()  # lazy one-time load of the fragment score table

    # fragment score
    fp = rdMolDescriptors.GetMorganFingerprint(m,
                                               2)  # <- 2 is the *radius* of the circular fingerprint
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in fps.items():
        nf += v
        sfp = bitId
        score1 += _fscores.get(sfp, -4) * v  # unknown fragments get the worst score
    score1 /= nf

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = 0
    for x in ri.AtomRings():
        if len(x) > 8:
            nMacrocycles += 1

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    #  macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthetise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    # (renamed from min/max so the Python builtins are no longer shadowed)
    score_min = -4.0
    score_max = 2.5
    sascore = 11. - (sascore - score_min + 1) / (score_max - score_min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore
114 |
115 |
def processMols(mols):
    """Print a TSV line (smiles, name, SA score) for every parseable molecule."""
    print('smiles\tName\tsa_score')
    for mol in mols:
        if mol is None:  # unparseable entries come through as None
            continue
        sa = calculateScore(mol)
        print(Chem.MolToSmiles(mol) + "\t" + mol.GetProp('_Name') + "\t%3f" % sa)
126 |
127 |
if __name__ == '__main__':
    import sys
    import time

    # time the score-table load separately from the per-molecule scoring
    t1 = time.time()
    readFragmentScores("fpscores")
    t2 = time.time()

    # argv[1]: a SMILES file; unparseable entries yield None molecules,
    # which processMols skips
    suppl = Chem.SmilesMolSupplier(sys.argv[1])
    t3 = time.time()
    processMols(suppl)
    t4 = time.time()

    print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
          file=sys.stderr)
143 |
144 | #
145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
146 | # All rights reserved.
147 | #
148 | # Redistribution and use in source and binary forms, with or without
149 | # modification, are permitted provided that the following conditions are
150 | # met:
151 | #
152 | # * Redistributions of source code must retain the above copyright
153 | # notice, this list of conditions and the following disclaimer.
154 | # * Redistributions in binary form must reproduce the above
155 | # copyright notice, this list of conditions and the following
156 | # disclaimer in the documentation and/or other materials provided
157 | # with the distribution.
158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
159 | # nor the names of its contributors may be used to endorse or promote
160 | # products derived from this software without specific prior written permission.
161 | #
162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
173 | #
--------------------------------------------------------------------------------
/single_design_esr1/sascorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5 | # Peter Ertl and Ansgar Schuffenhauer
6 | # Journal of Cheminformatics 1:8 (2009)
7 | # http://www.jcheminf.com/content/1/1/8
8 | #
9 | # several small modifications to the original paper are included
10 | # particularly slightly different formula for marocyclic penalty
11 | # and taking into account also molecule symmetry (fingerprint density)
12 | #
13 | # for a set of 10k diverse molecules the agreement between the original method
14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
15 | #
16 | # peter ertl & greg landrum, september 2013
17 | #
18 |
19 |
20 | from rdkit import Chem
21 | from rdkit.Chem import rdMolDescriptors
22 | import pickle
23 |
24 | import math
25 | from collections import defaultdict
26 |
27 | import os.path as op
28 |
29 | _fscores = None
30 |
31 |
def readFragmentScores(name='fpscores'):
    """Load the fragment-contribution table into the module-level _fscores dict.

    Reads '<name>.pkl.gz' (resolved next to this module when name == 'fpscores').
    Each pickled record is [score, bit_id, bit_id, ...]; every bit id maps to
    that record's score.
    """
    import gzip
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    # context manager so the gzip handle is closed even if unpickling fails
    # (the original leaked the file object)
    with gzip.open('%s.pkl.gz' % name) as gz:
        data = pickle.load(gz)
    outDict = {}
    for i in data:
        for j in range(1, len(i)):
            outDict[i[j]] = float(i[0])
    _fscores = outDict
44 |
45 |
def numBridgeheadsAndSpiro(mol, ri=None):
    """Return (nBridgeheads, nSpiro) for *mol*; *ri* is accepted but unused."""
    spiro_count = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    bridgehead_count = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    return bridgehead_count, spiro_count
50 |
51 |
def calculateScore(m):
    """Compute the synthetic accessibility (SA) score of RDKit molecule *m*.

    Combines a Morgan-fingerprint fragment score with complexity penalties
    (size, stereo centers, spiro/bridgehead atoms, macrocycles) and a symmetry
    correction, then maps the raw value onto a 1..10 scale (1 = easy to make).
    """
    if _fscores is None:
        readFragmentScores()  # lazy one-time load of the fragment score table

    # fragment score
    fp = rdMolDescriptors.GetMorganFingerprint(m,
                                               2)  # <- 2 is the *radius* of the circular fingerprint
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in fps.items():
        nf += v
        sfp = bitId
        score1 += _fscores.get(sfp, -4) * v  # unknown fragments get the worst score
    score1 /= nf

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = 0
    for x in ri.AtomRings():
        if len(x) > 8:
            nMacrocycles += 1

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    #  macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthetise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    # (renamed from min/max so the Python builtins are no longer shadowed)
    score_min = -4.0
    score_max = 2.5
    sascore = 11. - (sascore - score_min + 1) / (score_max - score_min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore
114 |
115 |
def processMols(mols):
    """Print a TSV line (smiles, name, SA score) for every parseable molecule."""
    print('smiles\tName\tsa_score')
    for mol in mols:
        if mol is None:  # unparseable entries come through as None
            continue
        sa = calculateScore(mol)
        print(Chem.MolToSmiles(mol) + "\t" + mol.GetProp('_Name') + "\t%3f" % sa)
126 |
127 |
if __name__ == '__main__':
    import sys
    import time

    # time the score-table load separately from the per-molecule scoring
    t1 = time.time()
    readFragmentScores("fpscores")
    t2 = time.time()

    # argv[1]: a SMILES file; unparseable entries yield None molecules,
    # which processMols skips
    suppl = Chem.SmilesMolSupplier(sys.argv[1])
    t3 = time.time()
    processMols(suppl)
    t4 = time.time()

    print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
          file=sys.stderr)
143 |
144 | #
145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
146 | # All rights reserved.
147 | #
148 | # Redistribution and use in source and binary forms, with or without
149 | # modification, are permitted provided that the following conditions are
150 | # met:
151 | #
152 | # * Redistributions of source code must retain the above copyright
153 | # notice, this list of conditions and the following disclaimer.
154 | # * Redistributions in binary form must reproduce the above
155 | # copyright notice, this list of conditions and the following
156 | # disclaimer in the documentation and/or other materials provided
157 | # with the distribution.
158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
159 | # nor the names of its contributors may be used to endorse or promote
160 | # products derived from this software without specific prior written permission.
161 | #
162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
173 | #
--------------------------------------------------------------------------------
/single_design_plogp/sascorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5 | # Peter Ertl and Ansgar Schuffenhauer
6 | # Journal of Cheminformatics 1:8 (2009)
7 | # http://www.jcheminf.com/content/1/1/8
8 | #
9 | # several small modifications to the original paper are included
10 | # particularly slightly different formula for marocyclic penalty
11 | # and taking into account also molecule symmetry (fingerprint density)
12 | #
13 | # for a set of 10k diverse molecules the agreement between the original method
14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
15 | #
16 | # peter ertl & greg landrum, september 2013
17 | #
18 |
19 |
20 | from rdkit import Chem
21 | from rdkit.Chem import rdMolDescriptors
22 | import pickle
23 |
24 | import math
25 | from collections import defaultdict
26 |
27 | import os.path as op
28 |
29 | _fscores = None
30 |
31 |
def readFragmentScores(name='fpscores'):
    """Load the fragment-contribution table into the module-level _fscores dict.

    Reads '<name>.pkl.gz' (resolved next to this module when name == 'fpscores').
    Each pickled record is [score, bit_id, bit_id, ...]; every bit id maps to
    that record's score.
    """
    import gzip
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    # context manager so the gzip handle is closed even if unpickling fails
    # (the original leaked the file object)
    with gzip.open('%s.pkl.gz' % name) as gz:
        data = pickle.load(gz)
    outDict = {}
    for i in data:
        for j in range(1, len(i)):
            outDict[i[j]] = float(i[0])
    _fscores = outDict
44 |
45 |
def numBridgeheadsAndSpiro(mol, ri=None):
    """Return (nBridgeheads, nSpiro) for *mol*; *ri* is accepted but unused."""
    spiro_count = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    bridgehead_count = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    return bridgehead_count, spiro_count
50 |
51 |
def calculateScore(m):
    """Compute the synthetic accessibility (SA) score of RDKit molecule *m*.

    Combines a Morgan-fingerprint fragment score with complexity penalties
    (size, stereo centers, spiro/bridgehead atoms, macrocycles) and a symmetry
    correction, then maps the raw value onto a 1..10 scale (1 = easy to make).
    """
    if _fscores is None:
        readFragmentScores()  # lazy one-time load of the fragment score table

    # fragment score
    fp = rdMolDescriptors.GetMorganFingerprint(m,
                                               2)  # <- 2 is the *radius* of the circular fingerprint
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in fps.items():
        nf += v
        sfp = bitId
        score1 += _fscores.get(sfp, -4) * v  # unknown fragments get the worst score
    score1 /= nf

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = 0
    for x in ri.AtomRings():
        if len(x) > 8:
            nMacrocycles += 1

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    #  macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthetise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    # (renamed from min/max so the Python builtins are no longer shadowed)
    score_min = -4.0
    score_max = 2.5
    sascore = 11. - (sascore - score_min + 1) / (score_max - score_min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore
114 |
115 |
def processMols(mols):
    """Print a TSV line (smiles, name, SA score) for every parseable molecule."""
    print('smiles\tName\tsa_score')
    for mol in mols:
        if mol is None:  # unparseable entries come through as None
            continue
        sa = calculateScore(mol)
        print(Chem.MolToSmiles(mol) + "\t" + mol.GetProp('_Name') + "\t%3f" % sa)
126 |
127 |
if __name__ == '__main__':
    import sys
    import time

    # time the score-table load separately from the per-molecule scoring
    t1 = time.time()
    readFragmentScores("fpscores")
    t2 = time.time()

    # argv[1]: a SMILES file; unparseable entries yield None molecules,
    # which processMols skips
    suppl = Chem.SmilesMolSupplier(sys.argv[1])
    t3 = time.time()
    processMols(suppl)
    t4 = time.time()

    print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
          file=sys.stderr)
143 |
144 | #
145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
146 | # All rights reserved.
147 | #
148 | # Redistribution and use in source and binary forms, with or without
149 | # modification, are permitted provided that the following conditions are
150 | # met:
151 | #
152 | # * Redistributions of source code must retain the above copyright
153 | # notice, this list of conditions and the following disclaimer.
154 | # * Redistributions in binary form must reproduce the above
155 | # copyright notice, this list of conditions and the following
156 | # disclaimer in the documentation and/or other materials provided
157 | # with the distribution.
158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
159 | # nor the names of its contributors may be used to endorse or promote
160 | # products derived from this software without specific prior written permission.
161 | #
162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
173 | #
--------------------------------------------------------------------------------
/single_design_qed/sascorer.py:
--------------------------------------------------------------------------------
1 | #
2 | # calculation of synthetic accessibility score as described in:
3 | #
4 | # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5 | # Peter Ertl and Ansgar Schuffenhauer
6 | # Journal of Cheminformatics 1:8 (2009)
7 | # http://www.jcheminf.com/content/1/1/8
8 | #
9 | # several small modifications to the original paper are included
10 | # particularly slightly different formula for marocyclic penalty
11 | # and taking into account also molecule symmetry (fingerprint density)
12 | #
13 | # for a set of 10k diverse molecules the agreement between the original method
14 | # as implemented in PipelinePilot and this implementation is r2 = 0.97
15 | #
16 | # peter ertl & greg landrum, september 2013
17 | #
18 |
19 |
20 | from rdkit import Chem
21 | from rdkit.Chem import rdMolDescriptors
22 | import pickle
23 |
24 | import math
25 | from collections import defaultdict
26 |
27 | import os.path as op
28 |
29 | _fscores = None
30 |
31 |
def readFragmentScores(name='fpscores'):
    """Load the fragment-contribution table into the module-level _fscores dict.

    Reads '<name>.pkl.gz' (resolved next to this module when name == 'fpscores').
    Each pickled record is [score, bit_id, bit_id, ...]; every bit id maps to
    that record's score.
    """
    import gzip
    global _fscores
    # generate the full path filename:
    if name == "fpscores":
        name = op.join(op.dirname(__file__), name)
    # context manager so the gzip handle is closed even if unpickling fails
    # (the original leaked the file object)
    with gzip.open('%s.pkl.gz' % name) as gz:
        data = pickle.load(gz)
    outDict = {}
    for i in data:
        for j in range(1, len(i)):
            outDict[i[j]] = float(i[0])
    _fscores = outDict
44 |
45 |
def numBridgeheadsAndSpiro(mol, ri=None):
    """Return (nBridgeheads, nSpiro) for *mol*; *ri* is accepted but unused."""
    spiro_count = rdMolDescriptors.CalcNumSpiroAtoms(mol)
    bridgehead_count = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    return bridgehead_count, spiro_count
50 |
51 |
def calculateScore(m):
    """Compute the synthetic accessibility (SA) score of RDKit molecule *m*.

    Combines a Morgan-fingerprint fragment score with complexity penalties
    (size, stereo centers, spiro/bridgehead atoms, macrocycles) and a symmetry
    correction, then maps the raw value onto a 1..10 scale (1 = easy to make).
    """
    if _fscores is None:
        readFragmentScores()  # lazy one-time load of the fragment score table

    # fragment score
    fp = rdMolDescriptors.GetMorganFingerprint(m,
                                               2)  # <- 2 is the *radius* of the circular fingerprint
    fps = fp.GetNonzeroElements()
    score1 = 0.
    nf = 0
    for bitId, v in fps.items():
        nf += v
        sfp = bitId
        score1 += _fscores.get(sfp, -4) * v  # unknown fragments get the worst score
    score1 /= nf

    # features score
    nAtoms = m.GetNumAtoms()
    nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
    ri = m.GetRingInfo()
    nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
    nMacrocycles = 0
    for x in ri.AtomRings():
        if len(x) > 8:
            nMacrocycles += 1

    sizePenalty = nAtoms**1.005 - nAtoms
    stereoPenalty = math.log10(nChiralCenters + 1)
    spiroPenalty = math.log10(nSpiro + 1)
    bridgePenalty = math.log10(nBridgeheads + 1)
    macrocyclePenalty = 0.
    # ---------------------------------------
    # This differs from the paper, which defines:
    #  macrocyclePenalty = math.log10(nMacrocycles+1)
    # This form generates better results when 2 or more macrocycles are present
    if nMacrocycles > 0:
        macrocyclePenalty = math.log10(2)

    score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty

    # correction for the fingerprint density
    # not in the original publication, added in version 1.1
    # to make highly symmetrical molecules easier to synthetise
    score3 = 0.
    if nAtoms > len(fps):
        score3 = math.log(float(nAtoms) / len(fps)) * .5

    sascore = score1 + score2 + score3

    # need to transform "raw" value into scale between 1 and 10
    # (renamed from min/max so the Python builtins are no longer shadowed)
    score_min = -4.0
    score_max = 2.5
    sascore = 11. - (sascore - score_min + 1) / (score_max - score_min) * 9.
    # smooth the 10-end
    if sascore > 8.:
        sascore = 8. + math.log(sascore + 1. - 9.)
    if sascore > 10.:
        sascore = 10.0
    elif sascore < 1.:
        sascore = 1.0

    return sascore
114 |
115 |
def processMols(mols):
    """Print a TSV line (smiles, name, SA score) for every parseable molecule."""
    print('smiles\tName\tsa_score')
    for mol in mols:
        if mol is None:  # unparseable entries come through as None
            continue
        sa = calculateScore(mol)
        print(Chem.MolToSmiles(mol) + "\t" + mol.GetProp('_Name') + "\t%3f" % sa)
126 |
127 |
if __name__ == '__main__':
    import sys
    import time

    # time the score-table load separately from the per-molecule scoring
    t1 = time.time()
    readFragmentScores("fpscores")
    t2 = time.time()

    # argv[1]: a SMILES file; unparseable entries yield None molecules,
    # which processMols skips
    suppl = Chem.SmilesMolSupplier(sys.argv[1])
    t3 = time.time()
    processMols(suppl)
    t4 = time.time()

    print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
          file=sys.stderr)
143 |
144 | #
145 | # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
146 | # All rights reserved.
147 | #
148 | # Redistribution and use in source and binary forms, with or without
149 | # modification, are permitted provided that the following conditions are
150 | # met:
151 | #
152 | # * Redistributions of source code must retain the above copyright
153 | # notice, this list of conditions and the following disclaimer.
154 | # * Redistributions in binary form must reproduce the above
155 | # copyright notice, this list of conditions and the following
156 | # disclaimer in the documentation and/or other materials provided
157 | # with the distribution.
158 | # * Neither the name of Novartis Institutes for BioMedical Research Inc.
159 | # nor the names of its contributors may be used to endorse or promote
160 | # products derived from this software without specific prior written permission.
161 | #
162 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
163 | # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
164 | # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
165 | # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
166 | # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
167 | # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
168 | # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
169 | # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
170 | # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
171 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
172 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
173 | #
--------------------------------------------------------------------------------
/multi_design_acaa1/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from collections import OrderedDict
3 | import os
4 | import sys
5 | import shutil
6 | import os.path as osp
7 | import json
8 |
9 | import dateutil.tz
10 |
11 | LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json', 'csv']
12 |
13 | DEBUG = 10
14 | INFO = 20
15 | WARN = 30
16 | ERROR = 40
17 |
18 | DISABLED = 50
19 |
20 |
class OutputFormat(object):
    """Abstract log sink: subclasses render key/value rows and message sequences."""

    def writekvs(self, kvs):
        """Write key-value pairs."""
        raise NotImplementedError

    def writeseq(self, args):
        """Write a sequence of other data (e.g. a logging message); no-op by default."""
        return None

    def close(self):
        """Release any underlying resource; no-op by default."""
        return None
36 |
37 |
class HumanOutputFormat(OutputFormat):
    """Render key/value rows as a framed ASCII table; sequences as plain lines."""

    def __init__(self, file):
        self.file = file

    def writekvs(self, kvs):
        # Stringify (floats via %-8.5g) and truncate both columns
        key2str = OrderedDict()
        for key, val in kvs.items():
            rendered = '%-8.5g' % (val,) if hasattr(val, '__float__') else val
            key2str[self._truncate(key)] = self._truncate(rendered)

        # Column widths track the longest key and value
        keywidth = max(map(len, key2str.keys()))
        valwidth = max(map(len, key2str.values()))

        # Assemble the framed table, then emit it in a single write
        dashes = '-' * (keywidth + valwidth + 7)
        lines = [dashes]
        for key, val in key2str.items():
            key_pad = ' ' * (keywidth - len(key))
            val_pad = ' ' * (valwidth - len(val))
            lines.append('| %s%s | %s%s |' % (key, key_pad, val, val_pad))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        # Anything longer than 23 chars is cut to 20 plus an ellipsis
        return s[:20] + '...' if len(s) > 23 else s

    def writeseq(self, args):
        for item in args:
            self.file.write(item)
        self.file.write('\n')
        self.file.flush()
77 |
78 |
class JSONOutputFormat(OutputFormat):
    """Emit each writekvs() call as one JSON object per line (JSON Lines)."""

    def __init__(self, file):
        self.file = file

    def writekvs(self, kvs):
        # Convert numpy-style values (anything with a .dtype) via .tolist() so
        # they are JSON-serializable. Build a copy instead of writing back into
        # the caller's dict — the original mutated kvs in place, leaking the
        # converted values back to the caller.
        serializable = {}
        for k, v in kvs.items():
            if hasattr(v, 'dtype'):
                v = v.tolist()
            serializable[k] = v
        self.file.write(json.dumps(serializable) + '\n')
        self.file.flush()

    def close(self):
        self.file.close()
93 |
94 |
class CSVOutputFormat(OutputFormat):
    # CSV sink that keeps the header in sync with the union of all keys seen
    # so far, rewriting the whole file in place when a row introduces new keys.
    # The file must therefore be opened read+write (e.g. mode 'w+t').

    def __init__(self, file):
        self.file = file
        self.keys = []      # columns, in the order they first appeared
        self.sep = ','      # pad separator used when back-filling old rows

    def writekvs(self, kvs):
        # Add our current row to the history
        # NOTE(review): kvs.keys() - self.keys is a set, so the relative order
        # in which multiple *new* columns are appended is unspecified.
        extra_keys = kvs.keys() - self.keys
        if extra_keys:
            # New columns: extend the header, then rewrite the file from the
            # top — fresh header line plus every previous row padded with one
            # empty field per new column.
            self.keys.extend(extra_keys)
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(',')
                self.file.write(k)
            self.file.write('\n')
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write('\n')
        # Append the current row; missing keys become empty fields
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(',')
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write('\n')
        self.file.flush()

    def close(self):
        self.file.close()
129 |
130 |
def make_output_format(format, ev_dir):
    """Build the log sink named by *format* inside *ev_dir* (created if absent).

    Raises ValueError for an unrecognized format string.
    """
    os.makedirs(ev_dir, exist_ok=True)
    if format == 'stdout':
        return HumanOutputFormat(sys.stdout)
    if format == 'log':
        return HumanOutputFormat(open(osp.join(ev_dir, 'log.txt'), 'wt'))
    if format == 'json':
        return JSONOutputFormat(open(osp.join(ev_dir, 'progress.json'), 'wt'))
    if format == 'csv':
        # read+write mode: CSVOutputFormat rewrites the file when new keys appear
        return CSVOutputFormat(open(osp.join(ev_dir, 'progress.csv'), 'w+t'))
    raise ValueError('Unknown format specified: %s' % (format,))
146 |
147 |
148 | # ================================================================
149 | # API
150 | # ================================================================
def log_params(params):
    """Dump an experiment's parameter dict to params.json in the current log dir."""
    assert isinstance(params, dict)
    json_file = open(osp.join(Logger.CURRENT.get_dir(), 'params.json'), 'wt')
    output_format = JSONOutputFormat(json_file)
    output_format.writekvs(params)
    output_format.close()

def logkv(key, val):
    """
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration
    """
    Logger.CURRENT.logkv(key, val)


def dumpkvs():
    """
    Write all of the diagnostics from the current iteration

    level: int. (see old_logger.py docs) If the global logger level is higher than
                the level argument here, don't print to stdout.
    """
    Logger.CURRENT.dumpkvs()


# for backwards compatibility
record_tabular = logkv
dump_tabular = dumpkvs


def log(*args, level=INFO):
    """
    Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
    """
    Logger.CURRENT.log(*args, level=level)


def debug(*args):
    """Log at DEBUG level (suppressed unless set_level(DEBUG) was called)."""
    log(*args, level=DEBUG)


def info(*args):
    """Log at INFO level (the default threshold)."""
    log(*args, level=INFO)


def warn(*args):
    """Log at WARN level."""
    log(*args, level=WARN)


def error(*args):
    """Log at ERROR level."""
    log(*args, level=ERROR)


def set_level(level):
    """
    Set logging threshold on current logger.
    """
    Logger.CURRENT.set_level(level)


def get_level():
    """
    Get logging threshold on current logger.
    """
    return Logger.CURRENT.level


def get_dir():
    """
    Get directory that log files are being written to.
    will be None if there is no output directory (i.e., if you didn't call start)
    """
    return Logger.CURRENT.get_dir()


def get_expt_dir():
    """Deprecated alias for get_dir(); emits a deprecation notice on stderr."""
    sys.stderr.write(
        "get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" % (get_dir(),))
    return get_dir()
230 |
231 |
232 | # ================================================================
233 | # Backend
234 | # ================================================================
235 |
236 |
class Logger(object):
    """Accumulates key/value diagnostics and forwards them to a set of sinks."""

    # A logger with no output files (see right below the class definition),
    # so that you can still log to the terminal without configuring anything.
    DEFAULT = None
    # Current logger being used by the free functions above.
    CURRENT = None

    def __init__(self, dir, output_formats):
        self.name2val = OrderedDict()  # values accumulated this iteration
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        self.name2val[key] = val

    def dumpkvs(self):
        # Emit the accumulated pairs to every sink, then reset for the next iteration
        for sink in self.output_formats:
            sink.writekvs(self.name2val)
        self.name2val.clear()

    def log(self, *args, level=INFO):
        stamp = datetime.datetime.now(dateutil.tz.tzlocal()).strftime(
            '[%Y-%m-%d %H:%M:%S.%f %Z] ')
        if self.level <= level:
            self._do_log((stamp,) + args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        self.level = level

    def get_dir(self):
        return self.dir

    def close(self):
        for sink in self.output_formats:
            sink.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        for sink in self.output_formats:
            sink.writeseq(args)
282 |
283 |
284 | # ================================================================
285 |
# Module-level default: a terminal-only logger that stays active until a
# session() context installs a configured one.
Logger.DEFAULT = Logger(
    output_formats=[HumanOutputFormat(sys.stdout)], dir=None)
Logger.CURRENT = Logger.DEFAULT
289 |
290 |
class session(object):
    """
    Context manager that sets up the loggers for an experiment.
    """

    def __init__(self, dir, format_strs=LOG_OUTPUT_FORMATS):
        self.dir = dir
        self.format_strs = format_strs

    def __enter__(self):
        # Build one sink per requested format and install a fresh logger.
        os.makedirs(self.dir, exist_ok=True)
        sinks = [make_output_format(name, self.dir) for name in self.format_strs]
        Logger.CURRENT = Logger(dir=self.dir, output_formats=sinks)

    def __exit__(self, *args):
        # Restore the terminal-only default logger on the way out.
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
307 |
308 |
309 | # ================================================================
310 |
311 |
def main():
    """Smoke-test the logger: levels, sessions, and tabular output."""
    info("hi")
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    dir = "/tmp/testlogging"
    if os.path.exists(dir):
        shutil.rmtree(dir)  # start from a clean log directory
    with session(dir=dir):
        record_tabular("a", 3)
        record_tabular("b", 2.5)
        dump_tabular()
        record_tabular("b", -2.5)
        record_tabular("a", 5.5)
        dump_tabular()
        info("^^^ should see a = 5.5")

    # After the session closes, output goes to the terminal-only default logger.
    record_tabular("b", -2.5)
    dump_tabular()

    record_tabular("a", "longasslongasslongasslongasslongasslongassvalue")
    dump_tabular()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/multi_design_esr1/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from collections import OrderedDict
3 | import os
4 | import sys
5 | import shutil
6 | import os.path as osp
7 | import json
8 |
9 | import dateutil.tz
10 |
# Output sinks that `session` creates by default.
LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json', 'csv']

# Logging levels; numeric values match the stdlib `logging` module.
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

# Higher than every level: silences all output.
DISABLED = 50
19 |
20 |
class OutputFormat(object):
    """Abstract sink for logger output."""

    def writekvs(self, kvs):
        """Write key-value pairs. Subclasses must override."""
        raise NotImplementedError

    def writeseq(self, args):
        """Write a sequence of other data (e.g. a logging message). No-op by default."""
        return None

    def close(self):
        """Release any underlying resource. No-op by default."""
        return
36 |
37 |
class HumanOutputFormat(OutputFormat):
    """Render key/value tables and raw log lines as human-readable text.

    Wraps an already-open, writable text file object (e.g. sys.stdout);
    does not own or close it.
    """

    def __init__(self, file):
        self.file = file

    def writekvs(self, kvs):
        # Create display strings: floats get compact '%g' formatting,
        # everything else is stringified (the original passed non-float
        # values through unchanged, which crashed `_truncate` on non-str).
        key2str = OrderedDict()
        for (key, val) in kvs.items():
            valstr = '%-8.5g' % (val,) if hasattr(val, '__float__') else str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Guard the empty case: max() over no values raises ValueError.
        if not key2str:
            self.file.write('(empty)\n')
            self.file.flush()
            return

        # Find max widths
        keywidth = max(map(len, key2str.keys()))
        valwidth = max(map(len, key2str.values()))

        # Write out the data as a boxed table
        dashes = '-' * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in key2str.items():
            lines.append('| %s%s | %s%s |' % (
                key,
                ' ' * (keywidth - len(key)),
                val,
                ' ' * (valwidth - len(val)),
            ))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        # Keep table cells at most 23 characters wide.
        return s[:20] + '...' if len(s) > 23 else s

    def writeseq(self, args):
        # Stringify defensively: log() callers may pass non-str values.
        for arg in args:
            self.file.write(str(arg))
        self.file.write('\n')
        self.file.flush()
77 |
78 |
class JSONOutputFormat(OutputFormat):
    """Write each key/value dump as one JSON object per line."""

    def __init__(self, file):
        self.file = file

    def writekvs(self, kvs):
        # Coerce numpy values (anything exposing .dtype) to plain Python
        # objects in place so json.dumps can serialize them.
        for key in list(kvs):
            value = kvs[key]
            if hasattr(value, 'dtype'):
                kvs[key] = value.tolist()
        self.file.write(json.dumps(kvs) + '\n')
        self.file.flush()

    def close(self):
        self.file.close()
93 |
94 |
class CSVOutputFormat(OutputFormat):
    """Append one CSV row per dump, rewriting the header when new keys appear.

    `file` must be opened in read/write text mode ('w+t') because the whole
    file is rewritten whenever a previously-unseen key shows up.
    """

    def __init__(self, file):
        self.file = file
        self.keys = []   # column order, first-seen
        self.sep = ','

    def writekvs(self, kvs):
        # Fix: the original used set difference (kvs.keys() - self.keys),
        # whose iteration order is nondeterministic, so column order could
        # differ between runs. Preserve first-seen order instead.
        extra_keys = [k for k in kvs.keys() if k not in self.keys]
        if extra_keys:
            self.keys.extend(extra_keys)
            # Rewrite the header line and pad existing rows with empty cells.
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(',')
                self.file.write(k)
            self.file.write('\n')
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write('\n')
        # Append the current row; missing keys become empty cells.
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(',')
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write('\n')
        self.file.flush()

    def close(self):
        self.file.close()
129 |
130 |
def make_output_format(format, ev_dir):
    """Create a single output sink of kind `format`, writing under `ev_dir`.

    Raises ValueError for unknown format names.
    """
    os.makedirs(ev_dir, exist_ok=True)
    if format == 'stdout':
        return HumanOutputFormat(sys.stdout)
    if format == 'log':
        return HumanOutputFormat(open(osp.join(ev_dir, 'log.txt'), 'wt'))
    if format == 'json':
        return JSONOutputFormat(open(osp.join(ev_dir, 'progress.json'), 'wt'))
    if format == 'csv':
        # 'w+t': CSVOutputFormat needs read access to rewrite its header.
        return CSVOutputFormat(open(osp.join(ev_dir, 'progress.csv'), 'w+t'))
    raise ValueError('Unknown format specified: %s' % (format,))
146 |
147 |
148 | # ================================================================
149 | # API
150 | # ================================================================
def log_params(params):
    """Dump the experiment's parameters to params.json in the current log dir.

    params: dict of JSON-serializable values.
    """
    assert isinstance(params, dict)
    json_file = open(osp.join(Logger.CURRENT.get_dir(), 'params.json'), 'wt')
    output_format = JSONOutputFormat(json_file)
    output_format.writekvs(params)
    output_format.close()

def logkv(key, val):
    """
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration
    """
    Logger.CURRENT.logkv(key, val)


def dumpkvs():
    """
    Write all of the diagnostics from the current iteration

    level: int. (see old_logger.py docs) If the global logger level is higher than
    the level argument here, don't print to stdout.
    """
    Logger.CURRENT.dumpkvs()


# for backwards compatibility
record_tabular = logkv
dump_tabular = dumpkvs


def log(*args, level=INFO):
    """
    Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
    """
    Logger.CURRENT.log(*args, level=level)


def debug(*args):
    """Log at DEBUG level."""
    log(*args, level=DEBUG)


def info(*args):
    """Log at INFO level."""
    log(*args, level=INFO)


def warn(*args):
    """Log at WARN level."""
    log(*args, level=WARN)


def error(*args):
    """Log at ERROR level."""
    log(*args, level=ERROR)


def set_level(level):
    """
    Set logging threshold on current logger.
    """
    Logger.CURRENT.set_level(level)


def get_level():
    """
    Get the logging threshold of the current logger.
    """
    return Logger.CURRENT.level


def get_dir():
    """
    Get directory that log files are being written to.
    will be None if there is no output directory (i.e., if you didn't call start)
    """
    return Logger.CURRENT.get_dir()


def get_expt_dir():
    """Deprecated alias for get_dir(); emits a warning on stderr."""
    sys.stderr.write(
        "get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" % (get_dir(),))
    return get_dir()
230 |
231 |
232 | # ================================================================
233 | # Backend
234 | # ================================================================
235 |
236 |
class Logger(object):
    """Collects key/value diagnostics and forwards them to output sinks."""

    # A logger with no output files (instantiated right below the class),
    # so terminal logging works without any setup.
    DEFAULT = None
    # The logger that the module-level convenience functions delegate to.
    CURRENT = None

    def __init__(self, dir, output_formats):
        self.name2val = OrderedDict()  # values accumulated this iteration
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        """Record one diagnostic value for the current iteration."""
        self.name2val[key] = val

    def dumpkvs(self):
        """Flush the accumulated values to every sink, then reset them."""
        for sink in self.output_formats:
            sink.writekvs(self.name2val)
        self.name2val.clear()

    def log(self, *args, level=INFO):
        """Write `args`, prefixed with a timestamp, if `level` passes the threshold."""
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        stamp = now.strftime('[%Y-%m-%d %H:%M:%S.%f %Z] ')
        if self.level <= level:
            self._do_log((stamp,) + args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        """Set the minimum level that log() will emit."""
        self.level = level

    def get_dir(self):
        """Return the output directory (None for the default logger)."""
        return self.dir

    def close(self):
        """Close every output sink."""
        for sink in self.output_formats:
            sink.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        for sink in self.output_formats:
            sink.writeseq(args)
282 |
283 |
284 | # ================================================================
285 |
# Default: a terminal-only logger, so the module-level helpers work
# before (and outside of) any `session`.
Logger.DEFAULT = Logger(
    output_formats=[HumanOutputFormat(sys.stdout)], dir=None)
Logger.CURRENT = Logger.DEFAULT
289 |
290 |
class session(object):
    """
    Context manager that sets up the loggers for an experiment.
    """

    def __init__(self, dir, format_strs=LOG_OUTPUT_FORMATS):
        self.dir = dir
        self.format_strs = format_strs

    def __enter__(self):
        # Build one sink per requested format and install a fresh logger.
        os.makedirs(self.dir, exist_ok=True)
        sinks = [make_output_format(name, self.dir) for name in self.format_strs]
        Logger.CURRENT = Logger(dir=self.dir, output_formats=sinks)

    def __exit__(self, *args):
        # Restore the terminal-only default logger on the way out.
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
307 |
308 |
309 | # ================================================================
310 |
311 |
def main():
    """Smoke-test the logger: levels, sessions, and tabular output."""
    info("hi")
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    dir = "/tmp/testlogging"
    if os.path.exists(dir):
        shutil.rmtree(dir)  # start from a clean log directory
    with session(dir=dir):
        record_tabular("a", 3)
        record_tabular("b", 2.5)
        dump_tabular()
        record_tabular("b", -2.5)
        record_tabular("a", 5.5)
        dump_tabular()
        info("^^^ should see a = 5.5")

    # After the session closes, output goes to the terminal-only default logger.
    record_tabular("b", -2.5)
    dump_tabular()

    record_tabular("a", "longasslongasslongasslongasslongasslongassvalue")
    dump_tabular()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/single_design_acaa1/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from collections import OrderedDict
3 | import os
4 | import sys
5 | import shutil
6 | import os.path as osp
7 | import json
8 |
9 | import dateutil.tz
10 |
# Output sinks that `session` creates by default.
LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json', 'csv']

# Logging levels; numeric values match the stdlib `logging` module.
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

# Higher than every level: silences all output.
DISABLED = 50
19 |
20 |
class OutputFormat(object):
    """Abstract sink for logger output."""

    def writekvs(self, kvs):
        """Write key-value pairs. Subclasses must override."""
        raise NotImplementedError

    def writeseq(self, args):
        """Write a sequence of other data (e.g. a logging message). No-op by default."""
        return None

    def close(self):
        """Release any underlying resource. No-op by default."""
        return
36 |
37 |
class HumanOutputFormat(OutputFormat):
    """Render key/value tables and raw log lines as human-readable text.

    Wraps an already-open, writable text file object (e.g. sys.stdout);
    does not own or close it.
    """

    def __init__(self, file):
        self.file = file

    def writekvs(self, kvs):
        # Create display strings: floats get compact '%g' formatting,
        # everything else is stringified (the original passed non-float
        # values through unchanged, which crashed `_truncate` on non-str).
        key2str = OrderedDict()
        for (key, val) in kvs.items():
            valstr = '%-8.5g' % (val,) if hasattr(val, '__float__') else str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Guard the empty case: max() over no values raises ValueError.
        if not key2str:
            self.file.write('(empty)\n')
            self.file.flush()
            return

        # Find max widths
        keywidth = max(map(len, key2str.keys()))
        valwidth = max(map(len, key2str.values()))

        # Write out the data as a boxed table
        dashes = '-' * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in key2str.items():
            lines.append('| %s%s | %s%s |' % (
                key,
                ' ' * (keywidth - len(key)),
                val,
                ' ' * (valwidth - len(val)),
            ))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        # Keep table cells at most 23 characters wide.
        return s[:20] + '...' if len(s) > 23 else s

    def writeseq(self, args):
        # Stringify defensively: log() callers may pass non-str values.
        for arg in args:
            self.file.write(str(arg))
        self.file.write('\n')
        self.file.flush()
77 |
78 |
class JSONOutputFormat(OutputFormat):
    """Write each key/value dump as one JSON object per line."""

    def __init__(self, file):
        self.file = file

    def writekvs(self, kvs):
        # Coerce numpy values (anything exposing .dtype) to plain Python
        # objects in place so json.dumps can serialize them.
        for key in list(kvs):
            value = kvs[key]
            if hasattr(value, 'dtype'):
                kvs[key] = value.tolist()
        self.file.write(json.dumps(kvs) + '\n')
        self.file.flush()

    def close(self):
        self.file.close()
93 |
94 |
class CSVOutputFormat(OutputFormat):
    """Append one CSV row per dump, rewriting the header when new keys appear.

    `file` must be opened in read/write text mode ('w+t') because the whole
    file is rewritten whenever a previously-unseen key shows up.
    """

    def __init__(self, file):
        self.file = file
        self.keys = []   # column order, first-seen
        self.sep = ','

    def writekvs(self, kvs):
        # Fix: the original used set difference (kvs.keys() - self.keys),
        # whose iteration order is nondeterministic, so column order could
        # differ between runs. Preserve first-seen order instead.
        extra_keys = [k for k in kvs.keys() if k not in self.keys]
        if extra_keys:
            self.keys.extend(extra_keys)
            # Rewrite the header line and pad existing rows with empty cells.
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(',')
                self.file.write(k)
            self.file.write('\n')
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write('\n')
        # Append the current row; missing keys become empty cells.
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(',')
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write('\n')
        self.file.flush()

    def close(self):
        self.file.close()
129 |
130 |
def make_output_format(format, ev_dir):
    """Create a single output sink of kind `format`, writing under `ev_dir`.

    Raises ValueError for unknown format names.
    """
    os.makedirs(ev_dir, exist_ok=True)
    if format == 'stdout':
        return HumanOutputFormat(sys.stdout)
    if format == 'log':
        return HumanOutputFormat(open(osp.join(ev_dir, 'log.txt'), 'wt'))
    if format == 'json':
        return JSONOutputFormat(open(osp.join(ev_dir, 'progress.json'), 'wt'))
    if format == 'csv':
        # 'w+t': CSVOutputFormat needs read access to rewrite its header.
        return CSVOutputFormat(open(osp.join(ev_dir, 'progress.csv'), 'w+t'))
    raise ValueError('Unknown format specified: %s' % (format,))
146 |
147 |
148 | # ================================================================
149 | # API
150 | # ================================================================
def log_params(params):
    """Dump the experiment's parameters to params.json in the current log dir.

    params: dict of JSON-serializable values.
    """
    assert isinstance(params, dict)
    json_file = open(osp.join(Logger.CURRENT.get_dir(), 'params.json'), 'wt')
    output_format = JSONOutputFormat(json_file)
    output_format.writekvs(params)
    output_format.close()

def logkv(key, val):
    """
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration
    """
    Logger.CURRENT.logkv(key, val)


def dumpkvs():
    """
    Write all of the diagnostics from the current iteration

    level: int. (see old_logger.py docs) If the global logger level is higher than
    the level argument here, don't print to stdout.
    """
    Logger.CURRENT.dumpkvs()


# for backwards compatibility
record_tabular = logkv
dump_tabular = dumpkvs


def log(*args, level=INFO):
    """
    Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
    """
    Logger.CURRENT.log(*args, level=level)


def debug(*args):
    """Log at DEBUG level."""
    log(*args, level=DEBUG)


def info(*args):
    """Log at INFO level."""
    log(*args, level=INFO)


def warn(*args):
    """Log at WARN level."""
    log(*args, level=WARN)


def error(*args):
    """Log at ERROR level."""
    log(*args, level=ERROR)


def set_level(level):
    """
    Set logging threshold on current logger.
    """
    Logger.CURRENT.set_level(level)


def get_level():
    """
    Get the logging threshold of the current logger.
    """
    return Logger.CURRENT.level


def get_dir():
    """
    Get directory that log files are being written to.
    will be None if there is no output directory (i.e., if you didn't call start)
    """
    return Logger.CURRENT.get_dir()


def get_expt_dir():
    """Deprecated alias for get_dir(); emits a warning on stderr."""
    sys.stderr.write(
        "get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" % (get_dir(),))
    return get_dir()
230 |
231 |
232 | # ================================================================
233 | # Backend
234 | # ================================================================
235 |
236 |
class Logger(object):
    """Collects key/value diagnostics and forwards them to output sinks."""

    # A logger with no output files (instantiated right below the class),
    # so terminal logging works without any setup.
    DEFAULT = None
    # The logger that the module-level convenience functions delegate to.
    CURRENT = None

    def __init__(self, dir, output_formats):
        self.name2val = OrderedDict()  # values accumulated this iteration
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        """Record one diagnostic value for the current iteration."""
        self.name2val[key] = val

    def dumpkvs(self):
        """Flush the accumulated values to every sink, then reset them."""
        for sink in self.output_formats:
            sink.writekvs(self.name2val)
        self.name2val.clear()

    def log(self, *args, level=INFO):
        """Write `args`, prefixed with a timestamp, if `level` passes the threshold."""
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        stamp = now.strftime('[%Y-%m-%d %H:%M:%S.%f %Z] ')
        if self.level <= level:
            self._do_log((stamp,) + args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        """Set the minimum level that log() will emit."""
        self.level = level

    def get_dir(self):
        """Return the output directory (None for the default logger)."""
        return self.dir

    def close(self):
        """Close every output sink."""
        for sink in self.output_formats:
            sink.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        for sink in self.output_formats:
            sink.writeseq(args)
282 |
283 |
284 | # ================================================================
285 |
# Default: a terminal-only logger, so the module-level helpers work
# before (and outside of) any `session`.
Logger.DEFAULT = Logger(
    output_formats=[HumanOutputFormat(sys.stdout)], dir=None)
Logger.CURRENT = Logger.DEFAULT
289 |
290 |
class session(object):
    """
    Context manager that sets up the loggers for an experiment.
    """

    def __init__(self, dir, format_strs=LOG_OUTPUT_FORMATS):
        self.dir = dir
        self.format_strs = format_strs

    def __enter__(self):
        # Build one sink per requested format and install a fresh logger.
        os.makedirs(self.dir, exist_ok=True)
        sinks = [make_output_format(name, self.dir) for name in self.format_strs]
        Logger.CURRENT = Logger(dir=self.dir, output_formats=sinks)

    def __exit__(self, *args):
        # Restore the terminal-only default logger on the way out.
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
307 |
308 |
309 | # ================================================================
310 |
311 |
def main():
    """Smoke-test the logger: levels, sessions, and tabular output."""
    info("hi")
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    dir = "/tmp/testlogging"
    if os.path.exists(dir):
        shutil.rmtree(dir)  # start from a clean log directory
    with session(dir=dir):
        record_tabular("a", 3)
        record_tabular("b", 2.5)
        dump_tabular()
        record_tabular("b", -2.5)
        record_tabular("a", 5.5)
        dump_tabular()
        info("^^^ should see a = 5.5")

    # After the session closes, output goes to the terminal-only default logger.
    record_tabular("b", -2.5)
    dump_tabular()

    record_tabular("a", "longasslongasslongasslongasslongasslongassvalue")
    dump_tabular()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/single_design_esr1/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from collections import OrderedDict
3 | import os
4 | import sys
5 | import shutil
6 | import os.path as osp
7 | import json
8 |
9 | import dateutil.tz
10 |
# Output sinks that `session` creates by default.
LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json', 'csv']

# Logging levels; numeric values match the stdlib `logging` module.
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40

# Higher than every level: silences all output.
DISABLED = 50
19 |
20 |
class OutputFormat(object):
    """Abstract sink for logger output."""

    def writekvs(self, kvs):
        """Write key-value pairs. Subclasses must override."""
        raise NotImplementedError

    def writeseq(self, args):
        """Write a sequence of other data (e.g. a logging message). No-op by default."""
        return None

    def close(self):
        """Release any underlying resource. No-op by default."""
        return
36 |
37 |
class HumanOutputFormat(OutputFormat):
    """Render key/value tables and raw log lines as human-readable text.

    Wraps an already-open, writable text file object (e.g. sys.stdout);
    does not own or close it.
    """

    def __init__(self, file):
        self.file = file

    def writekvs(self, kvs):
        # Create display strings: floats get compact '%g' formatting,
        # everything else is stringified (the original passed non-float
        # values through unchanged, which crashed `_truncate` on non-str).
        key2str = OrderedDict()
        for (key, val) in kvs.items():
            valstr = '%-8.5g' % (val,) if hasattr(val, '__float__') else str(val)
            key2str[self._truncate(key)] = self._truncate(valstr)

        # Guard the empty case: max() over no values raises ValueError.
        if not key2str:
            self.file.write('(empty)\n')
            self.file.flush()
            return

        # Find max widths
        keywidth = max(map(len, key2str.keys()))
        valwidth = max(map(len, key2str.values()))

        # Write out the data as a boxed table
        dashes = '-' * (keywidth + valwidth + 7)
        lines = [dashes]
        for (key, val) in key2str.items():
            lines.append('| %s%s | %s%s |' % (
                key,
                ' ' * (keywidth - len(key)),
                val,
                ' ' * (valwidth - len(val)),
            ))
        lines.append(dashes)
        self.file.write('\n'.join(lines) + '\n')

        # Flush the output to the file
        self.file.flush()

    def _truncate(self, s):
        # Keep table cells at most 23 characters wide.
        return s[:20] + '...' if len(s) > 23 else s

    def writeseq(self, args):
        # Stringify defensively: log() callers may pass non-str values.
        for arg in args:
            self.file.write(str(arg))
        self.file.write('\n')
        self.file.flush()
77 |
78 |
class JSONOutputFormat(OutputFormat):
    """Write each key/value dump as one JSON object per line."""

    def __init__(self, file):
        self.file = file

    def writekvs(self, kvs):
        # Coerce numpy values (anything exposing .dtype) to plain Python
        # objects in place so json.dumps can serialize them.
        for key in list(kvs):
            value = kvs[key]
            if hasattr(value, 'dtype'):
                kvs[key] = value.tolist()
        self.file.write(json.dumps(kvs) + '\n')
        self.file.flush()

    def close(self):
        self.file.close()
93 |
94 |
class CSVOutputFormat(OutputFormat):
    """Append one CSV row per dump, rewriting the header when new keys appear.

    `file` must be opened in read/write text mode ('w+t') because the whole
    file is rewritten whenever a previously-unseen key shows up.
    """

    def __init__(self, file):
        self.file = file
        self.keys = []   # column order, first-seen
        self.sep = ','

    def writekvs(self, kvs):
        # Fix: the original used set difference (kvs.keys() - self.keys),
        # whose iteration order is nondeterministic, so column order could
        # differ between runs. Preserve first-seen order instead.
        extra_keys = [k for k in kvs.keys() if k not in self.keys]
        if extra_keys:
            self.keys.extend(extra_keys)
            # Rewrite the header line and pad existing rows with empty cells.
            self.file.seek(0)
            lines = self.file.readlines()
            self.file.seek(0)
            for (i, k) in enumerate(self.keys):
                if i > 0:
                    self.file.write(',')
                self.file.write(k)
            self.file.write('\n')
            for line in lines[1:]:
                self.file.write(line[:-1])
                self.file.write(self.sep * len(extra_keys))
                self.file.write('\n')
        # Append the current row; missing keys become empty cells.
        for (i, k) in enumerate(self.keys):
            if i > 0:
                self.file.write(',')
            v = kvs.get(k)
            if v is not None:
                self.file.write(str(v))
        self.file.write('\n')
        self.file.flush()

    def close(self):
        self.file.close()
129 |
130 |
def make_output_format(format, ev_dir):
    """Create a single output sink of kind `format`, writing under `ev_dir`.

    Raises ValueError for unknown format names.
    """
    os.makedirs(ev_dir, exist_ok=True)
    if format == 'stdout':
        return HumanOutputFormat(sys.stdout)
    if format == 'log':
        return HumanOutputFormat(open(osp.join(ev_dir, 'log.txt'), 'wt'))
    if format == 'json':
        return JSONOutputFormat(open(osp.join(ev_dir, 'progress.json'), 'wt'))
    if format == 'csv':
        # 'w+t': CSVOutputFormat needs read access to rewrite its header.
        return CSVOutputFormat(open(osp.join(ev_dir, 'progress.csv'), 'w+t'))
    raise ValueError('Unknown format specified: %s' % (format,))
146 |
147 |
148 | # ================================================================
149 | # API
150 | # ================================================================
def log_params(params):
    """Dump the experiment's parameters to params.json in the current log dir.

    params: dict of JSON-serializable values.
    """
    assert isinstance(params, dict)
    json_file = open(osp.join(Logger.CURRENT.get_dir(), 'params.json'), 'wt')
    output_format = JSONOutputFormat(json_file)
    output_format.writekvs(params)
    output_format.close()

def logkv(key, val):
    """
    Log a value of some diagnostic
    Call this once for each diagnostic quantity, each iteration
    """
    Logger.CURRENT.logkv(key, val)


def dumpkvs():
    """
    Write all of the diagnostics from the current iteration

    level: int. (see old_logger.py docs) If the global logger level is higher than
    the level argument here, don't print to stdout.
    """
    Logger.CURRENT.dumpkvs()


# for backwards compatibility
record_tabular = logkv
dump_tabular = dumpkvs


def log(*args, level=INFO):
    """
    Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
    """
    Logger.CURRENT.log(*args, level=level)


def debug(*args):
    """Log at DEBUG level."""
    log(*args, level=DEBUG)


def info(*args):
    """Log at INFO level."""
    log(*args, level=INFO)


def warn(*args):
    """Log at WARN level."""
    log(*args, level=WARN)


def error(*args):
    """Log at ERROR level."""
    log(*args, level=ERROR)


def set_level(level):
    """
    Set logging threshold on current logger.
    """
    Logger.CURRENT.set_level(level)


def get_level():
    """
    Get the logging threshold of the current logger.
    """
    return Logger.CURRENT.level


def get_dir():
    """
    Get directory that log files are being written to.
    will be None if there is no output directory (i.e., if you didn't call start)
    """
    return Logger.CURRENT.get_dir()


def get_expt_dir():
    """Deprecated alias for get_dir(); emits a warning on stderr."""
    sys.stderr.write(
        "get_expt_dir() is Deprecated. Switch to get_dir() [%s]\n" % (get_dir(),))
    return get_dir()
230 |
231 |
232 | # ================================================================
233 | # Backend
234 | # ================================================================
235 |
236 |
class Logger(object):
    """Collects key/value diagnostics and forwards them to output sinks."""

    # A logger with no output files (instantiated right below the class),
    # so terminal logging works without any setup.
    DEFAULT = None
    # The logger that the module-level convenience functions delegate to.
    CURRENT = None

    def __init__(self, dir, output_formats):
        self.name2val = OrderedDict()  # values accumulated this iteration
        self.level = INFO
        self.dir = dir
        self.output_formats = output_formats

    # Logging API, forwarded
    # ----------------------------------------
    def logkv(self, key, val):
        """Record one diagnostic value for the current iteration."""
        self.name2val[key] = val

    def dumpkvs(self):
        """Flush the accumulated values to every sink, then reset them."""
        for sink in self.output_formats:
            sink.writekvs(self.name2val)
        self.name2val.clear()

    def log(self, *args, level=INFO):
        """Write `args`, prefixed with a timestamp, if `level` passes the threshold."""
        now = datetime.datetime.now(dateutil.tz.tzlocal())
        stamp = now.strftime('[%Y-%m-%d %H:%M:%S.%f %Z] ')
        if self.level <= level:
            self._do_log((stamp,) + args)

    # Configuration
    # ----------------------------------------
    def set_level(self, level):
        """Set the minimum level that log() will emit."""
        self.level = level

    def get_dir(self):
        """Return the output directory (None for the default logger)."""
        return self.dir

    def close(self):
        """Close every output sink."""
        for sink in self.output_formats:
            sink.close()

    # Misc
    # ----------------------------------------
    def _do_log(self, args):
        for sink in self.output_formats:
            sink.writeseq(args)
282 |
283 |
284 | # ================================================================
285 |
# Default: a terminal-only logger, so the module-level helpers work
# before (and outside of) any `session`.
Logger.DEFAULT = Logger(
    output_formats=[HumanOutputFormat(sys.stdout)], dir=None)
Logger.CURRENT = Logger.DEFAULT
289 |
290 |
class session(object):
    """
    Context manager that sets up the loggers for an experiment.
    """

    def __init__(self, dir, format_strs=LOG_OUTPUT_FORMATS):
        self.dir = dir
        self.format_strs = format_strs

    def __enter__(self):
        # Build one sink per requested format and install a fresh logger.
        os.makedirs(self.dir, exist_ok=True)
        sinks = [make_output_format(name, self.dir) for name in self.format_strs]
        Logger.CURRENT = Logger(dir=self.dir, output_formats=sinks)

    def __exit__(self, *args):
        # Restore the terminal-only default logger on the way out.
        Logger.CURRENT.close()
        Logger.CURRENT = Logger.DEFAULT
307 |
308 |
309 | # ================================================================
310 |
311 |
def main():
    """Smoke-test the logger: levels, sessions, and tabular output."""
    info("hi")
    debug("shouldn't appear")
    set_level(DEBUG)
    debug("should appear")
    dir = "/tmp/testlogging"
    if os.path.exists(dir):
        shutil.rmtree(dir)  # start from a clean log directory
    with session(dir=dir):
        record_tabular("a", 3)
        record_tabular("b", 2.5)
        dump_tabular()
        record_tabular("b", -2.5)
        record_tabular("a", 5.5)
        dump_tabular()
        info("^^^ should see a = 5.5")

    # After the session closes, output goes to the terminal-only default logger.
    record_tabular("b", -2.5)
    dump_tabular()

    record_tabular("a", "longasslongasslongasslongasslongasslongassvalue")
    dump_tabular()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------