├── .github
│   └── workflows
│       └── python-app.yml
├── API
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── dataloader_caltech.cpython-38.pyc
│   │   ├── dataloader_moving_mnist.cpython-38.pyc
│   │   ├── dataloader_taxibj.cpython-38.pyc
│   │   ├── metrics.cpython-38.pyc
│   │   └── recorder.cpython-38.pyc
│   ├── dataloader.py
│   ├── dataloader_caltech.py
│   ├── dataloader_caltech0.py
│   ├── dataloader_moving_mnist.py
│   ├── dataloader_moving_mnist_v2.py
│   ├── dataloader_sevir.py
│   ├── dataloader_taxibj.py
│   ├── metrics.py
│   └── recorder.py
├── DiscreteSTModel
│   └── log.log
├── README.md
├── __pycache__
│   ├── exp_vq.cpython-38.pyc
│   ├── params.cpython-38.pyc
│   └── utils.cpython-38.pyc
├── data
│   └── moving_mnist
│       └── download_mmnist.sh
├── exp_vq.py
├── figure
│   └── mm.png
├── id_estimate.py
├── log.log
├── models
│   ├── Fourier.py
│   ├── PastNet_Model.py
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-38.pyc
│   └── vqvae.ckpt
├── modules
│   ├── DiscreteSTModel_modules.py
│   ├── DiscreteSTModel_modules_BN.py
│   ├── DiscreteSTModel_modules_GN.py
│   ├── Fourier_modules.py
│   ├── STConvEncoderDecoder_modules.py
│   ├── __init__.py
│   └── __pycache__
│       ├── DiscreteSTModel_modules.cpython-38.pyc
│       ├── Fourier_modules.cpython-38.pyc
│       ├── Fourier_modules.cpython-39.pyc
│       ├── STConvEncoderDecoder_modules.cpython-38.pyc
│       ├── __init__.cpython-38.pyc
│       └── __init__.cpython-39.pyc
├── params.py
├── results
│   └── DiscreteSTModel
│       └── log.log
├── train_model.py
└── utils.py

/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
 3 | 
 4 | name: Python application
 5 | 
 6 | on:
 7 |   push:
 8 |     branches: [ "main" ]
 9 |   pull_request:
10 |     branches: [ "main" ]
11 | 
12 | permissions:
13 |   contents: read
14 | 
15 | jobs:
16 |   build:
17 | 
18 |     runs-on: ubuntu-latest
19 | 
20 |     steps:
21 |     - uses: actions/checkout@v3
22 |     - name: Set up Python 3.10
23 |       uses: actions/setup-python@v3
24 |       with:
25 |         python-version: "3.10"
26 |     - name: Install dependencies
27 |       run: |
28 |         python -m pip install --upgrade pip
29 |         pip install flake8 pytest
30 |         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
31 |     - name: Lint with flake8
32 |       run: |
33 |         # stop the build if there are Python syntax errors or undefined names
34 |         flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
35 |         # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
36 |         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
37 |     - name: Test with pytest
38 |       run: |
39 |         pytest
40 | 
--------------------------------------------------------------------------------
/API/__init__.py:
--------------------------------------------------------------------------------
 1 | from .dataloader import load_data
 2 | from .metrics import metric
 3 | from .recorder import Recorder
--------------------------------------------------------------------------------
/API/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/API/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/API/__pycache__/dataloader_caltech.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/API/__pycache__/dataloader_caltech.cpython-38.pyc
--------------------------------------------------------------------------------
/API/__pycache__/dataloader_moving_mnist.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/API/__pycache__/dataloader_moving_mnist.cpython-38.pyc
--------------------------------------------------------------------------------
/API/__pycache__/dataloader_taxibj.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/API/__pycache__/dataloader_taxibj.cpython-38.pyc
--------------------------------------------------------------------------------
/API/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/API/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/API/__pycache__/recorder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/API/__pycache__/recorder.cpython-38.pyc
--------------------------------------------------------------------------------
/API/dataloader.py:
--------------------------------------------------------------------------------
 1 | from .dataloader_taxibj import load_data as load_taxibj
 2 | from .dataloader_moving_mnist import load_data as load_mmnist
 3 | from .dataloader_sevir import load_data as load_sevir
 4 | 
 5 | def load_data(dataname, batch_size, val_batch_size, data_root, num_workers, **kwargs):
 6 |     if dataname == 'taxibj':
 7 |         return load_taxibj(batch_size, val_batch_size, data_root, num_workers)
 8 |     elif dataname == 'mmnist':
 9 |         return load_mmnist(batch_size, val_batch_size, data_root, num_workers)
10 |     elif dataname == 'sevir':
11 |         return load_sevir(batch_size, val_batch_size, data_root, num_workers)
12 |     else:
13 |         raise ValueError(f'unknown dataname: {dataname}')
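
A minimal usage sketch for the dispatcher above (values are illustrative; it assumes the Moving MNIST files were fetched into ./data/moving_mnist/ with data/moving_mnist/download_mmnist.sh):

    from API import load_data

    train_loader, vali_loader, test_loader, mean, std = load_data(
        'mmnist', batch_size=16, val_batch_size=4, data_root='./data/', num_workers=4)
    batch_x, batch_y = next(iter(train_loader))
    print(batch_x.shape, batch_y.shape)  # torch.Size([16, 10, 1, 64, 64]) for both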
--------------------------------------------------------------------------------
/API/dataloader_caltech.py:
--------------------------------------------------------------------------------
  1 | 
  2 | import os
  3 | import cv2
  4 | import numpy as np
  5 | import torch
  6 | import bisect
  7 | from torch.utils.data import Dataset
  8 | 
  9 | split_string = "\xFF\xD8\xFF\xE0\x00\x10\x4A\x46\x49\x46"
 10 | 
 11 | def read_seq(path):
 12 |     f = open(path, 'rb+')
 13 |     string = f.read().decode('latin-1')
 14 |     str_list = string.split(split_string)
 15 |     # print(len(str_list))
 16 |     f.close()
 17 |     return str_list[1:]
 18 | 
 19 | def seq_to_images(bytes_string):
 20 |     res = split_string.encode('latin-1') + bytes_string.encode('latin-1')
 21 |     img = cv2.imdecode(np.frombuffer(res, np.uint8), cv2.IMREAD_COLOR)
 22 |     return img / 255.0
 23 | 
 24 | def load_caltech(root):
 25 |     file_list = [file for file in os.listdir(root) if file.split('.')[-1] == "seq"]
 26 |     print(file_list)
 27 |     for file in file_list[:1]:
 28 |         path = os.path.join(root, file)
 29 |         str_list = read_seq(path)  # read_seq already drops the header chunk
 30 |         imgs = np.zeros([len(str_list), 480, 640, 3])
 31 |         idx = 0
 32 |         for frame_str in str_list:
 33 |             imgs[idx] = seq_to_images(frame_str)
 34 |             idx += 1
 35 |     return imgs
 36 | 
 37 | class Caltech(Dataset):
 38 |     @staticmethod
 39 |     def cumsum(sequence):
 40 |         r, s = [], 0
 41 |         for e in sequence:
 42 |             l = len(e)
 43 |             r.append(l + s)
 44 |             s += l
 45 |         return r
 46 | 
 47 |     def __init__(self, root, is_train=True, file_list=['V001.seq'], n_frames_input=4, n_frames_output=1):
 48 |         super().__init__()
 49 |         datasets = []
 50 |         for file in file_list:
 51 |             datasets.append(SingleCaltech(os.path.join(root, file), is_train=is_train, n_frames_input=n_frames_input, n_frames_output=n_frames_output))
 52 |         assert len(datasets) > 0, 'datasets should not be an empty iterable'
 53 |         self.datasets = list(datasets[:1])
 54 |         self.cumulative_sizes = self.cumsum(self.datasets)
 55 | 
 56 |         self.mean = 0
 57 |         self.std = 1
 58 | 
 59 |     def __len__(self):
 60 |         return self.cumulative_sizes[-1]
 61 | 
 62 |     def __getitem__(self, idx):
 63 |         if idx < 0:
 64 |             if -idx > len(self):
 65 |                 raise ValueError("absolute value of index should not exceed dataset length")
 66 |             idx = len(self) + idx
 67 |         dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
 68 |         if dataset_idx == 0:
 69 |             sample_idx = idx
 70 |         else:
 71 |             sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
 72 |         return self.datasets[dataset_idx][sample_idx]
 73 | 
 74 |     @property
 75 |     def cummulative_sizes(self):
 76 |         import warnings
 77 |         warnings.warn("cummulative_sizes attribute is renamed to cumulative_sizes", DeprecationWarning, stacklevel=2)
 78 |         return self.cumulative_sizes
 79 | 
 80 | 
 81 | class SingleCaltech(Dataset):
 82 |     def __init__(self, root, is_train=True, n_frames_input=4, n_frames_output=1):
 83 |         super().__init__()
 84 |         self.root = root
 85 |         if is_train:
 86 |             self.length = 100
 87 |         else:
 88 |             self.length = 50
 89 | 
 90 |         self.input_length = n_frames_input
 91 |         self.output_length = n_frames_output
 92 | 
 93 |         self.sequence = None
 94 |         self.get_current_data()
 95 | 
 96 | 
 97 | 
 98 |     def get_current_data(self):
 99 |         str_list = read_seq(self.root)
100 |         if self.length == 100:
101 |             str_list = str_list[:104]
102 |             self.sequence = np.zeros([104, 480, 640, 3])
103 |         else:
104 |             str_list = str_list[104:158]  # 54 frames, matching the buffer below
105 |             self.sequence = np.zeros([54, 480, 640, 3])
106 | 
107 |         for i, frame_str in enumerate(str_list):
108 |             self.sequence[i] = seq_to_images(frame_str)
109 | 
110 |     def __getitem__(self, index):
111 |         input = self.sequence[index: index + self.input_length]
112 |         input = np.transpose(input, (0, 3, 1, 2))
113 |         output = self.sequence[index + self.input_length: index + self.input_length + self.output_length]
114 |         output = np.transpose(output, (0, 3, 1, 2))
115 |         input = 
torch.from_numpy(input).contiguous().float() 116 | output = torch.from_numpy(output).contiguous().float() 117 | return input, output 118 | 119 | def __len__(self): 120 | return self.length 121 | 122 | 123 | def load_data( 124 | batch_size, val_batch_size, 125 | data_root, num_workers): 126 | 127 | file_list = [file for file in os.listdir(data_root) if file.split('.')[-1] == "seq"] 128 | # print(data_root) 129 | train_set = Caltech(root=data_root, is_train=True, 130 | n_frames_input=4, n_frames_output=1, file_list=file_list) 131 | test_set = Caltech(root=data_root, is_train=False, 132 | n_frames_input=4, n_frames_output=1, file_list=file_list) 133 | 134 | dataloader_train = torch.utils.data.DataLoader( 135 | train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers) 136 | dataloader_validation = torch.utils.data.DataLoader( 137 | test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers) 138 | dataloader_test = torch.utils.data.DataLoader( 139 | test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers) 140 | 141 | mean, std = 0, 1 142 | return dataloader_train, dataloader_validation, dataloader_test, mean, std 143 | 144 | 145 | if __name__ == "__main__": 146 | # data = load_caltech("/home/pan/workspace/simvp/SimVP-Simpler-yet-Better-Video-Prediction-master/data/caltech/USA/set01") 147 | file_list = [file for file in os.listdir("/home/pan/workspace/simvp/SimVP/data/caltech/USA/set01") if file.split('.')[-1] == "seq"] 148 | dataset = Caltech(root="/home/pan/workspace/simvp/SimVP/data/caltech/USA/set01", is_train=False, file_list=file_list) 149 | dataloader = torch.utils.data.DataLoader( 150 | dataset, batch_size=16, shuffle=True) 151 | 152 | for input, output in dataloader: 153 | print(input.shape, output.shape) -------------------------------------------------------------------------------- /API/dataloader_caltech0.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | split_string = "\xFF\xD8\xFF\xE0\x00\x10\x4A\x46\x49\x46" 9 | 10 | def read_seq(path): 11 | f = open(path, 'rb+') 12 | string = f.read().decode('latin-1') 13 | str_list = string.split(split_string) 14 | print(len(str_list)) 15 | f.close() 16 | return str_list, len(str_list) 17 | 18 | def seq_to_images(bytes_string): 19 | res = split_string.encode('latin-1') + bytes_string.encode('latin-1') 20 | img = cv2.imdecode(np.frombuffer(res, np.uint8), cv2.IMREAD_COLOR) 21 | return img / 255.0 22 | 23 | def load_caltech(root): 24 | file_list = [file for file in os.listdir(root) if file.split('.')[-1] == "seq"] 25 | print(file_list) 26 | for file in file_list[:1]: 27 | path = os.path.join(root, file) 28 | str_list, len = read_seq(path) 29 | imgs = np.zeros([len - 1, 480, 640, 3]) 30 | idx = 0 31 | for str in str_list[1:]: 32 | imgs[idx] = seq_to_images(str) 33 | idx += 1 34 | return imgs.transpose(0, 3, 1, 2) 35 | 36 | class Caltech(Dataset): 37 | def __init__(self, root, is_train=True, n_frames_input=4, n_frames_output=1): 38 | super().__init__() 39 | self.root = root 40 | print("loading .seq file list") 41 | self.file_list = [file for file in os.listdir(self.root) if file.split('.')[-1] == "seq"] 42 | 43 | if is_train: 44 | self.file_list = self.file_list[:-1] 45 | else: 46 | self.file_list = self.file_list[-1:] 47 | 48 | print("loading file list done, file list: ", self.file_list) 49 | 
self.length = 0
 50 | 
 51 |         self.input_length = n_frames_input
 52 |         self.output_length = n_frames_output
 53 | 
 54 |         self.current_seq = None
 55 |         self.current_length = 0
 56 |         self.current_file_index = 0
 57 | 
 58 |         self.get_next = True
 59 | 
 60 |         self.get_total_len(root)
 61 |         self.get_current_data()
 62 | 
 63 | 
 64 |     def get_total_len(self, root):
 65 |         print("calculating total length")
 66 |         count = 0
 67 |         for file in self.file_list:
 68 |             path = os.path.join(root, file)
 69 |             _, length = read_seq(path)
 70 |             count += (length - 5)
 71 |         self.length = count
 72 |         print("calculating total length done, total length: ", self.length)
 73 | 
 74 |     def get_current_data(self):
 75 |         print("getting current sequence")
 76 |         if self.current_file_index >= len(self.file_list):
 77 |             self.get_next = False
 78 |             return
 79 |         current_file = os.path.join(self.root, self.file_list[self.current_file_index])
 80 |         str_list, length = read_seq(current_file)
 81 |         self.current_length = length - 5
 82 |         self.current_seq = np.zeros([length - 1, 480, 640, 3])
 83 |         for i, frame_str in enumerate(str_list[1:]):
 84 |             self.current_seq[i] = seq_to_images(frame_str)
 85 |         print("getting current sequence done, the shape:", self.current_seq.shape)
 86 | 
 87 |     def get_next_seq(self):
 88 |         print("getting next sequence")
 89 |         self.current_file_index += 1
 90 |         self.get_current_data()
 91 |         self.get_next = False
 92 | 
 93 |     def __getitem__(self, index):
 94 |         if index >= self.current_length:
 95 |             self.get_next = True
 96 |         if self.get_next:
 97 |             self.get_next_seq()
 98 |         input = self.current_seq[index: index + self.input_length]
 99 |         output = self.current_seq[index + self.input_length: index + self.input_length + self.output_length]
100 |         input = torch.from_numpy(input).contiguous().float()
101 |         output = torch.from_numpy(output).contiguous().float()
102 |         return input, output
103 | 
104 |     def __len__(self):
105 |         return self.current_length
106 | 
107 | 
108 | def load_data(
109 |         batch_size, val_batch_size,
110 |         data_root, num_workers):
111 | 
112 |     train_set = Caltech(root=data_root, is_train=True,
113 |                         n_frames_input=10, n_frames_output=10)
114 |     test_set = Caltech(root=data_root, is_train=False,
115 |                        n_frames_input=10, n_frames_output=10)
116 | 
117 |     dataloader_train = torch.utils.data.DataLoader(
118 |         train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)
119 |     dataloader_validation = torch.utils.data.DataLoader(
120 |         test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
121 |     dataloader_test = torch.utils.data.DataLoader(
122 |         test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
123 | 
124 |     mean, std = 0, 1
125 |     return dataloader_train, dataloader_validation, dataloader_test, mean, std
126 | 
127 | 
128 | if __name__ == "__main__":
129 |     # data = load_caltech("/home/pan/workspace/simvp/SimVP-Simpler-yet-Better-Video-Prediction-master/data/caltech/USA/set01")
130 |     dataset = Caltech(root="/home/pan/workspace/simvp/SimVP-Simpler-yet-Better-Video-Prediction-master/data/caltech/USA/set01")
131 |     dataloader = torch.utils.data.DataLoader(
132 |         dataset, batch_size=16, shuffle=True, drop_last=True)
133 | 
134 |     for input, output in dataloader:
135 |         print(input.shape, output.shape)
--------------------------------------------------------------------------------
/API/dataloader_moving_mnist.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import gzip
 3 | import random
 4 | import 
numpy as np 5 | import torch 6 | import torch.utils.data as data 7 | 8 | 9 | def load_mnist(root): 10 | # Load MNIST dataset for generating training data. 11 | path = os.path.join(root, 'moving_mnist/train-images-idx3-ubyte.gz') 12 | with gzip.open(path, 'rb') as f: 13 | mnist = np.frombuffer(f.read(), np.uint8, offset=16) 14 | mnist = mnist.reshape(-1, 28, 28) 15 | return mnist 16 | 17 | 18 | def load_fixed_set(root): 19 | # Load the fixed dataset 20 | filename = 'moving_mnist/mnist_test_seq.npy' 21 | path = os.path.join(root, filename) 22 | dataset = np.load(path) 23 | dataset = dataset[..., np.newaxis] 24 | return dataset 25 | 26 | 27 | class MovingMNIST(data.Dataset): 28 | def __init__(self, root, is_train=True, n_frames_input=10, n_frames_output=10, num_objects=[2], 29 | transform=None): 30 | super(MovingMNIST, self).__init__() 31 | 32 | self.dataset = None 33 | if is_train: 34 | self.mnist = load_mnist(root) 35 | else: 36 | if num_objects[0] != 2: 37 | self.mnist = load_mnist(root) 38 | else: 39 | self.dataset = load_fixed_set(root) 40 | self.length = int(1e4) if self.dataset is None else self.dataset.shape[1] 41 | 42 | self.is_train = is_train 43 | self.num_objects = num_objects 44 | self.n_frames_input = n_frames_input 45 | self.n_frames_output = n_frames_output 46 | self.n_frames_total = self.n_frames_input + self.n_frames_output 47 | self.transform = transform 48 | # For generating data 49 | self.image_size_ = 64 50 | self.digit_size_ = 28 51 | self.step_length_ = 0.1 52 | 53 | self.mean = 0 54 | self.std = 1 55 | 56 | def get_random_trajectory(self, seq_length): 57 | ''' Generate a random sequence of a MNIST digit ''' 58 | canvas_size = self.image_size_ - self.digit_size_ 59 | x = random.random() 60 | y = random.random() 61 | theta = random.random() * 2 * np.pi 62 | v_y = np.sin(theta) 63 | v_x = np.cos(theta) 64 | 65 | start_y = np.zeros(seq_length) 66 | start_x = np.zeros(seq_length) 67 | for i in range(seq_length): 68 | # Take a step along velocity. 69 | y += v_y * self.step_length_ 70 | x += v_x * self.step_length_ 71 | 72 | # Bounce off edges. 73 | if x <= 0: 74 | x = 0 75 | v_x = -v_x 76 | if x >= 1.0: 77 | x = 1.0 78 | v_x = -v_x 79 | if y <= 0: 80 | y = 0 81 | v_y = -v_y 82 | if y >= 1.0: 83 | y = 1.0 84 | v_y = -v_y 85 | start_y[i] = y 86 | start_x[i] = x 87 | 88 | # Scale to the size of the canvas. 89 | start_y = (canvas_size * start_y).astype(np.int32) 90 | start_x = (canvas_size * start_x).astype(np.int32) 91 | return start_y, start_x 92 | 93 | def generate_moving_mnist(self, num_digits=2): 94 | ''' 95 | Get random trajectories for the digits and generate a video. 
 96 |         '''
 97 |         data = np.zeros((self.n_frames_total, self.image_size_,
 98 |                          self.image_size_), dtype=np.float32)
 99 |         for n in range(num_digits):
100 |             # Trajectory
101 |             start_y, start_x = self.get_random_trajectory(self.n_frames_total)
102 |             ind = random.randint(0, self.mnist.shape[0] - 1)
103 |             digit_image = self.mnist[ind]
104 |             for i in range(self.n_frames_total):
105 |                 top = start_y[i]
106 |                 left = start_x[i]
107 |                 bottom = top + self.digit_size_
108 |                 right = left + self.digit_size_
109 |                 # Draw digit
110 |                 data[i, top:bottom, left:right] = np.maximum(
111 |                     data[i, top:bottom, left:right], digit_image)
112 | 
113 |         data = data[..., np.newaxis]
114 |         return data
115 | 
116 |     def __getitem__(self, idx):
117 |         length = self.n_frames_input + self.n_frames_output
118 |         if self.is_train or self.num_objects[0] != 2:
119 |             # Sample number of objects
120 |             num_digits = random.choice(self.num_objects)
121 |             # Generate data on the fly
122 |             images = self.generate_moving_mnist(num_digits)
123 |         else:
124 |             images = self.dataset[:, idx, ...]
125 | 
126 |         r = 1
127 |         w = int(64 / r)
128 |         images = images.reshape((length, w, r, w, r)).transpose(
129 |             0, 2, 4, 1, 3).reshape((length, r * r, w, w))
130 | 
131 |         input = images[:self.n_frames_input]
132 |         if self.n_frames_output > 0:
133 |             output = images[self.n_frames_input:length]
134 |         else:
135 |             output = np.zeros((0, r * r, w, w), dtype=images.dtype)  # empty array, so from_numpy below still works
136 | 
137 |         output = torch.from_numpy(output / 255.0).contiguous().float()
138 |         input = torch.from_numpy(input / 255.0).contiguous().float()
139 |         return input, output
140 | 
141 |     def __len__(self):
142 |         return self.length
143 | 
144 | 
145 | def load_data(
146 |         batch_size, val_batch_size,
147 |         data_root, num_workers):
148 | 
149 |     train_set = MovingMNIST(root=data_root, is_train=True,
150 |                             n_frames_input=10, n_frames_output=10, num_objects=[2])
151 |     test_set = MovingMNIST(root=data_root, is_train=False,
152 |                            n_frames_input=10, n_frames_output=10, num_objects=[2])
153 | 
154 |     dataloader_train = torch.utils.data.DataLoader(
155 |         train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)
156 |     dataloader_validation = torch.utils.data.DataLoader(
157 |         test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
158 |     dataloader_test = torch.utils.data.DataLoader(
159 |         test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
160 | 
161 |     mean, std = 0, 1
162 |     return dataloader_train, dataloader_validation, dataloader_test, mean, std
163 | 
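
A quick sanity-check sketch for the generator above (illustrative; it assumes train-images-idx3-ubyte.gz sits under ./data/moving_mnist/):

    from API.dataloader_moving_mnist import MovingMNIST

    dataset = MovingMNIST(root='./data/', is_train=True,
                          n_frames_input=10, n_frames_output=10, num_objects=[2])
    x, y = dataset[0]
    print(x.shape, y.shape)  # torch.Size([10, 1, 64, 64]) for each half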
--------------------------------------------------------------------------------
/API/dataloader_moving_mnist_v2.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import torch
 3 | import torch.utils.data as data
 4 | 
 5 | 
 6 | class MovingMnistSequence(data.Dataset):
 7 |     def __init__(self, train=True, shuffle=True, root='./data', transform=None):
 8 |         super().__init__()
 9 |         if train:
10 |             npz = 'mnist_train.npz'
11 |             self.data = np.load(f'{root}/{npz}')['input_raw_data']
12 |         else:
13 |             npz = 'mnist_train.npz'  # note: the held-out split reuses the training file and keeps only its first 10000 frames
14 |             self.data = np.load(f'{root}/{npz}')['input_raw_data'][:10000]
15 | 
16 | 
17 |         self.transform = transform
18 |         self.data = self.data.transpose(0, 2, 3, 1)
19 | 
20 |     def __len__(self):
21 |         return self.data.shape[0] // 20
22 | 
23 |     def __getitem__(self, index):
24 |         imgs = self.data[index * 20: (index + 1) * 20]
25 |         imgs_tensor = torch.zeros([20, 1, 64, 64])
26 |         if self.transform is not None:
27 |             for i in range(imgs.shape[0]):
28 |                 imgs_tensor[i] = self.transform(imgs[i])
29 |         return imgs_tensor
30 | 
31 | 
32 | def load_data(
33 |         batch_size, val_batch_size,
34 |         data_root, num_workers):
35 | 
36 |     train_set = MovingMnistSequence(root=data_root,
37 |                                     train=True)
38 |     test_set = MovingMnistSequence(root=data_root,
39 |                                    train=False)
40 | 
41 |     dataloader_train = torch.utils.data.DataLoader(
42 |         train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)
43 |     dataloader_validation = torch.utils.data.DataLoader(
44 |         test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
45 |     dataloader_test = torch.utils.data.DataLoader(
46 |         test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
47 | 
48 |     mean, std = 0, 1
49 |     return dataloader_train, dataloader_validation, dataloader_test, mean, std
50 | 
--------------------------------------------------------------------------------
/API/dataloader_sevir.py:
--------------------------------------------------------------------------------
 1 | import torch.nn as nn
 2 | import numpy as np
 3 | import random
 4 | from torch.utils.data import Dataset
 5 | from torch.utils.data import DataLoader
 6 | import torchvision.transforms as transforms
 7 | import matplotlib.pyplot as plt
 8 | import torch
 9 | 
10 | class VilDataset(Dataset):
11 |     def __init__(self, train=True, root='./data', transform=None):
12 |         super().__init__()
13 |         if train:
14 |             npy = ['SEVIR_IR069_STORMEVENTS_2018_0101_0630.npy', 'SEVIR_IR069_STORMEVENTS_2018_0701_1231.npy']
15 |         else:
16 |             npy = ['SEVIR_IR069_RANDOMEVENTS_2018_0101_0430.npy']
17 | 
18 |         data = []
19 |         for file in npy:
20 |             data.append(np.load(f'{root}/{file}'))
21 | 
22 |         self.data = np.concatenate(data)
23 |         #N, L, H, W = self.data.shape
24 |         # self.data = self.data.reshape([N, L, H, W])
25 |         self.transform = transform
26 |         self.mean = 0
27 |         self.std = 1
28 | 
29 |     def __len__(self):
30 |         return self.data.shape[0]
31 | 
32 |     def __getitem__(self, index):
33 |         img = self.data[index].reshape(20, 1, 128, 128)
34 |         if self.transform:
35 |             img = self.transform(img)
36 | 
37 |         input_img = img[:10]
38 |         output_img = img[10:]
39 |         input_img = torch.from_numpy(input_img)
40 |         output_img = torch.from_numpy(output_img)
41 |         input_img = input_img.contiguous().float()
42 |         output_img = output_img.contiguous().float()
43 |         return input_img, output_img
44 | 
45 | 
46 | def load_data(batch_size, val_batch_size,
47 |               data_root, num_workers):
48 |     train_set = VilDataset(train=True, root=data_root, transform=None)
49 |     test_set = VilDataset(train=False, root=data_root, transform=None)
50 | 
51 |     dataloader_train = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True,
52 |                                   num_workers=num_workers)
53 |     dataloader_validation = DataLoader(test_set, batch_size=val_batch_size, shuffle=False,
54 |                                        pin_memory=True, num_workers=num_workers)
55 |     dataloader_test = DataLoader(test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True,
56 |                                  num_workers=num_workers)
57 |     mean, std = 0, 1
58 | 
59 |     return dataloader_train, dataloader_validation, dataloader_test, mean, std
60 | 
61 | 
62 | if __name__ == '__main__':
63 |     dataset = VilDataset(root='/root/Model_Phy/data')
64 |     input_img, output_img = dataset[1]
65 |     # `input_img` is a float tensor of shape (10, 1, 128, 128)
66 |     fig, axes = plt.subplots(nrows=1, ncols=10)
67 | 
68 |     for i in range(10):
69 |         axes[i].imshow(input_img[i, 0], cmap=None)
70 |         axes[i].axis('off')
71 | 
72 |     plt.show()
73 | 
74 |     fig, axes = plt.subplots(nrows=1, ncols=10)
75 | 
76 |     for i in range(10):
77 |         axes[i].imshow(output_img[i, 0], cmap=None)
78 |         axes[i].axis('off')
79 | 
80 |     plt.show()
--------------------------------------------------------------------------------
/API/dataloader_taxibj.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | from torch.utils.data import Dataset
 4 | 
 5 | 
 6 | class TrafficDataset(Dataset):
 7 |     def __init__(self, X, Y):
 8 |         super(TrafficDataset, self).__init__()
 9 |         self.X = (X + 1) / 2
10 |         self.Y = (Y + 1) / 2
11 |         self.mean = 0
12 |         self.std = 1
13 | 
14 |     def __len__(self):
15 |         return self.X.shape[0]
16 | 
17 |     def __getitem__(self, index):
18 |         data = torch.tensor(self.X[index, ::]).float()
19 |         labels = torch.tensor(self.Y[index, ::]).float()
20 |         return data, labels
21 | 
22 | def load_data(
23 |         batch_size, val_batch_size,
24 |         data_root, num_workers):
25 | 
26 |     dataset = np.load(data_root+'taxibj/dataset.npz')
27 |     X_train, Y_train, X_test, Y_test = dataset['X_train'], dataset['Y_train'], dataset['X_test'], dataset['Y_test']
28 | 
29 |     train_set = TrafficDataset(X=X_train, Y=Y_train)
30 |     test_set = TrafficDataset(X=X_test, Y=Y_test)
31 | 
32 |     dataloader_train = torch.utils.data.DataLoader(
33 |         train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)
34 |     dataloader_test = torch.utils.data.DataLoader(
35 |         test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
36 | 
37 |     return dataloader_train, None, dataloader_test, 0, 1
--------------------------------------------------------------------------------
/API/metrics.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from skimage.metrics import structural_similarity as cal_ssim
 3 | 
 4 | def MAE(pred, true):
 5 |     return np.mean(np.abs(pred-true), axis=(0,1)).sum()
 6 | 
 7 | 
 8 | def MSE(pred, true):
 9 |     return np.mean((pred-true)**2, axis=(0,1)).sum()
10 | 
11 | # cite the `PSNR` code from E3d-LSTM, Thanks!
12 | # https://github.com/google/e3d_lstm/blob/master/src/trainer.py line 39-40
13 | def PSNR(pred, true):
14 |     # subtract in float: raw uint8 arithmetic would wrap around and corrupt the MSE
15 |     mse = np.mean((np.uint8(pred * 255).astype(np.float64) - np.uint8(true * 255).astype(np.float64))**2)
16 |     return 20 * np.log10(255) - 10 * np.log10(mse)
17 | 
18 | def metric(pred, true, mean, std, return_ssim_psnr=False, clip_range=[0, 1]):
19 |     pred = pred*std + mean
20 |     true = true*std + mean
21 |     mae = MAE(pred, true)
22 |     mse = MSE(pred, true)
23 | 
24 |     if return_ssim_psnr:
25 |         pred = np.maximum(pred, clip_range[0])
26 |         pred = np.minimum(pred, clip_range[1])
27 |         ssim, psnr = 0, 0
28 |         for b in range(pred.shape[0]):
29 |             for f in range(pred.shape[1]):
30 |                 ssim += cal_ssim(pred[b, f].swapaxes(0, 2), true[b, f].swapaxes(0, 2), multichannel=True)
31 |                 psnr += PSNR(pred[b, f], true[b, f])
32 |         ssim = ssim / (pred.shape[0] * pred.shape[1])
33 |         psnr = psnr / (pred.shape[0] * pred.shape[1])
34 |         return mse, mae, ssim, psnr
35 |     else:
36 |         return mse, mae
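
An illustrative check of the metrics on random data (shapes follow the Moving MNIST setting of batch x frames x channels x height x width; `multichannel=True` assumes the older scikit-image API used throughout this repo):

    import numpy as np
    from API.metrics import metric

    pred = np.random.rand(4, 10, 1, 64, 64)
    true = np.random.rand(4, 10, 1, 64, 64)
    mse, mae, ssim, psnr = metric(pred, true, mean=0, std=1, return_ssim_psnr=True)
    print(mse, mae, ssim, psnr)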
--------------------------------------------------------------------------------
/API/recorder.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import torch
 3 | 
 4 | class Recorder:
 5 |     def __init__(self, verbose=False, delta=0):
 6 |         self.verbose = verbose
 7 |         self.best_score = None
 8 |         self.val_loss_min = np.Inf
 9 |         self.delta = delta
10 | 
11 |     def __call__(self, val_loss, model, path):
12 |         score = -val_loss
13 |         if self.best_score is None:
14 |             self.best_score = score
15 |             self.save_checkpoint(val_loss, model, path)
16 |         elif score >= self.best_score + self.delta:
17 |             self.best_score = score
18 |             self.save_checkpoint(val_loss, model, path)
19 | 
20 |     def save_checkpoint(self, val_loss, model, path):
21 |         if self.verbose:
22 |             print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
23 |         torch.save(model.state_dict(), path+'/'+'checkpoint.pth')
24 |         self.val_loss_min = val_loss
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # PastNet: Introducing Physical Inductive Biases for Spatio-temporal Video Prediction (ACMMM2024)
 2 | 
 3 | 
 4 | ## Abstract
 5 | 
 6 | In this paper, we investigate the challenge of spatio-temporal video prediction, which involves generating future videos from historical data streams. Existing approaches typically utilize external information such as semantic maps to enhance video prediction, but they often neglect the inherent physical knowledge embedded within videos. Furthermore, their high computational demands can impede their application to high-resolution videos. To address these constraints, we introduce a novel approach called **Physics-assisted Spatio-temporal Network (PastNet)** for high-quality video prediction. The core of PastNet lies in a spectral convolution operator in the Fourier domain, which efficiently introduces inductive biases from the underlying physical laws. Additionally, we employ a memory bank, sized according to the estimated intrinsic dimensionality of the data, to discretize local features during the processing of complex spatio-temporal signals, thereby reducing computational costs and facilitating efficient high-resolution video prediction. Extensive experiments on various widely-used datasets demonstrate the effectiveness and efficiency of the proposed PastNet compared with a range of state-of-the-art methods, particularly in high-resolution scenarios.
 7 | 
 8 | 
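 9 | The spectral operator described above can be pictured with a minimal sketch: transform to the Fourier domain, mix channels on a band of low-frequency modes with learned complex weights, and transform back. The class below is a simplified illustration of the idea, not the repo's exact `FourierNetBlock` from `modules/Fourier_modules.py`:

```python
import torch
import torch.nn as nn

class SpectralConv2d(nn.Module):
    """Toy spectral convolution: learned mixing of the lowest Fourier modes."""
    def __init__(self, channels, modes):
        super().__init__()
        self.modes = modes
        scale = 1.0 / (channels * channels)
        self.weight = nn.Parameter(
            scale * torch.randn(channels, channels, modes, modes, dtype=torch.cfloat))

    def forward(self, x):                                # x: [B, C, H, W]
        x_ft = torch.fft.rfft2(x)                        # to the Fourier domain
        out_ft = torch.zeros_like(x_ft)
        m = self.modes
        out_ft[:, :, :m, :m] = torch.einsum(
            'bixy,ioxy->boxy', x_ft[:, :, :m, :m], self.weight)
        return torch.fft.irfft2(out_ft, s=x.shape[-2:])  # back to pixel space

x = torch.randn(2, 8, 64, 64)
print(SpectralConv2d(8, modes=16)(x).shape)              # torch.Size([2, 8, 64, 64])
```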

10 | 11 |

12 | 13 | 14 | 15 | ## Overview 16 | 17 | * `API/` contains dataloaders and metrics. 18 | * `modules/` contains several used module blocks. 19 | * `models/` contains the PastNet model. 20 | * `train_model.py` and `exp_vq.py` are the core files for training, validating, and testing pipelines. 21 | 22 | 23 | 24 | ## Citation 25 | 26 | If you are interested in our repository and our paper, please cite the following paper: 27 | 28 | ``` 29 | @inproceedings{wu2024pastnet, 30 | title={Pastnet: Introducing physical inductive biases for spatio-temporal video prediction}, 31 | author={Wu, Hao and Xu, Fan and Chen, Chong and Hua, Xian-Sheng and Luo, Xiao and Wang, Haixin}, 32 | booktitle={Proceedings of the 32nd ACM International Conference on Multimedia}, 33 | pages={2917--2926}, 34 | year={2024} 35 | } 36 | ``` 37 | 38 | -------------------------------------------------------------------------------- /__pycache__/exp_vq.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/__pycache__/exp_vq.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/params.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/__pycache__/params.cpython-38.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /data/moving_mnist/download_mmnist.sh: -------------------------------------------------------------------------------- 1 | wget http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy 2 | wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /exp_vq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import json 4 | import torch 5 | import pickle 6 | import logging 7 | import numpy as np 8 | from models.PastNet_Model import PastNetModel 9 | from tqdm import tqdm 10 | from API import * 11 | from utils import * 12 | 13 | 14 | def relative_l1_error(true_values, predicted_values): 15 | error = torch.abs(true_values - predicted_values) 16 | return torch.mean(error / torch.abs(true_values)) 17 | 18 | 19 | class PastNet_exp: 20 | def __init__(self, args): 21 | super(PastNet_exp, self).__init__() 22 | self.args = args 23 | self.config = self.args.__dict__ 24 | self.device = self._acquire_device() 25 | 26 | self._preparation() 27 | print_log(output_namespace(self.args)) 28 | 29 | self._get_data() 30 | self._select_optimizer() 31 | self._select_criterion() 32 | 33 | def _acquire_device(self): 34 | if self.args.use_gpu: 35 | os.environ["CUDA_VISIBLE_DEVICES"] = str(self.args.gpu) 36 | device = torch.device('cuda:{}'.format(0)) 37 | print_log('Use GPU: {}'.format(self.args.gpu)) 38 | else: 39 | device = torch.device('cpu') 40 | print_log('Use CPU') 41 | return device 42 | 43 | def _preparation(self): 44 | # seed 45 | set_seed(self.args.seed) 46 | # 
log and checkpoint 47 | self.path = osp.join(self.args.res_dir, self.args.ex_name) 48 | check_dir(self.path) 49 | 50 | self.checkpoints_path = osp.join(self.path, 'checkpoints') 51 | check_dir(self.checkpoints_path) 52 | 53 | sv_param = osp.join(self.path, 'model_param.json') 54 | with open(sv_param, 'w') as file_obj: 55 | json.dump(self.args.__dict__, file_obj) 56 | 57 | for handler in logging.root.handlers[:]: 58 | logging.root.removeHandler(handler) 59 | logging.basicConfig(level=logging.INFO, filename=osp.join(self.path, 'log.log'), 60 | filemode='a', format='%(asctime)s - %(message)s') 61 | # prepare data 62 | self._get_data() 63 | # build the model 64 | self._build_model() 65 | 66 | def _build_model(self): 67 | args = self.args 68 | freeze = args.freeze_vqvae == 1 69 | self.model = PastNetModel(args, 70 | shape_in=tuple(args.in_shape), 71 | hid_T=args.hid_T, 72 | N_T=args.N_T, 73 | res_units=args.res_units, 74 | res_layers=args.res_layers, 75 | embedding_nums=args.K, 76 | embedding_dim=args.D).to(self.device) 77 | 78 | def _get_data(self): 79 | config = self.args.__dict__ 80 | self.train_loader, self.vali_loader, self.test_loader, self.data_mean, self.data_std = load_data(**config) 81 | self.vali_loader = self.test_loader if self.vali_loader is None else self.vali_loader 82 | 83 | def _select_optimizer(self): 84 | self.optimizer = torch.optim.Adam( 85 | self.model.parameters(), lr=self.args.lr) 86 | self.scheduler = torch.optim.lr_scheduler.OneCycleLR( 87 | self.optimizer, max_lr=self.args.lr, steps_per_epoch=len(self.train_loader), epochs=self.args.epochs) 88 | return self.optimizer 89 | 90 | def _select_criterion(self): 91 | self.criterion = torch.nn.MSELoss() 92 | 93 | def _save(self, name=''): 94 | torch.save(self.model.state_dict(), os.path.join( 95 | self.checkpoints_path, name + '.pth')) 96 | state = self.scheduler.state_dict() 97 | fw = open(os.path.join(self.checkpoints_path, name + '.pkl'), 'wb') 98 | pickle.dump(state, fw) 99 | 100 | def train(self, args): 101 | config = args.__dict__ 102 | recorder = Recorder(verbose=True) 103 | 104 | for epoch in range(config['epochs']): 105 | train_loss = [] 106 | self.model.train() 107 | train_pbar = tqdm(self.train_loader) 108 | 109 | for batch_x, batch_y in train_pbar: 110 | self.optimizer.zero_grad() 111 | batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device) 112 | pred_y = self.model(batch_x) 113 | 114 | loss = self.criterion(pred_y, batch_y) 115 | train_loss.append(loss.item()) 116 | train_pbar.set_description('train loss: {:.4f}'.format(loss.item())) 117 | 118 | loss.backward() 119 | self.optimizer.step() 120 | self.scheduler.step() 121 | 122 | train_loss = np.average(train_loss) 123 | 124 | if epoch % args.log_step == 0: 125 | with torch.no_grad(): 126 | vali_loss = self.vali(self.vali_loader) 127 | if epoch % (args.log_step * 100) == 0: 128 | self._save(name=str(epoch)) 129 | print_log("Epoch: {0} | Train Loss: {1:.4f} Vali Loss: {2:.4f}\n".format( 130 | epoch + 1, train_loss, vali_loss)) 131 | recorder(vali_loss, self.model, self.path) 132 | 133 | best_model_path = self.path + '/' + 'checkpoint.pth' 134 | self.model.load_state_dict(torch.load(best_model_path)) 135 | return self.model 136 | 137 | # def vali(self, vali_loader): 138 | # self.model.eval() 139 | # preds_lst, trues_lst, total_loss = [], [], [] 140 | # vali_pbar = tqdm(vali_loader) 141 | # for i, (batch_x, batch_y) in enumerate(vali_pbar): 142 | # if i * batch_x.shape[0] > 1000: 143 | # break 144 | 145 | # batch_x, batch_y = batch_x.to(self.device), 
batch_y.to(self.device)
146 |     #         pred_y = self.model(batch_x)
147 |     #         list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()), [
148 |     #             pred_y, batch_y], [preds_lst, trues_lst]))
149 | 
150 |     #         loss = self.criterion(pred_y, batch_y)
151 |     #         vali_pbar.set_description(
152 |     #             'vali loss: {:.4f}'.format(loss.mean().item()))
153 |     #         total_loss.append(loss.mean().item())
154 | 
155 |     #     total_loss = np.average(total_loss)
156 |     #     preds = np.concatenate(preds_lst, axis=0)
157 |     #     trues = np.concatenate(trues_lst, axis=0)
158 |     #     mse, mae, ssim, psnr = metric(preds, trues, vali_loader.dataset.mean, vali_loader.dataset.std, True)
159 |     #     print_log('vali mse:{:.4f}, mae:{:.4f}, ssim:{:.4f}, psnr:{:.4f}'.format(mse, mae, ssim, psnr))
160 | 
161 |     #     l2_error = torch.nn.MSELoss()(torch.tensor(preds), torch.tensor(trues)).item()
162 |     #     relative_l2_error = l2_error / torch.nn.MSELoss()(torch.tensor(trues), torch.zeros_like(torch.tensor(trues))).item()
163 | 
164 |     #     l1_error = torch.nn.L1Loss()(torch.tensor(preds), torch.tensor(trues)).item()
165 |     #     rel_l1_err = relative_l1_error(torch.tensor(trues), torch.tensor(preds)).item()
166 | 
167 |     #     # compute the RMSE
168 |     #     rmse = torch.sqrt(torch.mean((torch.tensor(preds) - torch.tensor(trues)) ** 2))
169 |     #     rmse = rmse.item()
170 | 
171 |     #     print_log('RMSE: {:.7f}'.format(rmse))
172 |     #     print_log('L1 error: {:.7f}, Relative L1 Error: {:.7f}, L2 error: {:.7f}, Relative L2 error: {:.7f},'.format(l1_error, rel_l1_err, l2_error, relative_l2_error))
173 | 
174 |     #     self.model.train()
175 |     #     return total_loss
176 | 
177 |     def vali(self, vali_loader):
178 |         self.model.eval()
179 |         preds_lst, trues_lst, total_loss = [], [], []
180 |         vali_pbar = tqdm(vali_loader)
181 |         for i, (batch_x, batch_y) in enumerate(vali_pbar):
182 |             if i * batch_x.shape[0] > 1000:
183 |                 break
184 | 
185 |             batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
186 |             pred_y = self.model(batch_x)
187 |             list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()), [
188 |                 pred_y, batch_y], [preds_lst, trues_lst]))
189 | 
190 |             loss = self.criterion(pred_y, batch_y)
191 |             vali_pbar.set_description(
192 |                 'vali loss: {:.4f}'.format(loss.mean().item()))
193 |             total_loss.append(loss.mean().item())
194 | 
195 |         total_loss = np.average(total_loss)
196 |         preds = np.concatenate(preds_lst, axis=0)
197 |         trues = np.concatenate(trues_lst, axis=0)
198 |         mse, mae, ssim, psnr = metric(preds, trues, vali_loader.dataset.mean, vali_loader.dataset.std, True)
199 |         # print_log('vali mse:{:.4f}, mae:{:.4f}, ssim:{:.4f}, psnr:{:.4f}'.format(mse, mae, ssim, psnr))
200 | 
201 |         l2_error = torch.nn.MSELoss()(torch.tensor(preds), torch.tensor(trues)).item()
202 |         relative_l2_error = l2_error / torch.nn.MSELoss()(torch.tensor(trues),
203 |                                                           torch.zeros_like(torch.tensor(trues))).item()
204 | 
205 |         l1_error = torch.nn.L1Loss()(torch.tensor(preds), torch.tensor(trues)).item()
206 |         rel_l1_err = relative_l1_error(torch.tensor(trues), torch.tensor(preds)).item()
207 | 
208 |         # calculate the RMSE, MSE, and MAE
209 |         rmse = torch.sqrt(torch.mean((torch.tensor(preds) - torch.tensor(trues)) ** 2)).item()
210 |         mse = torch.mean((torch.tensor(preds) - torch.tensor(trues)) ** 2).item()
211 |         mae = torch.mean(torch.abs(torch.tensor(preds) - torch.tensor(trues))).item()
212 | 
213 |         ape = torch.abs(torch.tensor(preds) - torch.tensor(trues)) / (torch.tensor(trues) + 1e-8)
214 |         ape[torch.tensor(trues) == 0] = 0  # set APE to zero where the true value is zero
215 |         mape = torch.mean(ape).item() * 100
216 |         # ape = torch.abs(torch.tensor(preds) - torch.tensor(trues)) / torch.abs(torch.tensor(trues))
217 |         # mape = torch.mean(ape).item() * 100
218 | 
219 |         print_log(
220 |             'L1 error: {:.7f}, Relative L1 Error: {:.7f}, L2 error: {:.7f}, Relative L2 error: {:.7f},'.format(l1_error,
221 |                                                                                                                rel_l1_err,
222 |                                                                                                                l2_error,
223 |                                                                                                                relative_l2_error))
224 |         print_log('RMSE: {:.7f}, MSE: {:.7f}, MAE: {:.7f}, MAPE: {:.7f}%'.format(rmse, mse, mae, mape))
225 | 
226 |         self.model.train()
227 |         return total_loss
228 | 
229 |     def test(self, args):
230 |         self.model.eval()
231 |         inputs_lst, trues_lst, preds_lst = [], [], []
232 |         for batch_x, batch_y in self.test_loader:
233 |             pred_y = self.model(batch_x.to(self.device))
234 |             list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()), [
235 |                 batch_x, batch_y, pred_y], [inputs_lst, trues_lst, preds_lst]))
236 | 
237 |         inputs, trues, preds = map(lambda data: np.concatenate(
238 |             data, axis=0), [inputs_lst, trues_lst, preds_lst])
239 | 
240 |         folder_path = self.path + '/results/{}/sv/'.format(args.ex_name)
241 |         if not os.path.exists(folder_path):
242 |             os.makedirs(folder_path)
243 | 
244 |         mse, mae, ssim, psnr = metric(preds, trues, self.test_loader.dataset.mean, self.test_loader.dataset.std, True)
245 |         print_log('mse:{:.4f}, mae:{:.4f}, ssim:{:.4f}, psnr:{:.4f}'.format(mse, mae, ssim, psnr))
246 | 
247 |         l2_error = torch.nn.MSELoss()(torch.tensor(preds), torch.tensor(trues)).item()
248 |         relative_l2_error = l2_error / torch.nn.MSELoss()(torch.tensor(trues),
249 |                                                           torch.zeros_like(torch.tensor(trues))).item()
250 | 
251 |         l1_error = torch.nn.L1Loss()(torch.tensor(preds), torch.tensor(trues)).item()
252 |         rel_l1_err = relative_l1_error(torch.tensor(trues), torch.tensor(preds)).item()
253 | 
254 |         # compute the RMSE
255 |         # rmse = torch.sqrt(torch.mean((torch.tensor(preds) - torch.tensor(trues)) ** 2))
256 |         # rmse = rmse.item()
257 | 
258 |         # print_log('RMSE: {:.7f}'.format(rmse))
259 |         # print_log('L1 error: {:.7f}, Relative L1 Error: {:.7f}, L2 error: {:.7f}, Relative L2 error: {:.7f},'.format(l1_error, rel_l1_err, l2_error, relative_l2_error))
260 | 
261 |         # calculate the RMSE, MSE, and MAE
262 |         rmse = torch.sqrt(torch.mean((torch.tensor(preds) - torch.tensor(trues)) ** 2)).item()
263 |         mse = torch.mean((torch.tensor(preds) - torch.tensor(trues)) ** 2).item()
264 |         mae = torch.mean(torch.abs(torch.tensor(preds) - torch.tensor(trues))).item()
265 |         ape = torch.abs(torch.tensor(preds) - torch.tensor(trues)) / torch.abs(torch.tensor(trues))
266 |         mape = torch.mean(ape).item() * 100
267 | 
268 |         print_log(
269 |             'L1 error: {:.7f}, Relative L1 Error: {:.7f}, L2 error: {:.7f}, Relative L2 error: {:.7f},'.format(l1_error,
270 |                                                                                                                rel_l1_err,
271 |                                                                                                                l2_error,
272 |                                                                                                                relative_l2_error))
273 |         print_log('RMSE: {:.7f}, MSE: {:.7f}, MAE: {:.7f}, MAPE: {:.7f}%'.format(rmse, mse, mae, mape))
274 | 
275 |         for np_data in ['inputs', 'trues', 'preds']:
276 |             np.save(osp.join(folder_path, np_data + '.npy'), vars()[np_data])
277 |         return mse
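
For reference, the error summaries logged by vali()/test() reduce to a few lines of NumPy; a stand-alone sketch (the helper name and the eps guard are ours, not part of the repo):

    import numpy as np

    def summarize_errors(preds, trues, eps=1e-8):
        l1 = np.abs(preds - trues).mean()
        l2 = ((preds - trues) ** 2).mean()
        rel_l1 = (np.abs(preds - trues) / (np.abs(trues) + eps)).mean()
        rel_l2 = l2 / ((trues ** 2).mean() + eps)  # MSE(pred, true) / MSE(true, 0)
        rmse = np.sqrt(l2)
        return l1, rel_l1, l2, rel_l2, rmse

    preds = np.random.rand(2, 10, 1, 64, 64)
    trues = np.random.rand(2, 10, 1, 64, 64)
    print(summarize_errors(preds, trues))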
--------------------------------------------------------------------------------
/figure/mm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/figure/mm.png
--------------------------------------------------------------------------------
/id_estimate.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import numpy as np
 3 | from sklearn.neighbors import NearestNeighbors
 4 | 
 5 | 
 6 | def kNN(X, n_neighbors, n_jobs):
 7 |     X = X.cpu().detach().numpy()
 8 |     neigh = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=n_jobs).fit(X)
 9 |     dists, inds = neigh.kneighbors(X)
10 |     return dists, inds
11 | 
12 | def Levina_Bickel(X, dists, k):
13 |     m = np.log(dists[:, k:k+1] / dists[:, 1:k])
14 |     m = (k-2) / np.sum(m, axis=1)
15 |     dim = np.mean(m)
16 |     return dim
17 | 
18 | def estimate_dimension(latent_embedding, k=1000):
19 |     B, T, C_, H_, W_ = latent_embedding.shape
20 |     latent_embedding = latent_embedding.permute(0, 3, 4, 1, 2)
21 |     X = latent_embedding.reshape(B*H_*W_, T*C_)
22 |     dists, _ = kNN(X, k+1, n_jobs=-1)
23 |     dim_estimate = Levina_Bickel(X, dists, k)
24 |     return dim_estimate
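
A quick usage sketch for the intrinsic-dimension estimator (synthetic latent tensor; a smaller `k` than the default so the toy input has enough neighbours):

    import torch
    from id_estimate import estimate_dimension

    latents = torch.randn(2, 10, 64, 16, 16)   # [B, T, C, H, W], e.g. VQ latents
    print(estimate_dimension(latents, k=20))   # scalar Levina-Bickel ID estimate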
--------------------------------------------------------------------------------
/models/Fourier.py:
--------------------------------------------------------------------------------
 1 | from modules.Fourier_modules import *
 2 | class FPG(nn.Module):
 3 |     def __init__(self,
 4 |                  img_size=224,
 5 |                  patch_size=16,
 6 |                  in_channels=20,
 7 |                  out_channels=20,
 8 |                  input_frames=20,
 9 |                  embed_dim=768,
10 |                  depth=12,
11 |                  mlp_ratio=4.,
12 |                  uniform_drop=False,
13 |                  drop_rate=0.,
14 |                  drop_path_rate=0.,
15 |                  norm_layer=None,
16 |                  dropcls=0.):
17 |         super(FPG, self).__init__()
18 |         self.embed_dim = embed_dim
19 |         self.num_frames = input_frames
20 |         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
21 | 
22 |         self.patch_embed = PatchEmbed(img_size=img_size,
23 |                                       patch_size=patch_size,
24 |                                       in_c=in_channels,
25 |                                       embed_dim=embed_dim)
26 |         num_patches = self.patch_embed.num_patches
27 | 
28 |         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))  # [1, 196, 768]
29 |         self.pos_drop = nn.Dropout(p=drop_rate)
30 | 
31 |         self.h = self.patch_embed.grid_size[0]
32 |         self.w = self.patch_embed.grid_size[1]
33 |         '''
34 |         stochastic depth decay rule
35 |         '''
36 |         if uniform_drop:
37 |             dpr = [drop_path_rate for _ in range(depth)]
38 |         else:
39 |             dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
40 | 
41 |         self.blocks = nn.ModuleList([FourierNetBlock(
42 |             dim=embed_dim,
43 |             mlp_ratio=mlp_ratio,
44 |             drop=drop_rate,
45 |             drop_path=dpr[i],
46 |             act_layer=nn.GELU,
47 |             norm_layer=norm_layer,
48 |             h=self.h,
49 |             w=self.w)
50 |             for i in range(depth)
51 |         ])
52 | 
53 |         self.norm = norm_layer(embed_dim)
54 | 
55 |         self.linearprojection = nn.Sequential(OrderedDict([
56 |             ('transposeconv1', nn.ConvTranspose2d(embed_dim, out_channels * 16, kernel_size=(2, 2), stride=(2, 2))),
57 |             ('act1', nn.Tanh()),
58 |             ('transposeconv2', nn.ConvTranspose2d(out_channels * 16, out_channels * 4, kernel_size=(2, 2), stride=(2, 2))),
59 |             ('act2', nn.Tanh()),
60 |             ('transposeconv3', nn.ConvTranspose2d(out_channels * 4, out_channels, kernel_size=(4, 4), stride=(4, 4)))
61 |         ]))
62 | 
63 |         if dropcls > 0:
64 |             print('dropout %.2f before classifier' % dropcls)
65 |             self.final_dropout = nn.Dropout(p=dropcls)
66 |         else:
67 |             self.final_dropout = nn.Identity()
68 | 
69 |         trunc_normal_(self.pos_embed, std=.02)
70 |         self.apply(self._init_weights)
71 | 
72 | 
73 |     def _init_weights(self, m):
74 |         
if isinstance(m, nn.Linear): 75 | trunc_normal_(m.weight, std=.02) 76 | if isinstance(m, nn.Linear) and m.bias is not None: 77 | nn.init.constant_(m.bias, 0) 78 | elif isinstance(m, nn.LayerNorm): 79 | nn.init.constant_(m.bias, 0) 80 | nn.init.constant_(m.weight, 1.0) 81 | 82 | @torch.jit.ignore 83 | def no_weight_decay(self): 84 | return {'pos_embed', 'cls_token'} 85 | 86 | def forward_features(self, x): 87 | ''' 88 | patch_embed: 89 | [B, T, C, H, W] -> [B*T, num_patches, embed_dim] 90 | ''' 91 | B,T,C,H,W = x.shape 92 | x = x.view(B*T, C, H, W) 93 | x = self.patch_embed(x) 94 | #enc = LearnableFourierPositionalEncoding(768, 768, 64, 768, 10) 95 | #fourierpos_embed = enc(x) 96 | x = self.pos_drop(x + self.pos_embed) 97 | #x = self.pos_drop(x + fourierpos_embed) 98 | 99 | 100 | if not get_fourcastnet_args().checkpoint_activations: 101 | for blk in self.blocks: 102 | x = blk(x) 103 | else: 104 | x = checkpoint_sequential(self.blocks, 4, x) 105 | 106 | x = self.norm(x).transpose(1, 2) 107 | x = torch.reshape(x, [-1, self.embed_dim, self.h, self.w]) 108 | return x 109 | 110 | def forward(self, x): 111 | B, T, C, H, W = x.shape 112 | x = self.forward_features(x) 113 | x = self.final_dropout(x) 114 | x = self.linearprojection(x) 115 | x = x.reshape(B, T, C, H, W) 116 | return x 117 | 118 | -------------------------------------------------------------------------------- /models/PastNet_Model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import * 3 | import logging 4 | from torch import nn 5 | from modules.DiscreteSTModel_modules import * 6 | from modules.Fourier_modules import * 7 | 8 | 9 | 10 | def stride_generator(N, reverse=False): 11 | strides = [1, 2]*10 12 | if reverse: 13 | return list(reversed(strides[:N])) 14 | else: 15 | return strides[:N] 16 | 17 | 18 | class GroupConv2d(nn.Module): 19 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False): 20 | super(GroupConv2d, self).__init__() 21 | self.act_norm = act_norm 22 | if in_channels % groups != 0: 23 | groups = 1 24 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 25 | stride=stride, padding=padding, groups=groups) 26 | self.norm = nn.GroupNorm(groups, out_channels) 27 | self.activate = nn.LeakyReLU(0.2, inplace=True) 28 | 29 | def forward(self, x): 30 | y = self.conv(x) 31 | if self.act_norm: 32 | y = self.activate(self.norm(y)) 33 | return y 34 | 35 | 36 | class Inception(nn.Module): 37 | def __init__(self, C_in, C_hid, C_out, incep_ker=[3, 5, 7, 11], groups=8): 38 | super(Inception, self).__init__() 39 | self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0) 40 | layers = [] 41 | for ker in incep_ker: 42 | layers.append(GroupConv2d(C_hid, C_out, kernel_size=ker, 43 | stride=1, padding=ker//2, groups=groups, act_norm=True)) 44 | self.layers = nn.Sequential(*layers) 45 | 46 | def forward(self, x): 47 | x = self.conv1(x) 48 | y = 0 49 | for layer in self.layers: 50 | y += layer(x) 51 | return y 52 | 53 | class FPG(nn.Module): 54 | def __init__(self, 55 | img_size=224, 56 | patch_size=16, 57 | in_channels=20, 58 | out_channels=20, 59 | input_frames=20, 60 | embed_dim=768, 61 | depth=12, 62 | mlp_ratio=4., 63 | uniform_drop=False, 64 | drop_rate=0., 65 | drop_path_rate=0., 66 | norm_layer=None, 67 | dropcls=0.): 68 | super(FPG, self).__init__() 69 | self.embed_dim = embed_dim 70 | self.num_frames = input_frames 71 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 72 | 
73 | self.patch_embed = PatchEmbed(img_size=img_size, 74 | patch_size=patch_size, 75 | in_c=in_channels, 76 | embed_dim=embed_dim) 77 | num_patches = self.patch_embed.num_patches 78 | 79 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) # [1, 196, 768] 80 | self.pos_drop = nn.Dropout(p=drop_rate) 81 | 82 | self.h = self.patch_embed.grid_size[0] 83 | self.w = self.patch_embed.grid_size[1] 84 | ''' 85 | stochastic depth decay rule 86 | ''' 87 | if uniform_drop: 88 | dpr = [drop_path_rate for _ in range(depth)] 89 | else: 90 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] 91 | 92 | self.blocks = nn.ModuleList([FourierNetBlock( 93 | dim=embed_dim, 94 | mlp_ratio=mlp_ratio, 95 | drop=drop_rate, 96 | drop_path=dpr[i], 97 | act_layer=nn.GELU, 98 | norm_layer=norm_layer, 99 | h=self.h, 100 | w=self.w) 101 | for i in range(depth) 102 | ]) 103 | 104 | self.norm = norm_layer(embed_dim) 105 | 106 | self.linearprojection = nn.Sequential(OrderedDict([ 107 | ('transposeconv1', nn.ConvTranspose2d(embed_dim, out_channels * 16, kernel_size=(2, 2), stride=(2, 2))), 108 | ('act1', nn.Tanh()), 109 | ('transposeconv2', nn.ConvTranspose2d(out_channels * 16, out_channels * 4, kernel_size=(2, 2), stride=(2, 2))), 110 | ('act2', nn.Tanh()), 111 | ('transposeconv3', nn.ConvTranspose2d(out_channels * 4, out_channels, kernel_size=(4, 4), stride=(4, 4))) 112 | ])) 113 | 114 | if dropcls > 0: 115 | print('dropout %.2f before classifier' % dropcls) 116 | self.final_dropout = nn.Dropout(p=dropcls) 117 | else: 118 | self.final_dropout = nn.Identity() 119 | 120 | trunc_normal_(self.pos_embed, std=.02) 121 | self.apply(self._init_weights) 122 | 123 | 124 | def _init_weights(self, m): 125 | if isinstance(m, nn.Linear): 126 | trunc_normal_(m.weight, std=.02) 127 | if isinstance(m, nn.Linear) and m.bias is not None: 128 | nn.init.constant_(m.bias, 0) 129 | elif isinstance(m, nn.LayerNorm): 130 | nn.init.constant_(m.bias, 0) 131 | nn.init.constant_(m.weight, 1.0) 132 | 133 | @torch.jit.ignore 134 | def no_weight_decay(self): 135 | return {'pos_embed', 'cls_token'} 136 | 137 | def forward_features(self, x): 138 | ''' 139 | patch_embed: 140 | [B, T, C, H, W] -> [B*T, num_patches, embed_dim] 141 | ''' 142 | B,T,C,H,W = x.shape 143 | x = x.view(B*T, C, H, W) 144 | x = self.patch_embed(x) 145 | #enc = LearnableFourierPositionalEncoding(768, 768, 64, 768, 10) 146 | #fourierpos_embed = enc(x) 147 | x = self.pos_drop(x + self.pos_embed) 148 | #x = self.pos_drop(x + fourierpos_embed) 149 | 150 | 151 | if not get_fourcastnet_args().checkpoint_activations: 152 | for blk in self.blocks: 153 | x = blk(x) 154 | else: 155 | x = checkpoint_sequential(self.blocks, 4, x) 156 | 157 | x = self.norm(x).transpose(1, 2) 158 | x = torch.reshape(x, [-1, self.embed_dim, self.h, self.w]) 159 | return x 160 | 161 | def forward(self, x): 162 | B, T, C, H, W = x.shape 163 | x = self.forward_features(x) 164 | x = self.final_dropout(x) 165 | x = self.linearprojection(x) 166 | x = x.reshape(B, T, C, H, W) 167 | return x 168 | 169 | class DST(nn.Module): 170 | def __init__(self, 171 | in_channel=1, 172 | num_hiddens=128, 173 | res_layers=2, 174 | res_units=32, 175 | embedding_nums=512, # K 176 | embedding_dim=64, # D 177 | commitment_cost=0.25): 178 | super(DST, self).__init__() 179 | self.embedding_dim = embedding_dim 180 | self.num_embeddings = embedding_nums 181 | self._encoder = Encoder(in_channel, num_hiddens, 182 | res_layers, res_units) # 183 | self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens, 184 | 
out_channels=embedding_dim, 185 | kernel_size=1, 186 | stride=1) 187 | 188 | # code book 189 | self._vq_vae = VectorQuantizerEMA(embedding_nums, 190 | embedding_dim, 191 | commitment_cost, 192 | decay=0.99) 193 | 194 | self._decoder = Decoder(embedding_dim, 195 | num_hiddens, 196 | res_layers, 197 | res_units, 198 | in_channel) 199 | 200 | def forward(self, x): 201 | # input shape : [B, C, W, H] 202 | z = self._encoder(x) # [B, hidden_units, W//4, H//4] 203 | # [B, embedding_dims, W//4, H//4] z -> encoding 204 | z = self._pre_vq_conv(z) 205 | # quantized -> embedding; `quantized` plays the role of the encoder output in VideoGPT 206 | loss, quantized, perplexity, _ = self._vq_vae(z) 207 | x_recon = self._decoder(quantized) 208 | return loss, x_recon, perplexity 209 | 210 | def get_embedding(self, x): 211 | return self._pre_vq_conv(self._encoder(x)) 212 | 213 | def get_quantization(self, x): 214 | z = self._encoder(x) 215 | z = self._pre_vq_conv(z) 216 | _, quantized, _, _ = self._vq_vae(z) 217 | return quantized 218 | 219 | def reconstruct_img_by_embedding(self, embedding): 220 | loss, quantized, perplexity, _ = self._vq_vae(embedding) 221 | return self._decoder(quantized) 222 | 223 | def reconstruct_img(self, q): 224 | return self._decoder(q) 225 | 226 | @property 227 | def pre_vq_conv(self): 228 | return self._pre_vq_conv 229 | 230 | @property 231 | def encoder(self): 232 | return self._encoder 233 | 234 | 235 | class DynamicPropagation(nn.Module): 236 | def __init__(self, channel_in, channel_hid, N_T, incep_ker=[3, 5, 7, 11], groups=8): 237 | super(DynamicPropagation, self).__init__() 238 | 239 | self.N_T = N_T 240 | enc_layers = [Inception( 241 | channel_in, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups)] 242 | for i in range(1, N_T-1): 243 | enc_layers.append(Inception( 244 | channel_hid, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups)) 245 | enc_layers.append(Inception(channel_hid, channel_hid // 246 | 2, channel_hid, incep_ker=incep_ker, groups=groups)) 247 | 248 | dec_layers = [Inception( 249 | channel_hid, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups)] 250 | for i in range(1, N_T-1): 251 | dec_layers.append(Inception( 252 | 2*channel_hid, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups)) 253 | dec_layers.append(Inception(2*channel_hid, channel_hid // 254 | 2, channel_in, incep_ker=incep_ker, groups=groups)) 255 | 256 | self.enc = nn.Sequential(*enc_layers) 257 | self.dec = nn.Sequential(*dec_layers) 258 | 259 | def forward(self, input_state): 260 | B, T, C, H, W = input_state.shape 261 | input_state = input_state.reshape(B, T*C, H, W) 262 | # encoder 263 | skips = [] 264 | hidden_embed = input_state 265 | for i in range(self.N_T): 266 | hidden_embed = self.enc[i](hidden_embed) 267 | if i < self.N_T - 1: 268 | skips.append(hidden_embed) 269 | 270 | # decoder 271 | hidden_embed = self.dec[0](hidden_embed) 272 | for i in range(1, self.N_T): 273 | hidden_embed = self.dec[i](torch.cat([hidden_embed, skips[-i]], dim=1)) 274 | 275 | output_state = hidden_embed.reshape(B, T, C, H, W) 276 | return output_state 277 | 278 | 279 | class PastNetModel(nn.Module): 280 | def __init__(self, 281 | args, 282 | shape_in, 283 | hid_T=256, 284 | N_T=8, 285 | incep_ker=[3, 5, 7, 11], 286 | groups=8, 287 | res_units=64, 288 | res_layers=2, 289 | embedding_nums=512, 290 | embedding_dim=64): 291 | super(PastNetModel, self).__init__() 292 | T, C, H, W = shape_in 293 | self.DST_module = DST(in_channel=C, 294 | res_units=res_units, 295 | res_layers=res_layers, 296 | 
embedding_dim=embedding_dim, 297 | embedding_nums=embedding_nums) 298 | 299 | self.FPG_module = FPG(img_size=64, 300 | patch_size=16, 301 | in_channels=1, 302 | out_channels=1, 303 | embed_dim=128, 304 | input_frames=10, 305 | depth=1, 306 | mlp_ratio=2., 307 | uniform_drop=False, 308 | drop_rate=0., 309 | drop_path_rate=0., 310 | norm_layer=None, 311 | dropcls=0.) 312 | 313 | if args.load_pred_train: 314 | print_log("Load pre-trained model.") 315 | self.DST_module.load_state_dict(torch.load("./models/vqvae.ckpt"), strict=False) 316 | 317 | if args.freeze_vqvae: 318 | print_log("Params of the VQ-VAE are frozen.") 319 | for p in self.DST_module.parameters(): 320 | p.requires_grad = False 321 | self.DynamicPro = DynamicPropagation(T*64, hid_T, N_T, incep_ker, groups) 322 | 323 | def forward(self, input_frames): 324 | B, T, C, H, W = input_frames.shape 325 | pde_features = self.FPG_module(input_frames) 326 | input_features = input_frames.view([B * T, C, H, W]) 327 | encoder_embed = self.DST_module._encoder(input_features) 328 | z = self.DST_module._pre_vq_conv(encoder_embed) 329 | vq_loss, Latent_embed, _, _ = self.DST_module._vq_vae(z) 330 | 331 | _, C_, H_, W_ = Latent_embed.shape 332 | Latent_embed = Latent_embed.reshape(B, T, C_, H_, W_) 333 | 334 | hidden_dim = self.DynamicPro(Latent_embed) 335 | B_, T_, C_, H_, W_ = hidden_dim.shape 336 | hid = hidden_dim.reshape([B_ * T_, C_, H_, W_]) 337 | 338 | predicted_feature = self.DST_module._decoder(hid) 339 | predicted_feature = predicted_feature.reshape([B, T, C, H, W]) + pde_features 340 | 341 | return predicted_feature 342 | 343 | 344 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/models/__init__.py -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/vqvae.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/models/vqvae.ckpt -------------------------------------------------------------------------------- /modules/DiscreteSTModel_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class VectorQuantizer(nn.Module): 7 | def __init__(self, num_embeddings, embedding_dim, commitment_cost): 8 | super(VectorQuantizer, self).__init__() 9 | 10 | self._embedding_dim = embedding_dim # D 11 | self._num_embeddings = num_embeddings # K 12 | 13 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) # K x D codebook 14 | self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings) 15 | self._commitment_cost = commitment_cost 16 | 17 | def forward(self, inputs): 18 | # convert inputs from B, C, H, W -> B, H, W, C 19 | inputs = inputs.permute(0, 2, 3, 1).contiguous() 20 | input_shape = inputs.shape 21 | 22 | # Flatten input 23 | flat_input = inputs.view(-1, 
self._embedding_dim) 24 | 25 | # Calculate distances 26 | distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True) 27 | + torch.sum(self._embedding.weight ** 2, dim=1) 28 | - 2 * torch.matmul(flat_input, self._embedding.weight.t())) 29 | 30 | # Encoding 31 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) 32 | encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) 33 | encodings.scatter_(1, encoding_indices, 1) 34 | 35 | # Quantize and unflatten 36 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) 37 | 38 | # Loss 39 | e_latent_loss = F.mse_loss(quantized.detach(), inputs) 40 | q_latent_loss = F.mse_loss(quantized, inputs.detach()) 41 | loss = q_latent_loss + self._commitment_cost * e_latent_loss 42 | 43 | quantized = inputs + (quantized - inputs).detach() 44 | avg_probs = torch.mean(encodings, dim=0) 45 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 46 | 47 | # convert quantized from B, H, W, C -> B, C, H, W 48 | return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings 49 | 50 | 51 | def lookup(self, x): 52 | embeddings = F.embedding(x, self._embedding.weight) # F.embedding expects the weight tensor, not the nn.Embedding module 53 | return embeddings 54 | 55 | 56 | class VectorQuantizerEMA(nn.Module): 57 | def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay=0.99, epsilon=1e-5): 58 | super(VectorQuantizerEMA, self).__init__() 59 | 60 | self._embedding_dim = embedding_dim 61 | self._num_embeddings = num_embeddings 62 | 63 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) 64 | self._embedding.weight.data.normal_() 65 | self._commitment_cost = commitment_cost 66 | 67 | self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings)) 68 | self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim)) 69 | self._ema_w.data.normal_() 70 | 71 | self._decay = decay 72 | self._epsilon = epsilon 73 | 74 | def forward(self, inputs): 75 | # convert inputs from BCHW -> BHWC 76 | inputs = inputs.permute(0, 2, 3, 1).contiguous() 77 | input_shape = inputs.shape 78 | 79 | # Flatten input 80 | flat_input = inputs.view(-1, self._embedding_dim) 81 | 82 | # Calculate distances 83 | distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True) 84 | + torch.sum(self._embedding.weight ** 2, dim=1) 85 | - 2 * torch.matmul(flat_input, self._embedding.weight.t())) 86 | 87 | # Encoding 88 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) 89 | encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) 90 | encodings.scatter_(1, encoding_indices, 1) 91 | 92 | # Quantize and unflatten 93 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) 94 | 95 | # Use EMA to update the embedding vectors 96 | if self.training: 97 | self._ema_cluster_size = self._ema_cluster_size * self._decay + \ 98 | (1 - self._decay) * torch.sum(encodings, 0) 99 | 100 | # Laplace smoothing of the cluster size 101 | n = torch.sum(self._ema_cluster_size.data) 102 | self._ema_cluster_size = ( 103 | (self._ema_cluster_size + self._epsilon) 104 | / (n + self._num_embeddings * self._epsilon) * n) 105 | 106 | dw = torch.matmul(encodings.t(), flat_input) 107 | self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw) 108 | 109 | self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1)) 110 | 111 | # Loss 112 | e_latent_loss = F.mse_loss(quantized.detach(), inputs) 113 | loss = 
self._commitment_cost * e_latent_loss 114 | 115 | # Straight Through Estimator 116 | quantized = inputs + (quantized - inputs).detach() 117 | avg_probs = torch.mean(encodings, dim=0) 118 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 119 | 120 | # convert quantized from BHWC -> BCHW 121 | return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings 122 | 123 | def lookup(self, x): 124 | embeddings = F.embedding(x, self._embedding.weight) # F.embedding expects the weight tensor, not the nn.Embedding module 125 | return embeddings 126 | 127 | 128 | class Residual(nn.Module): 129 | def __init__(self, in_channels, num_hiddens, num_residual_hiddens): 130 | super(Residual, self).__init__() 131 | self._block = nn.Sequential( 132 | nn.ReLU(True), 133 | nn.Conv2d(in_channels=in_channels, 134 | out_channels=num_residual_hiddens, 135 | kernel_size=3, stride=1, padding=1, bias=False), 136 | nn.BatchNorm2d(num_residual_hiddens), 137 | nn.ReLU(True), 138 | nn.Conv2d(in_channels=num_residual_hiddens, 139 | out_channels=num_hiddens, 140 | kernel_size=1, stride=1, bias=False), 141 | nn.BatchNorm2d(num_hiddens) 142 | ) 143 | 144 | def forward(self, x): 145 | return x + self._block(x) 146 | 147 | 148 | class ResidualStack(nn.Module): 149 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens): 150 | super(ResidualStack, self).__init__() 151 | self._num_residual_layers = num_residual_layers 152 | self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens) 153 | for _ in range(self._num_residual_layers)]) 154 | 155 | def forward(self, x): 156 | for i in range(self._num_residual_layers): 157 | x = self._layers[i](x) 158 | return F.relu(x) 159 | 160 | 161 | class Encoder(nn.Module): 162 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens): 163 | super(Encoder, self).__init__() 164 | 165 | self._conv_1 = nn.Conv2d(in_channels=in_channels, 166 | out_channels=num_hiddens // 2, 167 | kernel_size=4, 168 | stride=2, 169 | padding=1) 170 | self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2, 171 | out_channels=num_hiddens, 172 | kernel_size=4, 173 | stride=2, 174 | padding=1) 175 | self._conv_3 = nn.Conv2d(in_channels=num_hiddens, 176 | out_channels=num_hiddens, 177 | kernel_size=3, 178 | stride=1, padding=1) 179 | self._residual_stack = ResidualStack(in_channels=num_hiddens, 180 | num_hiddens=num_hiddens, 181 | num_residual_layers=num_residual_layers, 182 | num_residual_hiddens=num_residual_hiddens) 183 | 184 | def forward(self, inputs): 185 | # input shape: [B, C, W, H] 186 | x = self._conv_1(inputs) # [B, hidden_units//2 , W//2, H//2] 187 | x = F.relu(x) 188 | 189 | x = self._conv_2(x) # [B, hidden_units, W//4, H//4] 190 | x = F.relu(x) 191 | 192 | x = self._conv_3(x) 193 | return self._residual_stack(x) 194 | 195 | 196 | class Decoder(nn.Module): 197 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens, out_channels): 198 | super(Decoder, self).__init__() 199 | 200 | self._conv_1 = nn.Conv2d(in_channels=in_channels, 201 | out_channels=num_hiddens, 202 | kernel_size=3, 203 | stride=1, padding=1) 204 | 205 | self._residual_stack = ResidualStack(in_channels=num_hiddens, 206 | num_hiddens=num_hiddens, 207 | num_residual_layers=num_residual_layers, 208 | num_residual_hiddens=num_residual_hiddens) 209 | 210 | self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens, 211 | out_channels=num_hiddens // 2, 212 | kernel_size=4, 213 | stride=2, padding=1) 214 | 215 | self._conv_trans_2 = 
nn.ConvTranspose2d(in_channels=num_hiddens // 2, 216 | out_channels=out_channels, 217 | kernel_size=4, 218 | stride=2, padding=1) 219 | 220 | def forward(self, inputs): 221 | x = self._conv_1(inputs) 222 | x = self._residual_stack(x) 223 | x = self._conv_trans_1(x) 224 | x = F.relu(x) 225 | return self._conv_trans_2(x) 226 | -------------------------------------------------------------------------------- /modules/DiscreteSTModel_modules_BN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class VectorQuantizer(nn.Module): 7 | def __init__(self, num_embeddings, embedding_dim, commitment_cost): 8 | super(VectorQuantizer, self).__init__() 9 | 10 | self._embedding_dim = embedding_dim # D 11 | self._num_embeddings = num_embeddings # K 12 | 13 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) 14 | self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings) 15 | self._commitment_cost = commitment_cost 16 | 17 | def forward(self, inputs): 18 | # convert inputs from B, C, H, W -> B, H, W, C 19 | inputs = inputs.permute(0, 2, 3, 1).contiguous() 20 | input_shape = inputs.shape 21 | 22 | # Flatten input 23 | flat_input = inputs.view(-1, self._embedding_dim) 24 | 25 | # Calculate distances 26 | distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True) 27 | + torch.sum(self._embedding.weight ** 2, dim=1) 28 | - 2 * torch.matmul(flat_input, self._embedding.weight.t())) # expanded squared distance: ||z - e||^2 = ||z||^2 + ||e||^2 - 2*z.e 29 | 30 | # Encoding 31 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) 32 | encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) 33 | encodings.scatter_(1, encoding_indices, 1) 34 | 35 | # Quantize and unflatten 36 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) 37 | 38 | # Loss 39 | e_latent_loss = F.mse_loss(quantized.detach(), inputs) 40 | q_latent_loss = F.mse_loss(quantized, inputs.detach()) 41 | loss = q_latent_loss + self._commitment_cost * e_latent_loss 42 | 43 | quantized = inputs + (quantized - inputs).detach() 44 | avg_probs = torch.mean(encodings, dim=0) 45 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 46 | 47 | # convert quantized from B, H, W, C -> B, C, H, W 48 | return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings 49 | 50 | def lookup(self, x): 51 | embeddings = F.embedding(x, self._embedding.weight) # F.embedding expects the weight tensor, not the nn.Embedding module 52 | return embeddings 53 | 54 | 55 | class Residual(nn.Module): 56 | def __init__(self, in_channels, num_hiddens, num_residual_hiddens): 57 | super(Residual, self).__init__() 58 | self._block = nn.Sequential( 59 | nn.ReLU(True), 60 | nn.Conv2d(in_channels=in_channels, 61 | out_channels=num_residual_hiddens, 62 | kernel_size=3, stride=1, padding=1, bias=False), 63 | nn.BatchNorm2d(num_residual_hiddens), 64 | nn.ReLU(True), 65 | nn.Conv2d(in_channels=num_residual_hiddens, 66 | out_channels=num_hiddens, 67 | kernel_size=1, stride=1, bias=False), 68 | nn.BatchNorm2d(num_hiddens) 69 | ) 70 | 71 | def forward(self, x): 72 | return x + self._block(x) 73 | 74 | 75 | class ResidualStack(nn.Module): 76 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens): 77 | super(ResidualStack, self).__init__() 78 | self._num_residual_layers = num_residual_layers 79 | self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens) 80 | for _ in 
range(self._num_residual_layers)]) 81 | 82 | def forward(self, x): 83 | for i in range(self._num_residual_layers): 84 | x = self._layers[i](x) 85 | return F.relu(x) 86 | 87 | 88 | class Encoder(nn.Module): 89 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens): 90 | super(Encoder, self).__init__() 91 | 92 | self._conv_1 = nn.Conv2d(in_channels=in_channels, 93 | out_channels=num_hiddens // 2, 94 | kernel_size=4, 95 | stride=2, 96 | padding=1) 97 | self.norm_1 = nn.BatchNorm2d(num_hiddens // 2) 98 | self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2, 99 | out_channels=num_hiddens, 100 | kernel_size=4, 101 | stride=2, 102 | padding=1) 103 | self.norm_2 = nn.BatchNorm2d(num_hiddens) 104 | self._conv_3 = nn.Conv2d(in_channels=num_hiddens, 105 | out_channels=num_hiddens, 106 | kernel_size=3, 107 | stride=1, padding=1) 108 | self.norm_3 = nn.BatchNorm2d(num_hiddens) 109 | self._residual_stack = ResidualStack(in_channels=num_hiddens, 110 | num_hiddens=num_hiddens, 111 | num_residual_layers=num_residual_layers, 112 | num_residual_hiddens=num_residual_hiddens) 113 | 114 | def forward(self, inputs): 115 | # input shape: [B, C, W, H] 116 | x = self._conv_1(inputs) # [B, hidden_units//2 , W//2, H//2] 117 | x = F.relu(self.norm_1(x)) 118 | 119 | x = self._conv_2(x) # [B, hidden_units, W//4, H//4] 120 | x = F.relu(self.norm_2(x)) 121 | 122 | x = self._conv_3(x) 123 | x = self.norm_3(x) 124 | return self._residual_stack(x) 125 | 126 | 127 | class Decoder(nn.Module): 128 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens, out_channels): 129 | super(Decoder, self).__init__() 130 | 131 | self._conv_1 = nn.Conv2d(in_channels=in_channels, 132 | out_channels=num_hiddens, 133 | kernel_size=3, 134 | stride=1, padding=1) 135 | self._norm_1 = nn.BatchNorm2d(num_hiddens) 136 | 137 | self._residual_stack = ResidualStack(in_channels=num_hiddens, 138 | num_hiddens=num_hiddens, 139 | num_residual_layers=num_residual_layers, 140 | num_residual_hiddens=num_residual_hiddens) 141 | 142 | self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens, 143 | out_channels=num_hiddens // 2, 144 | kernel_size=4, 145 | stride=2, padding=1) 146 | self._norm_2 = nn.BatchNorm2d(num_hiddens // 2) 147 | self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2, 148 | out_channels=out_channels, 149 | kernel_size=4, 150 | stride=2, padding=1) 151 | 152 | def forward(self, inputs): 153 | x = self._conv_1(inputs) 154 | x = self._residual_stack(self._norm_1(x)) 155 | x = self._conv_trans_1(x) 156 | x = F.relu(self._norm_2(x)) 157 | return self._conv_trans_2(x) -------------------------------------------------------------------------------- /modules/DiscreteSTModel_modules_GN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class VectorQuantizer(nn.Module): 7 | def __init__(self, num_embeddings, embedding_dim, commitment_cost): 8 | super(VectorQuantizer, self).__init__() 9 | 10 | self._embedding_dim = embedding_dim # D 11 | self._num_embeddings = num_embeddings # K 12 | 13 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) 14 | self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings) 15 | self._commitment_cost = commitment_cost 16 | 17 | def forward(self, inputs): 18 | # convert inputs from B, C, H, W -> B, H, W, C 19 | inputs = inputs.permute(0, 2, 3, 1).contiguous() 
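        # From here on each spatial position is quantized independently: the BHWC
        # tensor is flattened to [B*H*W, C] rows below, and every row is snapped to
        # its nearest of the K codebook vectors.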
20 | input_shape = inputs.shape 21 | 22 | # Flatten input 23 | flat_input = inputs.view(-1, self._embedding_dim) 24 | 25 | # Calculate distances 26 | distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True) 27 | + torch.sum(self._embedding.weight ** 2, dim=1) 28 | - 2 * torch.matmul(flat_input, self._embedding.weight.t())) # expanded squared distance: ||z - e||^2 = ||z||^2 + ||e||^2 - 2*z.e 29 | 30 | # Encoding 31 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) 32 | encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) 33 | encodings.scatter_(1, encoding_indices, 1) 34 | 35 | # Quantize and unflatten 36 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) 37 | 38 | # Loss 39 | e_latent_loss = F.mse_loss(quantized.detach(), inputs) 40 | q_latent_loss = F.mse_loss(quantized, inputs.detach()) 41 | loss = q_latent_loss + self._commitment_cost * e_latent_loss 42 | 43 | quantized = inputs + (quantized - inputs).detach() 44 | avg_probs = torch.mean(encodings, dim=0) 45 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 46 | 47 | # convert quantized from B, H, W, C -> B, C, H, W 48 | return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings 49 | 50 | def lookup(self, x): 51 | embeddings = F.embedding(x, self._embedding.weight) # F.embedding expects the weight tensor, not the nn.Embedding module 52 | return embeddings 53 | 54 | 55 | class Residual(nn.Module): 56 | def __init__(self, in_channels, num_hiddens, num_residual_hiddens): 57 | super(Residual, self).__init__() 58 | self._block = nn.Sequential( 59 | nn.ReLU(True), 60 | nn.Conv2d(in_channels=in_channels, 61 | out_channels=num_residual_hiddens, 62 | kernel_size=3, stride=1, padding=1, bias=False), 63 | nn.GroupNorm(2, num_residual_hiddens), 64 | nn.ReLU(True), 65 | nn.Conv2d(in_channels=num_residual_hiddens, 66 | out_channels=num_hiddens, 67 | kernel_size=1, stride=1, bias=False), 68 | nn.GroupNorm(2, num_hiddens) 69 | ) 70 | 71 | def forward(self, x): 72 | return x + self._block(x) 73 | 74 | 75 | class ResidualStack(nn.Module): 76 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens): 77 | super(ResidualStack, self).__init__() 78 | self._num_residual_layers = num_residual_layers 79 | self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens) 80 | for _ in range(self._num_residual_layers)]) 81 | 82 | def forward(self, x): 83 | for i in range(self._num_residual_layers): 84 | x = self._layers[i](x) 85 | return F.relu(x) 86 | 87 | 88 | class Encoder(nn.Module): 89 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens): 90 | super(Encoder, self).__init__() 91 | 92 | self._conv_1 = nn.Conv2d(in_channels=in_channels, 93 | out_channels=num_hiddens // 2, 94 | kernel_size=4, 95 | stride=2, 96 | padding=1) 97 | self.norm_1 = nn.GroupNorm(2, num_hiddens // 2) 98 | self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2, 99 | out_channels=num_hiddens, 100 | kernel_size=4, 101 | stride=2, 102 | padding=1) 103 | self.norm_2 = nn.GroupNorm(2, num_hiddens) 104 | self._conv_3 = nn.Conv2d(in_channels=num_hiddens, 105 | out_channels=num_hiddens, 106 | kernel_size=3, 107 | stride=1, padding=1) 108 | self.norm_3 = nn.GroupNorm(2, num_hiddens) 109 | self._residual_stack = ResidualStack(in_channels=num_hiddens, 110 | num_hiddens=num_hiddens, 111 | num_residual_layers=num_residual_layers, 112 | num_residual_hiddens=num_residual_hiddens) 113 | 114 | def forward(self, inputs): 115 | # input shape: [B, C, W, H] 116 | x = self._conv_1(inputs) # [B, hidden_units//2 , W//2, 
H//2] 117 | x = F.relu(self.norm_1(x)) 118 | 119 | x = self._conv_2(x) # [B, hidden_units, W//4, H//4] 120 | x = F.relu(self.norm_2(x)) 121 | 122 | x = self._conv_3(x) 123 | x = self.norm_3(x) 124 | return self._residual_stack(x) 125 | 126 | 127 | class Decoder(nn.Module): 128 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens, out_channels): 129 | super(Decoder, self).__init__() 130 | 131 | self._conv_1 = nn.Conv2d(in_channels=in_channels, 132 | out_channels=num_hiddens, 133 | kernel_size=3, 134 | stride=1, padding=1) 135 | self._norm_1 = nn.GroupNorm(2, num_hiddens) 136 | 137 | self._residual_stack = ResidualStack(in_channels=num_hiddens, 138 | num_hiddens=num_hiddens, 139 | num_residual_layers=num_residual_layers, 140 | num_residual_hiddens=num_residual_hiddens) 141 | 142 | self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens, 143 | out_channels=num_hiddens // 2, 144 | kernel_size=4, 145 | stride=2, padding=1) 146 | self._norm_2 = nn.GroupNorm(2, num_hiddens // 2) 147 | self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2, 148 | out_channels=out_channels, 149 | kernel_size=4, 150 | stride=2, padding=1) 151 | 152 | def forward(self, inputs): 153 | x = self._conv_1(inputs) 154 | x = self._residual_stack(self._norm_1(x)) 155 | x = self._conv_trans_1(x) 156 | x = F.relu(self._norm_2(x)) 157 | return self._conv_trans_2(x) 158 | 159 | -------------------------------------------------------------------------------- /modules/Fourier_modules.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from collections import OrderedDict 3 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 4 | from torch.utils.checkpoint import checkpoint_sequential 5 | from params import get_fourcastnet_args 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torch.fft 10 | import numpy as np 11 | import torch.optim as optimizer 12 | 13 | 14 | class PatchEmbed(nn.Module): 15 | def __init__(self, img_size=None, patch_size=8, in_c=13, embed_dim=768, norm_layer=None): 16 | super(PatchEmbed, self).__init__() 17 | img_size = to_2tuple(img_size) 18 | patch_size = to_2tuple(patch_size) 19 | self.img_size = img_size 20 | self.patch_size = patch_size 21 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) # h, w 22 | self.num_patches = self.grid_size[0] * self.grid_size[1] 23 | self.projection = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size) 24 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 25 | 26 | def forward(self, x): 27 | B, C, H, W = x.shape 28 | assert H == self.img_size[0] and W == self.img_size[1], \ 29 | f"Input size ({H}x{W}) doesn't match the configured img_size ({self.img_size[0]}x{self.img_size[1]})." 
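        # The projection conv and the pos_embed built on top of it assume a fixed
        # patch grid, so inputs must match img_size exactly; nothing here resizes
        # or interpolates.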
30 | ''' 31 | [32, 3, 224, 224] -> [32, 768, 14, 14] -> [32, 768, 196] -> [32, 196, 768] 32 | Conv2D: [32, 3, 224, 224] -> [32, 768, 14, 14] 33 | Flatten: [B, C, H, W] -> [B, C, HW] 34 | Transpose: [B, C, HW] -> [B, HW, C] 35 | ''' 36 | x = self.projection(x).flatten(2).transpose(1, 2) 37 | x = self.norm(x) 38 | return x 39 | 40 | class Mlp(nn.Module): 41 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 42 | super(Mlp, self).__init__() 43 | out_features = out_features or in_features 44 | hidden_features = hidden_features or in_features 45 | self.fc1 = nn.Linear(in_features, hidden_features) 46 | self.act = act_layer() 47 | self.fc2 = nn.Linear(hidden_features, out_features) 48 | self.fc3 = nn.AdaptiveAvgPool1d(out_features) 49 | self.drop = nn.Dropout(drop) 50 | 51 | def forward(self, x): 52 | x = self.fc1(x) 53 | x = self.act(x) 54 | x = self.drop(x) 55 | x = self.fc3(x) 56 | x = self.drop(x) 57 | return x 58 | 59 | class LearnableFourierPositionalEncoding(nn.Module): 60 | def __init__(self, M: int, F_dim: int, H_dim: int, D: int, gamma: float): 61 | 62 | super().__init__() 63 | self.M = M 64 | self.F_dim = F_dim 65 | self.H_dim = H_dim 66 | self.D = D 67 | self.gamma = gamma 68 | 69 | self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False) 70 | self.mlp = nn.Sequential( 71 | nn.Linear(self.F_dim, self.H_dim, bias=True), 72 | nn.GELU(), 73 | nn.Linear(self.H_dim, self.D) 74 | ) 75 | 76 | self.init_weights() 77 | 78 | def init_weights(self): 79 | nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2) 80 | 81 | def forward(self, x): 82 | 83 | B, N, M = x.shape 84 | projected = self.Wr(x) 85 | cosines = torch.cos(projected) 86 | sines = torch.sin(projected) 87 | F = 1 / np.sqrt(self.F_dim) * torch.cat([cosines, sines], dim=-1) 88 | Y = self.mlp(F) 89 | PEx = Y.reshape((B, N, self.D)) 90 | return PEx 91 | 92 | class AdativeFourierNeuralOperator(nn.Module): 93 | def __init__(self, dim, h=14, w=14): 94 | super(AdativeFourierNeuralOperator, self).__init__() 95 | args = get_fourcastnet_args() 96 | self.hidden_size = dim 97 | self.h = h 98 | self.w = w 99 | self.num_blocks = args.fno_blocks 100 | self.block_size = self.hidden_size // self.num_blocks 101 | assert self.hidden_size % self.num_blocks == 0 102 | 103 | self.scale = 0.02 104 | self.w1 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size)) 105 | self.b1 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) 106 | self.w2 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size)) 107 | self.b2 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size)) 108 | self.relu = nn.ReLU() 109 | 110 | if args.fno_bias: 111 | self.bias = nn.Conv1d(self.hidden_size, self.hidden_size, 1) 112 | else: 113 | self.bias = None 114 | 115 | self.softshrink = args.fno_softshrink 116 | 117 | def multiply(self, input, weights): 118 | return torch.einsum('...bd, bdk->...bk', input, weights) 119 | 120 | def forward(self, x): 121 | B, N, C = x.shape 122 | 123 | if self.bias: 124 | bias = self.bias(x.permute(0, 2, 1)).permute(0, 2, 1) 125 | else: 126 | bias = torch.zeros(x.shape, device=x.device) 127 | 128 | x = x.reshape(B, self.h, self.w, C) 129 | x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho') 130 | x = x.reshape(B, x.shape[1], x.shape[2], self.num_blocks, self.block_size) 131 | 132 | x_real = F.relu(self.multiply(x.real, self.w1[0]) - self.multiply(x.imag, 
self.w1[1]) + self.b1[0], inplace=True) 133 | x_imag = F.relu(self.multiply(x.real, self.w1[1]) + self.multiply(x.imag, self.w1[0]) + self.b1[1], inplace=True) 134 | x_real, x_imag = (self.multiply(x_real, self.w2[0]) - self.multiply(x_imag, self.w2[1]) + self.b2[0], 135 | self.multiply(x_real, self.w2[1]) + self.multiply(x_imag, self.w2[0]) + self.b2[1])  # evaluate both halves before rebinding, so the imaginary part uses the pre-update x_real 136 | 137 | x = torch.stack([x_real, x_imag], dim=-1) 138 | x = F.softshrink(x, lambd=self.softshrink) if self.softshrink else x 139 | 140 | x = torch.view_as_complex(x) 141 | x = x.reshape(B, x.shape[1], x.shape[2], self.hidden_size) 142 | x = torch.fft.irfft2(x, s=(self.h, self.w), dim=(1,2), norm='ortho') 143 | x = x.reshape(B, N, C) 144 | 145 | return x + bias 146 | 147 | class FourierNetBlock(nn.Module): 148 | def __init__(self, 149 | dim, 150 | mlp_ratio=4., 151 | drop=0., 152 | drop_path=0., 153 | act_layer=nn.GELU, 154 | norm_layer=nn.LayerNorm, 155 | h=14, 156 | w=14): 157 | super(FourierNetBlock, self).__init__() 158 | args = get_fourcastnet_args() 159 | self.normlayer1 = norm_layer(dim) 160 | self.filter = AdativeFourierNeuralOperator(dim, h=h, w=w) 161 | 162 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 163 | self.normlayer2 = norm_layer(dim) 164 | mlp_hidden_dim = int(dim * mlp_ratio) 165 | self.mlp = Mlp(in_features=dim, 166 | hidden_features=mlp_hidden_dim, 167 | act_layer=act_layer, 168 | drop=drop) 169 | self.double_skip = args.double_skip 170 | 171 | def forward(self, x): 172 | x = x + self.drop_path(self.filter(self.normlayer1(x))) 173 | x = x + self.drop_path(self.mlp(self.normlayer2(x))) 174 | return x -------------------------------------------------------------------------------- /modules/STConvEncoderDecoder_modules.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class BasicConv2d(nn.Module): 5 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, transpose=False, act_norm=False): 6 | super(BasicConv2d, self).__init__() 7 | self.act_norm = act_norm 8 | if not transpose: 9 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) 10 | else: 11 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=stride // 2) 12 | self.norm = nn.GroupNorm(2, out_channels) 13 | self.act = nn.LeakyReLU(0.2, inplace=True) 14 | 15 | def forward(self, x): 16 | y = self.conv(x) 17 | if self.act_norm: 18 | y = self.act(self.norm(y)) 19 | return y 20 | 21 | 22 | class ConvSC(nn.Module): 23 | def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True): 24 | super(ConvSC, self).__init__() 25 | if stride == 1: 26 | transpose = False 27 | self.conv = BasicConv2d(C_in, C_out, kernel_size=3, stride=stride, 28 | padding=1, transpose=transpose, act_norm=act_norm) 29 | 30 | def forward(self, x): 31 | y = self.conv(x) 32 | return y 33 | 34 | 35 | class GroupConv2d(nn.Module): 36 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False): 37 | super(GroupConv2d, self).__init__() 38 | self.act_norm = act_norm 39 | if in_channels % groups != 0: 40 | groups = 1 41 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups) 42 | self.norm = nn.GroupNorm(groups, out_channels) 43 | self.activate = nn.LeakyReLU(0.2, inplace=True) 44 | 45 | def forward(self, x): 46 | y = self.conv(x) 47 | if 
self.act_norm: 48 | y = self.activate(self.norm(y)) 49 | return y 50 | 51 | 52 | class Inception(nn.Module): 53 | def __init__(self, C_in, C_hid, C_out, incep_ker=[3,5,7,11], groups=8): 54 | super(Inception, self).__init__() 55 | self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0) 56 | layers = [] 57 | for ker in incep_ker: 58 | layers.append(GroupConv2d(C_hid, C_out, kernel_size=ker, stride=1, padding=ker//2, groups=groups, act_norm=True)) 59 | self.layers = nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | x = self.conv1(x) 63 | y = 0 64 | for layer in self.layers: 65 | y += layer(x) 66 | return y -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__init__.py -------------------------------------------------------------------------------- /modules/__pycache__/DiscreteSTModel_modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__pycache__/DiscreteSTModel_modules.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/Fourier_modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__pycache__/Fourier_modules.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/Fourier_modules.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__pycache__/Fourier_modules.cpython-39.pyc -------------------------------------------------------------------------------- /modules/__pycache__/STConvEncoderDecoder_modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__pycache__/STConvEncoderDecoder_modules.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /modules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def get_fourcastnet_args(): 5 | parser = argparse.ArgumentParser('FourCastNet training and evaluation script', add_help=False) 6 | parser.add_argument('--batch-size', default=4, 
type=int) 7 | parser.add_argument('--pretrain-epochs', default=80, type=int) 8 | parser.add_argument('--finetune-epochs', default=25, type=int) 9 | 10 | # Model parameters 11 | parser.add_argument('--arch', default='deit_small', type=str, help='Name of model to train') 12 | 13 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)') 14 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', help='Drop path rate (default: 0.1)') 15 | 16 | # Optimizer parameters 17 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw")') 18 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: 1e-8)') 19 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: None, use opt default)') 20 | parser.add_argument('--clip-grad', type=float, default=1, metavar='NORM', help='Clip gradient norm (default: 1)') 21 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') 22 | parser.add_argument('--weight-decay', type=float, default=0.05, help='weight decay (default: 0.05)') 23 | # Learning rate schedule parameters 24 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', help='LR scheduler (default: "cosine")') 25 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', help='learning rate (default: 5e-4)') 26 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages') 27 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)') 28 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)') 29 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)') 30 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 31 | 32 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', help='epoch interval to decay LR') 33 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports') 34 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 35 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10)') 36 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') 37 | 38 | # Augmentation parameters 39 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') 40 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', help='Use AutoAugment policy. "v0" or "original". 
"(default: rand-m9-mstd0.5-inc1)') 41 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 42 | parser.add_argument('--train-interpolation', type=str, default='bicubic', help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 43 | 44 | parser.add_argument('--repeated-aug', action='store_true') 45 | parser.set_defaults(repeated_aug=False) 46 | 47 | # * Random Erase params 48 | parser.add_argument('--reprob', type=float, default=0, metavar='PCT', help='Random erase prob (default: 0)') 49 | parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode (default: "pixel")') 50 | parser.add_argument('--recount', type=int, default=1, help='Random erase count (default: 1)') 51 | parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first (clean) augmentation split') 52 | 53 | # fno parameters 54 | parser.add_argument('--fno-bias', action='store_true') 55 | parser.add_argument('--fno-blocks', type=int, default=4) 56 | parser.add_argument('--fno-softshrink', type=float, default=0.00) 57 | parser.add_argument('--double-skip', action='store_true') 58 | parser.add_argument('--tensorboard-dir', type=str, default=None) 59 | parser.add_argument('--hidden-size', type=int, default=256) 60 | parser.add_argument('--num-layers', type=int, default=12) 61 | parser.add_argument('--checkpoint-activations', action='store_true') 62 | parser.add_argument('--autoresume', action='store_true') 63 | 64 | # attention parameters 65 | parser.add_argument('--num-attention-heads', type=int, default=1) 66 | 67 | # long short parameters 68 | parser.add_argument('--ls-w', type=int, default=4) 69 | parser.add_argument('--ls-dp-rank', type=int, default=16) 70 | 71 | return parser.parse_args() 72 | 73 | 74 | def get_graphcast_args(): 75 | parser = argparse.ArgumentParser('Graphcast training and evaluation script', add_help=False) 76 | parser.add_argument('--batch-size', default=2, type=int) 77 | parser.add_argument('--epochs', default=200, type=int) 78 | 79 | # Model parameters 80 | parser.add_argument('--grid-node-num', default=720 * 1440, type=int, help='The number of grid nodes') 81 | parser.add_argument('--mesh-node-num', default=128 * 320, type=int, help='The number of mesh nodes') 82 | parser.add_argument('--mesh-edge-num', default=217170, type=int, help='The number of mesh edges') 83 | parser.add_argument('--grid2mesh-edge-num', default=1357920, type=int, help='The number of grid-to-mesh edges') 84 | parser.add_argument('--mesh2grid-edge-num', default=2230560, type=int, help='The number of mesh-to-grid edges') 85 | parser.add_argument('--grid-node-dim', default=49, type=int, help='The input dim of grid nodes') 86 | parser.add_argument('--grid-node-pred-dim', default=20, type=int, help='The output dim of grid-node prediction') 87 | parser.add_argument('--mesh-node-dim', default=3, type=int, help='The input dim of mesh nodes') 88 | parser.add_argument('--edge-dim', default=4, type=int, help='The input dim of all edges') 89 | parser.add_argument('--grid-node-embed-dim', default=64, type=int, help='The embedding dim of grid nodes') 90 | parser.add_argument('--mesh-node-embed-dim', default=64, type=int, help='The embedding dim of mesh nodes') 91 | parser.add_argument('--edge-embed-dim', default=8, type=int, help='The embedding dim of edges') 92 | 93 | # Optimizer parameters 94 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw")') 95 | 
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: 1e-8)') 96 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: None, use opt default)') 97 | parser.add_argument('--clip-grad', type=float, default=1, metavar='NORM', help='Clip gradient norm (default: 1)') 98 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') 99 | parser.add_argument('--weight-decay', type=float, default=0.05, help='weight decay (default: 0.05)') 100 | 101 | # Learning rate schedule parameters 102 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', help='LR scheduler (default: "cosine")') 103 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', help='learning rate (default: 5e-4)') 104 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages') 105 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)') 106 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)') 107 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)') 108 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 109 | 110 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', help='epoch interval to decay LR') 111 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports') 112 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 113 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10)') 114 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') 115 | 116 | # Pipeline training parameters 117 | parser.add_argument('--pp_size', type=int, default=8, help='pipeline parallel size') 118 | parser.add_argument('--chunks', type=int, default=1, help='chunk size') 119 | 120 | return parser.parse_args() -------------------------------------------------------------------------------- /train_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from exp_vq import PastNet_exp 3 | 4 | import warnings 5 | warnings.filterwarnings('ignore') 6 | 7 | def create_parser(): 8 | parser = argparse.ArgumentParser() 9 | # Set-up parameters 10 | parser.add_argument('--device', default='cuda', type=str, help='Name of device to use for tensor computations (cuda/cpu)') 11 | parser.add_argument('--res_dir', default='./results', type=str) 12 | parser.add_argument('--ex_name', default='DiscreteSTModel', type=str) 13 | parser.add_argument('--use_gpu', default=True, type=bool) 14 | parser.add_argument('--gpu', default=0, type=int) 15 | parser.add_argument('--seed', default=1, type=int) 16 | parser.add_argument('--load_model', default="", type=str) 17 | 18 | # dataset parameters 19 | parser.add_argument('--batch_size', default=16, type=int, help='Batch size') 20 | 
parser.add_argument('--val_batch_size', default=16, type=int, help='Validation batch size') 21 | parser.add_argument('--data_root', default='./data/') 22 | parser.add_argument('--dataname', default='mmnist', choices=['mmnist', 'taxibj', 'caltech']) 23 | parser.add_argument('--num_workers', default=8, type=int) 24 | 25 | # model parameters 26 | parser.add_argument('--in_shape', default=[10, 1, 64, 64], type=int, nargs='*') # [10, 1, 64, 64] for mmnist, [4, 2, 32, 32] for taxibj 27 | parser.add_argument('--hid_T', default=256, type=int) 28 | parser.add_argument('--N_T', default=8, type=int) 29 | parser.add_argument('--groups', default=4, type=int) 30 | parser.add_argument('--res_units', default=32, type=int) 31 | parser.add_argument('--res_layers', default=4, type=int) 32 | parser.add_argument('--K', default=512, type=int) 33 | parser.add_argument('--D', default=64, type=int) 34 | 35 | # Training parameters 36 | parser.add_argument('--epochs', default=201, type=int) 37 | parser.add_argument('--log_step', default=1, type=int) 38 | parser.add_argument('--lr', default=0.01, type=float, help='Learning rate') 39 | parser.add_argument('--load_pred_train', default=0, type=int, help='Load the pre-trained VQ-VAE checkpoint (0: no, 1: yes)') 40 | parser.add_argument('--freeze_vqvae', default=0, type=int) 41 | parser.add_argument('--theta', default=1, type=float) 42 | return parser 43 | 44 | 45 | if __name__ == '__main__': 46 | args = create_parser().parse_args() 47 | config = args.__dict__ 48 | 49 | exp = PastNet_exp(args) 50 | print('>>>>>>>>>>>>>>>>>>>>>>>>>>>> start <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<') 51 | exp.train(args) 52 | print('>>>>>>>>>>>>>>>>>>>>>>>>>>>> testing <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<') 53 | mse = exp.test(args) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import random 5 | import numpy as np 6 | import torch.backends.cudnn as cudnn 7 | 8 | def set_seed(seed): 9 | random.seed(seed) 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | cudnn.deterministic = True 13 | 14 | def print_log(message): 15 | print(message) 16 | logging.info(message) 17 | 18 | def output_namespace(namespace): 19 | configs = namespace.__dict__ 20 | message = '' 21 | for k, v in configs.items(): 22 | message += '\n' + k + ': \t' + str(v) + '\t' 23 | return message 24 | 25 | def check_dir(path): 26 | if not os.path.exists(path): 27 | os.makedirs(path) --------------------------------------------------------------------------------
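A minimal end-to-end smoke test of how the pieces above fit together (a sketch, not a file from the repository). It assumes the repo root is on PYTHONPATH so that models.PastNet_Model, params, and utils import cleanly, and it fabricates the argparse namespace from which PastNetModel reads load_pred_train and freeze_vqvae; shapes follow the Moving MNIST defaults in train_model.py (--in_shape 10 1 64 64). Note that params.get_fourcastnet_args() parses sys.argv when the Fourier modules are constructed, so run this without extra command-line flags.

import argparse
import torch
from models.PastNet_Model import PastNetModel

# Stand-in for the namespace built by train_model.create_parser(); only the
# two attributes PastNetModel actually reads need to exist here.
args = argparse.Namespace(load_pred_train=0, freeze_vqvae=0)

# shape_in = (T, C, H, W): 10 single-channel 64x64 frames per clip.
model = PastNetModel(args, shape_in=(10, 1, 64, 64),
                     hid_T=256, N_T=8, res_units=32, res_layers=2,
                     embedding_nums=512, embedding_dim=64)
model.eval()  # keep the EMA codebook frozen during this smoke test

frames = torch.randn(2, 10, 1, 64, 64)  # [B, T, C, H, W]
with torch.no_grad():
    pred = model(frames)  # FPG prior + quantized-latent dynamics, same shape as input
print(pred.shape)  # torch.Size([2, 10, 1, 64, 64])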