├── models ├── global_config.py ├── kernels │ ├── irregular_mean_fwd_old.cu │ ├── irregular_transpose.cu │ ├── irregular_transpose_old.cu │ ├── fused_dpa_bwd_q_v2.cu │ ├── fused_dpa_bwd_q_v3.cu │ ├── fused_dpa_bwd_q.cu │ ├── minimum_distance.cu │ ├── reci_enc_fwd.cu │ ├── reci_enc_fwd_v2.cu │ ├── reci_enc_bwd.cu │ ├── reci_enc_bwd_v2.cu │ ├── irregular_mean_fwd.cu │ ├── fused_dpa_fwd_v2.cu │ ├── pairwise_sum.cuh │ ├── fused_dpa_fwd.cu │ ├── fused_dpa_bwd.cu │ ├── fused_dpa_fwd_v3.cu │ ├── fused_dpa_bwd_v3.cu │ ├── fused_dpa_bwd_v2.cu │ ├── real_enc_fwd.cu │ ├── real_enc_fwd_v2.cu │ ├── real_enc_bwd.cu │ ├── real_enc_bwd_v2.cu │ ├── real_enc_proj_fwd.cu │ ├── real_enc_proj_fwd_v2.cu │ ├── reduce_kernel_utils.cuh │ ├── real_enc_proj_bwd.cu │ └── real_enc_proj_bwd_v2.cu ├── cuda_funcs │ ├── __init__.py │ ├── minimum_distance.py │ ├── irregular_mean.py │ ├── reci_space_enc.py │ ├── kernel_manager.py │ ├── real_space_enc_proj.py │ ├── fused_dpa_v2.py │ ├── fused_dpa.py │ └── real_space_enc.py ├── pooling.py └── latticeformer_params.py ├── losses └── regression_loss.py ├── init_datasets.py ├── demo.sh ├── LICENSE ├── train.sh ├── docker └── pytorch21_cuda121 │ └── Dockerfile ├── params └── latticeformer │ └── default.json ├── data ├── download_megnet_elastic.py └── download_jarvis.py ├── dataloaders └── dataset_latticeformer.py ├── demo.py ├── README.md ├── utils.py └── train.py /models/global_config.py: -------------------------------------------------------------------------------- 1 | REPRODUCIBLITY_STATE:int = 0 2 | # == 0: can reproduce paper's results but slow 3 | # >= 1: replace pooling by CUDA code 4 | # >= 2: replace self-attention by CUDA code (pairwise sum for softmax) 5 | # >= 3: replace self-attention by CUDA code (+ divided sum for running sum) 6 | # >= 4: replace periodic encoding with better numerical accuracy code 7 | -------------------------------------------------------------------------------- /losses/regression_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | 5 | def regression_loss(pred, data, targets, scale, bias, loss_fn:str): 6 | loss_fn = loss_fn.lower() 7 | assert(len(targets)==pred.shape[1]) 8 | assert loss_fn in ("l1", "mse", "smooth_l1") 9 | 10 | if loss_fn == "l1": 11 | loss_fn = F.l1_loss 12 | elif loss_fn == "mse": 13 | loss_fn = F.mse_loss 14 | elif loss_fn == "smooth_l1": 15 | loss_fn = F.smooth_l1_loss 16 | 17 | loss = 0 18 | for i, t in enumerate(targets): 19 | labels = (data[t]-bias[i])/scale[i] 20 | loss += loss_fn(pred[:, i], labels, reduction='none') 21 | return loss 22 | -------------------------------------------------------------------------------- /models/kernels/irregular_mean_fwd_old.cu: -------------------------------------------------------------------------------- 1 | 2 | extern "C" __global__ 3 | void irregular_mean_fwd( 4 | const float* src_n, 5 | const long long int* start_n, 6 | const long long int B, 7 | const long long int D, 8 | float* dst_n 9 | ){ 10 | const unsigned long long int tid = (unsigned long long int)blockDim.x * blockIdx.x + threadIdx.x; 11 | if (tid >= B*D) return; 12 | const unsigned int n = tid/D; 13 | const unsigned int k = tid%D; 14 | const unsigned long long int start = start_n[n]; 15 | const unsigned long long int end = start_n[n+1]; 16 | 17 | float sum = 0; 18 | const float* end_ptr = src_n + end*D + k; 19 | src_n += start*D + k; 20 | while(src_n != end_ptr){ 21 | sum += *src_n; 22 | src_n += D; 
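// src_n now points at feature column k of the next atom row for crystal n
// (the stride is D floats per row); the walk stops once it reaches row `end`.
// Note this is a plain sequential accumulation, so its rounding error grows
// with the number of atoms; the newer irregular_mean_fwd.cu replaces it with
// the pairwise (tree) summation from pairwise_sum.cuh.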
23 | } 24 | dst_n[n*D+k] = sum / (end - start); 25 | } -------------------------------------------------------------------------------- /init_datasets.py: -------------------------------------------------------------------------------- 1 | 2 | from dataloaders.dataset_latticeformer import RegressionDatasetMP_Latticeformer as Dataset 3 | from dataloaders.common import CellFormat 4 | splits = ["train", "val", "test", "all"] 5 | datasets = [ 6 | "jarvis__megnet", 7 | "jarvis__megnet-shear", 8 | "jarvis__megnet-bulk", 9 | "jarvis__dft_3d_2021", 10 | "jarvis__dft_3d_2021-ehull", 11 | "jarvis__dft_3d_2021-mbj_bandgap", 12 | ] 13 | 14 | import torch 15 | for dataset in datasets: 16 | for split in splits: 17 | for format in [CellFormat.RAW, CellFormat.PRIMITIVE]: 18 | if ("shear" in dataset or "bulk" in dataset) and split == "all": 19 | continue 20 | print("Processing ------------------", dataset, split, format) 21 | data = Dataset(split, dataset, format) 22 | sizes = data.data.sizes.float() 23 | print(torch.mean(sizes).item(), torch.max(sizes).item(), torch.median(sizes).item(), torch.std(sizes).item()) 24 | -------------------------------------------------------------------------------- /models/cuda_funcs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | CUPY_AVAILABLE=True 3 | 4 | try: 5 | import cupy as cp 6 | import pytorch_pfn_extras as ppe 7 | from torch.utils.dlpack import to_dlpack, from_dlpack 8 | ppe.cuda.use_torch_mempool_in_cupy() 9 | except: 10 | CUPY_AVAILABLE = False 11 | 12 | from .kernel_manager import Kernel, KernelManager, compile_kernels 13 | from .real_space_enc import RealPeriodicEncodingFuncCUDA 14 | from .real_space_enc_proj import RealPeriodicEncodingWithProjFuncCUDA 15 | from .reci_space_enc import ReciPeriodicEncodingFuncCUDA 16 | from .fused_dpa import FusedDotProductAttentionCUDA 17 | from .irregular_mean import IrregularMeanCUDA 18 | 19 | __all__ = [ 20 | 'KernelManager', 21 | 'Kernel', 22 | 'compile_kernels', 23 | 'FusedDotProductAttentionCUDA', 24 | 'RealPeriodicEncodingFuncCUDA', 25 | 'RealPeriodicEncodingWithProjFuncCUDA', 26 | 'ReciPeriodicEncodingFuncCUDA', 27 | 'IrregularMeanCUDA', 28 | 'CUPY_AVAILABLE', 29 | ] 30 | -------------------------------------------------------------------------------- /demo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | save_path="result/latticeformer/jarvis" 4 | 5 | # | target_set | targets | 6 | # jarvis__megnet | e_form | bandgap | 7 | # jarvis__megnet-bulk | bulk_modulus 8 | # jarvis__megnet-shear | shear_modulus 9 | # jarvis__dft_3d_2021 | formation_energy | total_energy | opt_bandgap | 10 | # jarvis__dft_3d_2021-ehull | ehull | 11 | # jarvis__dft_3d_2021-mbj_bandgap | mbj_bandgap | 12 | 13 | target_set=jarvis__megnet 14 | targets=e_form 15 | gpu=0 16 | exp_name=demo 17 | reproduciblity_state=3 18 | layer=4 # {4, 7} 19 | 20 | CUDA_VISIBLE_DEVICES=${gpu} python demo.py -p latticeformer/default.json \ 21 | --seed 123 \ 22 | --save_path ${save_path} \ 23 | --domain real \ 24 | --num_layers ${layer} \ 25 | --experiment_name ${exp_name}/${targets} \ 26 | --target_set ${target_set} \ 27 | --targets ${targets} \ 28 | --batch_size 256 \ 29 | --reproduciblity_state ${reproduciblity_state} \ 30 | --pretrained_model weights/megnet-${targets}-layer${layer}.ckpt 31 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 OMRON SINIC X 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/kernels/irregular_transpose.cu: -------------------------------------------------------------------------------- 1 | extern "C" __global__ 2 | void irregular_transpose( 3 | const float* src_matrices, 4 | const long long int* upper_e_t, 5 | const long long int* mat_sec_t, 6 | const long long int* dims_t, 7 | const long long int hE, 8 | const long long int D_, 9 | float* dst_matrices 10 | ){ 11 | const unsigned long long int tid = (unsigned long long int)blockDim.x * blockIdx.x + threadIdx.x; 12 | if (tid >= hE*D_) return; 13 | const unsigned long long int t = tid/D_; 14 | const unsigned int k = tid%D_; 15 | const unsigned int D = D_; 16 | const unsigned long long int e = upper_e_t[t]; 17 | const unsigned long long int e0 = mat_sec_t[t]; 18 | const unsigned int dim = dims_t[t]; 19 | const unsigned int i = (e - e0) / dim; 20 | const unsigned int j = (e - e0) % dim; 21 | 22 | const unsigned long long int mat_index = e0*D_ + k; 23 | src_matrices += mat_index; 24 | dst_matrices += mat_index; 25 | const unsigned int ij = (i*dim + j)*D; 26 | const unsigned int ji = (j*dim + i)*D; 27 | float tmp = src_matrices[ji]; 28 | dst_matrices[ji] = src_matrices[ij]; 29 | dst_matrices[ij] = tmp; 30 | } 31 | 32 | -------------------------------------------------------------------------------- /models/pooling.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import Tensor 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | from torch.nn import Parameter, Module, Linear, Dropout, LayerNorm, ModuleList 7 | from typing import List, Optional, Tuple, Union, Callable 8 | from . import cuda_funcs 9 | 10 | from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ 11 | import math 12 | from . 
import global_config as config 13 | 14 | def max_pool(x, batch, sizes): 15 | x = torch.split_with_sizes(x, sizes.tolist(), 0) 16 | x = torch.stack([torch.max(x,dim=0)[0] for x in x]) 17 | return x 18 | 19 | def avr_pool(x, batch, sizes): 20 | if config.REPRODUCIBLITY_STATE>=1 and cuda_funcs.CUPY_AVAILABLE: 21 | x = cuda_funcs.IrregularMeanCUDA.apply(x, batch, sizes) 22 | else: 23 | x = torch.split_with_sizes(x, sizes.tolist(), 0) 24 | x = torch.stack([torch.mean(x,dim=0) for x in x]) 25 | return x 26 | 27 | 28 | def _get_activation_fn(activation): 29 | if activation == "relu": 30 | return F.relu 31 | elif activation == "gelu": 32 | return F.gelu 33 | 34 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 35 | 36 | -------------------------------------------------------------------------------- /models/kernels/irregular_transpose_old.cu: -------------------------------------------------------------------------------- 1 | 2 | extern "C" __global__ 3 | void irregular_transpose_old( 4 | const float* src_matrices, 5 | const long long int* start_i, 6 | const short* dims_i, 7 | const long long int B, 8 | const long long int D_, 9 | float* dst_matrices 10 | ){ 11 | const unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x; 12 | if (tid >= B*D_) return; 13 | const unsigned int D = D_; 14 | const unsigned int i = tid/D; 15 | const unsigned int k = tid%D; 16 | const unsigned int dim = dims_i[i]; 17 | 18 | src_matrices += start_i[i]*D+k; 19 | dst_matrices += start_i[i]*D+k; 20 | 21 | unsigned int ij_step = D; 22 | unsigned int ji_step = D*dim; 23 | for (unsigned int row = 0; row < dim; row++){ 24 | unsigned int ij = (dim+1)*row*D; 25 | unsigned int ji = ij; 26 | for (unsigned int col = row; col < dim; col++){ 27 | // unsigned int ij = (row*dim+col)*D; 28 | // unsigned int ji = (col*dim+row)*D; 29 | 30 | float tmp = src_matrices[ji]; 31 | dst_matrices[ji] = src_matrices[ij]; 32 | dst_matrices[ij] = tmp; 33 | ij += ij_step; 34 | ji += ji_step; 35 | } 36 | } 37 | } 38 | 39 | -------------------------------------------------------------------------------- /models/cuda_funcs/minimum_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from .kernel_manager import KernelManager 4 | 5 | try: 6 | import cupy as cp 7 | import pytorch_pfn_extras as ppe 8 | from torch.utils.dlpack import to_dlpack, from_dlpack 9 | except: 10 | pass 11 | 12 | def _to_copy(x): 13 | if x is not None: 14 | return cp.from_dlpack(to_dlpack(x)) 15 | return 0 16 | 17 | def compute_minimum_distance(rpos_ij_e, tvecs_n, batch_i, edge_ij_e, rvlen_n, cutoff_radius): 18 | assert cutoff_radius >= 0 19 | E = rpos_ij_e.shape[0] 20 | dev = rpos_ij_e.device 21 | 22 | rpos_ij_e = rpos_ij_e.contiguous() 23 | tvecs_n = tvecs_n.contiguous() 24 | batch_i = batch_i.contiguous() 25 | edge_ij_e = edge_ij_e.contiguous() 26 | rvlen_n = rvlen_n.contiguous() 27 | 28 | dist2_min_e = torch.empty((E, ), device=dev, dtype=rpos_ij_e.dtype) 29 | 30 | bsz = 32 31 | with cp.cuda.Device(dev.index), ppe.cuda.stream(torch.cuda.current_stream(dev)): 32 | KernelManager.minimum_distance( ((E+bsz-1)//bsz, ), (bsz, ), ( 33 | _to_copy(rpos_ij_e), 34 | _to_copy(tvecs_n), 35 | _to_copy(batch_i), 36 | _to_copy(edge_ij_e), 37 | E, 38 | _to_copy(rvlen_n), 39 | cutoff_radius, 40 | _to_copy(dist2_min_e), 41 | )) 42 | return dist2_min_e -------------------------------------------------------------------------------- /train.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | save_path="result/latticeformer/jarvis" 4 | 5 | # | target_set | targets | 6 | # jarvis__megnet | e_form | bandgap | 7 | # jarvis__megnet-bulk | bulk_modulus 8 | # jarvis__megnet-shear | shear_modulus 9 | # jarvis__dft_3d_2021 | formation_energy | total_energy | opt_bandgap | 10 | # jarvis__dft_3d_2021-ehull | ehull | 11 | # jarvis__dft_3d_2021-mbj_bandgap | mbj_bandgap | 12 | 13 | target_set=jarvis__dft_3d_2021 14 | targets=formation_energy 15 | gpu=0 16 | exp_name=speed_test 17 | reproduciblity_state=3 18 | layer=4 19 | 20 | CUDA_VISIBLE_DEVICES=${gpu} python train.py -p latticeformer/default.json \ 21 | --seed 123 \ 22 | --save_path ${save_path} \ 23 | --domain real \ 24 | --num_layers ${layer} \ 25 | --experiment_name ${exp_name}/${targets} \ 26 | --target_set ${target_set} \ 27 | --targets ${targets} \ 28 | --batch_size 256 \ 29 | --value_pe_dist_real 0 \ 30 | --reproduciblity_state ${reproduciblity_state} \ 31 | ; \ 32 | CUDA_VISIBLE_DEVICES=${gpu} python train.py -p latticeformer/default.json \ 33 | --seed 123 \ 34 | --save_path ${save_path} \ 35 | --domain real \ 36 | --num_layers ${layer} \ 37 | --experiment_name ${exp_name}/${targets} \ 38 | --target_set ${target_set} \ 39 | --targets ${targets} \ 40 | --batch_size 256 \ 41 | --reproduciblity_state ${reproduciblity_state} \ 42 | 43 | -------------------------------------------------------------------------------- /models/kernels/fused_dpa_bwd_q_v2.cu: -------------------------------------------------------------------------------- 1 | #include "models/kernels/reduce_kernel_utils.cuh" 2 | 3 | __global__ __device__ void fused_dpa_bwd_q_thread( 4 | const float* key_ihk, 5 | const float* gaij_eh, 6 | const long long int* edge_j_e, 7 | const long long int e_start, 8 | const long long int e_end, 9 | const long long int H, 10 | float* gque_k 11 | ){ 12 | long long int e = e_start + threadIdx.x; 13 | float g_softmax = 0.0f; 14 | if (e < e_end){ 15 | long long int j = edge_j_e[e]; 16 | g_softmax = gaij_eh[e*H]; 17 | key_ihk += (j*H)*K_HEAD_DIM; 18 | } 19 | 20 | #pragma unroll 21 | for (int k = 0; k < K_HEAD_DIM; k++){ 22 | float gq = g_softmax*key_ihk[k]; 23 | __syncthreads(); 24 | gq = blockReduceSum(gq); 25 | if (threadIdx.x == 0) 26 | gque_k[k] = gq; 27 | } 28 | } 29 | 30 | extern "C" __global__ 31 | void fused_dpa_bwd_q_v2( 32 | const float* key_ihk, 33 | const float* gaij_eh, 34 | const long long int* edge_ij_e, 35 | const long long int* e_start_i, 36 | const long long int N, 37 | const long long int H, 38 | const long long int E, 39 | float* gque_ihk 40 | ){ 41 | const long long int tid = (long long int)blockIdx.x*blockDim.x + threadIdx.x; 42 | if (tid >= N*H) return; 43 | 44 | const long long int i = tid / H; 45 | const long long int h = tid % H; 46 | const long long int e_start = e_start_i[i]; 47 | const long long int e_end = e_start_i[i+1]; 48 | 49 | fused_dpa_bwd_q_thread<<< 1, ((e_end-e_start+31)/32)*32 >>>( 50 | key_ihk + h*K_HEAD_DIM, 51 | gaij_eh + h, 52 | edge_ij_e + E, 53 | e_start, 54 | e_end, 55 | H, 56 | gque_ihk + tid*K_HEAD_DIM 57 | ); 58 | } -------------------------------------------------------------------------------- /models/kernels/fused_dpa_bwd_q_v3.cu: -------------------------------------------------------------------------------- 1 | 2 | extern "C" __global__ 3 | void fused_dpa_bwd_q_v3( 4 | const float* key_ihk, 5 | const float* gaij_eh, 6 | const long long int* edge_ij_e, 7 | const long long int* 
e_start_i, 8 | const long long int N, 9 | const long long int H, 10 | const long long int E, 11 | float* gque_ihk 12 | ){ 13 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 14 | if (tid >= N*H) return; 15 | 16 | const long long int K = VPE_DIM; 17 | const long long int i = tid/H; 18 | const long long int h = tid%H; 19 | const long long int e_start = e_start_i[i]; 20 | const long long int e_end = e_start_i[i+1]; 21 | 22 | __shared__ float _gq[THREAD_NUM][K_HEAD_DIM+1]; 23 | __shared__ float _rgq[THREAD_NUM][K_HEAD_DIM+1]; 24 | float *gq = _gq[threadIdx.x]; 25 | float *rgq = _rgq[threadIdx.x]; 26 | 27 | #pragma unroll 28 | for (int k = 0; k < K_HEAD_DIM; k++){ 29 | gq[k] = 0; 30 | rgq[k] = 0; 31 | } 32 | for (long long int e = e_start; e < e_end; e++) 33 | { 34 | long long int j = edge_ij_e[E+e]; 35 | float g_softmax = gaij_eh[e*H+h]; 36 | const float *key = key_ihk + (j*H+h)*K_HEAD_DIM; 37 | #pragma unroll 38 | for (int k = 0; k < K_HEAD_DIM; k++){ 39 | rgq[k] += g_softmax*key[k]; 40 | } 41 | 42 | if (((e-e_start) % RUNNING_SUM_LEN) == 0 || e == e_end-1){ 43 | #pragma unroll 44 | for (int k = 0; k < K_HEAD_DIM; k++){ 45 | gq[k] += rgq[k]; 46 | rgq[k] = 0; 47 | } 48 | } 49 | } 50 | 51 | gque_ihk += (i*H+h)*K_HEAD_DIM; 52 | #pragma unroll 53 | for (int k = 0; k < K_HEAD_DIM; k++) 54 | gque_ihk[k] = gq[k]; 55 | } -------------------------------------------------------------------------------- /models/kernels/fused_dpa_bwd_q.cu: -------------------------------------------------------------------------------- 1 | 2 | extern "C" __global__ 3 | void fused_dpa_bwd_q( 4 | const float* key_ihk, 5 | const float* gaij_eh, 6 | const long long int* edge_ij_e, 7 | const long long int* e_start_i, 8 | const long long int N, 9 | const long long int H, 10 | const long long int E, 11 | float* gque_ihk 12 | ){ 13 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 14 | if (tid >= N*H) return; 15 | 16 | const long long int K = VPE_DIM; 17 | const long long int i = tid/H; 18 | const long long int h = tid%H; 19 | const long long int e_end = e_start_i[i+1]; 20 | 21 | __shared__ float _gq[THREAD_NUM][K_HEAD_DIM+1]; 22 | float *gq = _gq[threadIdx.x]; 23 | 24 | #pragma unroll 25 | for (int k = 0; k < K_HEAD_DIM; k++) 26 | gq[k] = 0; 27 | 28 | for (long long int e = e_start_i[i]; e < e_end; e++) 29 | { 30 | long long int j = edge_ij_e[E+e]; 31 | float g_softmax = gaij_eh[e*H+h]; 32 | const float *key = key_ihk + (j*H+h)*K_HEAD_DIM; 33 | #pragma unroll 34 | for (int k = 0; k < K_HEAD_DIM; k++){ 35 | gq[k] += g_softmax*key[k]; 36 | } 37 | // gb = go.reshape(s,1,H,K) * p.reshape(s,s,H,1) 38 | // gv = gb.sum(dim=0) 39 | // gval.append(gv) 40 | // gbij.append(gb.reshape(s*s,H,K)) 41 | // gsm = (v.reshape(1,s,H,K) + b.reshape(s,s,H,K) - o.reshape(s,1,H,K))*gb 42 | // ga = gsm.sum(dim=3) 43 | // gq = (ga.reshape(s,s,H,1)*k.reshape(1,s,H,K)).sum(dim=1) 44 | // gk = (ga.reshape(s,s,H,1)*q.reshape(s,1,H,K)).sum(dim=0) 45 | // gaij.append(ga.reshape(s*s,H)) 46 | // gque.append(gq) 47 | // gkey.append(gk) 48 | } 49 | 50 | gque_ihk += (i*H+h)*K_HEAD_DIM; 51 | #pragma unroll 52 | for (int k = 0; k < K_HEAD_DIM; k++) 53 | gque_ihk[k] = gq[k]; 54 | } -------------------------------------------------------------------------------- /docker/pytorch21_cuda121/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/cuda:12.1.1-cudnn8-devel-ubuntu20.04 2 | 3 | # Setup proxies if needed 4 | # ENV http_proxy 'hogehoge' 
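# ('hogehoge' is only a placeholder: to build behind a proxy, replace it with
#  your proxy URL, for example http://proxy.example.com:8080, and uncomment this
#  block; the lines below reuse the same value for the upper-/lower-case and
#  https/ftp variants.)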
5 | # ENV http_proxy $http_proxy 6 | # ENV HTTP_PROXY $http_proxy 7 | # ENV https_proxy $http_proxy 8 | # ENV HTTPS_PROXY $http_proxy 9 | # ENV ftp_proxy $http_proxy 10 | # ENV FTP_PROXY $http_proxy 11 | 12 | # Install basics 13 | ENV DEBIAN_FRONTEND=noninteractive 14 | RUN apt-get update -y \ 15 | && apt-get install -y software-properties-common apt-utils git wget curl ca-certificates bzip2 cmake tree htop bmon iotop g++ \ 16 | && apt-get clean 17 | 18 | # Fix the issue of missing GLIBCXX_3.4.29. 19 | # The 'software-properties-common' is installed above to run add-apt-repository. 20 | RUN add-apt-repository ppa:ubuntu-toolchain-r/test \ 21 | && apt update -y \ 22 | && apt upgrade -y libstdc++6 \ 23 | && apt-get clean 24 | 25 | # Install Miniconda 26 | ARG PYTHON_VERSION=3.11 27 | RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 28 | && chmod +x ~/miniconda.sh \ 29 | && ~/miniconda.sh -b -p /opt/conda \ 30 | && rm ~/miniconda.sh 31 | ENV PATH /opt/conda/bin:$PATH 32 | 33 | RUN conda install -y numpy scipy matplotlib \ 34 | && conda clean -i -t -y 35 | RUN pip install --no-cache-dir pymatgen 36 | 37 | ARG TORCH=2.1.1 38 | ARG TORCH_PYG=2.1.0 39 | ARG CUDA=cu121 40 | 41 | RUN pip install --no-cache-dir \ 42 | torch==${TORCH}+${CUDA} \ 43 | --extra-index-url https://download.pytorch.org/whl/${CUDA} 44 | 45 | RUN pip install --no-cache-dir \ 46 | pyg_lib \ 47 | torch_scatter \ 48 | torch_sparse \ 49 | torch_cluster \ 50 | torch_spline_conv \ 51 | torch_geometric -f https://data.pyg.org/whl/torch-${TORCH_PYG}+${CUDA}.html 52 | 53 | RUN pip install --no-cache-dir \ 54 | pytorch-lightning==2.1.3 \ 55 | cupy-cuda12x pytorch-pfn-extras \ 56 | jarvis-tools \ 57 | tensorboard 58 | 59 | ENV CUDA_DEVICE_ORDER PCI_BUS_ID 60 | -------------------------------------------------------------------------------- /params/latticeformer/default.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 123, 3 | "experiment_name": "code_v5", 4 | "save_path": "result/latticeformer", 5 | "optimizer": "adamw", 6 | "lr": 0.0005, 7 | "lr_sch": "inverse_sqrt_nowarmup", 8 | "sch_params": [4000.0], 9 | "adam_betas": [0.9, 0.98], 10 | "weight_decay": 1e-5, 11 | "clip_norm": 1.0, 12 | "clip_grad": 0.0, 13 | "dropout": 0.0, 14 | "batch_size": 128, 15 | "n_epochs": 500, 16 | "targets": "bandgap", 17 | "loss_func":"L1", 18 | "encoder_name": "latticeformer", 19 | "num_layers": 4, 20 | "lattice_range": 2, 21 | "adaptive_cutoff_sigma": -3.5, 22 | "minimum_range": true, 23 | "model_dim": 128, 24 | "k_dim": 0, 25 | "v_dim": 0, 26 | "head_num": 8, 27 | "ff_dim": 512, 28 | "t_fixup_init": true, 29 | "embedding_dim": [128], 30 | "t_activation": "relu", 31 | "domain": "real", 32 | "gauss_lb_real": 0.5, 33 | "gauss_lb_reci": 0.5, 34 | "gauss_state": "q", 35 | "normalize_gauss": true, 36 | "positive_func": "elu=0.1", 37 | "use_cgcnn_feat": false, 38 | "scale_real": [1.4], 39 | "scale_reci": [2.2], 40 | "value_pe_dist_real": 64, 41 | "value_pe_dist_reci": 0, 42 | "value_pe_wave_real": 0, 43 | "value_pe_wave_reci": 0, 44 | "value_pe_dist_max": -10.0, 45 | "value_pe_width_scale": 1.0, 46 | "value_pe_headed": true, 47 | "value_pe_condproj": "no", 48 | "exclude_self": false, 49 | "norm_func_mode": 0, 50 | "pooling": "avr", 51 | "pre_pooling_op": "no", 52 | "norm_type": "no", 53 | "freeze_bn_epochs": 0, 54 | "train_filter_max": 0, 55 | "train_filter_min": 0, 56 | "train_percent_check": 1.0, 57 | "val_percent_check": 1.0, 58 | 
"model_checkpoint_save_top_k": 1, 59 | "pretrained_model": null, 60 | "num_workers": 0, 61 | "target_set": null, 62 | "normalize_targets": "scale_bias", 63 | "swa_epochs": 50, 64 | "scale_grad": 0.0, 65 | "use_cuda_code": true, 66 | "use_low_memory": true, 67 | "ddp": false, 68 | "reproduciblity_state": 0 69 | } 70 | -------------------------------------------------------------------------------- /models/kernels/minimum_distance.cu: -------------------------------------------------------------------------------- 1 | #include 2 | extern "C" __global__ 3 | 4 | void minimum_distance( 5 | const float* rpos_ij_e, 6 | const float* tvecs_n, 7 | const long long int* batch_i, 8 | const long long int* edge_ij_e, 9 | const long long int E, 10 | const float* rveclens_n, 11 | const double cutoff_radius, 12 | float* dist2_min_e){ 13 | 14 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 15 | if (tid >= E) return; 16 | 17 | const long long int e = tid; 18 | const long long int i = edge_ij_e[e]; 19 | const long long int n = batch_i[i]; 20 | rpos_ij_e += e*3; 21 | const float r_ijx = rpos_ij_e[0]; 22 | const float r_ijy = rpos_ij_e[1]; 23 | const float r_ijz = rpos_ij_e[2]; 24 | tvecs_n += n*9; 25 | const float t1_x = tvecs_n[0]; 26 | const float t1_y = tvecs_n[1]; 27 | const float t1_z = tvecs_n[2]; 28 | const float t2_x = tvecs_n[3]; 29 | const float t2_y = tvecs_n[4]; 30 | const float t2_z = tvecs_n[5]; 31 | const float t3_x = tvecs_n[6]; 32 | const float t3_y = tvecs_n[7]; 33 | const float t3_z = tvecs_n[8]; 34 | 35 | rveclens_n += n*3; 36 | const float rvl1 = rveclens_n[0]; 37 | const float rvl2 = rveclens_n[1]; 38 | const float rvl3 = rveclens_n[2]; 39 | 40 | float cutoff = (float)cutoff_radius; 41 | int R1 = LATTICE_RANGE, R2 = LATTICE_RANGE, R3 = LATTICE_RANGE; 42 | if (cutoff > 0.0f) 43 | { 44 | R1 = ceil((cutoff + 0.01f)*rvl1/(2.0*CUDART_PI_F)); 45 | R2 = ceil((cutoff + 0.01f)*rvl2/(2.0*CUDART_PI_F)); 46 | R3 = ceil((cutoff + 0.01f)*rvl3/(2.0*CUDART_PI_F)); 47 | 48 | #if MINIMUM_RANGE > 0 49 | R1 = max(R1, MINIMUM_RANGE); 50 | R2 = max(R2, MINIMUM_RANGE); 51 | R3 = max(R3, MINIMUM_RANGE); 52 | #endif 53 | } 54 | 55 | float d2min = 1e10; 56 | for (float n1 = -R1; n1 <= R1; n1++) 57 | for (float n2 = -R2; n2 <= R2; n2++) 58 | for (float n3 = -R3; n3 <= R3; n3++) 59 | { 60 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 61 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 62 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 63 | float d2 = dx*dx + dy*dy + dz*dz; 64 | d2min = fminf(d2min, d2); 65 | } 66 | dist2_min_e[e] = d2min; 67 | } -------------------------------------------------------------------------------- /models/cuda_funcs/irregular_mean.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from .kernel_manager import KernelManager 4 | 5 | try: 6 | import cupy as cp 7 | import pytorch_pfn_extras as ppe 8 | from torch.utils.dlpack import to_dlpack, from_dlpack 9 | except: 10 | pass 11 | 12 | def _to_copy(x): 13 | if x is not None: 14 | return cp.from_dlpack(to_dlpack(x)) 15 | return 0 16 | 17 | # This function can be implemented with scatter_add, 18 | # which however does not ensure the reproducibility. 
19 | class IrregularMeanCUDA(torch.autograd.Function): 20 | @staticmethod 21 | def forward(ctx, x_ik:Tensor, batch_i:Tensor, sizes:Tensor): 22 | # x : (points, *) 23 | # batch_i : (points) 24 | N = x_ik.shape[0] 25 | D = x_ik.numel() // N 26 | dev = x_ik.device 27 | kw = {'device': x_ik.device, 'dtype': x_ik.dtype} 28 | 29 | if sizes is None: 30 | B = batch_i.max().item()+1 31 | sizes = torch.zeros(B, dtype=torch.long, device=dev) 32 | sizes.scatter_add_(0, batch_i, torch.ones(batch_i.shape, dtype=torch.long, device=dev)) 33 | else: 34 | B = sizes.shape[0] 35 | 36 | start_n = torch.constant_pad_nd(torch.cumsum(sizes, 0), (1,0)) 37 | 38 | x_ik = x_ik.contiguous().detach() 39 | start_n = start_n.contiguous() 40 | 41 | o_nk = torch.empty((B, ) + x_ik.shape[1:], **kw) 42 | 43 | bsz = min(32, D) 44 | with cp.cuda.Device(dev.index), ppe.cuda.stream(torch.cuda.current_stream(dev)): 45 | assert (sizes <= KernelManager.MAX_SYSTEM_SIZE_POW2).all(), "Increase MAX_SYSTEM_SIZE in KernelManager" 46 | 47 | KernelManager.irregular_mean_fwd(((B*D+bsz-1)//bsz, ), (bsz, ), 48 | ( 49 | _to_copy(x_ik), 50 | _to_copy(start_n), 51 | N, D, 52 | _to_copy(o_nk), 53 | ) 54 | ) 55 | 56 | ctx.save_for_backward(batch_i, sizes) 57 | return o_nk 58 | 59 | @staticmethod 60 | def backward(ctx, go_nk): 61 | batch_i, sizes = ctx.saved_tensors 62 | shape = [go_nk.shape[0]] + [1 for _ in go_nk.shape[1:]] 63 | 64 | # This code matches the implmentation of torch.mean(). 65 | # gx_ik = (go_nk * sizes.reshape(shape).float().reciprocal())[batch_i] 66 | 67 | gx_ik = (go_nk / sizes.reshape(shape).float())[batch_i] 68 | 69 | return gx_ik, None, None 70 | -------------------------------------------------------------------------------- /models/kernels/reci_enc_fwd.cu: -------------------------------------------------------------------------------- 1 | extern "C" __global__ 2 | 3 | void reci_enc_fwd( 4 | const float* a_ik, 5 | const float* kr_base_e, 6 | const float* rvecs_n, 7 | const float* vcell_n, 8 | const long long int* batch_i, 9 | const long long int* edge_ij_e, 10 | const long long int N, 11 | const long long int H, 12 | const long long int E, 13 | float* z_ek, 14 | float* sumexp_ek){ 15 | 16 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 17 | if (tid >= E*H) return; 18 | 19 | const long long int k = tid%H; 20 | const long long int e = tid/H; 21 | const long long int i = edge_ij_e[e]; 22 | const long long int j = edge_ij_e[E+e]; 23 | const long long int n = batch_i[i]; 24 | kr_base_e += e*3; 25 | const float kr_base_1 = kr_base_e[0]; 26 | const float kr_base_2 = kr_base_e[1]; 27 | const float kr_base_3 = kr_base_e[2]; 28 | rvecs_n += n*9; 29 | const float r1_x = rvecs_n[0]; 30 | const float r1_y = rvecs_n[1]; 31 | const float r1_z = rvecs_n[2]; 32 | const float r2_x = rvecs_n[3]; 33 | const float r2_y = rvecs_n[4]; 34 | const float r2_z = rvecs_n[5]; 35 | const float r3_x = rvecs_n[6]; 36 | const float r3_y = rvecs_n[7]; 37 | const float r3_z = rvecs_n[8]; 38 | const float a = a_ik[i*H + k]; 39 | const float vcell = vcell_n[n]; 40 | const int R = LATTICE_RANGE; 41 | 42 | // Unlike real space, normalization using max_logit is not needed since always max_logit = 0. 43 | float sum = 0.0; 44 | float sum_exp = 0.0; 45 | 46 | // Because of symmetry, n1 range can be [0, R] instead of [-R, R] 47 | // by scaling by a factor of 2 for values at n1!=0. 48 | // for (float n1 = -R; n1 <= R; n1++){ const float scale=1; 49 | for (float n1 = 0; n1 <= R; n1++) { float scale = n1==0.0f ? 
1.0f : 2.0f; 50 | for (float n2 = -R; n2 <= R; n2++) 51 | for (float n3 = -R; n3 <= R; n3++) 52 | { 53 | float k1 = r1_x*n1 + r2_x*n2 + r3_x*n3; 54 | float k2 = r1_y*n1 + r2_y*n2 + r3_y*n3; 55 | float k3 = r1_z*n1 + r2_z*n2 + r3_z*n3; 56 | float exp_ak = expf(a*(k1*k1 + k2*k2 + k3*k3)); 57 | sum += exp_ak * cosf(kr_base_1*n1 + kr_base_2*n2 + kr_base_3*n3) * scale; 58 | } 59 | } 60 | // if (n == 0 && sum < 0) { 61 | // printf("%f, %f, (%f %f %f)\n", sum, a, kr_base_1, kr_base_2, kr_base_3); 62 | // } 63 | 64 | const float C_4PI = 12.5663706144; 65 | float log_ci = 1.5f*(logf(-C_4PI*a)) - logf(vcell); 66 | //log_ci = 0; 67 | z_ek[tid] = logf(fmaxf(sum,1e-6)) + log_ci; 68 | sumexp_ek[tid] = sum; 69 | } -------------------------------------------------------------------------------- /models/kernels/reci_enc_fwd_v2.cu: -------------------------------------------------------------------------------- 1 | extern "C" __global__ 2 | 3 | void reci_enc_fwd_v2( 4 | const float* a_ik, 5 | const float* kr_base_e, 6 | const float* rvecs_n, 7 | const float* vcell_n, 8 | const long long int* batch_i, 9 | const long long int* edge_ij_e, 10 | const long long int N, 11 | const long long int H, 12 | const long long int E, 13 | float* z_ek, 14 | float* sumexp_ek){ 15 | 16 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 17 | if (tid >= E*H) return; 18 | 19 | const long long int k = tid%H; 20 | const long long int e = tid/H; 21 | const long long int i = edge_ij_e[e]; 22 | const long long int j = edge_ij_e[E+e]; 23 | const long long int n = batch_i[i]; 24 | kr_base_e += e*3; 25 | const float kr_base_1 = kr_base_e[0]; 26 | const float kr_base_2 = kr_base_e[1]; 27 | const float kr_base_3 = kr_base_e[2]; 28 | rvecs_n += n*9; 29 | const float r1_x = rvecs_n[0]; 30 | const float r1_y = rvecs_n[1]; 31 | const float r1_z = rvecs_n[2]; 32 | const float r2_x = rvecs_n[3]; 33 | const float r2_y = rvecs_n[4]; 34 | const float r2_z = rvecs_n[5]; 35 | const float r3_x = rvecs_n[6]; 36 | const float r3_y = rvecs_n[7]; 37 | const float r3_z = rvecs_n[8]; 38 | const float a = a_ik[i*H + k]; 39 | const float vcell = vcell_n[n]; 40 | const int R = LATTICE_RANGE; 41 | 42 | // Unlike real space, normalization using max_logit is not needed since always max_logit = 0. 43 | float sum = 0.0; 44 | float sum_exp = 0.0; 45 | 46 | // Because of symmetry, n1 range can be [0, R] instead of [-R, R] 47 | // by scaling by a factor of 2 for values at n1!=0. 48 | // for (float n1 = -R; n1 <= R; n1++){ const float scale=1; 49 | for (float n1 = 0, s1=0; n1 <= R; n1++, sum+=s1, s1=0) { float scale = n1==0.0f ? 
1.0f : 2.0f; 50 | for (float n2 = -R, s2=0; n2 <= R; n2++, s1 +=s2, s2=0) 51 | for (float n3 = -R; n3 <= R; n3++) 52 | { 53 | float k1 = r1_x*n1 + r2_x*n2 + r3_x*n3; 54 | float k2 = r1_y*n1 + r2_y*n2 + r3_y*n3; 55 | float k3 = r1_z*n1 + r2_z*n2 + r3_z*n3; 56 | float exp_ak = expf(a*(k1*k1 + k2*k2 + k3*k3)); 57 | s2 += exp_ak * cosf(kr_base_1*n1 + kr_base_2*n2 + kr_base_3*n3) * scale; 58 | } 59 | } 60 | // if (n == 0 && sum < 0) { 61 | // printf("%f, %f, (%f %f %f)\n", sum, a, kr_base_1, kr_base_2, kr_base_3); 62 | // } 63 | 64 | const float C_4PI = 12.5663706144; 65 | float log_ci = 1.5f*(logf(-C_4PI*a)) - logf(vcell); 66 | //log_ci = 0; 67 | z_ek[tid] = logf(fmaxf(sum,1e-6)) + log_ci; 68 | sumexp_ek[tid] = sum; 69 | } -------------------------------------------------------------------------------- /models/kernels/reci_enc_bwd.cu: -------------------------------------------------------------------------------- 1 | extern "C" __global__ 2 | 3 | void reci_enc_bwd( 4 | const float* a_ik, 5 | const float* kr_base_e, 6 | const float* rvecs_n, 7 | const float* vcell_n, 8 | const long long int* batch_i, 9 | const long long int* edge_ij_e, 10 | const long long int* e_start_i, 11 | const float* z_ek, 12 | const float* gz_ek, 13 | const float* sumexp_ek, 14 | const long long int N, 15 | const long long int H, 16 | const long long int E, 17 | float* ga_ik){ 18 | 19 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 20 | if (tid >= N*H) return; 21 | 22 | const long long int k = tid%H; 23 | const long long int i = tid/H; 24 | const long long int n = batch_i[i]; 25 | rvecs_n += n*9; 26 | const float r1_x = rvecs_n[0]; 27 | const float r1_y = rvecs_n[1]; 28 | const float r1_z = rvecs_n[2]; 29 | const float r2_x = rvecs_n[3]; 30 | const float r2_y = rvecs_n[4]; 31 | const float r2_z = rvecs_n[5]; 32 | const float r3_x = rvecs_n[6]; 33 | const float r3_y = rvecs_n[7]; 34 | const float r3_z = rvecs_n[8]; 35 | const float a = a_ik[i*H + k]; 36 | const int R = LATTICE_RANGE; 37 | const long long int e_start = e_start_i[i]; 38 | const long long int e_end = e_start_i[i+1]; 39 | 40 | float sum = 0; 41 | for (long long int e = e_start; e < e_end; e++) 42 | { 43 | const long long int j = edge_ij_e[E+e]; 44 | const float* kr_base = &kr_base_e[e*3]; 45 | const float kr_base_1 = kr_base[0]; 46 | const float kr_base_2 = kr_base[1]; 47 | const float kr_base_3 = kr_base[2]; 48 | const long long int ek = e*H+k; 49 | const float z = z_ek[ek]; 50 | const float sum_exp = sumexp_ek[ek]; 51 | const float gz = gz_ek[ek]; 52 | 53 | float s = 0; 54 | // Because of symmetry, n1 range can be [0, R] instead of [-R, R] 55 | // by scaling values at n1 != 0 by a factor of 2. 56 | //for (float n1 = -R; n1 <= R; n1++) { const float scale = 1.0f; 57 | for (float n1 = 0; n1 <= R; n1++) { float scale = n1==0.0f ? 1.0f : 2.0f; 58 | for (float n2 = -R; n2 <= R; n2++) 59 | for (float n3 = -R; n3 <= R; n3++) 60 | { 61 | float k1 = r1_x*n1 + r2_x*n2 + r3_x*n3; 62 | float k2 = r1_y*n1 + r2_y*n2 + r3_y*n3; 63 | float k3 = r1_z*n1 + r2_z*n2 + r3_z*n3; 64 | float kk = (k1*k1 + k2*k2 + k3*k3); 65 | s += kk * expf(a*kk) * cosf(kr_base_1*n1 + kr_base_2*n2 + kr_base_3*n3) * scale; 66 | } 67 | } 68 | //sum += s*gz/fmaxf(sum_exp, 1e-6); 69 | sum += s*gz*(sum_exp < 1e-6? 
0.0f : 1.0f/sum_exp); 70 | } 71 | ga_ik[tid] = sum; 72 | } -------------------------------------------------------------------------------- /data/download_megnet_elastic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import pickle 4 | from tqdm import tqdm 5 | from jarvis.core.atoms import Atoms 6 | 7 | # Prepare bulk and shear megnet dataset from the Matformer paper (NeurIPS 22): 8 | # https://github.com/YKQ98/Matformer/tree/569a7e9331b2acacc184fab38f5f6085e46a9881 9 | urls = [ 10 | 'https://figshare.com/ndownloader/files/40258705', 11 | 'https://figshare.com/ndownloader/files/40258675', 12 | 'https://figshare.com/ndownloader/files/40258666', 13 | 'https://figshare.com/ndownloader/files/40258681', 14 | 'https://figshare.com/ndownloader/files/40258684', 15 | 'https://figshare.com/ndownloader/files/40258678' 16 | ] 17 | 18 | elems = { 19 | 'bulk': { 20 | 'id': ('material_id', str), 21 | 'bulk modulus': ('bulk_modulus', float), 22 | 'structure': ('structure', object) 23 | }, 24 | 'shear': { 25 | 'id': ('material_id', str), 26 | 'shear modulus': ('shear_modulus', float), 27 | 'structure': ('structure', object) 28 | }, 29 | } 30 | 31 | count = 0 32 | for target in ["bulk", "shear"]: 33 | outdir = f'jarvis__megnet-{target}' 34 | os.makedirs(outdir, exist_ok=True) 35 | 36 | all = [] 37 | for split in ["train", "val", "test"]: 38 | url = urls[count] 39 | filename = f'{outdir}/{target}_megnet_{split}.pkl' 40 | 41 | # under proxy, use verify=False to avoid an SSL error. 42 | urlData = requests.get(url, verify=False).content 43 | with open(filename ,mode='wb') as f: 44 | f.write(urlData) 45 | 46 | with open(filename, mode="rb") as fp: 47 | data = pickle.load(fp) 48 | 49 | new_data = [] 50 | for x in tqdm(data): 51 | atoms = Atoms( 52 | lattice_mat=x['atoms']['lattice_mat'], 53 | coords=x['atoms']['coords'], 54 | elements=x['atoms']['elements'], 55 | cartesian=x['atoms']['cartesian'], 56 | ) 57 | x['structure'] = atoms.pymatgen_converter() 58 | 59 | new_x = {} 60 | for key, (newkey, vtype) in elems[target].items(): 61 | val = x[key] 62 | 63 | if vtype == float and type(val) != float: 64 | val = float(val) 65 | elif vtype == int and type(val) != int: 66 | val = int(val) 67 | elif vtype == str and type(val) != str: 68 | val = str(val) 69 | 70 | new_x[newkey] = val 71 | new_data.append(new_x) 72 | 73 | os.makedirs(f'{outdir}/{split}/raw', exist_ok=True) 74 | with open(f'{outdir}/{split}/raw/raw_data.pkl', mode="wb") as fp: 75 | pickle.dump(new_data, fp) 76 | print(new_data[0]) 77 | 78 | count += 1 -------------------------------------------------------------------------------- /models/kernels/reci_enc_bwd_v2.cu: -------------------------------------------------------------------------------- 1 | extern "C" __global__ 2 | 3 | void reci_enc_bwd_v2( 4 | const float* a_ik, 5 | const float* kr_base_e, 6 | const float* rvecs_n, 7 | const float* vcell_n, 8 | const long long int* batch_i, 9 | const long long int* edge_ij_e, 10 | const long long int* e_start_i, 11 | const float* z_ek, 12 | const float* gz_ek, 13 | const float* sumexp_ek, 14 | const long long int N, 15 | const long long int H, 16 | const long long int E, 17 | float* ga_ik){ 18 | 19 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 20 | if (tid >= N*H) return; 21 | 22 | const long long int k = tid%H; 23 | const long long int i = tid/H; 24 | const long long int n = batch_i[i]; 25 | rvecs_n += n*9; 26 | const float r1_x = rvecs_n[0]; 27 | 
const float r1_y = rvecs_n[1]; 28 | const float r1_z = rvecs_n[2]; 29 | const float r2_x = rvecs_n[3]; 30 | const float r2_y = rvecs_n[4]; 31 | const float r2_z = rvecs_n[5]; 32 | const float r3_x = rvecs_n[6]; 33 | const float r3_y = rvecs_n[7]; 34 | const float r3_z = rvecs_n[8]; 35 | const float a = a_ik[i*H + k]; 36 | const int R = LATTICE_RANGE; 37 | const long long int e_start = e_start_i[i]; 38 | const long long int e_end = e_start_i[i+1]; 39 | 40 | float sum = 0; 41 | for (long long int e = e_start; e < e_end; e++) 42 | { 43 | const long long int j = edge_ij_e[E+e]; 44 | const float* kr_base = &kr_base_e[e*3]; 45 | const float kr_base_1 = kr_base[0]; 46 | const float kr_base_2 = kr_base[1]; 47 | const float kr_base_3 = kr_base[2]; 48 | const long long int ek = e*H+k; 49 | const float z = z_ek[ek]; 50 | const float sum_exp = sumexp_ek[ek]; 51 | const float gz = gz_ek[ek]; 52 | 53 | float s = 0; 54 | // Because of symmetry, n1 range can be [0, R] instead of [-R, R] 55 | // by scaling values at n1 != 0 by a factor of 2. 56 | //for (float n1 = -R; n1 <= R; n1++) { const float scale = 1.0f; 57 | for (float n1 = 0, s1=0; n1 <= R; n1++, s +=s1, s1=0) { float scale = n1==0.0f ? 1.0f : 2.0f; 58 | for (float n2 = -R, s2=0; n2 <= R; n2++, s1+=s2, s2=0) 59 | for (float n3 = -R; n3 <= R; n3++) 60 | { 61 | float k1 = r1_x*n1 + r2_x*n2 + r3_x*n3; 62 | float k2 = r1_y*n1 + r2_y*n2 + r3_y*n3; 63 | float k3 = r1_z*n1 + r2_z*n2 + r3_z*n3; 64 | float kk = (k1*k1 + k2*k2 + k3*k3); 65 | s2 += kk * expf(a*kk) * cosf(kr_base_1*n1 + kr_base_2*n2 + kr_base_3*n3) * scale; 66 | } 67 | } 68 | //sum += s*gz/fmaxf(sum_exp, 1e-6); 69 | sum += s*gz*(sum_exp < 1e-6? 0.0f : 1.0f/sum_exp); 70 | } 71 | ga_ik[tid] = sum; 72 | } -------------------------------------------------------------------------------- /models/kernels/irregular_mean_fwd.cu: -------------------------------------------------------------------------------- 1 | #include "models/kernels/pairwise_sum.cuh" 2 | 3 | extern "C" __global__ 4 | void irregular_mean_fwd( 5 | const float* src_n, 6 | const long long int* start_n, 7 | const long long int B, 8 | const long long int D, 9 | float* dst_n 10 | ){ 11 | const unsigned long long int tid = (unsigned long long int)blockDim.x * blockIdx.x + threadIdx.x; 12 | if (tid >= B*D) return; 13 | const unsigned int n = tid/D; 14 | const unsigned int k = tid%D; 15 | const unsigned long long int start = start_n[n]; 16 | const unsigned long long int end = start_n[n+1]; 17 | int len = end - start; 18 | 19 | src_n += start*D + k; 20 | float sum = 0; 21 | 22 | // This code matches torch.sum() when batch_num = 1. 
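// Pairwise (tree) summation recursively splits the sequence in half and adds the
// halves, so the rounding error grows roughly with log2(len) instead of len, and
// the accumulation order is fixed by the template depth rather than by how the
// partial sums happen to be scheduled. The switch further below picks the
// pairwise_sum instantiation whose power-of-two bound covers len (e.g. a crystal
// with len = 25 atoms uses case 5, bound 32); entries beyond len are padded with
// zeros inside pairwise_sum (see pairwise_sum.cuh).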
23 | // constexpr int K = 4; 24 | // int batch_size = 8192/max(1<(src_n, D, len, data); break; 64 | 65 | case 3:// 8 66 | sum = pairwise_sum<3, INTER_NUM>(src_n, D, len, data); break; 67 | 68 | case 4:// 16 69 | sum = pairwise_sum<4, INTER_NUM>(src_n, D, len, data); break; 70 | 71 | case 5:// 32 72 | sum = pairwise_sum<5, INTER_NUM>(src_n, D, len, data); break; 73 | 74 | case 6:// 64 75 | sum = pairwise_sum<6, INTER_NUM>(src_n, D, len, data); break; 76 | 77 | case 7:// 128 78 | sum = pairwise_sum<7, INTER_NUM>(src_n, D, len, data); break; 79 | 80 | case 8:// 256 81 | sum = pairwise_sum<8, INTER_NUM>(src_n, D, len, data); break; 82 | 83 | case 9:// 512 84 | sum = pairwise_sum<9, INTER_NUM>(src_n, D, len, data); break; 85 | 86 | case 10:// 1024 87 | sum = pairwise_sum<10, INTER_NUM>(src_n, D, len, data); break; 88 | } 89 | 90 | dst_n[n*D+k] = sum / (float)len; 91 | } -------------------------------------------------------------------------------- /models/kernels/fused_dpa_fwd_v2.cu: -------------------------------------------------------------------------------- 1 | #include "models/kernels/reduce_kernel_utils.cuh" 2 | 3 | __global__ __device__ void fused_dpa_fwd_thread( 4 | const float* que_k, 5 | const float* key_ihk, 6 | const float* val_ihk, 7 | const float* aij_eh, 8 | const float* bij_ehk, 9 | const long long int* edge_j_e, 10 | const long long int e_start, 11 | const long long int e_end, 12 | const long long int H, 13 | float* prob_eh, 14 | float* out_k 15 | ) { 16 | long long int e = e_start + threadIdx.x; 17 | bool isValid = e < e_end; 18 | 19 | float attn; 20 | long long int j = 0; 21 | if (isValid){ 22 | j = edge_j_e[e]; 23 | key_ihk += (j*H)*K_HEAD_DIM; 24 | 25 | attn = 0; 26 | #pragma unroll 27 | for (int k = 0; k < K_HEAD_DIM; k++){ 28 | attn += que_k[k]*key_ihk[k]; 29 | } 30 | if (aij_eh != NULL) attn += aij_eh[e*H]; 31 | } else { 32 | attn = -1e20; 33 | } 34 | __syncthreads(); 35 | 36 | float max_attn = blockReduceMax(attn); 37 | 38 | attn = exp(attn - max_attn); 39 | float sum = blockReduceSum(attn); 40 | 41 | attn /= sum; 42 | if (isValid) 43 | prob_eh[e*H] = attn; 44 | 45 | if (bij_ehk != NULL) { 46 | val_ihk += (j*H)*V_HEAD_DIM; 47 | bij_ehk += (e*H)*V_HEAD_DIM; 48 | 49 | #pragma unroll 50 | for (int k = 0; k < V_HEAD_DIM; k++){ 51 | float output = isValid ? (val_ihk[k]+bij_ehk[k])*attn : 0.0f; 52 | __syncthreads(); 53 | output = blockReduceSum(output); 54 | if (threadIdx.x == 0){ 55 | out_k[k] = output; 56 | } 57 | } 58 | } else { 59 | val_ihk += (j*H)*V_HEAD_DIM; 60 | 61 | #pragma unroll 62 | for (int k = 0; k < V_HEAD_DIM; k++){ 63 | float output = isValid ? 
(val_ihk[k])*attn : 0.0f; 64 | __syncthreads(); 65 | output = blockReduceSum(output); 66 | if (threadIdx.x == 0) 67 | out_k[k] = output; 68 | } 69 | } 70 | } 71 | 72 | extern "C" __global__ 73 | void fused_dpa_fwd_v2( 74 | const float* que_ihk, 75 | const float* key_ihk, 76 | const float* val_ihk, 77 | const float* aij_eh, 78 | const float* bij_ehk, 79 | const long long int* edge_ij_e, 80 | const long long int* e_start_i, 81 | const long long int N, 82 | const long long int H, 83 | const long long int E, 84 | float* prob_eh, 85 | float* out_ihk 86 | ){ 87 | const long long int tid = (long long int)blockIdx.x*blockDim.x + threadIdx.x; 88 | if (tid >= N*H) return; 89 | 90 | const long long int i = tid / H; 91 | const long long int h = tid % H; 92 | const long long int e_start = e_start_i[i]; 93 | const long long int e_end = e_start_i[i+1]; 94 | 95 | fused_dpa_fwd_thread<<< 1, ((e_end-e_start+31)/32)*32 >>>( 96 | que_ihk + tid*K_HEAD_DIM, 97 | key_ihk + h*K_HEAD_DIM, 98 | val_ihk + h*V_HEAD_DIM, 99 | aij_eh ? aij_eh + h : aij_eh, 100 | bij_ehk ? bij_ehk+ h*V_HEAD_DIM : bij_ehk, 101 | edge_ij_e + E, 102 | e_start, 103 | e_end, 104 | H, 105 | prob_eh + h, 106 | out_ihk + tid*V_HEAD_DIM 107 | ); 108 | 109 | } -------------------------------------------------------------------------------- /models/latticeformer_params.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar, Type, List 2 | import copy 3 | Entity = TypeVar('Entity', bound='LatticeformerParams') 4 | 5 | class LatticeformerParams: 6 | def __init__(self, 7 | domain:str="real", 8 | lattice_range:int=4, 9 | minimum_range:bool=True, 10 | adaptive_cutoff_sigma:float=-3.5, 11 | gauss_lb_real:float=0.5, 12 | gauss_lb_reci:float=0.5, 13 | scale_real:List[float]=[1.4], 14 | scale_reci:List[float]=[2.2], 15 | normalize_gauss:bool=True, 16 | value_pe_dist_real:int=64, 17 | value_pe_wave_real:int=0, 18 | value_pe_dist_reci:int=0, 19 | value_pe_wave_reci:int=0, 20 | value_pe_headed:bool=True, 21 | value_pe_condproj:str="no", 22 | positive_func:str='elu=0.1', 23 | exclude_self:bool=False, 24 | layer_index:int=-1, 25 | norm_func_mode:int=0, 26 | value_pe_dist_max:float=-10.0, 27 | value_pe_width_scale:float=1.0, 28 | gauss_state:str="q", 29 | use_low_memory:bool=False, 30 | ) -> None: 31 | 32 | self.layer_index = layer_index 33 | self.domain = domain 34 | self.lattice_range = lattice_range 35 | self.minimum_range = minimum_range 36 | self.adaptive_cutoff_sigma = adaptive_cutoff_sigma 37 | self.gauss_lb_real = gauss_lb_real 38 | self.gauss_lb_reci = gauss_lb_reci 39 | self.scale_real = scale_real 40 | self.scale_reci = scale_reci 41 | self.normalize_gauss = normalize_gauss 42 | self.value_pe_dist_real = value_pe_dist_real 43 | self.value_pe_wave_real = value_pe_wave_real 44 | self.value_pe_dist_reci = value_pe_dist_reci 45 | self.value_pe_wave_reci = value_pe_wave_reci 46 | self.value_pe_headed = value_pe_headed 47 | self.value_pe_condproj = value_pe_condproj 48 | self.positive_func = positive_func 49 | self.exclude_self = exclude_self 50 | self.norm_func_mode = norm_func_mode 51 | self.value_pe_dist_max = value_pe_dist_max 52 | self.value_pe_width_scale = value_pe_width_scale 53 | self.gauss_state = gauss_state 54 | self.use_low_memory = use_low_memory 55 | 56 | def parseFromArgs(self, args): 57 | for key in self.__dict__: 58 | self.__dict__[key] = getattr(args, key, self.__dict__[key]) 59 | print("Parsed LatticeformerParams:") 60 | print(self.__dict__) 61 | 62 | def 
getLayerParameters(self, layer_index) -> Entity: 63 | if self.domain in ("real", "reci", "multihead"): 64 | domain = self.domain 65 | else: 66 | domains = self.domain.split('-') 67 | domain = domains[layer_index % len(domains)] 68 | 69 | scale_real = self.scale_real 70 | scale_reci = self.scale_reci 71 | if isinstance(scale_real, (list,tuple)): 72 | scale_real = scale_real[layer_index % len(scale_real)] 73 | if isinstance(scale_reci, (list,tuple)): 74 | scale_reci = scale_reci[layer_index % len(scale_reci)] 75 | 76 | params = copy.deepcopy(self) 77 | params.domain = domain 78 | params.scale_real = scale_real 79 | params.scale_reci = scale_reci 80 | params.layer_index = layer_index 81 | return params 82 | -------------------------------------------------------------------------------- /dataloaders/dataset_latticeformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch_geometric.data import Data 4 | from pymatgen.symmetry.analyzer import SpacegroupAnalyzer 5 | 6 | from dataloaders.common import MultimodalDatasetMP, RegressionDatasetMP 7 | from dataloaders.common import generate_site_species_vector 8 | from .common import CellFormat 9 | 10 | 11 | def make_data(material, ATOM_NUM_UPPER, cell_format:CellFormat): 12 | if "final_structure" in material: 13 | structure = material['final_structure'] 14 | elif "structure" in material: 15 | structure = material['structure'] 16 | else: 17 | raise AttributeError("Material has no structure!") 18 | if cell_format == CellFormat.CONVENTIONAL: 19 | structure = SpacegroupAnalyzer(structure).get_conventional_standard_structure() 20 | elif cell_format == CellFormat.PRIMITIVE: 21 | structure = SpacegroupAnalyzer(structure).get_primitive_standard_structure() 22 | # # assert len(structure.cart_coords) == len(primitive.cart_coords), f"{len(structure.cart_coords)}, {len(primitive.cart_coords)}" 23 | 24 | 25 | if "material_id" in material: 26 | id = material['material_id'] 27 | elif "file_id" in material: 28 | id = material['file_id'] 29 | else: 30 | id = material['id'] 31 | 32 | atom_pos = torch.tensor(structure.cart_coords, dtype=torch.float) 33 | atom_fea = generate_site_species_vector(structure, ATOM_NUM_UPPER) 34 | data = Data(x=atom_fea, y=None, pos=atom_pos) 35 | data.trans_vec = torch.tensor(structure.lattice.matrix, dtype=torch.float)[None] 36 | data.material_id = id 37 | data.sizes = torch.tensor([atom_pos.shape[0]], dtype=torch.long) 38 | return data 39 | 40 | class MultimodalDatasetMP_Latticeformer(MultimodalDatasetMP): 41 | def __init__(self, params, target_split, target_set=None, post_filter=None): 42 | self.use_primitive = params.use_primitive if hasattr(params, 'use_primitive') else True 43 | 44 | super(MultimodalDatasetMP_Latticeformer, self).__init__(target_split, target_set, post_filter) 45 | 46 | @property 47 | def processed_file_names(self): 48 | if self.use_primitive: 49 | return 'processed_data_latticeformer.pt' 50 | else: 51 | return 'processed_data_convcell_latticeformer.pt' 52 | 53 | def process_input(self, material): 54 | return make_data(material, self.ATOM_NUM_UPPER, self.use_primitive) 55 | 56 | # In torch_geometric.data.dataset.Dataset, these functions are checked 57 | # if exist in self.__class__.__dict__.keys(). But __dict__ does not capture 58 | # the inherited functions. 
So, here explicitly claim the process and download functions 59 | def process(self): 60 | super().process() 61 | def download(self): 62 | super().download() 63 | 64 | 65 | class RegressionDatasetMP_Latticeformer(RegressionDatasetMP): 66 | def __init__(self, target_split, target_set=None, cell_format:CellFormat=CellFormat.PRIMITIVE, post_filter=None): 67 | self.model_name = "latticeformer" 68 | super(RegressionDatasetMP_Latticeformer, self).__init__(target_split, target_set, cell_format, post_filter) 69 | 70 | def process_input(self, material): 71 | return make_data(material, self.ATOM_NUM_UPPER, self.cell_format) 72 | 73 | # In torch_geometric.data.dataset.Dataset, these functions are checked 74 | # if exist in self.__class__.__dict__.keys(). But __dict__ does not capture 75 | # the inherited functions. So, here explicitly claim the process and download functions 76 | def process(self): 77 | super().process() 78 | def download(self): 79 | super().download() 80 | -------------------------------------------------------------------------------- /models/cuda_funcs/reci_space_enc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from .kernel_manager import KernelManager 4 | 5 | try: 6 | import cupy as cp 7 | import pytorch_pfn_extras as ppe 8 | from torch.utils.dlpack import to_dlpack, from_dlpack 9 | except: 10 | pass 11 | 12 | def _to_copy(x): 13 | if x is not None: 14 | return cp.from_dlpack(to_dlpack(x)) 15 | return 0 16 | 17 | class ReciPeriodicEncodingFuncCUDA(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, a_ik, kr_base_e, rvecs_n, vcell_n, batch_i, edge_ij_e): 20 | # a_ik : (points, heads) 21 | # kr_base_e : (edges, 3) 22 | # rvecs_n : (batch, 3, 3) 23 | # vcell_n : (batch) 24 | # batch_i : (points) 25 | # edge_ij_e : (2, edges) 26 | # z_ijk = log( sum_n exp( a_ik*|pj + t1*n1+t2*n2+t3*n3 - pi|^2 ) ) 27 | # : (edges, heads) 28 | N, H = a_ik.shape 29 | E = edge_ij_e.shape[1] 30 | kw = {'device': a_ik.device, 'dtype': a_ik.dtype} 31 | 32 | a_ik = a_ik.contiguous().detach() 33 | kr_base_e = kr_base_e.contiguous() 34 | rvecs_n = rvecs_n.contiguous() 35 | vcell_n = vcell_n.contiguous() 36 | batch_i = batch_i.contiguous() 37 | edge_ij_e = edge_ij_e.contiguous() 38 | 39 | z_ek = torch.empty((E, H), **kw) 40 | sumexp_ek = torch.empty((E, H), **kw) 41 | bsz = H 42 | dev = a_ik.device 43 | 44 | with cp.cuda.Device(dev.index), ppe.cuda.stream(torch.cuda.current_stream(dev)): 45 | from .. 
import global_config as config 46 | kernel = KernelManager.reci_enc_fwd_v2 if config.REPRODUCIBLITY_STATE >= 4 \ 47 | else KernelManager.reci_enc_fwd 48 | kernel(((E*H+bsz-1)//bsz, ), (bsz, ), 49 | ( 50 | _to_copy(a_ik), 51 | _to_copy(kr_base_e), 52 | _to_copy(rvecs_n), 53 | _to_copy(vcell_n), 54 | _to_copy(batch_i), 55 | _to_copy(edge_ij_e), 56 | N, H, E, 57 | _to_copy(z_ek), 58 | _to_copy(sumexp_ek), 59 | ) 60 | ) 61 | 62 | ctx.save_for_backward(a_ik, kr_base_e, rvecs_n, vcell_n, batch_i, edge_ij_e, z_ek, sumexp_ek) 63 | return z_ek 64 | 65 | @staticmethod 66 | def backward(ctx, gz_ek): 67 | a_ik, kr_base_e, rvecs_n, vcell_n, batch_i, edge_ij_e, z_ek, sumexp_ek = ctx.saved_tensors 68 | N, H = a_ik.shape 69 | E = edge_ij_e.shape[1] 70 | 71 | e_start_i = torch.zeros(N+1, dtype=batch_i.dtype, device=batch_i.device) 72 | e_start_i.scatter_add_(0, edge_ij_e[0]+1, torch.ones_like(edge_ij_e[0])) 73 | e_start_i = e_start_i.cumsum(0) 74 | 75 | ga_ik = torch.empty_like(a_ik) 76 | bsz = H 77 | dev = a_ik.device 78 | 79 | with cp.cuda.Device(dev.index), ppe.cuda.stream(torch.cuda.current_stream(dev)): 80 | from .. import global_config as config 81 | kernel = KernelManager.reci_enc_bwd_v2 if config.REPRODUCIBLITY_STATE >= 4 \ 82 | else KernelManager.reci_enc_bwd 83 | kernel(((N*H+bsz-1)//bsz, ), (bsz, ), 84 | ( 85 | _to_copy(a_ik), 86 | _to_copy(kr_base_e), 87 | _to_copy(rvecs_n), 88 | _to_copy(vcell_n), 89 | _to_copy(batch_i), 90 | _to_copy(edge_ij_e), 91 | _to_copy(e_start_i), 92 | _to_copy(z_ek.detach()), 93 | _to_copy(gz_ek.detach().contiguous()), 94 | _to_copy(sumexp_ek.detach()), 95 | N, H, E, 96 | _to_copy(ga_ik), 97 | ) 98 | ) 99 | return ga_ik, None, None, None, None, None 100 | -------------------------------------------------------------------------------- /models/kernels/pairwise_sum.cuh: -------------------------------------------------------------------------------- 1 | 2 | __forceinline__ __device__ int log2_ceil(int value) { 3 | int log2_value = 0; 4 | while ((1 << log2_value) < value) ++log2_value; 5 | return log2_value; 6 | } 7 | 8 | 9 | #if 0 10 | template 11 | __forceinline__ __device__ float fixed_length_pairwise_sum(float *input) 12 | { 13 | #pragma unroll 14 | for (int i = num/2; i > 0; i /= 2 ){ 15 | #pragma unroll 16 | for (int j = 0; j < i; j++){ 17 | input[j] += input[j+i]; 18 | } 19 | } 20 | return input[0]; 21 | } 22 | template <> __forceinline__ float fixed_length_pairwise_sum<1>(float *input) { 23 | return input[0]; 24 | } 25 | template <> __forceinline__ float fixed_length_pairwise_sum<2>(float *input) { 26 | return input[0] + input[1]; 27 | } 28 | template <> __forceinline__ float fixed_length_pairwise_sum<4>(float *input) { 29 | return (input[0] + input[2]) + (input[1] + input[3]); 30 | } 31 | template <> __forceinline__ float fixed_length_pairwise_sum<8>(float *input) { 32 | return ((input[0] + input[4]) + (input[2] + input[6])) \ 33 | + ((input[1] + input[5]) + (input[3] + input[7])); 34 | } 35 | 36 | #else 37 | template 38 | __forceinline__ __device__ float fixed_length_pairwise_sum(float *input) 39 | { 40 | // #pragma unroll 41 | // for (int i = 1; i <= num/2; i *= 2 ){ 42 | // #pragma unroll 43 | // for (int j = 0; j < num; j += 2*i){ 44 | // input[j] += input[j+i]; 45 | // } 46 | // } 47 | // return input[0]; 48 | return fixed_length_pairwise_sum(input) + fixed_length_pairwise_sum(input+num/2); 49 | } 50 | 51 | template <> __forceinline__ float fixed_length_pairwise_sum<1>(float *input) { 52 | return input[0]; 53 | } 54 | template <> __forceinline__ float 
fixed_length_pairwise_sum<2>(float *input) { 55 | return input[0] + input[1]; 56 | } 57 | template <> __forceinline__ float fixed_length_pairwise_sum<3>(float *input) { 58 | return (input[0] + input[1]) + (input[2]); 59 | } 60 | template <> __forceinline__ float fixed_length_pairwise_sum<4>(float *input) { 61 | return (input[0] + input[1]) + (input[2] + input[3]); 62 | } 63 | template <> __forceinline__ float fixed_length_pairwise_sum<5>(float *input) { 64 | return fixed_length_pairwise_sum<4>(input) + input[4]; 65 | } 66 | template <> __forceinline__ float fixed_length_pairwise_sum<7>(float *input) { 67 | return fixed_length_pairwise_sum<4>(input) + ((input[4] + input[5]) + input[6]); 68 | } 69 | template <> __forceinline__ float fixed_length_pairwise_sum<8>(float *input) { 70 | return ((input[0] + input[1]) + (input[2] + input[3])) \ 71 | + ((input[4] + input[5]) + (input[6] + input[7])); 72 | } 73 | template <> __forceinline__ float fixed_length_pairwise_sum<9>(float *input) { 74 | return fixed_length_pairwise_sum<8>(input) + input[8]; 75 | } 76 | #endif 77 | 78 | template <int log2_elements, int iter_num> 79 | float pairwise_sum(const float *src, int stride, int num, float *buff) 80 | { 81 | constexpr int next_power_of_two = 1 << log2_elements; 82 | constexpr int batch_num = (next_power_of_two>iter_num) ? next_power_of_two/iter_num : 1; 83 | float mini_batch[iter_num]; 84 | 85 | #pragma unroll 86 | for (int i = 0; i < batch_num; i++){ 87 | #if 1 88 | #pragma unroll 89 | for (int j = 0; j < iter_num; j++){ 90 | int index = i*iter_num + j; 91 | if (index < num){ 92 | mini_batch[j] = src[index*stride]; 93 | } else { 94 | mini_batch[j] = 0; 95 | } 96 | } 97 | buff[i] = fixed_length_pairwise_sum<iter_num>(mini_batch); 98 | 99 | #else 100 | #pragma unroll 101 | buff[i] = 0; 102 | for (int j = 0; j < iter_num; j++){ 103 | int index = i*iter_num + j; 104 | if (index < num){ 105 | buff[i] += src[index*stride]; 106 | } 107 | } 108 | #endif 109 | } 110 | return fixed_length_pairwise_sum<batch_num>(buff); 111 | } 112 | 113 | -------------------------------------------------------------------------------- /models/kernels/fused_dpa_fwd.cu: -------------------------------------------------------------------------------- 1 | #include "models/kernels/pairwise_sum.cuh" 2 | 3 | extern "C" __global__ 4 | void fused_dpa_fwd( 5 | const float* que_ihk, 6 | const float* key_ihk, 7 | const float* val_ihk, 8 | const float* aij_eh, 9 | const float* bij_ehk, 10 | const long long int* edge_ij_e, 11 | const long long int* e_start_i, 12 | const long long int N, 13 | const long long int H, 14 | const long long int E, 15 | float* prob_eh, 16 | float* out_ihk 17 | ){ 18 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 19 | if (tid >= N*H) return; 20 | const long long int i = tid/H; 21 | const long long int h = tid%H; 22 | const long long int e_start = e_start_i[i]; 23 | const long long int e_end = e_start_i[i+1]; 24 | 25 | #if 0 26 | // Read q from global mem. 27 | const float* que_k = &que_ihk[tid*K_HEAD_DIM]; 28 | #else 29 | // Load q onto shared mem.
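    // Each thread handles one (node i, head h) pair, so its query row is reused for every
    // neighbor j in the loops below; caching it in shared memory avoids repeated global-memory reads.
    // (Editorial note, not in the original source: the +1 padding on the inner dimension of the
    //  shared arrays is presumably there to avoid shared-memory bank conflicts between threads.)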
30 | __shared__ float _que_k[THREAD_NUM][K_HEAD_DIM+1]; 31 | que_ihk += tid*K_HEAD_DIM; 32 | float* que_k = _que_k[threadIdx.x]; 33 | #pragma unroll 34 | for (int k = 0; k < K_HEAD_DIM; k++){ 35 | que_k[k] = que_ihk[k]; 36 | } 37 | #endif 38 | 39 | __shared__ float _attns[THREAD_NUM][MAX_SYSTEM_SIZE+1]; 40 | __shared__ float _output[THREAD_NUM][V_HEAD_DIM+1]; 41 | float *attns = _attns[threadIdx.x]; 42 | float *output = _output[threadIdx.x]; 43 | 44 | float max_attn = -1e20; 45 | int e_count = 0; 46 | for (long long int e = e_start; e < e_end; e++) 47 | { 48 | long long int j = edge_ij_e[E+e]; 49 | const float* key_k = &key_ihk[(j*H+h)*K_HEAD_DIM]; 50 | 51 | float attn = 0; 52 | #pragma unroll 53 | for (int k = 0; k < K_HEAD_DIM; k++){ 54 | attn += que_k[k]*key_k[k]; 55 | } 56 | if (aij_eh != NULL) 57 | attn += aij_eh[e*H+h]; 58 | max_attn = max(max_attn, attn); 59 | attns[e_count] = attn; 60 | e_count++; 61 | } 62 | 63 | float sum = 0; 64 | for (int j = 0; j < e_count; j++) 65 | { 66 | float v = exp(attns[j] - max_attn); 67 | attns[j] = v; 68 | sum += v; 69 | } 70 | 71 | // Compute pairwise sum for better numerical accuracy. 72 | constexpr int BS = 64; 73 | for (int j = e_count; j < (e_count+BS-1)/BS*BS; j++) 74 | attns[j] = 0.0f; 75 | float s[MAX_SYSTEM_SIZE_POW2/BS] = {0.0f}; 76 | for (int j = 0; j < (e_count+BS-1)/BS; j++){ 77 | s[j] = fixed_length_pairwise_sum(&attns[j*BS]); 78 | } 79 | sum = fixed_length_pairwise_sum(s); 80 | 81 | 82 | for (int j = 0; j < e_count; j++){ 83 | attns[j] /= sum; 84 | prob_eh[(e_start+j)*H+h] = attns[j]; 85 | } 86 | 87 | #pragma unroll 88 | for (int k = 0; k < V_HEAD_DIM; k++){ 89 | output[k] = 0; 90 | } 91 | 92 | e_count = 0; 93 | if (bij_ehk != NULL) { 94 | for (long long int e = e_start; e < e_end; e++) 95 | { 96 | long long int j = edge_ij_e[E+e]; 97 | const float* val_k = &val_ihk[(j*H+h)*V_HEAD_DIM]; 98 | const float* bij_k = &bij_ehk[(e*H+h)*V_HEAD_DIM]; 99 | 100 | float attn = attns[e_count]; 101 | #pragma unroll 102 | for (int k = 0; k < V_HEAD_DIM; k++){ 103 | output[k] += (val_k[k]+bij_k[k])*attn; 104 | } 105 | e_count++; 106 | } 107 | } else { 108 | for (long long int e = e_start; e < e_end; e++) 109 | { 110 | long long int j = edge_ij_e[E+e]; 111 | const float* val_k = &val_ihk[(j*H+h)*V_HEAD_DIM]; 112 | 113 | float attn = attns[e_count]; 114 | #pragma unroll 115 | for (int k = 0; k < V_HEAD_DIM; k++){ 116 | output[k] += val_k[k]*attn; 117 | } 118 | e_count++; 119 | } 120 | } 121 | 122 | out_ihk += tid*V_HEAD_DIM; 123 | #pragma unroll 124 | for (int k = 0; k < V_HEAD_DIM; k++) 125 | out_ihk[k] = output[k]; 126 | } -------------------------------------------------------------------------------- /models/kernels/fused_dpa_bwd.cu: -------------------------------------------------------------------------------- 1 | 2 | extern "C" __global__ 3 | void fused_dpa_bwd( 4 | const float* que_ihk, 5 | const float* key_ihk, 6 | const float* val_ihk, 7 | const float* taij_eh, 8 | const float* tbij_ehk, 9 | const long long int* batch_i, 10 | const long long int* edge_ij_e, 11 | const long long int* e_start_i, 12 | const long long int N, 13 | const long long int H, 14 | const long long int E, 15 | const float* tprob_eh, 16 | const float* out_ihk, 17 | const float* gout_ihk, 18 | float* gque_ihk, 19 | float* gkey_ihk, 20 | float* gval_ihk, 21 | float* tgaij_eh, 22 | float* tgbij_ehk 23 | ){ 24 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 25 | if (tid >= N*H) return; 26 | 27 | const long long int K = VPE_DIM; 28 | const long 
long int j = tid/H; 29 | const long long int h = tid%H; 30 | const long long int n = batch_i[j]; 31 | const long long int e_start = e_start_i[j]; 32 | const long long int e_end = e_start_i[j+1]; 33 | 34 | const float* que_k = &que_ihk[tid*K_HEAD_DIM]; 35 | __shared__ float _v[THREAD_NUM][V_HEAD_DIM+1]; 36 | __shared__ float _gv[THREAD_NUM][V_HEAD_DIM+1]; 37 | __shared__ float _gk[THREAD_NUM][K_HEAD_DIM+1]; 38 | float *v = _v[threadIdx.x]; 39 | float *gv = _gv[threadIdx.x]; 40 | float *gk = _gk[threadIdx.x]; 41 | 42 | const float *v_src = val_ihk + (j*H+h)*V_HEAD_DIM; 43 | #pragma unroll 44 | for (int k = 0; k < V_HEAD_DIM; k++){ 45 | gv[k] = 0; 46 | v[k] = v_src[k]; 47 | } 48 | 49 | #pragma unroll 50 | for (int k = 0; k < K_HEAD_DIM; k++) 51 | gk[k] = 0; 52 | 53 | if (tgbij_ehk != NULL && tbij_ehk != NULL ) { 54 | for (long long int e = e_start; e < e_end; e++) 55 | { 56 | long long int i = edge_ij_e[E+e]; 57 | 58 | float pij = tprob_eh[e*H+h]; 59 | const float *go = gout_ihk + (i*H+h)*V_HEAD_DIM; 60 | const float *o = out_ihk + (i*H+h)*V_HEAD_DIM; 61 | float *gb = tgbij_ehk + (e*H+h)*V_HEAD_DIM; 62 | const float *b = tbij_ehk + (e*H+h)*V_HEAD_DIM; 63 | float g_softmax = 0; 64 | #pragma unroll 65 | for (int k = 0; k < V_HEAD_DIM; k++){ 66 | float t = go[k]*pij; 67 | gv[k] += t; 68 | gb[k] = t; 69 | g_softmax += (v[k] + b[k] - o[k]) * t; 70 | } 71 | 72 | tgaij_eh[e*H+h] = g_softmax; 73 | 74 | const float *q = que_ihk + (i*H+h)*K_HEAD_DIM; 75 | #pragma unroll 76 | for (int k = 0; k < K_HEAD_DIM; k++){ 77 | gk[k] += g_softmax*q[k]; 78 | } 79 | // gb = go.reshape(s,1,H,K) * p.reshape(s,s,H,1) 80 | // gv = gb.sum(dim=0) 81 | // gval.append(gv) 82 | // gbij.append(gb.reshape(s*s,H,K)) 83 | // gsm = (v.reshape(1,s,H,K) + b.reshape(s,s,H,K) - o.reshape(s,1,H,K))*gb 84 | // ga = gsm.sum(dim=3) 85 | // gq = (ga.reshape(s,s,H,1)*k.reshape(1,s,H,K)).sum(dim=1) 86 | // gk = (ga.reshape(s,s,H,1)*q.reshape(s,1,H,K)).sum(dim=0) 87 | // gaij.append(ga.reshape(s*s,H)) 88 | // gque.append(gq) 89 | // gkey.append(gk) 90 | } 91 | } else { 92 | for (long long int e = e_start; e < e_end; e++) 93 | { 94 | long long int i = edge_ij_e[E+e]; 95 | 96 | float pij = tprob_eh[e*H+h]; 97 | const float *go = gout_ihk + (i*H+h)*V_HEAD_DIM; 98 | const float *o = out_ihk + (i*H+h)*V_HEAD_DIM; 99 | float g_softmax = 0; 100 | #pragma unroll 101 | for (int k = 0; k < V_HEAD_DIM; k++){ 102 | float t = go[k]*pij; 103 | gv[k] += t; 104 | g_softmax += (v[k] - o[k]) * t; 105 | } 106 | 107 | tgaij_eh[e*H+h] = g_softmax; 108 | 109 | const float *q = que_ihk + (i*H+h)*K_HEAD_DIM; 110 | #pragma unroll 111 | for (int k = 0; k < K_HEAD_DIM; k++){ 112 | gk[k] += g_softmax*q[k]; 113 | } 114 | } 115 | } 116 | 117 | gval_ihk += (j*H+h)*V_HEAD_DIM; 118 | #pragma unroll 119 | for (int k = 0; k < V_HEAD_DIM; k++) 120 | gval_ihk[k] = gv[k]; 121 | 122 | gkey_ihk += (j*H+h)*K_HEAD_DIM; 123 | #pragma unroll 124 | for (int k = 0; k < K_HEAD_DIM; k++) 125 | gkey_ihk[k] = gk[k]; 126 | } -------------------------------------------------------------------------------- /models/kernels/fused_dpa_fwd_v3.cu: -------------------------------------------------------------------------------- 1 | #include "models/kernels/pairwise_sum.cuh" 2 | 3 | extern "C" __global__ 4 | void fused_dpa_fwd_v3( 5 | const float* que_ihk, 6 | const float* key_ihk, 7 | const float* val_ihk, 8 | const float* aij_eh, 9 | const float* bij_ehk, 10 | const long long int* edge_ij_e, 11 | const long long int* e_start_i, 12 | const long long int N, 13 | const long long int H, 14 | const 
long long int E, 15 | float* prob_eh, 16 | float* out_ihk 17 | ){ 18 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 19 | if (tid >= N*H) return; 20 | const long long int i = tid/H; 21 | const long long int h = tid%H; 22 | const long long int e_start = e_start_i[i]; 23 | const long long int e_end = e_start_i[i+1]; 24 | 25 | // Load q onto shared mem. 26 | constexpr int DIM_MAX = (V_HEAD_DIM>K_HEAD_DIM) ? V_HEAD_DIM : K_HEAD_DIM; 27 | __shared__ float _vec_k[THREAD_NUM][DIM_MAX+1]; 28 | que_ihk += tid*K_HEAD_DIM; 29 | float* que_k = _vec_k[threadIdx.x]; 30 | #pragma unroll 31 | for (int k = 0; k < K_HEAD_DIM; k++){ 32 | que_k[k] = que_ihk[k]; 33 | } 34 | 35 | __shared__ float _attns[THREAD_NUM][MAX_SYSTEM_SIZE+1]; 36 | __shared__ float _output[THREAD_NUM][V_HEAD_DIM+1]; 37 | float *attns = _attns[threadIdx.x]; 38 | float *output = _output[threadIdx.x]; 39 | 40 | float max_attn = -1e20; 41 | int e_count = 0; 42 | for (long long int e = e_start; e < e_end; e++) 43 | { 44 | long long int j = edge_ij_e[E+e]; 45 | const float* key_k = &key_ihk[(j*H+h)*K_HEAD_DIM]; 46 | 47 | float attn = 0; 48 | #pragma unroll 49 | for (int k = 0; k < K_HEAD_DIM; k++){ 50 | attn += que_k[k]*key_k[k]; 51 | } 52 | if (aij_eh != NULL) 53 | attn += aij_eh[e*H+h]; 54 | max_attn = max(max_attn, attn); 55 | attns[e_count] = attn; 56 | e_count++; 57 | } 58 | 59 | for (int j = 0; j < e_count; j++) 60 | { 61 | float v = exp(attns[j] - max_attn); 62 | attns[j] = v; 63 | } 64 | 65 | // Compute pairwise sum for better numerical accuracy. 66 | constexpr int BS = 64; 67 | for (int j = e_count; j < (e_count+BS-1)/BS*BS; j++) 68 | attns[j] = 0.0f; 69 | float s[MAX_SYSTEM_SIZE_POW2/BS] = {0.0f}; 70 | for (int j = 0; j < (e_count+BS-1)/BS; j++){ 71 | s[j] = fixed_length_pairwise_sum(&attns[j*BS]); 72 | } 73 | float sum = fixed_length_pairwise_sum(s); 74 | 75 | 76 | for (int j = 0; j < e_count; j++){ 77 | attns[j] /= sum; 78 | prob_eh[(e_start+j)*H+h] = attns[j]; 79 | } 80 | 81 | float *short_run_sum = _vec_k[threadIdx.x]; 82 | #pragma unroll 83 | for (int k = 0; k < V_HEAD_DIM; k++){ 84 | output[k] = 0; 85 | short_run_sum[k] = 0; 86 | } 87 | 88 | e_count = 0; 89 | if (bij_ehk != NULL) { 90 | for (long long int e = e_start; e < e_end; e++) 91 | { 92 | long long int j = edge_ij_e[E+e]; 93 | const float* val_k = &val_ihk[(j*H+h)*V_HEAD_DIM]; 94 | const float* bij_k = &bij_ehk[(e*H+h)*V_HEAD_DIM]; 95 | 96 | float attn = attns[e_count]; 97 | #pragma unroll 98 | for (int k = 0; k < V_HEAD_DIM; k++){ 99 | short_run_sum[k] += (val_k[k]+bij_k[k])*attn; 100 | } 101 | e_count++; 102 | 103 | if ((e_count % RUNNING_SUM_LEN) == 0 || e == e_end-1){ 104 | #pragma unroll 105 | for (int k = 0; k < V_HEAD_DIM; k++){ 106 | output[k] += short_run_sum[k]; 107 | short_run_sum[k] = 0; 108 | } 109 | } 110 | } 111 | } else { 112 | for (long long int e = e_start; e < e_end; e++) 113 | { 114 | long long int j = edge_ij_e[E+e]; 115 | const float* val_k = &val_ihk[(j*H+h)*V_HEAD_DIM]; 116 | 117 | float attn = attns[e_count]; 118 | #pragma unroll 119 | for (int k = 0; k < V_HEAD_DIM; k++){ 120 | short_run_sum[k] += val_k[k]*attn; 121 | } 122 | e_count++; 123 | 124 | if ((e_count % RUNNING_SUM_LEN) == 0 || e == e_end-1){ 125 | #pragma unroll 126 | for (int k = 0; k < V_HEAD_DIM; k++){ 127 | output[k] += short_run_sum[k]; 128 | short_run_sum[k] = 0; 129 | } 130 | } 131 | } 132 | } 133 | 134 | out_ihk += tid*V_HEAD_DIM; 135 | #pragma unroll 136 | for (int k = 0; k < V_HEAD_DIM; k++) 137 | out_ihk[k] = output[k]; 138 | } 
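Editorial note: the fused_dpa_fwd / fused_dpa_fwd_v3 kernels above compute, for each destination node i and head h, a softmax over that node's incoming edges followed by an attention-weighted sum of (optionally bias-shifted) values. The sketch below is a minimal, unoptimized PyTorch reference of that computation, assuming the same layout as the kernels (que/key/val of shape (N, H, D), aij of shape (E, H), bij of shape (E, H, Dv), edge_ij of shape (2, E) with row 0 the destination node i, row 1 the neighbor j, and edges grouped per node by e_start). The function name fused_dpa_reference is hypothetical and for illustration only; the max-subtraction, pairwise-summation, and running-sum tricks in the kernels change only numerical accuracy, not this result (torch.softmax applies the max shift internally).

import torch

def fused_dpa_reference(que, key, val, aij, bij, edge_ij, e_start):
    # que/key: (N, H, Dk), val: (N, H, Dv), aij: (E, H) or None, bij: (E, H, Dv) or None,
    # edge_ij: (2, E) sorted by destination node, e_start: (N+1,) per-node edge offsets.
    N, H, Dv = val.shape
    out = val.new_empty(N, H, Dv)
    prob = val.new_empty(edge_ij.shape[1], H)
    for i in range(N):
        e0, e1 = int(e_start[i]), int(e_start[i + 1])
        j = edge_ij[1, e0:e1]                                  # neighbors of node i
        attn = (que[i] * key[j]).sum(-1)                       # (Ei, H) dot products q_i . k_j
        if aij is not None:
            attn = attn + aij[e0:e1]                           # additive pair bias a_ij
        p = torch.softmax(attn, dim=0)                         # softmax over neighbors, per head
        prob[e0:e1] = p
        v = val[j] if bij is None else val[j] + bij[e0:e1]     # value (+ optional pair bias b_ij)
        out[i] = (p.unsqueeze(-1) * v).sum(dim=0)              # attention-weighted sum
    return prob, out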
-------------------------------------------------------------------------------- /models/cuda_funcs/kernel_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import List, Optional, Tuple 3 | import os 4 | import inspect 5 | import math 6 | 7 | 8 | try: 9 | import cupy as cp 10 | import pytorch_pfn_extras as ppe 11 | except: 12 | pass 13 | 14 | 15 | class Kernel: 16 | def __init__(self, name:str): 17 | self.name:str = name 18 | self.code:str = None 19 | self.raw_kernel:cp.RawKernel = None 20 | 21 | def __call__(self, grid, block, args, **kwargs): 22 | self.raw_kernel(grid, block, args, **kwargs) 23 | 24 | class KernelManager: 25 | #position_enc_forward:Kernel = None 26 | #position_enc_backward:Kernel = None 27 | real_enc_fwd:Kernel = None 28 | real_enc_bwd:Kernel = None 29 | real_enc_fwd_v2:Kernel = None 30 | real_enc_bwd_v2:Kernel = None 31 | #position_enc_proj_forward:Kernel = None 32 | #position_enc_proj_backward:Kernel = None 33 | real_enc_proj_fwd:Kernel = None 34 | real_enc_proj_bwd:Kernel = None 35 | real_enc_proj_fwd_v2:Kernel = None 36 | real_enc_proj_bwd_v2:Kernel = None 37 | reci_enc_fwd:Kernel = None 38 | reci_enc_bwd:Kernel = None 39 | reci_enc_fwd_v2:Kernel = None 40 | reci_enc_bwd_v2:Kernel = None 41 | fused_dpa_fwd:Kernel = None 42 | fused_dpa_fwd_v2:Kernel = None 43 | fused_dpa_fwd_v3:Kernel = None 44 | fused_dpa_bwd:Kernel = None 45 | fused_dpa_bwd_v2:Kernel = None 46 | fused_dpa_bwd_v3:Kernel = None 47 | fused_dpa_bwd_q:Kernel = None 48 | fused_dpa_bwd_q_v2:Kernel = None 49 | fused_dpa_bwd_q_v3:Kernel = None 50 | irregular_transpose:Kernel = None 51 | irregular_transpose_old:Kernel = None 52 | irregular_mean_fwd:Kernel = None 53 | minimum_distance:Kernel = None 54 | 55 | MAX_SYSTEM_SIZE:int = 320 56 | MAX_SYSTEM_SIZE_POW2:int = int(2**math.ceil(math.log2(MAX_SYSTEM_SIZE))) 57 | RUNNING_SUM_LEN:int = 8 58 | 59 | @staticmethod 60 | def get_kernel_names() -> List[str]: 61 | return [name for name, attr in inspect.getmembers(KernelManager) \ 62 | if not name.startswith("_") and \ 63 | not inspect.isfunction(attr) and \ 64 | KernelManager.__annotations__.get(name, None) == Kernel 65 | ] 66 | 67 | @staticmethod 68 | def get_kernel(name:str) -> Kernel: 69 | return KernelManager.__dict__[name] 70 | 71 | @staticmethod 72 | def set_kernel(name:str, kernel:Kernel): 73 | setattr(KernelManager, name, kernel) 74 | #KernelManager.__dict__[name] = kernel 75 | 76 | src_dir = os.path.dirname(os.path.abspath(__file__)) 77 | for name in KernelManager.get_kernel_names(): 78 | kernel = Kernel(name) 79 | with open(os.path.join(src_dir, f'../kernels/{kernel.name}.cu'), 'r') as f: 80 | kernel.code = f.read() 81 | KernelManager.set_kernel(name, kernel) 82 | 83 | # kernels = [ 84 | # 'position_enc_forward', 85 | # 'position_enc_backward', 86 | # 'adaptive_real_forward', 87 | # 'adaptive_real_backward', 88 | # 'position_enc_proj_forward', 89 | # 'position_enc_proj_backward', 90 | # 'adaptive_real_proj_forward', 91 | # 'adaptive_real_proj_backward', 92 | # 'reciprocal_forward', 93 | # 'reciprocal_backward', 94 | # 'fused_dpa_fwd', 95 | # 'fused_dpa_bwd', 96 | # 'fused_dpa_bwd_q', 97 | # 'irregular_transpose', 98 | # 'irregular_transpose_old', 99 | # ] 100 | 101 | # kernels = { name: Kernel(name) for name in kernels } 102 | 103 | # src_dir = os.path.dirname(os.path.abspath(__file__)) 104 | # for name in kernels: 105 | # kernel = kernels[name] 106 | # with open(os.path.join(src_dir, f'../kernels/{kernel.name}.cu'), 'r') as f: 107 
| # kernel.code = f.read() 108 | 109 | def compile_kernels(lattice_range:int, head_num:int, key_head_dim:int, value_pe_dim:int, value_head_dim:int, set_minimum_range:bool): 110 | constants_dict = { 111 | 'LATTICE_RANGE': str(lattice_range), 112 | 'THREAD_NUM': str(head_num), 113 | 'HEAD_NUM': str(head_num), 114 | 'VPE_DIM': str(value_pe_dim), 115 | 'V_HEAD_DIM': str(value_head_dim), 116 | 'K_HEAD_DIM': str(key_head_dim), 117 | 'SKIP_OUTOF_RADIUS': '0', 118 | 'MINIMUM_RANGE': str(lattice_range) if set_minimum_range else '0', 119 | 'MAX_SYSTEM_SIZE_POW2': KernelManager.MAX_SYSTEM_SIZE_POW2, 120 | 'MAX_SYSTEM_SIZE': KernelManager.MAX_SYSTEM_SIZE, 121 | 'RUNNING_SUM_LEN': KernelManager.RUNNING_SUM_LEN, 122 | } 123 | def replace_constants(code:str): 124 | for key,val in constants_dict.items(): 125 | code = code.replace(key, val if isinstance(val, str) else str(val)) 126 | return code 127 | 128 | options = ('-dc', '--std=c++11') 129 | if torch.cuda.device_count() > 0: 130 | with cp.cuda.Device(0): 131 | for name in KernelManager.get_kernel_names(): 132 | kernel = KernelManager.get_kernel(name) 133 | code = replace_constants(kernel.code) 134 | kernel.raw_kernel = cp.RawKernel(code, kernel.name, options, jitify=True) 135 | -------------------------------------------------------------------------------- /models/kernels/fused_dpa_bwd_v3.cu: -------------------------------------------------------------------------------- 1 | 2 | extern "C" __global__ 3 | void fused_dpa_bwd_v3( 4 | const float* que_ihk, 5 | const float* key_ihk, 6 | const float* val_ihk, 7 | const float* taij_eh, 8 | const float* tbij_ehk, 9 | const long long int* batch_i, 10 | const long long int* edge_ij_e, 11 | const long long int* e_start_i, 12 | const long long int N, 13 | const long long int H, 14 | const long long int E, 15 | const float* tprob_eh, 16 | const float* out_ihk, 17 | const float* gout_ihk, 18 | float* gque_ihk, 19 | float* gkey_ihk, 20 | float* gval_ihk, 21 | float* tgaij_eh, 22 | float* tgbij_ehk 23 | ){ 24 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 25 | if (tid >= N*H) return; 26 | 27 | const long long int K = VPE_DIM; 28 | const long long int j = tid/H; 29 | const long long int h = tid%H; 30 | const long long int n = batch_i[j]; 31 | const long long int e_start = e_start_i[j]; 32 | const long long int e_end = e_start_i[j+1]; 33 | 34 | const float* que_k = &que_ihk[tid*K_HEAD_DIM]; 35 | __shared__ float _v[THREAD_NUM][V_HEAD_DIM+1]; 36 | __shared__ float _gv[THREAD_NUM][V_HEAD_DIM+1]; 37 | __shared__ float _gk[THREAD_NUM][K_HEAD_DIM+1]; 38 | __shared__ float _run_gv[THREAD_NUM][V_HEAD_DIM+1]; 39 | __shared__ float _run_gk[THREAD_NUM][K_HEAD_DIM+1]; 40 | float *v = _v[threadIdx.x]; 41 | float *gv = _gv[threadIdx.x]; 42 | float *gk = _gk[threadIdx.x]; 43 | float *rgv = _run_gv[threadIdx.x]; 44 | float *rgk = _run_gk[threadIdx.x]; 45 | 46 | const float *v_src = val_ihk + (j*H+h)*V_HEAD_DIM; 47 | #pragma unroll 48 | for (int k = 0; k < V_HEAD_DIM; k++){ 49 | gv[k] = 0; 50 | rgv[k] = 0; 51 | v[k] = v_src[k]; 52 | } 53 | 54 | #pragma unroll 55 | for (int k = 0; k < K_HEAD_DIM; k++){ 56 | gk[k] = 0; 57 | rgk[k] = 0; 58 | } 59 | 60 | if (tgbij_ehk != NULL && tbij_ehk != NULL ) { 61 | for (long long int e = e_start; e < e_end; e++) 62 | { 63 | long long int i = edge_ij_e[E+e]; 64 | 65 | float pij = tprob_eh[e*H+h]; 66 | const float *go = gout_ihk + (i*H+h)*V_HEAD_DIM; 67 | const float *o = out_ihk + (i*H+h)*V_HEAD_DIM; 68 | float *gb = tgbij_ehk + (e*H+h)*V_HEAD_DIM; 69 | const float 
*b = tbij_ehk + (e*H+h)*V_HEAD_DIM; 70 | float g_softmax = 0; 71 | #pragma unroll 72 | for (int k = 0; k < V_HEAD_DIM; k++){ 73 | float t = go[k]*pij; 74 | rgv[k] += t; 75 | gb[k] = t; 76 | g_softmax += (v[k] + b[k] - o[k]) * t; 77 | } 78 | 79 | tgaij_eh[e*H+h] = g_softmax; 80 | 81 | const float *q = que_ihk + (i*H+h)*K_HEAD_DIM; 82 | #pragma unroll 83 | for (int k = 0; k < K_HEAD_DIM; k++){ 84 | rgk[k] += g_softmax*q[k]; 85 | } 86 | 87 | if (((e-e_start) % RUNNING_SUM_LEN) == 0 || e == e_end-1){ 88 | #pragma unroll 89 | for (int k = 0; k < V_HEAD_DIM; k++){ 90 | gv[k] += rgv[k]; 91 | rgv[k] = 0; 92 | } 93 | #pragma unroll 94 | for (int k = 0; k < K_HEAD_DIM; k++){ 95 | gk[k] += rgk[k]; 96 | rgk[k] = 0; 97 | } 98 | } 99 | } 100 | } else { 101 | for (long long int e = e_start; e < e_end; e++) 102 | { 103 | long long int i = edge_ij_e[E+e]; 104 | 105 | float pij = tprob_eh[e*H+h]; 106 | const float *go = gout_ihk + (i*H+h)*V_HEAD_DIM; 107 | const float *o = out_ihk + (i*H+h)*V_HEAD_DIM; 108 | float g_softmax = 0; 109 | #pragma unroll 110 | for (int k = 0; k < V_HEAD_DIM; k++){ 111 | float t = go[k]*pij; 112 | rgv[k] += t; 113 | g_softmax += (v[k] - o[k]) * t; 114 | } 115 | 116 | tgaij_eh[e*H+h] = g_softmax; 117 | 118 | const float *q = que_ihk + (i*H+h)*K_HEAD_DIM; 119 | #pragma unroll 120 | for (int k = 0; k < K_HEAD_DIM; k++){ 121 | rgk[k] += g_softmax*q[k]; 122 | } 123 | 124 | if (((e-e_start) % RUNNING_SUM_LEN) == 0 || e == e_end-1){ 125 | #pragma unroll 126 | for (int k = 0; k < V_HEAD_DIM; k++){ 127 | gv[k] += rgv[k]; 128 | rgv[k] = 0; 129 | } 130 | #pragma unroll 131 | for (int k = 0; k < K_HEAD_DIM; k++){ 132 | gk[k] += rgk[k]; 133 | rgk[k] = 0; 134 | } 135 | } 136 | } 137 | } 138 | 139 | gval_ihk += (j*H+h)*V_HEAD_DIM; 140 | #pragma unroll 141 | for (int k = 0; k < V_HEAD_DIM; k++) 142 | gval_ihk[k] = gv[k]; 143 | 144 | gkey_ihk += (j*H+h)*K_HEAD_DIM; 145 | #pragma unroll 146 | for (int k = 0; k < K_HEAD_DIM; k++) 147 | gkey_ihk[k] = gk[k]; 148 | } -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import sys 4 | import torch 5 | from torch_geometric.loader import DataLoader 6 | 7 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 8 | 9 | from tqdm import tqdm 10 | from utils import Params 11 | from distutils.util import strtobool 12 | 13 | def get_option(): 14 | argparser = ArgumentParser(description='Training the network') 15 | argparser.add_argument('-p', '--param_file', type=str, default='default.json', help='filename of the parameter JSON') 16 | args, unknown = argparser.parse_known_args() 17 | return args 18 | 19 | def demo(): 20 | args = get_option() 21 | print('parsed args :') 22 | print(args) 23 | try: 24 | params = Params(f'{args.param_file}') 25 | except: 26 | params = Params(f'./params/{args.param_file}') 27 | 28 | parser = ArgumentParser(description='Training the network') 29 | parser.add_argument('-p', '--param_file', type=str, default=args.param_file, help='Config json file for default params') 30 | # load the json config and use it as default values. 
31 | boolder = lambda x:bool(strtobool(x)) 32 | typefinder = lambda v: str if v is None else boolder if type(v)==bool else type(v) 33 | for key in params.dict: 34 | v = params.dict[key] 35 | if isinstance(v, (list, tuple)): 36 | parser.add_argument(f"--{key}", type=typefinder(v[0]), default=v, nargs='+') 37 | else: 38 | parser.add_argument(f"--{key}", type=typefinder(v), default=v) 39 | params.__dict__ = parser.parse_args().__dict__ 40 | print(params.dict) 41 | 42 | import models.global_config as config 43 | config.REPRODUCIBLITY_STATE = getattr(params, 'reproduciblity_state', 0) 44 | print(f"reproduciblity_state = {config.REPRODUCIBLITY_STATE}") 45 | 46 | # Reproducibility 47 | seed = getattr(params, 'seed', 123) 48 | deterministic = params.encoder_name in ["latticeformer"] 49 | torch.manual_seed(seed) 50 | torch.cuda.manual_seed(seed) 51 | torch.cuda.manual_seed_all(seed) 52 | torch.backends.cudnn.benchmark = False 53 | torch.backends.cudnn.deterministic = deterministic 54 | torch.backends.cuda.matmul.allow_tf32 = False 55 | torch.backends.cudnn.allow_tf32 = False 56 | # torch.backends.cuda.preferred_linalg_library("cusolver") # needed since torch 1.11 to avoid an error from torch.det(), but det_3x3 is now implemented manually. 57 | 58 | from dataloaders.dataset_latticeformer import RegressionDatasetMP_Latticeformer as Dataset 59 | from models.latticeformer import Latticeformer 60 | 61 | model = Latticeformer(params) 62 | param_num = sum([p.nelement() for p in model.parameters()]) 63 | print(f"Whole: {param_num}, {param_num*4/1024**2} MB") 64 | param_num = sum([p.nelement() for p in model.encoder.layers[0].parameters()]) 65 | print(f"Block: {param_num}, {param_num*4/1024**1} KB") 66 | 67 | if params.pretrained_model is not None: 68 | with open(params.pretrained_model, "rb") as f: 69 | checkpoint = torch.load(f) 70 | state_dict = checkpoint['state_dict'] 71 | target_std = checkpoint['state_dict']['target_std'] 72 | target_mean = checkpoint['state_dict']['target_mean'] 73 | 74 | model_name = "swa_model.module." 75 | # if params.pretrained_model.endswith("best.ckpt"): 76 | # model_name = "model." 77 | # else: 78 | # model_name = "swa_model.module."
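        # Editorial note: the checkpoint stores the (SWA-averaged) model under the "swa_model.module."
        # key prefix and was trained on normalized targets. Below, that prefix is stripped from the
        # state-dict keys, and target_std/target_mean are folded into the last linear layer so the
        # restored model directly outputs values in the original target units
        # (effectively y = std * y_normalized + mean).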
79 | 80 | print("model name:", model_name) 81 | model_dict = { key.replace(model_name, ""):state_dict[key] for key in state_dict if key.startswith(model_name) } 82 | model.load_state_dict(model_dict) 83 | # correct the last linear layer weights 84 | target_std = target_std.to(model.mlp[-1].weight.device) 85 | target_mean = target_mean.to(model.mlp[-1].weight.device) 86 | model.mlp[-1].load_state_dict({ 87 | 'weight': model.mlp[-1].weight * target_std[:,None], 88 | 'bias': model.mlp[-1].bias * target_std + target_mean, 89 | }) 90 | 91 | else: 92 | print("Specify --pretrained_model for demonstration.") 93 | exit() 94 | 95 | # Setup datasets 96 | target_set = getattr(params, "target_set", None) 97 | test_dataset = Dataset(target_split='test', target_set=target_set) 98 | test_loader = DataLoader(test_dataset, batch_size=params.batch_size, shuffle=False, num_workers=0, drop_last=False) 99 | 100 | model = model.cuda() 101 | model.eval() 102 | targets = params.targets if isinstance(params.targets, list) else [params.targets] 103 | 104 | with torch.no_grad(): 105 | mae_err = {t: [] for t in targets} 106 | for batch in tqdm(test_loader): 107 | batch = batch.cuda() 108 | output = model(batch) 109 | for i, t in enumerate(targets): 110 | labels = batch[t] 111 | mae_err[t].append(abs(output[:, i] - labels).detach().cpu()) 112 | 113 | for t in targets: 114 | print(f"{t}: {torch.cat(mae_err[t]).mean().item()}") 115 | 116 | 117 | if __name__ == '__main__': 118 | demo() 119 | -------------------------------------------------------------------------------- /models/kernels/fused_dpa_bwd_v2.cu: -------------------------------------------------------------------------------- 1 | #include "models/kernels/reduce_kernel_utils.cuh" 2 | 3 | __global__ __device__ void fused_dpa_bwd_thread( 4 | const float* que_ihk, 5 | const float* val_k, 6 | const float* tbij_ehk, 7 | const long long int* edge_i_e, 8 | const long long int e_start, 9 | const long long int e_end, 10 | const long long int H, 11 | const float* tprob_eh, 12 | const float* out_ihk, 13 | const float* gout_ihk, 14 | float* gkey_k, 15 | float* gval_k, 16 | float* tgaij_eh, 17 | float* tgbij_ehk 18 | ){ 19 | long long int e = e_start + threadIdx.x; 20 | bool isValid = e < e_end; 21 | 22 | e = min(e, e_end-1); 23 | long long int i = edge_i_e[e]; 24 | 25 | float pij = isValid ? 
tprob_eh[e*H] : 0.0f; 26 | que_ihk += (i*H)*K_HEAD_DIM; 27 | gout_ihk += (i*H)*V_HEAD_DIM; 28 | out_ihk += (i*H)*V_HEAD_DIM; 29 | tgbij_ehk += (e*H)*V_HEAD_DIM; 30 | tbij_ehk += (e*H)*V_HEAD_DIM; 31 | 32 | float g_softmax = 0; 33 | #pragma unroll 34 | for (int k = 0; k < V_HEAD_DIM; k++){ 35 | float t = gout_ihk[k]*pij; 36 | g_softmax += (val_k[k] + tbij_ehk[k] - out_ihk[k]) * t; 37 | 38 | if (isValid) 39 | tgbij_ehk[k] = t; 40 | 41 | __syncthreads(); 42 | t = blockReduceSum(t); 43 | if (threadIdx.x == 0) 44 | gval_k[k] = t; 45 | } 46 | 47 | if (isValid) 48 | tgaij_eh[e*H] = g_softmax; 49 | 50 | #pragma unroll 51 | for (int k = 0; k < K_HEAD_DIM; k++){ 52 | float gk = g_softmax*que_ihk[k]; 53 | __syncthreads(); 54 | gk = blockReduceSum(gk); 55 | if (threadIdx.x == 0) 56 | gkey_k[k] = gk; 57 | } 58 | } 59 | 60 | 61 | __global__ __device__ void fused_dpa_bwd_thread_no_bij( 62 | const float* que_ihk, 63 | const float* val_k, 64 | const long long int* edge_i_e, 65 | const long long int e_start, 66 | const long long int e_end, 67 | const long long int H, 68 | const float* tprob_eh, 69 | const float* out_ihk, 70 | const float* gout_ihk, 71 | float* gkey_k, 72 | float* gval_k, 73 | float* tgaij_eh 74 | ){ 75 | long long int e = e_start + threadIdx.x; 76 | bool isValid = e < e_end; 77 | 78 | e = min(e, e_end-1); 79 | long long int i = edge_i_e[e]; 80 | 81 | float pij = isValid ? tprob_eh[e*H] : 0.0f; 82 | que_ihk += (i*H)*K_HEAD_DIM; 83 | gout_ihk += (i*H)*V_HEAD_DIM; 84 | out_ihk += (i*H)*V_HEAD_DIM; 85 | 86 | float g_softmax = 0; 87 | #pragma unroll 88 | for (int k = 0; k < V_HEAD_DIM; k++){ 89 | float t = gout_ihk[k]*pij; 90 | g_softmax += (val_k[k] - out_ihk[k]) * t; 91 | 92 | __syncthreads(); 93 | t = blockReduceSum(t); 94 | if (threadIdx.x == 0) 95 | gval_k[k] = t; 96 | } 97 | 98 | if (isValid) 99 | tgaij_eh[e*H] = g_softmax; 100 | 101 | #pragma unroll 102 | for (int k = 0; k < K_HEAD_DIM; k++){ 103 | float gk = g_softmax*que_ihk[k]; 104 | __syncthreads(); 105 | gk = blockReduceSum(gk); 106 | if (threadIdx.x == 0) 107 | gkey_k[k] = gk; 108 | } 109 | } 110 | 111 | 112 | extern "C" __global__ 113 | void fused_dpa_bwd_v2( 114 | const float* que_ihk, 115 | const float* val_ihk, 116 | const float* tbij_ehk, 117 | const long long int* edge_ij_e, 118 | const long long int* e_start_i, 119 | const long long int N, 120 | const long long int H, 121 | const long long int E, 122 | const float* tprob_eh, 123 | const float* out_ihk, 124 | const float* gout_ihk, 125 | float* gkey_ihk, 126 | float* gval_ihk, 127 | float* tgaij_eh, 128 | float* tgbij_ehk 129 | ){ 130 | const long long int tid = (long long int)blockIdx.x*blockDim.x + threadIdx.x; 131 | if (tid >= N*H) return; 132 | 133 | const long long int j = tid / H; 134 | const long long int h = tid % H; 135 | const long long int e_start = e_start_i[j]; 136 | const long long int e_end = e_start_i[j+1]; 137 | 138 | if (tgbij_ehk != NULL && tbij_ehk != NULL ) { 139 | fused_dpa_bwd_thread<<< 1, ((e_end-e_start+31)/32)*32 >>>( 140 | que_ihk + h*K_HEAD_DIM, 141 | val_ihk + tid*V_HEAD_DIM, 142 | tbij_ehk + h*V_HEAD_DIM, 143 | edge_ij_e + E, 144 | e_start, 145 | e_end, 146 | H, 147 | tprob_eh + h, 148 | out_ihk + h*V_HEAD_DIM, 149 | gout_ihk + h*V_HEAD_DIM, 150 | gkey_ihk + tid*K_HEAD_DIM, 151 | gval_ihk + tid*V_HEAD_DIM, 152 | tgaij_eh + h, 153 | tgbij_ehk + h*V_HEAD_DIM 154 | ); 155 | } else { 156 | fused_dpa_bwd_thread_no_bij<<< 1, ((e_end-e_start+31)/32)*32 >>>( 157 | que_ihk + h*K_HEAD_DIM, 158 | val_ihk + tid*V_HEAD_DIM, 159 | edge_ij_e + E, 160 | e_start, 
161 | e_end, 162 | H, 163 | tprob_eh + h, 164 | out_ihk + h*V_HEAD_DIM, 165 | gout_ihk + h*V_HEAD_DIM, 166 | gkey_ihk + tid*K_HEAD_DIM, 167 | gval_ihk + tid*V_HEAD_DIM, 168 | tgaij_eh + h 169 | ); 170 | } 171 | } -------------------------------------------------------------------------------- /models/kernels/real_enc_fwd.cu: -------------------------------------------------------------------------------- 1 | #include 2 | extern "C" __global__ 3 | 4 | void real_enc_fwd( 5 | const float* a_ik, 6 | const float* rpos_ij_e, 7 | const float* dist2_min_e, 8 | const float* tvecs_n, 9 | const long long int* batch_i, 10 | const long long int* edge_ij_e, 11 | const long long int N, 12 | const long long int H, 13 | const long long int E, 14 | const long long int K_, 15 | const double dist_max, 16 | const double wscale, 17 | const float* rveclens_n, 18 | const double cutoff_radius, 19 | float* z_ek, 20 | float* v_ekd){ 21 | 22 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 23 | if (tid >= E*H) return; 24 | 25 | const long long int k = tid%H; 26 | const long long int e = tid/H; 27 | const long long int i = edge_ij_e[e]; 28 | const long long int j = edge_ij_e[E+e]; 29 | const long long int n = batch_i[i]; 30 | rpos_ij_e += e*3; 31 | const float r_ijx = rpos_ij_e[0]; 32 | const float r_ijy = rpos_ij_e[1]; 33 | const float r_ijz = rpos_ij_e[2]; 34 | tvecs_n += n*9; 35 | const float t1_x = tvecs_n[0]; 36 | const float t1_y = tvecs_n[1]; 37 | const float t1_z = tvecs_n[2]; 38 | const float t2_x = tvecs_n[3]; 39 | const float t2_y = tvecs_n[4]; 40 | const float t2_z = tvecs_n[5]; 41 | const float t3_x = tvecs_n[6]; 42 | const float t3_y = tvecs_n[7]; 43 | const float t3_z = tvecs_n[8]; 44 | const float a = a_ik[i*H + k]; 45 | const int R = LATTICE_RANGE; 46 | const float Rf = (float)LATTICE_RANGE; 47 | 48 | #if VPE_DIM > 0 49 | __shared__ float shared_v[THREAD_NUM][VPE_DIM+1]; 50 | float *sv = shared_v[threadIdx.x]; 51 | 52 | for (int dim = 0; dim < VPE_DIM; dim++) 53 | sv[dim] = 0; 54 | const float reci_ws_sqrt2 = 1.0f/((float)wscale*sqrt(2.0f)); 55 | const float mu0 = (float)dist_max/VPE_DIM; 56 | #endif 57 | 58 | rveclens_n += n*3; 59 | const float rvl1 = rveclens_n[0]; 60 | const float rvl2 = rveclens_n[1]; 61 | const float rvl3 = rveclens_n[2]; 62 | 63 | float cutoff = (float)cutoff_radius; 64 | int R1 = LATTICE_RANGE, R2 = LATTICE_RANGE, R3 = LATTICE_RANGE; 65 | if (cutoff != 0.0f) 66 | { 67 | if (cutoff < 0) { 68 | // Better sync the threads in each block? 
69 | // -> disabled due to thread stucking 70 | // float a_max = a; 71 | // for (int t = 0; t < THREAD_NUM; t++) 72 | // a_max = max(a_max, a_ik[i*H + t]); 73 | //cutoff = sqrt(-0.5f/a_max)*(-cutoff); 74 | cutoff = sqrt(-0.5f/a)*(-cutoff); 75 | } 76 | R1 = ceil((cutoff + 0.01f)*rvl1/(2.0*CUDART_PI_F)); 77 | R2 = ceil((cutoff + 0.01f)*rvl2/(2.0*CUDART_PI_F)); 78 | R3 = ceil((cutoff + 0.01f)*rvl3/(2.0*CUDART_PI_F)); 79 | 80 | #if MINIMUM_RANGE > 0 81 | R1 = max(R1, MINIMUM_RANGE); 82 | R2 = max(R2, MINIMUM_RANGE); 83 | R3 = max(R3, MINIMUM_RANGE); 84 | #endif 85 | } 86 | 87 | float d2min = 1e10; 88 | if (1 || dist2_min_e == NULL) 89 | { 90 | for (float n1 = -R1; n1 <= R1; n1++) 91 | for (float n2 = -R2; n2 <= R2; n2++) 92 | for (float n3 = -R3; n3 <= R3; n3++) 93 | { 94 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 95 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 96 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 97 | float d2 = dx*dx + dy*dy + dz*dz; 98 | // float dx = fmaf(t1_x, n1, fmaf(t2_x, n2, fmaf(t3_x, n3, r_ijx))); 99 | // float dy = fmaf(t1_y, n1, fmaf(t2_y, n2, fmaf(t3_y, n3, r_ijy))); 100 | // float dz = fmaf(t1_z, n1, fmaf(t2_z, n2, fmaf(t3_z, n3, r_ijz))); 101 | // float d2 = fmaf(dx,dx, fmaf(dy,dy, dz*dz)); 102 | d2min = fminf(d2min, d2); 103 | } 104 | } else { 105 | d2min = dist2_min_e[e]; 106 | } 107 | 108 | float sum = 0; 109 | for (float n1 = -R1; n1 <= R1; n1++) 110 | for (float n2 = -R2; n2 <= R2; n2++) 111 | for (float n3 = -R3; n3 <= R3; n3++) 112 | { 113 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 114 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 115 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 116 | float d2 = dx*dx + dy*dy + dz*dz; 117 | // float dx = fmaf(t1_x, n1, fmaf(t2_x, n2, fmaf(t3_x, n3, r_ijx))); 118 | // float dy = fmaf(t1_y, n1, fmaf(t2_y, n2, fmaf(t3_y, n3, r_ijy))); 119 | // float dz = fmaf(t1_z, n1, fmaf(t2_z, n2, fmaf(t3_z, n3, r_ijz))); 120 | // float d2 = fmaf(dx,dx, fmaf(dy,dy, dz*dz)); 121 | float w = expf(a*(d2 - d2min)); 122 | sum += w; 123 | 124 | #if VPE_DIM > 0 125 | // b_dim = exp( -((dim*(m/K)-dist)/(sqrt(2)*wscale*dist_max/K))**2 ) 126 | float b = -sqrtf(d2)/mu0*reci_ws_sqrt2; 127 | #pragma unroll 128 | for (int dim = 0; dim < VPE_DIM; dim++) 129 | { 130 | b += reci_ws_sqrt2; 131 | sv[dim] += exp(-b*b)*w; 132 | } 133 | #endif 134 | } 135 | 136 | #if VPE_DIM > 0 137 | float *v = &v_ekd[tid*VPE_DIM]; 138 | #pragma unroll 139 | for (int dim = 0; dim < VPE_DIM; dim++) 140 | v[dim] = sv[dim]/sum; 141 | #endif 142 | 143 | z_ek[tid] = logf(sum) + d2min*a; 144 | } -------------------------------------------------------------------------------- /models/kernels/real_enc_fwd_v2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | extern "C" __global__ 3 | 4 | void real_enc_fwd_v2( 5 | const float* a_ik, 6 | const float* rpos_ij_e, 7 | const float* dist2_min_e, 8 | const float* tvecs_n, 9 | const long long int* batch_i, 10 | const long long int* edge_ij_e, 11 | const long long int N, 12 | const long long int H, 13 | const long long int E, 14 | const long long int K_, 15 | const double dist_max, 16 | const double wscale, 17 | const float* rveclens_n, 18 | const double cutoff_radius, 19 | float* z_ek, 20 | float* v_ekd){ 21 | 22 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 23 | if (tid >= E*H) return; 24 | 25 | const long long int k = tid%H; 26 | const long long int e = tid/H; 27 | const long long int i = edge_ij_e[e]; 28 | const long long int 
j = edge_ij_e[E+e]; 29 | const long long int n = batch_i[i]; 30 | rpos_ij_e += e*3; 31 | const float r_ijx = rpos_ij_e[0]; 32 | const float r_ijy = rpos_ij_e[1]; 33 | const float r_ijz = rpos_ij_e[2]; 34 | tvecs_n += n*9; 35 | const float t1_x = tvecs_n[0]; 36 | const float t1_y = tvecs_n[1]; 37 | const float t1_z = tvecs_n[2]; 38 | const float t2_x = tvecs_n[3]; 39 | const float t2_y = tvecs_n[4]; 40 | const float t2_z = tvecs_n[5]; 41 | const float t3_x = tvecs_n[6]; 42 | const float t3_y = tvecs_n[7]; 43 | const float t3_z = tvecs_n[8]; 44 | const float a = a_ik[i*H + k]; 45 | const int R = LATTICE_RANGE; 46 | const float Rf = (float)LATTICE_RANGE; 47 | 48 | #if VPE_DIM > 0 49 | __shared__ float shared_v[THREAD_NUM][VPE_DIM+1]; 50 | float *sv = shared_v[threadIdx.x]; 51 | 52 | for (int dim = 0; dim < VPE_DIM; dim++) 53 | sv[dim] = 0; 54 | const float reci_ws_sqrt2 = 1.0f/((float)wscale*sqrt(2.0f)); 55 | const float mu0 = (float)dist_max/VPE_DIM; 56 | #endif 57 | 58 | rveclens_n += n*3; 59 | const float rvl1 = rveclens_n[0]; 60 | const float rvl2 = rveclens_n[1]; 61 | const float rvl3 = rveclens_n[2]; 62 | 63 | float cutoff = (float)cutoff_radius; 64 | int R1 = LATTICE_RANGE, R2 = LATTICE_RANGE, R3 = LATTICE_RANGE; 65 | if (cutoff != 0.0f) 66 | { 67 | if (cutoff < 0) { 68 | // Better sync the threads in each block? 69 | // -> disabled due to thread stucking 70 | // float a_max = a; 71 | // for (int t = 0; t < THREAD_NUM; t++) 72 | // a_max = max(a_max, a_ik[i*H + t]); 73 | //cutoff = sqrt(-0.5f/a_max)*(-cutoff); 74 | cutoff = sqrt(-0.5f/a)*(-cutoff); 75 | } 76 | R1 = ceil((cutoff + 0.01f)*rvl1/(2.0*CUDART_PI_F)); 77 | R2 = ceil((cutoff + 0.01f)*rvl2/(2.0*CUDART_PI_F)); 78 | R3 = ceil((cutoff + 0.01f)*rvl3/(2.0*CUDART_PI_F)); 79 | 80 | #if MINIMUM_RANGE > 0 81 | R1 = max(R1, MINIMUM_RANGE); 82 | R2 = max(R2, MINIMUM_RANGE); 83 | R3 = max(R3, MINIMUM_RANGE); 84 | #endif 85 | } 86 | 87 | float d2min = 1e10; 88 | if (1 || dist2_min_e == NULL) 89 | { 90 | for (float n1 = -R1; n1 <= R1; n1++) 91 | for (float n2 = -R2; n2 <= R2; n2++) 92 | for (float n3 = -R3; n3 <= R3; n3++) 93 | { 94 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 95 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 96 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 97 | float d2 = dx*dx + dy*dy + dz*dz; 98 | // float dx = fmaf(t1_x, n1, fmaf(t2_x, n2, fmaf(t3_x, n3, r_ijx))); 99 | // float dy = fmaf(t1_y, n1, fmaf(t2_y, n2, fmaf(t3_y, n3, r_ijy))); 100 | // float dz = fmaf(t1_z, n1, fmaf(t2_z, n2, fmaf(t3_z, n3, r_ijz))); 101 | // float d2 = fmaf(dx,dx, fmaf(dy,dy, dz*dz)); 102 | d2min = fminf(d2min, d2); 103 | } 104 | } else { 105 | d2min = dist2_min_e[e]; 106 | } 107 | 108 | float sum = 0; 109 | for (float n1 = -R1, s1=0; n1 <= R1; n1++, sum+=s1, s1=0) 110 | for (float n2 = -R2, s2=0; n2 <= R2; n2++, s1 +=s2, s2=0) 111 | for (float n3 = -R3; n3 <= R3; n3++) 112 | { 113 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 114 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 115 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 116 | float d2 = dx*dx + dy*dy + dz*dz; 117 | // float dx = fmaf(t1_x, n1, fmaf(t2_x, n2, fmaf(t3_x, n3, r_ijx))); 118 | // float dy = fmaf(t1_y, n1, fmaf(t2_y, n2, fmaf(t3_y, n3, r_ijy))); 119 | // float dz = fmaf(t1_z, n1, fmaf(t2_z, n2, fmaf(t3_z, n3, r_ijz))); 120 | // float d2 = fmaf(dx,dx, fmaf(dy,dy, dz*dz)); 121 | float w = expf(a*(d2 - d2min)); 122 | s2 += w; 123 | 124 | #if VPE_DIM > 0 125 | // b_dim = exp( -((dim*(m/K)-dist)/(sqrt(2)*wscale*dist_max/K))**2 ) 126 | float b = 
-sqrtf(d2)/mu0*reci_ws_sqrt2; 127 | #pragma unroll 128 | for (int dim = 0; dim < VPE_DIM; dim++) 129 | { 130 | b += reci_ws_sqrt2; 131 | sv[dim] += exp(-b*b)*w; 132 | } 133 | #endif 134 | } 135 | 136 | #if VPE_DIM > 0 137 | float *v = &v_ekd[tid*VPE_DIM]; 138 | #pragma unroll 139 | for (int dim = 0; dim < VPE_DIM; dim++) 140 | v[dim] = sv[dim]/sum; 141 | #endif 142 | 143 | z_ek[tid] = logf(sum) + d2min*a; 144 | } -------------------------------------------------------------------------------- /models/kernels/real_enc_bwd.cu: -------------------------------------------------------------------------------- 1 | #include 2 | extern "C" __global__ 3 | 4 | void real_enc_bwd( 5 | const float* a_ik, 6 | const float* rpos_ij_e, 7 | const float* dist2_min_e, 8 | const float* tvecs_n, 9 | const long long int* batch_i, 10 | const long long int* edge_ij_e, 11 | const long long int* e_start_i, 12 | const float* z_ek, 13 | const float* gz_ek, 14 | const float* gv_ekd, 15 | const long long int N, 16 | const long long int H, 17 | const long long int E, 18 | const long long int K_, 19 | const double dist_max, 20 | const double wscale, 21 | const float* rveclens_n, 22 | const double cutoff_radius, 23 | float* ga_ik){ 24 | 25 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 26 | if (tid >= N*H) return; 27 | 28 | const long long int K = VPE_DIM; 29 | const long long int k = tid%H; 30 | const long long int i = tid/H; 31 | const long long int n = batch_i[i]; 32 | tvecs_n += n*9; 33 | const float t1_x = tvecs_n[0]; 34 | const float t1_y = tvecs_n[1]; 35 | const float t1_z = tvecs_n[2]; 36 | const float t2_x = tvecs_n[3]; 37 | const float t2_y = tvecs_n[4]; 38 | const float t2_z = tvecs_n[5]; 39 | const float t3_x = tvecs_n[6]; 40 | const float t3_y = tvecs_n[7]; 41 | const float t3_z = tvecs_n[8]; 42 | const float a = a_ik[i*H + k]; 43 | const int R = LATTICE_RANGE; 44 | const long long int e_start = e_start_i[i]; 45 | const long long int e_end = e_start_i[i+1]; 46 | #if VPE_DIM > 0 47 | __shared__ float shared_gv[THREAD_NUM][VPE_DIM+1]; 48 | #endif 49 | 50 | rveclens_n += n*3; 51 | const float rvl1 = rveclens_n[0]; 52 | const float rvl2 = rveclens_n[1]; 53 | const float rvl3 = rveclens_n[2]; 54 | 55 | float cutoff = (float)cutoff_radius; 56 | int R1 = LATTICE_RANGE, R2 = LATTICE_RANGE, R3 = LATTICE_RANGE; 57 | if (cutoff != 0.0f) 58 | { 59 | if (cutoff < 0) { 60 | // Better sync the threads in each block? 
61 | // -> disabled due to thread stucking 62 | // float a_max = a; 63 | // for (int t = 0; t < THREAD_NUM; t++) 64 | // a_max = max(a_max, a_ik[i*H + t]); 65 | //cutoff = sqrt(-0.5f/a_max)*(-cutoff); 66 | cutoff = sqrt(-0.5f/a)*(-cutoff); 67 | } 68 | R1 = ceil((cutoff + 0.01f)*rvl1/(2.0*CUDART_PI_F)); 69 | R2 = ceil((cutoff + 0.01f)*rvl2/(2.0*CUDART_PI_F)); 70 | R3 = ceil((cutoff + 0.01f)*rvl3/(2.0*CUDART_PI_F)); 71 | float cutoff2 = cutoff*cutoff; 72 | 73 | #if MINIMUM_RANGE > 0 74 | R1 = max(R1, MINIMUM_RANGE); 75 | R2 = max(R2, MINIMUM_RANGE); 76 | R3 = max(R3, MINIMUM_RANGE); 77 | #endif 78 | } 79 | 80 | float sum = 0; 81 | float sum_v = 0; 82 | for (long long int e = e_start; e < e_end; e++) 83 | { 84 | const long long int j = edge_ij_e[E+e]; 85 | const float r_ijx = rpos_ij_e[e*3+0]; 86 | const float r_ijy = rpos_ij_e[e*3+1]; 87 | const float r_ijz = rpos_ij_e[e*3+2]; 88 | const long long int ek = e*H+k; 89 | const float z = z_ek[ek]; 90 | const float gz = gz_ek[ek]; 91 | 92 | #if VPE_DIM > 0 93 | float *sgv = shared_gv[threadIdx.x]; 94 | const float *gv = &gv_ekd[ek*K]; 95 | #pragma unroll 96 | for (int dim = 0; dim < VPE_DIM; dim++) { 97 | sgv[dim] = gv[dim]; 98 | } 99 | #endif 100 | 101 | float px_avr = 0; 102 | float pbg_avr = 0; 103 | const float reci_ws_sqrt2 = 1.0f/((float)wscale*sqrt(2.0f)); 104 | const float mu0 = (float)dist_max/VPE_DIM; 105 | for (float n1 = -R1; n1 <= R1; n1++) 106 | for (float n2 = -R2; n2 <= R2; n2++) 107 | for (float n3 = -R3; n3 <= R3; n3++) 108 | { 109 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 110 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 111 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 112 | float d2 = dx*dx + dy*dy + dz*dz; 113 | // float dx = fmaf(t1_x, n1, fmaf(t2_x, n2, fmaf(t3_x, n3, r_ijx))); 114 | // float dy = fmaf(t1_y, n1, fmaf(t2_y, n2, fmaf(t3_y, n3, r_ijy))); 115 | // float dz = fmaf(t1_z, n1, fmaf(t2_z, n2, fmaf(t3_z, n3, r_ijz))); 116 | // float d2 = fmaf(dx,dx, fmaf(dy,dy, dz*dz)); 117 | float p = expf(a*d2 - z); 118 | float px = d2*p; 119 | px_avr += px; 120 | 121 | #if VPE_DIM > 0 122 | float bg = 0; 123 | float b = -sqrtf(d2)/mu0*reci_ws_sqrt2; 124 | #pragma unroll 125 | for (int dim = 0; dim < VPE_DIM; dim++) 126 | { 127 | b += reci_ws_sqrt2; 128 | bg += expf(-b*b)*sgv[dim]; 129 | } 130 | sum_v += px*bg; 131 | pbg_avr += p*bg; 132 | #endif 133 | } 134 | /* 135 | b: (E, 1, R, K) 136 | x: (E, 1, R, 1) 137 | y: (N, H, 1, 1) 138 | z: (E, H, 1, K) 139 | g: (E, H, 1, K) 140 | p: (E, H, R, 1) 141 | 142 | (E,H,R,K) (E,H,R,1) (E,H,R,K) (E,H,1,K): (E,H,R,1)*(E,1,R,K)*(E,H,1,K) 143 | dz/dye = p*x * ( b*g - (p*b*g).sum(axis=R)) 144 | 145 | (E,H,1,1) 146 | dz/dyi = (dz/dye).sum(axis=R,K).sum_for_j() 147 | 148 | (E,H,R,1)*(E,H,R,1) (E,H,1,1) *(E,H,1,1) 149 | dz/dye = (p*x) *(b*g).sum(axis=K) - (p*x).sum(axis=R)*(p*b*g).sum(axis=R,K)) 150 | */ 151 | 152 | sum += px_avr*gz; 153 | sum_v -= px_avr*pbg_avr; 154 | 155 | } 156 | ga_ik[tid] = sum + sum_v; 157 | } -------------------------------------------------------------------------------- /models/kernels/real_enc_bwd_v2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | extern "C" __global__ 3 | 4 | void real_enc_bwd_v2( 5 | const float* a_ik, 6 | const float* rpos_ij_e, 7 | const float* dist2_min_e, 8 | const float* tvecs_n, 9 | const long long int* batch_i, 10 | const long long int* edge_ij_e, 11 | const long long int* e_start_i, 12 | const float* z_ek, 13 | const float* gz_ek, 14 | const float* gv_ekd, 15 | const long long 
int N, 16 | const long long int H, 17 | const long long int E, 18 | const long long int K_, 19 | const double dist_max, 20 | const double wscale, 21 | const float* rveclens_n, 22 | const double cutoff_radius, 23 | float* ga_ik){ 24 | 25 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 26 | if (tid >= N*H) return; 27 | 28 | const long long int K = VPE_DIM; 29 | const long long int k = tid%H; 30 | const long long int i = tid/H; 31 | const long long int n = batch_i[i]; 32 | tvecs_n += n*9; 33 | const float t1_x = tvecs_n[0]; 34 | const float t1_y = tvecs_n[1]; 35 | const float t1_z = tvecs_n[2]; 36 | const float t2_x = tvecs_n[3]; 37 | const float t2_y = tvecs_n[4]; 38 | const float t2_z = tvecs_n[5]; 39 | const float t3_x = tvecs_n[6]; 40 | const float t3_y = tvecs_n[7]; 41 | const float t3_z = tvecs_n[8]; 42 | const float a = a_ik[i*H + k]; 43 | const int R = LATTICE_RANGE; 44 | const long long int e_start = e_start_i[i]; 45 | const long long int e_end = e_start_i[i+1]; 46 | #if VPE_DIM > 0 47 | __shared__ float shared_gv[THREAD_NUM][VPE_DIM+1]; 48 | #endif 49 | 50 | rveclens_n += n*3; 51 | const float rvl1 = rveclens_n[0]; 52 | const float rvl2 = rveclens_n[1]; 53 | const float rvl3 = rveclens_n[2]; 54 | 55 | float cutoff = (float)cutoff_radius; 56 | int R1 = LATTICE_RANGE, R2 = LATTICE_RANGE, R3 = LATTICE_RANGE; 57 | if (cutoff != 0.0f) 58 | { 59 | if (cutoff < 0) { 60 | // Better sync the threads in each block? 61 | // -> disabled due to thread stucking 62 | // float a_max = a; 63 | // for (int t = 0; t < THREAD_NUM; t++) 64 | // a_max = max(a_max, a_ik[i*H + t]); 65 | //cutoff = sqrt(-0.5f/a_max)*(-cutoff); 66 | cutoff = sqrt(-0.5f/a)*(-cutoff); 67 | } 68 | R1 = ceil((cutoff + 0.01f)*rvl1/(2.0*CUDART_PI_F)); 69 | R2 = ceil((cutoff + 0.01f)*rvl2/(2.0*CUDART_PI_F)); 70 | R3 = ceil((cutoff + 0.01f)*rvl3/(2.0*CUDART_PI_F)); 71 | float cutoff2 = cutoff*cutoff; 72 | 73 | #if MINIMUM_RANGE > 0 74 | R1 = max(R1, MINIMUM_RANGE); 75 | R2 = max(R2, MINIMUM_RANGE); 76 | R3 = max(R3, MINIMUM_RANGE); 77 | #endif 78 | } 79 | 80 | float sum = 0; 81 | float sum_v = 0; 82 | for (long long int e = e_start; e < e_end; e++) 83 | { 84 | const long long int j = edge_ij_e[E+e]; 85 | const float r_ijx = rpos_ij_e[e*3+0]; 86 | const float r_ijy = rpos_ij_e[e*3+1]; 87 | const float r_ijz = rpos_ij_e[e*3+2]; 88 | const long long int ek = e*H+k; 89 | const float z = z_ek[ek]; 90 | const float gz = gz_ek[ek]; 91 | 92 | #if VPE_DIM > 0 93 | float *sgv = shared_gv[threadIdx.x]; 94 | const float *gv = &gv_ekd[ek*K]; 95 | #pragma unroll 96 | for (int dim = 0; dim < VPE_DIM; dim++) { 97 | sgv[dim] = gv[dim]; 98 | } 99 | #endif 100 | 101 | float px_avr = 0; 102 | float pbg_avr = 0; 103 | const float reci_ws_sqrt2 = 1.0f/((float)wscale*sqrt(2.0f)); 104 | const float mu0 = (float)dist_max/VPE_DIM; 105 | for (float n1 = -R1, px_avr1=0, pbg_avr1=0, sum_v1=0; n1 <= R1; n1++, px_avr +=px_avr1, pbg_avr +=pbg_avr1, sum_v +=sum_v1, px_avr1=pbg_avr1=sum_v1=0) 106 | for (float n2 = -R2, px_avr2=0, pbg_avr2=0, sum_v2=0; n2 <= R2; n2++, px_avr1+=px_avr2, pbg_avr1+=pbg_avr2, sum_v1+=sum_v2, px_avr2=pbg_avr2=sum_v2=0) 107 | for (float n3 = -R3; n3 <= R3; n3++) 108 | { 109 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 110 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 111 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 112 | float d2 = dx*dx + dy*dy + dz*dz; 113 | // float dx = fmaf(t1_x, n1, fmaf(t2_x, n2, fmaf(t3_x, n3, r_ijx))); 114 | // float dy = fmaf(t1_y, n1, fmaf(t2_y, n2, fmaf(t3_y, n3, 
r_ijy))); 115 | // float dz = fmaf(t1_z, n1, fmaf(t2_z, n2, fmaf(t3_z, n3, r_ijz))); 116 | // float d2 = fmaf(dx,dx, fmaf(dy,dy, dz*dz)); 117 | float p = expf(a*d2 - z); 118 | float px = d2*p; 119 | px_avr2 += px; 120 | 121 | #if VPE_DIM > 0 122 | float bg = 0; 123 | float b = -sqrtf(d2)/mu0*reci_ws_sqrt2; 124 | #pragma unroll 125 | for (int dim = 0; dim < VPE_DIM; dim++) 126 | { 127 | b += reci_ws_sqrt2; 128 | bg += expf(-b*b)*sgv[dim]; 129 | } 130 | sum_v2 += px*bg; 131 | pbg_avr2 += p*bg; 132 | #endif 133 | } 134 | /* 135 | b: (E, 1, R, K) 136 | x: (E, 1, R, 1) 137 | y: (N, H, 1, 1) 138 | z: (E, H, 1, K) 139 | g: (E, H, 1, K) 140 | p: (E, H, R, 1) 141 | 142 | (E,H,R,K) (E,H,R,1) (E,H,R,K) (E,H,1,K): (E,H,R,1)*(E,1,R,K)*(E,H,1,K) 143 | dz/dye = p*x * ( b*g - (p*b*g).sum(axis=R)) 144 | 145 | (E,H,1,1) 146 | dz/dyi = (dz/dye).sum(axis=R,K).sum_for_j() 147 | 148 | (E,H,R,1)*(E,H,R,1) (E,H,1,1) *(E,H,1,1) 149 | dz/dye = (p*x) *(b*g).sum(axis=K) - (p*x).sum(axis=R)*(p*b*g).sum(axis=R,K)) 150 | */ 151 | 152 | sum += px_avr*gz; 153 | sum_v -= px_avr*pbg_avr; 154 | 155 | } 156 | ga_ik[tid] = sum + sum_v; 157 | } -------------------------------------------------------------------------------- /models/kernels/real_enc_proj_fwd.cu: -------------------------------------------------------------------------------- 1 | #include 2 | extern "C" __global__ 3 | 4 | void real_enc_proj_fwd( 5 | const float* a_ik, 6 | const float* rpos_ij_e, 7 | const float* dist2_min_e, 8 | const float* tvecs_n, 9 | const long long int* batch_i, 10 | const long long int* edge_ij_e, 11 | const long long int N, 12 | const long long int H, 13 | const long long int E, 14 | const long long int K_, 15 | const double dist_max, 16 | const double wscale, 17 | const float* W_k, 18 | const long long int W_num, 19 | const float* rveclens_n, 20 | const double cutoff_radius, 21 | float* z_ek, 22 | float* v_ekd){ 23 | 24 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 25 | if (tid >= E*H) return; 26 | 27 | const long long int k = tid%H; 28 | const long long int e = tid/H; 29 | const long long int i = edge_ij_e[e]; 30 | const long long int j = edge_ij_e[E+e]; 31 | const long long int n = batch_i[i]; 32 | rpos_ij_e += e*3; 33 | const float r_ijx = rpos_ij_e[0]; 34 | const float r_ijy = rpos_ij_e[1]; 35 | const float r_ijz = rpos_ij_e[2]; 36 | tvecs_n += n*9; 37 | const float t1_x = tvecs_n[0]; 38 | const float t1_y = tvecs_n[1]; 39 | const float t1_z = tvecs_n[2]; 40 | const float t2_x = tvecs_n[3]; 41 | const float t2_y = tvecs_n[4]; 42 | const float t2_z = tvecs_n[5]; 43 | const float t3_x = tvecs_n[6]; 44 | const float t3_y = tvecs_n[7]; 45 | const float t3_z = tvecs_n[8]; 46 | const float a = a_ik[i*H + k]; 47 | const int R = LATTICE_RANGE; 48 | const float Rf = (float)LATTICE_RANGE; 49 | 50 | #if VPE_DIM > 0 51 | __shared__ float shared_v[THREAD_NUM][VPE_DIM+1]; 52 | float *sv = shared_v[threadIdx.x]; 53 | 54 | for (int dim = 0; dim < VPE_DIM; dim++) 55 | sv[dim] = 0; 56 | const float reci_ws_sqrt2 = 1.0f/((float)wscale*sqrt(2.0f)); 57 | const float mu0 = (float)dist_max/VPE_DIM; 58 | #endif 59 | 60 | rveclens_n += n*3; 61 | const float rvl1 = rveclens_n[0]; 62 | const float rvl2 = rveclens_n[1]; 63 | const float rvl3 = rveclens_n[2]; 64 | 65 | float cutoff = (float)cutoff_radius; 66 | int R1 = LATTICE_RANGE, R2 = LATTICE_RANGE, R3 = LATTICE_RANGE; 67 | if (cutoff != 0.0f) 68 | { 69 | if (cutoff < 0) { 70 | // Better sync the threads in each block? 
71 | // -> disabled due to thread stucking 72 | // float a_max = a; 73 | // for (int t = 0; t < THREAD_NUM; t++) 74 | // a_max = max(a_max, a_ik[i*H + t]); 75 | //cutoff = sqrt(-0.5f/a_max)*(-cutoff); 76 | cutoff = sqrt(-0.5f/a)*(-cutoff); 77 | } 78 | R1 = ceil((cutoff + 0.01f)*rvl1/(2.0*CUDART_PI_F)); 79 | R2 = ceil((cutoff + 0.01f)*rvl2/(2.0*CUDART_PI_F)); 80 | R3 = ceil((cutoff + 0.01f)*rvl3/(2.0*CUDART_PI_F)); 81 | 82 | #if MINIMUM_RANGE > 0 83 | R1 = max(R1, MINIMUM_RANGE); 84 | R2 = max(R2, MINIMUM_RANGE); 85 | R3 = max(R3, MINIMUM_RANGE); 86 | #endif 87 | } 88 | 89 | float d2min = 1e10; 90 | if (1 || dist2_min_e == NULL) 91 | { 92 | for (float n1 = -R1; n1 <= R1; n1++) 93 | for (float n2 = -R2; n2 <= R2; n2++) 94 | for (float n3 = -R3; n3 <= R3; n3++) 95 | { 96 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 97 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 98 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 99 | float d2 = dx*dx + dy*dy + dz*dz; 100 | // float dx = fmaf(t1_x, n1, fmaf(t2_x, n2, fmaf(t3_x, n3, r_ijx))); 101 | // float dy = fmaf(t1_y, n1, fmaf(t2_y, n2, fmaf(t3_y, n3, r_ijy))); 102 | // float dz = fmaf(t1_z, n1, fmaf(t2_z, n2, fmaf(t3_z, n3, r_ijz))); 103 | // float d2 = fmaf(dx,dx, fmaf(dy,dy, dz*dz)); 104 | d2min = fminf(d2min, d2); 105 | } 106 | } else { 107 | d2min = dist2_min_e[e]; 108 | } 109 | 110 | float sum = 0; 111 | for (float n1 = -R1; n1 <= R1; n1++) 112 | for (float n2 = -R2; n2 <= R2; n2++) 113 | for (float n3 = -R3; n3 <= R3; n3++) 114 | { 115 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 116 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 117 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 118 | float d2 = dx*dx + dy*dy + dz*dz; 119 | // float dx = fmaf(t1_x, n1, fmaf(t2_x, n2, fmaf(t3_x, n3, r_ijx))); 120 | // float dy = fmaf(t1_y, n1, fmaf(t2_y, n2, fmaf(t3_y, n3, r_ijy))); 121 | // float dz = fmaf(t1_z, n1, fmaf(t2_z, n2, fmaf(t3_z, n3, r_ijz))); 122 | // float d2 = fmaf(dx,dx, fmaf(dy,dy, dz*dz)); 123 | float w = expf(a*(d2 - d2min)); 124 | sum += w; 125 | 126 | #if VPE_DIM > 0 127 | // b_dim = exp( -((dim*(m/K)-dist)/(sqrt(2)*wscale*dist_max/K))**2 ) 128 | float b = -sqrtf(d2)/mu0*reci_ws_sqrt2; 129 | #pragma unroll 130 | for (int dim = 0; dim < VPE_DIM; dim++) 131 | { 132 | b += reci_ws_sqrt2; 133 | sv[dim] += exp(-b*b)*w; 134 | } 135 | #endif 136 | } 137 | 138 | #if VPE_DIM > 0 139 | if (W_k == NULL){ 140 | float *v = &v_ekd[tid*VPE_DIM]; 141 | #pragma unroll 142 | for (int dim = 0; dim < VPE_DIM; dim++) 143 | v[dim] = sv[dim]/sum; 144 | } else { 145 | // Do the matrix-vector multiplication: v' = Wv. 146 | float *v = &v_ekd[tid*V_HEAD_DIM]; 147 | long long int w_ind = 0; 148 | if (W_num == 1){ 149 | w_ind = 0; 150 | } else if (W_num == E) { 151 | w_ind = e; 152 | } else if (W_num == N) { 153 | w_ind = i; 154 | } 155 | const float *W = &W_k[(w_ind*H+k)*V_HEAD_DIM*VPE_DIM]; 156 | for (int wdim = 0; wdim < V_HEAD_DIM; wdim++){ 157 | float sum_v = 0; 158 | #pragma unroll 159 | for (int dim = 0; dim < VPE_DIM; dim++){ 160 | // For numerical accuracy, it is important to do "sv[dim]/sum" 161 | // instead of "sum_v/sum" after the loop. 
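            // (Editorial note, not in the original source: in this kernel the division by sum still
            //  happens once after the loop, via v[wdim] = sum_v/sum below; the real_enc_proj_fwd_v2.cu
            //  variant instead pre-divides sv[dim] by sum before this projection loop, which is what
            //  the note above recommends.)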
162 | sum_v += W[wdim*VPE_DIM+dim]*(sv[dim]); 163 | } 164 | v[wdim] = sum_v/sum; 165 | } 166 | } 167 | #endif 168 | 169 | z_ek[tid] = logf(sum) + d2min*a; 170 | } -------------------------------------------------------------------------------- /models/kernels/real_enc_proj_fwd_v2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | extern "C" __global__ 3 | 4 | void real_enc_proj_fwd_v2( 5 | const float* a_ik, 6 | const float* rpos_ij_e, 7 | const float* dist2_min_e, 8 | const float* tvecs_n, 9 | const long long int* batch_i, 10 | const long long int* edge_ij_e, 11 | const long long int N, 12 | const long long int H, 13 | const long long int E, 14 | const long long int K_, 15 | const double dist_max, 16 | const double wscale, 17 | const float* W_k, 18 | const long long int W_num, 19 | const float* rveclens_n, 20 | const double cutoff_radius, 21 | float* z_ek, 22 | float* v_ekd){ 23 | 24 | const long long int tid = (long long int)blockDim.x * blockIdx.x + threadIdx.x; 25 | if (tid >= E*H) return; 26 | 27 | const long long int k = tid%H; 28 | const long long int e = tid/H; 29 | const long long int i = edge_ij_e[e]; 30 | const long long int j = edge_ij_e[E+e]; 31 | const long long int n = batch_i[i]; 32 | rpos_ij_e += e*3; 33 | const float r_ijx = rpos_ij_e[0]; 34 | const float r_ijy = rpos_ij_e[1]; 35 | const float r_ijz = rpos_ij_e[2]; 36 | tvecs_n += n*9; 37 | const float t1_x = tvecs_n[0]; 38 | const float t1_y = tvecs_n[1]; 39 | const float t1_z = tvecs_n[2]; 40 | const float t2_x = tvecs_n[3]; 41 | const float t2_y = tvecs_n[4]; 42 | const float t2_z = tvecs_n[5]; 43 | const float t3_x = tvecs_n[6]; 44 | const float t3_y = tvecs_n[7]; 45 | const float t3_z = tvecs_n[8]; 46 | const float a = a_ik[i*H + k]; 47 | const int R = LATTICE_RANGE; 48 | const float Rf = (float)LATTICE_RANGE; 49 | 50 | #if VPE_DIM > 0 51 | __shared__ float shared_v[THREAD_NUM][VPE_DIM+1]; 52 | float *sv = shared_v[threadIdx.x]; 53 | 54 | for (int dim = 0; dim < VPE_DIM; dim++) 55 | sv[dim] = 0; 56 | const float reci_ws_sqrt2 = 1.0f/((float)wscale*sqrt(2.0f)); 57 | const float mu0 = (float)dist_max/VPE_DIM; 58 | #endif 59 | 60 | rveclens_n += n*3; 61 | const float rvl1 = rveclens_n[0]; 62 | const float rvl2 = rveclens_n[1]; 63 | const float rvl3 = rveclens_n[2]; 64 | 65 | float cutoff = (float)cutoff_radius; 66 | int R1 = LATTICE_RANGE, R2 = LATTICE_RANGE, R3 = LATTICE_RANGE; 67 | if (cutoff != 0.0f) 68 | { 69 | if (cutoff < 0) { 70 | // Better sync the threads in each block? 
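// The ranges R1..R3 computed below convert the real-space cutoff into unit-cell counts:
// rvl_i is the length of the i-th reciprocal basis vector and the lattice-plane spacing
// along that axis is 2*pi/rvl_i, so roughly cutoff*rvl_i/(2*pi) repetitions cover the
// cutoff sphere; the +0.01 margin and the MINIMUM_RANGE floor act as safeguards.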
71 | // -> disabled due to thread stucking 72 | // float a_max = a; 73 | // for (int t = 0; t < THREAD_NUM; t++) 74 | // a_max = max(a_max, a_ik[i*H + t]); 75 | //cutoff = sqrt(-0.5f/a_max)*(-cutoff); 76 | cutoff = sqrt(-0.5f/a)*(-cutoff); 77 | } 78 | R1 = ceil((cutoff + 0.01f)*rvl1/(2.0*CUDART_PI_F)); 79 | R2 = ceil((cutoff + 0.01f)*rvl2/(2.0*CUDART_PI_F)); 80 | R3 = ceil((cutoff + 0.01f)*rvl3/(2.0*CUDART_PI_F)); 81 | 82 | #if MINIMUM_RANGE > 0 83 | R1 = max(R1, MINIMUM_RANGE); 84 | R2 = max(R2, MINIMUM_RANGE); 85 | R3 = max(R3, MINIMUM_RANGE); 86 | #endif 87 | } 88 | 89 | float d2min = 1e10; 90 | if (1 || dist2_min_e == NULL) 91 | { 92 | for (float n1 = -R1; n1 <= R1; n1++) 93 | for (float n2 = -R2; n2 <= R2; n2++) 94 | for (float n3 = -R3; n3 <= R3; n3++) 95 | { 96 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 97 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 98 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 99 | float d2 = dx*dx + dy*dy + dz*dz; 100 | // float dx = fmaf(t1_x, n1, fmaf(t2_x, n2, fmaf(t3_x, n3, r_ijx))); 101 | // float dy = fmaf(t1_y, n1, fmaf(t2_y, n2, fmaf(t3_y, n3, r_ijy))); 102 | // float dz = fmaf(t1_z, n1, fmaf(t2_z, n2, fmaf(t3_z, n3, r_ijz))); 103 | // float d2 = fmaf(dx,dx, fmaf(dy,dy, dz*dz)); 104 | d2min = fminf(d2min, d2); 105 | } 106 | } else { 107 | d2min = dist2_min_e[e]; 108 | } 109 | 110 | float sum = 0; 111 | for (float n1 = -R1, s1=0; n1 <= R1; n1++, sum+=s1, s1=0) 112 | for (float n2 = -R2, s2=0; n2 <= R2; n2++, s1 +=s2, s2=0) 113 | for (float n3 = -R3; n3 <= R3; n3++) 114 | { 115 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 116 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 117 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 118 | float d2 = dx*dx + dy*dy + dz*dz; 119 | // float dx = fmaf(t1_x, n1, fmaf(t2_x, n2, fmaf(t3_x, n3, r_ijx))); 120 | // float dy = fmaf(t1_y, n1, fmaf(t2_y, n2, fmaf(t3_y, n3, r_ijy))); 121 | // float dz = fmaf(t1_z, n1, fmaf(t2_z, n2, fmaf(t3_z, n3, r_ijz))); 122 | // float d2 = fmaf(dx,dx, fmaf(dy,dy, dz*dz)); 123 | float w = expf(a*(d2 - d2min)); 124 | s2 += w; 125 | 126 | #if VPE_DIM > 0 127 | // b_dim = exp( -((dim*(m/K)-dist)/(sqrt(2)*wscale*dist_max/K))**2 ) 128 | float b = -sqrtf(d2)/mu0*reci_ws_sqrt2; 129 | #pragma unroll 130 | for (int dim = 0; dim < VPE_DIM; dim++) 131 | { 132 | b += reci_ws_sqrt2; 133 | sv[dim] += exp(-b*b)*w; 134 | } 135 | #endif 136 | } 137 | 138 | #if VPE_DIM > 0 139 | if (W_k == NULL){ 140 | float *v = &v_ekd[tid*VPE_DIM]; 141 | #pragma unroll 142 | for (int dim = 0; dim < VPE_DIM; dim++) 143 | v[dim] = sv[dim]/sum; 144 | } else { 145 | // Do the matrix-vector multiplication: v' = Wv. 146 | float *v = &v_ekd[tid*V_HEAD_DIM]; 147 | long long int w_ind = 0; 148 | if (W_num == 1){ 149 | w_ind = 0; 150 | } else if (W_num == E) { 151 | w_ind = e; 152 | } else if (W_num == N) { 153 | w_ind = i; 154 | } 155 | #pragma unroll 156 | for (int dim = 0; dim < VPE_DIM; dim++) 157 | sv[dim] = sv[dim]/sum; 158 | const float *W = &W_k[(w_ind*H+k)*V_HEAD_DIM*VPE_DIM]; 159 | for (int wdim = 0; wdim < V_HEAD_DIM; wdim++){ 160 | float sum_v = 0; 161 | #pragma unroll 162 | for (int dim = 0; dim < VPE_DIM; dim++){ 163 | // For numerical accuracy, it is important to do "sv[dim]/sum" 164 | // instead of "sum_v/sum" after the loop. 
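// sv[dim] has already been divided by sum a few lines above, so each term accumulated
// here is W[wdim,dim]*(sv[dim]/sum) and v[wdim] = sum_v can be written out directly;
// normalizing per basis dim rather than dividing the final sum_v is the ordering the
// comment above refers to.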
165 | sum_v += W[wdim*VPE_DIM+dim]*sv[dim]; 166 | } 167 | v[wdim] = sum_v; 168 | } 169 | } 170 | #endif 171 | 172 | z_ek[tid] = logf(sum) + d2min*a; 173 | } -------------------------------------------------------------------------------- /data/download_jarvis.py: -------------------------------------------------------------------------------- 1 | import jarvis 2 | import os 3 | import pathlib 4 | from jarvis.db.figshare import data as jdata 5 | from jarvis.core.atoms import Atoms 6 | 7 | from tqdm import tqdm 8 | import random 9 | import numpy 10 | import pickle 11 | import math 12 | 13 | def get_id_train_val_test( 14 | total_size=1000, 15 | split_seed=123, 16 | train_ratio=None, 17 | val_ratio=0.1, 18 | test_ratio=0.1, 19 | n_train=None, 20 | n_test=None, 21 | n_val=None, 22 | keep_data_order=False, 23 | ): 24 | """Get train, val, test IDs.""" 25 | if ( 26 | train_ratio is None 27 | and val_ratio is not None 28 | and test_ratio is not None 29 | ): 30 | if train_ratio is None: 31 | assert val_ratio + test_ratio < 1 32 | train_ratio = 1 - val_ratio - test_ratio 33 | print("Using rest of the dataset except the test and val sets.") 34 | else: 35 | assert train_ratio + val_ratio + test_ratio <= 1 36 | # indices = list(range(total_size)) 37 | if n_train is None: 38 | n_train = int(train_ratio * total_size) 39 | if n_test is None: 40 | n_test = int(test_ratio * total_size) 41 | if n_val is None: 42 | n_val = int(val_ratio * total_size) 43 | ids = list(numpy.arange(total_size)) 44 | if not keep_data_order: 45 | random.seed(split_seed) 46 | random.shuffle(ids) 47 | if n_train + n_val + n_test > total_size: 48 | raise ValueError( 49 | "Check total number of samples.", 50 | n_train + n_val + n_test, 51 | ">", 52 | total_size, 53 | ) 54 | 55 | id_train = ids[:n_train] 56 | id_val = ids[-(n_val + n_test) : -n_test] # noqa:E203 57 | id_test = ids[-n_test:] 58 | return id_train, id_val, id_test 59 | 60 | 61 | if __name__ == '__main__': 62 | # If SSL certification error occurs Under proxy, the following workaround may work. 
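    # (Pointing requests/curl at an empty CA bundle is a common workaround that, in practice,
    #  makes the underlying HTTPS downloads from figshare skip certificate verification;
    #  only uncomment these lines behind a proxy you trust.)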
63 | # os.environ['REQUESTS_CA_BUNDLE'] = '' 64 | # os.environ['CURL_CA_BUNDLE'] = '' 65 | 66 | # print(jarvis.__path__) 67 | # cached_files = pathlib.Path(jarvis.__path__[0]).glob("db/*.zip") 68 | # for file in cached_files: 69 | # print(f"Removing {file.absolute()}") 70 | # os.remove(file.absolute()) 71 | 72 | datasets = [ 73 | "megnet", 74 | "dft_3d_2021", 75 | "dft_3d_2021", 76 | "dft_3d_2021", 77 | ] 78 | save_names = [ 79 | "jarvis__megnet", 80 | "jarvis__dft_3d_2021", 81 | "jarvis__dft_3d_2021-mbj_bandgap", 82 | "jarvis__dft_3d_2021-ehull", 83 | ] 84 | used_vals = [ 85 | { 86 | 'id': ('material_id', str), 87 | 'gap pbe': ('bandgap', float), 88 | 'e_form': ('e_form', float), 89 | 'structure': ('structure', object) 90 | }, 91 | { 92 | 'structure': ('structure', object), 93 | 'jid': ('material_id', str), 94 | 'formation_energy_peratom': ('formation_energy', float), 95 | 'optb88vdw_total_energy': ('total_energy', float), 96 | 'optb88vdw_bandgap': ('opt_bandgap', float), 97 | }, 98 | { 99 | 'structure': ('structure', object), 100 | 'jid': ('material_id', str), 101 | 'mbj_bandgap': ('mbj_bandgap', float), 102 | }, 103 | { 104 | 'structure': ('structure', object), 105 | 'jid': ('material_id', str), 106 | 'ehull': ('ehull', float), 107 | } 108 | ] 109 | 110 | for i, t in enumerate(datasets): 111 | try: 112 | print(f"Processing dataset: {t}") 113 | data = jdata(t) 114 | new_data = [] 115 | print(data[0]) 116 | for x in tqdm(data): 117 | atoms = Atoms( 118 | lattice_mat=x['atoms']['lattice_mat'], 119 | coords=x['atoms']['coords'], 120 | elements=x['atoms']['elements'], 121 | cartesian=x['atoms']['cartesian'], 122 | ) 123 | x['structure'] = atoms.pymatgen_converter() 124 | 125 | new_x = {} 126 | ok = True 127 | for key in used_vals[i]: 128 | newkey, vtype = used_vals[i][key] 129 | val = x[key] 130 | new_x[newkey] = val 131 | if vtype == int and type(val) != int: 132 | ok = False 133 | break 134 | 135 | elif vtype == float: 136 | if type(val) == int: 137 | x[newkey] = float(val) 138 | elif type(val) == float and not math.isnan(val) and not math.isinf(val): 139 | pass 140 | else: 141 | ok = False 142 | break 143 | 144 | elif vtype == str and val is None: 145 | x[newkey] = "" 146 | elif vtype == str and type(val) != str: 147 | x[newkey] = str(val) 148 | 149 | if ok: 150 | new_data.append(new_x) 151 | 152 | print(f"filtered: {len(data)} -> {len(new_data)} ({len(new_data) - len(data)})") 153 | data = new_data 154 | 155 | print("Printing the first item...") 156 | for k in data[0]: 157 | print(f"{k}\t: {data[0][k]}") 158 | 159 | if t == "megnet": 160 | id_train, id_val, id_test = get_id_train_val_test( 161 | len(data), 162 | n_train=60000, 163 | n_val=5000, 164 | n_test=4239 165 | ) 166 | else: 167 | id_train, id_val, id_test = get_id_train_val_test( 168 | len(data), 169 | ) 170 | 171 | splits = {} 172 | splits['train'] = [data[i] for i in id_train] 173 | splits['val'] = [data[i] for i in id_val] 174 | splits['test'] = [data[i] for i in id_test] 175 | splits['all'] = data 176 | 177 | print("Saving split data...") 178 | for key in splits: 179 | save_dir = f"{save_names[i]}/{key}/raw" 180 | os.makedirs(save_dir, exist_ok=True) 181 | 182 | print(f"{key}\t:{len(splits[key])}") 183 | with open(f"{save_dir}/raw_data.pkl", "wb") as fp: 184 | pickle.dump(splits[key], fp) 185 | 186 | except Exception as e: 187 | raise e 188 | -------------------------------------------------------------------------------- /models/kernels/reduce_kernel_utils.cuh: 
-------------------------------------------------------------------------------- 1 | /* 2 | * This code is modified from 3 | * https://github.com/NVIDIA/FasterTransformer/blob/df4a7534860137e060e18d2ebf019906120ea204/src/fastertransformer/kernels/reduce_kernel_utils.cuh 4 | * to only contain warpReduceSum, warpReduceMax, warpReduceSumV2, warpReduceMaxV2. 5 | */ 6 | 7 | /* 8 | * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 9 | * 10 | * Licensed under the Apache License, Version 2.0 (the "License"); 11 | * you may not use this file except in compliance with the License. 12 | * You may obtain a copy of the License at 13 | * 14 | * http://www.apache.org/licenses/LICENSE-2.0 15 | * 16 | * Unless required by applicable law or agreed to in writing, software 17 | * distributed under the License is distributed on an "AS IS" BASIS, 18 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | * See the License for the specific language governing permissions and 20 | * limitations under the License. 21 | */ 22 | 23 | #pragma once 24 | // #include 25 | // #include 26 | #include 27 | // #include 28 | // #include 29 | // #include 30 | //__reduce_max_sync 31 | 32 | #define FULL_MASK 0xffffffff 33 | 34 | template 35 | __inline__ __device__ T warpReduceSum(T val) 36 | { 37 | //unsigned b = __ballot_sync(FULL_MASK, threadIdx.x < blockDim.x); 38 | #pragma unroll 39 | for (int mask = 16; mask > 0; mask >>= 1) 40 | val += __shfl_xor_sync(FULL_MASK, val, mask, 32); //__shfl_sync bf16 return float when sm < 80 41 | return val; 42 | } 43 | 44 | /* Calculate the sum of all elements in a block */ 45 | template 46 | __inline__ __device__ T blockReduceSum(T val) 47 | { 48 | static __shared__ T shared[32]; 49 | int lane = threadIdx.x & 0x1f; 50 | int wid = threadIdx.x >> 5; 51 | 52 | val = warpReduceSum(val); 53 | 54 | if (lane == 0) 55 | shared[wid] = val; 56 | 57 | __syncthreads(); 58 | 59 | /* 60 | // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent 61 | // blockDim.x is not divided by 32 62 | val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); 63 | */ 64 | 65 | // Modify from threadIdx.x to lane to share the result in the block. 66 | val = (lane < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); 67 | val = warpReduceSum(val); 68 | 69 | return val; 70 | } 71 | 72 | template 73 | __inline__ __device__ T warpReduceMax(T val) 74 | { 75 | #pragma unroll 76 | for (int mask = 16; mask > 0; mask >>= 1) 77 | val = max(val, __shfl_xor_sync(FULL_MASK, val, mask, 32)); 78 | return val; 79 | } 80 | 81 | /* Calculate the maximum of all elements in a block */ 82 | template 83 | __inline__ __device__ T blockReduceMax(T val) 84 | { 85 | static __shared__ T shared[32]; 86 | int lane = threadIdx.x & 0x1f; // in-warp idx 87 | int wid = threadIdx.x >> 5; // warp idx 88 | 89 | val = warpReduceMax(val); // get maxx in each warp 90 | 91 | if (lane == 0) // record in-warp maxx by warp Idx 92 | shared[wid] = val; 93 | 94 | __syncthreads(); 95 | 96 | /* 97 | // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent 98 | // blockDim.x is not divided by 32 99 | val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; 100 | */ 101 | 102 | // Modify from threadIdx.x to lane to share the result in the block. 103 | val = (lane < (blockDim.x / 32.f)) ? 
shared[lane] : -1e20f; 104 | val = warpReduceMax(val); 105 | 106 | return val; 107 | } 108 | 109 | 110 | /* Calculate the maximum of all elements in a block */ 111 | template 112 | __inline__ __device__ T blockAllReduceMax(T val) 113 | { 114 | static __shared__ T shared[32]; 115 | int lane = threadIdx.x & 0x1f; // in-warp idx 116 | int wid = threadIdx.x >> 5; // warp idx 117 | 118 | val = warpReduceMax(val); // get maxx in each warp 119 | 120 | if (lane == 0) // record in-warp maxx by warp Idx 121 | shared[wid] = val; 122 | 123 | __syncthreads(); 124 | 125 | // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent 126 | // blockDim.x is not divided by 32 127 | val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; 128 | val = warpReduceMax(val); 129 | 130 | return val; 131 | } 132 | 133 | template 134 | __inline__ __device__ T warpReduceSumV2(T* val) 135 | { 136 | #pragma unroll 137 | for (int i = 0; i < NUM; i++) { 138 | #pragma unroll 139 | for (int mask = 16; mask > 0; mask >>= 1) 140 | val[i] += __shfl_xor_sync(FULL_MASK, val[i], mask, 32); 141 | } 142 | return (T)(0.0f); 143 | } 144 | 145 | template 146 | __inline__ __device__ T blockReduceSumV2(T* val) 147 | { 148 | static __shared__ T shared[NUM][33]; 149 | int lane = threadIdx.x & 0x1f; 150 | int wid = threadIdx.x >> 5; 151 | 152 | warpReduceSumV2(val); 153 | 154 | if (lane == 0) { 155 | #pragma unroll 156 | for (int i = 0; i < NUM; i++) { 157 | shared[i][wid] = val[i]; 158 | } 159 | } 160 | 161 | __syncthreads(); 162 | 163 | bool is_mask = threadIdx.x < (blockDim.x / 32.f); 164 | #pragma unroll 165 | for (int i = 0; i < NUM; i++) { 166 | val[i] = is_mask ? shared[i][lane] : (T)(0.0f); 167 | } 168 | warpReduceSumV2(val); 169 | return (T)0.0f; 170 | } 171 | 172 | template 173 | __inline__ __device__ T warpReduceMaxV2(T* val) 174 | { 175 | #pragma unroll 176 | for (int i = 0; i < NUM; i++) { 177 | #pragma unroll 178 | for (int mask = 16; mask > 0; mask >>= 1) 179 | val[i] = max(val[i], __shfl_xor_sync(FULL_MASK, val[i], mask, 32)); 180 | } 181 | return (T)(0.0f); 182 | } 183 | 184 | template 185 | __inline__ __device__ T blockReduceMaxV2(T* val) 186 | { 187 | static __shared__ T shared[32][NUM]; 188 | int lane = threadIdx.x & 0x1f; // in-warp idx 189 | int wid = threadIdx.x >> 5; // warp idx 190 | 191 | warpReduceMaxV2(val); // get maxx in each warp 192 | 193 | if (lane == 0) // record in-warp maxx by warp Idx 194 | { 195 | #pragma unroll 196 | for (int i = 0; i < NUM; i++) { 197 | shared[wid][i] = val[i]; 198 | } 199 | } 200 | 201 | __syncthreads(); 202 | 203 | // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent 204 | // blockDim.x is not divided by 32 205 | bool is_mask = threadIdx.x < (blockDim.x / 32.f); 206 | #pragma unroll 207 | for (int i = 0; i < NUM; i++) { 208 | val[i] = is_mask ? 
shared[lane][i] : (T)-1e20f; 209 | } 210 | warpReduceMaxV2(val); 211 | 212 | return (T)0.0f; 213 | } 214 | -------------------------------------------------------------------------------- /models/cuda_funcs/real_space_enc_proj.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from .kernel_manager import KernelManager 4 | 5 | try: 6 | import cupy as cp 7 | import pytorch_pfn_extras as ppe 8 | from torch.utils.dlpack import to_dlpack, from_dlpack 9 | except: 10 | pass 11 | 12 | def _to_copy(x): 13 | if x is not None: 14 | return cp.from_dlpack(to_dlpack(x)) 15 | return 0 16 | 17 | 18 | class RealPeriodicEncodingWithProjFuncCUDA(torch.autograd.Function): 19 | @staticmethod 20 | def forward(ctx, a_ik, rpos_ij_e, dist2_min_e, tvecs_n, batch_i, edge_ij_e, K, dist_max, wscale, \ 21 | W_k, rvlen_n=None, cutoff_radius=None): 22 | 23 | # a_ik : (points, heads) 24 | # rpos_ij_e : (edges, 3) 25 | # tvecs_n : (batch, 3, 3) 26 | # batch_i : (points) 27 | # edge_ij_e : (2, edges) 28 | # z_ijk = log( sum_n exp( a_ik*|pj + t1*n1+t2*n2+t3*n3 - pi|^2 ) ) 29 | # : (edges, heads) 30 | N, H = a_ik.shape 31 | E = edge_ij_e.shape[1] 32 | kw = {'device': a_ik.device, 'dtype': a_ik.dtype} 33 | 34 | a_ik = a_ik.contiguous().detach() 35 | rpos_ij_e = rpos_ij_e.contiguous() 36 | tvecs_n = tvecs_n.contiguous() 37 | batch_i = batch_i.contiguous() 38 | dist2_min_e = dist2_min_e.contiguous() if dist2_min_e is not None else None 39 | edge_ij_e = edge_ij_e.contiguous() 40 | if W_k is not None: 41 | W_k = W_k.detach().contiguous() 42 | assert W_k.dim() in (3, 4) 43 | W_num = 1 if W_k.dim() == 3 else W_k.shape[0] 44 | W_dim = W_k.shape[-2] 45 | v_ekd = torch.empty((E, H, W_dim), **kw) if K > 0 else None # not neaded for noproj 46 | else: 47 | W_num = 0 48 | W_dim = 0 49 | v_ekd = torch.empty((E, H, K), **kw) if K > 0 else None 50 | z_ek = torch.empty((E, H), **kw) 51 | 52 | bsz = H 53 | dev = a_ik.device 54 | 55 | with cp.cuda.Device(dev.index), ppe.cuda.stream(torch.cuda.current_stream(dev)): 56 | if False and rvlen_n is None: 57 | KernelManager.position_enc_proj_forward( ((E*H+bsz-1)//bsz, ), (bsz, ), ( 58 | _to_copy(a_ik), 59 | _to_copy(rpos_ij_e), 60 | _to_copy(dist2_min_e), 61 | _to_copy(tvecs_n), 62 | _to_copy(batch_i), 63 | _to_copy(edge_ij_e), 64 | N, H, E, 65 | K, dist_max, wscale, 66 | _to_copy(W_k), W_num, 67 | _to_copy(z_ek), 68 | _to_copy(v_ekd), 69 | )) 70 | else: 71 | from .. 
import global_config as config 72 | kernel = KernelManager.real_enc_proj_fwd_v2 if config.REPRODUCIBLITY_STATE >= 4 \ 73 | else KernelManager.real_enc_proj_fwd 74 | kernel( ((E*H+bsz-1)//bsz, ), (bsz, ), ( 75 | _to_copy(a_ik), 76 | _to_copy(rpos_ij_e), 77 | _to_copy(dist2_min_e), 78 | _to_copy(tvecs_n), 79 | _to_copy(batch_i), 80 | _to_copy(edge_ij_e), 81 | N, H, E, 82 | K, dist_max, wscale, 83 | _to_copy(W_k), W_num, 84 | _to_copy(rvlen_n), cutoff_radius, 85 | _to_copy(z_ek), 86 | _to_copy(v_ekd), 87 | )) 88 | 89 | ctx.save_for_backward(a_ik, rpos_ij_e, dist2_min_e, tvecs_n, batch_i, edge_ij_e, rvlen_n, W_k, z_ek, v_ekd) 90 | ctx.K = K 91 | ctx.dist_max = dist_max 92 | ctx.wscale = wscale 93 | ctx.cutoff_radius = cutoff_radius 94 | if K <= 0: 95 | return z_ek, 96 | 97 | return z_ek, v_ekd 98 | 99 | @staticmethod 100 | def backward(ctx, gz_ek, gv_ekd=None): 101 | # a_ik, rpos_ij_e, tvecs_n, batch_i, edge_ij_e, z_ek = ctx.saved_tensors[:6] 102 | a_ik, rpos_ij_e, dist2_min_e, tvecs_n, batch_i, edge_ij_e, rvlen_n, W_k, z_ek, v_ekd = ctx.saved_tensors 103 | K = ctx.K 104 | dist_max = ctx.dist_max 105 | wscale = ctx.wscale 106 | cutoff_radius = ctx.cutoff_radius 107 | N, H = a_ik.shape 108 | E = edge_ij_e.shape[1] 109 | 110 | e_start_i = torch.zeros(N+1, dtype=batch_i.dtype, device=batch_i.device) 111 | e_start_i.scatter_add_(0, edge_ij_e[0]+1, torch.ones_like(edge_ij_e[0])) 112 | e_start_i = e_start_i.cumsum(0) 113 | 114 | ga_ik = torch.empty_like(a_ik) 115 | 116 | dev = a_ik.device 117 | gW_k = None 118 | if W_k is not None: 119 | # W: (edges or 1, heads, head_dim, K) 120 | assert W_k.dim() in (3, 4) 121 | W_num = 1 if W_k.dim() == 3 else W_k.shape[0] 122 | W_dim = W_k.shape[-2] 123 | 124 | # W: (edges or 1, heads, head_dim, K) 125 | # gv_ekd:(edges , heads, Vdim) 126 | # v_ekd: (edges or 1, heads, K) 127 | gW_k = torch.empty((max(W_num,N),)+W_k.shape[-3:], device=dev, dtype=a_ik.dtype) 128 | else: 129 | W_num = 0 130 | W_dim = 0 131 | 132 | bsz = H 133 | with cp.cuda.Device(dev.index), ppe.cuda.stream(torch.cuda.current_stream(dev)): 134 | if False and rvlen_n is None: 135 | KernelManager.position_enc_proj_backward(((N*H+bsz-1)//bsz, ), (bsz, ), ( 136 | _to_copy(a_ik.detach()), 137 | _to_copy(rpos_ij_e), 138 | _to_copy(dist2_min_e), 139 | _to_copy(tvecs_n), 140 | _to_copy(batch_i), 141 | _to_copy(edge_ij_e), 142 | _to_copy(e_start_i), 143 | _to_copy(z_ek.detach()), 144 | _to_copy(gz_ek.detach().contiguous()), 145 | _to_copy(gv_ekd), 146 | N, H, E, 147 | K, dist_max, wscale, 148 | _to_copy(W_k), W_num, 149 | _to_copy(ga_ik), 150 | _to_copy(gW_k), 151 | )) 152 | else: 153 | from .. 
import global_config as config 154 | kernel = KernelManager.real_enc_proj_bwd_v2 if config.REPRODUCIBLITY_STATE >= 4 \ 155 | else KernelManager.real_enc_proj_bwd 156 | kernel(((N*H+bsz-1)//bsz, ), (bsz, ), ( 157 | _to_copy(a_ik.detach()), 158 | _to_copy(rpos_ij_e), 159 | # _to_copy(dist2_min_e), 160 | _to_copy(tvecs_n), 161 | _to_copy(batch_i), 162 | _to_copy(edge_ij_e), 163 | _to_copy(e_start_i), 164 | _to_copy(z_ek.detach()), 165 | _to_copy(gz_ek.detach().contiguous()), 166 | _to_copy(gv_ekd), 167 | N, H, E, #K, 168 | dist_max, wscale, 169 | _to_copy(W_k), W_num, 170 | _to_copy(rvlen_n), cutoff_radius, 171 | _to_copy(ga_ik), 172 | _to_copy(gW_k), 173 | )) 174 | 175 | if rvlen_n is None: 176 | return ga_ik, None, None, None, None, None, None, None, None, gW_k 177 | 178 | return ga_ik, None, None, None, None, None, None, None, None, gW_k, None, None 179 | 180 | 181 | -------------------------------------------------------------------------------- /models/cuda_funcs/fused_dpa_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from .kernel_manager import KernelManager 4 | 5 | try: 6 | import cupy as cp 7 | import pytorch_pfn_extras as ppe 8 | from torch.utils.dlpack import to_dlpack, from_dlpack 9 | except: 10 | pass 11 | 12 | def _to_copy(x): 13 | if x is not None: 14 | return cp.from_dlpack(to_dlpack(x)) 15 | return 0 16 | 17 | class FusedDotProductAttentionCUDA_v2(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, que_ihk, key_ihk, val_ihk, aij_eh, bij_ehk, batch_i, edge_ij_e): 20 | N, H, K = que_ihk.shape 21 | E = edge_ij_e.shape[1] 22 | dev = que_ihk.device 23 | 24 | e_start_i = torch.zeros(N+1, dtype=batch_i.dtype, device=batch_i.device) 25 | e_start_i.scatter_add_(0, edge_ij_e[0]+1, torch.ones_like(edge_ij_e[0])) 26 | e_start_i = e_start_i.cumsum(0) 27 | 28 | que_ihk = que_ihk.contiguous().detach() 29 | key_ihk = key_ihk.contiguous().detach() 30 | val_ihk = val_ihk.contiguous().detach() 31 | aij_eh = aij_eh.contiguous().detach() if aij_eh is not None else None 32 | bij_ehk = bij_ehk.contiguous().detach() if bij_ehk is not None else None 33 | batch_i = batch_i.contiguous() 34 | edge_ij_e = edge_ij_e.contiguous() 35 | 36 | output = torch.empty_like(val_ihk) 37 | prob_eh = torch.empty((E, H), dtype=que_ihk.dtype, device=dev) 38 | 39 | with cp.cuda.Device(dev.index), ppe.cuda.stream(torch.cuda.current_stream(dev)): 40 | bsz = 1 41 | KernelManager.fused_dpa_fwd_v2( 42 | ((N*H+bsz-1)//bsz,), (bsz,), 43 | ( 44 | _to_copy(que_ihk), 45 | _to_copy(key_ihk), 46 | _to_copy(val_ihk), 47 | _to_copy(aij_eh), 48 | _to_copy(bij_ehk), 49 | _to_copy(edge_ij_e), 50 | _to_copy(e_start_i), 51 | N, H, E, 52 | _to_copy(prob_eh), 53 | _to_copy(output), 54 | ) 55 | ) 56 | 57 | ctx.save_for_backward(que_ihk, key_ihk, val_ihk, aij_eh, bij_ehk, 58 | batch_i, edge_ij_e, e_start_i, 59 | prob_eh, output) 60 | return output 61 | 62 | @staticmethod 63 | def backward(ctx, go_ihk): 64 | que_ihk, key_ihk, val_ihk, aij_eh, bij_ehk, \ 65 | batch_i, edge_ij_e, e_start_i, \ 66 | prob_eh, output = ctx.saved_tensors 67 | 68 | N, H, K = que_ihk.shape 69 | E = edge_ij_e.shape[1] 70 | dev = que_ihk.device 71 | 72 | B = batch_i.max().item()+1 73 | sizes = torch.zeros(B, dtype=torch.long, device=dev) 74 | sizes.scatter_add_(0, batch_i, torch.ones_like(batch_i)) 75 | sizes2 = sizes*sizes 76 | 77 | if False: 78 | gque = [] 79 | gkey = [] 80 | gval = [] 81 | gbij = [] 82 | gaij = [] 83 | sizes = torch.zeros(N, dtype=batch_i.dtype, 
device=batch_i.device) 84 | sizes.scatter_add_(0, batch_i, torch.ones_like(batch_i)) 85 | sizes2 = sizes*sizes 86 | _sizes = sizes.tolist() 87 | _sizes2 = sizes2.tolist() 88 | for q,k,v,a,b,o,p,go,s in zip( 89 | que_ihk.split_with_sizes(_sizes), 90 | key_ihk.split_with_sizes(_sizes), 91 | val_ihk.split_with_sizes(_sizes), 92 | aij_eh.split_with_sizes(_sizes2), 93 | bij_ehk.split_with_sizes(_sizes2), 94 | output.split_with_sizes(_sizes), 95 | prob_eh.split_with_sizes(_sizes2), 96 | go_ihk.split_with_sizes(_sizes), 97 | _sizes): 98 | # q/k/v/o/go: (S, H, K) 99 | # a/p: (S*S, H) 100 | # b: (S*S, H, K) 101 | gb = go.reshape(s,1,H,K) * p.reshape(s,s,H,1) 102 | gv = gb.sum(dim=0) 103 | gval.append(gv) 104 | gbij.append(gb.reshape(s*s,H,K)) 105 | gsm = (v.reshape(1,s,H,K) + b.reshape(s,s,H,K) - o.reshape(s,1,H,K))*gb 106 | ga = gsm.sum(dim=3) 107 | gq = (ga.reshape(s,s,H,1)*k.reshape(1,s,H,K)).sum(dim=1) 108 | gk = (ga.reshape(s,s,H,1)*q.reshape(s,1,H,K)).sum(dim=0) 109 | gaij.append(ga.reshape(s*s,H)) 110 | gque.append(gq) 111 | gkey.append(gk) 112 | 113 | gque = torch.cat(gque) 114 | gkey = torch.cat(gkey) 115 | gval = torch.cat(gval) 116 | gbij = torch.cat(gbij) 117 | gaij = torch.cat(gaij) 118 | return gque, gkey, gval, gaij, gbij, None, None 119 | 120 | gque = torch.empty_like(que_ihk) 121 | gkey = torch.empty_like(key_ihk) 122 | gval = torch.empty_like(val_ihk) 123 | gaij = torch.empty_like(aij_eh) 124 | gbij = torch.empty_like(bij_ehk) if bij_ehk is not None else None 125 | go_ihk = go_ihk.contiguous().detach() 126 | 127 | tprob_eh = torch.empty_like(prob_eh) 128 | tbij_ehk = torch.empty_like(bij_ehk) if bij_ehk is not None else None 129 | start_inds = torch.constant_pad_nd(sizes2.cumsum(0), (1,0)) 130 | 131 | with cp.cuda.Device(dev.index), ppe.cuda.stream(torch.cuda.current_stream(dev)): 132 | upper_mask = edge_ij_e[0] <= edge_ij_e[1] 133 | hE = upper_mask.long().sum().item() 134 | upper_e_t = torch.arange(E, dtype=torch.long, device=dev)[upper_mask] 135 | upper_batch_t = batch_i[edge_ij_e[0, upper_mask]] 136 | mat_sec_t = start_inds[upper_batch_t] 137 | sizes_t = sizes[upper_batch_t] 138 | 139 | def irregular_transpose(src:Tensor, dst:Tensor, C:int): 140 | bsz = min(32, C) 141 | KernelManager.irregular_transpose( 142 | ((hE*C+bsz-1)//bsz, ), (bsz, ), 143 | (_to_copy(src), _to_copy(upper_e_t), _to_copy(mat_sec_t), _to_copy(sizes_t), hE, C, _to_copy(dst)) 144 | ) 145 | 146 | irregular_transpose(prob_eh, tprob_eh, H) 147 | 148 | if bij_ehk is not None: 149 | irregular_transpose(bij_ehk, tbij_ehk, H*K) 150 | 151 | assert (sizes <= 1024).all(), "Max system size is 1024" 152 | bsz = 1 153 | KernelManager.fused_dpa_bwd_v2( 154 | ((N*H+bsz-1)//bsz,), (bsz,), 155 | ( 156 | _to_copy(que_ihk), 157 | _to_copy(val_ihk), 158 | _to_copy(tbij_ehk), 159 | _to_copy(edge_ij_e), 160 | _to_copy(e_start_i), 161 | N, H, E, 162 | _to_copy(tprob_eh), 163 | _to_copy(output), 164 | _to_copy(go_ihk), 165 | _to_copy(gkey), 166 | _to_copy(gval), 167 | _to_copy(gaij), 168 | _to_copy(gbij), 169 | ) 170 | ) 171 | 172 | # tranpose gaij and gbij 173 | irregular_transpose(gaij, gaij, H) 174 | if gbij is not None: 175 | irregular_transpose(gbij, gbij, H*K) 176 | 177 | # use gaij as grad softmax to compute grad q. 
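            # gaij now holds the gradient w.r.t. the attention logits for every (i, j) edge and head.
            # Since the logit depends on q_i only through <q_i, k_j>, the gradient w.r.t. q_i is
            # sum_j gaij[i,j] * k_j; the kernel below accumulates that sum per query node i over
            # its edge range given by e_start_i, mirroring gq = (ga * k).sum(dim=1) in the dense
            # reference branch above.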
178 | bsz = 1 179 | KernelManager.fused_dpa_bwd_q_v2( 180 | ((N*H+bsz-1)//bsz,), (bsz,), 181 | ( 182 | _to_copy(key_ihk), 183 | _to_copy(gaij), 184 | _to_copy(edge_ij_e), 185 | _to_copy(e_start_i), 186 | N, H, E, 187 | _to_copy(gque), 188 | ) 189 | ) 190 | 191 | return gque, gkey, gval, gaij, gbij, None, None 192 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Crystalformer: Infinitely Connected Attention for Periodic Structure Encoding** 2 | Tatsunori Taniai, Ryo Igarashi, Yuta Suzuki, Naoya Chiba, Kotaro Saito, Yoshitaka Ushiku, and Kanta Ono 3 | In *The Twelfth International Conference on Learning Representations* (ICLR 2024) 4 | 5 | ![GNNs vs Crystalformer](https://omron-sinicx.github.io/crystalformer/teaser.png "Crystalfomer") 6 | 7 | [[Paper](https://openreview.net/forum?id=fxQiecl9HB)] [[Project](https://omron-sinicx.github.io/crystalformer/)] 8 | 9 | **NEWS: A cleaned codebase with extended features is provided in our follow-up work, [CrystalFramer](https://github.com/omron-sinicx/crystalframer).** 10 | 11 | ## Citation 12 | ```text 13 | @inproceedings{taniai2024crystalformer, 14 | title = {Crystalformer: Infinitely Connected Attention for Periodic Structure Encoding}, 15 | author = {Tatsunori Taniai and 16 | Ryo Igarashi and 17 | Yuta Suzuki and 18 | Naoya Chiba and 19 | Kotaro Saito and 20 | Yoshitaka Ushiku and 21 | Kanta Ono 22 | }, 23 | booktitle = {The Twelfth International Conference on Learning Representations}, 24 | year = {2024}, 25 | url = {https://openreview.net/forum?id=fxQiecl9HB} 26 | } 27 | ``` 28 | 29 | ## Setup a Docker environment 30 | ```bash 31 | cd docker/pytorch21_cuda121 32 | docker build -t main/crystalformer:latest . 33 | docker run --gpus=all --name crystalformer --shm-size=2g -v ../../:/workspace -it main/crystalformer:latest /bin/bash 34 | ``` 35 | 36 | ## Prepare datasets 37 | In the docker container: 38 | ```bash 39 | cd /workspace/data 40 | python download_megnet_elastic.py 41 | python downlad_jarvis.py 42 | ``` 43 | 44 | ## Testing 45 | Download pretrained weights: [[GoogleDrive](https://drive.google.com/file/d/1yEmwnWflYHGlwQia1xb3G91u2Edz8H2a/view?usp=sharing)] 46 | In the `/workspace` directory in the docker container: 47 | ```bash 48 | unzip weights.zip 49 | . demo.sh 50 | ``` 51 | Currently, pretrained models for MEGNET's bandgap and e_form with 4 or 7 attention blocks are available. 52 | 53 | ## Training 54 | ### Single GPU Training 55 | In the `/workspace` directory in the docker container: 56 | ```bash 57 | CUDA_VISIBLE_DEVICES=0 python train.py -p latticeformer/default.json \ 58 | --save_path result \ 59 | --n_epochs 500 \ 60 | --experiment_name demo \ 61 | --num_layers 4 \ 62 | --value_pe_dist_real 64 \ 63 | --target_set jarvis__megnet-shear \ 64 | --targets shear \ 65 | --batch_size 128 \ 66 | --lr 0.0005 \ 67 | --model_dim 128 \ 68 | --embedding_dim 128 \ 69 | 70 | ``` 71 | Setting `--value_pe_dist_real 0` yields the "simplified model" in the paper. 
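For the larger JARVIS datasets the same command applies, with the epoch/batch settings
listed under "Datasets and Targets" below. As a sketch, a `jarvis__dft_3d_2021` /
`formation_energy` run would look like this (adjust the batch size to your GPU memory):
```bash
CUDA_VISIBLE_DEVICES=0 python train.py -p latticeformer/default.json \
    --save_path result \
    --n_epochs 800 \
    --experiment_name demo \
    --num_layers 4 \
    --value_pe_dist_real 64 \
    --target_set jarvis__dft_3d_2021 \
    --targets formation_energy \
    --batch_size 256 \
    --lr 0.0005 \
    --model_dim 128 \
    --embedding_dim 128
```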
72 | 73 | ### Multiple GPU Training 74 | In the `/workspace' directory in the docker container: 75 | ```bash 76 | CUDA_VISIBLE_DEVICES=0,1 python train.py -p latticeformer/default.json \ 77 | --save_path result \ 78 | --n_epochs 500 \ 79 | --experiment_name demo \ 80 | --num_layers 4 \ 81 | --value_pe_dist_real 64 \ 82 | --target_set jarvis__megnet-shear \ 83 | --targets shear \ 84 | --batch_size 128 \ 85 | --lr 0.0005 \ 86 | --model_dim 128 \ 87 | --embedding_dim 128 \ 88 | 89 | ``` 90 | Currently, the throughput gain by multi-gpu training is limited. Suggest 2 or 4 GPUs at most. 91 | 92 | ## Datasets and Targets 93 | 94 | | target_set | targets | Unit | train | val | test | 95 | | ------------------------------- | ------------------- | --------- | --------- | ----- | --------- | 96 | | jarvis__megnet | e_form | eV/atom | 60000 | 5000 | 4239 | 97 | | jarvis__megnet | bandgap | eV | 60000 | 5000 | 4239 | 98 | | jarvis__megnet-bulk | bulk_modulus | log(GPA) | 4664 | 393 | 393 | 99 | | jarvis__megnet-shear | shear_modulus | log(GPA) | 4664 | 392 | 393 | 100 | | jarvis__dft_3d_2021 | formation_energy | eV/atom | 44578 | 5572 | 5572 | 101 | | jarvis__dft_3d_2021 | total_energy | eV/atom | 44578 | 5572 | 5572 | 102 | | jarvis__dft_3d_2021 | opt_bandgap | eV | 44578 | 5572 | 5572 | 103 | | jarvis__dft_3d_2021-mbj_bandgap | mbj_bandgap | eV | 14537 | 1817 | 1817 | 104 | | jarvis__dft_3d_2021-ehull | ehull | eV | 44296 | 5537 | 5537 | 105 | 106 | Use the following hyperparameters: 107 | - For the `jarvis__megnet` datasets: `--n_epochs 500 --batch_size 128` 108 | - For the `dft_3d_2021-mbj_bandgap` dataset: `--n_epochs 1600 --batch_size 256` 109 | - For the other `dft_3d_2021` datasets: `--n_epochs 800 --batch_size 256` 110 | 111 | ## Hyperparameters 112 | General training hyperparameters: 113 | - `n_epochs` (int): The number of training epochs. 114 | - `batch_size` (int): The batch size (i.e., the number of materials per training step). 115 | - `loss_func` (`L1`|`MSE`|`Smooth_L1`): The regression loss function form. 116 | - `optimizer` (`adamw`|`adam`|): The choice of optimizer. 117 | - `adam_betas` (floats): beta1 and beta2 of Adam and AdamW. 118 | - `lr` (float): The initial learning rate. The default setting (5e-4) works mostly the best. 119 | - `lr_sch` (`inverse_sqrt_nowarmup`|`const`): The learning rate schedule. `inverse_sqrt_nowarmup` sets learning rate to `lr*sqrt(t/(t+T))` where T is specified by `sch_params`. `const` uses a constant learning rate `lr`. 120 | 121 | Final MLP's hyperparameters: 122 | - `embedding_dim` (ints): The intermediate dims of the final MLP after pooling, defining Pooling-Repeat[Linear-ReLU]-FinalLinear. The default setting (128) defines Pooling-Linear(128)-ReLU-FinalLinear(1). 123 | - `norm_type` (`no`|`bn`): Whether or not use BatchNorm in MLP. 124 | 125 | Transformer's hyperparameters: 126 | - `num_layers` (int): The number of self-attention blocks. Should be 4 or higher. 127 | - `model_dim` (int): The feature dimension of Transformer. 128 | - `ff_dim` (int): The intermediate feature dimension of the feed-forward networks in Transformer. 129 | - `head_num` (int): The number of heads of multi-head attention (HMA). 130 | 131 | Crystalformer's hyperparameters. 132 | - `scale_real` (float or floats): "r_0" in the paper. (Passing multiple values allows different settings for individual attention blocks.) 133 | - `gauss_lb_real` (float): The bound "b" for the rho function in the paper. 
134 | - `value_pe_dist_real` (int): The number of radial basis functions (i.e., edge feature dim "K" in the paper). Should be a multiple of 16. 135 | - `value_pe_dist_max` (float): "r_max" in the paper. A positive value directly specifies r_max in Å, while a negative value specifies r_max via r_max = (-value_pe_dist_max)*scale_real. 136 | - `domain` (`real`|`multihead`|`real-reci`): Whether use reciprocal-space attention by parallel MHA (`multihead`) or block-wisely interleaving between real and reciprocal space (`real-reci`). When reciprocal-space attention is used, `scale_reci` and `gauss_lb_reci` can also be specified. 137 | 138 | ## Use a custom dataset 139 | For each of train, val, and test splits, make a list of dicts containing pymatgen's Structures and label values: 140 | - list 141 | - dict 142 | - 'structure': pymatgen.core.structure.Structure 143 | - 'property1': a float value of `propety1` of this structure 144 | - 'property2': a float value of `propety2` of this structure 145 | - ... 146 | 147 | Dump the list of each split in a directory with your dataset name as 148 | ```python 149 | import os 150 | import pickle 151 | 152 | target_set = 'your_dataset_name' 153 | split = 'train' # or 'val' or 'test' 154 | 155 | os.makedirs(f'data/{target_set}/{split}/raw', exist_ok=True) 156 | with open(f'data/{target_set}/{split}/raw/raw_data.pkl', mode="wb") as fp: 157 | pickle.dump(data_list, fp) 158 | ``` 159 | 160 | Then, you can specify your dataset and its target property name as 161 | ```bash 162 | python train -p latticeformer/default.json \ 163 | --target_set your_dataset_name \ 164 | --targets [property1|property2] \ 165 | ``` 166 | -------------------------------------------------------------------------------- /models/kernels/real_enc_proj_bwd.cu: -------------------------------------------------------------------------------- 1 | #include 2 | extern "C" __global__ 3 | 4 | void real_enc_proj_bwd( 5 | const float* a_ik, 6 | const float* rpos_ij_e, 7 | //const float* dist2_min_e, 8 | const float* tvecs_n, 9 | const long long int* batch_i, 10 | const long long int* edge_ij_e, 11 | const long long int* e_start_i, 12 | const float* z_ek, 13 | const float* gz_ek, 14 | const float* gv_ekd, 15 | const long long int N, 16 | const long long int H, 17 | const long long int E, 18 | //const long long int K_, 19 | const double dist_max, 20 | const double wscale, 21 | const float* W_k, 22 | const long long int W_num, 23 | const float* rveclens_n, 24 | const double cutoff_radius, 25 | float* ga_ik, 26 | float* gW_k){ 27 | 28 | const unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x; 29 | if (tid >= N*H) return; 30 | 31 | const unsigned int k = tid%H; 32 | const unsigned int i = tid/H; 33 | const unsigned int n = batch_i[i]; 34 | tvecs_n += n*9; 35 | const float t1_x = tvecs_n[0]; 36 | const float t1_y = tvecs_n[1]; 37 | const float t1_z = tvecs_n[2]; 38 | const float t2_x = tvecs_n[3]; 39 | const float t2_y = tvecs_n[4]; 40 | const float t2_z = tvecs_n[5]; 41 | const float t3_x = tvecs_n[6]; 42 | const float t3_y = tvecs_n[7]; 43 | const float t3_z = tvecs_n[8]; 44 | const float a = a_ik[i*H + k]; 45 | const unsigned int e_end = e_start_i[i+1]; 46 | #if VPE_DIM > 0 47 | __shared__ float shared_gv[THREAD_NUM][VPE_DIM+1]; 48 | __shared__ float shared_v[THREAD_NUM][VPE_DIM+1]; 49 | float *sv = shared_v[threadIdx.x]; 50 | float *gW = NULL; 51 | if (gW_k != NULL && (W_num == N || W_num == 1)){ 52 | gW = &gW_k[(i*H+k)*V_HEAD_DIM*VPE_DIM]; 53 | for (int dim = 0; dim < V_HEAD_DIM*VPE_DIM; dim++) 
54 | gW[dim] = 0; 55 | } 56 | #endif 57 | 58 | rveclens_n += n*3; 59 | const float rvl1 = rveclens_n[0]; 60 | const float rvl2 = rveclens_n[1]; 61 | const float rvl3 = rveclens_n[2]; 62 | 63 | float cutoff = (float)cutoff_radius; 64 | int R1 = LATTICE_RANGE, R2 = LATTICE_RANGE, R3 = LATTICE_RANGE; 65 | if (cutoff != 0.0f) 66 | { 67 | if (cutoff < 0) { 68 | // Better sync the threads in each block? 69 | // -> disabled due to thread stucking 70 | // float a_max = a; 71 | // for (int t = 0; t < THREAD_NUM; t++) 72 | // a_max = max(a_max, a_ik[i*H + t]); 73 | //cutoff = sqrt(-0.5f/a_max)*(-cutoff); 74 | cutoff = sqrt(-0.5f/a)*(-cutoff); 75 | } 76 | R1 = ceil((cutoff + 0.01f)*rvl1/(2.0*CUDART_PI_F)); 77 | R2 = ceil((cutoff + 0.01f)*rvl2/(2.0*CUDART_PI_F)); 78 | R3 = ceil((cutoff + 0.01f)*rvl3/(2.0*CUDART_PI_F)); 79 | float cutoff2 = cutoff*cutoff; 80 | 81 | #if MINIMUM_RANGE > 0 82 | R1 = max(R1, MINIMUM_RANGE); 83 | R2 = max(R2, MINIMUM_RANGE); 84 | R3 = max(R3, MINIMUM_RANGE); 85 | #endif 86 | } 87 | 88 | float sum = 0; 89 | float sum_v = 0; 90 | for (unsigned int e = e_start_i[i]; e < e_end; e++) 91 | { 92 | const unsigned int j = edge_ij_e[E+e]; 93 | const float r_ijx = rpos_ij_e[e*3+0]; 94 | const float r_ijy = rpos_ij_e[e*3+1]; 95 | const float r_ijz = rpos_ij_e[e*3+2]; 96 | const unsigned int ek = e*H+k; 97 | const float z = z_ek[ek]; 98 | const float gz = gz_ek[ek]; 99 | 100 | #if VPE_DIM > 0 101 | float *sgv = shared_gv[threadIdx.x]; 102 | if (gW_k == NULL){ 103 | const float *gv = &gv_ekd[ek*VPE_DIM]; 104 | #pragma unroll 105 | for (int dim = 0; dim < VPE_DIM; dim++) { 106 | sgv[dim] = gv[dim]; 107 | } 108 | } else { 109 | // Compute backward of v' = Wv, as gW = (gv')^T * v 110 | const float *gv = &gv_ekd[ek*V_HEAD_DIM]; 111 | unsigned int w_ind = 0; 112 | if (W_num == 1){ 113 | w_ind = 0; 114 | } else if (W_num == E) { 115 | w_ind = e; 116 | } else if (W_num == N) { 117 | w_ind = i; 118 | } 119 | const float *W = &W_k[(w_ind*H+k)*V_HEAD_DIM*VPE_DIM]; 120 | #pragma unroll 121 | for (int dim = 0; dim < VPE_DIM; dim++) 122 | sgv[dim] = 0; 123 | #pragma unroll 124 | for (int wdim = 0; wdim < V_HEAD_DIM; wdim++){ 125 | float gv_val = gv[wdim]; 126 | #pragma unroll 127 | for (int dim = 0; dim < VPE_DIM; dim++){ 128 | sgv[dim] += W[wdim*VPE_DIM+dim]*gv_val; 129 | //sgv[dim] += (*W++)*gv_val; 130 | } 131 | } 132 | 133 | // for gW 134 | if (W_num == E){ 135 | gW = &gW_k[(e*H+k)*V_HEAD_DIM*VPE_DIM]; 136 | for (int dim = 0; dim < V_HEAD_DIM*VPE_DIM; dim++) 137 | gW[dim] = 0; 138 | } 139 | } 140 | #endif 141 | 142 | float px_avr = 0; 143 | float pbg_avr = 0; 144 | const float reci_ws_sqrt2 = 1.0f/((float)wscale*sqrt(2.0f)); 145 | const float mu0 = (float)dist_max/VPE_DIM; 146 | #if VPE_DIM > 0 147 | if (gW_k != NULL){ 148 | #pragma unroll 149 | for (int dim = 0; dim < VPE_DIM; dim++) 150 | sv[dim] = 0; 151 | } 152 | #endif 153 | for (float n1 = -R1; n1 <= R1; n1++) 154 | for (float n2 = -R2; n2 <= R2; n2++) 155 | for (float n3 = -R3; n3 <= R3; n3++) 156 | { 157 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 158 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 159 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 160 | float d2 = dx*dx + dy*dy + dz*dz; 161 | // float dx = fmaf(t1_x, n1, fmaf(t2_x, n2, fmaf(t3_x, n3, r_ijx))); 162 | // float dy = fmaf(t1_y, n1, fmaf(t2_y, n2, fmaf(t3_y, n3, r_ijy))); 163 | // float dz = fmaf(t1_z, n1, fmaf(t2_z, n2, fmaf(t3_z, n3, r_ijz))); 164 | // float d2 = fmaf(dx,dx, fmaf(dy,dy, dz*dz)); 165 | float p = expf(a*d2 - z); 166 | float px = d2*p; 167 | px_avr 
+= px; 168 | 169 | #if VPE_DIM > 0 170 | float bg = 0; 171 | float b = -sqrtf(d2)/mu0*reci_ws_sqrt2; 172 | #pragma unroll 173 | for (int dim = 0; dim < VPE_DIM; dim++) 174 | { 175 | b += reci_ws_sqrt2; 176 | float gauss = expf(-b*b); 177 | bg += gauss*sgv[dim]; 178 | sv[dim] += gauss*p; 179 | } 180 | sum_v += px*bg; 181 | pbg_avr += p*bg; 182 | #endif 183 | } 184 | /* 185 | b: (E, 1, R, K) 186 | x: (E, 1, R, 1) 187 | y: (N, H, 1, 1) 188 | z: (E, H, 1, K) 189 | g: (E, H, 1, K) 190 | p: (E, H, R, 1) 191 | 192 | (E,H,R,K) (E,H,R,1) (E,H,R,K) (E,H,1,K): (E,H,R,1)*(E,1,R,K)*(E,H,1,K) 193 | dz/dye = p*x * ( b*g - (p*b*g).sum(axis=R)) 194 | 195 | (E,H,1,1) 196 | dz/dyi = (dz/dye).sum(axis=R,K).sum_for_j() 197 | 198 | (E,H,R,1)*(E,H,R,1) (E,H,1,1) *(E,H,1,1) 199 | dz/dye = (p*x) *(b*g).sum(axis=K) - (p*x).sum(axis=R)*(p*b*g).sum(axis=R,K)) 200 | */ 201 | 202 | sum += px_avr*gz; 203 | sum_v -= px_avr*pbg_avr; 204 | 205 | #if VPE_DIM > 0 206 | if (gW_k != NULL){ 207 | const float *gv = &gv_ekd[ek*V_HEAD_DIM]; 208 | #pragma unroll 209 | for (int wdim = 0; wdim < V_HEAD_DIM; wdim++){ 210 | float gv_val = gv[wdim]; 211 | #pragma unroll 212 | for (int dim = 0; dim < VPE_DIM; dim++){ 213 | //*(_sgw++) += sv[dim]*gv_val; 214 | gW[wdim*VPE_DIM+dim] += sv[dim]*gv_val; 215 | } 216 | } 217 | } 218 | #endif 219 | } 220 | ga_ik[tid] = sum + sum_v; 221 | } -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import numpy as np 5 | import torch 6 | import matplotlib.pyplot as plt 7 | 8 | def onehot(x, n_classes): 9 | return torch.eye(n_classes)[x] 10 | 11 | def seed_worker(worker_id): 12 | worker_seed = torch.initial_seed() % 2**32 13 | np.random.seed(worker_seed) 14 | random.seed(worker_seed) 15 | 16 | class Params(): 17 | """Class that loads hyperparameters from a json file. 
18 | Example: 19 | ``` 20 | params = Params(json_path) 21 | print(params.learning_rate) 22 | params.learning_rate = 0.5 # change the value of learning_rate in params 23 | ``` 24 | """ 25 | 26 | def __init__(self, json_path): 27 | with open(json_path) as f: 28 | params = json.load(f) 29 | self.__dict__.update(params) 30 | 31 | def save(self, json_path): 32 | with open(json_path, 'w') as f: 33 | json.dump(self.__dict__, f, indent=4) 34 | 35 | def update(self, json_path): 36 | """Loads parameters from json file""" 37 | with open(json_path) as f: 38 | params = json.load(f) 39 | self.__dict__.update(params) 40 | 41 | @property 42 | def dict(self): 43 | """Gives dict-like access to Params instance by `params.dict['learning_rate']""" 44 | return self.__dict__ 45 | 46 | 47 | # for analysis embeddings 48 | def search_kNN(embedding_query, embedding_target, use_gpu=True, k=10): 49 | import faiss 50 | """ 51 | embeddingはtensorなどarray-like、kは探す近傍点の個数(1以上のint) 52 | return: 53 | D: クエリベクトルからk近傍までの距離 54 | I: クエリベクトルに対するk近傍のインデックス 55 | """ 56 | vec_dim = embedding_target.shape[1] # ベクトルの次元(dimension) 57 | n_data = embedding_target.shape[0] # データベースのサイズ(database size) 58 | x_target_vec = embedding_target.numpy().astype('float32') 59 | x_query_vec = embedding_query.numpy().astype('float32') 60 | faiss_index_cpu = faiss.IndexFlatL2(vec_dim) 61 | if use_gpu: 62 | faiss_index_gpu = faiss.index_cpu_to_all_gpus(faiss_index_cpu) 63 | faiss_index_gpu.add(x_target_vec) 64 | D, I = faiss_index_gpu.search(x_query_vec, k) 65 | else: 66 | faiss_index_cpu.add(x_target_vec) 67 | D, I = faiss_index_cpu.search(x_query_vec, k) 68 | return D, I 69 | 70 | def calc_grobal_top_k_acc(embedding_query, embedding_target, k=10): 71 | # top-k accを計算 72 | top_k_correct_samplenums = [] 73 | top_k_acc = [] 74 | n_data = embedding_target.shape[0] # データベースのサイズ(database size) 75 | for i in range(k): 76 | # k近傍のインデックスのarray内にクエリベクトル自身が入っていればok 77 | correct_samplenum = get_bool_of_corrected_predictions(embedding_query=embedding_query, 78 | embedding_target=embedding_target, 79 | k=i+1).sum() 80 | correct_samplenum = correct_samplenum.astype(np.float) 81 | top_k_correct_samplenums.append(correct_samplenum) 82 | top_k_acc.append(correct_samplenum/n_data) 83 | return top_k_acc 84 | 85 | def get_bool_of_corrected_predictions(embedding_query, embedding_target, k=10): 86 | D, I = search_kNN(embedding_query=embedding_query, embedding_target=embedding_target, k=k) 87 | series_idx = np.arange(embedding_target.shape[0]) # クエリベクトル自身のインデックス 88 | is_predict_correctly = np.isin(series_idx, I[:, :k]) 89 | return is_predict_correctly 90 | 91 | def retrieve_materials_properties(metadata, torch_dataset): 92 | """ 93 | metadata: 材料ごとのメタデータのdictをlistに入れたもの 94 | torch_dataset: metadataと突き合わせるためのPyTorchのdataset object 95 | metadataとtorch_datasetの中身の順番は一致していないといけない 96 | 97 | metadata dictの例: 98 | {'material_id': 'mp-1025051', 99 | 'pretty_formula': 'YbB2Rh3', 100 | 'energy_per_atom': -6.870412148333333, 101 | 'energy': -41.22247289, 102 | 'density': 10.61234910722115, 103 | 'final_structure': Structure Summary 104 | (中略) 105 | 'spacegroup.number': 191, 106 | 'band_gap': 0.0, 107 | 'formation_energy_per_atom': -0.7159613472222226, 108 | 'total_magnetization': 0.0005206, 109 | 'xrd_hist': array([0., 0., 0., ..., 0., 0., 0.])} 110 | """ 111 | pretty_formula = [] 112 | energy_per_atom = [] 113 | energy = [] 114 | density = [] 115 | formation_energy_per_atom = [] 116 | total_magnetization_uB = [] 117 | total_magnetization_T = [] 118 | band_gap = [] 119 | sgr = [] 120 | 
num_sites = [] 121 | cell_volume = [] 122 | weight = [] 123 | material_id = [] 124 | valid_material_ids = set(torch_dataset.data.material_id) 125 | const_µB_to_T = 9.27401007833 * 4*np.pi /10 126 | if 'e_above_hull' in metadata[0]: 127 | e_above_hull = [] 128 | else: 129 | e_above_hull = None 130 | 131 | for data in metadata: 132 | if data['material_id'] in valid_material_ids: 133 | material_id.append(data['material_id']) 134 | pretty_formula.append(data['pretty_formula']) 135 | energy_per_atom.append(data['energy_per_atom']) 136 | energy.append(data['energy']) 137 | density.append(data['density']) 138 | formation_energy_per_atom.append(data['formation_energy_per_atom']) 139 | # missing value workaround 140 | if data["material_id"] == "mp-1245128": 141 | data['total_magnetization'] = 0.0 142 | total_magnetization_uB.append(data['total_magnetization']) 143 | total_magnetization_T.append(data['total_magnetization']/data['final_structure'].volume*const_µB_to_T) 144 | band_gap.append(data['band_gap']) 145 | sgr.append(data['spacegroup.number']) 146 | num_sites.append(data['final_structure'].num_sites) 147 | cell_volume.append(data['final_structure'].volume) 148 | weight.append(data['final_structure'].composition.weight) 149 | if e_above_hull is not None: 150 | e_above_hull.append(data['e_above_hull']) 151 | 152 | properties = {'energy_per_atom':energy_per_atom , 'energy':energy, 'density':density, 153 | 'formation_energy_per_atom':formation_energy_per_atom, 154 | 'total_magnetization_uB':total_magnetization_uB, 'total_magnetization_T':total_magnetization_T, 'band_gap':band_gap, 'sgr':sgr , 155 | 'num_sites':num_sites, 'cell_volume':cell_volume, 'weight':weight} 156 | if e_above_hull is not None: 157 | properties['e_above_hull'] = e_above_hull 158 | return properties, material_id, pretty_formula 159 | 160 | def plot_embedding(tsne_embedding_xrd, tsne_embedding_crystal, value_for_color, color_key=None): 161 | fig, (ax1, ax2) = plt.subplots(figsize=(11, 4), ncols=2, dpi=100) 162 | # ax1 = plt.subplot(121) 163 | pos = ax1.scatter(tsne_embedding_xrd[:, 0], tsne_embedding_xrd[:, 1], 164 | c=value_for_color, s=8, linewidths=0.01, alpha=0.2) 165 | ax1.set_title('t-SNE visualization of xrd embedding') 166 | fig.colorbar(pos, ax=ax1) 167 | 168 | # ax2 = plt.subplot(122) 169 | pos = ax2.scatter(tsne_embedding_crystal[:, 0], tsne_embedding_crystal[:, 1], 170 | c=value_for_color, s=8, linewidths=0.01, alpha=0.2) 171 | ax2.set_title('t-SNE visualization of crystal embedding') 172 | fig.colorbar(pos, ax=ax2) 173 | if color_key is not None: 174 | fig.suptitle(color_key, fontsize='large') 175 | 176 | def retrieve_neighbour_materials(query_mp_id, embedding, embedding_metadata, n_neighbours=1000, use_gpu=True): 177 | """ 178 | ある物質のembeddingについて近傍の物質を検索して可視化する 179 | 180 | Parameters 181 | ---------- 182 | query_mp_id: str 183 | クエリする物質のmp_id (e.g. 
mp-764) 184 | 185 | embedding: array-like 186 | 検索対象のembeddingのtensorやarray 187 | 188 | Returns 189 | ------- 190 | retrieved_neighbours : pd.DataFrame 191 | クエリ近傍の物質のメタデータのdataframe 192 | 193 | disp: ipywidgets.widgets.widget_box.HBox 194 | クエリ近傍の物質の結晶構造を可視化するNGL Viewerのipython widget 195 | 196 | """ 197 | idx = embedding_metadata.query('mp_id == @query_mp_id').index[0] 198 | D, I = search_kNN(embedding_query=embedding[idx].unsqueeze(0), embedding_target=embedding, k=n_neighbours, use_gpu=use_gpu) 199 | retrieved_neighbours = embedding_metadata.iloc[I.squeeze()] 200 | return retrieved_neighbours 201 | -------------------------------------------------------------------------------- /models/kernels/real_enc_proj_bwd_v2.cu: -------------------------------------------------------------------------------- 1 | #include 2 | extern "C" __global__ 3 | 4 | void real_enc_proj_bwd_v2( 5 | const float* a_ik, 6 | const float* rpos_ij_e, 7 | //const float* dist2_min_e, 8 | const float* tvecs_n, 9 | const long long int* batch_i, 10 | const long long int* edge_ij_e, 11 | const long long int* e_start_i, 12 | const float* z_ek, 13 | const float* gz_ek, 14 | const float* gv_ekd, 15 | const long long int N, 16 | const long long int H, 17 | const long long int E, 18 | //const long long int K_, 19 | const double dist_max, 20 | const double wscale, 21 | const float* W_k, 22 | const long long int W_num, 23 | const float* rveclens_n, 24 | const double cutoff_radius, 25 | float* ga_ik, 26 | float* gW_k){ 27 | 28 | const unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x; 29 | if (tid >= N*H) return; 30 | 31 | const unsigned int k = tid%H; 32 | const unsigned int i = tid/H; 33 | const unsigned int n = batch_i[i]; 34 | tvecs_n += n*9; 35 | const float t1_x = tvecs_n[0]; 36 | const float t1_y = tvecs_n[1]; 37 | const float t1_z = tvecs_n[2]; 38 | const float t2_x = tvecs_n[3]; 39 | const float t2_y = tvecs_n[4]; 40 | const float t2_z = tvecs_n[5]; 41 | const float t3_x = tvecs_n[6]; 42 | const float t3_y = tvecs_n[7]; 43 | const float t3_z = tvecs_n[8]; 44 | const float a = a_ik[i*H + k]; 45 | const unsigned int e_end = e_start_i[i+1]; 46 | #if VPE_DIM > 0 47 | __shared__ float shared_gv[THREAD_NUM][VPE_DIM+1]; 48 | __shared__ float shared_v[THREAD_NUM][VPE_DIM+1]; 49 | float *sv = shared_v[threadIdx.x]; 50 | float *gW = NULL; 51 | if (gW_k != NULL && (W_num == N || W_num == 1)){ 52 | gW = &gW_k[(i*H+k)*V_HEAD_DIM*VPE_DIM]; 53 | for (int dim = 0; dim < V_HEAD_DIM*VPE_DIM; dim++) 54 | gW[dim] = 0; 55 | } 56 | #endif 57 | 58 | rveclens_n += n*3; 59 | const float rvl1 = rveclens_n[0]; 60 | const float rvl2 = rveclens_n[1]; 61 | const float rvl3 = rveclens_n[2]; 62 | 63 | float cutoff = (float)cutoff_radius; 64 | int R1 = LATTICE_RANGE, R2 = LATTICE_RANGE, R3 = LATTICE_RANGE; 65 | if (cutoff != 0.0f) 66 | { 67 | if (cutoff < 0) { 68 | // Better sync the threads in each block? 
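// The backward kernel recomputes the truncation (R1..R3 and the sigma-based cutoff for a
// negative cutoff_radius) with exactly the same rule as the forward pass, so the gradient
// loops below visit the same set of lattice translations that produced z_ek and the
// softmax terms expf(a*d2 - z) stay properly normalized.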
69 | // -> disabled due to thread stucking 70 | // float a_max = a; 71 | // for (int t = 0; t < THREAD_NUM; t++) 72 | // a_max = max(a_max, a_ik[i*H + t]); 73 | //cutoff = sqrt(-0.5f/a_max)*(-cutoff); 74 | cutoff = sqrt(-0.5f/a)*(-cutoff); 75 | } 76 | R1 = ceil((cutoff + 0.01f)*rvl1/(2.0*CUDART_PI_F)); 77 | R2 = ceil((cutoff + 0.01f)*rvl2/(2.0*CUDART_PI_F)); 78 | R3 = ceil((cutoff + 0.01f)*rvl3/(2.0*CUDART_PI_F)); 79 | float cutoff2 = cutoff*cutoff; 80 | 81 | #if MINIMUM_RANGE > 0 82 | R1 = max(R1, MINIMUM_RANGE); 83 | R2 = max(R2, MINIMUM_RANGE); 84 | R3 = max(R3, MINIMUM_RANGE); 85 | #endif 86 | } 87 | 88 | float sum = 0; 89 | float sum_v = 0; 90 | for (unsigned int e = e_start_i[i]; e < e_end; e++) 91 | { 92 | const unsigned int j = edge_ij_e[E+e]; 93 | const float r_ijx = rpos_ij_e[e*3+0]; 94 | const float r_ijy = rpos_ij_e[e*3+1]; 95 | const float r_ijz = rpos_ij_e[e*3+2]; 96 | const unsigned int ek = e*H+k; 97 | const float z = z_ek[ek]; 98 | const float gz = gz_ek[ek]; 99 | 100 | #if VPE_DIM > 0 101 | float *sgv = shared_gv[threadIdx.x]; 102 | if (gW_k == NULL){ 103 | const float *gv = &gv_ekd[ek*VPE_DIM]; 104 | #pragma unroll 105 | for (int dim = 0; dim < VPE_DIM; dim++) { 106 | sgv[dim] = gv[dim]; 107 | } 108 | } else { 109 | // Compute backward of v' = Wv, as gW = (gv')^T * v 110 | const float *gv = &gv_ekd[ek*V_HEAD_DIM]; 111 | unsigned int w_ind = 0; 112 | if (W_num == 1){ 113 | w_ind = 0; 114 | } else if (W_num == E) { 115 | w_ind = e; 116 | } else if (W_num == N) { 117 | w_ind = i; 118 | } 119 | const float *W = &W_k[(w_ind*H+k)*V_HEAD_DIM*VPE_DIM]; 120 | #pragma unroll 121 | for (int dim = 0; dim < VPE_DIM; dim++) 122 | sgv[dim] = 0; 123 | #pragma unroll 124 | for (int wdim = 0; wdim < V_HEAD_DIM; wdim++){ 125 | float gv_val = gv[wdim]; 126 | #pragma unroll 127 | for (int dim = 0; dim < VPE_DIM; dim++){ 128 | sgv[dim] += W[wdim*VPE_DIM+dim]*gv_val; 129 | //sgv[dim] += (*W++)*gv_val; 130 | } 131 | } 132 | 133 | // for gW 134 | if (W_num == E){ 135 | gW = &gW_k[(e*H+k)*V_HEAD_DIM*VPE_DIM]; 136 | for (int dim = 0; dim < V_HEAD_DIM*VPE_DIM; dim++) 137 | gW[dim] = 0; 138 | } 139 | } 140 | #endif 141 | 142 | float px_avr = 0; 143 | float pbg_avr = 0; 144 | const float reci_ws_sqrt2 = 1.0f/((float)wscale*sqrt(2.0f)); 145 | const float mu0 = (float)dist_max/VPE_DIM; 146 | #if VPE_DIM > 0 147 | if (gW_k != NULL){ 148 | #pragma unroll 149 | for (int dim = 0; dim < VPE_DIM; dim++) 150 | sv[dim] = 0; 151 | } 152 | #endif 153 | for (float n1 = -R1, px_avr1=0, pbg_avr1=0, sum_v1=0; n1 <= R1; n1++, px_avr +=px_avr1, pbg_avr +=pbg_avr1, sum_v +=sum_v1, px_avr1=pbg_avr1=sum_v1=0) 154 | for (float n2 = -R2, px_avr2=0, pbg_avr2=0, sum_v2=0; n2 <= R2; n2++, px_avr1+=px_avr2, pbg_avr1+=pbg_avr2, sum_v1+=sum_v2, px_avr2=pbg_avr2=sum_v2=0) 155 | for (float n3 = -R3; n3 <= R3; n3++) 156 | { 157 | float dx = r_ijx + t1_x*n1 + t2_x*n2 + t3_x*n3; 158 | float dy = r_ijy + t1_y*n1 + t2_y*n2 + t3_y*n3; 159 | float dz = r_ijz + t1_z*n1 + t2_z*n2 + t3_z*n3; 160 | float d2 = dx*dx + dy*dy + dz*dz; 161 | // float dx = fmaf(t1_x, n1, fmaf(t2_x, n2, fmaf(t3_x, n3, r_ijx))); 162 | // float dy = fmaf(t1_y, n1, fmaf(t2_y, n2, fmaf(t3_y, n3, r_ijy))); 163 | // float dz = fmaf(t1_z, n1, fmaf(t2_z, n2, fmaf(t3_z, n3, r_ijz))); 164 | // float d2 = fmaf(dx,dx, fmaf(dy,dy, dz*dz)); 165 | float p = expf(a*d2 - z); 166 | float px = d2*p; 167 | px_avr2 += px; 168 | 169 | #if VPE_DIM > 0 170 | float bg = 0; 171 | float b = -sqrtf(d2)/mu0*reci_ws_sqrt2; 172 | #pragma unroll 173 | for (int dim = 0; dim < VPE_DIM; dim++) 174 
| { 175 | b += reci_ws_sqrt2; 176 | float gauss = expf(-b*b); 177 | bg += gauss*sgv[dim]; 178 | sv[dim] += gauss*p; 179 | } 180 | sum_v2 += px*bg; 181 | pbg_avr2 += p*bg; 182 | #endif 183 | } 184 | /* 185 | b: (E, 1, R, K) 186 | x: (E, 1, R, 1) 187 | y: (N, H, 1, 1) 188 | z: (E, H, 1, K) 189 | g: (E, H, 1, K) 190 | p: (E, H, R, 1) 191 | 192 | (E,H,R,K) (E,H,R,1) (E,H,R,K) (E,H,1,K): (E,H,R,1)*(E,1,R,K)*(E,H,1,K) 193 | dz/dye = p*x * ( b*g - (p*b*g).sum(axis=R)) 194 | 195 | (E,H,1,1) 196 | dz/dyi = (dz/dye).sum(axis=R,K).sum_for_j() 197 | 198 | (E,H,R,1)*(E,H,R,1) (E,H,1,1) *(E,H,1,1) 199 | dz/dye = (p*x) *(b*g).sum(axis=K) - (p*x).sum(axis=R)*(p*b*g).sum(axis=R,K)) 200 | */ 201 | 202 | sum += px_avr*gz; 203 | sum_v -= px_avr*pbg_avr; 204 | 205 | #if VPE_DIM > 0 206 | if (gW_k != NULL){ 207 | const float *gv = &gv_ekd[ek*V_HEAD_DIM]; 208 | #pragma unroll 209 | for (int wdim = 0; wdim < V_HEAD_DIM; wdim++){ 210 | float gv_val = gv[wdim]; 211 | #pragma unroll 212 | for (int dim = 0; dim < VPE_DIM; dim++){ 213 | //*(_sgw++) += sv[dim]*gv_val; 214 | gW[wdim*VPE_DIM+dim] += sv[dim]*gv_val; 215 | } 216 | } 217 | } 218 | #endif 219 | } 220 | ga_ik[tid] = sum + sum_v; 221 | } -------------------------------------------------------------------------------- /models/cuda_funcs/fused_dpa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from .kernel_manager import KernelManager 4 | 5 | try: 6 | import cupy as cp 7 | import pytorch_pfn_extras as ppe 8 | from torch.utils.dlpack import to_dlpack, from_dlpack 9 | except: 10 | pass 11 | 12 | def _to_copy(x): 13 | if x is not None: 14 | return cp.from_dlpack(to_dlpack(x)) 15 | return 0 16 | 17 | class FusedDotProductAttentionCUDA(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, que_ihk, key_ihk, val_ihk, aij_eh, bij_ehk, batch_i, edge_ij_e): 20 | N, H, K = que_ihk.shape 21 | E = edge_ij_e.shape[1] 22 | dev = que_ihk.device 23 | 24 | e_start_i = torch.zeros(N+1, dtype=batch_i.dtype, device=batch_i.device) 25 | e_start_i.scatter_add_(0, edge_ij_e[0]+1, torch.ones_like(edge_ij_e[0])) 26 | e_start_i = e_start_i.cumsum(0) 27 | 28 | que_ihk = que_ihk.contiguous().detach() 29 | key_ihk = key_ihk.contiguous().detach() 30 | val_ihk = val_ihk.contiguous().detach() 31 | aij_eh = aij_eh.contiguous().detach() if aij_eh is not None else None 32 | bij_ehk = bij_ehk.contiguous().detach() if bij_ehk is not None else None 33 | batch_i = batch_i.contiguous() 34 | edge_ij_e = edge_ij_e.contiguous() 35 | 36 | output = torch.empty_like(val_ihk) 37 | prob_eh = torch.empty((E, H), dtype=que_ihk.dtype, device=dev) 38 | 39 | bsz = H 40 | with cp.cuda.Device(dev.index), ppe.cuda.stream(torch.cuda.current_stream(dev)): 41 | from .. 
import global_config as config 42 | kernel = KernelManager.fused_dpa_fwd_v3 if config.REPRODUCIBLITY_STATE>=3 \ 43 | else KernelManager.fused_dpa_fwd 44 | kernel(((N*H+bsz-1)//bsz, ), (bsz, ), 45 | ( 46 | _to_copy(que_ihk), 47 | _to_copy(key_ihk), 48 | _to_copy(val_ihk), 49 | _to_copy(aij_eh), 50 | _to_copy(bij_ehk), 51 | _to_copy(edge_ij_e), 52 | _to_copy(e_start_i), 53 | N, H, E, 54 | _to_copy(prob_eh), 55 | _to_copy(output), 56 | ) 57 | ) 58 | ctx.save_for_backward(que_ihk, key_ihk, val_ihk, aij_eh, bij_ehk, 59 | batch_i, edge_ij_e, e_start_i, 60 | prob_eh, output) 61 | return output 62 | 63 | @staticmethod 64 | def backward(ctx, go_ihk): 65 | que_ihk, key_ihk, val_ihk, aij_eh, bij_ehk, \ 66 | batch_i, edge_ij_e, e_start_i, \ 67 | prob_eh, output = ctx.saved_tensors 68 | 69 | N, H, K = que_ihk.shape 70 | E = edge_ij_e.shape[1] 71 | dev = que_ihk.device 72 | 73 | B = batch_i.max().item()+1 74 | sizes = torch.zeros(B, dtype=torch.long, device=dev) 75 | sizes.scatter_add_(0, batch_i, torch.ones_like(batch_i)) 76 | sizes2 = sizes*sizes 77 | 78 | if False: 79 | gque = [] 80 | gkey = [] 81 | gval = [] 82 | gbij = [] 83 | gaij = [] 84 | sizes = torch.zeros(N, dtype=batch_i.dtype, device=batch_i.device) 85 | sizes.scatter_add_(0, batch_i, torch.ones_like(batch_i)) 86 | sizes2 = sizes*sizes 87 | _sizes = sizes.tolist() 88 | _sizes2 = sizes2.tolist() 89 | for q,k,v,a,b,o,p,go,s in zip( 90 | que_ihk.split_with_sizes(_sizes), 91 | key_ihk.split_with_sizes(_sizes), 92 | val_ihk.split_with_sizes(_sizes), 93 | aij_eh.split_with_sizes(_sizes2), 94 | bij_ehk.split_with_sizes(_sizes2), 95 | output.split_with_sizes(_sizes), 96 | prob_eh.split_with_sizes(_sizes2), 97 | go_ihk.split_with_sizes(_sizes), 98 | _sizes): 99 | # q/k/v/o/go: (S, H, K) 100 | # a/p: (S*S, H) 101 | # b: (S*S, H, K) 102 | gb = go.reshape(s,1,H,K) * p.reshape(s,s,H,1) 103 | gv = gb.sum(dim=0) 104 | gval.append(gv) 105 | gbij.append(gb.reshape(s*s,H,K)) 106 | gsm = (v.reshape(1,s,H,K) + b.reshape(s,s,H,K) - o.reshape(s,1,H,K))*gb 107 | ga = gsm.sum(dim=3) 108 | gq = (ga.reshape(s,s,H,1)*k.reshape(1,s,H,K)).sum(dim=1) 109 | gk = (ga.reshape(s,s,H,1)*q.reshape(s,1,H,K)).sum(dim=0) 110 | gaij.append(ga.reshape(s*s,H)) 111 | gque.append(gq) 112 | gkey.append(gk) 113 | 114 | gque = torch.cat(gque) 115 | gkey = torch.cat(gkey) 116 | gval = torch.cat(gval) 117 | gbij = torch.cat(gbij) 118 | gaij = torch.cat(gaij) 119 | return gque, gkey, gval, gaij, gbij, None, None 120 | 121 | gque = torch.empty_like(que_ihk) 122 | gkey = torch.empty_like(key_ihk) 123 | gval = torch.empty_like(val_ihk) 124 | gaij = torch.empty_like(aij_eh) 125 | gbij = torch.empty_like(bij_ehk) if bij_ehk is not None else None 126 | go_ihk = go_ihk.contiguous().detach() 127 | 128 | tprob_eh = torch.empty_like(prob_eh) 129 | tbij_ehk = torch.empty_like(bij_ehk) if bij_ehk is not None else None 130 | start_inds = torch.constant_pad_nd(sizes2.cumsum(0), (1,0)) 131 | 132 | with cp.cuda.Device(dev.index), ppe.cuda.stream(torch.cuda.current_stream(dev)): 133 | _start_inds = _to_copy(start_inds) 134 | _sizes = _to_copy(sizes) 135 | 136 | upper_mask = edge_ij_e[0] <= edge_ij_e[1] 137 | hE = upper_mask.long().sum().item() 138 | upper_e_t = torch.arange(E, dtype=torch.long, device=dev)[upper_mask] 139 | upper_batch_t = batch_i[edge_ij_e[0, upper_mask]] 140 | mat_sec_t = start_inds[upper_batch_t] 141 | sizes_t = sizes[upper_batch_t] 142 | 143 | def irregular_transpose(src:Tensor, dst:Tensor, C:int): 144 | bsz = min(32, C) 145 | KernelManager.irregular_transpose( 146 | 
((hE*C+bsz-1)//bsz, ), (bsz, ), 147 | (_to_copy(src), _to_copy(upper_e_t), _to_copy(mat_sec_t), _to_copy(sizes_t), hE, C, _to_copy(dst)) 148 | ) 149 | 150 | # def irregular_transpose(src:Tensor, dst:Tensor, C:int): 151 | # bsz = min(32, C) 152 | # kernels['irregular_transpose_old']( 153 | # ((B*C+bsz-1)//bsz, ), (bsz, ), 154 | # (_to_copy(src), _to_copy(start_inds), _to_copy(sizes), B, C, _to_copy(dst)) 155 | # ) 156 | 157 | irregular_transpose(prob_eh, tprob_eh, H) 158 | if bij_ehk is not None: 159 | irregular_transpose(bij_ehk, tbij_ehk, H*K) 160 | 161 | assert (sizes <= KernelManager.MAX_SYSTEM_SIZE).all(), "Increase MAX_SYSTEM_SIZE in KernelManager" 162 | bsz = H 163 | from .. import global_config as config 164 | kernel = KernelManager.fused_dpa_bwd_v3 if config.REPRODUCIBLITY_STATE>=3 \ 165 | else KernelManager.fused_dpa_bwd 166 | kernel(((N*H+bsz-1)//bsz, ), (bsz, ), 167 | ( 168 | _to_copy(que_ihk), 169 | _to_copy(key_ihk), 170 | _to_copy(val_ihk), 171 | _to_copy(aij_eh), 172 | _to_copy(tbij_ehk), 173 | _to_copy(batch_i), 174 | _to_copy(edge_ij_e), 175 | _to_copy(e_start_i), 176 | N, H, E, 177 | _to_copy(tprob_eh), 178 | _to_copy(output), 179 | _to_copy(go_ihk), 180 | _to_copy(gque), 181 | _to_copy(gkey), 182 | _to_copy(gval), 183 | _to_copy(gaij), 184 | _to_copy(gbij), 185 | ) 186 | ) 187 | 188 | # tranpose gaij and gbij 189 | irregular_transpose(gaij, gaij, H) 190 | if gbij is not None: 191 | irregular_transpose(gbij, gbij, H*K) 192 | 193 | # use gaij as grad softmax to compute grad q. 194 | bsz = H 195 | kernel = KernelManager.fused_dpa_bwd_q_v3 if config.REPRODUCIBLITY_STATE>=3 \ 196 | else KernelManager.fused_dpa_bwd_q 197 | kernel(((N*H+bsz-1)//bsz, ), (bsz, ), 198 | ( 199 | _to_copy(key_ihk), 200 | _to_copy(gaij), 201 | _to_copy(edge_ij_e), 202 | _to_copy(e_start_i), 203 | N, H, E, 204 | _to_copy(gque), 205 | ) 206 | ) 207 | 208 | return gque, gkey, gval, gaij, gbij, None, None 209 | -------------------------------------------------------------------------------- /models/cuda_funcs/real_space_enc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from .kernel_manager import KernelManager 4 | 5 | try: 6 | import cupy as cp 7 | import pytorch_pfn_extras as ppe 8 | from torch.utils.dlpack import to_dlpack, from_dlpack 9 | except: 10 | pass 11 | 12 | def _to_copy(x): 13 | if x is not None: 14 | return cp.from_dlpack(to_dlpack(x)) 15 | return 0 16 | 17 | class RealPeriodicEncodingFuncCUDA(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, a_ik, rpos_ij_e, dist2_min_e, tvecs_n, batch_i, edge_ij_e, K, dist_max, wscale, \ 20 | W_k, rvlen_n=None, cutoff_radius=None): 21 | # if True or not a_ik.requires_grad: 22 | # R = 2 23 | # grids = torch.arange(-R, R+1, dtype=a_ik.dtype, device=a_ik.device) 24 | # grids = torch.stack(torch.meshgrid([grids]*3, indexing='ij'), dim=-1) 25 | # grids = grids.reshape(-1, 3) # (R,3) 26 | # b2e = batch_i[edge_ij_e[0]] # (N)[b2e] -> (E) 27 | # lattice = grids @ tvecs_n # (N,R,D) = ( R,D)x(N, D,D) 28 | # pos_ijn = rpos_ij_e[:,None] + lattice[b2e] # (E,R,D) = (E, 1,D)+(E, R,D) 29 | # del lattice 30 | # dist2 = (pos_ijn**2).sum(axis=2,keepdim=True) # (E,R,1) 31 | # dist2_min = dist2.min(axis=1, keepdim=True)[0] # (E,1,1) 32 | # dist2 -= dist2_min 33 | # a_e1k = a_ik[edge_ij_e[0]][:,None] # (N,H) -> (E,1,H) 34 | # z = torch.exp(dist2*a_e1k) 35 | # z = z.sum(axis=1, keepdim=True) 36 | # z.log_() 37 | # z += a_e1k*dist2_min 38 | # z.squeeze_(1) 39 | 40 | # N, H = 
a_ik.shape 41 | # E = edge_ij_e.shape[1] 42 | # kw = {'device': a_ik.device, 'dtype': a_ik.dtype} 43 | # v_ekd = torch.empty((E, H, K), **kw) 44 | # ctx.save_for_backward(a_ik, rpos_ij_e, tvecs_n, batch_i, edge_ij_e, z, dist2_min, v_ekd) 45 | # ctx.K = K 46 | # ctx.dist_max = dist_max 47 | # ctx.wscale = wscale 48 | # return z, 49 | 50 | # a_ik : (points, heads) 51 | # rpos_ij_e : (edges, 3) 52 | # tvecs_n : (batch, 3, 3) 53 | # batch_i : (points) 54 | # edge_ij_e : (2, edges) 55 | # z_ijk = log( sum_n exp( a_ik*|pj + t1*n1+t2*n2+t3*n3 - pi|^2 ) ) 56 | # : (edges, heads) 57 | N, H = a_ik.shape 58 | E = edge_ij_e.shape[1] 59 | kw = {'device': a_ik.device, 'dtype': a_ik.dtype} 60 | 61 | a_ik = a_ik.contiguous().detach() 62 | rpos_ij_e = rpos_ij_e.contiguous() 63 | tvecs_n = tvecs_n.contiguous() 64 | batch_i = batch_i.contiguous() 65 | dist2_min_e = dist2_min_e.contiguous() if dist2_min_e is not None else None 66 | edge_ij_e = edge_ij_e.contiguous() 67 | if W_k is not None: 68 | W_k = W_k.detach().contiguous() 69 | assert W_k.dim() in (3, 4) 70 | W_num = 1 if W_k.dim() == 3 else W_k.shape[0] 71 | W_dim = W_k.shape[-2] 72 | # v_ekd = torch.empty((E, H, W_dim), **kw) # not neaded for noproj 73 | else: 74 | W_num = 0 75 | W_dim = 0 76 | v_ekd = torch.empty((E, H, K), **kw) 77 | z_ek = torch.empty((E, H), **kw) 78 | 79 | bsz = H 80 | dev = a_ik.device 81 | 82 | with cp.cuda.Device(dev.index), ppe.cuda.stream(torch.cuda.current_stream(dev)): 83 | if False and rvlen_n is None: 84 | KernelManager.position_enc_forward( ((E*H+bsz-1)//bsz, ), (bsz, ), ( 85 | _to_copy(a_ik), 86 | _to_copy(rpos_ij_e), 87 | _to_copy(dist2_min_e), 88 | _to_copy(tvecs_n), 89 | _to_copy(batch_i), 90 | _to_copy(edge_ij_e), 91 | N, H, E, 92 | K, dist_max, wscale, 93 | #_to_copy(W_k), W_num, 94 | _to_copy(z_ek), 95 | _to_copy(v_ekd), 96 | )) 97 | else: 98 | from .. 
import global_config as config 99 | kernel = KernelManager.real_enc_fwd_v2 if config.REPRODUCIBLITY_STATE >= 4 \ 100 | else KernelManager.real_enc_fwd 101 | kernel( ((E*H+bsz-1)//bsz, ), (bsz, ), ( 102 | _to_copy(a_ik), 103 | _to_copy(rpos_ij_e), 104 | _to_copy(dist2_min_e), 105 | _to_copy(tvecs_n), 106 | _to_copy(batch_i), 107 | _to_copy(edge_ij_e), 108 | N, H, E, 109 | K, dist_max, wscale, 110 | #_to_copy(W_k), W_num, 111 | _to_copy(rvlen_n), cutoff_radius, 112 | _to_copy(z_ek), 113 | _to_copy(v_ekd), 114 | )) 115 | 116 | if W_k is not None: 117 | # W_k : (edges or 1, heads, vdim, vpe_dim) 118 | # v_ekd: (edges , heads, vpe_dim) 119 | # v_out: (edges , heads, vdim, 1) 120 | v_out = W_k @ v_ekd[..., None] 121 | v_out = v_out.reshape(E, H, -1) 122 | assert W_num in (E, 1) 123 | #if W_num == 1: 124 | # v_ekd = v_ekd.sum(dim=0, keepdim=(W_k.dim()==4)) 125 | else: 126 | v_out = v_ekd 127 | v_ekd = None 128 | 129 | 130 | ctx.save_for_backward(a_ik, rpos_ij_e, dist2_min_e, tvecs_n, batch_i, edge_ij_e, rvlen_n, W_k, z_ek, v_ekd) 131 | ctx.K = K 132 | ctx.dist_max = dist_max 133 | ctx.wscale = wscale 134 | ctx.cutoff_radius = cutoff_radius 135 | if K <= 0: 136 | return z_ek, 137 | 138 | return z_ek, v_out 139 | 140 | @staticmethod 141 | def backward(ctx, gz_ek, gv_ekd=None): 142 | # a_ik, rpos_ij_e, tvecs_n, batch_i, edge_ij_e, z_ek = ctx.saved_tensors[:6] 143 | 144 | # R = 2 145 | # edges0 = edge_ij_e[0] 146 | # grids = torch.arange(-R, R+1, dtype=a_ik.dtype, device=a_ik.device) 147 | # grids = torch.stack(torch.meshgrid([grids]*3, indexing='ij'), dim=-1) 148 | # grids = grids.reshape(-1, 3) # (R,3) 149 | # b2e = batch_i[edges0] # (N)[b2e] -> (E) 150 | # lattice = grids @ tvecs_n # (N,R,D) = ( R,D)x(N, D,D) 151 | # pos_ijn = rpos_ij_e[:,None] + lattice[b2e] # (E,R,D) = (E, 1,D)+(E, R,D) 152 | # del lattice 153 | # dist2 = (pos_ijn**2).sum(axis=2,keepdim=True) # (E,R,1) 154 | # #dist2 -= dist2_min 155 | # a_e1k = a_ik[edges0][:,None] # (N,H) -> (E,1,H) 156 | # p = torch.exp(dist2*a_e1k - z_ek[:,None,:]) # (E,R,H) 157 | # g = (dist2*p).sum(axis=1)*gz_ek # (E,H) 158 | # N = a_ik.shape[0] 159 | # sizes = torch.zeros(N, dtype=torch.long, device=a_ik.device) 160 | # sizes.scatter_add_(0, edges0, torch.ones_like(edges0)) 161 | # g = [x.sum(axis=0) for x in torch.split_with_sizes(g, sizes.tolist())] 162 | # ga_ik = torch.stack(g) 163 | 164 | # return ga_ik, None, None, None, None, None, None, None 165 | 166 | a_ik, rpos_ij_e, dist2_min_e, tvecs_n, batch_i, edge_ij_e, rvlen_n, W_k, z_ek, v_ekd = ctx.saved_tensors 167 | K = ctx.K 168 | dist_max = ctx.dist_max 169 | wscale = ctx.wscale 170 | cutoff_radius = ctx.cutoff_radius 171 | N, H = a_ik.shape 172 | E = edge_ij_e.shape[1] 173 | 174 | e_start_i = torch.zeros(N+1, dtype=batch_i.dtype, device=batch_i.device) 175 | e_start_i.scatter_add_(0, edge_ij_e[0]+1, torch.ones_like(edge_ij_e[0])) 176 | e_start_i = e_start_i.cumsum(0) 177 | 178 | ga_ik = torch.empty_like(a_ik) 179 | 180 | dev = a_ik.device 181 | gW_k = None 182 | if W_k is not None: 183 | # W: (edges or 1, heads, head_dim, K) 184 | assert W_k.dim() in (3, 4) 185 | W_num = 1 if W_k.dim() == 3 else W_k.shape[0] 186 | W_dim = W_k.shape[-2] 187 | 188 | # W: (edges or 1, heads, head_dim, K) 189 | # gv_ekd:(edges , heads, Vdim) 190 | # v_ekd: (edges or 1, heads, K) 191 | gW_k = gv_ekd[..., :, None] * v_ekd[..., None, :] 192 | else: 193 | W_num = 0 194 | W_dim = 0 195 | 196 | if gv_ekd is not None and W_k is not None: 197 | # gv_ekd: (E,H,Vdim) 198 | # W_k : (1 or E, H, Vdim, VPEdim) 199 | gv_ekd = 
W_k.transpose(-1, -2) @ gv_ekd[..., None] 200 | 201 | bsz = H 202 | with cp.cuda.Device(dev.index), ppe.cuda.stream(torch.cuda.current_stream(dev)): 203 | if False and rvlen_n is None: 204 | KernelManager.position_enc_backward(((N*H+bsz-1)//bsz, ), (bsz, ), ( 205 | _to_copy(a_ik.detach()), 206 | _to_copy(rpos_ij_e), 207 | _to_copy(dist2_min_e), 208 | _to_copy(tvecs_n), 209 | _to_copy(batch_i), 210 | _to_copy(edge_ij_e), 211 | _to_copy(e_start_i), 212 | _to_copy(z_ek.detach()), 213 | _to_copy(gz_ek.detach().contiguous()), 214 | _to_copy(gv_ekd), 215 | N, H, E, 216 | K, dist_max, wscale, 217 | #_to_copy(W_k), W_num, 218 | _to_copy(ga_ik), 219 | #_to_copy(gW_k), 220 | )) 221 | else: 222 | from .. import global_config as config 223 | kernel = KernelManager.real_enc_bwd_v2 if config.REPRODUCIBLITY_STATE >= 4 \ 224 | else KernelManager.real_enc_bwd 225 | kernel(((N*H+bsz-1)//bsz, ), (bsz, ), ( 226 | _to_copy(a_ik.detach()), 227 | _to_copy(rpos_ij_e), 228 | _to_copy(dist2_min_e), 229 | _to_copy(tvecs_n), 230 | _to_copy(batch_i), 231 | _to_copy(edge_ij_e), 232 | _to_copy(e_start_i), 233 | _to_copy(z_ek.detach()), 234 | _to_copy(gz_ek.detach().contiguous()), 235 | _to_copy(gv_ekd), 236 | N, H, E, 237 | K, dist_max, wscale, 238 | #_to_copy(W_k), W_num, 239 | _to_copy(rvlen_n), cutoff_radius, 240 | _to_copy(ga_ik), 241 | #_to_copy(gW_k), 242 | )) 243 | 244 | if rvlen_n is None: 245 | return ga_ik, None, None, None, None, None, None, None, None, gW_k 246 | 247 | return ga_ik, None, None, None, None, None, None, None, None, gW_k, None, None 248 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import sys 4 | import torch 5 | from torch_geometric.loader import DataLoader 6 | import pytorch_lightning as pl 7 | from pytorch_lightning import Trainer, loggers 8 | from pytorch_lightning.callbacks import ModelCheckpoint 9 | import shutil 10 | 11 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 12 | 13 | from utils import Params, seed_worker 14 | from models.regression import RegressionModel 15 | from dataloaders.common import filter_by_atom_num 16 | from functools import partial 17 | from distutils.util import strtobool 18 | 19 | def get_option(): 20 | argparser = ArgumentParser(description='Training the network') 21 | argparser.add_argument('-p', '--param_file', type=str, default='default.json', help='filename of the parameter JSON') 22 | args, unknown = argparser.parse_known_args() 23 | return args 24 | 25 | def train(): 26 | args = get_option() 27 | print('parsed args :') 28 | print(args) 29 | try: 30 | params = Params(f'{args.param_file}') 31 | except: 32 | params = Params(f'./params/{args.param_file}') 33 | 34 | parser = ArgumentParser(description='Training the network') 35 | parser.add_argument('-p', '--param_file', type=str, default=args.param_file, help='Config json file for default params') 36 | # load the json config and use it as default values. 
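# Every key in the loaded JSON becomes a command-line flag of the same name below:
# scalars keep their JSON type (bools go through strtobool, so "--flag false" works
# from a shell) and list values become flags with nargs='+'. Any flag given on the
# command line therefore overrides its JSON default, e.g. (hypothetical invocation,
# assuming these keys exist in the chosen JSON file):
#   python train.py -p latticeformer/default.json --batch_size 64 --seed 7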
37 | boolder = lambda x:bool(strtobool(x)) 38 | typefinder = lambda v: str if v is None else boolder if type(v)==bool else type(v) 39 | for key in params.dict: 40 | v = params.dict[key] 41 | if isinstance(v, (list, tuple)): 42 | parser.add_argument(f"--{key}", type=typefinder(v[0]), default=v, nargs='+') 43 | else: 44 | parser.add_argument(f"--{key}", type=typefinder(v), default=v) 45 | params.__dict__ = parser.parse_args().__dict__ 46 | print(params.dict) 47 | 48 | import models.global_config as config 49 | config.REPRODUCIBLITY_STATE = getattr(params, 'reproduciblity_state', 0) 50 | print(f"reproduciblity_state = {config.REPRODUCIBLITY_STATE}") 51 | 52 | # Reproducibility 53 | seed = getattr(params, 'seed', 123) 54 | deterministic = params.encoder_name in ["latticeformer"] 55 | pl.seed_everything(seed) 56 | torch.manual_seed(seed) 57 | torch.cuda.manual_seed(seed) 58 | torch.cuda.manual_seed_all(seed) 59 | torch.backends.cudnn.benchmark = False 60 | torch.backends.cudnn.deterministic = deterministic 61 | torch.backends.cuda.matmul.allow_tf32 = False 62 | torch.backends.cudnn.allow_tf32 = False 63 | # torch.backends.cuda.preferred_linalg_library("cusolver") # since torch 1.11, needed to avoid an error by torch.det(), but now det_3x3 is implemented manually. 64 | 65 | if params.encoder_name == "latticeformer": 66 | from dataloaders.dataset_latticeformer import RegressionDatasetMP_Latticeformer as Dataset 67 | else: 68 | raise NameError(params.encoder_name) 69 | # from dataloaders.dataloader import PyMgStructureMP as Dataset 70 | 71 | # Setup datasets 72 | ddp = getattr(params, "ddp", False) 73 | max_val = getattr(params, "train_filter_max", 0) 74 | min_val = getattr(params, "train_filter_min", 0) 75 | num_workers = getattr(params, "num_workers", 4) 76 | num_workers = num_workers if num_workers >= 0 else os.cpu_count() 77 | target_set = getattr(params, "target_set", None) 78 | train_filter = partial(filter_by_atom_num, max_val=max_val, min_val=min_val) \ 79 | if max_val > 0 or min_val > 1 else None 80 | if not hasattr(params, "training_data") or params.training_data == "default": 81 | train_dataset = Dataset(target_split='train', target_set=target_set, post_filter=train_filter) 82 | elif params.training_data in ["train_6400", "train_10k"]: 83 | train_dataset = Dataset(target_split=params.training_data, post_filter=train_filter) 84 | else: 85 | raise NameError(params.training_data) 86 | 87 | val_dataset = Dataset(target_split='val', target_set=target_set) 88 | test_dataset = Dataset(target_split='test', target_set=target_set) 89 | 90 | if torch.cuda.device_count() == 1 or not ddp: 91 | train_loader = DataLoader(train_dataset, batch_size=params.batch_size, shuffle=True, num_workers=num_workers, drop_last=True, 92 | worker_init_fn=seed_worker, pin_memory=True) 93 | val_loader = DataLoader(val_dataset, batch_size=params.batch_size, shuffle=False, num_workers=num_workers, drop_last=False, pin_memory=True) 94 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=num_workers, drop_last=False) 95 | else: 96 | train_sampler = torch.utils.data.distributed.DistributedSampler( 97 | train_dataset, rank=0, num_replicas=torch.cuda.device_count(), shuffle=True) 98 | val_sampler = torch.utils.data.distributed.DistributedSampler( 99 | val_dataset, rank=0, num_replicas=torch.cuda.device_count(), shuffle=False) 100 | test_sampler = torch.utils.data.distributed.DistributedSampler( 101 | test_dataset, rank=0, num_replicas=torch.cuda.device_count(), shuffle=False) 102 | 103 | 
train_loader = DataLoader(train_dataset, batch_size=params.batch_size, shuffle=False, num_workers=num_workers, drop_last=True, 104 | worker_init_fn=seed_worker, sampler=train_sampler, pin_memory=True) 105 | val_loader = DataLoader(val_dataset, batch_size=params.batch_size, shuffle=False, num_workers=num_workers, drop_last=False, sampler=val_sampler, pin_memory=True) 106 | test_loader = DataLoader(test_dataset, batch_size=params.batch_size, shuffle=False, num_workers=num_workers, drop_last=False, sampler=test_sampler) 107 | 108 | # Uncomment below to force the updating of the cache files. 109 | # train_dataset.process() 110 | # val_dataset.process() 111 | # test_dataset.process() 112 | 113 | # Setup model and trainer 114 | logger = loggers.TensorBoardLogger(params.save_path, name=params.experiment_name, default_hp_metric=False) 115 | logger.log_hyperparams(params.__dict__, \ 116 | {"hp/val": 1.0, "hp/avr50":1.0, "hp/min_avr50":1.0, "hp/min":1.0, "hp/mean_min10": 1.0} 117 | ) 118 | ckpt_dir=logger.log_dir+'/model_checkpoint' 119 | checkpoint_callback = ModelCheckpoint(save_top_k=params.model_checkpoint_save_top_k, 120 | monitor='val/loss', mode='min', dirpath=ckpt_dir) 121 | 122 | system = RegressionModel(params, train_loader, val_loader) 123 | param_num = sum([p.nelement() for p in system.model.parameters()]) 124 | print(f"Whole: {param_num}, {param_num*4/1024**2} MB") 125 | param_num = sum([p.nelement() for p in system.model.encoder.layers[0].parameters()]) 126 | print(f"Block: {param_num}, {param_num*4/1024**1} KB") 127 | 128 | # initialize mean and std values in crystalformer by forwarding once. 129 | if ddp: 130 | with torch.no_grad(): 131 | import random 132 | import numpy 133 | state = torch.random.get_rng_state(), random.getstate(), numpy.random.get_state() 134 | system.train() 135 | system.cuda().forward(next(iter(train_loader)).cuda()) 136 | system.cpu() 137 | torch.random.set_rng_state(state[0]) # usually, resetting torch's state is sufficient 138 | random.setstate(state[1]) 139 | numpy.random.set_state(state[2]) 140 | 141 | if params.pretrained_model is not None: 142 | system = RegressionModel.load_from_checkpoint( 143 | params.pretrained_model, 144 | params=params, 145 | train_loader=train_loader, 146 | val_loader=val_loader, 147 | strict=False) 148 | 149 | # Train model 150 | trainer = Trainer( 151 | logger=logger, 152 | devices=torch.cuda.device_count() if ddp else 1, 153 | strategy='ddp' if ddp else 'auto', 154 | max_epochs=params.n_epochs, 155 | default_root_dir=params.save_path, 156 | enable_checkpointing=True, 157 | callbacks=[checkpoint_callback], 158 | num_nodes=1, 159 | limit_train_batches=params.train_percent_check, 160 | limit_val_batches=params.val_percent_check, 161 | fast_dev_run=False, 162 | deterministic=deterministic) 163 | 164 | import time 165 | time_dict = {} 166 | start_time = time.time() 167 | trainer.fit(system) 168 | time_dict['time-train'] = (time.time()-start_time) 169 | scores = [] 170 | 171 | trainer.save_checkpoint(ckpt_dir+'/last.ckpt') # ensure checkpointing after SWA's BN updating 172 | # Validate and test the SWA model if available 173 | if system.enable_average_model('val-swa'): 174 | start_time = time.time() 175 | scores += trainer.validate(model=system, dataloaders=val_loader) 176 | time_dict['time-val-swa'] = (time.time()-start_time) 177 | if system.enable_average_model('test-swa'): 178 | start_time = time.time() 179 | scores += trainer.test(model=system, dataloaders=test_loader) 180 | time_dict['time-test-swa'] = (time.time()-start_time) 181 
| system.disable_average_model() 182 | 183 | # Prepare the best model for testing 184 | if os.path.exists(checkpoint_callback.best_model_path): 185 | best_model = RegressionModel.load_from_checkpoint( 186 | checkpoint_callback.best_model_path, 187 | params=params, 188 | train_loader=train_loader, 189 | val_loader=val_loader) 190 | system.model = best_model.model 191 | system.disable_average_model() 192 | del best_model 193 | trainer.save_checkpoint(ckpt_dir+'/best.ckpt') 194 | 195 | start_time = time.time() 196 | scores += trainer.validate(model=system, dataloaders=val_loader) 197 | time_dict['time-val'] = (time.time()-start_time) 198 | 199 | start_time = time.time() 200 | scores += trainer.test(model=system, dataloaders=test_loader) 201 | time_dict['time-test'] = (time.time()-start_time) 202 | 203 | print("Running times-----------------------------------------") 204 | print(f"time-train : {time_dict['time-train']/(60**2)} h") 205 | print(f"time-val-swa : {time_dict.get('time-val-swa', float('nan'))} s") # .get() so reporting does not crash when SWA was disabled 206 | print(f"time-test-swa: {time_dict.get('time-test-swa', float('nan'))/len(test_dataset)*1000} ms") 207 | print(f"time-val : {time_dict['time-val']} s") 208 | print(f"time-test : {time_dict['time-test']/len(test_dataset)*1000} ms") 209 | with open(f'{logger.log_dir}/time.txt', 'w') as f: 210 | print(f"time-train : {time_dict['time-train']/(60**2)} h", file=f) 211 | print(f"time-val-swa : {time_dict.get('time-val-swa', float('nan'))} s", file=f) 212 | print(f"time-test-swa: {time_dict.get('time-test-swa', float('nan'))/len(test_dataset)*1000} ms", file=f) 213 | print(f"time-val : {time_dict['time-val']} s", file=f) 214 | print(f"time-test : {time_dict['time-test']/len(test_dataset)*1000} ms", file=f) 215 | for score in scores: 216 | for key in score: 217 | print(f"{key}\t:{score[key]}", file=f) 218 | 219 | logger.finalize('success') # to properly output all test scores in a TB log. 220 | 221 | if __name__ == '__main__': 222 | train() 223 | --------------------------------------------------------------------------------
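The two pure-PyTorch sketches below are added for readers who want to sanity-check the CUDA paths above on small inputs. They are minimal illustrations under stated assumptions, not part of the repository. The function names (fused_dpa_reference, real_enc_z_reference) are made up here. The first sketch assumes that aij_eh and bij_ehk are given (they may be None in the CUDA op) and that the E = sum(s^2) edges are stored crystal by crystal in row-major (i, j) order, as implied by the disabled split_with_sizes reference branch in models/cuda_funcs/fused_dpa.py. The second sketch mirrors the commented-out reference in models/cuda_funcs/real_space_enc.py and uses a fixed translation range R instead of the cutoff-derived ranges R1..R3 used by the kernels.

import torch

def fused_dpa_reference(que_ihk, key_ihk, val_ihk, aij_eh, bij_ehk, batch_i):
    # Dense per-crystal reference of the fused attention:
    #   prob_ijh = softmax_j( sum_k q_ihk * k_jhk + a_ijh )
    #   out_ihk  = sum_j prob_ijh * (v_jhk + b_ijhk)
    N, H, K = que_ihk.shape
    sizes = torch.bincount(batch_i)                     # nodes assumed grouped per crystal
    out, e0, n0 = [], 0, 0
    for s in sizes.tolist():
        q = que_ihk[n0:n0 + s]                          # (s, H, K)
        k = key_ihk[n0:n0 + s]
        v = val_ihk[n0:n0 + s]
        a = aij_eh[e0:e0 + s * s].reshape(s, s, H)      # (s, s, H), i outer, j inner (assumed)
        b = bij_ehk[e0:e0 + s * s].reshape(s, s, H, K)  # (s, s, H, K)
        logits = torch.einsum('ihk,jhk->ijh', q, k) + a
        p = logits.softmax(dim=1)                       # softmax over neighbors j
        out.append(torch.einsum('ijh,ijhk->ihk', p, v[None] + b))
        e0 += s * s
        n0 += s
    return torch.cat(out, dim=0)                        # (N, H, K)

def real_enc_z_reference(a_ik, rpos_ij_e, tvecs_n, batch_i, edge_ij_e, R=2):
    # z_eh = log sum_{n in [-R, R]^3} exp( a_ih * |r_ij + n1*t1 + n2*t2 + n3*t3|^2 ),
    # i.e. the same quantity as the kernels' z_ek, but summed over a fixed grid and
    # stabilized with torch.logsumexp instead of the explicit dist2_min subtraction.
    grids = torch.arange(-R, R + 1, dtype=a_ik.dtype, device=a_ik.device)
    grids = torch.stack(torch.meshgrid(grids, grids, grids, indexing='ij'), dim=-1).reshape(-1, 3)
    lattice = grids @ tvecs_n                           # (B, (2R+1)^3, 3) lattice translations
    b2e = batch_i[edge_ij_e[0]]                         # crystal index of each edge
    pos = rpos_ij_e[:, None] + lattice[b2e]             # (E, (2R+1)^3, 3)
    dist2 = pos.pow(2).sum(dim=-1, keepdim=True)        # (E, (2R+1)^3, 1)
    a_e = a_ik[edge_ij_e[0]][:, None]                   # (E, 1, H)
    return torch.logsumexp(dist2 * a_e, dim=1)          # (E, H)

On small float32 inputs these can be compared against FusedDotProductAttentionCUDA.apply and the z_ek output of RealPeriodicEncodingFuncCUDA.apply within a loose tolerance, since the kernels accumulate in single precision and sum over cutoff-dependent translation ranges.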