├── .github
│   └── workflows
│       └── python-app.yml
├── API
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── dataloader_caltech.cpython-38.pyc
│   │   ├── dataloader_moving_mnist.cpython-38.pyc
│   │   ├── dataloader_taxibj.cpython-38.pyc
│   │   ├── metrics.cpython-38.pyc
│   │   └── recorder.cpython-38.pyc
│   ├── dataloader.py
│   ├── dataloader_caltech.py
│   ├── dataloader_caltech0.py
│   ├── dataloader_moving_mnist.py
│   ├── dataloader_moving_mnist_v2.py
│   ├── dataloader_sevir.py
│   ├── dataloader_taxibj.py
│   ├── metrics.py
│   └── recorder.py
├── DiscreteSTModel
│   └── log.log
├── README.md
├── __pycache__
│   ├── exp_vq.cpython-38.pyc
│   ├── params.cpython-38.pyc
│   └── utils.cpython-38.pyc
├── data
│   └── moving_mnist
│       └── download_mmnist.sh
├── exp_vq.py
├── figure
│   └── mm.png
├── id_estimate.py
├── log.log
├── models
│   ├── Fourier.py
│   ├── PastNet_Model.py
│   ├── __init__.py
│   ├── __pycache__
│   │   └── __init__.cpython-38.pyc
│   └── vqvae.ckpt
├── modules
│   ├── DiscreteSTModel_modules.py
│   ├── DiscreteSTModel_modules_BN.py
│   ├── DiscreteSTModel_modules_GN.py
│   ├── Fourier_modules.py
│   ├── STConvEncoderDecoder_modules.py
│   ├── __init__.py
│   └── __pycache__
│       ├── DiscreteSTModel_modules.cpython-38.pyc
│       ├── Fourier_modules.cpython-38.pyc
│       ├── Fourier_modules.cpython-39.pyc
│       ├── STConvEncoderDecoder_modules.cpython-38.pyc
│       ├── __init__.cpython-38.pyc
│       └── __init__.cpython-39.pyc
├── params.py
├── results
│   └── DiscreteSTModel
│       └── log.log
├── train_model.py
└── utils.py
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Python application
5 |
6 | on:
7 | push:
8 | branches: [ "main" ]
9 | pull_request:
10 | branches: [ "main" ]
11 |
12 | permissions:
13 | contents: read
14 |
15 | jobs:
16 | build:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v3
22 | - name: Set up Python 3.10
23 | uses: actions/setup-python@v3
24 | with:
25 | python-version: "3.10"
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install flake8 pytest
30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
31 | - name: Lint with flake8
32 | run: |
33 | # stop the build if there are Python syntax errors or undefined names
34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
36 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
37 | - name: Test with pytest
38 | run: |
39 | pytest
40 |
--------------------------------------------------------------------------------
/API/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataloader import load_data
2 | from .metrics import metric
3 | from .recorder import Recorder
--------------------------------------------------------------------------------
/API/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/API/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/API/__pycache__/dataloader_caltech.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/API/__pycache__/dataloader_caltech.cpython-38.pyc
--------------------------------------------------------------------------------
/API/__pycache__/dataloader_moving_mnist.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/API/__pycache__/dataloader_moving_mnist.cpython-38.pyc
--------------------------------------------------------------------------------
/API/__pycache__/dataloader_taxibj.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/API/__pycache__/dataloader_taxibj.cpython-38.pyc
--------------------------------------------------------------------------------
/API/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/API/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/API/__pycache__/recorder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/API/__pycache__/recorder.cpython-38.pyc
--------------------------------------------------------------------------------
/API/dataloader.py:
--------------------------------------------------------------------------------
1 | from .dataloader_taxibj import load_data as load_taxibj
2 | from .dataloader_moving_mnist import load_data as load_mmnist
3 | from .dataloader_sevir import load_data as load_sevir
4 |
5 | def load_data(dataname, batch_size, val_batch_size, data_root, num_workers, **kwargs):
6 |     if dataname == 'taxibj':
7 |         return load_taxibj(batch_size, val_batch_size, data_root, num_workers)
8 |     elif dataname == 'mmnist':
9 |         return load_mmnist(batch_size, val_batch_size, data_root, num_workers)
10 |     elif dataname == 'sevir':
11 |         return load_sevir(batch_size, val_batch_size, data_root, num_workers)
12 |     else:
13 |         raise ValueError(f'unknown dataname: {dataname}')
--------------------------------------------------------------------------------
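
A minimal usage sketch for the dispatcher above (paths and batch sizes are placeholders; `data_root` must contain the dataset-specific files each loader expects):

```python
from API import load_data

# 'mmnist' routes to API/dataloader_moving_mnist.load_data
train_loader, vali_loader, test_loader, mean, std = load_data(
    dataname='mmnist', batch_size=16, val_batch_size=16,
    data_root='./data/', num_workers=4)
```
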
/API/dataloader_caltech.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | import os
3 | import cv2
4 | import numpy as np
5 | import torch
6 | import bisect
7 | from torch.utils.data import Dataset
8 |
9 | split_string = "\xFF\xD8\xFF\xE0\x00\x10\x4A\x46\x49\x46"
10 |
11 | def read_seq(path):
12 | f = open(path, 'rb+')
13 | string = f.read().decode('latin-1')
14 | str_list = string.split(split_string)
15 | # print(len(str_list))
16 | f.close()
17 | return str_list[1:]
18 |
19 | def seq_to_images(bytes_string):
20 | res = split_string.encode('latin-1') + bytes_string.encode('latin-1')
21 | img = cv2.imdecode(np.frombuffer(res, np.uint8), cv2.IMREAD_COLOR)
22 | return img / 255.0
23 |
24 | def load_caltech(root):
25 | file_list = [file for file in os.listdir(root) if file.split('.')[-1] == "seq"]
26 | print(file_list)
27 | for file in file_list[:1]:
28 | path = os.path.join(root, file)
29 |         str_list = read_seq(path)  # read_seq already drops the header chunk
30 |         imgs = np.zeros([len(str_list), 480, 640, 3])
31 |         idx = 0
32 |         for chunk in str_list:
33 |             imgs[idx] = seq_to_images(chunk)
34 | idx += 1
35 | return imgs
36 |
37 | class Caltech(Dataset):
38 | @staticmethod
39 | def cumsum(sequence):
40 | r, s = [], 0
41 | for e in sequence:
42 | l = len(e)
43 | r.append(l + s)
44 | s += l
45 | return r
46 |
47 | def __init__(self, root, is_train=True, file_list=['V001.seq'], n_frames_input=4, n_frames_output=1):
48 | super().__init__()
49 | datasets = []
50 | for file in file_list:
51 | datasets.append(SingleCaltech(os.path.join(root, file), is_train=is_train, n_frames_input=n_frames_input, n_frames_output=n_frames_output))
52 | assert len(datasets) > 0, 'datasets should not be an empty iterable'
53 | self.datasets = list(datasets[:1])
54 | self.cumulative_sizes = self.cumsum(self.datasets)
55 |
56 | self.mean = 0
57 | self.std = 1
58 |
59 | def __len__(self):
60 | return self.cumulative_sizes[-1]
61 |
62 | def __getitem__(self, idx):
63 | if idx < 0:
64 | if -idx > len(self):
65 | raise ValueError("absolute value of index should not exceed dataset length")
66 | idx = len(self) + idx
67 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
68 | if dataset_idx == 0:
69 | sample_idx = idx
70 | else:
71 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
72 | return self.datasets[dataset_idx][sample_idx]
73 |
74 | @property
75 | def cummulative_sizes(self):
76 | warnings.warn("cummulative_sizes attribute is renamed to "
77 | "cumulative_sizes", DeprecationWarning, stacklevel=2)
78 | return self.cumulative_sizes
79 |
80 |
81 | class SingleCaltech(Dataset):
82 | def __init__(self, root, is_train=True, n_frames_input=4, n_frames_output=1):
83 | super().__init__()
84 | self.root = root
85 | if is_train:
86 | self.length = 100
87 | else:
88 | self.length = 50
89 |
90 | self.input_length = n_frames_input
91 | self.output_length = n_frames_output
92 |
93 | self.sequence = None
94 | self.get_current_data()
95 |
96 |
97 |
98 | def get_current_data(self):
99 | str_list = read_seq(self.root)
100 | if self.length == 100:
101 | str_list = str_list[:104]
102 | self.sequence = np.zeros([104, 480, 640, 3])
103 | else:
104 |             str_list = str_list[104:158]  # 50 test samples of 5 frames each need 54 frames
105 | self.sequence = np.zeros([54, 480, 640, 3])
106 |
107 |         for i, chunk in enumerate(str_list):
108 |             self.sequence[i] = seq_to_images(chunk)
109 |
110 | def __getitem__(self, index):
111 | input = self.sequence[index: index + self.input_length]
112 | input = np.transpose(input, (0, 3, 1, 2))
113 | output = self.sequence[index + self.input_length: index + self.input_length + self.output_length]
114 | output = np.transpose(output, (0, 3, 1, 2))
115 | input = torch.from_numpy(input).contiguous().float()
116 | output = torch.from_numpy(output).contiguous().float()
117 | return input, output
118 |
119 | def __len__(self):
120 | return self.length
121 |
122 |
123 | def load_data(
124 | batch_size, val_batch_size,
125 | data_root, num_workers):
126 |
127 | file_list = [file for file in os.listdir(data_root) if file.split('.')[-1] == "seq"]
128 | # print(data_root)
129 | train_set = Caltech(root=data_root, is_train=True,
130 | n_frames_input=4, n_frames_output=1, file_list=file_list)
131 | test_set = Caltech(root=data_root, is_train=False,
132 | n_frames_input=4, n_frames_output=1, file_list=file_list)
133 |
134 | dataloader_train = torch.utils.data.DataLoader(
135 | train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)
136 | dataloader_validation = torch.utils.data.DataLoader(
137 | test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
138 | dataloader_test = torch.utils.data.DataLoader(
139 | test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
140 |
141 | mean, std = 0, 1
142 | return dataloader_train, dataloader_validation, dataloader_test, mean, std
143 |
144 |
145 | if __name__ == "__main__":
146 | # data = load_caltech("/home/pan/workspace/simvp/SimVP-Simpler-yet-Better-Video-Prediction-master/data/caltech/USA/set01")
147 | file_list = [file for file in os.listdir("/home/pan/workspace/simvp/SimVP/data/caltech/USA/set01") if file.split('.')[-1] == "seq"]
148 | dataset = Caltech(root="/home/pan/workspace/simvp/SimVP/data/caltech/USA/set01", is_train=False, file_list=file_list)
149 | dataloader = torch.utils.data.DataLoader(
150 | dataset, batch_size=16, shuffle=True)
151 |
152 | for input, output in dataloader:
153 | print(input.shape, output.shape)
--------------------------------------------------------------------------------
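
A toy check of the `.seq` parsing convention used by `read_seq` above: a `.seq` file is a byte stream of JPEG frames, each beginning with the JFIF magic in `split_string`, so splitting the latin-1-decoded stream on that magic yields one chunk per frame plus a header chunk before the first frame. Payloads here are fake strings, not real frames:

```python
split_string = "\xFF\xD8\xFF\xE0\x00\x10\x4A\x46\x49\x46"  # JFIF magic bytes
stream = "HEADER" + split_string + "frame-1" + split_string + "frame-2"
print(stream.split(split_string))  # ['HEADER', 'frame-1', 'frame-2']
```
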
/API/dataloader_caltech0.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 |
8 | split_string = "\xFF\xD8\xFF\xE0\x00\x10\x4A\x46\x49\x46"
9 |
10 | def read_seq(path):
11 | f = open(path, 'rb+')
12 | string = f.read().decode('latin-1')
13 | str_list = string.split(split_string)
14 | print(len(str_list))
15 | f.close()
16 | return str_list, len(str_list)
17 |
18 | def seq_to_images(bytes_string):
19 | res = split_string.encode('latin-1') + bytes_string.encode('latin-1')
20 | img = cv2.imdecode(np.frombuffer(res, np.uint8), cv2.IMREAD_COLOR)
21 | return img / 255.0
22 |
23 | def load_caltech(root):
24 | file_list = [file for file in os.listdir(root) if file.split('.')[-1] == "seq"]
25 | print(file_list)
26 | for file in file_list[:1]:
27 | path = os.path.join(root, file)
28 |         str_list, length = read_seq(path)
29 |         imgs = np.zeros([length - 1, 480, 640, 3])
30 |         idx = 0
31 |         for chunk in str_list[1:]:
32 |             imgs[idx] = seq_to_images(chunk)
33 | idx += 1
34 | return imgs.transpose(0, 3, 1, 2)
35 |
36 | class Caltech(Dataset):
37 | def __init__(self, root, is_train=True, n_frames_input=4, n_frames_output=1):
38 | super().__init__()
39 | self.root = root
40 | print("loading .seq file list")
41 | self.file_list = [file for file in os.listdir(self.root) if file.split('.')[-1] == "seq"]
42 |
43 | if is_train:
44 | self.file_list = self.file_list[:-1]
45 | else:
46 | self.file_list = self.file_list[-1:]
47 |
48 | print("loading file list done, file list: ", self.file_list)
49 | self.length = 0
50 |
51 | self.input_length = n_frames_input
52 | self.output_length = n_frames_output
53 |
54 | self.current_seq = None
55 | self.current_length = 0
56 | self.current_file_index = 0
57 |
58 | self.get_next = True
59 |
60 | self.get_total_len(root)
61 | self.get_current_data()
62 |
63 |
64 | def get_total_len(self, root):
65 | print("calculating total length")
66 | count = 0
67 | for file in self.file_list:
68 | path = os.path.join(root, file)
69 |             _, length = read_seq(path)
70 |             count += (length - 5)
71 | self.length = count
72 | print("calculating total length done, total length: ", self.length)
73 |
74 | def get_current_data(self):
75 | print("getting current sequence")
76 | if self.current_file_index >= len(self.file_list):
77 | self.get_next = False
78 | return
79 | current_file = os.path.join(self.root, self.file_list[self.current_file_index])
80 | str_list, length = read_seq(current_file)
81 | self.current_length = length - 5
82 | self.current_seq = np.zeros([length - 1, 480, 640, 3])
83 |         for i, chunk in enumerate(str_list[1:]):
84 |             self.current_seq[i] = seq_to_images(chunk)
85 | print("getting current sequence done, the shape:", self.current_seq.shape)
86 |
87 | def get_next_seq(self):
88 | print("getting next sequence")
89 | self.current_file_index += 1
90 | self.get_current_data()
91 | self.get_next = False
92 |
93 | def __getitem__(self, index):
94 | if index >= self.current_length:
95 | self.get_next = True
96 | if self.get_next:
97 | self.get_next_seq()
98 | input = self.current_seq[index: index + self.input_length]
99 | output = self.current_seq[index + self.input_length: index + self.input_length + self.output_length]
100 | input = torch.from_numpy(input).contiguous().float()
101 | output = torch.from_numpy(output).contiguous().float()
102 | return input, output
103 |
104 | def __len__(self):
105 | return self.current_length
106 |
107 |
108 | def load_data(
109 | batch_size, val_batch_size,
110 | data_root, num_workers):
111 |
112 |     train_set = Caltech(root=data_root, is_train=True,
113 |                         n_frames_input=10, n_frames_output=10)
114 |     test_set = Caltech(root=data_root, is_train=False,
115 |                         n_frames_input=10, n_frames_output=10)
116 |
117 | dataloader_train = torch.utils.data.DataLoader(
118 | train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)
119 | dataloader_validation = torch.utils.data.DataLoader(
120 | test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
121 | dataloader_test = torch.utils.data.DataLoader(
122 | test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
123 |
124 | mean, std = 0, 1
125 | return dataloader_train, dataloader_validation, dataloader_test, mean, std
126 |
127 |
128 | if __name__ == "__main__":
129 | # data = load_caltech("/home/pan/workspace/simvp/SimVP-Simpler-yet-Better-Video-Prediction-master/data/caltech/USA/set01")
130 | dataset = Caltech(root="/home/pan/workspace/simvp/SimVP-Simpler-yet-Better-Video-Prediction-master/data/caltech/USA/set01")
131 | dataloader = torch.utils.data.DataLoader(
132 | dataset, batch_size=16, shuffle=True, drop_last=True)
133 |
134 | for input, output in dataloader:
135 | print(input.shape, output.shape)
--------------------------------------------------------------------------------
/API/dataloader_moving_mnist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gzip
3 | import random
4 | import numpy as np
5 | import torch
6 | import torch.utils.data as data
7 |
8 |
9 | def load_mnist(root):
10 | # Load MNIST dataset for generating training data.
11 | path = os.path.join(root, 'moving_mnist/train-images-idx3-ubyte.gz')
12 | with gzip.open(path, 'rb') as f:
13 | mnist = np.frombuffer(f.read(), np.uint8, offset=16)
14 | mnist = mnist.reshape(-1, 28, 28)
15 | return mnist
16 |
17 |
18 | def load_fixed_set(root):
19 | # Load the fixed dataset
20 | filename = 'moving_mnist/mnist_test_seq.npy'
21 | path = os.path.join(root, filename)
22 | dataset = np.load(path)
23 | dataset = dataset[..., np.newaxis]
24 | return dataset
25 |
26 |
27 | class MovingMNIST(data.Dataset):
28 | def __init__(self, root, is_train=True, n_frames_input=10, n_frames_output=10, num_objects=[2],
29 | transform=None):
30 | super(MovingMNIST, self).__init__()
31 |
32 | self.dataset = None
33 | if is_train:
34 | self.mnist = load_mnist(root)
35 | else:
36 | if num_objects[0] != 2:
37 | self.mnist = load_mnist(root)
38 | else:
39 | self.dataset = load_fixed_set(root)
40 | self.length = int(1e4) if self.dataset is None else self.dataset.shape[1]
41 |
42 | self.is_train = is_train
43 | self.num_objects = num_objects
44 | self.n_frames_input = n_frames_input
45 | self.n_frames_output = n_frames_output
46 | self.n_frames_total = self.n_frames_input + self.n_frames_output
47 | self.transform = transform
48 | # For generating data
49 | self.image_size_ = 64
50 | self.digit_size_ = 28
51 | self.step_length_ = 0.1
52 |
53 | self.mean = 0
54 | self.std = 1
55 |
56 | def get_random_trajectory(self, seq_length):
57 | ''' Generate a random sequence of a MNIST digit '''
58 | canvas_size = self.image_size_ - self.digit_size_
59 | x = random.random()
60 | y = random.random()
61 | theta = random.random() * 2 * np.pi
62 | v_y = np.sin(theta)
63 | v_x = np.cos(theta)
64 |
65 | start_y = np.zeros(seq_length)
66 | start_x = np.zeros(seq_length)
67 | for i in range(seq_length):
68 | # Take a step along velocity.
69 | y += v_y * self.step_length_
70 | x += v_x * self.step_length_
71 |
72 | # Bounce off edges.
73 | if x <= 0:
74 | x = 0
75 | v_x = -v_x
76 | if x >= 1.0:
77 | x = 1.0
78 | v_x = -v_x
79 | if y <= 0:
80 | y = 0
81 | v_y = -v_y
82 | if y >= 1.0:
83 | y = 1.0
84 | v_y = -v_y
85 | start_y[i] = y
86 | start_x[i] = x
87 |
88 | # Scale to the size of the canvas.
89 | start_y = (canvas_size * start_y).astype(np.int32)
90 | start_x = (canvas_size * start_x).astype(np.int32)
91 | return start_y, start_x
92 |
93 | def generate_moving_mnist(self, num_digits=2):
94 | '''
95 | Get random trajectories for the digits and generate a video.
96 | '''
97 | data = np.zeros((self.n_frames_total, self.image_size_,
98 | self.image_size_), dtype=np.float32)
99 | for n in range(num_digits):
100 | # Trajectory
101 | start_y, start_x = self.get_random_trajectory(self.n_frames_total)
102 | ind = random.randint(0, self.mnist.shape[0] - 1)
103 | digit_image = self.mnist[ind]
104 | for i in range(self.n_frames_total):
105 | top = start_y[i]
106 | left = start_x[i]
107 | bottom = top + self.digit_size_
108 | right = left + self.digit_size_
109 | # Draw digit
110 | data[i, top:bottom, left:right] = np.maximum(
111 | data[i, top:bottom, left:right], digit_image)
112 |
113 | data = data[..., np.newaxis]
114 | return data
115 |
116 | def __getitem__(self, idx):
117 | length = self.n_frames_input + self.n_frames_output
118 | if self.is_train or self.num_objects[0] != 2:
119 | # Sample number of objects
120 | num_digits = random.choice(self.num_objects)
121 | # Generate data on the fly
122 | images = self.generate_moving_mnist(num_digits)
123 | else:
124 | images = self.dataset[:, idx, ...]
125 |
126 | r = 1
127 | w = int(64 / r)
128 | images = images.reshape((length, w, r, w, r)).transpose(
129 | 0, 2, 4, 1, 3).reshape((length, r * r, w, w))
130 |
131 | input = images[:self.n_frames_input]
132 | if self.n_frames_output > 0:
133 | output = images[self.n_frames_input:length]
134 | else:
135 |             output = np.empty((0, r * r, w, w), dtype=images.dtype)
136 |
137 | output = torch.from_numpy(output / 255.0).contiguous().float()
138 | input = torch.from_numpy(input / 255.0).contiguous().float()
139 | return input, output
140 |
141 | def __len__(self):
142 | return self.length
143 |
144 |
145 | def load_data(
146 | batch_size, val_batch_size,
147 | data_root, num_workers):
148 |
149 | train_set = MovingMNIST(root=data_root, is_train=True,
150 | n_frames_input=10, n_frames_output=10, num_objects=[2])
151 | test_set = MovingMNIST(root=data_root, is_train=False,
152 | n_frames_input=10, n_frames_output=10, num_objects=[2])
153 |
154 | dataloader_train = torch.utils.data.DataLoader(
155 | train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)
156 | dataloader_validation = torch.utils.data.DataLoader(
157 | test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
158 | dataloader_test = torch.utils.data.DataLoader(
159 | test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
160 |
161 | mean, std = 0, 1
162 | return dataloader_train, dataloader_validation, dataloader_test, mean, std
163 |
--------------------------------------------------------------------------------
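
A smoke test for this loader, assuming the two files fetched by `data/moving_mnist/download_mmnist.sh` sit under `./data/moving_mnist/`:

```python
from API.dataloader_moving_mnist import load_data

train_loader, vali_loader, test_loader, mean, std = load_data(
    batch_size=16, val_batch_size=16, data_root='./data/', num_workers=0)
x, y = next(iter(train_loader))
print(x.shape, y.shape)  # torch.Size([16, 10, 1, 64, 64]) for both
```
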
/API/dataloader_moving_mnist_v2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.utils.data as data
4 |
5 |
6 | class MovingMnistSequence(data.Dataset):
7 | def __init__(self, train=True, shuffle=True, root='./data', transform=None):
8 | super().__init__()
9 | if train:
10 | npz = 'mnist_train.npz'
11 | self.data = np.load(f'{root}/{npz}')['input_raw_data']
12 |         else:
13 |             npz = 'mnist_train.npz'  # no separate test file: reuse the first 10k training sequences
14 |             self.data = np.load(f'{root}/{npz}')['input_raw_data'][:10000]
15 |
16 |
17 | self.transform = transform
18 | self.data = self.data.transpose(0, 2, 3, 1)
19 |
20 | def __len__(self):
21 | return self.data.shape[0] // 20
22 |
23 | def __getitem__(self, index):
24 | imgs = self.data[index * 20: (index + 1) * 20]
25 | imgs_tensor = torch.zeros([20, 1, 64, 64])
26 | if self.transform is not None:
27 | for i in range(imgs.shape[0]):
28 | imgs_tensor[i] = self.transform(imgs[i])
29 | return imgs_tensor
30 |
31 |
32 | def load_data(
33 | batch_size, val_batch_size,
34 | data_root, num_workers):
35 |
36 |     train_set = MovingMnistSequence(root=data_root, train=True,
37 |                                     transform=None)
38 |     test_set = MovingMnistSequence(root=data_root, train=False,
39 |                                     transform=None)
40 |
41 | dataloader_train = torch.utils.data.DataLoader(
42 | train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)
43 | dataloader_validation = torch.utils.data.DataLoader(
44 | test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
45 | dataloader_test = torch.utils.data.DataLoader(
46 | test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
47 |
48 | mean, std = 0, 1
49 | return dataloader_train, dataloader_validation, dataloader_test, mean, std
50 |
--------------------------------------------------------------------------------
/API/dataloader_sevir.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import numpy as np
3 | import random
4 | from torch.utils.data import Dataset
5 | from torch.utils.data import DataLoader
6 | import torchvision.transforms as transforms
7 | import matplotlib.pyplot as plt
8 | import torch
9 |
10 | class VilDataset(Dataset):
11 | def __init__(self, train=True, root='./data', transform=None):
12 | super().__init__()
13 | if train:
14 | npy = ['SEVIR_IR069_STORMEVENTS_2018_0101_0630.npy', 'SEVIR_IR069_STORMEVENTS_2018_0701_1231.npy']
15 | else:
16 | npy = ['SEVIR_IR069_RANDOMEVENTS_2018_0101_0430.npy']
17 |
18 | data = []
19 | for file in npy:
20 | data.append(np.load(f'{root}/{file}'))
21 |
22 | self.data = np.concatenate(data)
23 | #N, L, H, W = self.data.shape
24 | # self.data = self.data.reshape([N L, H, W])
25 | self.transform = transform
26 | self.mean = 0
27 | self.std = 1
28 |
29 | def __len__(self):
30 | return self.data.shape[0]
31 |
32 | def __getitem__(self, index):
33 | img = self.data[index].reshape(20, 1, 128, 128)
34 | if self.transform:
35 | img = self.transform(img)
36 |
37 |         input_img = torch.from_numpy(img[:10]).contiguous().float()   # first 10 frames in
38 |         output_img = torch.from_numpy(img[10:]).contiguous().float()  # last 10 frames out
39 |         return input_img, output_img
46 |
47 |
48 | def load_data(batch_size, val_batch_size,
49 | data_root, num_workers):
50 |     train_set = VilDataset(train=True, root=data_root, transform=None)
51 |     test_set = VilDataset(train=False, root=data_root, transform=None)
52 |
53 | dataloader_train = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True,
54 | num_workers=num_workers)
55 | dataloader_validation = DataLoader(test_set, batch_size=val_batch_size, shuffle=False,
56 | pin_memory=True, num_workers=num_workers)
57 | dataloader_test = DataLoader(test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True,
58 | num_workers=num_workers)
59 | mean, std = 0, 1
60 |
61 | return dataloader_train, dataloader_validation, dataloader_test, mean, std
62 |
63 |
64 | if __name__ == '__main__':
65 | dataset = VilDataset(root='/root/Model_Phy/data')
66 | input_img, output_img = dataset[1]
67 |     # `input_img` is a tensor of shape (10, 1, 128, 128)
68 | fig, axes = plt.subplots(nrows=1, ncols=10)
69 |
70 | for i in range(10):
71 |         axes[i].imshow(input_img[i, 0], cmap=None)
72 | axes[i].axis('off')
73 |
74 | plt.show()
75 |
76 | fig, axes = plt.subplots(nrows=1, ncols=10)
77 |
78 | for i in range(10):
79 |         axes[i].imshow(output_img[i, 0], cmap=None)
80 | axes[i].axis('off')
81 |
82 | plt.show()
--------------------------------------------------------------------------------
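
The per-sample reshape/split logic above, demonstrated on a synthetic stand-in for one SEVIR event (the real `.npy` files are assumed to hold arrays reshapeable to 20 frames of 128x128):

```python
import numpy as np
import torch

sample = np.random.rand(20, 128, 128).astype(np.float32)  # fake event
img = sample.reshape(20, 1, 128, 128)                     # add channel dim
input_img = torch.from_numpy(img[:10])                    # first 10 frames in
output_img = torch.from_numpy(img[10:])                   # last 10 frames out
print(input_img.shape, output_img.shape)  # torch.Size([10, 1, 128, 128]) x2
```
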
/API/dataloader_taxibj.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 |
5 |
6 | class TrafficDataset(Dataset):
7 | def __init__(self, X, Y):
8 | super(TrafficDataset, self).__init__()
9 | self.X = (X + 1) / 2
10 | self.Y = (Y + 1) / 2
11 | self.mean = 0
12 | self.std = 1
13 |
14 | def __len__(self):
15 | return self.X.shape[0]
16 |
17 | def __getitem__(self, index):
18 | data = torch.tensor(self.X[index, ::]).float()
19 | labels = torch.tensor(self.Y[index, ::]).float()
20 | return data, labels
21 |
22 | def load_data(
23 | batch_size, val_batch_size,
24 | data_root, num_workers):
25 |
26 | dataset = np.load(data_root+'taxibj/dataset.npz')
27 | X_train, Y_train, X_test, Y_test = dataset['X_train'], dataset['Y_train'], dataset['X_test'], dataset['Y_test']
28 |
29 | train_set = TrafficDataset(X=X_train, Y=Y_train)
30 | test_set = TrafficDataset(X=X_test, Y=Y_test)
31 |
32 | dataloader_train = torch.utils.data.DataLoader(
33 | train_set, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)
34 | dataloader_test = torch.utils.data.DataLoader(
35 | test_set, batch_size=val_batch_size, shuffle=False, pin_memory=True, num_workers=num_workers)
36 |
37 | return dataloader_train, None, dataloader_test, 0, 1
--------------------------------------------------------------------------------
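
TaxiBJ tensors are stored in [-1, 1]; the `(X + 1) / 2` rescale in `TrafficDataset` maps them to [0, 1]. A quick check with synthetic data (the 4-frame, 2-channel, 32x32 shape mirrors the TaxiBJ grids):

```python
import numpy as np
from API.dataloader_taxibj import TrafficDataset

X = np.random.uniform(-1, 1, size=(8, 4, 2, 32, 32)).astype(np.float32)
ds = TrafficDataset(X=X, Y=X)
data, labels = ds[0]
print(float(data.min()) >= 0.0, float(data.max()) <= 1.0)  # True True
```
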
/API/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.metrics import structural_similarity as cal_ssim
3 |
4 | def MAE(pred, true):
5 | return np.mean(np.abs(pred-true),axis=(0,1)).sum()
6 |
7 |
8 | def MSE(pred, true):
9 | return np.mean((pred-true)**2,axis=(0,1)).sum()
10 |
11 | # cite the `PSNR` code from E3d-LSTM, Thanks!
12 | # https://github.com/google/e3d_lstm/blob/master/src/trainer.py line 39-40
13 | def PSNR(pred, true):
14 |     mse = np.mean((np.uint8(pred * 255).astype(np.float64) - np.uint8(true * 255).astype(np.float64)) ** 2)  # float cast avoids uint8 wrap-around
15 | return 20 * np.log10(255) - 10 * np.log10(mse)
16 |
17 | def metric(pred, true, mean, std, return_ssim_psnr=False, clip_range=[0, 1]):
18 | pred = pred*std + mean
19 | true = true*std + mean
20 | mae = MAE(pred, true)
21 | mse = MSE(pred, true)
22 |
23 | if return_ssim_psnr:
24 | pred = np.maximum(pred, clip_range[0])
25 | pred = np.minimum(pred, clip_range[1])
26 | ssim, psnr = 0, 0
27 | for b in range(pred.shape[0]):
28 | for f in range(pred.shape[1]):
29 | ssim += cal_ssim(pred[b, f].swapaxes(0, 2), true[b, f].swapaxes(0, 2), multichannel=True)
30 | psnr += PSNR(pred[b, f], true[b, f])
31 | ssim = ssim / (pred.shape[0] * pred.shape[1])
32 | psnr = psnr / (pred.shape[0] * pred.shape[1])
33 | return mse, mae, ssim, psnr
34 | else:
35 | return mse, mae
--------------------------------------------------------------------------------
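
Usage sketch for `metric` on random tensors (shapes follow the repo's (B, T, C, H, W) convention; `mean=0, std=1` makes the de-normalization a no-op):

```python
import numpy as np
from API.metrics import metric

pred = np.random.rand(2, 10, 1, 64, 64)
true = np.random.rand(2, 10, 1, 64, 64)
mse, mae, ssim, psnr = metric(pred, true, mean=0, std=1, return_ssim_psnr=True)
print(mse, mae, ssim, psnr)
```
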
/API/recorder.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | class Recorder:
5 | def __init__(self, verbose=False, delta=0):
6 | self.verbose = verbose
7 | self.best_score = None
8 |         self.val_loss_min = np.inf
9 | self.delta = delta
10 |
11 | def __call__(self, val_loss, model, path):
12 | score = -val_loss
13 | if self.best_score is None:
14 | self.best_score = score
15 | self.save_checkpoint(val_loss, model, path)
16 | elif score >= self.best_score + self.delta:
17 | self.best_score = score
18 | self.save_checkpoint(val_loss, model, path)
19 |
20 | def save_checkpoint(self, val_loss, model, path):
21 | if self.verbose:
22 | print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
23 | torch.save(model.state_dict(), path+'/'+'checkpoint.pth')
24 | self.val_loss_min = val_loss
--------------------------------------------------------------------------------
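
How `Recorder` is driven in `exp_vq.py`, sketched with a stand-in model and made-up losses: it writes `<path>/checkpoint.pth` whenever the validation loss improves by at least `delta`:

```python
import torch.nn as nn
from API.recorder import Recorder

model = nn.Linear(4, 4)                # stand-in for PastNetModel
recorder = Recorder(verbose=True)
for val_loss in [1.0, 0.8, 0.9, 0.5]:  # pretend per-epoch validation losses
    recorder(val_loss, model, '.')     # saves at 1.0, 0.8 and 0.5; skips the 0.9 regression
```
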
/README.md:
--------------------------------------------------------------------------------
1 | # PastNet: Introducing Physical Inductive Biases for Spatio-temporal Video Prediction (ACM MM 2024)
2 |
3 |
4 | ## Abstract
5 |
6 | In this paper, we investigate the challenge of spatio-temporal video prediction, which involves generating future videos based on historical data streams. Existing approaches typically utilize external information such as semantic maps to enhance video prediction, which often neglect the inherent physical knowledge embedded within videos. Furthermore, their high computational demands could impede their applications for high-resolution videos. To address these constraints, we introduce a novel approach called **Physics-assisted Spatio-temporal Network (PastNet)** for generating high-quality video prediction. The core of our PastNet lies in incorporating a spectral convolution operator in the Fourier domain, which efficiently introduces inductive biases from the underlying physical laws. Additionally, we employ a memory bank with the estimated intrinsic dimensionality to discretize local features during the processing of complex spatio-temporal signals, thereby reducing computational costs and facilitating efficient high-resolution video prediction. Extensive experiments on various widely-used datasets demonstrate the effectiveness and efficiency of the proposed PastNet compared with a range of state-of-the-art methods, particularly in high-resolution scenarios.
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 | ## Overview
16 |
17 | * `API/` contains dataloaders and metrics.
18 | * `modules/` contains the building-block modules used by the models.
19 | * `models/` contains the PastNet model.
20 | * `train_model.py` and `exp_vq.py` are the core files for training, validating, and testing pipelines.
21 |
22 |
23 |
24 | ## Citation
25 |
26 | If you find our repository or paper useful, please cite:
27 |
28 | ```
29 | @inproceedings{wu2024pastnet,
30 |   title={{PastNet}: Introducing physical inductive biases for spatio-temporal video prediction},
31 | author={Wu, Hao and Xu, Fan and Chen, Chong and Hua, Xian-Sheng and Luo, Xiao and Wang, Haixin},
32 | booktitle={Proceedings of the 32nd ACM International Conference on Multimedia},
33 | pages={2917--2926},
34 | year={2024}
35 | }
36 | ```
37 |
38 |
--------------------------------------------------------------------------------
/__pycache__/exp_vq.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/__pycache__/exp_vq.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/params.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/__pycache__/params.cpython-38.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/data/moving_mnist/download_mmnist.sh:
--------------------------------------------------------------------------------
1 | wget http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy
2 | wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
--------------------------------------------------------------------------------
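
After downloading, the fixed test set should be a frame-major array (what `load_fixed_set` in `API/dataloader_moving_mnist.py` indexes as `dataset[:, idx, ...]`); a quick sanity check:

```python
import numpy as np

seq = np.load('mnist_test_seq.npy')
print(seq.shape)  # expected (20, 10000, 64, 64): 20 frames x 10000 sequences
```
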
/exp_vq.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import json
4 | import torch
5 | import pickle
6 | import logging
7 | import numpy as np
8 | from models.PastNet_Model import PastNetModel
9 | from tqdm import tqdm
10 | from API import *
11 | from utils import *
12 |
13 |
14 | def relative_l1_error(true_values, predicted_values):
15 | error = torch.abs(true_values - predicted_values)
16 | return torch.mean(error / torch.abs(true_values))
17 |
18 |
19 | class PastNet_exp:
20 | def __init__(self, args):
21 | super(PastNet_exp, self).__init__()
22 | self.args = args
23 | self.config = self.args.__dict__
24 | self.device = self._acquire_device()
25 |
26 | self._preparation()
27 | print_log(output_namespace(self.args))
28 |
29 |         # data loaders are already built inside _preparation()
30 | self._select_optimizer()
31 | self._select_criterion()
32 |
33 | def _acquire_device(self):
34 | if self.args.use_gpu:
35 | os.environ["CUDA_VISIBLE_DEVICES"] = str(self.args.gpu)
36 | device = torch.device('cuda:{}'.format(0))
37 | print_log('Use GPU: {}'.format(self.args.gpu))
38 | else:
39 | device = torch.device('cpu')
40 | print_log('Use CPU')
41 | return device
42 |
43 | def _preparation(self):
44 | # seed
45 | set_seed(self.args.seed)
46 | # log and checkpoint
47 | self.path = osp.join(self.args.res_dir, self.args.ex_name)
48 | check_dir(self.path)
49 |
50 | self.checkpoints_path = osp.join(self.path, 'checkpoints')
51 | check_dir(self.checkpoints_path)
52 |
53 | sv_param = osp.join(self.path, 'model_param.json')
54 | with open(sv_param, 'w') as file_obj:
55 | json.dump(self.args.__dict__, file_obj)
56 |
57 | for handler in logging.root.handlers[:]:
58 | logging.root.removeHandler(handler)
59 | logging.basicConfig(level=logging.INFO, filename=osp.join(self.path, 'log.log'),
60 | filemode='a', format='%(asctime)s - %(message)s')
61 | # prepare data
62 | self._get_data()
63 | # build the model
64 | self._build_model()
65 |
66 | def _build_model(self):
67 | args = self.args
68 | freeze = args.freeze_vqvae == 1
69 | self.model = PastNetModel(args,
70 | shape_in=tuple(args.in_shape),
71 | hid_T=args.hid_T,
72 | N_T=args.N_T,
73 | res_units=args.res_units,
74 | res_layers=args.res_layers,
75 | embedding_nums=args.K,
76 | embedding_dim=args.D).to(self.device)
77 |
78 | def _get_data(self):
79 | config = self.args.__dict__
80 | self.train_loader, self.vali_loader, self.test_loader, self.data_mean, self.data_std = load_data(**config)
81 | self.vali_loader = self.test_loader if self.vali_loader is None else self.vali_loader
82 |
83 | def _select_optimizer(self):
84 | self.optimizer = torch.optim.Adam(
85 | self.model.parameters(), lr=self.args.lr)
86 | self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
87 | self.optimizer, max_lr=self.args.lr, steps_per_epoch=len(self.train_loader), epochs=self.args.epochs)
88 | return self.optimizer
89 |
90 | def _select_criterion(self):
91 | self.criterion = torch.nn.MSELoss()
92 |
93 | def _save(self, name=''):
94 | torch.save(self.model.state_dict(), os.path.join(
95 | self.checkpoints_path, name + '.pth'))
96 | state = self.scheduler.state_dict()
97 | fw = open(os.path.join(self.checkpoints_path, name + '.pkl'), 'wb')
98 | pickle.dump(state, fw)
99 |
100 | def train(self, args):
101 | config = args.__dict__
102 | recorder = Recorder(verbose=True)
103 |
104 | for epoch in range(config['epochs']):
105 | train_loss = []
106 | self.model.train()
107 | train_pbar = tqdm(self.train_loader)
108 |
109 | for batch_x, batch_y in train_pbar:
110 | self.optimizer.zero_grad()
111 | batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
112 | pred_y = self.model(batch_x)
113 |
114 | loss = self.criterion(pred_y, batch_y)
115 | train_loss.append(loss.item())
116 | train_pbar.set_description('train loss: {:.4f}'.format(loss.item()))
117 |
118 | loss.backward()
119 | self.optimizer.step()
120 | self.scheduler.step()
121 |
122 | train_loss = np.average(train_loss)
123 |
124 | if epoch % args.log_step == 0:
125 | with torch.no_grad():
126 | vali_loss = self.vali(self.vali_loader)
127 | if epoch % (args.log_step * 100) == 0:
128 | self._save(name=str(epoch))
129 | print_log("Epoch: {0} | Train Loss: {1:.4f} Vali Loss: {2:.4f}\n".format(
130 | epoch + 1, train_loss, vali_loss))
131 | recorder(vali_loss, self.model, self.path)
132 |
133 | best_model_path = self.path + '/' + 'checkpoint.pth'
134 | self.model.load_state_dict(torch.load(best_model_path))
135 | return self.model
136 |
137 | # def vali(self, vali_loader):
138 | # self.model.eval()
139 | # preds_lst, trues_lst, total_loss = [], [], []
140 | # vali_pbar = tqdm(vali_loader)
141 | # for i, (batch_x, batch_y) in enumerate(vali_pbar):
142 | # if i * batch_x.shape[0] > 1000:
143 | # break
144 |
145 | # batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
146 | # pred_y = self.model(batch_x)
147 | # list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()), [
148 | # pred_y, batch_y], [preds_lst, trues_lst]))
149 |
150 | # loss = self.criterion(pred_y, batch_y)
151 | # vali_pbar.set_description(
152 | # 'vali loss: {:.4f}'.format(loss.mean().item()))
153 | # total_loss.append(loss.mean().item())
154 |
155 | # total_loss = np.average(total_loss)
156 | # preds = np.concatenate(preds_lst, axis=0)
157 | # trues = np.concatenate(trues_lst, axis=0)
158 | # mse, mae, ssim, psnr = metric(preds, trues, vali_loader.dataset.mean, vali_loader.dataset.std, True)
159 | # print_log('vali mse:{:.4f}, mae:{:.4f}, ssim:{:.4f}, psnr:{:.4f}'.format(mse, mae, ssim, psnr))
160 |
161 | # l2_error = torch.nn.MSELoss()(torch.tensor(preds), torch.tensor(trues)).item()
162 | # relative_l2_error = l2_error / torch.nn.MSELoss()(torch.tensor(trues), torch.zeros_like(torch.tensor(trues))).item()
163 |
164 | # l1_error = torch.nn.L1Loss()(torch.tensor(preds), torch.tensor(trues)).item()
165 | # rel_l1_err = relative_l1_error(torch.tensor(trues), torch.tensor(preds)).item()
166 |
167 | #         # compute RMSE
168 | # rmse = torch.sqrt(torch.mean((torch.tensor(preds) - torch.tensor(trues)) ** 2))
169 | # rmse = rmse.item()
170 |
171 | # print_log('RMSE: {:.7f}'.format(rmse))
172 | # print_log('L1 error: {:.7f}, Relative L1 Error: {:.7f}, L2 error: {:.7f}, Relative L2 error: {:.7f},'.format(l1_error, rel_l1_err, l2_error, relative_l2_error))
173 |
174 | # self.model.train()
175 | # return total_loss
176 |
177 | def vali(self, vali_loader):
178 | self.model.eval()
179 | preds_lst, trues_lst, total_loss = [], [], []
180 | vali_pbar = tqdm(vali_loader)
181 | for i, (batch_x, batch_y) in enumerate(vali_pbar):
182 | if i * batch_x.shape[0] > 1000:
183 | break
184 |
185 | batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
186 | pred_y = self.model(batch_x)
187 | list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()), [
188 | pred_y, batch_y], [preds_lst, trues_lst]))
189 |
190 | loss = self.criterion(pred_y, batch_y)
191 | vali_pbar.set_description(
192 | 'vali loss: {:.4f}'.format(loss.mean().item()))
193 | total_loss.append(loss.mean().item())
194 |
195 | total_loss = np.average(total_loss)
196 | preds = np.concatenate(preds_lst, axis=0)
197 | trues = np.concatenate(trues_lst, axis=0)
198 | mse, mae, ssim, psnr = metric(preds, trues, vali_loader.dataset.mean, vali_loader.dataset.std, True)
199 | # print_log('vali mse:{:.4f}, mae:{:.4f}, ssim:{:.4f}, psnr:{:.4f}'.format(mse, mae, ssim, psnr))
200 |
201 | l2_error = torch.nn.MSELoss()(torch.tensor(preds), torch.tensor(trues)).item()
202 | relative_l2_error = l2_error / torch.nn.MSELoss()(torch.tensor(trues),
203 | torch.zeros_like(torch.tensor(trues))).item()
204 |
205 | l1_error = torch.nn.L1Loss()(torch.tensor(preds), torch.tensor(trues)).item()
206 | rel_l1_err = relative_l1_error(torch.tensor(trues), torch.tensor(preds)).item()
207 |
208 | # calculate the RMSE, MSE, and MAE
209 | rmse = torch.sqrt(torch.mean((torch.tensor(preds) - torch.tensor(trues)) ** 2)).item()
210 | mse = torch.mean((torch.tensor(preds) - torch.tensor(trues)) ** 2).item()
211 | mae = torch.mean(torch.abs(torch.tensor(preds) - torch.tensor(trues))).item()
212 |
213 |         ape = torch.abs(torch.tensor(preds) - torch.tensor(trues)) / (torch.tensor(trues) + 1e-8)
214 | ape[torch.tensor(trues) == 0] = 0 # set APE to zero where true value is zero
215 | mape = torch.mean(ape).item() * 100
216 | # ape = torch.abs(torch.tensor(preds) - torch.tensor(trues)) / torch.abs(torch.tensor(trues))
217 | # mape = torch.mean(ape).item() * 100
218 |
219 | print_log(
220 | 'L1 error: {:.7f}, Relative L1 Error: {:.7f}, L2 error: {:.7f}, Relative L2 error: {:.7f},'.format(l1_error,
221 | rel_l1_err,
222 | l2_error,
223 | relative_l2_error))
224 | print_log('RMSE: {:.7f}, MSE: {:.7f}, MAE: {:.7f}, MAPE: {:.7f}%'.format(rmse, mse, mae, mape))
225 |
226 | self.model.train()
227 | return total_loss
228 |
229 | def test(self, args):
230 | self.model.eval()
231 | inputs_lst, trues_lst, preds_lst = [], [], []
232 | for batch_x, batch_y in self.test_loader:
233 | pred_y = self.model(batch_x.to(self.device))
234 | list(map(lambda data, lst: lst.append(data.detach().cpu().numpy()), [
235 | batch_x, batch_y, pred_y], [inputs_lst, trues_lst, preds_lst]))
236 |
237 | inputs, trues, preds = map(lambda data: np.concatenate(
238 | data, axis=0), [inputs_lst, trues_lst, preds_lst])
239 |
240 | folder_path = self.path + '/results/{}/sv/'.format(args.ex_name)
241 | if not os.path.exists(folder_path):
242 | os.makedirs(folder_path)
243 |
244 | mse, mae, ssim, psnr = metric(preds, trues, self.test_loader.dataset.mean, self.test_loader.dataset.std, True)
245 | print_log('mse:{:.4f}, mae:{:.4f}, ssim:{:.4f}, psnr:{:.4f}'.format(mse, mae, ssim, psnr))
246 |
247 | l2_error = torch.nn.MSELoss()(torch.tensor(preds), torch.tensor(trues)).item()
248 | relative_l2_error = l2_error / torch.nn.MSELoss()(torch.tensor(trues),
249 | torch.zeros_like(torch.tensor(trues))).item()
250 |
251 | l1_error = torch.nn.L1Loss()(torch.tensor(preds), torch.tensor(trues)).item()
252 | rel_l1_err = relative_l1_error(torch.tensor(trues), torch.tensor(preds)).item()
253 |
254 |         # compute RMSE
255 | # rmse = torch.sqrt(torch.mean((torch.tensor(preds) - torch.tensor(trues)) ** 2))
256 | # rmse = rmse.item()
257 |
258 | # print_log('RMSE: {:.7f}'.format(rmse))
259 | # print_log('L1 error: {:.7f}, Relative L1 Error: {:.7f}, L2 error: {:.7f}, Relative L2 error: {:.7f},'.format(l1_error, rel_l1_err, l2_error, relative_l2_error))
260 |
261 | # calculate the RMSE, MSE, and MAE
262 | rmse = torch.sqrt(torch.mean((torch.tensor(preds) - torch.tensor(trues)) ** 2)).item()
263 | mse = torch.mean((torch.tensor(preds) - torch.tensor(trues)) ** 2).item()
264 | mae = torch.mean(torch.abs(torch.tensor(preds) - torch.tensor(trues))).item()
265 |         ape = torch.abs(torch.tensor(preds) - torch.tensor(trues)) / (torch.abs(torch.tensor(trues)) + 1e-8)
266 | mape = torch.mean(ape).item() * 100
267 |
268 | print_log(
269 | 'L1 error: {:.7f}, Relative L1 Error: {:.7f}, L2 error: {:.7f}, Relative L2 error: {:.7f},'.format(l1_error,
270 | rel_l1_err,
271 | l2_error,
272 | relative_l2_error))
273 | print_log('RMSE: {:.7f}, MSE: {:.7f}, MAE: {:.7f}, MAPE: {:.7f}%'.format(rmse, mse, mae, mape))
274 |
275 | for np_data in ['inputs', 'trues', 'preds']:
276 | np.save(osp.join(folder_path, np_data + '.npy'), vars()[np_data])
277 | return mse
--------------------------------------------------------------------------------
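
A quick numeric check of `relative_l1_error` defined at the top of `exp_vq.py`, i.e. the mean of |true - pred| / |true| (assumes the module's heavier imports, `models` and `modules`, resolve when run from the repo root):

```python
import torch
from exp_vq import relative_l1_error

true = torch.tensor([1.0, 2.0, 4.0])
pred = torch.tensor([1.1, 1.8, 4.4])
print(relative_l1_error(true, pred))  # tensor(0.1000): each entry is 10% off
```
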
/figure/mm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/figure/mm.png
--------------------------------------------------------------------------------
/id_estimate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn.neighbors import NearestNeighbors
15 |
16 |
17 | def kNN(X, n_neighbors, n_jobs):
18 | X = X.cpu().detach().numpy()
19 | neigh = NearestNeighbors(n_neighbors=n_neighbors, n_jobs=n_jobs).fit(X)
20 | dists, inds = neigh.kneighbors(X)
21 | return dists, inds
22 |
23 | def Levina_Bickel(X, dists, k):
24 | m = np.log(dists[:, k:k+1] / dists[:, 1:k])
25 | m = (k-2) / np.sum(m, axis=1)
26 | dim = np.mean(m)
27 | return dim
28 |
29 | def estimate_dimension(latent_embedding, k=1000):
30 | B, T, C_, H_, W_ = latent_embedding.shape
31 | latent_embedding = latent_embedding.permute(0, 3, 4, 1, 2)
32 | X = latent_embedding.reshape(B*H_*W_, T*C_)
33 | dists, _ = kNN(X, k+1, n_jobs=-1)
34 | dim_estimate = Levina_Bickel(X, dists, k)
35 | return dim_estimate
--------------------------------------------------------------------------------
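
Usage sketch for the Levina-Bickel intrinsic-dimension estimate above on a random latent tensor. `k` must stay below the number of points `B*H*W` (here 4*16*16 = 1024), so the default `k=1000` is replaced with 100:

```python
import torch
from id_estimate import estimate_dimension

latent = torch.randn(4, 10, 8, 16, 16)    # [B, T, C, H, W] stand-in embedding
print(estimate_dimension(latent, k=100))  # scalar dimension estimate
```
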
/models/Fourier.py:
--------------------------------------------------------------------------------
1 | from modules.Fourier_modules import *
2 | class FPG(nn.Module):
3 | def __init__(self,
4 | img_size=224,
5 | patch_size=16,
6 | in_channels=20,
7 | out_channels=20,
8 | input_frames=20,
9 | embed_dim=768,
10 | depth=12,
11 | mlp_ratio=4.,
12 | uniform_drop=False,
13 | drop_rate=0.,
14 | drop_path_rate=0.,
15 | norm_layer=None,
16 | dropcls=0.):
17 | super(FPG, self).__init__()
18 | self.embed_dim = embed_dim
19 | self.num_frames = input_frames
20 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
21 |
22 | self.patch_embed = PatchEmbed(img_size=img_size,
23 | patch_size=patch_size,
24 | in_c=in_channels,
25 | embed_dim=embed_dim)
26 | num_patches = self.patch_embed.num_patches
27 |
28 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) # [1, 196, 768]
29 | self.pos_drop = nn.Dropout(p=drop_rate)
30 |
31 | self.h = self.patch_embed.grid_size[0]
32 | self.w = self.patch_embed.grid_size[1]
33 | '''
34 | stochastic depth decay rule
35 | '''
36 | if uniform_drop:
37 | dpr = [drop_path_rate for _ in range(depth)]
38 | else:
39 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
40 |
41 | self.blocks = nn.ModuleList([FourierNetBlock(
42 | dim=embed_dim,
43 | mlp_ratio=mlp_ratio,
44 | drop=drop_rate,
45 | drop_path=dpr[i],
46 | act_layer=nn.GELU,
47 | norm_layer=norm_layer,
48 | h=self.h,
49 | w=self.w)
50 | for i in range(depth)
51 | ])
52 |
53 | self.norm = norm_layer(embed_dim)
54 |
55 | self.linearprojection = nn.Sequential(OrderedDict([
56 | ('transposeconv1', nn.ConvTranspose2d(embed_dim, out_channels * 16, kernel_size=(2, 2), stride=(2, 2))),
57 | ('act1', nn.Tanh()),
58 | ('transposeconv2', nn.ConvTranspose2d(out_channels * 16, out_channels * 4, kernel_size=(2, 2), stride=(2, 2))),
59 | ('act2', nn.Tanh()),
60 | ('transposeconv3', nn.ConvTranspose2d(out_channels * 4, out_channels, kernel_size=(4, 4), stride=(4, 4)))
61 | ]))
62 |
63 | if dropcls > 0:
64 | print('dropout %.2f before classifier' % dropcls)
65 | self.final_dropout = nn.Dropout(p=dropcls)
66 | else:
67 | self.final_dropout = nn.Identity()
68 |
69 | trunc_normal_(self.pos_embed, std=.02)
70 | self.apply(self._init_weights)
71 |
72 |
73 | def _init_weights(self, m):
74 | if isinstance(m, nn.Linear):
75 | trunc_normal_(m.weight, std=.02)
76 | if isinstance(m, nn.Linear) and m.bias is not None:
77 | nn.init.constant_(m.bias, 0)
78 | elif isinstance(m, nn.LayerNorm):
79 | nn.init.constant_(m.bias, 0)
80 | nn.init.constant_(m.weight, 1.0)
81 |
82 | @torch.jit.ignore
83 | def no_weight_decay(self):
84 | return {'pos_embed', 'cls_token'}
85 |
86 | def forward_features(self, x):
87 | '''
88 | patch_embed:
89 | [B, T, C, H, W] -> [B*T, num_patches, embed_dim]
90 | '''
91 | B,T,C,H,W = x.shape
92 | x = x.view(B*T, C, H, W)
93 | x = self.patch_embed(x)
94 | #enc = LearnableFourierPositionalEncoding(768, 768, 64, 768, 10)
95 | #fourierpos_embed = enc(x)
96 | x = self.pos_drop(x + self.pos_embed)
97 | #x = self.pos_drop(x + fourierpos_embed)
98 |
99 |
100 | if not get_fourcastnet_args().checkpoint_activations:
101 | for blk in self.blocks:
102 | x = blk(x)
103 | else:
104 | x = checkpoint_sequential(self.blocks, 4, x)
105 |
106 | x = self.norm(x).transpose(1, 2)
107 | x = torch.reshape(x, [-1, self.embed_dim, self.h, self.w])
108 | return x
109 |
110 | def forward(self, x):
111 | B, T, C, H, W = x.shape
112 | x = self.forward_features(x)
113 | x = self.final_dropout(x)
114 | x = self.linearprojection(x)
115 | x = x.reshape(B, T, C, H, W)
116 | return x
117 |
118 |
--------------------------------------------------------------------------------
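
Shape bookkeeping for FPG's `linearprojection` head: the three transposed convolutions upsample the patch grid by 2 x 2 x 4 = 16x, undoing `patch_size=16`. A standalone check with the default sizes (`img_size=224`, `embed_dim=768`, `out_channels=20` as in the constructor above):

```python
from collections import OrderedDict
import torch
from torch import nn

embed_dim, out_channels = 768, 20
head = nn.Sequential(OrderedDict([
    ('transposeconv1', nn.ConvTranspose2d(embed_dim, out_channels * 16, kernel_size=2, stride=2)),
    ('act1', nn.Tanh()),
    ('transposeconv2', nn.ConvTranspose2d(out_channels * 16, out_channels * 4, kernel_size=2, stride=2)),
    ('act2', nn.Tanh()),
    ('transposeconv3', nn.ConvTranspose2d(out_channels * 4, out_channels, kernel_size=4, stride=4)),
]))
x = torch.randn(1, embed_dim, 14, 14)  # 224 / 16 = 14x14 patch grid
print(head(x).shape)                   # torch.Size([1, 20, 224, 224])
```
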
/models/PastNet_Model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import *
3 | import logging
4 | from torch import nn
5 | from modules.DiscreteSTModel_modules import *
6 | from modules.Fourier_modules import *
7 |
8 |
9 |
10 | def stride_generator(N, reverse=False):
11 | strides = [1, 2]*10
12 | if reverse:
13 | return list(reversed(strides[:N]))
14 | else:
15 | return strides[:N]
16 |
17 |
18 | class GroupConv2d(nn.Module):
19 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False):
20 | super(GroupConv2d, self).__init__()
21 | self.act_norm = act_norm
22 | if in_channels % groups != 0:
23 | groups = 1
24 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
25 | stride=stride, padding=padding, groups=groups)
26 | self.norm = nn.GroupNorm(groups, out_channels)
27 | self.activate = nn.LeakyReLU(0.2, inplace=True)
28 |
29 | def forward(self, x):
30 | y = self.conv(x)
31 | if self.act_norm:
32 | y = self.activate(self.norm(y))
33 | return y
34 |
35 |
36 | class Inception(nn.Module):
37 | def __init__(self, C_in, C_hid, C_out, incep_ker=[3, 5, 7, 11], groups=8):
38 | super(Inception, self).__init__()
39 | self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
40 | layers = []
41 | for ker in incep_ker:
42 | layers.append(GroupConv2d(C_hid, C_out, kernel_size=ker,
43 | stride=1, padding=ker//2, groups=groups, act_norm=True))
44 | self.layers = nn.Sequential(*layers)
45 |
46 | def forward(self, x):
47 | x = self.conv1(x)
48 | y = 0
49 | for layer in self.layers:
50 | y += layer(x)
51 | return y
52 |
53 | class FPG(nn.Module):
54 | def __init__(self,
55 | img_size=224,
56 | patch_size=16,
57 | in_channels=20,
58 | out_channels=20,
59 | input_frames=20,
60 | embed_dim=768,
61 | depth=12,
62 | mlp_ratio=4.,
63 | uniform_drop=False,
64 | drop_rate=0.,
65 | drop_path_rate=0.,
66 | norm_layer=None,
67 | dropcls=0.):
68 | super(FPG, self).__init__()
69 | self.embed_dim = embed_dim
70 | self.num_frames = input_frames
71 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
72 |
73 | self.patch_embed = PatchEmbed(img_size=img_size,
74 | patch_size=patch_size,
75 | in_c=in_channels,
76 | embed_dim=embed_dim)
77 | num_patches = self.patch_embed.num_patches
78 |
79 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) # [1, 196, 768]
80 | self.pos_drop = nn.Dropout(p=drop_rate)
81 |
82 | self.h = self.patch_embed.grid_size[0]
83 | self.w = self.patch_embed.grid_size[1]
84 | '''
85 | stochastic depth decay rule
86 | '''
87 | if uniform_drop:
88 | dpr = [drop_path_rate for _ in range(depth)]
89 | else:
90 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
91 |
92 | self.blocks = nn.ModuleList([FourierNetBlock(
93 | dim=embed_dim,
94 | mlp_ratio=mlp_ratio,
95 | drop=drop_rate,
96 | drop_path=dpr[i],
97 | act_layer=nn.GELU,
98 | norm_layer=norm_layer,
99 | h=self.h,
100 | w=self.w)
101 | for i in range(depth)
102 | ])
103 |
104 | self.norm = norm_layer(embed_dim)
105 |
106 | self.linearprojection = nn.Sequential(OrderedDict([
107 | ('transposeconv1', nn.ConvTranspose2d(embed_dim, out_channels * 16, kernel_size=(2, 2), stride=(2, 2))),
108 | ('act1', nn.Tanh()),
109 | ('transposeconv2', nn.ConvTranspose2d(out_channels * 16, out_channels * 4, kernel_size=(2, 2), stride=(2, 2))),
110 | ('act2', nn.Tanh()),
111 | ('transposeconv3', nn.ConvTranspose2d(out_channels * 4, out_channels, kernel_size=(4, 4), stride=(4, 4)))
112 | ]))
113 |
114 | if dropcls > 0:
115 | print('dropout %.2f before classifier' % dropcls)
116 | self.final_dropout = nn.Dropout(p=dropcls)
117 | else:
118 | self.final_dropout = nn.Identity()
119 |
120 | trunc_normal_(self.pos_embed, std=.02)
121 | self.apply(self._init_weights)
122 |
123 |
124 | def _init_weights(self, m):
125 | if isinstance(m, nn.Linear):
126 | trunc_normal_(m.weight, std=.02)
127 | if isinstance(m, nn.Linear) and m.bias is not None:
128 | nn.init.constant_(m.bias, 0)
129 | elif isinstance(m, nn.LayerNorm):
130 | nn.init.constant_(m.bias, 0)
131 | nn.init.constant_(m.weight, 1.0)
132 |
133 | @torch.jit.ignore
134 | def no_weight_decay(self):
135 | return {'pos_embed', 'cls_token'}
136 |
137 | def forward_features(self, x):
138 | '''
139 | patch_embed:
140 | [B, T, C, H, W] -> [B*T, num_patches, embed_dim]
141 | '''
142 | B,T,C,H,W = x.shape
143 | x = x.view(B*T, C, H, W)
144 | x = self.patch_embed(x)
145 | #enc = LearnableFourierPositionalEncoding(768, 768, 64, 768, 10)
146 | #fourierpos_embed = enc(x)
147 | x = self.pos_drop(x + self.pos_embed)
148 | #x = self.pos_drop(x + fourierpos_embed)
149 |
150 |
151 | if not get_fourcastnet_args().checkpoint_activations:
152 | for blk in self.blocks:
153 | x = blk(x)
154 | else:
155 | x = checkpoint_sequential(self.blocks, 4, x)
156 |
157 | x = self.norm(x).transpose(1, 2)
158 | x = torch.reshape(x, [-1, self.embed_dim, self.h, self.w])
159 | return x
160 |
161 | def forward(self, x):
162 | B, T, C, H, W = x.shape
163 | x = self.forward_features(x)
164 | x = self.final_dropout(x)
165 | x = self.linearprojection(x)
166 | x = x.reshape(B, T, C, H, W)
167 | return x
168 |
169 | class DST(nn.Module):
170 | def __init__(self,
171 | in_channel=1,
172 | num_hiddens=128,
173 | res_layers=2,
174 | res_units=32,
175 | embedding_nums=512, # K
176 | embedding_dim=64, # D
177 | commitment_cost=0.25):
178 | super(DST, self).__init__()
179 | self.embedding_dim = embedding_dim
180 | self.num_embeddings = embedding_nums
181 | self._encoder = Encoder(in_channel, num_hiddens,
182 | res_layers, res_units) #
183 | self._pre_vq_conv = nn.Conv2d(in_channels=num_hiddens,
184 | out_channels=embedding_dim,
185 | kernel_size=1,
186 | stride=1)
187 |
188 | # code book
189 | self._vq_vae = VectorQuantizerEMA(embedding_nums,
190 | embedding_dim,
191 | commitment_cost,
192 | decay=0.99)
193 |
194 | self._decoder = Decoder(embedding_dim,
195 | num_hiddens,
196 | res_layers,
197 | res_units,
198 | in_channel)
199 |
200 | def forward(self, x):
201 | # input shape : [B, C, W, H]
202 | z = self._encoder(x) # [B, hidden_units, W//4, H//4]
203 | # [B, embedding_dims, W//4, H//4] z -> encoding
204 | z = self._pre_vq_conv(z)
205 |         # quantized -> embedding; quantized plays the role of the encoder output in VideoGPT
206 | loss, quantized, perplexity, _ = self._vq_vae(z)
207 | x_recon = self._decoder(quantized)
208 | return loss, x_recon, perplexity
209 |
210 | def get_embedding(self, x):
211 | return self._pre_vq_conv(self._encoder(x))
212 |
213 | def get_quantization(self, x):
214 | z = self._encoder(x)
215 | z = self._pre_vq_conv(z)
216 | _, quantized, _, _ = self._vq_vae(z)
217 | return quantized
218 |
219 | def reconstruct_img_by_embedding(self, embedding):
220 | loss, quantized, perplexity, _ = self._vq_vae(embedding)
221 | return self._decoder(quantized)
222 |
223 | def reconstruct_img(self, q):
224 | return self._decoder(q)
225 |
226 | @property
227 | def pre_vq_conv(self):
228 | return self._pre_vq_conv
229 |
230 | @property
231 | def encoder(self):
232 | return self._encoder
233 |
234 |
235 | class DynamicPropagation(nn.Module):
236 | def __init__(self, channel_in, channel_hid, N_T, incep_ker=[3, 5, 7, 11], groups=8):
237 | super(DynamicPropagation, self).__init__()
238 |
239 | self.N_T = N_T
240 | enc_layers = [Inception(
241 | channel_in, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups)]
242 | for i in range(1, N_T-1):
243 | enc_layers.append(Inception(
244 | channel_hid, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups))
245 |         enc_layers.append(Inception(channel_hid, channel_hid // 2,
246 |                                     channel_hid, incep_ker=incep_ker, groups=groups))
247 |
248 | dec_layers = [Inception(
249 | channel_hid, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups)]
250 | for i in range(1, N_T-1):
251 | dec_layers.append(Inception(
252 | 2*channel_hid, channel_hid//2, channel_hid, incep_ker=incep_ker, groups=groups))
253 |         dec_layers.append(Inception(2*channel_hid, channel_hid // 2,
254 |                                     channel_in, incep_ker=incep_ker, groups=groups))
255 |
256 | self.enc = nn.Sequential(*enc_layers)
257 | self.dec = nn.Sequential(*dec_layers)
258 |
259 | def forward(self, input_state):
260 | B, T, C, H, W = input_state.shape
261 | input_state = input_state.reshape(B, T*C, H, W)
262 | # encoder
263 | skips = []
264 | hidden_embed = input_state
265 | for i in range(self.N_T):
266 | hidden_embed = self.enc[i](hidden_embed)
267 | if i < self.N_T - 1:
268 | skips.append(hidden_embed)
269 |
270 | # decoder
271 | hidden_embed = self.dec[0](hidden_embed)
272 | for i in range(1, self.N_T):
273 | hidden_embed = self.dec[i](torch.cat([hidden_embed, skips[-i]], dim=1))
274 |
275 | output_state = hidden_embed.reshape(B, T, C, H, W)
276 | return output_state
277 |
278 |
279 | class PastNetModel(nn.Module):
280 | def __init__(self,
281 | args,
282 | shape_in,
283 | hid_T=256,
284 | N_T=8,
285 | incep_ker=[3, 5, 7, 11],
286 | groups=8,
287 | res_units=64,
288 | res_layers=2,
289 | embedding_nums=512,
290 | embedding_dim=64):
291 | super(PastNetModel, self).__init__()
292 | T, C, H, W = shape_in
293 | self.DST_module = DST(in_channel=C,
294 | res_units=res_units,
295 | res_layers=res_layers,
296 | embedding_dim=embedding_dim,
297 | embedding_nums=embedding_nums)
298 |
299 | self.FPG_module = FPG(img_size=64,
300 | patch_size=16,
301 | in_channels=1,
302 | out_channels=1,
303 | embed_dim=128,
304 | input_frames=10,
305 | depth=1,
306 | mlp_ratio=2.,
307 | uniform_drop=False,
308 | drop_rate=0.,
309 | drop_path_rate=0.,
310 | norm_layer=None,
311 | dropcls=0.)
312 |
313 |         if args.load_pred_train:
314 |             print_log("Loading pre-trained VQ-VAE weights.")
315 |             self.DST_module.load_state_dict(torch.load("./models/vqvae.ckpt"), strict=False)
316 | 
317 |         if args.freeze_vqvae:
318 |             print_log("Parameters of the VQ-VAE are frozen.")
319 |             for p in self.DST_module.parameters():
320 |                 p.requires_grad = False
321 | self.DynamicPro = DynamicPropagation(T*64, hid_T, N_T, incep_ker, groups)
322 |
323 | def forward(self, input_frames):
324 | B, T, C, H, W = input_frames.shape
325 | pde_features = self.FPG_module(input_frames)
326 | input_features = input_frames.view([B * T, C, H, W])
327 | encoder_embed = self.DST_module._encoder(input_features)
328 | z = self.DST_module._pre_vq_conv(encoder_embed)
329 | vq_loss, Latent_embed, _, _ = self.DST_module._vq_vae(z)
330 |
331 | _, C_, H_, W_ = Latent_embed.shape
332 | Latent_embed = Latent_embed.reshape(B, T, C_, H_, W_)
333 |
334 | hidden_dim = self.DynamicPro(Latent_embed)
335 | B_, T_, C_, H_, W_ = hidden_dim.shape
336 | hid = hidden_dim.reshape([B_ * T_, C_, H_, W_])
337 |
338 |         predicted_frames = self.DST_module._decoder(hid)
339 |         predicted_frames = predicted_frames.reshape([B, T, C, H, W]) + pde_features
340 | 
341 |         return predicted_frames
342 |
343 |
344 |
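A minimal smoke test for the full `PastNetModel` pipeline (a sketch, not part of the repository: the import path and the two-flag `argparse.Namespace` are assumptions based on the constructor above; run it as a plain script so `get_fourcastnet_args()` inside the Fourier blocks sees an empty command line):

```python
import argparse
import torch
from models.PastNet_Model import PastNetModel  # assumed import path

# Only these two flags are read in __init__; both disabled for the smoke test.
args = argparse.Namespace(load_pred_train=False, freeze_vqvae=False)
model = PastNetModel(args, shape_in=(10, 1, 64, 64))
model.eval()  # skip the EMA codebook update during this standalone check

frames = torch.randn(2, 10, 1, 64, 64)  # [B, T, C, H, W], Moving MNIST-sized
with torch.no_grad():
    pred = model(frames)
print(pred.shape)  # torch.Size([2, 10, 1, 64, 64]) -- same layout as the input
```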
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/models/__init__.py
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/vqvae.ckpt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/models/vqvae.ckpt
--------------------------------------------------------------------------------
/modules/DiscreteSTModel_modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class VectorQuantizer(nn.Module):
7 | def __init__(self, num_embeddings, embedding_dim, commitment_cost):
8 | super(VectorQuantizer, self).__init__()
9 |
10 | self._embedding_dim = embedding_dim # D
11 | self._num_embeddings = num_embeddings # K
12 |
13 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) #
14 | self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
15 | self._commitment_cost = commitment_cost
16 |
17 | def forward(self, inputs):
18 | # convert inputs from B, C, H, W -> B, H, W, C
19 | inputs = inputs.permute(0, 2, 3, 1).contiguous()
20 | input_shape = inputs.shape
21 |
22 | # Flatten input
23 | flat_input = inputs.view(-1, self._embedding_dim)
24 |
25 | # Calculate distances
26 | distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
27 | + torch.sum(self._embedding.weight ** 2, dim=1)
28 | - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
29 |
30 | # Encoding
31 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
32 | encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
33 | encodings.scatter_(1, encoding_indices, 1)
34 |
35 | # Quantize and unflatten
36 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
37 |
38 | # Loss
39 | e_latent_loss = F.mse_loss(quantized.detach(), inputs)
40 | q_latent_loss = F.mse_loss(quantized, inputs.detach())
41 | loss = q_latent_loss + self._commitment_cost * e_latent_loss
42 |
43 | quantized = inputs + (quantized - inputs).detach()
44 | avg_probs = torch.mean(encodings, dim=0)
45 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
46 |
47 | # convert quantized from B, H, W, C -> B, C, H, W
48 | return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
49 |
50 |
51 | def lookup(self, x):
52 |         embeddings = F.embedding(x, self._embedding.weight)  # F.embedding expects the weight tensor, not the module
53 | return embeddings
54 |
55 |
56 | class VectorQuantizerEMA(nn.Module):
57 | def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay=0.99, epsilon=1e-5):
58 | super(VectorQuantizerEMA, self).__init__()
59 |
60 | self._embedding_dim = embedding_dim
61 | self._num_embeddings = num_embeddings
62 |
63 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
64 | self._embedding.weight.data.normal_()
65 | self._commitment_cost = commitment_cost
66 |
67 | self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
68 | self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
69 | self._ema_w.data.normal_()
70 |
71 | self._decay = decay
72 | self._epsilon = epsilon
73 |
74 | def forward(self, inputs):
75 | # convert inputs from BCHW -> BHWC
76 | inputs = inputs.permute(0, 2, 3, 1).contiguous()
77 | input_shape = inputs.shape
78 |
79 | # Flatten input
80 | flat_input = inputs.view(-1, self._embedding_dim)
81 |
82 | # Calculate distances
83 | distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
84 | + torch.sum(self._embedding.weight ** 2, dim=1)
85 | - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
86 |
87 | # Encoding
88 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
89 | encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
90 | encodings.scatter_(1, encoding_indices, 1)
91 |
92 | # Quantize and unflatten
93 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
94 |
95 | # Use EMA to update the embedding vectors
96 | if self.training:
97 | self._ema_cluster_size = self._ema_cluster_size * self._decay + \
98 | (1 - self._decay) * torch.sum(encodings, 0)
99 |
100 | # Laplace smoothing of the cluster size
101 | n = torch.sum(self._ema_cluster_size.data)
102 | self._ema_cluster_size = (
103 | (self._ema_cluster_size + self._epsilon)
104 | / (n + self._num_embeddings * self._epsilon) * n)
105 |
106 | dw = torch.matmul(encodings.t(), flat_input)
107 | self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
108 |
109 | self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
110 |
111 | # Loss
112 | e_latent_loss = F.mse_loss(quantized.detach(), inputs)
113 | loss = self._commitment_cost * e_latent_loss
114 |
115 | # Straight Through Estimator
116 | quantized = inputs + (quantized - inputs).detach()
117 | avg_probs = torch.mean(encodings, dim=0)
118 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
119 |
120 | # convert quantized from BHWC -> BCHW
121 | return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
122 |
123 | def lookup(self, x):
124 |         embeddings = F.embedding(x, self._embedding.weight)  # F.embedding expects the weight tensor, not the module
125 | return embeddings
126 |
127 |
128 | class Residual(nn.Module):
129 | def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
130 | super(Residual, self).__init__()
131 | self._block = nn.Sequential(
132 | nn.ReLU(True),
133 | nn.Conv2d(in_channels=in_channels,
134 | out_channels=num_residual_hiddens,
135 | kernel_size=3, stride=1, padding=1, bias=False),
136 | nn.BatchNorm2d(num_residual_hiddens),
137 | nn.ReLU(True),
138 | nn.Conv2d(in_channels=num_residual_hiddens,
139 | out_channels=num_hiddens,
140 | kernel_size=1, stride=1, bias=False),
141 | nn.BatchNorm2d(num_hiddens)
142 | )
143 |
144 | def forward(self, x):
145 | return x + self._block(x)
146 |
147 |
148 | class ResidualStack(nn.Module):
149 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
150 | super(ResidualStack, self).__init__()
151 | self._num_residual_layers = num_residual_layers
152 | self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
153 | for _ in range(self._num_residual_layers)])
154 |
155 | def forward(self, x):
156 | for i in range(self._num_residual_layers):
157 | x = self._layers[i](x)
158 | return F.relu(x)
159 |
160 |
161 | class Encoder(nn.Module):
162 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
163 | super(Encoder, self).__init__()
164 |
165 | self._conv_1 = nn.Conv2d(in_channels=in_channels,
166 | out_channels=num_hiddens // 2,
167 | kernel_size=4,
168 | stride=2,
169 | padding=1)
170 | self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2,
171 | out_channels=num_hiddens,
172 | kernel_size=4,
173 | stride=2,
174 | padding=1)
175 | self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
176 | out_channels=num_hiddens,
177 | kernel_size=3,
178 | stride=1, padding=1)
179 | self._residual_stack = ResidualStack(in_channels=num_hiddens,
180 | num_hiddens=num_hiddens,
181 | num_residual_layers=num_residual_layers,
182 | num_residual_hiddens=num_residual_hiddens)
183 |
184 | def forward(self, inputs):
185 | # input shape: [B, C, W, H]
186 | x = self._conv_1(inputs) # [B, hidden_units//2 , W//2, H//2]
187 | x = F.relu(x)
188 |
189 | x = self._conv_2(x) # [B, hidden_units, W//4, H//4]
190 | x = F.relu(x)
191 |
192 | x = self._conv_3(x)
193 | return self._residual_stack(x)
194 |
195 |
196 | class Decoder(nn.Module):
197 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens, out_channels):
198 | super(Decoder, self).__init__()
199 |
200 | self._conv_1 = nn.Conv2d(in_channels=in_channels,
201 | out_channels=num_hiddens,
202 | kernel_size=3,
203 | stride=1, padding=1)
204 |
205 | self._residual_stack = ResidualStack(in_channels=num_hiddens,
206 | num_hiddens=num_hiddens,
207 | num_residual_layers=num_residual_layers,
208 | num_residual_hiddens=num_residual_hiddens)
209 |
210 | self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
211 | out_channels=num_hiddens // 2,
212 | kernel_size=4,
213 | stride=2, padding=1)
214 |
215 | self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2,
216 | out_channels=out_channels,
217 | kernel_size=4,
218 | stride=2, padding=1)
219 |
220 | def forward(self, inputs):
221 | x = self._conv_1(inputs)
222 | x = self._residual_stack(x)
223 | x = self._conv_trans_1(x)
224 | x = F.relu(x)
225 | return self._conv_trans_2(x)
226 |
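A quick shape-and-gradient check for `VectorQuantizerEMA` (a sketch; the `modules.DiscreteSTModel_modules` import path assumes this file's location in the tree). It shows the straight-through estimator passing gradients back to the encoder side:

```python
import torch
from modules.DiscreteSTModel_modules import VectorQuantizerEMA

vq = VectorQuantizerEMA(num_embeddings=512, embedding_dim=64, commitment_cost=0.25)
vq.eval()  # skip the EMA codebook update for this standalone check
z = torch.randn(2, 64, 16, 16, requires_grad=True)  # [B, D, H', W'] pre-VQ features

loss, quantized, perplexity, encodings = vq(z)
print(quantized.shape, encodings.shape)  # [2, 64, 16, 16] and [2*16*16, 512] one-hot codes
print(float(perplexity))                 # effective codebook usage, at most 512

quantized.sum().backward()               # straight-through: gradients reach z
print(z.grad.shape)                      # torch.Size([2, 64, 16, 16])
```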
--------------------------------------------------------------------------------
/modules/DiscreteSTModel_modules_BN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class VectorQuantizer(nn.Module):
7 | def __init__(self, num_embeddings, embedding_dim, commitment_cost):
8 | super(VectorQuantizer, self).__init__()
9 |
10 | self._embedding_dim = embedding_dim # D
11 | self._num_embeddings = num_embeddings # K
12 |
13 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
14 | self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
15 | self._commitment_cost = commitment_cost
16 |
17 | def forward(self, inputs):
18 | # convert inputs from B, C, H, W -> B, H, W, C
19 | inputs = inputs.permute(0, 2, 3, 1).contiguous()
20 | input_shape = inputs.shape
21 |
22 | # Flatten input
23 | flat_input = inputs.view(-1, self._embedding_dim)
24 |
25 | # Calculate distances
26 | distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
27 | + torch.sum(self._embedding.weight ** 2, dim=1)
28 |                      + torch.sum(self._embedding.weight ** 2, dim=1)
29 |
30 | # Encoding
31 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
32 | encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
33 | encodings.scatter_(1, encoding_indices, 1)
34 |
35 | # Quantize and unflatten
36 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
37 |
38 | # Loss
39 | e_latent_loss = F.mse_loss(quantized.detach(), inputs)
40 | q_latent_loss = F.mse_loss(quantized, inputs.detach())
41 | loss = q_latent_loss + self._commitment_cost * e_latent_loss
42 |
43 | quantized = inputs + (quantized - inputs).detach()
44 | avg_probs = torch.mean(encodings, dim=0)
45 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
46 |
47 | # convert quantized from B, H, W, C -> B, C, H, W
48 | return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
49 |
50 | def lookup(self, x):
51 |         embeddings = F.embedding(x, self._embedding.weight)  # F.embedding expects the weight tensor, not the module
52 | return embeddings
53 |
54 |
55 | class Residual(nn.Module):
56 | def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
57 | super(Residual, self).__init__()
58 | self._block = nn.Sequential(
59 | nn.ReLU(True),
60 | nn.Conv2d(in_channels=in_channels,
61 | out_channels=num_residual_hiddens,
62 | kernel_size=3, stride=1, padding=1, bias=False),
63 | nn.BatchNorm2d(num_residual_hiddens),
64 | nn.ReLU(True),
65 | nn.Conv2d(in_channels=num_residual_hiddens,
66 | out_channels=num_hiddens,
67 | kernel_size=1, stride=1, bias=False),
68 | nn.BatchNorm2d(num_hiddens)
69 | )
70 |
71 | def forward(self, x):
72 | return x + self._block(x)
73 |
74 |
75 | class ResidualStack(nn.Module):
76 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
77 | super(ResidualStack, self).__init__()
78 | self._num_residual_layers = num_residual_layers
79 | self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
80 | for _ in range(self._num_residual_layers)])
81 |
82 | def forward(self, x):
83 | for i in range(self._num_residual_layers):
84 | x = self._layers[i](x)
85 | return F.relu(x)
86 |
87 |
88 | class Encoder(nn.Module):
89 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
90 | super(Encoder, self).__init__()
91 |
92 | self._conv_1 = nn.Conv2d(in_channels=in_channels,
93 | out_channels=num_hiddens // 2,
94 | kernel_size=4,
95 | stride=2,
96 | padding=1)
97 | self.norm_1 = nn.BatchNorm2d(num_hiddens // 2)
98 | self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2,
99 | out_channels=num_hiddens,
100 | kernel_size=4,
101 | stride=2,
102 | padding=1)
103 | self.norm_2 = nn.BatchNorm2d(num_hiddens)
104 | self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
105 | out_channels=num_hiddens,
106 | kernel_size=3,
107 | stride=1, padding=1)
108 | self.norm_3 = nn.BatchNorm2d(num_hiddens)
109 | self._residual_stack = ResidualStack(in_channels=num_hiddens,
110 | num_hiddens=num_hiddens,
111 | num_residual_layers=num_residual_layers,
112 | num_residual_hiddens=num_residual_hiddens)
113 |
114 | def forward(self, inputs):
115 | # input shape: [B, C, W, H]
116 | x = self._conv_1(inputs) # [B, hidden_units//2 , W//2, H//2]
117 | x = F.relu(self.norm_1(x))
118 |
119 | x = self._conv_2(x) # [B, hidden_units, W//4, H//4]
120 | x = F.relu(self.norm_2(x))
121 |
122 | x = self._conv_3(x)
123 | x = self.norm_3(x)
124 | return self._residual_stack(x)
125 |
126 |
127 | class Decoder(nn.Module):
128 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens, out_channels):
129 | super(Decoder, self).__init__()
130 |
131 | self._conv_1 = nn.Conv2d(in_channels=in_channels,
132 | out_channels=num_hiddens,
133 | kernel_size=3,
134 | stride=1, padding=1)
135 | self._norm_1 = nn.BatchNorm2d(num_hiddens)
136 |
137 | self._residual_stack = ResidualStack(in_channels=num_hiddens,
138 | num_hiddens=num_hiddens,
139 | num_residual_layers=num_residual_layers,
140 | num_residual_hiddens=num_residual_hiddens)
141 |
142 | self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
143 | out_channels=num_hiddens // 2,
144 | kernel_size=4,
145 | stride=2, padding=1)
146 | self._norm_2 = nn.BatchNorm2d(num_hiddens // 2)
147 | self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2,
148 | out_channels=out_channels,
149 | kernel_size=4,
150 | stride=2, padding=1)
151 |
152 | def forward(self, inputs):
153 | x = self._conv_1(inputs)
154 | x = self._residual_stack(self._norm_1(x))
155 | x = self._conv_trans_1(x)
156 | x = F.relu(self._norm_2(x))
157 | return self._conv_trans_2(x)
--------------------------------------------------------------------------------
/modules/DiscreteSTModel_modules_GN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class VectorQuantizer(nn.Module):
7 | def __init__(self, num_embeddings, embedding_dim, commitment_cost):
8 | super(VectorQuantizer, self).__init__()
9 |
10 | self._embedding_dim = embedding_dim # D
11 | self._num_embeddings = num_embeddings # K
12 |
13 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
14 | self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
15 | self._commitment_cost = commitment_cost
16 |
17 | def forward(self, inputs):
18 | # convert inputs from B, C, H, W -> B, H, W, C
19 | inputs = inputs.permute(0, 2, 3, 1).contiguous()
20 | input_shape = inputs.shape
21 |
22 | # Flatten input
23 | flat_input = inputs.view(-1, self._embedding_dim)
24 |
25 | # Calculate distances
26 | distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
27 | + torch.sum(self._embedding.weight ** 2, dim=1)
28 |                      - 2 * torch.matmul(flat_input, self._embedding.weight.t()))  # expands ||z - e||^2 = ||z||^2 + ||e||^2 - 2*z.e
29 |
30 | # Encoding
31 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
32 | encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
33 | encodings.scatter_(1, encoding_indices, 1)
34 |
35 | # Quantize and unflatten
36 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
37 |
38 | # Loss
39 | e_latent_loss = F.mse_loss(quantized.detach(), inputs)
40 | q_latent_loss = F.mse_loss(quantized, inputs.detach())
41 | loss = q_latent_loss + self._commitment_cost * e_latent_loss
42 |
43 | quantized = inputs + (quantized - inputs).detach()
44 | avg_probs = torch.mean(encodings, dim=0)
45 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
46 |
47 | # convert quantized from B, H, W, C -> B, C, H, W
48 | return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
49 |
50 | def lookup(self, x):
51 |         embeddings = F.embedding(x, self._embedding.weight)  # F.embedding expects the weight tensor, not the module
52 | return embeddings
53 |
54 |
55 | class Residual(nn.Module):
56 | def __init__(self, in_channels, num_hiddens, num_residual_hiddens):
57 | super(Residual, self).__init__()
58 | self._block = nn.Sequential(
59 | nn.ReLU(True),
60 | nn.Conv2d(in_channels=in_channels,
61 | out_channels=num_residual_hiddens,
62 | kernel_size=3, stride=1, padding=1, bias=False),
63 | nn.GroupNorm(2, num_residual_hiddens),
64 | nn.ReLU(True),
65 | nn.Conv2d(in_channels=num_residual_hiddens,
66 | out_channels=num_hiddens,
67 | kernel_size=1, stride=1, bias=False),
68 | nn.GroupNorm(2, num_hiddens)
69 | )
70 |
71 | def forward(self, x):
72 | return x + self._block(x)
73 |
74 |
75 | class ResidualStack(nn.Module):
76 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
77 | super(ResidualStack, self).__init__()
78 | self._num_residual_layers = num_residual_layers
79 | self._layers = nn.ModuleList([Residual(in_channels, num_hiddens, num_residual_hiddens)
80 | for _ in range(self._num_residual_layers)])
81 |
82 | def forward(self, x):
83 | for i in range(self._num_residual_layers):
84 | x = self._layers[i](x)
85 | return F.relu(x)
86 |
87 |
88 | class Encoder(nn.Module):
89 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens):
90 | super(Encoder, self).__init__()
91 |
92 | self._conv_1 = nn.Conv2d(in_channels=in_channels,
93 | out_channels=num_hiddens // 2,
94 | kernel_size=4,
95 | stride=2,
96 | padding=1)
97 | self.norm_1 = nn.GroupNorm(2, num_hiddens // 2)
98 | self._conv_2 = nn.Conv2d(in_channels=num_hiddens // 2,
99 | out_channels=num_hiddens,
100 | kernel_size=4,
101 | stride=2,
102 | padding=1)
103 | self.norm_2 = nn.GroupNorm(2, num_hiddens)
104 | self._conv_3 = nn.Conv2d(in_channels=num_hiddens,
105 | out_channels=num_hiddens,
106 | kernel_size=3,
107 | stride=1, padding=1)
108 | self.norm_3 = nn.GroupNorm(2, num_hiddens)
109 | self._residual_stack = ResidualStack(in_channels=num_hiddens,
110 | num_hiddens=num_hiddens,
111 | num_residual_layers=num_residual_layers,
112 | num_residual_hiddens=num_residual_hiddens)
113 |
114 | def forward(self, inputs):
115 | # input shape: [B, C, W, H]
116 | x = self._conv_1(inputs) # [B, hidden_units//2 , W//2, H//2]
117 | x = F.relu(self.norm_1(x))
118 |
119 | x = self._conv_2(x) # [B, hidden_units, W//4, H//4]
120 | x = F.relu(self.norm_2(x))
121 |
122 | x = self._conv_3(x)
123 | x = self.norm_3(x)
124 | return self._residual_stack(x)
125 |
126 |
127 | class Decoder(nn.Module):
128 | def __init__(self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens, out_channels):
129 | super(Decoder, self).__init__()
130 |
131 | self._conv_1 = nn.Conv2d(in_channels=in_channels,
132 | out_channels=num_hiddens,
133 | kernel_size=3,
134 | stride=1, padding=1)
135 | self._norm_1 = nn.GroupNorm(2, num_hiddens)
136 |
137 | self._residual_stack = ResidualStack(in_channels=num_hiddens,
138 | num_hiddens=num_hiddens,
139 | num_residual_layers=num_residual_layers,
140 | num_residual_hiddens=num_residual_hiddens)
141 |
142 | self._conv_trans_1 = nn.ConvTranspose2d(in_channels=num_hiddens,
143 | out_channels=num_hiddens // 2,
144 | kernel_size=4,
145 | stride=2, padding=1)
146 | self._norm_2 = nn.GroupNorm(2, num_hiddens // 2)
147 | self._conv_trans_2 = nn.ConvTranspose2d(in_channels=num_hiddens // 2,
148 | out_channels=out_channels,
149 | kernel_size=4,
150 | stride=2, padding=1)
151 |
152 | def forward(self, inputs):
153 | x = self._conv_1(inputs)
154 | x = self._residual_stack(self._norm_1(x))
155 | x = self._conv_trans_1(x)
156 | x = F.relu(self._norm_2(x))
157 | return self._conv_trans_2(x)
158 |
159 |
--------------------------------------------------------------------------------
/modules/Fourier_modules.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from collections import OrderedDict
3 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
4 | from torch.utils.checkpoint import checkpoint_sequential
5 | from params import get_fourcastnet_args
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | import torch.fft
10 | import numpy as np
11 | import torch.optim as optimizer
12 |
13 |
14 | class PatchEmbed(nn.Module):
15 | def __init__(self, img_size=None, patch_size=8, in_c=13, embed_dim=768, norm_layer=None):
16 | super(PatchEmbed, self).__init__()
17 | img_size = to_2tuple(img_size)
18 | patch_size = to_2tuple(patch_size)
19 | self.img_size = img_size
20 | self.patch_size = patch_size
21 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) # h, w
22 | self.num_patches = self.grid_size[0] * self.grid_size[1]
23 | self.projection= nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
24 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
25 |
26 | def forward(self, x):
27 | B, C, H, W = x.shape
28 | assert H == self.img_size[0] and W == self.img_size[1], \
29 |             f"Input size ({H}x{W}) doesn't match the configured img_size ({self.img_size[0]}x{self.img_size[1]})."
30 | '''
31 | [32, 3, 224, 224] -> [32, 768, 14, 14] -> [32, 768, 196] -> [32, 196, 768]
32 | Conv2D: [32, 3, 224, 224] -> [32, 768, 14, 14]
33 | Flatten: [B, C, H, W] -> [B, C, HW]
34 | Transpose: [B, C, HW] -> [B, HW, C]
35 | '''
36 | x = self.projection(x).flatten(2).transpose(1, 2)
37 | x = self.norm(x)
38 | return x
39 |
40 | class Mlp(nn.Module):
41 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
42 | super(Mlp, self).__init__()
43 | out_features = out_features or in_features
44 | hidden_features = hidden_features or in_features
45 | self.fc1 = nn.Linear(in_features, hidden_features)
46 | self.act = act_layer()
47 |         self.fc2 = nn.Linear(hidden_features, out_features)  # defined but unused: forward() projects with fc3 instead
48 | self.fc3 = nn.AdaptiveAvgPool1d(out_features)
49 | self.drop = nn.Dropout(drop)
50 |
51 | def forward(self, x):
52 | x = self.fc1(x)
53 | x = self.act(x)
54 | x = self.drop(x)
55 | x = self.fc3(x)
56 | x = self.drop(x)
57 | return x
58 |
59 | class LearnableFourierPositionalEncoding(nn.Module):
60 | def __init__(self, M: int, F_dim: int, H_dim: int, D: int, gamma: float):
61 |
62 | super().__init__()
63 | self.M = M
64 | self.F_dim = F_dim
65 | self.H_dim = H_dim
66 | self.D = D
67 | self.gamma = gamma
68 |
69 | self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
70 | self.mlp = nn.Sequential(
71 | nn.Linear(self.F_dim, self.H_dim, bias=True),
72 | nn.GELU(),
73 | nn.Linear(self.H_dim, self.D)
74 | )
75 |
76 | self.init_weights()
77 |
78 | def init_weights(self):
79 | nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
80 |
81 | def forward(self, x):
82 |
83 | B, N, M = x.shape
84 | projected = self.Wr(x)
85 | cosines = torch.cos(projected)
86 | sines = torch.sin(projected)
87 | F = 1 / np.sqrt(self.F_dim) * torch.cat([cosines, sines], dim=-1)
88 | Y = self.mlp(F)
89 | PEx = Y.reshape((B, N, self.D))
90 | return PEx
91 |
92 | class AdaptiveFourierNeuralOperator(nn.Module):
93 | def __init__(self, dim, h=14, w=14):
94 |         super(AdaptiveFourierNeuralOperator, self).__init__()
95 | args = get_fourcastnet_args()
96 | self.hidden_size = dim
97 | self.h = h
98 | self.w = w
99 | self.num_blocks = args.fno_blocks
100 | self.block_size = self.hidden_size // self.num_blocks
101 | assert self.hidden_size % self.num_blocks == 0
102 |
103 | self.scale = 0.02
104 | self.w1 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size))
105 | self.b1 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))
106 | self.w2 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size, self.block_size))
107 | self.b2 = torch.nn.Parameter(self.scale * torch.randn(2, self.num_blocks, self.block_size))
108 | self.relu = nn.ReLU()
109 |
110 | if args.fno_bias:
111 | self.bias = nn.Conv1d(self.hidden_size, self.hidden_size, 1)
112 | else:
113 | self.bias = None
114 |
115 | self.softshrink = args.fno_softshrink
116 |
117 | def multiply(self, input, weights):
118 | return torch.einsum('...bd, bdk->...bk', input, weights)
119 |
120 | def forward(self, x):
121 | B, N, C = x.shape
122 |
123 | if self.bias:
124 | bias = self.bias(x.permute(0, 2, 1)).permute(0, 2, 1)
125 | else:
126 | bias = torch.zeros(x.shape, device=x.device)
127 |
128 | x = x.reshape(B, self.h, self.w, C)
129 | x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
130 | x = x.reshape(B, x.shape[1], x.shape[2], self.num_blocks, self.block_size)
131 |
132 |         o1_real = F.relu(self.multiply(x.real, self.w1[0]) - self.multiply(x.imag, self.w1[1]) + self.b1[0], inplace=True)
133 |         o1_imag = F.relu(self.multiply(x.real, self.w1[1]) + self.multiply(x.imag, self.w1[0]) + self.b1[1], inplace=True)
134 |         x_real = self.multiply(o1_real, self.w2[0]) - self.multiply(o1_imag, self.w2[1]) + self.b2[0]  # second stage must read the first-stage outputs,
135 |         x_imag = self.multiply(o1_real, self.w2[1]) + self.multiply(o1_imag, self.w2[0]) + self.b2[1]  # not the freshly overwritten x_real
136 |
137 | x = torch.stack([x_real, x_imag], dim=-1)
138 | x = F.softshrink(x, lambd=self.softshrink) if self.softshrink else x
139 |
140 | x = torch.view_as_complex(x)
141 | x = x.reshape(B, x.shape[1], x.shape[2], self.hidden_size)
142 | x = torch.fft.irfft2(x, s=(self.h, self.w), dim=(1,2), norm='ortho')
143 | x = x.reshape(B, N, C)
144 |
145 | return x+bias
146 |
147 | class FourierNetBlock(nn.Module):
148 | def __init__(self,
149 | dim,
150 | mlp_ratio=4.,
151 | drop=0.,
152 | drop_path=0.,
153 | act_layer=nn.GELU,
154 | norm_layer=nn.LayerNorm,
155 | h=14,
156 | w=14):
157 | super(FourierNetBlock, self).__init__()
158 | args = get_fourcastnet_args()
159 | self.normlayer1 = norm_layer(dim)
160 |         self.filter = AdaptiveFourierNeuralOperator(dim, h=h, w=w)
161 |
162 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163 | self.normlayer2 = norm_layer(dim)
164 | mlp_hidden_dim = int(dim * mlp_ratio)
165 | self.mlp = Mlp(in_features=dim,
166 | hidden_features=mlp_hidden_dim,
167 | act_layer=act_layer,
168 | drop=drop)
169 | self.double_skip = args.double_skip
170 |
171 | def forward(self, x):
172 | x = x + self.drop_path(self.filter(self.normlayer1(x)))
173 | x = x + self.drop_path(self.mlp(self.normlayer2(x)))
174 | return x
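A sketch of driving one `FourierNetBlock` with dummy tokens (assumed: run as a plain script with no extra CLI flags, since `get_fourcastnet_args()` parses `sys.argv` inside the block's constructor). The sequence length must equal `h*w` because the mixing happens on a 2-D frequency grid:

```python
import torch
from modules.Fourier_modules import FourierNetBlock

block = FourierNetBlock(dim=128, mlp_ratio=2., h=4, w=4)
tokens = torch.randn(2, 16, 128)  # [B, N, C] with N == h * w
out = block(tokens)
print(out.shape)  # torch.Size([2, 16, 128]); shape-preserving token mixing
```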
--------------------------------------------------------------------------------
/modules/STConvEncoderDecoder_modules.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class BasicConv2d(nn.Module):
5 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, transpose=False, act_norm=False):
6 | super(BasicConv2d, self).__init__()
7 | self.act_norm=act_norm
8 | if not transpose:
9 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
10 | else:
11 |             self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=stride // 2)
12 | self.norm = nn.GroupNorm(2, out_channels)
13 | self.act = nn.LeakyReLU(0.2, inplace=True)
14 |
15 | def forward(self, x):
16 | y = self.conv(x)
17 | if self.act_norm:
18 | y = self.act(self.norm(y))
19 | return y
20 |
21 |
22 | class ConvSC(nn.Module):
23 | def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):
24 | super(ConvSC, self).__init__()
25 | if stride == 1:
26 | transpose = False
27 | self.conv = BasicConv2d(C_in, C_out, kernel_size=3, stride=stride,
28 | padding=1, transpose=transpose, act_norm=act_norm)
29 |
30 | def forward(self, x):
31 | y = self.conv(x)
32 | return y
33 |
34 |
35 | class GroupConv2d(nn.Module):
36 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False):
37 | super(GroupConv2d, self).__init__()
38 | self.act_norm = act_norm
39 | if in_channels % groups != 0:
40 | groups = 1
41 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,groups=groups)
42 | self.norm = nn.GroupNorm(groups,out_channels)
43 | self.activate = nn.LeakyReLU(0.2, inplace=True)
44 |
45 | def forward(self, x):
46 | y = self.conv(x)
47 | if self.act_norm:
48 | y = self.activate(self.norm(y))
49 | return y
50 |
51 |
52 | class Inception(nn.Module):
53 | def __init__(self, C_in, C_hid, C_out, incep_ker=[3,5,7,11], groups=8):
54 | super(Inception, self).__init__()
55 | self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
56 | layers = []
57 | for ker in incep_ker:
58 | layers.append(GroupConv2d(C_hid, C_out, kernel_size=ker, stride=1, padding=ker//2, groups=groups, act_norm=True))
59 | self.layers = nn.Sequential(*layers)
60 |
61 | def forward(self, x):
62 | x = self.conv1(x)
63 | y = 0
64 | for layer in self.layers:
65 | y += layer(x)
66 | return y
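A usage sketch for `Inception` (the import path is an assumption): the grouped convolutions run in parallel on the 1x1-reduced features and are summed, and `padding=ker//2` with stride 1 keeps the spatial size fixed:

```python
import torch
from modules.STConvEncoderDecoder_modules import Inception

incep = Inception(C_in=640, C_hid=128, C_out=256, incep_ker=[3, 5, 7, 11], groups=8)
x = torch.randn(2, 640, 16, 16)
print(incep(x).shape)  # torch.Size([2, 256, 16, 16]) -- H, W unchanged
```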
--------------------------------------------------------------------------------
/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__init__.py
--------------------------------------------------------------------------------
/modules/__pycache__/DiscreteSTModel_modules.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__pycache__/DiscreteSTModel_modules.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/Fourier_modules.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__pycache__/Fourier_modules.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/Fourier_modules.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__pycache__/Fourier_modules.cpython-39.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/STConvEncoderDecoder_modules.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__pycache__/STConvEncoderDecoder_modules.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/easylearningscores/PastNet/d7b30bdf26397f49e0993bfde1255aac31a1bae5/modules/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/params.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def get_fourcastnet_args():
5 | parser = argparse.ArgumentParser('FourCastNet training and evaluation script', add_help=False)
6 | parser.add_argument('--batch-size', default=4, type=int)
7 | parser.add_argument('--pretrain-epochs', default=80, type=int)
8 |     parser.add_argument('--finetune-epochs', default=25, type=int)
9 |
10 | # Model parameters
11 | parser.add_argument('--arch', default='deit_small', type=str, help='Name of model to train')
12 |
13 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)')
14 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', help='Drop path rate (default: 0.1)')
15 |
16 | # Optimizer parameters
17 |     parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw")')
18 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: 1e-8)')
19 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: None, use opt default)')
20 | parser.add_argument('--clip-grad', type=float, default=1, metavar='NORM', help='Clip gradient norm (default: None, no clipping)')
21 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
22 | parser.add_argument('--weight-decay', type=float, default=0.05, help='weight decay (default: 0.05)')
23 | # Learning rate schedule parameters
24 |     parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', help='LR scheduler (default: "cosine")')
25 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', help='learning rate (default: 5e-4)')
26 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages')
27 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)')
28 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)')
29 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)')
30 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
31 |
32 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', help='epoch interval to decay LR')
33 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports')
34 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
35 |     parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10)')
36 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)')
37 |
38 | # Augmentation parameters
39 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)')
40 |     parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', help='Use AutoAugment policy, e.g. "v0" or "original" (default: "rand-m9-mstd0.5-inc1")')
41 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
42 | parser.add_argument('--train-interpolation', type=str, default='bicubic', help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
43 |
44 | parser.add_argument('--repeated-aug', action='store_true')
45 | parser.set_defaults(repeated_aug=False)
46 |
47 | # * Random Erase params
48 |     parser.add_argument('--reprob', type=float, default=0, metavar='PCT', help='Random erase prob (default: 0)')
49 | parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode (default: "pixel")')
50 | parser.add_argument('--recount', type=int, default=1, help='Random erase count (default: 1)')
51 | parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first (clean) augmentation split')
52 |
53 | # fno parameters
54 | parser.add_argument('--fno-bias', action='store_true')
55 | parser.add_argument('--fno-blocks', type=int, default=4)
56 | parser.add_argument('--fno-softshrink', type=float, default=0.00)
57 | parser.add_argument('--double-skip', action='store_true')
58 | parser.add_argument('--tensorboard-dir', type=str, default=None)
59 | parser.add_argument('--hidden-size', type=int, default=256)
60 | parser.add_argument('--num-layers', type=int, default=12)
61 | parser.add_argument('--checkpoint-activations', action='store_true')
62 | parser.add_argument('--autoresume', action='store_true')
63 |
64 | # attention parameters
65 | parser.add_argument('--num-attention-heads', type=int, default=1)
66 |
67 | # long short parameters
68 | parser.add_argument('--ls-w', type=int, default=4)
69 | parser.add_argument('--ls-dp-rank', type=int, default=16)
70 |
71 | return parser.parse_args()
72 |
73 |
74 | def get_graphcast_args():
75 | parser = argparse.ArgumentParser('Graphcast training and evaluation script', add_help=False)
76 | parser.add_argument('--batch-size', default=2, type=int)
77 | parser.add_argument('--epochs', default=200, type=int)
78 |
79 | # Model parameters
80 | parser.add_argument('--grid-node-num', default=720 * 1440, type=int, help='The number of grid nodes')
81 | parser.add_argument('--mesh-node-num', default=128 * 320, type=int, help='The number of mesh nodes')
82 |     parser.add_argument('--mesh-edge-num', default=217170, type=int, help='The number of mesh edges')
83 |     parser.add_argument('--grid2mesh-edge-num', default=1357920, type=int, help='The number of grid-to-mesh edges')
84 |     parser.add_argument('--mesh2grid-edge-num', default=2230560, type=int, help='The number of mesh-to-grid edges')
85 | parser.add_argument('--grid-node-dim', default=49, type=int, help='The input dim of grid nodes')
86 | parser.add_argument('--grid-node-pred-dim', default=20, type=int, help='The output dim of grid-node prediction')
87 | parser.add_argument('--mesh-node-dim', default=3, type=int, help='The input dim of mesh nodes')
88 | parser.add_argument('--edge-dim', default=4, type=int, help='The input dim of all edges')
89 | parser.add_argument('--grid-node-embed-dim', default=64, type=int, help='The embedding dim of grid nodes')
90 | parser.add_argument('--mesh-node-embed-dim', default=64, type=int, help='The embedding dim of mesh nodes')
91 |     parser.add_argument('--edge-embed-dim', default=8, type=int, help='The embedding dim of edges')
92 |
93 | # Optimizer parameters
94 |     parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw")')
95 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: 1e-8)')
96 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: None, use opt default)')
97 | parser.add_argument('--clip-grad', type=float, default=1, metavar='NORM', help='Clip gradient norm (default: None, no clipping)')
98 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
99 | parser.add_argument('--weight-decay', type=float, default=0.05, help='weight decay (default: 0.05)')
100 |
101 | # Learning rate schedule parameters
102 |     parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', help='LR scheduler (default: "cosine")')
103 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', help='learning rate (default: 5e-4)')
104 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages')
105 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)')
106 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)')
107 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', help='warmup learning rate (default: 1e-6)')
108 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
109 |
110 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', help='epoch interval to decay LR')
111 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', help='epochs to warmup LR, if scheduler supports')
112 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
113 |     parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10)')
114 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)')
115 |
116 |     # Pipeline training parameters
117 | parser.add_argument('--pp_size', type=int, default=8, help='pipeline parallel size')
118 | parser.add_argument('--chunks', type=int, default=1, help='chunk size')
119 |
120 | return parser.parse_args()
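Note that both helpers call `parser.parse_args()`, which reads `sys.argv` of whatever process imports them, so unfamiliar flags from a host CLI will make them exit. A sketch of overriding defaults programmatically by patching `argv` before the call (the flag values here are arbitrary examples):

```python
import sys
from params import get_fourcastnet_args

sys.argv = ['train', '--fno-blocks', '8', '--double-skip']  # fake CLI for the parser
args = get_fourcastnet_args()
print(args.fno_blocks, args.double_skip)  # 8 True
```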
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from exp_vq import PastNet_exp
3 |
4 | import warnings
5 | warnings.filterwarnings('ignore')
6 |
7 | def create_parser():
8 | parser = argparse.ArgumentParser()
9 | # Set-up parameters
10 | parser.add_argument('--device', default='cuda', type=str, help='Name of device to use for tensor computations (cuda/cpu)')
11 | parser.add_argument('--res_dir', default='./results', type=str)
12 | parser.add_argument('--ex_name', default='DiscreteSTModel', type=str)
13 |     parser.add_argument('--use_gpu', default=True, type=bool)  # note: with type=bool, any non-empty string parses as True
14 | parser.add_argument('--gpu', default=0, type=int)
15 | parser.add_argument('--seed', default=1, type=int)
16 | parser.add_argument('--load_model', default="", type=str)
17 |
18 | # dataset parameters
19 | parser.add_argument('--batch_size', default=16, type=int, help='Batch size')
20 |     parser.add_argument('--val_batch_size', default=16, type=int, help='Validation batch size')
21 | parser.add_argument('--data_root', default='./data/')
22 | parser.add_argument('--dataname', default='mmnist', choices=['mmnist', 'taxibj', 'caltech'])
23 | parser.add_argument('--num_workers', default=8, type=int)
24 |
25 | # model parameters
26 | parser.add_argument('--in_shape', default=[10, 1, 64, 64], type=int, nargs='*') # [10, 1, 64, 64] for mmnist, [4, 2, 32, 32] for taxibj
27 | parser.add_argument('--hid_T', default=256, type=int)
28 | parser.add_argument('--N_T', default=8, type=int)
29 | parser.add_argument('--groups', default=4, type=int)
30 | parser.add_argument('--res_units', default=32, type=int)
31 | parser.add_argument('--res_layers', default=4, type=int)
32 | parser.add_argument('--K', default=512, type=int)
33 | parser.add_argument('--D', default=64, type=int)
34 |
35 | # Training parameters
36 | parser.add_argument('--epochs', default=201, type=int)
37 | parser.add_argument('--log_step', default=1, type=int)
38 | parser.add_argument('--lr', default=0.01, type=float, help='Learning rate')
39 |     parser.add_argument('--load_pred_train', default=0, type=int, help='Load the pre-trained VQ-VAE checkpoint (0/1)')
40 | parser.add_argument('--freeze_vqvae', default=0, type=int)
41 | parser.add_argument('--theta', default=1, type=float)
42 | return parser
43 |
44 |
45 | if __name__ == '__main__':
46 | args = create_parser().parse_args()
47 | config = args.__dict__
48 |
49 | exp = PastNet_exp(args)
50 | print('>>>>>>>>>>>>>>>>>>>>>>>>>>>> start <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
51 | exp.train(args)
52 | print('>>>>>>>>>>>>>>>>>>>>>>>>>>>> testing <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
53 | mse = exp.test(args)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import torch
4 | import random
5 | import numpy as np
6 | import torch.backends.cudnn as cudnn
7 |
8 | def set_seed(seed):
9 | random.seed(seed)
10 | np.random.seed(seed)
11 | torch.manual_seed(seed)
12 | cudnn.deterministic = True
13 |
14 | def print_log(message):
15 | print(message)
16 | logging.info(message)
17 |
18 | def output_namespace(namespace):
19 | configs = namespace.__dict__
20 | message = ''
21 | for k, v in configs.items():
22 | message += '\n' + k + ': \t' + str(v) + '\t'
23 | return message
24 |
25 | def check_dir(path):
26 | if not os.path.exists(path):
27 | os.makedirs(path)
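`print_log` forwards to `logging.info`, which stays silent until the root logger is configured. A typical setup (a sketch; the path matches the repository's `results/DiscreteSTModel/log.log`):

```python
import logging
from utils import check_dir, print_log, set_seed

check_dir('./results/DiscreteSTModel')
logging.basicConfig(filename='./results/DiscreteSTModel/log.log',
                    level=logging.INFO, format='%(asctime)s %(message)s')
set_seed(1)
print_log('seed fixed, starting experiment')  # goes to stdout and to log.log
```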
--------------------------------------------------------------------------------