├── README.md
├── data_prep.py
├── evaluation.py
├── models
│   ├── __init__.py
│   ├── model_mag.py
│   ├── model_noequ.py
│   ├── model_rot.py
│   ├── model_scale.py
│   └── model_um.py
├── requirements.txt
├── run.sh
├── run_model.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | ## Paper:
2 | Rui Wang*, Robin Walters*, Rose Yu [Incorporating Symmetry into Deep Dynamics Models for Improved Generalization](https://arxiv.org/abs/2002.03061), International Conference on Learning Representations 2021. (*equal contribution)
3 | 
4 | ## Abstract:
5 | Recent work has shown that deep learning can accelerate the prediction of physical dynamics relative to numerical solvers. However, limited physical accuracy and an inability to generalize under distributional shift limit its applicability to the real world. We propose to improve accuracy and generalization by incorporating symmetries into convolutional neural networks. Specifically, we employ a variety of methods, each tailored to enforce a different symmetry. Our models are both theoretically and experimentally robust to distributional shift by symmetry group transformations and enjoy favorable sample complexity. We demonstrate the advantage of our approach on a variety of physical dynamics, including Rayleigh–Bénard convection and real-world ocean currents and temperatures. Compared with image or text applications, our work is a significant step towards applying equivariant neural networks to high-dimensional systems with complex dynamics.
6 | 
7 | ## Data Sets
8 | * [Rayleigh–Bénard convection DataSet](https://roselab1.ucsd.edu/seafile/d/7e7abe7c9c51489daa21/.): 2000 velocity fields (![formula](https://render.githubusercontent.com/render/math?math=2000\times2\times256\times1792))
9 | * [Ocean Current DataSet](https://data.marine.copernicus.eu/products)
10 | 
11 | ## Requirements
12 | - To install the requirements:
13 | ```
14 | pip install -r requirements.txt
15 | ```
16 | 
17 | ## Description
18 | 1. models: PyTorch implementations of the equivariant models.
19 | 
20 | 2. utils.py: training functions and dataset classes.
21 | 
22 | 3. evaluation.py: code for computing Energy Spectrum Errors.
23 | 
24 | 4. data_prep.py: preprocesses the RBC and ocean data.
25 | 
26 | 5. run_model.py: trains models and reports test RMSEs and ESEs.
27 | 
28 | ## Instructions
29 | ### Dataset and Preprocessing
30 | - Download the [RBC and ocean data](https://roselab1.ucsd.edu/seafile/d/7e7abe7c9c51489daa21/.) and put 'rbc_data.pt' and all the ocean NetCDF files in the same directory as data_prep.py. Because the ocean data we previously downloaded from [Copernicus](https://resources.marine.copernicus.eu/?option=com_csw&view=details&product_id=GLOBAL_ANALYSIS_FORECAST_PHY_001_024) for the years 2016 to 2017 is no longer available, we reran the experiments using data from 2021 to 2022 over the same latitude and longitude range.
31 | 
32 | - Run data_prep.py to preprocess the RBC and ocean data and generate the training, validation, and test sets (as well as the transformed test sets).
33 | ```
34 | python data_prep.py
35 | ```
36 | 
37 | ### Training
38 | - Train the equivariant and non-equivariant ResNets and U-nets on the RBC and ocean data.
39 | ```
40 | sh run.sh
41 | ```
42 | 
43 | ### New results on ocean currents data from 2021 to 2022
44 | Results are reported as mean ± standard deviation over five random seeds. The first five rows are ResNet-based models; the last five are U-net-based models.
45 | 
46 | | Model     | RMSE (Test_time) | RMSE (Test_domain) | ESE (Test_time) | ESE (Test_domain) |
47 | |-----------|:----------------:|:------------------:|:---------------:|:-----------------:|
48 | | ResNet    |   0.82 ± 0.02    |    1.42 ± 0.02     |   1.09 ± 0.13   |    1.23 ± 0.14    |
49 | | Equ_UM    |   0.80 ± 0.01    |    1.29 ± 0.01     |   1.10 ± 0.02   |    1.19 ± 0.03    |
50 | | Equ_Mag   |   0.79 ± 0.01    |    1.29 ± 0.00     |   1.01 ± 0.01   |    1.02 ± 0.03    |
51 | | Equ_Rot   |   0.78 ± 0.01    |    1.26 ± 0.01     | **0.88 ± 0.01** |  **0.98 ± 0.01**  |
52 | | Equ_Scale | **0.76 ± 0.03**  |  **1.25 ± 0.02**   |   0.98 ± 0.04   |    1.00 ± 0.03    |
53 | | Unet      |   0.78 ± 0.03    |    1.30 ± 0.02     |   1.17 ± 0.08   |    1.25 ± 0.05    |
54 | | Equ_UM    |   0.75 ± 0.01    |    1.32 ± 0.02     |   1.19 ± 0.02   |    1.29 ± 0.00    |
55 | | Equ_Mag   | **0.74 ± 0.01**  |    1.27 ± 0.02     |   0.94 ± 0.01   |    1.00 ± 0.03    |
56 | | Equ_Rot   |   0.74 ± 0.02    |  **1.08 ± 0.02**   | **0.63 ± 0.00** |  **0.85 ± 0.01**  |
57 | | Equ_Scale |   0.75 ± 0.00    |    1.12 ± 0.00     |   0.89 ± 0.04   |    0.98 ± 0.02    |
58 | 
59 | 
60 | ## Cite
61 | ```
62 | @inproceedings{wang2021incorporating,
63 | title={Incorporating Symmetry into Deep Dynamics Models for Improved Generalization},
64 | author={Rui Wang and Robin Walters and Rose Yu},
65 | booktitle={International Conference on Learning Representations},
66 | year={2021},
67 | url={https://openreview.net/forum?id=wta_8Hx2KD}
68 | }
69 | ```
70 | 
--------------------------------------------------------------------------------
/data_prep.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | import os
5 | import cv2
6 | import random
7 | from torchvision import transforms
8 | import torchvision.transforms.functional as TF
9 | from PIL import Image
10 | import torch.nn.functional as F
11 | from netCDF4 import Dataset
12 | 
13 | ################ Data Preprocessing ################
14 | # read data
15 | data = torch.load("rbc_data.pt")
16 | 
17 | # create folder for original data samples
18 | orig_data_direc = "data_64/"
19 | os.mkdir(orig_data_direc)
20 | 
21 | # standardization
22 | std = torch.std(data)
23 | avg = torch.mean(data)
24 | data = (data - avg)/std
25 | data = data[:,:,::4,::4]  # downsample the 256 x 1792 fields to 64 x 448
26 | 
27 | # divide each rectangular snapshot into 7 subregions
28 | # data_prep shape: num_subregions * time * channels * w * h
29 | data_prep = torch.stack([data[:,:,:,k*64:(k+1)*64] for k in range(7)])
30 | 
31 | # use sliding windows to generate 10500 samples of 50 frames each;
32 | # the first 10000 are used: 6000 training, 2000 validation, 2000 test
33 | for j in range(0, 1500):
34 |     for i in range(7):
35 |         torch.save(data_prep[i, j : j + 50].double().float(), orig_data_direc + "sample_" + str(j*7+i) + ".pt")
36 | 
37 | 
38 | 
39 | ################ Generate Transformed Test Sets ################
40 | # Magnitude Transformation
41 | mag_data_direc = "data_mag/"
42 | os.mkdir(mag_data_direc)
43 | for i in range(8000, 10000):
44 |     # multiply each sample by a random value drawn from U(0, 2)
45 |     mag_transformed_img = torch.load(orig_data_direc + "sample_" + str(i) + ".pt") * torch.rand(1) * 2
46 |     torch.save(mag_transformed_img, mag_data_direc + "sample_" + str(i) + ".pt")
47 | 
48 | # Uniform Motion Transformation
49 | um_data_direc = "data_um/"
50 | os.mkdir(um_data_direc)
51 | for i in range(8000, 10000):
52 |     # add a random constant vector drawn from U(-2, 2)
53 |     um_transformed_img = torch.load(orig_data_direc + "sample_" + str(i) + ".pt") + (torch.rand(1, 2, 1, 1)*4-2)
54 |     torch.save(um_transformed_img, um_data_direc + "sample_" + str(i) + ".pt")
55 | 
56 | # Rotation Transformation
57 | def rotate_image(mat, angle):
58 |     """
59 |     Rotates an image (angle in degrees) and expands the image to avoid cropping.
60 |     """
61 | 
62 |     height, width = mat.shape[:2] # image shape has 3 dimensions
63 |     image_center = (width/2, height/2) # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape
64 | 
65 |     rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.)
66 | 
67 |     # rotation calculates the cos and sin, taking absolutes of those.
68 |     abs_cos = abs(rotation_mat[0,0])
69 |     abs_sin = abs(rotation_mat[0,1])
70 | 
71 |     # find the new width and height bounds
72 |     bound_w = int(height * abs_sin + width * abs_cos)
73 |     bound_h = int(height * abs_cos + width * abs_sin)
74 | 
75 |     # subtract the old image center (bringing the image back to the origin) and add the new image center coordinates
76 |     rotation_mat[0, 2] += bound_w/2 - image_center[0]
77 |     rotation_mat[1, 2] += bound_h/2 - image_center[1]
78 | 
79 |     # rotate image with the new bounds and translated rotation matrix
80 |     rotated_mat = cv2.warpAffine(mat, rotation_mat, (bound_w, bound_h))
81 |     return rotated_mat
82 | 
83 | def normalize(tensor):
84 |     return (tensor - torch.min(tensor))/(torch.max(tensor) - torch.min(tensor))
85 | 
86 | def rotate(img, degree):
87 |     # img shape: 2 x 64 x 64
88 |     # rotate the velocity vectors: 2x2 rotation matrix applied to the channel dim, (2,1,64,64) -> (2,1,64,64)
89 |     theta = torch.tensor(degree/180*np.pi)
90 |     rot_m = torch.tensor([[torch.cos(theta), -torch.sin(theta)], [torch.sin(theta), torch.cos(theta)]])
91 |     img = torch.einsum("ab, bcde -> acde",(rot_m, img.unsqueeze(1))).squeeze(1)
92 | 
93 |     mmin = torch.min(img)
94 |     mmax = torch.max(img)
95 |     img = normalize(img).data.numpy()
96 |     x = TTen(TF.rotate(Image.fromarray(np.uint8(img[0]*255)), degree, expand = True))
97 |     y = TTen(TF.rotate(Image.fromarray(np.uint8(img[1]*255)), degree, expand = True))
98 |     rot_img = torch.cat([x, y], dim = 0)
99 | 
100 |     rot_img[rot_img!=0] = rot_img[rot_img!=0]*(mmax - mmin) + mmin
101 |     return rot_img
102 | 
103 | rot_data_direc = "data_rot/"
104 | os.mkdir(rot_data_direc)
105 | PIL = transforms.ToPILImage()
106 | TTen = transforms.ToTensor()
107 | for i in range(8000, 10000):
108 |     degree = random.choice([90, 180, 270, 360])
109 |     img = torch.load(orig_data_direc + "sample_" + str(i) + ".pt")
110 |     rot_img = torch.cat([rotate(img[j], degree).unsqueeze(0) for j in range(img.shape[0])], dim = 0)
111 |     torch.save(rot_img, rot_data_direc + "sample_" + str(i) + ".pt")
112 | 
113 | # Scale Transformation
114 | scale_data_direc = "data_scale/"
115 | os.mkdir(scale_data_direc)
116 | for i in range(8000, 10000):
117 |     img = torch.load(orig_data_direc + "sample_" + str(i) + ".pt")
118 |     factor = (torch.rand(1)*9+1)/2  # scale factor drawn from U(0.5, 5)
119 |     scale_transformed_img = F.interpolate(img.transpose(0,1).unsqueeze(0), scale_factor = (factor**2, factor, factor), mode="trilinear", align_corners=None)[0,:,:100].transpose(0,1)/factor
120 |     torch.save(scale_transformed_img, scale_data_direc + "sample_" + str(i) + ".pt")
121 | 
122 | 
123 | 
124 | ############### Preprocess Ocean Data ##################
125 | def load_nc(path):
126 |     nc = Dataset(path)
127 |     u0 = torch.from_numpy(np.array([nc["uo"][i].filled()[0] for i in range(len(nc["uo"]))])).float().unsqueeze(1)
128 |     v0 = torch.from_numpy(np.array([nc["vo"][i].filled()[0] for i in range(len(nc["vo"]))])).float().unsqueeze(1)
129 |     w = torch.cat([u0, v0], dim = 1)
130 |     w[w<-1000] = 0
131 |     w[w>10000] = 0
132 |     return w
133 | 
134 | 
135 | atlantic = load_nc("atlantic.nc")
136 | indian = load_nc("indian.nc")
137 | north_pacific = load_nc("north_pacific.nc")
138 | south_pacific_test = load_nc("south_pacific_test.nc")
139 | 
140 | 
141 | os.mkdir("ocean_train")
142 | os.mkdir("ocean_test")
143 | 
144 | # divide each region into 3 x 3 subregions of 64 x 64 and slide a 50-frame window
145 | k = 0
146 | for t in range(500):
147 |     for i in range(3):
148 |         for j in range(3):
149 |             torch.save(atlantic[t:t+50,:,64*i:64*(i+1),64*j:64*(j+1)].double().float(), "ocean_train/sample_" + str(k) + ".pt")
150 |             k += 1
151 |             torch.save(indian[t:t+50,:,64*i:64*(i+1),64*j:64*(j+1)].double().float(), "ocean_train/sample_" + str(k) + ".pt")
152 |             k += 1
153 |             torch.save(north_pacific[t:t+50,:,64*i:64*(i+1),64*j:64*(j+1)].double().float(), "ocean_train/sample_" + str(k) + ".pt")
154 |             k += 1
155 | 
156 | k = 0
157 | for t in range(300):
158 |     for i in range(3):
159 |         for j in range(3):
160 |             torch.save(south_pacific_test[t:t+50,:,64*i:64*(i+1),64*j:64*(j+1)].double().float(), "ocean_test/sample_" + str(k) + ".pt")
161 |             k += 1
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch.utils import data
5 | from torch.autograd import Variable
6 | import os
7 | 
8 | def azimuthalAverage(image, center=None):
9 |     """
10 |     Calculate the azimuthally averaged radial profile.
11 | 
12 |     image - The 2D image
13 |     center - The [x,y] pixel coordinates used as the center. The default is
14 |              None, which then uses the center of the image (including
15 |              fractional pixels).
16 | 
17 |     """
18 |     # Calculate the indices from the image
19 |     y, x = np.indices(image.shape)
20 | 
21 |     if center is None:
22 |         center = np.array([(x.max()-x.min())/2.0, (y.max()-y.min())/2.0])
23 | 
24 |     r = np.hypot(x - center[0], y - center[1])
25 | 
26 |     # Get sorted radii
27 |     ind = np.argsort(r.flat)
28 |     r_sorted = r.flat[ind]
29 |     i_sorted = image.flat[ind]
30 | 
31 |     # Get the integer part of the radii (bin size = 1)
32 |     r_int = r_sorted.astype(int)
33 | 
34 |     # Find all pixels that fall within each radial bin.
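    # Since the radii are sorted, each integer bin occupies a contiguous run of
    # indices; the run boundaries plus a cumulative sum below give each bin's
    # total (and hence its average) in a single O(n) pass.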
35 | deltar = r_int[1:] - r_int[:-1] # Assumes all radii represented 36 | rind = np.where(deltar)[0] # location of changed radius 37 | nr = rind[1:] - rind[:-1] # number of radius bin 38 | 39 | # Cumulative sum to figure out sums for each radius bin 40 | csim = np.cumsum(i_sorted, dtype=float) 41 | tbin = csim[rind[1:]] - csim[rind[:-1]] 42 | 43 | radial_prof = tbin / nr 44 | 45 | return radial_prof 46 | 47 | 48 | def TKE(preds): 49 | """ 50 | Calculate the TKE field of the predictions 51 | """ 52 | mean_flow = np.expand_dims(np.mean(preds, axis = 0), axis = 0) 53 | tur_preds = np.mean((preds - mean_flow)**2, axis = 0) 54 | tke = (tur_preds[0] + tur_preds[1])/2 55 | return tke 56 | 57 | def tke2spectrum(tke): 58 | """ 59 | Convert TKE field to spectrum 60 | """ 61 | sp = np.fft.fft2(tke) 62 | sp = np.fft.fftshift(sp) 63 | sp = np.real(sp*np.conjugate(sp)) 64 | sp1D = azimuthalAverage(sp) 65 | return np.log10(sp1D) 66 | 67 | 68 | def spectrum_band(tensor): 69 | """ 70 | Calculate spectrum_band of predictions 71 | """ 72 | spec = np.array([tke2spectrum(TKE(tensor[i])) for i in range(tensor.shape[0])]) 73 | return spec 74 | 75 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_noequ import ResNet, Unet 2 | from .model_um import ResNet_UM, Unet_UM 3 | from .model_mag import ResNet_Mag, Unet_Mag 4 | from .model_rot import ResNet_Rot, Unet_Rot 5 | from .model_scale import ResNet_Scale, Unet_Scale 6 | import e2cnn -------------------------------------------------------------------------------- /models/model_mag.py: -------------------------------------------------------------------------------- 1 | """ 2 | Magnitude Equivariant ResNet and U-net 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | class mag_conv2d(nn.Module): 13 | def __init__(self, 14 | input_channels, 15 | output_channels, 16 | kernel_size, 17 | um_dim = 2, # The number of channels of each frame 18 | activation = True, # whether to use activation functions 19 | stride = 1, 20 | deconv = False):# Whether this is used as a deconvolutional layer 21 | """ 22 | Magnitude Equivariant 2D Convolutional Layers 23 | """ 24 | super(mag_conv2d, self).__init__() 25 | self.activation = activation 26 | self.input_channels = input_channels 27 | self.kernel_size = kernel_size 28 | self.um_dim = um_dim 29 | self.stride = stride 30 | self.conv2d = nn.Conv2d(input_channels, output_channels, kernel_size, stride = kernel_size, bias = True) 31 | self.pad_size = (kernel_size - 1)//2 32 | self.input_channels = self.input_channels 33 | self.batchnorm = nn.BatchNorm2d(output_channels) 34 | self.deconv = deconv 35 | 36 | def unfold(self, x): 37 | """ 38 | Extracts sliding local blocks from a batched input tensor. 
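        The returned tensor has shape (batch, in_channels, kernel_size,
        kernel_size, H, W) before any striding; a square spatial input
        (H == W) is assumed by the reshape below.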
39 | """ 40 | if not self.deconv: 41 | x = F.pad(x, ((self.pad_size, self.pad_size)*2), mode = 'replicate') 42 | out = F.unfold(x, kernel_size = self.kernel_size) 43 | out = out.reshape(out.shape[0], self.input_channels, self.kernel_size, self.kernel_size, out.shape[-1]) 44 | 45 | ## Batch_size x (in_channels x kernel_size x kernel_size) x 64 x 64 46 | out = out.reshape(out.shape[0], self.input_channels, self.kernel_size, self.kernel_size, int(np.sqrt(out.shape[-1])), int(np.sqrt(out.shape[-1]))) 47 | if self.stride > 1: 48 | return out[:,:,:,:,::self.stride,::self.stride] 49 | return out 50 | 51 | def transform(self, x): 52 | """ 53 | Max-Min Normalization on each sliding local block. 54 | """ 55 | # Calculates the max-min of input sliding local blocks 56 | out = x.reshape(x.shape[0], self.input_channels//self.um_dim, self.um_dim, self.kernel_size, self.kernel_size, x.shape[-2], x.shape[-1]) 57 | stds = (out.max(1).values.unsqueeze(1).max(3).values.unsqueeze(3).max(4).values.unsqueeze(4) - 58 | out.min(1).values.unsqueeze(1).min(3).values.unsqueeze(3).min(4).values.unsqueeze(4)) 59 | 60 | out = out /(stds + 10e-7) 61 | out = out.reshape(out.shape[0], self.input_channels, self.kernel_size, self.kernel_size, x.shape[-2], x.shape[-1]).transpose(2,4).transpose(-1,-2) 62 | out = out.reshape(out.shape[0], self.input_channels, x.shape[-2]*self.kernel_size, x.shape[-1], self.kernel_size) 63 | 64 | ## Batch_size x in_channels x (64 x kernel_size) x (64 x kernel_size) 65 | out = out.reshape(out.shape[0], self.input_channels, x.shape[-2]*self.kernel_size, x.shape[-1]*self.kernel_size) 66 | return out, stds.squeeze(3).squeeze(3) 67 | 68 | 69 | def inverse_transform(self, out, stds): 70 | """ 71 | Inverse Max-Min Normalization. 72 | """ 73 | out = out.reshape(out.shape[0], out.shape[1]//self.um_dim, self.um_dim, out.shape[-2], out.shape[-1]) 74 | out = out * (stds + 10e-7) 75 | out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1]) 76 | return out 77 | 78 | def forward(self, x): 79 | x = self.unfold(x) 80 | x, stds = self.transform(x) 81 | out = self.conv2d(x) 82 | if self.activation: 83 | out = F.relu(out) 84 | out = self.inverse_transform(out, stds) 85 | return out 86 | 87 | 88 | class mag_deconv2d(nn.Module): 89 | def __init__(self, input_channels, output_channels): 90 | """ 91 | Magnitude Equivariant 2D Transposed Convolutional Layers 92 | """ 93 | super(mag_deconv2d, self).__init__() 94 | self.conv2d = mag_conv2d(input_channels = input_channels, output_channels = output_channels, kernel_size = 4, um_dim = 2, 95 | activation = True, stride = 1, deconv = True) 96 | 97 | def pad(self, x): 98 | pad_x = torch.zeros(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3]*2) 99 | pad_x[:,:,::2,::2].copy_(x) 100 | pad_x = F.pad(pad_x, (1,2,1,2), mode='replicate') 101 | return pad_x 102 | 103 | def forward(self, x): 104 | out = self.pad(x).to(device) 105 | return self.conv2d(out) 106 | 107 | # Magnitude Equivariant ResNet. 
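
# Illustrative sanity check (not part of the original pipeline): a
# magnitude-equivariant layer L should satisfy L(c * x) ≈ c * L(x) for any
# c > 0, up to the 1e-6 stabilizer used in the max-min normalization above,
# since each local block is divided by its range before the convolution and
# multiplied by it again afterwards.
if __name__ == "__main__":
    _layer = mag_conv2d(input_channels=2, output_channels=2, kernel_size=3)
    _x = torch.rand(1, 2, 16, 16)
    # Scaling the input by 2 should scale the output by 2 (approximately).
    print(torch.allclose(_layer(2.0 * _x), 2.0 * _layer(_x), atol=1e-3))
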
108 | class mag_resblock(nn.Module): 109 | def __init__(self, input_channels, hidden_dim, kernel_size): 110 | super(mag_resblock, self).__init__() 111 | self.layer1 = mag_conv2d(input_channels, hidden_dim, kernel_size) 112 | self.layer2 = mag_conv2d(hidden_dim, hidden_dim, kernel_size) 113 | self.input_channels = input_channels 114 | self.hidden_dim = hidden_dim 115 | 116 | if input_channels != hidden_dim: 117 | self.upscale = mag_conv2d(input_channels, hidden_dim, kernel_size, activation = False) 118 | 119 | def forward(self, x): 120 | out = self.layer1(x) 121 | 122 | if self.input_channels != self.hidden_dim: 123 | out = self.layer2(out) + self.upscale(x) 124 | else: 125 | out = self.layer2(out) + x 126 | 127 | return out 128 | 129 | class ResNet_Mag(nn.Module): 130 | def __init__(self, input_channels, output_channels, kernel_size): 131 | super(ResNet_Mag, self).__init__() 132 | layers = [mag_resblock(input_channels, 64, kernel_size), mag_resblock(64, 64, kernel_size)] 133 | layers += [mag_resblock(64, 128, kernel_size), mag_resblock(128, 128, kernel_size)] 134 | layers += [mag_resblock(128, 256, kernel_size), mag_resblock(256, 256, kernel_size)] 135 | layers += [mag_resblock(256, 512, kernel_size), mag_resblock(512, 512, kernel_size)] 136 | layers += [mag_conv2d(512, output_channels, kernel_size = kernel_size, activation = False)] 137 | self.model = nn.Sequential(*layers) 138 | 139 | def forward(self, x): 140 | out = self.model(x) 141 | return out 142 | 143 | 144 | # Magnitude Equivariant U_net. 145 | class Unet_Mag(nn.Module): 146 | def __init__(self, input_channels, output_channels, kernel_size): 147 | super(Unet_Mag, self).__init__() 148 | self.input_channels = input_channels 149 | self.conv1 = mag_conv2d(input_channels, 64, kernel_size = kernel_size, stride=2) 150 | self.conv1_1 = mag_conv2d(64, 64, kernel_size = kernel_size, stride=1) 151 | self.conv2 = mag_conv2d(64, 128, kernel_size = kernel_size, stride=2) 152 | self.conv2_1 = mag_conv2d(128, 128, kernel_size = kernel_size, stride = 1) 153 | self.conv3 = mag_conv2d(128, 256, kernel_size = kernel_size, stride=2) 154 | self.conv3_1 = mag_conv2d(256, 256, kernel_size = kernel_size, stride=1) 155 | self.conv4 = mag_conv2d(256, 512, kernel_size = kernel_size, stride=2) 156 | self.conv4_1 = mag_conv2d(512, 512, kernel_size = kernel_size, stride=1) 157 | 158 | self.deconv3 = mag_deconv2d(512, 128) 159 | self.deconv2 = mag_deconv2d(384, 64) 160 | self.deconv1 = mag_deconv2d(192, 32) 161 | self.deconv0 = mag_deconv2d(96, 16) 162 | self.output_layer = mag_conv2d(16 + input_channels, output_channels, kernel_size=kernel_size, activation = False) 163 | 164 | def forward(self, x): 165 | out_conv1 = self.conv1_1(self.conv1(x)) 166 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 167 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 168 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 169 | 170 | out_deconv3 = self.deconv3(out_conv4) 171 | concat3 = torch.cat((out_conv3, out_deconv3), 1) 172 | out_deconv2 = self.deconv2(concat3) 173 | concat2 = torch.cat((out_conv2, out_deconv2), 1) 174 | out_deconv1 = self.deconv1(concat2) 175 | concat1 = torch.cat((out_conv1, out_deconv1), 1) 176 | out_deconv0 = self.deconv0(concat1) 177 | concat0 = torch.cat((x, out_deconv0), 1) 178 | out = self.output_layer(concat0) 179 | return out 180 | 181 | 182 | 183 | 184 | 185 | -------------------------------------------------------------------------------- /models/model_noequ.py: -------------------------------------------------------------------------------- 1 | 
""" 2 | Non-equivariant ResNet and U-net 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.nn.functional as F 9 | 10 | 11 | def conv(input_channels, output_channels, kernel_size, stride): 12 | return nn.Sequential( 13 | nn.Conv2d(input_channels, output_channels, kernel_size = kernel_size, 14 | stride = stride, padding=(kernel_size - 1) // 2), 15 | nn.ReLU() 16 | ) 17 | 18 | def deconv(input_channels, output_channels): 19 | return nn.Sequential( 20 | nn.ConvTranspose2d(input_channels, output_channels, kernel_size = 4, 21 | stride = 2, padding=1), 22 | nn.ReLU() 23 | ) 24 | 25 | class Unet(nn.Module): 26 | def __init__(self, input_channels, output_channels, kernel_size): 27 | super(Unet, self).__init__() 28 | self.input_channels = input_channels 29 | self.conv1 = conv(input_channels, 64, kernel_size=kernel_size, stride=2) 30 | self.conv2 = conv(64, 128, kernel_size=kernel_size, stride=2) 31 | self.conv2_1 = conv(128, 128, kernel_size=kernel_size, stride=1) 32 | self.conv3 = conv(128, 256, kernel_size=kernel_size, stride=2) 33 | self.conv3_1 = conv(256, 256, kernel_size=kernel_size, stride=1) 34 | self.conv4 = conv(256, 512, kernel_size=kernel_size, stride=2) 35 | self.conv4_1 = conv(512, 512, kernel_size=kernel_size, stride=1) 36 | 37 | self.deconv3 = deconv(512, 128) 38 | self.deconv2 = deconv(384, 64) 39 | self.deconv1 = deconv(192, 32) 40 | self.deconv0 = deconv(96, 16) 41 | 42 | self.output_layer = nn.Conv2d(16 + input_channels, output_channels, kernel_size=kernel_size, 43 | stride = 1, padding=(kernel_size - 1) // 2) 44 | 45 | def forward(self, x): 46 | out_conv1 = self.conv1(x) 47 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 48 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 49 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 50 | 51 | out_deconv3 = self.deconv3(out_conv4) 52 | concat3 = torch.cat((out_conv3, out_deconv3), 1) 53 | out_deconv2 = self.deconv2(concat3) 54 | concat2 = torch.cat((out_conv2, out_deconv2), 1) 55 | out_deconv1 = self.deconv1(concat2) 56 | concat1 = torch.cat((out_conv1, out_deconv1), 1) 57 | out_deconv0 = self.deconv0(concat1) 58 | concat0 = torch.cat((x, out_deconv0), 1) 59 | out = self.output_layer(concat0) 60 | return out 61 | 62 | 63 | class Resblock(nn.Module): 64 | def __init__(self, input_channels, hidden_dim, kernel_size): 65 | super(Resblock, self).__init__() 66 | self.layer1 = nn.Sequential( 67 | nn.Conv2d(input_channels, hidden_dim, kernel_size = kernel_size, padding = (kernel_size-1)//2), 68 | nn.BatchNorm2d(hidden_dim), 69 | nn.LeakyReLU() 70 | ) 71 | self.layer2 = nn.Sequential( 72 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size = kernel_size, padding = (kernel_size-1)//2), 73 | nn.BatchNorm2d(hidden_dim), 74 | nn.LeakyReLU() 75 | ) 76 | 77 | if input_channels != hidden_dim: 78 | self.upscale = nn.Sequential( 79 | nn.Conv2d(input_channels, hidden_dim, kernel_size = kernel_size, padding = (kernel_size-1)//2), 80 | nn.LeakyReLU() 81 | ) 82 | self.input_channels = input_channels 83 | self.hidden_dim = hidden_dim 84 | 85 | 86 | def forward(self, xx): 87 | out = self.layer1(xx) 88 | if self.input_channels != self.hidden_dim: 89 | out = self.layer2(out) + self.upscale(xx) 90 | else: 91 | out = self.layer2(out) + xx 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | def __init__(self, input_channels, output_channels, kernel_size): 97 | super(ResNet, self).__init__() 98 | layers = [Resblock(input_channels, 64, kernel_size), Resblock(64, 64, kernel_size)] 99 | layers += [Resblock(64, 
128, kernel_size), Resblock(128, 128, kernel_size)] 100 | layers += [Resblock(128, 256, kernel_size), Resblock(256, 256, kernel_size)] 101 | layers += [Resblock(256, 512, kernel_size), Resblock(512, 512, kernel_size)] 102 | layers += [nn.Conv2d(512, output_channels, kernel_size = kernel_size, padding = (kernel_size-1)//2)] 103 | self.model = nn.Sequential(*layers) 104 | 105 | def forward(self, xx): 106 | out = self.model(xx) 107 | return out -------------------------------------------------------------------------------- /models/model_rot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Rotational Equivariant ResNet and U-net 3 | """ 4 | import os 5 | import torch 6 | from e2cnn import gspaces 7 | from e2cnn import nn 8 | import numpy as np 9 | from torch.utils import data 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | 12 | 13 | 14 | ##### Rotational Equivariant ResNet ##### 15 | class rot_resblock(torch.nn.Module): 16 | def __init__(self, 17 | input_channels, 18 | hidden_dim, 19 | kernel_size, 20 | N # Group size 21 | ): 22 | super(rot_resblock, self).__init__() 23 | 24 | # Specify symmetry transformation 25 | r2_act = gspaces.Rot2dOnR2(N = N) 26 | feat_type_in = nn.FieldType(r2_act, input_channels*[r2_act.regular_repr]) 27 | feat_type_hid = nn.FieldType(r2_act, hidden_dim*[r2_act.regular_repr]) 28 | 29 | self.layer1 = nn.SequentialModule( 30 | nn.R2Conv(feat_type_in, feat_type_hid, kernel_size = kernel_size, padding = (kernel_size - 1)//2), 31 | nn.InnerBatchNorm(feat_type_hid), 32 | nn.ReLU(feat_type_hid) 33 | ) 34 | 35 | self.layer2 = nn.SequentialModule( 36 | nn.R2Conv(feat_type_hid, feat_type_hid, kernel_size = kernel_size, padding = (kernel_size - 1)//2), 37 | nn.InnerBatchNorm(feat_type_hid), 38 | nn.ReLU(feat_type_hid) 39 | ) 40 | 41 | self.upscale = nn.SequentialModule( 42 | nn.R2Conv(feat_type_in, feat_type_hid, kernel_size = kernel_size, padding = (kernel_size - 1)//2), 43 | nn.InnerBatchNorm(feat_type_hid), 44 | nn.ReLU(feat_type_hid) 45 | ) 46 | 47 | self.input_channels = input_channels 48 | self.hidden_dim = hidden_dim 49 | 50 | def forward(self, x): 51 | out = self.layer1(x) 52 | 53 | if self.input_channels != self.hidden_dim: 54 | out = self.layer2(out) + self.upscale(x) 55 | else: 56 | out = self.layer2(out) + x 57 | 58 | return out 59 | 60 | 61 | ##### Rotational Equivariant ResNet ##### 62 | class ResNet_Rot(torch.nn.Module): 63 | def __init__(self, input_frames, output_frames, kernel_size, N): 64 | super(ResNet_Rot, self).__init__() 65 | r2_act = gspaces.Rot2dOnR2(N = N) 66 | # we use rho_1 representation since the input is velocity fields 67 | self.feat_type_in = nn.FieldType(r2_act, input_frames*[r2_act.irrep(1)]) 68 | # we use regular representation for middle layers 69 | self.feat_type_in_hid = nn.FieldType(r2_act, 16*[r2_act.regular_repr]) 70 | self.feat_type_hid_out = nn.FieldType(r2_act, 192*[r2_act.regular_repr]) 71 | self.feat_type_out = nn.FieldType(r2_act, output_frames*[r2_act.irrep(1)]) 72 | 73 | self.input_layer = nn.SequentialModule( 74 | nn.R2Conv(self.feat_type_in, self.feat_type_in_hid, kernel_size = kernel_size, padding = (kernel_size - 1)//2), 75 | nn.InnerBatchNorm(self.feat_type_in_hid), 76 | nn.ReLU(self.feat_type_in_hid) 77 | ) 78 | layers = [self.input_layer] 79 | layers += [rot_resblock(16, 32, kernel_size, N), rot_resblock(32, 32, kernel_size, N)] 80 | layers += [rot_resblock(32, 64, kernel_size, N), rot_resblock(64, 64, kernel_size, N)] 81 | layers += 
[rot_resblock(64, 128, kernel_size, N), rot_resblock(128, 128, kernel_size, N)] 82 | layers += [rot_resblock(128, 192, kernel_size, N), rot_resblock(192, 192, kernel_size, N)] 83 | layers += [nn.R2Conv(self.feat_type_hid_out, self.feat_type_out, kernel_size = kernel_size, padding = (kernel_size - 1)//2)] 84 | self.model = torch.nn.Sequential(*layers) 85 | 86 | def forward(self, x): 87 | #BxCxHxW 88 | x = nn.GeometricTensor(x, self.feat_type_in) 89 | out = self.model(x) 90 | return out.tensor 91 | 92 | 93 | ##### Rotational Equivariant Unet ##### 94 | class rot_conv2d(torch.nn.Module): 95 | def __init__(self, input_channels, output_channels, kernel_size, stride, N, activation = True, deconv = False, last_deconv = False): 96 | super(rot_conv2d, self).__init__() 97 | r2_act = gspaces.Rot2dOnR2(N = N) 98 | 99 | feat_type_in = nn.FieldType(r2_act, input_channels*[r2_act.regular_repr]) 100 | feat_type_hid = nn.FieldType(r2_act, output_channels*[r2_act.regular_repr]) 101 | if not deconv: 102 | if activation: 103 | self.layer = nn.SequentialModule( 104 | nn.R2Conv(feat_type_in, feat_type_hid, kernel_size = kernel_size, stride = stride, padding = (kernel_size - 1)//2), 105 | nn.InnerBatchNorm(feat_type_hid), 106 | nn.ReLU(feat_type_hid) 107 | ) 108 | else: 109 | self.layer = nn.R2Conv(feat_type_in, feat_type_hid, kernel_size = kernel_size, stride = stride,padding = (kernel_size - 1)//2) 110 | else: 111 | if last_deconv: 112 | feat_type_in = nn.FieldType(r2_act, input_channels*[r2_act.regular_repr]) 113 | feat_type_hid = nn.FieldType(r2_act, output_channels*[r2_act.irrep(1)]) 114 | self.layer = nn.R2Conv(feat_type_in, feat_type_hid, kernel_size = kernel_size, stride = stride, padding = 0) 115 | else: 116 | self.layer = nn.SequentialModule( 117 | nn.R2Conv(feat_type_in, feat_type_hid, kernel_size = kernel_size, stride = stride, padding = 0), 118 | nn.InnerBatchNorm(feat_type_hid), 119 | nn.ReLU(feat_type_hid) 120 | ) 121 | 122 | def forward(self, x): 123 | return self.layer(x) 124 | 125 | class rot_deconv2d(torch.nn.Module): 126 | def __init__(self, input_channels, output_channels, N, last_deconv = False): 127 | super(rot_deconv2d, self).__init__() 128 | self.conv2d = rot_conv2d(input_channels = input_channels, output_channels = output_channels, kernel_size = 4, 129 | activation = True, stride = 1, N = N, deconv = True, last_deconv = last_deconv) 130 | r2_act = gspaces.Rot2dOnR2(N = N) 131 | self.feat_type = nn.FieldType(r2_act, input_channels*[r2_act.regular_repr]) 132 | 133 | def pad(self, x): 134 | new_x = torch.zeros(x.shape[0], x.shape[1], x.shape[2]*2 + 3, x.shape[3]*2 + 3) 135 | new_x[:,:,:-3,:-3][:,:,::2,::2] = x 136 | new_x[:,:,:-3,:-3][:,:,1::2,1::2] = x 137 | new_x = nn.GeometricTensor(new_x, self.feat_type) 138 | return new_x 139 | 140 | def forward(self, x): 141 | out = self.pad(x).to(device) 142 | return self.conv2d(out) 143 | 144 | class Unet_Rot(torch.nn.Module): 145 | def __init__(self, input_frames, output_frames, kernel_size, N): 146 | super(Unet_Rot, self).__init__() 147 | r2_act = gspaces.Rot2dOnR2(N = N) 148 | self.feat_type_in = nn.FieldType(r2_act, input_frames*[r2_act.irrep(1)]) 149 | self.feat_type_in_hid = nn.FieldType(r2_act, 32*[r2_act.regular_repr]) 150 | self.feat_type_hid_out = nn.FieldType(r2_act, (16 + input_frames)*[r2_act.irrep(1)]) 151 | self.feat_type_out = nn.FieldType(r2_act, output_frames*[r2_act.irrep(1)]) 152 | 153 | self.conv1 = nn.SequentialModule( 154 | nn.R2Conv(self.feat_type_in, self.feat_type_in_hid, kernel_size = kernel_size, stride = 2, padding = 
(kernel_size - 1)//2), 155 | nn.InnerBatchNorm(self.feat_type_in_hid), 156 | nn.ReLU(self.feat_type_in_hid) 157 | ) 158 | 159 | self.conv2 = rot_conv2d(32, 64, kernel_size = kernel_size, stride = 1, N = N) 160 | self.conv2_1 = rot_conv2d(64, 64, kernel_size = kernel_size, stride = 1, N = N) 161 | self.conv3 = rot_conv2d(64, 128, kernel_size = kernel_size, stride = 2, N = N) 162 | self.conv3_1 = rot_conv2d(128, 128, kernel_size = kernel_size, stride = 1, N = N) 163 | self.conv4 = rot_conv2d(128, 256, kernel_size = kernel_size, stride = 2, N = N) 164 | self.conv4_1 = rot_conv2d(256, 256, kernel_size = kernel_size, stride = 1, N = N) 165 | 166 | self.deconv3 = rot_deconv2d(256, 64, N) 167 | self.deconv2 = rot_deconv2d(192, 32, N) 168 | self.deconv1 = rot_deconv2d(96, 16, N, last_deconv = True) 169 | 170 | 171 | self.output_layer = nn.R2Conv(self.feat_type_hid_out, self.feat_type_out, kernel_size = kernel_size, padding = (kernel_size - 1)//2) 172 | 173 | 174 | def forward(self, x): 175 | 176 | x = nn.GeometricTensor(x, self.feat_type_in) 177 | out_conv1 = self.conv1(x) 178 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 179 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 180 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 181 | 182 | out_deconv3 = self.deconv3(out_conv4.tensor) 183 | concat3 = torch.cat((out_conv3.tensor, out_deconv3.tensor), 1) 184 | out_deconv2 = self.deconv2(concat3) 185 | concat2 = torch.cat((out_conv2.tensor, out_deconv2.tensor), 1) 186 | out_deconv1 = self.deconv1(concat2) 187 | 188 | concat0 = torch.cat((x.tensor, out_deconv1.tensor), 1) 189 | concat0 = nn.GeometricTensor(concat0, self.feat_type_hid_out) 190 | out = self.output_layer(concat0) 191 | return out.tensor 192 | -------------------------------------------------------------------------------- /models/model_scale.py: -------------------------------------------------------------------------------- 1 | """ 2 | Scale Equivariant ResNet and U-net 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import math 10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 11 | 12 | class scale_conv2d(nn.Module): 13 | def __init__(self, 14 | in_channels, 15 | out_channels, 16 | kernel_size, 17 | l = 3, # Number of levels of input 18 | sout = 5, # Number of scales we model in the convolution layer. 19 | activation = True, # If add the activation function at the end 20 | stride = 1, 21 | deconv = False): 22 | super(scale_conv2d, self).__init__() 23 | self.out_channels= out_channels 24 | self.in_channels = in_channels 25 | self.l = l 26 | self.sout = sout 27 | self.activation = activation 28 | self.kernel_size = kernel_size 29 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 30 | weight_shape = (out_channels, l, 2, in_channels//2, kernel_size, kernel_size) # The shape of scale equivariant conv2d kernels 31 | self.stdv = math.sqrt(1. / (kernel_size * kernel_size * in_channels * l)) 32 | self.weights = nn.Parameter(torch.Tensor(*weight_shape)) 33 | self.reset_parameters() 34 | self.stride = stride 35 | self.deconv = deconv 36 | 37 | 38 | def reset_parameters(self): 39 | self.weights.data.uniform_(-self.stdv, self.stdv) 40 | if self.bias is not None: 41 | self.bias.data.fill_(0) 42 | 43 | def shrink_kernel(self, kernel, up_scale): 44 | """ 45 | Shrink the kernel via boundary padding and grid sampling. 
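        Roughly: the kernel is zero-padded according to the scaling law and
        then resampled with grid_sample on a rescaled grid, which yields a
        spatially shrunk copy of the kernel.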
46 | """ 47 | up_scale = torch.tensor(up_scale).float() 48 | # boundary padding based on the scaling law 49 | pad_in = (torch.ceil(up_scale**2).int())*((kernel.shape[2]-1)//2) 50 | pad_h = (torch.ceil(up_scale).int())*((kernel.shape[3]-1)//2) 51 | pad_w = (torch.ceil(up_scale).int())*((kernel.shape[4]-1)//2) 52 | padded_kernel = F.pad(kernel, (pad_w, pad_w, pad_h, pad_h, pad_in, pad_in)) 53 | delta = up_scale%1 54 | 55 | if delta == 0: 56 | shrink_factor = 1 57 | else: 58 | # shrink_factor for coordinates. 59 | shrink_factor = (((kernel.shape[4]-1))/(padded_kernel.shape[-1]-1)*(up_scale+1)) 60 | 61 | # Adjustment to deal with weird filtering on the grid sample function. 62 | shrink_factor = 1.5*(shrink_factor-0.5)**3 + 0.57 63 | 64 | grid = torch.meshgrid(torch.linspace(-1, 1, kernel.shape[2])*(shrink_factor**2), 65 | torch.linspace(-1, 1, kernel.shape[3])*shrink_factor, 66 | torch.linspace(-1, 1, kernel.shape[4])*shrink_factor) 67 | 68 | grid = torch.cat([grid[2].unsqueeze(0).unsqueeze(-1), 69 | grid[1].unsqueeze(0).unsqueeze(-1), 70 | grid[0].unsqueeze(0).unsqueeze(-1)], dim = -1).repeat(kernel.shape[0],1,1,1,1) 71 | 72 | new_kernel = F.grid_sample(padded_kernel, grid.to(device)) 73 | if kernel.shape[-1] - 2*up_scale > 0: 74 | new_kernel = new_kernel * (kernel.shape[-1]**2/((kernel.shape[-1] - 2*up_scale)**2 + 0.01)) 75 | return new_kernel 76 | 77 | def dilate_kernel(self, kernel, dilation): 78 | """ 79 | upscale the kernel via inside padding and grid sampling. 80 | """ 81 | if dilation == 0: 82 | return kernel 83 | # inside padding based on the scaling law 84 | dilation = torch.tensor(dilation).float() 85 | delta = dilation%1 86 | 87 | d_in = torch.ceil(dilation**2).int() 88 | new_in = kernel.shape[2] + (kernel.shape[2]-1)*d_in 89 | 90 | d_h = torch.ceil(dilation).int() 91 | new_h = kernel.shape[3] + (kernel.shape[3]-1)*d_h 92 | 93 | d_w = torch.ceil(dilation).int() 94 | new_w = kernel.shape[4] + (kernel.shape[4]-1)*d_h 95 | 96 | new_kernel = torch.zeros(kernel.shape[0], kernel.shape[1], new_in, new_h, new_w) 97 | new_kernel[:,:,::(d_in+1),::(d_h+1), ::(d_w+1)] = kernel 98 | dilate_factor = 1 99 | 100 | new_kernel = F.pad(new_kernel, ((kernel.shape[4]-1)//2, (kernel.shape[4]-1)//2)*3) 101 | 102 | dilate_factor = (new_kernel.shape[-1] - 1 - (kernel.shape[4]-1)*(delta))/(new_kernel.shape[-1] - 1) 103 | 104 | grid = torch.meshgrid(torch.linspace(-1, 1, new_in)*(dilate_factor**2), 105 | torch.linspace(-1, 1, new_h)*dilate_factor, 106 | torch.linspace(-1, 1, new_w)*dilate_factor) 107 | 108 | grid = torch.cat([grid[2].unsqueeze(0).unsqueeze(-1), 109 | grid[1].unsqueeze(0).unsqueeze(-1), 110 | grid[0].unsqueeze(0).unsqueeze(-1)], dim = -1).repeat(kernel.shape[0],1,1,1,1) 111 | 112 | new_kernel = F.grid_sample(new_kernel, grid) 113 | 114 | return new_kernel[:,:,-kernel.shape[2]:] 115 | 116 | 117 | def forward(self, xx): 118 | out = [] 119 | for s in range(self.sout): 120 | t = np.minimum(s + self.l, self.sout) 121 | inp = xx[:,s:t].reshape(xx.shape[0], -1, xx.shape[-2], xx.shape[-1]) 122 | w = self.weights[:,:(t-s),:,:,:].reshape(self.out_channels, 2*(t-s), self.in_channels//2, self.kernel_size, self.kernel_size).to(device) 123 | 124 | if (s - self.sout//2) < 0: 125 | new_kernel = self.shrink_kernel(w, (self.sout//2 - s)/2).to(device) 126 | elif (s - self.sout//2) > 0: 127 | new_kernel = self.dilate_kernel(w, (s - self.sout//2)/2).to(device) 128 | else: 129 | new_kernel = w.to(device) 130 | 131 | new_kernel = new_kernel.reshape(self.out_channels, (t-s)*self.in_channels, new_kernel.shape[-2], 
new_kernel.shape[-1]) 132 | 133 | 134 | if self.deconv: 135 | if (s - self.sout//2) > 0: 136 | conv = F.conv2d(F.pad(inp, (1,2,1,2)), new_kernel) 137 | else: 138 | conv = F.conv2d(inp, new_kernel) 139 | else: 140 | conv = F.conv2d(inp, new_kernel, padding = ((new_kernel.shape[-2]-1)//2, (new_kernel.shape[-1]-1)//2), stride = self.stride) 141 | 142 | out.append(conv.unsqueeze(1)) 143 | 144 | out = torch.cat(out, dim = 1) 145 | if self.activation: 146 | out = F.leaky_relu(out) 147 | 148 | return out 149 | 150 | class scale_deconv2d(nn.Module): 151 | def __init__(self, in_channels, out_channels): 152 | super(scale_deconv2d, self).__init__() 153 | self.conv2d = scale_conv2d(in_channels, out_channels, kernel_size = 4, deconv = True) 154 | 155 | def pad(self, xx): 156 | new_xx = torch.zeros(xx.shape[0], xx.shape[1], xx.shape[2], xx.shape[3]*2+3, xx.shape[4]*2+3) 157 | new_xx[:,:,:,:-3,:-3][:,:,:,::2,::2] = xx 158 | return new_xx 159 | 160 | def forward(self, xx): 161 | out = self.pad(xx).to(device) 162 | return self.conv2d(out) 163 | 164 | class scale_resblock(nn.Module): 165 | def __init__(self, in_channels, hidden_dim, kernel_size, skip = True): 166 | super(scale_resblock, self).__init__() 167 | self.layer1 = scale_conv2d(in_channels = in_channels, out_channels = hidden_dim, kernel_size = kernel_size) 168 | self.layer2 = scale_conv2d(in_channels = hidden_dim, out_channels = hidden_dim, kernel_size = kernel_size) 169 | self.input_channels = in_channels 170 | self.hidden_dim = hidden_dim 171 | if in_channels != hidden_dim: 172 | self.upscale = scale_conv2d(in_channels = in_channels, out_channels = hidden_dim, kernel_size = kernel_size, activation = False) 173 | 174 | def forward(self, xx): 175 | out = self.layer1(xx) 176 | if self.input_channels != self.hidden_dim: 177 | out = self.layer2(out) + self.upscale(xx) 178 | else: 179 | out = self.layer2(out) + xx 180 | return out 181 | 182 | class ResNet_Scale(nn.Module): 183 | def __init__(self, input_channels, output_channels, kernel_size): 184 | super(ResNet_Scale, self).__init__() 185 | self.input_layer = scale_conv2d(out_channels = 32, in_channels = input_channels, kernel_size = kernel_size) 186 | layers = [self.input_layer] 187 | layers += [scale_resblock(32, 32, kernel_size, True), scale_resblock(32, 32, kernel_size, True)] 188 | layers += [scale_resblock(32, 64, kernel_size, False), scale_resblock(64, 64, kernel_size, True)] 189 | layers += [scale_resblock(64, 128, kernel_size, False), scale_resblock(128, 128, kernel_size, True)] 190 | layers += [scale_resblock(128, 128, kernel_size, True), scale_resblock(128, 128, kernel_size, True)] 191 | layers += [scale_conv2d(out_channels = output_channels, in_channels = 128, kernel_size = kernel_size, sout = 1, activation = False)] 192 | self.model = nn.Sequential(*layers) 193 | 194 | def forward(self, xx): 195 | out = self.model(xx) 196 | out = out.squeeze(1) 197 | return out 198 | 199 | 200 | class Unet_Scale(nn.Module): 201 | def __init__(self, input_channels, output_channels, kernel_size): 202 | super(Unet_Scale, self).__init__() 203 | self.conv1 = scale_conv2d(input_channels, 32, kernel_size = kernel_size, stride=2) 204 | self.conv2 = scale_conv2d(32, 64, kernel_size = kernel_size, stride=2) 205 | self.conv2_2 = scale_conv2d(64, 64, kernel_size = kernel_size, stride = 1) 206 | self.conv3 = scale_conv2d(64, 128, kernel_size = kernel_size, stride=2) 207 | self.conv3_1 = scale_conv2d(128, 128, kernel_size = kernel_size, stride=1) 208 | self.conv4 = scale_conv2d(128, 256, kernel_size = kernel_size, 
stride=2) 209 | self.conv4_1 = scale_conv2d(256, 256, kernel_size = kernel_size, stride=1) 210 | 211 | self.deconv3 = scale_deconv2d(256, 64) 212 | self.deconv2 = scale_deconv2d(192, 32) 213 | self.deconv1 = scale_deconv2d(96, 16) 214 | self.deconv0 = scale_deconv2d(48, 8) 215 | 216 | self.output_layer = scale_conv2d(8 + input_channels, output_channels, kernel_size=kernel_size, activation = False, sout = 1) 217 | 218 | def forward(self, x): 219 | 220 | out_conv1 = self.conv1(x) 221 | 222 | 223 | out_conv2 = self.conv2_2(self.conv2(out_conv1)) 224 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 225 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 226 | out_deconv3 = self.deconv3(out_conv4) 227 | 228 | concat3 = torch.cat((out_conv3, out_deconv3), 2) 229 | out_deconv2 = self.deconv2(concat3) 230 | 231 | concat2 = torch.cat((out_conv2, out_deconv2), 2) 232 | out_deconv1 = self.deconv1(concat2) 233 | 234 | concat1 = torch.cat((out_conv1, out_deconv1), 2) 235 | out_deconv0 = self.deconv0(concat1) 236 | 237 | concat0 = torch.cat((x.reshape([x.shape[0], x.shape[1], -1, x.shape[4], x.shape[5]]), out_deconv0), 2) 238 | out = self.output_layer(concat0) 239 | out = out.squeeze(1) 240 | return out 241 | 242 | 243 | -------------------------------------------------------------------------------- /models/model_um.py: -------------------------------------------------------------------------------- 1 | """ 2 | Uniform Motion Equivariant ResNet and U-net 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | class um_conv2d(nn.Module): 13 | def __init__(self, 14 | input_channels, 15 | output_channels, 16 | kernel_size, 17 | um_dim = 2, # The number of channels of each frame 18 | activation = True, # whether to use activation functions 19 | stride = 1, 20 | deconv = False # Whether this is used as a deconvolutional layer 21 | ): 22 | """ 23 | Uniform Motion Equivariant 2D Convolutional Layers 24 | """ 25 | super(um_conv2d, self).__init__() 26 | self.activation = activation 27 | self.input_channels = input_channels 28 | self.kernel_size = kernel_size 29 | self.um_dim = um_dim 30 | self.stride = stride 31 | 32 | # Applied to unfolded input tensors (stride == kernel_size) 33 | self.conv2d = nn.Conv2d(input_channels, output_channels, kernel_size, stride = kernel_size, bias = True) 34 | self.pad_size = (kernel_size - 1)//2 35 | self.input_channels = self.input_channels 36 | self.batchnorm = nn.BatchNorm2d(output_channels) 37 | self.deconv = deconv 38 | 39 | 40 | def unfold(self, x): 41 | """ 42 | Extracts sliding local blocks from a batched input tensor. 43 | """ 44 | if not self.deconv: 45 | x = F.pad(x, ((self.pad_size, self.pad_size)*2), mode='replicate') 46 | 47 | # Unfold input tensor 48 | out = F.unfold(x, kernel_size = self.kernel_size) 49 | out = out.reshape(out.shape[0], self.input_channels, self.kernel_size, self.kernel_size, out.shape[-1]) 50 | 51 | # Batch_size x (in_channels x kernel_size x kernel_size) x 64 x 64 52 | out = out.reshape(out.shape[0], self.input_channels, self.kernel_size, self.kernel_size, int(np.sqrt(out.shape[-1])), int(np.sqrt(out.shape[-1]))) 53 | 54 | if self.stride > 1: 55 | return out[:,:,:,:,::self.stride,::self.stride] 56 | return out 57 | 58 | def subtract_mean(self, x): 59 | """ 60 | Shifts the mean of input sliding local blocks to zero. 
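        Returns the zero-mean blocks together with the per-block mean
        velocities; `add_mean` adds these back after the convolution to
        restore the original reference frame.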
61 | """ 62 | # Calculates and Subtracts mean velocity 63 | out = x.reshape(x.shape[0], self.input_channels//self.um_dim, self.um_dim, self.kernel_size, self.kernel_size, x.shape[-2], x.shape[-1]) 64 | avgs = out.mean((1,3,4), keepdim=True) 65 | out = out - avgs 66 | 67 | # Reshape the input 68 | out = out.reshape(out.shape[0], self.input_channels, self.kernel_size, self.kernel_size, x.shape[-2], x.shape[-1]).transpose(2,4).transpose(-1,-2) 69 | out = out.reshape(out.shape[0], self.input_channels, x.shape[-2]*self.kernel_size, x.shape[-1], self.kernel_size) 70 | 71 | # Batch_size x in_channels x (64 x kernel_size) x (64 x kernel_size) 72 | out = out.reshape(out.shape[0], self.input_channels, x.shape[-2]*self.kernel_size, x.shape[-1]*self.kernel_size) 73 | return out, avgs.squeeze(3).squeeze(3) 74 | 75 | 76 | def add_mean(self, out, avgs): 77 | """ 78 | Shifts the mean of the outputs back. 79 | """ 80 | out = out.reshape(out.shape[0], out.shape[1]//self.um_dim, self.um_dim, out.shape[-2], out.shape[-1]) 81 | out = out + avgs 82 | out = out.reshape(out.shape[0], -1, out.shape[-2], out.shape[-1]) 83 | return out 84 | 85 | def forward(self, x, shift_back = True): 86 | x = self.unfold(x) 87 | x, avgs = self.subtract_mean(x) 88 | out = self.conv2d(x) 89 | 90 | if self.activation: 91 | out = F.relu(out) 92 | 93 | if shift_back: 94 | out = self.add_mean(out, avgs) 95 | return out 96 | 97 | 98 | class um_deconv2d(nn.Module): 99 | def __init__(self, input_channels, output_channels): 100 | """ 101 | Uniform Motion Equivariant 2D Transposed Convolutional Layers 102 | """ 103 | super(um_deconv2d, self).__init__() 104 | self.conv2d = um_conv2d(input_channels = input_channels, output_channels = output_channels, kernel_size = 4, um_dim = 2, 105 | activation = True, stride = 1, deconv = True) 106 | 107 | def pad(self, x): 108 | # Add padding inside of the input tensor 109 | # To preserve uniform motion equivariance, we use the input tensor itself as padding instead of zero. 110 | # Because the padding zero would change the original mean of input sliding blocks. 111 | pad_x = torch.zeros(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3]*2) 112 | pad_x[:,:,::2,::2].copy_(x) 113 | pad_x[:,:,1::2,::2].copy_(x) 114 | pad_x[:,:,::2,1::2].copy_(x) 115 | pad_x[:,:,1::2,1::2].copy_(x) 116 | pad_x = F.pad(pad_x, (1,2,1,2), mode='replicate') 117 | return pad_x 118 | 119 | def forward(self, x): 120 | out = self.pad(x).to(device) 121 | return self.conv2d(out) 122 | 123 | 124 | # Uniform Motion Equivariant U_net. 
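
# Illustrative sanity check (not part of the original pipeline): a
# uniform-motion equivariant layer L satisfies L(x + v) == L(x) + v for any
# constant velocity v, because the per-block mean is subtracted before the
# convolution and added back afterwards.
if __name__ == "__main__":
    _layer = um_conv2d(input_channels=2, output_channels=2, kernel_size=3)
    _x = torch.rand(1, 2, 16, 16)
    _v = torch.rand(1, 2, 1, 1)  # constant velocity field
    print(torch.allclose(_layer(_x + _v), _layer(_x) + _v, atol=1e-5))
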
125 | class Unet_UM(nn.Module):
126 |     def __init__(self, input_channels, output_channels, kernel_size):
127 |         super(Unet_UM, self).__init__()
128 |         self.input_channels = input_channels
129 |         self.conv1 = um_conv2d(input_channels, 64, kernel_size = kernel_size, stride=2)
130 |         self.conv2 = um_conv2d(64, 128, kernel_size = kernel_size, stride=2)
131 |         self.conv2_2 = um_conv2d(128, 128, kernel_size = kernel_size, stride = 1)
132 |         self.conv3 = um_conv2d(128, 256, kernel_size = kernel_size, stride=2)
133 |         self.conv3_1 = um_conv2d(256, 256, kernel_size = kernel_size, stride=1)
134 |         self.conv4 = um_conv2d(256, 512, kernel_size = kernel_size, stride=2)
135 |         self.conv4_1 = um_conv2d(512, 512, kernel_size = kernel_size, stride=1)
136 | 
137 |         self.deconv3 = um_deconv2d(512, 128)
138 |         self.deconv2 = um_deconv2d(384, 64)
139 |         self.deconv1 = um_deconv2d(192, 32)
140 |         self.deconv0 = um_deconv2d(96, 16)
141 | 
142 |         self.output_layer = um_conv2d(16 + input_channels, output_channels, kernel_size=kernel_size, activation = False)
143 | 
144 |     def forward(self, x):
145 |         out_conv1 = self.conv1(x)
146 |         out_conv2 = self.conv2_2(self.conv2(out_conv1))
147 |         out_conv3 = self.conv3_1(self.conv3(out_conv2))
148 |         out_conv4 = self.conv4_1(self.conv4(out_conv3))
149 | 
150 |         out_deconv3 = self.deconv3(out_conv4)
151 |         concat3 = torch.cat((out_conv3, out_deconv3), 1)
152 |         out_deconv2 = self.deconv2(concat3)
153 |         concat2 = torch.cat((out_conv2, out_deconv2), 1)
154 |         out_deconv1 = self.deconv1(concat2)
155 |         concat1 = torch.cat((out_conv1, out_deconv1), 1)
156 |         out_deconv0 = self.deconv0(concat1)
157 |         concat0 = torch.cat((x, out_deconv0), 1)
158 |         out = self.output_layer(concat0)
159 |         return out
160 | 
161 | 
162 | # Uniform Motion Equivariant ResNet.
163 | class um_resblock(nn.Module):
164 |     def __init__(self, input_channels, hidden_dim, kernel_size):
165 |         super(um_resblock, self).__init__()
166 |         self.layer1 = um_conv2d(input_channels, hidden_dim, kernel_size)
167 |         self.layer2 = um_conv2d(hidden_dim, hidden_dim, kernel_size)
168 |         if input_channels != hidden_dim:
169 |             self.upscale = um_conv2d(input_channels, hidden_dim, kernel_size)
170 | 
171 |         self.input_channels = input_channels
172 |         self.hidden_dim = hidden_dim
173 | 
174 |     def forward(self, x):
175 |         out = self.layer1(x)
176 |         if self.input_channels == self.hidden_dim:
177 |             out = self.layer2(out, False) + x
178 |         else:
179 |             out = self.layer2(out, False) + self.upscale(x)
180 | 
181 |         return out
182 | 
183 | class ResNet_UM(nn.Module):
184 |     def __init__(self, input_channels, output_channels, kernel_size):
185 |         super(ResNet_UM, self).__init__()
186 |         layers = [um_resblock(input_channels, 64, kernel_size), um_resblock(64, 64, kernel_size)]
187 |         layers += [um_resblock(64, 128, kernel_size), um_resblock(128, 128, kernel_size)]
188 |         layers += [um_resblock(128, 256, kernel_size), um_resblock(256, 256, kernel_size)]
189 |         layers += [um_resblock(256, 512, kernel_size), um_resblock(512, 512, kernel_size)]
190 |         layers += [um_conv2d(512, output_channels, kernel_size = kernel_size, activation = False)]
191 |         self.model = nn.Sequential(*layers)
192 | 
193 |     def forward(self, x):
194 |         out = self.model(x)
195 |         return out
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # Tested with Python 3.9.16 and CUDA 11.7
2 | numpy==1.24.2
3 | pandas==1.5.3
4 | torch==1.11.0
5 | opencv-python==4.7.0.72
6 | torchvision==0.12.0
7 | e2cnn
8 | 
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # RBC Data
3 | python3 run_model.py --dataset=RBC --architecture=ResNet --symmetry=None --output_length=3 --learning_rate=0.001
4 | python3 run_model.py --dataset=RBC --architecture=Unet --symmetry=None --output_length=4 --learning_rate=0.001
5 | python3 run_model.py --dataset=RBC --architecture=ResNet --symmetry=UM --output_length=3 --learning_rate=0.001
6 | python3 run_model.py --dataset=RBC --architecture=Unet --symmetry=UM --output_length=4 --learning_rate=0.001
7 | python3 run_model.py --dataset=RBC --architecture=ResNet --symmetry=Rot --output_length=3 --learning_rate=0.001
8 | python3 run_model.py --dataset=RBC --architecture=Unet --symmetry=Rot --output_length=4 --learning_rate=0.001
9 | python3 run_model.py --dataset=RBC --architecture=ResNet --symmetry=Mag --output_length=3 --learning_rate=0.005
10 | python3 run_model.py --dataset=RBC --architecture=Unet --symmetry=Mag --output_length=4 --learning_rate=0.005
11 | python3 run_model.py --dataset=RBC --architecture=ResNet --symmetry=Scale --output_length=3 --learning_rate=0.0001
12 | python3 run_model.py --dataset=RBC --architecture=Unet --symmetry=Scale --output_length=4 --learning_rate=0.0001
13 | 
14 | # Ocean Currents
15 | # To reproduce the numbers in the README table, please run each command five times with different random seeds.
16 | python3 run_model.py --dataset=Ocean --architecture=ResNet --symmetry=None --output_length=4 --learning_rate=0.001 --input_length=24 --batch_size=32 --decay_rate=0.9 --seed=0
17 | python3 run_model.py --dataset=Ocean --architecture=Unet --symmetry=None --output_length=5 --learning_rate=0.001 --input_length=21 --batch_size=16 --decay_rate=0.9 --seed=0
18 | python3 run_model.py --dataset=Ocean --architecture=ResNet --symmetry=Rot --output_length=3 --learning_rate=0.001 --input_length=21 --batch_size=16 --decay_rate=0.9 --seed=0
19 | python3 run_model.py --dataset=Ocean --architecture=Unet --symmetry=Rot --output_length=3 --learning_rate=0.001 --input_length=21 --batch_size=64 --decay_rate=0.9 --seed=0
20 | python3 run_model.py --dataset=Ocean --architecture=ResNet --symmetry=UM --output_length=5 --learning_rate=0.001 --input_length=24 --batch_size=32 --decay_rate=0.9 --seed=0
21 | python3 run_model.py --dataset=Ocean --architecture=Unet --symmetry=UM --output_length=5 --learning_rate=0.001 --input_length=24 --batch_size=32 --decay_rate=0.9 --seed=0
22 | python3 run_model.py --dataset=Ocean --architecture=ResNet --symmetry=Mag --output_length=5 --learning_rate=0.001 --input_length=21 --batch_size=32 --decay_rate=0.9 --seed=0
23 | python3 run_model.py --dataset=Ocean --architecture=Unet --symmetry=Mag --output_length=6 --learning_rate=0.001 --input_length=21 --batch_size=32 --decay_rate=0.9 --seed=0
24 | python3 run_model.py --dataset=Ocean --architecture=ResNet --symmetry=Scale --output_length=3 --learning_rate=0.0001 --input_length=26 --batch_size=32 --decay_rate=0.9 --kernel_size=5 --seed=0
25 | python3 run_model.py --dataset=Ocean --architecture=Unet --symmetry=Scale --output_length=3 --learning_rate=0.0001 --input_length=26 --batch_size=16 --decay_rate=0.9 --kernel_size=5 --seed=0
26 | 
--------------------------------------------------------------------------------
/run_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time, copy
3 | import
argparse 4 | import torch 5 | import numpy as np 6 | import torch.nn as nn 7 | import random 8 | import torch.optim as optim 9 | from torch.utils import data 10 | import torch.nn.functional as F 11 | from evaluation import spectrum_band 12 | from models import ResNet, Unet, ResNet_UM, Unet_UM, ResNet_Mag, Unet_Mag, ResNet_Rot, Unet_Rot, ResNet_Scale, Unet_Scale 13 | import matplotlib.pyplot as plt 14 | from utils import train_epoch, eval_epoch, test_epoch, Dataset, get_lr, train_epoch_scale, eval_epoch_scale, test_epoch_scale, Dataset_scale 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | 18 | ###### Hyperparameter ###### 19 | 20 | parser = argparse.ArgumentParser(description='Deep Equivariant Dynamics Models') 21 | parser.add_argument('--dataset', type=str, required=False, default="RBC", help='RBC or Ocean') 22 | parser.add_argument('--kernel_size', type=int, required=False, default="3", help='convolution kernel size') 23 | parser.add_argument('--symmetry', type=str, required=False, default="UM", help='None, UM, Rot, Mag, Scale') 24 | parser.add_argument('--architecture', type=str, required=False, default="ResNet", help='ResNet or Unet') 25 | parser.add_argument('--output_length', type=int, required=False, default="4", help='number of prediction losses used for backpropagation') 26 | parser.add_argument('--input_length', type=int, required=False, default="24", help='input length') 27 | parser.add_argument('--batch_size', type=int, required=False, default="16", help='batch size') 28 | parser.add_argument('--num_epoch', type=int, required=False, default="100", help='maximum number of epochs') 29 | parser.add_argument('--learning_rate', type=float, required=False, default="0.001", help='learning rate') 30 | parser.add_argument('--decay_rate', type=float, required=False, default="0.95", help='learning decay rate') 31 | parser.add_argument('--seed', type=int, required=False, default="0", help='random seed') 32 | args = parser.parse_args() 33 | 34 | 35 | random.seed(args.seed) # python random generator 36 | np.random.seed(args.seed) # numpy random generator 37 | 38 | torch.manual_seed(args.seed) 39 | torch.cuda.manual_seed_all(args.seed) 40 | 41 | torch.backends.cudnn.deterministic = True 42 | torch.backends.cudnn.benchmark = False 43 | 44 | 45 | symmetry = args.symmetry 46 | model_name = args.architecture + "_" + args.symmetry 47 | num_epoch = args.num_epoch 48 | learning_rate = args.learning_rate # 0.0005 for mag_equ resnet; 0.0001 for scale_equ resnet 49 | batch_size = args.batch_size 50 | input_length = args.input_length 51 | train_output_length = args.output_length # 4 for all Unets 52 | test_output_length = 10 53 | kernel_size = args.kernel_size 54 | lr_decay = args.decay_rate 55 | ########################### 56 | 57 | ########## Data ########### 58 | if args.dataset == "RBC": 59 | train_direc = "data_64/sample_" 60 | valid_direc = "data_64/sample_" 61 | train_indices = list(range(0, 6000)) 62 | valid_indices = list(range(6000, 8000)) 63 | 64 | # test on future time steps 65 | test_future_direc = "data_64/sample_" 66 | test_future_indices = list(range(8000, 10000)) 67 | 68 | # test on data applied with symmetry transformations 69 | test_domain_direc = "data_64/sample_" if args.symmetry == "None" else "data_" + symmetry.lower() + "/sample_" 70 | print(test_domain_direc) 71 | test_domain_indices = list(range(8000, 10000)) 72 | 73 | elif args.dataset == "Ocean": 74 | train_direc = "ocean_train/sample_" 75 | valid_direc = "ocean_train/sample_" 76 | 
########## Data ###########
if args.dataset == "RBC":
    train_direc = "data_64/sample_"
    valid_direc = "data_64/sample_"
    train_indices = list(range(0, 6000))
    valid_indices = list(range(6000, 8000))

    # test on future time steps
    test_future_direc = "data_64/sample_"
    test_future_indices = list(range(8000, 10000))

    # test on data transformed by the chosen symmetry group
    test_domain_direc = "data_64/sample_" if args.symmetry == "None" else "data_" + symmetry.lower() + "/sample_"
    print(test_domain_direc)
    test_domain_indices = list(range(8000, 10000))

elif args.dataset == "Ocean":
    train_direc = "ocean_train/sample_"
    valid_direc = "ocean_train/sample_"
    train_indices = list(range(0, 8000))
    valid_indices = list(range(8000, 10000))

    # test on future time steps
    test_future_direc = "ocean_train/sample_"
    test_future_indices = list(range(10000, 12000))

    # test on data from a different spatial domain
    test_domain_direc = "ocean_test/sample_"
    test_domain_indices = list(range(0, 2000))

else:
    raise ValueError("Invalid dataset name entered!")

if symmetry != "Scale":
    train_set = Dataset(train_indices, input_length, 30, train_output_length, train_direc, True)
    valid_set = Dataset(valid_indices, input_length, 30, train_output_length, valid_direc, True)
    test_future_set = Dataset(test_future_indices, input_length, 30, test_output_length, test_future_direc, True)
    test_domain_set = Dataset(test_domain_indices, input_length, 30, test_output_length, test_domain_direc, True)
else:
    # use Dataset_scale for scale-equivariant models
    train_set = Dataset_scale(train_indices, input_length, 30, train_output_length, train_direc)
    valid_set = Dataset_scale(valid_indices, input_length, 30, train_output_length, valid_direc)
    test_future_set = Dataset_scale(test_future_indices, input_length, 30, test_output_length, test_future_direc)
    test_domain_set = Dataset_scale(test_domain_indices, input_length, 30, test_output_length, test_domain_direc)

train_loader = data.DataLoader(train_set, batch_size = batch_size, shuffle = True, num_workers = 8, pin_memory = True)
valid_loader = data.DataLoader(valid_set, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
test_future_loader = data.DataLoader(test_future_set, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
test_domain_loader = data.DataLoader(test_domain_set, batch_size = batch_size, shuffle = False, num_workers = 8, pin_memory = True)
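# How the windows above are sliced (see utils.Dataset): with mid = 30,
#     x = sample[30 - input_length : 30]    # e.g. frames 6..29 for input_length = 24
#     y = sample[30 : 30 + output_length]   # e.g. frames 30..39 at test time
# and stack_x = True reshapes x to (input_length * 2, h, w), so the networks
# always receive input_length * 2 stacked velocity channels.
# (Illustrative numbers assume the 50-frame, 64 x 64 RBC samples written by
# data_prep.py.)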
save_name = args.dataset + "_model{}_bz{}_inp{}_pred{}_lr{}_decay{}_kernel{}_seed{}".format(model_name,
                                                                                            batch_size,
                                                                                            input_length,
                                                                                            train_output_length,
                                                                                            learning_rate,
                                                                                            lr_decay,
                                                                                            kernel_size,
                                                                                            args.seed)

print(save_name)
####### Select Model #######
if model_name == "ResNet_UM":
    model = nn.DataParallel(ResNet_UM(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device))
elif model_name == "Unet_UM":
    model = nn.DataParallel(Unet_UM(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device))
elif model_name == "ResNet_Rot":
    model = nn.DataParallel(ResNet_Rot(input_frames = input_length, output_frames = 1, kernel_size = kernel_size, N = 8).to(device))
elif model_name == "Unet_Rot":
    model = nn.DataParallel(Unet_Rot(input_frames = input_length, output_frames = 1, kernel_size = kernel_size, N = 8).to(device))
elif model_name == "ResNet_Mag":
    model = nn.DataParallel(ResNet_Mag(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device))
elif model_name == "Unet_Mag":
    model = nn.DataParallel(Unet_Mag(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device))
elif model_name == "ResNet_Scale":
    model = nn.DataParallel(ResNet_Scale(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device))
elif model_name == "Unet_Scale":
    model = nn.DataParallel(Unet_Scale(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device))
elif model_name == "ResNet_None":
    model = nn.DataParallel(ResNet(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device))
elif model_name == "Unet_None":
    model = nn.DataParallel(Unet(input_channels = input_length*2, output_channels = 2, kernel_size = kernel_size).to(device))
else:
    raise ValueError("Invalid model name entered!")


optimizer = torch.optim.Adam(model.parameters(), learning_rate, betas = (0.9, 0.999), weight_decay = 4e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma = lr_decay)
loss_fun = torch.nn.MSELoss()

min_rmse = 1e6
train_rmse = []
valid_rmse = []
test_rmse = []

for i in range(num_epoch):
    start = time.time()

    if symmetry != "Scale":
        model.train()
        train_rmse.append(train_epoch(train_loader, model, optimizer, loss_fun))
        model.eval()
        rmse, _, _ = eval_epoch(valid_loader, model, loss_fun)
        valid_rmse.append(rmse)
    else:
        model.train()
        train_rmse.append(train_epoch_scale(train_loader, model, optimizer, loss_fun))
        model.eval()
        rmse, _, _ = eval_epoch_scale(valid_loader, model, loss_fun)
        valid_rmse.append(rmse)

    if valid_rmse[-1] < min_rmse:
        min_rmse = valid_rmse[-1]
        # deepcopy so the snapshot is not an alias of the live model,
        # which keeps training after the best epoch
        best_model = copy.deepcopy(model)
    end = time.time()

    # early stopping, but train for at least 50 epochs
    if (len(train_rmse) > 50 and np.mean(valid_rmse[-5:]) >= np.mean(valid_rmse[-10:-5])):
        break
    print("Epoch {} | T: {:0.2f} | Train RMSE: {:0.3f} | Valid RMSE: {:0.3f}".format(i + 1, (end - start) / 60, train_rmse[-1], valid_rmse[-1]))
    scheduler.step()


if symmetry != "Scale":
    test_future_rmse, test_future_preds, test_future_trues, test_future_loss_curve = test_epoch(test_future_loader, best_model, loss_fun)
    test_domain_rmse, test_domain_preds, test_domain_trues, test_domain_loss_curve = test_epoch(test_domain_loader, best_model, loss_fun)
else:
    test_future_rmse, test_future_preds, test_future_trues, test_future_loss_curve = test_epoch_scale(test_future_loader, best_model, loss_fun)
    test_domain_rmse, test_domain_preds, test_domain_trues, test_domain_loss_curve = test_epoch_scale(test_domain_loader, best_model, loss_fun)

# Compute Energy Spectrum Errors
test_future_ese = np.sqrt(np.mean((spectrum_band(test_future_preds) - spectrum_band(test_future_trues))**2))
test_domain_ese = np.sqrt(np.mean((spectrum_band(test_domain_preds) - spectrum_band(test_domain_trues))**2))
print("Model: {} | Symmetry: {} | Future RMSE: {:0.3f} | Future ESE: {:0.3f} | Domain RMSE: {:0.3f} | Domain ESE: {:0.3f}".format(args.architecture,
                                                                                                                                  args.symmetry,
                                                                                                                                  test_future_rmse,
                                                                                                                                  test_future_ese,
                                                                                                                                  test_domain_rmse,
                                                                                                                                  test_domain_ese))

torch.save({"test_future": [test_future_rmse, test_future_ese, test_future_preds[::10], test_future_trues[::10]],
            "test_domain": [test_domain_rmse, test_domain_ese, test_domain_preds[::10], test_domain_trues[::10]]},
           save_name + ".pt")
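# A saved result file can be reloaded for analysis, e.g. (the file name below
# is one hypothetical instance of the save_name pattern above):
#     results = torch.load("Ocean_modelResNet_Rot_bz16_inp21_pred3_lr0.001_decay0.9_kernel3_seed0.pt")
#     rmse, ese, preds, trues = results["test_future"]
# where preds and trues hold every 10th test sequence as numpy arrays.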
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.utils import data
import itertools
import re
import random
import time
import math
from scipy.ndimage import gaussian_filter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

class Dataset(data.Dataset):
    def __init__(self, indices, input_length, mid, output_length, direc, stack_x):
        self.input_length = input_length
        self.mid = mid
        self.output_length = output_length
        self.stack_x = stack_x
        self.direc = direc
        self.list_IDs = indices

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        ID = self.list_IDs[index]
        # load the sample once and slice input and target windows from it
        sample = torch.load(self.direc + str(ID) + ".pt")
        y = sample[self.mid:(self.mid + self.output_length)]
        if self.stack_x:
            x = sample[(self.mid - self.input_length):self.mid].reshape(-1, y.shape[-2], y.shape[-1])
        else:
            x = sample[(self.mid - self.input_length):self.mid]
        return x.float(), y.float()

def train_epoch(train_loader, model, optimizer, loss_function):
    train_mse = []
    for xx, yy in train_loader:
        xx = xx.to(device)
        yy = yy.to(device)
        loss = 0
        ims = []
        for y in yy.transpose(0, 1):
            im = model(xx)
            # slide the window: drop the oldest frame (2 channels), append the prediction
            xx = torch.cat([xx[:, 2:], im], 1)
            loss += loss_function(im, y)
        train_mse.append(loss.item() / yy.shape[1])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_mse = round(np.sqrt(np.mean(train_mse)), 5)
    return train_mse



def eval_epoch(valid_loader, model, loss_function):
    valid_mse = []
    preds = []
    trues = []
    with torch.no_grad():
        for xx, yy in valid_loader:
            xx = xx.to(device)
            yy = yy.to(device)
            loss = 0
            ims = []
            for y in yy.transpose(0, 1):
                im = model(xx)
                xx = torch.cat([xx[:, 2:], im], 1)
                loss += loss_function(im, y)
                ims.append(im.unsqueeze(1).cpu().data.numpy())

            ims = np.concatenate(ims, axis = 1)
            preds.append(ims)
            trues.append(yy.cpu().data.numpy())
            valid_mse.append(loss.item() / yy.shape[1])
        preds = np.concatenate(preds, axis = 0)
        trues = np.concatenate(trues, axis = 0)
        valid_mse = round(np.sqrt(np.mean(valid_mse)), 5)
    return valid_mse, preds, trues

def test_epoch(valid_loader, model, loss_function):
    valid_mse = []
    preds = []
    trues = []
    with torch.no_grad():
        loss_curve = []
        for xx, yy in valid_loader:
            xx = xx.to(device)
            yy = yy.to(device)
            loss = 0
            ims = []

            for y in yy.transpose(0, 1):
                im = model(xx)
                xx = torch.cat([xx[:, 2:], im], 1)
                mse = loss_function(im, y)
                loss += mse
                loss_curve.append(mse.item())
                ims.append(im.unsqueeze(1).cpu().data.numpy())

            ims = np.concatenate(ims, axis = 1)
            preds.append(ims)
            trues.append(yy.cpu().data.numpy())
            valid_mse.append(loss.item() / yy.shape[1])

        loss_curve = np.array(loss_curve).reshape(-1, yy.shape[1])
        preds = np.concatenate(preds, axis = 0)
        trues = np.concatenate(trues, axis = 0)
        # report RMSE, consistent with train_epoch and eval_epoch
        valid_mse = round(np.sqrt(np.mean(valid_mse)), 5)
        loss_curve = np.sqrt(np.mean(loss_curve, axis = 0))
    return valid_mse, preds, trues, loss_curve
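# Shape bookkeeping for the rollout above (illustrative, assuming the 64 x 64
# RBC setting): xx starts as (batch, input_length * 2, 64, 64); each model call
# returns one predicted frame im of shape (batch, 2, 64, 64), and
# torch.cat([xx[:, 2:], im], 1) shifts the window forward by one frame while
# keeping exactly input_length * 2 channels.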


### Functions for scale-equivariant models ##########

class gaussian_blur(nn.Module):
    def __init__(self, size, sigma, dim, channels):
        super(gaussian_blur, self).__init__()
        self.kernel = self.gaussian_kernel(size, sigma, dim, channels).to(device)

    def gaussian_kernel(self, size, sigma, dim, channels):

        kernel_size = 2*size + 1
        kernel_size = [kernel_size] * dim
        kernel = 1
        meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])

        # accumulate the separable Gaussian density along each dimension
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-((mgrid - mean) ** 2) / (2 * std ** 2))

        # make sure the values in the gaussian kernel sum to 1
        kernel = kernel / torch.sum(kernel)

        # reshape to a depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(1, channels, 1, 1, 1)

        return kernel

    def forward(self, xx):
        xx = xx.reshape(xx.shape[0]*2, 1, xx.shape[2], xx.shape[3], xx.shape[4])
        xx = F.conv3d(xx, self.kernel, padding = (self.kernel.shape[-1]-1)//2)
        return xx.reshape(xx.shape[0]//2, 2, xx.shape[2], xx.shape[3], xx.shape[4])


def blur_input(xx):
    out = []
    for s in np.linspace(-1, 1, 5):
        if s > 0:
            blur = gaussian_blur(size = np.ceil(s), sigma = [s**2, s, s], dim = 3, channels = 1).to(device)
            out.append(blur(xx).unsqueeze(1)*(s+1))
        elif s < 0:
            out.append(xx.unsqueeze(1)*(1/(np.abs(s)+1)))
        else:
            out.append(xx.unsqueeze(1))
    out = torch.cat(out, dim = 1)
    return out

class Dataset_scale(data.Dataset):
    def __init__(self, indices, input_length, mid, output_length, direc):
        self.input_length = input_length
        self.mid = mid
        self.output_length = output_length
        self.direc = direc
        self.list_IDs = indices

    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        ID = self.list_IDs[index]
        # load the sample once and slice input and target windows from it
        sample = torch.load(self.direc + str(ID) + ".pt")
        x = sample[(self.mid - self.input_length):self.mid].transpose(0, 1)
        y = sample[self.mid:(self.mid + self.output_length)].transpose(0, 1)
        return x.float(), y.float()
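# Sketch of what blur_input produces (illustrative shapes): for xx of shape
# (batch, 2, time, h, w) it returns a tensor of shape (batch, 5, 2, time, h, w),
# one channel per scale factor s in {-1, -0.5, 0, 0.5, 1} -- blurred and
# re-weighted for s > 0, rescaled in magnitude for s < 0, and passed through
# unchanged for s = 0.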
# Training functions for scale-equivariant models.
def train_epoch_scale(train_loader, model, optimizer, loss_function):
    train_mse = []
    for xx, yy in train_loader:
        xx = xx.to(device)
        yy = yy.to(device)
        loss = 0
        ims = []
        for i in range(yy.shape[2]):
            blur_xx = blur_input(xx)
            im = model(blur_xx)
            # slide the window along the time axis and append the prediction
            xx = torch.cat([xx[:, :, 1:], im.unsqueeze(2)], 2)
            loss += loss_function(im, yy[:, :, i])
        train_mse.append(loss.item() / yy.shape[2])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_mse = round(np.sqrt(np.mean(train_mse)), 5)
    return train_mse



def eval_epoch_scale(valid_loader, model, loss_function):
    valid_mse = []
    preds = []
    trues = []
    with torch.no_grad():
        for xx, yy in valid_loader:
            xx = xx.to(device)
            yy = yy.to(device)
            loss = 0
            ims = []
            for i in range(yy.shape[2]):
                blur_xx = blur_input(xx)
                im = model(blur_xx)
                xx = torch.cat([xx[:, :, 1:], im.unsqueeze(2)], 2)
                loss += loss_function(im, yy[:, :, i])
                ims.append(im.unsqueeze(2).cpu().data.numpy())

            valid_mse.append(loss.item() / yy.shape[2])
            ims = np.concatenate(ims, axis = 2)
            preds.append(ims)
            trues.append(yy.cpu().data.numpy())
        try:
            preds = np.concatenate(preds, axis = 0)
            trues = np.concatenate(trues, axis = 0)
        except ValueError:
            print("can't concatenate predictions across batches")
        valid_mse = round(np.sqrt(np.mean(valid_mse)), 5)
    return valid_mse, preds, trues

def test_epoch_scale(valid_loader, model, loss_function):
    valid_mse = []
    preds = []
    trues = []
    with torch.no_grad():
        loss_curve = []
        for xx, yy in valid_loader:
            xx = xx.to(device)
            yy = yy.to(device)
            loss = 0
            ims = []

            for i in range(yy.shape[2]):
                # blur the input to build the multiscale stack, exactly as in
                # train_epoch_scale and eval_epoch_scale
                blur_xx = blur_input(xx)
                im = model(blur_xx)
                xx = torch.cat([xx[:, :, 1:], im.unsqueeze(2)], 2)
                mse = loss_function(im, yy[:, :, i])
                loss += mse
                loss_curve.append(mse.item())
                ims.append(im.unsqueeze(2).cpu().data.numpy())

            ims = np.concatenate(ims, axis = 2)
            preds.append(ims)
            trues.append(yy.cpu().data.numpy())
            # average over the time axis (dim 2), as in the other scale functions
            valid_mse.append(loss.item() / yy.shape[2])
        loss_curve = np.array(loss_curve).reshape(-1, yy.shape[2])
        # report RMSE, consistent with the non-scale test_epoch
        valid_mse = round(np.sqrt(np.mean(valid_mse)), 5)
        loss_curve = np.sqrt(np.mean(loss_curve, axis = 0))
    return valid_mse, preds, trues, loss_curve

--------------------------------------------------------------------------------
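run.sh suggests averaging the Ocean results over five random seeds. A minimal sketch of that aggregation, assuming the result files written by run_model.py exist in the working directory (the configuration shown is the Rot ResNet; adjust the name pattern to the runs you actually launched):

```
import numpy as np
import torch

rmses, eses = [], []
for seed in range(5):
    res = torch.load("Ocean_modelResNet_Rot_bz16_inp21_pred3_lr0.001_decay0.9_kernel3_seed{}.pt".format(seed))
    rmse, ese, _, _ = res["test_future"]
    rmses.append(rmse)
    eses.append(ese)

print("RMSE: {:.2f} +/- {:.2f}".format(np.mean(rmses), np.std(rmses)))
print("ESE:  {:.2f} +/- {:.2f}".format(np.mean(eses), np.std(eses)))
```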