├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── real_1.txt
│   ├── real_2.txt
│   ├── real_3.txt
│   ├── synthetic_1.txt
│   └── synthetic_2.txt
├── environment.yml
├── gifs
│   ├── real_1.gif
│   ├── real_2.gif
│   ├── real_3.gif
│   ├── synthetic_1.gif
│   └── synthetic_2.gif
├── io_utils.py
├── learning_utils.py
├── simulation_utils.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/logs
2 | **/data
3 | **/pretrained
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | pip-wheel-metadata/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 |
58 | # Translations
59 | *.mo
60 | *.pot
61 |
62 | # Django stuff:
63 | *.log
64 | local_settings.py
65 | db.sqlite3
66 | db.sqlite3-journal
67 |
68 | # Flask stuff:
69 | instance/
70 | .webassets-cache
71 |
72 | # Scrapy stuff:
73 | .scrapy
74 |
75 | # Sphinx documentation
76 | docs/_build/
77 |
78 | # PyBuilder
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow
99 | __pypackages__/
100 |
101 | # Celery stuff
102 | celerybeat-schedule
103 | celerybeat.pid
104 |
105 | # SageMath parsed files
106 | *.sage.py
107 |
108 | # Environments
109 | .env
110 | .venv
111 | env/
112 | venv/
113 | ENV/
114 | env.bak/
115 | venv.bak/
116 |
117 | # Spyder project settings
118 | .spyderproject
119 | .spyproject
120 |
121 | # Rope project settings
122 | .ropeproject
123 |
124 | # mkdocs documentation
125 | /site
126 |
127 | # mypy
128 | .mypy_cache/
129 | .dmypy.json
130 | dmypy.json
131 |
132 | # Pyre type checker
133 | .pyre/
134 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 yitongdeng-projects
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [ICLR 2023] Learning Vortex Dynamics for Fluid Inference and Prediction
2 | by [Yitong Deng](https://yitongdeng.github.io/), [Hong-Xing Yu](https://kovenyu.com/), [Jiajun Wu](https://jiajunwu.com/), and [Bo Zhu](https://cs.dartmouth.edu/~bozhu/).
3 |
4 | Our paper can be found at: https://arxiv.org/abs/2301.11494.
5 |
6 | Video results can be found at: https://yitongdeng.github.io/vortex_learning_webpage.
7 |
8 | ## Environment
9 | The environment can be installed with conda via:
10 | ```bash
11 | conda env create -f environment.yml
12 | conda activate vortex_env
13 | ```
14 |
15 | Our code is tested on `Windows 10` and `Ubuntu 20.04`.
16 |
17 | ## Data
18 | The 5 videos (2 synthetic and 3 real-world) used in our paper can be downloaded from [Google Drive](https://drive.google.com/file/d/1kCX2NF4IMtB2IC_xZPfrrnR2uiV83zRC/view?usp=sharing). Once downloaded, place the unzipped `data` folder in the project root directory.
19 |
20 | ## Synthetic 1
21 |
22 | #### Pretrain
23 |
24 | First, execute the command below to pretrain the trajectory network so that the initial vortices are regularly spaced, cover the simulation domain, and remain stationary.
25 |
26 | ```bash
27 | python train.py --config configs/synthetic_1.txt --run_pretrain True
28 | ```
29 |
30 | Once completed, navigate to `pretrained/exp_synthetic_1/tests/` and check that the plotted dots are regularly spaced and remain roughly stationary.
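For reference, below is a minimal, self-contained sketch (not part of the repository) of the stationary target that `pretrain()` in `train.py` fits the trajectory network to; the frame count is a placeholder, and `pretrain()` additionally rescales this grid by `--init_vort_dist`:

```python
import torch

# placeholders: vorts_num_x/vorts_num_y follow train.py; num_frames is data-dependent
vorts_num_x, vorts_num_y, num_frames = 4, 4, 50
dx = 1.0 / vorts_num_y
cx, cy = torch.meshgrid(torch.arange(vorts_num_x), torch.arange(vorts_num_y), indexing="ij")
grid = dx * (torch.stack((cx, cy), dim=-1).float() + 0.5)  # cell centers, as in gen_grid()
init_poss = grid.view(-1, 2)                               # [num_vorts, 2]
pos_GT = init_poss[None, ...].expand(num_frames, -1, -1)   # same grid at every timestamp
vel_GT = torch.zeros_like(pos_GT)                          # zero target velocity -> stationary
# pretrain() then minimizes L2(pos_pred, pos_GT) + 0.001 * L2(vel_pred, vel_GT)
```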
A file `pretrained.tar` should also appear at `pretrained/exp_synthetic_1/ckpts/`.
31 |
32 | #### Train
33 |
34 | Then, run the command below to train.
35 |
36 | ```bash
37 | python train.py --config configs/synthetic_1.txt
38 | ```
39 |
40 | Checkpoints are written to `logs/exp_synthetic_1/ckpts/` and testing results to `logs/exp_synthetic_1/tests/` once every 1000 training iterations.
41 |
42 | #### Results
43 |
44 | When run on our Windows machine with an AMD Ryzen Threadripper 3990X and an NVIDIA RTX A6000, this is the final testing result we obtain:
45 |
46 | ![synthetic_1](gifs/synthetic_1.gif)
47 |
48 | Note that because our PyTorch code includes nondeterministic components (e.g., the CUDA grid sampler), different training sessions are not expected to produce exactly the same outcome.
49 |
50 | ## Synthetic 2
51 |
52 | #### Pretrain
53 |
54 | ```bash
55 | python train.py --config configs/synthetic_2.txt --run_pretrain True
56 | ```
57 |
58 | #### Train
59 |
60 | ```bash
61 | python train.py --config configs/synthetic_2.txt
62 | ```
63 |
64 | #### Results
65 |
66 | ![synthetic_2](gifs/synthetic_2.gif)
67 |
68 | ## Real 1
69 |
70 | #### Pretrain
71 |
72 | ```bash
73 | python train.py --config configs/real_1.txt --run_pretrain True
74 | ```
75 |
76 | #### Train
77 |
78 | ```bash
79 | python train.py --config configs/real_1.txt
80 | ```
81 |
82 | #### Results
83 |
84 | ![real_1](gifs/real_1.gif)
85 |
86 | ## Real 2
87 |
88 | #### Pretrain
89 |
90 | ```bash
91 | python train.py --config configs/real_2.txt --run_pretrain True
92 | ```
93 |
94 | #### Train
95 |
96 | ```bash
97 | python train.py --config configs/real_2.txt
98 | ```
99 |
100 | #### Results
101 |
102 | ![real_2](gifs/real_2.gif)
103 |
104 | ## Real 3
105 |
106 | #### Pretrain
107 |
108 | ```bash
109 | python train.py --config configs/real_3.txt --run_pretrain True
110 | ```
111 |
112 | #### Train
113 |
114 | ```bash
115 | python train.py --config configs/real_3.txt
116 | ```
117 |
118 | #### Results
119 |
120 | ![real_3](gifs/real_3.gif)
121 |
122 | ## Trying your own video
123 | We assume the input is a Numpy array of shape `[num_frames, 256, 256, 3]`, with the last dimension representing RGB pixel values between 0.0 and 1.0, located in `data/[your_name_here]/imgs.npy` (a minimal packing sketch is given at the end of this section). For fluid videos with boundaries (like in our real-world examples), a Numpy array of shape `[256, 256]` representing the signed distance field to the boundary must be supplied in `data/[your_name_here]/sdf.npy`. We assume the signed distance is measured in pixels.
124 |
125 | For videos of higher dynamical complexity, we also encourage experimenting with the number of vortex particles used. This is determined by the `vorts_num_x` and `vorts_num_y` parameters in `train.py`, which are currently hard-coded to 4 and can be increased as needed.
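For reference, here is a minimal sketch (not part of the repository) of how `imgs.npy` could be assembled from a folder of pre-extracted, already-resized frames; the folder and video names are placeholders:

```python
import os
import numpy as np
import imageio.v2 as imageio

frame_dir = "my_frames"                     # hypothetical folder of 256x256 RGB frames
out_dir = os.path.join("data", "my_video")  # "my_video" is a placeholder video name
frames = []
for fname in sorted(os.listdir(frame_dir)):
    img = imageio.imread(os.path.join(frame_dir, fname))    # uint8, [256, 256, 3 or 4]
    frames.append(img[..., :3].astype(np.float32) / 255.0)  # keep RGB, scale to [0, 1]
imgs = np.stack(frames, axis=0)                             # [num_frames, 256, 256, 3]
assert imgs.shape[1:] == (256, 256, 3), "resize/crop frames to 256x256 RGB first"
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, "imgs.npy"), imgs)
```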
126 | 127 | ## Bibliography 128 | If you find our paper or code helpful, please consider citing: 129 | ``` 130 | @inproceedings{deng2023vortex, 131 | title={Learning Vortex Dynamics for Fluid Inference and Prediction}, 132 | author={Yitong Deng and Hong-Xing Yu and Jiajun Wu and Bo Zhu}, 133 | booktitle={Proceedings of the International Conference on Learning Representations}, 134 | year={2023}, 135 | } 136 | ``` 137 | -------------------------------------------------------------------------------- /configs/real_1.txt: -------------------------------------------------------------------------------- 1 | --data_name real_1 2 | --seen_ratio 0.6666 3 | --exp_name exp_real_1 4 | --num_train = 4000 5 | --vort_scale = 0.5 6 | --init_vort_dist = 0.7 -------------------------------------------------------------------------------- /configs/real_2.txt: -------------------------------------------------------------------------------- 1 | --data_name real_2 2 | --seen_ratio 0.6666 3 | --exp_name exp_real_2 4 | --num_train = 4000 5 | --vort_scale = 0.5 6 | --init_vort_dist = 0.7 -------------------------------------------------------------------------------- /configs/real_3.txt: -------------------------------------------------------------------------------- 1 | --data_name real_3 2 | --seen_ratio 0.6666 3 | --exp_name exp_real_3 4 | --num_train = 4000 5 | --vort_scale = 0.5 6 | --init_vort_dist = 0.7 -------------------------------------------------------------------------------- /configs/synthetic_1.txt: -------------------------------------------------------------------------------- 1 | --data_name synthetic_1 2 | --seen_ratio 0.3333 3 | --exp_name exp_synthetic_1 4 | --num_train = 40000 5 | --vort_scale = 0.33 6 | --init_vort_dist = 1.0 7 | -------------------------------------------------------------------------------- /configs/synthetic_2.txt: -------------------------------------------------------------------------------- 1 | --data_name synthetic_2 2 | --seen_ratio 0.3333 3 | --exp_name exp_synthetic_2 4 | --num_train = 40000 5 | --vort_scale = 0.33 6 | --init_vort_dist = 1.0 7 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: vortex_env 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - pytorch=1.13.1 9 | - torchvision=0.14.1 10 | - cudatoolkit=11.7 11 | - pip: 12 | - configargparse==1.5.3 13 | - imageio 14 | - matplotlib -------------------------------------------------------------------------------- /gifs/real_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yitongdeng-projects/learning_vortex_dynamics_code/4d2576b38ea264db2994a20f2a140654237e7480/gifs/real_1.gif -------------------------------------------------------------------------------- /gifs/real_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yitongdeng-projects/learning_vortex_dynamics_code/4d2576b38ea264db2994a20f2a140654237e7480/gifs/real_2.gif -------------------------------------------------------------------------------- /gifs/real_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yitongdeng-projects/learning_vortex_dynamics_code/4d2576b38ea264db2994a20f2a140654237e7480/gifs/real_3.gif 
--------------------------------------------------------------------------------
/gifs/synthetic_1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yitongdeng-projects/learning_vortex_dynamics_code/4d2576b38ea264db2994a20f2a140654237e7480/gifs/synthetic_1.gif
--------------------------------------------------------------------------------
/gifs/synthetic_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yitongdeng-projects/learning_vortex_dynamics_code/4d2576b38ea264db2994a20f2a140654237e7480/gifs/synthetic_2.gif
--------------------------------------------------------------------------------
/io_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import numpy as np
4 | import imageio.v2 as imageio
5 | import torch
6 | import matplotlib.pyplot as plt
7 | import copy
8 | import configargparse
9 |
10 | def to_numpy(x):
11 |     return x.detach().cpu().numpy()
12 |
13 | def to8b(x):
14 |     return (255*np.clip(x,0,1)).astype(np.uint8)
15 |
16 | # create gif from images using FFMPEG
17 | def merge_imgs(framerate, save_dir):
18 |     os.system('ffmpeg -hide_banner -loglevel error -y -i {0}/%03d.jpg -vf palettegen {0}/palette.png'.format(save_dir))
19 |     os.system('ffmpeg -hide_banner -loglevel error -y -framerate {0} -i {1}/%03d.jpg -i {1}/palette.png -lavfi paletteuse {1}/output.gif'.format(framerate, save_dir))
20 |
21 | # remove everything in dir
22 | def remove_everything_in(folder):
23 |     for filename in os.listdir(folder):
24 |         file_path = os.path.join(folder, filename)
25 |         try:
26 |             if os.path.isfile(file_path) or os.path.islink(file_path):
27 |                 os.unlink(file_path)
28 |             elif os.path.isdir(file_path):
29 |                 shutil.rmtree(file_path)
30 |         except Exception as e:
31 |             print('Failed to delete %s. Reason: %s' % (file_path, e))
32 |
33 | # convert to 8-bit and write image
34 | def imwrite(f, img):
35 |     img = to8b(img)
36 |     imageio.imwrite(f, img) # save frame as jpeg file
37 |
38 | # generate grid coordinates
39 | def gen_grid(width, height, device):
40 |     img_n_grid_x = width
41 |     img_n_grid_y = height
42 |     img_dx = 1./img_n_grid_y
43 |     c_x, c_y = torch.meshgrid(torch.arange(img_n_grid_x), torch.arange(img_n_grid_y), indexing = "ij")
44 |     img_x = img_dx * (torch.cat((c_x[..., None], c_y[..., None]), axis = 2) + 0.5).to(device) # grid center locations
45 |     return img_x
46 |
47 | # write image
48 | def write_image(img_xy, outdir, i):
49 |     img_xy = copy.deepcopy(img_xy)
50 |     c_pred = np.flip(img_xy.transpose([1,0,2]), 0)
51 |     img8b = to8b(c_pred)
52 |     save_filepath = os.path.join(outdir, '{:03d}.jpg'.format(i))
53 |     imageio.imwrite(save_filepath, img8b)
54 |
55 | # write vortex particles with their velocities
56 | def write_vorts(vorts_pos, vorts_uv, outdir, i):
57 |     vorts_pos = copy.deepcopy(vorts_pos)
58 |     pos = vorts_pos
59 |     fig = plt.figure(num=1, figsize=(7, 7), clear=True)
60 |     ax = fig.add_subplot()
61 |     fig.subplots_adjust(0.1,0.1,0.9,0.9)
62 |     ax.set_xlim([0, 1])
63 |     ax.set_ylim([0, 1])
64 |     s = ax.scatter(pos[..., 0], pos[..., 1], s = 100)
65 |     ax.quiver(pos[..., 0], pos[..., 1], vorts_uv[..., 0], vorts_uv[..., 1], color = "red", scale = 10.)
66 | fig.savefig(os.path.join(outdir, '{:03d}.jpg'.format(i)), dpi = 512//8) 67 | 68 | # write vorticity field 69 | def write_vorticity(vort_img, outdir, i): 70 | vort_img = copy.deepcopy(vort_img) 71 | array = vort_img 72 | scale = array.shape[1] 73 | array = np.transpose(array, (1, 0, 2)) # from X, Y to Y, X 74 | fig = plt.figure(num=1, figsize=(8, 7), clear=True) 75 | ax = fig.add_subplot() 76 | fig.subplots_adjust(0.05,0.,0.9,1) 77 | ax.set_xlim([0, array.shape[0]]) 78 | ax.set_ylim([0, array.shape[1]]) 79 | p = ax.imshow(array, alpha = 0.75, vmin = -10, vmax = 10) 80 | fig.colorbar(p, fraction=0.04) 81 | fig.savefig(os.path.join(outdir, '{:03d}.jpg'.format(i)), dpi = 512//8) 82 | 83 | # write vortices over image 84 | def write_visualization(img, vorts_pos, vorts_w, outdir, i, boundary = None): 85 | img = copy.deepcopy(img) 86 | # mask the out-of-bound area as green 87 | if boundary is not None: 88 | OUT = (boundary[0] >= -boundary[2]).cpu().numpy() 89 | img[OUT] = (0.5 * img[OUT]) 90 | img[OUT, 1] += 0.5 91 | vorts_pos = copy.deepcopy(vorts_pos) 92 | vorts_w = copy.deepcopy(vorts_w) 93 | array = img 94 | scale = array.shape[1] 95 | pos = vorts_pos * scale 96 | array = np.transpose(array, (1, 0, 2)) # from X, Y to Y, X 97 | fig = plt.figure(num=1, figsize=(8, 7), clear=True) 98 | ax = fig.add_subplot() 99 | fig.subplots_adjust(0.05,0.,0.9,1) 100 | ax.set_xlim([0, array.shape[0]]) 101 | ax.set_ylim([0, array.shape[1]]) 102 | s = ax.scatter(pos[..., 0], pos[..., 1], s = 100, c = vorts_w.flatten(), vmin = None, vmax = None) 103 | p = ax.imshow(array, alpha = 0.75) 104 | fig.colorbar(s, fraction=0.04) 105 | fig.savefig(os.path.join(outdir, '{:03d}.jpg'.format(i)), dpi = 512//8) 106 | 107 | # convert rgb image to yuv parametrization 108 | def rgb_to_yuv(_rgb_image): 109 | rgb_image = _rgb_image[..., None] 110 | matrix = np.array([[0.299, 0.587, 0.114], 111 | [-0.14713, -0.28886, 0.436], 112 | [0.615, -0.51499, -0.10001]]).astype(np.float32) 113 | matrix = torch.from_numpy(matrix)[None, None, None, ...].to(rgb_image.get_device()) 114 | yuv_image = torch.einsum("abcde, abcef -> abcdf", matrix, rgb_image) 115 | return yuv_image.squeeze() 116 | 117 | # command line tools 118 | def config_parser(): 119 | parser = configargparse.ArgumentParser() 120 | parser.add_argument('--config', is_config_file=True, 121 | help='config file path') 122 | parser.add_argument('--seen_ratio', type=float, default=0.3333, 123 | help='fraction of input video available during training') 124 | parser.add_argument("--data_name", type=str, default='synthetic_1', 125 | help='name of video data') 126 | parser.add_argument("--run_pretrain", type=bool, default = False, 127 | help='whether to run pretrain only') 128 | parser.add_argument("--test_only", type=bool, default = False, 129 | help='whether to run test only') 130 | parser.add_argument("--start_over", type=bool, default = False, 131 | help='whether to clear previous record on this experiment') 132 | parser.add_argument("--exp_name", type=str, default = "exp_0", 133 | help='the name of current experiment') 134 | parser.add_argument('--vort_scale', type=float, default=0.33, 135 | help='characteristic scale for vortices') 136 | parser.add_argument('--num_train', type=int, default=40000, 137 | help='number of training iterations') 138 | parser.add_argument("--init_vort_dist", type=float, default=1.0, 139 | help='how spread out are init vortices') 140 | return parser 141 | -------------------------------------------------------------------------------- /learning_utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | from torch.optim.lr_scheduler import LambdaLR, StepLR 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | 9 | device = torch.device("cuda") 10 | real = torch.float32 11 | 12 | L2_Loss = nn.MSELoss().cuda() 13 | 14 | class SineResidualBlock(nn.Module): 15 | 16 | def __init__(self, in_features, out_features, bias=True, 17 | is_first=False, omega_0=30): 18 | super().__init__() 19 | self.omega_0 = omega_0 20 | self.is_first = is_first 21 | 22 | self.in_features = in_features 23 | self.linear = nn.Linear(in_features, out_features, bias=bias) 24 | # add shortcut 25 | self.shortcut = nn.Sequential() 26 | if in_features != out_features: 27 | self.shortcut = nn.Sequential( 28 | nn.Linear(in_features, out_features), 29 | ) 30 | 31 | self.init_weights() 32 | 33 | def init_weights(self): 34 | with torch.no_grad(): 35 | if self.is_first: 36 | self.linear.weight.uniform_(-1 / self.in_features, 37 | 1 / self.in_features) 38 | else: 39 | self.linear.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 40 | np.sqrt(6 / self.in_features) / self.omega_0) 41 | 42 | def forward(self, input): 43 | out = torch.sin(self.omega_0 * self.linear(input)) 44 | out += self.shortcut(input) 45 | out = nn.functional.relu(out) 46 | return out 47 | 48 | class Dynamics_Net(nn.Module): 49 | def __init__(self): 50 | super().__init__() 51 | in_dim = 1 52 | out_dim = 1 53 | width = 40 54 | self.layers = nn.Sequential(SineResidualBlock(in_dim, width, omega_0=1., is_first=True), 55 | SineResidualBlock(width, width, omega_0=1.), 56 | SineResidualBlock(width, width, omega_0=1.), 57 | SineResidualBlock(width, width, omega_0=1.), 58 | nn.Linear(width, out_dim), 59 | ) 60 | 61 | def forward(self, x): 62 | '''Forward pass''' 63 | return self.layers(x) 64 | 65 | class Position_Net(nn.Module): 66 | def __init__(self, num_vorts): 67 | super().__init__() 68 | in_dim = 1 69 | out_dim = num_vorts * 2 70 | self.layers = nn.Sequential(SineResidualBlock(in_dim, 64, omega_0=1., is_first=True), 71 | SineResidualBlock(64, 128, omega_0=1.), 72 | SineResidualBlock(128, 256, omega_0=1.), 73 | nn.Linear(256, out_dim) 74 | ) 75 | 76 | def forward(self, x): 77 | '''Forward pass''' 78 | return self.layers(x) 79 | 80 | 81 | def create_bundle(logdir, num_vorts, decay_step, decay_gamma, pretrain_dir = None): 82 | model_len = Dynamics_Net().to(device) 83 | model_pos = Position_Net(num_vorts).to(device) 84 | grad_vars = list(model_len.parameters()) 85 | grad_vars2 = list(model_pos.parameters()) 86 | ########################## 87 | # Load checkpoints 88 | ckpts = [os.path.join(logdir, f) for f in sorted(os.listdir(logdir)) if 'tar' in f] 89 | pretrain_ckpts = [] 90 | if pretrain_dir: 91 | pretrain_ckpts = [os.path.join(pretrain_dir, f) for f in sorted(os.listdir(pretrain_dir)) if 'tar' in f] 92 | 93 | if len(ckpts) <= 0: # no checkpoints to load 94 | w_pred = torch.zeros(num_vorts, 1, device = device, dtype = real) 95 | w_pred.requires_grad = True 96 | size_pred = torch.zeros(num_vorts, 1, device = device, dtype = real) 97 | size_pred.requires_grad = True 98 | start = 0 99 | optimizer = torch.optim.Adam([{'params': grad_vars}, \ 100 | {'params': grad_vars2, 'lr':3.e-4},\ 101 | {'params': w_pred, 'lr':5.e-3},\ 102 | {'params': size_pred, 'lr':5.e-3}], lr=1.e-3, betas=(0.9, 0.999)) 103 | # Load pretrained if there is one and no checkpoint exists 104 | if len(pretrain_ckpts) > 0: 105 | pre_ckpt_path = 
pretrain_ckpts[-1] 106 | print ("[Initialize] Has pretrained available, reloading from: ", pre_ckpt_path) 107 | pre_ckpt = torch.load(pre_ckpt_path) 108 | model_pos.load_state_dict(pre_ckpt['model_pos_state_dict']) 109 | 110 | else: # has checkpoints to load: 111 | ckpt_path = ckpts[-1] 112 | print ("[Initialize] Has checkpoint available, reloading from: ", ckpt_path) 113 | ckpt = torch.load(ckpt_path) 114 | start = ckpt['global_step'] 115 | w_pred = ckpt['w_pred'] 116 | w_pred.requires_grad = True 117 | size_pred = ckpt['size_pred'] 118 | size_pred.requires_grad = True 119 | model_len.load_state_dict(ckpt["model_len_state_dict"]) 120 | model_pos.load_state_dict(ckpt["model_pos_state_dict"]) 121 | optimizer = torch.optim.Adam([{'params': grad_vars}, \ 122 | {'params': grad_vars2, 'lr':3.e-4},\ 123 | {'params': w_pred, 'lr':5.e-3},\ 124 | {'params': size_pred, 'lr':5.e-3}], lr=1.e-3, betas=(0.9, 0.999)) 125 | optimizer.load_state_dict(ckpt["optimizer_state_dict"]) 126 | 127 | lr_scheduler = StepLR(optimizer, step_size = decay_step, gamma = decay_gamma) 128 | 129 | 130 | ########################## 131 | net_dict = { 132 | 'model_len' : model_len, 133 | 'model_pos' : model_pos, 134 | 'w_pred' : w_pred, 135 | 'size_pred' : size_pred, 136 | } 137 | 138 | return net_dict, start, grad_vars, optimizer, lr_scheduler 139 | 140 | 141 | # vels: [batch, width, height, 2] 142 | def calc_div(vels): 143 | batch_size, width, height, D = vels.shape 144 | dx = 1./height 145 | du_dx = 1./(2*dx) * (vels[:, 2:, 1:-1, 0] - vels[:, :-2, 1:-1, 0]) 146 | dv_dy = 1./(2*dx) * (vels[:, 1:-1, 2:, 1] - vels[:, 1:-1, :-2, 1]) 147 | return du_dx + dv_dy 148 | 149 | # field: [batch, width, height, 1] 150 | def calc_grad(field): 151 | batch_size, width, height, _ = field.shape 152 | dx = 1./height 153 | df_dx = 1./(2*dx) * (field[:, 2:, 1:-1] - field[:, :-2, 1:-1]) 154 | df_dy = 1./(2*dx) * (field[:, 1:-1, 2:] - field[:, 1:-1, :-2]) 155 | return torch.cat((df_dx, df_dy), dim = -1) 156 | 157 | def calc_vort(vel_img, boundary = None): # compute the curl of velocity 158 | W, H, _ = vel_img.shape 159 | dx = 1./H 160 | vort_img = torch.zeros(W, H, 1, device = device, dtype = real) 161 | u = vel_img[...,[0]] 162 | v = vel_img[...,[1]] 163 | dvdx = 1/(2*dx) * (v[2:, 1:-1] - v[:-2, 1:-1]) 164 | dudy = 1/(2*dx) * (u[1:-1, 2:] - u[1:-1, :-2]) 165 | vort_img[1:-1, 1:-1] = dvdx - dudy 166 | if boundary is not None: 167 | # set out-of-bound pixels to 0 because velocity undefined there 168 | OUT = (boundary[0] >= -boundary[2] - 4) 169 | vort_img[OUT] *= 0 170 | return vort_img 171 | 172 | # sdf: [W, H] 173 | # sdf normal: [W, H, 2] 174 | def calc_sdf_normal(sdf): 175 | W, H = sdf.shape 176 | sdf_normal = torch.zeros((W, H, 2)).cuda() #[W, H, 2] 177 | sdf_normal[1:-1, 1:-1] = calc_grad(sdf[None,...,None])[0] # outward pointing [W, H, 2] 178 | sdf_normal = F.normalize(sdf_normal, dim = -1, p = 2) 179 | return sdf_normal 180 | 181 | # vorts_pos: [batch, num_vorts, 2] 182 | # query_pos: [num_query, 2] or [batch, num_query, 2] 183 | # return: [batch, num_queries, num_vorts, 2] 184 | def calc_diff_batched(_vorts_pos, _query_pos): 185 | vorts_pos = _vorts_pos[:, None, :, :] # [batch, 1, num_vorts, 2] 186 | if len(_query_pos.shape) > 2: 187 | query_pos = _query_pos[:, :, None, :] # [batch, num_query, 1, 2] 188 | else: 189 | query_pos = _query_pos[None, :, None, :] # [1, num_query, 1, 2] 190 | diff = query_pos - vorts_pos # [batch, num_queries, num_vorts, 2] 191 | return diff 192 | 193 | 194 | # vorts_pos shape: [batch, num_vorts, 2] 195 | # 
vorts_w shape: [num_vorts, 1] or [batch, num_vorts, 1]
196 | # vorts_size shape: [num_vorts, 1] or [batch, num_vorts, 1]
197 | def vort_to_vel(network_length, vorts_size, vorts_w, vorts_pos, query_pos, length_scale):
198 |     diff = calc_diff_batched(vorts_pos, query_pos) # [batch, num_query, num_vorts, 2]
199 |     # some broadcasting
200 |     if len(vorts_size.shape) > 2:
201 |         blob_size = vorts_size[:, None, ...] # [batch, 1, num_vorts, 1]
202 |     else:
203 |         blob_size = vorts_size[None, None, ...] # [1, 1, num_vorts, 1]
204 |     if len(vorts_w.shape) > 2:
205 |         vorts_w = vorts_w[:, None, ...] # [batch, 1, num_vorts, 1]
206 |     else:
207 |         vorts_w = vorts_w[None, None, ...] # [1, 1, num_vorts, 1]
208 |
209 |     diff = calc_diff_batched(vorts_pos, query_pos)
210 |     dist = torch.norm(diff, dim = -1, p = 2, keepdim = True)
211 |     dist_not_zero = dist > 0.0
212 |
213 |     # cross product in 2D
214 |     R = diff.flip([-1]) # (x, y) becomes (y, x)
215 |     R[..., 0] *= -1 # (y, x) becomes (-y, x)
216 |     R = F.normalize(R, dim = -1)
217 |
218 |     dist = dist / (blob_size/length_scale)
219 |     dist[dist_not_zero] = torch.pow(dist[dist_not_zero], 0.3)
220 |     magnitude = network_length(dist)
221 |     magnitude = magnitude / (blob_size/length_scale)
222 |
223 |     result = magnitude * R * vorts_w
224 |     result = torch.sum(result, dim = -2) # [batch_size, num_queries, 2]
225 |
226 |     return result
--------------------------------------------------------------------------------
/simulation_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.manual_seed(123)
3 | import numpy as np
4 | import os
5 | from learning_utils import calc_grad
6 | import math
7 | import time
8 | from functorch import vmap
9 | import torch.nn.functional as F
10 |
11 | device = torch.device("cuda")
12 | real = torch.float32
13 |
14 | def RK1(pos, u, dt):
15 |     return pos + dt * u
16 |
17 | def RK2(pos, u, dt):
18 |     p_mid = pos + 0.5 * dt * u
19 |     return pos + dt * sample_grid_batched(u, p_mid)
20 |
21 | def RK3(pos, u, dt):
22 |     u1 = u
23 |     p1 = pos + 0.5 * dt * u1
24 |     u2 = sample_grid_batched(u, p1)
25 |     p2 = pos + 0.75 * dt * u2
26 |     u3 = sample_grid_batched(u, p2)
27 |     return pos + dt * (2/9 * u1 + 1/3 * u2 + 4/9 * u3)
28 |
29 | def advect_quantity_batched(quantity, u, x, dt, boundary):
30 |     return advect_quantity_batched_BFECC(quantity, u, x, dt, boundary)
31 |
32 | # pos: [num_queries, 2]
33 | # if a backtraced position is out-of-bound, project it to the interior
34 | def project_to_inside(pos, boundary):
35 |     if boundary is None: # if no boundary then do nothing
36 |         return pos
37 |     sdf, sdf_normal, _ = boundary
38 |     W, H = sdf.shape
39 |     dx = 1./H
40 |     pos_grid = (pos / dx).floor().long()
41 |     pos_grid_x = pos_grid[...,0]
42 |     pos_grid_y = pos_grid[...,1]
43 |     pos_grid_x = torch.clamp(pos_grid_x, 0, W-1)
44 |     pos_grid_y = torch.clamp(pos_grid_y, 0, H-1)
45 |     sd_at_pos = sdf[pos_grid_x, pos_grid_y][...,None] # [num_queries, 1]
46 |     sd_normal_at_pos = sdf_normal[pos_grid_x, pos_grid_y] # [num_queries, 2]
47 |     OUT = (sd_at_pos >= -boundary[2]).squeeze(-1) # [num_queries]
48 |     OUT_pos = pos[OUT] #[num_out_queries, 2]
49 |     OUT_pos_fixed = OUT_pos - (sd_at_pos[OUT]+boundary[2]) * dx * sd_normal_at_pos[OUT] # remember to multiply by dx
50 |     pos[OUT] = OUT_pos_fixed
51 |     return pos
52 |
53 |
54 | def index_take_2D(source, index_x, index_y):
55 |     W, H, Channel = source.shape
56 |     W_, H_ = index_x.shape
57 |     index_flattened_x = index_x.flatten()
58 |     index_flattened_y = index_y.flatten()
59 |
sampled = source[index_flattened_x, index_flattened_y].view((W_, H_, Channel)) 60 | return sampled 61 | 62 | index_take_batched = vmap(index_take_2D) 63 | 64 | # clipping used for MacCormack and BFECC 65 | def MacCormack_clip(advected_quantity, quantity, u, x, dt, boundary): 66 | batch, W, H, _ = u.shape 67 | prev_pos = RK3(x, u, -1. * dt) # [batch, W, H, 2] 68 | prev_pos = project_to_inside(prev_pos.view((-1, 2)), boundary).view(prev_pos.shape) 69 | dx = 1./H 70 | pos_grid = (prev_pos / dx - 0.5).floor().long() 71 | pos_grid_x = torch.clamp(pos_grid[..., 0], 0, W-2) 72 | pos_grid_y = torch.clamp(pos_grid[..., 1], 0, H-2) 73 | pos_grid_x_plus = pos_grid_x + 1 74 | pos_grid_y_plus = pos_grid_y + 1 75 | BL = index_take_batched(quantity, pos_grid_x, pos_grid_y) 76 | BR = index_take_batched(quantity, pos_grid_x_plus, pos_grid_y) 77 | TR = index_take_batched(quantity, pos_grid_x_plus, pos_grid_y_plus) 78 | TL = index_take_batched(quantity, pos_grid_x, pos_grid_y_plus) 79 | stacked = torch.stack((BL, BR, TR, TL), dim = 0) 80 | maxed = torch.max(stacked, dim = 0).values # [batch, W, H, 3] 81 | mined = torch.min(stacked, dim = 0).values # [batch, W, H, 3] 82 | _advected_quantity = torch.clamp(advected_quantity, mined, maxed) 83 | return _advected_quantity 84 | 85 | # SL 86 | def advect_quantity_batched_SL(quantity, u, x, dt, boundary): 87 | prev_pos = RK3(x, u, -1. * dt) # [batch, W, H, 2] 88 | prev_pos = project_to_inside(prev_pos.view((-1, 2)), boundary).view(prev_pos.shape) 89 | new_quantity = sample_grid_batched(quantity, prev_pos) 90 | return new_quantity 91 | 92 | # BFECC 93 | def advect_quantity_batched_BFECC(quantity, u, x, dt, boundary): 94 | quantity1 = advect_quantity_batched_SL(quantity, u, x, dt, boundary) 95 | quantity2 = advect_quantity_batched_SL(quantity1, u, x, -1.*dt, boundary) 96 | new_quantity = advect_quantity_batched_SL(quantity + 0.5 * (quantity-quantity2), u, x, dt, boundary) 97 | new_quantity = MacCormack_clip(new_quantity, quantity, u, x, dt, boundary) 98 | return new_quantity 99 | 100 | # MacCormack 101 | def advect_quantity_batched_MacCormack(quantity, u, x, dt, boundary): 102 | quantity1 = advect_quantity_batched_SL(quantity, u, x, dt, boundary) 103 | quantity2 = advect_quantity_batched_SL(quantity1, u, x, -1.*dt, boundary) 104 | new_quantity = quantity1 + 0.5 * (quantity - quantity2) 105 | new_quantity = MacCormack_clip(new_quantity, quantity, u, x, dt, boundary) 106 | return new_quantity 107 | 108 | # data = [batch, X, Y, n_channel] 109 | # pos = [batch, X, Y, 2] 110 | def sample_grid_batched(data, pos): 111 | data_ = data.permute([0, 3, 2, 1]) 112 | pos_ = pos.clone().permute([0, 2, 1, 3]) 113 | pos_ = (pos_ - 0.5) * 2 114 | F_sample_grid = F.grid_sample(data_, pos_, padding_mode = 'border', align_corners = False, mode = "bilinear") 115 | F_sample_grid = F_sample_grid.permute([0, 3, 2, 1]) 116 | return F_sample_grid 117 | 118 | # pos: [num_query, 2] or [batch, num_query, 2] 119 | # vel: [batch, num_query, 2] 120 | # mode: 0 for image, 1 for vort 121 | def boundary_treatment(pos, vel, boundary, mode = 0): 122 | vel_after = vel.clone() 123 | batch, num_query, _ = vel.shape 124 | sdf = boundary[0] # [W, H] 125 | sdf_normal = boundary[1] 126 | if mode == 0: 127 | score = torch.clamp((sdf / -15.), min = 0.).flatten() 128 | inside_band = (score < 1.).squeeze(-1).flatten() 129 | score = score[None, ..., None] 130 | vel_after[:, inside_band, :] = score[:, inside_band, :] * vel[:, inside_band, :] 131 | else: 132 | W, H = sdf.shape 133 | dx = 1./H 134 | pos_grid = (pos / 
dx).floor().long() 135 | pos_grid_x = pos_grid[...,0] 136 | pos_grid_y = pos_grid[...,1] 137 | pos_grid_x = torch.clamp(pos_grid_x, 0, W-1) 138 | pos_grid_y = torch.clamp(pos_grid_y, 0, H-1) 139 | sd = sdf[pos_grid_x, pos_grid_y][...,None] 140 | sd_normal = sdf_normal[pos_grid_x, pos_grid_y] 141 | score = torch.clamp((sd / -75.), min = 0.) 142 | inside_band = (score < 1.).squeeze(-1) 143 | vel_normal = torch.einsum('bij,bij->bi', vel, sd_normal)[...,None] * sd_normal 144 | vel_tang = vel - vel_normal 145 | tang_at_boundary = 0.33 146 | vel_after[inside_band] = ((1.-tang_at_boundary) * score[inside_band] + tang_at_boundary) * vel_tang[inside_band] + score[inside_band] * vel_normal[inside_band] 147 | 148 | return vel_after 149 | 150 | # simulate a single step 151 | def simulate_step(img, img_x, vorts_pos, vorts_w, vorts_size, vel_func, dt, boundary): 152 | batch_size = vorts_pos.shape[0] 153 | img_x_flattened = img_x.view(-1, 2) 154 | if boundary is None: 155 | img_vel_flattened = vel_func(vorts_size, vorts_w, vorts_pos, img_x_flattened) 156 | img_vel = img_vel_flattened.view((batch_size, img_x.shape[0], img_x.shape[1], -1)) 157 | new_img = torch.clip(advect_quantity_batched(img, img_vel, img_x, dt, boundary), 0., 1.) 158 | vorts_vel = vel_func(vorts_size, vorts_w, vorts_pos, vorts_pos) 159 | new_vorts_pos = RK1(vorts_pos, vorts_vel, dt) 160 | else: 161 | OUT = (boundary[0]>=-boundary[2]) 162 | IN = ~OUT 163 | img_x_flattened = img_x.view(-1, 2) 164 | IN_flattened = IN.expand(img_x.shape[:-1]).flatten() 165 | img_vel_flattened = torch.zeros(batch_size, *img_x_flattened.shape).to(device) 166 | # only the velocity of the IN part will be computed, the rest will be left as 0 167 | img_vel_flattened[:, IN_flattened] = vel_func(vorts_size, vorts_w, vorts_pos, img_x_flattened[IN_flattened]) 168 | img_vel_flattened = boundary_treatment(img_x_flattened, img_vel_flattened, boundary, mode = 0) 169 | img_vel = img_vel_flattened.view((batch_size, img_x.shape[0], img_x.shape[1], -1)) 170 | new_img = torch.clip(advect_quantity_batched(img, img_vel, img_x, dt, boundary), 0., 1.) 
171 |         new_img[:, OUT] = img[:, OUT] # the image of the OUT part will be left unchanged
172 |         vorts_vel = vel_func(vorts_size, vorts_w, vorts_pos, vorts_pos)
173 |         vorts_vel = boundary_treatment(vorts_pos, vorts_vel, boundary, mode = 1)
174 |         new_vorts_pos = RK1(vorts_pos, vorts_vel, dt)
175 |
176 |     return new_img, new_vorts_pos, img_vel, vorts_vel
177 |
178 | # simulate in batches
179 | # img: the initial image
180 | # img_x: the grid coordinates (meshgrid)
181 | # vorts_pos: init vortex positions
182 | # vorts_w: vorticity
183 | # vorts_size: size
184 | # num_steps: how many steps to simulate
185 | # vel_func: how to compute velocity from vorticity
186 | def simulate(img, img_x, vorts_pos, vorts_w, vorts_size, num_steps, vel_func, boundary = None, dt = 0.01):
187 |     imgs = []
188 |     vorts_poss = []
189 |     img_vels = []
190 |     vorts_vels = []
191 |     for i in range(num_steps):
192 |         img, vorts_pos, img_vel, vorts_vel = simulate_step(img, img_x, vorts_pos, vorts_w, vorts_size, vel_func, dt, boundary = boundary)
193 |         imgs.append(img.clone())
194 |         vorts_poss.append(vorts_pos.clone())
195 |         img_vels.append(img_vel)
196 |         vorts_vels.append(vorts_vel)
197 |
198 |     return imgs, vorts_poss, img_vels, vorts_vels
199 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from io_utils import *
2 | from simulation_utils import *
3 | from learning_utils import *
4 | torch.manual_seed(123)
5 | import sys
6 | import os
7 | from functorch import jacrev, vmap
8 |
9 | device = torch.device("cuda")
10 | real = torch.float32
11 |
12 | # command line parse
13 | parser = config_parser()
14 | args = parser.parse_args()
15 |
16 | # some switches
17 | run_pretrain = args.run_pretrain # if this is set to true, then only pretrain() will be run
18 | test_only = args.test_only # if this is set to true, then only test() will be run
19 | start_over = args.start_over # if this is set to true, then the logs/[exp_name] dir. will be emptied
20 |
21 | # some hyperparameters
22 | print("[Train] Number of training iters: ", args.num_train)
23 | num_iters = args.num_train # total number of training iterations
24 | decimate_point = 20000 # LR decays to ~10% of its initial value by this iteration
25 | decay_gamma = 0.99
26 | decay_step = max(1, int(decimate_point/math.log(0.1, decay_gamma))) # decay once every (# >= 1) learning steps
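# (clarifying note on the line above, using the defaults in this file:
#  math.log(0.1, 0.99) ≈ 229.1, so decay_step = int(20000/229.1) = 87; stepping the LR
#  by 0.99 once every 87 iterations compounds to roughly a 10x reduction over the
#  20000-iteration decimate_point, since 0.99 ** (20000/87) ≈ 0.1.)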
27 | save_ckpt_every = 1000
28 | test_every = 1000
29 | print_every = 20
30 | num_sims = 2 # the "m" param in paper
31 | batch_size = 4
32 |
33 | # load data
34 | datadir = os.path.join('data', args.data_name)
35 | print("[Data] Load from path: ", datadir)
36 | imgs = torch.from_numpy(np.load(os.path.join(datadir, 'imgs.npy'))).to(device).type(real)
37 | try:
38 |     sdf = torch.from_numpy(np.load(os.path.join(datadir, 'sdf.npy'))).to(device).type(real)
39 | except FileNotFoundError:
40 |     print("[Boundary] SDF file doesn't exist, no boundary")
41 |     boundary = None
42 | else:
43 |     print("[Boundary] SDF file exists, has boundary")
44 |     sdf = torch.flip(sdf, [0])
45 |     sdf = torch.permute(sdf, (1, 0))
46 |     sdf_normal = calc_sdf_normal(sdf)
47 |     # 1. signed distance field
48 |     # 2. unit normal of sdf
49 |     # 3. thickness (in pixels)
50 |     boundary = (sdf, sdf_normal, 2)
51 |
52 | num_total_frames = imgs.shape[0] # seen + unseen frames
53 | print("[Data] Number of frames we have: ", num_total_frames)
54 | imgs = imgs[:math.ceil(num_total_frames * args.seen_ratio)] # select a number of frames to be revealed for training
55 | num_frames, width, height, num_channels = imgs.shape
56 | print("[Data] Number of frames revealed: ", num_frames)
57 | num_unseen_frames = num_total_frames - num_frames
58 | print("[Data] Number of frames concealed: ", num_unseen_frames)
59 | timestamps = torch.arange(num_frames).type(real)[..., None].to(device) * 0.01
60 | num_available_frames = num_frames - num_sims
61 | probs = torch.ones((num_available_frames), device = device, dtype = real)
62 |
63 | # set up initial vortices (as a vorts_num_x-by-vorts_num_y grid)
64 | vorts_num_x = 4
65 | vorts_num_y = 4
66 | num_vorts = vorts_num_x * vorts_num_y
67 |
68 | # create some directories
69 | # logs dir
70 | exp_name = args.exp_name
71 | logsdir = os.path.join('logs', exp_name)
72 | print("[Output] Results saving to: ", logsdir)
73 | os.makedirs(logsdir, exist_ok=True)
74 | if start_over:
75 |     remove_everything_in(logsdir)
76 | # folder for tests
77 | testdir = 'tests'
78 | testdir = os.path.join(logsdir, testdir)
79 | os.makedirs(testdir, exist_ok=True)
80 | # folder for ckpts
81 | ckptdir = 'ckpts'
82 | ckptdir = os.path.join(logsdir, ckptdir)
83 | os.makedirs(ckptdir, exist_ok=True)
84 | # folder for pretrained
85 | pretraindir = 'pretrained'
86 | pretraindir = os.path.join(pretraindir, exp_name)
87 | os.makedirs(pretraindir, exist_ok=True)
88 | if run_pretrain: # if calling pretrain, then remove previous pretrain records
89 |     remove_everything_in(pretraindir)
90 | pre_ckptdir = os.path.join(pretraindir, 'ckpts') # ckpt for pretrain
91 | os.makedirs(pre_ckptdir, exist_ok=True)
92 | pre_testdir = os.path.join(pretraindir, 'tests') # test for pretrain
93 | os.makedirs(pre_testdir, exist_ok=True)
94 |
95 | # init or load networks
96 | net_dict, start, grad_vars, optimizer, lr_scheduler = create_bundle(ckptdir, num_vorts, decay_step, decay_gamma, pretrain_dir = pre_ckptdir)
97 | img_x = gen_grid(width, height, device) # grid coordinates
98 |
99 | def eval_vel(vorts_size, vorts_w, vorts_pos, query_pos):
100 |     return vort_to_vel(net_dict['model_len'], vorts_size, vorts_w, vorts_pos, query_pos, length_scale = args.vort_scale)
101 |
102 | def dist_2_len_(dist):
103 |     return net_dict['model_len'](dist)
104 |
105 | def size_pred():
106 |     pred = net_dict['size_pred']
107 |     size = 0.03 + torch.sigmoid(pred)
108 |     return size
109 |
110 | def w_pred():
111 |     pred = net_dict['w_pred']
112 |     w = torch.sin(pred)
113 |     return w
114 |
115 | def comp_velocity(timestamps):
116 |     jac = vmap(jacrev((net_dict['model_pos'])))(timestamps)
117 |     post = jac[:, :, 0:1].view((timestamps.shape[0],-1,2,1))
118 |     xt = post[:, :, 0, :]
119 |     yt = post[:, :, 1, :]
120 |     uv = torch.cat((xt, yt), dim = 2)
121 |     return uv
122 |
123 | # pretrain (of the trajectory module)
124 | # the scale parameter influences the initial positions of the vortices
125 | def pretrain(scale = 1.):
126 |     if start > 0:
127 |         print("[Pretrain] Pretraining needs to be the start of the training pipeline.
Please re-run with --start_over set to True.")
128 |         sys.exit()
129 |
130 |     with torch.no_grad():
131 |         init_poss = gen_grid(vorts_num_x, vorts_num_y, device).view([-1, 2])
132 |         init_poss = scale * init_poss + 0.5 * (1.-scale) # scale the initial grid
133 |         pos_GT = init_poss[None, ...].expand(num_frames, -1, -1)
134 |         vel_GT = torch.zeros_like(pos_GT)
135 |
136 |     for it in range(10000):
137 |         pos_pred = net_dict['model_pos'](timestamps).view(-1, num_vorts, 2)
138 |         pos_loss = L2_Loss(pos_pred, pos_GT)
139 |
140 |         vel_pred = comp_velocity(timestamps)
141 |         vel_loss = 0.001 * L2_Loss(vel_pred, vel_GT)
142 |
143 |         loss = pos_loss + vel_loss
144 |
145 |         optimizer.zero_grad()
146 |         loss.backward()
147 |         optimizer.step()
148 |
149 |         if it % 200 == 0:
150 |             print("[Pretrain] Iter: ", it, ", loss: ", loss.detach().cpu().numpy(), "/ pos loss: ", pos_loss.detach().cpu().numpy(), "/ vel loss: ", vel_loss.detach().cpu().numpy())
151 |
152 |     # save pretrained results (trajectory module only)
153 |     path = os.path.join(pre_ckptdir, 'pretrained.tar')
154 |     torch.save({
155 |         'model_pos_state_dict': net_dict['model_pos'].state_dict(),
156 |     }, path)
157 |     print('[Pretrain] Saved checkpoint to: ', path)
158 |     with torch.no_grad():
159 |         # output all vortex positions with velocities
160 |         values = net_dict["model_pos"](timestamps)
161 |         values = values.view([values.shape[0], -1, 2])
162 |         uvs = comp_velocity(timestamps)
163 |         for i in range(values.shape[0]):
164 |             print("[Pretrain] Writing test frame: ", i)
165 |             vorts_pos_numpy = values[i].detach().cpu().numpy()
166 |             vel_numpy = uvs[i].detach().cpu().numpy()
167 |             write_vorts(vorts_pos_numpy, vel_numpy, pre_testdir, i)
168 |
169 |     print('[Pretrain] Complete.')
170 |
171 |
172 | # test learned simulation
173 | def test(curr_it):
174 |     print ("[Test] Testing at iter: " + str(curr_it))
175 |     currdir = os.path.join(testdir, str(curr_it))
176 |     os.makedirs(currdir, exist_ok=True)
177 |
178 |     with torch.no_grad():
179 |         total_imgs = [imgs[[0]]]
180 |         total_vels = [None]
181 |         total_vorts = [None]
182 |         for i in range(num_available_frames):
183 |             num_to_sim = 1
184 |             if i == num_available_frames-1:
185 |                 num_to_sim += num_sims + max(num_unseen_frames, int(1.5 * num_frames)) -1 # if at the last revealed image, simulate to the end of the video
186 |             pos_pred = net_dict['model_pos'](timestamps[[i]]).view((1,num_vorts,2))
187 |             sim_imgs, sim_vorts_poss, sim_vels, sim_vorts_vels = simulate(total_imgs[-1].clone(), img_x, pos_pred.clone(), \
188 |                 w_pred().clone(), size_pred().clone(),\
189 |                 num_to_sim, vel_func = eval_vel, \
190 |                 boundary = boundary)
191 |             total_imgs = total_imgs + sim_imgs
192 |             total_vels = total_vels + sim_vels
193 |             total_vorts = total_vorts + sim_vorts_poss
194 |
195 |         visdir = os.path.join(currdir, 'particles')
196 |         os.makedirs(visdir, exist_ok=True)
197 |         imgdir = os.path.join(currdir, 'imgs')
198 |         os.makedirs(imgdir, exist_ok=True)
199 |         vortdir = os.path.join(currdir, 'vorts')
200 |         os.makedirs(vortdir, exist_ok=True)
201 |         write_image(total_imgs[0][0].cpu().numpy(), imgdir, 0) # write init image
202 |         for i in range(1, len(total_imgs)):
203 |             print("[Test] Writing test frame: ", i)
204 |             img = total_imgs[i].squeeze()
205 |             vorts_pos = total_vorts[i]
206 |             vorts_w = w_pred()[None,...]
207 |             vorts_size = size_pred()[None,...]
208 | img_vel = total_vels[i].squeeze() 209 | vort_img = calc_vort(img_vel, boundary) 210 | vort_img_numpy = vort_img.detach().cpu().numpy() 211 | img_numpy = img.detach().cpu().numpy() 212 | vorts_pos_numpy = vorts_pos.detach().cpu().numpy() 213 | vorts_w_numpy = vorts_w.detach().cpu().numpy() 214 | write_visualization(img_numpy, vorts_pos_numpy, vorts_w_numpy, visdir, i, boundary = boundary) 215 | write_image(img_numpy, imgdir, i) 216 | write_vorticity(vort_img_numpy, vortdir, i) 217 | 218 | # # # # # 219 | 220 | # if pretrain is True then run pretrain() and quit 221 | if run_pretrain: 222 | pretrain(args.init_vort_dist) 223 | sys.exit() 224 | 225 | # if test_only is True then run test() and quit 226 | if test_only: 227 | test(start) 228 | sys.exit() 229 | 230 | # below is training code 231 | prev_time = time.time() 232 | for it in range(start, num_iters): 233 | # each iter select some different starting frames 234 | init_frames = probs.multinomial(num_samples = batch_size, replacement = False) 235 | 236 | # compute velocity prescribed by dynamics module 237 | with torch.no_grad(): 238 | pos_pred_gradless = net_dict['model_pos'](timestamps[init_frames]).view((-1,num_vorts,2)) 239 | D_vel = eval_vel(size_pred(), w_pred(), pos_pred_gradless, pos_pred_gradless) 240 | if boundary is not None: 241 | D_vel = boundary_treatment(pos_pred_gradless, D_vel, boundary, mode = 1) 242 | 243 | # velocity loss 244 | T_vel = comp_velocity(timestamps[init_frames]) # velocity prescribed by trajectory module 245 | vel_loss = 0.001 * L2_Loss(T_vel, D_vel) 246 | 247 | pos_pred = net_dict['model_pos'](timestamps[init_frames]).view((batch_size,num_vorts,2)) 248 | sim_imgs, sim_vorts_poss, sim_img_vels, sim_vorts_vels = simulate(imgs[init_frames].clone(), img_x, pos_pred, w_pred(), \ 249 | size_pred(), num_sims, vel_func = eval_vel, boundary = boundary) 250 | sim_imgs = torch.stack(sim_imgs) 251 | 252 | # comp img loss 253 | img_losses = [] 254 | if boundary is None: # if no boundary then compute loss on entire images 255 | for i in range(batch_size): 256 | pred = rgb_to_yuv(sim_imgs[:, i]) 257 | GT = rgb_to_yuv(imgs[init_frames[i]+1: init_frames[i]+1+num_sims]) 258 | img_losses.append(L2_Loss(pred, GT)) 259 | else: # if has boundary then compute loss only on the valid regions 260 | OUT = (boundary[0] >= -boundary[2]) 261 | IN = ~OUT 262 | for i in range(batch_size): 263 | pred = rgb_to_yuv(sim_imgs[:, i])[:, IN] 264 | GT = rgb_to_yuv(imgs[init_frames[i]+1: init_frames[i]+1+num_sims])[:, IN] 265 | img_losses.append(L2_Loss(pred, GT)) 266 | img_loss = torch.stack(img_losses).sum() 267 | 268 | # loss is the sum of the two losses 269 | loss = img_loss + vel_loss 270 | 271 | # optimize 272 | optimizer.zero_grad() 273 | loss.backward() 274 | optimizer.step() 275 | lr_scheduler.step() 276 | 277 | if it % print_every == 0: 278 | print("[Train] Iter: ", it, ", loss: ", loss.detach().cpu().numpy(), "/ img loss: ", img_loss.detach().cpu().numpy(), "/ vel loss: ", vel_loss.detach().cpu().numpy()) 279 | curr_time = time.time() 280 | print("[Train] Time Cost: ", curr_time-prev_time) 281 | prev_time = curr_time 282 | 283 | next_it = it + 1 284 | # save ckpt 285 | if (next_it % save_ckpt_every == 0 and next_it > 0) or (next_it == num_iters): 286 | path = os.path.join(ckptdir, '{:06d}.tar'.format(next_it)) 287 | torch.save({ 288 | 'global_step': next_it, 289 | 'w_pred': net_dict['w_pred'], 290 | 'size_pred': net_dict['size_pred'], 291 | 'model_pos_state_dict': net_dict['model_pos'].state_dict(), 292 | 'model_len_state_dict': 
net_dict['model_len'].state_dict(), 293 | 'optimizer_state_dict': optimizer.state_dict(), 294 | }, path) 295 | print('[Train] Saved checkpoints at', path) 296 | 297 | if (next_it % test_every == 0 and next_it > 0) or (next_it == num_iters): 298 | test(next_it) 299 | 300 | print('[Train] Complete.') 301 | --------------------------------------------------------------------------------