├── CrystalVAE.py ├── Discriminator.py ├── Generate.py ├── Generate_Single_Unit_Data.py ├── LICENSE ├── README.md ├── Segmentation.py ├── dataset.py ├── ims │   ├── Interpolation.png │   ├── Model.png │   └── Unit-Cell.png └── main.py /CrystalVAE.py: -------------------------------------------------------------------------------- 1 | # File for repeating lattices. 2 | # Code written by Jordan 3 | from __future__ import print_function 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torchvision import datasets, transforms 10 | import pickle 11 | import numpy as np 12 | from torch.utils.data.dataset import Dataset 13 | from pymatgen.core.structure import Structure 14 | 15 | class Flatten(nn.Module): 16 | ''' 17 | Helper module to flatten a tensor to (batch, -1). 18 | ''' 19 | def forward(self, input): 20 | return input.view(input.size(0), -1) 21 | 22 | 23 | class UnFlatten(nn.Module): 24 | # Convert the flat feature vector back into a 3-D volume for the decoder. 25 | def forward(self, input): 26 | return input.view(input.size(0), 256, 4, 4, 4) # 256*4*4*4 matches h_dim3 = 256*(4*4*4) produced by fc4 27 | 28 | class Interpolate(nn.Module): 29 | ''' 30 | Interpolate for upsampling. Use convolution and upsampling 31 | in favor of conv transpose. 32 | ''' 33 | def __init__(self, scale_factor, mode): 34 | super(Interpolate, self).__init__() 35 | self.interp = nn.functional.interpolate 36 | self.scale_factor = scale_factor 37 | self.mode = mode 38 | 39 | def forward(self, x): 40 | x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode) 41 | return x 42 | 43 | 44 | class CVAE(nn.Module): 45 | ''' 46 | Crystal VAE. 47 | ''' 48 | def __init__(self, input_channels=1, h_dim=256*(4*4*4),h_dim3=256*(4*4*4),h_dim2=3200, z_dim=300): 49 | super(CVAE, self).__init__() 50 | self.encoder = nn.Sequential( 51 | nn.Conv3d(input_channels, 16, kernel_size=5, stride=2), 52 | nn.BatchNorm3d(16), 53 | nn.LeakyReLU(), 54 | nn.Conv3d(16, 32, kernel_size=3, stride=1), 55 | nn.BatchNorm3d(32), 56 | nn.LeakyReLU(), 57 | nn.Conv3d(32, 64, kernel_size=3, stride=1), 58 | nn.BatchNorm3d(64), 59 | nn.LeakyReLU(), 60 | nn.Conv3d(64, 128, kernel_size=3, stride=2), 61 | nn.BatchNorm3d(128), 62 | nn.LeakyReLU(), 63 | nn.Conv3d(128, 256, kernel_size=3, stride=2), 64 | nn.BatchNorm3d(256), 65 | nn.LeakyReLU(), 66 | Flatten() 67 | ) 68 | self.fc0 = nn.Linear(h_dim, 3000) 69 | 70 | self.fc1 = nn.Linear(3000, z_dim) 71 | self.fc2 = nn.Linear(3000, z_dim) 72 | self.fc3 = nn.Linear(z_dim, h_dim2) 73 | self.fc4 = nn.Linear(h_dim2, h_dim3) 74 | 75 | self.decoder = nn.Sequential( 76 | UnFlatten(), 77 | Interpolate(scale_factor=2,mode='trilinear'), 78 | nn.Conv3d(256, 128, kernel_size=4, stride=1), 79 | nn.BatchNorm3d(128, 1e-3), 80 | nn.LeakyReLU(), 81 | Interpolate(scale_factor=2,mode='trilinear'), 82 | nn.ReplicationPad3d(2), 83 | nn.Conv3d(128, 64, kernel_size=5, stride=1), 84 | nn.BatchNorm3d(64, 1e-3), 85 | nn.LeakyReLU(), 86 | Interpolate(scale_factor=2,mode='trilinear'), 87 | nn.ReplicationPad3d(1), 88 | nn.Conv3d(64, 32, kernel_size=5, stride=1), 89 | nn.BatchNorm3d(32, 1e-3), 90 | nn.LeakyReLU(), 91 | Interpolate(scale_factor=2,mode='trilinear'), 92 | #nn.ReplicationPad3d(2), 93 | nn.Conv3d(32, 16, kernel_size=4, stride=1), 94 | nn.BatchNorm3d(16, 1e-3), 95 | nn.LeakyReLU(), 96 | nn.Conv3d(16, 1, kernel_size=4, stride=1), 97 | nn.BatchNorm3d(1, 1e-3), 98 | nn.ReLU() 99 | ) 100 | 101 | 102 |
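# Note on the two methods below: reparameterization() draws the latent sample as
#   z = mu + exp(0.5 * logvar) * eps,   eps ~ N(0, I),
# which keeps the random draw differentiable with respect to mu and logvar;
# bottleneck() produces mu and logvar with fc1/fc2 and returns the sampled z.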
103 | def reparameterization(self, mu, logvar): 104 | # Reparameterization trick so we can backpropagate through the random draw of z from mu and sigma 105 | device = mu.device # draw the noise on the same device as mu and logvar 106 | std = logvar.mul(0.5).exp_() 107 | esp = torch.randn(*mu.size(), device=device) 108 | z = mu + std * esp 109 | return z 110 | 111 | 112 | def bottleneck(self, h): 113 | # Go through the small latent space. 114 | mu, logvar = self.fc1(h), self.fc2(h) 115 | z = self.reparameterization(mu, logvar) 116 | return z, mu, logvar 117 | 118 | def encode(self, x): 119 | # Go through the encoder. 120 | # *Input* is the original grid 121 | # *Output* is a vector to pass through the bottleneck. 122 | h = self.encoder(x) 123 | h = self.fc0(h) 124 | z, mu, logvar = self.bottleneck(h) 125 | return z, mu, logvar 126 | 127 | def decode(self, z, label=None): # label is unused; made optional so forward() can call decode(z) 128 | z = self.fc3(z) 129 | z = self.fc4(z) 130 | z = self.decoder(z) 131 | return z 132 | 133 | def forward(self, x): 134 | z, mu, logvar = self.encode(x) 135 | z = self.decode(z) 136 | return z, mu, logvar -------------------------------------------------------------------------------- /Discriminator.py: -------------------------------------------------------------------------------- 1 | # Network used for the discriminator shown in the SI. 2 | # Code written by Jordan 3 | from __future__ import print_function 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torchvision import datasets, transforms 10 | import pickle 11 | import numpy as np 12 | from CrystalVAE import Flatten # Flatten is defined in CrystalVAE.py 13 | 14 | class DNet(nn.Module): 15 | def __init__(self): 16 | super(DNet, self).__init__() 17 | self.Network = nn.Sequential( 18 | nn.Conv3d(1, 16, kernel_size=7, stride=1), 19 | nn.BatchNorm3d(16), 20 | nn.LeakyReLU(), 21 | nn.Conv3d(16, 32, kernel_size=5, stride=1), 22 | nn.BatchNorm3d(32), 23 | nn.LeakyReLU(), 24 | nn.Conv3d(32, 64, kernel_size=5, stride=1), 25 | nn.BatchNorm3d(64), 26 | nn.LeakyReLU(), 27 | nn.Conv3d(64, 128, kernel_size=3, stride=1), 28 | nn.BatchNorm3d(128), 29 | nn.LeakyReLU(), 30 | nn.Conv3d(128, 256, kernel_size=3, stride=2), 31 | nn.BatchNorm3d(256), 32 | nn.LeakyReLU(), 33 | nn.Conv3d(256, 512, kernel_size=3, stride=1), 34 | nn.BatchNorm3d(512), 35 | nn.LeakyReLU(), 36 | Flatten(), 37 | nn.Linear(32768, 1000), # 512 channels on a 4x4x4 grid when the input is 1x30x30x30 38 | nn.LeakyReLU(), 39 | nn.Linear(1000, 100), 40 | nn.LeakyReLU(), 41 | nn.Linear(100, 2), 42 | nn.Softmax(dim=1) 43 | ) 44 | 45 | def forward(self, x): 46 | x = self.Network(x) 47 | return x -------------------------------------------------------------------------------- /Generate.py: -------------------------------------------------------------------------------- 1 | # File to generate the periodic lattice input data 2 | # Code written by Jordan 3 | from pymatgen.core.structure import Structure 4 | import numpy as np 5 | from pymatgen import Molecule 6 | from pymatgen.analysis.local_env import MinimumVIRENN, VoronoiNN 7 | from pymatgen.core.sites import * 8 | import glob 9 | from mpi4py import MPI 10 | import gzip 11 | import pickle 12 | 13 | 14 | def get_structure(file): 15 | ''' 16 | Input: file name to get the structure of 17 | Return: The crystal structure 18 | ''' 19 | crystal = Structure.from_file(file) 20 | return crystal 21 | 22 | def read_in_properties(dataset): 23 | ''' 24 | Input: Properties file 25 | Return: column of interest.
26 | ''' 27 | data = np.genfromtxt(dataset,delimiter=',') 28 | return data[:,1] 29 | 30 | def generate(crystal,a,b,c): 31 | ''' 32 | Input: Crystal and offsets a,b,c for sampling 33 | Output: Density matrix and species matrix 34 | ''' 35 | Atomic_Numbers = [Element(crystal[i].species_string).Z for i in range(len(crystal))] 36 | lattice = crystal.lattice 37 | mat = np.zeros((30,30,30)) 38 | mat2 = np.zeros((30,30,30)) 39 | sigma = 1.0 40 | bins = np.linspace(0,10,30) 41 | for i in range(30): 42 | for j in range(30): 43 | for k in range(30): 44 | s = PeriodicSite(1,[bins[i]+a,bins[j]+b,bins[k]+c],lattice,coords_are_cartesian=True) 45 | for x in range(len(crystal)): 46 | (dist,im) = crystal[x].distance_and_image(s) 47 | mat[i,j,k] += 1.0/((2.0*np.pi)**1.5)*Atomic_Numbers[x]*(1.0/sigma**3)*np.exp(-dist**2/(2*sigma**2)) 48 | if dist < 0.65: 49 | if mat2[i,j,k] > 0.0: 50 | if np.random.rand() > 0.5: 51 | mat2[i,j,k] = int(Atomic_Numbers[x]) 52 | else: 53 | mat2[i,j,k] = int(Atomic_Numbers[x]) 54 | return mat,mat2 55 | 56 | if __name__=='__main__': 57 | comm = MPI.COMM_WORLD 58 | size = comm.Get_size() 59 | rank = comm.Get_rank() 60 | directory = '/DIRECTORY/' 61 | upper = 46381 62 | delta = int(46381/size) 63 | saveQ = True 64 | 65 | target = read_in_properties(directory+'./id_prop.csv') 66 | print(target[0:7]) 67 | 68 | files = glob.glob(directory+'*') 69 | np.random.seed(1) 70 | file_list = np.random.permutation(np.arange(0,upper+1)) 71 | start = delta*rank 72 | stop = start + delta 73 | if rank == (size-1): 74 | stop = 46381 75 | 76 | to_save = [] 77 | for i in range(start,stop): 78 | if rank == 0: 79 | print("I am on "+str(i)+" and I am going to "+str(stop)) 80 | crystal = get_structure('/DIRECTORY/'+str(file_list[i])+'.cif') 81 | abc = np.array(crystal.lattice.abc) 82 | if min(abc) < 10: 83 | electron_density,electron_density2 = generate(crystal,0,0,0) 84 | to_save.append([file_list[i],electron_density,electron_density2,target[file_list[i]]]) 85 | electron_density,electron_density2 = generate(crystal,10*np.random.rand(),10*np.random.rand(),10*np.random.rand()) 86 | to_save.append([file_list[i],electron_density,electron_density2,target[file_list[i]]]) 87 | electron_density,electron_density2 = generate(crystal,10*np.random.rand(),10*np.random.rand(),10*np.random.rand()) 88 | to_save.append([file_list[i],electron_density,electron_density2,target[file_list[i]]]) 89 | 90 | 91 | if saveQ == True: 92 | with open('/SAVE/DIRECTORY/S_3x_'+str(rank)+'.pickle', 'wb') as f: 93 | print('SAVING') 94 | pickle.dump(to_save, f) 95 | print('SAVED OK') 96 | -------------------------------------------------------------------------------- /Generate_Single_Unit_Data.py: -------------------------------------------------------------------------------- 1 | # File to generate the single unit cell input data 2 | # Code written by Jordan 3 | from pymatgen.core.structure import Structure 4 | import numpy as np 5 | from pymatgen import Molecule 6 | from pymatgen.analysis.local_env import MinimumVIRENN, VoronoiNN 7 | from pymatgen.core.sites import * 8 | import glob 9 | from mpi4py import MPI 10 | import gzip 11 | import pickle 12 | 13 | 14 | def get_structure(file): 15 | ''' 16 | Input: file name to get the structure of 17 | Return: The crystal structure 18 | ''' 19 | crystal = Structure.from_file(file) 20 | return crystal 21 | 22 | def read_in_properties(dataset): 23 | ''' 24 | Input: Properties file 25 | Return: column of interest.
26 | ''' 27 | data = np.genfromtxt(dataset,delimiter=',') 28 | return data[:,1] 29 | 30 | def generate(crystal,a,b,c): 31 | ''' 32 | Input: Crystal and offsets a,b,c for sampling [not used] 33 | Output: Density matrix, species matrix, atomic numbers, and shifted Cartesian coordinates 34 | ''' 35 | abc = np.array(crystal.lattice.abc) 36 | Zs = np.zeros(len(crystal)) 37 | XYZs = np.zeros((len(crystal),3)) 38 | for i in range(len(crystal)): 39 | Element_ID = Element(crystal[i].species_string).Z 40 | x = crystal[i].x 41 | y = crystal[i].y 42 | z = crystal[i].z 43 | XYZ = np.array([x,y,z]) 44 | Zs[i] = Element_ID 45 | XYZs[i] = XYZ 46 | mean = np.mean(XYZs,axis=0) 47 | shift = np.array([5,5,5]) - mean 48 | XYZs_shifted = np.copy(XYZs) 49 | for i in range(len(XYZs_shifted)): 50 | XYZs_shifted[i] += shift 51 | max_v = np.amax(XYZs_shifted) 52 | min_v = np.amin(XYZs_shifted) 53 | if max_v < 10.0 and min_v > 0.0: 54 | mat = np.zeros((30,30,30)) 55 | mat2 = np.zeros((30,30,30)) 56 | bins = np.linspace(0,10,30) 57 | sigma = 1.0 58 | for i in range(30): 59 | for j in range(30): 60 | for k in range(30): 61 | coordinate = np.array([bins[i],bins[j],bins[k]]) 62 | for x in range(len(XYZs_shifted)): 63 | dist = np.linalg.norm(coordinate - XYZs_shifted[x]) 64 | mat[i,j,k] += 1.0/((2.0*np.pi)**1.5)*Zs[x]*(1.0/sigma**3)*np.exp(-dist**2/(2*sigma**2)) 65 | if dist < 0.667: 66 | if mat2[i,j,k] > 0.0: 67 | print("Overwrite?") 68 | if np.random.rand() > 0.5: 69 | mat2[i,j,k] = int(Zs[x]) 70 | else: 71 | mat2[i,j,k] = int(Zs[x]) 72 | return mat,mat2,Zs,XYZs_shifted 73 | else: 74 | return 0,0,0,0 75 | 76 | 77 | 78 | if __name__=='__main__': 79 | comm = MPI.COMM_WORLD 80 | size = comm.Get_size() 81 | rank = comm.Get_rank() 82 | print('I am rank ',rank) 83 | directory = '/scratch3/jordan/Mila/Crystal-Project/dataset/code/data/' 84 | upper = 46381 85 | delta = int(upper/size) 86 | saveQ = True 87 | target = read_in_properties(directory+'./id_prop.csv') 88 | print(target[0:7]) 89 | files = glob.glob(directory+'*') 90 | np.random.seed(1) 91 | file_list = np.random.permutation(np.arange(0,upper+1)) 92 | start = delta*rank 93 | stop = start + delta 94 | if rank == (size-1): 95 | # Ensure the last one is done 96 | stop = 46381 97 | 98 | to_save = [] 99 | for i in range(start,stop): 100 | if rank == 0: 101 | print("I am on "+str(i)+" and I am going to "+str(stop)) 102 | crystal = get_structure('/scratch3/jordan/Mila/Crystal-Project/dataset/code/data/'+str(file_list[i])+'.cif') 103 | 104 | electron_density,species_mat,Zs,XYZs_shifted = generate(crystal,0,0,0) 105 | if len(np.shape(electron_density)) != 0: 106 | # Ignore things that were [0,0,0,0] 107 | to_save.append([file_list[i],electron_density,species_mat,target[file_list[i]],Zs,XYZs_shifted]) 108 | 109 | 110 | if saveQ == True: 111 | with open('/DIRECTORY/Unit_'+str(rank)+'.pickle', 'wb') as f: 112 | print('SAVING ',rank) 113 | pickle.dump(to_save, f) 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Jordan Hoffmann 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data-Driven Approach to Encoding and Decoding 3-D Crystal Structures 2 | This is the code accompanying [Data-Driven Approach to Encoding and Decoding 3-D Crystal Structures](https://arxiv.org/abs/1909.00949). 3 | 4 | [Jordan Hoffmann](https://jhoffmann.org/), Louis Maestrati, Yoshihide Sawada, [Jian Tang](https://jian-tang.com), 5 | Jean Michel Sellier, and [Yoshua Bengio](https://mila.quebec/en/yoshua-bengio/) 6 | 7 | Click on the image below to see a video highlighting some of the results. 8 | [![Video](https://img.youtube.com/vi/ZpFN5tSo5Pg/0.jpg)](https://www.youtube.com/watch?v=ZpFN5tSo5Pg) 9 | 10 | # Code 11 | There are two types of code in this repository, each detailed below. 12 | The code is written using `pytorch`. For data generation, we use `mpi4py`. For training, we used a `Tesla V100` (16 GB) GPU. 13 | Requirements: 14 | ```bash 15 | python 3.7 16 | pytorch 1.1.0 17 | pymatgen 18 | mpi4py # Can easily be omitted. 19 | ``` 20 | 21 | ## Data Availability 22 | For now, please email me at `echo ude.dravrah.g@nnamffohj|rev`. 23 | 24 | 25 | ## Data Generation 26 | Use `Generate.py` and `Generate_Single_Unit_Data.py` to generate the repeating-lattice and single-unit-cell representations, respectively. 27 | We use the library mpi4py to compute these representations in parallel. The input should be a list of files 28 | in `.cif` format. 29 | ```bash 30 | > mpiexec -n 32 python Generate.py 31 | ``` 32 | 33 | 34 | ## Crystal-VAE 35 | ```bash 36 | > python main.py --lr 0.0000001 --epochs 100 37 | ``` 38 | 39 | # Citation 40 | If you use the code or the paper, please cite: 41 | ``` 42 | @article{hoffmann2019data, 43 | title={Data-Driven Approach to Encoding and Decoding 3-D Crystal Structures}, 44 | author={Hoffmann, Jordan and Maestrati, Louis and Sawada, Yoshihide and Tang, Jian and Sellier, Jean Michel and Bengio, Yoshua}, 45 | journal={arXiv preprint arXiv:1909.00949}, 46 | year={2019} 47 | } 48 | ``` 49 | 50 | 51 | # Article Description 52 | [![Video0](https://img.youtube.com/vi/3qVVew7-DgQ/0.jpg)](https://www.youtube.com/watch?v=3qVVew7-DgQ) 53 | In this article, we encode and decode 3-D crystal structures. We use a variational autoencoder 54 | and a U-Net segmentation model to (1) encode and decode a density field representing the locations 55 | and species of atoms in a crystal and (2) segment the decoded density field into different atomic species. 56 | ![model](./ims/Model.png) 57 | By coupling these two tasks, we are able to encode and decode our representations of crystals very accurately. 58 | ![Unit](./ims/Unit-Cell.png) 59 | We consider two representations of crystals: in the first, we consider a single unit cell. In the second, we consider 60 | repeating unit cells that can have over 200 different atoms to encode and decode. We use the same network for both approaches.
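The density grids that the networks operate on are produced by `Generate.py` (repeating lattices, using periodic distances from `pymatgen`) and `Generate_Single_Unit_Data.py` (single, centered unit cells): every atom of atomic number Z is smeared onto a 30×30×30 grid spanning a 10 Å box with a Gaussian of width σ = 1. The snippet below is a minimal, vectorized NumPy sketch of that smearing; the two-atom "structure" is made up purely for illustration, and the generation scripts themselves loop over voxels and over structures read from `.cif` files.

```python
import numpy as np

sigma = 1.0                                  # smearing width used in the generation scripts
bins = np.linspace(0, 10, 30)                # 30 grid points per axis over a 10 Angstrom box
grid = np.stack(np.meshgrid(bins, bins, bins, indexing="ij"), axis=-1)   # (30, 30, 30, 3)

# Hypothetical two-atom structure: (atomic number, Cartesian position in Angstrom)
atoms = [(8, np.array([5.0, 5.0, 5.0])),
         (1, np.array([6.0, 5.0, 5.0]))]

density = np.zeros((30, 30, 30))
for Z, pos in atoms:
    dist2 = np.sum((grid - pos) ** 2, axis=-1)   # squared distance of every voxel to the atom
    density += Z / ((2.0 * np.pi) ** 1.5 * sigma ** 3) * np.exp(-dist2 / (2.0 * sigma ** 2))

print(density.shape, density.max())          # (30, 30, 30); the peak sits near the heavier atom
```

Alongside this density grid, the generation scripts store a second 30×30×30 grid holding the atomic number of any atom lying within roughly 0.65 Å of a voxel; this species grid is what the segmentation network is trained to predict.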
61 | Using the latent space vectors, we can interpolate between different molecules in both frameworks, as shown in the following videos: 62 | 63 | [![Video1](https://img.youtube.com/vi/3yPVdgd2mQ0/0.jpg)](https://www.youtube.com/watch?v=3yPVdgd2mQ0) 64 | [![Video2](https://img.youtube.com/vi/q2d8LZq8RW4/0.jpg)](https://www.youtube.com/watch?v=q2d8LZq8RW4) 65 | [![Video3](https://img.youtube.com/vi/pxYb8cnLxio/0.jpg)](https://www.youtube.com/watch?v=pxYb8cnLxio) 66 | [![Video4](https://img.youtube.com/vi/U5-x3jL2zcc/0.jpg)](https://www.youtube.com/watch?v=U5-x3jL2zcc) 67 | 68 | ![Interpolation](./ims/Interpolation.png) 69 | 70 | -------------------------------------------------------------------------------- /Segmentation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | import pickle 9 | import numpy as np 10 | from torch.utils.data.dataset import Dataset 11 | from pymatgen.core.structure import Structure 12 | 13 | # Code adapted by Jordan from: https://github.com/LeeJunHyun/Image_Segmentation/blob/master/network.py 14 | # Changed from 2-D to 3-D 15 | 16 | class Attention3D(nn.Module): 17 | def __init__(self,F_g,F_l,F_int): 18 | super(Attention3D,self).__init__() 19 | self.W_g = nn.Sequential( 20 | nn.Conv3d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True), 21 | nn.BatchNorm3d(F_int) 22 | ) 23 | 24 | self.W_x = nn.Sequential( 25 | nn.Conv3d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True), 26 | nn.BatchNorm3d(F_int) 27 | ) 28 | 29 | self.psi = nn.Sequential( 30 | nn.Conv3d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True), 31 | nn.BatchNorm3d(1), 32 | nn.Sigmoid() 33 | ) 34 | 35 | self.relu = nn.ReLU(inplace=True) 36 | 37 | def forward(self,g,x): 38 | g1 = self.W_g(g) 39 | x1 = self.W_x(x) 40 | psi = self.relu(g1+x1) 41 | psi = self.psi(psi) 42 | return x*psi 43 | 44 | class conv_block3D(nn.Module): 45 | def __init__(self,ch_in,ch_out): 46 | super(conv_block3D,self).__init__() 47 | self.Conv3D_ = nn.Sequential( 48 | nn.Conv3d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 49 | nn.BatchNorm3d(ch_out), 50 | nn.ReLU(inplace=True), 51 | nn.Conv3d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True), 52 | nn.BatchNorm3d(ch_out), 53 | nn.ReLU(inplace=True) 54 | ) 55 | 56 | 57 | def forward(self,x): 58 | x = self.Conv3D_(x) 59 | return x 60 | 61 | class conv_block3D2(nn.Module): 62 | def __init__(self,ch_in,ch_out): 63 | super(conv_block3D2,self).__init__() 64 | self.Conv3D_ = nn.Sequential( 65 | nn.Conv3d(ch_in, ch_out, kernel_size=4,stride=1,padding=1,bias=True), 66 | nn.BatchNorm3d(ch_out), 67 | nn.ReLU(inplace=True), 68 | nn.Conv3d(ch_out, ch_out, kernel_size=4,stride=1,padding=1,bias=True), 69 | nn.BatchNorm3d(ch_out), 70 | nn.ReLU(inplace=True) 71 | ) 72 | 73 | 74 | def forward(self,x): 75 | x = self.Conv3D_(x) 76 | return x 77 | 78 | 79 | class up_conv3D(nn.Module): 80 | def __init__(self,ch_in,ch_out): 81 | super(up_conv3D,self).__init__() 82 | self.up = nn.Sequential( 83 | nn.Upsample(scale_factor=2), 84 | nn.Conv3d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True), 85 | nn.BatchNorm3d(ch_out), 86 | nn.ReLU(inplace=True) 87 | ) 88
| 89 | def forward(self,x): 90 | x = self.up(x) 91 | return x 92 | 93 | class Interpolate(nn.Module): 94 | def __init__(self, scale_factor, mode): 95 | super(Interpolate, self).__init__() 96 | self.interp = nn.functional.interpolate 97 | self.scale_factor = scale_factor 98 | self.mode = mode 99 | 100 | def forward(self, x): 101 | x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode) 102 | return x 103 | 104 | class AttU_Net3D(nn.Module): 105 | def __init__(self,input_ch=1,output_ch=95): #Number of classes 106 | super(AttU_Net3D,self).__init__() 107 | 108 | self.Maxpool3D = nn.MaxPool3d(kernel_size=2,stride=2) 109 | 110 | self.pad = nn.ReplicationPad3d(1) 111 | self.sig = nn.Sigmoid() 112 | self.Conv3D_1 = conv_block3D(ch_in=input_ch,ch_out=64) 113 | self.Conv3D_2 = conv_block3D(ch_in=64,ch_out=128) 114 | self.Conv3D_3 = conv_block3D(ch_in=128,ch_out=256) 115 | self.Conv3D_4 = conv_block3D(ch_in=256,ch_out=512) 116 | self.Conv3D_5 = conv_block3D(ch_in=512,ch_out=1024) 117 | 118 | self.Up5 = up_conv3D(ch_in=1024,ch_out=512) 119 | self.Att3D_5 = Attention3D(F_g=512,F_l=512,F_int=256) 120 | self.Up3D_conv5 = conv_block3D(ch_in=1024, ch_out=512) 121 | 122 | self.Up4 = up_conv3D(ch_in=512,ch_out=256) 123 | self.Att3D_4 = Attention3D(F_g=256,F_l=256,F_int=128) 124 | self.Up3D_conv4 = conv_block3D(ch_in=512, ch_out=256) 125 | 126 | self.Up3 = up_conv3D(ch_in=256,ch_out=128) 127 | self.Att3D_3 = Attention3D(F_g=128,F_l=128,F_int=64) 128 | self.Up3D_conv3 = conv_block3D(ch_in=256, ch_out=128) 129 | 130 | self.Up2 = up_conv3D(ch_in=128,ch_out=64) 131 | self.Att3D_2 = Attention3D(F_g=64,F_l=64,F_int=32) 132 | self.Up3D_conv2 = conv_block3D2(ch_in=128, ch_out=64) 133 | 134 | self.Conv3D_1x1 = nn.Conv3d(64,output_ch,kernel_size=1,stride=1,padding=0) 135 | 136 | 137 | def forward(self,x): 138 | x = self.pad(x) #Needed for 2^N input 139 | x1 = self.Conv3D_1(x) 140 | x2 = self.Maxpool3D(x1) 141 | x2 = self.Conv3D_2(x2) 142 | x3 = self.Maxpool3D(x2) 143 | x3 = self.Conv3D_3(x3) 144 | x4 = self.Maxpool3D(x3) 145 | x4 = self.Conv3D_4(x4) 146 | x5 = self.Maxpool3D(x4) 147 | x5 = self.Conv3D_5(x5) 148 | d5 = self.Up5(x5) 149 | x4 = self.Att3D_5(g=d5,x=x4) 150 | d5 = torch.cat((x4,d5),dim=1) 151 | d5 = self.Up3D_conv5(d5) 152 | d4 = self.Up4(d5) 153 | x3 = self.Att3D_4(g=d4,x=x3) 154 | d4 = torch.cat((x3,d4),dim=1) 155 | d4 = self.Up3D_conv4(d4) 156 | d3 = self.Up3(d4) 157 | x2 = self.Att3D_3(g=d3,x=x2) 158 | d3 = torch.cat((x2,d3),dim=1) 159 | d3 = self.Up3D_conv3(d3) 160 | d2 = self.Up2(d3) 161 | x1 = self.Att3D_2(g=d2,x=x1) 162 | d2 = torch.cat((x1,d2),dim=1) 163 | d2 = self.Up3D_conv2(d2) 164 | d1 = self.Conv3D_1x1(d2) 165 | d1 = self.sig(d1) 166 | return d1 -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # File to load in the periodic lattice data 2 | # Code written by Jordan 3 | from __future__ import print_function 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torchvision import datasets, transforms 10 | import pickle 11 | import numpy as np 12 | from torch.utils.data.dataset import Dataset 13 | from pymatgen.core.structure import Structure 14 | 15 | class CrystalDataset(Dataset): 16 | #load input data. 
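# Each entry in the input pickles is unpacked in __getitem__ below as
# [ID, electron-density grid (30x30x30), species grid (30x30x30), target property],
# which matches the per-structure entries written out by Generate.py.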
17 | def __init__(self, lower=0,upper=1): 18 | print('Doing data loading now') 19 | options = range(32) 20 | self.input_data = np.vstack([pickle.load(open('/scratch/jordanh/2/RL_2_'+str(options[n])+'.pickle','rb'),encoding='bytes') for n in range(lower,upper)]) 21 | print('Shape of Data is ',np.shape(self.input_data)) 22 | 23 | def __getitem__(self, index): 24 | ID,electron,species,label = self.input_data[index] 25 | species = np.array(species).astype(int) 26 | specie = np.zeros((95,30,30,30)) 27 | np.put_along_axis(specie,species[None,...],1,0) 28 | cif_file = '/scratch/jordanh/data/'+str(ID)+'.cif' 29 | crystal = Structure.from_file(cif_file) 30 | abc = np.array(crystal.lattice.abc) 31 | angles = np.array(crystal.lattice.angles) 32 | abcangles = np.concatenate((abc,angles)) # lattice lengths and angles; returned so the loops in main.py can unpack them (not otherwise used) 33 | label = (label+1.7021) # shifted target property; not currently used 34 | electronPadded = np.pad(electron,5,'symmetric') 35 | electron = electron.reshape((1,30,30,30)) 36 | electronPadded = electronPadded.reshape((1,40,40,40)) 37 | electron = torch.from_numpy(electron) 38 | electronPadded = torch.from_numpy(electronPadded) 39 | mat = torch.from_numpy(specie) 40 | label = torch.as_tensor([label]) 41 | abcangles = torch.as_tensor(abcangles) 42 | return (electronPadded.float(), electron.float(), label.float(), abcangles.float(), mat.float()) 43 | 44 | 45 | def __len__(self): 46 | return len(self.input_data) -------------------------------------------------------------------------------- /ims/Interpolation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hoffmannjordan/Encoding-Decoding-3D-Crystals/4d437884b0ffdf0cc6a0df2fb56992f4238a88c9/ims/Interpolation.png -------------------------------------------------------------------------------- /ims/Model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hoffmannjordan/Encoding-Decoding-3D-Crystals/4d437884b0ffdf0cc6a0df2fb56992f4238a88c9/ims/Model.png -------------------------------------------------------------------------------- /ims/Unit-Cell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hoffmannjordan/Encoding-Decoding-3D-Crystals/4d437884b0ffdf0cc6a0df2fb56992f4238a88c9/ims/Unit-Cell.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torchvision import datasets, transforms 8 | import pickle 9 | import numpy as np 10 | from torch.utils.data.dataset import Dataset 11 | from pymatgen.core.structure import Structure 12 | from dataset import * 13 | from CrystalVAE import * 14 | from Discriminator import * 15 | from Segmentation import * 16 | 17 | 18 | 19 | def loss_MSE(recon_x, x, mu, logvar,epoch): 20 | BCE = F.mse_loss(recon_x, x, size_average=True) 21 | weight = 5.0 22 | KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) 23 | return BCE + weight*KLD, BCE, weight*KLD 24 | 25 | def loss_BCE(recon_x, x, mu, logvar,epoch): 26 | weight = 0.0 27 | BCE = F.binary_cross_entropy(recon_x, x, size_average=True) 28 | KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) 29 | return BCE + weight*KLD, BCE, weight*KLD 30 | 31 |
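# Both loss functions above pair a reconstruction term with the KL divergence between
# the approximate posterior N(mu, exp(logvar)) and a standard normal prior,
#   KLD = -0.5 * mean(1 + logvar - mu^2 - exp(logvar)).
# loss_MSE (density reconstruction) weights the KL term by 5.0, while loss_BCE
# (species segmentation) uses weight 0.0 and therefore reduces to a plain BCE term.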
32 | def train(args, model, device, train_loader, optimizer, epoch, UNet, optimizer2): 33 | model.train() 34 | for batch_idx, (data, target, label, abcangles, species_mat) in enumerate(train_loader): 35 | 36 | data, target = data.to(device), target.to(device) 37 | species_mat = species_mat.to(device) 38 | 39 | optimizer.zero_grad() 40 | 41 | optimizer2.zero_grad() 42 | 43 | reconstruction, mu, logvar = model(data) 44 | 45 | UNetRecon = UNet(reconstruction) 46 | loss1, bce, kld = loss_MSE(reconstruction, target, mu, logvar,epoch) 47 | lossBCE, bce2, kld2 = loss_BCE(UNetRecon, species_mat, mu, logvar,epoch) 48 | loss = loss1 + 0.1*lossBCE 49 | loss.backward(retain_graph=True) 50 | lossBCE.backward() 51 | optimizer2.step() 52 | optimizer.step() 53 | 54 | 55 | 56 | def test(args, model, device, test_loader, epoch,unet): 57 | 58 | model.eval() 59 | unet.eval() 60 | test_loss = 0 61 | SAVE_UP_TO = 6 62 | with torch.no_grad(): 63 | i = 0 64 | counter = 0 65 | for data, target, label, abcangles, species_mat in test_loader: 66 | # Send the data to the device (GPU) 67 | data, target = data.to(device), target.to(device) 68 | species_mat = species_mat.to(device) 69 | # Compute reconstruction and segmentation 70 | reconstruction, mu, logvar = model(data) 71 | species_pred = unet(reconstruction) 72 | # Compute reconstruction loss 73 | loss, bce, kld = loss_MSE(reconstruction, target, mu, logvar,1) 74 | loss2, bce2, kld2 = loss_BCE(species_pred, species_mat, mu, logvar,epoch) 75 | 76 | species_pred = species_pred.cpu().numpy() 77 | label = label.cpu().numpy() 78 | outputNP3 = reconstruction.cpu().numpy() 79 | targetNP3 = target.cpu().numpy() 80 | species_mat = species_mat.cpu().numpy() 81 | 82 | 83 | if i < SAVE_UP_TO: 84 | for ii in range(0,18): 85 | PREDICTION = outputNP3[ii] 86 | TARGET = targetNP3[ii] 87 | mat1 = PREDICTION.flatten() 88 | np.savetxt('/SAVE/DIRECTORY/ElectronDensity_Pred_'+str(counter)+'_'+str(epoch)+'.csv',mat1) 89 | mat1 = TARGET.flatten() 90 | np.savetxt('/SAVE/DIRECTORY/ElectronDensity_True_'+str(counter)+'_'+str(epoch)+'.csv',mat1) 91 | species_mat_True = species_mat[ii].argmax(axis=0) 92 | species_mat_Pred = species_pred[ii].argmax(axis=0) 93 | np.savetxt('/SAVE/DIRECTORY/Species_True_'+str(counter)+'_'+str(epoch)+'.csv',species_mat_True.flatten().astype(int),fmt='%i') 94 | np.savetxt('/SAVE/DIRECTORY/Species_Pred_'+str(counter)+'_'+str(epoch)+'.csv',species_mat_Pred.flatten().astype(int),fmt='%i') 95 | counter += 1 96 | 97 | i += 1 98 | print("Test loss is: ",loss) 99 | 100 | 101 | 102 | def main(): 103 | 104 | parser = argparse.ArgumentParser(description='Repeating Lattice VAE (ReLa)') 105 | parser.add_argument('--epochs', type=int, default=50, metavar='N', 106 | help='number of epochs to train (default: 50)') 107 | parser.add_argument('--lr', type=float, default=0.0000001, metavar='LR', 108 | help='learning rate (default: 0.0000001)') 109 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 110 | help='SGD momentum (default: 0.5)') 111 | parser.add_argument('--no-cuda', action='store_true', default=False, 112 | help='disables CUDA training') 113 | parser.add_argument('--seed', type=int, default=1, metavar='S', 114 | help='random seed (default: 1)') 115 | 116 | args = parser.parse_args() 117 | use_cuda = not args.no_cuda and torch.cuda.is_available() 118 | 119 | torch.manual_seed(args.seed) 120 | 121 | device = torch.device("cuda" if use_cuda else "cpu") 122 | 123 | kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {} 124 | 125 | # We used 32 pickle files to load the data.
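# Shards 1-31 of those 32 pickle files are used for training below;
# shard 0 is held out as the test set.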
126 | CrystalDataset_Train = CrystalDataset(1,32) 127 | train_loader = torch.utils.data.DataLoader(CrystalDataset_Train, batch_size=18, 128 | shuffle=True, num_workers=0) 129 | CrystalDataset_Test = CrystalDataset(0,1) 130 | test_loader = torch.utils.data.DataLoader(CrystalDataset_Test, batch_size=18, 131 | shuffle=False, num_workers=0) 132 | 133 | model = CVAE().to(device) 134 | UNet = AttU_Net3D().to(device) 135 | 136 | 137 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) 138 | optimizer2 = optim.SGD(UNet.parameters(), lr=args.lr, momentum=args.momentum) 139 | 140 | for epoch in range(1, args.epochs + 1): 141 | train(args, model, device, train_loader, optimizer, epoch,UNet,optimizer2) 142 | test(args, model, device, test_loader, epoch,UNet) 143 | if False: 144 | if epoch%2==0: 145 | torch.save(model.state_dict(),"ReLaDS_"+str(epoch)+".pt") 146 | torch.save(UNet.state_dict(), "ReLaDS_U_"+str(epoch)+".pt") 147 | 148 | if __name__ == '__main__': 149 | main() 150 | 151 | 152 | --------------------------------------------------------------------------------