├── .gitignore ├── asset ├── teaser.png ├── smokegun.png ├── chocolate_seq.png └── dambreak2d_seq.png ├── data └── image │ ├── fire.png │ ├── okeffe.jpg │ ├── starry.jpg │ ├── volcano.png │ ├── ben_giles.png │ ├── oil_crop.jpg │ ├── pattern1.png │ ├── turbulence.png │ ├── blue_strokes.jpg │ ├── seated-nude.jpg │ ├── dark_matter_bw.png │ └── Nature-EruptingVolcano.jpg ├── LICENSE ├── setup.bat ├── README.md ├── run.bat ├── scene ├── dambreak2d.py ├── chocolate.py └── smokegun.py ├── vgg.py ├── config.py ├── test_smokegun.py ├── test_dambreak2d.py ├── test_chocolate.py ├── styler_2p.py ├── test_smokegun_resim.py ├── styler_base.py ├── styler_3p.py ├── util.py └── transform.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .vscode 3 | data 4 | log 5 | venv -------------------------------------------------------------------------------- /asset/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/asset/teaser.png -------------------------------------------------------------------------------- /asset/smokegun.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/asset/smokegun.png -------------------------------------------------------------------------------- /data/image/fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/data/image/fire.png -------------------------------------------------------------------------------- /data/image/okeffe.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/data/image/okeffe.jpg -------------------------------------------------------------------------------- 
/data/image/starry.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/data/image/starry.jpg -------------------------------------------------------------------------------- /data/image/volcano.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/data/image/volcano.png -------------------------------------------------------------------------------- /asset/chocolate_seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/asset/chocolate_seq.png -------------------------------------------------------------------------------- /asset/dambreak2d_seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/asset/dambreak2d_seq.png -------------------------------------------------------------------------------- /data/image/ben_giles.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/data/image/ben_giles.png -------------------------------------------------------------------------------- /data/image/oil_crop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/data/image/oil_crop.jpg -------------------------------------------------------------------------------- /data/image/pattern1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/data/image/pattern1.png -------------------------------------------------------------------------------- /data/image/turbulence.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/data/image/turbulence.png -------------------------------------------------------------------------------- /data/image/blue_strokes.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/data/image/blue_strokes.jpg -------------------------------------------------------------------------------- /data/image/seated-nude.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/data/image/seated-nude.jpg -------------------------------------------------------------------------------- /data/image/dark_matter_bw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/data/image/dark_matter_bw.png -------------------------------------------------------------------------------- /data/image/Nature-EruptingVolcano.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byungsook/neural-flow-style/HEAD/data/image/Nature-EruptingVolcano.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright © 2020, ETH Zurich, Byungsoo Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished 
to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /setup.bat: -------------------------------------------------------------------------------- 1 | REM install virtualenv if needed 2 | pip3 install virtualenv 3 | virtualenv --system-site-packages ./venv 4 | call .\venv\Scripts\activate 5 | python -m pip install --upgrade pip 6 | 7 | REM install packages 8 | pip install --upgrade tensorflow==1.15 tqdm matplotlib Pillow imageio scipy scikit-image==0.14.2 open3d-python 9 | 10 | REM REM 1. mantaflow 11 | REM cd .. 12 | REM git clone https://bitbucket.org/mantaflow/manta.git 13 | REM cd manta 14 | REM git checkout 15eaf4 15 | 16 | REM REM 2. SPlisHSPlasH 17 | REM cd .. 18 | REM git clone https://github.com/InteractiveComputerGraphics/SPlisHSPlasH.git 19 | 20 | REM REM 3. partio 21 | REM cd ..
22 | REM git clone https://github.com/wdas/partio.git 23 | 24 | REM REM download freeglut (MSVC) for compiling partio 25 | REM https://www.transmissionzero.co.uk/files/software/development/GLUT/freeglut-MSVC.zip 26 | 27 | REM REM download swig for partio python-binding 28 | REM http://prdownloads.sourceforge.net/swig/swigwin-4.0.2.zip 29 | 30 | REM Note for partio compile 31 | REM Note1: uncheck BUILD_SHARED_LIBS 32 | REM Note2: remove None/None.lib; from Linker-Input of _partio PropertyPages. 33 | REM Note3: copy build/py/partio.py to build/py/Release -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lagrangian Neural Style Transfer for Fluids 2 | 3 | TensorFlow implementation of [Lagrangian Neural Style Transfer for Fluids](http://www.byungsoo.me/project/lnst). 4 | 5 | [Byungsoo Kim](http://www.byungsoo.me), [Vinicius C. Azevedo](http://graphics.ethz.ch/~vviniciu/), [Markus Gross](https://graphics.ethz.ch/people/grossm), [Barbara Solenthaler](https://graphics.ethz.ch/~sobarbar/) 6 | 7 | [Computer Graphics Laboratory](https://cgl.ethz.ch/), ETH Zurich 8 | 9 | ![teaser](./asset/teaser.png) 10 | 11 | (Note that the [Transport-Based Neural Style Transfer for Smoke Simulations (TNST)](http://www.byungsoo.me/project/neural-flow-style) implementation has been moved to the `tnst` branch.) 12 | 13 | ## Requirements 14 | 15 | This code has been tested on Windows 10 with a GTX 1080 (8GB) and the following requirements: 16 | 17 | - [Python 3](https://www.python.org/) 18 | - [TensorFlow 1.15](https://www.tensorflow.org/install/) 19 | - [mantaflow](http://mantaflow.com) 20 | - [SPlisHSPlasH](https://github.com/InteractiveComputerGraphics/SPlisHSPlasH) 21 | - [Partio](https://github.com/wdas/partio) 22 | 23 | Run `setup.bat` for setup (the third-party tools mantaflow, SPlisHSPlasH, and Partio must be installed manually).
24 | 25 | Also download the pre-trained [inception](https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip) and [vgg19](http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz) networks and unzip them into `data/model`. 26 | 27 | ## Usage 28 | 29 | For details about parameters and examples, see `run.bat` and the corresponding demo code. For semantic transfer, see [this page](http://storage.googleapis.com/deepdream/visualz/tensorflow_inception/index.html) for pattern selection. 30 | 31 | ## Results (smokegun, single frame) 32 | 33 | ![single](./asset/smokegun.png) 34 | 35 | ## Results (chocolate, sequence) 36 | 37 | ![sequence](./asset/chocolate_seq.png) 38 | 39 | ## Results (dambreak2d, sequence) 40 | 41 | ![sequence](./asset/dambreak2d_seq.png) 42 | 43 | ## Author 44 | 45 | [Byungsoo Kim](http://www.byungsoo.me) / [byungsook@github](https://github.com/byungsook) -------------------------------------------------------------------------------- /run.bat: -------------------------------------------------------------------------------- 1 | REM setup if needed 2 | REM .\setup.bat 3 | 4 | REM activate env 5 | call .\venv\Scripts\activate 6 | 7 | REM ------------------------------------------------------- 8 | REM generate a smokegun dataset 9 | ..\manta\build\Release\manta.exe ./scene/smokegun.py 10 | 11 | REM generate a particle-based dataset from smokegun 12 | python test_smokegun_resim.py 13 | 14 | REM density based stylization 15 | python test_smokegun.py --tag net --content_layer mixed3b_3x3_bottleneck_pre_relu --content_channel 44 16 | python test_smokegun.py --tag square --content_layer mixed3b_3x3_bottleneck_pre_relu --content_channel 65 17 | python test_smokegun.py --tag cloud --content_layer mixed4b_pool_reduce_pre_relu --content_channel 6 18 | python test_smokegun.py --tag flower --content_layer mixed4b_pool_reduce_pre_relu --content_channel 16 19 | python test_smokegun.py --tag fluffy --content_layer
mixed4b_pool_reduce_pre_relu --content_channel 60 20 | python test_smokegun.py --tag ribbon --content_layer mixed4b_pool_reduce_pre_relu --content_channel 38 21 | 22 | python test_smokegun.py --tag fire --style_target data/image/fire.png --w_content 0 --w_style 1 23 | python test_smokegun.py --tag volcano --style_target data/image/volcano.png --w_content 0 --w_style 1 24 | python test_smokegun.py --tag nude --style_target data/image/seated-nude.jpg --w_content 0 --w_style 1 25 | python test_smokegun.py --tag starry --style_target data/image/starry.jpg --w_content 0 --w_style 1 26 | python test_smokegun.py --tag stroke --style_target data/image/blue_strokes.jpg --w_content 0 --w_style 1 27 | python test_smokegun.py --tag spiral --style_target data/image/pattern1.png --w_content 0 --w_style 1 28 | 29 | REM ------------------------------------------------------- 30 | REM generate a chocolate dataset 31 | python scene/chocolate.py 32 | 33 | REM position based stylization 34 | python test_chocolate.py --dataset chocolate --target_frame 70 --style_target data/image/pattern1.png --w_style 1 --w_content 0 35 | REM interpolation test 36 | python test_chocolate.py --dataset chocolate --target_frame 70 --num_frames 21 --interp 5 --style_target data/image/pattern1.png --w_style 1 --w_content 0 37 | 38 | REM ------------------------------------------------------- 39 | REM generate a dambreak2d dataset 40 | python scene/dambreak2d.py 41 | 42 | REM 2d color stylization 43 | python test_dambreak2d.py --dataset dambreak2d --target_frame 150 --style_target data/image/fire_new.jpg --w_style 1 --w_content 0 44 | python test_dambreak2d.py --dataset dambreak2d --target_frame 150 --num_frames 20 --batch_size 4 --style_target data/image/wave.jpeg --w_style 1 --w_content 0 -------------------------------------------------------------------------------- /scene/dambreak2d.py: -------------------------------------------------------------------------------- 1 |
############################################################# 2 | # MIT License, Copyright © 2020, ETH Zurich, Byungsoo Kim 3 | ############################################################# 4 | import argparse 5 | from datetime import datetime 6 | import os 7 | import json 8 | from subprocess import call 9 | from glob import glob 10 | import numpy as np 11 | from tqdm import trange 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument("--data_dir", type=str, default='E:/lnst/data/dambreak2d') 16 | parser.add_argument("--path_format", type=str, default='%03d.%s') 17 | parser.add_argument("--sph_path", type=str, default='E:/SPlisHSPlasH/bin/StaticBoundarySimulator.exe') 18 | 19 | parser.add_argument("--gui", type=bool, default=False) 20 | parser.add_argument("--fps", type=int, default=25) 21 | parser.add_argument("--attr", type=str, default='density') # ;velocity 22 | 23 | parser.add_argument("--particleRadius", type=float, default=0.025) 24 | parser.add_argument("--cflMaxTimeStepSize", type=float, default=0.0025) 25 | parser.add_argument("--timeStepSize", type=float, default=0.001) 26 | parser.add_argument("--numFrames", type=int, default=200) 27 | 28 | parser.add_argument("--res_x", type=int, default=256) 29 | parser.add_argument("--res_y", type=int, default=128) 30 | parser.add_argument("--disc", type=int, default=2) # 4 in 2d if 2 (8 in 2d) 31 | 32 | args = parser.parse_args() 33 | 34 | args.cell_size = 2*args.particleRadius * args.disc 35 | args.domain_x = args.res_x * args.cell_size 36 | args.domain_y = args.res_y * args.cell_size 37 | print('res:', args.res_x, args.res_y) 38 | print('domain:', args.domain_x, args.domain_y) 39 | 40 | # default scene 41 | scene = { 42 | "Configuration": { 43 | "pause": True, 44 | "sim2D": True, 45 | "particleRadius": args.particleRadius, 46 | "colorMapType": 1, 47 | "numberOfStepsPerRenderUpdate": 4, 48 | "density0": 1000, 49 | "simulationMethod": 4, 50 | "gravitation": [ 0, -9.81, 0 ], 51 | "cflMethod": 1, 
52 | "cflFactor": 1, 53 | "cflMaxTimeStepSize": args.cflMaxTimeStepSize, 54 | "maxIterations": 100, 55 | "maxError": 0.1, 56 | "maxIterationsV": 100, 57 | "maxErrorV": 0.1, 58 | "stiffness": 50000, 59 | "exponent": 7, 60 | "velocityUpdateMethod": 0, 61 | "enableDivergenceSolver": True, 62 | "boundaryHandlingMethod": 2, 63 | "enableZSort": False, 64 | 'renderWalls': 3, 65 | "stopAt": args.numFrames / args.fps, 66 | 'enablePartioExport': True, 67 | 'dataExportFPS': args.fps, 68 | 'particleAttributes': args.attr 69 | }, 70 | "Fluid": { 71 | "surfaceTension": 0.2, 72 | "surfaceTensionMethod": 0, 73 | "viscosity": 0.01, 74 | "viscosityMethod": 1, 75 | "vorticityMethod": 1, 76 | "vorticity": 0.1, 77 | "viscosityOmega": 0.05, 78 | "inertiaInverse": 0.5, 79 | "colorMapType": 1 80 | }, 81 | "RigidBodies": [ 82 | { 83 | "geometryFile": "E:/SPlisHSPlasH/data/models/UnitBox.obj", 84 | "translation": [ args.domain_x/2, args.domain_y/2, 0 ], 85 | "rotationAxis": [ 1, 0, 0 ], 86 | "rotationAngle": 0, 87 | "scale": [ args.domain_x, args.domain_y, 1 ], 88 | "color": [ 0.1, 0.4, 0.6, 1.0 ], 89 | "isDynamic": False, 90 | "isWall": True, 91 | "mapInvert": True, 92 | "mapThickness": 0.0, 93 | "mapResolution": [ 30, 30, 20 ] 94 | } 95 | ], 96 | "FluidBlocks": [ 97 | { 98 | "denseMode": 0, 99 | "start": [ 0.0, 0.0, -1 ], 100 | "end": [ args.domain_x*0.35, args.domain_y*0.6, 1 ], 101 | "translation": [ args.particleRadius, args.particleRadius, 0 ], 102 | "scale": [ 1, 1, 1 ] 103 | } 104 | ] 105 | } 106 | 107 | def main(): 108 | if not os.path.exists(args.data_dir): 109 | os.makedirs(args.data_dir) 110 | 111 | args.scene_path = os.path.join(args.data_dir, 'scene.json') 112 | with open(args.scene_path, 'w') as fp: 113 | json.dump(scene, fp, indent=2) #, sort_keys=True) 114 | 115 | args.sh = [args.sph_path, args.scene_path, '--output-dir', args.data_dir] #, '--no-cache'] 116 | if not args.gui: args.sh.append('--no-gui') 117 | 118 | args_file = os.path.join(args.data_dir, 'args.txt') 119 | 
with open(args_file, 'w') as f: 120 | print('%s: arguments' % datetime.now()) 121 | for k, v in vars(args).items(): 122 | print(' %s: %s' % (k, v)) 123 | f.write('%s: %s\n' % (k, v)) 124 | 125 | # simulation 126 | call(args.sh, shell=True) 127 | 128 | print('Done') 129 | 130 | if __name__ == '__main__': 131 | main() -------------------------------------------------------------------------------- /vgg.py: -------------------------------------------------------------------------------- 1 | # http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz 2 | # https://github.com/singlasahil14/style-transfer/blob/master/nets/vgg.py 3 | # https://medium.com/mlreview/getting-inception-architectures-to-work-with-style-transfer-767d53475bf8 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | from collections import OrderedDict 11 | import os 12 | 13 | slim = tf.contrib.slim 14 | 15 | # _R_MEAN = 123.68 16 | # _G_MEAN = 116.779 17 | # _B_MEAN = 103.939 18 | _R_MEAN = 0.485*255 19 | _G_MEAN = 0.456*255 20 | _B_MEAN = 0.406*255 21 | _R_STD = 0.229*255 22 | _G_STD = 0.224*255 23 | _B_STD = 0.225*255 24 | 25 | 26 | _content_layers_dict = { 27 | 'vgg-16': ('conv2_2',), 28 | 'vgg-19': ('conv2_2',), 29 | 'inception-v1': ('Conv2d_2c_3x3',), 30 | 'inception-v2': ('Conv2d_2c_3x3',), 31 | 'inception-v3': ('Conv2d_4a_3x3',), 32 | 'inception-v4': ('Mixed_3a',), 33 | } 34 | 35 | _style_layers_dict = { 36 | 'vgg-16': ('conv3_1', 'conv4_1', 'conv5_1'), 37 | 'vgg-19': ('conv3_1', 'conv4_1', 'conv5_1'), 38 | 'inception-v1': ('Conv2d_2c_3x3', 'Mixed_3c', 'Mixed_4b', 'Mixed_5b'), 39 | 'inception-v2': ('Conv2d_2c_3x3', 'Mixed_3b', 'Mixed_4a', 'Mixed_5a'), 40 | 'inception-v3': ('Conv2d_4a_3x3', 'Mixed_5b', 'Mixed_6a', 'Mixed_7a'), 41 | 'inception-v4': ('Mixed_4a', 'Mixed_5a', 'Mixed_6a', 'Mixed_7a'), 42 | } 43 | 44 | def vgg_arg_scope(padding='SAME'): 45 | with 
slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu, 46 | biases_initializer=tf.zeros_initializer()): 47 | with slim.arg_scope([slim.conv2d], padding=padding) as arg_sc: 48 | return arg_sc 49 | 50 | def preprocess(images): 51 | images -= tf.constant([ _R_MEAN , _G_MEAN, _B_MEAN]) 52 | # images /= tf.constant([ _R_STD , _G_STD, _B_STD]) 53 | return images 54 | 55 | def repeat(inputs, repetitions, layer, *args, **kwargs): 56 | scope = kwargs.pop('scope', 'conv') 57 | end_points = kwargs.pop('end_points', OrderedDict()) 58 | with tf.compat.v1.variable_scope(scope, 'Repeat', [inputs]): 59 | inputs = tf.convert_to_tensor(inputs) 60 | outputs = inputs 61 | for i in range(repetitions): 62 | scope_name = scope + '_' + str(i+1) 63 | kwargs['scope'] = scope_name 64 | outputs = layer(outputs, *args, **kwargs) 65 | end_points[scope_name] = outputs 66 | return outputs, end_points 67 | 68 | def vgg_16(inputs, scope='vgg_16', reuse=False, pool_fn=slim.avg_pool2d): 69 | with tf.compat.v1.variable_scope(scope, 'vgg_16', [inputs], reuse=reuse) as sc: 70 | # Collect outputs for conv2d and pool_fn. 
71 | with slim.arg_scope([slim.conv2d, pool_fn]): 72 | net, end_points = repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 73 | net = pool_fn(net, [2, 2], scope='pool1') 74 | end_points['pool1'] = net 75 | net, end_points = repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2', end_points=end_points) 76 | net = pool_fn(net, [2, 2], scope='pool2') 77 | end_points['pool2'] = net 78 | net, end_points = repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3', end_points=end_points) 79 | net = pool_fn(net, [2, 2], scope='pool3') 80 | end_points['pool3'] = net 81 | net, end_points = repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4', end_points=end_points) 82 | net = pool_fn(net, [2, 2], scope='pool4') 83 | end_points['pool4'] = net 84 | net, end_points = repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5', end_points=end_points) 85 | net = pool_fn(net, [2, 2], scope='pool5') 86 | end_points['pool5'] = net 87 | return end_points 88 | 89 | def vgg_19(inputs, scope='vgg_19', reuse=False, pool_fn=slim.avg_pool2d): 90 | with tf.compat.v1.variable_scope(scope, 'vgg_19', [inputs], reuse=reuse) as sc: 91 | # Collect outputs for conv2d and pool_fn. 
92 | with slim.arg_scope([slim.conv2d, pool_fn]): 93 | net, end_points = repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 94 | net = pool_fn(net, [2, 2], scope='pool1') 95 | end_points['pool1'] = net 96 | net, end_points = repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2', end_points=end_points) 97 | net = pool_fn(net, [2, 2], scope='pool2') 98 | end_points['pool2'] = net 99 | net, end_points = repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3', end_points=end_points) 100 | net = pool_fn(net, [2, 2], scope='pool3') 101 | end_points['pool3'] = net 102 | net, end_points = repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4', end_points=end_points) 103 | net = pool_fn(net, [2, 2], scope='pool4') 104 | end_points['pool4'] = net 105 | net, end_points = repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5', end_points=end_points) 106 | net = pool_fn(net, [2, 2], scope='pool5') 107 | end_points['pool5'] = net 108 | return end_points 109 | 110 | def load_vgg(d, model_path, sess, pool_fn=slim.avg_pool2d): 111 | model_name = os.path.basename(model_path).split('.')[0] # vgg_16 112 | # print(model_name) 113 | vgg_in = preprocess(d) 114 | arg_scope = vgg_arg_scope() 115 | with slim.arg_scope(arg_scope): 116 | if '16' in model_name: layers = vgg_16(vgg_in, pool_fn=pool_fn) 117 | else: layers = vgg_19(vgg_in, pool_fn=pool_fn) 118 | 119 | init = slim.assign_from_checkpoint_fn(model_path, slim.get_model_variables(model_name)) 120 | init(sess) 121 | return layers -------------------------------------------------------------------------------- /scene/chocolate.py: -------------------------------------------------------------------------------- 1 | ############################################################# 2 | # MIT License, Copyright © 2020, ETH Zurich, Byungsoo Kim 3 | ############################################################# 4 | import argparse 5 | from datetime import datetime 6 | import os 7 | import json 8 | from subprocess import call 9 | from glob 
import glob 10 | import numpy as np 11 | from tqdm import trange 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument("--data_dir", type=str, default='E:/lnst/data/chocolate') 16 | parser.add_argument("--path_format", type=str, default='%03d.%s') 17 | parser.add_argument("--sph_path", type=str, default='E:/SPlisHSPlasH/bin/StaticBoundarySimulator.exe') 18 | 19 | parser.add_argument("--gui", type=bool, default=False) 20 | parser.add_argument("--fps", type=int, default=30) 21 | parser.add_argument("--attr", type=str, default='density') # ;velocity 22 | 23 | parser.add_argument("--particleRadius", type=float, default=0.025) # 0.025 0.003125 24 | parser.add_argument("--cflMaxTimeStepSize", type=float, default=0.0025) # 0.005 0.0005 25 | parser.add_argument("--timeStepSize", type=float, default=0.0005) 26 | parser.add_argument("--numFrames", type=int, default=120) 27 | 28 | parser.add_argument("--res_x", type=int, default=128) 29 | parser.add_argument("--res_y", type=int, default=128) 30 | parser.add_argument("--res_z", type=int, default=128) 31 | parser.add_argument("--disc", type=int, default=2) # 4 in 2d if 2 (8 in 2d) 32 | 33 | args = parser.parse_args() 34 | 35 | args.cell_size = 2*args.particleRadius * args.disc 36 | args.domain_x = args.res_x * args.cell_size 37 | args.domain_y = args.res_y * args.cell_size 38 | args.domain_z = args.res_z * args.cell_size 39 | print('res:', args.res_x, args.res_y, args.res_z) 40 | print('domain:', args.domain_x, args.domain_y, args.domain_z) 41 | 42 | # default scene 43 | scene = { 44 | "Configuration": { 45 | "pause": True, 46 | "sim2D": False, 47 | "particleRadius": args.particleRadius, 48 | "colorMapType": 1, 49 | "numberOfStepsPerRenderUpdate": 4, 50 | "density0": 1000, 51 | "simulationMethod": 4, 52 | "gravitation": [ 0, -9.81, 0.1 ], 53 | "cflMethod": 1, 54 | "cflFactor": 1, 55 | "cflMaxTimeStepSize": args.cflMaxTimeStepSize, 56 | "maxIterations": 100, 57 | "maxError": 0.1, 58 | "maxIterationsV": 100, 
59 | "maxErrorV": 0.1, 60 | "stiffness": 50000, 61 | "exponent": 7, 62 | "velocityUpdateMethod": 0, 63 | "enableDivergenceSolver": True, 64 | "boundaryHandlingMethod": 2, 65 | "enableZSort": False, 66 | 'renderWalls': 3, 67 | "stopAt": args.numFrames / args.fps, 68 | 'enablePartioExport': True, 69 | 'dataExportFPS': args.fps, 70 | 'particleAttributes': args.attr, 71 | }, 72 | "Fluid": { 73 | "surfaceTension": 0.2, 74 | "surfaceTensionMethod": 0, 75 | "viscosity": 5000, 76 | "viscosityBoundary": 5000, 77 | "viscosityMethod": 7, 78 | "viscoMaxIter": 200, 79 | "viscoMaxError": 0.05, 80 | "colorMapType": 1, 81 | "maxEmitterParticles": 1000000, 82 | "emitterReuseParticles": False, 83 | "emitterBoxMin": [args.domain_x*0.3, args.domain_y*0.95, args.domain_z*0.49], 84 | "emitterBoxMax": [args.domain_x*0.7, args.domain_y, args.domain_z*0.51], 85 | }, 86 | "RigidBodies": [ 87 | { 88 | "geometryFile": "E:/SPlisHSPlasH/data/models/UnitBox.obj", 89 | "translation": [ args.domain_x/2, args.domain_y/2, args.domain_z/2 ], 90 | "rotationAxis": [ 1, 0, 0 ], 91 | "rotationAngle": 0, 92 | "scale": [ args.domain_x, args.domain_y, args.domain_z ], 93 | "color": [ 0.1, 0.4, 0.6, 1.0 ], 94 | "isDynamic": False, 95 | "isWall": True, 96 | "mapInvert": True, 97 | "mapThickness": 0.0, 98 | "mapResolution": [ 30, 30, 30 ] 99 | }, 100 | { 101 | "geometryFile": "E:/SPlisHSPlasH/data/models/sphere.obj", 102 | "translation": [ args.domain_x/2, 0, args.domain_z/2.5 ], 103 | "rotationAxis": [0, 1, 0], 104 | "rotationAngle": 0, 105 | "scale": [args.domain_x/4, args.domain_y/4, args.domain_z/4], 106 | "color": [0.1, 0.4, 0.6, 1.0], 107 | "isDynamic": False, 108 | "isWall": False, 109 | "mapInvert": False, 110 | "mapThickness": 0.0, 111 | "mapResolution": [20,20,20] 112 | } 113 | ], 114 | "Emitters": [ 115 | { 116 | "width": 2, 117 | "height": 100, 118 | "translation": [args.domain_x*0.5,args.domain_y,args.domain_z*0.5], 119 | "rotationAxis": [0, 0, 1], 120 | "rotationAngle": 
-1.5707963267948966192313216916398, 121 | "velocity": 5, 122 | "type": 0, 123 | "emitStartTime": 0, 124 | "emitEndTime": 10000000, 125 | } 126 | ] 127 | } 128 | 129 | def main(): 130 | if not os.path.exists(args.data_dir): 131 | os.makedirs(args.data_dir) 132 | 133 | args.scene_path = os.path.join(args.data_dir, 'scene.json') 134 | with open(args.scene_path, 'w') as fp: 135 | json.dump(scene, fp, indent=2) #, sort_keys=True) 136 | 137 | args.sh = [args.sph_path, args.scene_path, '--output-dir', args.data_dir] #, '--no-cache'] 138 | if not args.gui: args.sh.append('--no-gui') 139 | 140 | args_file = os.path.join(args.data_dir, 'args.txt') 141 | with open(args_file, 'w') as f: 142 | print('%s: arguments' % datetime.now()) 143 | for k, v in vars(args).items(): 144 | print(' %s: %s' % (k, v)) 145 | f.write('%s: %s\n' % (k, v)) 146 | 147 | # simulation 148 | call(args.sh, shell=True) 149 | 150 | print('Done') 151 | 152 | if __name__ == '__main__': 153 | main() -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | ############################################################# 2 | # MIT License, Copyright © 2020, ETH Zurich, Byungsoo Kim 3 | ############################################################# 4 | import argparse 5 | from util import str2bool 6 | 7 | arg_lists = [] 8 | parser = argparse.ArgumentParser() 9 | 10 | def add_argument_group(name): 11 | arg = parser.add_argument_group(name) 12 | arg_lists.append(arg) 13 | return arg 14 | 15 | # path 16 | path_arg = add_argument_group('Path') 17 | path_arg.add_argument("--data_dir", type=str, default='data') 18 | path_arg.add_argument("--log_dir", type=str, default='log') 19 | path_arg.add_argument("--model_dir", type=str, default='model') 20 | path_arg.add_argument("--d_path", type=str, default='d/%03d.npz') 21 | path_arg.add_argument("--v_path", type=str, default='v/%03d.npz') 22 | 
path_arg.add_argument("--tag", type=str, default='test') 23 | 24 | # dataset 25 | data_arg = add_argument_group('Data') 26 | data_arg.add_argument("--dataset", type=str, default='smokegun') 27 | data_arg.add_argument("--target_frame", type=int, default=70) 28 | data_arg.add_argument("--num_frames", type=int, default=1) 29 | data_arg.add_argument("--scale", type=float, default=2.0) 30 | 31 | # network 32 | network_arg = add_argument_group('Network') 33 | network_arg.add_argument("--network", type=str, default='tensorflow_inception_graph.pb', 34 | choices=['tensorflow_inception_graph.pb','vgg_19.ckpt']) 35 | network_arg.add_argument("--pool1", type=str2bool, default=False) 36 | network_arg.add_argument("--batch_size", type=int, default=1) 37 | 38 | # grid 39 | grid_arg = add_argument_group('Grid') 40 | grid_arg.add_argument("--resolution", nargs='+', type=int, default=[384,288]) # HW or DHW 41 | grid_arg.add_argument("--adv_order", type=int, default=1, choices=[1,2], help='SL or MacCormack') 42 | 43 | # particle 44 | pt_arg = add_argument_group('Particle') 45 | pt_arg.add_argument("--domain", nargs='+', type=int, default=[12.8,12.8,12.8]) # HW or DHW 46 | pt_arg.add_argument("--radius", type=float, default=0.025, help='kernel radius for density estimation') 47 | pt_arg.add_argument("--disc", type=int, default=2, help='grid discretization') 48 | pt_arg.add_argument("--nsize", type=int, default=1, help='# neighbors cells to check') 49 | pt_arg.add_argument("--rest_density", type=float, default=1000) 50 | pt_arg.add_argument("--w_pressure", type=float, default=0) 51 | pt_arg.add_argument("--w_density", type=float, default=0) 52 | pt_arg.add_argument("--window_sigma", type=float, default=2) 53 | pt_arg.add_argument("--interp", type=int, default=1) 54 | pt_arg.add_argument("--support", type=float, default=4) 55 | pt_arg.add_argument("--k", type=int, default=3) 56 | pt_arg.add_argument("--clip", type=str2bool, default=False, help='whether to clamp particle pos to domain or 
# [review] Tail of config.py: the Render / Optimizer / Style / Misc argument groups (add_argument_group and
# str2bool are presumably defined earlier in config.py, outside this view), then get_config(), which returns
# parser.parse_known_args() as (config namespace, list of unrecognized CLI tokens). The second line ends with
# the license header and first import of test_smokegun.py.
not') 57 | 58 | # rendering 59 | render_arg = add_argument_group('Render') 60 | render_arg.add_argument("--resize_scale", type=float, default=1.0, help='to upscale rendering') 61 | render_arg.add_argument("--transmit", type=float, default=0.01) 62 | render_arg.add_argument("--rotate", type=str2bool, default=False) 63 | render_arg.add_argument('--phi0', type=int, default=-5) # latitude (elevation) start 64 | render_arg.add_argument('--phi1', type=int, default=5) # latitude end 65 | render_arg.add_argument('--phi_unit', type=int, default=5) 66 | render_arg.add_argument('--theta0', type=int, default=-10) # longitude start 67 | render_arg.add_argument('--theta1', type=int, default=10) # longitude end 68 | render_arg.add_argument('--theta_unit', type=int, default=10) 69 | render_arg.add_argument('--v_batch', type=int, default=1, help='# of rotation matrix for batch process') 70 | render_arg.add_argument('--n_views', type=int, default=9, help='# of view points') 71 | render_arg.add_argument('--sample_type', type=str, default='poisson', 72 | choices=['uniform', 'poisson', 'both']) 73 | render_arg.add_argument("--render_liquid", type=str2bool, default=False) 74 | 75 | # optimizer 76 | opt_arg = add_argument_group('Optimizer') 77 | opt_arg.add_argument("--target_field", type=str, default='p', choices=['d', 'p', 'c']) 78 | opt_arg.add_argument("--optimizer", type=str, default='adam') 79 | opt_arg.add_argument("--iter", type=int, default=20) 80 | opt_arg.add_argument("--lr", type=float, default=0.0007) 81 | opt_arg.add_argument("--lr_scale", type=float, default=1) 82 | opt_arg.add_argument("--octave_n", type=int, default=2) 83 | opt_arg.add_argument("--octave_scale", type=float, default=1.8) 84 | opt_arg.add_argument("--frames_per_opt", type=int, default=10) 85 | 86 | # style 87 | style_arg = add_argument_group('Style') 88 | style_arg.add_argument("--content_layer", type=str, default='mixed4d_3x3_bottleneck_pre_relu') 89 | style_arg.add_argument("--content_channel", type=int, 
default=139) 90 | style_arg.add_argument("--w_content", type=float, default=1) 91 | style_arg.add_argument("--w_content_amp", type=float, default=100) 92 | style_arg.add_argument("--content_target", type=str, default='') 93 | style_arg.add_argument("--top_k", type=int, default=5) 94 | style_arg.add_argument("--style_layer", nargs='+', type=str, default=['conv3_1']) #['conv2d2','mixed3a','mixed4a','mixed5a']) 95 | style_arg.add_argument("--w_style", type=float, default=0) 96 | style_arg.add_argument("--w_style_layer", nargs='+', type=float, default=[1]) #[1,0.01,0.3,10]) 97 | style_arg.add_argument("--hist_layer", nargs='+', type=str, default=['input']) 98 | style_arg.add_argument("--w_hist", type=float, default=0) 99 | style_arg.add_argument("--w_hist_layer", nargs='+', type=float, default=[1]) 100 | style_arg.add_argument("--w_tv", type=float, default=0) 101 | style_arg.add_argument("--style_target", type=str, default='') # data/image/fire_new.jpg 102 | style_arg.add_argument("--style_mask", type=str2bool, default=False) 103 | style_arg.add_argument("--style_mask_on_ref", type=str2bool, default=False) 104 | style_arg.add_argument("--style_tiling", type=int, default=1) 105 | style_arg.add_argument("--style_init", type=str, default='noise', choices=['noise','style']) 106 | 107 | # misc 108 | misc_arg = add_argument_group('Misc') 109 | misc_arg.add_argument("--seed", type=int, default=123) 110 | misc_arg.add_argument('--gpu_id', type=str, default='0', help='-1:cpu') 111 | 112 | def get_config(): 113 | config, unparsed = parser.parse_known_args() 114 | return config, unparsed -------------------------------------------------------------------------------- /test_smokegun.py: -------------------------------------------------------------------------------- 1 | ############################################################# 2 | # MIT License, Copyright © 2020, ETH Zurich, Byungsoo Kim 3 | ############################################################# 4 | import numpy as np 5 
# [review] test_smokegun.py. run() first scans all frames for the min/max particle count. It then loads each partio
# frame into a [nmax,3] array padded with -1 (plus a [nmax,num_kernels] density array), normalizes positions by
# config.domain and reorders them to [z,y,x]. It runs the styler_3p Styler and writes the loss plot, '%03d.png' renders,
# '%03d.npz' fields and per-octave intermediates ('o%02d_%03d.png') to config.log_dir. main() hard-codes the
# smokegun scene settings before calling run().
# NOTE(review): this file fills rows as p_[j] = pt.get(p_pos, p_id_j) (loop index j), whereas test_chocolate.py
# fills p_[p_id_j]; confirm which per-row particle layout the styler expects across frames.
# The last line below also holds the header and first half of run() of test_dambreak2d.py.
# NOTE(review): test_dambreak2d.py does not import matplotlib.pyplot, but its run() calls plt.*; it presumably
# relies on `from util import *` re-exporting plt — confirm.
| import tensorflow as tf 6 | import matplotlib.pyplot as plt 7 | import os 8 | from tqdm import trange 9 | from config import get_config 10 | from util import * 11 | from styler_3p import Styler 12 | import sys 13 | sys.path.append('E:/partio/build/py/Release') 14 | import partio 15 | 16 | def run(config): 17 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # so the IDs match nvidia-smi 18 | os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_id # "0, 1" for multiple 19 | 20 | prepare_dirs_and_logger(config) 21 | tf.compat.v1.set_random_seed(config.seed) 22 | config.rng = np.random.RandomState(config.seed) 23 | 24 | styler = Styler(config) 25 | styler.load_img(config.resolution[1:]) 26 | 27 | params = {} 28 | 29 | # the number of particles range 30 | nmin, nmax = np.iinfo(np.int32).max, 0 31 | for i in range(config.num_frames): 32 | pt_path = os.path.join(config.data_dir, config.dataset, config.d_path % (config.target_frame+i)) 33 | pt = partio.read(pt_path) 34 | p_num = pt.numParticles() 35 | nmin, nmax = min(nmin,p_num), max(nmax,p_num) 36 | 37 | print('# range:', nmin, nmax) 38 | 39 | p, r = [], [] 40 | for i in trange(config.num_frames, desc='load particle'): 41 | pt_path = os.path.join(config.data_dir, config.dataset, config.d_path % (config.target_frame+i)) 42 | pt = partio.read(pt_path) 43 | 44 | p_id = pt.attributeInfo('id') 45 | p_pos = pt.attributeInfo('position') 46 | p_den = pt.attributeInfo('density') 47 | 48 | p_ = np.ones([nmax,3], dtype=np.float32)*-1 49 | r_ = np.zeros([nmax,config.num_kernels], dtype=np.float32) 50 | 51 | p_num = pt.numParticles() 52 | for j in range(p_num): 53 | p_id_j = pt.get(p_id, j)[0] 54 | p_[j] = pt.get(p_pos, p_id_j) 55 | r_[j] = pt.get(p_den, p_id_j) 56 | 57 | r.append(r_) 58 | 59 | # normalize particle position [0-1] 60 | px, py, pz = p_[...,0], p_[...,1], p_[...,2] 61 | px /= config.domain[2] 62 | py /= config.domain[1] 63 | pz /= config.domain[0] 64 | p_ = np.stack([pz,py,px], axis=-1) 65 | p.append(p_) 66 | 67 | 
print('resolution:', config.resolution) 69 | print('domain:', config.domain) 70 | print('radius:', config.radius) 71 | print('normalized px range', px.min(), px.max()) 72 | print('normalized py range', py.min(), py.max()) 73 | print('normalized pz range', pz.min(), pz.max()) 74 | 75 | params['p'] = p 76 | params['r'] = r 77 | 78 | # styler.render_test(params) 79 | result = styler.run(params) 80 | 81 | # save loss plot 82 | l = result['l'] 83 | lb = [] 84 | for o, l_ in enumerate(l): 85 | lb_, = plt.plot(range(len(l_)), l_, label='oct %d' % o) 86 | lb.append(lb_) 87 | plt.legend(handles=lb) 88 | # plt.show() 89 | plot_path = os.path.join(config.log_dir, 'loss_plot.png') 90 | plt.savefig(plot_path) 91 | 92 | r_sty = result['r'] 93 | for i, r_sty_ in enumerate(r_sty): 94 | im = Image.fromarray(r_sty_) 95 | d_path = os.path.join(config.log_dir, '%03d.png' % (config.target_frame+i)) 96 | im.save(d_path) 97 | 98 | d_sty = result['d'] 99 | for i, d_sty_ in enumerate(d_sty): 100 | d_path = os.path.join(config.log_dir, '%03d.npz' % (config.target_frame+i)) 101 | np.savez_compressed(d_path, x=d_sty_[:,::-1]) 102 | 103 | d_intm = result['d_intm'] 104 | for o, d_intm_o in enumerate(d_intm): 105 | for i, d_intm_ in enumerate(d_intm_o): 106 | if d_intm_ is None: continue 107 | im = Image.fromarray(d_intm_) 108 | d_path = os.path.join(config.log_dir, 'o%02d_%03d.png' % (o, config.target_frame+i)) 109 | im.save(d_path) 110 | 111 | def main(config): 112 | config.dataset = 'smokegun' 113 | 114 | # config.d_path = 'pt_low_o1/%03d.npz' 115 | # config.num_kernels = 1 116 | 117 | config.d_path = 'pt_low_o2/%03d.bgeo' 118 | config.num_kernels = 2 119 | 120 | config.kernel_scale = 2 121 | config.support = 4 122 | 123 | config.disc = 1 124 | cell_size = 1 # == 2*radius*disc 125 | config.radius = cell_size/config.disc/2 126 | config.nsize = 1 127 | config.rest_density = 1000 128 | config.resolution = [200,300,200] 129 | config.domain = [200,300,200] 130 | config.clip = False 131 | 
config.w_density = 0 132 | config.k = 3 133 | 134 | config.window_sigma = 3 135 | config.batch_size = 1 136 | config.frames_per_opt = 1 137 | 138 | config.target_field = 'd' 139 | config.lr = 0.1 140 | config.network = 'tensorflow_inception_graph.pb' 141 | config.style_layer = ['conv2d2','mixed3b','mixed4b'] 142 | config.w_style_layer = [1,1,1] 143 | config.octave_n = 1 144 | config.octave_scale = 1.8 145 | config.transmit = 0.01 # 0.01, 5 146 | config.iter = 20 147 | config.resize_scale = 300/config.resolution[0] 148 | config.rotate = False 149 | 150 | multi_frame = False 151 | config.interp = 1 152 | config.batch_size = 1 153 | config.frames_per_opt = 1 154 | # if multi_frame: 155 | # config.num_frames = 120 156 | # config.target_frame = 0 157 | # else: 158 | # config.target_frame = 70 159 | # config.num_frames = 1 160 | 161 | semantic = True 162 | density_reg = False 163 | 164 | # if semantic: 165 | # config.w_style = 0 166 | # config.w_content = 1 167 | # config.content_layer = 'mixed3b_3x3_bottleneck_pre_relu' 168 | # config.content_channel = 44 # net 169 | # else: 170 | # # style 171 | # config.w_style = 1 172 | # config.w_content = 0 173 | 174 | # style_list = { 175 | # 'spiral': 'pattern1.png', 176 | # 'fire_new': 'fire_new.jpg', 177 | # 'ben_giles': 'ben_giles.png', 178 | # 'wave': 'wave.jpeg', 179 | # } 180 | # style = 'spiral' 181 | # config.style_target = os.path.join(config.data_dir, 'image', style_list[style]) 182 | 183 | # density regularization 184 | if density_reg: 185 | config.w_density = 1e-6 186 | 187 | # if config.w_content == 1: 188 | # config.tag = 'test_%s_%s_%d' % ( 189 | # config.target_field, config.content_layer, config.content_channel) 190 | # else: 191 | # style = os.path.splitext(os.path.basename(config.style_target))[0] 192 | # config.tag = 'test_%s_%s' % ( 193 | # config.target_field, style) 194 | 195 | # config.tag += '_%d' % config.num_frames 196 | 197 | run(config) 198 | 199 | if __name__ == "__main__": 200 | config, unparsed = 
get_config() 201 | main(config) -------------------------------------------------------------------------------- /test_dambreak2d.py: -------------------------------------------------------------------------------- 1 | ############################################################# 2 | # MIT License, Copyright © 2020, ETH Zurich, Byungsoo Kim 3 | ############################################################# 4 | import numpy as np 5 | import tensorflow as tf 6 | import os 7 | from tqdm import trange 8 | from config import get_config 9 | from util import * 10 | from styler_2p import Styler 11 | import sys 12 | sys.path.append('E:/partio/build/py/Release') 13 | import partio 14 | 15 | def run(config): 16 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # so the IDs match nvidia-smi 17 | os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_id # "0, 1" for multiple 18 | 19 | prepare_dirs_and_logger(config) 20 | tf.compat.v1.set_random_seed(config.seed) 21 | config.rng = np.random.RandomState(config.seed) 22 | 23 | styler = Styler(config) 24 | styler.load_img(config.resolution) 25 | 26 | params = {} 27 | 28 | # load particles 29 | p = [] 30 | r = [] 31 | for i in trange(config.num_frames, desc='load particle'): 32 | pt_path = os.path.join(config.data_dir, config.dataset, config.d_path % (config.target_frame+i)) 33 | pt = partio.read(pt_path) 34 | 35 | p_id = pt.attributeInfo('id') 36 | p_pos = pt.attributeInfo('position') 37 | p_den = pt.attributeInfo('density') 38 | 39 | p_num = pt.numParticles() 40 | p_ = np.zeros([p_num,2], dtype=np.float32) 41 | r_ = np.zeros([p_num,1], dtype=np.float32) 42 | 43 | for j in range(p_num): 44 | p_id_ = pt.get(p_id, j)[0] 45 | p_[p_id_] = pt.get(p_pos, p_id_)[:-1] # 2d 46 | r_[p_id_] = pt.get(p_den, p_id_) 47 | 48 | r.append(r_) 49 | 50 | # normalize particle position [0-1] 51 | px, py = p_[...,0], p_[...,1] 52 | px /= config.domain[1] 53 | py /= config.domain[0] 54 | p_ = np.stack([py,px], axis=-1) 55 | p.append(p_) 56 | 
print('resolution:', config.resolution) 58 | print('domain:', config.domain) 59 | print('radius:', config.radius) 60 | print('normalized px range', px.min(), px.max()) 61 | print('normalized py range', py.min(), py.max()) 62 | print('num particles:', p[0].shape) # the number of particles is fixed 63 | 64 | params['p'] = p 65 | params['r'] = r 66 | 67 | # styler.render_test(params) 68 | result = styler.run(params) 69 | 70 | # save loss plot 71 | l = result['l'] 72 | lb = [] 73 | for o, l_ in enumerate(l): 74 | lb_, = plt.plot(range(len(l_)), l_, label='oct %d' % o) 75 | lb.append(lb_) 76 | plt.legend(handles=lb) 77 | # plt.show() 78 | plot_path = os.path.join(config.log_dir, 'loss_plot.png') 79 | plt.savefig(plot_path) 80 | 81 | 82 | # save density fields 83 | d_sty = result['d'] # [0-255], uint8 84 | # d_path = os.path.join(config.log_dir, 'd%03d_%03d.png' % (config.target_frame,config.target_frame+config.num_frames-1)) 85 | # save_image(d_sty, d_path, nrow=5, gray=not 'c' in config.target_field) 86 | 87 | for i, d_sty_ in enumerate(d_sty): 88 | im = Image.fromarray(d_sty_) 89 | d_path = os.path.join(config.log_dir, '%03d.png' % (config.target_frame+i)) 90 | im.save(d_path) 91 | 92 | d_intm = result['d_intm'] 93 | for o, d_intm_o in enumerate(d_intm): 94 | for i, d_intm_ in enumerate(d_intm_o): 95 | im = Image.fromarray(d_intm_) 96 | d_path = os.path.join(config.log_dir, 'o%02d_%03d.png' % (o, config.target_frame)) 97 | im.save(d_path) 98 | 99 | # save particles (load using Houdini GPlay) 100 | c_sty = result['c'] 101 | p_org = [] 102 | for p_ in p: 103 | # denormalize particle positions 104 | px, py = p_[...,1], p_[...,0] 105 | px *= config.domain[1] 106 | py *= config.domain[0] 107 | p_org.append(np.stack([px,py], axis=-1)) 108 | 109 | for i in range(config.num_frames): 110 | # create a particle set and attributes 111 | pt = partio.create() 112 | position = pt.addAttribute("position",partio.VECTOR,2) 113 | color = pt.addAttribute("Cd",partio.FLOAT,3) 114 | radius 
= pt.addAttribute("radius",partio.FLOAT,1) 115 | # normal = pt.addAttribute("normal",partio.VECTOR,3) 116 | 117 | for pi in range(p_org[i].shape[0]): 118 | p_ = pt.addParticle() 119 | pt.set(position, p_, tuple(p_org[i][pi].astype(np.float))) 120 | pt.set(color, p_, tuple(c_sty[i][pi].astype(np.float))) 121 | pt.set(radius, p_, (config.radius,)) 122 | 123 | p_path = os.path.join(config.log_dir, '%03d.bgeo' % (config.target_frame+i)) 124 | partio.write(p_path, pt) 125 | 126 | # visualization using open3d 127 | bbox = [ 128 | [0,0,-1], 129 | [config.domain[1],config.domain[0],1], # [X,Y,Z] 130 | ] 131 | draw_pt(p_org, pc=c_sty, bbox=bbox, dt=0.1) 132 | 133 | def main(config): 134 | config.dataset = 'dambreak2d' 135 | config.d_path = 'partio/ParticleData_Fluid_%d.bgeo' 136 | 137 | # from scene 138 | config.radius = 0.025 139 | config.support = 4 140 | config.disc = 2 141 | config.rest_density = 1000 142 | config.resolution = [128, 256] # [H,W] 143 | cell_size = 2*config.radius*config.disc 144 | config.domain = [float(_*cell_size) for _ in config.resolution] # [H,W] 145 | config.nsize = max(3-config.disc,1) # 1 is enough if disc is 2, 2 if disc is 1 146 | 147 | # upscaling for rendering 148 | config.scale = 4 149 | config.nsize *= config.scale 150 | config.resolution = [config.resolution[0]*config.scale, config.resolution[1]*config.scale] 151 | 152 | ##################### 153 | # frame range setting 154 | multi_frame = True 155 | config.frames_per_opt = 200 156 | config.window_sigma = 3 157 | # if multi_frame: 158 | # config.target_frame = 1 159 | # config.num_frames = 200 160 | # config.batch_size = 4 161 | # else: 162 | # config.target_frame = 150 163 | # config.num_frames = 1 164 | # config.batch_size = 1 165 | 166 | ###### 167 | # color test 168 | config.target_field = 'c' 169 | config.lr = 0.01 170 | config.iter = 100 171 | config.octave_n = 3 172 | config.octave_scale = 1.7 173 | config.clip = False 174 | 175 | # style_list = { 176 | # 'fire_new': 'fire_new.jpg', 
177 | # 'ben_giles': 'ben_giles.png', 178 | # 'wave': 'wave.jpeg', 179 | # } 180 | # style = 'wave' 181 | # config.style_target = os.path.join(config.data_dir, 'image', style_list[style]) 182 | 183 | config.network = 'vgg_19.ckpt' 184 | config.w_style = 1 185 | config.w_content = 0 186 | config.style_init = 'noise' 187 | config.style_layer = ['conv2_1','conv3_1'] 188 | config.w_style_layer = [0.5,0.5] 189 | config.style_mask = True 190 | config.style_mask_on_ref = False 191 | config.style_tiling = 2 192 | config.w_tv = 0.01 193 | 194 | if config.w_content == 1: 195 | config.tag = 'test_%s_%s_%d' % ( 196 | config.target_field, config.content_layer, config.content_channel) 197 | else: 198 | style = os.path.splitext(os.path.basename(config.style_target))[0] 199 | config.tag = 'test_%s_%s' % ( 200 | config.target_field, style) 201 | 202 | config.tag += '_%d' % config.num_frames 203 | 204 | run(config) 205 | 206 | if __name__ == "__main__": 207 | config, unparsed = get_config() 208 | main(config) -------------------------------------------------------------------------------- /scene/smokegun.py: -------------------------------------------------------------------------------- 1 | ############################################################# 2 | # MIT License, Copyright © 2020, ETH Zurich, Byungsoo Kim 3 | ############################################################# 4 | import argparse 5 | from datetime import datetime 6 | import os 7 | from tqdm import trange 8 | import numpy as np 9 | from PIL import Image 10 | import platform 11 | from subprocess import call 12 | try: 13 | from manta import * 14 | except ImportError: 15 | pass 16 | 17 | parser = argparse.ArgumentParser() 18 | 19 | parser.add_argument("--data_dir", type=str, default='data/smokegun') 20 | parser.add_argument("--num_param", type=int, default=3) 21 | parser.add_argument("--path_format", type=str, default='%03d.%s') 22 | 23 | parser.add_argument("--src_x_pos", type=float, default=0.2) 24 | 
parser.add_argument("--src_z_pos", type=float, default=0.5) 25 | parser.add_argument("--src_y_pos", type=float, default=0.15) 26 | parser.add_argument("--src_inflow", type=float, default=8) 27 | parser.add_argument("--strength", type=float, default=0.05) 28 | parser.add_argument("--src_radius", type=float, default=0.12) 29 | parser.add_argument("--num_frames", type=int, default=120) 30 | parser.add_argument("--obstacle", type=bool, default=False) 31 | 32 | parser.add_argument("--resolution_x", type=int, default=200) 33 | parser.add_argument("--resolution_y", type=int, default=300) 34 | parser.add_argument("--resolution_z", type=int, default=200) 35 | parser.add_argument("--buoyancy", type=float, default=-4e-3) 36 | parser.add_argument("--bWidth", type=int, default=1) 37 | parser.add_argument("--open_bound", type=bool, default=True) 38 | parser.add_argument("--time_step", type=float, default=0.5) 39 | parser.add_argument("--adv_order", type=int, default=2) 40 | parser.add_argument("--clamp_mode", type=int, default=2) 41 | 42 | parser.add_argument("--transmit", type=float, default=0.01) 43 | parser.add_argument("--downup_factor", type=int, default=8) 44 | 45 | args = parser.parse_args() 46 | 47 | def downup_sample(): 48 | # d_path = os.path.join(args.data_dir, 'd_low') 49 | # v_path = os.path.join(args.data_dir, 'v_low') 50 | d_path = os.path.join('E:/lnst/data/smokegun', 'd_low') 51 | v_path = os.path.join('E:/lnst/data/smokegun', 'v_low') 52 | for f_path in [d_path, v_path]: 53 | if not os.path.exists(f_path): 54 | os.mkdir(f_path) 55 | 56 | res_x = args.resolution_x 57 | res_y = args.resolution_y 58 | res_z = args.resolution_z 59 | org_res = [res_z,res_y,res_x] 60 | down_res = [r//args.downup_factor for r in org_res] 61 | d_ = np.zeros(org_res, dtype=np.float32) 62 | 63 | # solver params 64 | gs = vec3(res_x, res_y, res_z) 65 | buoyancy = vec3(0,args.buoyancy,0) 66 | 67 | s = Solver(name='main', gridSize=gs, dim=3) 68 | s.timestep = args.time_step 69 | 70 | flags 
= s.create(FlagGrid) 71 | vel = s.create(MACGrid) 72 | density = s.create(RealGrid) 73 | 74 | flags.initDomain(boundaryWidth=args.bWidth) 75 | flags.fillGrid() 76 | if args.open_bound: 77 | setOpenBound(flags, args.bWidth,'xXyYzZ', FlagOutflow|FlagEmpty) 78 | 79 | radius = gs.x*args.src_radius 80 | center = gs*vec3(args.src_x_pos,args.src_y_pos,args.src_z_pos) 81 | source = s.create(Sphere, center=center, radius=radius) 82 | 83 | if args.obstacle: 84 | obs_radius = gs.x*0.15 85 | obs_center = gs*vec3(0.7, 0.5, 0.5) 86 | obs = s.create(Sphere, center=obs_center, radius=obs_radius) 87 | obs.applyToGrid(grid=flags, value=FlagObstacle) 88 | 89 | if (GUI): 90 | gui = Gui() 91 | gui.show(True) 92 | #gui.pause() 93 | 94 | def resize(v, vshape, order=3): 95 | import skimage.transform 96 | v0 = skimage.transform.resize(v[...,0], vshape, order=3).astype(np.float32) 97 | v1 = skimage.transform.resize(v[...,1], vshape, order=3).astype(np.float32) 98 | v2 = skimage.transform.resize(v[...,2], vshape, order=3).astype(np.float32) 99 | return np.stack([v0,v1,v2], axis=-1).astype(np.float32) 100 | 101 | for t in trange(args.num_frames, desc='downup_sample'): 102 | source.applyToGrid(grid=density, value=1) 103 | 104 | # save density 105 | copyGridToArrayReal(density, d_) 106 | d_file_path = os.path.join(d_path, args.path_format % (t, 'npz')) 107 | np.savez_compressed(d_file_path, x=d_) 108 | 109 | d_file_path = os.path.join(d_path, args.path_format % (t, 'png')) 110 | transmit = np.exp(-np.cumsum(d_[::-1], axis=0)*args.transmit) 111 | d_img = np.sum(d_*transmit, axis=0) 112 | d_img /= d_img.max() 113 | im = Image.fromarray((d_img[::-1]*255).astype(np.uint8)) 114 | im.save(d_file_path) 115 | 116 | v_file_path = os.path.join(args.data_dir, 'v', args.path_format % (t, 'npz')) 117 | with np.load(v_file_path) as data: 118 | v_ = data['x'] 119 | v_ = resize(resize(v_, down_res), org_res) 120 | 121 | # save velocity 122 | v_file_path = os.path.join(v_path, args.path_format % (t, 'npz')) 123 
# [review] scene/smokegun.py.
# - End of downup_sample(): each frame it copies the re-sampled velocity into the MAC grid, advects only the
#   density with it, and steps the solver.
# - main(): creates args.data_dir/d and args.data_dir/v and writes args.txt. It then runs the full-resolution
#   smoke simulation: inflow source, optional obstacle, density/velocity advection, vorticity confinement,
#   buoyancy and a pressure solve. Each frame it saves density (.npz + .png) and velocity (.npz).
# - The __main__ guard runs main() and then downup_sample().
# NOTE(review): d_img /= d_img.max() would divide by zero for an all-zero density field; presumably unreachable
# because the source is applied before the first save — confirm.
# The last line below holds the header and first half of run() of test_chocolate.py. That file uses plt later
# without importing matplotlib.pyplot, presumably relying on `from util import *` — confirm.
| np.savez_compressed(v_file_path, x=v_) 124 | 125 | # advect density 126 | copyArrayToGridMAC(v_, vel) 127 | advectSemiLagrange(flags=flags, vel=vel, grid=density, order=args.adv_order, 128 | openBounds=args.open_bound, boundaryWidth=args.bWidth, clampMode=args.clamp_mode) 129 | 130 | s.step() 131 | 132 | def main(): 133 | if not os.path.exists(args.data_dir): 134 | os.makedirs(args.data_dir) 135 | 136 | field_type = ['d', 'v'] 137 | for field in field_type: 138 | field_path = os.path.join(args.data_dir,field) 139 | if not os.path.exists(field_path): 140 | os.mkdir(field_path) 141 | 142 | args_file = os.path.join(args.data_dir, 'args.txt') 143 | with open(args_file, 'w') as f: 144 | print('%s: arguments' % datetime.now()) 145 | for k, v in vars(args).items(): 146 | print(' %s: %s' % (k, v)) 147 | f.write('%s: %s\n' % (k, v)) 148 | 149 | res_x = args.resolution_x 150 | res_y = args.resolution_y 151 | res_z = args.resolution_z 152 | d_ = np.zeros([res_z,res_y,res_x], dtype=np.float32) 153 | v_ = np.zeros([res_z,res_y,res_x,3], dtype=np.float32) 154 | 155 | # solver params 156 | gs = vec3(res_x, res_y, res_z) 157 | buoyancy = vec3(0,args.buoyancy,0) 158 | 159 | s = Solver(name='main', gridSize=gs, dim=3) 160 | s.timestep = args.time_step 161 | 162 | flags = s.create(FlagGrid) 163 | vel = s.create(MACGrid) 164 | density = s.create(RealGrid) 165 | pressure = s.create(RealGrid) 166 | 167 | flags.initDomain(boundaryWidth=args.bWidth) 168 | flags.fillGrid() 169 | if args.open_bound: 170 | setOpenBound(flags, args.bWidth,'xXyYzZ', FlagOutflow|FlagEmpty) 171 | 172 | radius = gs.x*args.src_radius 173 | center = gs*vec3(args.src_x_pos,args.src_y_pos,args.src_z_pos) 174 | source = s.create(Sphere, center=center, radius=radius) 175 | 176 | if args.obstacle: 177 | obs_radius = gs.x*0.15 178 | obs_center = gs*vec3(0.7, 0.5, 0.5) 179 | obs = s.create(Sphere, center=obs_center, radius=obs_radius) 180 | obs.applyToGrid(grid=flags, value=FlagObstacle) 181 | 182 | if (GUI): 183 | gui 
= Gui() 184 | gui.show(True) 185 | #gui.pause() 186 | 187 | for t in trange(args.num_frames, desc='sim'): 188 | source.applyToGrid(grid=density, value=1) 189 | source.applyToGrid(grid=vel, value=vec3(args.src_inflow,0,0)) 190 | 191 | # save density 192 | copyGridToArrayReal(density, d_) 193 | d_file_path = os.path.join(args.data_dir, 'd', args.path_format % (t, 'npz')) 194 | np.savez_compressed(d_file_path, x=d_) 195 | 196 | d_file_path = os.path.join(args.data_dir, 'd', args.path_format % (t, 'png')) 197 | transmit = np.exp(-np.cumsum(d_[::-1], axis=0)*args.transmit) 198 | d_img = np.sum(d_*transmit, axis=0) 199 | d_img /= d_img.max() 200 | im = Image.fromarray((d_img[::-1]*255).astype(np.uint8)) 201 | im.save(d_file_path) 202 | 203 | # save velocity 204 | v_file_path = os.path.join(args.data_dir, 'v', args.path_format % (t, 'npz')) 205 | copyGridToArrayMAC(vel, v_) 206 | np.savez_compressed(v_file_path, x=v_) 207 | 208 | advectSemiLagrange(flags=flags, vel=vel, grid=density, order=args.adv_order, 209 | openBounds=args.open_bound, boundaryWidth=args.bWidth, clampMode=args.clamp_mode) 210 | 211 | advectSemiLagrange(flags=flags, vel=vel, grid=vel, order=args.adv_order, 212 | openBounds=args.open_bound, boundaryWidth=args.bWidth, clampMode=args.clamp_mode) 213 | 214 | vorticityConfinement(vel=vel, flags=flags, strength=args.strength) 215 | 216 | setWallBcs(flags=flags, vel=vel) 217 | addBuoyancy(density=density, vel=vel, gravity=buoyancy, flags=flags) 218 | solvePressure(flags=flags, vel=vel, pressure=pressure, cgMaxIterFac=10.0, cgAccuracy=0.0001) 219 | setWallBcs(flags=flags, vel=vel) 220 | 221 | s.step() 222 | 223 | if __name__ == '__main__': 224 | main() 225 | downup_sample() -------------------------------------------------------------------------------- /test_chocolate.py: -------------------------------------------------------------------------------- 1 | ############################################################# 2 | # MIT License, Copyright © 2020, ETH 
Zurich, Byungsoo Kim 3 | ############################################################# 4 | import numpy as np 5 | import tensorflow as tf 6 | import os 7 | from tqdm import trange 8 | from config import get_config 9 | from util import * 10 | from styler_3p import Styler 11 | import sys 12 | sys.path.append('E:/partio/build/py/Release') 13 | import partio 14 | 15 | def run(config): 16 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # so the IDs match nvidia-smi 17 | os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_id # "0, 1" for multiple 18 | 19 | prepare_dirs_and_logger(config) 20 | tf.compat.v1.set_random_seed(config.seed) 21 | config.rng = np.random.RandomState(config.seed) 22 | 23 | styler = Styler(config) 24 | styler.load_img(config.resolution[1:]) 25 | 26 | params = {} 27 | 28 | # load particles 29 | nmin, nmax = np.iinfo(np.int32).max, 0 30 | for i in range(config.num_frames): 31 | pt_path = os.path.join(config.data_dir, config.dataset, config.d_path % (config.target_frame+i)) 32 | pt = partio.read(pt_path) 33 | p_num = pt.numParticles() 34 | nmin = min(p_num, nmin) 35 | nmax = max(p_num, nmax) 36 | 37 | print('# range:', nmin, nmax) 38 | 39 | p = [] 40 | # r = [] 41 | for i in trange(config.num_frames, desc='load particle'): # last one for mask 42 | pt_path = os.path.join(config.data_dir, config.dataset, config.d_path % (config.target_frame+i)) 43 | pt = partio.read(pt_path) 44 | 45 | p_attr_id = pt.attributeInfo('id') 46 | p_attr_pos = pt.attributeInfo('position') 47 | # p_attr_den = pt.attributeInfo('density') 48 | 49 | p_ = np.ones([nmax,3], dtype=np.float32)*-1 50 | # r_ = np.zeros([nmax,1], dtype=np.float32) 51 | 52 | p_num = pt.numParticles() 53 | for j in range(p_num): 54 | p_id_j = pt.get(p_attr_id, j)[0] 55 | p_[p_id_j] = pt.get(p_attr_pos, p_id_j) 56 | # r_[p_id_j] = pt.get(p_attr_den, p_id_j) 57 | # r.append(r_) 58 | 59 | # normalize particle position [0-1] 60 | px, py, pz = p_[...,0], p_[...,1], p_[...,2] 61 | px /= config.domain[2] 62 | py /= 
config.domain[1] 63 | pz /= config.domain[0] 64 | p_ = np.stack([pz,py,px], axis=-1) 65 | p.append(p_) 66 | 67 | print('resolution:', config.resolution) 68 | print('domain:', config.domain) 69 | print('radius:', config.radius) 70 | print('normalized px range', px.min(), px.max()) 71 | print('normalized py range', py.min(), py.max()) 72 | 73 | params['p'] = p 74 | 75 | # styler.render_test(params) 76 | result = styler.run(params) 77 | 78 | # save loss plot 79 | l = result['l'] 80 | lb = [] 81 | for o, l_ in enumerate(l): 82 | lb_, = plt.plot(range(len(l_)), l_, label='oct %d' % o) 83 | lb.append(lb_) 84 | plt.legend(handles=lb) 85 | # plt.show() 86 | plot_path = os.path.join(config.log_dir, 'loss_plot.png') 87 | plt.savefig(plot_path) 88 | 89 | # save particle (load using Houdini GPlay) 90 | p_sty = result['p'] 91 | p = [] 92 | # v_sty = result['v'] 93 | # v = [] 94 | for i in range(config.num_frames): 95 | # denormalize particle positions 96 | px, py, pz = p_sty[i][...,2], p_sty[i][...,1], p_sty[i][...,0] 97 | px *= config.domain[2] 98 | py *= config.domain[1] 99 | pz *= config.domain[0] 100 | p_sty_ = np.stack([px,py,pz], axis=-1) 101 | p.append(p_sty_) 102 | 103 | # # denormalize particle displacement for stylization 104 | # vx, vy, vz = v_sty[i][...,2], v_sty[i][...,1], v_sty[i][...,0] 105 | # vx *= config.domain[2] 106 | # vy *= config.domain[1] 107 | # vz *= config.domain[0] 108 | # v_sty_ = np.stack([vx,vy,vz], axis=-1) 109 | # v.append(v_sty_) 110 | 111 | # create a particle set and attributes 112 | pt = partio.create() 113 | position = pt.addAttribute("position",partio.VECTOR,3) 114 | # color = pt.addAttribute("Cd",partio.FLOAT,3) 115 | radius = pt.addAttribute("radius",partio.FLOAT,1) 116 | # normal = pt.addAttribute("normal",partio.VECTOR,3) 117 | 118 | for p_sty_i in p_sty_: 119 | if p_sty_i[0] < 0: continue 120 | p_ = pt.addParticle() 121 | pt.set(position, p_, tuple(p_sty_i.astype(np.float))) 122 | pt.set(radius, p_, (config.radius,)) 123 | 124 | 
# [review] test_chocolate.py.
# - Tail of run(): writes one '%03d.bgeo' per frame, the '%03d.png' renders and the per-octave intermediate
#   images ('o%02d_%03d.png', skipping None entries), then previews the particles with draw_pt.
# - Most of main(): hard-codes the chocolate scene (radius 0.025, disc 2, a 128^3 simulation grid rendered at
#   200^3, position target field) and builds config.tag.
# NOTE(review): the tag reads config.num_frames and config.interp, but every assignment to them in main() is
# commented out. They must therefore come from config.py defaults outside this view — confirm that an `interp`
# option exists there.
p_path = os.path.join(config.log_dir, '%03d.bgeo' % (config.target_frame+i)) 125 | partio.write(p_path, pt) 126 | 127 | r_sty = result['r'] 128 | for i, r_sty_ in enumerate(r_sty): 129 | im = Image.fromarray(r_sty_) 130 | d_path = os.path.join(config.log_dir, '%03d.png' % (config.target_frame+i)) 131 | im.save(d_path) 132 | 133 | d_intm = result['d_intm'] 134 | for o, d_intm_o in enumerate(d_intm): 135 | for i, d_intm_ in enumerate(d_intm_o): 136 | if d_intm_ is None: continue 137 | im = Image.fromarray(d_intm_) 138 | d_path = os.path.join(config.log_dir, 'o%02d_%03d.png' % (o, config.target_frame+i)) 139 | im.save(d_path) 140 | 141 | # visualization using open3d 142 | bbox = [ 143 | [0,0,0], 144 | [config.domain[2],config.domain[1],config.domain[0]], # [X,Y,Z] 145 | ] 146 | draw_pt(p, bbox=bbox, dt=0.1, is_2d=False) # pv=v, 147 | 148 | def main(config): 149 | config.dataset = 'chocolate' 150 | config.d_path = 'partio/ParticleData_Fluid_%d.bgeo' 151 | 152 | # from scene 153 | config.radius = 0.025 154 | config.support = 4 155 | config.disc = 2 # 1 or 2 156 | config.rest_density = 1000 157 | config.resolution = [128,128,128] # original resolution, # [D,H,W] 158 | cell_size = 2*config.radius*config.disc 159 | config.domain = [float(_*cell_size) for _ in config.resolution] # [D,H,W] 160 | config.nsize = max(3-config.disc,1) # 1 is enough if disc is 2, 2 if disc is 1 161 | 162 | # upscaling for rendering 163 | config.resolution = [200,200,200] 164 | 165 | # default settings 166 | config.lr = 0.002 167 | config.iter = 20 168 | config.resize_scale = 1 169 | config.transmit = 0.2 # 0.01, 1 170 | config.clip = False # ignore particles outside of domain 171 | config.num_kernels = 1 172 | config.k = 3 173 | config.network = 'tensorflow_inception_graph.pb' 174 | 175 | config.octave_n = 2 176 | config.octave_scale = 1.8 177 | config.render_liquid = True 178 | config.rotate = False 179 | config.style_layer = ['conv2d2','mixed3b','mixed4b'] 180 | config.w_style_layer = [1,1,1] 
181 | 182 | ##################### 183 | # frame range setting 184 | config.frames_per_opt = 120 185 | config.batch_size = 1 186 | config.window_sigma = 9 187 | # multi_frame = True 188 | # if multi_frame: 189 | # config.target_frame = 1 190 | # config.num_frames = 120 191 | # config.interp = 1 192 | 193 | # ###### 194 | # # interpolation test 195 | # interpolate = False 196 | # if interpolate: 197 | # config.interp = 5 198 | # n = (config.num_frames-1)//config.interp 199 | # config.num_frames = n*config.interp + 1 200 | # assert (config.num_frames - 1) % config.interp == 0 201 | # ##### 202 | # 203 | # else: 204 | # config.target_frame = 90 205 | # config.num_frames = 1 206 | # config.interp = 1 207 | 208 | ###### 209 | # position test 210 | config.target_field = 'p' 211 | semantic = False 212 | pressure = False 213 | 214 | # if semantic: 215 | # config.w_style = 0 216 | # config.w_content = 1 217 | # config.content_layer = 'mixed3b_3x3_bottleneck_pre_relu' 218 | # config.content_channel = 44 # net 219 | # else: 220 | # # style 221 | # config.w_style = 1 222 | # config.w_content = 0 223 | 224 | # style_list = { 225 | # 'spiral': 'pattern1.png', 226 | # 'fire_new': 'fire_new.jpg', 227 | # 'ben_giles': 'ben_giles.png', 228 | # 'wave': 'wave.jpeg', 229 | # } 230 | # style = 'spiral' 231 | # config.style_target = os.path.join(config.data_dir, 'image', style_list[style]) 232 | 233 | # pressure test 234 | if pressure: 235 | config.w_pressure = 1e12 # 1e10 ~ 1e12 236 | ##### 237 | 238 | if config.w_content == 1: 239 | config.tag = 'test_%s_%s_%d' % ( 240 | config.target_field, config.content_layer, config.content_channel) 241 | else: 242 | style = os.path.splitext(os.path.basename(config.style_target))[0] 243 | config.tag = 'test_%s_%s' % ( 244 | config.target_field, style) 245 | 246 | config.tag += '_%d_intp%d' % (config.num_frames, config.interp) 247 | 248 | quick_test = False 249 | if quick_test: 250 | config.scale = 1 251 | config.iter = 0 252 | config.octave_n = 1 
253 | 254 | run(config) 255 | 256 | if __name__ == "__main__": 257 | config, unparsed = get_config() 258 | main(config) -------------------------------------------------------------------------------- /styler_2p.py: -------------------------------------------------------------------------------- 1 | ############################################################# 2 | # MIT License, Copyright © 2020, ETH Zurich, Byungsoo Kim 3 | ############################################################# 4 | import tensorflow as tf 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import os 8 | from tqdm import trange 9 | from util import * 10 | from transform import p2g 11 | import vgg 12 | from styler_base import StylerBase 13 | 14 | class Styler(StylerBase): 15 | def __init__(self, self_dict): 16 | StylerBase.__init__(self, self_dict) 17 | 18 | # particle position 19 | # shape: [N,2], scale: [0,1] 20 | p = [] 21 | p_shp = [None,2] 22 | self.p = [] 23 | 24 | # particle density shape 25 | r_shp = [None,1] 26 | self.r = [] 27 | 28 | # particle color shape 29 | c_shp = [None,3] 30 | self.c = [] 31 | 32 | # output and density field 33 | d = [] 34 | d_gray = [] 35 | 36 | self.opt_init = [] 37 | self.opt_ph = [] 38 | self.opt = [] 39 | 40 | self.res = tf.compat.v1.placeholder(tf.int32, [2], name='resolution') 41 | 42 | for i in range(self.batch_size): 43 | # particle position, [N,2] 44 | p_ = tf.compat.v1.placeholder(dtype=tf.float32, shape=p_shp, name='p%d' % i) 45 | self.p.append(p_) 46 | p_ = tf.expand_dims(p_, axis=0) # [1,N,2] 47 | p.append(p_[0]) 48 | 49 | # particle density, [N,1] 50 | r_ = tf.compat.v1.placeholder(dtype=tf.float32, shape=r_shp, name='r%d' % i) 51 | self.r.append(r_) 52 | r_ = tf.expand_dims(r_, axis=0) # [1,N,1] 53 | 54 | # position-based (SPH) density field estimation 55 | d_gray_ = p2g(p_, self.domain, self.res, self.radius, self.rest_density, self.nsize, support=self.support, clip=self.clip) # [B,N,2] -> [B,H,W,1] 56 | d_gray_ /= self.rest_density 
# normalize density 57 | d_gray.append(d_gray_) 58 | 59 | # particle color, [N,3] 60 | opt_ph = tf.compat.v1.placeholder(dtype=tf.float32, shape=c_shp, name='c_opt_ph%d' % i) 61 | self.opt_ph.append(opt_ph) 62 | opt_var = tf.Variable(opt_ph, validate_shape=False, name='c_opt%d' % i) 63 | self.opt.append(opt_var) 64 | opt_var_ = tf.reshape(opt_var, tf.shape(opt_ph)) 65 | opt_var_ = tf.expand_dims(opt_var_, axis=0) 66 | 67 | # clip particle color 68 | c_ = tf.clip_by_value(opt_var_, 0, 1) 69 | 70 | # mask color 71 | self.c.append(c_[0]*tf.clip_by_value(r_[0]/self.rest_density, 0, 1)) 72 | 73 | # position-based (SPH) color field estimation 74 | d_ = p2g(p_, self.domain, self.res, self.radius, self.rest_density, self.nsize, support=self.support, clip=self.clip, 75 | pc=c_, pd=r_) # [B,N,2] -> [B,H,W,3] 76 | 77 | d.append(d_) 78 | 79 | self.opt_init = tf.compat.v1.initializers.variables(self.opt) 80 | 81 | # particle position 82 | self.p_out = p # [N,2]*B 83 | 84 | # estimated color fields 85 | d = tf.concat(d, axis=0) # [B,H,W,3] 86 | 87 | # value clipping for rendering 88 | d = tf.clip_by_value(d, 0, 1) 89 | 90 | # estimated density fields for masking 91 | d_gray = tf.concat(d_gray, axis=0) # [B,H,W,1] 92 | 93 | # clamp density field [0,1] 94 | d_gray = tf.clip_by_value(d_gray, 0, 1) 95 | 96 | # mask for style features 97 | self.d_gray = d_gray 98 | 99 | # stylized result 100 | self.d_out = d*d_gray # [B,H,W,3] 101 | 102 | self._plugin_to_loss_net(d) 103 | 104 | def render_test(self, params): 105 | feed = {} 106 | feed[self.res] = self.resolution 107 | 108 | for i in range(self.batch_size): 109 | feed[self.p[i]] = params['p'][i] 110 | feed[self.r[i]] = params['r'][i] 111 | n = params['p'][i].shape[0] 112 | 113 | # feed[self.opt_ph[i]] = np.ones([n,3]) 114 | c_init_shp = [n,3] 115 | c_init = self.rng.uniform(-5,5, c_init_shp).astype(np.float32) 116 | c_init += np.array([vgg._R_MEAN, vgg._G_MEAN, vgg._B_MEAN]) 117 | feed[self.opt_ph[i]] = c_init/255 118 | 119 | 
self.sess.run(self.opt_init, feed) 120 | p_out, d_out, d_gray = self.sess.run([self.p_out, self.d_out, self.d_gray], feed) 121 | plt.subplot(121) 122 | plt.imshow(d_out[0]) 123 | plt.subplot(122) 124 | plt.imshow(d_gray[0,...,0]) 125 | plt.show() 126 | 127 | for i, p in enumerate(p_out): 128 | p[:,0] = p[:,0]*self.domain[0] 129 | p[:,1] = p[:,1]*self.domain[1] 130 | p_out[i] = np.stack([p[:,1],p[:,0]], axis=-1) 131 | v_ = None 132 | bbox = [ 133 | [0,0,-1], 134 | [self.domain[1],self.domain[0],1], 135 | ] 136 | draw_pt(p_out, v_, bbox=bbox) 137 | return 138 | 139 | # save to image 140 | for t in trange(0,self.num_frames,self.batch_size): 141 | if t == 0: 142 | n = params['p'][0].shape[0] 143 | from matplotlib import cm 144 | c = cm.plasma(np.linspace(0,1,n))[...,:-1] 145 | 146 | for i in range(self.batch_size): 147 | feed[self.p[i]] = params['p'][t+i] 148 | feed[self.r[i]] = params['r'][t+i] 149 | if 'p' in self.target_field: 150 | feed[self.opt_ph[i]] = np.zeros([n,2]) 151 | if 'c' in self.target_field: 152 | feed[self.opt_ph[i]] = c 153 | 154 | self.sess.run(self.opt_init, feed) 155 | d_out = self.sess.run(self.d_out, feed) 156 | if d_out.shape[-1] == 1: 157 | d_out = d_out[...,0] # [B,H,W] 158 | # plt.imshow(d_out[0]) 159 | # plt.show() 160 | for i in range(self.batch_size): 161 | im = Image.fromarray((d_out[i]*255).astype(np.uint8)) 162 | d_path = os.path.join(self.log_dir, '%03d.png' % (t+i)) 163 | im.save(d_path) 164 | 165 | def run(self, params): 166 | # loss 167 | self._loss(params) 168 | 169 | # optimizer 170 | self.opt_lr = tf.compat.v1.placeholder(tf.float32) 171 | 172 | # settings for octave process 173 | oct_size = [] 174 | hw = np.array(self.resolution) 175 | for _ in range(self.octave_n): 176 | oct_size.append(hw) 177 | hw = (hw//self.octave_scale).astype(np.int) 178 | oct_size.reverse() 179 | print('input size for each octave', oct_size) 180 | 181 | p = params['p'] 182 | r = params['r'] 183 | 184 | g_opt = [] 185 | n = p[0].shape[0] # n is fixed 186 
| # # same noise 187 | # c_opt_shp = [n, 3] 188 | # different noise 189 | c_opt_shp = [self.num_frames, n, 3] 190 | c_opt = self.rng.uniform(-5,5, c_opt_shp).astype(np.float32) 191 | c_opt += np.array([vgg._R_MEAN, vgg._G_MEAN, vgg._B_MEAN]) 192 | c_opt /= 255 # [0,1] 193 | for i in range(self.num_frames): 194 | # # same noise 195 | # c_opt.append(c_opt) 196 | # different noise 197 | g_opt.append(c_opt[i]) 198 | 199 | # optimize 200 | loss_history = [] 201 | d_intm = [] 202 | opt_ = {} 203 | for octave in trange(self.octave_n, desc='octave'): 204 | loss_history_o = [] 205 | d_intm_o = [] 206 | 207 | feed = {} 208 | feed[self.res] = oct_size[octave] 209 | if self.content_img is not None: 210 | feed[self.content_feature] = self._content_feature( 211 | self.content_img, oct_size[octave]) 212 | 213 | if self.style_img is not None: 214 | style_features = self._style_feature( 215 | self.style_img, oct_size[octave]) 216 | 217 | for i in range(len(self.style_features)): 218 | feed[self.style_features[i]] = style_features[i] 219 | 220 | if self.w_hist > 0: 221 | hist_features = self._hist_feature( 222 | self.style_img, oct_size[octave]) 223 | 224 | for i in range(len(self.hist_features)): 225 | feed[self.hist_features[i]] = hist_features[i] 226 | 227 | if type(self.lr) == list: 228 | lr = self.lr[octave] 229 | else: 230 | lr = self.lr 231 | 232 | # optimizer list for each batch 233 | for step in trange(self.iter,desc='iter'): 234 | g_tmp = [None]*self.num_frames 235 | 236 | for t in range(0,self.num_frames,self.batch_size): 237 | for i in range(self.batch_size): 238 | feed[self.p[i]] = p[t+i] 239 | feed[self.r[i]] = r[t+i] 240 | feed[self.opt_ph[i]] = g_opt[t+i] 241 | 242 | # assign g_opt to self.opt through self.opt_ph 243 | self.sess.run(self.opt_init, feed) 244 | 245 | feed[self.opt_lr] = lr 246 | opt_id = t//self.frames_per_opt 247 | # opt_id = self.rng.randint(num_opt) 248 | if opt_id in opt_: 249 | train_op = opt_[opt_id] 250 | else: 251 | opt = 
tf.compat.v1.train.AdamOptimizer(learning_rate=self.opt_lr) 252 | train_op = opt.minimize(self.total_loss, var_list=self.opt) 253 | self.sess.run(tf.compat.v1.variables_initializer(opt.variables()), feed) 254 | opt_[opt_id] = train_op 255 | 256 | # optimize 257 | _, l_ = self.sess.run([train_op, self.total_loss], feed) 258 | loss_history_o.append(l_) 259 | 260 | g_opt_ = self.sess.run(self.opt, feed) 261 | for i in range(self.batch_size): 262 | g_tmp[t+i] = np.nan_to_num(g_opt_[i]) - g_opt[t+i] 263 | 264 | if step == self.iter-1 and octave < self.octave_n-1: # True or 265 | d_intm_ = self.sess.run(self.d_out, feed) 266 | d_intm_o.append((d_intm_*255).astype(np.uint8)) 267 | 268 | # ## debug 269 | # d_gray = self.sess.run(self.d_gray, feed) 270 | # plt.subplot(121) 271 | # plt.imshow(d_intm_[0,...]) 272 | # plt.subplot(122) 273 | # plt.imshow(d_gray[0,...,0]) 274 | # plt.show() 275 | 276 | ######### 277 | # gradient alignment 278 | if self.window_sigma > 0 and self.num_frames > 1: 279 | g_tmp = denoise(g_tmp, sigma=(self.window_sigma,0,0)) 280 | 281 | for t in range(self.num_frames): 282 | g_opt[t] += g_tmp[t] 283 | 284 | loss_history.append(loss_history_o) 285 | if octave < self.octave_n-1: 286 | d_intm.append(np.concatenate(d_intm_o, axis=0)) 287 | 288 | # gather outputs 289 | result = { 290 | 'l': loss_history, 'd_intm': d_intm, 291 | } 292 | 293 | # final inference 294 | c_sty = [None]*self.num_frames 295 | d_sty = [None]*self.num_frames 296 | for t in range(0,self.num_frames,self.batch_size): 297 | for i in range(self.batch_size): 298 | feed[self.p[i]] = p[t+i] 299 | feed[self.r[i]] = r[t+i] 300 | feed[self.opt_ph[i]] = g_opt[t+i] 301 | 302 | self.sess.run(self.opt_init, feed) 303 | p_, d_ = self.sess.run([self.p_out, self.d_out], feed) 304 | c_ = self.sess.run(self.c, feed) 305 | 306 | for i in range(self.batch_size): 307 | c_sty[t+i] = c_[i] 308 | 309 | d_ = (d_*255).astype(np.uint8) 310 | d_sty[t:t+self.batch_size] = d_ 311 | 312 | result['c'] = c_sty 313 | 
result['d'] = np.array(d_sty) 314 | 315 | return result -------------------------------------------------------------------------------- /test_smokegun_resim.py: -------------------------------------------------------------------------------- 1 | ############################################################# 2 | # MIT License, Copyright © 2020, ETH Zurich, Byungsoo Kim 3 | ############################################################# 4 | import numpy as np 5 | import tensorflow as tf 6 | import matplotlib.pyplot as plt 7 | import os 8 | from tqdm import trange 9 | import struct 10 | from config import get_config 11 | from util import * 12 | from transform import g2p, p2g, p2g_wavg 13 | import sys 14 | sys.path.append('E:/partio/build/py/Release') 15 | import partio 16 | 17 | class SimG2P(object): 18 | def __init__(self, self_dict): 19 | # get arguments 20 | for arg in vars(self_dict): 21 | setattr(self, arg, getattr(self_dict,arg)) 22 | 23 | self.sess = tf.compat.v1.InteractiveSession() 24 | 25 | # particle position at t 26 | x_shp = [None,3] 27 | self.x = tf.compat.v1.placeholder(dtype=tf.float32, shape=x_shp, name='x') 28 | x = tf.expand_dims(self.x, axis=0) # [1,None,3] 29 | 30 | # velocity field to sample 31 | u_shp = [None,None,None,3] 32 | self.u = tf.compat.v1.placeholder(dtype=tf.float32, shape=u_shp, name='u') 33 | u = tf.expand_dims(self.u, axis=0) # [1,None,None,None,3] 34 | 35 | # grid to particle velocity 36 | v = g2p(u, x, is_2d=False) 37 | 38 | #### 39 | # RK4 velocity sampling 40 | x1 = x + v*0.5 41 | v1 = g2p(u, x1, is_2d=False) 42 | 43 | x2 = x + v1*0.5 44 | v2 = g2p(u, x2, is_2d=False) 45 | 46 | x3 = x + v2 47 | v3 = g2p(u, x3, is_2d=False) 48 | v = (v + v1*2 + v2*2 + v3)/6 49 | #### 50 | 51 | # advect to t+1 52 | time_step = 0.5 53 | x_adv = x + v*time_step 54 | self.x_adv = x_adv[0] 55 | 56 | ############ 57 | # particle position displacement for optimization 58 | self.v = tf.compat.v1.placeholder(dtype=tf.float32, shape=x_shp, name='v') 59 | 
self.optv = tf.Variable(self.v, validate_shape=False, name='v_opt') 60 | v_opt = tf.reshape(self.optv, tf.shape(self.v)) 61 | self.pv = v_opt 62 | pv = tf.expand_dims(v_opt, axis=0) 63 | self.x_hat = x + pv 64 | 65 | # density splatting 66 | self.res = tf.compat.v1.placeholder(tf.int32, [3], name='resolution') 67 | d_rec = p2g(self.x_hat, self.domain, self.res, self.radius, self.rest_density, self.nsize, kernel='cubic', support=4, clip=False, is_2d=False) 68 | pressure = d_rec - self.rest_density 69 | pressure = tf.where(d_rec>0, pressure, tf.zeros_like(pressure)) 70 | 71 | # L2 Loss: pressure 72 | self.pres_loss = tf.reduce_mean(pressure**2) 73 | self.loss = self.pres_loss # + 0.1*tf.reduce_mean(tf.compat.v1.image.total_variation(pressure[0,...,0])) # weak TV loss 74 | 75 | self.opt_init = tf.compat.v1.initializers.variables([self.optv]) 76 | self.opt = tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr) 77 | self.train_op = self.opt.minimize(self.loss, var_list=[self.optv]) 78 | 79 | ############ 80 | # multi-scale density sampling 81 | 82 | # ground truth density field at t+1 83 | d_shp = [None,None,None] 84 | self.d = tf.compat.v1.placeholder(dtype=tf.float32, shape=d_shp, name='d') 85 | d = tf.expand_dims(tf.expand_dims(self.d, axis=0), axis=-1) 86 | # d = resize_tf(d, self.res, method=tf.image.ResizeMethod.BILINEAR, is_3d=True) 87 | 88 | # particle density sampling at t+1 89 | r = [] 90 | for o in range(self.octave_n): 91 | if o > 0: 92 | d_hi = d_hat 93 | d_ = d - d_hi[:,:,::-1] 94 | else: 95 | d_ = d 96 | r_ = g2p(d_, self.x_hat, is_2d=False) 97 | r.append(r_) 98 | 99 | factor = self.octave_scale**o 100 | d_hat = p2g_wavg(self.x_hat, r_, self.domain, self.res, self.radius, self.nsize, kernel='cubic', is_2d=False, clip=False, support=self.support/factor) 101 | if o > 0: 102 | d_hat += d_hi 103 | 104 | self.r_smp = tf.concat(r, axis=-1)[0] 105 | self.d_smp = tf.clip_by_value(d_hat[0,...,0], 0, 1) 106 | self.d_diff = (d[:,:,::-1] - d_hat)[0,:,::-1,:,0] 107 
| 108 | # simple advection test 109 | r_shp = [None,1] 110 | self.r = tf.compat.v1.placeholder(dtype=tf.float32, shape=r_shp, name='r') 111 | r_ = tf.expand_dims(self.r, axis=0) # [1,None,3] 112 | d_rec = p2g_wavg(x_adv, r_, self.domain, self.res, self.radius, self.nsize, kernel='cubic', is_2d=False, clip=False, support=4) 113 | self.d_rec = d_rec[0,...,0] 114 | 115 | def sample(self, d, disc=1, threshold=0, p0=None, p_id=None): 116 | ''' 117 | sample particles where d's value is higher than threshold 118 | ''' 119 | # pid = np.where(d > threshold) 120 | # add pt only in src region 121 | pid = np.where(d[76:124,231:279,16:64] > threshold) 122 | pid = np.array(pid).transpose([1,0]).astype(np.float) 123 | pid += np.array([76,231,16]) 124 | 125 | cell_size = 1/disc 126 | offset = cell_size/2 127 | p = [] 128 | for i in range(disc): 129 | for j in range(disc): 130 | for k in range(disc): 131 | p_ = pid + offset + np.array([cell_size*i, cell_size*j, cell_size*k]) 132 | p.append(p_) 133 | p = np.concatenate(p, axis=0) 134 | 135 | # normalize to [0,1] 136 | pz, py, px = p[:,0], p[:,1], p[:,2] 137 | pz /= d.shape[0] 138 | py /= d.shape[1] 139 | px /= d.shape[2] 140 | p = np.stack([pz,py,px], axis=-1) 141 | 142 | # if there are new particles, add to prev 143 | if len(p) > 0: 144 | if p_id is None: 145 | p_id = np.arange(p.shape[0]) 146 | else: 147 | p_id0 = p_id[-1]+1 148 | p_id_new = np.arange(p_id0, p_id0+p.shape[0]) 149 | p_id = np.concatenate([p_id, p_id_new]) 150 | 151 | if p0 is not None: 152 | p = np.concatenate([p0, p], axis=0) 153 | 154 | return p, p_id 155 | 156 | def naive_adv(self, p, u, r): 157 | ''' 158 | reconstruct density field from p_t' with r 159 | ''' 160 | # advect particle to t+1 first 161 | feed = {self.res: self.resolution} 162 | feed[self.x] = p 163 | feed[self.u] = u 164 | feed[self.r] = r 165 | p_adv, d_rec = self.sess.run([self.x_adv, self.d_rec], feed) 166 | return p_adv, d_rec 167 | 168 | def optimize(self, p, p_id, d, u): 169 | ''' 170 | 1. 
advect p_t using u_t then optimize for redistribution 171 | 2. sample new particles where particles don't cover (src region) 172 | 3. sample particle density from d_(t+1) 173 | ''' 174 | # advect particle to t+1 first 175 | feed = {self.res: self.resolution} 176 | feed[self.x] = p # p_t 177 | feed[self.u] = u 178 | p = self.sess.run(self.x_adv, feed) 179 | 180 | # optimize for particle redistribution 181 | feed[self.x] = p # p_t' 182 | feed[self.v] = np.zeros_like(p) 183 | 184 | # init variables 185 | self.sess.run(self.opt_init, feed) 186 | self.sess.run(tf.compat.v1.variables_initializer(self.opt.variables()), feed) 187 | 188 | # optimize particle positions 189 | l = [] 190 | for _ in range(self.iter): 191 | # self.sess.run(self.train_op, feed) 192 | l_, _ = self.sess.run([self.loss, self.train_op], feed) 193 | l.append(l_) 194 | 195 | # seed particles 196 | feed[self.d] = d # d_t 197 | d_diff = self.sess.run(self.d_diff, feed) 198 | p_new = self.sess.run(self.x_hat, feed)[0] 199 | p, p_id = self.sample(d_diff, disc=self.disc, threshold=self.threshold, p0=p_new, p_id=p_id) 200 | 201 | # sample density at new position 202 | feed[self.x_hat] = p[None,:] # p_t' 203 | p_den = self.sess.run(self.r_smp, feed) 204 | 205 | result = { 206 | 'p': p, 207 | 'p_id': p_id, 208 | 'p_den': p_den, 209 | 'l': l, 210 | } 211 | 212 | # for debug 213 | result['d_diff'] = np.mean(d_diff, axis=0) 214 | d_smp = self.sess.run(self.d_smp, feed) 215 | result['d_smp'] = d_smp 216 | 217 | return result 218 | 219 | def run(config): 220 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # so the IDs match nvidia-smi 221 | os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu_id # "0, 1" for multiple 222 | 223 | prepare_dirs_and_logger(config) 224 | tf.compat.v1.set_random_seed(config.seed) 225 | config.rng = np.random.RandomState(config.seed) 226 | 227 | resampler = SimG2P(config) 228 | 229 | # load input density fields 230 | for t in trange(config.num_frames, desc='load density'): # last one for mask 
231 | d_path = os.path.join(config.data_dir, config.dataset, config.d_path % (config.target_frame+t)) 232 | with np.load(d_path) as data: 233 | d = data['x'][:,::-1] # [D,H,W], [0-1] 234 | 235 | # mantaflow dataset 236 | v_path = os.path.join(config.data_dir, config.dataset, config.v_path % (config.target_frame+t)) 237 | with np.load(v_path) as data: 238 | v_ = data['x'] # [D,H,W,3] 239 | vx = np.dstack((v_,np.zeros((v_.shape[0],v_.shape[1],1,v_.shape[3])))) 240 | vx = (vx[:,:,1:,0] + vx[:,:,:-1,0]) * 0.5 241 | vy = np.hstack((v_,np.zeros((v_.shape[0],1,v_.shape[2],v_.shape[3])))) 242 | vy = (vy[:,1:,:,1] + vy[:,:-1,:,1]) * 0.5 243 | vz = np.vstack((v_,np.zeros((1,v_.shape[1],v_.shape[2],v_.shape[3])))) 244 | vz = (vz[1:,:,:,2] + vz[:-1,:,:,2]) * 0.5 245 | v_ = np.stack([vx,vy,vz], axis=-1) 246 | v_ = v_[:,::-1] 247 | 248 | vx = v_[...,0] / v_.shape[2] * config.scale 249 | vy = -v_[...,1] / v_.shape[1] * config.scale 250 | vz = v_[...,2] / v_.shape[0] * config.scale 251 | u = np.stack([vz,vy,vx], axis=-1) 252 | 253 | if config.resampling: 254 | if t == 0: 255 | n_prev = 0 256 | 257 | # sampling at the beginning wo opt. 258 | p, p_id = resampler.sample(d, disc=config.disc, threshold=0) 259 | 260 | result = resampler.optimize(p, p_id, d, u) 261 | 262 | p = result['p'] 263 | p_id = result['p_id'] 264 | p_den = result['p_den'] 265 | # d_diff = result['d_diff'] 266 | # plt.imshow(d_diff); plt.show() 267 | l = result['l'][-1] # last loss 268 | d_smp = result['d_smp'] 269 | else: 270 | if t == 0: 271 | n_prev = 0 272 | 273 | # sampling at the beginning wo opt. 
274 | p, p_id = resampler.sample(d, disc=config.disc, threshold=0) 275 | p_src = p 276 | else: 277 | # simply source particles of t=0 278 | p = np.concatenate([p,p_src], axis=0) 279 | p_id = np.arange(p.shape[0]) 280 | 281 | p_den = np.ones([p.shape[0],1]) 282 | p, d_smp = resampler.naive_adv(p, u, p_den) 283 | l = 0 284 | 285 | print(t, 'num particles', p.shape[0], '(+%d)' % (p.shape[0]-n_prev), 'loss', l) 286 | n_prev = p.shape[0] 287 | 288 | # convert to the original domain coordinate 289 | px, py, pz = p[...,2], 1-p[...,1], p[...,0] 290 | p_ = np.stack([ 291 | px*config.domain[2], 292 | py*config.domain[1], 293 | pz*config.domain[0]], axis=-1) 294 | 295 | # create a particle set and attributes 296 | pt = partio.create() 297 | pid = pt.addAttribute('id',partio.INT,1) 298 | position = pt.addAttribute("position",partio.VECTOR,3) 299 | if p_den.shape[1] > 1: 300 | density = pt.addAttribute('density',partio.VECTOR,p_den.shape[1]) 301 | else: 302 | density = pt.addAttribute('density',partio.FLOAT,1) 303 | color = pt.addAttribute("Cd",partio.FLOAT,3) 304 | radius = pt.addAttribute("radius",partio.FLOAT,1) 305 | 306 | for i in range(p_.shape[0]): 307 | pt_ = pt.addParticle() 308 | pt.set(pid, pt_, (int(p_id[i]),)) 309 | pt.set(position, pt_, tuple(p_[i].astype(np.float))) 310 | if p_den.shape[1] > 1: 311 | pt.set(density, pt_, tuple(p_den[i].astype(np.float))) 312 | else: 313 | pt.set(density, pt_, (float(p_den[i]),)) 314 | pt.set(color, pt_, tuple(np.array([p_den[i,0]]*3,dtype=np.float))) 315 | pt.set(radius, pt_, (config.radius,)) 316 | 317 | # save particle 318 | p_path = os.path.join(config.log_dir, '%03d.bgeo' % (config.target_frame+t)) 319 | partio.write(p_path, pt) 320 | 321 | # save density image 322 | transmit = np.exp(-np.cumsum(d_smp[::-1], axis=0)*config.transmit) 323 | d_img = np.sum(d_smp*transmit, axis=0) 324 | d_img /= d_img.max() 325 | im = Image.fromarray((d_img[::-1]*255).astype(np.uint8)) 326 | im_path = os.path.join(config.log_dir, '%03d.png' % 
(config.target_frame+t)) 327 | im.save(im_path) 328 | 329 | stat_path = os.path.join(config.log_dir, 'stat.txt') 330 | with open(stat_path, 'w') as f: 331 | f.write('num particles %d\n' % p.shape[0]) 332 | f.write('loss %.2f' % l) 333 | 334 | # # visualize last frame 335 | # bbox = [ 336 | # [0,0,0], 337 | # [config.domain[2],config.domain[1],config.domain[0]], 338 | # ] 339 | # if config.octave_n == 1: 340 | # pc = np.concatenate([p_den]*3, axis=-1) 341 | # else: 342 | # pc = np.concatenate([p_den[:,0,None]]*3, axis=-1) 343 | # draw_pt([p_], pc=[pc], bbox=bbox, is_2d=False) 344 | 345 | def main(config): 346 | config.dataset = 'smokegun' 347 | config.d_path = 'd_low/%03d.npz' 348 | config.v_path = 'v_low/%03d.npz' 349 | 350 | config.num_frames = 120 351 | config.target_frame = 0 # 120 - config.num_frames 352 | 353 | # config.target_frame = 60 354 | # config.num_frames = 3 355 | 356 | config.scale = 1 357 | config.domain = [_*config.scale for _ in [200,300,200]] 358 | config.resolution = [int(_) for _ in config.domain] 359 | 360 | config.disc = 1 361 | cell_size = 1 # == 2*radius*disc 362 | config.radius = cell_size/config.disc/2 363 | config.nsize = 1 364 | config.support = 4 365 | config.rest_density = 1000 366 | config.threshold = 0.01 367 | config.lr = 0.0005 368 | config.iter = 20 369 | config.transmit = 0.01 370 | config.octave_n = 2 371 | if config.octave_n > 1: 372 | config.octave_scale = 2 373 | else: 374 | config.octave_scale = 1 375 | 376 | # resampling or naive advection 377 | config.resampling = True 378 | if config.resampling: 379 | config.tag = 'n%d_it%d_o%d' % (config.num_frames, config.iter, config.octave_n) 380 | else: 381 | config.tag = 'naive_n%d' % config.num_frames 382 | 383 | run(config) 384 | 385 | if __name__ == "__main__": 386 | config, unparsed = get_config() 387 | main(config) 388 | -------------------------------------------------------------------------------- /styler_base.py: 
-------------------------------------------------------------------------------- 1 | ############################################################# 2 | # MIT License, Copyright © 2020, ETH Zurich, Byungsoo Kim 3 | ############################################################# 4 | import tensorflow as tf 5 | import numpy as np 6 | import os 7 | from util import * 8 | from transform import advect 9 | import vgg 10 | 11 | class StylerBase(object): 12 | def __init__(self, self_dict): 13 | # get arguments 14 | for arg in vars(self_dict): 15 | setattr(self, arg, getattr(self_dict,arg)) 16 | 17 | # inception network setting 18 | self.model_path = os.path.join(self.data_dir, self.model_dir, self.network) 19 | if 'inception' in self.model_path: 20 | self.graph = tf.compat.v1.Graph() 21 | self.sess = tf.compat.v1.InteractiveSession(graph=self.graph) 22 | with tf.io.gfile.GFile(self.model_path, 'rb') as f: 23 | self.graph_def = tf.compat.v1.GraphDef() 24 | self.graph_def.ParseFromString(f.read()) 25 | 26 | # fix checkerboard artifacts: ksize should be divisible by the stride size 27 | # but it changes scale 28 | if self.pool1: 29 | for n in self.graph_def.node: 30 | if 'conv2d0_pre_relu/conv' in n.name: 31 | n.attr['strides'].list.i[1:3] = [1,1] 32 | 33 | def _plugin_to_loss_net(self, d): 34 | # resize rendering if needed 35 | if not np.isclose(self.resize_scale, 1): 36 | h = tf.cast(tf.multiply(float(self.resize_scale), tf.cast(tf.shape(d)[1], tf.float32)), tf.int32) 37 | w = tf.cast(tf.multiply(float(self.resize_scale), tf.cast(tf.shape(d)[2], tf.float32)), tf.int32) 38 | d = tf.compat.v1.image.resize(d, (h,w), method=tf.image.ResizeMethod.BILINEAR) # upsample w/ BICUBIC -> artifacts 39 | 40 | # change the range of d image [0-1] to [0-255] 41 | d = d*255 42 | if not 'c' in self.target_field: 43 | d = tf.concat([d]*3, axis=-1) # [B,H,W,3] 44 | d = tf.reshape(d, [tf.shape(d)[0],tf.shape(d)[1],tf.shape(d)[2],3]) 45 | self.d_img = d 46 | 47 | # plug-in to the pre-trained network 
48 | if 'vgg' in self.model_path: 49 | self.sess = tf.compat.v1.InteractiveSession() 50 | self.layers = vgg.load_vgg(d, self.model_path, self.sess) 51 | print(self.layers.keys()) 52 | else: 53 | # imagenet_mean = 117.0 54 | # d_preprocessed = d - vggimagenet_mean 55 | tf.import_graph_def(self.graph_def, {'input': vgg.preprocess(d)}) 56 | self.layers = [op.name for op in self.graph.get_operations() if op.type=='Conv2D' and 'import/' in op.name] 57 | print(self.layers) 58 | 59 | def _transport(self, g, v, a, b, recursive=True): 60 | # g: [H,W,1 or 2], v: [N,H,W,2] 61 | if a < b: 62 | if recursive: 63 | for i in range(a,b): 64 | g = self.sess.run(self.adv, {self.g: g[None,:], self.u: v[i,None]})[0] 65 | else: 66 | # forward once 67 | g = self.sess.run(self.adv, {self.g: g[None,:], self.u: v[a,None]*(b-a)})[0] 68 | elif a > b: 69 | if recursive: 70 | for i in reversed(range(b,a)): 71 | g = self.sess.run(self.adv, {self.g: g[None,:], self.u: -v[i,None]})[0] 72 | else: 73 | g = self.sess.run(self.adv, {self.g: g[None,:], self.u: -v[a-1,None]*(a-b)})[0] 74 | return g 75 | 76 | def _transport_tf(self, v, a, b, recursive=True): 77 | if a < b: 78 | if recursive: 79 | for i in range(a,b): 80 | v = advect(v, tf.expand_dims(self.u[i], axis=0)) 81 | else: 82 | v = advect(v, tf.expand_dims(self.u[a]*(b-a), axis=0)) 83 | elif a > b: 84 | if recursive: 85 | for i in reversed(range(b,a)): 86 | v = advect(v, tf.expand_dims(-self.u[i], axis=0)) 87 | else: 88 | v = advect(v, tf.expand_dims(-self.u[a-1]*(a-b), axis=0)) 89 | return v 90 | 91 | def _layer(self, layer): 92 | if 'input' in layer: return self.d_img 93 | if 'vgg' in self.model_path: return self.layers[layer] 94 | else: return self.graph.get_tensor_by_name("import/%s:0" % layer) 95 | 96 | def _gram_matrix(self, x): 97 | g_ = [] 98 | for i in range(self.batch_size): 99 | F = tf.reshape(x[i], (-1, x.shape[-1])) 100 | g = tf.matmul(tf.transpose(F), F) 101 | g_.append(g) 102 | return tf.stack(g_, axis=0) 103 | 104 | def 
_hist_match(self, s, t, mask=None): 105 | m_ = [] 106 | sm_ = [] 107 | for i in range(self.batch_size): 108 | m_c = [] 109 | sm_c = [] 110 | for j in range(s.shape[-1]): 111 | s_ = s[i,...,j] 112 | if mask is not None: 113 | nz = tf.not_equal(mask[i,...,0], 0) 114 | s_ = tf.boolean_mask(s_, nz) 115 | sm_c.append(s_) 116 | result = histogram_match_tf(s_, t[i,...,j]) 117 | m_c.append(result['matched']) 118 | m_.append(tf.stack(m_c, axis=-1)) 119 | if mask is not None: 120 | sm_.append(tf.stack(sm_c, axis=-1)) 121 | if mask is not None: 122 | return m_, sm_ 123 | # return tf.stack(m_, axis=0), tf.stack(sm_, axis=0) 124 | else: 125 | return tf.stack(m_, axis=0) 126 | 127 | def _loss(self, params): 128 | self.content_loss = 0 129 | self.style_loss = 0 130 | self.style_loss_layer = [] 131 | self.hist_loss = 0 132 | self.hist_loss_layer = [] 133 | self.total_loss = 0 134 | 135 | if self.w_content: 136 | feature = self._layer(self.content_layer) # assert only one layer 137 | if self.content_img is not None: 138 | self.content_feature = tf.compat.v1.placeholder(tf.float32, name='content_feature_%s' % self.content_layer) 139 | # self.content_loss -= tf.reduce_mean(feature*self.content_feature) # dot 140 | self.content_loss += tf.reduce_mean(tf.math.squared_difference(feature, 141 | self.content_feature*self.w_content_amp)) 142 | else: 143 | if self.content_channel: 144 | self.content_loss -= tf.reduce_mean(feature[...,self.content_channel]) 145 | self.content_loss += tf.reduce_mean(tf.abs(feature[...,:self.content_channel])) 146 | self.content_loss += tf.reduce_mean(tf.abs(feature[...,self.content_channel+1:])) 147 | else: 148 | self.content_loss -= tf.reduce_mean(feature) 149 | 150 | self.total_loss += self.content_loss*self.w_content 151 | 152 | if self.w_style and self.style_img is not None: 153 | self.style_features = [] 154 | for style_layer, w_style_layer in zip(self.style_layer, self.w_style_layer): 155 | feature = self._layer(style_layer) 156 | f_shp = 
tf.shape(feature) 157 | gram_denom = tf.cast(2*f_shp[1]*f_shp[2]*f_shp[3], tf.float32) 158 | 159 | style_feature = tf.compat.v1.placeholder(tf.float32, shape=feature.shape, name='style_feature_%s' % style_layer) 160 | # style_denom = tf.cast(2*f_shp[1]*f_shp[2]*f_shp[3], tf.float32) 161 | f_shp_ = tf.shape(style_feature) 162 | style_denom = tf.cast(2*f_shp_[1]*f_shp_[2]*f_shp_[3], tf.float32) 163 | self.style_features.append(style_feature) 164 | 165 | if self.style_mask: 166 | style_mask = tf.compat.v1.image.resize(self.d_gray, (f_shp[1],f_shp[2]), method=tf.image.ResizeMethod.BICUBIC) 167 | feature *= style_mask 168 | area_mask = tf.reduce_sum(style_mask[...,0], axis=[1,2], keepdims=True) 169 | gram_denom = 2*area_mask*tf.cast(f_shp[3], tf.float32) 170 | 171 | if self.style_mask_on_ref: 172 | style_feature *= style_mask 173 | style_denom = 2*area_mask*tf.cast(f_shp[3], tf.float32) 174 | 175 | gram = self._gram_matrix(feature) 176 | gram /= gram_denom 177 | 178 | style_gram = self._gram_matrix(style_feature) 179 | style_gram /= style_denom 180 | 181 | style_loss = tf.reduce_sum(tf.math.squared_difference(gram, style_gram)) 182 | self.style_loss_layer.append(style_loss) 183 | self.style_loss += w_style_layer*style_loss 184 | 185 | self.total_loss += self.style_loss*self.w_style 186 | 187 | if self.w_hist and self.style_img is not None: 188 | self.hist_features = [] 189 | for hist_layer, w_hist_layer in zip(self.hist_layer, self.w_hist_layer): 190 | feature = self._layer(hist_layer) 191 | f_shp = tf.shape(feature) 192 | 193 | hist_feature = tf.compat.v1.placeholder(tf.float32, shape=feature.shape, name='hist_feature_%s' % hist_layer) 194 | self.hist_features.append(hist_feature) 195 | 196 | if self.style_mask: 197 | hist_mask = tf.compat.v1.image.resize(self.d_gray, (f_shp[1],f_shp[2]), method=tf.image.ResizeMethod.BICUBIC) 198 | matched_feature, feature_m = self._hist_match(feature, hist_feature, hist_mask) 199 | hist_loss = 0 200 | for m1, m2 in 
zip(matched_feature, feature_m): 201 | hist_loss += tf.reduce_sum(tf.math.squared_difference(m1, m2)) 202 | else: 203 | matched_feature = self._hist_match(feature, style_feature) 204 | hist_loss = tf.reduce_sum(tf.math.squared_difference(feature, matched_feature)) 205 | 206 | self.hist_loss_layer.append(hist_loss) 207 | self.hist_loss += w_style_layer*hist_loss 208 | 209 | self.total_loss += self.hist_loss*self.w_hist 210 | 211 | if self.w_tv: 212 | self.tv_loss = tf.reduce_mean(tf.compat.v1.image.total_variation(self.d_img)) 213 | self.total_loss += self.tv_loss*self.w_tv 214 | 215 | ####### 216 | # loss for density preservation 217 | if self.w_density > 0: 218 | self.d_loss = 0 219 | self.d_pres = 0 220 | for i in range(self.batch_size): 221 | self.d_loss += tf.reduce_sum(self.d[i])**2 222 | self.d_pres += tf.reduce_sum(-tf.log(tf.abs(self.d[i]) + 1e-6)) 223 | self.total_loss += (self.d_loss+self.d_pres*1e3)*self.w_density 224 | ####### 225 | 226 | ###### 227 | # loss for density correction 228 | if self.w_pressure > 0: 229 | self.pressure_loss = tf.reduce_mean(self.pressure**2) 230 | self.total_loss += self.pressure_loss*self.w_pressure 231 | ###### 232 | 233 | def _content_feature(self, content_target, content_shp): 234 | if not np.isclose(self.resize_scale, 1): 235 | content_shp = [int(s*self.resize_scale) for s in content_shp] 236 | content_target_ = resize(content_target, content_shp, order=3) # bicubic for downsampling 237 | feature = self._layer(self.content_layer) 238 | feature_ = self.sess.run(feature, {self.d_img: [content_target_]*self.batch_size}) 239 | 240 | if self.top_k > 0: 241 | assert('softmax2_pre_activation' in self.content_layer) 242 | feature_k_ = self.sess.run(tf.nn.top_k(np.abs(feature_), k=self.top_k)) 243 | for i in range(len(feature_)): 244 | exclude_idx = np.setdiff1d(np.arange(feature_.shape[1]), feature_k_.indices[i]) 245 | feature_[i,exclude_idx] = 0 246 | 247 | return feature_ 248 | 249 | def _style_feature(self, style_target, 
style_shp=None): 250 | # mask for style texture 251 | style_m = None 252 | if style_target.shape[-1] == 4: 253 | style_m = style_target[...,-1]/255 254 | style_target = style_target[...,:-1] 255 | style_target *= np.stack([style_m]*3, axis=-1) 256 | 257 | if style_shp is not None: 258 | if not np.isclose(self.resize_scale, 1): 259 | style_shp = [int(s*self.resize_scale) for s in style_shp] 260 | style_target_ = resize(style_target, style_shp, order=3) # bicubic for downsampling 261 | else: 262 | style_target_ = style_target 263 | 264 | style_features = [] 265 | for style_layer, w_style_layer in zip(self.style_layer, self.w_style_layer): 266 | style_feature = self._layer(style_layer) 267 | feed = {self.d_img: [style_target_]*self.batch_size} 268 | style_feature_ = self.sess.run(style_feature, feed) 269 | 270 | if style_m is not None: 271 | feature_mask_ = resize(style_m, style_feature_.shape[1:-1], order=3) # bicubic for downsampling 272 | feature_mask_ = np.stack([feature_mask_]*style_feature_.shape[-1], axis=-1) 273 | feature_mask_= np.stack([feature_mask_]*style_feature_.shape[0], axis=0) 274 | style_feature *= feature_mask_ 275 | 276 | style_features.append(style_feature_) 277 | 278 | return style_features 279 | 280 | def _hist_feature(self, style_target, style_shp=None): 281 | # mask for style texture 282 | style_m = None 283 | if style_target.shape[-1] == 4: 284 | style_m = style_target[...,-1]/255 285 | style_target = style_target[...,:-1] 286 | style_target *= np.stack([style_m]*3, axis=-1) 287 | 288 | if style_shp is not None: 289 | if not np.isclose(self.resize_scale, 1): 290 | style_shp = [int(s*self.resize_scale) for s in style_shp] 291 | style_target_ = resize(style_target, style_shp, order=3) # bicubic for downsampling 292 | else: 293 | style_target_ = style_target 294 | 295 | hist_features = [] 296 | for hist_layer, w_hist_layer in zip(self.hist_layer, self.w_hist_layer): 297 | hist_feature = self._layer(hist_layer) 298 | feed = {self.d_img: 
[style_target_]*self.batch_size} 299 | hist_feature_ = self.sess.run(hist_feature, feed) 300 | 301 | if style_m is not None: 302 | feature_mask_ = resize(style_m, hist_feature_.shape[1:-1], order=3) # bicubic for downsampling 303 | feature_mask_ = np.stack([feature_mask_]*hist_feature_.shape[-1], axis=-1) 304 | feature_mask_= np.stack([feature_mask_]*hist_feature_.shape[0], axis=0) 305 | hist_feature *= feature_mask_ 306 | 307 | hist_features.append(hist_feature_) 308 | 309 | return hist_features 310 | 311 | def load_img(self, hw=None): 312 | self.content_img = None 313 | self.style_img = None 314 | 315 | if self.w_content > 0 and self.content_target: 316 | content_target = np.float32(Image.open(self.content_target)) 317 | # remove alpha channel if exists 318 | if content_target.shape[-1] == 4: 319 | content_target = content_target[...,:-1] 320 | 321 | # crop 322 | if hw is not None: 323 | ratio = hw[1] / hw[0] 324 | content_target = crop_ratio(content_target, ratio) 325 | 326 | self.content_img = content_target 327 | 328 | if self.w_style > 0 and self.style_target: 329 | style_target = np.float32(Image.open(self.style_target)) 330 | # print(style_target.shape) 331 | if self.style_tiling > 1: 332 | style_target = np.tile(style_target, (self.style_tiling, self.style_tiling, 1)) 333 | # print(style_target.shape) 334 | 335 | # crop 336 | if hw is not None: 337 | ratio = hw[1] / hw[0] 338 | style_target = crop_ratio(style_target, ratio) 339 | 340 | # if style_target.shape[-1] == 4: 341 | # style_m = style_target[...,-1]/255 342 | # style_target = style_target[...,:-1] 343 | # style_target *= np.stack([style_m]*3, axis=-1) 344 | # # plt.imshow(style_target/255); plt.show() 345 | 346 | self.style_img = style_target -------------------------------------------------------------------------------- /styler_3p.py: -------------------------------------------------------------------------------- 1 | ############################################################# 2 | # MIT 
License, Copyright © 2020, ETH Zurich, Byungsoo Kim 3 | ############################################################# 4 | import tensorflow as tf 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import os 8 | from tqdm import trange 9 | from util import * 10 | from transform import p2g, p2g_wavg, rotate, rot_mat 11 | import vgg 12 | from styler_base import StylerBase 13 | 14 | class Styler(StylerBase): 15 | def __init__(self, self_dict): 16 | StylerBase.__init__(self, self_dict) 17 | 18 | # particle position 19 | # shape: [N,3], scale: [0,1] 20 | p = [] 21 | p_shp = [None,3] 22 | self.p = [] # input 23 | self.v = [] # style 24 | 25 | # particle density, [N,nk] 26 | r_shp = [None,self.num_kernels] 27 | self.r = [] # input 28 | self.d = [] # style 29 | 30 | # output 31 | d = [] 32 | d_gray = [] 33 | 34 | pressure = [] 35 | 36 | self.opt_init = [] 37 | self.opt_ph = [] 38 | self.opt = [] 39 | 40 | self.res = tf.compat.v1.placeholder(tf.int32, [3], name='resolution') 41 | 42 | for i in range(self.batch_size): 43 | # particle position, [N,3] 44 | p_ = tf.compat.v1.placeholder(dtype=tf.float32, shape=p_shp, name='p%d' % i) 45 | self.p.append(p_) 46 | p_ = tf.expand_dims(p_, axis=0) # [1,N,3] 47 | 48 | # particle velocity, [N,3] 49 | if 'p' in self.target_field: 50 | p_opt_ph = tf.compat.v1.placeholder(dtype=tf.float32, shape=p_shp, name='p_opt_ph%d' % i) 51 | self.opt_ph.append(p_opt_ph) 52 | p_opt = tf.Variable(p_opt_ph, validate_shape=False, name='p_opt%d' % i) 53 | self.opt.append(p_opt) 54 | p_opt_ = tf.reshape(p_opt, tf.shape(p_opt_ph)) 55 | p_opt_ = tf.expand_dims(p_opt_, axis=0) 56 | v_ = p_opt_ 57 | self.v.append(v_[0]) 58 | p_ += v_ 59 | 60 | p.append(p_[0]) 61 | 62 | # particle density, [N,nk] 63 | if 'd' in self.target_field: 64 | r_ = tf.compat.v1.placeholder(dtype=tf.float32, shape=r_shp, name='r%d' % i) 65 | self.r.append(r_) 66 | r_ = tf.expand_dims(r_, axis=0) # [1,N,nk] 67 | 68 | r_opt_ph = tf.compat.v1.placeholder(dtype=tf.float32, 
shape=r_shp, name='r_opt_ph') 69 | self.opt_ph.append(r_opt_ph) 70 | r_opt = tf.Variable(r_opt_ph, validate_shape=False, name='r_opt') 71 | self.opt.append(r_opt) 72 | r_opt_ = tf.reshape(r_opt, tf.shape(r_opt_ph)) 73 | r_opt_ = tf.expand_dims(r_opt_, axis=0) # [1,N,nk] 74 | r_opt_ = tf.clip_by_value(r_opt_, -1, 1) #### necessary! 75 | self.d.append(r_opt_[0]) 76 | r_ += r_opt_ 77 | 78 | # weighted avg. density estimation 79 | for k in range(self.num_kernels): 80 | factor = self.kernel_scale**k 81 | support = self.support/factor 82 | r_k = tf.expand_dims(r_[...,k], axis=-1) 83 | d_hat = p2g_wavg(p_, r_k, self.domain, self.res, self.radius, self.nsize, kernel='cubic', support=support, clip=self.clip, is_2d=False) 84 | if k == 0: 85 | d_ = d_hat 86 | else: 87 | d_ += d_hat 88 | else: 89 | # position-based (SPH) density field estimation 90 | d_ = p2g(p_, self.domain, self.res, self.radius, self.rest_density, self.nsize, support=self.support, clip=self.clip, is_2d=False) # [B,N,3] -> [B,D,H,W,1] 91 | d_ /= self.rest_density # normalize density 92 | 93 | d.append(d_) 94 | 95 | # pressure estimation 96 | if self.w_pressure > 0 and 'p' in self.target_field: 97 | pressure_ = tf.where(d_>0, d_-1, tf.zeros_like(d_)) 98 | pressure.append(pressure_) 99 | 100 | self.opt_init = tf.compat.v1.initializers.variables(self.opt) 101 | 102 | # stylized (advected) particles 103 | self.p_out = p # [N,3]*B 104 | 105 | # estimated density fields 106 | d = tf.concat(d, axis=0) # [B,D,H,W,1] 107 | 108 | if self.w_pressure > 0 and 'p' in self.target_field: 109 | pressure = tf.concat(pressure, axis=0) # [B,D,H,W,1] 110 | self.pressure = pressure 111 | 112 | if self.k > 0: 113 | # smoothing density for density optimization 114 | k = [] 115 | k1 = np.float32([1,self.k,1]) 116 | k2 = np.outer(k1, k1) 117 | for i in k1: 118 | k.append(k2*i) 119 | k = np.array(k) 120 | k = k[:,:,:,None,None]/k.sum() 121 | d = tf.nn.conv3d(d, k, [1,1,1,1,1], 'SAME') 122 | 123 | # value clipping for rendering 124 | # 
d = tf.clip_by_value(d, 0, 1) 125 | d = tf.maximum(d, 0) 126 | 127 | # stylized result 128 | self.d_out = d # [B,D,H,W,1] 129 | 130 | #### 131 | # rotate 3d smoke for rendering 132 | if self.rotate: 133 | d, self.rot_mat = rotate(d) # [B,D,H,W,1] or [B,D,H,W,4] 134 | self.d_out_rot = d 135 | 136 | # compute rotation matrices 137 | self.rot_mat_, self.views = rot_mat(self.phi0, self.phi1, self.phi_unit, 138 | self.theta0, self.theta1, self.theta_unit, 139 | sample_type=self.sample_type, rng=self.rng, 140 | nv=self.n_views) 141 | 142 | if self.n_views is None: 143 | self.n_views = len(self.views) 144 | print('# vps:', self.n_views) 145 | assert(self.n_views % self.v_batch == 0) 146 | 147 | # render 3d volume 148 | if self.render_liquid: 149 | # d = tf.reduce_max(d, axis=1) # [B,H,W,1] 150 | transmit = tf.exp(-tf.cumsum(d[:,::-1], axis=1)*self.transmit) 151 | self.d_trans = transmit 152 | d = 1 - transmit[:,-1] # [B,H,W,1], [0,1] 153 | # d = (1 - transmit[:,-1])*np.array([0.26, 0.5, 0.75]) + transmit[:,-1]*np.array([1, 1, 1]) # [B,H,W,1], [0,1] 154 | else: 155 | transmit = tf.exp(-tf.cumsum(d[:,::-1], axis=1)*self.transmit)[:,::-1] 156 | d *= transmit 157 | d = tf.reduce_sum(d, axis=1) # [B,H,W,1] or [B,H,W,3] 158 | d /= tf.reduce_max(d) # [B,H,W,1], [0,1] 159 | 160 | # mask for style features 161 | self.d_gray = d # [B,H,W,1] 162 | #### 163 | 164 | self._plugin_to_loss_net(d) 165 | 166 | def render_test(self, params): 167 | feed = {} 168 | feed[self.res] = self.resolution 169 | 170 | if self.rotate: 171 | feed[self.rot_mat] = self.rot_mat_[:self.v_batch] 172 | 173 | for i in range(self.batch_size): 174 | feed[self.p[i]] = params['p'][i] 175 | n = params['p'][i].shape[0] 176 | if 'p' in self.target_field: 177 | feed[self.opt_ph[i]] = np.zeros([n,3]) 178 | if 'd' in self.target_field: 179 | feed[self.r[i]] = params['r'][i] 180 | feed[self.opt_ph[i]] = np.zeros([n,self.num_kernels]) 181 | 182 | self.sess.run(self.opt_init, feed) 183 | p_out, d_img, d_gray = 
self.sess.run([self.p_out, self.d_img, self.d_gray], feed) 184 | plt.subplot(121) 185 | plt.imshow(d_img[0].astype(np.uint8)) 186 | plt.subplot(122) 187 | plt.imshow(d_gray[0,...,0]) 188 | plt.show() 189 | 190 | for i, p in enumerate(p_out): 191 | p[:,0] = p[:,0]*self.domain[0] 192 | p[:,1] = p[:,1]*self.domain[1] 193 | p[:,2] = p[:,2]*self.domain[2] 194 | p_out[i] = np.stack([p[:,2],p[:,1],p[:,0]], axis=-1) 195 | v_ = None 196 | bbox = [ 197 | [0,0,0], 198 | [self.domain[2],self.domain[1],self.domain[0]], 199 | ] 200 | draw_pt(p_out, pv=v_, bbox=bbox, is_2d=False) 201 | 202 | feed = {} 203 | feed[self.res] = self.resolution 204 | if self.rotate: 205 | feed[self.rot_mat] = [np.identity(3)]*self.batch_size 206 | 207 | # save to image 208 | for t in trange(0,self.num_frames,self.batch_size): 209 | if t == 0: 210 | n = params['p'][0].shape[0] 211 | 212 | for i in range(self.batch_size): 213 | feed[self.p[i]] = params['p'][t+i] 214 | if 'p' in self.target_field: 215 | feed[self.opt_ph[i]] = np.zeros([n,3]) 216 | if 'd' in self.target_field: 217 | feed[self.r[i]] = params['r'][t+i] 218 | feed[self.opt_ph[i]] = np.zeros([n,self.num_kernels]) 219 | 220 | self.sess.run(self.opt_init, feed) 221 | d_out = self.sess.run(self.d_img, feed) 222 | # plt.imshow(d_out[0]) 223 | # plt.show() 224 | for i in range(self.batch_size): 225 | im = Image.fromarray(d_out[i].astype(np.uint8)) 226 | d_path = os.path.join(self.log_dir, '%03d.png' % (t+i)) 227 | im.save(d_path) 228 | 229 | def run(self, params): 230 | # loss 231 | self._loss(params) 232 | 233 | # optimizer 234 | self.opt_lr = tf.compat.v1.placeholder(tf.float32) 235 | 236 | # adaptive learning rate per octave 237 | if abs(self.lr_scale - 1) > 1e-7: 238 | self.lr = [self.lr/self.lr_scale**i for i in range(self.octave_n)] 239 | 240 | # settings for octave process 241 | oct_size = [] 242 | dhw = np.array(self.resolution) 243 | for _ in range(self.octave_n): 244 | oct_size.append(dhw) 245 | dhw = 
(dhw//self.octave_scale).astype(np.int) 246 | oct_size.reverse() 247 | print('input size for each octave', oct_size) 248 | 249 | p = params['p'] 250 | 251 | g_opt = [] 252 | if 'p' in self.target_field: 253 | for i in range(self.num_frames): 254 | n = p[i].shape[0] 255 | p_opt_shp = [n, 3] 256 | p_opt = np.zeros(shape=p_opt_shp, dtype=np.float32) 257 | g_opt.append(p_opt) 258 | 259 | if 'd' in self.target_field: 260 | r = params['r'] 261 | for i in range(self.num_frames): 262 | n = p[i].shape[0] 263 | r_opt_shp = [n, self.num_kernels] 264 | r_opt_ = np.zeros(shape=r_opt_shp, dtype=np.float32) 265 | g_opt.append(r_opt_) 266 | 267 | # optimize 268 | loss_history = [] 269 | d_intm = [] 270 | opt_ = {} 271 | for octave in trange(self.octave_n, desc='octave'): 272 | loss_history_o = [] 273 | d_intm_o = [] 274 | 275 | feed = {} 276 | feed[self.res] = oct_size[octave] 277 | if self.content_img is not None: 278 | feed[self.content_feature] = self._content_feature( 279 | self.content_img, oct_size[octave][1:]) 280 | 281 | if self.style_img is not None: 282 | style_features = self._style_feature( 283 | self.style_img, oct_size[octave][1:]) 284 | 285 | for i in range(len(self.style_features)): 286 | feed[self.style_features[i]] = style_features[i] 287 | 288 | if self.w_hist > 0: 289 | hist_features = self._hist_feature( 290 | self.style_img, oct_size[octave][1:]) 291 | 292 | for i in range(len(self.hist_features)): 293 | feed[self.hist_features[i]] = hist_features[i] 294 | 295 | if type(self.lr) == list: 296 | lr = self.lr[octave] 297 | else: 298 | lr = self.lr 299 | 300 | # optimizer list for each batch 301 | for step in trange(self.iter,desc='iter'): 302 | g_tmp = [None]*self.num_frames 303 | 304 | for t in range(0,self.num_frames,self.batch_size*self.interp): 305 | for i in range(self.batch_size): 306 | feed[self.p[i]] = p[t+i*self.interp] 307 | feed[self.opt_ph[i]] = g_opt[t+i*self.interp] 308 | if 'd' in self.target_field: 309 | feed[self.r[i]] = r[t+i*self.interp] 310 | 
311 | # assign g_opt to self.opt through self.opt_ph 312 | self.sess.run(self.opt_init, feed) 313 | 314 | feed[self.opt_lr] = lr 315 | opt_id = t//self.frames_per_opt 316 | # opt_id = self.rng.randint(num_opt) 317 | if opt_id in opt_: 318 | train_op = opt_[opt_id] 319 | else: 320 | opt = tf.compat.v1.train.AdamOptimizer(learning_rate=self.opt_lr) 321 | train_op = opt.minimize(self.total_loss, var_list=self.opt) 322 | self.sess.run(tf.compat.v1.variables_initializer(opt.variables()), feed) 323 | opt_[opt_id] = train_op 324 | 325 | # optimize 326 | if self.rotate: 327 | g_opt_ = None 328 | l_ = [] 329 | for i in range(0, self.n_views, self.v_batch): 330 | feed[self.rot_mat] = self.rot_mat_[i:i+self.v_batch] 331 | _, l_vp = self.sess.run([train_op, self.total_loss], feed) 332 | l_.append(l_vp) 333 | 334 | g_opt_i = self.sess.run(self.opt, feed) 335 | 336 | if i == 0: 337 | g_opt_ = np.nan_to_num(g_opt_i) 338 | else: 339 | for j in range(self.batch_size): 340 | g_opt_[j] += np.nan_to_num(g_opt_i[j]) 341 | 342 | loss_history_o.append(np.mean(l_)) 343 | 344 | if not 'uniform' in self.sample_type: 345 | self.rot_mat_, self.views = rot_mat( 346 | self.phi0, self.phi1, self.phi_unit, 347 | self.theta0, self.theta1, self.theta_unit, 348 | sample_type=self.sample_type, rng=self.rng, 349 | nv=self.n_views) 350 | 351 | for i in range(self.batch_size): 352 | g_opt_[i] /= (self.n_views/self.v_batch) 353 | else: 354 | _, l_ = self.sess.run([train_op, self.total_loss], feed) 355 | loss_history_o.append(l_) 356 | 357 | g_opt_ = self.sess.run(self.opt, feed) 358 | 359 | for i in range(self.batch_size): 360 | g_tmp[t+i*self.interp] = np.nan_to_num(g_opt_[i]) - g_opt[t+i*self.interp] 361 | if 'd' in self.target_field: 362 | # masking by original density 363 | g_tmp[t+i*self.interp] *= r[t+i*self.interp][...,0,None] 364 | 365 | if step == self.iter-1 and octave < self.octave_n-1: # True or 366 | if self.rotate: 367 | feed[self.rot_mat] = [np.identity(3)]*self.batch_size 368 | 369 | 
d_intm_ = self.sess.run(self.d_img, feed) 370 | d_intm_o.append(d_intm_.astype(np.uint8)) 371 | 372 | # ## debug 373 | # d_gray = self.sess.run(self.d_gray, feed) 374 | # plt.subplot(121) 375 | # plt.imshow(d_intm_[0,...]) 376 | # plt.subplot(122) 377 | # plt.imshow(d_gray[0,...,0]) 378 | # plt.show() 379 | 380 | ######### 381 | # gradient alignment 382 | if self.window_sigma > 0 and self.num_frames > 1: 383 | g_tmp[:self.num_frames:self.interp] = denoise(g_tmp[:self.num_frames:self.interp], sigma=(self.window_sigma,0,0)) 384 | 385 | for t in range(0,self.num_frames,self.interp): 386 | g_opt[t] += g_tmp[t] 387 | 388 | loss_history.append(loss_history_o) 389 | if octave < self.octave_n-1: 390 | d_intm.append(np.concatenate(d_intm_o, axis=0)) 391 | 392 | if self.interp > 1: 393 | w = np.linspace(0, 1, self.interp+1) 394 | for t in range(0,self.num_frames-1,self.interp): 395 | for i in range(1,self.interp): 396 | print(t+i, w[i]) 397 | g_opt[t+i] = g_opt[t]*(1-w[i]) + g_opt[t+self.interp]*w[i] 398 | 399 | # gather outputs 400 | result = { 401 | 'l': loss_history, 'd_intm': d_intm, 402 | 'v': None, 'c': None} 403 | 404 | # final inference 405 | p_sty = [None]*self.num_frames 406 | v_sty = [None]*self.num_frames 407 | r_sty = [None]*self.num_frames 408 | d_sty = [None]*self.num_frames 409 | for t in range(0,self.num_frames,self.batch_size): 410 | for i in range(self.batch_size): 411 | feed[self.p[i]] = p[t+i] 412 | feed[self.opt_ph[i]] = g_opt[t+i] 413 | if 'd' in self.target_field: 414 | feed[self.r[i]] = r[t+i] 415 | 416 | if self.rotate: 417 | feed[self.rot_mat] = [np.identity(3)]*self.batch_size 418 | 419 | self.sess.run(self.opt_init, feed) 420 | p_, d_, d_img = self.sess.run([self.p_out, self.d_out, self.d_img], feed) 421 | 422 | if 'p' in self.target_field: 423 | v_ = self.sess.run(self.v, feed) 424 | 425 | for i in range(self.batch_size): 426 | p_sty[t+i] = p_[i] 427 | if 'p' in self.target_field: 428 | v_sty[t+i] = v_[i] 429 | 430 | d_sty[t:t+self.batch_size] = 
d_ 431 | r_sty[t:t+self.batch_size] = d_img.astype(np.uint8) 432 | 433 | result['p'] = p_sty 434 | if 'p' in self.target_field: 435 | result['v'] = v_sty 436 | result['d'] = np.array(d_sty) 437 | result['r'] = np.array(r_sty) 438 | 439 | return result -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | ############################################################# 2 | # MIT License, Copyright © 2020, ETH Zurich, Byungsoo Kim 3 | ############################################################# 4 | from datetime import datetime 5 | import os 6 | import imageio 7 | from glob import glob 8 | import shutil 9 | import numpy as np 10 | from PIL import Image 11 | import logging 12 | import json 13 | from scipy.ndimage import gaussian_filter, zoom 14 | import skimage.transform 15 | from functools import partial 16 | import tensorflow as tf 17 | import matplotlib.colors 18 | import matplotlib.pyplot as plt 19 | try: 20 | import open3d as o3d 21 | except ImportError: 22 | pass # leonhard 23 | from matplotlib import cm 24 | from subprocess import call 25 | import sys 26 | 27 | k = np.float32([1,4,6,4,1]) 28 | k = np.outer(k, k) 29 | k5x5 = { 30 | 1: k[:,:,None,None]/k.sum(), 31 | 2: k[:,:,None,None]/k.sum()*np.eye(2, dtype=np.float32), 32 | 3: k[:,:,None,None]/k.sum()*np.eye(3, dtype=np.float32), 33 | 4: k[:,:,None,None]/k.sum()*np.eye(4, dtype=np.float32), 34 | 6: k[:,:,None,None]/k.sum()*np.eye(6, dtype=np.float32), 35 | } 36 | k_ = [] 37 | k2_ = [1,16**(1/3),36**(1/3),16**(1/3),1] 38 | k2 = np.float32([1,16**(1/3),36**(1/3),16**(1/3),1]) 39 | k2 = np.outer(k2, k2) 40 | for i in k2_: 41 | k_.append(k2*i) 42 | k_ = np.floor(np.array(k_)) 43 | k5x5x5 = { 44 | 1: k_[:,:,:,None,None]/k_.sum(), 45 | 3: k_[:,:,:,None,None]/k_.sum()*np.eye(3, dtype=np.float32), 46 | 5: k_[:,:,:,None,None]/k_.sum()*np.eye(5, dtype=np.float32), 47 | } 48 | 49 | 50 | def 
cosine_decay(global_step, decay_steps, learning_rate, factor): 51 | global_step = min(global_step, decay_steps) 52 | cos_decay = np.cos(np.pi * global_step / decay_steps) # [1, -1] 53 | cos_decay = (cos_decay + 1)*0.5*(factor-1) + 1 # [factor, 1] 54 | return learning_rate * cos_decay # 2lr -> lr 55 | 56 | 57 | def lap_split(img, is_3d, k): 58 | '''Split the image into lo and hi frequency components''' 59 | with tf.name_scope('split'): 60 | if is_3d: 61 | lo = tf.nn.conv3d(img, k, [1,2,2,2,1], 'SAME') 62 | lo2 = tf.nn.conv3d_transpose(lo, k*5, tf.shape(img), [1,2,2,2,1]) 63 | else: 64 | lo = tf.nn.conv2d(img, k, [1,2,2,1], 'SAME') 65 | lo2 = tf.nn.conv2d_transpose(lo, k*4, tf.shape(img), [1,2,2,1]) 66 | hi = img-lo2 67 | return lo, hi 68 | 69 | def lap_split_n(img, n, is_3d, k): 70 | '''Build Laplacian pyramid with n splits''' 71 | levels = [] 72 | for i in range(n): 73 | img, hi = lap_split(img, is_3d, k) 74 | levels.append(hi) 75 | levels.append(img) 76 | return levels[::-1] 77 | 78 | def lap_merge(levels, is_3d, k): 79 | '''Merge Laplacian pyramid''' 80 | img = levels[0] 81 | for hi in levels[1:]: 82 | with tf.name_scope('merge'): 83 | if is_3d: 84 | img = tf.nn.conv3d_transpose(img, k*5, tf.shape(hi), [1,2,2,2,1]) + hi 85 | else: 86 | img = tf.nn.conv2d_transpose(img, k*4, tf.shape(hi), [1,2,2,1]) + hi 87 | return img 88 | 89 | def normalize_std(img, eps=1e-10): 90 | '''Normalize image by making its standard deviation = 1.0''' 91 | with tf.name_scope('normalize'): 92 | std = tf.sqrt(tf.reduce_mean(tf.square(img))) 93 | return img/tf.maximum(std, eps) 94 | 95 | def lap_normalize(img, scale_n=3, is_3d=False, c=1): 96 | '''Perform the Laplacian pyramid normalization.''' 97 | if scale_n == 0: 98 | m = tf.reduce_mean(tf.abs(img)) 99 | return img/tf.maximum(m, 1e-7) 100 | else: 101 | if is_3d: 102 | k = k5x5x5[c] 103 | else: 104 | k = k5x5[c] 105 | 106 | img = tf.expand_dims(img, 0) 107 | tlevels = lap_split_n(img, scale_n, is_3d, k) 108 | tlevels = 
list(map(normalize_std, tlevels)) 109 | out = lap_merge(tlevels, is_3d, k) 110 | return out[0] 111 | 112 | def tffunc(*argtypes): 113 | '''Helper that transforms TF-graph generating function into a regular one. 114 | See "resize" function below. 115 | ''' 116 | placeholders = list(map(tf.compat.v1.placeholder, argtypes)) 117 | def wrap(f): 118 | out = f(*placeholders) 119 | def wrapper(*args, **kw): 120 | return out.eval(dict(zip(placeholders, args)), session=kw.get('session')) 121 | return wrapper 122 | return wrap 123 | 124 | def int_shape(tensor): 125 | shape = tensor.get_shape().as_list() 126 | return [num if num is not None else -1 for num in shape] 127 | 128 | def resize_tf(x, size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, is_3d=False): 129 | if is_3d: 130 | # b, d, h, w, c = int_shape(x) 131 | shp = tf.shape(x) 132 | b, d, h, w, c = shp[0], shp[1], shp[2], shp[3], shp[4] 133 | hw = tf.reshape(tf.transpose(x, [0,2,3,1,4]), [b,h,w,d*c]) 134 | h, w = size[1], size[2] 135 | hw = tf.compat.v1.image.resize(hw, (h,w), method=method) 136 | hw = tf.reshape(hw, [b,h,w,d,c]) 137 | dh = tf.reshape(tf.transpose(hw, [0,3,1,2,4]), [b,d,h,w*c]) 138 | d = size[0] 139 | dh = tf.compat.v1.image.resize(dh, (d,h), method=method) 140 | x = tf.reshape(dh, [b,d,h,w,c]) 141 | else: 142 | x = tf.compat.v1.image.resize(x, size, method=method) 143 | return x 144 | 145 | def rescale_tf(x, scale, method=tf.image.ResizeMethod.BILINEAR, is_3d=False): 146 | if is_3d: 147 | # b, d, h, w, c = int_shape(x) 148 | shp = tf.shape(x) 149 | b, d, h, w, c = shp[0], shp[1], shp[2], shp[3], shp[4] 150 | hw = tf.reshape(tf.transpose(x, [0,2,3,1,4]), [b,h,w,d*c]) 151 | h = tf.cast(tf.cast(h, tf.float32)*scale, tf.int32) 152 | w = tf.cast(tf.cast(w, tf.float32)*scale, tf.int32) 153 | 154 | hw = tf.compat.v1.image.resize(hw, (h,w), method=method) 155 | hw = tf.reshape(hw, [b,h,w,d,c]) 156 | dh = tf.reshape(tf.transpose(hw, [0,3,1,2,4]), [b,d,h,w*c]) 157 | d = tf.cast(tf.cast(d, tf.float32)*scale, 
tf.int32) 158 | dh = tf.compat.v1.image.resize(dh, (d,h), method=method) 159 | x = tf.reshape(dh, [b,d,h,w,c]) 160 | else: 161 | # b, h, w, c = int_shape(x) 162 | shp = tf.shape(x) 163 | b, h, w, c = shp[0], shp[1], shp[2], shp[3] 164 | h = tf.cast(tf.cast(h, tf.float32)*scale, tf.int32) 165 | w = tf.cast(tf.cast(w, tf.float32)*scale, tf.int32) 166 | x = tf.compat.v1.image.resize(x, (h,w), method=method) 167 | return x 168 | 169 | def denoise(img, sigma): 170 | return gaussian_filter(img, sigma=sigma) 171 | # if sigma > 0: 172 | # return gaussian_filter(img, sigma=sigma) 173 | # else: 174 | # return img 175 | 176 | def crop_ratio(img, ratio): 177 | hw_t = img.shape[:2] 178 | ratio_t = hw_t[1] / float(hw_t[0]) 179 | if ratio_t > ratio: 180 | hw_ = [hw_t[0], int(hw_t[0]*ratio)] 181 | else: 182 | hw_ = [int(hw_t[1]/ratio), hw_t[1]] 183 | assert(hw_[0] <= hw_t[0] and hw_[1] <= hw_t[1]) 184 | o = [int((hw_t[0]-hw_[0])*0.5), int((hw_t[1]-hw_[1])*0.5)] 185 | return img[o[0]:o[0]+hw_[0], o[1]:o[1]+hw_[1]] 186 | 187 | def resize(img, size=None, f=None, order=1): 188 | vmin, vmax = img.min(), img.max() 189 | if vmin < -1 or vmax > 1: 190 | img = (img - vmin) / (vmax-vmin) # [0,1] 191 | if size is not None: 192 | if img.ndim == 4: 193 | if len(size) == 4: size = size[:-1] 194 | img_ = [] 195 | for i in range(img.shape[-1]): 196 | img_.append(skimage.transform.resize(img[...,i], size, order=order, mode='constant', anti_aliasing=True).astype(np.float32)) 197 | img = np.stack(img_, axis=-1) 198 | elif img.ndim < 4: 199 | img = skimage.transform.resize(img, size, order=order, mode='constant', anti_aliasing=True).astype(np.float32) 200 | else: 201 | assert False 202 | else: 203 | img = skimage.transform.rescale(img, f, order=order, multichannel=None, mode='constant', anti_aliasing=True).astype(np.float32) 204 | if vmin < -1 or vmax > 1: 205 | return img * (vmax-vmin) + vmin 206 | else: 207 | return img 208 | 209 | def save_density(d, d_path): 210 | im = d*255 211 | im = 
np.stack((im,im,im), axis=-1).astype(np.uint8) 212 | im = Image.fromarray(im) 213 | im.save(d_path) 214 | 215 | def yuv2rgb(y,u,v): 216 | # https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/ops/image_ops_impl.py 217 | r = y + 1.13988303*v 218 | g = y - 0.394642334*u - 0.58062185*v 219 | b = y + 2.03206185*u 220 | # r = y + 1.4746*v 221 | # g = y - 0.16455*u - 0.57135*v 222 | # b = y + 1.8814*u 223 | # ## JPEG 224 | # r = y + 1.402*v 225 | # g = y - 0.344136*u - 0.714136*v 226 | # b = y + 1.772*u 227 | return r,g,b 228 | 229 | def rgb2yuv(r,g,b): 230 | y = 0.299*r + 0.587*g + 0.114*b 231 | u = -0.14714119*r - 0.28886916*g + 0.43601035*b 232 | v = 0.61497538*r - 0.51496512*g - 0.10001026*b 233 | return y,u,v 234 | 235 | def hsv2rgb(h,s,v): 236 | c = s * v 237 | m = v - c 238 | dh = h * 6 239 | h_category = tf.cast(dh, tf.int32) 240 | fmodu = tf.mod(dh, 2) 241 | x = c * (1 - tf.abs(fmodu - 1)) 242 | component_shape = tf.shape(h) 243 | dtype = h.dtype 244 | rr = tf.zeros(component_shape, dtype=dtype) 245 | gg = tf.zeros(component_shape, dtype=dtype) 246 | bb = tf.zeros(component_shape, dtype=dtype) 247 | h0 = tf.equal(h_category, 0) 248 | rr = tf.where(h0, c, rr) 249 | gg = tf.where(h0, x, gg) 250 | h1 = tf.equal(h_category, 1) 251 | rr = tf.where(h1, x, rr) 252 | gg = tf.where(h1, c, gg) 253 | h2 = tf.equal(h_category, 2) 254 | gg = tf.where(h2, c, gg) 255 | bb = tf.where(h2, x, bb) 256 | h3 = tf.equal(h_category, 3) 257 | gg = tf.where(h3, x, gg) 258 | bb = tf.where(h3, c, bb) 259 | h4 = tf.equal(h_category, 4) 260 | rr = tf.where(h4, x, rr) 261 | bb = tf.where(h4, c, bb) 262 | h5 = tf.equal(h_category, 5) 263 | rr = tf.where(h5, c, rr) 264 | bb = tf.where(h5, x, bb) 265 | r = rr + m 266 | g = gg + m 267 | b = bb + m 268 | return r,g,b 269 | 270 | # Util function to match histograms 271 | def match_histograms(source, template): 272 | """ 273 | Adjust the pixel values of a grayscale image such that its histogram 274 | matches that of a target image 
(source to template) 275 | 276 | Arguments: 277 | ----------- 278 | source: np.ndarray 279 | Image to transform; the histogram is computed over the flattened 280 | array 281 | template: np.ndarray 282 | Template image; can have different dimensions to source 283 | Returns: 284 | ----------- 285 | matched: np.ndarray 286 | The transformed output image 287 | """ 288 | 289 | oldshape = source.shape 290 | source = source.ravel() 291 | template = template.ravel() 292 | 293 | # get the set of unique pixel values and their corresponding indices and 294 | # counts 295 | s_values, bin_idx, s_counts = np.unique(source, return_inverse=True, 296 | return_counts=True) 297 | t_values, t_counts = np.unique(template, return_counts=True) 298 | 299 | # take the cumsum of the counts and normalize by the number of pixels to 300 | # get the empirical cumulative distribution functions for the source and 301 | # template images (maps pixel value --> quantile) 302 | s_quantiles = np.cumsum(s_counts).astype(np.float64) 303 | s_quantiles /= s_quantiles[-1] 304 | t_quantiles = np.cumsum(t_counts).astype(np.float64) 305 | t_quantiles /= t_quantiles[-1] 306 | 307 | # interpolate linearly to find the pixel values in the template image 308 | # that correspond most closely to the quantiles in the source image 309 | interp_t_values = np.interp(s_quantiles, t_quantiles, t_values) 310 | 311 | # plt.figure() 312 | # plt.plot(range(len(s_quantiles)), s_quantiles, range(len(t_quantiles)), t_quantiles) 313 | # plt.show() 314 | 315 | return interp_t_values[bin_idx].reshape(oldshape) 316 | 317 | def histogram_match_tf(source, template, hist_bins=255): 318 | shape = tf.shape(source) 319 | 320 | source = tf.layers.flatten(source) 321 | template = tf.layers.flatten(template) 322 | 323 | # get the set of unique pixel values and their corresponding indices and counts 324 | # hist_bins = trainer.hist_bins 325 | 326 | # Defining the 'x_axis' of the histogram 327 | max_value = 
tf.reduce_max([tf.reduce_max(source), tf.reduce_max(template)]) 328 | min_value = tf.reduce_min([tf.reduce_min(source), tf.reduce_min(template)]) 329 | 330 | hist_delta = (max_value - min_value)/hist_bins 331 | 332 | # Getting the x-axis for each value 333 | hist_range = tf.range(min_value, max_value, hist_delta) 334 | # I don't want the bin values; instead, I want the average value of each bin, which is 335 | # lower_value + hist_delta/2 336 | hist_range = tf.add(hist_range, tf.divide(hist_delta, 2)) 337 | 338 | # Now, making fixed width histograms on this hist_axis 339 | s_hist = tf.histogram_fixed_width(source, 340 | [min_value, max_value], 341 | nbins=hist_bins, 342 | dtype=tf.int64 343 | ) 344 | 345 | t_hist = tf.histogram_fixed_width(template, 346 | [min_value, max_value], 347 | nbins=hist_bins, 348 | dtype=tf.int64 349 | ) 350 | 351 | # take the cumsum of the counts and normalize by the number of pixels to 352 | # get the empirical cumulative distribution functions for the source and 353 | # template images (maps pixel value --> quantile) 354 | s_quantiles = tf.cumsum(s_hist) 355 | s_quantiles /= s_quantiles[-1] 356 | 357 | t_quantiles = tf.cumsum(t_hist) 358 | t_quantiles /= t_quantiles[-1] 359 | 360 | from scipy.interpolate import interp1d 361 | def intp(x, xp): 362 | intp = interp1d(xp, np.arange(hist_bins), bounds_error=False, fill_value=(0,hist_bins-1)) 363 | return np.round(intp(x)).astype(np.int64) 364 | nearest_indices = tf.py_func(intp, [s_quantiles, t_quantiles], tf.int64) 365 | 366 | # nearest_indices = tf.map_fn(lambda x: tf.argmin(tf.abs(tf.subtract(t_quantiles, x))), 367 | # s_quantiles, dtype=tf.int64) 368 | 369 | 370 | # Finding the correct s-histogram bin for every element in source 371 | s_bin_index = tf.cast((source-min_value)/hist_delta, tf.int64) 372 | 373 | ## In the case where an activation function of 0-1 is used, then there might be some index exception errors. 
374 | ## This is to deal with those 375 | s_bin_index = tf.clip_by_value(s_bin_index, 0, hist_bins-1) 376 | 377 | # Matching it to the correct t-histogram bin, and then making it the correct shape again 378 | matched_to_t = tf.gather(hist_range, tf.gather(nearest_indices, s_bin_index)) 379 | matched = tf.reshape(matched_to_t, shape) 380 | 381 | # to compare histograms 382 | m_hist = tf.histogram_fixed_width(matched_to_t, 383 | [min_value, max_value], 384 | nbins=hist_bins, 385 | dtype=tf.int64 386 | ) 387 | 388 | # self.s_hist = s_hist 389 | # self.t_hist = t_hist 390 | # self.m_hist = m_hist 391 | # self.s_quantiles = s_quantiles 392 | # self.t_quantiles = t_quantiles 393 | result = { 394 | 's_hist': s_hist, 395 | 't_hist': t_hist, 396 | 'm_hist': m_hist, 397 | 'matched': matched, 398 | } 399 | return result 400 | 401 | def str2bool(v): 402 | return v.lower() in ('true', '1') 403 | 404 | def prepare_dirs_and_logger(config): 405 | config.command = str(sys.argv) 406 | 407 | # print(__file__) 408 | os.chdir(os.path.dirname(__file__)) 409 | 410 | model_name = "{}_{}".format(get_time(), config.tag) 411 | config.log_dir = os.path.join(config.log_dir, config.dataset, model_name) 412 | 413 | if not os.path.exists(config.log_dir): 414 | os.makedirs(config.log_dir) 415 | 416 | save_config(config) 417 | 418 | def save_config(config): 419 | param_path = os.path.join(config.log_dir, "params.json") 420 | 421 | print("[*] MODEL dir: %s" % config.log_dir) 422 | print("[*] PARAM path: %s" % param_path) 423 | 424 | with open(param_path, 'w') as fp: 425 | json.dump(config.__dict__, fp, indent=4, sort_keys=True) 426 | 427 | def get_time(): 428 | return datetime.now().strftime("%m%d_%H%M%S") 429 | 430 | def save_video(imgdir, filename, ext='png', fps=24, delete_imgdir=False): 431 | filename = os.path.join(imgdir, filename+'.mp4') 432 | try: 433 | writer = imageio.get_writer(filename, fps=fps) 434 | except Exception: 435 | imageio.plugins.ffmpeg.download() 436 | writer = 
imageio.get_writer(filename, fps=fps) 437 | 438 | imgs = glob("{}/*.{}".format(imgdir, ext)) 439 | imgs = sorted(imgs, key=lambda x: int(os.path.basename(x).split('.')[0])) 440 | 441 | # print(imgs) 442 | for img in imgs: 443 | im = imageio.imread(img) 444 | writer.append_data(im) 445 | 446 | writer.close() 447 | 448 | if delete_imgdir: shutil.rmtree(imgdir) 449 | 450 | def v2rgb(v): 451 | # lazyfluid colormap 452 | theta = np.arctan2(-v[...,0], -v[...,1]) 453 | theta = (theta + np.pi) / (2*np.pi) 454 | r = np.sqrt(v[...,0]**2+v[...,1]**2) 455 | r_max = r.max() 456 | r /= r_max 457 | o = np.ones_like(r) 458 | hsv = np.stack((theta,r,o), axis=-1) 459 | rgb = matplotlib.colors.hsv_to_rgb(hsv) 460 | rgb = (rgb*255).astype(np.uint8) 461 | return rgb 462 | 463 | # v_path = 'E:/neural-flow-style\log\smoke_plume_f200/1104_165443_test/0_0_151_v.npz' 464 | # v_sty = np.load(v_path)['v'] 465 | # # v_path = 'D:\dev\deep-fluids\data\smoke3_vel5_buo3_f250/v/0_0_150.npz' 466 | # # v_sty = np.load(v_path)['x'][::-1,:,16] 467 | # # import matplotlib.pyplot as plt 468 | # # plt.figure() 469 | # # plt.subplot(131) 470 | # # plt.imshow(v_sty[...,0]) 471 | # # plt.subplot(132) 472 | # # plt.imshow(v_sty[...,1]) 473 | # # plt.subplot(133) 474 | # # plt.imshow(v_sty[...,2]) 475 | # # plt.show() 476 | # # v_sty = np.stack((v_sty[...,1], v_sty[...,0]), axis=-1) 477 | # # v_path = 'E:/neural-flow-style\log\smoke_plume_f200/1104_165443_test/test.npz' 478 | # im = Image.fromarray(v2rgb(v_sty)) 479 | # d_file_path = v_path[:-4]+'.png' 480 | # im.save(d_file_path) 481 | 482 | # save_video('E:/neural-flow-style\log\smoke_plume_f200/1102_064036_adv_s4_w0.05_volc', 'E:/neural-flow-style\log\smoke_plume_f200/1102_064036_adv_s4_w0.05_volc') 483 | 484 | def make_grid(tensor, nrow=8, padding=2, 485 | normalize=False, scale_each=False, gray=True): 486 | """Code based on https://github.com/pytorch/vision/blob/master/torchvision/utils.py""" 487 | nmaps = tensor.shape[0] 488 | xmaps = min(nrow, nmaps) 
489 | ymaps = int(np.ceil(float(nmaps) / xmaps)) 490 | height, width = int(tensor.shape[1] + padding), int(tensor.shape[2] + padding) 491 | if padding == 0: 492 | if gray: 493 | grid = np.zeros([height * ymaps, width * xmaps], dtype=np.uint8) 494 | else: 495 | grid = np.zeros([height * ymaps, width * xmaps, 3], dtype=np.uint8) 496 | else: 497 | if gray: 498 | grid = np.zeros([height * ymaps + 1 + padding // 2, width * xmaps + 1 + padding // 2], dtype=np.uint8) 499 | else: 500 | grid = np.zeros([height * ymaps + 1 + padding // 2, width * xmaps + 1 + padding // 2, 3], dtype=np.uint8) 501 | k = 0 502 | for y in range(ymaps): 503 | for x in range(xmaps): 504 | if k >= nmaps: 505 | break 506 | if padding == 0: 507 | h, h_width = y * height, height 508 | w, w_width = x * width, width 509 | else: 510 | h, h_width = y * height + 1 + padding // 2, height - padding 511 | w, w_width = x * width + 1 + padding // 2, width - padding 512 | 513 | grid[h:h+h_width, w:w+w_width] = tensor[k] 514 | k = k + 1 515 | return grid 516 | 517 | def save_image(tensor, filename, nrow=8, padding=2, 518 | normalize=False, scale_each=False, single=False, gray=True): 519 | if not single: 520 | ndarr = make_grid(tensor, nrow=nrow, padding=padding, 521 | normalize=normalize, scale_each=scale_each, gray=gray) 522 | else: 523 | # h, w = tensor.shape[0], tensor.shape[1] 524 | # if gray: 525 | # ndarr = np.zeros([h,w], dtype=np.uint8) 526 | # else: 527 | # ndarr = np.zeros([h,w,3], dtype=np.uint8) 528 | ndarr = tensor 529 | 530 | im = Image.fromarray(ndarr) 531 | im.save(filename) 532 | 533 | def draw_voxel(d): 534 | vox = np.argwhere(d>0) 535 | pcd = o3d.geometry.PointCloud() 536 | pcd.points = o3d.utility.Vector3dVector(vox) 537 | c = d[d>0] 538 | c = np.stack([c]*3, axis=-1) 539 | pcd.colors = o3d.utility.Vector3dVector(c) 540 | # vol = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, 1) 541 | o3d.visualization.draw_geometries([pcd]) 542 | 543 | def npz2vdb(d, vdb_exe, d_path): 544 | 
np.savez_compressed(d_path, x=d[:,::-1]) 545 | sh = [vdb_exe, 'npz2vdb.py', '--src_path='+d_path] 546 | call(sh, shell=True) 547 | 548 | def draw_pt(pt, pv=None, pc=None, dt=1, is_2d=True, bbox=None): 549 | geom = [] 550 | 551 | # # bounding box 552 | # xmin, xmax = 0, 0 553 | # ymin, ymax = 0, 0 554 | # zmin, zmax = -1, 1 555 | # for i, p in enumerate(pt): 556 | # if i == 0: 557 | # xmin, xmax = p[:,0].min(), p[:,0].max() 558 | # ymin, ymax = p[:,1].min(), p[:,1].max() 559 | # if not is_2d: 560 | # zmin, zmax = p[:,2].min(), p[:,2].max() 561 | # else: 562 | # xmin = min(p[:,0].min(), xmin) 563 | # xmax = max(p[:,0].max(), xmax) 564 | # ymin = min(p[:,1].min(), ymin) 565 | # ymax = max(p[:,1].max(), ymax) 566 | # if not is_2d: 567 | # zmin = min(p[:,2].min(), zmin) 568 | # zmax = max(p[:,2].max(), zmax) 569 | 570 | # bbox = [ 571 | # [xmin, ymin, zmin], 572 | # [xmax, ymax, zmax] 573 | # ] 574 | if bbox is not None: 575 | bp = [] 576 | for i in range(2): 577 | for j in range(2): 578 | for k in range(2): 579 | bp.append([bbox[i][0],bbox[j][1],bbox[k][2]]) 580 | bl = [[0, 1], [0, 2], [1, 3], [2, 3], [4, 5], [4, 6], [5, 7], [6, 7], 581 | [0, 4], [1, 5], [2, 6], [3, 7]] 582 | bbox_line = o3d.geometry.LineSet() 583 | bbox_line.points = o3d.utility.Vector3dVector(bp) 584 | bbox_line.lines = o3d.utility.Vector2iVector(bl) 585 | geom.append(bbox_line) 586 | 587 | # gizmo 588 | # gizmo = o3d.geometry.TriangleMesh.create_coordinate_frame( 589 | gizmo = o3d.geometry.create_mesh_coordinate_frame( 590 | size=1, origin=[0, 0, 0]) 591 | geom.append(gizmo) 592 | 593 | # particles 594 | pcd = o3d.geometry.PointCloud() 595 | geom.append(pcd) 596 | # pcd_idx = len(geom) 597 | # for i in range(pt.shape[1]): 598 | # mesh_sphere = o3d.geometry.TriangleMesh.create_sphere(radius=0.5, resolution=10) 599 | # geom.append(mesh_sphere) 600 | # if i == 20: break 601 | 602 | # velocity 603 | if pv is not None: 604 | line_set = o3d.geometry.LineSet() 605 | geom.append(line_set) 606 | 607 | # 
draw_pt.vol = None 608 | 609 | draw_pt.t = 0 610 | def loadframe(vis): 611 | print('frame', draw_pt.t) 612 | p = pt[draw_pt.t] 613 | if is_2d: 614 | pz = np.zeros([p.shape[0],1]) 615 | p = np.concatenate((p,pz), axis=-1) 616 | 617 | pcd.points = o3d.utility.Vector3dVector(p) 618 | 619 | # if draw_pt.vol is not None: vis.remove_geometry(draw_pt.vol) 620 | # draw_pt.vol = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, 1) 621 | # vis.add_geometry(draw_pt.vol) 622 | 623 | # for i in range(p.shape[0]): 624 | # geom[pcd_idx+i].translate(translation=p[i]) 625 | # geom[pcd_idx+i].compute_vertex_normals() 626 | # if i == 20: break 627 | 628 | if pc is not None: 629 | c = pc[draw_pt.t] 630 | pcd.colors = o3d.utility.Vector3dVector(c) 631 | 632 | if pv is not None: 633 | v = pv[draw_pt.t] 634 | if is_2d: 635 | vz = np.zeros([v.shape[0],1]) 636 | v = np.concatenate((v,vz), axis=-1) 637 | 638 | p_ = p + v*dt 639 | p = np.concatenate((p,p_), axis=0) 640 | l0 = np.arange(v.shape[0]) 641 | l1 = np.arange(v.shape[0],2*v.shape[0]) 642 | l = np.stack((l0,l1), axis=-1) 643 | 644 | if pc is None: 645 | c = np.sqrt(np.sum(v**2, axis=-1)) 646 | c /= c.max() 647 | c = cm.Blues(1 - c)[...,:-1] 648 | pcd.colors = o3d.utility.Vector3dVector(c) 649 | 650 | # for i in range(c.shape[0]): 651 | # geom[pcd_idx+i].paint_uniform_color(c[i]) 652 | # if i == 20: break 653 | 654 | line_set.points = o3d.utility.Vector3dVector(p) 655 | line_set.lines = o3d.utility.Vector2iVector(l) 656 | line_set.colors = o3d.utility.Vector3dVector(c) 657 | 658 | elif pc is None: 659 | c = np.zeros_like(p) 660 | # c[...,-1] = 0.8 # 661 | pcd.colors = o3d.utility.Vector3dVector(c) 662 | 663 | vis.update_geometry() 664 | vis.poll_events() 665 | vis.update_renderer() 666 | 667 | # cam = vis.get_view_control() 668 | # if cam is not None: 669 | # param = cam.convert_to_pinhole_camera_parameters() 670 | # print('intrisic', param.intrinsic.width, param.intrinsic.height) 671 | # print(param.intrinsic.intrinsic_matrix) 672 
| # print('extrisic\n', param.extrinsic) 673 | 674 | # p = pt[draw_pt.t] 675 | # for i in range(p.shape[0]): 676 | # geom[pcd_idx+i].translate(translation=-p[i]) 677 | # if i == 20: break 678 | 679 | vis = o3d.visualization.Visualizer() 680 | loadframe(vis) # for the first frame 681 | 682 | def nextframe(vis): 683 | # print('nextframe') 684 | draw_pt.t += 1 685 | if draw_pt.t == len(pt): 686 | draw_pt.t = 0 687 | loadframe(vis) 688 | return False 689 | 690 | def prevframe(vis): 691 | # print('prevframe') 692 | draw_pt.t -= 1 693 | if draw_pt.t == -1: 694 | draw_pt.t = len(pt)-1 695 | loadframe(vis) 696 | return False 697 | 698 | key_to_callback = {} 699 | key_to_callback[ord(",")] = prevframe 700 | key_to_callback[ord(".")] = nextframe 701 | o3d.visualization.draw_geometries_with_key_callbacks(geom, key_to_callback) -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | ############################################################# 2 | # MIT License, Copyright © 2020, ETH Zurich, Byungsoo Kim 3 | ############################################################# 4 | 5 | # https://github.com/Ryo-Ito/spatial_transformer_network 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | 10 | # https://github.com/scipython/scipython_maths/blob/master/poisson_disc_sampled_noise/poisson.py 11 | # For mathematical details of this algorithm, please see the blog 12 | # article at https://scipython.com/blog/poisson-disc-sampling-in-python/ 13 | # Christian Hill, March 2017. 
class PoissonDisc(object):
    """A class for generating two-dimensional Poisson-disc (blue) noise."""

    def __init__(self, rng, width=50, height=50, r=1, k=30):
        """
        Args:
            rng: random generator providing ``uniform(low, high)`` and
                ``choice(seq)`` (e.g. ``numpy.random.RandomState``).
            width, height: size of the rectangular sampling domain.
            r: minimum distance between any two samples.
            k: number of in-domain candidates tried around a reference point
                before it is retired from the active list.
        """
        self.rng = rng
        self.width, self.height = width, height
        self.r = r
        self.k = k

        # Cell side length
        self.a = r/np.sqrt(2)
        # Number of cells in the x- and y-directions of the grid
        self.nx, self.ny = int(width / self.a) + 1, int(height / self.a) + 1

        self.reset()

    def reset(self):
        """Reset the cells dictionary."""

        # A list of coordinates in the grid of cells
        coords_list = [(ix, iy) for ix in range(self.nx)
                       for iy in range(self.ny)]
        # Initialize the dictionary of cells: each key is a cell's coordinates
        # the corresponding value is the index of that cell's point's
        # coordinates in the samples list (or None if the cell is empty).
        self.cells = {coords: None for coords in coords_list}

    def get_cell_coords(self, pt):
        """Get the coordinates of the cell that pt = (x,y) falls in."""

        return int(pt[0] // self.a), int(pt[1] // self.a)

    def get_neighbours(self, coords):
        """Return the indexes of points in cells neighbouring cell at coords.

        For the cell at coords = (x,y), return the indexes of points in the
        cells with neighbouring coordinates illustrated below: ie those cells
        that could contain points closer than r.

             ooo
            ooooo
            ooXoo
            ooooo
             ooo
        """

        dxdy = [(-1,-2),(0,-2),(1,-2),(-2,-1),(-1,-1),(0,-1),(1,-1),(2,-1),
                (-2,0),(-1,0),(1,0),(2,0),(-2,1),(-1,1),(0,1),(1,1),(2,1),
                (-1,2),(0,2),(1,2),(0,0)]
        neighbours = []
        for dx, dy in dxdy:
            neighbour_coords = coords[0] + dx, coords[1] + dy
            if not (0 <= neighbour_coords[0] < self.nx and
                    0 <= neighbour_coords[1] < self.ny):
                # We're off the grid: no neighbours here.
                continue
            neighbour_cell = self.cells[neighbour_coords]
            if neighbour_cell is not None:
                # This cell is occupied: store the index of the contained point
                neighbours.append(neighbour_cell)
        return neighbours

    def point_valid(self, pt):
        """Is pt a valid point to emit as a sample?

        It must be no closer than r from any other point: check the cells in
        its immediate neighbourhood.
        """

        cell_coords = self.get_cell_coords(pt)
        for idx in self.get_neighbours(cell_coords):
            nearby_pt = self.samples[idx]
            # Squared distance between candidate point, pt, and this nearby_pt.
            distance2 = (nearby_pt[0]-pt[0])**2 + (nearby_pt[1]-pt[1])**2
            if distance2 < self.r**2:
                # The points are too close, so pt is not a candidate.
                return False
        # All points tested: if we're here, pt is valid
        return True

    def get_point(self, refpt):
        """Try to find a candidate point near refpt to emit in the sample.

        We draw up to k points from the annulus of inner radius r, outer radius
        2r around the reference point, refpt. If none of them are suitable
        (because they're too close to existing points in the sample), return
        False. Otherwise, return the pt.
        """

        i = 0
        while i < self.k:
            rho, theta = (self.rng.uniform(self.r, 2*self.r),
                          self.rng.uniform(0, 2*np.pi))
            pt = refpt[0] + rho*np.cos(theta), refpt[1] + rho*np.sin(theta)
            if not (0 < pt[0] < self.width and 0 < pt[1] < self.height):
                # This point falls outside the domain, so try again.
                # NOTE(review): out-of-domain draws do not count toward k, so
                # this never terminates if the annulus cannot overlap the
                # domain (e.g. a domain much smaller than r).
                continue
            if self.point_valid(pt):
                return pt
            i += 1
        # We failed to find a suitable point in the vicinity of refpt.
        return False

    def sample(self):
        """Poisson disc random sampling in 2D.

        Draw random samples on the domain width x height such that no two
        samples are closer than r apart. The parameter k determines the
        maximum number of candidate points to be chosen around each reference
        point before removing it from the "active" list.

        Returns the list of (x, y) sample tuples; each call starts afresh.
        """

        # Start from an empty cell grid: otherwise cells still hold indices
        # into a previous call's samples list, which point_valid would then
        # look up in the new (shorter) list.
        self.reset()

        # Pick a random point to start with.
        pt = (self.rng.uniform(0, self.width),
              self.rng.uniform(0, self.height))
        self.samples = [pt]
        # Our first sample is indexed at 0 in the samples list...
        self.cells[self.get_cell_coords(pt)] = 0
        # and it is active, in the sense that we're going to look for more
        # points in its neighbourhood.
        active = [0]

        # As long as there are points in the active list, keep looking for
        # samples.
        while active:
            # choose a random "reference" point from the active list.
            idx = self.rng.choice(active)
            refpt = self.samples[idx]
            # Try to pick a new point relative to the reference point.
            pt = self.get_point(refpt)
            if pt:
                # Point pt is valid: add it to samples list and mark as active
                self.samples.append(pt)
                nsamples = len(self.samples) - 1
                active.append(nsamples)
                self.cells[self.get_cell_coords(pt)] = nsamples
            else:
                # We had to give up looking for valid points near refpt, so
                # remove it from the list of "active" points.
                active.remove(idx)

        return self.samples

def mgrid(*args, **kwargs):
    """
    create orthogonal grid
    similar to np.mgrid

    Parameters
    ----------
    args : int
        number of points on each axis
    low : float
        minimum coordinate value
    high : float
        maximum coordinate value

    Returns
    -------
    grid : tf.Tensor [len(args), args[0], ...]
        orthogonal grid
    """
    low = kwargs.pop("low", -1)
    high = kwargs.pop("high", 1)
    low = tf.cast(low, tf.float32)
    high = tf.cast(high, tf.float32)
    coords = (tf.linspace(low, high, arg) for arg in args)
    grid = tf.stack(tf.meshgrid(*coords, indexing='ij'))
    return grid


def batch_mgrid(n_batch, *args, **kwargs):
    """
    create batch of orthogonal grids
    similar to np.mgrid

    Parameters
    ----------
    n_batch : int
        number of grids to create
    args : int
        number of points on each axis
    low : float
        minimum coordinate value
    high : float
        maximum coordinate value

    Returns
    -------
    grids : tf.Tensor [n_batch, len(args), args[0], ...]
        batch of orthogonal grids
    """
    grid = mgrid(*args, **kwargs)
    grid = tf.expand_dims(grid, 0)
    grids = tf.tile(grid, [n_batch] + [1 for _ in range(len(args) + 1)])
    return grids

def batch_warp2d(imgs, mappings, sample_shape):
    """
    warp image using mapping function
    I(x) -> I(phi(x))
    phi: mapping function

    Parameters
    ----------
    imgs : tf.Tensor
        images to be warped
        [n_batch, xlen, ylen, n_channel]
    mappings : tf.Tensor
        grids representing mapping function, coordinates in [-1, 1]
        [n_batch, 2, xlen_, ylen_]
    sample_shape : list
        [n_batch, xlen_, ylen_], shape of the sampled output

    Returns
    -------
    output : tf.Tensor
        warped images
        [n_batch, xlen_, ylen_, n_channel]
    """
    # n_batch = tf.shape(imgs)[0]
    n_batch = sample_shape[0]
    coords = tf.reshape(mappings, [n_batch, 2, -1])
    x_coords = tf.slice(coords, [0, 0, 0], [-1, 1, -1])
    y_coords = tf.slice(coords, [0, 1, 0], [-1, 1, -1])
    x_coords_flat = tf.reshape(x_coords, [-1])
    y_coords_flat = tf.reshape(y_coords, [-1])

    output = _interpolate2d(imgs, x_coords_flat, y_coords_flat, sample_shape)
    return output

def batch_warp3d(imgs, mappings, sample_shape):
    """
    warp image using mapping function
    I(x) -> I(phi(x))
    phi: mapping function

    Parameters
    ----------
    imgs : tf.Tensor
        images to be warped
        [n_batch, xlen, ylen, zlen, n_channel]
    mappings : tf.Tensor
        grids representing mapping function, coordinates in [-1, 1]
        [n_batch, 3, xlen_, ylen_, zlen_]
    sample_shape : list
        [n_batch, xlen_, ylen_, zlen_], shape of the sampled output

    Returns
    -------
    output : tf.Tensor
        warped images
        [n_batch, xlen_, ylen_, zlen_, n_channel]
    """
    n_batch = sample_shape[0]
    coords = tf.reshape(mappings, [n_batch, 3, -1])
    x_coords = tf.slice(coords, [0, 0, 0], [-1, 1, -1])
    y_coords = tf.slice(coords, [0, 1, 0], [-1, 1, -1])
    z_coords = tf.slice(coords, [0, 2, 0], [-1, 1, -1])
    x_coords_flat = tf.reshape(x_coords, [-1])
    y_coords_flat = tf.reshape(y_coords, [-1])
    z_coords_flat = tf.reshape(z_coords, [-1])

    output = _interpolate3d(imgs, x_coords_flat, y_coords_flat, z_coords_flat, sample_shape)
    return output

def _repeat(base_indices, n_repeats):
    """Return the flat int32 tensor with each element of ``base_indices``
    repeated ``n_repeats`` times consecutively."""
    base_indices = tf.matmul(
        tf.reshape(base_indices, [-1, 1]),
        tf.ones([1, n_repeats], dtype='int32'))
        # tf.reshape(tf.cast(base_indices, tf.float32), [-1, 1]),
        # tf.ones([1, n_repeats], dtype=tf.float32))
    # base_indices = tf.to_int32(base_indices)
    return tf.reshape(base_indices, [-1])

def _interpolate2d(imgs, x, y, sample_shape):
    """Bilinearly sample ``imgs`` at flat coordinates (x, y) in [-1, 1].

    Sample indices are clipped to the image bounds. Returns
    [n_batch, xlen_, ylen_, n_channel] with (n_batch, xlen_, ylen_) =
    sample_shape.
    """
    # n_batch = tf.shape(imgs)[0]
    xlen = tf.shape(imgs)[1]
    ylen = tf.shape(imgs)[2]
    n_channel = tf.shape(imgs)[3]

    n_batch = sample_shape[0]
    xlen_ = sample_shape[1]
    ylen_ = sample_shape[2]

    x = tf.cast(x, tf.float32)
    y = tf.cast(y, tf.float32)
    xlen_f = tf.cast(xlen, tf.float32)
    ylen_f = tf.cast(ylen, tf.float32)
    zero = tf.zeros([], dtype='int32')
    max_x = tf.cast(xlen - 1, 'int32')
    max_y = tf.cast(ylen - 1, 'int32')

    # scale indices from [-1, 1] to [0, xlen-1] / [0, ylen-1]
    x = (x + 1.) * (xlen_f - 1.) * 0.5
    y = (y + 1.) * (ylen_f - 1.) * 0.5

    # do sampling
    x0 = tf.cast(tf.floor(x), 'int32')
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), 'int32')
    y1 = y0 + 1

    x0 = tf.clip_by_value(x0, zero, max_x)
    x1 = tf.clip_by_value(x1, zero, max_x)
    y0 = tf.clip_by_value(y0, zero, max_y)
    y1 = tf.clip_by_value(y1, zero, max_y)
    # Batch offsets address the flattened *source* images, so the stride is
    # the source size xlen*ylen (not the output size); each offset is
    # repeated once per output sample.
    base = _repeat(tf.range(n_batch) * xlen * ylen, ylen_ * xlen_)
    base_x0 = base + x0 * ylen
    base_x1 = base + x1 * ylen
    index00 = base_x0 + y0
    index01 = base_x0 + y1
    index10 = base_x1 + y0
    index11 = base_x1 + y1

    # use indices to lookup pixels in the flat image and restore
    # n_channel dim
    imgs_flat = tf.reshape(imgs, [-1, n_channel])
    imgs_flat = tf.cast(imgs_flat, tf.float32)
    I00 = tf.gather(imgs_flat, index00)
    I01 = tf.gather(imgs_flat, index01)
    I10 = tf.gather(imgs_flat, index10)
    I11 = tf.gather(imgs_flat, index11)

    # and finally calculate interpolated values
    dx = x - tf.cast(x0, tf.float32)
    dy = y - tf.cast(y0, tf.float32)
    w00 = tf.expand_dims((1. - dx) * (1. - dy), 1)
    w01 = tf.expand_dims((1. - dx) * dy, 1)
    w10 = tf.expand_dims(dx * (1. - dy), 1)
    w11 = tf.expand_dims(dx * dy, 1)
    output = tf.add_n([w00*I00, w01*I01, w10*I10, w11*I11])

    # reshape
    output = tf.reshape(output, [n_batch, xlen_, ylen_, n_channel])

    return output

def _interpolate3d(imgs, x, y, z, sample_shape):
    """Trilinearly sample ``imgs`` at flat coordinates (x, y, z) in [-1, 1].

    Sample indices are clipped to the volume bounds. Returns
    [n_batch, xlen_, ylen_, zlen_, n_channel] with
    (n_batch, xlen_, ylen_, zlen_) = sample_shape.
    """
    # n_batch = tf.shape(imgs)[0]
    xlen = tf.shape(imgs)[1]
    ylen = tf.shape(imgs)[2]
    zlen = tf.shape(imgs)[3]
    n_channel = tf.shape(imgs)[4]

    n_batch = sample_shape[0]
    xlen_ = sample_shape[1]
    ylen_ = sample_shape[2]
    zlen_ = sample_shape[3]

    x = tf.cast(x, tf.float32)
    y = tf.cast(y, tf.float32)
    z = tf.cast(z, tf.float32)
    xlen_f = tf.cast(xlen, tf.float32)
    ylen_f = tf.cast(ylen, tf.float32)
    zlen_f = tf.cast(zlen, tf.float32)
    zero = tf.zeros([], dtype='int32')
    max_x = tf.cast(xlen - 1, 'int32')
    max_y = tf.cast(ylen - 1, 'int32')
    max_z = tf.cast(zlen - 1, 'int32')

    # scale indices from [-1, 1] to [0, len-1] on each axis
    x = (x + 1.) * (xlen_f - 1.) * 0.5
    y = (y + 1.) * (ylen_f - 1.) * 0.5
    z = (z + 1.) * (zlen_f - 1.) * 0.5

    # do sampling
    x0 = tf.cast(tf.floor(x), 'int32')
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), 'int32')
    y1 = y0 + 1
    z0 = tf.cast(tf.floor(z), 'int32')
    z1 = z0 + 1

    x0 = tf.clip_by_value(x0, zero, max_x)
    x1 = tf.clip_by_value(x1, zero, max_x)
    y0 = tf.clip_by_value(y0, zero, max_y)
    y1 = tf.clip_by_value(y1, zero, max_y)
    z0 = tf.clip_by_value(z0, zero, max_z)
    z1 = tf.clip_by_value(z1, zero, max_z)
    # Batch offsets address the flattened *source* volume (stride
    # xlen*ylen*zlen), repeated once per output sample.
    base = _repeat(tf.range(n_batch) * xlen * ylen * zlen,
                   xlen_ * ylen_ * zlen_)
    base_x0 = base + x0 * ylen * zlen
    base_x1 = base + x1 * ylen * zlen
    base00 = base_x0 + y0 * zlen
    base01 = base_x0 + y1 * zlen
    base10 = base_x1 + y0 * zlen
    base11 = base_x1 + y1 * zlen
    index000 = base00 + z0
    index001 = base00 + z1
    index010 = base01 + z0
    index011 = base01 + z1
    index100 = base10 + z0
    index101 = base10 + z1
    index110 = base11 + z0
    index111 = base11 + z1

    # use indices to lookup pixels in the flat image and restore
    # n_channel dim
    imgs_flat = tf.reshape(imgs, [-1, n_channel])
    imgs_flat = tf.cast(imgs_flat, tf.float32)
    I000 = tf.gather(imgs_flat, index000)
    I001 = tf.gather(imgs_flat, index001)
    I010 = tf.gather(imgs_flat, index010)
    I011 = tf.gather(imgs_flat, index011)
    I100 = tf.gather(imgs_flat, index100)
    I101 = tf.gather(imgs_flat, index101)
    I110 = tf.gather(imgs_flat, index110)
    I111 = tf.gather(imgs_flat, index111)

    # and finally calculate interpolated values
    dx = x - tf.cast(x0, tf.float32)
    dy = y - tf.cast(y0, tf.float32)
    dz = z - tf.cast(z0, tf.float32)
    w000 = tf.expand_dims((1. - dx) * (1. - dy) * (1. - dz), 1)
    w001 = tf.expand_dims((1. - dx) * (1. - dy) * dz, 1)
    w010 = tf.expand_dims((1. - dx) * dy * (1. - dz), 1)
    w011 = tf.expand_dims((1. - dx) * dy * dz, 1)
    w100 = tf.expand_dims(dx * (1. - dy) * (1. - dz), 1)
    w101 = tf.expand_dims(dx * (1. - dy) * dz, 1)
    w110 = tf.expand_dims(dx * dy * (1. - dz), 1)
    w111 = tf.expand_dims(dx * dy * dz, 1)
    output = tf.add_n([w000 * I000, w001 * I001, w010 * I010, w011 * I011,
                       w100 * I100, w101 * I101, w110 * I110, w111 * I111])

    # reshape
    output = tf.reshape(output, [n_batch, xlen_, ylen_, zlen_, n_channel])
    # output = tf.concat([output]*n, axis=0)
    return output

def batch_affine_warp2d(imgs, theta):
    """
    affine transforms 2d images

    Parameters
    ----------
    imgs : tf.Tensor
        images to be warped
        [n_batch, xlen, ylen, n_channel]
    theta : tf.Tensor
        parameters of affine transformation
        [n_batch, 6]

    Returns
    -------
    output : tf.Tensor
        warped images
        [n_batch, xlen, ylen, n_channel]
    """
    n_batch = tf.shape(imgs)[0]
    xlen = tf.shape(imgs)[1]
    ylen = tf.shape(imgs)[2]
    theta = tf.reshape(theta, [-1, 2, 3])
    matrix = tf.slice(theta, [0, 0, 0], [-1, -1, 2])
    t = tf.slice(theta, [0, 0, 2], [-1, -1, -1])

    grids = batch_mgrid(n_batch, xlen, ylen)
    coords = tf.reshape(grids, [n_batch, 2, -1])

    T_g = tf.matmul(matrix, coords) + t
    T_g = tf.reshape(T_g, [n_batch, 2, xlen, ylen])

    # batch_warp2d requires the output shape; the output keeps the input size.
    output = batch_warp2d(imgs, T_g, [n_batch, xlen, ylen])
    return output


def batch_affine_warp3d(imgs, theta):
    """
    affine transforms 3d images

    Parameters
    ----------
    imgs : tf.Tensor
        images to be warped
        [n_batch, xlen, ylen, zlen, n_channel]
    theta : tf.Tensor
        parameters of affine transformation
        [n_batch, 12]

    Returns
    -------
    output : tf.Tensor
        warped images
        [n_batch, xlen, ylen, zlen, n_channel]
    """
    n_batch = tf.shape(imgs)[0]
    xlen = tf.shape(imgs)[1]
    ylen = tf.shape(imgs)[2]
    zlen = tf.shape(imgs)[3]
    theta = tf.reshape(theta, [-1, 3, 4])
    matrix = tf.slice(theta, [0, 0, 0], [-1, -1, 3])
    t = tf.slice(theta, [0, 0, 3], [-1, -1, -1])

    grids = batch_mgrid(n_batch, xlen, ylen, zlen)
    grids = tf.reshape(grids, [n_batch, 3, -1])

    T_g = tf.matmul(matrix, grids) + t
    T_g = tf.reshape(T_g, [n_batch, 3, xlen, ylen, zlen])
    # batch_warp3d requires the output shape; the output keeps the input size.
    output = batch_warp3d(imgs, T_g, [n_batch, xlen, ylen, zlen])
    return output

def grad(p):
    """Forward differences of ``p`` along axes 3, 2 and 1.

    Each difference repeats its last slice so it keeps ``p``'s shape; the
    results are concatenated on the last axis as [dx, dy, dz].
    """
    dx = p[:,:,:,1:] - p[:,:,:,:-1]
    dy = p[:,:,1:,:] - p[:,:,:-1,:]
    dz = p[:,1:,:,:] - p[:,:-1,:,:]
    dx = tf.concat((dx, tf.expand_dims(dx[:,:,:,-1], axis=3)), axis=3)
    dy = tf.concat((dy, tf.expand_dims(dy[:,:,-1,:], axis=2)), axis=2)
    dz = tf.concat((dz, tf.expand_dims(dz[:,-1,:,:], axis=1)), axis=1)
    return tf.concat([dx,dy,dz], axis=-1)

def curl(s, is_2d=True):
    """Discrete curl via forward differences (last difference repeated).

    2D: ``s`` is a scalar stream function [B,H,W,1]; returns the velocity
    (ds/dy, -ds/dx) stacked as [B,H,W,2].
    3D: ``s`` is a vector potential [B,D,H,W,3] with x along axis 3, y along
    axis 2 and z along axis 1; returns its curl as [B,D,H,W,3].
    """
    if is_2d:
        # s: [B,H,W,1]
        u = s[:,1:,:,0] - s[:,:-1,:,0] # ds/dy
        v = s[:,:,:-1,0] - s[:,:,1:,0] # -ds/dx,
        u = tf.concat([u, tf.expand_dims(u[:,-1,:], axis=1)], axis=1)
        v = tf.concat([v, tf.expand_dims(v[:,:,-1], axis=2)], axis=2)
        return tf.stack([u,v], axis=-1)
    else:
        # s: [B,D,H,W,3]
        # dudx = s[:,:,:,1:,0] - s[:,:,:,:-1,0]
        dvdx = s[:,:,:,1:,1] - s[:,:,:,:-1,1]
        dwdx = s[:,:,:,1:,2] - s[:,:,:,:-1,2]

        dudy = s[:,:,1:,:,0] - s[:,:,:-1,:,0]
        # dvdy = s[:,:,1:,:,1] - s[:,:,:-1,:,1]
        dwdy = s[:,:,1:,:,2] - s[:,:,:-1,:,2]

        dudz = s[:,1:,:,:,0] - s[:,:-1,:,:,0]
        dvdz = s[:,1:,:,:,1] - s[:,:-1,:,:,1]
        # dwdz = s[:,1:,:,:,2] - s[:,:-1,:,:,2]

        # dudx = tf.concat((dudx, tf.expand_dims(dudx[:,:,:,-1], axis=3)), axis=3)
        dvdx = tf.concat((dvdx, tf.expand_dims(dvdx[:,:,:,-1], axis=3)), axis=3)
        dwdx = tf.concat((dwdx, tf.expand_dims(dwdx[:,:,:,-1], axis=3)), axis=3)

        dudy = tf.concat((dudy, tf.expand_dims(dudy[:,:,-1,:], axis=2)), axis=2)
        # dvdy = tf.concat((dvdy, tf.expand_dims(dvdy[:,:,-1,:], axis=2)), axis=2)
        dwdy = tf.concat((dwdy, tf.expand_dims(dwdy[:,:,-1,:], axis=2)), axis=2)

        dudz = tf.concat((dudz, tf.expand_dims(dudz[:,-1,:,:], axis=1)), axis=1)
        dvdz = tf.concat((dvdz, tf.expand_dims(dvdz[:,-1,:,:], axis=1)), axis=1)
        # dwdz = tf.concat((dwdz, tf.expand_dims(dwdz[:,-1,:,:], axis=1)), axis=1)

        u = dwdy - dvdz
        v = dudz - dwdx
        w = dvdx - dudy

        return tf.stack([u,v,w], axis=-1)

def _snap_to_cell(grids, sizes):
    """Snap [-1, 1] sampling coordinates down to the lower-corner grid node.

    ``grids`` is [b, ndim, ...] in the convention of ``_interpolate2d`` /
    ``_interpolate3d`` (index = (g + 1) * (n - 1) / 2) and ``sizes`` holds the
    grid length of each of the ndim axes. Every coordinate is floored in index
    space and mapped back, so warping with the result reads the value stored
    at the cell that the original coordinate falls into.
    """
    snapped = []
    for axis, n in enumerate(sizes):
        n_f = tf.cast(n, tf.float32)
        idx = tf.floor((grids[:, axis] + 1.) * (n_f - 1.) * 0.5)
        snapped.append(idx * 2. / (n_f - 1.) - 1.)
    return tf.stack(snapped, axis=1)

def advect(d, vel, order=1, is_3d=False):
    """Advect field ``d`` through velocity ``vel`` for one step (dt = 1).

    Args:
        d: field to transport, [1, X, Y, C] (2D) or [1, X, Y, Z, C] (3D); the
            batch size is fixed to 1.
        vel: velocity in the same [-1, 1] grid units as ``batch_mgrid``
            (it is subtracted from those coordinates), [1, X, Y, 2] or
            [1, X, Y, Z, 3].
        order: 1 for semi-Lagrangian; any other value for MacCormack, whose
            result falls back to the semi-Lagrangian value wherever it leaves
            the min/max of the source cell.
        is_3d: select the 3D code path.
    Returns:
        the advected field, same shape as ``d``.
    """
    n_batch = 1 # assert(tf.shape(d)[0] == 1)
    xlen = tf.shape(d)[1]
    ylen = tf.shape(d)[2]

    if is_3d:
        zlen = tf.shape(d)[3]
        sample_shape = [n_batch, xlen, ylen, zlen]
        grids = batch_mgrid(n_batch, xlen, ylen, zlen) # [b,3,u,v,w]
        vel = tf.transpose(vel, [0,4,1,2,3]) # [b,u,v,w,3] -> [b,3,u,v,w]
        grids -= vel # p' = p - v*dt, dt = 1

        if order == 1: # semi-lagrangian
            d_adv = batch_warp3d(d, grids, sample_shape)
        else: # maccormack
            d_fwd = batch_warp3d(d, grids, sample_shape)
            grids_ = batch_mgrid(n_batch, xlen, ylen, zlen) + vel
            d_bwd = batch_warp3d(d_fwd, grids_, sample_shape)
            d_adv = d_fwd + (d-d_bwd)*0.5
            # neighbourhood max/min of d over a 2x2x2 window at every cell
            d_max = tf.nn.max_pool3d(d, ksize=(1,2,2,2,1), strides=(1,1,1,1,1), padding='SAME')
            d_min = -tf.nn.max_pool3d(-d, ksize=(1,2,2,2,1), strides=(1,1,1,1,1), padding='SAME')
            # Read the bounds at the cell containing each back-traced point.
            # (Casting the [-1, 1] coordinates to int32, as before, truncated
            # almost every coordinate to 0, i.e. the volume centre.)
            cells = _snap_to_cell(grids, [xlen, ylen, zlen])
            d_max = batch_warp3d(d_max, cells, sample_shape)
            d_min = batch_warp3d(d_min, cells, sample_shape)
            # soft clamp
            d_max = tf.greater(d_adv, d_max)
            d_min = tf.greater(d_min, d_adv)
            d_adv = tf.where(tf.logical_or(d_min,d_max), d_fwd, d_adv)
    else:
        sample_shape = [n_batch, xlen, ylen]
        grids = batch_mgrid(n_batch, xlen, ylen) # [b,2,u,v]
        vel = tf.transpose(vel, [0,3,1,2]) # [b,u,v,2] -> [b,2,u,v]
        grids -= vel # p' = p - v*dt, dt = 1

        if order == 1:
            d_adv = batch_warp2d(d, grids, sample_shape)
        else:
            d_fwd = batch_warp2d(d, grids, sample_shape)
            grids_ = batch_mgrid(n_batch, xlen, ylen) + vel
            d_bwd = batch_warp2d(d_fwd, grids_, sample_shape)
            # flags = tf.clip_by_value(tf.math.ceil(d_fwd), 0, 1)
            d_adv = d_fwd + (d-d_bwd)*0.5
            # neighbourhood max/min of d over a 2x2 window at every cell
            d_max = tf.nn.max_pool(d, ksize=(1,2,2,1), strides=(1,1,1,1), padding='SAME')
            d_min = -tf.nn.max_pool(-d, ksize=(1,2,2,1), strides=(1,1,1,1), padding='SAME')
            # Read the bounds at the cell containing each back-traced point.
            # (Previously `d_max[grids], d_max[grids]`: invalid tensor
            # indexing, and it assigned the max field to d_min as well.)
            cells = _snap_to_cell(grids, [xlen, ylen])
            d_max = batch_warp2d(d_max, cells, sample_shape)
            d_min = batch_warp2d(d_min, cells, sample_shape)
            # # hard clamp
            # d_adv = tf.clip_by_value(d_adv, d_min, d_max)
            # soft clamp
            d_max = tf.greater(d_adv, d_max) # find values larger than max (true if x > y)
            d_min = tf.greater(d_min, d_adv) # find values smaller than min (true if x > y)
            d_adv = tf.where(tf.logical_or(d_min,d_max), d_fwd, d_adv) # *flags

    return d_adv

def rotate(d):
    """Warp every volume in ``d`` by every rotation fed into a placeholder.

    Args:
        d: batch of volumes [b, X, Y, Z, C].
    Returns:
        (d_rot, rot_mat): ``rot_mat`` is a float32 placeholder [n_rot, 3, 3];
        ``d_rot`` is [n_rot*b, X, Y, Z, C], ordered rotation-major (entry
        i*b + j is volume j rotated by matrix i).
    """
    b = tf.shape(d)[0]
    xlen = tf.shape(d)[1]
    ylen = tf.shape(d)[2]
    zlen = tf.shape(d)[3]

    rot_mat = tf.placeholder(shape=[None,3,3], dtype=tf.float32)
    n_rot = tf.shape(rot_mat)[0]
    n_batch = b*n_rot

    # d is tiled as [d_0..d_{b-1}, d_0..d_{b-1}, ...]; each matrix must
    # therefore be repeated b times consecutively so every (volume, rotation)
    # pair occurs exactly once (tiling both paired them modulo b and n_rot).
    d = tf.tile(d, [n_rot,1,1,1,1])
    r = tf.reshape(tf.tile(tf.expand_dims(rot_mat, 1), [1,b,1,1]), [-1,3,3])
    grids = batch_mgrid(n_batch, xlen, ylen, zlen) # [b,3,u,v,w]
    grids = tf.reshape(grids, [n_batch, 3, -1])
    grids = tf.matmul(r, grids)
    grids = tf.reshape(grids, [n_batch, 3, xlen, ylen, zlen])
    d_rot = batch_warp3d(d, grids, [n_batch, xlen, ylen, zlen])
    return d_rot, rot_mat

def subsample(d, scale):
    """Resample a 2D batch ``d`` [b, X, Y, C] to int(X*scale) x int(Y*scale)
    by bilinear sampling over the full [-1, 1] extent."""
    n_batch = tf.shape(d)[0]
    xlen = tf.to_int32(
        tf.multiply(tf.cast(tf.shape(d)[1], tf.float32),scale))
    ylen = tf.to_int32(
        tf.multiply(tf.cast(tf.shape(d)[2], tf.float32),scale))
    grids = batch_mgrid(n_batch, xlen, ylen) # [b,2,u,v]
    d_sample = batch_warp2d(d, grids, [n_batch, xlen, ylen])
    return d_sample

def rot_z_3d(deg):
    """Return the 3x3 rotation matrix for ``deg`` degrees about the z axis."""
    rad = deg/180.0*np.pi
    c = np.cos(rad)
    s = np.sin(rad)
    rot_mat = np.array([
        [c,-s,0],
        [s,c,0],
        [0,0,1]])
    return rot_mat

def rot_y_3d(deg):
    """Return the 3x3 rotation matrix for ``deg`` degrees about the y axis
    (sign convention as written below)."""
    rad = deg/180.0*np.pi
    c = np.cos(rad)
    s = np.sin(rad)
    rot_mat = np.array([
        [c,0,-s],
        [0,1,0],
        [s,0,c]])
    return rot_mat

def rot_x_3d(deg):
    """Return the 3x3 rotation matrix for ``deg`` degrees about the x axis."""
    rad = deg/180.0*np.pi
    c = np.cos(rad)
    s = np.sin(rad)
    rot_mat = np.array([
        [1,0,0],
        [0,c,-s],
        [0,s,c]])
    return rot_mat

def scale(s):
    """Return the 3x3 uniform scaling matrix ``s * I``."""
    s_mat = np.array([
        [s,0,0],
        [0,s,0],
        [0,0,s]])
    return s_mat


def rot_mat_turb(theta_unit, poisson_sample=False, rng=None):
    """Return y-axis rotation matrices for views at theta 0, 90 and 180.

    With ``poisson_sample``, Poisson-disc sampled theta values over [0, 180]
    (spacing ``theta_unit / 2``) are appended. Returns (matrices, views).
    """
    views = [{'theta':0}, {'theta':90}, {'theta':180}]
    if poisson_sample:
        views += rot_mat_poisson(0,0,0, 0, 180, theta_unit, rng)

    mat = []
    for view in views:
        theta = view['theta']
        mat.append(rot_y_3d(theta))
    return mat, views

def rot_mat(phi0, phi1, phi_unit, theta0, theta1, theta_unit,
            sample_type='uniform', rng=None, nv=None):
    """Build view rotation matrices ``rot_y(theta) @ rot_z(phi)``.

    ``sample_type`` containing 'uniform' uses a regular grid, containing
    'poisson' uses Poisson-disc samples plus the range midpoint, anything
    else combines a regular grid with coarser Poisson samples. For the latter
    two, ``nv`` (if given) trims to the last ``nv`` views or tops up with more
    Poisson samples. Returns (matrices, views).
    """

    if 'uniform' in sample_type:
        views = rot_mat_uniform(phi0, phi1, phi_unit, theta0, theta1, theta_unit)
    elif 'poisson' in sample_type:
        views = rot_mat_poisson(phi0, phi1, phi_unit, theta0, theta1, theta_unit, rng)
        views += rot_mat_uniform(phi0, phi1, 0, theta0, theta1, 0) # [midpoint]
        if nv is not None:
            if len(views) > nv:
                views = views[len(views)-nv:]
            elif len(views) < nv:
                views_ = rot_mat_poisson(phi0, phi1, phi_unit, theta0, theta1, theta_unit, rng)
                views += views_[:nv-len(views)]
    else: # both
        views = rot_mat_uniform(phi0, phi1, phi_unit, theta0, theta1, theta_unit)
        views += rot_mat_poisson(phi0, phi1, phi_unit*2, theta0, theta1, theta_unit*2, rng)
        if nv is not None:
            if len(views) > nv:
                views = views[len(views)-nv:]
            elif len(views) < nv:
                views_ = rot_mat_poisson(phi0, phi1, phi_unit*2, theta0, theta1, theta_unit*2, rng)
                views += views_[:nv-len(views)]

    mat = []
    for view in views:
        phi, theta = view['phi'], view['theta']
        rz = rot_z_3d(phi)
        ry = rot_y_3d(theta)
        # local name differs from this function's name to avoid shadowing it
        m = np.matmul(ry,rz)
        # s = scale(3)
        # m = np.matmul(s, m)
        mat.append(m)
    return mat, views

def rot_mat_poisson(phi0, phi1, phi_unit, theta0, theta1, theta_unit, rng):
    """Poisson-disc sample view angles over [phi0, phi1] x [theta0, theta1].

    The minimum spacing is ``max(phi_unit, theta_unit) / 2``. An axis whose
    unit is 0 is collapsed to a strip of width 1 starting at -0.5. Returns a
    list of ``{'phi': ..., 'theta': ...}`` dicts.
    """
    if phi_unit == 0:
        h = 1
        phi0 = -0.5
    else:
        h = phi1 - phi0

    if theta_unit == 0:
        w = 1
        theta0 = -0.5
    else:
        w = theta1 - theta0

    r = max(phi_unit, theta_unit)/2

    p = PoissonDisc(rng, height=h, width=w, r=r)
    s = p.sample()

    views = []
    for s_ in s:
        phi_ = s_[1]+phi0
        theta_ = s_[0]+theta0
        views.append({'phi':phi_, 'theta':theta_})

    return views

def rot_mat_uniform(phi0, phi1, phi_unit, theta0, theta1, theta_unit):
    """Return view angles on a regular (phi, theta) grid, phi-major.

    Each axis runs from its start to its end (inclusive) with spacing
    ``unit``; a unit of 0 yields the single midpoint ``(start + end) / 2``.
    Returns a list of ``{'phi': ..., 'theta': ...}`` dicts.
    """
    def _axis(x0, x1, unit):
        if unit == 0:
            # midpoint of the range (previously (x1 - x0) / 2, which is the
            # half-width and only equals the midpoint when x0 == 0)
            return [(x0 + x1) / 2]
        # np.linspace requires an integer sample count; round to absorb
        # floating-point error in range/unit.
        n = int(round(np.abs(x1 - x0) / float(unit))) + 1
        return np.linspace(x0, x1, n, endpoint=True)

    phi = _axis(phi0, phi1, phi_unit)
    theta = _axis(theta0, theta1, theta_unit)

    views = []
    for phi_ in phi:
        for theta_ in theta:
            views.append({'phi':phi_, 'theta':theta_})

    return views


def g2p(g, p, is_2d=True, is_linear=False):
    """Sample grid ``g`` at particle positions ``p`` (linear or cubic)."""

    if is_linear:
        return g2p_linear(g, p, is_2d)
    else:
        return g2p_cubic(g, p, is_2d)

def g2p_cubic(g, p, is_2d=True):
    n_batch = 1 # tf.shape(g)[0]
    xlen = tf.shape(g)[1]
    ylen = tf.shape(g)[2]
    if is_2d:
        n_channel = tf.shape(g)[3]
    else:
        zlen = tf.shape(g)[3]
        n_channel = tf.shape(g)[4]
    pn = tf.shape(p)[1]

    x = tf.cast(p[0,...,0], tf.float32) # [0-1]
    y = tf.cast(p[0,...,1], tf.float32)
    if not is_2d:
        z = tf.cast(p[0,...,2], tf.float32)

    # scale to g
    xlen_f = tf.cast(xlen, tf.float32)
    ylen_f = tf.cast(ylen, tf.float32)
    x *= xlen_f
    y *= ylen_f
    if not is_2d:
        zlen_f = tf.cast(zlen, tf.float32)
        z *= zlen_f

    # do sampling
    zero = tf.zeros([], dtype='int32')
    max_x = tf.cast(xlen - 1, 'int32')
    max_y = tf.cast(ylen - 1, 'int32')
    if not is_2d:
        max_z = tf.cast(zlen - 1, 'int32')

    # shifted index to interpolate cell centers
    x1 = tf.cast(tf.floor(x - 0.5), 'int32')
    x0 = x1 - 1
    x2 = x1 + 1
    x3 = x1 + 2
    y1 = tf.cast(tf.floor(y - 0.5), 'int32')
    y0 = y1 - 1
    y2 = y1 + 1
    y3 = y1 + 2
    if not is_2d:
        z1 = tf.cast(tf.floor(z - 0.5), 'int32')
        z0 = z1 - 1
        z2 = z1 + 1
        z3 = z1 + 2

    x0 = tf.clip_by_value(x0, zero, max_x)
    x1 = tf.clip_by_value(x1, zero, max_x)
    x2 = tf.clip_by_value(x2, zero, max_x)
    x3 = tf.clip_by_value(x3, zero, max_x)
    y0 = tf.clip_by_value(y0, zero, max_y)
    y1 = tf.clip_by_value(y1, zero, max_y)
    y2 = tf.clip_by_value(y2, zero, max_y)
    y3 = tf.clip_by_value(y3, zero, max_y)
    if not is_2d:
        z0 = tf.clip_by_value(z0, zero, max_z)
        z1 = tf.clip_by_value(z1, zero, max_z)
        z2 = tf.clip_by_value(z2, zero, max_z)
        z3 = tf.clip_by_value(z3, zero, max_z)

    # compute flat indices
    if is_2d:
        base = _repeat(tf.range(n_batch)*xlen*ylen, pn)
        base_x0 = base + x0 * ylen
        base_x1 = base + x1 * ylen
        base_x2 = base + x2 * ylen
        base_x3 = base + x3 * ylen

        index00 = base_x0 + y0
        index01 = base_x0 + y1
        index02 = base_x0 + y2
        index03 = base_x0 + y3
        index10 = base_x1 + y0
        index11 = base_x1 + y1
        index12 = base_x1 + y2
        index13 = base_x1 + y3
        index20 = base_x2 + y0
        index21 = base_x2 + y1
        index22 = base_x2 + y2
        index23 = base_x2 + y3
        index30 = base_x3 + y0
        index31 = base_x3 + y1
        index32 = base_x3 + y2
        index33 = base_x3 + y3
    else:
        base = _repeat(tf.range(n_batch)*xlen*ylen*zlen, pn)
        base_x0 = base + x0 * ylen * zlen
        base_x1 = base + x1 * ylen * zlen
        base_x2 = base + x2 * ylen * zlen
        base_x3 = base + x3 * ylen * zlen

        base00 = base_x0 + y0 * zlen
        base01 = base_x0 + y1 * zlen
        base02 = base_x0 + y2 * zlen
        base03 = base_x0 + y3 * zlen
        base10 = base_x1 + y0 * zlen
        base11 = base_x1 + y1 * zlen
        base12 = base_x1 + y2 * zlen
        base13 = base_x1 + y3 * zlen
        base20 = base_x2 + y0 * zlen
        base21 = base_x2 + y1 * zlen
        base22 = base_x2 + y2 * zlen
        base23 = base_x2 + y3 * zlen
        base30 = base_x3 + y0 * zlen
        base31 = base_x3 + y1 * zlen
        base32 = base_x3 + y2 * zlen
        base33 = base_x3 + y3 * zlen

        index000 = base00 + z0
        index001 = base00 + z1
        index002 = base00 + z2
        index003 = base00 + z3

        index010 = base01 + z0
        index011 = base01 + z1
        index012 = base01 + z2
        index013 = base01 + z3

        index020 = base02 + z0
        index021 = base02 + z1
        index022 = base02 + z2
        index023 = base02 + z3

        index030 = base03 + z0
        index031 = base03 + z1
        index032 = base03 + z2
        index033 = base03 + z3

        index100 = base10 + z0
        index101 = base10 + z1
        index102 = base10 + z2
        index103 = base10 + z3
index110 = base11 + z0 913 | index111 = base11 + z1 914 | index112 = base11 + z2 915 | index113 = base11 + z3 916 | 917 | index120 = base12 + z0 918 | index121 = base12 + z1 919 | index122 = base12 + z2 920 | index123 = base12 + z3 921 | 922 | index130 = base13 + z0 923 | index131 = base13 + z1 924 | index132 = base13 + z2 925 | index133 = base13 + z3 926 | 927 | index200 = base20 + z0 928 | index201 = base20 + z1 929 | index202 = base20 + z2 930 | index203 = base20 + z3 931 | 932 | index210 = base21 + z0 933 | index211 = base21 + z1 934 | index212 = base21 + z2 935 | index213 = base21 + z3 936 | 937 | index220 = base22 + z0 938 | index221 = base22 + z1 939 | index222 = base22 + z2 940 | index223 = base22 + z3 941 | 942 | index230 = base23 + z0 943 | index231 = base23 + z1 944 | index232 = base23 + z2 945 | index233 = base23 + z3 946 | 947 | index300 = base30 + z0 948 | index301 = base30 + z1 949 | index302 = base30 + z2 950 | index303 = base30 + z3 951 | 952 | index310 = base31 + z0 953 | index311 = base31 + z1 954 | index312 = base31 + z2 955 | index313 = base31 + z3 956 | 957 | index320 = base32 + z0 958 | index321 = base32 + z1 959 | index322 = base32 + z2 960 | index323 = base32 + z3 961 | 962 | index330 = base33 + z0 963 | index331 = base33 + z1 964 | index332 = base33 + z2 965 | index333 = base33 + z3 966 | 967 | # use indices to lookup pixels in the flat image and restore 968 | # n_channel dim 969 | g_flat = tf.reshape(g, [-1, n_channel]) 970 | g_flat = tf.cast(g_flat, tf.float32) 971 | 972 | def _hermite(A, B, C, D, t): 973 | # https://github.com/iwyoo/bicubic_interp-tensorflow/blob/master/bicubic_interp.py 974 | a = A * (-0.5) + B * 1.5 + C * (-1.5) + D * 0.5 975 | b = A + B * (-2.5) + C * 2.0 + D * (-0.5) 976 | c = A * (-0.5) + C * 0.5 977 | d = B 978 | return a*t*t*t + b*t*t + c*t + d 979 | 980 | if is_2d: 981 | I00 = tf.gather(g_flat, index00) 982 | I01 = tf.gather(g_flat, index01) 983 | I02 = tf.gather(g_flat, index02) 984 | I03 = tf.gather(g_flat, 
index03) 985 | I10 = tf.gather(g_flat, index10) 986 | I11 = tf.gather(g_flat, index11) 987 | I12 = tf.gather(g_flat, index12) 988 | I13 = tf.gather(g_flat, index13) 989 | I20 = tf.gather(g_flat, index20) 990 | I21 = tf.gather(g_flat, index21) 991 | I22 = tf.gather(g_flat, index22) 992 | I23 = tf.gather(g_flat, index23) 993 | I30 = tf.gather(g_flat, index30) 994 | I31 = tf.gather(g_flat, index31) 995 | I32 = tf.gather(g_flat, index32) 996 | I33 = tf.gather(g_flat, index33) 997 | 998 | # and finally calculate interpolated values 999 | dx = x - (tf.cast(x1, tf.float32) + 0.5) 1000 | dx = tf.expand_dims(dx, axis=-1) 1001 | I0 = _hermite(I00, I10, I20, I30, dx) 1002 | I1 = _hermite(I01, I11, I21, I31, dx) 1003 | I2 = _hermite(I02, I12, I22, I32, dx) 1004 | I3 = _hermite(I03, I13, I23, I33, dx) 1005 | 1006 | dy = y - (tf.cast(y1, tf.float32) + 0.5) 1007 | dy = tf.expand_dims(dy, axis=-1) 1008 | output = _hermite(I0, I1, I2, I3, dy) 1009 | else: 1010 | I000 = tf.gather(g_flat, index000) 1011 | I001 = tf.gather(g_flat, index001) 1012 | I002 = tf.gather(g_flat, index002) 1013 | I003 = tf.gather(g_flat, index003) 1014 | I010 = tf.gather(g_flat, index010) 1015 | I011 = tf.gather(g_flat, index011) 1016 | I012 = tf.gather(g_flat, index012) 1017 | I013 = tf.gather(g_flat, index013) 1018 | I020 = tf.gather(g_flat, index020) 1019 | I021 = tf.gather(g_flat, index021) 1020 | I022 = tf.gather(g_flat, index022) 1021 | I023 = tf.gather(g_flat, index023) 1022 | I030 = tf.gather(g_flat, index030) 1023 | I031 = tf.gather(g_flat, index031) 1024 | I032 = tf.gather(g_flat, index032) 1025 | I033 = tf.gather(g_flat, index033) 1026 | I100 = tf.gather(g_flat, index100) 1027 | I101 = tf.gather(g_flat, index101) 1028 | I102 = tf.gather(g_flat, index102) 1029 | I103 = tf.gather(g_flat, index103) 1030 | I110 = tf.gather(g_flat, index110) 1031 | I111 = tf.gather(g_flat, index111) 1032 | I112 = tf.gather(g_flat, index112) 1033 | I113 = tf.gather(g_flat, index113) 1034 | I120 = tf.gather(g_flat, 
index120) 1035 | I121 = tf.gather(g_flat, index121) 1036 | I122 = tf.gather(g_flat, index122) 1037 | I123 = tf.gather(g_flat, index123) 1038 | I130 = tf.gather(g_flat, index130) 1039 | I131 = tf.gather(g_flat, index131) 1040 | I132 = tf.gather(g_flat, index132) 1041 | I133 = tf.gather(g_flat, index133) 1042 | I200 = tf.gather(g_flat, index200) 1043 | I201 = tf.gather(g_flat, index201) 1044 | I202 = tf.gather(g_flat, index202) 1045 | I203 = tf.gather(g_flat, index203) 1046 | I210 = tf.gather(g_flat, index210) 1047 | I211 = tf.gather(g_flat, index211) 1048 | I212 = tf.gather(g_flat, index212) 1049 | I213 = tf.gather(g_flat, index213) 1050 | I220 = tf.gather(g_flat, index220) 1051 | I221 = tf.gather(g_flat, index221) 1052 | I222 = tf.gather(g_flat, index222) 1053 | I223 = tf.gather(g_flat, index223) 1054 | I230 = tf.gather(g_flat, index230) 1055 | I231 = tf.gather(g_flat, index231) 1056 | I232 = tf.gather(g_flat, index232) 1057 | I233 = tf.gather(g_flat, index233) 1058 | I300 = tf.gather(g_flat, index300) 1059 | I301 = tf.gather(g_flat, index301) 1060 | I302 = tf.gather(g_flat, index302) 1061 | I303 = tf.gather(g_flat, index303) 1062 | I310 = tf.gather(g_flat, index310) 1063 | I311 = tf.gather(g_flat, index311) 1064 | I312 = tf.gather(g_flat, index312) 1065 | I313 = tf.gather(g_flat, index313) 1066 | I320 = tf.gather(g_flat, index320) 1067 | I321 = tf.gather(g_flat, index321) 1068 | I322 = tf.gather(g_flat, index322) 1069 | I323 = tf.gather(g_flat, index323) 1070 | I330 = tf.gather(g_flat, index330) 1071 | I331 = tf.gather(g_flat, index331) 1072 | I332 = tf.gather(g_flat, index332) 1073 | I333 = tf.gather(g_flat, index333) 1074 | 1075 | # and finally calculate interpolated values 1076 | dx = x - (tf.cast(x1, tf.float32) + 0.5) 1077 | dx = tf.expand_dims(dx, axis=-1) 1078 | I00 = _hermite(I000, I100, I200, I300, dx) 1079 | I01 = _hermite(I001, I101, I201, I301, dx) 1080 | I02 = _hermite(I002, I102, I202, I302, dx) 1081 | I03 = _hermite(I003, I103, I203, I303, dx) 1082 
| I10 = _hermite(I010, I110, I210, I310, dx) 1083 | I11 = _hermite(I011, I111, I211, I311, dx) 1084 | I12 = _hermite(I012, I112, I212, I312, dx) 1085 | I13 = _hermite(I013, I113, I213, I313, dx) 1086 | I20 = _hermite(I020, I120, I220, I320, dx) 1087 | I21 = _hermite(I021, I121, I221, I321, dx) 1088 | I22 = _hermite(I022, I122, I222, I322, dx) 1089 | I23 = _hermite(I023, I123, I223, I323, dx) 1090 | I30 = _hermite(I030, I130, I230, I330, dx) 1091 | I31 = _hermite(I031, I131, I231, I331, dx) 1092 | I32 = _hermite(I032, I132, I232, I332, dx) 1093 | I33 = _hermite(I033, I133, I233, I333, dx) 1094 | 1095 | dy = y - (tf.cast(y1, tf.float32) + 0.5) 1096 | dy = tf.expand_dims(dy, axis=-1) 1097 | I0 = _hermite(I00, I10, I20, I30, dy) 1098 | I1 = _hermite(I01, I11, I21, I31, dy) 1099 | I2 = _hermite(I02, I12, I22, I32, dy) 1100 | I3 = _hermite(I03, I13, I23, I33, dy) 1101 | 1102 | dz = z - (tf.cast(z1, tf.float32) + 0.5) 1103 | dz = tf.expand_dims(dz, axis=-1) 1104 | output = _hermite(I0, I1, I2, I3, dz) 1105 | 1106 | # reshape 1107 | output = tf.reshape(output, [n_batch, pn, n_channel]) 1108 | return output 1109 | 1110 | def g2p_linear(g, p, is_2d=True): 1111 | n_batch = 1 # tf.shape(g)[0] 1112 | xlen = tf.shape(g)[1] 1113 | ylen = tf.shape(g)[2] 1114 | if is_2d: 1115 | n_channel = tf.shape(g)[3] 1116 | else: 1117 | zlen = tf.shape(g)[3] 1118 | n_channel = tf.shape(g)[4] 1119 | pn = tf.shape(p)[1] 1120 | 1121 | x = tf.cast(p[0,...,0], tf.float32) # [0-1] 1122 | y = tf.cast(p[0,...,1], tf.float32) 1123 | if not is_2d: 1124 | z = tf.cast(p[0,...,2], tf.float32) 1125 | 1126 | # scale to g 1127 | xlen_f = tf.cast(xlen, tf.float32) 1128 | ylen_f = tf.cast(ylen, tf.float32) 1129 | x *= xlen_f 1130 | y *= ylen_f 1131 | if not is_2d: 1132 | zlen_f = tf.cast(zlen, tf.float32) 1133 | z *= zlen_f 1134 | 1135 | # do sampling 1136 | zero = tf.zeros([], dtype='int32') 1137 | max_x = tf.cast(xlen - 1, 'int32') 1138 | max_y = tf.cast(ylen - 1, 'int32') 1139 | if not is_2d: 1140 | max_z = 
tf.cast(zlen - 1, 'int32') 1141 | 1142 | # shifted index to interpolate cell centers 1143 | x0 = tf.cast(tf.floor(x - 0.5), 'int32') 1144 | x1 = x0 + 1 1145 | y0 = tf.cast(tf.floor(y - 0.5), 'int32') 1146 | y1 = y0 + 1 1147 | if not is_2d: 1148 | z0 = tf.cast(tf.floor(z - 0.5), 'int32') 1149 | z1 = z0 + 1 1150 | 1151 | x0 = tf.clip_by_value(x0, zero, max_x) 1152 | x1 = tf.clip_by_value(x1, zero, max_x) 1153 | y0 = tf.clip_by_value(y0, zero, max_y) 1154 | y1 = tf.clip_by_value(y1, zero, max_y) 1155 | if not is_2d: 1156 | z0 = tf.clip_by_value(z0, zero, max_z) 1157 | z1 = tf.clip_by_value(z1, zero, max_z) 1158 | 1159 | # compute flat indices 1160 | if is_2d: 1161 | base = _repeat(tf.range(n_batch)*xlen*ylen, pn) 1162 | base_x0 = base + x0 * ylen 1163 | base_x1 = base + x1 * ylen 1164 | index00 = base_x0 + y0 1165 | index01 = base_x0 + y1 1166 | index10 = base_x1 + y0 1167 | index11 = base_x1 + y1 1168 | else: 1169 | base = _repeat(tf.range(n_batch)*xlen*ylen*zlen, pn) 1170 | base_x0 = base + x0 * ylen * zlen 1171 | base_x1 = base + x1 * ylen * zlen 1172 | base00 = base_x0 + y0 * zlen 1173 | base01 = base_x0 + y1 * zlen 1174 | base10 = base_x1 + y0 * zlen 1175 | base11 = base_x1 + y1 * zlen 1176 | index000 = base00 + z0 1177 | index001 = base00 + z1 1178 | index010 = base01 + z0 1179 | index011 = base01 + z1 1180 | index100 = base10 + z0 1181 | index101 = base10 + z1 1182 | index110 = base11 + z0 1183 | index111 = base11 + z1 1184 | 1185 | # use indices to lookup pixels in the flat image and restore 1186 | # n_channel dim 1187 | g_flat = tf.reshape(g, [-1, n_channel]) 1188 | g_flat = tf.cast(g_flat, tf.float32) 1189 | 1190 | if is_2d: 1191 | I00 = tf.gather(g_flat, index00) 1192 | I01 = tf.gather(g_flat, index01) 1193 | I10 = tf.gather(g_flat, index10) 1194 | I11 = tf.gather(g_flat, index11) 1195 | 1196 | # and finally calculate interpolated values 1197 | dx = x - (tf.cast(x0, tf.float32) + 0.5) 1198 | dy = y - (tf.cast(y0, tf.float32) + 0.5) 1199 | w00 = 
tf.expand_dims((1. - dx) * (1. - dy), 1) 1200 | w01 = tf.expand_dims((1. - dx) * dy, 1) 1201 | w10 = tf.expand_dims(dx * (1. - dy), 1) 1202 | w11 = tf.expand_dims(dx * dy, 1) 1203 | output = tf.add_n([w00*I00, w01*I01, w10*I10, w11*I11]) 1204 | else: 1205 | I000 = tf.gather(g_flat, index000) 1206 | I001 = tf.gather(g_flat, index001) 1207 | I010 = tf.gather(g_flat, index010) 1208 | I011 = tf.gather(g_flat, index011) 1209 | I100 = tf.gather(g_flat, index100) 1210 | I101 = tf.gather(g_flat, index101) 1211 | I110 = tf.gather(g_flat, index110) 1212 | I111 = tf.gather(g_flat, index111) 1213 | 1214 | # and finally calculate interpolated values 1215 | dx = x - (tf.cast(x0, tf.float32) + 0.5) 1216 | dy = y - (tf.cast(y0, tf.float32) + 0.5) 1217 | dz = z - (tf.cast(z0, tf.float32) + 0.5) 1218 | w000 = tf.expand_dims((1. - dx) * (1. - dy) * (1. - dz), 1) 1219 | w001 = tf.expand_dims((1. - dx) * (1. - dy) * dz, 1) 1220 | w010 = tf.expand_dims((1. - dx) * dy * (1. - dz), 1) 1221 | w011 = tf.expand_dims((1. - dx) * dy * dz, 1) 1222 | w100 = tf.expand_dims(dx * (1. - dy) * (1. - dz), 1) 1223 | w101 = tf.expand_dims(dx * (1. - dy) * dz, 1) 1224 | w110 = tf.expand_dims(dx * dy * (1. 
- dz), 1) 1225 | w111 = tf.expand_dims(dx * dy * dz, 1) 1226 | output = tf.add_n([w000 * I000, w001 * I001, w010 * I010, w011 * I011, 1227 | w100 * I100, w101 * I101, w110 * I110, w111 * I111]) 1228 | 1229 | # reshape 1230 | output = tf.reshape(output, [n_batch, pn, n_channel]) 1231 | return output 1232 | 1233 | def W(k='cubic'): 1234 | def cubicspline(q, h, is_3d=False): 1235 | if is_3d: 1236 | sigma = 8/np.pi/(h**3) 1237 | else: 1238 | sigma = 40/7/np.pi/(h**2) 1239 | return tf.compat.v1.where(q > 1, 1240 | tf.zeros_like(q), 1241 | sigma * tf.where(q <= 0.5, 1242 | 6 * (q**3 - q**2) + 1, 1243 | 2 * (1-q)**3 1244 | ) 1245 | ) 1246 | 1247 | def linear(q, h=0, is_3d=False): 1248 | return tf.maximum(1 - q, 0) 1249 | 1250 | def smooth(q, h=0, is_3d=False): 1251 | return tf.maximum(1 - q**2, 0) 1252 | 1253 | def sharp(q, h=0, is_3d=False): 1254 | return tf.maximum((1/q)**2 - 1, 0) 1255 | 1256 | def poly6(q, h, is_3d=False): 1257 | if is_3d: 1258 | sigma = 315/64/np.pi/h**9 1259 | else: 1260 | sigma = 4/np.pi/h**8 1261 | return tf.maximum(sigma*(h**2 - q**2)**3, 0) 1262 | 1263 | if k == 'cubic': return cubicspline 1264 | elif k == 'linear': return linear 1265 | elif k == 'smooth': return smooth 1266 | elif k == 'sharp': return sharp 1267 | elif k == 'poly6': return poly6 1268 | 1269 | def GW(k='cubic'): 1270 | def cubicspline(r, h, is_3d=False): 1271 | if is_3d: 1272 | sigma = 48/np.pi/(h**3) 1273 | else: 1274 | sigma = 240/7/np.pi/(h**2) 1275 | rl = tf.sqrt(tf.reduce_sum(r**2, axis=-1)) 1276 | rl = tf.tile(tf.expand_dims(rl, axis=-1), [1,1,r.shape[-1]]) 1277 | q = rl / h 1278 | return tf.compat.v1.where(tf.logical_and(q <= 1, rl > 1e-6), 1279 | sigma / (rl * h) * r * tf.where(q <= 0.5, 1280 | q * (3*q - 2), 1281 | -(1-q)**2 1282 | ), 1283 | tf.zeros_like(r), 1284 | ) 1285 | 1286 | if k == 'cubic': return cubicspline 1287 | 1288 | # MIT License 1289 | 1290 | # Copyright (c) 2018 Eldar Insafutdinov 1291 | 1292 | # Permission is hereby granted, free of charge, to any 
person obtaining a copy 1293 | # of this software and associated documentation files (the "Software"), to deal 1294 | # in the Software without restriction, including without limitation the rights 1295 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 1296 | # copies of the Software, and to permit persons to whom the Software is 1297 | # furnished to do so, subject to the following conditions: 1298 | 1299 | # The above copyright notice and this permission notice shall be included in all 1300 | # copies or substantial portions of the Software. 1301 | 1302 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 1303 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 1304 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 1305 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 1306 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 1307 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 1308 | # SOFTWARE. 
1309 | # https://github.com/eldar/differentiable-point-clouds/dpc/util/point_cloud.py-pointcloud2voxel3d_fast 1310 | def p2g(p, domain, res, radius, rest_density, nsize, pc=None, pd=None, is_2d=True, kernel='cubic', eps=1e-6, clip=True, support=4): 1311 | # p.shape: [B,N,2 or 3] 1312 | batch_size = p.shape[0] 1313 | num_points = tf.shape(p)[1] 1314 | 1315 | # scale [0,1] -> [domain size] 1316 | domain_ = tf.cast(domain, tf.float32) 1317 | p = p * domain_ 1318 | 1319 | # clip for outliers (after advection) 1320 | if clip: 1321 | p = tf.clip_by_value(p, 0, domain_ - eps) 1322 | else: 1323 | valid = tf.logical_and(p >= 0, p < domain_) 1324 | valid = tf.reduce_all(valid, axis=-1) 1325 | valid = tf.reshape(valid, [-1]) 1326 | 1327 | # compute grid id 1328 | cell_size = domain_ / tf.cast(res, tf.float32) 1329 | # assert cell_size[0] == cell_size[1] 1330 | cell_size = cell_size[0] 1331 | indices_floor = tf.floor(p / cell_size) 1332 | indices_int = tf.cast(indices_floor, tf.int32) 1333 | batch_indices = tf.range(0, batch_size, 1) 1334 | batch_indices = tf.expand_dims(batch_indices, -1) 1335 | batch_indices = tf.tile(batch_indices, [1, num_points]) 1336 | batch_indices = tf.expand_dims(batch_indices, -1) 1337 | 1338 | gc = (indices_floor+0.5)*cell_size # grid cell center 1339 | r = p - gc # fractional part from grid cell center 1340 | rr = [] 1341 | for n in range(-nsize,nsize+1): 1342 | rr.append(r-n*cell_size) # [+cx,0,-cx] 1343 | # rr = [r] 1344 | 1345 | # for sph 1346 | W_ = W(kernel) 1347 | support_radius = radius*support # 2*particle spacing 1348 | if is_2d: 1349 | volume = 0.8 * (2*radius)**2 1350 | else: 1351 | volume = 0.8 * (2*radius)**3 1352 | mass = volume * rest_density 1353 | 1354 | if is_2d: 1355 | # [B,N,3], last three has its integer indices including batch id 1356 | indices = tf.concat([batch_indices, indices_int], axis=2) 1357 | indices = tf.reshape(indices, [-1, 3]) 1358 | 1359 | if not clip: 1360 | indices = tf.boolean_mask(indices, valid) 1361 | 1362 | 
def interpolate_scatter2d(pos): 1363 | dy,dx = rr[pos[0]][...,0], rr[pos[1]][...,1] 1364 | q = tf.sqrt(dx**2 + dy**2) / support_radius 1365 | 1366 | updates_raw = W_(q, support_radius, is_3d=False) 1367 | updates = mass*tf.reshape(updates_raw, [-1]) 1368 | if not clip: 1369 | updates = tf.boolean_mask(updates, valid) 1370 | 1371 | indices_loc = indices 1372 | indices_shift = tf.constant([[0] + [pos[0]-nsize, pos[1]-nsize]]) 1373 | # indices_shift = tf.constant([[0] + [pos[0], pos[1]]]) 1374 | num_updates = tf.shape(indices_loc)[0] 1375 | indices_shift = tf.tile(indices_shift, [num_updates, 1]) 1376 | indices_loc = indices_loc + indices_shift 1377 | 1378 | if pc is None: 1379 | img = tf.scatter_nd(indices_loc, updates, tf.concat(([batch_size],res), axis=-1)) #[batch_size]+res 1380 | img = tf.expand_dims(img, axis=-1) 1381 | else: 1382 | updates_ = mass*tf.expand_dims(updates_raw, axis=-1) * pc 1383 | if pd is None: 1384 | updates_ /= rest_density 1385 | else: 1386 | updates_ /= pd 1387 | updates = tf.reshape(updates_, [-1, tf.shape(pc)[-1]]) 1388 | if not clip: 1389 | updates = tf.boolean_mask(updates, valid) 1390 | img = tf.scatter_nd(indices_loc, updates, tf.concat(([batch_size],res,[tf.shape(pc)[-1]]), axis=-1)) 1391 | 1392 | return img 1393 | 1394 | img = [] 1395 | for j in range(2*nsize+1): 1396 | for i in range(2*nsize+1): 1397 | vx = interpolate_scatter2d([j, i]) 1398 | img.append(vx) 1399 | 1400 | # vx, vx_rgb = interpolate_scatter2d([0, 0]) 1401 | # img.append(vx) 1402 | # img_rgb.append(vx_rgb) 1403 | 1404 | img = tf.add_n(img)[:,::-1] # flip in y 1405 | return img 1406 | else: 1407 | # [B,N,4], last three has its integer indices including batch id 1408 | indices = tf.concat([batch_indices, indices_int], axis=2) 1409 | indices = tf.reshape(indices, [-1, 4]) 1410 | 1411 | if not clip: 1412 | indices = tf.boolean_mask(indices, valid) 1413 | 1414 | def interpolate_scatter3d(pos): 1415 | dz,dy,dx = rr[pos[0]][...,0], rr[pos[1]][...,1], rr[pos[2]][...,2] 1416 | 
q = tf.sqrt(dx**2 + dy**2 + dz**2) / support_radius 1417 | 1418 | updates_raw = W_(q, support_radius, is_3d=True) 1419 | updates = mass*tf.reshape(updates_raw, [-1]) 1420 | if not clip: 1421 | updates = tf.boolean_mask(updates, valid) 1422 | 1423 | indices_loc = indices 1424 | indices_shift = tf.constant([[0] + [pos[0]-nsize, pos[1]-nsize, pos[2]-nsize]]) 1425 | num_updates = tf.shape(indices_loc)[0] 1426 | indices_shift = tf.tile(indices_shift, [num_updates, 1]) 1427 | indices_loc = indices_loc + indices_shift 1428 | 1429 | if pc is None: 1430 | vox = tf.scatter_nd(indices_loc, updates, tf.concat(([batch_size],res), axis=-1)) 1431 | vox = tf.expand_dims(vox, axis=-1) 1432 | else: 1433 | updates_ = mass*tf.expand_dims(updates_raw, axis=-1) * pc 1434 | if pd is None: 1435 | updates_ /= rest_density 1436 | else: 1437 | updates_ /= pd 1438 | updates = tf.reshape(updates_, [-1, tf.shape(pc)[-1]]) 1439 | if not clip: 1440 | updates = tf.boolean_mask(updates, valid) 1441 | vox = tf.scatter_nd(indices_loc, updates, tf.concat(([batch_size],res,[tf.shape(pc)[-1]]), axis=-1)) 1442 | 1443 | return vox 1444 | 1445 | vox = [] 1446 | for k in range(2*nsize+1): 1447 | for j in range(2*nsize+1): 1448 | for i in range(2*nsize+1): 1449 | vx = interpolate_scatter3d([k,j,i]) 1450 | vox.append(vx) 1451 | 1452 | vox = tf.add_n(vox)[:,:,::-1] # flip in y 1453 | return vox 1454 | 1455 | def p2g_grad(p, domain, res, radius, rest_density, nsize, pc=None, pd=None, is_2d=True, kernel='cubic', eps=1e-6, clip=True): 1456 | # p.shape: [B,N,2 or 3] 1457 | batch_size = p.shape[0] 1458 | num_points = tf.shape(p)[1] 1459 | 1460 | # scale [0,1] -> [domain size] 1461 | domain_ = tf.cast(domain, tf.float32) 1462 | p = p * domain_ 1463 | 1464 | # clip for outliers (after advection) 1465 | if clip: 1466 | p = tf.clip_by_value(p, 0, domain_ - eps) 1467 | else: 1468 | valid = tf.logical_and(p >= 0, p < domain_) 1469 | valid = tf.reduce_all(valid, axis=-1) 1470 | valid = tf.reshape(valid, [-1]) 1471 | 1472 
| # compute grid id 1473 | cell_size = domain_ / tf.cast(res, tf.float32) 1474 | # assert cell_size[0] == cell_size[1] 1475 | cell_size = cell_size[0] 1476 | indices_floor = tf.floor(p / cell_size) 1477 | indices_int = tf.cast(indices_floor, tf.int32) 1478 | batch_indices = tf.range(0, batch_size, 1) 1479 | batch_indices = tf.expand_dims(batch_indices, -1) 1480 | batch_indices = tf.tile(batch_indices, [1, num_points]) 1481 | batch_indices = tf.expand_dims(batch_indices, -1) 1482 | 1483 | gc = (indices_floor+0.5)*cell_size # grid cell center 1484 | r = gc - p # fractional part from grid cell center 1485 | rr = [] 1486 | for n in range(-nsize,nsize+1): 1487 | rr.append(r+n*cell_size) # [+cx,0,-cx] 1488 | # rr = [r] 1489 | 1490 | # for sph 1491 | W_ = GW(kernel) 1492 | support_radius = radius*4 # 2*particle spacing 1493 | if is_2d: 1494 | volume = 0.8 * (2*radius)**2 1495 | else: 1496 | volume = 0.8 * (2*radius)**3 1497 | mass = volume * rest_density 1498 | 1499 | if is_2d: 1500 | # [B,N,3], last three has its integer indices including batch id 1501 | indices = tf.concat([batch_indices, indices_int], axis=2) 1502 | indices = tf.reshape(indices, [-1, 3]) 1503 | 1504 | if not clip: 1505 | indices = tf.boolean_mask(indices, valid) 1506 | 1507 | def interpolate_scatter2d(pos): 1508 | dy,dx = rr[pos[0]][...,0], rr[pos[1]][...,1] 1509 | updates_raw = W_(tf.stack([dy,dx], axis=-1), support_radius, is_3d=False) 1510 | 1511 | updates = mass*updates_raw 1512 | if pd is None: 1513 | updates /= rest_density 1514 | else: 1515 | updates /= pd 1516 | updates = tf.reshape(updates, [-1, 2]) 1517 | if not clip: 1518 | updates = tf.boolean_mask(updates, valid) 1519 | 1520 | indices_loc = indices 1521 | indices_shift = tf.constant([[0] + [pos[0]-nsize, pos[1]-nsize]]) 1522 | num_updates = tf.shape(indices_loc)[0] 1523 | indices_shift = tf.tile(indices_shift, [num_updates, 1]) 1524 | indices_loc = indices_loc + indices_shift 1525 | 1526 | n = tf.scatter_nd(indices_loc, updates, 
tf.concat(([batch_size],res,[2]), axis=-1)) 1527 | return n 1528 | 1529 | n = [] 1530 | for j in range(2*nsize+1): 1531 | for i in range(2*nsize+1): 1532 | n_ = interpolate_scatter2d([j, i]) 1533 | n.append(n_) 1534 | 1535 | n = -tf.add_n(n)[:,::-1] # flip in y 1536 | return n 1537 | else: 1538 | # [B,N,4], last three has its integer indices including batch id 1539 | indices = tf.concat([batch_indices, indices_int], axis=2) 1540 | indices = tf.reshape(indices, [-1, 4]) 1541 | 1542 | if not clip: 1543 | indices = tf.boolean_mask(indices, valid) 1544 | 1545 | def interpolate_scatter3d(pos): 1546 | dz,dy,dx = rr[pos[0]][...,0], rr[pos[1]][...,1], rr[pos[2]][...,2] 1547 | updates_raw = W_(tf.stack([dz,dy,dx], axis=-1), support_radius, is_3d=True) 1548 | 1549 | updates = mass*updates_raw 1550 | if pd is None: 1551 | updates /= rest_density 1552 | else: 1553 | updates /= pd 1554 | updates = tf.reshape(updates, [-1, 3]) 1555 | if not clip: 1556 | updates = tf.boolean_mask(updates, valid) 1557 | 1558 | indices_loc = indices 1559 | indices_shift = tf.constant([[0] + [pos[0]-nsize, pos[1]-nsize, pos[2]-nsize]]) 1560 | num_updates = tf.shape(indices_loc)[0] 1561 | indices_shift = tf.tile(indices_shift, [num_updates, 1]) 1562 | indices_loc = indices_loc + indices_shift 1563 | 1564 | n = tf.scatter_nd(indices_loc, updates, tf.concat(([batch_size],res,[3]), axis=-1)) 1565 | return n 1566 | 1567 | n = [] 1568 | for k in range(2*nsize+1): 1569 | for j in range(2*nsize+1): 1570 | for i in range(2*nsize+1): 1571 | n_ = interpolate_scatter3d([k,j,i]) 1572 | n.append(n_) 1573 | 1574 | n = -tf.add_n(n)[:,:,::-1] # flip in y 1575 | return n 1576 | 1577 | def p2g_wavg(p, x, domain, res, radius, nsize, is_2d=True, kernel='linear', eps=1e-6, clip=True, support=4): 1578 | # p.shape: [B,N,2 or 3] 1579 | batch_size = p.shape[0] 1580 | num_points = tf.shape(p)[1] 1581 | 1582 | # scale [0,1] -> [domain size] 1583 | domain_ = tf.cast(domain, tf.float32) 1584 | p = p * domain_ 1585 | 1586 | # 
clip for outliers (after advection) 1587 | if clip: 1588 | p = tf.clip_by_value(p, 0, domain_ - eps) 1589 | else: 1590 | valid = tf.logical_and(p >= 0, p < domain_) 1591 | valid = tf.reduce_all(valid, axis=-1) 1592 | valid = tf.reshape(valid, [-1]) 1593 | 1594 | # compute grid id 1595 | cell_size = domain_ / tf.cast(res, tf.float32) 1596 | # assert cell_size[0] == cell_size[1] 1597 | cell_size = cell_size[0] 1598 | indices_floor = tf.floor(p / cell_size) 1599 | indices_int = tf.cast(indices_floor, tf.int32) 1600 | batch_indices = tf.range(0, batch_size, 1) 1601 | batch_indices = tf.expand_dims(batch_indices, -1) 1602 | batch_indices = tf.tile(batch_indices, [1, num_points]) 1603 | batch_indices = tf.expand_dims(batch_indices, -1) 1604 | 1605 | gc = (indices_floor+0.5)*cell_size # grid cell center 1606 | r = p - gc # fractional part from grid cell center 1607 | rr = [] 1608 | for n in range(-nsize,nsize+1): 1609 | rr.append(r-n*cell_size) # [+cx,0,-cx] 1610 | # rr = [r] 1611 | 1612 | W_ = W(kernel) 1613 | support_radius = radius*support # 2*particle spacing 1614 | 1615 | if is_2d: 1616 | # [B,N,3], last three has its integer indices including batch id 1617 | indices = tf.concat([batch_indices, indices_int], axis=2) 1618 | indices = tf.reshape(indices, [-1, 3]) 1619 | if not clip: 1620 | indices = tf.boolean_mask(indices, valid) 1621 | 1622 | def interpolate_scatter2d(pos): 1623 | dy,dx = rr[pos[0]][...,0], rr[pos[1]][...,1] 1624 | q = tf.sqrt(dx**2 + dy**2) / support_radius 1625 | 1626 | updates_raw = W_(q, support_radius, is_3d=False) 1627 | updates = tf.reshape(updates_raw, [-1]) 1628 | if not clip: 1629 | updates = tf.boolean_mask(updates, valid) 1630 | 1631 | indices_loc = indices 1632 | indices_shift = tf.constant([[0] + [pos[0]-nsize, pos[1]-nsize]]) 1633 | # indices_shift = tf.constant([[0] + [pos[0], pos[1]]]) 1634 | num_updates = tf.shape(indices_loc)[0] 1635 | indices_shift = tf.tile(indices_shift, [num_updates, 1]) 1636 | indices_loc = indices_loc + 
indices_shift 1637 | 1638 | wmap = tf.scatter_nd(indices_loc, updates, tf.concat(([batch_size],res), axis=-1)) #[batch_size]+res 1639 | wmap = tf.expand_dims(wmap, axis=-1) 1640 | wmap = tf.tile(wmap, [1, 1, 1, tf.shape(x)[-1]]) 1641 | 1642 | updates_x = tf.expand_dims(updates_raw, axis=-1) * x 1643 | updates = tf.reshape(updates_x, [-1, tf.shape(x)[-1]]) 1644 | if not clip: 1645 | updates = tf.boolean_mask(updates, valid) 1646 | d_img = tf.scatter_nd(indices_loc, updates, tf.concat(([batch_size],res,[tf.shape(x)[-1]]), axis=-1)) 1647 | return wmap, d_img 1648 | 1649 | wmap, d_img = [], [] 1650 | for j in range(2*nsize+1): 1651 | for i in range(2*nsize+1): 1652 | w, d = interpolate_scatter2d([j, i]) 1653 | wmap.append(w) 1654 | d_img.append(d) 1655 | 1656 | wmap = tf.add_n(wmap)[:,::-1] # flip 1657 | d_img = tf.add_n(d_img)[:,::-1] # flip 1658 | d_img = tf.compat.v1.where(wmap > eps, d_img/wmap, d_img) 1659 | return d_img 1660 | else: 1661 | # [B,N,4], last three has its integer indices including batch id 1662 | indices = tf.concat([batch_indices, indices_int], axis=2) 1663 | indices = tf.reshape(indices, [-1, 4]) 1664 | if not clip: 1665 | indices = tf.boolean_mask(indices, valid) 1666 | 1667 | def interpolate_scatter3d(pos): 1668 | dz,dy,dx = rr[pos[0]][...,0], rr[pos[1]][...,1], rr[pos[2]][...,2] 1669 | q = tf.sqrt(dx**2 + dy**2 + dz**2) / support_radius 1670 | 1671 | updates_raw = W_(q, support_radius, is_3d=True) 1672 | updates = tf.reshape(updates_raw, [-1]) 1673 | if not clip: 1674 | updates = tf.boolean_mask(updates, valid) 1675 | 1676 | indices_loc = indices 1677 | indices_shift = tf.constant([[0] + [pos[0]-nsize, pos[1]-nsize, pos[2]-nsize]]) 1678 | num_updates = tf.shape(indices_loc)[0] 1679 | indices_shift = tf.tile(indices_shift, [num_updates, 1]) 1680 | indices_loc = indices_loc + indices_shift 1681 | 1682 | wmap = tf.scatter_nd(indices_loc, updates, tf.concat(([batch_size],res), axis=-1)) #[batch_size]+res 1683 | wmap = tf.expand_dims(wmap, axis=-1) 
1684 | wmap = tf.tile(wmap, [1, 1, 1, 1, tf.shape(x)[-1]]) 1685 | 1686 | updates_x = tf.expand_dims(updates_raw, axis=-1) * x 1687 | updates = tf.reshape(updates_x, [-1, tf.shape(x)[-1]]) 1688 | if not clip: 1689 | updates = tf.boolean_mask(updates, valid) 1690 | d_vox = tf.scatter_nd(indices_loc, updates, tf.concat(([batch_size],res,[tf.shape(x)[-1]]), axis=-1)) 1691 | return wmap, d_vox 1692 | 1693 | wmap, d_vox = [], [] 1694 | for k in range(2*nsize+1): 1695 | for j in range(2*nsize+1): 1696 | for i in range(2*nsize+1): 1697 | w, vx = interpolate_scatter3d([k,j,i]) 1698 | wmap.append(w) 1699 | d_vox.append(vx) 1700 | 1701 | wmap = tf.add_n(wmap)[:,:,::-1] # flip 1702 | d_vox = tf.add_n(d_vox)[:,:,::-1] # flip 1703 | d_vox = tf.compat.v1.where(wmap > eps, d_vox/wmap, d_vox) 1704 | return d_vox 1705 | 1706 | def p2g_repulsive(p, domain, res, radius, nsize, is_2d=True, kernel='smooth', eps=1e-6, alpha=50): 1707 | # p.shape: [B,N,2 or 3] 1708 | batch_size = p.shape[0] 1709 | num_points = tf.shape(p)[1] 1710 | 1711 | # scale [0,1] -> [domain size] 1712 | domain_ = tf.cast(domain, tf.float32) 1713 | p = p * domain_ 1714 | 1715 | # clip for outliers (after advection) 1716 | p = tf.clip_by_value(p, 0, domain_ - eps) 1717 | 1718 | # compute grid id 1719 | cell_size = domain_ / tf.cast(res, tf.float32) 1720 | # assert cell_size[0] == cell_size[1] 1721 | cell_size = cell_size[0] 1722 | indices_floor = tf.floor(p / cell_size) 1723 | indices_int = tf.cast(indices_floor, tf.int32) 1724 | batch_indices = tf.range(0, batch_size, 1) 1725 | batch_indices = tf.expand_dims(batch_indices, -1) 1726 | batch_indices = tf.tile(batch_indices, [1, num_points]) 1727 | batch_indices = tf.expand_dims(batch_indices, -1) 1728 | 1729 | gc = (indices_floor+0.5)*cell_size # grid cell center 1730 | r = gc - p # particle to grid center vector 1731 | rr = [] 1732 | for n in range(-nsize,nsize+1): 1733 | rr.append(r + n*cell_size) 1734 | # rr = [r] 1735 | 1736 | # repulsive force with artificial weak 
spring [Ando and Tsuruno 2011] 1737 | W_ = W(kernel) 1738 | support_radius = radius*2 # particle spacing 1739 | 1740 | if is_2d: 1741 | # [B,N,3], last three has its integer indices including batch id 1742 | indices = tf.concat([batch_indices, indices_int], axis=2) 1743 | indices = tf.reshape(indices, [-1, 3]) 1744 | 1745 | def interpolate_scatter2d(pos): 1746 | dy,dx = rr[pos[0]][...,0], rr[pos[1]][...,1] 1747 | q = tf.sqrt(dx**2 + dy**2) / support_radius 1748 | 1749 | disp = tf.stack([dy,dx], axis=-1) 1750 | updates_raw = W_(q) 1751 | updates_force = tf.expand_dims(updates_raw/q, axis=-1) * disp 1752 | updates = tf.reshape(updates_force, [-1,2]) 1753 | 1754 | indices_loc = indices 1755 | indices_shift = tf.constant([[0] + [pos[0]-nsize, pos[1]-nsize]]) 1756 | # indices_shift = tf.constant([[0] + [pos[0], pos[1]]]) 1757 | num_updates = tf.shape(indices_loc)[0] 1758 | indices_shift = tf.tile(indices_shift, [num_updates, 1]) 1759 | indices_loc = indices_loc + indices_shift 1760 | 1761 | img = tf.scatter_nd(indices_loc, updates, tf.concat(([batch_size],res,[2]), axis=-1)) 1762 | 1763 | return img 1764 | 1765 | img = [] 1766 | for j in range(2*nsize+1): 1767 | for i in range(2*nsize+1): 1768 | vx = interpolate_scatter2d([j, i]) 1769 | img.append(vx) 1770 | 1771 | # vx, vx_rgb = interpolate_scatter2d([0, 0]) 1772 | # img.append(vx) 1773 | # img_rgb.append(vx_rgb) 1774 | 1775 | img = tf.add_n(img)[:,::-1] # flip in y 1776 | img *= -alpha*support_radius 1777 | return img 1778 | else: 1779 | # [B,N,4], last three has its integer indices including batch id 1780 | indices = tf.concat([batch_indices, indices_int], axis=2) 1781 | indices = tf.reshape(indices, [-1, 4]) 1782 | 1783 | def interpolate_scatter3d(pos): 1784 | dz,dy,dx = rr[pos[0]][...,0], rr[pos[1]][...,1], rr[pos[2]][...,2] 1785 | q = tf.sqrt(dx**2 + dy**2 + dz**2) / support_radius 1786 | 1787 | updates_raw = W_(q, support_radius, is_3d=True) 1788 | updates = tf.reshape(updates_raw, [-1]) 1789 | 1790 | indices_loc 
= indices 1791 | indices_shift = tf.constant([[0] + [pos[0]-nsize, pos[1]-nsize, pos[2]-nsize]]) 1792 | num_updates = tf.shape(indices_loc)[0] 1793 | indices_shift = tf.tile(indices_shift, [num_updates, 1]) 1794 | indices_loc = indices_loc + indices_shift 1795 | 1796 | vox = tf.scatter_nd(indices_loc, updates, tf.concat(([batch_size],res), axis=-1)) 1797 | vox = tf.expand_dims(vox, axis=-1) 1798 | 1799 | return vox 1800 | 1801 | vox = [] 1802 | for k in range(2*nsize+1): 1803 | for j in range(2*nsize+1): 1804 | for i in range(2*nsize+1): 1805 | vx = interpolate_scatter3d([k,j,i]) 1806 | vox.append(vx) 1807 | 1808 | vox = tf.add_n(vox)[:,:,::-1] # flip in y 1809 | return vox 1810 | 1811 | def p2g_(p, res): 1812 | # p.shape: [B,N,2] 1813 | batch_size = p.shape[0] 1814 | num_points = tf.shape(p)[1] 1815 | 1816 | indices_floor = tf.floor(p) 1817 | indices_int = tf.cast(indices_floor, tf.int32) 1818 | batch_indices = tf.range(0, batch_size, 1) 1819 | batch_indices = tf.expand_dims(batch_indices, -1) 1820 | batch_indices = tf.tile(batch_indices, [1, num_points]) 1821 | batch_indices = tf.expand_dims(batch_indices, -1) 1822 | 1823 | indices = tf.concat([batch_indices, indices_int], axis=2) 1824 | indices = tf.reshape(indices, [-1, 4]) 1825 | 1826 | r = p - indices_floor # fractional part 1827 | # rr = [1.0 - r, r] 1828 | rr = [r, 1-r] 1829 | W_ = W('cubic') 1830 | 1831 | def interpolate_scatter3d(pos): 1832 | # updates_raw = rr[pos[0]][:, :, 0] * rr[pos[1]][:, :, 1] * rr[pos[2]][:, :, 2] 1833 | dx,dy,dz = rr[pos[0]][...,0], rr[pos[1]][...,1], rr[pos[2]][...,2] 1834 | updates_raw = W_(tf.sqrt(dx**2 + dy**2 + dz**2) / np.sqrt(3)) # normalized distance 1835 | updates = tf.reshape(updates_raw, [-1]) 1836 | 1837 | indices_loc = indices 1838 | indices_shift = tf.constant([[0] + pos]) 1839 | num_updates = tf.shape(indices_loc)[0] 1840 | indices_shift = tf.tile(indices_shift, [num_updates, 1]) 1841 | indices_loc = indices_loc + indices_shift 1842 | 1843 | voxels = 
tf.scatter_nd(indices_loc, updates, [batch_size]+res) 1844 | return voxels 1845 | 1846 | voxels = [] 1847 | for k in range(2): 1848 | for j in range(2): 1849 | for i in range(2): 1850 | vx = interpolate_scatter3d([k, j, i]) 1851 | voxels.append(vx) 1852 | 1853 | voxels = tf.expand_dims(tf.add_n(voxels), axis=-1) 1854 | voxels = voxels[:,:,::-1] # flip 1855 | voxels = tf.transpose(voxels, [0, 3, 2, 1, 4]) 1856 | voxels /= tf.reduce_max(voxels) # TODO: remove for multiple frames 1857 | return voxels 1858 | 1859 | if __name__ == '__main__': 1860 | """ 1861 | for test 1862 | 1863 | the result will be 1864 | 1865 | the original image 1866 | [[ 0. 1. 2. 3. 4.] 1867 | [ 5. 6. 7. 8. 9.] 1868 | [ 10. 11. 12. 13. 14.] 1869 | [ 15. 16. 17. 18. 19.] 1870 | [ 20. 21. 22. 23. 24.]] 1871 | 1872 | identity warped 1873 | [[ 0. 1. 2. 3. 4.] 1874 | [ 5. 6. 7. 8. 9.] 1875 | [ 10. 11. 12. 13. 14.] 1876 | [ 15. 16. 17. 18. 19.] 1877 | [ 20. 21. 22. 23. 24.]] 1878 | 1879 | zoom in warped 1880 | [[ 6. 6.5 7. 7.5 8. ] 1881 | [ 8.5 9. 9.5 10. 10.5] 1882 | [ 11. 11.5 12. 12.5 13. ] 1883 | [ 13.5 14. 14.5 15. 15.5] 1884 | [ 16. 16.5 17. 17.5 18. ]] 1885 | """ 1886 | img = tf.cast(np.arange(25).reshape(1, 5, 5, 1), tf.float32) 1887 | identity_matrix = tf.cast([1, 0, 0, 0, 1, 0], tf.float32) 1888 | zoom_in_matrix = identity_matrix * 0.5 1889 | identity_warped = batch_affine_warp2d(img, identity_matrix) 1890 | zoom_in_warped = batch_affine_warp2d(img, zoom_in_matrix) 1891 | with tf.Session() as sess: 1892 | print(sess.run(img[0, :, :, 0])) 1893 | 1894 | # # mgrid test 1895 | # print(sess.run(batch_mgrid(2, 5, 4))) 1896 | 1897 | print(sess.run(identity_warped[0, :, :, 0])) 1898 | print(sess.run(zoom_in_warped[0, :, :, 0])) --------------------------------------------------------------------------------