├── phi ├── __init__.py ├── tf │ ├── __init__.py │ ├── profiling.py │ ├── util.py │ └── flow.py ├── control │ ├── __init__.py │ ├── nets │ │ ├── __init__.py │ │ ├── force │ │ │ ├── __init__.py │ │ │ ├── forcenet2d_3x_16 │ │ │ │ ├── model.ckpt.index │ │ │ │ ├── model.ckpt.meta │ │ │ │ ├── model.ckpt.data-00000-of-00001 │ │ │ │ └── checkpoint │ │ │ └── forcenets.py │ │ └── project │ │ │ ├── __init__.py │ │ │ ├── projectnet2d_5x_8 │ │ │ ├── checkpoint │ │ │ ├── model.ckpt.index │ │ │ ├── model.ckpt.meta │ │ │ └── model.ckpt.data-00000-of-00001 │ │ │ └── projectnets.py │ ├── smoke_control.py │ ├── distances.py │ ├── sequences.py │ ├── control_scene.py │ └── voxelutil.py ├── local │ ├── __init__.py │ └── hostname.py ├── solver │ ├── __init__.py │ ├── cuda │ │ ├── __init__.py │ │ ├── benchmarks │ │ │ ├── __init__.py │ │ │ ├── benchmark_laplace.py │ │ │ ├── benchmark3d.py │ │ │ ├── benchmark2d.py │ │ │ └── floatingerror.py │ │ ├── build.sh │ │ ├── src │ │ │ ├── laplace_op.cc │ │ │ ├── laplace_op.cu.cc │ │ │ ├── pressure_solve_op.cc │ │ │ └── pressure_solve_op.cu.cc │ │ └── cuda.py │ ├── .DS_Store │ ├── conv.py │ ├── spcg.py │ ├── manta.py │ ├── explicit.py │ ├── net.py │ └── base.py ├── viz │ ├── __init__.py │ └── plot.py ├── .DS_Store ├── data │ ├── __init__.py │ ├── transform.py │ └── augment.py ├── math │ ├── __init__.py │ ├── scipy_backend.py │ └── base.py └── fluidformat.py ├── assets └── figure1.png ├── model ├── __init__.py └── text.py ├── utils_1d ├── utils.py ├── model_utils.py ├── result_io.py └── train_diffusion.py ├── scripts_1d ├── train_syn.sh ├── train_asyn.sh └── inf_asyn.sh ├── scripts_2d ├── train_asyn.sh ├── train_syn.sh ├── default_config.yaml └── inf_asyn.sh ├── __init__.py ├── LICENSE ├── inference └── evaluate_2d.py ├── dataset ├── data_1d.py └── data_2d.py ├── .gitignore ├── train ├── train_2d.py └── train_1d.py ├── README.md └── environment.yml /phi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /phi/tf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /phi/control/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /phi/local/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /phi/solver/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /phi/viz/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /phi/control/nets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /phi/solver/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /phi/control/nets/force/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /phi/solver/cuda/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /phi/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Science-WestlakeU/CL_DiffPhyCon/HEAD/phi/.DS_Store -------------------------------------------------------------------------------- /assets/figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Science-WestlakeU/CL_DiffPhyCon/HEAD/assets/figure1.png -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | # from video_diffusion_pytorch.video_diffusion_pytorch import Unet3D, GaussianDiffusion, Trainer -------------------------------------------------------------------------------- /phi/solver/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Science-WestlakeU/CL_DiffPhyCon/HEAD/phi/solver/.DS_Store -------------------------------------------------------------------------------- /phi/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from phi.data.data import * 4 | 5 | import phi.data.augment 6 | import phi.data.transform -------------------------------------------------------------------------------- /utils_1d/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def none_or_str(value): 3 | if value.lower() == 'none': 4 | return None 5 | return value -------------------------------------------------------------------------------- /phi/control/nets/project/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from phi.control.nets.project.projectnets import projectnet2d_5x_8 as projectnet -------------------------------------------------------------------------------- /phi/control/nets/project/projectnet2d_5x_8/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt" 2 | all_model_checkpoint_paths: "model.ckpt" 3 | -------------------------------------------------------------------------------- /phi/control/nets/force/forcenet2d_3x_16/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Science-WestlakeU/CL_DiffPhyCon/HEAD/phi/control/nets/force/forcenet2d_3x_16/model.ckpt.index -------------------------------------------------------------------------------- /phi/control/nets/force/forcenet2d_3x_16/model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Science-WestlakeU/CL_DiffPhyCon/HEAD/phi/control/nets/force/forcenet2d_3x_16/model.ckpt.meta -------------------------------------------------------------------------------- /phi/control/nets/project/projectnet2d_5x_8/model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Science-WestlakeU/CL_DiffPhyCon/HEAD/phi/control/nets/project/projectnet2d_5x_8/model.ckpt.index 
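The none_or_str helper in utils_1d/utils.py above is the kind of converter typically passed to argparse via type=, so a flag can accept either a string or the literal "none" (mapped to Python None). A minimal usage sketch, assuming utils_1d is importable; the parser and the --pad_mode flag below are illustrative, not taken from the repo:

import argparse
from utils_1d.utils import none_or_str

parser = argparse.ArgumentParser()
# "--pad_mode none" becomes None, any other string is passed through unchanged
parser.add_argument('--pad_mode', type=none_or_str, default=None)
args = parser.parse_args(['--pad_mode', 'none'])
assert args.pad_mode is None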
-------------------------------------------------------------------------------- /phi/control/nets/project/projectnet2d_5x_8/model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Science-WestlakeU/CL_DiffPhyCon/HEAD/phi/control/nets/project/projectnet2d_5x_8/model.ckpt.meta -------------------------------------------------------------------------------- /scripts_1d/train_syn.sh: -------------------------------------------------------------------------------- 1 | python ../train/train_1d.py \ 2 | --exp_id model_syn \ 3 | --is_condition_u0 True \ 4 | --is_condition_uT True \ 5 | --train_data_path /usr/train_data 6 | -------------------------------------------------------------------------------- /phi/control/nets/force/forcenet2d_3x_16/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Science-WestlakeU/CL_DiffPhyCon/HEAD/phi/control/nets/force/forcenet2d_3x_16/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /phi/control/nets/project/projectnet2d_5x_8/model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI4Science-WestlakeU/CL_DiffPhyCon/HEAD/phi/control/nets/project/projectnet2d_5x_8/model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /scripts_1d/train_asyn.sh: -------------------------------------------------------------------------------- 1 | python ../train/train_1d.py \ 2 | --exp_id model_asyn \ 3 | --is_condition_u0 True \ 4 | --is_condition_uT True \ 5 | --asynch_inference_mode \ 6 | --train_data_path /usr/train_data 7 | -------------------------------------------------------------------------------- /phi/control/nets/force/forcenet2d_3x_16/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt" 2 | all_model_checkpoint_paths: "../checkpoint_00010081/model.ckpt" 3 | all_model_checkpoint_paths: "../checkpoint_00015845/model.ckpt" 4 | all_model_checkpoint_paths: "model.ckpt" 5 | -------------------------------------------------------------------------------- /phi/local/hostname.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import socket 3 | 4 | hostname = socket.gethostname() 5 | print("Local hostname: %s" % hostname) 6 | 7 | if hostname == "cube4": 8 | hostname = "cube4.ge.in.tum.de" 9 | print("Recognized as cube4") -------------------------------------------------------------------------------- /scripts_2d/train_asyn.sh: -------------------------------------------------------------------------------- 1 | data_path="/data/cl_diffphycon/2d" 2 | accelerate launch --config_file default_config.yaml \ 3 | --main_process_port 29501 \ 4 | --gpu_ids 0,1 \ 5 | ../train/train_2d.py \ 6 | --results_path "${data_path}/checkpoints/asyn_models" \ 7 | --dataset_path ${data_path} -------------------------------------------------------------------------------- /scripts_2d/train_syn.sh: -------------------------------------------------------------------------------- 1 | data_path="/data/cl_diffphycon/2d" 2 | accelerate launch --config_file default_config.yaml \ 3 | --main_process_port 29500 \ 4 | --gpu_ids 0,1 \ 5 | ../train/train_2d.py \ 6 | 
--results_path "${data_path}/checkpoints/syn_models" \ 7 | --dataset_path ${data_path} \ 8 | --is_synch_model -------------------------------------------------------------------------------- /scripts_2d/default_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | # debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: fp16 8 | num_machines: 1 9 | num_processes: 2 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false -------------------------------------------------------------------------------- /phi/solver/conv.py: -------------------------------------------------------------------------------- 1 | from phi.math.nd import * 2 | import tensorflow as tf 3 | 4 | def conv_pressure(divergence): 5 | comps = np.meshgrid(*[range(-dim, dim+1) for dim in divergence.shape[1:-1]]) 6 | d = np.sqrt(np.sum([comp**2 for comp in comps], axis=0)) 7 | weights = - np.float32(1) / np.maximum(d, 0.5) # / (4*np.pi) 8 | weights = np.reshape(weights, list(d.shape)+[1, 1]) 9 | return tf.nn.conv2d(divergence, weights, [1, 1, 1, 1], "SAME") -------------------------------------------------------------------------------- /scripts_2d/inf_asyn.sh: -------------------------------------------------------------------------------- 1 | data_path="/data/cl_diffphycon/2d" 2 | python ../inference/inference_2d.py \ 3 | --dataset_path "${data_path}" \ 4 | --inference_result_path "${data_path}/inference_results/" \ 5 | --init_diffusion_model_path "${data_path}/checkpoints/unet_syn_bsize12_cond_d_v_s_diffsteps600_horizon15" \ 6 | --online_diffusion_model_path "${data_path}/checkpoints/unet_asyn_bsize12_cond_d_v_s_diffsteps600_horizon15/" \ 7 | --using_ddim True \ 8 | --asynch_inference_mode 9 | -------------------------------------------------------------------------------- /scripts_1d/inf_asyn.sh: -------------------------------------------------------------------------------- 1 | exp_id_ls=(0) 2 | 3 | for i in "${!exp_id_ls[@]}"; do 4 | exp_id=${exp_id_ls[i]} 5 | dim=${dim_ls[i]} 6 | CUDA_VISIBLE_DEVICES=0 python ../inference/inference_1d.py \ 7 | --exp_id_i model_syn \ 8 | --exp_id_f model_asyn \ 9 | --dataset '/usr/test_data' \ 10 | --test_target '/usr/test_data' \ 11 | --is_condition_u0 True \ 12 | --is_condition_uT True \ 13 | --save_file /usr/inference_savepath/test.yaml \ 14 | --infer_interval 1 \ 15 | --checkpoint 9 \ 16 | --eval_save /usr/inference_savepath \ 17 | --diffusion_model_path '/usr/checkpoints' \ 18 | --asynch_inference_mode 19 | done 20 | -------------------------------------------------------------------------------- /phi/data/transform.py: -------------------------------------------------------------------------------- 1 | from phi.data import * 2 | from phi.math.nd import * 3 | 4 | 5 | class Downsample(DerivedChannel): 6 | 7 | def __init__(self, field): 8 | DerivedChannel.__init__(self, [field]) 9 | self.field = self.input_fields[0] 10 | 11 | def size(self, datasource): 12 | return self.field.size(datasource) 13 | 14 | def shape(self, datasource): 15 | in_shape = self.field.shape(datasource) 16 | return [in_shape[0]] + [s // 2 for s in in_shape[1:-1]] + [in_shape[-1]] 17 | 18 | def get(self, datasource, indices): 19 | data = self.field.get(datasource, indices) 20 | for array in data: 21 | yield downsample2x(array) 
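The Downsample channel in phi/data/transform.py above halves every spatial dimension (its shape() keeps the batch and channel axes and integer-divides the rest by 2) and delegates the actual pooling to downsample2x from phi.math.nd, which is not shown in this listing. A rough stand-in with the same contract, assuming arrays of shape (batch, spatial..., channels) with even spatial extents; this is a sketch, not the repo's implementation:

import numpy as np

def downsample2x_sketch(array):
    # Average-pool each spatial axis by a factor of 2, keeping batch and channel axes.
    for axis in range(1, array.ndim - 1):
        shape = list(array.shape)
        shape[axis] //= 2
        shape.insert(axis + 1, 2)
        array = array.reshape(shape).mean(axis=axis + 1)
    return array

x = np.random.rand(1, 8, 8, 2)
print(downsample2x_sketch(x).shape)   # (1, 4, 4, 2)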
-------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # from pde_gen_control.denoising_diffusion_pytorch import GaussianDiffusion, Unet, Trainer 2 | 3 | # from pde_gen_control.learned_gaussian_diffusion import LearnedGaussianDiffusion 4 | # from pde_gen_control.continuous_time_gaussian_diffusion import ContinuousTimeGaussianDiffusion 5 | # from pde_gen_control.weighted_objective_gaussian_diffusion import WeightedObjectiveGaussianDiffusion 6 | # from pde_gen_control.elucidated_diffusion import ElucidatedDiffusion 7 | # from pde_gen_control.v_param_continuous_time_gaussian_diffusion import VParamContinuousTimeGaussianDiffusion 8 | 9 | # from pde_gen_control.denoising_diffusion_pytorch_1d import GaussianDiffusion1D, Unet1D, Trainer1D, Dataset1D 10 | -------------------------------------------------------------------------------- /phi/control/smoke_control.py: -------------------------------------------------------------------------------- 1 | from phi.control.sequences import * 2 | 3 | 4 | 5 | class SmokeState(Frame): 6 | 7 | def __init__(self, index, density, velocity, type=TYPE_KEYFRAME): 8 | Frame.__init__(self, index, type=type) 9 | assert density is not None and velocity is not None 10 | self.density = density 11 | self.velocity = velocity 12 | 13 | 14 | class UpdatableSmokeState(Frame): 15 | 16 | def __init__(self, ground_truth): 17 | Frame.__init__(self, ground_truth.index) 18 | self.ground_truth = ground_truth 19 | self.states = [] 20 | 21 | def update(self, new_state, type): 22 | assert type >= self.type 23 | self.states.append(new_state) 24 | self.type = type -------------------------------------------------------------------------------- /phi/solver/cuda/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | rm -r ./build/ 4 | mkdir ./build/ 5 | 6 | TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') ) 7 | TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') ) 8 | 9 | /usr/local/cuda/bin/nvcc -std=c++11 -c -o ./build/laplace_op.cu.o ./src/laplace_op.cu.cc ${TF_CFLAGS[@]} -x cu -Xcompiler -fPIC 10 | 11 | # just used for laplace benchmark, can be removed when benchmark is not required 12 | g++ -std=c++11 -shared -o ./build/laplace_op.so ./src/laplace_op.cc ./build/laplace_op.cu.o ${TF_CFLAGS[@]} -fPIC ${TF_LFLAGS[@]} 13 | 14 | /usr/local/cuda/bin/nvcc -lcublas -std=c++11 -c -o ./build/pressure_solve_op.cu.o ./src/pressure_solve_op.cu.cc ${TF_CFLAGS[@]} -x cu -Xcompiler -fPIC 15 | g++ -std=c++11 -shared -o ./build/pressure_solve_op.so ./src/pressure_solve_op.cc ./build/pressure_solve_op.cu.o ./build/laplace_op.cu.o ${TF_CFLAGS[@]} -fPIC ${TF_LFLAGS[@]} -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 AI4Science-WestlakeU 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, 
subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /inference/evaluate_2d.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pylab as plt 2 | import numpy as np 3 | import tqdm 4 | import sys, os 5 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..')) 6 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..')) 7 | from dataset.apps.evaluate_solver import * 8 | 9 | def evaluate(root, folder, n_sim=50): 10 | path = os.path.join(root, "inference_results", folder) 11 | smoke_out_dir = os.path.join(path, "smoke_outs") 12 | smoke_out_files = os.listdir(smoke_out_dir) 13 | n_sim = len(smoke_out_files) 14 | if n_sim == 0: 15 | return 16 | smoke_out_files.sort() 17 | all_smoke_out = [] 18 | for i in range(n_sim): 19 | smoke = np.load(os.path.join(smoke_out_dir, "{}.npy".format(i))) 20 | final_smoke_out = smoke[-1] 21 | all_smoke_out.append(final_smoke_out) 22 | avg_smoke_out = np.mean(all_smoke_out, axis=0) 23 | print(n_sim, folder, ", control objective J: ", 1-avg_smoke_out, "\n") 24 | 25 | if __name__ == '__main__': 26 | root = "/data/cl_diffphycon/2d" 27 | n_sim = 50 28 | all_folders = os.listdir(os.path.join(root, "inference_results")) 29 | all_folders.sort() 30 | for folder in all_folders: 31 | evaluate(root, folder, n_sim) -------------------------------------------------------------------------------- /phi/solver/cuda/benchmarks/benchmark_laplace.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib.ticker import FormatStrFormatter 3 | from phi.solver.cuda.benchmarks.benchmark_utils import * 4 | 5 | testruns = 25 6 | warmup = 5 7 | tests = [8, 16, 32, 64, 128]#, 256, 512, 1024, 2048] 8 | dimension = 3 9 | 10 | cudaResults = benchmark_laplace_matrix_cuda(tests, dimension, warmup, testruns) 11 | gc.collect() 12 | phiResults = benchmark_laplace_matrix_phi(tests, dimension, warmup, testruns) 13 | 14 | cudaAVG = [np.mean(a) for a in cudaResults] 15 | cudaSTD = [np.std(a) for a in cudaResults] 16 | phiAVG = [np.mean(a) for a in phiResults] 17 | phiSTD = [np.std(a) for a in phiResults] 18 | 19 | print("tests = " + str(tests)) 20 | print("cudaAVG = " + str(cudaAVG)) 21 | print("cudaSTD = " + str(cudaSTD)) 22 | print("phiAVG = " + str(phiAVG)) 23 | print("phiSTD = " + str(phiSTD)) 24 | 25 | 26 | 27 | plt.errorbar(tests, cudaAVG, cudaSTD, fmt='-o') 28 | plt.errorbar(tests, phiAVG, phiSTD, fmt='-o') 29 | 30 | 31 | plt.legend(['Cuda', 'PhiFlow'], loc='upper left') 32 | plt.xscale('log', basex=2) 33 | plt.yscale('log') 34 | plt.xticks(tests) 35 | ax = plt.gca() 36 | ax.xaxis.set_major_formatter(FormatStrFormatter('%.0f')) 37 | plt.xlabel("Grid Dimension") 38 | plt.ylabel("Computation Time in seconds") 39 | plt.show() 40 | 
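benchmark_laplace.py above (like benchmark3d.py and benchmark2d.py further down) reduces a list of per-resolution timing arrays to means and standard deviations before plotting them with errorbars on log-log axes. A small, repo-independent helper expressing that shared reduction step; the function name and sample numbers are illustrative only:

import numpy as np

def summarize(timings):
    # timings: list of 1-D arrays, one array of repeated measurements per tested grid size
    avg = [float(np.mean(t)) for t in timings]
    std = [float(np.std(t)) for t in timings]
    return avg, std

avg, std = summarize([np.array([0.10, 0.12, 0.11]), np.array([0.85, 0.90, 0.88])])
print(avg, std)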
-------------------------------------------------------------------------------- /phi/solver/spcg.py: -------------------------------------------------------------------------------- 1 | from phi.math.nd import * 2 | from phi.solver.base import PressureSolver, conjugate_gradient 3 | 4 | 5 | class SPCGPressureSolver(PressureSolver): 6 | 7 | def __init__(self): 8 | PressureSolver.__init__(self, "Single-Phase Conjugate Gradient") 9 | 10 | def solve(self, divergence, active_mask, fluid_mask, boundaries, accuracy, pressure_guess=None, 11 | max_iterations=500, return_loop_counter=False, gradient_accuracy=None): 12 | if fluid_mask is not None: 13 | fluid_mask = boundaries.pad_fluid(fluid_mask) 14 | # if active_mask is not None: 15 | # active_mask = boundaries.pad_active(active_mask) 16 | 17 | def presure_gradient(op, grad): 18 | return solve_pressure_forward(grad, fluid_mask, max_gradient_iterations, None, gradient_accuracy, boundaries)[0] 19 | 20 | pressure_with_gradient, iteration_count = math.with_custom_gradient(solve_pressure_forward, 21 | [divergence, fluid_mask, max_iterations, pressure_guess, accuracy, boundaries], 22 | presure_gradient, 23 | input_index=0, output_index=0, 24 | name_base="spcg_pressure_solve") 25 | 26 | max_gradient_iterations = max_iterations if gradient_accuracy is not None else iteration_count 27 | 28 | if return_loop_counter: 29 | return pressure_with_gradient, iteration_count 30 | else: 31 | return pressure_with_gradient 32 | 33 | 34 | def solve_pressure_forward(divergence, fluid_mask, max_iterations, guess, accuracy, boundaries): 35 | apply_A = lambda pressure: laplace(boundaries.pad_pressure(pressure), weights=fluid_mask, padding="valid") 36 | return conjugate_gradient(divergence, apply_A, guess, accuracy, max_iterations) 37 | -------------------------------------------------------------------------------- /phi/math/__init__.py: -------------------------------------------------------------------------------- 1 | from phi.math.base import DynamicBackend 2 | backend = DynamicBackend() 3 | 4 | from phi.math.scipy_backend import SciPyBackend 5 | backend.backends.append(SciPyBackend()) 6 | 7 | 8 | def load_tensorflow(): 9 | """ 10 | Internal function to register the TensorFlow backend. 11 | This function is called automatically once a TFSimulation is instantiated. 12 | :return: True if TensorFlow could be imported, else False 13 | """ 14 | try: 15 | import phi.math.tensorflow_backend as tfb 16 | for b in backend.backends: 17 | if isinstance(b, tfb.TFBackend): return True 18 | backend.backends.append(tfb.TFBackend()) 19 | return True 20 | except BaseException as e: 21 | import logging 22 | logging.fatal("Failed to load TensorFlow backend. Error: %s" % e) 23 | print("Failed to load TensorFlow backend. 
Error: %s" % e) 24 | return False 25 | 26 | 27 | abs = backend.abs 28 | add = backend.add 29 | boolean_mask = backend.boolean_mask 30 | ceil = backend.ceil 31 | floor = backend.floor 32 | concat = backend.concat 33 | conv = backend.conv 34 | dimrange = backend.dimrange 35 | dot = backend.dot 36 | exp = backend.exp 37 | expand_dims = backend.expand_dims 38 | flatten = backend.flatten 39 | gather = backend.gather 40 | isfinite = backend.isfinite 41 | matmul = backend.matmul 42 | max = backend.max 43 | maximum = backend.maximum 44 | mean = backend.mean 45 | minimum = backend.minimum 46 | name = backend.name 47 | ones_like = backend.ones_like 48 | pad = backend.pad 49 | py_func = backend.py_func 50 | resample = backend.resample 51 | reshape = backend.reshape 52 | shape = backend.shape 53 | sqrt = backend.sqrt 54 | stack = backend.stack 55 | std = backend.std 56 | sum = backend.sum 57 | tile = backend.tile 58 | to_float = backend.to_float 59 | unstack = backend.unstack 60 | while_loop = backend.while_loop 61 | with_custom_gradient = backend.with_custom_gradient 62 | zeros_like = backend.zeros_like -------------------------------------------------------------------------------- /phi/tf/profiling.py: -------------------------------------------------------------------------------- 1 | import json, threading, os, socket, urllib 2 | import tensorflow as tf 3 | from tensorflow.python.client import timeline 4 | from tensorboard import default, program 5 | 6 | 7 | class Timeliner: 8 | _timeline_dict = None 9 | options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 10 | run_metadata = tf.RunMetadata() 11 | 12 | 13 | def update_timeline(self, chrome_trace): 14 | # convert crome trace to python dict 15 | chrome_trace_dict = json.loads(chrome_trace) 16 | # for first run store full trace 17 | if self._timeline_dict is None: 18 | self._timeline_dict = chrome_trace_dict 19 | # for other - update only time consumption, not definitions 20 | else: 21 | for event in chrome_trace_dict['traceEvents']: 22 | # events time consumption started with 'ts' prefix 23 | if 'ts' in event: 24 | self._timeline_dict['traceEvents'].append(event) 25 | 26 | def save(self, f_name): 27 | os.path.isdir(os.path.dirname(f_name)) or os.makedirs(os.path.dirname(f_name)) 28 | with open(f_name, 'w') as f: 29 | json.dump(self._timeline_dict, f) 30 | 31 | def add_run(self, run_metadata=None): 32 | if run_metadata is None: 33 | run_metadata = self.run_metadata 34 | fetched_timeline = timeline.Timeline(run_metadata.step_stats) 35 | chrome_trace = fetched_timeline.generate_chrome_trace_format() 36 | self.update_timeline(chrome_trace) 37 | 38 | 39 | def launch_tensorboard(log_dir, same_process=False, port=6006): 40 | if same_process: 41 | from tensorboard import main as tb 42 | tf.flags.FLAGS.logdir = log_dir 43 | tf.flags.FLAGS.reload_interval = 1 44 | tf.flags.FLAGS.port = port 45 | threading.Thread(target=tb.main).start() 46 | else: 47 | def run_tb(): 48 | os.system('tensorboard --logdir=%s --port=%d' % (log_dir,port)) 49 | threading.Thread(target=run_tb).start() 50 | try: 51 | import phi.local.hostname 52 | host = phi.local.hostname.hostname 53 | except: 54 | host = socket.gethostname() 55 | url = "http://%s:%d/"%(host,port) 56 | return url 57 | -------------------------------------------------------------------------------- /model/text.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | 4 | def exists(val): 5 | return val is not None 6 | 7 | # 
singleton globals 8 | 9 | MODEL = None 10 | TOKENIZER = None 11 | BERT_MODEL_DIM = 768 12 | 13 | def get_tokenizer(): 14 | global TOKENIZER 15 | if not exists(TOKENIZER): 16 | TOKENIZER = torch.hub.load('huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased') 17 | return TOKENIZER 18 | 19 | def get_bert(): 20 | global MODEL 21 | if not exists(MODEL): 22 | MODEL = torch.hub.load('huggingface/pytorch-transformers', 'model', 'bert-base-cased') 23 | if torch.cuda.is_available(): 24 | MODEL = MODEL.cuda() 25 | 26 | return MODEL 27 | 28 | # tokenize 29 | 30 | def tokenize(texts, add_special_tokens = True): 31 | if not isinstance(texts, (list, tuple)): 32 | texts = [texts] 33 | 34 | tokenizer = get_tokenizer() 35 | 36 | encoding = tokenizer.batch_encode_plus( 37 | texts, 38 | add_special_tokens = add_special_tokens, 39 | padding = True, 40 | return_tensors = 'pt' 41 | ) 42 | 43 | token_ids = encoding.input_ids 44 | return token_ids 45 | 46 | # embedding function 47 | 48 | @torch.no_grad() 49 | def bert_embed( 50 | token_ids, 51 | return_cls_repr = False, 52 | eps = 1e-8, 53 | pad_id = 0. 54 | ): 55 | model = get_bert() 56 | mask = token_ids != pad_id 57 | 58 | if torch.cuda.is_available(): 59 | token_ids = token_ids.cuda() 60 | mask = mask.cuda() 61 | 62 | outputs = model( 63 | input_ids = token_ids, 64 | attention_mask = mask, 65 | output_hidden_states = True 66 | ) 67 | 68 | hidden_state = outputs.hidden_states[-1] 69 | 70 | if return_cls_repr: 71 | return hidden_state[:, 0] # return [cls] as representation 72 | 73 | if not exists(mask): 74 | return hidden_state.mean(dim = 1) 75 | 76 | mask = mask[:, 1:] # mean all tokens excluding [cls], accounting for length 77 | mask = rearrange(mask, 'b n -> b n 1') 78 | 79 | numer = (hidden_state[:, 1:] * mask).sum(dim = 1) 80 | denom = mask.sum(dim = 1) 81 | masked_mean = numer / (denom + eps) 82 | return masked_mean 83 | -------------------------------------------------------------------------------- /phi/solver/cuda/benchmarks/benchmark3d.py: -------------------------------------------------------------------------------- 1 | from matplotlib.ticker import FormatStrFormatter 2 | 3 | from phi.backend.base import load_tensorflow 4 | from phi.solver.cuda.cuda import CudaPressureSolver 5 | from phi.solver.sparse import SparseCGPressureSolver 6 | import matplotlib.pyplot as plt 7 | from phi.solver.cuda.benchmarks.benchmark_utils import * 8 | 9 | cudaSolver = CudaPressureSolver() 10 | sparseCGSolver = SparseCGPressureSolver() 11 | 12 | # configuration of the benchmark 13 | warmup = 5 14 | testruns = 25 15 | 16 | dimension = 3 17 | accuracy = 1e-5 18 | batch_size = 1 19 | 20 | cpuTests = []#[8, 16, 32, 64, 128] 21 | tfTests = []#[8, 16, 32, 64, 128] 22 | cudaTests = [8, 16, 32, 64, 128]#, 256] 23 | 24 | 25 | # benchmark 26 | load_tensorflow() 27 | cudaTimes = benchmark_pressure_solve(cudaSolver, cudaTests, dimension, tf.float32, warmup, testruns, accuracy, batch_size) 28 | tfTimes = benchmark_pressure_solve(sparseCGSolver, tfTests, dimension, tf.float64, warmup, testruns, accuracy, batch_size) 29 | cpuTimes = benchmark_pressure_solve(sparseCGSolver, cpuTests, dimension, tf.float64, warmup, testruns, accuracy, batch_size, cpu=True) 30 | 31 | cudaAVG = [np.mean(a) for a in cudaTimes] 32 | cudaSTD = [np.std(a) for a in cudaTimes] 33 | tfAVG = [np.mean(a) for a in tfTimes] 34 | tfSTD = [np.std(a) for a in tfTimes] 35 | cpuAVG = [np.mean(a) for a in cpuTimes] 36 | cpuSTD = [np.std(a) for a in cpuTimes] 37 | 38 | # serialize and print all data necessary 
for the graph 39 | print("cudaTests = " + str(cudaTests)) 40 | print("cudaAVG = " + str(cudaAVG)) 41 | print("cudaSTD = " + str(cudaSTD)) 42 | print("tfTests = " + str(tfTests)) 43 | print("tfAVG = " + str(tfAVG)) 44 | print("tfSTD = " + str(tfSTD)) 45 | print("cpuTests = " + str(cpuTests)) 46 | print("cpuAVG = " + str(cpuAVG)) 47 | print("cpuSTD = " + str(cpuSTD)) 48 | 49 | plt.errorbar(tfTests, tfAVG, tfSTD, fmt='-o') 50 | plt.errorbar(cpuTests, cpuAVG, cpuSTD, fmt='-o') 51 | plt.errorbar(cudaTests, cudaAVG, cudaSTD, fmt='-o') 52 | 53 | plt.legend(['Tensorflow GPU', 'Tensorflow CPU', 'CUDA'], loc='upper left') 54 | plt.xscale('log', basex=2) 55 | plt.yscale('log') 56 | plt.xticks(cudaTests) 57 | ax = plt.gca() 58 | ax.xaxis.set_major_formatter(FormatStrFormatter('%.0f')) 59 | plt.xlabel("Grid Dimension 3D") 60 | plt.ylabel("Computation Time in seconds") 61 | plt.show() 62 | -------------------------------------------------------------------------------- /phi/solver/cuda/benchmarks/benchmark2d.py: -------------------------------------------------------------------------------- 1 | from matplotlib.ticker import FormatStrFormatter 2 | 3 | from phi.backend.base import load_tensorflow 4 | from phi.solver.cuda.cuda import CudaPressureSolver 5 | from phi.solver.sparse import SparseCGPressureSolver 6 | import matplotlib.pyplot as plt 7 | from phi.solver.cuda.benchmarks.benchmark_utils import * 8 | 9 | cudaSolver = CudaPressureSolver() 10 | sparseCGSolver = SparseCGPressureSolver() 11 | 12 | # configuration of the benchmark 13 | warmup = 5 14 | testruns = 25 15 | 16 | dimension = 2 17 | accuracy = 1e-5 18 | batch_size = 1 19 | 20 | cpuTests = [] #[16, 32, 64]#, 128, 256, 512]#, 1024]#, 2048] 21 | tfTests = [] #[16, 32, 64, 128, 256, 512, 1024]#, 2048] 22 | cudaTests = [16, 32, 64, 128, 256, 512, 1024, 2048] 23 | 24 | # benchmark 25 | load_tensorflow() 26 | cudaTimes = benchmark_pressure_solve(cudaSolver, cudaTests, dimension, tf.float32, warmup, testruns, accuracy, batch_size) 27 | tfTimes = benchmark_pressure_solve(sparseCGSolver, tfTests, dimension, tf.float64, warmup, testruns, accuracy, batch_size) 28 | cpuTimes = benchmark_pressure_solve(sparseCGSolver, cpuTests, dimension, tf.float64, warmup, testruns, accuracy, batch_size, cpu=True) 29 | 30 | cudaAVG = [np.mean(a) for a in cudaTimes] 31 | cudaSTD = [np.std(a) for a in cudaTimes] 32 | tfAVG = [np.mean(a) for a in tfTimes] 33 | tfSTD = [np.std(a) for a in tfTimes] 34 | cpuAVG = [np.mean(a) for a in cpuTimes] 35 | cpuSTD = [np.std(a) for a in cpuTimes] 36 | 37 | # serialize and print all data necessary for the graph 38 | print("cudaTests = " + str(cudaTests)) 39 | print("cudaAVG = " + str(cudaAVG)) 40 | print("cudaSTD = " + str(cudaSTD)) 41 | print("tfTests = " + str(tfTests)) 42 | print("tfAVG = " + str(tfAVG)) 43 | print("tfSTD = " + str(tfSTD)) 44 | print("cpuTests = " + str(cpuTests)) 45 | print("cpuAVG = " + str(cpuAVG)) 46 | print("cpuSTD = " + str(cpuSTD)) 47 | 48 | plt.errorbar(tfTests, tfAVG, tfSTD, fmt='-o') 49 | plt.errorbar(cpuTests, cpuAVG, cpuSTD, fmt='-o') 50 | plt.errorbar(cudaTests, cudaAVG, cudaSTD, fmt='-o') 51 | 52 | plt.legend(['Tensorflow GPU', 'Tensorflow CPU', 'CUDA'], loc='upper left') 53 | plt.xscale('log', basex=2) 54 | plt.yscale('log') 55 | plt.xticks(cudaTests) 56 | ax = plt.gca() 57 | ax.xaxis.set_major_formatter(FormatStrFormatter('%.0f')) 58 | plt.xlabel("Grid Dimension 2D") 59 | plt.ylabel("Computation Time in seconds") 60 | plt.show() 61 | 
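The solvers benchmarked here (SparseCGPressureSolver, the CUDA kernel, and spcg.py earlier in this listing) are all conjugate-gradient pressure solves; the repo's own conjugate_gradient is imported from phi.solver.base in spcg.py but its body is not shown above. For orientation only, a textbook matrix-free CG iteration for A x = b looks roughly like the sketch below, which is not the repo's implementation:

import numpy as np

def conjugate_gradient_sketch(apply_A, b, x0=None, accuracy=1e-5, max_iterations=500):
    x = np.zeros_like(b) if x0 is None else x0.copy()
    r = b - apply_A(x)          # residual
    p = r.copy()                # search direction
    rs = r @ r
    for i in range(max_iterations):
        Ap = apply_A(p)
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < accuracy:
            return x, i + 1
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x, max_iterations

A = np.array([[4.0, 1.0], [1.0, 3.0]])
x, iters = conjugate_gradient_sketch(lambda v: A @ v, np.array([1.0, 2.0]))
print(x, iters)   # approximately [0.0909, 0.6364] after 2 iterations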
-------------------------------------------------------------------------------- /phi/control/nets/force/forcenets.py: -------------------------------------------------------------------------------- 1 | from phi.tf.flow import * 2 | from phi.tf.util import residual_block 3 | import inspect, os 4 | 5 | 6 | def forcenet2d_3x_16(initial_density, initial_velocity, target_velocity, training=False, trainable=True, reuse=tf.AUTO_REUSE): 7 | with tf.variable_scope("ForceNet"): 8 | y = tf.concat([initial_velocity.staggered[:, :-1, :-1, :], initial_density, target_velocity.staggered[:, :-1, :-1, :]], axis=-1) 9 | downres_steps = 3 10 | downres_padding = sum([2 ** i for i in range(downres_steps)]) 11 | y = tf.pad(y, [[0, 0], [0, downres_padding], [0, downres_padding], [0, 0]]) 12 | resolutions = [ y ] 13 | filter_count = 16 14 | res_block_count = 2 15 | for i in range(downres_steps): # 1/2, 1/4 16 | y = tf.layers.conv2d(resolutions[0], filter_count, 2, strides=2, activation=tf.nn.relu, padding="valid", name="downconv_%d"%i, trainable=trainable, reuse=reuse) 17 | for j, nb_channels in enumerate([filter_count] * res_block_count): 18 | y = residual_block(y, nb_channels, name="downrb_%d_%d" % (i,j), training=training, trainable=trainable, reuse=reuse) 19 | resolutions.insert(0, y) 20 | 21 | for j, nb_channels in enumerate([filter_count] * res_block_count): 22 | y = residual_block(y, nb_channels, name="centerrb_%d" % j, training=training, trainable=trainable, reuse=reuse) 23 | 24 | for i in range(1, len(resolutions)): 25 | y = upsample2x(y) 26 | res_in = resolutions[i][:, 0:y.shape[1], 0:y.shape[2], :] 27 | y = tf.concat([y, res_in], axis=-1) 28 | if i < len(resolutions)-1: 29 | y = tf.pad(y, [[0, 0], [0, 1], [0, 1], [0, 0]], mode="SYMMETRIC") 30 | y = tf.layers.conv2d(y, filter_count, 2, 1, activation=tf.nn.relu, padding="valid", name="upconv_%d" % i, trainable=trainable, reuse=reuse) 31 | for j, nb_channels in enumerate([filter_count] * res_block_count): 32 | y = residual_block(y, nb_channels, 2, name="uprb_%d_%d" % (i, j), training=training, trainable=trainable, reuse=reuse) 33 | else: 34 | # Last iteration 35 | y = tf.pad(y, [[0,0], [1,1], [1,1], [0,0]], mode="SYMMETRIC") 36 | y = tf.layers.conv2d(y, 2, 2, 1, activation=None, padding="valid", name="upconv_%d"%i, trainable=trainable, reuse=reuse) 37 | force = StaggeredGrid(y) 38 | path = os.path.join(os.path.dirname(inspect.getabsfile(forcenet2d_3x_16)), "forcenet2d_3x_16") 39 | return force, path -------------------------------------------------------------------------------- /phi/solver/cuda/benchmarks/floatingerror.py: -------------------------------------------------------------------------------- 1 | from matplotlib.ticker import FormatStrFormatter 2 | 3 | from phi.solver.cuda.cuda import CudaPressureSolver 4 | from phi.solver.sparse import SparseCGPressureSolver, load_tensorflow 5 | import matplotlib.pyplot as plt 6 | from phi.solver.cuda.benchmarks.benchmark_utils import * 7 | 8 | cudaSolver = CudaPressureSolver() 9 | numpySolver = SparseCGPressureSolver() 10 | load_tensorflow() 11 | 12 | testruns = 20 13 | accuracy = 1e-5 14 | tests2d = [16, 32, 64, 128, 256, 512, 1024] 15 | dimension = 2 16 | 17 | error2dAbs, error2dRel = benchmark_error(cudaSolver, numpySolver, tests2d, dimension, testruns, accuracy) 18 | 19 | error2dAbsAVG = [np.mean(a) for a in error2dAbs] 20 | error2dAbsSTD = [np.std(a) for a in error2dAbs] 21 | error2dRelAVG = [np.mean(a) for a in error2dRel] 22 | error2dRelSTD = [np.std(a) for a in error2dRel] 23 | 24 | print('tests2d = 
' + str(tests2d)) 25 | print('error2dAbsAVG = ' + str(error2dAbsAVG)) 26 | print('error2dAbsSTD = ' + str(error2dAbsSTD)) 27 | print('error2dRelAVG = ' + str(error2dRelAVG)) 28 | print('error2dRelSTD = ' + str(error2dRelSTD)) 29 | 30 | gc.collect() 31 | 32 | tests3d = [16, 32, 64, 128] 33 | dimension = 3 34 | 35 | error3dAbs, error3dRel = benchmark_error(cudaSolver, numpySolver, tests3d, dimension, testruns, accuracy) 36 | 37 | error3dAbsAVG = [np.mean(a) for a in error3dAbs] 38 | error3dAbsSTD = [np.std(a) for a in error3dAbs] 39 | error3dRelAVG = [np.mean(a) for a in error3dRel] 40 | error3dRelSTD = [np.std(a) for a in error3dRel] 41 | 42 | 43 | print('tests2d = ' + str(tests2d)) 44 | print('error2dAbsAVG = ' + str(error2dAbsAVG)) 45 | print('error2dAbsSTD = ' + str(error2dAbsSTD)) 46 | print('error2dRelAVG = ' + str(error2dRelAVG)) 47 | print('error2dRelSTD = ' + str(error2dRelSTD)) 48 | 49 | print('tests3d = ' + str(tests3d)) 50 | print('error3dAbsAVG = ' + str(error3dAbsAVG)) 51 | print('error3dAbsSTD = ' + str(error3dAbsSTD)) 52 | print('error3dRelAVG = ' + str(error3dRelAVG)) 53 | print('error3dRelSTD = ' + str(error3dRelSTD)) 54 | 55 | 56 | plt.errorbar(tests2d, error2dAbsAVG, error2dAbsSTD, fmt='-o') 57 | plt.errorbar(tests2d, error2dRelAVG, error2dRelSTD, fmt='-o') 58 | plt.errorbar(tests3d, error3dAbsAVG, error3dAbsSTD, fmt='-o') 59 | plt.errorbar(tests3d, error3dRelAVG, error3dRelSTD, fmt='-o') 60 | 61 | plt.legend(['Absolute error 2D', 'Relative error 2D', 'Absolute error 3D', 'Relative error 3D'], loc='bottom right') 62 | plt.xscale('log', basex=2) 63 | plt.yscale('log') 64 | plt.xticks(tests2d) 65 | ax = plt.gca() 66 | ax.xaxis.set_major_formatter(FormatStrFormatter('%.0f')) 67 | plt.xlabel("Grid Dimension") 68 | plt.ylabel("Error compared to CPU result") 69 | plt.show() 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /phi/control/nets/project/projectnets.py: -------------------------------------------------------------------------------- 1 | from phi.tf.flow import * 2 | from phi.tf.util import residual_block 3 | import inspect, os 4 | 5 | 6 | def projectnet2d_5x_8(velocity, training=False, trainable=True, reuse=None, scope="projectnet"): 7 | with tf.variable_scope(scope): 8 | y = velocity.staggered[:, :-1, :-1, :] 9 | downres_padding = sum([2**i for i in range(5)]) 10 | y = tf.pad(y, [[0,0], [0,downres_padding], [0,downres_padding], [0,0]]) 11 | resolutions = [y] 12 | for i, filters in enumerate([4, 8, 8, 8, 8]): 13 | y = tf.layers.conv2d(resolutions[0], filters, 2, strides=2, activation=tf.nn.relu, padding="valid", 14 | name="downconv_%d" % i, trainable=trainable, reuse=reuse) 15 | for j in range(2): 16 | y = residual_block(y, filters, name="downrb_%d_%d" % (i, j), training=training, trainable=trainable, 17 | reuse=reuse) 18 | resolutions.insert(0, y) 19 | 20 | for j, nb_channels in enumerate([8, 8]): 21 | y = residual_block(y, nb_channels, name="centerrb_%d" % j, training=training, trainable=trainable, reuse=reuse) 22 | 23 | for i, resolution_data in enumerate(resolutions[1:]): 24 | y = upsample2x(y) 25 | res_in = resolution_data[:, 0:y.shape[1], 0:y.shape[2], :] 26 | y = tf.concat([y, res_in], axis=-1) 27 | if i < len(resolutions) - 2: 28 | y = tf.pad(y, [[0, 0], [0, 1], [0, 1], [0, 0]], mode="SYMMETRIC") 29 | y = tf.layers.conv2d(y, 8, 2, 1, activation=tf.nn.relu, padding="valid", name="upconv_%d" % i, 30 | trainable=trainable, reuse=reuse) 31 | for j, nb_channels in enumerate([8, 
8]): 32 | y = residual_block(y, nb_channels, 2, name="uprb_%d_%d" % (i, j), training=training, 33 | trainable=trainable, reuse=reuse) 34 | else: 35 | # Last iteration 36 | boundary_feature = tf.ones_like(y[...,0:1]) 37 | boundary_feature = tf.pad(boundary_feature, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT") 38 | y = tf.pad(y, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="SYMMETRIC") 39 | y = math.concat([y, boundary_feature], axis=-1) 40 | y = tf.layers.conv2d(y, 1, 2, 1, activation=None, padding="valid", name="upconv_%d" % i, 41 | trainable=trainable, reuse=reuse) 42 | 43 | velocity = StaggeredGrid(y).curl() 44 | path = os.path.join(os.path.dirname(inspect.getabsfile(projectnet2d_5x_8)), "projectnet2d_5x_8") 45 | return velocity, path -------------------------------------------------------------------------------- /dataset/data_1d.py: -------------------------------------------------------------------------------- 1 | import scipy.io 2 | import numpy as np 3 | import h5py 4 | import pdb 5 | import pickle 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils.data import Dataset 9 | from typing import Tuple 10 | import sys, os 11 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..')) 12 | from IPython import embed 13 | 14 | def get_burgers_preprocess( 15 | rescaler=None, 16 | stack_u_and_f=False, 17 | pad_for_2d_conv=False, 18 | partially_observed_fill_zero_unobserved=None, 19 | ): 20 | if rescaler is None: 21 | raise NotImplementedError('Should specify rescaler. If no rescaler is not used, specify 1.') 22 | 23 | def preprocess(db): 24 | '''We are only returning f and u for now, in the shape of 25 | (u0, u1, ..., f0, f1, ...) 26 | ''' 27 | 28 | u = db['u'] 29 | f = db['f'] 30 | f = f[:,:15] 31 | 32 | fill_zero_unobserved = partially_observed_fill_zero_unobserved 33 | if fill_zero_unobserved is not None: 34 | if fill_zero_unobserved == 'front_rear_quarter': 35 | u = u.squeeze() 36 | nx = u.shape[-1] 37 | u[..., nx // 4: (nx * 3) // 4] = 0 38 | else: 39 | raise ValueError('Unknown partially observed mode') 40 | 41 | if stack_u_and_f: 42 | assert pad_for_2d_conv 43 | nt = f.size(-2) 44 | f = nn.functional.pad(f, (0, 0, 0, 16 - nt), 'constant', 0) 45 | u = nn.functional.pad(u, (0, 0, 0, 15 - nt), 'constant', 0) 46 | u_target = u 47 | data = torch.stack((u, f, u_target), dim=1) 48 | else: 49 | assert not pad_for_2d_conv 50 | data = torch.cat((u, f), dim=1) 51 | 52 | data = data / rescaler 53 | return data 54 | 55 | return preprocess 56 | 57 | 58 | 59 | class DiffusionDataset(Dataset): 60 | def __init__( 61 | self, 62 | fname, 63 | preprocess=get_burgers_preprocess('all'), 64 | load_all=True 65 | ): 66 | ''' 67 | Arguments: 68 | 69 | ''' 70 | self.load_all = load_all 71 | if load_all: 72 | self.db = torch.load(fname) 73 | self.x = preprocess(self.db) 74 | else: 75 | raise NotImplementedError 76 | 77 | def __len__(self): 78 | if self.load_all: 79 | return self.x.size(0) 80 | else: 81 | raise NotImplementedError 82 | 83 | def __getitem__(self, idx): 84 | if self.load_all: 85 | return self.x[idx] 86 | else: 87 | raise NotImplementedError 88 | 89 | def get(self, idx): 90 | return self.__getitem__(idx) 91 | 92 | def len(self): 93 | return self.__len__() -------------------------------------------------------------------------------- /phi/solver/cuda/src/laplace_op.cc: -------------------------------------------------------------------------------- 1 | //nsight -vm /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java 2 | 3 | 4 | #include "tensorflow/core/framework/op.h" 5 | 
#include "tensorflow/core/framework/op_kernel.h" 6 | #include "tensorflow/core/framework/shape_inference.h" 7 | 8 | using namespace tensorflow; // NOLINT(build/namespaces) 9 | 10 | REGISTER_OP("LaplaceMatrix") 11 | .Input("dimensions: int32") 12 | .Input("mask_dimensions: int32") 13 | .Input("active_mask: float32") 14 | .Input("fluid_mask: float32") 15 | .Attr("dim_product: int") 16 | .Input("input_laplace_matrix: int8") 17 | .Output("laplace_matrix: int8"); 18 | //.SetIsStateful(); 19 | 20 | void LaplaceMatrixKernelLauncher(const int *dimensions, const int dim_size, const int dimProduct, const float *active_mask, const float *fluid_mask, const int *maskDimensions, signed char *laplaceMatrix, int *cords); 21 | 22 | class LaplaceMatrixOp : public OpKernel 23 | { 24 | private: 25 | int dim_product; 26 | 27 | public: 28 | explicit LaplaceMatrixOp(OpKernelConstruction *context) : OpKernel(context) { 29 | context->GetAttr("dim_product", &dim_product); 30 | } 31 | 32 | void Compute(OpKernelContext *context) override 33 | { 34 | // This Op is only required for benchmarking the Laplace Matrix generation speed. 35 | // The pressure solve Op calls the LaplaceMatrixKernelLauncher before solving the pressure 36 | const Tensor &input_dimensions = context->input(0); 37 | const Tensor &input_mask_dimensions = context->input(1); 38 | const Tensor &input_active_mask = context->input(2); 39 | const Tensor &input_fluid_mask = context->input(3); 40 | Tensor input_laplace_matrix = context->input(4); 41 | 42 | auto dimensions = input_dimensions.flat(); 43 | auto mask_dimensions = input_mask_dimensions.flat(); 44 | auto active_mask = input_active_mask.flat(); 45 | auto fluid_mask = input_fluid_mask.flat(); 46 | auto laplace_matrix = input_laplace_matrix.flat(); 47 | 48 | int dim_size = dimensions.size(); 49 | 50 | Tensor cords; 51 | TensorShape cords_shape; 52 | cords_shape.AddDim(dim_product); 53 | cords_shape.AddDim(input_dimensions.dims()); 54 | 55 | context->set_output(0, input_laplace_matrix); 56 | 57 | OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32, cords_shape, &cords)); 58 | 59 | auto cords_flat = cords.flat(); 60 | 61 | LaplaceMatrixKernelLauncher(dimensions.data(), dim_size, dim_product, active_mask.data(), fluid_mask.data(), mask_dimensions.data(), laplace_matrix.data(), cords_flat.data()); 62 | } 63 | }; 64 | 65 | REGISTER_KERNEL_BUILDER(Name("LaplaceMatrix").Device(DEVICE_GPU), LaplaceMatrixOp); 66 | -------------------------------------------------------------------------------- /phi/solver/manta.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import mantatensor.mantatensor_bindings as mt 4 | import mantatensor.mantatensor_gradients # registers gradients 5 | from phi.solver.base import PressureSolver 6 | 7 | 8 | class MantaSolver(PressureSolver): 9 | 10 | def __init__(self): 11 | super(MantaSolver, self).__init__("Manta") 12 | 13 | def solve(self, divergence, accuracy=1e-05): 14 | return mt_solve_pressure(divergence, self.fluid_mask, accuracy) 15 | 16 | def set_fluid_mask(self, fluid_mask): 17 | self.fluid_mask = fluid_mask 18 | 19 | 20 | 21 | def mt_solve_pressure(divergence, fluid_mask, accuracy): 22 | dimensions = list(divergence.shape[1:-1]) 23 | 24 | neg_div = - divergence 25 | batches = neg_div.shape[0] 26 | try: 27 | batches = int(batches) 28 | except: 29 | raise ValueError("Manta solver requires fixed batch size") 30 | 31 | if len(dimensions) == 3: 32 | velocity_mac = 
tf.zeros([batches] + dimensions + [3]) 33 | scalar_shape_mt = neg_div.shape 34 | else: 35 | scalar_shape_mt = [batches, 1] + dimensions + [1] 36 | velocity_mac = tf.zeros([batches, 1] + dimensions + [3]) 37 | neg_div = tf.reshape(neg_div, scalar_shape_mt) 38 | 39 | flags_tensor = tf.constant(flags_array(fluid_mask, dimensions), name='flag_grid') 40 | flags_tensor = tf.tile(flags_tensor, [batches, 1, 1, 1, 1]) 41 | 42 | 43 | pressure_var = tf.Variable(np.zeros(scalar_shape_mt, dtype=np.float32), name='pressure') 44 | pressure_out = mt.solve_pressure_system(neg_div, velocity_mac, pressure_var, flags_tensor, 1, batches, 45 | cgAccuracy=accuracy) 46 | 47 | return to_tensorflow_scalar(pressure_out, dimensions) 48 | 49 | 50 | def flags_array(fluid_mask, dimensions): 51 | flags = (2 - fluid_mask).astype(np.int32) 52 | if len(dimensions) == 3: 53 | return flags.reshape((1,) + flags.shape) 54 | elif len(dimensions) == 2: 55 | return flags.reshape((1, 1) + flags.shape) 56 | else: 57 | raise ValueError("Only 2 and 3 dimensions supported") 58 | 59 | 60 | def to_tensorflow_scalar(field, dimensions): 61 | if len(field.shape) != len(dimensions) + 2: 62 | field = field[:, 0, :, :, :] 63 | return field 64 | 65 | 66 | def to_tensorflow_vector(field, dimensions): 67 | if len(field.shape) != len(dimensions) + 2: 68 | field = field[:, 0, :, :, :] 69 | if field.shape[-1] != len(dimensions): 70 | field = field[..., 0:len(dimensions)] 71 | return field 72 | 73 | 74 | # def _to_mantaflow_3vec(field): 75 | # if isinstance(field, StaggeredGrid): 76 | # field = field.staggered 77 | # if field.shape[-1] == 2: 78 | # backend.pad(field, [[0,0]]*(len(field.shape)-1) + [[0,1]]) 79 | # if len(field.shape) == 4: 80 | # field = backend.reshape(field, [1] + list(field.shape)) 81 | # 82 | # if len(field.shape) != 5 or field.shape[-1] != 3: 83 | # raise ValueError("Cannot convert field of shape {} to mantaflow".format(field.shape)) 84 | # return field -------------------------------------------------------------------------------- /phi/control/distances.py: -------------------------------------------------------------------------------- 1 | from phi.math.nd import * 2 | 3 | 4 | def _initial_dijkstra_tensor(target_mask): 5 | finite_mask = target_mask != 0 6 | array = np.ones_like(target_mask, np.float32) * np.inf 7 | array[finite_mask] = 0 8 | return array 9 | 10 | 11 | def _dijkstra_step(tensor): # (batch, spatial_dims, 1) 12 | rank = spatial_rank(tensor) 13 | dims = range(rank) 14 | all_dims = range(len(tensor.shape)) 15 | center_min = tensor 16 | for dimension in dims: 17 | padded = np.pad(tensor, [[1,1] if i-1 == dimension else [0,0] for i in all_dims], "constant", constant_values=np.inf) 18 | upper_slices = [(slice(2, None) if i-1 == dimension else slice(None)) for i in all_dims] 19 | lower_slices = [(slice(-2) if i-1 == dimension else slice(None)) for i in all_dims] 20 | center_min_dim = np.minimum(padded[lower_slices]+1, padded[upper_slices]+1) 21 | center_min = np.minimum(center_min, center_min_dim) 22 | return center_min 23 | 24 | 25 | def l1_distance_map(target_mask, fluid_mask=None, non_fluid_value=-1): 26 | """ 27 | Calculates the shortest distance from all grid points to the nearest entry in target_mask. 28 | Neighbouring cells are separated by distance 1. All resulting distances are integers. 29 | :param target_mask: Mask encoding points of distance 0. 
Shape (batch, spatial_dims..., 1) 30 | :param fluid_mask: Mask encoding the domain topology (same shape as target_mask) 31 | :param non_fluid_value: This value will be used in the returned tensor for non-fluid cells (fluid_mask==0). 32 | :return: A tensor of same shape as target_mask containing the shortest distances 33 | """ 34 | tensor = _initial_dijkstra_tensor(target_mask) 35 | obstacle_cell_count = 0 if fluid_mask is None else np.sum(fluid_mask == 0) 36 | while True: 37 | prev_tensor = tensor 38 | tensor = _dijkstra_step(tensor) 39 | if fluid_mask is not None: 40 | tensor[fluid_mask == 0] = np.inf 41 | 42 | if len(tensor[~np.isfinite(tensor)]) == obstacle_cell_count: 43 | tensor[~np.isfinite(tensor)] = non_fluid_value 44 | return tensor 45 | if np.array_equal(prev_tensor, tensor): 46 | raise ValueError("Unconnected regions detected, failed to create distance map") 47 | 48 | 49 | def shortest_halfway_point(distribution1, distribution2, fluid_mask=None): 50 | distances1 = l1_distance_map(distribution1, fluid_mask, non_fluid_value=np.inf) 51 | distances2 = l1_distance_map(distribution2, fluid_mask, non_fluid_value=np.inf) 52 | with np.errstate(invalid="ignore"): 53 | equal_points = np.abs(distances1-distances2) <= 1 54 | # Find lowest indices 55 | min_dist = np.min(distances1[equal_points]) 56 | shortest_equal_points = equal_points & (distances1 == min_dist) 57 | indices = np.argwhere(shortest_equal_points) 58 | mean_index = np.mean(indices, 0) 59 | return np.round(mean_index).astype(np.int) 60 | 61 | 62 | # fluid_mask = np.ones([1,8,8,1]) 63 | # fluid_mask[:, 4, 1:7, :] = 0 64 | # p1 = np.zeros([1,8,8,1]) 65 | # p1[:,0, 4, :] = 1 66 | # p2 = np.zeros([1,8,8,1]) 67 | # p2[:, 7, 4, :] = 1 68 | # print(shortest_halfway_point(p1, p2, fluid_mask)) -------------------------------------------------------------------------------- /phi/data/augment.py: -------------------------------------------------------------------------------- 1 | from phi.data import * 2 | 3 | 4 | class AugmentChannel(DerivedChannel): 5 | 6 | def __init__(self, aug_dimensions, field, affect_flags=(DATAFLAG_TRAIN,)): 7 | DerivedChannel.__init__(self, [field]) 8 | self.field = self.input_fields[0] 9 | self.affect_flags = affect_flags 10 | self.aug_dimensions = range(aug_dimensions) if isinstance(aug_dimensions, int) else aug_dimensions 11 | self.augmentation_factor = 2 ** len(self.aug_dimensions) 12 | 13 | def affects(self, datasource): 14 | if not self.affect_flags: 15 | return True 16 | for flag in self.affect_flags: 17 | if flag in datasource.flags: 18 | return True 19 | return False 20 | 21 | def size(self, datasource): 22 | return self.field.size(datasource) * (self.augmentation_factor if self.affects(datasource) else 1) 23 | 24 | def shape(self, datasource): 25 | return self.field.shape(datasource) 26 | 27 | def get(self, datasource, indices): 28 | if not self.affects(datasource): 29 | return self.field.get(datasource, indices) 30 | else: 31 | return self._augment_interleave(datasource, indices) 32 | 33 | def _augment_interleave(self, datasource, indices): 34 | arrays = self.field.get(datasource, [i // self.augmentation_factor for i in indices]) 35 | for array, index in zip(arrays, indices): 36 | perm = index % self.augmentation_factor 37 | if perm != 0: 38 | array = self.augment_single(array, perm) 39 | yield array 40 | 41 | def augment_single(self, array, perm): 42 | raise NotImplementedError() 43 | 44 | 45 | class AxisFlip(AugmentChannel): 46 | 47 | def __init__(self, flip_dimensions, field, flip_vectors=True, 
affect_flags=(DATAFLAG_TRAIN,)): 48 | AugmentChannel.__init__(self, flip_dimensions, field, affect_flags) 49 | self.flip_vectors = flip_vectors 50 | 51 | def augment_single(self, array, perm): 52 | slices = [slice(None, None, -1) if d >= 1 and perm & 2 ** (d - 1) != 0 else slice(None) for d in range(len(array.shape))] 53 | array = array[slices] 54 | if self.flip_vectors and array.shape[-1] == len(array.shape) - 2: 55 | flipped_components = [len(array.shape) - d - 3 for d in self.aug_dimensions if perm & 2 ** (d) != 0] 56 | array[..., flipped_components] *= -1 57 | return array 58 | 59 | 60 | # class SpatialShift(AugmentChannel): 61 | # 62 | # def __init__(self, shift_dimensions, shift, field, padding="symmetric", affect_flags=(DATAFLAG_TRAIN,)): 63 | # AugmentChannel.__init__(self, shift_dimensions, field, affect_flags) 64 | # self.shift = shift 65 | # self.padding = padding 66 | # 67 | # def shape(self, datasource): 68 | # input_shape = self.field.shape(datasource) 69 | # if self.padding is not None: 70 | # return input_shape 71 | # else: 72 | # input_shape = np.ndarray(input_shape) 73 | # input_shape[1:-1] -= 2 74 | # return input_shape 75 | # 76 | # def augment_single(self, array, perm): 77 | # slices = [slice(None, None, -1) if d >= 1 and perm & 2 ** (d - 1) != 0 else slice(None) for d in range(len(array.shape))] 78 | # 79 | # def velocity_adjustment(self, field): -------------------------------------------------------------------------------- /phi/solver/cuda/cuda.py: -------------------------------------------------------------------------------- 1 | from phi.solver.base import ExplicitBoundaryPressureSolver 2 | import tensorflow as tf 3 | from phi import math 4 | import numpy as np 5 | 6 | 7 | class CudaPressureSolver(ExplicitBoundaryPressureSolver): 8 | 9 | def __init__(self): 10 | super(CudaPressureSolver, self).__init__("CUDA Conjugate Gradient") 11 | import os 12 | current_dir = os.path.dirname(os.path.realpath(__file__)) 13 | self.pressure_op = tf.load_op_library(current_dir + "/build/pressure_solve_op.so") 14 | 15 | def solve_with_boundaries(self, divergence, active_mask, fluid_mask, accuracy=1e-5, pressure_guess=None, # pressure_guess is not used in this implementation => Kernel automatically takes the last pressure value for initial_guess 16 | max_iterations=2000, gradient_accuracy=None, return_loop_counter=False): 17 | 18 | def pressure_gradient(op, grad): 19 | return self.cuda_solve_forward(grad, active_mask, fluid_mask, accuracy, max_iterations)[0] 20 | 21 | pressure_out, iterations = math.with_custom_gradient(self.cuda_solve_forward, 22 | [divergence, active_mask, fluid_mask, accuracy, max_iterations], 23 | pressure_gradient, input_index=0, output_index=0, name_base="cuda_pressure_solve") 24 | 25 | if return_loop_counter: 26 | return pressure_out, iterations 27 | else: 28 | return pressure_out 29 | 30 | def cuda_solve_forward(self, divergence, active_mask, fluid_mask, accuracy, max_iterations): 31 | dimensions = divergence.get_shape()[1:-1] 32 | dimensions = dimensions[::-1] # the custom op needs it in the x,y,z order 33 | dim_array = np.array(dimensions) 34 | dim_product = np.prod(dimensions) 35 | 36 | mask_dimensions = dim_array + 2 37 | 38 | laplace_matrix = tf.zeros(dim_product * (len(dimensions) * 2 + 1), dtype=tf.int8) 39 | 40 | # Helper variables for CG, make sure new memory is allocated for each variable. 
41 | one_vector = tf.ones(dim_product, dtype=tf.float32) 42 | p = tf.zeros_like(divergence, dtype=tf.float32) + 1 43 | z = tf.zeros_like(divergence, dtype=tf.float32) + 2 44 | r = tf.zeros_like(divergence, dtype=tf.float32) + 3 45 | pressure = tf.zeros_like(divergence, dtype=tf.float32) + 4 46 | 47 | # Call the custom kernel 48 | pressure_out, iterations = self.pressure_op.pressure_solve(dimensions, 49 | 50 | mask_dimensions, 51 | active_mask, 52 | fluid_mask, 53 | laplace_matrix, 54 | 55 | divergence, 56 | p, r, z, pressure, one_vector, 57 | 58 | dim_product, 59 | accuracy, 60 | max_iterations) 61 | return pressure_out, iterations -------------------------------------------------------------------------------- /phi/solver/explicit.py: -------------------------------------------------------------------------------- 1 | from phi.experimental import * 2 | 3 | 4 | 5 | def explicit_dipole_pressure(div, num=1): 6 | # [filter_height, filter_width, in_channels, out_channels] 7 | filter = np.zeros([3, 3, 3, 3], np.float32) 8 | # pressure (q) 9 | filter[(0, 1, 1, 2), (1, 0, 2, 1), 0, 0] = 1 # edges q 10 | filter[(0, 1, 1, 2), (1, 0, 2, 1), (2, 1, 1, 2), 0] = (+0.0986, +0.0986, -0.0986, -0.0986) # edges px, py 11 | filter[(0, 0, 2, 2), (0, 2, 0, 2), 0, 0] = 0.7071 # corners q 12 | filter[(0, 0, 2, 2), (0, 2, 0, 2), 1, 0] = (0.03288, -0.03288, -0.03288, 0.03288) # corners px 13 | filter[(0, 0, 2, 2), (0, 2, 0, 2), 2, 0] = (0.03288, 0.03288, -0.03288, -0.03288) # corners py 14 | filter[1, 1, 0, 0] = 1.4142 # self-pressure 15 | # pressure gradient 16 | filter[(0, 1, 1, 2), (1, 0, 2, 1), 0, (2, 1, 1, 2)] = (-0.5, -0.5, +0.5, +0.5) # edges q 17 | filter[(0, 1, 1, 2), (1, 0, 2, 1), (2, 1, 1, 2), (2, 1, 1, 2)] = (-0.2347, -0.2347, 0.2347, 0.2347) # edges px, py longitudinal 18 | filter[(0, 1, 1, 2), (1, 0, 2, 1), (1, 2, 2, 1), (1, 2, 2, 1)] = (0.2347/4, 0.2347/4, 0.2347/4, 0.2347/4) # edges px, py transversal 19 | filter[(0, 0, 2, 2), (0, 2, 0, 2), 0, 1] = (-0.3536, +0.3536, -0.3536, +0.3536) # corners q -> px 20 | filter[(0, 0, 2, 2), (0, 2, 0, 2), 0, 2] = (-0.3536, -0.3536, +0.3536, +0.3536) # corners q -> py 21 | filter[1, 1, (1, 2), (1, 2)] = 1./num # self-pressure 22 | # corners px,py -> px,py is comparably small 23 | return tf.nn.conv2d(div, filter, strides=[1, 1, 1, 1], padding="SAME") 24 | 25 | 26 | def explicit_pressure_multigrid(divergence, level_control=False): 27 | rank = spatial_rank(divergence) 28 | dV = 2**rank 29 | size = int(max(divergence.shape[1:])) 30 | 31 | multires_div = [to_dipole_format(divergence)] # order from low-res to high-res 32 | for i in range(math.frexp(float(size))[1] - 2): # downsample until 2x2 33 | p = downsample_dipole_2d_2x(multires_div[0]) 34 | # p = tf.layers.average_pooling2d(multires_div[0], pool_size=[2, 2], strides=2) * dV 35 | multires_div.insert(0, p) 36 | 37 | p_div = None # Divergence of pressure 38 | pressure = None 39 | 40 | pressure_accum = [] 41 | pressure_by_lvl = [] 42 | p_div_by_level = [] 43 | level_scalings = [] 44 | i = 0 45 | 46 | for div_lvl in multires_div: # start with low-res 47 | div = div_lvl 48 | 49 | if p_div is not None: # Upsample previous level and subtract div p 50 | div -= to_dipole_format(p_div) 51 | 52 | pressure_lvl = explicit_dipole_pressure(div, num=len(multires_div)) 53 | pressure_lvl = upsample_flatten_dipole_2d_2x(pressure_lvl) 54 | delta_p_div = laplace(pressure_lvl) 55 | 56 | pressure_by_lvl.append(pressure_lvl) 57 | p_div_by_level.append(delta_p_div) 58 | 59 | if isinstance(level_control, collections.Iterable): 60 | 
level_scaling = level_control[i] 61 | elif level_control is True: 62 | level_scaling = tf.placeholder(tf.float32, shape=[1, 1, 1, 1], name="lvl_scale_%d"%i) 63 | elif level_control is False: 64 | level_scaling = 1 65 | else: 66 | raise ValueError("illegal level_control: {}".format(level_control)) 67 | level_scalings.append(level_scaling) 68 | 69 | if p_div is None: 70 | pressure = pressure_lvl * level_scaling 71 | p_div = delta_p_div * level_scaling 72 | else: 73 | pressure = upsample2x(pressure) + pressure_lvl * level_scaling 74 | p_div = upsample2x(p_div) / dV + delta_p_div * level_scaling 75 | 76 | pressure_accum.append(pressure) 77 | i += 1 78 | 79 | pressure = tf.layers.average_pooling2d(pressure, [2, 2], [2, 2]) 80 | p_div = tf.layers.average_pooling2d(p_div, [2, 2], [2, 2]) 81 | return pressure, (p_div, pressure_accum, level_scalings) 82 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /phi/control/sequences.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | class PartitioningExecutor(object): 5 | 6 | def create_frame(self, index, step_count): 7 | return Frame(index, type=TYPE_KEYFRAME if index == 0 or index == step_count else TYPE_UNKNOWN) 8 | 9 | def execute_step(self, initial_frame, target_frame): 10 | # type: (Frame, Frame) -> None 11 | print("Execute -> %d" % (initial_frame.index + 1)) 12 | assert initial_frame.type >= TYPE_REAL 13 | target_frame.type = max(TYPE_REAL, target_frame.type) 14 | 15 | def partition(self, n, initial_frame, target_frame, center_frame): 16 | # type: (int, Frame, Frame, Frame) -> None 17 | print("Partition length %d sequence (from %d to %d) at frame %d" % (n, initial_frame.index, target_frame.index, center_frame.index)) 18 | assert initial_frame.type != TYPE_UNKNOWN and target_frame.type != TYPE_UNKNOWN 19 | center_frame.type = TYPE_PLANNED 20 | 21 | 22 | TYPE_UNKNOWN = 0 23 | TYPE_PLANNED = 1 24 | TYPE_REAL = 2 25 | TYPE_KEYFRAME = 3 26 | 27 | class Frame(object): 28 | 29 | def __init__(self, index, type=TYPE_UNKNOWN): 30 | self.index = index 31 | self.type = type 32 | 33 | def next(self): 34 | return self.index + 1 35 | 36 | def __repr__(self): 37 | return "Frame#%d" % self.index 38 | 39 | 40 | class PartitionedSequence(object): 41 | 42 | def __init__(self, step_count, operator): 43 | # type: (int, PartitioningExecutor) -> None 44 | self.step_count = step_count 45 | self.operator = operator 46 | self._frames = [operator.create_frame(i, step_count) for i in range(step_count+1)] 47 | 48 | def execute(self): 49 | self.partition_execute(self.step_count, 0) 50 | 51 | def partition_execute(self, n, start_frame_index, **kwargs): 52 | if n == 1: 53 | self.leaf_execute(self._frames[start_frame_index], self._frames[start_frame_index+1], **kwargs) 54 | else: 55 | self.branch_execute(n, start_frame_index, **kwargs) 56 | 57 | def leaf_execute(self, start_frame, end_frame, **kwargs): 58 | self.operator.execute_step(start_frame, end_frame) 59 | 60 | def 
branch_execute(self, n, start_frame_index, **kwargs): 61 | raise NotImplementedError() 62 | 63 | def partition(self, n, start_frame_index): 64 | self.operator.partition(n, self._frames[start_frame_index], self._frames[start_frame_index + n], 65 | self._frames[start_frame_index + n // 2]) 66 | 67 | def __getitem__(self, item): 68 | return self._frames[item] 69 | 70 | def __len__(self): 71 | return len(self._frames) 72 | 73 | def __iter__(self): 74 | return self._frames.__iter__() 75 | 76 | 77 | 78 | 79 | class TreeSequence(PartitionedSequence): 80 | 81 | def __init__(self, step_count, operator): 82 | PartitionedSequence.__init__(self, step_count, operator) 83 | 84 | def branch_execute(self, n, start_frame_index, **kwargs): 85 | self.partition(n, start_frame_index) 86 | self.partition_execute(n//2, start_frame_index) 87 | self.partition_execute(n//2, start_frame_index+n//2) 88 | 89 | 90 | class AdaptivePlanSequence(PartitionedSequence): 91 | 92 | def __init__(self, step_count, operator): 93 | PartitionedSequence.__init__(self, step_count, operator) 94 | 95 | def branch_execute(self, n, start_frame_index, update_target=False, **kwargs): 96 | self.partition(n, start_frame_index) 97 | self.partition_execute(n // 2, start_frame_index, update_target=True) 98 | if update_target: 99 | self.partition(n, start_frame_index + n) 100 | self.partition(n, start_frame_index + n // 2) 101 | self.partition_execute(n // 2, start_frame_index + n // 2, update_target=update_target) 102 | 103 | 104 | # AdaptivePlanSequence(8, PartitioningExecutor()).execute() -------------------------------------------------------------------------------- /phi/solver/net.py: -------------------------------------------------------------------------------- 1 | import collections, math, os.path, inspect 2 | from phi.solver.base import PressureSolver 3 | from phi.experimental import * 4 | 5 | 6 | class NetworkSolver(PressureSolver): 7 | 8 | def __init__(self): 9 | super(NetworkSolver, self).__init__("Net") 10 | 11 | def solve(self, divergence, active_mask, fluid_mask, boundaries, accuracy, pressure_guess=None, **kwargs): 12 | base_path = os.path.dirname(inspect.getfile(inspect.currentframe())) 13 | ckpt = os.path.join(base_path, "data/pnet_/modelsave.ckpt") 14 | return solve_pressure_tompson2(divergence, level_control=False, constants_file=ckpt, **kwargs)[0] 15 | 16 | 17 | 18 | 19 | def tompson2_pressure(div, constants_file=None): 20 | conv = conv_function("Tompson2", constants_file=constants_file) 21 | n = div 22 | n = conv(n, filters=8, kernel_size=[3, 3], padding="same", activation=tf.nn.relu, name="conv1") 23 | n = conv(n, filters=16, kernel_size=[3, 3], padding="same", activation=tf.nn.relu, name="conv2") 24 | n = conv(n, filters=1, kernel_size=[1, 1], padding="same", activation=None, name="conv_out") 25 | return n 26 | 27 | 28 | def tompson2_load(sess, graph, path="./data/tompson2/modelsave.ckpt"): 29 | restore_net("Tompson2", sess, graph, path) 30 | 31 | 32 | def solve_pressure_tompson2(divergence, level_control=False, constants_file=None, cubic=True): 33 | rank = spatial_rank(divergence) 34 | dV = 2**rank 35 | size = int(max(divergence.shape[1:])) 36 | 37 | # if cubic: 38 | # resize = tf.image.resize_bicubic 39 | # else: 40 | # resize = tf.image.resize_bilinear 41 | 42 | multires_div = [to_dipole_format(divergence)] # order from low-res to high-res 43 | for i in range(math.frexp(float(size))[1] - 2): # downsample until 2x2 44 | p = downsample_dipole_2d_2x(multires_div[0], scaling="sum") 45 | # p = 
tf.layers.average_pooling2d(multires_div[0], pool_size=[2, 2], strides=2) * dV 46 | multires_div.insert(0, p) 47 | 48 | p_div = None # Divergence of pressure 49 | pressure = None 50 | 51 | pressure_accum = [] 52 | pressure_by_lvl = [] 53 | p_div_accum = [] 54 | p_div_by_level = [] 55 | remaining_div = [] 56 | level_scalings = [] 57 | i = 0 58 | 59 | for div_lvl in multires_div: # start with low-res 60 | div = div_lvl 61 | 62 | if p_div is not None: # Upsample previous level and subtract div p 63 | double_shape = np.array(pressure.shape[1:-1]) * 2 64 | p_div = upsample2x(p_div) / dV 65 | pressure = upsample2x(pressure)[:,2:-2,2:-2,:] 66 | # p_div = resize(p_div, div.shape[1:-1]) / dV 67 | # pressure = resize(pressure, double_shape)[:,2:-2,2:-2,:] 68 | div -= to_dipole_format(p_div) 69 | 70 | normalized_div, std = normalize_dipole(div) 71 | padded_div = tf.pad(normalized_div, [[0,0]]+[[1,1]]*rank+[[0,0]]) 72 | if pressure is not None: pressure = tf.pad(pressure, [[0,0]]+[[1,1]]*rank+[[0,0]], mode="SYMMETRIC") 73 | pressure_lvl = std * tompson2_pressure(padded_div, constants_file=constants_file) 74 | delta_p_div = laplace(pressure_lvl)[:, 1:-1, 1:-1, :] 75 | 76 | pressure_by_lvl.append(pressure_lvl) 77 | p_div_by_level.append(delta_p_div) 78 | 79 | if isinstance(level_control, collections.Iterable): 80 | level_scaling = level_control[i] 81 | elif level_control is True: 82 | level_scaling = tf.placeholder(tf.float32, shape=[1, 1, 1, 1], name="lvl_scale_%d"%i) 83 | elif level_control is False: 84 | level_scaling = 1 85 | else: 86 | raise ValueError("illegal level_control: {}".format(level_control)) 87 | level_scalings.append(level_scaling) 88 | 89 | if p_div is None: 90 | pressure = pressure_lvl * level_scaling 91 | p_div = delta_p_div * level_scaling 92 | else: 93 | pressure += pressure_lvl * level_scaling 94 | p_div += delta_p_div * level_scaling 95 | 96 | pressure_accum.append(pressure) 97 | i += 1 98 | 99 | return pressure[:,1:-1,1:-1,:], (p_div, pressure_accum, level_scalings) 100 | -------------------------------------------------------------------------------- /phi/control/control_scene.py: -------------------------------------------------------------------------------- 1 | import json, os, inspect, shutil 2 | from phi.fluidformat import * 3 | 4 | 5 | class ControlScene: 6 | 7 | def __init__(self, path, mode="r", index=None): 8 | self.path = path 9 | self.index = index 10 | if mode.lower() == "r": 11 | with open(os.path.join(path, "description.json"), "r") as file: 12 | self.infodict = json.load(file) 13 | elif mode.lower() == "w": 14 | self.infodict = {} 15 | else: 16 | raise ValueError("Illegal mode: %s " %mode) 17 | 18 | def get_final_loss(self, include_reg_loss=True): 19 | final_loss = self.infodict["final_loss"] 20 | if not include_reg_loss and "regloss" in self.infodict: 21 | final_loss -= self.infodict["regloss"] 22 | return final_loss 23 | 24 | def improvement(self): 25 | final_loss = self.get_final_loss(include_reg_loss=False) 26 | initial_loss = self.infodict["initial_loss"] 27 | return initial_loss / final_loss 28 | 29 | @property 30 | def scenetype(self): 31 | return self.infodict["scenetype"] 32 | 33 | def control_frames(self): 34 | return range(self.infodict["n_frames"]) 35 | 36 | def target_density(self): 37 | return read_sim_frames(self.path, ["target density"])[0] 38 | 39 | def get_state(self, index): 40 | return read_sim_frame(self.path, ["density", "velocity", "force"], index, set_missing_to_none=False) 41 | 42 | def time_to_keyframe(self, index): 43 | return 
self.infodict["n_frames"] - index 44 | 45 | def put(self, dict, save=True): 46 | self.infodict.update(dict) 47 | if save: 48 | with open(os.path.join(self.path, "description.json"), "w") as out: 49 | json.dump(self.infodict, out, indent=2) 50 | 51 | def file(self, name): 52 | return os.path.join(self.path, name) 53 | 54 | def __getitem__(self, key): 55 | return self.infodict[key] 56 | 57 | def __getattr__(self, item): 58 | return self.infodict[item] 59 | 60 | def __str__(self): 61 | return self.path 62 | 63 | def copy_calling_script(self): 64 | script_path = inspect.stack()[1][1] 65 | script_name = os.path.basename(script_path) 66 | src_path = os.path.join(self.path, "src") 67 | os.path.isdir(src_path) or os.mkdir(src_path) 68 | target = os.path.join(self.path, "src", script_name) 69 | shutil.copy(script_path, target) 70 | try: 71 | shutil.copystat(script_path, target) 72 | except: 73 | pass # print("Could not copy file metadata to %s"%target) 74 | 75 | 76 | def list_scenes(directory, category, min=None, max=None): 77 | scenes = [] 78 | if min is None: 79 | i = 1 80 | else: 81 | i = int(min) 82 | while True: 83 | path = os.path.join(directory, category, "sim_%06d/"%i) 84 | if not os.path.isdir(path): break 85 | scenes.append(ControlScene(path, "r", i)) 86 | if max is not None and i == max: break 87 | i += 1 88 | return scenes 89 | 90 | 91 | def new_scene(directory, category): 92 | scenedir = os.path.join(directory, category) 93 | if not os.path.isdir(scenedir): 94 | os.makedirs(scenedir) 95 | next_index = 1 96 | else: 97 | indices = [int(name[4:]) for name in os.listdir(scenedir) if name.startswith("sim_")] 98 | if not indices: 99 | next_index = 1 100 | else: 101 | next_index = max(indices) + 1 102 | path = os.path.join(scenedir, "sim_%06d"%next_index) 103 | os.mkdir(path) 104 | return ControlScene(path, "w", next_index) 105 | 106 | 107 | 108 | 109 | def load_scene_data(scenes): 110 | densities = [] 111 | velocities = [] 112 | forces = [] 113 | targets = [] 114 | remaining_times = [] 115 | 116 | for scene in scenes: 117 | target = scene.target_density() 118 | for i in scene.control_frames(): 119 | density, velocity, force = scene.get_state(i) 120 | remaining_time = scene.time_to_keyframe(i) 121 | densities.append(density) 122 | velocities.append(velocity) 123 | forces.append(force) 124 | remaining_times.append(remaining_time) 125 | targets.append(target) 126 | 127 | return densities, velocities, forces, targets, remaining_times -------------------------------------------------------------------------------- /train/train_2d.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..')) 3 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..')) 4 | from diffusion.diffusion_2d import GaussianDiffusion, Trainer 5 | from dataset.data_2d import Smoke 6 | import pdb 7 | import torch 8 | from accelerate import Accelerator 9 | import datetime 10 | 11 | import argparse 12 | 13 | from IPython import embed 14 | 15 | def load_model_accelerator(model, model_path): 16 | fp16 = False 17 | accelerator = Accelerator( 18 | split_batches = True, 19 | mixed_precision = 'fp16' if fp16 else 'no' 20 | ) 21 | device = accelerator.device 22 | data = torch.load(model_path, map_location=device) 23 | model = accelerator.unwrap_model(model) 24 | model.load_state_dict(data) 25 | 26 | return model 27 | 28 | 29 | parser = argparse.ArgumentParser(description='Train EBM model') 30 | 31 | 
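# Illustrative invocation only; the runs described in the README use the shell scripts under
# scripts_2d/ (train_syn.sh / train_asyn.sh), so the path and sizes below are placeholders:
#   python train/train_2d.py --dataset Smoke --dataset_path /path/to/smoke_data \
#       --batch_size 12 --horizon 15 --diffusion_steps 600 [--is_synch_model]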
parser.add_argument('--dataset', default='Smoke', type=str, 32 | help='dataset to evaluate') 33 | parser.add_argument('--dataset_path', default="/data", type=str, 34 | help='path to dataset') 35 | parser.add_argument('--is_condition_control', default=False, type=eval, 36 | help='If condition on control') 37 | parser.add_argument('--is_condition_pad', default=True, type=eval, 38 | help='If condition on padded state') 39 | parser.add_argument('--batch_size', default=12, type=int, 40 | help='size of batch of input to use') 41 | parser.add_argument('--horizon', default=15, type=int, 42 | help='number of horizon to diffuse') 43 | parser.add_argument('--train_num_steps', default=250000, type=int, 44 | help='total training steps') 45 | parser.add_argument('--results_path', default="./results/train/{}/".format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')), type=str, 46 | help='folder to save training checkpoints') 47 | parser.add_argument('--diffusion_steps', default=600, type=int, 48 | help='number of denoising steps in diffusion model') 49 | parser.add_argument('--is_synch_model', action='store_true', help="whether use synchronous denoising steps among different time steps") 50 | 51 | 52 | if __name__ == "__main__": 53 | FLAGS = parser.parse_args() 54 | print(FLAGS) 55 | 56 | # get shape, RESCALER 57 | if FLAGS.dataset == "Smoke": 58 | dataset = Smoke( 59 | dataset_path=FLAGS.dataset_path, 60 | horizon=FLAGS.horizon, 61 | is_train=True, 62 | ) 63 | _, shape, ori_shape, _ = dataset[0] 64 | else: 65 | assert False 66 | RESCALER = dataset.RESCALER.unsqueeze(0) 67 | 68 | from model.video_diffusion_pytorch_conv3d import Unet3D_with_Conv3D 69 | model = Unet3D_with_Conv3D( 70 | dim = 64, 71 | dim_mults = (1, 2, 4), 72 | channels = 6, 73 | ) 74 | 75 | print("Number of parameters: {}". 
76 | format(sum(p.numel() for p in model.parameters() if p.requires_grad))) 77 | print("Saved at: ", FLAGS.results_path) 78 | 79 | diffusion = GaussianDiffusion( 80 | model, 81 | RESCALER, 82 | FLAGS.is_condition_control, 83 | FLAGS.is_condition_pad, 84 | ori_shape, 85 | image_size = 64, 86 | horizon = FLAGS.horizon, 87 | diffusion_steps = FLAGS.diffusion_steps, # number of diffusion steps 88 | sampling_timesteps = 250, # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper]) 89 | loss_type = 'l2', # L1 or L2 90 | objective = "pred_noise", 91 | is_synch_model = FLAGS.is_synch_model 92 | ) 93 | 94 | trainer = Trainer( 95 | diffusion, 96 | FLAGS.dataset, 97 | FLAGS.dataset_path, 98 | train_batch_size = FLAGS.batch_size, 99 | train_lr = 1e-3, 100 | train_num_steps = FLAGS.train_num_steps, # total training steps 101 | gradient_accumulate_every = 1, # gradient accumulation steps 102 | ema_decay = 0.995, # exponential moving average decay 103 | save_and_sample_every = 10000, # 10000 104 | horizon = FLAGS.horizon, 105 | results_path = FLAGS.results_path, 106 | amp = False, # turn on mixed precision 107 | calculate_fid = False, # whether to calculate fid during training 108 | ) 109 | 110 | trainer.train() -------------------------------------------------------------------------------- /dataset/data_2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | import pdb 5 | import sys, os 6 | 7 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..')) 8 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..')) 9 | 10 | from IPython import embed 11 | 12 | class Smoke(Dataset): 13 | def __init__( 14 | self, 15 | dataset_path, 16 | time_steps=64, 17 | horizon=10, 18 | time_interval=1, 19 | all_size=128, 20 | size=64, 21 | is_train=True, 22 | ): 23 | super().__init__() 24 | self.root = dataset_path 25 | self.time_steps = time_steps # total time steps of each trajectory after down sampling 26 | self.horizon = horizon # horizon of diffusion model 27 | self.time_interval = time_interval 28 | self.time_steps_effective = (self.time_steps - self.horizon + 1) // self.time_interval 29 | self.all_size = all_size 30 | self.size = size 31 | self.space_interval = int(all_size/size) 32 | self.is_train = is_train 33 | self.dirname = "train" if self.is_train else "test" 34 | if self.is_train: 35 | self.n_simu = 40000 36 | else: 37 | self.n_simu = 50 38 | # self.RESCALER = torch.tensor([3, 20, 20, 17, 19, 1]).reshape(1, 6, 1, 1) 39 | self.RESCALER = torch.tensor([1, 45, 50, 45, 50, 1]).reshape(1, 6, 1, 1) # rescale the data to [-1, 1] with relaxation, on 64 time steps dataset 40 | 41 | def __len__(self): 42 | # return self.n_simu 43 | if self.is_train: 44 | return self.n_simu * self.time_steps_effective 45 | else: 46 | return self.n_simu 47 | 48 | def __getitem__(self, idx): 49 | if self.is_train: 50 | sim_id, time_id = divmod(idx, self.time_steps_effective) 51 | else: 52 | sim_id, time_id = idx, 0 # for test, pass each trajectory as a whole and only once 53 | 54 | if self.is_train: 55 | d = torch.tensor(np.load(os.path.join(self.root, self.dirname, 'sim_{:06d}/Density.npy'.format(sim_id))), \ 56 | dtype=torch.float).permute(2,3,0,1) 57 | v = torch.tensor(np.load(os.path.join(self.root, self.dirname, 'sim_{:06d}/Velocity.npy'.format(sim_id))), \ 58 | dtype=torch.float).permute(2,3,0,1) # 2, 65, 64, 64 59 | c = 
torch.tensor(np.load(os.path.join(self.root, self.dirname, 'sim_{:06d}/Control.npy'.format(sim_id))), \ 60 | dtype=torch.float).permute(2,3,0,1) 61 | s = torch.tensor(np.load(os.path.join(self.root, self.dirname, 'sim_{:06d}/Smoke.npy'.format(sim_id))), \ 62 | dtype=torch.float) # 65, 8 63 | s = s[:, 1]/s.sum(-1) # shape: [65]; 1 is index of of the target bucket 64 | s = s.reshape(1, s.shape[0], 1, 1).expand(1, s.shape[0], self.size, self.size) # 1, 65, 64, 64 65 | state = torch.cat((d, v, c, s), dim=0)[:, time_id: time_id + self.horizon] # 6, horizon, 64, 64 66 | 67 | data = ( 68 | state.permute(1, 0, 2, 3) / self.RESCALER, # horizon, 6, 64, 64 69 | list(state.shape[-3:]), 70 | list(state.shape[-3:]), 71 | sim_id, 72 | ) 73 | else: 74 | d = torch.tensor(np.load(os.path.join(self.root, self.dirname, 'sim_{:06d}/Density.npy'.format(sim_id))), \ 75 | dtype=torch.float).permute(2,3,0,1) 76 | v = torch.tensor(np.load(os.path.join(self.root, self.dirname, 'sim_{:06d}/Velocity.npy'.format(sim_id))), \ 77 | dtype=torch.float).permute(2,3,0,1) 78 | c = torch.tensor(np.load(os.path.join(self.root, self.dirname, 'sim_{:06d}/Control.npy'.format(sim_id))), \ 79 | dtype=torch.float).permute(2,3,0,1) 80 | s = torch.tensor(np.load(os.path.join(self.root, self.dirname, 'sim_{:06d}/Smoke.npy'.format(sim_id))), \ 81 | dtype=torch.float) 82 | 83 | s = s[:, 1]/s.sum(-1) 84 | s = s.reshape(1, s.shape[0], 1, 1).expand(1, s.shape[0], self.size, self.size) 85 | state = torch.cat((d, v, c, s), dim=0) # 6, 65, 64, 64 86 | data = ( 87 | state.permute(1, 0, 2, 3), # 65, 6, 64, 64, not rescaled 88 | list(state.shape[-3:]), 89 | list(state.shape[-3:]), 90 | sim_id, 91 | ) 92 | 93 | return data 94 | 95 | if __name__ == "__main__": 96 | dataset = Smoke( 97 | dataset_path="/data/", 98 | is_train=True, 99 | ) 100 | print("len(dataset): ", len(dataset)) -------------------------------------------------------------------------------- /phi/control/voxelutil.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw, ImageFont 2 | import numpy as np 3 | import random 4 | from phi.fluidformat import * 5 | 6 | 7 | def text_to_pixels(text, size=10, binary=False, as_numpy_array=True): 8 | image = Image.new("1" if binary else "L", (len(text)*size*3//4, size), 0) 9 | draw = ImageDraw.Draw(image) 10 | try: 11 | font = ImageFont.truetype("arial.ttf", size) 12 | except: 13 | font = ImageFont.truetype('Pillow/Tests/fonts/DejaVuSans.ttf', size=size) 14 | draw.text((0,0), text, fill=255, font=font) 15 | del draw 16 | 17 | if as_numpy_array: 18 | return np.array(image).astype(np.float32) / 255.0 19 | else: 20 | return image 21 | 22 | 23 | # image = text_to_pixels("The", as_numpy_array=False) 24 | # image.save("testimg.png", "PNG") 25 | 26 | 27 | def alphabet_soup(shape, count, margin=1, total_content=100, fontsize=10): 28 | if len(shape) != 4: raise ValueError("shape must be 4D") 29 | array = np.zeros(shape) 30 | letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' 31 | 32 | for batch in range(shape[0]): 33 | for i in range(count): 34 | letter = letters[random.randint(0, len(letters)-1)] 35 | tile = text_to_pixels(letter, fontsize)#[::-1, :] 36 | y = random.randint(margin, shape[1] - margin - tile.shape[0] - 2) 37 | x = random.randint(margin, shape[2] - margin - tile.shape[1] - 2) 38 | array[batch, y:(y+tile.shape[0]), x:(x+tile.shape[1]), 0] += tile 39 | 40 | return array.astype(np.float32) * total_content / np.sum(array) 41 | 42 | 43 | def 
random_word(shape, min_count, max_count, margin=1, total_content=100, fontsize=10, y=40): 44 | if len(shape) != 4: raise ValueError("shape must be 4D") 45 | array = np.zeros(shape) 46 | letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789' 47 | 48 | for b in range(shape[0]): 49 | count = random.randint(min_count, max_count) 50 | for i in range(count): 51 | letter = letters[random.randint(0, len(letters)-1)] 52 | tile = text_to_pixels(letter, fontsize)#[::-1, :] 53 | x = random.randint(margin, shape[2] - margin - tile.shape[1] - 2) 54 | array[b, y:(y+tile.shape[0]), x:(x+tile.shape[1]), 0] += tile 55 | 56 | return array.astype(np.float32) * total_content / np.sum(array) 57 | 58 | 59 | 60 | def single_shape(shape, scene, margin=1, fluid_mask=None): 61 | if len(shape) != 4: raise ValueError("shape must be 4D") 62 | array = np.zeros(shape) 63 | for batch in range(shape[0]): 64 | img = scene.read_array("Shape", random.choice(scene.indices))[0,...] 65 | while True: 66 | y = random.randint(margin, shape[1] - margin - img.shape[0] - 2) 67 | x = random.randint(margin, shape[2] - margin - img.shape[1] - 2) 68 | array[batch, y:(y + img.shape[0]), x:(x + img.shape[1]), :] = img 69 | if _all_density_valid(array[batch:batch+1,...], fluid_mask): 70 | break 71 | else: 72 | array[batch,...] = 0 73 | 74 | return array.astype(np.float32) 75 | 76 | 77 | def _all_density_valid(density, fluid_mask): 78 | if fluid_mask is None: 79 | return True 80 | return np.sum(density * fluid_mask) == np.sum(density) 81 | 82 | 83 | def push_density_inside(density_tile, tile_location, fluid_mask): # (y, x) 84 | """ 85 | Tries to adjust the tile_location so that the density_tile does not overlap with any obstacles. 86 | :param density_tile: 2D binary array, representing the density mask to be shifted 87 | :param tile_location: the initial location of the tile, (1D array with 2 values) 88 | :param fluid_mask: 2D binary array (must be larger than the tile) 89 | :return: the shifted location (1D array with 2 values) 90 | """ 91 | x, y = np.meshgrid(*[np.linspace(-1, 1, d) for d in density_tile.shape]) 92 | location = np.array(tile_location, dtype=np.int) 93 | 94 | def cropped_mask(location): 95 | slices = [slice(location[i], location[i]+density_tile.shape[i]) for i in range(2)] 96 | return fluid_mask[slices] 97 | 98 | while True: 99 | cropped_fluid_mask = cropped_mask(location) 100 | overlap = density_tile * (1-cropped_fluid_mask) 101 | if np.sum(overlap) == 0: 102 | return location 103 | update = -np.sign([np.sum(overlap * y), np.sum(overlap * x)]).astype(np.int) 104 | if np.all(update == 0): 105 | raise ValueError("Failed to push tile with initial location %s out of obstacle" % (tile_location,)) 106 | location += update 107 | 108 | 109 | # print(alphabet_soup([1, 16, 16, 1], 1000)[0,:,:,0]) 110 | 111 | # result = single_shape((2, 64, 64, 1), scene_at("data/shapelib/sim_000000")) 112 | # print(result.shape, np.sum(result)) 113 | 114 | # Test push_density_inside 115 | # fluid_mask = np.ones([64, 64]) 116 | # fluid_mask[10:20, 10:20] = 0 117 | # density_tile = np.ones([5,5]) 118 | # tile_location = (18,9) 119 | # print(push_density_inside(density_tile, tile_location, fluid_mask)) -------------------------------------------------------------------------------- /phi/tf/util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | from phi.math.nd import * 4 | 5 | 6 | def group_normalization(x, group_count, eps=1e-5): 7 | batch_size, H, 
W, C = tf.shape(x) 8 | gamma = tf.Variable(np.ones([1,1,1,C]), dtype=tf.float32, name="GN_gamma") 9 | beta = tf.Variable(np.zeros([1,1,1,C]), dtype=tf.float32, name="GN_beta") 10 | x = tf.reshape(x, [batch_size, group_count, H, W, C // group_count]) 11 | mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True) 12 | x = (x - mean) / tf.sqrt(var + eps) 13 | x = tf.reshape(x, [batch_size, H, W, C]) 14 | return x * gamma + beta 15 | 16 | 17 | def residual_block(y, nb_channels, kernel_size=(3, 3), _strides=(1, 1), activation=tf.nn.leaky_relu, 18 | _project_shortcut=False, padding="SYMMETRIC", name=None, training=False, trainable=True, reuse=tf.AUTO_REUSE): 19 | shortcut = y 20 | 21 | if isinstance(kernel_size, int): 22 | kernel_size = (kernel_size, kernel_size) 23 | 24 | pad1 = [(kernel_size[0] - 1) // 2, kernel_size[0] // 2] 25 | pad2 = [(kernel_size[1] - 1) // 2, kernel_size[1] // 2] 26 | 27 | # down-sampling is performed with a stride of 2 28 | y = tf.pad(y, [[0,0], pad1, pad2, [0,0]], mode=padding) 29 | y = tf.layers.conv2d(y, nb_channels, kernel_size=kernel_size, strides=_strides, padding='valid', 30 | name=None if name is None else name+"/conv1", trainable=trainable, reuse=reuse) 31 | # y = tf.layers.batch_normalization(y, name=None if name is None else name+"/norm1", training=training, trainable=trainable, reuse=reuse) 32 | y = activation(y) 33 | 34 | y = tf.pad(y, [[0,0], pad1, pad2, [0,0]], mode=padding) 35 | y = tf.layers.conv2d(y, nb_channels, kernel_size=kernel_size, strides=(1, 1), padding='valid', 36 | name=None if name is None else name + "/conv2", trainable=trainable, reuse=reuse) 37 | # y = tf.layers.batch_normalization(y, name=None if name is None else name+"/norm2", training=training, trainable=trainable, reuse=reuse) 38 | 39 | # identity shortcuts used directly when the input and output are of the same dimensions 40 | if _project_shortcut or _strides != (1, 1): 41 | # when the dimensions increase projection shortcut is used to match dimensions (done by 1×1 convolutions) 42 | # when the shortcuts go across feature maps of two sizes, they are performed with a stride of 2 43 | shortcut = tf.pad(shortcut, [[0,0], pad1, pad2, [0,0]], mode=padding) 44 | shortcut = tf.layers.conv2d(shortcut, nb_channels, kernel_size=(1, 1), strides=_strides, padding='valid', 45 | name=None if name is None else name + "/convid", trainable=trainable, reuse=reuse) 46 | # shortcut = tf.layers.batch_normalization(shortcut, name=None if name is None else name+"/normid", training=training, trainable=trainable, reuse=reuse) 47 | 48 | y += shortcut 49 | y = activation(y) 50 | 51 | return y 52 | 53 | 54 | def residual_block_1d(y, nb_channels, kernel_size=(3,), _strides=(1,), activation=tf.nn.leaky_relu, 55 | _project_shortcut=False, padding="SYMMETRIC", name=None, training=False, trainable=True, reuse=tf.AUTO_REUSE): 56 | shortcut = y 57 | 58 | if isinstance(kernel_size, int): 59 | kernel_size = (kernel_size,) 60 | 61 | pad1 = [(kernel_size[0] - 1) // 2, kernel_size[0] // 2] 62 | 63 | # down-sampling is performed with a stride of 2 64 | y = tf.pad(y, [[0,0], pad1, [0,0]], mode=padding) 65 | y = tf.layers.conv1d(y, nb_channels, kernel_size=kernel_size, strides=_strides, padding='valid', 66 | name=None if name is None else name+"/conv1", trainable=trainable, reuse=reuse) 67 | # y = tf.layers.batch_normalization(y, name=None if name is None else name+"/norm1", training=training, trainable=trainable, reuse=reuse) 68 | y = activation(y) 69 | 70 | y = tf.pad(y, [[0,0], pad1, [0,0]], mode=padding) 71 | y = 
tf.layers.conv1d(y, nb_channels, kernel_size=kernel_size, strides=(1,), padding='valid', 72 | name=None if name is None else name + "/conv2", trainable=trainable, reuse=reuse) 73 | # y = tf.layers.batch_normalization(y, name=None if name is None else name+"/norm2", training=training, trainable=trainable, reuse=reuse) 74 | 75 | # identity shortcuts used directly when the input and output are of the same dimensions 76 | if _project_shortcut or _strides != (1,): 77 | # when the dimensions increase projection shortcut is used to match dimensions (done by 1×1 convolutions) 78 | # when the shortcuts go across feature maps of two sizes, they are performed with a stride of 2 79 | shortcut = tf.pad(shortcut, [[0,0], pad1, [0,0]], mode=padding) 80 | shortcut = tf.layers.conv1d(shortcut, nb_channels, kernel_size=(1, 1), strides=_strides, padding='valid', 81 | name=None if name is None else name + "/convid", trainable=trainable, reuse=reuse) 82 | # shortcut = tf.layers.batch_normalization(shortcut, name=None if name is None else name+"/normid", training=training, trainable=trainable, reuse=reuse) 83 | 84 | y += shortcut 85 | y = activation(y) 86 | 87 | return y 88 | 89 | 90 | def istensor(object): 91 | if isinstance(object, StaggeredGrid): 92 | object = object.staggered 93 | return isinstance(object, tf.Tensor) 94 | -------------------------------------------------------------------------------- /phi/solver/base.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from phi.math.nd import * 3 | 4 | 5 | def create_mask(divergence_tensor): 6 | return np.ones([1]+list(divergence_tensor.shape)[1:], np.float32) 7 | 8 | 9 | class PressureSolver(object): 10 | 11 | def __init__(self, name): 12 | self.name = name 13 | 14 | def solve(self, divergence, active_mask, fluid_mask, boundaries, accuracy, pressure_guess=None, **kwargs): 15 | """ 16 | Solves the pressure equation Δp = ∇·v for all active fluid cells where active cells are given by the active_mask. 17 | The resulting pressure is expected to fulfill (Δp-∇·v) ≤ accuracy for every active cell. 18 | :param divergence: the scalar divergence of the velocity field, ∇·v 19 | :param active_mask: (Optional) Scalar field encoding active cells as ones and inactive (open/obstacle) as zero. 20 | :param fluid_mask: (Optional) Scalar field encoding fluid cells as ones and obstacles as zero. 21 | Has the same dimensions as the divergence field. If no obstacles are present, None may be passed. 22 | :param boundaries: DomainBoundary object defining open and closed boundaries 23 | :param accuracy: The accuracy of the result. 
Every grid cell should fulfill (Δp-∇·v) ≤ accuracy 24 | :param pressure_guess: (Optional) Pressure field which can be used as an initial state for the solver 25 | :param kwargs: solver-specific arguments 26 | """ 27 | raise NotImplementedError() 28 | 29 | 30 | class ExplicitBoundaryPressureSolver(PressureSolver): 31 | 32 | def __init__(self, name): 33 | PressureSolver.__init__(self, name) 34 | 35 | def solve(self, divergence, active_mask, fluid_mask, boundaries, accuracy, pressure_guess=None, **kwargs): 36 | active_mask = create_mask(divergence) if active_mask is None else active_mask 37 | active_mask = boundaries.pad_active(active_mask) 38 | fluid_mask = create_mask(divergence) if fluid_mask is None else fluid_mask 39 | fluid_mask = boundaries.pad_fluid(fluid_mask) 40 | return self.solve_with_boundaries(divergence, active_mask, fluid_mask, accuracy, pressure_guess, **kwargs) 41 | 42 | def solve_with_boundaries(self, divergence, active_mask, fluid_mask, accuracy, pressure_guess=None, **kwargs): 43 | """ 44 | See :func:`PressureSolver.solve`. Unlike the regular solve method, active_mask and fluid_mask are valid tensors which include 45 | one extra voxel at each boundary to account for boundary conditions. 46 | :param divergence: n^d dimensional scalar field 47 | :param active_mask: (n+2)^d dimensional scalar field 48 | :param fluid_mask: (n+2)^d dimensional scalar field 49 | :param accuracy: 50 | :param pressure_guess: 51 | :param kwargs: 52 | """ 53 | raise NotImplementedError() 54 | 55 | 56 | def conjugate_gradient(k, apply_A, initial_x=None, accuracy=1e-5, max_iterations=1024, back_prop=False): 57 | """ 58 | Solve the linear system of equations Ax=k using the conjugate gradient (CG) algorithm. 59 | The implementation is based on https://nvlpubs.nist.gov/nistpubs/jres/049/jresv49n6p409_A1b.pdf 60 | :param k: Right-hand-side vector 61 | :param apply_A: function that takes x and calculates Ax 62 | :param initial_x: initial guess for the value of x 63 | :param accuracy: the algorithm terminates once |Ax-k| ≤ accuracy for every element. If None, the algorithm runs until max_iterations is reached. 
64 | :param max_iterations: maximum number of CG iterations to perform 65 | :return: Pair containing the result for x and the number of iterations performed 66 | """ 67 | if initial_x is None: 68 | x = math.zeros_like(k) 69 | momentum = k 70 | else: 71 | x = initial_x 72 | momentum = k - apply_A(x) 73 | residual = momentum 74 | 75 | laplace_momentum = apply_A(momentum) 76 | loop_index = 0 77 | 78 | vars = [x, momentum, laplace_momentum, residual, loop_index] 79 | 80 | if accuracy is not None: 81 | def loop_condition(_1, _2, _3, residual, i): 82 | return math.max(math.abs(residual)) >= accuracy 83 | else: 84 | def loop_condition(_1, _2, _3, residual, i): 85 | return True 86 | 87 | def loop_body(pressure, momentum, A_times_momentum, residual, loop_index): 88 | tmp = math.sum(momentum * A_times_momentum) 89 | a = math.sum(momentum * residual) / tmp 90 | pressure += a * momentum 91 | residual -= a * A_times_momentum 92 | b = - math.sum(residual * A_times_momentum) / tmp 93 | momentum = residual + b * momentum 94 | A_times_momentum = apply_A(momentum) 95 | return [pressure, momentum, A_times_momentum, residual, loop_index + 1] 96 | 97 | x, momentum, laplace_momentum, residual, loop_index = math.while_loop(loop_condition, loop_body, vars, 98 | parallel_iterations=2, back_prop=back_prop, 99 | swap_memory=False, 100 | name="pressure_solve_loop", 101 | maximum_iterations=max_iterations) 102 | 103 | return x, loop_index 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CL-DiffPhyCon: Closed-loop Diffusion Control of Complex Physical Systems (ICLR 2025) 2 | 3 | [Paper](https://openreview.net/forum?id=PiHGrTTnvb) | [arXiv](https://arxiv.org/pdf/2408.03124) 4 | 5 | 6 | 7 | Official repo for the paper [CL-DiffPhyCon: Closed-loop Diffusion Control of Complex Physical Systems](https://openreview.net/pdf?id=PiHGrTTnvb).
8 | [Long Wei*](https://longweizju.github.io/), [Haodong Feng*](https://scholar.google.com/citations?user=0GOKl_gAAAAJ&hl=en), [Yuchen Yang](), [Ruiqi Feng](https://weenming.github.io/), [Peiyan Hu](https://peiyannn.github.io/), [Xiang Zheng](), [Tao Zhang](https://zhangtao167.github.io), [Dixia Fan](https://en.westlake.edu.cn/faculty/dixia-fan.html), [Tailin Wu†](https://tailin.org/)
9 | ICLR 2025. 10 | 11 | We propose a diffusion method with an asynchronous denoising schedule for physical systems control tasks. It achieves closed-loop control with a significant speedup of sampling efficiency. Specifically, it has the following features: 12 | 13 | - Efficient Sampling: CL-DiffPhyCon significantly reduces the computational cost during the sampling process through an asynchronous denoising framework. Compared with existing diffusion-based control methods, CL-DiffPhyCon can generate high-quality control signals in a much shorter time. 14 | 15 | - Closed-loop Control: CL-DiffPhyCon enables closed-loop control, adjusting strategies according to real-time environmental feedback. It outperforms open-loop diffusion-based planning methods in control effectiveness. 16 | 17 | - Accelerated Sampling: CL-DiffPhyCon can integrate with acceleration techniques such as [DDIM](https://arxiv.org/abs/2010.02502). It further enhances control efficiency while keeping the control effect stable. 18 | 19 | Framework of CL-DiffPhyCon: 20 | 21 | 22 | 23 | This is a follow-up work of our previous DiffPhyCon (NeurIPS 2024): [Paper](https://openreview.net/forum?id=MbZuh8L0Xg) | [Code](https://github.com/AI4Science-WestlakeU/diffphycon). 24 | 25 | # Installation 26 | 27 | Run the following commands to install dependencies. In particular, the Python version must be 3.8 when running the 2D smoke control task, as the Phiflow software requires. 28 | 29 | ```code 30 | conda env create -f environment.yml 31 | conda activate base 32 | ``` 33 | 34 | # Dataset and checkpoints 35 | ## Dataset 36 | The training and testing datasets and checkpoints of our CL-DiffPhyCon on both tasks (1D Burgers control and 2D smoke control) can be downloaded in [link](https://drive.google.com/drive/folders/1moLdtqmvmAU8FoWt6ELWOTXT0tPuY-qJ). To run the following training and inference scripts locally, replace the path names in the following scripts with your local paths. 37 | 38 | 39 | # Training: 40 | ## 1D Burgers' Equation Control: 41 | 42 | In the scripts_1d/ folder, run the following two scripts to train the synchronous and asynchronous diffusion models, respectively: 43 | ```code 44 | bash train_syn.sh 45 | bash train_asyn.sh 46 | ``` 47 | 48 | ## 2D Smoke Control: 49 | 50 | In the scripts_2d/ folder, modify the configs in the file default_config.yaml and the argument "main_process_port" and "gpu_ids" according to your local GPU environments to run [accelerate](https://pypi.org/project/accelerate/) properly. Then, run the following two scripts to train the two diffusion models, respectively: 51 | ```code 52 | bash train_syn.sh 53 | bash train_asyn.sh 54 | ``` 55 | 56 | # Inference: 57 | ## 1D Burgers' Equation Control: 58 | In the scripts_1d/ folder, run the following script for closed-loop diffusion control: 59 | ``` 60 | bash inf_asyn.sh 61 | ``` 62 | 63 | ## 2D Smoke Control: 64 | ### CL-DiffPhyCon 65 | In the scripts_2d/ folder, run the following script for closed-loop diffusion control: 66 | ``` 67 | bash inf_asyn.sh 68 | ``` 69 | 70 | Then in the inference/ folder, run evaluate_2d.py to evaluate the inference results (also modify the data path variable "root" first) 71 | ``` 72 | python evaluate_2d.py 73 | ``` 74 | 75 | ## Related Projects 76 | * [DiffPhyCon](https://github.com/AI4Science-WestlakeU/diffphycon) (NeurIPS 2024): We introduce DiffPhyCon which uses diffusion generative models to jointly model control and simulation of complex physical systems as a single task. 
77 | 78 | * [WDNO](https://github.com/AI4Science-WestlakeU/wdno) (ICLR 2025): We propose Wavelet Diffusion Neural Operator (WDNO), a novel method for generative PDE simulation and control, to address diffusion models' challenges of modeling system states with abrupt changes and generalizing to higher resolutions, via performing diffusion in the wavelet space. 79 | 80 | * [SafeDiffCon](https://github.com/AI4Science-WestlakeU/safediffcon) (ICML 2025): We propose safe diffusion models for PDE Control, which introduces the uncertainty quantile as model uncertainty quantification to achieve optimal control under safety constraints through both post-training and inference phases. 81 | 82 | * [CinDM](https://github.com/AI4Science-WestlakeU/cindm) (ICLR 2024 spotlight): We introduce a method that uses compositional generative models to design boundaries and initial states significantly more complex than the ones seen in training for physical simulations. 83 | 84 | ## Citation 85 | If you find our work and/or our code useful, please cite us via: 86 | 87 | ```bibtex 88 | @inproceedings{ 89 | wei2025cldiffphycon, 90 | title={{CL}-DiffPhyCon: Closed-loop Diffusion Control of Complex Physical Systems}, 91 | author={Long Wei and Haodong Feng and Yuchen Yang and Ruiqi Feng and Peiyan Hu and Xiang Zheng and Tao Zhang and Dixia Fan and Tailin Wu}, 92 | booktitle={The Thirteenth International Conference on Learning Representations}, 93 | year={2025}, 94 | url={https://openreview.net/forum?id=PiHGrTTnvb} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /utils_1d/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | def normalize_to_neg_one_to_one(img): 5 | return img * 2 - 1 6 | 7 | def unnormalize_to_zero_to_one(t): 8 | return (t + 1) * 0.5 9 | 10 | def exists(x): 11 | return x is not None 12 | 13 | def default(val, d): 14 | if exists(val): 15 | return val 16 | return d() if callable(d) else d 17 | 18 | def identity(t, *args, **kwargs): 19 | return t 20 | 21 | def cycle(dl): 22 | while True: 23 | for data in dl: 24 | yield data 25 | 26 | def extract(a, t, x_shape): 27 | if len(t.shape) == 1: 28 | b, *_ = t.shape 29 | out = a.gather(-1, t) 30 | out = out.reshape(b, *((1,) * (len(x_shape) - 1))) 31 | 32 | elif len(t.shape) == 2: 33 | a = a.unsqueeze(0).expand(t.shape[0], -1) 34 | b, T, *_ = t.shape 35 | out = a.gather(-1, t) 36 | 37 | out = out.reshape(b, T, *((1,) * (len(x_shape) - 2))).permute(0,2,1,3) 38 | 39 | return out 40 | 41 | 42 | # guidance helper functions 43 | 44 | def get_nablaJ(loss_fn: callable): 45 | '''Use explicit loss for guided inference in diffusion. 46 | J is the loss here, not Jacobian. 47 | 48 | Arguments: 49 | loss_fn: callable, calculates the loss. 50 | Arguments: 51 | x: state + control 52 | Returns: loss (requires_grad) 53 | ''' 54 | def nablaJ(x: torch.TensorType): 55 | x.requires_grad_(True) 56 | J = loss_fn(x) # vec of size of batch 57 | grad = torch.autograd.grad(J, x, grad_outputs=torch.ones_like(J), retain_graph = True, create_graph=True, allow_unused=True)[0] 58 | return grad.detach() 59 | return nablaJ 60 | 61 | def get_proj_ep_orthogonal_func(norm='F'): 62 | # well,for 1D case there is no ambiguity but it seems less straightforward 63 | # for highr-dimensional embedding (even for Burgers' the data is essentailly 2D) 64 | # The inner product of two matrices are their F norm though. 
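    # Roughly, each branch below returns ep + nabla_J - <nabla_J, ep> * ep / ||ep||, i.e. the
    # guidance gradient nabla_J is added after (approximately) removing its component along the
    # sampled noise ep; the branches differ only in the axes over which the inner product and
    # norm are taken ('F' treats the whole field, '1D_x' / '1D_t' a single axis).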
65 | 66 | if norm == 'F': 67 | def proj_ep_orthogonal(ep, nabla_J): 68 | return ep + nabla_J - (nabla_J * ep).sum() * ep / ep.square().sum((-2, -1)).sqrt().unsqueeze(-1).unsqueeze(-1) 69 | elif norm == '1D_x': 70 | def proj_ep_orthogonal(ep, nabla_J): 71 | return ep + nabla_J - (nabla_J * ep).sum(-1).unsqueeze(-1) * ep / ep.square().sum(-1).sqrt().unsqueeze(-1) 72 | elif norm == '1D_t': 73 | def proj_ep_orthogonal(ep, nabla_J): 74 | return ep + nabla_J - (nabla_J * ep).sum(-2) * ep / ep.square().sum(-2).sqrt() 75 | else: 76 | raise NotImplementedError 77 | 78 | return proj_ep_orthogonal 79 | 80 | 81 | def cosine_beta_J_schedule(t, s = 0.008): 82 | """ 83 | cosine schedule (returns beta = 1 - cos^2 (x / N), which is increasing.) 84 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 85 | """ 86 | timesteps = 1000 87 | steps = timesteps + 1 88 | x = torch.linspace(0, timesteps, steps, dtype = torch.float64) 89 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 90 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 91 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 92 | return torch.clip(betas, 0, 0.999)[t] 93 | 94 | def plain_cosine_schedule(t, s = 0.0): 95 | """ 96 | cosine schedule, which is decreasing... 97 | """ 98 | timesteps = 1000 99 | steps = timesteps + 1 100 | x = torch.linspace(0, timesteps, steps, dtype = torch.float64) 101 | eta = torch.cos((x + s) / (timesteps + s)) 102 | return eta.flip()[t] # t=0 should be zero (small step size) 103 | 104 | def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-5): 105 | """ 106 | sigmoid schedule 107 | proposed in https://arxiv.org/abs/2212.11972 - Figure 8 108 | better for images > 64x64, when used during training 109 | """ 110 | timesteps = 1000 111 | steps = timesteps + 1 112 | x = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps 113 | v_start = torch.tensor(start / tau).sigmoid() 114 | v_end = torch.tensor(end / tau).sigmoid() 115 | alphas_cumprod = (-((x * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start) 116 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 117 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 118 | return torch.clip(betas, 0, 0.999)[t] 119 | 120 | def sigmoid_schedule_flip(t): 121 | return sigmoid_schedule(999 - t) 122 | 123 | def linear_schedule(t): 124 | timesteps = 1 125 | scale = 1000 / timesteps 126 | beta_start = scale * 0.0001 127 | beta_end = scale * 0.02 128 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)[t] 129 | 130 | 131 | def linear_beta_schedule(timesteps): 132 | scale = 1000 / timesteps 133 | beta_start = scale * 0.0001 134 | beta_end = scale * 0.02 135 | return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) 136 | 137 | def cosine_beta_schedule(timesteps, s = 0.008): 138 | """ 139 | cosine schedule 140 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 141 | """ 142 | steps = timesteps + 1 143 | x = torch.linspace(0, timesteps, steps, dtype = torch.float64) 144 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 145 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 146 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 147 | return torch.clip(betas, 0, 0.999) 148 | 149 | 150 | # helpers functions 151 | 152 | 153 | def has_int_squareroot(num): 154 | return (math.sqrt(num) ** 2) == num 155 | 156 | def num_to_groups(num, divisor): 157 | groups = num // divisor 158 | remainder = num % 
divisor 159 | arr = [divisor] * groups 160 | if remainder > 0: 161 | arr.append(remainder) 162 | return arr 163 | 164 | def convert_image_to_fn(img_type, image): 165 | if image.mode != img_type: 166 | return image.convert(img_type) 167 | return image 168 | -------------------------------------------------------------------------------- /utils_1d/result_io.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import numpy as np 3 | import torch 4 | from pathlib import Path 5 | import os 6 | import copy 7 | 8 | class YamlReaderError(Exception): 9 | pass 10 | 11 | 12 | def data_merge(a, b) -> None: 13 | """merges b into a and return merged result 14 | NOTE: side effect: a is modified 15 | NOTE: tuples and arbitrary objects are not handled as it is totally ambiguous what should happen 16 | """ 17 | key = None 18 | # ## debug output 19 | # sys.stderr.write("DEBUG: %s to %s\n" %(b,a)) 20 | try: 21 | if ( 22 | a is None 23 | or isinstance(a, str) 24 | or isinstance(a, int) 25 | or isinstance(a, float) 26 | ): 27 | # border case for first run or if a is a primitive 28 | a = b 29 | elif isinstance(a, list): 30 | # lists can be only appended 31 | if isinstance(b, list): 32 | # merge lists 33 | a.extend(b) 34 | else: 35 | # append to list 36 | a.append(b) 37 | elif isinstance(a, dict): 38 | # dicts must be merged 39 | if isinstance(b, dict): 40 | for key in b: 41 | if key in a: 42 | a[key] = data_merge(a[key], b[key]) 43 | else: 44 | a[key] = b[key] 45 | else: 46 | raise YamlReaderError( 47 | 'Cannot merge non-dict "%s" into dict "%s"' % (b, a) 48 | ) 49 | else: 50 | raise YamlReaderError('NOT IMPLEMENTED "%s" into "%s"' % (b, a)) 51 | except TypeError as e: 52 | raise YamlReaderError( 53 | 'TypeError "%s" in key "%s" when merging "%s" into "%s"' % (e, key, b, a) 54 | ) 55 | return a 56 | 57 | def merge_save_dict(fname, new_res): 58 | # save to file 59 | # creates file if not exist 60 | if not os.path.exists(fname): 61 | with open(fname, 'w'): pass 62 | 63 | with open(fname, "r") as f: 64 | res = yaml.safe_load(f) 65 | 66 | res = data_merge(res, new_res) # join recursively result dicts 67 | 68 | with open(fname, "w") as f: 69 | yaml.dump(res, f) 70 | 71 | def save_acc( 72 | acc, 73 | fname, 74 | make_dict_path: callable=lambda acc_dict, dict_args:{dict_args['model_name']: {"result": acc_dict}}, 75 | **dict_path_args 76 | ): 77 | # NOTE: convert to python float 78 | new_res = make_dict_path({"mean": float(np.mean(acc)), "std": float(np.std(acc))}, dict_path_args) 79 | merge_save_dict(fname, new_res) 80 | 81 | 82 | def save_edge_flip(edge_flip, fpath): 83 | 84 | edge_flip = edge_flip.detach().cpu().numpy() 85 | # check if path exists, if not, create that directory. 86 | Path(fpath).mkdir(parents=True, exist_ok=True) 87 | np.save(os.path.join(fpath, 'flip.npy'), edge_flip) 88 | 89 | 90 | def load_edge_flip(fpath): 91 | if not os.path.exists(fpath): 92 | raise ValueError('Edge filp file not found: path does not exist') 93 | try: 94 | edge_flip = np.load(os.path.join(fpath, 'flip.npy')) 95 | except Exception as e: 96 | raise ValueError('Edge filp file not found: unknown error. 
' + e) 97 | return torch.from_numpy(edge_flip).cuda() 98 | 99 | 100 | 101 | 102 | def rep_save_model(model_name, budget, models, rep_per_split=5, save_name='clean_model', dataset_name='cora'): 103 | for rep, model in enumerate(models): 104 | fpath = f'./my_exp/{dataset_name}_flips/{model_name}/models/{budget:.3f}/split_{rep // rep_per_split}/rand_model_{rep % rep_per_split}/' 105 | Path(fpath).mkdir(parents=True, exist_ok=True) 106 | fpath += save_name 107 | torch.save(model.state_dict(), fpath) 108 | 109 | def rep_load_model(model_name, budget, model, rep_per_split=5, save_name='clean_model', neglect_ada_model=False, dataset_name='cora', debug=False): 110 | models = [] 111 | rep = 0 112 | while True: 113 | try: 114 | if neglect_ada_model and rep % rep_per_split == 0: 115 | raise Exception # not load the model init that is attacked 116 | fpath = f'{os.path.dirname(__file__)}/{dataset_name}_flips/{model_name}/models/{budget:.3f}/split_{rep // rep_per_split}/rand_model_{rep % rep_per_split}/{save_name}' 117 | cur_model: torch.nn.Module = copy.deepcopy(model) 118 | cur_model.load_state_dict(torch.load(fpath)) 119 | models.append(cur_model.eval()) 120 | except Exception as e: 121 | if debug: 122 | print(e) 123 | break 124 | rep += 1 125 | return models 126 | 127 | 128 | 129 | 130 | 131 | def save_rep_edge_flips(model_name, budget: float, attack_name, flip_ls, dataset_name='cora'): 132 | for rep, flip in enumerate(flip_ls): 133 | if type(flip) is list: 134 | for i, f in enumerate(flip): 135 | fpath = f'./my_exp/{dataset_name}_flips/{model_name}/{attack_name}/{budget:.3f}/split_{rep}/node_{i}/' 136 | save_edge_flip(f, fpath) 137 | else: 138 | # save flip to file 139 | fpath = f'./my_exp/{dataset_name}_flips/{model_name}/{attack_name}/{budget:.3f}/split_{rep}/' 140 | save_edge_flip(flip, fpath) 141 | 142 | 143 | def load_rep_edge_flips(model_name, budget: float, attack_name, dataset_name='cora'): 144 | flips = [] 145 | rep = 0 146 | while True: 147 | try: 148 | if 'global' in attack_name: 149 | fpath = f'{os.path.dirname(__file__)}/{dataset_name}_flips/{model_name}/{attack_name}/{budget:.3f}/split_{rep}/' 150 | flip = load_edge_flip(fpath) 151 | flips.append(flip) 152 | elif 'local' in attack_name: 153 | cur_flips = [] 154 | node_idx = 0 155 | while True: 156 | fpath = f'{os.path.dirname(__file__)}/{dataset_name}_flips/{model_name}/{attack_name}/{budget:.3f}/split_{rep}/' 157 | 158 | if os.path.exists(fpath) and not os.path.exists(fpath + f'node_{node_idx}/'): 159 | print(f'{node_idx} nodes loaded') 160 | break 161 | 162 | fpath += f'node_{node_idx}/' 163 | flip = load_edge_flip(fpath) 164 | cur_flips.append(flip) 165 | node_idx += 1 166 | 167 | flips.append(cur_flips) 168 | 169 | except Exception as e: 170 | print('Loading terminates.', f'{rep} splits loaded.') 171 | if rep == 0: 172 | print('Last error message:', e) 173 | break 174 | rep += 1 175 | return flips 176 | 177 | -------------------------------------------------------------------------------- /phi/solver/cuda/src/laplace_op.cu.cc: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | 6 | static void CheckCudaErrorAux(const char* file, unsigned line, const char* statement, cudaError_t err) { 7 | if (err == cudaSuccess) return; 8 | std::cerr << statement << " returned " << cudaGetErrorString(err) << "(" 9 | << err << ") at " << file << ":" << line << std::endl; 10 | 11 | exit(1); 12 | } 13 | #define CUDA_CHECK_RETURN(value) CheckCudaErrorAux(__FILE__, __LINE__, #value, 
value) 14 | 15 | // Converts coordinates of the simulation grid to indices of the extended mask grid with a shift to get the indices of the neighbors 16 | __device__ int gridIDXWithOffsetShifted(const int *dimensions, const int dim_size, const int *cords, int cords_offset, int dim_index_offset, int offset) 17 | { 18 | int factor = 1; 19 | int result = 0; 20 | for (int i = 0; i < dim_size; i++) 21 | { 22 | if (i == dim_index_offset) 23 | result += factor * (cords[i + cords_offset * dim_size] + offset); 24 | else 25 | result += factor * (cords[i + cords_offset * dim_size] + 1); 26 | 27 | factor *= dimensions[i]; 28 | } 29 | return result; 30 | } 31 | 32 | __device__ void CordsByRow(int row, const int *dimensions, const int dim_size, const int dim_product, int *cords) 33 | { 34 | int modulo = 0; 35 | int divisor = dim_product; 36 | 37 | for (int i = dim_size - 1; i >= 0; i--) 38 | { 39 | divisor = divisor / dimensions[i]; 40 | 41 | cords[i + row * dim_size] = (modulo == 0 ? row : (row % modulo)) / divisor; // 0 mod 0 not possible due to c++ restrictions 42 | modulo = divisor; 43 | } 44 | } 45 | 46 | 47 | __global__ void calcLaplaceMatrix(const int *dimensions, const int dim_size, const int dim_product, const float *active_mask, const float *fluid_mask, const int *mask_dimensions, signed char *laplace_matrix, int *cords) 48 | { 49 | for (int row = blockIdx.x * blockDim.x + threadIdx.x; row < dim_product; row += blockDim.x * gridDim.x) 50 | { // TODO: reduce this by half since the matrix is symmetrical and only the half needs to be created? => Already pretty fast 51 | // Derive the coordinates of the dim_size-Dimensional mask by the laplace row id 52 | CordsByRow(row, dimensions, dim_size, dim_product, cords); 53 | 54 | // Every thread accesses the laplaceDataBuffer at different areas. index_pointer points to the current position of the current thread 55 | int index_pointer = row * (dim_size * 2 + 1); 56 | 57 | // forward declaration of variables, that are reused 58 | int mask_idx = 0; 59 | int mask_idx_before = 0; 60 | int mask_idx_after = 0; 61 | 62 | // dim_size-Dimensional "Cubes" have exactly dim_size * 2 "neighbors" 63 | int diagonal = -dim_size * 2; 64 | 65 | // get the index on the extended mask grid of the current cell 66 | int rowMaskIdx = gridIDXWithOffsetShifted(mask_dimensions, dim_size, cords, row, 0, 1); 67 | 68 | // Check neighbors if they are solids. 
For every solid neighbor increment diagonal by one 69 | for (int j = dim_size - 1; j >= 0; j--) 70 | { 71 | // get the index on the extended mask grid of the neighbor cells 72 | mask_idx_before = gridIDXWithOffsetShifted(mask_dimensions, dim_size, cords, row, j, 0); 73 | mask_idx_after = gridIDXWithOffsetShifted(mask_dimensions, dim_size, cords, row, j, 2); 74 | if(active_mask[mask_idx_before] == 0.0f && fluid_mask[mask_idx_before] == 0.0f) diagonal++; 75 | if(active_mask[mask_idx_after] == 0.0f && fluid_mask[mask_idx_after] == 0.0f) diagonal++; 76 | } 77 | 78 | // Check the "left"/"before" neighbors if they are fluid and add them to the laplaceData 79 | for (int j = dim_size - 1; j >= 0; j--) 80 | { 81 | mask_idx = gridIDXWithOffsetShifted(mask_dimensions, dim_size, cords, row, j, 0); 82 | 83 | if (active_mask[mask_idx] == 1 && fluid_mask[mask_idx] == 1 && !(active_mask[rowMaskIdx] == 0 && fluid_mask[rowMaskIdx] == 0)) 84 | { // fluid - fluid 85 | laplace_matrix[index_pointer] = 1; 86 | } 87 | else if (active_mask[mask_idx] == 0 && fluid_mask[mask_idx] == 1) 88 | { // Empty / open cell 89 | // pass, because we initialized the data with zeros 90 | } 91 | index_pointer++; 92 | } 93 | 94 | // Add the diagonal value 95 | laplace_matrix[index_pointer] = diagonal; 96 | index_pointer++; 97 | 98 | // Finally add the "right"/"after" neighbors 99 | for (int j = 0; j < dim_size; j++) 100 | { 101 | mask_idx = gridIDXWithOffsetShifted(mask_dimensions, dim_size, cords, row, j, 2); 102 | 103 | if (active_mask[mask_idx] == 1 && fluid_mask[mask_idx] == 1 && !(active_mask[rowMaskIdx] == 0 && fluid_mask[rowMaskIdx] == 0)) 104 | { // fluid - fluid 105 | laplace_matrix[index_pointer] = 1; 106 | } 107 | else if (active_mask[mask_idx] == 0 && fluid_mask[mask_idx] == 1) 108 | { // Empty / open cell 109 | // pass, because we initialized the data with zeros 110 | } 111 | index_pointer++; 112 | } 113 | } 114 | } 115 | 116 | __global__ void setUpData( const int dim_size, const int dim_product, signed char *laplace_matrix) { 117 | for (int row = blockIdx.x * blockDim.x + threadIdx.x; 118 | row < dim_product * (dim_size * 2 + 1); 119 | row += blockDim.x * gridDim.x) 120 | { 121 | laplace_matrix[row] = 0; 122 | } 123 | } 124 | void LaplaceMatrixKernelLauncher(const int *dimensions, const int dim_size, const int dim_product, const float *active_mask, const float *fluid_mask, const int *mask_dimensions, signed char *laplace_matrix, int *cords) { 125 | // get block and gridSize to theoretically get best occupancy 126 | int blockSize; 127 | int minGridSize; 128 | int gridSize; 129 | cudaOccupancyMaxPotentialBlockSize( &minGridSize, &blockSize, 130 | setUpData, 0, 0); 131 | gridSize = (dim_product * (dim_size * 2 + 1) + blockSize - 1) / blockSize; 132 | 133 | // Init Laplace Matrix with zeros 134 | setUpData<<>>(dim_size, dim_product, laplace_matrix); 135 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 136 | 137 | 138 | cudaOccupancyMaxPotentialBlockSize( &minGridSize, &blockSize, 139 | calcLaplaceMatrix, 0, 0); 140 | gridSize = (dim_product + blockSize - 1) / blockSize; 141 | 142 | // Calculate the Laplace Matrix 143 | calcLaplaceMatrix<<>>(dimensions, dim_size, dim_product, active_mask, fluid_mask, mask_dimensions, laplace_matrix, cords); 144 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 145 | } 146 | 147 | -------------------------------------------------------------------------------- /phi/math/scipy_backend.py: -------------------------------------------------------------------------------- 1 | from phi.math.base import 
Backend 2 | import numpy as np 3 | import numbers 4 | import collections 5 | import scipy.sparse, scipy.signal 6 | 7 | class SciPyBackend(Backend): 8 | 9 | def __init__(self): 10 | Backend.__init__(self, "SciPy") 11 | 12 | def is_applicable(self, values): 13 | if values is None: return True 14 | if isinstance(values, np.ndarray): return True 15 | if isinstance(values, numbers.Number): return True 16 | if isinstance(values, bool): return True 17 | if scipy.sparse.issparse(values): return True 18 | # if isinstance(values, collections.Iterable): python < 3.9 19 | if isinstance(values, collections.abc.Iterable): # python >= 3.9 20 | try: 21 | for value in values: 22 | if not self.is_applicable(value): return False 23 | return True 24 | except: 25 | return False 26 | return False 27 | 28 | def rank(self, value): 29 | return len(value.shape) 30 | 31 | def stack(self, values, axis=0): 32 | return np.stack(values, axis) 33 | 34 | def concat(self, values, axis): 35 | return np.concatenate(values, axis) 36 | 37 | def pad(self, value, pad_width, mode="constant", constant_values=0): 38 | if mode.lower() == "constant": 39 | return np.pad(value, pad_width, "constant", constant_values=constant_values) 40 | else: 41 | return np.pad(value, pad_width, mode.lower()) 42 | 43 | def add(self, values): 44 | return np.sum(values, axis=0) 45 | 46 | def reshape(self, value, shape): 47 | return value.reshape(shape) 48 | 49 | def sum(self, value, axis=None): 50 | return np.sum(value, axis=tuple(axis) if axis is not None else None) 51 | 52 | def py_func(self, func, inputs, Tout, shape_out, stateful=True, name=None, grad=None): 53 | result = func(*inputs) 54 | assert result.dtype == Tout, "returned value has wrong type: {}, expected {}".format(result.dtype, Tout) 55 | assert result.shape == shape_out, "returned value has wrong shape: {}, expected {}".format(result.shape, 56 | shape_out) 57 | return result 58 | 59 | def resample(self, inputs, sample_coords, interpolation="LINEAR", boundary="ZERO"): 60 | if boundary.lower() == "zero": 61 | pass # default 62 | elif boundary.lower() == "replicate": 63 | sample_coords = clamp(sample_coords, inputs.shape[1:-1][::-1]) 64 | else: 65 | raise ValueError("Unsupported boundary: %s"%boundary) 66 | 67 | import scipy.interpolate 68 | points = [np.arange(dim) for dim in inputs.shape[1:-1]] 69 | result = [] 70 | for batch in range(sample_coords.shape[0]): 71 | components = [] 72 | for dim in range(inputs.shape[-1]): 73 | resampled = scipy.interpolate.interpn(points, inputs[batch, ..., dim], sample_coords[batch, ...], 74 | method=interpolation.lower(), bounds_error=False, fill_value=0) 75 | components.append(resampled) 76 | result.append(np.stack(components, -1)) 77 | 78 | result = np.stack(result).astype(inputs.dtype) 79 | return result 80 | 81 | def zeros_like(self, tensor): 82 | return np.zeros_like(tensor) 83 | 84 | def ones_like(self, tensor): 85 | return np.ones_like(tensor) 86 | 87 | def mean(self, value, axis=None): 88 | return np.mean(value, axis) 89 | 90 | def dot(self, a, b, axes): 91 | return np.tensordot(a, b, axes) 92 | 93 | def matmul(self, A, b): 94 | return np.stack([A.dot(b[i]) for i in range(b.shape[0])]) 95 | 96 | def while_loop(self, cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, back_prop=True, 97 | swap_memory=False, name=None, maximum_iterations=None): 98 | i = 0 99 | while cond(*loop_vars): 100 | if maximum_iterations is not None and i == maximum_iterations: break 101 | loop_vars = body(*loop_vars) 102 | i += 1 103 | return loop_vars 104 | 
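    # Unlike tf.while_loop, this eager emulation simply calls cond/body in a plain
    # Python loop; shape_invariants, parallel_iterations, back_prop and swap_memory
    # are accepted only for signature compatibility and are ignored here.
    # Minimal usage sketch (an illustration, not part of the original module):
    #   backend = SciPyBackend()
    #   final, = backend.while_loop(lambda i: i < 5, lambda i: (i + 1,), (0,))
    #   # final == 5; the loop also stops once maximum_iterations is reached.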
105 | def abs(self, x): 106 | return np.abs(x) 107 | 108 | def ceil(self, x): 109 | return np.ceil(x) 110 | 111 | def floor(self, x): 112 | return np.floor(x) 113 | 114 | def max(self, x, axis=None): 115 | return np.max(x, axis) 116 | 117 | def with_custom_gradient(self, function, inputs, gradient, input_index=0, output_index=None, name_base="custom_gradient_func"): 118 | return function(*inputs) 119 | 120 | def maximum(self, a, b): 121 | return np.maximum(a, b) 122 | 123 | def minimum(self, a, b): 124 | return np.minimum(a, b) 125 | 126 | def sqrt(self, x): 127 | return np.sqrt(x) 128 | 129 | def exp(self, x): 130 | return np.exp(x) 131 | 132 | def conv(self, tensor, kernel, padding="SAME"): 133 | assert tensor.shape[-1] == kernel.shape[-2] 134 | # kernel = kernel[[slice(None)] + [slice(None, None, -1)] + [slice(None)]*(len(kernel.shape)-3) + [slice(None)]] 135 | if padding.lower() == "same": 136 | result = np.zeros(tensor.shape[:-1]+(kernel.shape[-1],), np.float32) 137 | elif padding.lower() == "valid": 138 | valid = [tensor.shape[i+1]-(kernel.shape[i]+1)//2 for i in range(tensor_spatial_rank(tensor))] 139 | result = np.zeros([tensor.shape[0]]+valid+[kernel.shape[-1]], np.float32) 140 | else: 141 | raise ValueError("Illegal padding: %s"%padding) 142 | for batch in range(tensor.shape[0]): 143 | for o in range(kernel.shape[-1]): 144 | for i in range(tensor.shape[-1]): 145 | result[batch, ..., o] += scipy.signal.correlate(tensor[batch, ..., i], kernel[..., i, o], padding.lower()) 146 | return result 147 | 148 | def expand_dims(self, a, axis): 149 | return np.expand_dims(a, axis) 150 | 151 | def shape(self, tensor): 152 | return tensor.shape 153 | 154 | def to_float(self, x): 155 | if not isinstance(x, np.ndarray): 156 | return float(x) 157 | return x.astype(np.float32) 158 | 159 | def gather(self, values, indices): 160 | return values[indices] 161 | 162 | def unstack(self, tensor, axis=0): 163 | result = [] 164 | for i in range(tensor.shape[axis]): 165 | result.append(tensor[[i if d==axis else slice(None) for d in range(len(tensor.shape))]]) 166 | return result 167 | 168 | def std(self, x, axis=None): 169 | return np.std(x, axis) 170 | 171 | def boolean_mask(self, x, mask): 172 | return x[mask] 173 | 174 | def isfinite(self, x): 175 | return np.isfinite(x) 176 | 177 | def tile(self, x, multiples): 178 | return np.tile(x, multiples) 179 | 180 | 181 | 182 | def clamp(coordinates, shape): 183 | assert coordinates.shape[-1] == len(shape) 184 | for i in range(len(shape)): 185 | coordinates[...,i] = np.maximum(0, np.minimum(shape[i], coordinates[...,i])) 186 | return coordinates 187 | 188 | 189 | def tensor_spatial_rank(field): 190 | dims = len(field.shape) - 2 191 | assert dims > 0, "field has no spatial dimensions" 192 | return dims -------------------------------------------------------------------------------- /phi/solver/cuda/src/pressure_solve_op.cc: -------------------------------------------------------------------------------- 1 | #include "tensorflow/core/framework/op.h" 2 | #include "tensorflow/core/framework/op_kernel.h" 3 | #include "tensorflow/core/framework/shape_inference.h" 4 | #include 5 | 6 | using namespace tensorflow; // NOLINT(build/namespaces) 7 | 8 | REGISTER_OP("PressureSolve") 9 | .Input("dimensions: int32") 10 | 11 | .Input("mask_dimensions: int32") 12 | .Input("active_mask: float32") 13 | .Input("fluid_mask: float32") 14 | .Input("laplace_matrix: int8") 15 | 16 | .Input("divergence: float32") 17 | .Input("p: float32") 18 | .Input("r: float32") 19 | .Input("z: float32") 
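    // p, r and z above are scratch vectors for the iterative (conjugate-gradient-style)
    // pressure solve; `pressure`, declared next, doubles as the initial guess: the kernel
    // writes the solution into it and the op hands the same tensor back as `pressure_out`
    // (see set_output(0, input_pressure) in Compute below).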
20 | .Input("pressure: float32") 21 | .Input("one_vector: float32") 22 | 23 | .Attr("dim_product: int") 24 | .Attr("accuracy: float") 25 | .Attr("max_iterations: int") 26 | .Output("pressure_out: float32") 27 | .Output("iterations: int32") 28 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { 29 | c->set_output(0, c->input(5)); // divergence 30 | return Status::OK(); 31 | }); 32 | 33 | 34 | void LaunchPressureKernel(const int *dimensions, const int dimProduct, const int dimSize, 35 | const signed char* laplaceMatrix, 36 | float* p, float* z, float* r, float *divergence, float* x, 37 | const float* oneVector, 38 | bool* thresholdReached, 39 | const float accuracy, 40 | const int max_iterations, 41 | const int batch_size, 42 | int* iterations_gpu); 43 | 44 | void LaplaceMatrixKernelLauncher(const int *dimensions, const int dimSize, const int dimProduct, const float *active_mask, const float *fluid_mask, const int *maskDimensions, signed char *laplaceMatrix, int *cords); 45 | 46 | class PressureSolveOp : public OpKernel { 47 | private: 48 | int dim_product; 49 | float accuracy; 50 | int max_iterations; 51 | 52 | public: 53 | explicit PressureSolveOp(OpKernelConstruction* context) : OpKernel(context) { 54 | context->GetAttr("dim_product", &dim_product); 55 | context->GetAttr("accuracy", &accuracy); 56 | context->GetAttr("max_iterations", &max_iterations); 57 | } 58 | 59 | void Compute(OpKernelContext* context) override { 60 | auto begin = std::chrono::high_resolution_clock::now(); 61 | 62 | // General 63 | const Tensor& input_dimensions = context->input(0); 64 | 65 | // Laplace related 66 | const Tensor &input_mask_dimensions = context->input(1); 67 | const Tensor &input_active_mask = context->input(2); 68 | const Tensor &input_fluid_mask = context->input(3); 69 | Tensor input_laplace_matrix = context->input(4); 70 | 71 | // Pressure Solve 72 | Tensor input_divergence = context->input(5); 73 | Tensor input_p = context->input(6); 74 | Tensor input_r = context->input(7); 75 | Tensor input_z = context->input(8); 76 | Tensor input_pressure = context->input(9); 77 | const Tensor& input_one_vector = context->input(10); 78 | 79 | // Flattening 80 | auto dimensions = input_dimensions.flat(); 81 | 82 | auto mask_dimensions = input_mask_dimensions.flat(); 83 | auto active_mask = input_active_mask.flat(); 84 | auto fluid_mask = input_fluid_mask.flat(); 85 | auto laplace_matrix = input_laplace_matrix.flat(); 86 | 87 | auto divergence = input_divergence.flat(); 88 | auto p = input_p.flat(); 89 | auto r = input_r.flat(); 90 | auto z = input_z.flat(); 91 | auto pressure = input_pressure.flat(); 92 | auto one_vector = input_one_vector.flat(); 93 | 94 | int batch_size = input_divergence.shape().dim_size(0); 95 | int dim_size = dimensions.size(); 96 | 97 | auto end = std::chrono::high_resolution_clock::now(); 98 | 99 | // printf("General Preparation took: %f\n", std::chrono::duration_cast(end-begin).count() * 1e-6); 100 | 101 | begin = std::chrono::high_resolution_clock::now(); 102 | // Laplace: 103 | // Laplace Helper 104 | Tensor cords; // cords allocation does not really impact the performance. However it could be outsourced to be reused. 
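        // `cords` is a dim_product x dim_size int32 scratch buffer: each row receives the
        // N-dimensional grid coordinates of one cell (filled on the GPU by CordsByRow in
        // laplace_op.cu.cc) and is then reused by gridIDXWithOffsetShifted to index the
        // padded active/fluid mask grids while the Laplace stencil rows are assembled.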
105 | TensorShape cords_shape; 106 | cords_shape.AddDim(dim_product); 107 | cords_shape.AddDim(dim_size); 108 | OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_INT32, cords_shape, &cords)); 109 | auto cords_flat = cords.flat(); 110 | 111 | end = std::chrono::high_resolution_clock::now(); 112 | 113 | // printf("Laplace Preparation took: %f\n", std::chrono::duration_cast(end-begin).count() * 1e-6); 114 | 115 | 116 | begin = std::chrono::high_resolution_clock::now(); 117 | LaplaceMatrixKernelLauncher(dimensions.data(), dim_size, dim_product, active_mask.data(), fluid_mask.data(), mask_dimensions.data(), laplace_matrix.data(), cords_flat.data()); 118 | end = std::chrono::high_resolution_clock::now(); 119 | 120 | // printf("Laplace Matrix Generation took: %f\n", std::chrono::duration_cast(end-begin).count() * 1e-6); 121 | 122 | 123 | begin = std::chrono::high_resolution_clock::now(); 124 | 125 | TensorShape threshold_shape; 126 | threshold_shape.AddDim(batch_size); 127 | Tensor threshold_reached_tensor; 128 | OP_REQUIRES_OK(context, context->allocate_temp(DataType::DT_BOOL, threshold_shape, &threshold_reached_tensor)); 129 | auto threshold_reached = threshold_reached_tensor.flat(); 130 | 131 | context->set_output(0, input_pressure); 132 | 133 | TensorShape iterations_shape; 134 | iterations_shape.AddDim(1); 135 | Tensor* iterations_tensor; 136 | 137 | OP_REQUIRES_OK(context, context->allocate_output(1, iterations_shape, &iterations_tensor)); 138 | auto iterations_flat = iterations_tensor->flat(); 139 | 140 | end = std::chrono::high_resolution_clock::now(); 141 | 142 | 143 | // printf("Pressure Solve Preparation took: %f\n", std::chrono::duration_cast(end-begin).count() * 1e-6); 144 | 145 | 146 | begin = std::chrono::high_resolution_clock::now(); 147 | LaunchPressureKernel(dimensions.data(), dim_product, dim_size, 148 | laplace_matrix.data(), 149 | p.data(), z.data(), r.data(), divergence.data(), pressure.data(), 150 | one_vector.data(), 151 | threshold_reached.data(), 152 | accuracy, 153 | max_iterations, 154 | batch_size, 155 | iterations_flat.data()); 156 | end = std::chrono::high_resolution_clock::now(); 157 | 158 | 159 | // printf("Pressure Solve took: %f\n", std::chrono::duration_cast(end-begin).count() * 1e-6); 160 | // printf("%f\n", std::chrono::duration_cast(end-begin).count() * 1e-6); 161 | 162 | } 163 | }; 164 | 165 | REGISTER_KERNEL_BUILDER(Name("PressureSolve").Device(DEVICE_GPU), PressureSolveOp); 166 | -------------------------------------------------------------------------------- /utils_1d/train_diffusion.py: -------------------------------------------------------------------------------- 1 | import math 2 | from pathlib import Path 3 | from random import random 4 | from functools import partial 5 | from multiprocessing import cpu_count 6 | import pdb 7 | import torch 8 | import torch.nn as nn 9 | from torch import nn, einsum, Tensor 10 | import torch.nn.functional as F 11 | from torch.cuda.amp import autocast 12 | from torch.optim import Adam 13 | from torch.utils.data import Dataset, DataLoader 14 | from torch.optim.lr_scheduler import CosineAnnealingLR 15 | from einops import rearrange, reduce 16 | from einops.layers.torch import Rearrange 17 | from accelerate import Accelerator 18 | from ema_pytorch import EMA 19 | from tqdm.auto import tqdm 20 | from IPython import embed 21 | import datetime 22 | from collections import namedtuple 23 | from diffusion.diffusion_1d import GaussianDiffusion 24 | from utils_1d.model_utils import has_int_squareroot, cycle, exists 25 
| 26 | # trainer class 27 | 28 | class Trainer(object): 29 | def __init__( 30 | self, 31 | diffusion_model: GaussianDiffusion, 32 | dataset: Dataset, 33 | *, 34 | train_batch_size = 16, 35 | gradient_accumulate_every = 1, 36 | train_lr = 1e-4, 37 | train_num_steps = 100000, 38 | ema_update_every = 16, 39 | ema_decay = 0.995, 40 | adam_betas = (0.9, 0.99), 41 | save_and_sample_every = 1000, 42 | num_samples = 25, 43 | results_folder = './results', 44 | amp = False, 45 | mixed_precision_type = 'fp16', 46 | split_batches = True, 47 | max_grad_norm = 1., 48 | ): 49 | super().__init__() 50 | 51 | # accelerator 52 | 53 | self.accelerator = Accelerator( 54 | split_batches = split_batches, 55 | mixed_precision = mixed_precision_type if amp else 'no' 56 | ) 57 | 58 | # model 59 | 60 | self.model = diffusion_model 61 | self.channels = diffusion_model.channels 62 | 63 | # sampling and training hyperparameters 64 | 65 | assert has_int_squareroot(num_samples), 'number of samples must have an integer square root' 66 | self.num_samples = num_samples 67 | self.save_and_sample_every = save_and_sample_every 68 | 69 | self.batch_size = train_batch_size 70 | self.gradient_accumulate_every = gradient_accumulate_every 71 | self.max_grad_norm = max_grad_norm 72 | 73 | self.train_num_steps = train_num_steps 74 | 75 | # dataset and dataloader 76 | 77 | dl = DataLoader(dataset, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count()) 78 | 79 | dl = self.accelerator.prepare(dl) 80 | self.dl = cycle(dl) 81 | 82 | # optimizer & scheduler 83 | 84 | self.opt = Adam(diffusion_model.parameters(), lr = train_lr, betas = adam_betas) 85 | self.scheduler = CosineAnnealingLR(self.opt, T_max=10000, eta_min=0) 86 | 87 | # for logging results in a folder periodically 88 | 89 | if self.accelerator.is_main_process: 90 | self.ema = EMA(diffusion_model, beta = ema_decay, update_every = ema_update_every) 91 | self.ema.to(self.device) 92 | 93 | self.results_folder = Path(results_folder) 94 | make_dir(results_folder) 95 | self.results_folder.mkdir(exist_ok = True) 96 | 97 | # step counter state 98 | 99 | self.step = 0 100 | 101 | # prepare model, dataloader, optimizer with accelerator 102 | 103 | self.model, self.opt = self.accelerator.prepare(self.model, self.opt) 104 | 105 | @property 106 | def device(self): 107 | return self.accelerator.device 108 | 109 | def save(self, milestone): 110 | if not self.accelerator.is_local_main_process: 111 | return 112 | 113 | data = { 114 | 'step': self.step, 115 | 'model': self.accelerator.get_state_dict(self.model), 116 | 'opt': self.opt.state_dict(), 117 | 'ema': self.ema.state_dict(), 118 | 'scaler': self.accelerator.scaler.state_dict() if exists(self.accelerator.scaler) else None, 119 | 'loss': self.total_loss, 120 | # 'version': __version__ 121 | } 122 | 123 | torch.save(data, str(self.results_folder / f'cos10000-model-{milestone}.pt')) 124 | 125 | def load(self, milestone): 126 | accelerator = self.accelerator 127 | device = accelerator.device 128 | 129 | print('loading model from', str(self.results_folder / f'cos10000-model-{milestone}.pt')) 130 | 131 | if type(milestone) is int: 132 | data = torch.load(str(self.results_folder / f'cos10000-model-{milestone}.pt'), map_location=device) 133 | else: 134 | data = torch.load(str(self.results_folder / milestone), map_location=device) 135 | 136 | model = self.accelerator.unwrap_model(self.model) 137 | model.load_state_dict(data['model']) 138 | 139 | self.step = data['step'] 140 | 
self.opt.load_state_dict(data['opt']) 141 | if self.accelerator.is_main_process: 142 | self.ema.load_state_dict(data["ema"]) 143 | 144 | if 'version' in data: 145 | print(f"loading from version {data['version']}") 146 | 147 | if exists(self.accelerator.scaler) and exists(data['scaler']): 148 | self.accelerator.scaler.load_state_dict(data['scaler']) 149 | 150 | def train(self): 151 | print(self.device) 152 | 153 | accelerator = self.accelerator 154 | device = accelerator.device 155 | 156 | with tqdm(initial = self.step, total = self.train_num_steps, disable = not accelerator.is_main_process) as pbar: 157 | 158 | while self.step < self.train_num_steps: 159 | 160 | total_loss = 0. 161 | 162 | for _ in range(self.gradient_accumulate_every): 163 | data = next(self.dl).to(device) 164 | 165 | with self.accelerator.autocast(): 166 | loss = self.model(data) 167 | loss = loss / self.gradient_accumulate_every 168 | total_loss += loss.item() 169 | 170 | self.accelerator.backward(loss) 171 | 172 | self.total_loss = total_loss 173 | pbar.set_description(f'loss: {total_loss:.4f}') 174 | 175 | accelerator.wait_for_everyone() 176 | accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 177 | 178 | self.opt.step() 179 | self.opt.zero_grad() 180 | self.scheduler.step() 181 | 182 | accelerator.wait_for_everyone() 183 | 184 | self.step += 1 185 | if accelerator.is_main_process: 186 | self.ema.update() 187 | 188 | if self.step != 0 and self.step % self.save_and_sample_every == 0: 189 | self.ema.ema_model.eval() 190 | 191 | with torch.no_grad(): 192 | milestone = self.step // self.save_and_sample_every 193 | self.save(milestone) 194 | 195 | pbar.update(1) 196 | 197 | accelerator.print('training complete') 198 | 199 | class Trainer1D(Trainer): 200 | '''Rename Trainer1D to Trainer, but also allow the original references 201 | ''' 202 | def __init__(self, *args, **kwargs): 203 | super().__init__(*args, **kwargs) 204 | 205 | 206 | def make_dir(filename): 207 | """Make directory using filename if the directory does not exist""" 208 | import os 209 | import errno 210 | if not os.path.exists(os.path.dirname(filename)): 211 | print("directory {0} does not exist, created.".format(os.path.dirname(filename))) 212 | try: 213 | os.makedirs(os.path.dirname(filename)) 214 | except OSError as exc: # Guard against race condition 215 | if exc.errno != errno.EEXIST: 216 | print(exc) 217 | raise 218 | 219 | -------------------------------------------------------------------------------- /phi/math/base.py: -------------------------------------------------------------------------------- 1 | 2 | class Backend: 3 | 4 | def __init__(self, name): 5 | self.name = name 6 | 7 | def __str__(self): 8 | return self.name 9 | 10 | def __repr__(self): 11 | return self.name 12 | 13 | def matches_name(self, name): 14 | return self.name.lower() == name.lower() 15 | 16 | def is_applicable(self, values): 17 | return False 18 | 19 | def stack(self, values, axis=0): 20 | raise NotImplementedError() 21 | 22 | def concat(self, values, axis): 23 | raise NotImplementedError() 24 | 25 | def pad(self, value, pad_width, mode="constant", constant_values=0): 26 | raise NotImplementedError() 27 | 28 | def add(self, values): 29 | raise NotImplementedError() 30 | 31 | def reshape(self, value, shape): 32 | raise NotImplementedError() 33 | 34 | def sum(self, value, axis=None): 35 | raise NotImplementedError() 36 | 37 | def mean(self, value, axis=None): 38 | raise NotImplementedError() 39 | 40 | def py_func(self, func, inputs, Tout, shape_out, 
stateful=True, name=None, grad=None): 41 | raise NotImplementedError() 42 | 43 | def resample(self, inputs, sample_coords, interpolation="LINEAR", boundary="ZERO"): 44 | raise NotImplementedError() 45 | 46 | def zeros_like(self, tensor): 47 | raise NotImplementedError() 48 | 49 | def ones_like(self, tensor): 50 | raise NotImplementedError() 51 | 52 | def dot(self, a, b, axes): 53 | raise NotImplementedError() 54 | 55 | def matmul(self, A, b): 56 | raise NotImplementedError() 57 | 58 | def while_loop(self, cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, back_prop=True, 59 | swap_memory=False, name=None, maximum_iterations=None): 60 | raise NotImplementedError() 61 | 62 | def abs(self, x): 63 | raise NotImplementedError() 64 | 65 | def ceil(self, x): 66 | raise NotImplementedError() 67 | 68 | def floor(self, x): 69 | raise NotImplementedError() 70 | 71 | def max(self, x, axis=None): 72 | raise NotImplementedError() 73 | 74 | def maximum(self, a, b): 75 | raise NotImplementedError() 76 | 77 | def minimum(self, a, b): 78 | raise NotImplementedError() 79 | 80 | def with_custom_gradient(self, function, inputs, gradient, input_index=0, output_index=None, name_base="custom_gradient_func"): 81 | raise NotImplementedError() 82 | 83 | def sqrt(self, x): 84 | raise NotImplementedError() 85 | 86 | def exp(self, x): 87 | raise NotImplementedError() 88 | 89 | def conv(self, tensor, kernel, padding="SAME"): 90 | raise NotImplementedError() 91 | 92 | def expand_dims(self, a, axis): 93 | raise NotImplementedError() 94 | 95 | def shape(self, tensor): 96 | raise NotImplementedError() 97 | 98 | def to_float(self, x): 99 | raise NotImplementedError() 100 | 101 | def dimrange(self, tensor): 102 | return range(1, len(tensor.shape)-1) 103 | 104 | def gather(self, values, indices): 105 | raise NotImplementedError() 106 | 107 | def flatten(self, x): 108 | return self.reshape(x, (-1,) ) 109 | 110 | def unstack(self, tensor, axis=0): 111 | raise NotImplementedError() 112 | 113 | def std(self, x, axis=None): 114 | raise NotImplementedError() 115 | 116 | def boolean_mask(self, x, mask): 117 | raise NotImplementedError() 118 | 119 | def isfinite(self, x): 120 | raise NotImplementedError() 121 | 122 | def tile(self, x, multiples): 123 | raise NotImplementedError() 124 | 125 | 126 | class DynamicBackend(Backend): 127 | 128 | def __init__(self): 129 | Backend.__init__(self, "Dynamic") 130 | self.backends = [] 131 | 132 | def choose_backend(self, values): 133 | if not isinstance(values, tuple) and not isinstance(values, list): 134 | values = [values] 135 | for backend in self.backends: 136 | if backend.is_applicable(values): 137 | return backend 138 | raise NoBackendFound("No backend found for values %s; registered backends are %s" % (values, self.backends)) 139 | 140 | def is_applicable(self, values): 141 | if not isinstance(values, tuple) and not isinstance(values, list): 142 | values = [values] 143 | for backend in self.backends: 144 | if backend.is_applicable(values): 145 | return True 146 | return False 147 | 148 | def stack(self, values, axis=0): 149 | return self.choose_backend(values).stack(values, axis) 150 | 151 | def concat(self, values, axis): 152 | return self.choose_backend(values).concat(values, axis) 153 | 154 | def pad(self, value, pad_width, mode="constant", constant_values=0): 155 | return self.choose_backend(value).pad(value, pad_width, mode, constant_values) 156 | 157 | def add(self, values): 158 | return self.choose_backend(values).add(values) 159 | 160 | def reshape(self, value, 
shape): 161 | return self.choose_backend(value).reshape(value, shape) 162 | 163 | def sum(self, value, axis=None): 164 | return self.choose_backend(value).sum(value, axis) 165 | 166 | def mean(self, value, axis=None): 167 | return self.choose_backend(value).mean(value, axis) 168 | 169 | def py_func(self, func, inputs, Tout, shape_out, stateful=True, name=None, grad=None): 170 | return self.choose_backend(inputs).py_func(func, inputs, Tout, shape_out, stateful, name, grad) 171 | 172 | def resample(self, inputs, sample_coords, interpolation="LINEAR", boundary="ZERO"): 173 | return self.choose_backend((inputs, sample_coords)).resample(inputs, sample_coords, interpolation, boundary) 174 | 175 | def zeros_like(self, tensor): 176 | return self.choose_backend(tensor).zeros_like(tensor) 177 | 178 | def ones_like(self, tensor): 179 | return self.choose_backend(tensor).ones_like(tensor) 180 | 181 | def dot(self, a, b, axes): 182 | return self.choose_backend((a, b)).dot(a, b, axes) 183 | 184 | def matmul(self, A, b): 185 | return self.choose_backend((A, b)).matmul(A, b) 186 | 187 | def while_loop(self, cond, body, loop_vars, shape_invariants=None, parallel_iterations=10, back_prop=True, 188 | swap_memory=False, name=None, maximum_iterations=None): 189 | return self.choose_backend(loop_vars).while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, 190 | back_prop, swap_memory, name, maximum_iterations) 191 | 192 | def abs(self, x): 193 | return self.choose_backend(x).abs(x) 194 | 195 | def ceil(self, x): 196 | return self.choose_backend(x).ceil(x) 197 | 198 | def floor(self, x): 199 | return self.choose_backend(x).floor(x) 200 | 201 | def max(self, x, axis=None): 202 | return self.choose_backend(x).max(x, axis) 203 | 204 | def maximum(self, a, b): 205 | return self.choose_backend([a,b]).maximum(a, b) 206 | 207 | def minimum(self, a, b): 208 | return self.choose_backend([a,b]).minimum(a, b) 209 | 210 | def with_custom_gradient(self, function, inputs, gradient, input_index=0, output_index=None, name_base="custom_gradient_func"): 211 | return self.choose_backend(inputs[0]).with_custom_gradient(function, inputs, gradient, input_index, output_index, name_base) 212 | 213 | def sqrt(self, x): 214 | return self.choose_backend(x).sqrt(x) 215 | 216 | def exp(self, x): 217 | return self.choose_backend(x).exp(x) 218 | 219 | def conv(self, tensor, kernel, padding="SAME"): 220 | return self.choose_backend([tensor, kernel]).conv(tensor, kernel, padding) 221 | 222 | def expand_dims(self, a, axis): 223 | return self.choose_backend(a).expand_dims(a, axis) 224 | 225 | def shape(self, tensor): 226 | return self.choose_backend(tensor).shape(tensor) 227 | 228 | def to_float(self, x): 229 | return self.choose_backend(x).to_float(x) 230 | 231 | def gather(self, values, indices): 232 | return self.choose_backend([values, indices]).gather(values, indices) 233 | 234 | def unstack(self, tensor, axis=0): 235 | return self.choose_backend(tensor).unstack(tensor, axis) 236 | 237 | def std(self, x, axis=None): 238 | return self.choose_backend(x).std(x, axis) 239 | 240 | def boolean_mask(self, x, mask): 241 | return self.choose_backend((x, mask)).boolean_mask(x, mask) 242 | 243 | def isfinite(self, x): 244 | return self.choose_backend(x).isfinite(x) 245 | 246 | def tile(self, x, multiples): 247 | return self.choose_backend([x, multiples]).tile(x, multiples) 248 | 249 | 250 | class NoBackendFound(Exception): 251 | def __init__(self, msg): 252 | Exception.__init__(self, msg) 253 | 
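The Backend / DynamicBackend pair above is a thin dispatch layer: DynamicBackend keeps a list of concrete backends and forwards every math call to the first backend whose is_applicable() accepts the arguments. A minimal sketch of that wiring, assuming SciPyBackend from phi/math/scipy_backend.py is importable as shown (the actual registration performed in the phi.math package init is not reproduced in this excerpt):

# Illustrative sketch; the import paths assume the phi package above is installed.
import numpy as np
from phi.math.base import DynamicBackend
from phi.math.scipy_backend import SciPyBackend

backend = DynamicBackend()
backend.backends.append(SciPyBackend())          # register the NumPy/SciPy implementation

x = np.ones((2, 4, 3), dtype=np.float32)
stacked = backend.stack([x, x], axis=0)          # dispatched to SciPyBackend.stack -> np.stack
print(backend.shape(stacked))                    # (2, 2, 4, 3)
print(backend.sum(stacked, axis=(2, 3)).shape)   # SciPyBackend.sum wraps np.sum -> (2, 2)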
-------------------------------------------------------------------------------- /phi/viz/plot.py: -------------------------------------------------------------------------------- 1 | import numpy, os 2 | import plotly.figure_factory as ff 3 | 4 | import phi.math.nd 5 | 6 | # Views 7 | FRONT = 'front' 8 | RIGHT = 'right' 9 | TOP = 'top' 10 | 11 | # Vector display 12 | LENGTH = 'length' 13 | VECTOR2 = 'vec2' 14 | 15 | class PlotlyFigureBuilder(object): 16 | 17 | 18 | def __init__(self, 19 | batches=slice(None), 20 | depths=slice(None), 21 | staggered=False, 22 | antisymmetry=False, 23 | view=FRONT, 24 | component=LENGTH, 25 | draw_arrows_backward=True, 26 | max_vector_resolution=18, 27 | max_resolution=512): 28 | self.batches = batches 29 | self.depths = depths 30 | self.staggered = staggered 31 | self.antisymmetry = antisymmetry 32 | self.view = view 33 | self.component = component 34 | self.draw_arrows_backward = draw_arrows_backward 35 | self.max_vector_resolution = max_vector_resolution 36 | self.max_resolution = max_resolution 37 | 38 | def select_batch(self, batch): 39 | if batch is None: 40 | self.batches = slice(None) 41 | elif isinstance(batch, int): 42 | self.batches = [batch] 43 | else: 44 | self.batches = batch 45 | 46 | def select_depth(self, depth): 47 | if depth is None: 48 | self.depths = slice(None) 49 | elif isinstance(depth, int): 50 | self.depths = [depth] 51 | else: 52 | self.depths = depth 53 | 54 | def save_figures(self, directory, fieldname, time, data, same_scale_data=None): 55 | import matplotlib.pyplot as plt 56 | batches = self.batches if self.batches is not None else range(data.shape[0]) 57 | for batch in batches: 58 | if len(data.shape) == 5: 59 | for depth in self.get_selected_slices(data.shape): 60 | path = os.path.join(directory, '%s_batch%04d_depth%04d_%04d.png' % (fieldname, batch, depth, time)) 61 | fig = self.create_figure(data, batch=batch, depth=depth, same_scale_data=same_scale_data, library='matplotlib') 62 | plt.savefig(path) 63 | plt.close() 64 | yield path 65 | else: 66 | path = os.path.join(directory, '%s_batch%04d_%04d.png' % (fieldname, batch, time)) 67 | fig = self.create_figure(data, batch=batch, same_scale_data=same_scale_data, library='matplotlib') 68 | plt.savefig(path) 69 | plt.close() 70 | yield path 71 | 72 | def get_selected_slices(self, shape): 73 | try: 74 | selected_depths = numpy.arange(self.slice_count(shape))[self.depths] 75 | except: 76 | selected_depths = [self.slice_count(shape) - 1] 77 | return selected_depths 78 | 79 | def create_figure(self, data, same_scale_data=None, batch=None, depth=None, library='matplotlib'): 80 | # Determine batch 81 | if data.shape[0] == 1: 82 | batch = 0 83 | if batch is None: 84 | try: 85 | selected_batches = numpy.arange(data.shape[0])[self.batches] 86 | if len(selected_batches) != 1: 87 | raise ValueError('no batch specified and default batches contains more than one element') 88 | except: 89 | return None 90 | batch = selected_batches[0] 91 | # Determine slice 92 | if depth is None and len(data.shape) == 5: 93 | selected_depths = self.get_selected_slices(data.shape) 94 | if len(selected_depths) != 1: 95 | raise ValueError('no depth specified and default depths contains more than one element') 96 | depth = selected_depths[0] 97 | 98 | # Antisymmetry 99 | if isinstance(data, phi.math.nd.StaggeredGrid): 100 | data = data.at_centers() 101 | staggered = True 102 | else: 103 | staggered = self.staggered 104 | if self.antisymmetry: 105 | if staggered: 106 | data = data[..., 1:,:] 107 | if 
data.shape[-1] != 1: 108 | datax = data[..., ::-1, 0:1] + data[..., 0:1] 109 | datayz = data[..., ::-1, 1:] - data[..., 1:] 110 | data = numpy.concatenate((datax, datayz), axis=-1) 111 | else: 112 | data = data - data[..., ::-1, :] 113 | 114 | # Select batch 115 | if batch < data.shape[0]: 116 | data = data[batch, ...] 117 | else: 118 | return { 'data': [{'type': 'heatmap', 'z': [[0]]}] } 119 | 120 | # 1D graph 121 | if len(data.shape) == 2: 122 | return self.graphs(data, library) 123 | 124 | # 3D projection / Select depth 125 | if len(data.shape) == 4: 126 | if self.view == FRONT: 127 | data = data[min(depth, data.shape[0]), :, :, :] 128 | elif self.view == RIGHT: 129 | data = data[:, :, min(depth, data.shape[2]), :] 130 | print(data.shape) 131 | data = numpy.transpose(data, axes=(1,0,2)) 132 | elif self.view == TOP: 133 | data = data[:, min(depth, data.shape[1]), :, :] 134 | else: 135 | data = data[0, ...] 136 | 137 | # Create figure 138 | component = 0 if data.shape[-1] == 1 else self.component 139 | 140 | if component == VECTOR2: 141 | # Downsample 142 | while numpy.prod(data.shape[:-1]) > self.max_vector_resolution ** 2: 143 | data = data[::2, ::2, :] * 0.5 144 | return self.draw_vector_field(data, library) 145 | 146 | elif component == LENGTH: 147 | if data.shape[-1] == 3: 148 | data = numpy.sqrt( data[...,0:1]**2 + data[...,1:2]**2 + data[...,2:3]**2) 149 | else: 150 | data = numpy.sqrt( data[...,0:1]**2 + data[...,1:2]**2) 151 | else: 152 | if component >= data.shape[-1]: 153 | data = numpy.zeros_like(data[...,0:1]) 154 | else: 155 | data = data[...,component:component+1] 156 | 157 | # Downsample 158 | while numpy.prod(data.shape[:-1]) > self.max_resolution ** 2: 159 | data = data[::2, ::2, :] 160 | if same_scale_data is not None: 161 | return self.heatmap(data[..., 0], library, minmax=global_minmax(same_scale_data)) 162 | else: 163 | return self.heatmap(data[..., 0], library) 164 | 165 | def slice_count(self, shape): 166 | if len(shape) <= 4: 167 | return 1 168 | if self.view == FRONT: 169 | return shape[1] 170 | elif self.view == TOP: 171 | return shape[2] 172 | elif self.view == RIGHT: 173 | return shape[3] 174 | else: 175 | raise ValueError('Illegal view: %s'%self.view) 176 | 177 | def heatmap(self, z, library, minmax=None): 178 | if library == 'dash': 179 | args = {'z' : z, 'type': 'heatmap'} 180 | if minmax is not None: 181 | args['zmin'] = minmax[0] 182 | args['zmax'] = minmax[1] 183 | return { 'data': [args] } 184 | elif library == 'matplotlib': 185 | import matplotlib.pyplot as plt 186 | fig = plt.figure() 187 | plt.imshow(z, cmap='bwr', origin='lower') 188 | return fig 189 | else: 190 | raise NotImplementedError() 191 | 192 | def graphs(self, data, library): 193 | x = numpy.arange(data.shape[0]) 194 | if library == 'dash': 195 | graphs = [{ 'mode': 'markers+lines', 'type': 'scatter', 'x': x, 'y': data[:,i]} for i in range(data.shape[-1])] 196 | return {'data': graphs } 197 | else: 198 | import matplotlib.pyplot as plt 199 | fig = plt.figure() 200 | for i in range(data.shape[-1]): 201 | plt.plot(x, data[:,i]) 202 | return fig 203 | 204 | 205 | def draw_vector_field(self, vector_field, library): 206 | x, y = numpy.meshgrid(numpy.arange(0, vector_field.shape[1], 1), numpy.arange(0, vector_field.shape[0], 1)) 207 | if library == 'dash': 208 | if self.draw_arrows_backward: 209 | return ff.create_quiver(x - vector_field[..., 0], y - vector_field[..., 1], vector_field[..., 0], vector_field[..., 1], scale=1.0) 210 | else: 211 | return ff.create_quiver(x, y, vector_field[..., 0], 
vector_field[..., 1], scale=1.0) 212 | else: 213 | raise NotImplementedError() 214 | 215 | def empty_figure(self, library): 216 | if library == 'dash': 217 | return { 218 | 'data': [{'z': None, 'type': 'heatmap'}] 219 | } 220 | elif library == 'matplotlib': 221 | import matplotlib.pyplot as plt 222 | fig = plt.figure() 223 | return fig 224 | 225 | 226 | def global_minmax(arrays): 227 | global_min = numpy.minimum(*[numpy.min(data) for data in arrays]) 228 | global_max = numpy.maximum(*[numpy.max(data) for data in arrays]) 229 | return global_min, global_max -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - local 4 | - conda-forge 5 | dependencies: 6 | - _libgcc_mutex=0.1=conda_forge 7 | - _openmp_mutex=4.5=1_llvm 8 | - backcall=0.2.0=pyh9f0ad1d_0 9 | - backports=1.0=py_2 10 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 11 | - beautifulsoup4=4.10.0=pyha770c72_0 12 | - brotlipy=0.7.0=py38h497a2fe_1001 13 | - bzip2=1.0.8=h7f98852_4 14 | - c-ares=1.17.2=h7f98852_0 15 | - ca-certificates=2021.5.30=ha878542_0 16 | - catalogue=2.0.6=py38h578d9bd_0 17 | - certifi=2021.5.30=py38h578d9bd_0 18 | - cffi=1.14.6=py38h3931269_1 19 | - chardet=4.0.0=py38h578d9bd_1 20 | - charset-normalizer=2.0.0=pyhd8ed1ab_0 21 | - click=8.0.1=py38h578d9bd_0 22 | - cmake=3.21.3=h8897547_0 23 | - colorama=0.4.4=pyh9f0ad1d_0 24 | - conda=4.10.3=py38h578d9bd_2 25 | - conda-build=3.21.4=py38h578d9bd_0 26 | - conda-package-handling=1.7.3=py38h497a2fe_0 27 | - cryptography=3.4.7=py38ha5dfef3_0 28 | - cymem=2.0.5=py38h709712a_2 29 | - cython-blis=0.7.4=py38h5c078b8_0 30 | - dataclasses=0.8=pyhc8e2a94_3 31 | - decorator=5.1.0=pyhd8ed1ab_0 32 | - expat=2.4.1=h9c3ff4c_0 33 | - filelock=3.3.0=pyhd8ed1ab_0 34 | - glob2=0.7=py_0 35 | - icu=68.1=h58526e2_0 36 | - idna=3.1=pyhd3deb0d_0 37 | - ipython=7.28.0=py38he5a9106_0 38 | - jedi=0.18.0=py38h578d9bd_2 39 | - krb5=1.19.2=hcc1bbae_2 40 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 41 | - libarchive=3.5.2=hccf745f_1 42 | - libblas=3.9.0=5_h92ddd45_netlib 43 | - libcblas=3.9.0=5_h92ddd45_netlib 44 | - libcurl=7.79.1=h2574ce0_1 45 | - libedit=3.1.20191231=he28a2e2_2 46 | - libev=4.33=h516909a_1 47 | - libffi=3.4.2=h9c3ff4c_4 48 | - libgcc-ng=11.2.0=h1d223b6_9 49 | - libgfortran-ng=11.2.0=h69a702a_9 50 | - libgfortran5=11.2.0=h5c6108e_9 51 | - libgomp=11.2.0=h1d223b6_9 52 | - libiconv=1.16=h516909a_0 53 | - liblapack=3.9.0=5_h92ddd45_netlib 54 | - liblief=0.11.5=h9c3ff4c_0 55 | - libllvm10=10.0.1=he513fc3_3 56 | - libnghttp2=1.43.0=h812cca2_1 57 | - libopenblas=0.3.18=pthreads_h8fe5266_0 58 | - libssh2=1.10.0=ha56f1ee_2 59 | - libstdcxx-ng=11.2.0=he4da1e4_9 60 | - libuv=1.42.0=h7f98852_0 61 | - libxml2=2.9.12=h72842e0_0 62 | - llvm-openmp=12.0.1=h4bd325d_1 63 | - llvmlite=0.35.0=py38h4630a5e_1 64 | - lz4-c=1.9.3=h9c3ff4c_1 65 | - lzo=2.10=h516909a_1000 66 | - magma-cuda110=2.5.2=5 67 | - markupsafe=2.0.1=py38h497a2fe_0 68 | - matplotlib-inline=0.1.3=pyhd8ed1ab_0 69 | - mkl=2019.5=281 70 | - mkl-include=2019.5=281 71 | - mock=4.0.3=py38h578d9bd_1 72 | - murmurhash=1.0.5=py38h709712a_0 73 | - ncurses=6.2=h58526e2_4 74 | - ninja=1.10.2=h4bd325d_1 75 | - numba=0.52.0=py38h51da96c_0 76 | - numpy=1.21.2=py38he2449b9_0 77 | - openblas=0.3.18=pthreads_h4748800_0 78 | - openssl=1.1.1l=h7f98852_0 79 | - parso=0.8.2=pyhd8ed1ab_0 80 | - patchelf=0.13=h58526e2_0 81 | - pathy=0.6.0=pyhd8ed1ab_0 82 | - pexpect=4.8.0=pyh9f0ad1d_2 83 | - 
pickleshare=0.7.5=py_1003 84 | - pip=21.2.4=pyhd8ed1ab_0 85 | - pkginfo=1.7.1=pyhd8ed1ab_0 86 | - preshed=3.0.5=py38h709712a_1 87 | - prompt-toolkit=3.0.20=pyha770c72_0 88 | - psutil=5.8.0=py38h497a2fe_1 89 | - ptyprocess=0.7.0=pyhd3deb0d_0 90 | - py-lief=0.11.5=py38h709712a_0 91 | - pycosat=0.6.3=py38h497a2fe_1006 92 | - pycparser=2.20=pyh9f0ad1d_2 93 | - pydantic=1.8.2=py38h497a2fe_0 94 | - pygments=2.10.0=pyhd8ed1ab_0 95 | - pyopenssl=20.0.1=pyhd8ed1ab_0 96 | - pyparsing=2.4.7=pyh9f0ad1d_0 97 | - pysocks=1.7.1=py38h578d9bd_3 98 | - python=3.8.12=hb7a2778_1_cpython 99 | - python-libarchive-c=3.1=py38h578d9bd_0 100 | - python_abi=3.8=2_cp38 101 | - pytz=2021.3=pyhd8ed1ab_0 102 | - pyyaml=5.4.1=py38h497a2fe_1 103 | - readline=8.1=h46c0cb4_0 104 | - rhash=1.4.1=h7f98852_0 105 | - ripgrep=13.0.0=habb4d0f_1 106 | - ruamel_yaml=0.15.80=py38h497a2fe_1004 107 | - scipy=1.6.3=py38h7b17777_0 108 | - setuptools=58.2.0=py38h578d9bd_0 109 | - shellingham=1.4.0=pyh44b312d_0 110 | - six=1.16.0=pyh6c4a22f_0 111 | - smart_open=5.2.1=pyhd8ed1ab_0 112 | - soupsieve=2.0.1=py_1 113 | - spacy=3.1.3=py38h2b96118_0 114 | - spacy-legacy=3.0.8=pyhd8ed1ab_0 115 | - sqlite=3.36.0=h9cd32fc_2 116 | - srsly=2.4.1=py38h709712a_0 117 | - tk=8.6.11=h27826a3_1 118 | - tqdm=4.62.3=pyhd8ed1ab_0 119 | - typer=0.4.0=pyhd8ed1ab_0 120 | - tzdata=2021a=he74cb21_1 121 | - wheel=0.37.0=pyhd8ed1ab_1 122 | - xz=5.2.5=h516909a_1 123 | - yaml=0.2.5=h516909a_0 124 | - zlib=1.2.11=h516909a_1010 125 | - zstd=1.5.0=ha95c52a_0 126 | - pip: 127 | - absl-py==0.14.1 128 | - accelerate==0.27.2 129 | - alabaster==0.7.12 130 | - anyio==3.6.2 131 | - apex==0.1 132 | - appdirs==1.4.4 133 | - argon2-cffi==21.1.0 134 | - asgiref==3.4.1 135 | - attrs==21.2.0 136 | - audioread==2.1.9 137 | - babel==2.11.0 138 | - beartype==0.17.2 139 | - bleach==4.1.0 140 | - cachetools==4.2.4 141 | - cftime==1.6.4 142 | - cloudpickle==1.6.0 143 | - codecov==2.1.12 144 | - coverage==6.0.1 145 | - cycler==0.10.0 146 | - cython==0.29.24 147 | - debugpy==1.5.0 148 | - defusedxml==0.7.1 149 | - django==3.2.6 150 | - docker-pycreds==0.4.0 151 | - docutils==0.17.1 152 | - einops==0.7.0 153 | - einops-exts==0.0.4 154 | - ema-pytorch==0.4.2 155 | - entrypoints==0.3 156 | - expecttest==0.1.3 157 | - farama-notifications==0.0.4 158 | - fastjsonschema==2.16.2 159 | - flake8==3.7.9 160 | - flask==2.0.2 161 | - fsspec==2024.2.0 162 | - future==0.18.2 163 | - gitdb==4.0.11 164 | - gitpython==3.1.43 165 | - grpcio==1.41.0 166 | - gunicorn==20.1.0 167 | - h11==0.12.0 168 | - h5py==3.11.0 169 | - horovod==0.28.1 170 | - httptools==0.2.0 171 | - huggingface-hub==0.29.1 172 | - imageio==2.35.1 173 | - imagesize==1.2.0 174 | - importlib-metadata==6.0.0 175 | - importlib-resources==5.10.2 176 | - iniconfig==1.1.1 177 | - iopath==0.1.9 178 | - ipykernel==6.4.1 179 | - ipython-genutils==0.2.0 180 | - itsdangerous==2.0.1 181 | - jinja2==3.1.2 182 | - joblib==1.1.0 183 | - json5==0.9.6 184 | - jsonschema==4.17.3 185 | - jupytext==1.14.4 186 | - kiwisolver==1.3.2 187 | - librosa==0.8.1 188 | - lmdb==1.2.1 189 | - markdown==3.3.4 190 | - markdown-it-py==1.1.0 191 | - matplotlib==3.4.3 192 | - mccabe==0.6.1 193 | - mdit-py-plugins==0.2.8 194 | - mistune==2.0.4 195 | - mpmath==1.3.0 196 | - nbclassic==0.5.1 197 | - nbclient==0.5.4 198 | - nbconvert==7.2.9 199 | - nbformat==5.7.3 200 | - nest-asyncio==1.5.6 201 | - networkx==2.0 202 | - nltk==3.6.4 203 | - notebook==6.4.1 204 | - notebook-shim==0.2.2 205 | - nvidia-cublas-cu12==12.1.3.1 206 | - nvidia-cuda-cupti-cu12==12.1.105 207 | - 
nvidia-cuda-nvrtc-cu12==12.1.105 208 | - nvidia-cuda-runtime-cu12==12.1.105 209 | - nvidia-cudnn-cu12==8.9.2.26 210 | - nvidia-cufft-cu12==11.0.2.54 211 | - nvidia-curand-cu12==10.3.2.106 212 | - nvidia-cusolver-cu12==11.4.5.107 213 | - nvidia-cusparse-cu12==12.1.0.106 214 | - nvidia-dali-cuda110==1.6.0 215 | - nvidia-dlprof-pytorch-nvtx==1.6.0 216 | - nvidia-dlprofviewer==1.6.0 217 | - nvidia-nccl-cu12==2.19.3 218 | - nvidia-nvjitlink-cu12==12.3.101 219 | - nvidia-nvtx-cu12==12.1.105 220 | - nvidia-pyindex==1.0.9 221 | - oauthlib==3.1.1 222 | - onnx==1.8.204 223 | - packaging==23.0 224 | - pandocfilters==1.5.0 225 | - pillow==10.4.0 226 | - pkgutil-resolve-name==1.3.10 227 | - platformdirs==2.6.2 228 | - pluggy==1.0.0 229 | - polygraphy==0.33.0 230 | - pooch==1.5.1 231 | - portalocker==2.3.2 232 | - prettytable==2.2.1 233 | - prometheus-client==0.11.0 234 | - protobuf==3.20.3 235 | - py==1.10.0 236 | - pyasn1==0.4.8 237 | - pyasn1-modules==0.2.8 238 | - pybind11==2.8.0 239 | - pycocotools==2.0+nv0.5.1 240 | - pycodestyle==2.5.0 241 | - pydot==1.4.2 242 | - pyflakes==2.1.1 243 | - pyrsistent==0.18.0 244 | - pytest==6.2.5 245 | - pytest-cov==3.0.0 246 | - pytest-pythonpath==0.7.3 247 | - python-dateutil==2.8.2 248 | - python-dotenv==0.19.1 249 | - python-hostlist==1.21 250 | - python-nvd3==0.15.0 251 | - python-slugify==5.0.2 252 | - pytorch-quantization==2.1.0 253 | - pyzmq==25.0.0 254 | - regex==2021.10.8 255 | - requests==2.28.2 256 | - requests-oauthlib==1.3.0 257 | - resampy==0.2.2 258 | - revtok==0.0.3 259 | - rotary-embedding-torch==0.5.3 260 | - rsa==4.7.2 261 | - sacremoses==0.0.46 262 | - safetensors==0.4.2 263 | - scikit-learn==1.0 264 | - send2trash==1.8.0 265 | - sentry-sdk==2.1.1 266 | - setproctitle==1.3.3 267 | - shapely==2.0.3 268 | - sniffio==1.3.0 269 | - snowballstemmer==2.1.0 270 | - soundfile==0.10.3.post1 271 | - sphinx==4.2.0 272 | - sphinx-glpi-theme==0.3 273 | - sphinx-rtd-theme==1.0.0 274 | - sphinxcontrib-applehelp==1.0.2 275 | - sphinxcontrib-devhelp==1.0.2 276 | - sphinxcontrib-htmlhelp==2.0.0 277 | - sphinxcontrib-jsmath==1.0.1 278 | - sphinxcontrib-qthelp==1.0.3 279 | - sphinxcontrib-serializinghtml==1.1.5 280 | - sqlparse==0.4.2 281 | - sympy==1.12 282 | - tensorrt==8.0.3.4 283 | - terminado==0.12.1 284 | - testpath==0.5.0 285 | - threadpoolctl==3.0.0 286 | - tinycss2==1.2.1 287 | - toml==0.10.2 288 | - tomli==1.2.1 289 | - torch==2.2.1 290 | - torchtext==0.11.0a0 291 | - torchvision==0.11.0a0 292 | - tornado==6.2 293 | - traitlets==5.8.1 294 | - triton==2.2.0 295 | - typing-extensions==4.12.2 296 | - uff==0.6.9 297 | - urllib3==1.26.18 298 | - uvicorn==0.15.0 299 | - uvloop==0.16.0 300 | - wcwidth==0.2.13 301 | - zipp==3.12.0 302 | prefix: /opt/conda -------------------------------------------------------------------------------- /phi/tf/flow.py: -------------------------------------------------------------------------------- 1 | from phi.flow import * 2 | from phi.flow import _default_phi_stack 3 | from phi.tf.profiling import Timeliner 4 | import tensorflow as tf 5 | import os 6 | from tensorflow.python.client import device_lib 7 | 8 | 9 | class TFFluidSimulation(FluidSimulation): 10 | 11 | def __init__(self, shape, boundary='open', batch_size=None, session=None, solver=None, **kwargs): 12 | math.load_tensorflow() 13 | # Init 14 | self.session = session if session else tf.Session() 15 | self.graph = tf.get_default_graph() 16 | gpus = [device.name for device in device_lib.list_local_devices() if device.device_type == 'GPU'] 17 | if solver is None and 
len(gpus) > 0: 18 | try: 19 | from phi.solver.cuda.cuda import CudaPressureSolver 20 | solver = CudaPressureSolver() 21 | except: 22 | pass 23 | FluidSimulation.__init__(self, shape, boundary=boundary, batch_size=batch_size, solver=solver, **kwargs) 24 | self.timeliner = None 25 | self.timeline_file = None 26 | self.summary_writers = {} 27 | self.summary_directory = '' 28 | 29 | def run(self, tasks, feed_dict=None, options=None, run_metadata=None, summary_key=None, time=None, merged_summary=None): 30 | if isinstance(tasks, np.ndarray): 31 | return tasks 32 | if tasks is None: 33 | return None 34 | 35 | # Fix feed dict 36 | if feed_dict is not None: 37 | new_feed_dict = {} 38 | for (key, value) in feed_dict.items(): 39 | if isinstance(key, StaggeredGrid): 40 | key = key.staggered 41 | if isinstance(value, StaggeredGrid): 42 | value = value.staggered 43 | new_feed_dict[key] = value 44 | feed_dict = new_feed_dict 45 | # Fix tasks 46 | wrap_results = [] 47 | if isinstance(tasks, list) or isinstance(tasks, tuple): 48 | tasks = list(tasks) 49 | for i in range(len(tasks)): 50 | wrap_results.append(isinstance(tasks[i], StaggeredGrid)) 51 | if wrap_results[-1]: 52 | tasks[i] = tasks[i].staggered 53 | if isinstance(tasks, StaggeredGrid): 54 | wrap_results = True 55 | tasks = tasks.staggered 56 | 57 | # Handle tracing 58 | if self.timeliner: 59 | options = self.timeliner.options 60 | run_metadata = self.timeliner.run_metadata 61 | # Summary 62 | if summary_key is not None and merged_summary is not None: 63 | tasks = [merged_summary] + list(tasks) 64 | 65 | result = self.session.run(tasks, feed_dict=feed_dict, options=options, run_metadata=run_metadata) 66 | 67 | if summary_key: 68 | summary_buffer = result[0] 69 | result = result[1:] 70 | if summary_key in self.summary_writers: 71 | summary_writer = self.summary_writers[summary_key] 72 | else: 73 | summary_writer = tf.summary.FileWriter(os.path.join(self.summary_directory, str(summary_key)), self.graph) 74 | self.summary_writers[summary_key] = summary_writer 75 | summary_writer.add_summary(summary_buffer, time) 76 | summary_writer.flush() 77 | 78 | if self.timeliner: 79 | self.timeliner.add_run() 80 | 81 | if wrap_results is True: 82 | result = StaggeredGrid(result) 83 | elif wrap_results: 84 | result = list(result) 85 | for i in range(len(wrap_results)): 86 | if wrap_results[i]: 87 | result[i] = StaggeredGrid(result[i]) 88 | 89 | return result 90 | 91 | def initialize_variables(self): 92 | import tensorflow as tf 93 | self.session.run(tf.global_variables_initializer()) 94 | self.saver = tf.train.Saver(max_to_keep=100, allow_empty=True) 95 | 96 | def placeholder(self, element_type=1, name=None, batch_size=None, dtype=np.float32): 97 | import tensorflow as tf 98 | if element_type == "velocity": 99 | element_type = "staggered" if self._mac else "vector" 100 | array = tf.placeholder(dtype, self.shape(element_type, batch_size), name=name) 101 | if element_type == "staggered": 102 | return StaggeredGrid(array) 103 | else: 104 | return array 105 | 106 | def clear_domain(self): 107 | # Active / Fluid Mask 108 | if self._force_use_masks or self._active_mask is not None or self._fluid_mask is not None: 109 | self._active_mask = self._create_or_reset_mask(self._active_mask) 110 | self._fluid_mask = self._create_or_reset_mask(self._fluid_mask) 111 | # Velocity Mask 112 | if self._force_use_masks or self._active_mask is not None or self._fluid_mask is not None: 113 | self._update_velocity_mask() 114 | 115 | def set_obstacle(self, mask_or_size, origin=None): 116 
| if self._active_mask is None: 117 | self._active_mask = self._create_or_reset_mask(None) 118 | if self._fluid_mask is None: 119 | self._fluid_mask = self._create_or_reset_mask(None) 120 | 121 | dims = range(self.rank) 122 | 123 | if isinstance(mask_or_size, np.ndarray): 124 | value = mask_or_size 125 | slices = None 126 | raise NotImplementedError() # TODO 127 | else: 128 | # mask_or_size = tuple/list of extents 129 | if isinstance(mask_or_size, int): 130 | mask_or_size = [mask_or_size for i in dims] 131 | if origin is None: 132 | origin = [0 for i in range(len(mask_or_size))] 133 | else: 134 | origin = list(origin) 135 | fluid_mask_data, active_mask_data = self.session.run([self._fluid_mask, self._active_mask]) 136 | fluid_mask_data[[0]+[slice(origin[i], origin[i]+mask_or_size[i]) for i in dims]+[0]] = 0 137 | active_mask_data[[0]+[slice(origin[i], origin[i]+mask_or_size[i]) for i in dims]+[0]] = 0 138 | self.session.run([self._fluid_mask.assign(fluid_mask_data), self._active_mask.assign(active_mask_data)]) 139 | self._update_velocity_mask() 140 | 141 | def _create_or_reset_mask(self, old_mask): 142 | if old_mask is None: 143 | return tf.Variable(self._create_mask(), dtype=tf.float32, trainable=False) 144 | else: 145 | self.session.run(old_mask.assign(self._create_mask())) 146 | return old_mask 147 | 148 | def _update_velocity_mask(self): 149 | new_velocity_mask = self._boundary.create_velocity_mask(self._fluid_mask, self._dimensions, self._mac) 150 | if self._velocity_mask is None: 151 | self._velocity_mask = StaggeredGrid(tf.Variable(new_velocity_mask.staggered, dtype=tf.float32, trainable=False)) 152 | else: 153 | self.session.run(self._velocity_mask.staggered.assign(new_velocity_mask.staggered)) 154 | 155 | def _update_masks(self, cell_type_mask): 156 | if self.cell_type_mask is None: 157 | self.cell_type_mask = tf.Variable(cell_type_mask, dtype=tf.float32, trainable=False) 158 | else: 159 | self.session.run(self.cell_type_mask.assign(cell_type_mask)) 160 | 161 | new_velocity_mask = self.boundary.create_velocity_mask(cell_type_mask, self._mac) 162 | if new_velocity_mask is None: 163 | self.velocity_mask = None 164 | else: 165 | if self.velocity_mask is None: 166 | self.velocity_mask = StaggeredGrid(tf.Variable(new_velocity_mask.staggered, dtype=tf.float32, trainable=False)) 167 | else: 168 | self.session.run(self.velocity_mask.staggered.assign(new_velocity_mask.staggered)) 169 | 170 | def save(self, dir): 171 | os.path.isdir(dir) or os.makedirs(dir) 172 | self.saver.save(self.session, os.path.join(dir, "model.ckpt")) 173 | 174 | 175 | def restore(self, dir, scope=None): 176 | path = os.path.join(dir, "model.ckpt") 177 | vars = self.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=scope) 178 | saver = tf.train.Saver(var_list=vars) 179 | saver.restore(self.session, path) 180 | 181 | def restore_new_scope(self, dir, saved_scope, tf_scope): 182 | var_remap = dict() 183 | vars = [v for v in self.graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf_scope) if "Adam" not in v.name] 184 | for var in vars: 185 | var_remap[saved_scope + var.name[len(tf_scope):-2]] = var 186 | path = os.path.join(dir, "model.ckpt") 187 | saver = tf.train.Saver(var_list=var_remap) 188 | try: 189 | saver.restore(self.session, path) 190 | except tf.errors.NotFoundError as e: 191 | from tensorflow.contrib.framework.python.framework import checkpoint_utils 192 | print(checkpoint_utils.list_variables(dir)) 193 | raise e 194 | 195 | 196 | @property 197 | def tracing(self): 198 | return self.timeliner 
is not None 199 | 200 | def start_tracing(self, file): 201 | dir = os.path.dirname(file) 202 | os.path.isdir(dir) or os.makedirs(dir) 203 | self.timeline_file = file 204 | self.timeliner = Timeliner() 205 | 206 | def stop_tracing(self): 207 | self.timeliner.save(self.timeline_file) 208 | self.timeliner = None 209 | 210 | 211 | def placeholder(element_type=1, name=None, batch_size=None, dtype=np.float32): 212 | return _default_phi_stack.get_default().placeholder(element_type, name=name, batch_size=batch_size, dtype=dtype) 213 | 214 | 215 | def run(tasks, feed_dict=None, options=None, run_metadata=None, summary_key=None, time=None, merged_summary=None): 216 | return _default_phi_stack.get_default().run(tasks, feed_dict, options, run_metadata, summary_key=summary_key, 217 | time=time, merged_summary=merged_summary) -------------------------------------------------------------------------------- /phi/fluidformat.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import numpy as np 3 | import os, os.path, json, inspect, shutil, six, re 4 | from os.path import join, isfile, isdir 5 | 6 | 7 | def read_zipped_array(filename): 8 | file = np.load(filename) 9 | array = file[file.files[0]] 10 | if array.shape[0] != 1: 11 | array = array.reshape((1,)+array.shape) 12 | return array 13 | 14 | 15 | def write_zipped_array(filename, array): 16 | if array.shape[0] == 1: 17 | array = array[0,...] 18 | np.savez_compressed(filename, array) 19 | 20 | 21 | def _check_same_dimensions(arrays): 22 | for array in arrays: 23 | if array.shape[1:-1] != arrays[0].shape[1:-1]: 24 | raise ValueError("All arrays should have the same spatial dimensions, but got %s and %s" % (array.shape, arrays[0].shape)) 25 | 26 | 27 | def read_sim_frame(simpath, fieldnames, index, set_missing_to_none=True): 28 | if isinstance(fieldnames, six.string_types): fieldnames = [fieldnames] 29 | for fieldname in fieldnames: 30 | filename = join(simpath, "%s_%06i.npz"%(fieldname,index)) 31 | if os.path.isfile(filename): 32 | yield read_zipped_array(filename) 33 | else: 34 | if set_missing_to_none: 35 | yield None 36 | else: 37 | raise IOError("Missing frame at index %d: %s"%(index,filename)) 38 | 39 | 40 | def write_sim_frame(simpath, arrays, fieldnames, index, check_same_dimensions=False): 41 | if check_same_dimensions: _check_same_dimensions(arrays) 42 | os.path.isdir(simpath) or os.mkdir(simpath) 43 | if not isinstance(fieldnames, (tuple, list)) and not isinstance(arrays, (tuple, list)): 44 | fieldnames = [fieldnames] 45 | arrays = [arrays] 46 | filenames = [join(simpath, "%s_%06i.npz"%(name,index)) for name in fieldnames] 47 | for i in range(len(arrays)): 48 | write_zipped_array(filenames[i], arrays[i]) 49 | return filenames 50 | 51 | 52 | def read_sim_frames(simpath, fieldnames=None, indices=None): 53 | if fieldnames is None: fieldnames = get_fieldnames(simpath) 54 | if not fieldnames: return [] 55 | if indices is None: indices = get_indices(simpath, fieldnames[0]) 56 | if isinstance(indices, int): indices = [indices] 57 | single_fieldname = isinstance(fieldnames, six.string_types) 58 | if single_fieldname: fieldnames = [fieldnames] 59 | 60 | field_lists = [[] for f in fieldnames] 61 | for i in indices: 62 | fields = read_sim_frame(simpath, fieldnames, i, set_missing_to_none=False) 63 | for j in range(len(fieldnames)): 64 | field_lists[j].append(fields[j]) 65 | result = [np.concatenate(list, 0) for list in field_lists] 66 | return result if not single_fieldname else result[0] 67 | 68 | 69 
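# Usage sketch for the frame I/O helpers above (paths and field names are hypothetical):
#   write_sim_frame("/tmp/mysim/sim_000000", [density_array], ["density"], index=0)
#   density0 = next(read_sim_frame("/tmp/mysim/sim_000000", "density", 0, set_missing_to_none=False))
# Each field of each frame is stored as its own compressed array, named "<fieldname>_<index padded to 6 digits>.npz".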
| def get_fieldnames(simpath): 70 | fieldnames_set = {f[:-11] for f in os.listdir(simpath) if f.endswith(".npz")} 71 | return sorted(fieldnames_set) 72 | 73 | 74 | def first_index(simpath, fieldname=None): 75 | return min(get_indices(simpath, fieldname)) 76 | 77 | 78 | def get_indices(simpath, fieldname=None, mode="intersect"): 79 | if fieldname is not None: 80 | all_indices = {int(f[-10:-4]) for f in os.listdir(simpath) if f.startswith(fieldname) and f.endswith(".npz")} 81 | return sorted(all_indices) 82 | else: 83 | indices_lists = [get_indices(simpath, fieldname) for fieldname in get_fieldnames(simpath)] 84 | if mode.lower() == "intersect": 85 | intersection = set(indices_lists[0]).intersection(*indices_lists[1:]) 86 | return sorted(intersection) 87 | elif mode.lower() == "union": 88 | if not indices_lists: 89 | return [] 90 | union = set(indices_lists[0]).union(*indices_lists[1:]) 91 | return sorted(union) 92 | 93 | 94 | class Scene(object): 95 | 96 | def __init__(self, dir, category, index): 97 | self.dir = dir 98 | self.category = category 99 | self.index = index 100 | self._properties = None 101 | 102 | @property 103 | def path(self): 104 | return join(self.dir, self.category, "sim_%06d"%self.index) 105 | 106 | def subpath(self, name, create=False): 107 | path = join(self.path, name) 108 | if create and not os.path.isdir(path): 109 | os.mkdir(path) 110 | return path 111 | 112 | def _init_properties(self): 113 | if self._properties is not None: return 114 | dfile = join(self.path, "description.json") 115 | if isfile(dfile): 116 | self._properties = json.load(dfile) 117 | else: 118 | self._properties = {} 119 | 120 | def exists_config(self): 121 | return isfile(join(self.path, "description.json")) 122 | 123 | 124 | @property 125 | def properties(self): 126 | self._init_properties() 127 | return self._properties 128 | 129 | @properties.setter 130 | def properties(self, dict): 131 | self._properties = dict 132 | with open(join(self.path, "description.json"), "w") as out: 133 | json.dump(self._properties, out, indent=2) 134 | 135 | def put_property(self, key, value): 136 | self._init_properties() 137 | self._properties[key] = value 138 | with open(join(self.path, "description.json"), "w") as out: 139 | json.dump(self._properties, out, indent=2) 140 | 141 | 142 | def read_sim_frames(self, fieldnames=None, indices=None): 143 | return read_sim_frames(self.path, fieldnames=fieldnames, indices=indices) 144 | 145 | def read_array(self, fieldname, index): 146 | return next(read_sim_frame(self.path, [fieldname], index, set_missing_to_none=False)) 147 | 148 | def write_sim_frame(self, arrays, fieldnames, index, check_same_dimensions=False): 149 | write_sim_frame(self.path, arrays, fieldnames, index, check_same_dimensions=check_same_dimensions) 150 | 151 | @property 152 | def fieldnames(self): 153 | return get_fieldnames(self.path) 154 | 155 | @property 156 | def indices(self): 157 | return get_indices(self.path) 158 | 159 | def get_indices(self, mode="intersect"): 160 | return get_indices(self.path, None, mode) 161 | 162 | def __str__(self): 163 | return self.path 164 | 165 | def __repr__(self): 166 | return self.path 167 | 168 | def copy_calling_script(self): 169 | script_path = inspect.stack()[1][1] 170 | script_name = os.path.basename(script_path) 171 | src_path = os.path.join(self.path, "src") 172 | os.path.isdir(src_path) or os.mkdir(src_path) 173 | target = os.path.join(self.path, "src", script_name) 174 | shutil.copy(script_path, target) 175 | try: 176 | shutil.copystat(script_path, 
target) 177 | except: 178 | pass # print("Could not copy file metadata to %s"%target) 179 | 180 | def copy_src(self, path): 181 | file_name = os.path.basename(path) 182 | src_dir = os.path.dirname(path) 183 | target_dir = join(self.path, "src") 184 | # Create directory and copy 185 | isdir(target_dir) or os.mkdir(target_dir) 186 | shutil.copy(path, join(target_dir, file_name)) 187 | try: 188 | shutil.copystat(path, join(target_dir, file_name)) 189 | except: 190 | pass # print("Could not copy file metadata to %s"%target) 191 | 192 | def mkdir(self, subdir=None): 193 | path = self.path 194 | isdir(path) or os.mkdir(path) 195 | if subdir is not None: 196 | subpath = join(path, subdir) 197 | isdir(subpath) or os.mkdir(subpath) 198 | 199 | def remove(self): 200 | if isdir(self.path): 201 | shutil.rmtree(self.path) 202 | 203 | 204 | def scenes(directory, category=None, indexfilter=None, max_count=None): 205 | directory = os.path.expanduser(directory) 206 | if not category: 207 | root_path = directory 208 | category = os.path.basename(directory) 209 | directory = os.path.dirname(directory) 210 | else: 211 | root_path = join(directory, category) 212 | indices = [int(sim[4:]) for sim in os.listdir(root_path) if sim.startswith("sim_")] 213 | if indexfilter: 214 | indices = indexfilter(indices) 215 | if max_count and len(indices) >= max_count: 216 | indices = indices[0:max_count] 217 | for scene_index in indices: 218 | yield Scene(directory, category, scene_index) 219 | 220 | 221 | def new_scene(directory, category=None, mkdir=True): 222 | directory = os.path.expanduser(directory) 223 | if category is None: 224 | category = os.path.basename(directory) 225 | directory = os.path.dirname(directory) 226 | else: 227 | category = slugify(category) 228 | 229 | scenedir = join(directory, category) 230 | if not isdir(scenedir): 231 | os.makedirs(scenedir) 232 | next_index = 0 233 | else: 234 | indices = [int(name[4:]) for name in os.listdir(scenedir) if name.startswith("sim_")] 235 | if not indices: 236 | next_index = 0 237 | else: 238 | next_index = max(indices) + 1 239 | scene = Scene(directory, category, next_index) 240 | if mkdir: scene.mkdir() 241 | return scene 242 | 243 | 244 | def scene_at(sim_dir): 245 | sim_dir = os.path.expanduser(sim_dir) 246 | dirname = os.path.basename(sim_dir) 247 | if not dirname.startswith("sim_"): 248 | raise ValueError("%s is not a valid scene directory."%sim_dir) 249 | category_directory = os.path.dirname(sim_dir) 250 | category = os.path.basename(category_directory) 251 | directory = os.path.dirname(category_directory) 252 | index = int(dirname[4:]) 253 | return Scene(directory, category, index) 254 | 255 | 256 | def slugify(value): 257 | """ 258 | Normalizes string, converts to lowercase, removes non-alpha characters, 259 | and converts spaces to hyphens. 
260 | """ 261 | import unicodedata 262 | value = six.u(value) 263 | # value = u"{}".format(value.decode('utf-8')) 264 | value = unicodedata.normalize('NFKD', value)#.encode('ascii', 'ignore') 265 | value = re.sub('Φ', "Phi", value) 266 | value = re.sub('[^\w\s-]', '', value).strip().lower() 267 | value = re.sub('[-\s]+', '-', value) 268 | return value 269 | -------------------------------------------------------------------------------- /phi/solver/cuda/src/pressure_solve_op.cu.cc: -------------------------------------------------------------------------------- 1 | // NOTE: the include targets were not preserved in this export; the list below is reconstructed from what the code uses (C++ standard library, CUDA runtime API, cuBLAS). 2 | #include <iostream> 3 | #include <cstdlib> 4 | #include <cstdio> 5 | #include <cuda.h> 6 | #include <cuda_runtime.h> 7 | #include <cublas_v2.h> 8 | 9 | using namespace std; 10 | 11 | static void CheckCudaErrorAux(const char* file, unsigned line, const char* statement, cudaError_t err) { 12 | if (err == cudaSuccess) return; 13 | std::cerr << statement << " returned " << cudaGetErrorString(err) << "(" << err << ") at " << file << ":" << line << std::endl; 14 | exit(1); 15 | } 16 | 17 | #define CUDA_CHECK_RETURN(value) CheckCudaErrorAux(__FILE__, __LINE__, #value, value) 18 | 19 | __global__ void calcZ_v4(const int *dimensions, const int dimProduct, const int maxDataPerRow, const signed char *laplaceMatrix, const float *p, float *z) { 20 | extern __shared__ int diagonalOffsets[]; 21 | 22 | // Build diagonalOffsets on the first thread of each block and write it to shared memory 23 | if(threadIdx.x == 0) { 24 | const int diagonal = maxDataPerRow / 2; 25 | diagonalOffsets[diagonal] = 0; 26 | int factor = 1; 27 | 28 | for(int i = 0, offset = 1; i < diagonal; i++, offset++) { 29 | diagonalOffsets[diagonal - offset] = -factor; 30 | diagonalOffsets[diagonal + offset] = factor; 31 | factor *= dimensions[i]; 32 | } 33 | } 34 | __syncthreads(); 35 | 36 | const int row = blockIdx.x * blockDim.x + threadIdx.x; 37 | if (row < dimProduct) { 38 | const int diagonal = row * maxDataPerRow; 39 | float tmp = 0; 40 | for(int i = diagonal; i < diagonal + maxDataPerRow; i++) { 41 | // when accessing out of bound memory in p, laplaceMatrix[i] is always zero. So no illegal mem-access will be made.
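// (Illustration: on a 2-D grid with dimensions = {X, Y}, maxDataPerRow = 5 and diagonalOffsets = {-X, -1, 0, +1, +X},
//  i.e. each cell stores one row of the banded five-point Laplacian as maxDataPerRow signed-char coefficients;
//  only rows next to the domain boundary would read outside p, and those coefficients are stored as zero.)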
42 | // Anyway, if this causes problems add this: 43 | // if(row + offsets[i - diagonal] >= 0 && row + offsets[i - diagonal] < dimProduct) 44 | tmp += (signed char)laplaceMatrix[i] * p[row + diagonalOffsets[i - diagonal]]; // No modulo here (as the general way in the thesis suggests) 45 | } 46 | z[row] = tmp; 47 | } 48 | } 49 | 50 | __global__ void checkResiduum(const int dimProduct, const float* r, const float threshold, bool *threshold_reached) { 51 | for (int row = blockIdx.x * blockDim.x + threadIdx.x; row < dimProduct; row += blockDim.x * gridDim.x) { 52 | if (r[row] >= threshold) { 53 | *threshold_reached = false; 54 | break; 55 | } 56 | } 57 | } 58 | 59 | __global__ void initVariablesWithGuess(const int dimProduct, const float *divergence, float* A_times_x_0, float *p, float *r, bool *threshold_reached) { 60 | const int row = blockIdx.x * blockDim.x + threadIdx.x; 61 | if (row < dimProduct) { 62 | float tmp = divergence[row] - A_times_x_0[row]; 63 | p[row] = tmp; 64 | r[row] = tmp; 65 | 66 | } 67 | if(row == 0) *threshold_reached = false; 68 | } 69 | 70 | void LaunchPressureKernel(const int* dimensions, const int dimProduct, const int dim_size, 71 | const signed char *laplaceMatrix, 72 | float* p, float* z, float* r, float* divergence, float* x, 73 | const float *oneVector, 74 | bool* threshold_reached, 75 | const float accuracy, 76 | const int max_iterations, 77 | const int batch_size, 78 | int* iterations_gpu) { 79 | // printf("Address of laplaceMatrix is %p\n", (void *)laplaceMatrix); 80 | // printf("Address of oneVector is %p\n", (void *)oneVector); 81 | // printf("Address of x is %p\n", (void *)x); 82 | // printf("Address of p is %p\n", (void *)p); 83 | // printf("Address of z is %p\n", (void *)z); 84 | // printf("Address of r is %p\n", (void *)r); 85 | // printf("Address of divergence is %p\n", (void *)divergence); 86 | 87 | cublasHandle_t blasHandle; 88 | cublasCreate_v2(&blasHandle); 89 | cublasSetPointerMode_v2(blasHandle, CUBLAS_POINTER_MODE_HOST); 90 | 91 | // CG helper variables init 92 | float *alpha = new float[batch_size], *beta = new float[batch_size]; 93 | const float oneScalar = 1.0f; 94 | bool *threshold_reached_cpu = new bool[batch_size]; 95 | float *p_r = new float[batch_size], *p_z = new float[batch_size], *r_z = new float[batch_size]; 96 | 97 | // get block and gridSize to theoretically get best occupancy 98 | int blockSize; 99 | int minGridSize; 100 | int gridSize; 101 | 102 | // Initialize the helper variables 103 | cudaOccupancyMaxPotentialBlockSize( &minGridSize, &blockSize, calcZ_v4, 0, 0); 104 | gridSize = (dimProduct + blockSize - 1) / blockSize; 105 | 106 | // First calc A * x_0, save result to z: 107 | for(int i = 0; i < batch_size; i++) { 108 | calcZ_v4<<<gridSize, blockSize, sizeof(int) * (dim_size * 2 + 1)>>>(dimensions, // dynamic shared memory sized for the diagonalOffsets table (maxDataPerRow ints) 109 | dimProduct, 110 | dim_size * 2 + 1, 111 | laplaceMatrix, 112 | x + i * dimProduct, 113 | z + i * dimProduct); 114 | } 115 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 116 | 117 | cudaOccupancyMaxPotentialBlockSize( &minGridSize, &blockSize, initVariablesWithGuess, 0, 0); 118 | gridSize = (dimProduct + blockSize - 1) / blockSize; 119 | 120 | // Second apply result to the helper variables 121 | for(int i = 0; i < batch_size; i++) { 122 | int offset = i * dimProduct; 123 | initVariablesWithGuess<<<gridSize, blockSize>>>(dimProduct, 124 | divergence + offset, 125 | z + offset, 126 | p + offset, 127 | r + offset, 128 | threshold_reached + i); 129 | } 130 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 131 | 132 | 133 | // Init residuum checker variables 134 |
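// (The memcpy just below seeds the host-side mirror of the per-batch convergence flags; converged batch
//  entries are then skipped in every loop over batch_size.)
// The solve that follows is plain batched conjugate gradient; per batch entry i and iteration:
//     z     = A * p                               (calcZ_v4)
//     alpha = (p . r) / (p . z)                   (two cublasSdot_v2 calls)
//     x    += alpha * p;   r -= alpha * z         (cublasSaxpy_v2 with alpha, then -alpha)
//     beta  = -(r . z) / (p . z)                  (keeps the new search direction A-conjugate to the old one)
//     p     = r + beta * p                        (cublasSscal_v2 followed by cublasSaxpy_v2)
// The residuum is only compared against `accuracy` every fifth iteration to limit host/device copies.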
CUDA_CHECK_RETURN(cudaMemcpy(threshold_reached_cpu, threshold_reached, sizeof(bool) * batch_size, cudaMemcpyDeviceToHost)); 135 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 136 | 137 | cudaOccupancyMaxPotentialBlockSize( &minGridSize, &blockSize, 138 | calcZ_v4, 0, 0); 139 | gridSize = (dimProduct + blockSize - 1) / blockSize; 140 | 141 | // Do CG-Solve 142 | int checker = 1; 143 | int iterations = 0; 144 | for (; iterations < max_iterations; iterations++) { 145 | for(int i = 0; i < batch_size; i++) { 146 | if(threshold_reached_cpu[i]) continue; 147 | calcZ_v4<<<gridSize, blockSize, sizeof(int) * (dim_size * 2 + 1)>>>(dimensions, dimProduct, dim_size * 2 + 1, laplaceMatrix, p + i * dimProduct, z + i * dimProduct); 148 | } 149 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 150 | 151 | 152 | for(int i = 0; i < batch_size; i++) { 153 | if(threshold_reached_cpu[i]) continue; 154 | cublasSdot_v2(blasHandle, dimProduct, p + i * dimProduct, 1, r + i * dimProduct, 1, p_r + i); 155 | cublasSdot_v2(blasHandle, dimProduct, p + i * dimProduct, 1, z + i * dimProduct, 1, p_z + i); 156 | } 157 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 158 | 159 | for(int i = 0; i < batch_size; i++) { 160 | if(threshold_reached_cpu[i]) continue; 161 | alpha[i] = p_r[i] / p_z[i]; 162 | cublasSaxpy_v2(blasHandle, dimProduct, alpha + i, p + i * dimProduct, 1, x + i * dimProduct, 1); 163 | 164 | alpha[i] = -alpha[i]; 165 | cublasSaxpy_v2(blasHandle, dimProduct, alpha + i, z + i * dimProduct, 1, r + i * dimProduct, 1); 166 | 167 | } 168 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 169 | 170 | // Check the residuum every 5 steps to keep memcopies between H&D low 171 | // Tests have shown that 5 is a good average trade-off between memcopies and extra computation, and it improves performance 172 | if (checker % 5 == 0) { 173 | for(int i = 0; i < batch_size; i++) { 174 | if(threshold_reached_cpu[i]) continue; 175 | // Use lower occupancy here, because in most cases the residual is still too high and the check exits early anyway 176 | checkResiduum<<<8, blockSize>>>(dimProduct, r + i * dimProduct, accuracy, threshold_reached + i); 177 | } 178 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 179 | 180 | CUDA_CHECK_RETURN(cudaMemcpy(threshold_reached_cpu, threshold_reached, sizeof(bool) * batch_size, cudaMemcpyDeviceToHost)); 181 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 182 | 183 | bool done = true; 184 | for(int i = 0; i < batch_size; i++) { 185 | if (!threshold_reached_cpu[i]) { 186 | done = false; 187 | break; 188 | } 189 | } 190 | if(done){ 191 | iterations++; 192 | break; 193 | } 194 | CUDA_CHECK_RETURN(cudaMemset(threshold_reached, 1, sizeof(bool) * batch_size)); 195 | } 196 | checker++; 197 | 198 | for(int i = 0; i < batch_size; i++) { 199 | if(threshold_reached_cpu[i]) continue; 200 | cublasSdot_v2(blasHandle, dimProduct, r + i * dimProduct, 1, z + i * dimProduct, 1, r_z + i); 201 | } 202 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 203 | 204 | for(int i = 0; i < batch_size; i++) { 205 | if(threshold_reached_cpu[i]) continue; 206 | beta[i] = -r_z[i] / p_z[i]; 207 | cublasSscal_v2(blasHandle, dimProduct, beta + i, p + i * dimProduct, 1); 208 | cublasSaxpy_v2(blasHandle, dimProduct, &oneScalar, r + i * dimProduct, 1, p + i * dimProduct, 1); 209 | } 210 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 211 | } 212 | 213 | delete[] alpha; delete[] beta; delete[] threshold_reached_cpu; delete[] p_r; delete[] p_z; delete[] r_z; 214 | // printf("I: %i\n", iterations); 215 | 216 | CUDA_CHECK_RETURN(cudaMemcpy(iterations_gpu, &iterations, sizeof(int), cudaMemcpyHostToDevice)); 217 | CUDA_CHECK_RETURN(cudaDeviceSynchronize()); 218 | 219 | } 220 | 
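// Buffer layout assumed by LaunchPressureKernel, as implied by the indexing above: every float array
// (p, z, r, divergence, x) is batch-major with dimProduct contiguous entries per batch element,
// threshold_reached holds one bool per batch element, and laplaceMatrix holds dimProduct rows of
// (dim_size * 2 + 1) signed-char stencil coefficients shared by all batch elements.
// Hypothetical host-side sizing sketch:
//     size_t field_bytes  = sizeof(float) * batch_size * dimProduct;
//     size_t matrix_bytes = sizeof(signed char) * dimProduct * (dim_size * 2 + 1);
//     cudaMalloc((void**) &p, field_bytes);
//     cudaMalloc((void**) &laplaceMatrix, matrix_bytes);   // etc. for z, r, divergence, x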
-------------------------------------------------------------------------------- /train/train_1d.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys, os 3 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..')) 4 | sys.path.append(os.path.join(os.path.dirname("__file__"), '..', '..')) 5 | import torch 6 | import numpy as np 7 | import pdb 8 | from dataset.data_1d import DiffusionDataset, get_burgers_preprocess 9 | from diffusion.diffusion_1d import GaussianDiffusion1D, GaussianDiffusion 10 | from model.model_1d.unet import Unet1D, Unet2D 11 | from utils_1d.train_diffusion import Trainer, Trainer1D 12 | from utils_1d.utils import none_or_str 13 | import matplotlib.pyplot as plt 14 | from utils_1d.result_io import merge_save_dict 15 | from datetime import datetime 16 | import yaml 17 | 18 | RESCALER = 10. 19 | 20 | parser = argparse.ArgumentParser(description='Train model') 21 | parser.add_argument('--exp_id', default='gen-control', type=str, 22 | help='experiment folder id') 23 | parser.add_argument('--date_time', default=datetime.today().strftime('%Y-%m-%d'), type=str, 24 | help='date for the experiment folder') 25 | parser.add_argument('--dataset', default='free_u_f_1e5', type=str, 26 | help='dataset name') 27 | parser.add_argument('--train_data_path', default='/data', type=str, 28 | help='train data path') 29 | parser.add_argument('--train_num_steps', default=100000, type=int, 30 | help='train_num_steps') 31 | parser.add_argument('--checkpoint_interval', default=10000, type=int, 32 | help='save checkpoint every checkpoint_interval steps') 33 | 34 | parser.add_argument('--is_condition_u0', default=False, type=eval, 35 | help='If learning p(u_[1, T] | u0)') 36 | parser.add_argument('--is_condition_uT', default=False, type=eval, 37 | help='If learning p(u_[0, T-1] | uT)') 38 | parser.add_argument('--is_condition_u0_zero_pred_noise', default=True, type=eval, 39 | help='If enforcing the pred_noise to be zero for the conditioned data\ 40 | when learning p(u_[1, T-1] | u0). if false, reproduce some faulty behaviors') 41 | parser.add_argument('--is_condition_uT_zero_pred_noise', default=True, type=eval, 42 | help='If enforcing the pred_noise to be zero for the conditioned data\ 43 | when learning p(u_[1, T-1] | u0). if false, reproduce some faulty behaviors') 44 | parser.add_argument('--condition_on_residual', default=None, type=str, 45 | help='option: None, residual_gradient') 46 | parser.add_argument('--residual_on_u0', default=False, type=eval, 47 | help='when using conditioning on residual, whether feeding u0 or ut into Unet') 48 | # exp setting 49 | parser.add_argument('--partially_observed', default=None, type=none_or_str, 50 | help='If None, fully observed, otherwise, partially observed during training\ 51 | Note that the force is always fuly observed. Possible choices:\ 52 | front_rear_quarter. \ 53 | In the training part, partially_observed sets the training trajectories to zero at the unobserved locations') 54 | parser.add_argument('--train_on_partially_observed', default=None, type=none_or_str, 55 | help='Whether to train the model to generate zero states at the unobserved locations. 
if None, enforce zero.') 56 | 57 | # sampling setting: does not affect 58 | parser.add_argument('--set_unobserved_to_zero_during_sampling', default=False, type=eval, 59 | help='Set central 1/2 to zero in each p sample loop.') 60 | parser.add_argument('--recurrence', default=False, type=eval, help='whether to use recurrence in Universal Guidance for Diffusion Models') 61 | parser.add_argument('--recurrence_k', default=1, type=int, help='how many iterations of recurrence. k in Algo 1 in Universal Guidance for Diffusion Models') 62 | 63 | # unet hyperparam 64 | parser.add_argument('--dim', default=64, type=int, 65 | help='first layer feature dim num in Unet') 66 | parser.add_argument('--resnet_block_groups', default=1, type=int, 67 | help='group num in GroupNorm default 8') 68 | parser.add_argument('--dim_muls', nargs='+', default=[1, 2, 4, 8], type=int, 69 | help='dimension of channels, multiplied to the base dim\ 70 | seq_length % (2 ** len(dim_muls)) must be 0') 71 | 72 | # 2 ddpm: learn p(w, u) and p(w) -> use p(u | w) during inference 73 | parser.add_argument('--is_model_w', default=False, type=eval, help='If training the p(w) model, else train the p(u, w) model') 74 | parser.add_argument('--eval_two_models', default=False, type=eval, help='Set to False in this training file') 75 | parser.add_argument('--expand_condition', default=False, type=eval, help='Expand conditioning information of u0 or uT in separate channels') 76 | parser.add_argument('--prior_beta', default=1, type=eval, help='strength of the prior (1 is p(u,w); 0 is p(u|w))') 77 | parser.add_argument('--asynch_inference_mode', action='store_true', 78 | help="if true, all time steps are denoised asynchronously (our method), and only infer_interval=1 allowed; otherwise synchronously (baseline) for inference, and 1<=infer_interval