├── Generative_Model ├── main.py ├── MOFVAE │ ├── main.py │ ├── Auto_Encoder.py │ └── VAE_model.py ├── TensorToLoc.py ├── util │ ├── henrys_constant.py │ ├── periodic_padding.py │ ├── scaling.py │ └── rotations.py ├── gen_data_sample.py ├── VAE.py ├── util.py ├── mof_dataset_v2.py ├── mof_dataset.py ├── GAN.py ├── AE.py ├── CIFtoTensor.py ├── CIFtoVoxel.py ├── MOFGan.py └── mof_wgan_gp_multi_channel.py ├── exploratory ├── exploratory_stats.py ├── LJ_POTENTIAL.png ├── properties_stats.py ├── distance_matrix.py ├── minimum_radius_connectivity.py └── parse_cif.py ├── .gitignore ├── Moledule_Generation └── MolGAN │ ├── MOF_GENERATOR_32.p │ ├── model.py │ ├── main.py │ └── tester.py ├── data ├── readme.md ├── data_download.py └── generate_energy_grids.py ├── requirements.txt ├── data_util.py ├── data_download.py ├── MOF_Force_Field ├── MOLGCN.py ├── MOLGCN_Tester.py ├── model.py ├── tester.py ├── dataloader.py └── MOF_Approximator.ipynb ├── README.md └── binary_dataset └── binary_dataset.py /Generative_Model/main.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Generative_Model/MOFVAE/main.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /exploratory/exploratory_stats.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | venv/ 2 | data/ 3 | gcn_model/data/ 4 | __pyc__/ 5 | *.pyc 6 | *.ics 7 | *.cif 8 | *.idea 9 | -------------------------------------------------------------------------------- /exploratory/LJ_POTENTIAL.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/szaman19/Materials-Search/HEAD/exploratory/LJ_POTENTIAL.png -------------------------------------------------------------------------------- /Moledule_Generation/MolGAN/MOF_GENERATOR_32.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/szaman19/Materials-Search/HEAD/Moledule_Generation/MolGAN/MOF_GENERATOR_32.p -------------------------------------------------------------------------------- /Generative_Model/TensorToLoc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from CIFtoTensor import CIFtoTensor 4 | 5 | 6 | 7 | cif_file = CIFtoTensor.get_cif_file() 8 | struc = CIFtoTensor.get_pymat_struct(cif_file) 9 | mol_tensor = CIFtoTensor.to3DTensor(struc) 10 | 11 | print(mol_tensor.shape) -------------------------------------------------------------------------------- /data/readme.md: -------------------------------------------------------------------------------- 1 | ## Data Generation 2 | --- 3 | 4 | This directory holds scripts to generate the energy grids for the MOF CoRE dataset. 5 | 6 | ### Data Download 7 | 8 | Run the download script with: 9 | 10 | ```python 11 | python data_download.py 12 | ``` 13 | 14 | ### Energy Grid Generation 15 | 16 | The script requires Musen Zhao's implementation of `cif2input` and `grid_gen`. 
# Presumably the molar gas constant R in J / (mol K) -- confirm the unit
# convention against the code that consumes the Henry's constant.
GAS_CONSTANT = 8.31446


class PropertyCalculations:
    """Property estimates computed from a voxelised energy grid."""

    @staticmethod
    def get_henrys_constant(grid: List[List[List]], temperature=77):
        """Return the Henry's constant estimate for a 3-D energy grid.

        The value is the mean Boltzmann factor ``exp(-E / temperature)`` over
        every cell of ``grid`` divided by ``GAS_CONSTANT * temperature``.

        The grid dimensions are taken from ``grid`` itself.  The previous
        version always iterated over a 32 x 32 x 32 cube, which raised
        ``IndexError`` for smaller grids and silently ignored part of larger
        ones.  For a 32^3 grid the result is unchanged.

        Args:
            grid: Three nested sequences of energy values.  Energies are
                divided directly by ``temperature``, so they are presumably
                already expressed in Kelvin (E / k_B) -- TODO confirm.
            temperature: Temperature used in the Boltzmann factor.

        Returns:
            The Henry's constant estimate as a float.

        Raises:
            ValueError: If ``grid`` contains no energy values.
        """
        boltzmann_sum = 0.0
        num_cells = 0
        for plane in grid:
            for row in plane:
                for energy_value in row:
                    boltzmann_sum += math.exp(-energy_value / temperature)
                    num_cells += 1

        if num_cells == 0:
            raise ValueError("grid must contain at least one energy value")

        return boltzmann_sum / (GAS_CONSTANT * temperature * num_cells)
def downloader(link, file_name):
    """Download ``link`` and store the response body in ``file_name``.

    Args:
        link: URL to fetch with ``requests.get``.
        file_name: Path of the file that receives the raw response bytes.
    """
    response = requests.get(link)

    # The body is written whatever the HTTP status is, exactly as before.
    # The ``with`` block closes the file, so no explicit close() is needed.
    with open(file_name, 'wb') as f:
        f.write(response.content)

    # ``status_code`` is an int; the old comparison against the string "200"
    # was never true, so the completion message was never printed.
    if response.status_code == 200:
        print("Completed download")
def distance(x, y, z):
    """Return the Euclidean norm of the vector ``(x, y, z)``.

    ``math.hypot`` is used instead of squaring the components by hand:
    ``x**2`` raises ``OverflowError`` for large but finite floats, while
    ``hypot`` scales internally.  The two-argument form is nested so the
    function also runs on interpreters older than Python 3.8.
    """
    return math.hypot(math.hypot(x, y), z)
if __name__ == '__main__':
    # Smoke test: push one random batch through the encoder and the decoder
    # and save a voxel plot of the first decoded sample.
    enc = VAE_model.Encoder(Num_features=12, Z_dim=64).to(device)
    dec = VAE_model.Decoder(num_features=12, z_dimension=64, voxel_side_length=32).to(device)

    data = Variable(torch.rand(2, 12, 32, 32, 32)).to(device)

    # The original line was ``optim.Adam()``: ``optim`` is never imported in
    # this module and Adam needs the parameters to optimise, so the script
    # stopped here with a NameError.  Both networks are optimised jointly.
    # Assumes Encoder and Decoder are nn.Module subclasses (they are moved
    # with ``.to(device)`` above) -- confirm in MOFVAE/VAE_model.py.
    optimizer = torch.optim.Adam(
        list(enc.parameters()) + list(dec.parameters()),
        lr=learning_rate,
    )

    z, mu, logvar = enc(data)
    print(z.shape, mu.shape, logvar.shape)

    # Give the latent vector three singleton spatial dimensions before it is
    # handed to the decoder.
    X = z.view(z.size(0), z.size(1), 1, 1, 1)
    print(X.shape)

    decoded_z = dec(X)
    print(decoded_z[0].shape)

    chans = range(0, 1)
    util.Visualize_4DTensor(decoded_z[0].cpu().detach().numpy(), chans)
def downloader(link, file_name):
    """Download ``link`` and store the response body in ``file_name``.

    Args:
        link: URL to fetch with ``requests.get``.
        file_name: Path of the file that receives the raw response bytes.
    """
    response = requests.get(link)

    # The body is written whatever the HTTP status is, exactly as before.
    # The ``with`` block closes the file, so no explicit close() is needed.
    with open(file_name, 'wb') as f:
        f.write(response.content)

    # ``status_code`` is an int; the old comparison against the string "200"
    # was never true, so the completion message was never printed.
    if response.status_code == 200:
        print("Completed download")
data - scripts for data gathering from CoRE MOF Database 6 | 2. exploratory - exploratory statistical analysis of MOFs 7 | 3. 3D_Grid_Data - code to convert cif to 3D molecular tensors 8 | 4. cnn_model - 3D CNN model for property classification and regression 9 | 5. gcn_model - Graph Convolutional model for property classification and regression 10 | 6. 3DGen_Model - 3D generative models (GAN, VAE) for materials generation 11 | 12 | 13 | 14 | 15 | 16 | ### Downloading CoRE MOF Database 17 | 18 | *** 19 | 20 | #### Requirements: 21 | 1. Python3 22 | 2. Requests 23 | 24 | 25 | #### Download: 26 | 1. Create a virtual environment 27 | 2. Activate your virtual environment. (Call it venv so that git automatically ignores it) 28 | 3. Install dependencies 29 | ```bash 30 | pip install -r requirements.txt 31 | ``` 32 | 4. Run the downloader with: 33 | ``` 34 | python data_download.py 35 | ``` 36 | 5. Decompress the .tar.gz archive to retrieve the .cif files 37 | 38 | This material is based upon work supported by the National Science Foundation under Grant No. DMR-1940243. 39 | 40 | Any opinions, findings, and conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the National Science Foundation. 41 | -------------------------------------------------------------------------------- /Generative_Model/util.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from mpl_toolkits.mplot3d import Axes3D 5 | 6 | 7 | def Visualize_4DTensor(tensor, channels, threshold=1E-6, savefile="Visualize_4DTensor.png"): 8 | Channel_Titles = ["Energy Grid","H","O", "N", "C", "P", "Cu","Co","Ag","Zn","Cd", "Fe"] 9 | if (len(tensor.shape) != 4): 10 | print("Tensor must be 4-dimensional. 
def Visualize_MOF_Split(tensor, channels, threshold=1E-1, savefile="MOF.png"):
    """Save one voxel plot per requested channel of ``tensor``.

    Each channel is written to ``"<channel title>_<savefile>"``.

    Two defects of the previous version are fixed:

    * the title was looked up with the loop position instead of the channel
      id, so e.g. ``channels=[4]`` was saved as ``"Energy Grid_MOF.png"``
      rather than ``"C_MOF.png"``;
    * every channel was drawn on one shared axes, so each saved image also
      contained the voxels of all previously drawn channels.

    Args:
        tensor: Array indexed as ``tensor[channel]``, giving a 3-D grid that
            ``Axes3D.voxels`` can draw.
        channels: Iterable of channel indices into ``tensor`` and into the
            title list below.
        threshold: Values below this are zeroed before plotting.
        savefile: File-name suffix shared by all output images.
    """
    Channel_Titles = ["Energy Grid", "H", "O", "N", "C", "P", "Cu", "Co", "Ag", "Zn", "Cd", "Fe"]

    for channel in channels:
        # Work on a copy so the caller's tensor is not modified.
        grid = np.copy(tensor[channel])
        grid[grid < threshold] = 0

        # A fresh figure per channel keeps the images independent.
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1, projection='3d')
        ax.voxels(grid)
        fig.savefig(Channel_Titles[channel] + "_" + savefile)
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)
class ScaleUtil:
    """Helpers that resample an unbatched tensor to a new spatial size."""

    @staticmethod
    def resize_2d(t: Tensor, size: Union[int, Tuple[int, int]]) -> Tensor:
        """Resize a 2-D tensor with bilinear interpolation.

        Args:
            t: Tensor with exactly two dimensions.
            size: Target size, either one int used for both dimensions or a
                pair.  (The annotation previously listed a 3-tuple, copied
                from ``resize_3d``.)

        Returns:
            The resized 2-D tensor.
        """
        # interpolate expects Batch x Channels x Dims, so add two singleton
        # leading dimensions and strip them again from the result.
        batched = t.unsqueeze(0).unsqueeze(0)
        return functional.interpolate(batched, size=size, mode='bilinear', align_corners=False)[0][0]

    @staticmethod
    def resize_3d(t: Tensor, size: Union[int, Tuple[int, int, int]]) -> Tensor:
        """Resize a 3-D tensor with trilinear interpolation.

        Args:
            t: Tensor with exactly three dimensions.
            size: Target size, either one int used for all dimensions or a
                triple.

        Returns:
            The resized 3-D tensor.
        """
        # interpolate expects Batch x Channels x Dims, so add two singleton
        # leading dimensions and strip them again from the result.
        batched = t.unsqueeze(0).unsqueeze(0)
        return functional.interpolate(batched, size=size, mode='trilinear', align_corners=False)[0][0]
class MOFDataset(Dataset):
    """Dataset backed by a pickled sequence of voxelised MOF records.

    Each record is read through one of three attributes -- ``loc_tensor``,
    ``grid_tensor`` or ``data`` -- selected by the ``no_grid`` / ``no_loc``
    flags.
    """

    def __init__(self, path, no_grid=False, no_loc=False, transform=None):
        """Load the pickled records stored at ``path``.

        Args:
            path: Location of the pickle file.
            no_grid: If true, items are the records' ``loc_tensor``.
            no_loc: If true (and ``no_grid`` is false), items are the
                records' ``grid_tensor``.
            transform: Optional callable applied to every item before it is
                returned.  It was previously stored but never used.
        """
        self.path = path
        self.no_grid = no_grid
        self.no_loc = no_loc
        self.transform = transform
        # SECURITY NOTE: unpickling executes code embedded in the file, so
        # only load dataset files from a trusted source.
        with Path(path).open("rb") as f:
            # Presumably a list of Voxel_MOF records -- confirm with the
            # script that writes these files.
            self.data = pickle.load(f)

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the selected tensor of record ``idx``, transformed if asked."""
        record = self.data[idx]
        if self.no_grid:
            sample = record.loc_tensor
        elif self.no_loc:
            sample = record.grid_tensor
        else:
            sample = record.data

        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    @staticmethod
    def get_data_loader(path: str, batch_size: int, no_grid=False, no_loc=False):
        """Build a shuffling ``DataLoader`` over the dataset at ``path``."""
        return torch.utils.data.DataLoader(
            MOFDataset(path, no_grid=no_grid, no_loc=no_loc),
            batch_size=batch_size,
            shuffle=True,
        )
def main():
    """Generate an energy grid for every CIF file in ``structure_10143``.

    For each ``<name>.cif`` next to this script, ``./cif2input`` is run to
    create ``inp_grids/<name>.inp`` and ``./grid_gen`` to create
    ``energy_grids/<name>.grid``.  Outputs that already exist are skipped.
    ``grid_gen`` processes are started in batches and each batch is awaited
    before the next one begins.

    Raises:
        ValueError: If the force-field file ``MOFGAN/data_ff_UFF`` is missing.
    """
    cur_dir = osp.dirname(osp.realpath(__file__))
    struct_dir = osp.join(cur_dir, "structure_10143")
    inp_dir = osp.join(cur_dir, "inp_grids")
    energy_dir = osp.join(cur_dir, "energy_grids")

    # exist_ok avoids the check-then-create race of isdir() + mkdir().
    for out_dir in (inp_dir, energy_dir):
        os.makedirs(out_dir, exist_ok=True)

    UFF_loc = osp.join(cur_dir, "MOFGAN/data_ff_UFF")
    if not osp.exists(UFF_loc):
        raise ValueError("Couldn't find valid force-field file. ")

    cif_files = glob(osp.join(struct_dir, "*.cif"))
    num_files = len(cif_files)

    num_concurrent_processes = 38
    for block_start in tqdm(range(0, num_files, num_concurrent_processes)):
        procs = []
        block = cif_files[block_start:block_start + num_concurrent_processes]
        for cif_file in block:
            # splitext/basename instead of split("/")[-1][:-4], so the stem
            # does not depend on the path separator or the suffix length.
            cif_name = osp.splitext(osp.basename(cif_file))[0]
            inp_file = osp.join(inp_dir, cif_name + ".inp")
            grid_file = osp.join(energy_dir, cif_name + ".grid")

            # NOTE(review): both tools are resolved relative to the current
            # working directory, unlike the data paths above which are
            # relative to this script -- confirm that this is intended.
            if not osp.exists(inp_file):
                # The input file must exist before grid_gen starts, so this
                # step is awaited immediately.
                sp.Popen(["./cif2input", cif_file, UFF_loc, inp_file]).wait()

            if not osp.exists(grid_file):
                procs.append(sp.Popen(["./grid_gen", inp_file, grid_file]))

        # Wait for the whole batch before launching the next one.
        exit_codes = [p.wait() for p in procs]
        print(exit_codes)

    print(num_files)
class Generator(nn.Module):
    """Transposed-convolution generator producing 64 x 64 x 64 voxel grids.

    The BatchNorm widths now follow ``cube_side``.  They were hard-coded to
    512/256/128/64, which only matched the convolution outputs when
    ``cube_side == 64``; any other value failed in the first BatchNorm
    layer.  For ``cube_side == 64`` the layers are unchanged.

    Args:
        z_len: Length of the latent vector (input channels).
        cube_side: Base channel width; layers use 8x, 4x, 2x and 1x of it.
        num_atoms: Number of output channels.
    """

    def __init__(self, z_len, cube_side, num_atoms):
        super(Generator, self).__init__()
        self.inp_vector = z_len
        self.length = cube_side
        self.num_channels = num_atoms

        # ConvTranspose3d(in_channels, out_channels, kernel_size, stride, padding)
        # The input is z_len x 1 x 1 x 1; each layer doubles the side length
        # after the first one expands it to 4.
        self.main = nn.Sequential(
            nn.ConvTranspose3d(self.inp_vector, self.length * 8, 4, 2, 0),  # side 4
            nn.BatchNorm3d(self.length * 8),
            nn.Tanh(),

            nn.ConvTranspose3d(self.length * 8, self.length * 4, 4, 2, 1),  # side 8
            nn.BatchNorm3d(self.length * 4),
            nn.Tanh(),

            nn.ConvTranspose3d(self.length * 4, self.length * 2, 4, 2, 1),  # side 16
            nn.BatchNorm3d(self.length * 2),
            nn.Tanh(),

            nn.ConvTranspose3d(self.length * 2, self.length, 4, 2, 1),  # side 32
            nn.BatchNorm3d(self.length),
            nn.Tanh(),

            nn.ConvTranspose3d(self.length, self.num_channels, 4, 2, 1),  # side 64
            nn.Sigmoid()
        )

    def forward(self, x):
        """Map latent vectors of shape (batch, z_len, ...) to voxel grids."""
        # Give the latent vector three singleton spatial dimensions.
        x = x.view(x.size(0), x.size(1), 1, 1, 1)
        return self.main(x)


class Discriminator(nn.Module):
    """Strided-convolution discriminator for 64 x 64 x 64 voxel grids.

    As in ``Generator``, the BatchNorm widths follow ``cube_side`` instead
    of being hard-coded for ``cube_side == 64``.

    Args:
        num_atoms: Number of input channels.
        cube_side: Base channel width; layers use 1x, 2x, 4x and 8x of it.
    """

    def __init__(self, num_atoms, cube_side):
        super(Discriminator, self).__init__()
        self.num_channels = num_atoms
        self.length = cube_side

        # Conv3d(in_channels, out_channels, kernel_size, stride, padding);
        # each layer halves the side length, the last reduces 4 to 1.
        self.main = nn.Sequential(
            nn.Conv3d(self.num_channels, self.length, 4, 2, 1),
            nn.BatchNorm3d(self.length),
            nn.LeakyReLU(0.2),

            nn.Conv3d(self.length, self.length * 2, 4, 2, 1),
            nn.BatchNorm3d(self.length * 2),
            nn.LeakyReLU(0.2),

            nn.Conv3d(self.length * 2, self.length * 4, 4, 2, 1),
            nn.BatchNorm3d(self.length * 4),
            nn.LeakyReLU(0.2),

            nn.Conv3d(self.length * 4, self.length * 8, 4, 2, 1),
            nn.BatchNorm3d(self.length * 8),
            nn.LeakyReLU(0.2),

            nn.Conv3d(self.length * 8, 1, 4, 2, 0),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return one score in [0, 1] per sample, shaped (batch, 1)."""
        x = self.main(x)
        return x.view(-1, x.size(1))
def energy(bond_type, distance):
    """Return the Gaussian energy term for one bond.

    The mean and variance passed to ``gaussian_dist`` depend on
    ``bond_type``: type 1 uses (0.6, 0.1), type 2 uses (0.05, 0.01) and
    every other value falls back to (0.3, 0.02).
    """
    known_bonds = (
        (1, 0.6, .1),
        (2, 0.05, 0.01),
    )
    # Compare with ``==`` in declaration order, like an if/elif chain.
    for known_type, mean, variance in known_bonds:
        if bond_type == known_type:
            return gaussian_dist(distance, mean, variance)
    return gaussian_dist(distance, 0.3, 0.02)
class Periodic_Struture(object):
    """Distance matrix of a CIF structure plus a connectivity-radius search.

    Attributes:
        distance_matrix: pairwise site distances of the first structure
            parsed from ``file_name``.
        file_name: path of the CIF file the matrix was read from.
        r: starts as the largest pairwise distance; after
            ``get_min_radius`` it holds the smallest cutoff radius the
            search found at which the distance graph is connected.
    """

    def __init__(self, file_name):
        super(Periodic_Struture, self).__init__()
        self.distance_matrix = self.cif_distance_matrix(file_name)
        self.file_name = file_name
        self.r = np.amax(self.distance_matrix)

    def cif_distance_matrix(self, file_name):
        """Parse ``file_name`` and return the first structure's distance matrix."""
        parser = CifParser(file_name)
        structure = parser.get_structures()[0]
        return structure.distance_matrix

    def get_min_radius(self):
        """Run the radius search, store the result in ``self.r`` and return it."""
        self.min_radius_helper(self.r)
        return self.r

    def min_radius_helper(self, radius):
        """Shrink ``self.r`` towards the smallest radius that keeps the graph connected.

        Starting from ``radius / 2``, a candidate radius is tested by keeping
        only the distances below it as edge weights.  A connected graph makes
        the candidate the new ``self.r`` and halves the next candidate; a
        disconnected one moves the candidate halfway back up to ``self.r``.
        The search stops once the candidate is within 0.1 of ``self.r``.
        """
        new_radius = radius / 2

        # ">=" rather than "not <" so that a NaN radius ends the loop
        # instead of spinning forever.
        while abs(new_radius - self.r) >= .1:
            # Distances at or above the candidate are zeroed, i.e. no edge.
            masked = (self.distance_matrix < new_radius) * self.distance_matrix
            graph = nx.from_numpy_array(masked)

            if nx.is_connected(graph):
                self.r = new_radius
                new_radius = new_radius / 2
            else:
                new_radius = new_radius + (self.r - new_radius) / 2
class MLPAutoEncoder(nn.Module):
    """Fully connected auto-encoder over flattened inputs.

    The encoder compresses ``input_dim`` features to a 128-dimensional code
    (input_dim -> 512 -> 512 -> 256 -> 128) and the decoder mirrors it back
    to ``input_dim`` features, finishing with ``Tanh``.

    Args:
        input_dim: number of features per flattened sample.  Defaults to
            ``32 * 32 * 32 * 12``, the size that used to be hard-coded.
    """

    def __init__(self, input_dim=32 * 32 * 32 * 12):
        # The original called super(AutoEncoder, self), a name that does not
        # exist in this module, so construction always raised NameError.
        super(MLPAutoEncoder, self).__init__()
        self.input_dim = input_dim

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.ReLU(True),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128))

        self.decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(True),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(512, input_dim),
            nn.Tanh())

    def forward(self, x):
        """Encode and decode ``x`` of shape (batch, input_dim); returns the same shape."""
        u = self.encoder(x)
        x_prime = self.decoder(u)
        return x_prime
43 | """docstring for ConvolutionalAE""" 44 | def __init__(self, z_dimension, num_features): 45 | super(ConvolutionalAE, self).__init__() 46 | self.z_dim = z_dimension 47 | self.num_features = num_features 48 | 49 | self.encoder = nn.Sequential( 50 | nn.Conv3d(self.num_features, self.z_dim // 8, 4,2,1), #32 51 | nn.BatchNorm3d(self.z_dim // 8), 52 | nn.LeakyReLU(0.2), 53 | 54 | nn.Conv3d(self.z_dim // 8, self.z_dim // 4, 4,2,1), #16 55 | nn.BatchNorm3d(self.z_dim // 4), 56 | nn.LeakyReLU(0.2), 57 | 58 | nn.Conv3d(self.z_dim // 4, self.z_dim // 2, 4,2,1), #4 59 | nn.BatchNorm3d(self.z_dim // 2), 60 | nn.LeakyReLU(0.2), 61 | 62 | nn.Conv3d(self.z_dim // 2, self.z_dim, 4,2,1), #2 63 | nn.BatchNorm3d(self.z_dim), 64 | nn.LeakyReLU(0.2), 65 | 66 | nn.Conv3d(self.z_dim, self.z_dim, 2,2,0), #1 67 | ) 68 | 69 | self.decoder = nn.Sequential( 70 | 71 | nn.ConvTranspose3d(self.z_dim, self.num_features*16, 4,2,0), # self.num_features*8 x 4 x 4 x 4 72 | nn.BatchNorm3d(self.num_features*16), 73 | nn.Tanh(), 74 | 75 | nn.ConvTranspose3d(self.num_features*16, self.num_features*4, 4,2,1), # self.self.num_features*4 x 8 x 8 x 8 76 | nn.BatchNorm3d(self.num_features*4), 77 | nn.Tanh(), 78 | 79 | nn.ConvTranspose3d(self.num_features*4, self.num_features*2, 4,2,1), # self.self.num_features*2 x 16 x 16 x 16 80 | nn.BatchNorm3d(self.num_features*2), 81 | nn.Tanh(), 82 | 83 | nn.ConvTranspose3d(self.num_features*2, self.num_features, 4,2,1), # self.self.num_features x 32 x 32 x 32 84 | nn.Sigmoid(), 85 | ) 86 | def forward(self, x): 87 | u = self.encoder(x) 88 | x_prime = self.decoder(u) 89 | return x_prime 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | # def main(): 101 | 102 | 103 | 104 | # if __name__ == '__main__': 105 | # main() 106 | # -------------------------------------------------------------------------------- /Moledule_Generation/MolGAN/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 
class Generator(nn.Module):
    """Map a latent vector to a dense graph: an edge tensor and a node-feature matrix."""

    def __init__(self, input_vector_dim = 32, num_nodes=32, num_features=11, num_edge_features = 1):
        super().__init__()

        self.z_dim = input_vector_dim  # size of the latent vector
        self.N = num_nodes             # number of atoms (graph nodes)
        self.T = num_features          # atom types (features per node)
        self.Y = num_edge_features     # features per edge

        # Three Linear + Tanh stages widen the latent vector to 512 features.
        self.layer_1 = nn.Linear(self.z_dim, 128)
        self.layer_1_act = nn.Tanh()
        self.layer_2 = nn.Linear(128, 256)
        self.layer_2_act = nn.Tanh()
        self.layer_3 = nn.Linear(256, 512)
        self.layer_3_act = nn.Tanh()

        # Output heads: one value per (edge feature, node, node) entry ...
        self.edges_layer = nn.Linear(512, self.N * self.N * self.Y)
        # ... and one value per (node, atom type) entry.
        self.node_layer = nn.Linear(512, self.N * self.T)

    def forward(self, x):
        """Return ``(adj_mat, feat_mat)`` for a batch of latent vectors.

        ``adj_mat`` has shape (batch, N, N, Y) and is symmetric in its two
        node dimensions; ``feat_mat`` has shape (batch, N, T).
        """
        stages = (
            (self.layer_1, self.layer_1_act),
            (self.layer_2, self.layer_2_act),
            (self.layer_3, self.layer_3_act),
        )
        hidden = x
        for linear, activation in stages:
            hidden = activation(linear(hidden))

        edges = self.edges_layer(hidden).view(-1, self.Y, self.N, self.N)
        # Average with the node-transposed tensor to make it symmetric.
        edges = (edges + edges.permute(0, 1, 3, 2)) / 2
        # Move the edge-feature axis to the end.
        adj_mat = edges.permute(0, 2, 3, 1)

        feat_mat = self.node_layer(hidden).view(-1, self.N, self.T)

        return adj_mat, feat_mat
59 | def __init__(self, num_input_features = 11): 60 | super(Discriminator, self).__init__() 61 | self.atom_types = num_input_features 62 | 63 | self.conv1 = GraphConv(self.atom_types, 128) 64 | self.pool1 = TopKPooling(128, ratio=0.8) 65 | 66 | self.conv2 = GraphConv(128, 128) 67 | self.pool2 = TopKPooling(128, ratio = 0.8) 68 | 69 | 70 | self.lin1 = nn.Linear(512, 64) 71 | self.lin2 = nn.Linear(64,16) 72 | self.lin3 = nn.Linear(16,1) 73 | 74 | def forward(self, data): 75 | x, edge_index, batch = data.x, data.edge_index, data.batch 76 | 77 | x = F.relu (self.conv1(x,edge_index)) 78 | 79 | x, edge_index,_, batch, _,_ = self.pool1(x, edge_index, None, batch) 80 | x1 = torch.cat([gmp(x,batch), gap(x,batch)], dim=1) 81 | 82 | x = F.relu (self.conv2(x,edge_index)) 83 | x, edge_index, _, batch, _,_ = self.pool2(x, edge_index, None,batch) 84 | x2 = torch.cat([gmp(x,batch), gap(x,batch)], dim=1) 85 | 86 | x = torch.cat([x1,x2], dim=1) 87 | 88 | x = F.relu(self.lin1(x)) 89 | x = F.relu(self.lin2(x)) 90 | x = F.relu(self.lin3(x)) 91 | 92 | return x 93 | 94 | 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /MOF_Force_Field/model.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.data as data_utils 8 | from torch_geometric.data import Data, DataLoader 9 | from torch import nn 10 | from torch_geometric.nn import global_mean_pool as gap 11 | from torch_geometric.nn import global_max_pool as gmp 12 | from torch_geometric.nn import global_add_pool as gaddp 13 | 14 | from MOLGCN import MOLGCN 15 | 16 | 17 | 18 | 19 | class MOF_Net(torch.nn.Module): 20 | def __init__(self, 21 | input_features = None, 22 | mlp = None): 23 | super(MOF_Net, self).__init__() 24 | if (mlp): 25 | self.nn = mlp 26 | else: 27 | 
def run(loader,
        model,
        optimizer,
        loss_func,
        device,
        train = True):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    Args:
        loader: iterable of batches; each batch must support ``.to(device)``
            and expose the targets as ``.y``.  ``len(loader)`` must work.
        model: callable applied to each batch.
        optimizer: optimizer stepped after every batch when ``train`` is true.
        loss_func: called as ``loss_func(y, y_out)``.
        device: device the batches and targets are moved to.
        train: when true, back-propagate and update the weights; when false,
            evaluate only, with gradients disabled.

    Returns:
        Sum of the batch losses divided by ``len(loader)``.
    """
    # Put an nn.Module into the matching mode so Dropout/BatchNorm behave
    # correctly; validation used to run with Dropout still active.  The
    # previous mode is restored afterwards so the caller sees no change.
    is_module = isinstance(model, torch.nn.Module)
    was_training = model.training if is_module else None
    if is_module:
        model.train(train)

    total = 0
    try:
        # Disabling autograd during validation reduces memory usage.
        with torch.set_grad_enabled(train):
            for data in loader:
                data = data.to(device)
                y_out = model(data)
                y = data.y.to(device)
                loss = loss_func(y, y_out)

                if train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                total += loss.item()
    finally:
        if is_module:
            model.train(was_training)

    return total / len(loader)
def post_process(inp):
    """Softmax one or more logit tensors over their last dimension.

    Args:
        inp: a single tensor, or a list/tuple of tensors.

    Returns:
        A list with one softmaxed tensor per input.  A tensor whose first
        dimension has length 1 is returned without that dimension.
    """
    def listify(x):
        # isinstance instead of exact type comparison, so list/tuple
        # subclasses are treated as sequences too.
        return x if isinstance(x, (list, tuple)) else [x]

    def delistifiy(x):
        # Drop a leading dimension of length 1 (e.g. a batch of one sample).
        return x if len(x) > 1 else x[0]

    softmax = [F.softmax(e_logits, -1) for e_logits in listify(inp)]

    return [delistifiy(e) for e in softmax]
optimizerD = optim 54 | 55 | 56 | criterion = nn.BCELoss() 57 | 58 | real_label = 1 59 | fake_label = 0 60 | 61 | optimizerD = optim.Adam(D.parameters(), lr=0.0001, betas=(0.9, 0.999)) 62 | optimizerG = optim.Adam(Gen.parameters(), lr=0.01, betas=(0.9, 0.999)) 63 | 64 | NUM_EPOCHS = 200 65 | 66 | for epoch in range(NUM_EPOCHS): 67 | ## 68 | ## First train the discriminator 69 | ## 70 | 71 | training_data_list = pickle.load(open('MOF_GENERATOR_32.p','rb')) 72 | fake_data = gen_samples(Gen, device, num_samples=16) 73 | true_loader = DataLoader(training_data_list, batch_size=16) 74 | fake_loader = DataLoader(fake_data, batch_size=16) 75 | 76 | real_loss = 0 77 | fake_loss = 0 78 | 79 | for data in true_loader: 80 | data = data.to(device) 81 | out = D(data) 82 | real_loss += -torch.mean(out) 83 | 84 | 85 | for data in fake_loader: 86 | data = data.to(device) 87 | out = D(data) 88 | fake_loss += torch.mean(out) 89 | disc_loss = fake_loss + real_loss 90 | 91 | optimizerD.zero_grad() 92 | disc_loss.backward() 93 | optimizerD.step() 94 | print("Discriminator Loss: ",disc_loss.item(), 95 | "Real Loss:", -real_loss.item(), " Fake Loss", fake_loss.item()) 96 | ## 97 | ## Second train the generator 98 | ## 99 | fake_data = gen_samples(Gen, device) 100 | fake_loader = DataLoader(fake_data, batch_size=16) 101 | 102 | g_fake_loss = 0 103 | for data in fake_loader: 104 | data = data.to(device) 105 | out = D(data) 106 | g_fake_loss += -torch.mean(out) 107 | 108 | optimizerG.zero_grad() 109 | g_fake_loss.backward() 110 | optimizerG.step() 111 | print("Generator Loss: ", -g_fake_loss.item()) 112 | print("*" * 40) 113 | ## 114 | ## Not really sure here 115 | ## 116 | 117 | if __name__ == '__main__': 118 | main() 119 | -------------------------------------------------------------------------------- /Moledule_Generation/MolGAN/tester.py: -------------------------------------------------------------------------------- 1 | from model import Generator 2 | from model import Discriminator 3 
def post_process(inp):
    """Softmax one or more logit tensors over their last dimension.

    Args:
        inp: a single tensor, or a list/tuple of tensors.

    Returns:
        A list with one softmaxed tensor per input.  A tensor whose first
        dimension has length 1 is returned without that dimension.
    """
    def listify(x):
        # isinstance instead of exact type comparison, so list/tuple
        # subclasses are treated as sequences too.
        return x if isinstance(x, (list, tuple)) else [x]

    def delistifiy(x):
        # Drop a leading dimension of length 1 (e.g. a batch of one sample).
        return x if len(x) > 1 else x[0]

    softmax = [F.softmax(e_logits, -1) for e_logits in listify(inp)]

    return [delistifiy(e) for e in softmax]
graph_utils.dense_to_sparse(adj_hat) 72 | 73 | value = value.unsqueeze(1) 74 | 75 | data = Data(x=node_hat, edge_index=index, edge_attr=value, y=fake_label) 76 | 77 | # print(type(data)) 78 | fake_data.append(data) 79 | 80 | true_loader = DataLoader(training_data_list, batch_size=16) 81 | fake_loader = DataLoader(fake_data, batch_size=16) 82 | 83 | real_loss = 0 84 | fake_loss = 0 85 | 86 | for data in true_loader: 87 | data = data.to(device) 88 | # print(dir(data)) 89 | # print((data.edge_attr).shape) 90 | # print((data.edge_index).shape) 91 | # print((data.weight).shape) 92 | out = Discrim(data) 93 | real_loss += -torch.mean(out) 94 | 95 | print("*" * 40) 96 | 97 | for data in fake_loader: 98 | data = data.to(device) 99 | # print(dir(data)) 100 | # print((data.x).shape) 101 | # print((data.edge_attr).shape) 102 | # print((data.edge_index).shape) 103 | out = Discrim(data) 104 | fake_loss += torch.mean(out) 105 | disc_loss = fake_loss + real_loss 106 | 107 | discrim_optim.zero_grad() 108 | disc_loss.backward() 109 | discrim_optim.step() 110 | 111 | 112 | ## 113 | ## Train the generator 114 | ## 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /MOF_Force_Field/tester.py: -------------------------------------------------------------------------------- 1 | from dataloader import MOFDataset 2 | import os.path as osp 3 | import os 4 | 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.utils.data as data_utils 12 | 13 | from torch_geometric.data import Data, DataLoader 14 | from model import MOF_Net, run 15 | from MOLGCN import MOLGCN 16 | import model 17 | 18 | from pathlib import Path 19 | 20 | 21 | 22 | 23 | 24 | cur_dir = os.curdir 25 | dataset = MOFDataset('FIGXAU_V2.csv','.') 26 | 27 | 28 | 29 | batch_size = 48 30 | 31 | 
one_tenth_length = int(len(dataset) * 0.1) 32 | 33 | train_dataset = dataset[:one_tenth_length * 8] 34 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 35 | 36 | val_dataset = dataset[one_tenth_length * 8 :] 37 | val_loader = DataLoader(val_dataset, batch_size = 1024) 38 | 39 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 40 | 41 | mlp = nn.Sequential(nn.Linear(5,1024), 42 | nn.ReLU(), 43 | nn.Dropout(0.5), 44 | nn.Linear(1024,256), 45 | nn.ReLU(), 46 | nn.Dropout(0.5), 47 | nn.Linear(256,32), 48 | nn.ReLU(), 49 | nn.Dropout(0.5), 50 | nn.Linear(32,1) 51 | ) 52 | 53 | 54 | model = MOF_Net(5,mlp).to(device) 55 | 56 | PATH = 'warmup_model.pt' 57 | 58 | saved_model = Path(PATH) 59 | 60 | if (saved_model.is_file()): 61 | print("Loading Saved Model") 62 | model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want 63 | 64 | model.to(device) 65 | 66 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) 67 | loss_func = nn.SmoothL1Loss() 68 | 69 | train_loss_list = [] 70 | val_loss_list = [] 71 | # warmup_batch = next(iter(train_loader)) 72 | 73 | # not_converge = True 74 | 75 | # counter = 0 76 | # while(not_converge): 77 | # total = 0 78 | # for epoch in range(1000): 79 | # data = warmup_batch.to(device) 80 | # y_out = model(data) 81 | # y = data.y.to(device) 82 | # loss = loss_func(y,y_out) 83 | # optimizer.zero_grad() 84 | # loss.backward() 85 | # optimizer.step() 86 | # # print("Epoch {} : Training Loss: {:.9f} \t".\ 87 | # # format(epoch + 1, loss.item())) 88 | # total += loss.item() 89 | # total = total / 1000 90 | # counter +=1 91 | # if(total > 2): 92 | # print("Not Converged. Continuing Training. Average Error: {}".format(total)) 93 | # not_converge = True 94 | # else: 95 | # print("Converged! Ending Training. 
@ray.remote
def get_torch_data(df, threshold = 3):
    """Convert the rows of one run into a torch_geometric ``Data`` graph.

    Args:
        df: rows of a single run with the columns ``atom``, ``Energy(Ry)``,
            ``x(angstrom)``, ``y(angstrom)`` and ``z(angstrom)``.
        threshold: initial distance cutoff; atom pairs farther apart than
            this get no edge.  It is raised in steps of 0.5 until no node
            is isolated.

    Returns:
        ``Data`` with one-hot atom types as ``x``, the kept pairwise
        distances as ``edge_attr`` and the negated first ``Energy(Ry)``
        value as ``y``.
    """
    atoms = df['atom'].values

    energy = np.array([-1*df['Energy(Ry)'].values[0]])
    atoms = np.expand_dims(atoms, axis=1)

    # NOTE(review): the encoder is fitted per run, so the one-hot columns
    # depend on which atom types occur in this run -- confirm every run
    # contains the same set of types.
    one_hot_encoding = OneHotEncoder(sparse=False).fit_transform(atoms)
    coords = df[['x(angstrom)','y(angstrom)','z(angstrom)']].values

    # The pairwise distances do not depend on the threshold: compute once.
    pairwise = distance.cdist(coords, coords)
    max_dist = pairwise.max()
    # Use the real atom count; it used to be hard-coded to 13.
    num_nodes = len(coords)

    edge_index = None
    edge_attr = None

    while True:
        # Zero entries are dropped by dense_to_sparse, i.e. become "no edge".
        dist = torch.from_numpy(np.where(pairwise > threshold, 0, pairwise))
        edge_index, edge_attr = data_utils.dense_to_sparse(dist)
        edge_attr = edge_attr.unsqueeze(dim=1).type(torch.FloatTensor)
        edge_index = edge_index.long()
        if not data_utils.contains_isolated_nodes(edge_index, num_nodes = num_nodes):
            break
        if threshold >= max_dist:
            # Nothing is masked any more; a larger cutoff cannot add edges,
            # so stop instead of looping forever (e.g. a single atom).
            break
        threshold += 0.5

    x = torch.from_numpy(one_hot_encoding).type(torch.FloatTensor)
    y = torch.from_numpy(energy).type(torch.FloatTensor)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

    return data
def to_img(x, channels=11, side=32):
    """Rescale a batch from [-1, 1] to [0, 1] and reshape it into voxel grids.

    Args:
        x: tensor with ``channels * side ** 3`` values per sample.
        channels: number of grid channels (default 11, the former constant).
        side: edge length of the cubic grid (default 32, the former constant).

    Returns:
        Tensor of shape (batch, channels, side, side, side) with values
        clamped to [0, 1].
    """
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    return x.view(x.size(0), channels, side, side, side)
def small_batch_train(batch_of_data):
    """Run one optimisation step of the auto-encoder on a single batch.

    The batch is reconstructed by the module-level ``model``, compared with
    itself using ``criterion`` and the weights are updated with
    ``optimizer``.

    Args:
        batch_of_data: tensor holding one batch of inputs.

    Returns:
        The reconstruction loss of this batch as a Python float.
    """
    # Use the module-level ``device`` instead of an unconditional .cuda(),
    # which raised on machines without a GPU.
    img = batch_of_data.float().to(device)
    output = model(img)
    loss = criterion(output, img)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
class VAE(nn.Module):
    """3D convolutional variational auto-encoder for voxel grids.

    Args:
        z_dimension: size of the latent vector; the encoder widths are
            z_dimension // 8, // 4, // 2 and z_dimension.
        num_features: number of channels of the voxel grid.
        voxel_side_length: edge length of the cubic grid.  A decoder is only
            built for 32.
    """

    def __init__(self, z_dimension, num_features, voxel_side_length):
        super(VAE, self).__init__()
        self.z_dim = z_dimension
        self.voxel_side_length = voxel_side_length
        self.num_features = num_features

        # Conv3d arguments: (in_channels, out_channels, kernel, stride, padding).
        # For a 32^3 input the spatial side goes 32 -> 16 -> 8 -> 4 -> 2 -> 1.
        self.encoder = nn.Sequential(
            nn.Conv3d(self.num_features, self.z_dim // 8, 4, 2, 1),
            nn.BatchNorm3d(self.z_dim // 8),
            nn.LeakyReLU(0.2),

            nn.Conv3d(self.z_dim // 8, self.z_dim // 4, 4, 2, 1),
            nn.BatchNorm3d(self.z_dim // 4),
            nn.LeakyReLU(0.2),

            nn.Conv3d(self.z_dim // 4, self.z_dim // 2, 4, 2, 1),
            nn.BatchNorm3d(self.z_dim // 2),
            nn.LeakyReLU(0.2),

            nn.Conv3d(self.z_dim // 2, self.z_dim, 4, 2, 1),
            nn.BatchNorm3d(self.z_dim),
            nn.LeakyReLU(0.2),

            nn.Conv3d(self.z_dim, self.z_dim, 2, 2, 0),
        )
        self.mu_fc = nn.Linear(self.z_dim, self.z_dim)
        self.logvar_fc = nn.Linear(self.z_dim, self.z_dim)

        if self.voxel_side_length == 32:
            # Spatial side: 1 -> 4 -> 8 -> 16 -> 32.
            self.decoder = nn.Sequential(
                nn.ConvTranspose3d(self.z_dim, self.num_features*8, 4, 2, 0),
                nn.BatchNorm3d(self.num_features*8),
                nn.Tanh(),

                nn.ConvTranspose3d(self.num_features*8, self.num_features*4, 4, 2, 1),
                nn.BatchNorm3d(self.num_features*4),
                nn.Tanh(),

                nn.ConvTranspose3d(self.num_features*4, self.num_features*2, 4, 2, 1),
                nn.BatchNorm3d(self.num_features*2),
                nn.Tanh(),

                nn.ConvTranspose3d(self.num_features*2, self.num_features, 4, 2, 1),
                nn.Sigmoid(),
            )
        else:
            # Keep the attribute defined so decode() can fail with a clear
            # error instead of an AttributeError.
            self.decoder = None
            print("Not Implemented variable sized grid initializer. set voxel_side_length = 32. ")

    def reparametrize(self, mu, logvar):
        """Sample ``mu + eps * std`` with ``std = exp(0.5 * logvar)`` and ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x):
        """Encode a batch of grids.

        Args:
            x: tensor of shape (batch, num_features, D, W, H).

        Returns:
            ``(z, mu, logvar)``, each of shape (batch, z_dim), where ``z`` is
            a reparametrized sample.
        """
        x = self.encoder(x)
        x = x.view(-1, x.size(1))
        mu = self.mu_fc(x)
        logvar = self.logvar_fc(x)
        return self.reparametrize(mu, logvar), mu, logvar

    def decode(self, x):
        """Decode latent vectors into voxel grids.

        Accepts either (batch, z_dim) or (batch, z_dim, 1, 1, 1).  The flat
        form is what ``encode`` returns; it is reshaped here because the
        transposed convolutions need 5-D input, which made ``forward`` fail.

        Raises:
            NotImplementedError: if no decoder exists for this grid size.
        """
        if self.decoder is None:
            raise NotImplementedError(
                "decoder is only implemented for voxel_side_length = 32")
        if x.dim() == 2:
            x = x.view(x.size(0), x.size(1), 1, 1, 1)
        return self.decoder(x)

    def forward(self, data):
        """Return ``(reconstruction, mu, logvar)`` for a batch of grids."""
        z, mu, logvar = self.encode(data)
        return self.decode(z), mu, logvar
class BinaryDataSet(InMemoryDataset):
    """In-memory PyTorch Geometric dataset built from downloaded CIF files.

    Each CIF file becomes one graph whose edges come from the structure's
    distance matrix, whose node features are all ones and whose target ``y``
    is read from the ``LCD`` column of the properties CSV.

    The dataset root is the current working directory.  Raw files are looked
    up under ``<root>/raw`` and processed files under the dataset's
    ``processed_dir``.

    Args:
        train: Load the training split when true, otherwise the test split.
    """

    def __init__(self, train=True):
        root = os.getcwd()
        self.train = train
        self.properties_file = 'properties.csv'
        self.data_file = 'data.tar'
        self.processed_test_data = 'test_data.pt'
        self.processed_training_data = 'training_data.pt'
        self.data_file_path = path.join(root, 'raw', self.data_file)
        self.properties_file_path = path.join(root, 'raw', self.properties_file)
        super(BinaryDataSet, self).__init__(root, transform=None, pre_transform=None)
        # Load from processed_dir: that is where process() writes and where
        # the base class checks for processed_file_names.
        processed_name = self.processed_training_data if train else self.processed_test_data
        self.data, self.slices = torch.load(path.join(self.processed_dir, processed_name))

    @property
    def raw_file_names(self):
        """Names of the files expected in the raw directory."""
        return [self.properties_file, self.data_file]

    @property
    def processed_file_names(self):
        """Names of the files expected in the processed directory."""
        return [self.processed_test_data, self.processed_training_data]

    @staticmethod
    def validate_caller():
        """Exit with status 1 unless this module is being run as a script."""
        if __name__ != '__main__':
            print("Run this script separately to generate necessary files")
            raise SystemExit(1)

    def download(self):
        """Download the CSV and CIF archive, extract it and split it randomly.

        Files are moved into ``training`` (probability 0.79) or ``test``
        sub-directories of the raw directory.
        """
        self.validate_caller()
        print('Downloading properties CSV...')
        self.download_file(url='https://zenodo.org/record/3370144/files/2019-07-01-ASR-public_12020.csv?download=1',
                           file_name=self.properties_file_path)
        print('Downloading CIF data...')
        self.download_file(url='https://zenodo.org/record/3370144/files/2019-07-01-ASR-public_12020.tar?download=1',
                           file_name=self.data_file_path)

        print("Extracting tar...")
        # NOTE(review): extractall trusts member paths inside the downloaded
        # archive; confirm the source is trusted or filter members.
        with tarfile.open(self.data_file_path) as tar:
            tar.extractall(self.raw_dir)

        print("Generating training set and test set...")

        source_dir = path.join(self.raw_dir, "structure_11660")
        train_dir = path.join(self.raw_dir, "training")
        test_dir = path.join(self.raw_dir, "test")

        # exist_ok so an interrupted earlier run does not block a retry.
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(test_dir, exist_ok=True)

        # Materialise the listing first: files are moved out of source_dir
        # while we walk it.
        for cif_file_path in list(Path(source_dir).iterdir()):
            cif_file = cif_file_path.name
            source_file = path.join(source_dir, cif_file)
            target_dir = train_dir if random.random() < .79 else test_dir
            os.rename(source_file, path.join(target_dir, cif_file))

        print("Done!")

    def process(self):
        """Convert both splits to collated tensors and save one file per split."""
        self.validate_caller()
        print("Creating binary PyTorch dataset...")

        labels = pd.read_csv(self.properties_file_path)
        # from_numpy_matrix was removed in NetworkX 3.0; prefer its
        # replacement and fall back for very old installations.
        graph_from_matrix = getattr(nx, "from_numpy_array", None) or nx.from_numpy_matrix

        for data_type in ['training', 'test']:
            output_file = self.processed_training_data if data_type == 'training' else self.processed_test_data
            directory = Path(path.join(self.raw_dir, data_type))

            # Fresh list per split, otherwise the test file would also
            # contain every training sample.
            data_list = []
            total_files = sum(1 for _ in directory.iterdir())

            # NOTE(review): the label is picked by iteration position, not by
            # matching the CIF file name, and the position restarts at 0 for
            # each split -- confirm this matches the CSV row order.
            for counter, file in enumerate(directory.iterdir()):
                structure = self.cif_structure(str(file))
                distance_matrix = structure.distance_matrix

                graph = graph_from_matrix(distance_matrix.astype(np.double))
                num_nodes = distance_matrix.shape[0]

                data = torch_geometric.utils.from_networkx(graph)
                data.x = torch.ones(num_nodes, 1)
                data.y = labels['LCD'][counter]
                data_list.append(data)

                print("Elements loaded: ", counter, "/", total_files)

            data, slices = self.collate(data_list)
            torch.save((data, slices), path.join(self.processed_dir, output_file))

    @staticmethod
    def download_file(url, file_name):
        """Download *url* to *file_name*.

        Raises:
            requests.HTTPError: If the server answers with an error status, so
                an error page is never written to disk as if it were data.
        """
        r = requests.get(url)
        r.raise_for_status()

        with open(file_name, 'wb') as f:
            f.write(r.content)
        # status_code is an int; comparing against the string "200" never matched.
        if r.status_code == 200:
            print("Completed download")

    @staticmethod
    def cif_structure(file_name):
        """Parse *file_name* and return the first structure found in it."""
        parser = CifParser(file_name)
        structure = parser.get_structures()[0]
        return structure
"outputs": [], 49 | "source": [ 50 | "import numpy as np \n", 51 | "import matplotlib.pyplot as plt \n", 52 | "import torch \n", 53 | "import torch.nn as nn \n", 54 | "import torch.nn.functional as F\n", 55 | "import torch.utils.data as data_utils\n", 56 | "\n", 57 | "from torch_geometric.data import Data, DataLoader\n", 58 | "from model import MOF_Net, run\n", 59 | "from MOLGCN import MOLGCN\n", 60 | "import model" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "Epoch 1 : Training Loss: 1975.3556 \t Validation Loss: 514.7431 \n", 73 | "Epoch 2 : Training Loss: 564.7342 \t Validation Loss: 470.6502 \n", 74 | "Epoch 3 : Training Loss: 502.4572 \t Validation Loss: 423.8463 \n", 75 | "Epoch 4 : Training Loss: 455.0027 \t Validation Loss: 422.1005 \n", 76 | "Epoch 5 : Training Loss: 402.6585 \t Validation Loss: 415.1606 \n", 77 | "Epoch 6 : Training Loss: 380.8641 \t Validation Loss: 415.1636 \n", 78 | "Epoch 7 : Training Loss: 380.8642 \t Validation Loss: 415.1636 \n", 79 | "Epoch 8 : Training Loss: 380.8648 \t Validation Loss: 415.1636 \n", 80 | "Epoch 9 : Training Loss: 380.8714 \t Validation Loss: 415.1636 \n", 81 | "Epoch 10 : Training Loss: 380.8642 \t Validation Loss: 415.1636 \n", 82 | "Epoch 11 : Training Loss: 380.8642 \t Validation Loss: 415.1636 \n", 83 | "Epoch 12 : Training Loss: 380.8642 \t Validation Loss: 415.1636 \n", 84 | "Epoch 13 : Training Loss: 380.8642 \t Validation Loss: 415.1636 \n", 85 | "Epoch 14 : Training Loss: 380.8642 \t Validation Loss: 415.1636 \n", 86 | "Epoch 15 : Training Loss: 380.8642 \t Validation Loss: 415.1636 \n" 87 | ] 88 | } 89 | ], 90 | "source": [ 91 | "dataset = dataset.shuffle()\n", 92 | "batch_size = 2\n", 93 | "one_tenth_length = int(len(dataset) * 0.1)\n", 94 | "\n", 95 | "train_dataset = dataset[:one_tenth_length * 8]\n", 96 | "train_loader = DataLoader(train_dataset, 
batch_size=batch_size)\n", 97 | "\n", 98 | "val_dataset = dataset[one_tenth_length * 8 :]\n", 99 | "val_loader = DataLoader(val_dataset, batch_size = batch_size)\n", 100 | "\n", 101 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 102 | "model = MOF_Net(5).to(device)\n", 103 | "optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n", 104 | "loss_func = nn.MSELoss()\n", 105 | "\n", 106 | "train_loss_list = []\n", 107 | "val_loss_list = [] \n", 108 | "\n", 109 | "for epoch in range(100):\n", 110 | " training_loss = run(train_loader,model,optimizer,loss_func,device,True)\n", 111 | " val_loss = run(val_loader,\n", 112 | " model,\n", 113 | " optimizer,\n", 114 | " loss_func,\n", 115 | " device,\n", 116 | " False)\n", 117 | " train_loss_list.append(training_loss)\n", 118 | " val_loss_list.append(val_loss)\n", 119 | " print(\"Epoch {} : Training Loss: {:.4f} \\t Validation Loss: {:.4f} \".format(epoch + 1, training_loss, val_loss)) " 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [] 128 | } 129 | ], 130 | "metadata": { 131 | "kernelspec": { 132 | "display_name": "Python 3", 133 | "language": "python", 134 | "name": "python3" 135 | }, 136 | "language_info": { 137 | "codemirror_mode": { 138 | "name": "ipython", 139 | "version": 3 140 | }, 141 | "file_extension": ".py", 142 | "mimetype": "text/x-python", 143 | "name": "python", 144 | "nbconvert_exporter": "python", 145 | "pygments_lexer": "ipython3", 146 | "version": "3.6.9" 147 | } 148 | }, 149 | "nbformat": 4, 150 | "nbformat_minor": 4 151 | } 152 | -------------------------------------------------------------------------------- /Generative_Model/CIFtoTensor.py: -------------------------------------------------------------------------------- 1 | from pymatgen.io.cif import CifParser 2 | import warnings 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from mpl_toolkits.mplot3d 
class CIFtoTensor(object):
    """Helpers that turn a CIF file into a per-species 3D density tensor.

    All helpers are static: they are called on the class itself
    (``CIFtoTensor.to3DTensor(struct)``) and never read instance state.
    """

    def __init__(self):
        super(CIFtoTensor, self).__init__()
        self.default_atoms = ["H", "O", "N", "C", "P", "Cu", "Co", "Ag", "Zn", "Cd", "Fe"]

    @staticmethod
    def get_cif_file(file_name="AHOKOX_clean.cif"):
        """Return a ``CifParser`` for *file_name*."""
        cif_file = CifParser(file_name)
        return cif_file

    @staticmethod
    def get_pymat_struct(cif_file):
        """Return the first structure of *cif_file*, warning if there are several."""
        struct = cif_file.get_structures()
        if len(struct) > 1:
            warnings.warn("Multiple Structures generated")
        return struct[0]

    @staticmethod
    def get_image_distance(cif_file):
        """Return the distance matrix of the first structure of *cif_file*."""
        struct = cif_file.get_structures()
        if len(struct) > 1:
            warnings.warn("Multiple Structures generated")
        return struct[0].distance_matrix

    @staticmethod
    def to3DTensor(pymat_struc,
                   atom_species=None,
                   dimensions=(32, 32, 32),
                   gaussian_blurring=True,
                   half_precision=False, normalize=134, spread=0.5):
        """Build a ``(len(atom_species), *dimensions)`` tensor of atom densities.

        Args:
            pymat_struc: Object exposing ``sites``; every site needs ``x``,
                ``y``, ``z`` and ``specie``.
            atom_species: Species names, one channel each.  ``None`` or an
                empty sequence selects the built-in list of 11 species.
            dimensions: Grid size along the three axes.
            gaussian_blurring: Only triggers a warning when true.
            half_precision: Only triggers a warning when true.
            normalize: Divisor applied to the absolute coordinate.
            spread: Passed to ``add_mol_gaussian`` as ``variance``.

        Returns:
            The filled numpy array.

        Raises:
            ValueError: If a site's species is not in *atom_species*.
            AssertionError: If a mapped coordinate leaves ``[0, MAX]``.
        """
        if gaussian_blurring:
            warnings.warn("Haven't implemented blurring. Returning tensor with pointwise values")

        if half_precision:
            warnings.warn("Using half precision tensors. May not be compatible with training on network without some rewrites")

        # TODO: consider using integer values for the representation.
        if atom_species is None or len(atom_species) == 0:
            atom_species = ["H", "O", "N", "C", "P", "Cu", "Co", "Ag", "Zn", "Cd", "Fe"]

        dimensions = (len(atom_species), dimensions[0], dimensions[1], dimensions[2])
        mol_tensor = np.zeros(dimensions)

        MAX = dimensions[1] - 1
        NORMALIZE = normalize
        shifted = False

        site_0 = None
        for site in pymat_struc.sites:
            if site.x == 0 and site.y == 0 and site.z == 0:
                # The origin is placed last, once we know whether any
                # coordinate had to be shifted.
                site_0 = site
                continue

            specie = atom_species.index(str(site.specie))
            # NOTE(review): non-negative coordinates stay at grid index 0, as
            # in the original code -- confirm that is intended.
            x = 0
            y = 0
            z = 0
            if site.x < 0:
                shifted = True
                x = MAX - (abs(site.x) / NORMALIZE)
            # Each axis is tested on its own coordinate; the original tested
            # site.x for the y and z axes as well.
            if site.y < 0:
                shifted = True
                y = MAX - (abs(site.y) / NORMALIZE)
            if site.z < 0:
                shifted = True
                z = MAX - (abs(site.z) / NORMALIZE)
            assert x >= 0 and x <= MAX
            assert y >= 0 and y <= MAX
            assert z >= 0 and z <= MAX

            mol_tensor[specie] = add_mol_gaussian(mol_tensor, specie, x, y, z, variance=spread)

        if site_0 is not None:
            site_0_specie = atom_species.index(str(site_0.specie))
            # With shifted coordinates the origin maps to the far corner.
            corner = MAX if shifted else 0
            mol_tensor[site_0_specie] = add_mol_gaussian(
                mol_tensor, site_0_specie, corner, corner, corner, variance=spread)

        return mol_tensor
class CIFtoVoxel(object):
    """Voxelise the first structure of a CIF file into per-species densities.

    Site positions are read from ``site.a``, ``site.b`` and ``site.c`` and
    compared against grid points spread evenly over ``[0, 1]`` on each axis.
    """

    def __init__(self, filename):
        super(CIFtoVoxel, self).__init__()
        self.filename = filename
        cif_file = CifParser(self.filename)

        self.struct = cif_file.get_structures()[0]

        self.lattice = self.struct.lattice

        self.voxel = self.__generate_voxel()

    def __generate_voxel(self):
        """Build the voxel tensor with the default settings."""
        return self.to3DTensor()

    def to3DTensor(self,
                   atom_species=None,
                   dimensions=(32, 32, 32),
                   gaussian_blurring=True,
                   spread=0.1):
        """Return a ``(len(atom_species), *dimensions)`` density tensor.

        Args:
            atom_species: Species names, one channel each.  ``None`` or an
                empty sequence selects the built-in list of 11 species.
            dimensions: Grid size along the three axes.
            gaussian_blurring: Unused; kept for interface compatibility.
            spread: Passed to :meth:`add_mol_gaussian` as ``variance``.

        Raises:
            ValueError: If a site's species is not in *atom_species*.
        """
        # TODO: consider using integer values for the representation.
        if atom_species is None or len(atom_species) == 0:
            atom_species = ["H", "O", "N", "C", "P",
                            "Cu", "Co", "Ag", "Zn", "Cd", "Fe"]

        mol_tensor = np.zeros((len(atom_species),
                               dimensions[0], dimensions[1], dimensions[2]))

        for site in self.struct.sites:
            specie = atom_species.index(str(site.specie))
            # add_mol_gaussian accumulates into mol_tensor[specie] in place.
            self.add_mol_gaussian(mol_tensor, specie,
                                  site.a, site.b, site.c,
                                  variance=spread)

        return mol_tensor

    def add_mol_gaussian(self, tensor, specie, x, y, z, variance=0.5):
        """Add one Gaussian blob centred at ``(x, y, z)`` to ``tensor[specie]``.

        Grid point ``i`` along an axis of length ``n`` sits at ``i / (n - 1)``.
        The value added at each point is
        ``(1 / (2*pi)) ** 1.5 * exp(-0.5 * d2 / variance ** 2)`` where ``d2``
        is the squared distance to the centre; *variance* is therefore used
        as a standard deviation.

        Returns:
            ``tensor[specie]`` after the update.

        Raises:
            ValueError: If any grid axis has fewer than two points.
        """
        grid = tensor[specie]
        shape = grid.shape
        if min(shape) < 2:
            raise ValueError("every grid axis needs at least two points, got {}".format(shape))

        # One squared-offset vector per axis; broadcasting builds the full
        # 3D squared-distance field without a Python-level triple loop.
        dx2 = np.abs(np.arange(shape[0]) / (shape[0] - 1) - x) ** 2
        dy2 = np.abs(np.arange(shape[1]) / (shape[1] - 1) - y) ** 2
        dz2 = np.abs(np.arange(shape[2]) / (shape[2] - 1) - z) ** 2
        distances = (dx2[:, None, None] + dy2[None, :, None]) + dz2[None, None, :]

        distances = distances / (variance ** 2)
        distances = np.exp(- 0.5 * distances)
        distances = np.power(1 / (2 * np.pi), 3 / 2) * distances
        tensor[specie] += distances

        return tensor[specie]

    def get_voxel(self):
        """Return the voxel tensor computed at construction time."""
        return self.voxel
CIFtoVoxel('AHOKOX_clean.cif').get_voxel() 166 | # Plot3D(cif_voxel) 167 | vox_cords = Voxel_coords(cif_voxel) 168 | 169 | save_file = open("AHOKOX_reconstructed.csv", 'w') 170 | save_file.write(vox_cords) 171 | save_file.close() 172 | 173 | print(vox_cords) 174 | -------------------------------------------------------------------------------- /exploratory/parse_cif.py: -------------------------------------------------------------------------------- 1 | import pymatgen 2 | import sys 3 | from pymatgen.io.cif import CifParser 4 | import os 5 | import glob 6 | import numpy as np 7 | import networkx as nx 8 | import matplotlib.pyplot as plt 9 | from mpl_toolkits.mplot3d import Axes3D 10 | from multiprocessing import Pool 11 | 12 | ''' 13 | Given a structure, find the number of unique elements present 14 | and return a list of elements 15 | ''' 16 | def num_species(structure): 17 | 18 | sites = structure.as_dict()['sites'] 19 | species = set() 20 | for site in sites: 21 | elements = site['species'] 22 | for each in elements: 23 | species.add(each['element']) 24 | return species 25 | 26 | def num_elements(structure): 27 | sites = structure.as_dict()['sites'] 28 | return len(sites) 29 | ''' 30 | Given a valid cif filename, returns a pymatgen structure object 31 | 32 | ''' 33 | def cif_structure(file_name): 34 | parser = CifParser(file_name) 35 | structure = parser.get_structures()[0] 36 | return structure 37 | 38 | def gen_3d_Plot(structure): 39 | fig = plt.figure() 40 | ax = fig.gca(projection='3d') 41 | dims = (25,10,5) 42 | numpy_pos = np.zeros(dims, dtype=bool) 43 | numpy_color = np.zeros(dims, dtype=str) 44 | 45 | dic_color = {"Cu":"yellow","O":"blue","C":"green","P":"red"} 46 | for each in structure.sites: 47 | # print(type(each.specie),type(each.coords)) 48 | new_coords = [0]*3 49 | for i in range(len(each.coords)): 50 | new_coords[i] = int(-1* each.coords[i]) 51 | x = new_coords[0] 52 | y = new_coords[1] 53 | z = new_coords[2] 54 | numpy_pos[x][y][z] = True 55 | 
def min_max_coords(structure):
    """Return the axis-aligned bounding box of all sites in *structure*.

    Args:
        structure: Object exposing ``sites``; each site needs ``coords`` with
            at least three components (x, y, z).

    Returns:
        Tuple ``(x_min, x_max, y_min, y_max, z_min, z_max)``.

    Raises:
        IndexError: If the structure has no sites.
    """
    coords = [site.coords for site in structure.sites]
    if not coords:
        raise IndexError("structure has no sites")

    # Builtin min/max replace the bool-multiplication helpers, which turned
    # an infinite coordinate into NaN (0 * inf).
    xs = [c[0] for c in coords]
    ys = [c[1] for c in coords]
    zs = [c[2] for c in coords]
    return min(xs), max(xs), min(ys), max(ys), min(zs), max(zs)
def main():
    """Write to ``files.log`` every CIF file accepted by ``func``.

    Changes into ``../data/structure_11660/``, splits the ``*.cif`` files
    into 20 contiguous chunks, runs ``func`` on each chunk in a process pool
    and writes the returned file names, one per line.
    """
    os.chdir("../data/structure_11660/")
    files = glob.glob("*.cif")

    print(len(files))
    print(type(files))

    Num_Processes = 20
    num_files = len(files)

    # Integer arithmetic keeps the chunk bounds exact; the float version
    # could round the last bound down and silently drop a file.
    file_chunks = [
        files[num_files * i // Num_Processes: num_files * (i + 1) // Num_Processes]
        for i in range(Num_Processes)
    ]

    for each in file_chunks:
        print(each)

    # The context manager releases the worker processes even if a task
    # raises; all results are collected before the pool is torn down.
    with Pool(processes=Num_Processes) as pool:
        results = [pool.apply_async(func, args=(chunk,)) for chunk in file_chunks]
        output = [p.get() for p in results]

    with open("files.log", "w") as log:
        for returned_list in output:
            for each in returned_list:
                log.write(each)
                log.write("\n")
def compute_gradient_penalty(critic: "Discriminator",
                             real_samples: torch.Tensor, fake_samples: torch.Tensor) -> torch.Tensor:
    """Calculate the gradient penalty loss for WGAN GP.

    A random point is drawn on the segment between each real/fake pair, the
    critic is evaluated there, and the penalty is the batch mean of
    ``(||grad||_2 - 1) ** 2`` of the critic output with respect to that point.

    Args:
        critic: Callable mapping a batch of samples to one score per sample.
        real_samples: Batch of real samples, 5-D ``(N, C, D, H, W)``.
        fake_samples: Batch of generated samples with the same shape.

    Returns:
        Scalar tensor holding the penalty.
    """
    # Draw the mixing weights on the samples' own device and dtype rather
    # than depending on a module-level ``device``.
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, 1,
                       device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = critic(interpolates)
    # Match whatever shape the critic returns instead of assuming (N, 1).
    grad_outputs = torch.ones_like(d_interpolates)
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # reshape, not view: autograd may hand back a non-contiguous gradient.
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
critic_updates_per_generator_update = 1 103 | 104 | # MOFS 105 | num_atoms = 12 106 | grid_size = 32 107 | 108 | train_loader = MOFDataset.get_data_loader("../3D_Grid_Data/Test_MOFS.p", batch_size) 109 | 110 | # Initialize generator and discriminator 111 | generator: Generator = Generator(num_atoms, grid_size) 112 | discriminator: Discriminator = Discriminator(num_atoms, grid_size) 113 | 114 | if cuda: 115 | generator.cuda() 116 | discriminator.cuda() 117 | 118 | generator_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, beta2)) 119 | discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, beta2)) 120 | 121 | batches_done = 0 122 | for epoch in range(epochs): 123 | for batch, mof in enumerate(train_loader): 124 | 125 | real_images = mof.to(device) 126 | discriminator_optimizer.zero_grad() 127 | numpy_array = np.random.normal(0, 1, (real_images.shape[0], 1024)) 128 | 129 | z = torch.from_numpy(numpy_array).float().requires_grad_().to(device) 130 | 131 | fake_images = generator(z) 132 | real_validity = discriminator(real_images) 133 | fake_validity = discriminator(fake_images) 134 | gradient_penalty = compute_gradient_penalty(discriminator, real_images.data, fake_images.data) 135 | d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + alpha_gp * gradient_penalty 136 | 137 | d_loss.backward() 138 | discriminator_optimizer.step() 139 | generator_optimizer.zero_grad() 140 | 141 | if batch % critic_updates_per_generator_update == 0: 142 | fake_images = generator(z) 143 | fake_validity = discriminator(fake_images) 144 | g_loss = -torch.mean(fake_validity) 145 | 146 | g_loss.backward() 147 | generator_optimizer.step() 148 | 149 | if batch % 16 == 0: 150 | print( 151 | "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" 152 | % (epoch, epochs, batch, len(train_loader), d_loss.item(), g_loss.item()) 153 | ) 154 | 155 | # if batch == 0: 156 | # print(f"[Epoch {epoch}/{epochs}] 
[D loss: {d_loss.item()}] [G loss: {g_loss.item()}]") 157 | 158 | if batches_done % 20 == 0: 159 | pass 160 | # print("Generated Structure: ") 161 | # torch.set_printoptions(profile="full") 162 | # print(fake_images[0].shape) 163 | 164 | batches_done += 1 165 | 166 | 167 | if __name__ == '__main__': 168 | # G = Generator(12, 64) 169 | # z = torch.rand(16, 1024, 1, 1, 1, requires_grad=True) 170 | # X = G(z) 171 | # print(X.shape) 172 | main() 173 | # with open('gan-processed.p', 'rb') as f: 174 | # data = pickle.load(f) 175 | # print(data[0].shape) 176 | -------------------------------------------------------------------------------- /Generative_Model/mof_wgan_gp_multi_channel.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import time 5 | from enum import Enum 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | import torch.autograd as autograd 11 | import torch.nn as nn 12 | from torch.utils.data import DataLoader, TensorDataset 13 | 14 | from gan_logger import GANLogger 15 | from mof_dataset_v2 import MOFDatasetV2 16 | from sphere_dataset import SphereDataset 17 | 18 | folder = Path("mof_wgan_gp_multi_channel") 19 | images_folder = folder / "images" 20 | os.makedirs(images_folder, exist_ok=True) 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--batch_size", type=int, default=16, help="size of the batches") 24 | parser.add_argument("--latent_dim", type=int, default=1024, help="dimensionality of the latent space") 25 | parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter") 26 | parser.add_argument("--sample_interval", type=int, default=100, help="interval between image samples") 27 | 28 | adam_b1 = 0.5 29 | adam_b2 = 0.9 # or 0.999 30 | clip_value = 0.01 31 | 32 | opt = parser.parse_args() 33 | GANLogger.log(opt) 34 | 35 | grid_size = 32 36 | channels = 1 37 | 
img_shape = (channels, grid_size, grid_size, grid_size)

cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')


class Generator(nn.Module):
    """Map a latent vector to a ``channels``-channel 3-D volume with values in [-1, 1].

    A linear layer expands the latent vector into a ``grid_size * 8``-channel
    2x2x2 seed volume. Four stride-2 transposed convolutions then double the
    spatial extent each time (2 -> 4 -> 8 -> 16 -> 32). The output is always
    32 voxels per side, so it only matches ``img_shape`` while
    ``grid_size == 32``; ``grid_size`` also serves as the base channel width.
    """

    def __init__(self):
        super().__init__()

        kernel_size = 4
        stride = 2
        padding = 1

        # The seed volume always has grid_size * 8 feature maps because that
        # is what the first transposed convolution consumes. The number of
        # output channels must not scale it, otherwise channels != 1 fails
        # with a channel mismatch.
        self.seed_channels = grid_size * 8

        self.latent_to_features = nn.Sequential(
            nn.Linear(opt.latent_dim, self.seed_channels * 2 * 2 * 2),
            nn.ReLU()
        )

        self.features_to_image = nn.Sequential(
            nn.ConvTranspose3d(grid_size * 8, grid_size * 4, kernel_size, stride, padding),
            nn.BatchNorm3d(grid_size * 4),
            nn.ReLU(),

            nn.ConvTranspose3d(grid_size * 4, grid_size * 2, kernel_size, stride, padding),
            nn.BatchNorm3d(grid_size * 2),
            nn.ReLU(),

            nn.ConvTranspose3d(grid_size * 2, grid_size, kernel_size, stride, padding),
            nn.BatchNorm3d(grid_size),
            nn.ReLU(),

            nn.ConvTranspose3d(grid_size, channels, kernel_size, stride, padding),
            nn.Tanh(),
        )

    def forward(self, z):
        """Return generated volumes for a batch of latent vectors ``z``."""
        features = self.latent_to_features(z)
        features = features.view(features.shape[0], self.seed_channels, 2, 2, 2)
        return self.features_to_image(features)


class Discriminator(nn.Module):
    """3-D convolutional critic that returns one unbounded score per sample.

    Four stride-2 circular-padded convolutions, each followed by LayerNorm
    and LeakyReLU, feed a single linear layer. There is no final sigmoid:
    the training loss uses the raw score.
    """

    def __init__(self):
        super().__init__()

        # Despite the label, this is the number of values in one sample.
        in_channels = int(np.prod(img_shape))
        kernel = 5
        stride = 2
        padding = 3
        print("In Channels:", in_channels)

        # LayerNorm needs the exact shape of every feature map. That shape
        # depends on grid_size and on how the installed PyTorch applies
        # circular padding (this changed between releases), so it is measured
        # with a dummy forward pass instead of hard-coding 16/8/4/2.
        widths = (grid_size, grid_size * 2, grid_size * 4, grid_size * 8)
        layers = []
        probe = torch.zeros(1, *img_shape)
        previous_width = channels
        with torch.no_grad():
            for width in widths:
                conv = nn.Conv3d(previous_width, width, kernel, stride, padding,
                                 padding_mode='circular')
                probe = conv(probe)
                # DON'T USE BATCH NORM WITH GP
                layers += [conv, nn.LayerNorm(list(probe.shape[1:])), nn.LeakyReLU(0.2)]
                # TODO: decide whether nn.Dropout3d(0.5) belongs after the first stage.
                previous_width = width

        self.model = nn.Sequential(*layers)

        # Flatten, then a single linear layer producing the critic score.
        self.final = nn.Sequential(
            nn.Linear(probe[0].numel(), 1),
        )

    def forward(self, x):
        """Score a batch shaped ``[batch_size, *img_shape]``; returns ``[batch_size, 1]``."""
        x = self.model(x)
        return self.final(x.view(x.shape[0], -1))


def init_weights(m):
    """Initialise one module in place; intended for ``nn.Module.apply``.

    * ``nn.Linear``: Xavier-uniform weights, bias filled with 0.01.
    * any class whose name contains ``Conv``: weights ~ N(0, 0.02).
    * any class whose name contains ``BatchNorm``: weights ~ N(1, 0.02), zero bias.

    Modules created without a bias (or without affine parameters) are
    handled instead of raising ``AttributeError`` on ``None``.
    """
    classname = type(m).__name__
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)
    elif 'Conv' in classname:
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


# Loss weight for gradient penalty
lambda_gp = 10

# Reference point for the data-loading timer printed later in the script.
start = time.time()
class DatasetMode(Enum):
    """Selects which dataset the script trains on."""

    TRAIN = 1
    TEST = 2
    SPHERE = 3


# Switch to DatasetMode.TEST or DatasetMode.SPHERE to train on another source.
dataset_mode = DatasetMode.TRAIN

# Pickled MOF grids for the two MOF-backed modes.
_MOF_PICKLES = {
    DatasetMode.TRAIN: "_data/Training_MOFS_v2.p",
    DatasetMode.TEST: "_data/Test_MOFS_v2.p",
}

if dataset_mode in _MOF_PICKLES:
    data_loader = MOFDatasetV2.get_data_loader(_MOF_PICKLES[dataset_mode], batch_size=opt.batch_size)
elif dataset_mode is DatasetMode.SPHERE:
    sphere_tensors = SphereDataset.generate_complex(200, 3)
    data_loader = DataLoader(
        TensorDataset(sphere_tensors),
        batch_size=opt.batch_size,
        shuffle=True,
    )

print("LOAD TIME:", (time.time() - start))

# Build both networks first, then initialise them, generator before critic.
generator = Generator()
discriminator = Discriminator()
for network in (generator, discriminator):
    network.apply(init_weights)

if cuda:
    for network in (generator, discriminator):
        network.cuda()

GANLogger.log(generator, discriminator)
# Optimizers (0.002 for both networks was also tried).
glr = 0.0001
dlr = 0.0001
optimizer_G = torch.optim.Adam(generator.parameters(), lr=glr, betas=(adam_b1, adam_b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=dlr, betas=(adam_b1, adam_b2))


def compute_gradient_penalty(disc, real_samples, fake_samples):
    """Calculate the gradient penalty loss for WGAN-GP.

    Each real sample is blended with its fake counterpart using a random
    per-sample weight, the critic is evaluated on the blends, and the
    squared deviation of the input-gradient L2 norm from 1 is averaged over
    the batch.

    Args:
        disc: Callable critic mapping a batch of samples to scores.
        real_samples: Batch of real samples; the first axis is the batch.
        fake_samples: Batch of generated samples with the same shape.

    Returns:
        A scalar tensor that is differentiable with respect to the critic
        parameters (``create_graph=True``).
    """
    batch = real_samples.size(0)
    # One weight per sample, broadcast over every remaining axis. It is
    # created on the samples' own device and dtype, so the function works
    # for any input rank and does not depend on a module-level device.
    alpha_shape = (batch,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, dtype=real_samples.dtype, device=real_samples.device)

    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = disc(interpolates)

    # Gradient of the critic output w.r.t. the interpolated inputs.
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(batch, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


# ----------
#  Training
# ----------


def main():
    """Train the critic on every batch and the generator every ``opt.n_critic`` batches.

    Per batch the critic loss is ``mean(fake) - mean(real) + lambda_gp * GP``.
    On generator steps the losses are printed and logged, and every
    ``opt.sample_interval`` counted batches the latest generated batch is
    pickled into ``images_folder``.

    NOTE(review): this function was reconstructed from a source whose
    indentation was lost. The logging, sampling and ``batches_done`` update
    are assumed to sit inside the generator-step branch -- confirm against
    the original file.
    """
    previous_real = None

    batches_done = 0
    epochs = 1000
    for epoch in range(epochs):
        for i, images in enumerate(data_loader):
            if dataset_mode == DatasetMode.SPHERE:
                images = images[0]  # Sphere dataset ONLY: unwrap the batch
            real_images = images.to(device)
            batch_size = real_images.shape[0]

            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            # Sample noise as generator input
            noise = torch.from_numpy(np.random.normal(0, 1, (batch_size, opt.latent_dim))) \
                .float().to(device)

            # The critic update only needs the generated values, so cut the
            # graph here instead of backpropagating through the generator.
            fake_images = generator(noise).detach()

            real_pred = discriminator(real_images)  # Real images
            fake_pred = discriminator(fake_images)  # Fake images

            # Pure-noise input, scored for diagnostics only; no graph needed.
            with torch.no_grad():
                garbage = torch.from_numpy(np.random.normal(0, 1, (batch_size, np.prod(img_shape)))) \
                    .float().to(device).view(batch_size, *img_shape)
                garbage_pred = discriminator(garbage)

            if previous_real is not None:
                wasserstein_distance_real_only = abs(real_pred.mean() - previous_real.mean()).item()
                print(f"REAL ONLY EMD: {wasserstein_distance_real_only}")
            # Keep only the values; holding the attached tensor would keep
            # the previous batch's autograd graph alive.
            previous_real = real_pred.detach()

            # TODO: Regularize the outputs -1 1 range?

            # Gradient penalty
            gradient_penalty = compute_gradient_penalty(discriminator, real_images, fake_images)
            # Adversarial loss
            # We want the critic to maximize the separation between fake and real
            wasserstein_distance = abs(fake_pred.mean() - real_pred.mean()).item()
            d_loss = fake_pred.mean() - real_pred.mean() + lambda_gp * gradient_penalty

            d_loss.backward()
            optimizer_D.step()

            optimizer_G.zero_grad()

            # Train the generator every n_critic steps
            if i % opt.n_critic == 0:
                # -----------------
                #  Train Generator
                # -----------------

                # Generate a batch of images
                fake_images = generator(noise)
                # Loss measures generator's ability to fool the discriminator
                fake_pred = discriminator(fake_images)
                g_loss = -torch.mean(fake_pred)

                g_loss.backward()
                optimizer_G.step()

                print(f"[Epoch {epoch}/{epochs}]".ljust(16)
                      + f"[Batch {i}/{len(data_loader)}] ".ljust(14)
                      + f"[-C Loss: {'{:.4f}'.format(-d_loss.item()).rjust(11)}] "
                      + f"[G Loss: {'{:.4f}'.format(g_loss.item()).rjust(11)}] "
                      + f"[Wasserstein Distance: {round(wasserstein_distance, 3)}]")

                print("GARBAGE/REAL EMD:", abs(garbage_pred.mean() - real_pred.mean()).item())
                # NOTE: Garbage EMD should theoretically be very high relative to generated/real
                # but we're not training to maximize that, only between generated so I guess it makes sense?
                GANLogger.update(-d_loss.item(), g_loss.item())

                if batches_done % opt.sample_interval == 0:
                    with open(f"{images_folder}/{str(batches_done).zfill(5)}.p", "wb+") as f:
                        # Detach so the pickle holds a plain tensor rather
                        # than one tied to the training graph.
                        pickle.dump(fake_images.detach().cpu(), f)

                batches_done += opt.n_critic


def test():
    """Print the maximum value of the first dataset item."""
    print("TEST")
    tensor: torch.Tensor = data_loader.dataset[0]
    print(tensor.max())


if __name__ == '__main__':
    title = f"MOF WGAN GP - GLR: {glr}, DLR: {dlr}, S={img_shape}, BS={opt.batch_size}"
    GANLogger.init(title, folder)
    main()
    # test()

"""
For convolutions:
    Log energy values
    Scale 0-1
    Log then scale?

Make sure no zeros
"""