├── pde_superresolution ├── scripts │ ├── __init__.py │ ├── create_baseline_data_test.py │ ├── create_training_data_test.py │ ├── run_survival.py │ ├── run_training.py │ ├── simulate_train_evaluate_test.py │ ├── run_mae.py │ ├── create_exact_data.py │ ├── create_baseline_data.py │ ├── create_training_data.py │ └── run_evaluation.py ├── __init__.py ├── model_test.py ├── utils.py ├── analysis_test.py ├── equations_test.py ├── duckarray_test.py ├── xarray_beam_test.py ├── training_test.py ├── analysis.py ├── layers_test.py ├── weno_test.py ├── weno.py ├── xarray_beam.py ├── layers.py ├── polynomials_test.py ├── duckarray.py ├── integrate_test.py ├── polynomials.py ├── integrate.py ├── equations.py └── training.py ├── CONTRIBUTING.md ├── setup.py ├── README.md └── LICENSE /pde_superresolution/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution, 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | -------------------------------------------------------------------------------- /pde_superresolution/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Code for PDE-superresolution.""" 16 | from pde_superresolution import analysis 17 | from pde_superresolution import duckarray 18 | from pde_superresolution import equations 19 | from pde_superresolution import integrate 20 | from pde_superresolution import layers 21 | from pde_superresolution import model 22 | from pde_superresolution import polynomials 23 | from pde_superresolution import training 24 | from pde_superresolution import utils 25 | from pde_superresolution import weno 26 | from pde_superresolution import xarray_beam 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Install pde-superresolution.""" 16 | import setuptools 17 | 18 | 19 | INSTALL_REQUIRES = [ 20 | 'absl-py', 21 | 'apache-beam', 22 | 'h5py', 23 | 'numpy', 24 | 'pandas', 25 | 'scipy', 26 | 'tensorflow<2', 27 | 'xarray', 28 | ] 29 | 30 | setuptools.setup( 31 | name='pde-superresolution', 32 | version='0.0.0', 33 | license='Apache 2.0', 34 | author='Google LLC', 35 | author_email='noreply@google.com', 36 | install_requires=INSTALL_REQUIRES, 37 | url='https://github.com/google/pde-superresolution', 38 | packages=setuptools.find_packages(), 39 | python_requires='>=3') 40 | -------------------------------------------------------------------------------- /pde_superresolution/model_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for model functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest # pylint: disable=g-bad-import-order 22 | from absl.testing import parameterized 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | from pde_superresolution import model # pylint: disable=g-bad-import-order 27 | 28 | 29 | class ModelTest(parameterized.TestCase): 30 | 31 | def test_stack_all_rolls(self): 32 | with tf.Graph().as_default(): 33 | with tf.Session(): 34 | inputs = tf.range(5) 35 | actual = model._stack_all_rolls(inputs, 3) 36 | expected = [[0, 1, 2, 3, 4], [1, 2, 3, 4, 0], [2, 3, 4, 0, 1]] 37 | np.testing.assert_allclose(expected, actual.eval()) 38 | 39 | 40 | if __name__ == '__main__': 41 | absltest.main() 42 | -------------------------------------------------------------------------------- /pde_superresolution/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Miscellaneous utility functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import contextlib 22 | import os.path 23 | import shutil 24 | import tempfile 25 | 26 | import h5py 27 | import tensorflow as tf 28 | from typing import Iterator 29 | 30 | 31 | @contextlib.contextmanager 32 | def write_h5py(path: str) -> Iterator[h5py.File]: 33 | """Context manager to open an h5py.File for writing.""" 34 | tmp_dir = tempfile.mkdtemp() 35 | local_path = os.path.join(tmp_dir, 'data.h5') 36 | with h5py.File(local_path) as f: 37 | yield f 38 | tf.gfile.Copy(local_path, path) 39 | shutil.rmtree(tmp_dir) 40 | 41 | 42 | @contextlib.contextmanager 43 | def read_h5py(path: str) -> Iterator[h5py.File]: 44 | """Context manager to open an h5py.File for reading.""" 45 | tmp_dir = tempfile.mkdtemp() 46 | local_path = os.path.join(tmp_dir, 'data.h5') 47 | tf.gfile.Copy(path, local_path) 48 | with h5py.File(local_path) as f: 49 | yield f 50 | shutil.rmtree(tmp_dir) 51 | -------------------------------------------------------------------------------- /pde_superresolution/analysis_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for analysis functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest # pylint: disable=g-bad-import-order 22 | from absl.testing import parameterized 23 | import numpy as np 24 | import xarray 25 | 26 | from pde_superresolution import analysis # pylint: disable=g-bad-import-order 27 | 28 | 29 | class AnalysisTest(parameterized.TestCase): 30 | 31 | @parameterized.parameters( 32 | dict(data=np.arange(100) < 65, expected=6.5), 33 | dict(data=np.ones(100), expected=9.9), 34 | dict(data=np.zeros(100), expected=0), 35 | dict(data=np.concatenate([np.ones(10), np.zeros(1), 36 | np.ones(9), np.zeros(80)]), 37 | expected=1), 38 | ) 39 | def test_calculate_survival(self, data, expected): 40 | array = xarray.DataArray(data, [('time', np.arange(100) / 10)]) 41 | result = analysis.calculate_survival(array).item() 42 | self.assertEqual(expected, result) 43 | 44 | 45 | if __name__ == '__main__': 46 | absltest.main() 47 | -------------------------------------------------------------------------------- /pde_superresolution/scripts/create_baseline_data_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Sanity test for create_training_data.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os.path 21 | 22 | from absl import flags 23 | from absl.testing import flagsaver 24 | import xarray 25 | from absl.testing import absltest # pylint: disable=g-bad-import-order 26 | 27 | from pde_superresolution.scripts import create_baseline_data 28 | 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | class CreateBaselineDataTest(absltest.TestCase): 34 | 35 | def test(self): 36 | output_path = os.path.join(FLAGS.test_tmpdir, 'temp.nc') 37 | 38 | # run the beam job 39 | with flagsaver.flagsaver( 40 | output_path=output_path, 41 | equation_name='burgers', 42 | equation_kwargs='{"num_points": 400}', 43 | num_samples=2, 44 | accuracy_orders=[1, 3, 5], 45 | time_max=1.0, 46 | time_delta=0.1, 47 | warmup=0): 48 | create_baseline_data.main([]) 49 | 50 | # verify the results 51 | with xarray.open_dataset(output_path) as ds: 52 | self.assertEqual(ds['y'].dims, ('sample', 'accuracy_order', 'time', 'x')) 53 | self.assertEqual(ds['y'].shape, (2, 3, 11, 400)) 54 | 55 | 56 | if __name__ == '__main__': 57 | absltest.main() 58 | -------------------------------------------------------------------------------- /pde_superresolution/scripts/create_training_data_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Sanity test for create_training_data.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os.path 21 | 22 | from absl import flags 23 | from absl.testing import flagsaver 24 | from absl.testing import absltest # pylint: disable=g-bad-import-order 25 | 26 | from pde_superresolution import utils 27 | from pde_superresolution.scripts import create_training_data 28 | 29 | 30 | FLAGS = flags.FLAGS 31 | 32 | 33 | class CreateTrainingDataTest(absltest.TestCase): 34 | 35 | def test(self): 36 | output_path = os.path.join(FLAGS.test_tmpdir, 'temp.h5') 37 | 38 | # run the beam job 39 | with flagsaver.flagsaver( 40 | output_path=output_path, 41 | equation_name='burgers', 42 | equation_kwargs='{"num_points": 400}', 43 | num_tasks=2, 44 | time_max=1.0, 45 | time_delta=0.1, 46 | warmup=0): 47 | create_training_data.main([]) 48 | 49 | # verify the results 50 | with utils.read_h5py(output_path) as f: 51 | data = f['v'][...] 
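# the equation kwargs used to generate the snapshots are stored as HDF5 attributes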
52 | metadata = dict(f.attrs) 53 | self.assertEqual(data.shape, (20, 400)) 54 | self.assertEqual(metadata, {'num_points': 400}) 55 | 56 | 57 | if __name__ == '__main__': 58 | absltest.main() 59 | -------------------------------------------------------------------------------- /pde_superresolution/equations_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for equations.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest # pylint: disable=g-bad-import-order 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | from pde_superresolution import equations # pylint: disable=g-bad-import-order 26 | 27 | 28 | class GridTest(absltest.TestCase): 29 | 30 | def test_grid(self): 31 | grid = equations.Grid(3, resample_factor=2, period=60) 32 | 33 | self.assertEqual(grid.solution_num_points, 3) 34 | np.testing.assert_equal(grid.solution_x, [0, 20, 40]) 35 | np.testing.assert_equal(grid.solution_dx, 20) 36 | 37 | self.assertEqual(grid.reference_num_points, 6) 38 | np.testing.assert_equal(grid.reference_x, [0, 10, 20, 30, 40, 50]) 39 | np.testing.assert_equal(grid.reference_dx, 10) 40 | 41 | 42 | class EquationsTest(absltest.TestCase): 43 | 44 | def test_staggered_first_derivative_consistency(self): 45 | # numpy and tensorflow should give the same result 46 | y = np.random.RandomState(0).randn(10) 47 | np_result = equations.staggered_first_derivative(y, dx=1.0) 48 | with tf.Graph().as_default(): 49 | with tf.Session(): 50 | tf_result = equations.staggered_first_derivative( 51 | tf.constant(y), dx=1.0).eval() 52 | np.testing.assert_allclose(np_result, tf_result) 53 | 54 | 55 | if __name__ == '__main__': 56 | absltest.main() 57 | -------------------------------------------------------------------------------- /pde_superresolution/duckarray_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for duck array functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest # pylint: disable=g-bad-import-order 22 | from absl.testing import parameterized 23 | import numpy as np 24 | import scipy.fftpack 25 | import tensorflow as tf 26 | 27 | from pde_superresolution import duckarray # pylint: disable=g-bad-import-order 28 | 29 | 30 | class DuckArrayTest(parameterized.TestCase): 31 | 32 | def test_resample_mean(self): 33 | inputs = np.arange(6.0) 34 | expected = np.array([0.5, 2.5, 4.5]) 35 | actual = duckarray.resample_mean(inputs, factor=2) 36 | np.testing.assert_allclose(expected, actual) 37 | 38 | with tf.Graph().as_default(): 39 | with tf.Session() as sess: 40 | actual = sess.run( 41 | duckarray.resample_mean(tf.constant(inputs), factor=2)) 42 | np.testing.assert_allclose(expected, actual) 43 | 44 | def test_subsample(self): 45 | inputs = np.arange(6) 46 | expected = np.array([0, 2, 4]) 47 | actual = duckarray.subsample(inputs, factor=2) 48 | np.testing.assert_allclose(expected, actual) 49 | 50 | with tf.Graph().as_default(): 51 | with tf.Session() as sess: 52 | actual = sess.run( 53 | duckarray.subsample(tf.constant(inputs), factor=2)) 54 | np.testing.assert_allclose(expected, actual) 55 | 56 | @parameterized.parameters( 57 | dict(y=np.sin(2*np.pi*np.arange(8)/8), period=1), 58 | dict(y=np.sin(2*np.pi*np.arange(8)/8), period=8), 59 | dict(y=np.linspace(-1, 1, num=12) ** 2, period=2), 60 | ) 61 | def test_spectral_derivative(self, y, period): 62 | for order in range(3): 63 | with self.subTest(order=order): 64 | expected = scipy.fftpack.diff(y, order=order, period=period) 65 | actual = duckarray.spectral_derivative(y, order, period) 66 | np.testing.assert_allclose(expected, actual, atol=1e-12) 67 | 68 | 69 | if __name__ == '__main__': 70 | absltest.main() 71 | -------------------------------------------------------------------------------- /pde_superresolution/scripts/run_survival.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | # pylint: disable=line-too-long 16 | """Run a beam pipeline to add netCDF files with survival results.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import app 22 | from absl import flags 23 | import apache_beam as beam 24 | from pde_superresolution import analysis # pylint: disable=g-bad-import-order 25 | import tensorflow as tf 26 | import xarray 27 | 28 | 29 | flags.DEFINE_string( 30 | 'file_pattern', None, 31 | 'Glob to use for matching simulation files.') 32 | flags.DEFINE_string( 33 | 'exact_results_file', None, 34 | 'Optional file providing alternative "exact" simulation results.') 35 | flags.DEFINE_float( 36 | 'quantile', 0.8, 37 | 'Quantile to use for "good enough".') 38 | 39 | 40 | FLAGS = flags.FLAGS 41 | 42 | 43 | def create_survival_netcdf(simulation_path, quantile=0.8, exact_path=None): 44 | """Create a new netCDF file with survival analysis results.""" 45 | 46 | if '/results.nc' not in simulation_path: 47 | # no simulation results 48 | return 49 | 50 | # read data 51 | with tf.gfile.GFile(simulation_path, 'rb') as f: 52 | ds = xarray.open_dataset(f.read()).load() 53 | 54 | if exact_path is not None: 55 | with tf.gfile.GFile(exact_path, 'rb') as f: 56 | ds_exact = xarray.open_dataset(f.read()).load() 57 | ds['y_exact'] = (ds_exact['y'] 58 | .rename({'x': 'x_high'}) 59 | .reindex_like(ds, method='nearest')) 60 | 61 | # do analysis 62 | survival = analysis.mostly_good_survival(ds, quantile) 63 | 64 | # save results 65 | survival_path = simulation_path.replace('/results.nc', '/survival.nc') 66 | with tf.gfile.GFile(survival_path, 'wb') as f: 67 | f.write(survival.to_netcdf()) 68 | 69 | 70 | def main(_, runner=None): 71 | if runner is None: 72 | # must create before flags are used 73 | runner = beam.runners.DirectRunner() 74 | 75 | pipeline = ( 76 | beam.Create(tf.gfile.Glob(FLAGS.file_pattern)) 77 | | beam.Reshuffle() 78 | | beam.Map(create_survival_netcdf, quantile=FLAGS.quantile, 79 | exact_path=FLAGS.exact_results_file) 80 | ) 81 | runner.run(pipeline) 82 | 83 | 84 | if __name__ == '__main__': 85 | app.run(main) 86 | -------------------------------------------------------------------------------- /pde_superresolution/xarray_beam_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | import os.path 20 | 21 | from absl import flags 22 | from absl.testing import absltest # pylint: disable=g-bad-import-order 23 | import numpy as np 24 | import xarray 25 | 26 | from pde_superresolution import xarray_beam # pylint: disable=g-bad-import-order 27 | 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | 32 | class NetCDFTest(absltest.TestCase): 33 | 34 | def test_read_write(self): 35 | data = np.random.RandomState(0).rand(3, 4) 36 | ds = xarray.Dataset({'foo': (('x', 'y'), data)}) 37 | path = os.path.join(FLAGS.test_tmpdir, 'tmp.nc') 38 | xarray_beam.write_netcdf(ds, path) 39 | roundtripped = xarray_beam.read_netcdf(path) 40 | xarray.testing.assert_equal(ds, roundtripped) 41 | 42 | 43 | class StackUnstackTest(absltest.TestCase): 44 | 45 | def test_stack_1d(self): 46 | input_ds = xarray.Dataset({'foo': ('x', [1, 2])}, {'x': [0, 1]}) 47 | stacked = xarray_beam.stack(input_ds, dim='z', levels=['x']) 48 | expected = xarray.Dataset({'foo': ('z', [1, 2])}, 49 | {'x': ('z', [0, 1])}) 50 | xarray.testing.assert_equal(stacked, expected) 51 | 52 | def test_stack_2d(self): 53 | input_ds = xarray.Dataset({'foo': (('x', 'y'), [[1, 2], [3, 4]])}, 54 | {'x': [0, 1], 'y': ['a', 'b']}) 55 | stacked = xarray_beam.stack(input_ds, dim='z', levels=['x', 'y']) 56 | expected = xarray.Dataset({'foo': ('z', [1, 2, 3, 4])}, 57 | {'x': ('z', [0, 0, 1, 1]), 58 | 'y': ('z', ['a', 'b', 'a', 'b'])}) 59 | xarray.testing.assert_equal(stacked, expected) 60 | 61 | def test_stack_unstack_1d(self): 62 | input_ds = xarray.Dataset({'foo': ('x', [1, 2])}, {'x': [0, 1]}) 63 | stacked = xarray_beam.stack(input_ds, dim='z', levels=['x']) 64 | roundtripped = xarray_beam.unstack(stacked, dim='z', levels=['x']) 65 | xarray.testing.assert_equal(roundtripped, input_ds) 66 | 67 | def test_stack_unstack_2d(self): 68 | input_ds = xarray.Dataset({'foo': (('x', 'y'), [[1, 2], [3, 4]])}, 69 | {'x': [0, 1], 'y': ['a', 'b']}) 70 | stacked = xarray_beam.stack(input_ds, dim='z', levels=['x', 'y']) 71 | roundtripped = xarray_beam.unstack(stacked, dim='z', levels=['x', 'y']) 72 | xarray.testing.assert_equal(roundtripped, input_ds) 73 | 74 | 75 | if __name__ == '__main__': 76 | absltest.main() 77 | -------------------------------------------------------------------------------- /pde_superresolution/scripts/run_training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Binary for running training.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import json 21 | import os.path 22 | 23 | from absl import app 24 | from absl import flags 25 | from absl import logging 26 | import tensorflow as tf 27 | 28 | from pde_superresolution import equations # pylint: disable=g-bad-import-order 29 | from pde_superresolution import training # pylint: disable=g-bad-import-order 30 | from pde_superresolution import utils # pylint: disable=g-bad-import-order 31 | 32 | 33 | # NOTE(shoyer): allow_override=True lets us import multiple binaries for the 34 | # purpose of running integration tests. This is safe since we're strict about 35 | # only using FLAGS inside main(). 36 | 37 | flags.DEFINE_string( 38 | 'checkpoint_dir', '', 39 | 'Directory to use for saving model', 40 | allow_override=True) 41 | flags.DEFINE_string( 42 | 'input_path', None, 43 | 'Path to HDF5 file with input data.') 44 | flags.DEFINE_enum( 45 | 'equation', None, list(equations.EQUATION_TYPES), 46 | 'Equation to integrate.') 47 | flags.DEFINE_string( 48 | 'hparams', '', 49 | 'Additional hyper-parameter values to use, in the form of a ' 50 | 'comma-separated list of name=value pairs, e.g., ' 51 | '"num_layers=3,filter_size=64".') 52 | flags.DEFINE_string( 53 | 'master', '', 54 | 'Master to use with TensorFlow.') 55 | 56 | 57 | FLAGS = flags.FLAGS 58 | 59 | 60 | def main(unused_argv): 61 | logging.info('Loading training data') 62 | with utils.read_h5py(FLAGS.input_path) as f: 63 | snapshots = f['v'][...] 64 | equation_kwargs = {k: v.item() for k, v in f.attrs.items()} 65 | 66 | logging.info('Inputs have shape %r', snapshots.shape) 67 | 68 | if FLAGS.checkpoint_dir: 69 | tf.gfile.MakeDirs(FLAGS.checkpoint_dir) 70 | 71 | hparams = training.create_hparams( 72 | FLAGS.equation, equation_kwargs=json.dumps(equation_kwargs)) 73 | hparams.parse(FLAGS.hparams) 74 | 75 | logging.info('Starting training loop') 76 | metrics_df = training.training_loop(snapshots, FLAGS.checkpoint_dir, 77 | hparams, master=FLAGS.master) 78 | 79 | if FLAGS.checkpoint_dir: 80 | logging.info('Saving CSV with metrics') 81 | csv_path = os.path.join(FLAGS.checkpoint_dir, 'metrics.csv') 82 | with tf.gfile.GFile(csv_path, 'w') as f: 83 | metrics_df.to_csv(f, index=False) 84 | 85 | logging.info('Finished') 86 | 87 | 88 | if __name__ == '__main__': 89 | flags.mark_flag_as_required('checkpoint_dir') 90 | flags.mark_flag_as_required('input_path') 91 | flags.mark_flag_as_required('equation') 92 | app.run(main) 93 | -------------------------------------------------------------------------------- /pde_superresolution/scripts/simulate_train_evaluate_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """An integration test that does data generation, training and evaluation.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os.path 21 | 22 | from absl import flags 23 | from absl.testing import flagsaver 24 | import apache_beam as beam 25 | from pde_superresolution.scripts import create_exact_data 26 | from pde_superresolution.scripts import create_training_data 27 | from pde_superresolution.scripts import run_evaluation 28 | from pde_superresolution.scripts import run_training 29 | import xarray 30 | from absl.testing import absltest 31 | 32 | 33 | FLAGS = flags.FLAGS 34 | 35 | 36 | class IntegrationTest(absltest.TestCase): 37 | 38 | def test(self): 39 | training_path = os.path.join(FLAGS.test_tmpdir, 'training.h5') 40 | exact_path = os.path.join(FLAGS.test_tmpdir, 'exact.nc') 41 | checkpoint_dir = os.path.join(FLAGS.test_tmpdir, 'checkpoint') 42 | samples_output_name = 'results.nc' 43 | samples_output_path = os.path.join(checkpoint_dir, samples_output_name) 44 | 45 | with flagsaver.flagsaver( 46 | output_path=training_path, 47 | equation_name='burgers', 48 | equation_kwargs='{"num_points": 256}', 49 | num_tasks=2, 50 | time_max=1.0, 51 | time_delta=0.1, 52 | warmup=0): 53 | create_training_data.main([], runner=beam.runners.DirectRunner()) 54 | 55 | with flagsaver.flagsaver( 56 | output_path=exact_path, 57 | equation_name='burgers', 58 | equation_kwargs='{"num_points": 256}', 59 | num_samples=2, 60 | time_max=1.0, 61 | time_delta=0.1, 62 | warmup=1.0): 63 | create_exact_data.main([], runner=beam.runners.DirectRunner()) 64 | 65 | with flagsaver.flagsaver( 66 | checkpoint_dir=checkpoint_dir, 67 | input_path=training_path, 68 | hparams='resample_factor=4,learning_rates=[1e-3],learning_stops=[20],' 69 | 'eval_interval=10', 70 | equation='burgers'): 71 | run_training.main([]) 72 | 73 | with flagsaver.flagsaver( 74 | checkpoint_dir=checkpoint_dir, 75 | exact_solution_path=exact_path, 76 | samples_output_name=samples_output_name, 77 | num_samples=2, 78 | time_max=1.0, 79 | time_delta=0.1): 80 | run_evaluation.main([], runner=beam.runners.DirectRunner()) 81 | 82 | # verify the results 83 | with xarray.open_dataset(samples_output_path) as ds: 84 | self.assertEqual(dict(ds.dims), 85 | {'sample': 2, 'time': 11, 'x': 64}) 86 | self.assertEqual(set(ds), {'y'}) 87 | 88 | 89 | if __name__ == '__main__': 90 | absltest.main() 91 | -------------------------------------------------------------------------------- /pde_superresolution/training_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Sanity tests for training a model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import json 22 | import tempfile 23 | 24 | from absl import flags 25 | from absl.testing import parameterized 26 | from absl.testing import absltest # pylint: disable=g-bad-import-order 27 | import numpy as np 28 | import pandas as pd 29 | import tensorflow as tf 30 | 31 | from pde_superresolution import training # pylint: disable=g-bad-import-order 32 | 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | NUM_X_POINTS = 256 37 | 38 | 39 | class TrainingTest(parameterized.TestCase): 40 | 41 | def setUp(self): 42 | self.tmpdir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir) 43 | 44 | extra_testcases = [] 45 | for equation in ['burgers', 'kdv', 'ks']: 46 | for conservative in [True, False]: 47 | for num_time_steps in [0, 1]: 48 | extra_testcases.append({ 49 | 'equation': equation, 50 | 'conservative': conservative, 51 | 'num_time_steps': num_time_steps, 52 | }) 53 | 54 | @parameterized.parameters( 55 | dict(equation='burgers', polynomial_accuracy_order=0), 56 | dict(equation='ks', coefficient_grid_min_size=9), 57 | dict(equation='ks', polynomial_accuracy_order=0), 58 | dict(equation='burgers', conservative=True, numerical_flux=True), 59 | dict(equation='ks', conservative=True, numerical_flux=True), 60 | dict(equation='kdv', conservative=True, numerical_flux=True), 61 | dict(equation='burgers', noise_probability=0.5, noise_amplitude=0.1), 62 | dict(equation='burgers', noise_probability=0.5, noise_amplitude=0.1, 63 | noise_type='filtered'), 64 | dict(equation='burgers', kernel_size=5, nonlinearity='relu6'), 65 | dict(equation='burgers', resample_factor=64), 66 | dict(equation='burgers', polynomial_accuracy_order=1, num_layers=0), 67 | dict(equation='burgers', model_target='space_derivatives'), 68 | dict(equation='burgers', model_target='flux'), 69 | dict(equation='burgers', model_target='time_derivative'), 70 | dict(equation='burgers', error_max=10.0), 71 | *extra_testcases) 72 | def test_training_loop(self, **hparam_values): 73 | with tf.Graph().as_default(): 74 | snapshots = np.random.RandomState(0).randn(100, NUM_X_POINTS) 75 | hparams = training.create_hparams( 76 | learning_rates=[1e-3], 77 | learning_stops=[20], 78 | eval_interval=10, 79 | equation_kwargs=json.dumps({'num_points': NUM_X_POINTS}), 80 | **hparam_values) 81 | results = training.training_loop(snapshots, self.tmpdir, hparams) 82 | self.assertIsInstance(results, pd.DataFrame) 83 | self.assertEqual(results.shape[0], 2) 84 | 85 | 86 | if __name__ == '__main__': 87 | absltest.main() 88 | -------------------------------------------------------------------------------- /pde_superresolution/scripts/run_mae.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # pylint: disable=line-too-long 16 | """Run a beam pipeline to add netCDF files with mean absolute error.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import json 22 | 23 | from absl import app 24 | from absl import flags 25 | import apache_beam as beam 26 | from pde_superresolution import analysis # pylint: disable=g-bad-import-order 27 | import pandas 28 | import tensorflow as tf 29 | import xarray 30 | 31 | 32 | flags.DEFINE_string( 33 | 'file_pattern', None, 34 | 'Glob to use for matching simulation files.') 35 | flags.DEFINE_string( 36 | 'exact_results_file', None, 37 | 'Optional file providing alternative "exact" simulation results.') 38 | # These defaults values are chosen for Burgers [13, 15, 20, 25], KdV [51] and 39 | # KS [103] 40 | flags.DEFINE_string( 41 | 'stop_times', json.dumps([13, 15, 20, 25, 51, 103]), 42 | 'Cut-off times to use when calculating MAE.') 43 | 44 | 45 | FLAGS = flags.FLAGS 46 | 47 | 48 | def create_mae_netcdf(simulation_path, stop_times=None, exact_path=None): 49 | """Create a new netCDF file with mean absolute error.""" 50 | 51 | if '/results.nc' not in simulation_path: 52 | # no simulation results 53 | return 54 | 55 | # read data 56 | with tf.gfile.GFile(simulation_path, 'rb') as f: 57 | ds = xarray.open_dataset(f.read()).load() 58 | 59 | if exact_path is not None: 60 | with tf.gfile.GFile(exact_path, 'rb') as f: 61 | ds_exact = xarray.open_dataset(f.read()).load() 62 | ds['y_exact'] = (ds_exact['y'] 63 | .rename({'x': 'x_high'}) 64 | .reindex_like(ds, method='nearest')) 65 | 66 | # do the analysis 67 | ds = analysis.unify_x_coords(ds) 68 | 69 | results = [] 70 | for time_max in stop_times: 71 | ds = ds.sel(time=slice(None, time_max)) 72 | mae = abs(ds.drop('y_exact') - ds.y_exact).mean(['x', 'time'], skipna=False) 73 | results.append(mae) 74 | dim = pandas.Index(stop_times, name='time_max') 75 | mae_all = xarray.concat(results, dim=dim) 76 | 77 | # save results 78 | mae_path = simulation_path.replace('/results.nc', '/mae.nc') 79 | with tf.gfile.GFile(mae_path, 'wb') as f: 80 | f.write(mae_all.to_netcdf()) 81 | 82 | 83 | def main(_, runner=None): 84 | if runner is None: 85 | # must create before flags are used 86 | runner = beam.runners.DirectRunner() 87 | 88 | pipeline = ( 89 | beam.Create(tf.gfile.Glob(FLAGS.file_pattern)) 90 | | beam.Reshuffle() 91 | | beam.Map(create_mae_netcdf, 92 | stop_times=json.loads(FLAGS.stop_times), 93 | exact_path=FLAGS.exact_results_file) 94 | ) 95 | runner.run(pipeline) 96 | 97 | 98 | if __name__ == '__main__': 99 | app.run(main) 100 | -------------------------------------------------------------------------------- /pde_superresolution/analysis.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Analysis functions for saved model results.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | from typing import Union 23 | 24 | import xarray 25 | from pde_superresolution import duckarray # pylint: disable=g-bad-import-order 26 | 27 | 28 | XarrayObject = Union[xarray.Dataset, xarray.DataArray] # pylint: disable=invalid-name 29 | 30 | 31 | def resample_mean(ds, dim, factor): 32 | """Resample an xarray object along a single dimension.""" 33 | return xarray.apply_ufunc( 34 | duckarray.resample_mean, ds, 35 | input_core_dims=[[dim]], output_core_dims=[['dim_new']], 36 | kwargs=dict(factor=factor)).rename({'dim_new': dim}) 37 | 38 | 39 | def unify_x_coords(ds: xarray.Dataset) -> xarray.Dataset: 40 | """Resample data variables in an xarray.Dataset to only use low resolution.""" 41 | factor = ds.sizes['x_high'] // ds.sizes['x_low'] 42 | 43 | high_vars = [k for k, v in ds.variables.items() if 'x_high' in v.dims] 44 | ds_low = ds.drop(high_vars).rename({'x_low': 'x'}) 45 | 46 | low_vars = [k for k, v in ds.variables.items() if 'x_low' in v.dims] 47 | ds_high = ds.drop(low_vars).rename({'x_high': 'x'}) 48 | ds_high_resampled = resample_mean(ds_high, 'x', factor) 49 | 50 | unified = ds_low.merge(ds_high_resampled) 51 | return xarray.Dataset( 52 | collections.OrderedDict((k, unified[k]) for k in sorted(unified))) 53 | 54 | 55 | def is_good( 56 | model: XarrayObject, 57 | exact: XarrayObject, 58 | max_error: float = 0.5, 59 | ) -> XarrayObject: 60 | """Is each point of solution accurate within some error threshold?""" 61 | return abs(model - exact) <= max_error 62 | 63 | 64 | def mostly_good( 65 | model: XarrayObject, 66 | exact: XarrayObject, 67 | max_error: float = 0.5, 68 | frac_good: float = 0.8, 69 | ) -> XarrayObject: 70 | """Is the solution at a single-time within acceptable error bounds?""" 71 | return is_good(model, exact, max_error=max_error).mean('x') >= frac_good 72 | 73 | 74 | def calculate_survival(ds: XarrayObject) -> XarrayObject: 75 | """Calculate the "lifetime" of an xarray object with a boolean dtype.""" 76 | return xarray.where(ds.all('time'), 77 | ds['time'].max(), 78 | ds['time'].isel(time=ds.argmin('time'))).drop('time') 79 | 80 | 81 | def mostly_good_survival( 82 | ds: xarray.Dataset, quantile: float = 0.8) -> xarray.Dataset: 83 | """Calculate mostly good survival for a Dataset with a "y_exact" variable.""" 84 | max_error = abs(ds['y_exact']).quantile(q=1-quantile).item() 85 | unified = unify_x_coords(ds) 86 | good_enough = mostly_good( 87 | unified.drop('y_exact').to_array(dim='variable'), unified['y_exact'], 88 | max_error=max_error, frac_good=quantile) 89 | survival = calculate_survival(good_enough).to_dataset(dim='variable') 90 | return survival 91 | -------------------------------------------------------------------------------- /pde_superresolution/layers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Sanity tests for layers.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest # pylint: disable=g-bad-import-order 22 | from absl.testing import parameterized 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | from pde_superresolution import layers # pylint: disable=g-bad-import-order 27 | 28 | 29 | def pad_periodic_1d(inputs, padding, center=False): 30 | padded_inputs = inputs[tf.newaxis, :, tf.newaxis] 31 | padded_outputs = layers.pad_periodic(padded_inputs, padding, center) 32 | return tf.squeeze(padded_outputs, axis=(0, 2)) 33 | 34 | 35 | class LayersTest(parameterized.TestCase): 36 | 37 | def test_static_or_dynamic_size(self): 38 | with tf.Graph().as_default(): 39 | with tf.Session(): 40 | self.assertEqual(layers.static_or_dynamic_size(tf.range(5), axis=0), 5) 41 | 42 | feed_size = tf.placeholder(tf.int32, ()) 43 | size = layers.static_or_dynamic_size(tf.range(feed_size), axis=0) 44 | self.assertEqual(size.eval(feed_dict={feed_size: 5}), 5) 45 | 46 | with self.assertRaisesRegexp(ValueError, 'out of bounds'): 47 | layers.static_or_dynamic_size(tf.range(5), axis=1) 48 | 49 | @parameterized.parameters( 50 | dict(padding=0, center=True, expected=[0, 1, 2]), 51 | dict(padding=1, center=True, expected=[2, 0, 1, 2]), 52 | dict(padding=2, center=True, expected=[2, 0, 1, 2, 0]), 53 | dict(padding=3, center=True, expected=[1, 2, 0, 1, 2, 0]), 54 | dict(padding=4, center=True, expected=[1, 2, 0, 1, 2, 0, 1]), 55 | dict(padding=6, center=True, expected=[0, 1, 2, 0, 1, 2, 0, 1, 2]), 56 | dict(padding=7, center=True, expected=[2, 0, 1, 2, 0, 1, 2, 0, 1, 2]), 57 | dict(padding=0, center=False, expected=[0, 1, 2]), 58 | dict(padding=1, center=False, expected=[0, 1, 2, 0]), 59 | dict(padding=2, center=False, expected=[0, 1, 2, 0, 1]), 60 | dict(padding=3, center=False, expected=[0, 1, 2, 0, 1, 2]), 61 | dict(padding=5, center=False, expected=[0, 1, 2, 0, 1, 2, 0, 1]), 62 | ) 63 | def test_pad_periodic(self, padding, expected, center): 64 | with tf.Graph().as_default(): 65 | with tf.Session(): 66 | inputs = pad_periodic_1d(tf.range(3), padding=padding, center=center) 67 | np.testing.assert_equal(inputs.eval(), expected) 68 | 69 | def test_nn_conv1d_periodic(self): 70 | with tf.Graph().as_default(): 71 | with tf.Session(): 72 | inputs = tf.range(5.0)[tf.newaxis, :, tf.newaxis] 73 | 74 | filters = tf.constant([0.0, 1.0, 0.0])[:, tf.newaxis, tf.newaxis] 75 | actual = layers.nn_conv1d_periodic(inputs, filters, center=True) 76 | np.testing.assert_allclose(inputs.eval(), actual.eval()) 77 | 78 | filters = tf.constant([0.0, 1.0])[:, tf.newaxis, tf.newaxis] 79 | actual = layers.nn_conv1d_periodic(inputs, filters, center=True) 80 | np.testing.assert_allclose(inputs.eval(), actual.eval()) 81 | 82 | filters = tf.constant([0.5, 0.5])[:, tf.newaxis, tf.newaxis] 83 | expected = tf.constant( 84 | [2.0, 0.5, 1.5, 2.5, 3.5])[tf.newaxis, :, tf.newaxis] 85 
| actual = layers.nn_conv1d_periodic(inputs, filters, center=True) 86 | np.testing.assert_allclose(expected.eval(), actual.eval()) 87 | 88 | 89 | if __name__ == '__main__': 90 | absltest.main() 91 | -------------------------------------------------------------------------------- /pde_superresolution/weno_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for WENO reconstruction.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest # pylint: disable=g-bad-import-order 22 | from absl.testing import parameterized 23 | import numpy as np 24 | 25 | from pde_superresolution import weno 26 | 27 | 28 | class WENOTest(parameterized.TestCase): 29 | 30 | def test_calculate_omega_smooth(self): 31 | u = np.zeros(5) 32 | actual = weno.calculate_omega(u) 33 | expected = np.stack(5 * [[0.1, 0.6, 0.3]], axis=1) 34 | np.testing.assert_allclose(actual, expected) 35 | 36 | def test_left_coefficients_smooth(self): 37 | u = np.zeros(5) 38 | actual = weno.left_coefficients(u) 39 | expected = np.stack(5 * [[2/60, -13/60, 47/60, 27/60, -3/60]], axis=0) 40 | np.testing.assert_allclose(actual, expected) 41 | 42 | def test_right_coefficients_smooth(self): 43 | u = np.zeros(5) 44 | actual = weno.right_coefficients(u) 45 | expected = np.stack(5 * [[-3/60, 27/60, 47/60, -13/60, 2/60]], axis=0) 46 | np.testing.assert_allclose(actual, expected) 47 | 48 | def test_reconstruct_left_discontinuity(self): 49 | u = np.array([0, 1, 2, 3, 4, -4, -3, -2, -1]) 50 | actual = weno.reconstruct_left(u) 51 | expected = [0.5, 1.5, 2.5, 3.5, 4.5, -3.5, -2.5, -1.5, -0.5] 52 | np.testing.assert_allclose(actual, expected, atol=0.005) 53 | 54 | def test_reconstruct_right_discontinuity(self): 55 | u = np.array([0, 1, 2, 3, 4, -4, -3, -2, -1]) 56 | actual = weno.reconstruct_right(u) 57 | expected = [0.5, 1.5, 2.5, 3.5, -4.5, -3.5, -2.5, -1.5, -0.5] 58 | np.testing.assert_allclose(actual, expected, atol=0.005) 59 | 60 | @parameterized.parameters( 61 | dict(u=[0, 0, 0, 0, 1, 0, 0, 0, 0, 0]), 62 | dict(u=[1, 1, 1, 1, 1, 0, 0, 0, 0, 0]), 63 | dict(u=[1, 2, 3, 4, 5, 0, 0, 0, 0, 0]), 64 | dict(u=[0, 0, 1, 2, 3, 0, 0, 0, 0, 0]), 65 | dict(u=[0, 0, 0, 1, 2, 0, 0, 0, 0, 0]), 66 | dict(u=2 * np.random.RandomState(0).rand(10)), 67 | ) 68 | def test_reconstruction_symmetry(self, u): 69 | u = np.array(u, dtype=float) 70 | 71 | def flip(x): 72 | return x[::-1] 73 | 74 | def flip_staggered(x): 75 | return flip(np.roll(x, +1)) 76 | 77 | left_direct = weno.reconstruct_left(u) 78 | left_flipped = flip_staggered(weno.reconstruct_right(flip(u))) 79 | np.testing.assert_allclose(left_direct, left_flipped, atol=1e-6) 80 | 81 | right_direct = weno.reconstruct_right(u) 82 | right_flipped = 
flip_staggered(weno.reconstruct_left(flip(u))) 83 | np.testing.assert_allclose(right_direct, right_flipped, atol=1e-6) 84 | 85 | def test_batched(self): 86 | u_batched = np.array([[0, 0, 0, 1, 2, 3, 4], 87 | [0, 0, 1, 2, 3, 4, 5]]) 88 | expected_left = np.stack([weno.reconstruct_left(u_batched[0]), 89 | weno.reconstruct_left(u_batched[1])]) 90 | expected_right = np.stack([weno.reconstruct_right(u_batched[0]), 91 | weno.reconstruct_right(u_batched[1])]) 92 | 93 | actual_left = weno.reconstruct_left(u_batched) 94 | actual_right = weno.reconstruct_right(u_batched) 95 | 96 | np.testing.assert_allclose(actual_left, expected_left) 97 | np.testing.assert_allclose(actual_right, expected_right) 98 | 99 | 100 | if __name__ == '__main__': 101 | absltest.main() 102 | -------------------------------------------------------------------------------- /pde_superresolution/weno.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """An implementation of 5th order upwind-biased WENO, "WENO5". 16 | 17 | Based on the implementation described in: 18 | [1] Tang, Lei. 2005. "Upwind and Central WENO Schemes." Applied Mathematics and 19 | Computation 166 (2): 434-48. 20 | [2] Shu, Chi-Wang. 1998. "Essentially Non-Oscillatory and Weighted Essentially 21 | Non-Oscillatory Schemes for Hyperbolic Conservation Laws." In Advanced 22 | Numerical Approximation of Nonlinear Hyperbolic Equations: Lectures given 23 | at the 2nd Session of the Centro Internazionale Matematico Estivo 24 | (C.I.M.E.) Held in Cetraro, Italy, June 23-28, 1997, edited by Bernardo 25 | Cockburn, Chi-Wang Shu, Claes Johnson, Eitan Tadmor, and Alfio Quarteroni, 26 | 325-432. Berlin, Heidelberg: Springer Berlin Heidelberg. 27 | https://www3.nd.edu/~zxu2/acms60790S13/Shu-WENO-notes.pdf 28 | """ 29 | from __future__ import absolute_import 30 | from __future__ import division 31 | from __future__ import print_function 32 | 33 | import numpy as np 34 | 35 | from pde_superresolution import duckarray # pylint: disable=g-bad-import-order 36 | 37 | 38 | # These optimal weights result in a 5th order one-point upwinded coefficients 39 | # for smooth functions. 
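# In smooth regions, combining the three 3rd-order candidate stencils with these
# weights reproduces the single 5th-order upwind-biased stencil (see ref [2]).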
40 | OPTIMAL_SMOOTH_WEIGHTS = (0.1, 0.6, 0.3) 41 | 42 | 43 | def calculate_smoothness_indicators(u): 44 | """Calculate smoothness indicators for picking weights.""" 45 | # see Equation (7) in ref [1] 46 | u_minus2 = duckarray.roll(u, +2, axis=-1) 47 | u_minus1 = duckarray.roll(u, +1, axis=-1) 48 | u_plus1 = duckarray.roll(u, -1, axis=-1) 49 | u_plus2 = duckarray.roll(u, -2, axis=-1) 50 | return duckarray.stack([ 51 | 1/4 * (u_minus2 - 4 * u_minus1 + 3 * u) ** 2 + 52 | 13/12 * (u_minus2 - 2 * u_minus1 + u) ** 2, 53 | 1/4 * (u_minus1 - u_plus1) ** 2 + 54 | 13/12 * (u_minus1 - 2 * u + u_plus1) ** 2, 55 | 1/4 * (3 * u - 4 * u_plus1 + u_plus2) ** 2 + 56 | 13/12 * (u - 2 * u_plus1 + u_plus2) ** 2, 57 | ], axis=-2) 58 | 59 | 60 | def calculate_omega( 61 | u, 62 | optimal_linear_weights=OPTIMAL_SMOOTH_WEIGHTS, 63 | epsilon=1e-6, 64 | p=2, 65 | ): 66 | """Calculate linear weights for the three polynomial reconstructions.""" 67 | # see Equation (6) in ref [1] 68 | indicator_kj = calculate_smoothness_indicators(u) 69 | # p=2 is used by ref. [2] 70 | alpha_kj = (np.array(optimal_linear_weights)[:, np.newaxis] 71 | / (epsilon + indicator_kj) ** p) 72 | omega_kj = alpha_kj / duckarray.sum(alpha_kj, axis=-2, keepdims=True) 73 | return omega_kj 74 | 75 | 76 | def left_coefficients(u): 77 | """Linear coefficients for WENO reconstruction from the left.""" 78 | # see Equation (5) from ref [1] 79 | omega_kj = calculate_omega(u) 80 | omega0 = omega_kj[..., 0, :] 81 | omega1 = omega_kj[..., 1, :] 82 | omega2 = omega_kj[..., 2, :] 83 | return duckarray.stack([ 84 | omega0 / 3, 85 | - (7 * omega0 + omega1) / 6, 86 | (11 * omega0 + 5 * omega1 + 2 * omega2) / 6, 87 | (2 * omega1 + 5 * omega2) / 6, 88 | - omega2 / 6, 89 | ], axis=-1) 90 | 91 | 92 | def reconstruct_left(u): 93 | """Reconstruct u at +1/2 cells with a left-biased stencil.""" 94 | coefficients = left_coefficients(u) 95 | u_all = duckarray.stack( 96 | [duckarray.roll(u, i, axis=-1) for i in [2, 1, 0, -1, -2]], axis=-1) 97 | return duckarray.sum(coefficients * u_all, axis=-1) 98 | 99 | 100 | def right_coefficients(u): 101 | """Linear coefficients for WENO reconstruction from the right.""" 102 | # see Equation (9) from ref [1], but note that it has an error: optimal 103 | # smoothing weights should be reversed, per step 2 of Procedure 2.2 in ref [2] 104 | omega_kj = calculate_omega(u, OPTIMAL_SMOOTH_WEIGHTS[::-1]) 105 | omega_kj_rolled = duckarray.roll(omega_kj, -1, axis=-1) 106 | omega2 = omega_kj_rolled[..., 0, :] 107 | omega1 = omega_kj_rolled[..., 1, :] 108 | omega0 = omega_kj_rolled[..., 2, :] 109 | return duckarray.stack([ 110 | -omega2 / 6, 111 | (5 * omega2 + 2 * omega1) / 6, 112 | (2 * omega2 + 5 * omega1 + 11 * omega0) / 6, 113 | -(omega1 + 7 * omega0) / 6, 114 | omega0 / 3, 115 | ], axis=-1) 116 | 117 | 118 | def reconstruct_right(u): 119 | """Reconstruct u at +1/2 cells with a right-biased stencil.""" 120 | coefficients = right_coefficients(u) 121 | u_all = duckarray.stack( 122 | [duckarray.roll(u, i, axis=-1) for i in [1, 0, -1, -2, -3]], axis=-1) 123 | return duckarray.sum(coefficients * u_all, axis=-1) 124 | -------------------------------------------------------------------------------- /pde_superresolution/scripts/create_exact_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Run a beam pipeline to run the WENO5 model.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import json 21 | 22 | from absl import app 23 | from absl import flags 24 | import apache_beam as beam 25 | import numpy as np 26 | from pde_superresolution import equations 27 | from pde_superresolution import integrate 28 | from pde_superresolution import xarray_beam 29 | 30 | 31 | # NOTE(shoyer): allow_override=True lets us import multiple binaries for the 32 | # purpose of running integration tests. This is safe since we're strict about 33 | # only using FLAGS inside main(). 34 | 35 | # files 36 | flags.DEFINE_string( 37 | 'output_path', '', 38 | 'Full path to which to save the resulting netCDF file.', 39 | allow_override=True) 40 | 41 | # equation parameters 42 | flags.DEFINE_enum( 43 | 'equation_name', 'burgers', list(equations.EQUATION_TYPES), 44 | 'Equation to integrate.', 45 | allow_override=True) 46 | flags.DEFINE_string( 47 | 'equation_kwargs', '{"num_points": 400}', 48 | 'Parameters to pass to the equation constructor.', 49 | allow_override=True) 50 | flags.DEFINE_integer( 51 | 'num_samples', 10, 52 | 'Number of times to integrate each equation.', 53 | allow_override=True) 54 | 55 | # integrate parameters 56 | flags.DEFINE_float( 57 | 'time_max', 10, 58 | 'Total time for which to run each integration.', 59 | allow_override=True) 60 | flags.DEFINE_float( 61 | 'time_delta', 1, 62 | 'Difference between saved time steps in the integration.', 63 | allow_override=True) 64 | flags.DEFINE_float( 65 | 'warmup', 0, 66 | 'Amount of time to integrate before using the neural network.', 67 | allow_override=True) 68 | flags.DEFINE_enum( 69 | 'discretization_method', 'exact', ['exact', 'weno', 'spectral'], 70 | 'How the exact solution is discretized. By default, uses the "exact" ' 71 | 'method that has been saved for this equation.', 72 | allow_override=True) 73 | flags.DEFINE_string( 74 | 'integrate_method', 'RK23', 75 | 'Method to use for integration with scipy.integrate.solve_ivp.', 76 | allow_override=True) 77 | flags.DEFINE_float( 78 | 'exact_filter_interval', 0, 79 | 'Interval between periodic filtering. 
Only used for spectral methods.', 80 | allow_override=True) 81 | 82 | 83 | FLAGS = flags.FLAGS 84 | 85 | 86 | def main(_, runner=None): 87 | if runner is None: 88 | # must create before flags are used 89 | runner = beam.runners.DirectRunner() 90 | 91 | equation_kwargs = json.loads(FLAGS.equation_kwargs) 92 | 93 | use_weno = (FLAGS.discretization_method == 'weno' 94 | or (FLAGS.discretization_method == 'exact' 95 | and FLAGS.equation_name == 'burgers')) 96 | 97 | if (not use_weno and FLAGS.exact_filter_interval): 98 | exact_filter_interval = float(FLAGS.exact_filter_interval) 99 | else: 100 | exact_filter_interval = None 101 | 102 | def create_equation(seed, name=FLAGS.equation_name, 103 | kwargs=equation_kwargs): 104 | equation_type = (equations.FLUX_EQUATION_TYPES 105 | if use_weno else 106 | equations.EQUATION_TYPES)[name] 107 | return equation_type(random_seed=seed, **kwargs) 108 | 109 | def do_integrate( 110 | equation, 111 | times=np.arange(0, FLAGS.time_max + FLAGS.time_delta, FLAGS.time_delta), 112 | warmup=FLAGS.warmup, 113 | integrate_method=FLAGS.integrate_method): 114 | integrate_func = (integrate.integrate_weno 115 | if use_weno 116 | else integrate.integrate_spectral) 117 | return integrate_func(equation, times, warmup, integrate_method, 118 | exact_filter_interval=exact_filter_interval) 119 | 120 | def create_equation_and_integrate(seed): 121 | equation = create_equation(seed) 122 | result = do_integrate(equation) 123 | result.coords['sample'] = seed 124 | return result 125 | 126 | pipeline = ( 127 | beam.Create(list(range(FLAGS.num_samples))) 128 | | beam.Map(create_equation_and_integrate) 129 | | beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample')) 130 | | beam.Map(lambda ds: ds.sortby('sample')) 131 | | beam.Map(xarray_beam.write_netcdf, path=FLAGS.output_path)) 132 | 133 | runner.run(pipeline) 134 | 135 | 136 | if __name__ == '__main__': 137 | app.run(main) 138 | -------------------------------------------------------------------------------- /pde_superresolution/scripts/create_baseline_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Run a beam pipeline to generate training data.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import json 21 | 22 | from absl import app 23 | from absl import flags 24 | import apache_beam as beam 25 | import numpy as np 26 | from pde_superresolution import equations 27 | from pde_superresolution import integrate 28 | from pde_superresolution import xarray_beam 29 | 30 | 31 | # NOTE(shoyer): allow_override=True lets us import multiple binaries for the 32 | # purpose of running integration tests. This is safe since we're strict about 33 | # only using FLAGS inside main(). 
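# Example invocation (an illustrative sketch only; the output path and flag
# values below are hypothetical and simply echo the defaults defined here):
#
#   python pde_superresolution/scripts/create_baseline_data.py \
#     --output_path=/tmp/burgers_baseline.nc \
#     --equation_name=burgers \
#     --equation_kwargs='{"num_points": 400}' \
#     --accuracy_orders=1 --accuracy_orders=3 \
#     --num_samples=10 --time_max=10 --time_delta=1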
34 | 35 | # files 36 | flags.DEFINE_string( 37 | 'output_path', '', 38 | 'Full path to which to save the resulting netCDF file.', 39 | allow_override=True) 40 | 41 | # equation parameters 42 | flags.DEFINE_enum( 43 | 'equation_name', 'burgers', list(equations.CONSERVATIVE_EQUATION_TYPES), 44 | 'Equation to integrate.', allow_override=True) 45 | flags.DEFINE_string( 46 | 'equation_kwargs', '{"num_points": 400}', 47 | 'Parameters to pass to the equation constructor.', allow_override=True) 48 | flags.DEFINE_integer( 49 | 'num_samples', 10, 50 | 'Number of times to integrate each equation.', allow_override=True) 51 | 52 | # integrate parameters 53 | flags.DEFINE_float( 54 | 'time_max', 10, 55 | 'Total time for which to run each integration.', 56 | allow_override=True) 57 | flags.DEFINE_multi_integer( 58 | 'accuracy_orders', [1, 3], 59 | 'Accuracy order for which to calculate results', 60 | allow_override=True) 61 | flags.DEFINE_float( 62 | 'time_delta', 1, 63 | 'Difference between saved time steps in the integration.', 64 | allow_override=True) 65 | flags.DEFINE_float( 66 | 'warmup', 0, 67 | 'Amount of time to integrate before using the neural network.', 68 | allow_override=True) 69 | flags.DEFINE_string( 70 | 'integrate_method', 'RK23', 71 | 'Method to use for integration with scipy.integrate.solve_ivp.', 72 | allow_override=True) 73 | flags.DEFINE_float( 74 | 'exact_filter_interval', 0, 75 | 'Interval between periodic filtering. Only used for spectral methods.', 76 | allow_override=True) 77 | 78 | 79 | FLAGS = flags.FLAGS 80 | 81 | 82 | def main(_, runner=None): 83 | if runner is None: 84 | # must create before flags are used 85 | runner = beam.runners.DirectRunner() 86 | 87 | equation_kwargs = json.loads(FLAGS.equation_kwargs) 88 | accuracy_orders = FLAGS.accuracy_orders 89 | 90 | if (equations.EQUATION_TYPES[FLAGS.equation_name].EXACT_METHOD 91 | is equations.ExactMethod.SPECTRAL and FLAGS.exact_filter_interval): 92 | exact_filter_interval = float(FLAGS.exact_filter_interval) 93 | else: 94 | exact_filter_interval = None 95 | 96 | def create_equation(seed, name=FLAGS.equation_name, kwargs=equation_kwargs): 97 | equation_type = equations.CONSERVATIVE_EQUATION_TYPES[name] 98 | return equation_type(random_seed=seed, **kwargs) 99 | 100 | def integrate_baseline( 101 | equation, accuracy_order, 102 | times=np.arange(0, FLAGS.time_max + FLAGS.time_delta, FLAGS.time_delta), 103 | warmup=FLAGS.warmup, 104 | integrate_method=FLAGS.integrate_method, 105 | exact_filter_interval=exact_filter_interval): 106 | return integrate.integrate_baseline( 107 | equation, times, warmup, accuracy_order, integrate_method, 108 | exact_filter_interval).astype(np.float32) 109 | 110 | def create_equation_and_integrate(seed_and_accuracy_order): 111 | seed, accuracy_order = seed_and_accuracy_order 112 | equation = create_equation(seed) 113 | assert equation.CONSERVATIVE 114 | result = integrate_baseline(equation, accuracy_order) 115 | result.coords['sample'] = seed 116 | result.coords['accuracy_order'] = accuracy_order 117 | return (seed, result) 118 | 119 | pipeline = ( 120 | beam.Create(list(range(FLAGS.num_samples))) 121 | | beam.FlatMap( 122 | lambda seed: [(seed, accuracy) for accuracy in accuracy_orders]) 123 | | beam.Map(create_equation_and_integrate) 124 | | beam.CombinePerKey(xarray_beam.ConcatCombineFn('accuracy_order')) 125 | | beam.Map(lambda seed_and_ds: seed_and_ds[1].sortby('accuracy_order')) 126 | | beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample')) 127 | | beam.Map(lambda ds: ds.sortby('sample')) 
128 | | beam.Map(xarray_beam.write_netcdf, path=FLAGS.output_path)) 129 | 130 | runner.run(pipeline) 131 | 132 | 133 | if __name__ == '__main__': 134 | app.run(main) 135 | -------------------------------------------------------------------------------- /pde_superresolution/scripts/create_training_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Run a beam pipeline to generate training data.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import functools 21 | import json 22 | import os.path 23 | 24 | from absl import app 25 | from absl import flags 26 | import apache_beam as beam 27 | import numpy as np 28 | from pde_superresolution import equations 29 | from pde_superresolution import integrate 30 | from pde_superresolution import utils 31 | import tensorflow as tf 32 | import xarray 33 | 34 | 35 | # NOTE(shoyer): allow_override=True lets us import multiple binaries for the 36 | # purpose of running integration tests. This is safe since we're strict about 37 | # only using FLAGS inside main(). 38 | 39 | # files 40 | flags.DEFINE_string( 41 | 'output_path', '', 42 | 'Full path to which to save the resulting HDF5 file.', 43 | allow_override=True) 44 | 45 | # equation parameters 46 | flags.DEFINE_enum( 47 | 'equation_name', 'burgers', list(equations.CONSERVATIVE_EQUATION_TYPES), 48 | 'Equation to integrate.', allow_override=True) 49 | flags.DEFINE_string( 50 | 'equation_kwargs', '{"num_points": 400}', 51 | 'Parameters to pass to the equation constructor.', allow_override=True) 52 | flags.DEFINE_integer( 53 | 'num_tasks', 10, 54 | 'Number of times to integrate each equation.', 55 | allow_override=True) 56 | flags.DEFINE_integer( 57 | 'seed_offset', 1000000, 58 | 'Integer seed offset for random number generator. This should be larger ' 59 | 'than the largest possible number of evaluation seeds, but smaller ' 60 | 'than 2^32 (the size of NumPy\'s random number seed).', 61 | allow_override=True) 62 | 63 | # integrate parameters 64 | flags.DEFINE_float( 65 | 'time_max', 10, 66 | 'Total time for which to run each integration.', 67 | allow_override=True) 68 | flags.DEFINE_float( 69 | 'time_delta', 1, 70 | 'Difference between saved time steps in the integration.', 71 | allow_override=True) 72 | flags.DEFINE_float( 73 | 'warmup', 0, 74 | 'Amount of time to integrate before saving snapshots.', 75 | allow_override=True) 76 | flags.DEFINE_string( 77 | 'integrate_method', 'RK23', 78 | 'Method to use for integration with scipy.integrate.solve_ivp.', 79 | allow_override=True) 80 | flags.DEFINE_float( 81 | 'exact_filter_interval', 0, 82 | 'Interval between periodic filtering. 
Only used for spectral methods.', 83 | allow_override=True) 84 | 85 | 86 | FLAGS = flags.FLAGS 87 | 88 | 89 | def main(_, runner=None): 90 | if runner is None: 91 | # must create before flags are used 92 | runner = beam.runners.DirectRunner() 93 | 94 | equation_kwargs = json.loads(FLAGS.equation_kwargs) 95 | 96 | def create_equation(seed, name=FLAGS.equation_name, kwargs=equation_kwargs): 97 | equation_type = equations.EQUATION_TYPES[name] 98 | return equation_type(random_seed=seed, **kwargs) 99 | 100 | if (equations.EQUATION_TYPES[FLAGS.equation_name].EXACT_METHOD 101 | is equations.ExactMethod.SPECTRAL and FLAGS.exact_filter_interval): 102 | filter_interval = FLAGS.exact_filter_interval 103 | else: 104 | filter_interval = None 105 | 106 | integrate_exact = functools.partial( 107 | integrate.integrate_exact, 108 | times=np.arange(0, FLAGS.time_max, FLAGS.time_delta), 109 | warmup=FLAGS.warmup, 110 | integrate_method=FLAGS.integrate_method, 111 | filter_interval=filter_interval) 112 | 113 | expected_samples_per_task = int(round(FLAGS.time_max / FLAGS.time_delta)) 114 | expected_total_samples = expected_samples_per_task * FLAGS.num_tasks 115 | 116 | def save(list_of_datasets, path=FLAGS.output_path, attrs=equation_kwargs): 117 | assert len(list_of_datasets) == len(seeds), len(list_of_datasets) 118 | combined = xarray.concat(list_of_datasets, dim='time') 119 | num_samples = combined.sizes['time'] 120 | assert num_samples == expected_total_samples, num_samples 121 | tf.gfile.MakeDirs(os.path.dirname(path)) 122 | with utils.write_h5py(path) as f: 123 | f.create_dataset('v', data=combined['y'].values) 124 | f.attrs.update(attrs) 125 | 126 | # introduce an offset so there's no overlap with the evaluation dataset 127 | seeds = [i + FLAGS.seed_offset for i in range(FLAGS.num_tasks)] 128 | 129 | pipeline = ( 130 | beam.Create(seeds) 131 | | beam.Map(create_equation) 132 | | beam.Map(integrate_exact) 133 | | beam.combiners.ToList() 134 | | beam.Map(save) 135 | ) 136 | runner.run(pipeline) 137 | 138 | 139 | if __name__ == '__main__': 140 | app.run(main) 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning data-driven discretizations for partial differential equations 2 | 3 | Code associated with the paper: 4 | 5 | [Learning data-driven discretizations for partial differential equations](https://www.pnas.org/content/116/31/15344). 6 | Yohai Bar-Sinai, Stephan Hoyer, Jason Hickey, Michael P. Brenner. 7 | Proceedings of the National Academy of Sciences Jul 2019, 116 (31) 15344-15349; DOI: 10.1073/pnas.1814058116. 8 | 9 | 10 | ## Deprecation 11 | 12 | This code for Data Driven Discretization was developed for and used in [https://arxiv.org/abs/1808.04930]. The code is fully functional, but is no longer maintained. It was deprecated by a new implementation that can natively handle higher dimensions and is better designed to be generalized. The new code is available [here](https://github.com/google-research/data-driven-pdes). If you want to implement our method on your favorite equation, please contact the authors. 13 | 14 | ## Running the code 15 | 16 | ### Local installation 17 | 18 | If desired, you can install the code locally. You can also run using Google's hosted Colab notebook service (see below for examples). 
19 | 20 | Clone this repository and install in-place: 21 | 22 | git clone https://github.com/google/data-driven-discretization-1d.git 23 | pip install -e data-driven-discretization-1d 24 | 25 | Note that Python 3 is required. Dependencies for the core library (including 26 | TensorFlow) are specified in setup.py and should be installed automatically as 27 | required. Also note that TensorFlow 1.x is required: this code has not been 28 | updated to use TensorFlow 2.0. 29 | 30 | From the source directory, execute each test file: 31 | 32 | cd data-driven-discretization-1d 33 | python ./pde_superresolution/integrate_test.py 34 | python ./pde_superresolution/training_test.py 35 | 36 | ### Training your own models 37 | 38 | We used the scripts in the `pde_superresolution/scripts` directory to run 39 | training. In particular, see `run_training.py`. 40 | 41 | Training data was created with `create_training_data.py`, but can also be 42 | downloaded from Google Cloud Storage: 43 | 44 | - https://storage.googleapis.com/data-driven-discretization-public/training-data/burgers.h5 45 | - https://storage.googleapis.com/data-driven-discretization-public/training-data/kdv.h5 46 | - https://storage.googleapis.com/data-driven-discretization-public/training-data/ks.h5 47 | 48 | We have two notebooks showing how to train and run parts of our model. As written, these notebooks are intended to run in Google Colab, which you can do by clicking the links below: 49 | - [Super resolution of Burgers' equation](https://colab.research.google.com/github/google/data-driven-discretization-1d/blob/master/notebooks/burgers-super-resolution.ipynb) 50 | - [Time integration of Burgers' equation](https://colab.research.google.com/github/google/data-driven-discretization-1d/blob/master/notebooks/time-integration.ipynb) 51 | 52 | These notebooks install the code from scratch; skip those cells if running things locally. You will also need [gsutil](https://cloud.google.com/storage/docs/gsutil) installed to download data from Google Cloud Storage. 53 | 54 | ## Citation 55 | 56 | ``` 57 | @article {Bar-Sinai15344, 58 | author = {Bar-Sinai, Yohai and Hoyer, Stephan and Hickey, Jason and Brenner, Michael P.}, 59 | title = {Learning data-driven discretizations for partial differential equations}, 60 | volume = {116}, 61 | number = {31}, 62 | pages = {15344--15349}, 63 | year = {2019}, 64 | doi = {10.1073/pnas.1814058116}, 65 | publisher = {National Academy of Sciences}, 66 | abstract = {In many physical systems, the governing equations are known with high confidence, but direct numerical solution is prohibitively expensive. Often this situation is alleviated by writing effective equations to approximate dynamics below the grid scale. This process is often impossible to perform analytically and is often ad hoc. Here we propose data-driven discretization, a method that uses machine learning to systematically derive discretizations for continuous physical systems. On a series of model problems, data-driven discretization gives accurate solutions with a dramatic drop in required resolution.The numerical solution of partial differential equations (PDEs) is challenging because of the need to resolve spatiotemporal features over wide length- and timescales. Often, it is computationally intractable to resolve the finest features in the solution. The only recourse is to use approximate coarse-grained representations, which aim to accurately represent long-wavelength dynamics while properly accounting for unresolved small-scale physics.
Deriving such coarse-grained equations is notoriously difficult and often ad hoc. Here we introduce data-driven discretization, a method for learning optimized approximations to PDEs based on actual solutions to the known underlying equations. Our approach uses neural networks to estimate spatial derivatives, which are optimized end to end to best satisfy the equations on a low-resolution grid. The resulting numerical methods are remarkably accurate, allowing us to integrate in time a collection of nonlinear equations in 1 spatial dimension at resolutions 4{\texttimes} to 8{\texttimes} coarser than is possible with standard finite-difference methods.}, 67 | issn = {0027-8424}, 68 | URL = {https://www.pnas.org/content/116/31/15344}, 69 | eprint = {https://www.pnas.org/content/116/31/15344.full.pdf}, 70 | journal = {Proceedings of the National Academy of Sciences} 71 | } 72 | ``` 73 | -------------------------------------------------------------------------------- /pde_superresolution/xarray_beam.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """"Utilities for using xarray with beam.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import apache_beam as beam 21 | import tensorflow as tf 22 | from typing import Iterator, List 23 | import xarray 24 | 25 | 26 | def read_netcdf(path: str) -> xarray.Dataset: 27 | """Read a netCDF file from a path into memory.""" 28 | with tf.gfile.GFile(path, mode='rb') as f: 29 | return xarray.open_dataset(f.read()).load() 30 | 31 | 32 | def write_netcdf(ds: xarray.Dataset, path: str) -> None: 33 | """Write an xarray.Datset to the given path.""" 34 | with tf.gfile.GFile(path, 'w') as f: 35 | f.write(ds.to_netcdf()) 36 | 37 | 38 | def _swap_dims_no_coordinate( 39 | ds: xarray.Dataset, old_dim: str, new_dim: str) -> xarray.Dataset: 40 | """Like xarray.Dataset.swap_dims(), but works even for non-coordinates. 41 | 42 | See https://github.com/pydata/xarray/issues/1855 for the upstream bug. 43 | 44 | Args: 45 | ds: old dataset. 46 | old_dim: name of existing dimension name. 47 | new_dim: name of new dimension name. 48 | 49 | Returns: 50 | Dataset with swapped dimensions. 51 | """ 52 | fix_dims = lambda dims: tuple(new_dim if d == old_dim else d for d in dims) 53 | return xarray.Dataset( 54 | {k: (fix_dims(v.dims), v.data, v.attrs) for k, v in ds.data_vars.items()}, 55 | {k: (fix_dims(v.dims), v.data, v.attrs) for k, v in ds.coords.items()}, 56 | ds.attrs) 57 | 58 | 59 | def stack(ds: xarray.Dataset, 60 | dim: str, 61 | levels: List[str]) -> xarray.Dataset: 62 | """Stack multiple dimensions along a new dimension. 63 | 64 | Unlike xarray's built-in stack: 65 | 1. This works for a single level. 66 | 2. 
Levels are turned into new coordinates, not levels in a MultiIndex. 67 | 68 | Args: 69 | ds: input dataset. 70 | dim: name of the new stacked dimension. Should not be found on the input 71 | dataset. 72 | levels: list of names of dimensions on the input dataset. Variables along 73 | these dimensions will be stacked together along the new dimension `dim`. 74 | 75 | Returns: 76 | Dataset with stacked data. 77 | 78 | """ 79 | if len(levels) == 1: 80 | # xarray's stack doesn't work properly with one level 81 | level = levels[0] 82 | return _swap_dims_no_coordinate(ds, level, dim) 83 | 84 | return ds.stack(**{dim: levels}).reset_index(dim) 85 | 86 | 87 | def unstack(ds: xarray.Dataset, 88 | dim: str, 89 | levels: List[str]) -> xarray.Dataset: 90 | """Unstack a dimension into multiple dimensions. 91 | 92 | Unlike xarray's built-in stack: 93 | 1. This works for a single level. 94 | 2. It does not expect levels to exist in a MultiIndex, but rather as 1D 95 | coordinates. 96 | 97 | Args: 98 | ds: input dataset. 99 | dim: name of an existing dimension on the input. 100 | levels: list of names of 1D variables along the dimension `dim` in the 101 | input dataset. Each of these will be a dimension on the output. 102 | 103 | Returns: 104 | Dataset with unstacked data, with each level turned into a new dimension. 105 | """ 106 | if len(levels) == 1: 107 | # xarray's unstack doesn't work properly with one level 108 | level = levels[0] 109 | return _swap_dims_no_coordinate(ds, dim, level) 110 | 111 | return ds.set_index(**{dim: levels}).unstack(dim) 112 | 113 | 114 | class SplitDoFn(beam.DoFn): 115 | """DoFn that splits an xarray Dataset across a dimension.""" 116 | 117 | def __init__(self, dim: str, keep_dims: bool = False): 118 | self.dim = dim 119 | self.keep_dims = keep_dims 120 | 121 | def process(self, element: xarray.Dataset) -> Iterator[xarray.Dataset]: 122 | for i in range(element.sizes[self.dim]): 123 | index = slice(i, i + 1) if self.keep_dims else i 124 | yield element[{self.dim: index}].copy() 125 | 126 | 127 | class ConcatCombineFn(beam.CombineFn): 128 | """CombineFn that concatenates across the given dimension.""" 129 | 130 | def __init__(self, dim: str): 131 | self._dim = dim 132 | 133 | def create_accumulator(self): 134 | return [] 135 | 136 | def add_input(self, 137 | accumulator: List[xarray.Dataset], 138 | element: xarray.Dataset) -> List[xarray.Dataset]: 139 | accumulator.append(element) 140 | return accumulator 141 | 142 | def merge_accumulators( 143 | self, accumulators: List[List[xarray.Dataset]]) -> List[xarray.Dataset]: 144 | return [xarray.concat(sum(accumulators, []), dim=self._dim)] 145 | 146 | def extract_output( 147 | self, accumulator: List[xarray.Dataset]) -> xarray.Dataset: 148 | if accumulator: 149 | ds = xarray.concat(accumulator, dim=self._dim) 150 | else: 151 | # NOTE(shoyer): I'm not quite sure why, but Beam needs to be able to run 152 | # this step on a empty accumulator. 153 | ds = xarray.Dataset() 154 | return ds 155 | -------------------------------------------------------------------------------- /pde_superresolution/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Layers for 1D convolutional networks with periodic boundary conditions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from typing import Any, Union 23 | 24 | 25 | def static_or_dynamic_size( 26 | tensor: tf.Tensor, axis: int) -> Union[int, tf.Tensor]: 27 | """Return the size of a tensor dimension, as an integer if possible.""" 28 | try: 29 | static_size = tensor.shape[axis].value 30 | except IndexError: 31 | raise ValueError('axis {} out of bounds for tensor with shape {}' 32 | .format(axis, tensor.shape)) 33 | if static_size is not None: 34 | return static_size 35 | else: 36 | return tf.shape(tensor)[axis] 37 | 38 | 39 | def pad_periodic(inputs: tf.Tensor, 40 | padding: int, 41 | center: bool = False, 42 | name: str = None): 43 | """Pad a 3D tensor with periodic boundary conditions along the second axis. 44 | 45 | Args: 46 | inputs: tensor with shape [batch_size, length, num_features]. 47 | padding: integer amount of padding to add along the length axis. 48 | center: bool indicating whether to center convolutions or not. Useful if you 49 | need to align convolutional layers with different kernels. 50 | name: optional name for this operation. 51 | 52 | Returns: 53 | Padded tensor. 54 | 55 | Raises: 56 | ValueError: if the convolution kernel would span more than once across the 57 | periodic dimension. 
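
  Example (illustrative): for a single batch and feature holding the values
  [0, 1, 2, 3], padding=2 yields [0, 1, 2, 3, 0, 1], while padding=2 with
  center=True yields [3, 0, 1, 2, 3, 0].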
58 | """ 59 | if len(inputs.shape) != 3: 60 | raise ValueError('inputs must be 3D for periodic padding') 61 | 62 | with tf.name_scope(name, 'pad_periodic', [inputs]) as scope: 63 | inputs = tf.convert_to_tensor(inputs, name='inputs') 64 | 65 | if padding == 0: 66 | # allow assuming padding > 0 67 | return tf.identity(inputs, name=scope) 68 | 69 | num_x_points = static_or_dynamic_size(inputs, axis=1) 70 | if center: 71 | repeats = (padding // 2) // num_x_points 72 | else: 73 | repeats = padding // num_x_points 74 | tiled_inputs = tf.tile(inputs, (1, 1 + repeats, 1)) 75 | 76 | if center: 77 | inputs_list = [tiled_inputs[:, -padding//2:, :], 78 | inputs, 79 | tiled_inputs[:, :padding//2, :]] 80 | else: 81 | inputs_list = [inputs, tiled_inputs[:, :padding, :]] 82 | 83 | return tf.concat(inputs_list, axis=1, name=scope) 84 | 85 | 86 | def _check_periodic_layer_shape( 87 | inputs: tf.Tensor, outputs: tf.Tensor, strides: int) -> None: 88 | """Verify that a periodic 1d layer changes length as expected.""" 89 | num_x_points = inputs.shape[1].value 90 | if num_x_points is not None: 91 | expected_in_length = outputs.shape[1].value * strides 92 | assert expected_in_length == num_x_points, (outputs, inputs) 93 | 94 | 95 | def nn_conv1d_periodic(inputs: tf.Tensor, filters: tf.Tensor, stride: int = 1, 96 | center: bool = False, **kwargs: Any) -> tf.Tensor: 97 | """tf.nn.conv1d with periodic boundary conditions.""" 98 | padded_inputs = pad_periodic( 99 | inputs, filters.shape[0].value - 1, center=center) 100 | return tf.nn.conv1d(padded_inputs, filters, stride, padding='VALID', **kwargs) 101 | 102 | 103 | def conv1d_periodic_layer(inputs: tf.Tensor, 104 | filters: int, 105 | kernel_size: int, 106 | strides: int = 1, 107 | dilation_rate: int = 1, 108 | center: bool = False, 109 | **kwargs: Any) -> tf.Tensor: 110 | """1D convolutional layer with periodic boundary conditions. 111 | 112 | Args: 113 | inputs: tensor with shape [batch_size, length, num_features]. 114 | filters: integer filter size, the number of output channels. 115 | kernel_size: integer size of the kernel to apply. 116 | strides: integer specifying the stride length of the convolution. 117 | dilation_rate: integer specifying the dilation rate of the convolution. 118 | center: bool indicating whether to center convolutions or not. Useful if you 119 | need to align convolutional layers with different kernels. If kernel_size 120 | is even, then the result is shifted one half unit size to the left, e.g., 121 | for kernel_size=2, position 1 in the result is computed by convolving over positions 122 | 0 and 1 on inputs. 123 | **kwargs: passed on to tf.layers.conv1d. 124 | 125 | Returns: 126 | Tensor with shape [batch_size, ceil(length / strides), filters]. 127 | """ 128 | with tf.name_scope('conv1d_periodic_layer'): 129 | padding = (kernel_size - 1) * dilation_rate 130 | padded_inputs = pad_periodic(inputs, padding, center) 131 | outputs = tf.layers.conv1d(padded_inputs, filters, kernel_size, 132 | padding='valid', 133 | strides=strides, 134 | dilation_rate=dilation_rate, 135 | **kwargs) 136 | _check_periodic_layer_shape(inputs, outputs, strides) 137 | return outputs 138 | 139 | 140 | def max_pooling1d_periodic(inputs: tf.Tensor, 141 | pool_size: int, 142 | strides: int = 1, 143 | center: bool = False) -> tf.Tensor: 144 | """1D max pooling layer with periodic boundary conditions. 145 | 146 | Args: 147 | inputs: tensor with shape [batch_size, length, num_features]. 148 | pool_size: integer size of the pooling window. 149 | strides: integer specifying the stride length.
150 | center: bool indicating whether to center convolutions or not. Useful if you 151 | need to align convolutional layers with different kernels. If kernel_size 152 | is even, then the result is shifted one half unit size to the left, e.g., 153 | for kernel_size=2, position 1 in the result by convolving over positions 154 | 0 and 1 on inputs. 155 | 156 | Returns: 157 | Tensor with shape [batch_size, ceil(length / strides), filters]. 158 | """ 159 | with tf.name_scope('max_pooling1d_periodic'): 160 | padded_inputs = pad_periodic(inputs, pool_size - 1, center) 161 | outputs = tf.layers.max_pooling1d(padded_inputs, pool_size, strides, 162 | padding='valid') 163 | _check_periodic_layer_shape(inputs, outputs, strides) 164 | return outputs 165 | -------------------------------------------------------------------------------- /pde_superresolution/polynomials_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for polynomial finite differences.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import absltest # pylint: disable=g-bad-import-order 22 | from absl.testing import parameterized 23 | import numpy as np 24 | 25 | from pde_superresolution import polynomials 26 | 27 | 28 | FINITE_DIFF = polynomials.Method.FINITE_DIFFERENCES 29 | FINITE_VOL = polynomials.Method.FINITE_VOLUMES 30 | 31 | 32 | class PolynomialsTest(parameterized.TestCase): 33 | 34 | # For test-cases, see 35 | # https://en.wikipedia.org/wiki/Finite_difference_coefficient 36 | @parameterized.parameters( 37 | dict(grid=[-1, 0, 1], derivative_order=1, expected=[-1/2, 0, 1/2]), 38 | dict(grid=[-1, 0, 1], derivative_order=2, expected=[1, -2, 1]), 39 | dict(grid=[-2, -1, 0, 1, 2], derivative_order=2, 40 | expected=[-1/12, 4/3, -5/2, 4/3, -1/12]), 41 | dict(grid=[0, 1], derivative_order=1, expected=[-1, 1]), 42 | dict(grid=[0, 2], derivative_order=1, expected=[-0.5, 0.5]), 43 | dict(grid=[0, 0.5], derivative_order=1, expected=[-2, 2]), 44 | dict(grid=[0, 1, 2, 3, 4], derivative_order=4, 45 | expected=[1, -4, 6, -4, 1]), 46 | ) 47 | def test_finite_difference_coefficients( 48 | self, grid, derivative_order, expected): 49 | result = polynomials.coefficients( 50 | np.array(grid), FINITE_DIFF, derivative_order) 51 | np.testing.assert_allclose(result, expected) 52 | 53 | # based in part on standard WENO coefficients 54 | @parameterized.parameters( 55 | dict(grid=[-0.5, 0.5], derivative_order=0, expected=[1/2, 1/2]), 56 | dict(grid=[-1, 1], derivative_order=0, expected=[1/2, 1/2]), 57 | dict(grid=[-1.5, -0.5], derivative_order=0, expected=[-1/2, 3/2]), 58 | dict(grid=[-0.5, 0.5, 1.5], derivative_order=0, 59 | expected=[1/3, 5/6, -1/6]), 60 | dict(grid=[-0.25, 0.25, 0.75], 
derivative_order=0, 61 | expected=[1/3, 5/6, -1/6]), 62 | dict(grid=[2.5, 1.5, 0.5, -0.5, -1.5], derivative_order=0, 63 | expected=[2/60, -13/60, 47/60, 27/60, -3/60]), 64 | dict(grid=[-0.5, 0.5], derivative_order=1, expected=[-1, 1]), 65 | dict(grid=[-1, 1], derivative_order=1, expected=[-1/2, 1/2]), 66 | dict(grid=[0.5, 1.5, 2.5], derivative_order=1, expected=[-2, 3, -1]), 67 | dict(grid=[-1.5, -0.5, 0.5, 1.5], derivative_order=1, 68 | expected=[1/12, -5/4, 5/4, -1/12]), 69 | dict(grid=[-.75, -0.25, 0.25, 0.75], derivative_order=1, 70 | expected=[1/6, -5/2, 5/2, -1/6]), 71 | ) 72 | def test_finite_volume_coefficients( 73 | self, grid, derivative_order, expected): 74 | result = polynomials.coefficients( 75 | np.array(grid), FINITE_VOL, derivative_order) 76 | np.testing.assert_allclose(result, expected) 77 | 78 | def test_finite_difference_constraints(self): 79 | # first and second order accuracy should be identical constraints 80 | grid = np.array([-1, 0, 1]) 81 | a1, b1 = polynomials.constraints( 82 | grid, FINITE_DIFF, derivative_order=1, accuracy_order=1) 83 | a2, b2 = polynomials.constraints( 84 | grid, FINITE_DIFF, derivative_order=1, accuracy_order=2) 85 | np.testing.assert_allclose(a1, a2) 86 | np.testing.assert_allclose(b1, b2) 87 | 88 | @parameterized.parameters( 89 | dict(grid=[-2, -1, 0, 1, 2], method=FINITE_DIFF, derivative_order=1), 90 | dict(grid=[-2, -1, 0, 1, 2], method=FINITE_DIFF, derivative_order=2), 91 | dict(grid=[-1.5, -0.5, 0.5, 1.5], method=FINITE_DIFF, derivative_order=1), 92 | dict(grid=[-1.5, -0.5, 0.5, 1.5], method=FINITE_VOL, derivative_order=1), 93 | ) 94 | def test_polynomial_accuracy_layer_consistency( 95 | self, grid, method, derivative_order, accuracy_order=2): 96 | args = (np.array(grid), method, derivative_order, accuracy_order) 97 | A, b = polynomials.constraints(*args) # pylint: disable=invalid-name 98 | layer = polynomials.PolynomialAccuracyLayer(*args) 99 | 100 | inputs = np.random.RandomState(0).randn(10, layer.input_size) 101 | outputs = layer.bias + np.einsum('bi,ij->bj', inputs, layer.nullspace) 102 | 103 | residual = np.einsum('ij,bj->bi', A, outputs) - b 104 | np.testing.assert_allclose(residual, 0, atol=1e-7) 105 | 106 | def test_polynomial_accuracy_layer_bias_zero_padding(self): 107 | layer = polynomials.PolynomialAccuracyLayer( 108 | np.array([-1.5, -0.5, 0.5, 1.5]), FINITE_DIFF, derivative_order=0, 109 | bias_zero_padding=(0, 1)) 110 | expected_bias = np.concatenate( 111 | [polynomials.coefficients( 112 | np.array([-1.5, -0.5, 0.5]), FINITE_DIFF, derivative_order=0), 113 | [0.0]]) 114 | np.testing.assert_allclose(layer.bias, expected_bias) 115 | 116 | @parameterized.parameters( 117 | dict(derivative_order=0, 118 | grid_offset=polynomials.GridOffset.CENTERED, 119 | expected_grid=[0]), 120 | dict(derivative_order=1, 121 | grid_offset=polynomials.GridOffset.CENTERED, 122 | expected_grid=[-1, 0, 1]), 123 | dict(derivative_order=2, 124 | grid_offset=polynomials.GridOffset.CENTERED, 125 | expected_grid=[-1, 0, 1]), 126 | dict(derivative_order=3, 127 | grid_offset=polynomials.GridOffset.CENTERED, 128 | expected_grid=[-2, -1, 0, 1, 2]), 129 | dict(derivative_order=4, 130 | grid_offset=polynomials.GridOffset.CENTERED, 131 | expected_grid=[-2, -1, 0, 1, 2]), 132 | dict(derivative_order=0, 133 | grid_offset=polynomials.GridOffset.STAGGERED, 134 | expected_grid=[-0.5, 0.5]), 135 | dict(derivative_order=1, 136 | grid_offset=polynomials.GridOffset.STAGGERED, 137 | expected_grid=[-0.5, 0.5]), 138 | dict(derivative_order=2, 139 |
grid_offset=polynomials.GridOffset.STAGGERED, 140 | expected_grid=[-1.5, -0.5, 0.5, 1.5]), 141 | dict(derivative_order=3, 142 | grid_offset=polynomials.GridOffset.STAGGERED, 143 | expected_grid=[-1.5, -0.5, 0.5, 1.5]), 144 | dict(derivative_order=0, 145 | accuracy_order=6, 146 | grid_offset=polynomials.GridOffset.CENTERED, 147 | expected_grid=[-3, -2, -1, 0, 1, 2, 3]), 148 | dict(derivative_order=0, 149 | accuracy_order=6, 150 | grid_offset=polynomials.GridOffset.STAGGERED, 151 | expected_grid=[-2.5, -1.5, -0.5, 0.5, 1.5, 2.5]), 152 | ) 153 | def test_regular_grid( 154 | self, grid_offset, derivative_order, expected_grid, accuracy_order=1): 155 | actual_grid = polynomials.regular_grid( 156 | grid_offset, derivative_order, accuracy_order) 157 | np.testing.assert_allclose(actual_grid, expected_grid) 158 | 159 | if __name__ == '__main__': 160 | absltest.main() 161 | -------------------------------------------------------------------------------- /pde_superresolution/duckarray.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Duck array functions that work on NumPy arrays and TensorFlow tensors. 16 | 17 | TODO(shoyer): remove this in favor of a comprehensive solution. 
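
Example (illustrative):

  import numpy as np
  import tensorflow as tf
  from pde_superresolution import duckarray

  duckarray.sum(np.ones((2, 3)), axis=-1)  # NumPy in, NumPy out: [3., 3.]
  duckarray.sum(tf.ones((2, 3)), axis=-1)  # tf.Tensor in, tf.Tensor out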
18 | """ 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy as np 24 | import tensorflow as tf 25 | from typing import List, Sequence, Optional, Tuple, TypeVar, Union 26 | 27 | 28 | # TODO(shoyer): replace with TypeVar('T', np.ndarray, tf.Tensor) when pytype 29 | # supports it (b/74212131) 30 | T = TypeVar('T') 31 | 32 | 33 | def concatenate(arrays: List[T], axis: int) -> T: 34 | """Concatenate arrays or tensors.""" 35 | if isinstance(arrays[0], tf.Tensor): 36 | return tf.concat(arrays, axis=axis) 37 | else: 38 | return np.concatenate(arrays, axis=axis) 39 | 40 | 41 | def stack(arrays: List[T], axis: int) -> T: 42 | """Stack arrays or tensors.""" 43 | if isinstance(arrays[0], tf.Tensor): 44 | return tf.stack(arrays, axis=axis) 45 | else: 46 | return np.stack(arrays, axis=axis) 47 | 48 | 49 | def sin(x: T) -> T: 50 | if isinstance(x, tf.Tensor): 51 | return tf.sin(x) 52 | else: 53 | return np.sin(x) 54 | 55 | 56 | def sum(x: T, axis: int = None, **kwargs) -> T: # pylint: disable=redefined-builtin 57 | if isinstance(x, tf.Tensor): 58 | return tf.reduce_sum(x, axis=axis, **kwargs) 59 | else: 60 | return np.sum(x, axis=axis, **kwargs) 61 | 62 | 63 | def mean(x: T, axis: int = None, **kwargs) -> T: 64 | if isinstance(x, tf.Tensor): 65 | return tf.reduce_mean(x, axis=axis, **kwargs) 66 | else: 67 | return np.mean(x, axis=axis, **kwargs) 68 | 69 | 70 | def get_shape(x: Union[tf.Tensor, np.ndarray]) -> Tuple[Optional[int]]: 71 | if isinstance(x, tf.Tensor): 72 | return tuple(x.shape.as_list()) # pytype: disable=attribute-error 73 | else: 74 | return x.shape 75 | 76 | 77 | def reshape(x: T, shape: Sequence[int]) -> T: 78 | if isinstance(x, tf.Tensor): 79 | return tf.reshape(x, shape) 80 | else: 81 | return np.reshape(x, shape) 82 | 83 | 84 | def maximum(x: T, y: T) -> T: 85 | return tf.maximum(x, y) if isinstance(x, tf.Tensor) else np.maximum(x, y) 86 | 87 | 88 | def minimum(x: T, y: T) -> T: 89 | return tf.minimum(x, y) if isinstance(x, tf.Tensor) else np.minimum(x, y) 90 | 91 | 92 | def where(cond: T, x: T, y: T) -> T: 93 | where_ = tf.where if isinstance(cond, tf.Tensor) else np.where 94 | return where_(cond, x, y) 95 | 96 | 97 | def rfft(x: T) -> T: 98 | return tf.spectral.rfft(x) if isinstance(x, tf.Tensor) else np.fft.rfft(x) 99 | 100 | 101 | def irfft(x: T) -> T: 102 | return tf.spectral.irfft(x) if isinstance(x, tf.Tensor) else np.fft.irfft(x) 103 | 104 | 105 | def spectral_derivative(x: T, order: int = 1, period: float = 2*np.pi) -> T: 106 | """Differentiate along the last axis of x using a Fourier transform.""" 107 | length = get_shape(x)[-1] 108 | if length % 2: 109 | raise ValueError('spectral derivative only works for even length data') 110 | c = 2*np.pi*1j / period 111 | k = np.fft.rfftfreq(length, d=1/length) 112 | return irfft((c * k) ** order * rfft(x)) 113 | 114 | 115 | def smoothing_filter(x: T, 116 | alpha: float = -np.log(1e-15), 117 | order: int = 2) -> T: 118 | """Apply a low-pass smoothing filter to remove noise.""" 119 | # Based on: 120 | # Gottlieb and Hesthaven (2001), "Spectral methods for hyperbolic problems" 121 | # https://doi.org/10.1016/S0377-0427(00)00510-0 122 | length = get_shape(x)[-1] 123 | if length % 2: 124 | raise ValueError('smoothing filter only works for even length data') 125 | count = length // 2 126 | eta = np.arange(count+1) / count 127 | sigma = np.exp(-alpha * eta**(2*order)) 128 | return irfft(sigma * rfft(x)) 129 | 130 | 131 | def _normalize_axis(axis: 
int, ndim: int) -> int: 132 | if not -ndim <= axis < ndim: 133 | raise ValueError('invalid axis {} for ndim {}'.format(axis, ndim)) 134 | if axis < 0: 135 | axis += ndim 136 | return axis 137 | 138 | 139 | def resample_mean(inputs: T, factor: int, axis: int = -1) -> T: 140 | """Resample data to a lower-resolution with the mean. 141 | 142 | Args: 143 | inputs: array with dimensions [batch, x, ...]. 144 | factor: integer factor by which to reduce the size of the x-dimension. 145 | axis: integer axis to resample over. 146 | 147 | Returns: 148 | Array with dimensions [batch, x//factor, ...]. 149 | 150 | Raises: 151 | ValueError: if x is not evenly divided by factor. 152 | """ 153 | shape = get_shape(inputs) 154 | axis = _normalize_axis(axis, len(shape)) 155 | if shape[axis] % factor: 156 | raise ValueError('resample factor {} must divide size {}' 157 | .format(factor, shape[axis])) 158 | 159 | new_shape = shape[:axis] + (shape[axis] // factor, factor) + shape[axis+1:] 160 | new_shape = [-1 if size is None else size for size in new_shape] 161 | 162 | reshaped = reshape(inputs, new_shape) 163 | return mean(reshaped, axis=axis+1) 164 | 165 | 166 | def subsample(inputs: T, factor: int, axis: int = -1) -> T: 167 | """Resample data to a lower-resolution by subsampling data-points. 168 | 169 | Args: 170 | inputs: array with dimensions [batch, x, ...]. 171 | factor: integer factor by which to reduce the size of the x-dimension. 172 | axis: integer axis to resample over. 173 | 174 | Returns: 175 | Array with dimensions [batch, x//factor, ...]. 176 | 177 | Raises: 178 | ValueError: if x is not evenly divided by factor. 179 | """ 180 | shape = get_shape(inputs) 181 | axis = _normalize_axis(axis, len(shape)) 182 | if shape[axis] % factor: 183 | raise ValueError('resample factor {} must divide size {}' 184 | .format(factor, shape[axis])) 185 | 186 | indexer = [slice(None)] * len(shape) 187 | indexer[axis] = slice(None, None, factor) 188 | 189 | return inputs[tuple(indexer)] 190 | 191 | 192 | def _roll_once( 193 | tensor: T, 194 | shift: int, 195 | axis: int, 196 | ) -> T: 197 | """Roll along a single dimension like tf.roll().""" 198 | if not shift: 199 | return tensor 200 | axis = _normalize_axis(axis, len(tensor.shape)) 201 | slice_left = (slice(None),) * axis + (slice(-shift, None),) 202 | slice_right = (slice(None),) * axis + (slice(None, -shift),) 203 | return concatenate([tensor[slice_left], tensor[slice_right]], axis=axis) 204 | 205 | 206 | def roll( 207 | tensor: T, 208 | shift: Union[int, Sequence[int]], 209 | axis: Union[int, Sequence[int]], 210 | ) -> T: 211 | """Like tf.roll(), but runs on GPU as a well as CPU.""" 212 | if isinstance(axis, int): 213 | axis = [axis] 214 | if isinstance(shift, int): 215 | shift = [shift] 216 | result = tensor 217 | for axis_element, shift_element in zip(axis, shift): 218 | result = _roll_once(result, shift_element, axis_element) 219 | return result 220 | 221 | 222 | RESAMPLE_FUNCS = { 223 | 'mean': resample_mean, 224 | 'subsample': subsample, 225 | } 226 | -------------------------------------------------------------------------------- /pde_superresolution/integrate_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Sanity tests for training a model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import json 22 | import tempfile 23 | 24 | from absl import flags 25 | from absl.testing import absltest # pylint: disable=g-bad-import-order 26 | from absl.testing import parameterized 27 | import numpy as np 28 | import tensorflow as tf 29 | import xarray 30 | 31 | from pde_superresolution import duckarray # pylint: disable=g-bad-import-order 32 | from pde_superresolution import equations # pylint: disable=g-bad-import-order 33 | from pde_superresolution import integrate # pylint: disable=g-bad-import-order 34 | from pde_superresolution import training # pylint: disable=g-bad-import-order 35 | from pde_superresolution import weno # pylint: disable=g-bad-import-order 36 | 37 | 38 | FLAGS = flags.FLAGS 39 | 40 | NUM_X_POINTS = 256 41 | RANDOM_SEED = 0 42 | 43 | 44 | class IntegrateTest(parameterized.TestCase): 45 | 46 | def setUp(self): 47 | self.checkpoint_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir) 48 | self.model_kwargs = dict(num_layers=1, filter_size=32) 49 | 50 | def train(self, hparams): 51 | # train a model on random noise 52 | with tf.Graph().as_default(): 53 | snapshots = 0.01 * np.random.RandomState(0).randn(100, NUM_X_POINTS) 54 | training.training_loop(snapshots, self.checkpoint_dir, hparams) 55 | 56 | @parameterized.parameters( 57 | dict(equation='burgers'), 58 | dict(equation='kdv'), 59 | dict(equation='ks'), 60 | dict(equation='burgers', conservative=True), 61 | dict(equation='kdv', conservative=True), 62 | dict(equation='ks', conservative=True), 63 | dict(equation='burgers', conservative=True, numerical_flux=True), 64 | dict(equation='kdv', conservative=True, numerical_flux=True), 65 | dict(equation='ks', conservative=True, numerical_flux=True), 66 | dict(equation='burgers', warmup=1), 67 | dict(equation='burgers', warmup=1, conservative=True), 68 | dict(equation='kdv', warmup=1, conservative=True), 69 | dict(equation='kdv', warmup=1, conservative=True, 70 | exact_filter_interval=1), 71 | ) 72 | def test_integrate_exact_baseline_and_model( 73 | self, warmup=0, conservative=False, resample_factor=4, 74 | exact_filter_interval=None, **hparam_values): 75 | hparams = training.create_hparams( 76 | learning_rates=[1e-3], 77 | learning_stops=[20], 78 | eval_interval=10, 79 | equation_kwargs=json.dumps({'num_points': NUM_X_POINTS}), 80 | conservative=conservative, 81 | resample_factor=resample_factor, 82 | **hparam_values) 83 | self.train(hparams) 84 | 85 | results = integrate.integrate_exact_baseline_and_model( 86 | self.checkpoint_dir, 87 | random_seed=RANDOM_SEED, 88 | times=np.linspace(0, 1, num=11), 89 | warmup=warmup, 90 | exact_filter_interval=exact_filter_interval) 91 | 92 | self.assertIsInstance(results, xarray.Dataset) 93 | self.assertEqual(dict(results.dims), 94 | {'time': 11, 95 | 'x_high': NUM_X_POINTS, 96 | 'x_low': NUM_X_POINTS // resample_factor}) 97 | 
self.assertEqual(results['y_exact'].dims, ('time', 'x_high')) 98 | self.assertEqual(results['y_baseline'].dims, ('time', 'x_low')) 99 | self.assertEqual(results['y_model'].dims, ('time', 'x_low')) 100 | 101 | with self.subTest('average should be zero'): 102 | y_exact_mean = results.y_exact.mean('x_high') 103 | xarray.testing.assert_allclose( 104 | y_exact_mean, xarray.zeros_like(y_exact_mean), atol=1e-3) 105 | 106 | with self.subTest('matching initial conditions'): 107 | if conservative: 108 | resample = duckarray.resample_mean 109 | else: 110 | resample = duckarray.subsample 111 | y_exact = resample(results.y_exact.isel(time=0).values, 112 | resample_factor) 113 | np.testing.assert_allclose( 114 | y_exact, results.y_baseline.isel(time=0).values) 115 | np.testing.assert_allclose( 116 | y_exact, results.y_model.isel(time=0).values) 117 | 118 | with self.subTest('matches integrate_baseline'): 119 | equation_type = equations.equation_type_from_hparams(hparams) 120 | assert equation_type.CONSERVATIVE == conservative 121 | equation = equation_type(NUM_X_POINTS//resample_factor, 122 | resample_factor=resample_factor, 123 | random_seed=RANDOM_SEED) 124 | results2 = integrate.integrate_baseline( 125 | equation, times=np.linspace(0, 1, num=11), warmup=warmup) 126 | np.testing.assert_allclose( 127 | results['y_baseline'].data, results2['y'].data, atol=1e-5) 128 | 129 | @parameterized.parameters( 130 | dict(equation=equations.BurgersEquation(200)), 131 | dict(equation=equations.KdVEquation(200)), 132 | dict(equation=equations.KSEquation(200), warmup=50.0), 133 | ) 134 | def test_integrate_exact(self, equation, **kwargs): 135 | results = integrate.integrate_exact( 136 | equation, times=np.linspace(0, 1, num=11), **kwargs) 137 | self.assertIsInstance(results, xarray.Dataset) 138 | self.assertEqual(dict(results.dims), {'time': 11, 'x': 200}) 139 | self.assertEqual(results['y'].dims, ('time', 'x')) 140 | 141 | with self.subTest('average should be zero'): 142 | y_mean = results.y.mean('x') 143 | xarray.testing.assert_allclose( 144 | y_mean, xarray.zeros_like(y_mean), atol=1e-3) 145 | 146 | def test_burgers_exact_weno(self): 147 | equation = equations.BurgersEquation(200) 148 | results_exact = integrate.integrate_exact( 149 | equation, times=np.linspace(0, 1, num=11)) 150 | 151 | equation = equations.GodunovBurgersEquation(200) 152 | results_weno = integrate.integrate_weno( 153 | equation, times=np.linspace(0, 1, num=11)) 154 | np.testing.assert_allclose( 155 | results_exact['y'].data, results_weno['y'].data, atol=1e-10) 156 | 157 | @parameterized.parameters( 158 | dict(equation=equations.KdVEquation(200)), 159 | dict(equation=equations.KSEquation(200)), 160 | ) 161 | def test_spectral_exact(self, equation): 162 | results_exact = integrate.integrate_exact( 163 | equation, times=np.linspace(0, 1, num=11)) 164 | results_spectra = integrate.integrate_spectral( 165 | equation, times=np.linspace(0, 1, num=11)) 166 | np.testing.assert_allclose( 167 | results_exact['y'].data, results_spectra['y'].data, atol=1e-10) 168 | 169 | @parameterized.parameters( 170 | dict(equation=equations.BurgersEquation(200)), 171 | dict(equation=equations.ConservativeBurgersEquation(200)), 172 | dict(equation=equations.KdVEquation(200)), 173 | dict(equation=equations.KSEquation(200), warmup=50.0), 174 | ) 175 | def test_integrate_baseline(self, equation, **kwargs): 176 | results = integrate.integrate_baseline( 177 | equation, times=np.linspace(0, 1, num=11), **kwargs) 178 | self.assertIsInstance(results, xarray.Dataset) 179 | 
self.assertEqual(dict(results.dims), {'time': 11, 'x': 200}) 180 | self.assertEqual(results['y'].dims, ('time', 'x')) 181 | 182 | # average value should remain near 0 183 | y_mean = results.y.mean('x') 184 | xarray.testing.assert_allclose( 185 | y_mean, xarray.zeros_like(y_mean), atol=1e-3) 186 | 187 | @parameterized.parameters( 188 | dict(equation=equations.GodunovBurgersEquation(200)), 189 | dict(equation=equations.GodunovKdVEquation(200), tol=5e-3), 190 | dict(equation=equations.GodunovKSEquation(200)), 191 | ) 192 | def test_integrate_baseline_and_weno_consistency(self, equation, tol=1e-3): 193 | times = np.linspace(0, 1, num=11) 194 | results_baseline = integrate.integrate_baseline(equation, times=times) 195 | results_weno = integrate.integrate_weno(equation, times=times) 196 | xarray.testing.assert_allclose( 197 | results_baseline.drop('num_evals'), results_weno.drop('num_evals'), 198 | rtol=tol, atol=tol) 199 | 200 | 201 | if __name__ == '__main__': 202 | absltest.main() 203 | -------------------------------------------------------------------------------- /pde_superresolution/scripts/run_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Run a beam pipeline to evaluate our PDE models.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import json 21 | import pandas 22 | import os.path 23 | 24 | from absl import app 25 | from absl import flags 26 | import apache_beam as beam 27 | import numpy as np 28 | from pde_superresolution import analysis 29 | from pde_superresolution import duckarray 30 | from pde_superresolution import equations 31 | from pde_superresolution import integrate 32 | from pde_superresolution import training 33 | from pde_superresolution import xarray_beam 34 | import tensorflow as tf 35 | import xarray 36 | 37 | 38 | # NOTE(shoyer): allow_override=True lets us import multiple binaries for the 39 | # purpose of running integration tests. This is safe since we're strict about 40 | # only using FLAGS inside main(). 
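# Example invocation (an illustrative sketch only; the checkpoint directory
# and exact-solution path below are hypothetical placeholders, and the
# checkpoint is assumed to have been produced by run_training.py):
#
#   python pde_superresolution/scripts/run_evaluation.py \
#     --checkpoint_dir=/tmp/burgers_checkpoint \
#     --exact_solution_path=/tmp/burgers_exact.nc \
#     --equation_name=burgers \
#     --num_samples=10 --time_max=10 --time_delta=0.05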
41 | 42 | # files 43 | flags.DEFINE_string( 44 | 'checkpoint_dir', '', 45 | 'Directory from which to load a trained model and save results.', 46 | allow_override=True) 47 | flags.DEFINE_string( 48 | 'exact_solution_path', '', 49 | 'Path from which to load the exact solution for an initial condition.', 50 | allow_override=True) 51 | flags.DEFINE_enum( 52 | 'equation_name', 'burgers', list(equations.CONSERVATIVE_EQUATION_TYPES), 53 | 'Equation to integrate.', allow_override=True) 54 | flags.DEFINE_string( 55 | 'equation_kwargs', '', 56 | 'If provided, use these parameters instead of those on the saved equation.', 57 | allow_override=True) 58 | flags.DEFINE_string( 59 | 'samples_output_name', 'results.nc', 60 | 'Name of the netCDF file in checkpoint_dir to which to save samples.') 61 | flags.DEFINE_string( 62 | 'mae_output_name', 'mae.nc', 63 | 'Name of the netCDF file in checkpoint_dir to which to save MAE results.') 64 | flags.DEFINE_string( 65 | 'survival_output_name', 'survival.nc', 66 | 'Name of the netCDF file in checkpoint_dir to which to save survival ' 67 | 'results.') 68 | flags.DEFINE_string( 69 | 'stop_times', json.dumps([13, 15, 20, 25, 51, 103]), 70 | 'Cut-off times to use when calculating MAE.') 71 | flags.DEFINE_string( 72 | 'quantiles', json.dumps([0.8, 0.9, 0.95]), 73 | 'Quantiles to use for "good enough".') 74 | 75 | # integrate parameters 76 | flags.DEFINE_integer( 77 | 'num_samples', 10, 78 | 'Number of times to integrate each equation.', 79 | allow_override=True) 80 | flags.DEFINE_float( 81 | 'time_max', 10, 82 | 'Total time for which to run each integration.', 83 | allow_override=True) 84 | flags.DEFINE_float( 85 | 'time_delta', 0.05, 86 | 'Difference between saved time steps in the integration.', 87 | allow_override=True) 88 | flags.DEFINE_float( 89 | 'warmup', 0, 90 | 'Amount of time to integrate before using the neural network.', 91 | allow_override=True) 92 | flags.DEFINE_string( 93 | 'integrate_method', 'RK23', 94 | 'Method to use for integration with scipy.integrate.solve_ivp.', 95 | allow_override=True) 96 | flags.DEFINE_float( 97 | 'exact_filter_interval', 0, 98 | 'Interval between periodic filtering. 
Only used for spectral methods.', 99 | allow_override=True) 100 | 101 | 102 | FLAGS = flags.FLAGS 103 | 104 | _METRICS_NAMESPACE = 'finitediff/run_integrate' 105 | 106 | 107 | def get_counter_metric(name): 108 | return beam.metrics.Metrics.counter(_METRICS_NAMESPACE, name) 109 | 110 | 111 | def count_start_finish(func, name=None): 112 | """Run a function with Beam metric counters for each start/finish.""" 113 | if name is None: 114 | name = func.__name__ 115 | 116 | def wrapper(*args, **kwargs): 117 | get_counter_metric('%s_started' % name).inc() 118 | get_counter_metric('%s_in_progress' % name).inc() 119 | results = func(*args, **kwargs) 120 | get_counter_metric('%s_in_progress' % name).dec() 121 | get_counter_metric('%s_finished' % name).inc() 122 | return results 123 | return wrapper 124 | 125 | 126 | def main(_, runner=None): 127 | if runner is None: 128 | # must create before flags are used 129 | runner = beam.runners.DirectRunner() 130 | 131 | hparams = training.load_hparams(FLAGS.checkpoint_dir) 132 | 133 | if FLAGS.equation_kwargs: 134 | hparams.set_hparam('equation_kwargs', FLAGS.equation_kwargs) 135 | 136 | def load_initial_conditions(path=FLAGS.exact_solution_path, 137 | num_samples=FLAGS.num_samples): 138 | ds = xarray_beam.read_netcdf(path) 139 | initial_conditions = duckarray.resample_mean( 140 | ds['y'].isel(time=0).data, hparams.resample_factor) 141 | 142 | if np.isnan(initial_conditions).any(): 143 | raise ValueError('initial conditions cannot have NaNs') 144 | if ds.sizes['sample'] != num_samples: 145 | raise ValueError('invalid number of samples in exact dataset') 146 | 147 | for seed in range(num_samples): 148 | y0 = initial_conditions[seed, :] 149 | assert y0.ndim == 1 150 | yield (seed, y0) 151 | 152 | def run_integrate( 153 | seed_and_initial_condition, 154 | checkpoint_dir=FLAGS.checkpoint_dir, 155 | times=np.arange(0, FLAGS.time_max + FLAGS.time_delta, FLAGS.time_delta), 156 | warmup=FLAGS.warmup, 157 | integrate_method=FLAGS.integrate_method, 158 | ): 159 | random_seed, y0 = seed_and_initial_condition 160 | _, equation_coarse = equations.from_hparams( 161 | hparams, random_seed=random_seed) 162 | checkpoint_path = training.checkpoint_dir_to_path(checkpoint_dir) 163 | differentiator = integrate.SavedModelDifferentiator( 164 | checkpoint_path, equation_coarse, hparams) 165 | solution_model, num_evals_model = integrate.odeint( 166 | y0, differentiator, warmup+times, method=integrate_method) 167 | 168 | results = xarray.Dataset( 169 | data_vars={'y': (('time', 'x'), solution_model)}, 170 | coords={'time': warmup+times, 171 | 'x': equation_coarse.grid.solution_x, 172 | 'num_evals': num_evals_model, 173 | 'sample': random_seed}) 174 | return results 175 | 176 | samples_path = os.path.join(FLAGS.checkpoint_dir, FLAGS.samples_output_name) 177 | mae_path = os.path.join(FLAGS.checkpoint_dir, FLAGS.mae_output_name) 178 | survival_path = os.path.join(FLAGS.checkpoint_dir, FLAGS.survival_output_name) 179 | 180 | def finalize( 181 | ds_model, 182 | exact_path=FLAGS.exact_solution_path, 183 | stop_times=json.loads(FLAGS.stop_times), 184 | quantiles=json.loads(FLAGS.quantiles), 185 | ): 186 | ds_model = ds_model.sortby('sample') 187 | xarray_beam.write_netcdf(ds_model, samples_path) 188 | 189 | # build combined dataset 190 | ds_exact = xarray_beam.read_netcdf(exact_path) 191 | ds = ds_model.rename({'y': 'y_model', 'x': 'x_low'}) 192 | ds['y_exact'] = ds_exact['y'].rename({'x': 'x_high'}) 193 | unified = analysis.unify_x_coords(ds) 194 | 195 | # calculate MAE 196 | results = [] 
197 | for time_max in stop_times: 198 | ds_sel = unified.sel(time=slice(None, time_max)) 199 | mae = abs(ds_sel.drop('y_exact') - ds_sel.y_exact).mean( 200 | ['x', 'time'], skipna=False) 201 | results.append(mae) 202 | dim = pandas.Index(stop_times, name='time_max') 203 | mae_all = xarray.concat(results, dim=dim) 204 | xarray_beam.write_netcdf(mae_all, mae_path) 205 | 206 | # calculate survival 207 | survival_all = xarray.concat( 208 | [analysis.mostly_good_survival(ds, q) for q in quantiles], 209 | dim=pandas.Index(quantiles, name='quantile')) 210 | xarray_beam.write_netcdf(survival_all, survival_path) 211 | 212 | pipeline = ( 213 | 'create' >> beam.Create(range(1)) 214 | | 'load' >> beam.FlatMap(lambda _: load_initial_conditions()) 215 | | 'reshuffle' >> beam.Reshuffle() 216 | | 'integrate' >> beam.Map( 217 | count_start_finish(run_integrate, name='run_integrate')) 218 | | 'combine' >> beam.CombineGlobally(xarray_beam.ConcatCombineFn('sample')) 219 | | 'finalize' >> beam.Map(finalize) 220 | ) 221 | runner.run(pipeline) 222 | 223 | 224 | if __name__ == '__main__': 225 | flags.mark_flag_as_required('checkpoint_dir') 226 | app.run(main) 227 | 228 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /pde_superresolution/polynomials.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Polynomial based models for finite differences and finite volumes.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import enum 22 | 23 | import numpy as np 24 | import scipy.special 25 | import tensorflow as tf 26 | from typing import Tuple 27 | 28 | from pde_superresolution import layers # pylint: disable=g-bad-import-order 29 | 30 | 31 | class GridOffset(enum.Enum): 32 | """Relationship between successive grids.""" 33 | CENTERED = 1 34 | STAGGERED = 2 35 | 36 | 37 | class Method(enum.Enum): 38 | """Discretization method.""" 39 | FINITE_DIFFERENCES = 1 40 | FINITE_VOLUMES = 2 41 | 42 | 43 | def regular_grid( 44 | grid_offset: GridOffset, 45 | derivative_order: int, 46 | accuracy_order: int = 1, 47 | dx: float = 1) -> np.ndarray: 48 | """Return the smallest grid on which finite differences can be calculated. 49 | 50 | Args: 51 | grid_offset: offset between input and output grids. 52 | derivative_order: integer derivative order to calculate. 53 | accuracy_order: integer order of polynomial accuracy to enforce. By default, 54 | only 1st order accuracy is guaranteed. 55 | dx: difference between grid points. 56 | 57 | Returns: 58 | 1D numpy array giving positions at which to calculate finite differences. 59 | """ 60 | min_grid_size = derivative_order + accuracy_order 61 | 62 | if grid_offset is GridOffset.CENTERED: 63 | max_offset = min_grid_size // 2 # 1 -> 0, 2 -> 1, 3 -> 1, 4 -> 2, ... 64 | grid = np.arange(-max_offset, max_offset + 1) * dx 65 | elif grid_offset is GridOffset.STAGGERED: 66 | max_offset = (min_grid_size + 1) // 2 # 1 -> 1, 2 -> 1, 3 -> 2, 4 -> 2, ... 67 | grid = (0.5 + np.arange(-max_offset, max_offset)) * dx 68 | else: 69 | raise ValueError('unexpected grid_offset: {}'.format(grid_offset)) # pylint: disable=g-doc-exception 70 | 71 | return grid 72 | 73 | 74 | def constraints( 75 | grid: np.ndarray, 76 | method: Method, 77 | derivative_order: int, 78 | accuracy_order: int = None) -> Tuple[np.ndarray, np.ndarray]: 79 | """Setup the linear equation A @ c = b for finite difference coefficients. 80 | 81 | Args: 82 | grid: grid on which to calculate the finite difference stencil, relative 83 | to the point at which to approximate the derivative. The grid must be 84 | regular. 85 | method: discretization method. 86 | derivative_order: integer derivative order to approximate. 87 | accuracy_order: minimum accuracy order for the solution. 88 | 89 | Returns: 90 | Tuple of arrays `(A, b)` where `A` is 2D and `b` is 1D providing linear 91 | constraints. Any vector of finite difference coefficients `c` such that 92 | `A @ c = b` satisfies the requested accuracy order. The matrix `A` is 93 | guaranteed not to have more rows than columns. 94 | 95 | Raises: 96 | ValueError: if the linear constraints are not satisfiable. 
97 | 98 | References: 99 | https://en.wikipedia.org/wiki/Finite_difference_coefficient 100 | Fornberg, Bengt (1988), "Generation of Finite Difference Formulas on 101 | Arbitrarily Spaced Grids", Mathematics of Computation, 51 (184): 699-706, 102 | doi:10.1090/S0025-5718-1988-0935077-0, ISSN 0025-5718. 103 | """ 104 | if accuracy_order is None: 105 | # Use the highest order accuracy we can ensure in general. (In some cases, 106 | # e.g., centered finite differences, this solution actually has higher order 107 | # accuracy.) 108 | accuracy_order = grid.size - derivative_order 109 | 110 | if accuracy_order < 1: 111 | raise ValueError('cannot compute constraints with non-positive ' 112 | 'accuracy_order: {}'.format(accuracy_order)) 113 | 114 | deltas = np.unique(np.diff(grid)) 115 | if (abs(deltas - deltas[0]) > 1e-8).any(): 116 | raise ValueError('not a regular grid: {}'.format(deltas)) 117 | delta = deltas[0] 118 | 119 | final_constraint = None 120 | zero_constraints = set() 121 | for m in range(accuracy_order + derivative_order): 122 | if method is Method.FINITE_DIFFERENCES: 123 | constraint = grid ** m 124 | elif method is Method.FINITE_VOLUMES: 125 | constraint = (1 / delta 126 | * ((grid + delta/2) ** (m + 1) 127 | - (grid - delta/2) ** (m + 1)) 128 | / (m + 1)) 129 | else: 130 | raise ValueError('unexpected method: {}'.format(method)) 131 | if m == derivative_order: 132 | final_constraint = constraint 133 | else: 134 | zero_constraints.add(tuple(constraint)) 135 | 136 | assert final_constraint is not None 137 | 138 | num_constraints = len(zero_constraints) + 1 139 | if num_constraints > grid.size: 140 | raise ValueError('no valid {} stencil exists for derivative_order={} and ' 141 | 'accuracy_order={} with grid={}' 142 | .format(method, derivative_order, accuracy_order, grid)) 143 | 144 | A = np.array(sorted(zero_constraints) + [final_constraint]) # pylint: disable=invalid-name 145 | 146 | b = np.zeros(A.shape[0]) 147 | b[-1] = scipy.special.factorial(derivative_order) 148 | 149 | return A, b 150 | 151 | 152 | def coefficients( 153 | grid: np.ndarray, 154 | method: Method, 155 | derivative_order: int) -> np.ndarray: 156 | """Calculate standard finite difference coefficients for the given grid. 157 | 158 | Args: 159 | grid: grid on which to calculate finite difference coefficients. 160 | method: discretization method. 161 | derivative_order: integer derivative order to approximate. 162 | 163 | Returns: 164 | NumPy array giving finite difference coefficients on the grid. 165 | """ 166 | A, b = constraints(grid, method, derivative_order) # pylint: disable=invalid-name 167 | return np.linalg.solve(A, b) 168 | 169 | 170 | def zero_padded_coefficients( 171 | grid: np.ndarray, 172 | method: Method, 173 | derivative_order: int, 174 | padding: Tuple[int, int]) -> np.ndarray: 175 | """Calculate finite difference coefficients, but padded by zeros. 176 | 177 | These coefficients always hold on the given grid, but the result is guaranteed 178 | to have values on the left and right sides with the indicated number of zeros. 179 | 180 | Args: 181 | grid: grid on which to calculate finite difference coefficients, which will 182 | be trimmed based on padding. 183 | method: discretization method. 184 | derivative_order: integer derivative order to approximate. 185 | padding: number of zeros to pad on the left and right sides of the result. 186 | 187 | Returns: 188 | NumPy array giving finite difference coefficients on the grid.
189 | """ 190 | # note: need the "or" to avoid slicing with 0 as a right bound, because 0 191 | # is always interpretted as an offset from the start. 192 | pad_left, pad_right = padding 193 | trimmed_grid = grid[pad_left : (-pad_right or None)] 194 | trimmed_coefficients = coefficients(trimmed_grid, method, derivative_order) 195 | return np.pad(trimmed_coefficients, padding, mode='constant') 196 | 197 | 198 | class PolynomialAccuracyLayer(object): 199 | """Layer to enforce polynomial accuracy for finite difference coefficients. 200 | 201 | Attributes: 202 | input_size: length of input vectors that are transformed into valid finite 203 | difference coefficients. 204 | bias: numpy array of shape (grid_size,) to which zero vectors are mapped. 205 | nullspace: numpy array of shape (input_size, output_size) representing the 206 | nullspace of the constraint matrix. 207 | """ 208 | 209 | def __init__(self, 210 | grid: np.ndarray, 211 | method: Method, 212 | derivative_order: int, 213 | accuracy_order: int = 2, 214 | bias: np.ndarray = None, 215 | bias_zero_padding: Tuple[int, int] = (0, 0), 216 | out_scale: float = 1.0): 217 | """Constructor. 218 | 219 | Args: 220 | grid: grid on which to calculate finite difference coefficients. 221 | method: discretization method. 222 | derivative_order: integer derivative order to approximate. 223 | accuracy_order: integer order of polynomial accuracy to enforce. 224 | bias: np.ndarray of shape (grid_size,) to which zero-vectors will be 225 | mapped. Must satisfy polynomial accuracy to the requested order. By 226 | default, we calculate the standard finite difference coefficients for 227 | the given grid. 228 | bias_zero_padding: if a value for bias is not provided, ensure that the 229 | computed bias has the indicated number of zeros padded on the left and 230 | right sides. This is useful for initializing bias with upwinded 231 | coefficients. 232 | out_scale: desired multiplicative scaling on the outputs, relative to the 233 | bias. 234 | """ 235 | A, b = constraints(grid, method, derivative_order, accuracy_order) # pylint: disable=invalid-name 236 | 237 | if bias is None: 238 | bias = zero_padded_coefficients( 239 | grid, method, derivative_order, bias_zero_padding) 240 | 241 | norm = np.linalg.norm(np.dot(A, bias) - b) 242 | if norm > 1e-8: 243 | raise ValueError('invalid bias, not in nullspace') # pylint: disable=g-doc-exception 244 | 245 | # https://en.wikipedia.org/wiki/Kernel_(linear_algebra)#Nonhomogeneous_systems_of_linear_equations 246 | _, _, v = np.linalg.svd(A) 247 | input_size = A.shape[1] - A.shape[0] 248 | if not input_size: 249 | raise ValueError( # pylint: disable=g-doc-exception 250 | 'there is only one valid solution accurate to this order') 251 | 252 | # nullspace from the SVD is always normalized such that its singular values 253 | # are 1 or 0, which means it's actually independent of the grid spacing. 254 | nullspace = v[-input_size:] 255 | 256 | # ensure the nullspace is scaled comparably to the bias 257 | # TODO(shoyer): fix this for arbitrary spaced grids 258 | dx = grid[1] - grid[0] 259 | scaled_nullspace = nullspace * (out_scale / dx ** derivative_order) 260 | 261 | self.input_size = input_size 262 | self.grid_size = grid.size 263 | self.nullspace = scaled_nullspace 264 | self.bias = bias 265 | 266 | def apply(self, inputs: tf.Tensor) -> tf.Tensor: 267 | """Apply this layer to inputs. 268 | 269 | Args: 270 | inputs: float32 Tensor with dimensions [batch, x, input_size]. 
271 | 272 | Returns: 273 | Float32 Tensor with dimensions [batch, x, grid_size]. 274 | """ 275 | bias = self.bias.astype(np.float32) 276 | nullspace = tf.convert_to_tensor(self.nullspace.astype(np.float32)) 277 | return bias + tf.einsum('bxi,ij->bxj', inputs, nullspace) 278 | 279 | 280 | def reconstruct( 281 | inputs: tf.Tensor, 282 | grid: np.ndarray, 283 | method: Method, 284 | derivative_order: int) -> tf.Tensor: 285 | """Calculate finite difference/volumes using the standard tables. 286 | 287 | Args: 288 | inputs: tf.Tensor with dimensions [batch, x]. 289 | grid: grid on which to calculate finite difference coefficients. 290 | method: discretization method. 291 | derivative_order: integer derivative order to calculate. 292 | 293 | Returns: 294 | tf.Tensor with dimensions [batch, x] with finite difference approximations 295 | to spatial derivatives at each point. 296 | """ 297 | filters = tf.convert_to_tensor( 298 | coefficients(grid, method, derivative_order), 299 | dtype=tf.float32) 300 | convolved = layers.nn_conv1d_periodic( 301 | inputs[..., tf.newaxis], filters[..., tf.newaxis, tf.newaxis], 302 | stride=1, center=True) 303 | return tf.squeeze(convolved, axis=2) 304 | -------------------------------------------------------------------------------- /pde_superresolution/integrate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
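# ---------------------------------------------------------------------------
# [Editorial sketch, not part of polynomials.py or integrate.py] A worked
# example of the constraint system A @ c = b built by polynomials.constraints
# above: on the centered three-point grid, solving the system should recover
# the familiar second-derivative stencil [1, -2, 1]. Only functions shown
# above (regular_grid, coefficients, Method, GridOffset) are assumed.
import numpy as np
from pde_superresolution import polynomials

grid = polynomials.regular_grid(
    polynomials.GridOffset.CENTERED, derivative_order=2, accuracy_order=1,
    dx=1.0)
coefs = polynomials.coefficients(
    grid, polynomials.Method.FINITE_DIFFERENCES, derivative_order=2)
# d2y/dx2 at point i is approximately y[i-1] - 2*y[i] + y[i+1] when dx = 1.
np.testing.assert_allclose(coefs, [1.0, -2.0, 1.0])
# ---------------------------------------------------------------------------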
14 | # ============================================================================== 15 | """Utilities for integrating PDEs with pretrained and baseline models.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import functools 21 | import os 22 | 23 | from absl import logging 24 | import numpy as np 25 | import scipy.fftpack 26 | import scipy.integrate 27 | import tensorflow as tf 28 | from typing import Any, Optional, Tuple 29 | import xarray 30 | from pde_superresolution import duckarray # pylint: disable=g-bad-import-order 31 | from pde_superresolution import equations # pylint: disable=g-bad-import-order 32 | from pde_superresolution import model # pylint: disable=g-bad-import-order 33 | from pde_superresolution import training # pylint: disable=g-bad-import-order 34 | from pde_superresolution import weno # pylint: disable=g-bad-import-order 35 | 36 | 37 | _DEFAULT_TIMES = np.linspace(0, 10, num=201) 38 | 39 | 40 | class Differentiator(object): 41 | """Base class for calculating time derivatives.""" 42 | 43 | def __call__(self, t: float, y: np.ndarray) -> np.ndarray: 44 | """Calculate all desired spatial derivatives.""" 45 | raise NotImplementedError 46 | 47 | 48 | class SavedModelDifferentiator(Differentiator): 49 | """Calculate derivatives from a saved TensorFlow model.""" 50 | 51 | def __init__(self, 52 | checkpoint_dir: str, 53 | equation: equations.Equation, 54 | hparams: tf.contrib.training.HParams): 55 | 56 | with tf.Graph().as_default(): 57 | self.t = tf.placeholder(tf.float32, shape=()) 58 | 59 | num_points = equation.grid.solution_num_points 60 | self.inputs = tf.placeholder(tf.float32, shape=(num_points,)) 61 | 62 | time_derivative = tf.squeeze(model.predict_time_derivative( 63 | self.inputs[tf.newaxis, :], hparams), axis=0) 64 | self.value = equation.finalize_time_derivative(self.t, time_derivative) 65 | 66 | saver = tf.train.Saver() 67 | self.sess = tf.Session() 68 | saver.restore(self.sess, checkpoint_dir) 69 | 70 | def __call__(self, t: float, y: np.ndarray) -> np.ndarray: 71 | return self.sess.run(self.value, feed_dict={self.t: t, self.inputs: y}) 72 | 73 | 74 | class PolynomialDifferentiator(Differentiator): 75 | """Calculate derivatives using standard finite difference coefficients.""" 76 | 77 | def __init__(self, 78 | equation: equations.Equation, 79 | accuracy_order: Optional[int] = 1): 80 | 81 | with tf.Graph().as_default(): 82 | self.t = tf.placeholder(tf.float32, shape=()) 83 | 84 | num_points = equation.grid.solution_num_points 85 | self.inputs = tf.placeholder(tf.float32, shape=(num_points,)) 86 | 87 | batched_inputs = self.inputs[tf.newaxis, :] 88 | space_derivatives = model.baseline_space_derivatives( 89 | batched_inputs, equation, accuracy_order=accuracy_order) 90 | time_derivative = tf.squeeze(model.apply_space_derivatives( 91 | space_derivatives, batched_inputs, equation), axis=0) 92 | self.value = equation.finalize_time_derivative(self.t, time_derivative) 93 | 94 | self._space_derivatives = { 95 | k: tf.squeeze(space_derivatives[..., i], axis=0) 96 | for i, k in enumerate(equation.DERIVATIVE_NAMES) 97 | } 98 | 99 | self.sess = tf.Session() 100 | 101 | def __call__(self, t: float, y: np.ndarray) -> np.ndarray: 102 | return self.sess.run(self.value, feed_dict={self.t: t, self.inputs: y}) 103 | 104 | def calculate_space_derivatives(self, y): 105 | return self.sess.run(self._space_derivatives, feed_dict={self.inputs: y}) 106 | 107 | 108 | class 
SpectralDifferentiator(Differentiator): 109 | """Calculate derivatives using a spectral method.""" 110 | 111 | def __init__(self, equation: equations.Equation): 112 | self.equation = equation 113 | 114 | def __call__(self, t: float, y: np.ndarray) -> np.ndarray: 115 | period = self.equation.grid.period 116 | names_and_orders = zip(self.equation.DERIVATIVE_NAMES, 117 | self.equation.DERIVATIVE_ORDERS) 118 | space_derivatives = {name: scipy.fftpack.diff(y, order, period) 119 | for name, order in names_and_orders} 120 | time_derivative = self.equation.equation_of_motion(y, space_derivatives) 121 | return self.equation.finalize_time_derivative(t, time_derivative) 122 | 123 | 124 | class WENODifferentiator(Differentiator): 125 | """Calculate derivatives using a 5th order WENO method.""" 126 | 127 | def __init__(self, 128 | equation: equations.Equation, 129 | non_weno_accuracy_order: int = 3): 130 | self.equation = equation 131 | self.poly_diff = PolynomialDifferentiator(equation, non_weno_accuracy_order) 132 | 133 | def __call__(self, t: float, y: np.ndarray) -> np.ndarray: 134 | space_derivatives = self.poly_diff.calculate_space_derivatives(y) 135 | # replace u^- and u^+ with WENO reconstructions 136 | assert 'u_minus' in space_derivatives and 'u_plus' in space_derivatives 137 | space_derivatives['u_minus'] = np.roll(weno.reconstruct_left(y), 1) 138 | space_derivatives['u_plus'] = np.roll(weno.reconstruct_right(y), 1) 139 | time_derivative = self.equation.equation_of_motion(y, space_derivatives) 140 | return self.equation.finalize_time_derivative(t, time_derivative) 141 | 142 | 143 | def odeint(y0: np.ndarray, 144 | differentiator: Differentiator, 145 | times: np.ndarray, 146 | method: str = 'RK23') -> Tuple[np.ndarray, int]: 147 | """Integrate an ODE.""" 148 | logging.info('solve_ivp from %s to %s', times[0], times[-1]) 149 | 150 | # Most of our equations are somewhat stiff, so lower order Runge-Kutta is a 151 | # sane default. For whatever reason, the stiff solvers are much slower when 152 | # using TensorFlow to compute derivatives (even the baseline model) than 153 | # when using NumPy. 154 | sol = scipy.integrate.solve_ivp(differentiator, (times[0], times[-1]), y0, 155 | t_eval=times, max_step=0.01, method=method) 156 | y = sol.y.T # (time, x) 157 | 158 | logging.info('nfev: %r, njev: %r, nlu: %r', sol.nfev, sol.njev, sol.nlu) 159 | logging.info('status: %r, message: %s', sol.status, sol.message) 160 | 161 | # if integration diverges, pad result with NaN 162 | logging.info('output has length %s', y.shape[0]) 163 | num_missing = len(times) - y.shape[0] 164 | if num_missing: 165 | logging.info('padding with %s values', num_missing) 166 | pad_width = ((0, num_missing), (0, 0)) 167 | y = np.pad(y, pad_width, mode='constant', constant_values=np.nan) 168 | 169 | return y, sol.nfev 170 | 171 | 172 | def odeint_with_periodic_filtering( 173 | y0: np.ndarray, 174 | differentiator: Differentiator, 175 | times: np.ndarray, 176 | filter_interval: float, 177 | filter_order: int, 178 | method: str = 'RK23'): 179 | """Integrate with periodic filtering.""" 180 | 181 | # Spectral methods for hyperbolic problems can suffer from aliasing artifacts, 182 | # which can be alleviated by applying a low-pass (smoothing) filter. See 183 | # Sections 4.2 and 5 of: 184 | # Hesthaven, J. S. 2016. "Spectral Methods for Hyperbolic Problems." In 185 | # Handbook of Numerical Analysis, edited by Remi Abgrall and Chi-Wang Shu, 186 | # 17:441-66. Elsevier.
187 | # https://infoscience.epfl.ch/record/221484/files/SpecHandBook.pdf 188 | 189 | eps = 1e-8 190 | split_times = np.arange(times[0], times[-1] + eps, filter_interval) 191 | if not np.isin(split_times, times).all(): 192 | raise ValueError('all times in filter_interval must be sampled') 193 | split_indexes = np.searchsorted(times, split_times, side='right') 194 | 195 | y_list = [y0[np.newaxis, ...]] 196 | 197 | num_evals = 0 198 | for start_index, end_index in zip(split_indexes[:-1], split_indexes[1:]): 199 | cur_times = times[start_index-1:end_index] 200 | y, cur_num_evals = odeint(y0, differentiator, cur_times, method=method) 201 | y_list.append(y[1:]) # exclude y0 202 | y0 = duckarray.smoothing_filter(y[-1], order=filter_order) 203 | num_evals += cur_num_evals 204 | 205 | y = np.concatenate(y_list, axis=0) 206 | assert y.shape == (times.size, y0.size) 207 | 208 | # apply the filter again for post-processing 209 | # note: applying the filter at each time step during integration adds noise 210 | y = duckarray.smoothing_filter(y, order=filter_order) 211 | 212 | return y, num_evals 213 | 214 | 215 | def exact_differentiator( 216 | equation: equations.Equation) -> Differentiator: 217 | """Return an "exact" differentiator for the given equation. 218 | 219 | Args: 220 | equation: equation for which to produce an "exact" differentiator. 221 | 222 | Returns: 223 | Differentiator to use for "exact" integration. 224 | """ 225 | if type(equation.to_exact()) is not type(equation): 226 | raise TypeError('an exact equation must be provided') 227 | if equation.EXACT_METHOD is equations.ExactMethod.POLYNOMIAL: 228 | differentiator = PolynomialDifferentiator(equation, accuracy_order=None) 229 | elif equation.EXACT_METHOD is equations.ExactMethod.SPECTRAL: 230 | differentiator = SpectralDifferentiator(equation) 231 | elif equation.EXACT_METHOD is equations.ExactMethod.WENO: 232 | differentiator = WENODifferentiator(equation) 233 | else: 234 | raise TypeError('unexpected equation: {}'.format(equation)) 235 | return differentiator 236 | 237 | 238 | def integrate( 239 | equation: equations.Equation, 240 | differentiator: Differentiator, 241 | times: np.ndarray = _DEFAULT_TIMES, 242 | warmup: float = 0, 243 | integrate_method: str = 'RK23', 244 | filter_interval: float = None, 245 | filter_all_times: bool = False) -> xarray.Dataset: 246 | """Integrate an equation with possible warmup or periodic filtering.""" 247 | 248 | if filter_interval is not None: 249 | warmup_odeint = functools.partial( 250 | odeint_with_periodic_filtering, 251 | filter_interval=filter_interval, 252 | filter_order=max(equation.to_exact().DERIVATIVE_ORDERS)) 253 | else: 254 | warmup_odeint = odeint 255 | 256 | if warmup: 257 | equation_exact = equation.to_exact() 258 | diff_exact = exact_differentiator(equation_exact) 259 | if filter_interval is not None: 260 | warmup_times = np.arange(0, warmup + 1e-8, filter_interval) 261 | else: 262 | warmup_times = np.array([0, warmup]) 263 | y0_0 = equation_exact.initial_value() 264 | solution_warmup, _ = warmup_odeint( 265 | y0_0, diff_exact, times=warmup_times, method=integrate_method) 266 | # use the sample after warmup to initialize later simulations 267 | y0 = equation.grid.resample(solution_warmup[-1, :]) 268 | else: 269 | y0 = equation.initial_value() 270 | 271 | odeint_func = warmup_odeint if filter_all_times else odeint 272 | solution, num_evals = odeint_func( 273 | y0, differentiator, times=warmup+times, method=integrate_method) 274 | 275 | results = xarray.Dataset( 276 | data_vars={'y': 
(('time', 'x'), solution)}, 277 | coords={'time': warmup+times, 'x': equation.grid.solution_x, 278 | 'num_evals': num_evals}) 279 | return results 280 | 281 | 282 | def integrate_exact( 283 | equation: equations.Equation, 284 | times: np.ndarray = _DEFAULT_TIMES, 285 | warmup: float = 0, 286 | integrate_method: str = 'RK23', 287 | filter_interval: float = None) -> xarray.Dataset: 288 | """Integrate only the exact model.""" 289 | equation = equation.to_exact() 290 | differentiator = exact_differentiator(equation) 291 | return integrate(equation, differentiator, times, warmup, 292 | integrate_method=integrate_method, 293 | filter_interval=filter_interval) 294 | 295 | 296 | def integrate_baseline( 297 | equation: equations.Equation, 298 | times: np.ndarray = _DEFAULT_TIMES, 299 | warmup: float = 0, 300 | accuracy_order: int = 1, 301 | integrate_method: str = 'RK23', 302 | exact_filter_interval: float = None) -> xarray.Dataset: 303 | """Integrate a baseline finite difference model.""" 304 | differentiator = PolynomialDifferentiator( 305 | equation, accuracy_order=accuracy_order) 306 | return integrate(equation, differentiator, times, warmup, 307 | integrate_method=integrate_method, 308 | filter_interval=exact_filter_interval) 309 | 310 | 311 | def integrate_weno( 312 | equation: equations.Equation, 313 | times: np.ndarray = _DEFAULT_TIMES, 314 | warmup: float = 0, 315 | integrate_method: str = 'RK23', 316 | exact_filter_interval: float = None, 317 | **kwargs: Any) -> xarray.Dataset: 318 | """Integrate a baseline WENO model.""" 319 | if type(equation) not in equations.FLUX_EQUATION_TYPES.values(): 320 | raise ValueError('invalid equation: {}'.format(equation)) 321 | differentiator = WENODifferentiator(equation, **kwargs) 322 | return integrate(equation, differentiator, times, warmup, 323 | integrate_method=integrate_method, 324 | filter_interval=exact_filter_interval) 325 | 326 | 327 | def integrate_spectral( 328 | equation: equations.Equation, 329 | times: np.ndarray = _DEFAULT_TIMES, 330 | warmup: float = 0, 331 | integrate_method: str = 'RK23', 332 | exact_filter_interval: float = None) -> xarray.Dataset: 333 | """Integrate a baseline spectral-method model.""" 334 | if type(equation) not in equations.EQUATION_TYPES.values(): 335 | raise ValueError('invalid equation: {}'.format(equation)) 336 | differentiator = SpectralDifferentiator(equation) 337 | return integrate(equation, differentiator, times, warmup, 338 | integrate_method=integrate_method, 339 | filter_interval=exact_filter_interval) 340 | 341 | 342 | def integrate_exact_baseline_and_model( 343 | checkpoint_dir: str, 344 | hparams: tf.contrib.training.HParams = None, 345 | random_seed: int = 0, 346 | times: np.ndarray = _DEFAULT_TIMES, 347 | warmup: float = 0, 348 | integrate_method: str = 'RK23', 349 | exact_filter_interval: float = None) -> xarray.Dataset: 350 | """Integrate the given PDE with standard and modeled finite differences.""" 351 | 352 | if hparams is None: 353 | hparams = training.load_hparams(checkpoint_dir) 354 | 355 | logging.info('integrating %s with seed=%s', hparams.equation, random_seed) 356 | equation_fine, equation_coarse = equations.from_hparams( 357 | hparams, random_seed=random_seed) 358 | 359 | logging.info('solving the "exact" model at high resolution') 360 | ds_solution_exact = integrate_exact( 361 | equation_fine, times, warmup, integrate_method=integrate_method, 362 | filter_interval=exact_filter_interval) 363 | solution_exact = ds_solution_exact['y'].data 364 | num_evals_exact =
ds_solution_exact['num_evals'].item() 365 | 366 | # resample to the coarse grid 367 | y0 = equation_coarse.grid.resample(solution_exact[0, :]) 368 | 369 | if np.isnan(y0).any(): 370 | raise ValueError('solution contains NaNs') 371 | 372 | logging.info('solving baseline finite differences at low resolution') 373 | differentiator = PolynomialDifferentiator(equation_coarse) 374 | solution_baseline, num_evals_baseline = odeint( 375 | y0, differentiator, warmup+times, method=integrate_method) 376 | 377 | logging.info('solving neural network model at low resolution') 378 | checkpoint_path = training.checkpoint_dir_to_path(checkpoint_dir) 379 | differentiator = SavedModelDifferentiator( 380 | checkpoint_path, equation_coarse, hparams) 381 | solution_model, num_evals_model = odeint( 382 | y0, differentiator, warmup+times, method=integrate_method) 383 | 384 | results = xarray.Dataset({ 385 | 'y_exact': (('time', 'x_high'), solution_exact), 386 | 'y_baseline': (('time', 'x_low'), solution_baseline), 387 | 'y_model': (('time', 'x_low'), solution_model), 388 | }, coords={ 389 | 'time': warmup+times, 390 | 'x_low': equation_coarse.grid.solution_x, 391 | 'x_high': equation_fine.grid.solution_x, 392 | 'num_evals_exact': num_evals_exact, 393 | 'num_evals_baseline': num_evals_baseline, 394 | 'num_evals_model': num_evals_model, 395 | }) 396 | return results 397 | 398 | 399 | def integrate_model_from_warm_start( 400 | checkpoint_dir: str, 401 | y0: np.ndarray, 402 | hparams: tf.contrib.training.HParams = None, 403 | random_seed: int = 0, 404 | times: np.ndarray = _DEFAULT_TIMES, 405 | warmup: float = 0, 406 | integrate_method: str = 'RK23') -> xarray.Dataset: 407 | """Integrate the given PDE with the trained neural network model.""" 408 | 409 | if hparams is None: 410 | hparams = training.load_hparams(checkpoint_dir) 411 | 412 | logging.info('integrating %s with seed=%s', hparams.equation, random_seed) 413 | _, equation_coarse = equations.from_hparams(hparams, random_seed=random_seed) 414 | 415 | logging.info('solving neural network model at low resolution') 416 | checkpoint_path = training.checkpoint_dir_to_path(checkpoint_dir) 417 | differentiator = SavedModelDifferentiator( 418 | checkpoint_path, equation_coarse, hparams) 419 | solution_model, num_evals_model = odeint( 420 | y0, differentiator, warmup+times, method=integrate_method) 421 | 422 | results = xarray.Dataset( 423 | data_vars={'y': (('time', 'x'), solution_model)}, 424 | coords={'time': warmup+times, 425 | 'x': equation_coarse.grid.solution_x, 426 | 'num_evals': num_evals_model}) 427 | return results 428 | -------------------------------------------------------------------------------- /pde_superresolution/equations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
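# ---------------------------------------------------------------------------
# [Editorial sketch, not part of equations.py] How the Grid attributes defined
# below relate once an equation is built with a resample_factor: the coarse
# solution grid keeps num_points points, while the fine reference grid used
# for "exact" solutions has num_points * resample_factor points. Burgers'
# equation is not conservative, so its grid subsamples rather than averages.
from pde_superresolution import equations

eq = equations.BurgersEquation(num_points=64, resample_factor=4)
assert eq.grid.solution_num_points == 64
assert eq.grid.reference_num_points == 256
assert eq.grid.resample_method == 'subsample'
# ---------------------------------------------------------------------------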
14 | # ============================================================================== 15 | """Equations for inference and training data.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import enum 21 | import json 22 | 23 | import numpy as np 24 | import tensorflow as tf 25 | from typing import Mapping, Tuple, Type, TypeVar 26 | 27 | from pde_superresolution import duckarray # pylint: disable=g-bad-import-order 28 | from pde_superresolution import polynomials # pylint: disable=g-bad-import-order 29 | 30 | 31 | # TODO(shoyer): replace with TypeVar('T', np.ndarray, tf.Tensor) when pytype 32 | # supports it (b/74212131) 33 | T = TypeVar('T') 34 | 35 | 36 | @enum.unique 37 | class ExactMethod(enum.Enum): 38 | """Method to use for the "exact" solution at high resolution.""" 39 | POLYNOMIAL = 1 40 | SPECTRAL = 2 41 | WENO = 3 42 | 43 | 44 | class Grid(object): 45 | """Object for keeping track of grids and resampling.""" 46 | 47 | def __init__(self, 48 | solution_num_points: int, 49 | resample_factor: int = 1, 50 | resample_method: str = 'mean', 51 | period: float = 1.0): 52 | 53 | self.resample_factor = resample_factor 54 | self.resample_method = resample_method 55 | self.period = period 56 | 57 | self.solution_num_points = solution_num_points 58 | self.solution_dx = period / solution_num_points 59 | self.solution_x = self.solution_dx * np.arange(solution_num_points) 60 | 61 | self.reference_num_points = solution_num_points * resample_factor 62 | self.reference_dx = period / self.reference_num_points 63 | self.reference_x = self.reference_dx * np.arange(self.reference_num_points) 64 | 65 | def resample(self, x: T, axis: int = -1) -> T: 66 | """Resample from the reference resolution to the solution resolution.""" 67 | func = duckarray.RESAMPLE_FUNCS[self.resample_method] 68 | return func(x, self.resample_factor, axis=axis) 69 | 70 | 71 | class Equation(object): 72 | """Base class for equations to integrate.""" 73 | 74 | # TODO(shoyer): switch to use ClassVar when pytype supports it (b/72678203) 75 | CONSERVATIVE = ... # type: bool 76 | GRID_OFFSET = ... # type: polynomials.GridOffset 77 | EXACT_METHOD = ... # type: ExactMethod 78 | DERIVATIVE_NAMES = ... # type: Tuple[str, ...] 79 | DERIVATIVE_ORDERS = ... # type: Tuple[int, ...] 80 | 81 | def __init__(self, 82 | num_points: int, 83 | resample_factor: int = 1, 84 | period: float = 1.0, 85 | random_seed: int = 0): 86 | """Constructor. 87 | 88 | Args: 89 | num_points: number of positions in x at which the equation is solved. 90 | resample_factor: integer factor by which num_points is resampled from the 91 | original grid. 92 | period: period for x. Equation subclasses may set different default 93 | values appropriate for the equation being solved. 94 | random_seed: integer random seed for any stochastic aspects of the 95 | equation. 96 | """ 97 | # Note: Ideally we would pass in grid as a constructor argument, but we need 98 | # different default grids for different equations, so we initialize it here 99 | # instead.
100 | resample_method = 'mean' if self.CONSERVATIVE else 'subsample' 101 | self.grid = Grid(num_points, resample_factor, resample_method, period) 102 | self.random_seed = random_seed 103 | 104 | def initial_value(self) -> np.ndarray: 105 | """Initial condition for time integration.""" 106 | raise NotImplementedError 107 | 108 | @property 109 | def time_step(self) -> float: 110 | """Time step size to use with explicit integration (the midpoint rule).""" 111 | raise NotImplementedError 112 | 113 | @property 114 | def standard_deviation(self) -> float: 115 | """Empirical standard deviation for integrated solutions.""" 116 | raise NotImplementedError 117 | 118 | def equation_of_motion( 119 | self, y: T, spatial_derivatives: Mapping[str, T]) -> T: 120 | """Time derivatives of the state `y` for integration. 121 | 122 | ML models may have access to equation_of_motion() for training. 123 | 124 | Args: 125 | y: float np.ndarray or tf.Tensor (with any number of dimensions) giving 126 | current function values. 127 | spatial_derivatives: dict of np.ndarray or Tensor with same dtype/shape 128 | as `y` mapping from spatial derivatives by name to derivative values. 129 | 130 | Returns: 131 | ndarray or Tensor with same dtype/shape as `y` giving the partial 132 | derivative of `y` with respect to time according to this equation. 133 | """ 134 | raise NotImplementedError 135 | 136 | def finalize_time_derivative(self, t: float, y_t: np.ndarray) -> np.ndarray: 137 | """Finalize time derivatives for integrations. 138 | 139 | ML models do *not* have access to finalize_time_derivative() during 140 | training. It is only used when integrating the PDE. 141 | 142 | The default implementation returns y_t unmodified. 143 | 144 | Args: 145 | t: float giving current time. 146 | y_t: float np.ndarray with any number of dimensions giving 147 | current time derivative values. 148 | 149 | Returns: 150 | Array with same dtype/shape as `y_t`. 151 | """ 152 | del t # unused 153 | return y_t 154 | 155 | def to_fine(self) -> 'Equation': 156 | """Return a copy of this equation on a fine resolution grid. 157 | 158 | This equation will have exactly the same type and parameters, with the 159 | exception of resample_factor. 160 | """ 161 | raise NotImplementedError 162 | 163 | @classmethod 164 | def exact_type(cls) -> 'Type[Equation]': 165 | raise NotImplementedError 166 | 167 | @classmethod 168 | def conservative_type(cls) -> 'Type[Equation]': 169 | raise NotImplementedError 170 | 171 | @classmethod 172 | def base_type(cls) -> 'Type[Equation]': 173 | raise NotImplementedError 174 | 175 | def params(self) -> dict: 176 | raise NotImplementedError 177 | 178 | def to_exact(self) -> 'Equation': 179 | """Return the "exact" version of this equation, on the same grid. 180 | 181 | This equation will have exactly the same parameters, except it may have a 182 | different type. 183 | 184 | This is used for "exact" numerical integration in integrate.py. It should 185 | be WENO for Burgers' and a non-conservative equation for KdV and KS (we use 186 | it with spectral methods). 187 | """ 188 | return self.exact_type()(**self.params()) 189 | 190 | def to_conservative(self) -> 'Equation': 191 | """Return the "conservative" version of this equation, on the same grid.
192 | """ 193 | return self.conservative_type()(**self.params()) 194 | 195 | 196 | class RandomForcing(object): 197 | """Deterministic random forcing, periodic in both space and time.""" 198 | 199 | def __init__(self, 200 | grid: Grid, 201 | nparams: int = 20, 202 | seed: int = 0, 203 | amplitude: float = 1, 204 | k_min: int = 1, 205 | k_max: int = 3): 206 | self.grid = grid 207 | rs = np.random.RandomState(seed) 208 | self.a = 0.5 * amplitude * rs.uniform(-1, 1, size=(nparams, 1)) 209 | self.omega = rs.uniform(-0.4, 0.4, size=(nparams, 1)) 210 | k_values = np.arange(k_min, k_max + 1) 211 | self.k = rs.choice(np.concatenate([-k_values, k_values]), size=(nparams, 1)) 212 | self.phi = rs.uniform(0, 2 * np.pi, size=(nparams, 1)) 213 | 214 | def __call__(self, t: float) -> np.ndarray: 215 | spatial_phase = (2 * np.pi * self.k * self.grid.reference_x 216 | / self.grid.period) 217 | signals = duckarray.sin(self.omega * t + spatial_phase + self.phi) 218 | reference_forcing = duckarray.sum(self.a * signals, axis=0) 219 | return self.grid.resample(reference_forcing) 220 | 221 | def export(self, path): 222 | """Export to a text file.""" 223 | p = np.zeros_like(self.a) 224 | p[0] = self.grid.period 225 | p[1] = self.grid.reference_num_points 226 | array = np.array([self.a, self.omega, self.k, self.phi, p]).squeeze() 227 | np.savetxt(path, array) 228 | 229 | 230 | class BurgersEquation(Equation): 231 | """Burger's equation with random forcing.""" 232 | 233 | CONSERVATIVE = False 234 | GRID_OFFSET = polynomials.GridOffset.CENTERED 235 | EXACT_METHOD = ExactMethod.WENO 236 | DERIVATIVE_NAMES = ('u_x', 'u_xx') 237 | DERIVATIVE_ORDERS = (1, 2) 238 | 239 | def __init__(self, 240 | num_points: int, 241 | resample_factor: int = 1, 242 | period: float = 2 * np.pi, 243 | random_seed: int = 0, 244 | eta: float = 0.04, 245 | k_min: int = 1, 246 | k_max: int = 3, 247 | ): 248 | super(BurgersEquation, self).__init__( 249 | num_points, resample_factor, period, random_seed) 250 | self.forcing = RandomForcing(self.grid, seed=random_seed, k_min=k_min, 251 | k_max=k_max) 252 | self.eta = eta 253 | self.k_min = k_min 254 | self.k_max = k_max 255 | 256 | def initial_value(self) -> np.ndarray: 257 | return np.zeros_like(self.grid.solution_x) 258 | 259 | @property 260 | def time_step(self) -> float: 261 | # TODO(shoyer): pick this dynamically 262 | return 1e-3 263 | 264 | @property 265 | def standard_deviation(self) -> float: 266 | # TODO(shoyer): pick this dynamically 267 | return 0.7917 268 | 269 | def equation_of_motion( 270 | self, y: T, spatial_derivatives: Mapping[str, T]) -> T: 271 | y_x = spatial_derivatives['u_x'] 272 | y_xx = spatial_derivatives['u_xx'] 273 | y_t = self.eta * y_xx - y * y_x 274 | return y_t 275 | 276 | def finalize_time_derivative(self, t: float, y_t: tf.Tensor) -> tf.Tensor: 277 | return y_t + self.forcing(t) 278 | 279 | def params(self): 280 | return dict( 281 | num_points=self.grid.reference_num_points, 282 | period=self.grid.period, 283 | random_seed=self.random_seed, 284 | eta=self.eta, 285 | k_min=self.k_min, 286 | k_max=self.k_max, 287 | ) 288 | 289 | def to_fine(self): 290 | return type(self)(**self.params()) 291 | 292 | @classmethod 293 | def exact_type(cls): 294 | return GodunovBurgersEquation 295 | 296 | @classmethod 297 | def conservative_type(cls): 298 | return ConservativeBurgersEquation 299 | 300 | @classmethod 301 | def base_type(cls): 302 | return BurgersEquation 303 | 304 | 305 | def staggered_first_derivative(y: T, dx: float) -> T: 306 | """Calculate a first-order derivative 
with second order finite differences. 307 | 308 | This function works on both NumPy arrays and tf.Tensor objects. 309 | 310 | Args: 311 | y: array to differentiate, with shape [..., x]. 312 | dx: spacing between grid points. 313 | 314 | Returns: 315 | Differentiated array, same type and shape as `y`. 316 | """ 317 | # Use concat instead of roll because roll doesn't have GPU or TPU 318 | # implementations in TensorFlow 319 | y_forward = duckarray.concatenate([y[..., 1:], y[..., :1]], axis=-1) 320 | return (1 / dx) * (y_forward - y) 321 | 322 | 323 | class ConservativeBurgersEquation(BurgersEquation): 324 | """Burgers constrained to obey the continuity equation.""" 325 | 326 | CONSERVATIVE = True 327 | GRID_OFFSET = polynomials.GridOffset.STAGGERED 328 | DERIVATIVE_NAMES = ('u', 'u_x') 329 | DERIVATIVE_ORDERS = (0, 1) 330 | 331 | def equation_of_motion( 332 | self, y: T, spatial_derivatives: Mapping[str, T]) -> T: 333 | del y # unused 334 | y = spatial_derivatives['u'] 335 | y_x = spatial_derivatives['u_x'] 336 | flux = 0.5 * y ** 2 - self.eta * y_x 337 | y_t = -staggered_first_derivative(flux, self.grid.solution_dx) 338 | return y_t 339 | 340 | 341 | def godunov_convective_flux(u_minus, u_plus): 342 | """Calculate Godunov's flux for 0.5*u**2.""" 343 | u_minus_squared = u_minus ** 2 344 | u_plus_squared = u_plus ** 2 345 | return 0.5 * duckarray.where( 346 | u_minus <= u_plus, 347 | duckarray.minimum(u_minus_squared, u_plus_squared), 348 | duckarray.maximum(u_minus_squared, u_plus_squared), 349 | ) 350 | 351 | 352 | class GodunovBurgersEquation(BurgersEquation): 353 | """Conservative Burgers' equation using Godunov numerical flux.""" 354 | 355 | CONSERVATIVE = True 356 | GRID_OFFSET = polynomials.GridOffset.STAGGERED 357 | DERIVATIVE_NAMES = ('u_minus', 'u_plus', 'u_x') 358 | DERIVATIVE_ORDERS = (0, 0, 1) 359 | 360 | def equation_of_motion( 361 | self, y: T, spatial_derivatives: Mapping[str, T]) -> T: 362 | del y # unused 363 | y_minus = spatial_derivatives['u_minus'] 364 | y_plus = spatial_derivatives['u_plus'] 365 | y_x = spatial_derivatives['u_x'] 366 | 367 | convective_flux = godunov_convective_flux(y_minus, y_plus) 368 | flux = convective_flux - self.eta * y_x 369 | y_t = -staggered_first_derivative(flux, self.grid.solution_dx) 370 | return y_t 371 | 372 | 373 | class KdVEquation(Equation): 374 | """Korteweg-de Vries (KdV) equation with random initial conditions.""" 375 | 376 | CONSERVATIVE = False 377 | GRID_OFFSET = polynomials.GridOffset.CENTERED 378 | EXACT_METHOD = ExactMethod.SPECTRAL 379 | DERIVATIVE_NAMES = ('u_x', 'u_xxx') 380 | DERIVATIVE_ORDERS = (1, 3) 381 | 382 | def __init__(self, 383 | num_points: int, 384 | resample_factor: int = 1, 385 | period: float = 32, 386 | random_seed: int = 0, 387 | k_min: int = 1, 388 | k_max: int = 3, 389 | ): 390 | super(KdVEquation, self).__init__( 391 | num_points, resample_factor, period, random_seed) 392 | self.forcing = RandomForcing(self.grid, nparams=10, seed=random_seed, 393 | k_min=k_min, k_max=k_max) 394 | self.k_min = k_min 395 | self.k_max = k_max 396 | 397 | def initial_value(self) -> np.ndarray: 398 | return self.forcing(0) 399 | 400 | @property 401 | def time_step(self) -> float: 402 | # TODO(shoyer): pick this dynamically 403 | return 2.5e-5 404 | 405 | @property 406 | def standard_deviation(self) -> float: 407 | # TODO(shoyer): pick this dynamically 408 | return 0.594 409 | 410 | def equation_of_motion( 411 | self, y: T, spatial_derivatives: Mapping[str, T]) -> T: 412 | y_x = spatial_derivatives['u_x'] 413 | y_xxx =
spatial_derivatives['u_xxx'] 414 | y_t = -6 * y * y_x - y_xxx 415 | return y_t 416 | 417 | def params(self): 418 | return dict( 419 | num_points=self.grid.reference_num_points, 420 | period=self.grid.period, 421 | random_seed=self.random_seed, 422 | k_min=self.k_min, 423 | k_max=self.k_max, 424 | ) 425 | 426 | def to_fine(self): 427 | return type(self)(**self.params()) 428 | 429 | @classmethod 430 | def exact_type(cls): 431 | return KdVEquation 432 | 433 | @classmethod 434 | def conservative_type(cls): 435 | return ConservativeKdVEquation 436 | 437 | @classmethod 438 | def base_type(cls): 439 | return KdVEquation 440 | 441 | 442 | class ConservativeKdVEquation(KdVEquation): 443 | """KdV constrained to obey the continuity equation.""" 444 | 445 | CONSERVATIVE = True 446 | GRID_OFFSET = polynomials.GridOffset.STAGGERED 447 | DERIVATIVE_NAMES = ('u', 'u_xx') 448 | DERIVATIVE_ORDERS = (0, 2) 449 | 450 | def equation_of_motion( 451 | self, y: T, spatial_derivatives: Mapping[str, T]) -> T: 452 | del y # unused 453 | y = spatial_derivatives['u'] 454 | y_xx = spatial_derivatives['u_xx'] 455 | flux = 3 * y ** 2 + y_xx 456 | y_t = -staggered_first_derivative(flux, self.grid.solution_dx) 457 | return y_t 458 | 459 | 460 | class GodunovKdVEquation(KdVEquation): 461 | """Conservative KdV using Godunov numerical flux.""" 462 | 463 | CONSERVATIVE = True 464 | GRID_OFFSET = polynomials.GridOffset.STAGGERED 465 | DERIVATIVE_NAMES = ('u_minus', 'u_plus', 'u_xx') 466 | DERIVATIVE_ORDERS = (0, 0, 2) 467 | 468 | def equation_of_motion( 469 | self, y: T, spatial_derivatives: Mapping[Tuple[str, int], T]) -> T: 470 | del y # unused 471 | y_minus = spatial_derivatives['u_minus'] 472 | y_plus = spatial_derivatives['u_plus'] 473 | y_xx = spatial_derivatives['u_xx'] 474 | 475 | convective_flux = godunov_convective_flux(y_minus, y_plus) 476 | flux = 6 * convective_flux + y_xx 477 | y_t = -staggered_first_derivative(flux, self.grid.solution_dx) 478 | return y_t 479 | 480 | 481 | class KSEquation(Equation): 482 | """Kuramoto-Sivashinsky (KS) equation with random initial conditions.""" 483 | 484 | CONSERVATIVE = False 485 | GRID_OFFSET = polynomials.GridOffset.CENTERED 486 | EXACT_METHOD = ExactMethod.SPECTRAL 487 | DERIVATIVE_NAMES = ('u_x', 'u_xx', 'u_xxxx') 488 | DERIVATIVE_ORDERS = (1, 2, 4) 489 | 490 | def __init__(self, 491 | num_points: int, 492 | resample_factor: int = 1, 493 | period: float = 64, 494 | random_seed: int = 0, 495 | k_min: int = 1, 496 | k_max: int = 3, 497 | ): 498 | super(KSEquation, self).__init__( 499 | num_points, resample_factor, period, random_seed) 500 | self.forcing = RandomForcing(self.grid, nparams=10, seed=random_seed, 501 | k_min=k_min, k_max=k_max) 502 | self.k_min = k_min 503 | self.k_max = k_max 504 | 505 | @property 506 | def time_step(self) -> float: 507 | # TODO(shoyer): pick this dynamically 508 | return 2.5e-5 509 | 510 | @property 511 | def standard_deviation(self) -> float: 512 | # TODO(shoyer): pick this dynamically 513 | return 0.299 514 | 515 | def initial_value(self) -> np.ndarray: 516 | return self.forcing(0) 517 | 518 | def equation_of_motion( 519 | self, y: T, spatial_derivatives: Mapping[str, T]) -> T: 520 | y_x = spatial_derivatives['u_x'] 521 | y_xx = spatial_derivatives['u_xx'] 522 | y_xxxx = spatial_derivatives['u_xxxx'] 523 | y_t = -y*y_x - y_xxxx - y_xx 524 | return y_t 525 | 526 | def params(self): 527 | return dict( 528 | num_points=self.grid.reference_num_points, 529 | period=self.grid.period, 530 | random_seed=self.random_seed, 531 | k_min=self.k_min, 532 | 
k_max=self.k_max, 533 | ) 534 | 535 | def to_fine(self): 536 | return type(self)(**self.params()) 537 | 538 | @classmethod 539 | def exact_type(cls): 540 | return KSEquation 541 | 542 | @classmethod 543 | def conservative_type(cls): 544 | return ConservativeKSEquation 545 | 546 | @classmethod 547 | def base_type(cls): 548 | return KSEquation 549 | 550 | 551 | class ConservativeKSEquation(KSEquation): 552 | """KS constrained to obey the continuity equation.""" 553 | 554 | CONSERVATIVE = True 555 | GRID_OFFSET = polynomials.GridOffset.STAGGERED 556 | DERIVATIVE_NAMES = ('u', 'u_x', 'u_xxx') 557 | DERIVATIVE_ORDERS = (0, 1, 3) 558 | 559 | def equation_of_motion( 560 | self, y: T, spatial_derivatives: Mapping[str, T]) -> T: 561 | del y # unused 562 | y = spatial_derivatives['u'] 563 | y_x = spatial_derivatives['u_x'] 564 | y_xxx = spatial_derivatives['u_xxx'] 565 | flux = 0.5*y**2 + y_xxx + y_x 566 | y_t = -staggered_first_derivative(flux, self.grid.solution_dx) 567 | return y_t 568 | 569 | 570 | class GodunovKSEquation(KSEquation): 571 | CONSERVATIVE = True 572 | GRID_OFFSET = polynomials.GridOffset.STAGGERED 573 | DERIVATIVE_NAMES = ('u_minus', 'u_plus', 'u_x', 'u_xxx') 574 | DERIVATIVE_ORDERS = (0, 0, 1, 3) 575 | 576 | def equation_of_motion( 577 | self, y: T, spatial_derivatives: Mapping[str, T]) -> T: 578 | del y # unused 579 | y_minus = spatial_derivatives['u_minus'] 580 | y_plus = spatial_derivatives['u_plus'] 581 | y_x = spatial_derivatives['u_x'] 582 | y_xxx = spatial_derivatives['u_xxx'] 583 | 584 | convective_flux = godunov_convective_flux(y_minus, y_plus) 585 | flux = y_xxx + y_x + convective_flux 586 | y_t = -staggered_first_derivative(flux, self.grid.solution_dx) 587 | return y_t 588 | 589 | 590 | EQUATION_TYPES = { 591 | 'burgers': BurgersEquation, 592 | 'kdv': KdVEquation, 593 | 'ks': KSEquation, 594 | } 595 | 596 | CONSERVATIVE_EQUATION_TYPES = { 597 | 'burgers': ConservativeBurgersEquation, 598 | 'kdv': ConservativeKdVEquation, 599 | 'ks': ConservativeKSEquation, 600 | } 601 | 602 | FLUX_EQUATION_TYPES = { 603 | 'burgers': GodunovBurgersEquation, 604 | 'kdv': GodunovKdVEquation, 605 | 'ks': GodunovKSEquation, 606 | } 607 | 608 | 609 | def equation_type_from_hparams( 610 | hparams: tf.contrib.training.HParams) -> Type[Equation]: 611 | """Create an equation type from HParams. 612 | 613 | Args: 614 | hparams: hyperparameters for training. 615 | 616 | Returns: 617 | Corresponding equation type. 618 | """ 619 | if hparams.conservative: 620 | if hparams.numerical_flux: 621 | types = FLUX_EQUATION_TYPES 622 | else: 623 | types = CONSERVATIVE_EQUATION_TYPES 624 | else: 625 | types = EQUATION_TYPES 626 | return types[hparams.equation] 627 | 628 | 629 | def from_hparams( 630 | hparams: tf.contrib.training.HParams, 631 | random_seed: int = 0) -> Tuple[Equation, Equation]: 632 | """Create Equation objects for model training from HParams. 633 | 634 | Args: 635 | hparams: hyperparameters for training. 636 | random_seed: integer random seed. 637 | 638 | Returns: 639 | A tuple of two Equation objects, providing the equations being solved on 640 | the fine (exact) and coarse (modeled) grids. 641 | 642 | Raises: 643 | ValueError: if hparams.resample_factor does not exactly divide 644 | exact_grid_size.
645 | """ 646 | kwargs = json.loads(hparams.equation_kwargs) 647 | exact_num_points = kwargs.pop('num_points') 648 | 649 | num_points, remainder = divmod(exact_num_points, hparams.resample_factor) 650 | if remainder: 651 | raise ValueError('resample_factor={} does not divide exact_num_points={}' 652 | .format(hparams.resample_factor, exact_num_points)) 653 | 654 | equation_type = equation_type_from_hparams(hparams) 655 | coarse_equation = equation_type( 656 | num_points, 657 | resample_factor=hparams.resample_factor, 658 | random_seed=random_seed, 659 | **kwargs) 660 | fine_equation = coarse_equation.to_fine() 661 | 662 | return fine_equation, coarse_equation 663 | -------------------------------------------------------------------------------- /pde_superresolution/training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utility functions for training a finite difference coefficient model. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import copy 23 | import os.path 24 | 25 | from absl import logging 26 | import numpy as np 27 | import pandas as pd 28 | import tensorflow as tf 29 | from google.protobuf import text_format # pylint: disable=g-bad-import-order 30 | from tensorflow.contrib.training.python.training import hparam_pb2 # pylint: disable=g-bad-import-order 31 | from tensorflow.core.protobuf import config_pb2 32 | from tensorflow.core.protobuf import rewriter_config_pb2 33 | from typing import Any, Dict, List, Tuple, Type, Union 34 | 35 | # pylint: disable=g-bad-import-order 36 | from pde_superresolution import equations 37 | from pde_superresolution import model 38 | 39 | 40 | def create_hparams(equation: str, **kwargs: Any) -> tf.contrib.training.HParams: 41 | """Create default hyper-parameters for training a model. 42 | 43 | Dataset parameters: 44 | equation: name of the equation being solved. 45 | conservative: boolean indicating whether to use the continuity preserving 46 | variant of this equation or not. 47 | numerical_flux: whether to use the Godunov numerical flux formulation of the 48 | equation. 49 | resample_factor: integer factor by which to upscale from low to high 50 | resolution. Must evenly divide the high resolution grid. 51 | equation_kwargs: JSON encoded string with equation specific keyword 52 | arguments, excluding resample_factor and random_seed. 53 | 54 | Neural network parameters: 55 | model_target: string indicating what the neural network is asked to directly 56 | output, any of 'coefficients', 'space_derivatives', 'flux' or 57 | 'time_derivative'. 58 | num_layers: integer number of conv1d layers to use for coefficient 59 | prediction. 60 | filter_size: integer filter size for conv1d layers.
61 | polynomial_accuracy_order: integer order of polynomial accuracy to enforce 62 | by construction. 63 | polynomial_accuracy_scale: float scaling on output from the polynomial 64 | accuracy layer. 65 | ensure_unbiased_coefficients: boolean indicating whether to ensure finite 66 | difference constraints are unbiased. Only used if 67 | polynomial_accuracy_order == 0. 68 | coefficient_grid_min_size: integer minimum size of the grid used for finite 69 | difference coefficients. The coefficient grid will be either this size or 70 | one larger, if GRID_OFFSET is False. 71 | 72 | Training parameters: 73 | base_batch_size: base batch size. Scaled by resample_factor to compute the 74 | batch size used in training. This ensures that models trained at 75 | different resolutions use the same number of data points per batch. 76 | learning_rates: List[float] giving constant learning rates to use with Adam. 77 | learning_stops: List[int] giving global steps at which to move on to the 78 | next learning rate or stop training. 79 | frac_training: float fraction of the input dataset to use for training vs. 80 | validation. 81 | eval_interval: integer training step frequency at which to run evaluation. 82 | 83 | Noise parameters: 84 | noise_probability: float probability of adding noise to input data for any 85 | particular example during training. 86 | noise_amplitude: float amplitude of Gaussian noise to add to input data 87 | during training. 88 | noise_type: string 'white' or 'filtered' indicating the type of noise to 89 | apply. 90 | 91 | Loss parameters: 92 | ground_truth_order: polynomial accuracy order to use for creating ground- 93 | truth labels used in the loss. -1 is a special sentinel value indicating 94 | "maximum accuracy", i.e., WENO for Burgers' equation and a spectral method 95 | for KdV and KS. 96 | num_time_steps: integer number of integration time steps to include in the 97 | loss. 98 | error_floor_quantile: float quantile to use for the error floor. 99 | error_scale: List[float] with length 2*num_channels indicating the 100 | scaling in the loss to use on squared error and relative squared error 101 | for each derivative target. 102 | error_floor: List[float] with length num_channels giving the scale for 103 | weighting of relative errors. 104 | error_max: float indicating the largest relative error (compared to 105 | predicting no change over time) to use in computing each error term in the 106 | loss. Larger error values are clipped to this maximum value. 107 | absolute_error_weight: float relative weighting for absolute error term in 108 | the loss. 109 | relative_error_weight: float relative weighting for relative error term in 110 | the loss. 111 | space_derivatives_weight: float relative weighting for space derivatives in 112 | the loss. 113 | time_derivative_weight: float relative weighting for time derivatives in the 114 | loss. 115 | integrated_solution_weight: float relative weighting for the integrated 116 | solution in the loss. 117 | 118 | Args: 119 | equation: lowercase string name of the equation to solve. 120 | **kwargs: default hyper-parameter values to override. 121 | 122 | Returns: 123 | HParams object with all hyperparameter values.
124 | """ 125 | hparams = tf.contrib.training.HParams( 126 | # dataset parameters 127 | equation=equation, 128 | conservative=True, 129 | numerical_flux=False, 130 | equation_kwargs='{}', 131 | resample_factor=4, 132 | # neural network parameters 133 | model_target='coefficients', 134 | num_layers=3, 135 | filter_size=32, 136 | kernel_size=5, 137 | nonlinearity='relu', 138 | polynomial_accuracy_order=1, 139 | polynomial_accuracy_scale=1.0, 140 | ensure_unbiased_coefficients=False, 141 | coefficient_grid_min_size=6, 142 | # training parameters 143 | base_batch_size=128, 144 | learning_rates=[1e-3, 1e-4], 145 | learning_stops=[20000, 40000], 146 | frac_training=0.8, 147 | eval_interval=250, 148 | noise_probability=0.0, 149 | noise_amplitude=0.0, 150 | noise_type='white', 151 | # loss parameters 152 | ground_truth_order=-1, 153 | num_time_steps=0, 154 | error_floor_quantile=0.1, 155 | error_scale=[np.nan], # set by set_data_dependent_hparams 156 | error_floor=[np.nan], # set by set_data_dependent_hparams 157 | error_max=0.0, 158 | absolute_error_weight=1.0, 159 | relative_error_weight=0.0, 160 | space_derivatives_weight=0.0, 161 | time_derivative_weight=1.0, 162 | integrated_solution_weight=0.0, 163 | ) 164 | hparams.override_from_dict(kwargs) 165 | return hparams 166 | 167 | 168 | def set_data_dependent_hparams( 169 | hparams: tf.contrib.training.HParams, 170 | snapshots: np.ndarray): 171 | """Add data-dependent hyperparameters to hparams. 172 | 173 | Added hyper-parameters: 174 | error_scale: List[float] with length 2*num_channels indicating the 175 | scaling in the loss to use on squared error and relative squared error 176 | for each derivative target. 177 | error_floor: List[float] with length num_channels giving the scale for 178 | weighting of relative errors. 179 | 180 | Args: 181 | hparams: hyper-parameters for training. Will be modified by adding 182 | 'error_floor' and 'error_scale' entries (lists of float). 183 | snapshots: np.ndarray with shape [examples, x] with high-resolution 184 | training data. 185 | """ 186 | error_floor, error_scale = determine_loss_scales(snapshots, hparams) 187 | hparams.set_hparam('error_scale', error_scale.ravel().tolist()) 188 | hparams.set_hparam('error_floor', error_floor.tolist()) 189 | 190 | 191 | def create_training_step( 192 | loss: tf.Tensor, 193 | hparams: tf.contrib.training.HParams) -> tf.Tensor: 194 | """Create a training step operation for training our neural network. 195 | 196 | Args: 197 | loss: loss to optimize. 198 | hparams: hyperparameters for training. 199 | 200 | Returns: 201 | Tensor that runs a single step of training each time it is evaluated. 202 | """ 203 | global_step = tf.train.get_or_create_global_step() 204 | 205 | if len(hparams.learning_rates) > 1: 206 | learning_rate = tf.train.piecewise_constant( 207 | global_step, boundaries=hparams.learning_stops[:-1], 208 | values=hparams.learning_rates) 209 | else: 210 | (learning_rate,) = hparams.learning_rates 211 | 212 | optimizer = tf.train.AdamOptimizer(learning_rate, beta2=0.99) 213 | 214 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 215 | with tf.control_dependencies(update_ops): 216 | train_step = optimizer.minimize(loss, global_step=global_step) 217 | 218 | return train_step 219 | 220 | 221 | def setup_training( 222 | snapshots: np.ndarray, 223 | hparams: tf.contrib.training.HParams) -> Tuple[tf.Tensor, tf.Tensor]: 224 | """Create Tensors for training. 
225 | 226 | Args: 227 | snapshots: np.ndarray with shape [examples, x] with high-resolution 228 | training data. 229 | hparams: hyperparameters for training. 230 | 231 | Returns: 232 | Tensors for the current loss, and for taking a training step. 233 | """ 234 | dataset = model.make_dataset(snapshots, hparams, 235 | dataset_type=model.Dataset.TRAINING) 236 | tensors = dataset.make_one_shot_iterator().get_next() 237 | 238 | predictions = model.predict_result(tensors['inputs'], hparams) 239 | 240 | loss_per_head = model.loss_per_head(predictions, 241 | labels=tensors['labels'], 242 | baseline=tensors['baseline'], 243 | hparams=hparams) 244 | loss = model.weighted_loss(loss_per_head, hparams) 245 | train_step = create_training_step(loss, hparams) 246 | 247 | return loss, train_step 248 | 249 | 250 | MetricsDict = Dict[str, Tuple[tf.Tensor, tf.Tensor]] # pylint: disable=invalid-name 251 | 252 | 253 | class Inferer(object): 254 | """Object for repeated running inference over a fixed dataset.""" 255 | 256 | def __init__(self, 257 | snapshots: np.ndarray, 258 | hparams: tf.contrib.training.HParams, 259 | training: bool = False): 260 | """Initialize an object for running inference. 261 | 262 | Args: 263 | snapshots: np.ndarray with shape [examples, x] with high-resolution 264 | training data. 265 | hparams: hyperparameters for training. 266 | training: whether to evaluate on training or validation datasets. 267 | """ 268 | if training: 269 | dataset_type = model.Dataset.TRAINING 270 | else: 271 | dataset_type = model.Dataset.VALIDATION 272 | dataset = model.make_dataset(snapshots, hparams, dataset_type=dataset_type, 273 | repeat=False, evaluation=True) 274 | iterator = dataset.make_initializable_iterator() 275 | data = iterator.get_next() 276 | 277 | _, coarse_equation = equations.from_hparams(hparams) 278 | 279 | predictions = model.predict_result(data['inputs'], hparams) 280 | loss_per_head = model.loss_per_head( 281 | predictions, 282 | labels=data['labels'], 283 | baseline=data['baseline'], 284 | hparams=hparams) 285 | loss = model.weighted_loss(loss_per_head, hparams) 286 | 287 | results = dict(data, predictions=predictions) 288 | metrics = {k: tf.contrib.metrics.streaming_concat(v) 289 | for k, v in results.items()} 290 | metrics['loss'] = tf.metrics.mean(loss) 291 | 292 | space_loss, time_loss, integrated_loss = model.result_unstack( 293 | loss_per_head, coarse_equation) 294 | metrics['loss/space_derivatives'] = tf.metrics.mean(space_loss) 295 | metrics['loss/time_derivative'] = tf.metrics.mean(time_loss) 296 | if integrated_loss is not None: 297 | metrics['loss/integrated_solution'] = tf.metrics.mean(integrated_loss) 298 | 299 | initializer = tf.group(iterator.initializer, 300 | tf.local_variables_initializer()) 301 | 302 | self._initializer = initializer 303 | self._metrics = metrics 304 | 305 | def run(self, sess: tf.Session) -> Dict[str, np.ndarray]: 306 | """Run inference over a complete dataset. 307 | 308 | Args: 309 | sess: active session. 310 | 311 | Returns: 312 | Dict with evaluated metrics as NumPy arrays. 313 | """ 314 | return evaluate_metrics(sess, self._initializer, self._metrics) 315 | 316 | 317 | def evaluate_metrics(sess: tf.Session, 318 | initializer: tf.Tensor, 319 | metrics: MetricsDict) -> Dict[str, np.ndarray]: 320 | """Evaluate metrics over a complete dataset. 321 | 322 | Args: 323 | sess: active session. 324 | initializer: tensor to run to (re)initialize local variables. 325 | metrics: metrics to evaluate. 
326 | 327 | Returns: 328 | Dict with evaluated metrics as NumPy arrays. 329 | """ 330 | values, updates = tf.contrib.metrics.aggregate_metric_map(metrics) 331 | sess.run(initializer) 332 | while True: 333 | try: 334 | sess.run(updates) 335 | except tf.errors.OutOfRangeError: 336 | break 337 | return sess.run(values) 338 | 339 | 340 | def load_dataset(dataset: tf.data.Dataset) -> Dict[str, np.ndarray]: 341 | """Given a TensorFlow dataset, load it into memory as numpy arrays. 342 | 343 | Args: 344 | dataset: input dataset with some finite size. 345 | 346 | Returns: 347 | Dict of numpy arrays with concatenated data from the full input dataset. 348 | """ 349 | tensors = dataset.make_one_shot_iterator().get_next() 350 | metrics = { 351 | k: tf.contrib.metrics.streaming_concat(v) for k, v in tensors.items() 352 | } 353 | initializer = tf.local_variables_initializer() 354 | with tf.Session(config=_session_config()) as sess: 355 | return evaluate_metrics(sess, initializer, metrics) 356 | 357 | 358 | def determine_loss_scales( 359 | snapshots: np.ndarray, 360 | hparams: tf.contrib.training.HParams) -> Tuple[np.ndarray, np.ndarray]: 361 | """Determine scale factors for the loss. 362 | 363 | When passed into model.compute_loss, predictions of all zero should result 364 | in a loss of 1.0 when averaged over the full dataset. 365 | 366 | Args: 367 | snapshots: np.ndarray with shape [examples, x] with high-resolution 368 | training data. 369 | hparams: hyperparameters to use for training. 370 | 371 | Returns: 372 | Tuple of two numpy arrays: 373 | error_scale: array with dimensions [2, derivative] indicating the 374 | scaling in the loss to use on squared error and relative squared error 375 | for each derivative target. 376 | error_floor: numpy array with scale for weighting of relative errors. 377 | """ 378 | with tf.Graph().as_default(): 379 | dataset = model.make_dataset(snapshots, hparams, repeat=False) 380 | data = load_dataset(dataset) 381 | 382 | baseline = data['baseline'] 383 | labels = data['labels'] 384 | inputs = data['inputs'] 385 | 386 | # Handle cases where we use WENO for only ground truth labels or predictions 387 | if baseline.shape[-1] < labels.shape[-1]: 388 | labels = labels[..., 1:] 389 | elif baseline.shape[-1] > labels.shape[-1]: 390 | labels = np.concatenate([labels[..., :1], labels], axis=-1) 391 | 392 | baseline_error = (labels - baseline) ** 2 393 | percentile = 100 * hparams.error_floor_quantile 394 | error_floor = np.maximum( 395 | np.percentile(baseline_error, percentile, axis=(0, 1)), 1e-12) 396 | 397 | # predict zero for all derivatives, and a constant value for the integrated 398 | # solution over time. 
399 | equation_type = equations.equation_type_from_hparams(hparams) 400 | num_zero_predictions = len(equation_type.DERIVATIVE_ORDERS) + 1 401 | labels_shape = labels.shape 402 | predictions = np.concatenate([ 403 | np.zeros(labels_shape[:-1] + (num_zero_predictions,)), 404 | np.repeat(inputs[..., np.newaxis], 405 | labels_shape[-1] - num_zero_predictions, 406 | axis=-1) 407 | ], axis=-1) 408 | 409 | components = np.stack(model.abs_and_rel_error(predictions=predictions, 410 | labels=labels, 411 | baseline=baseline, 412 | error_floor=error_floor)) 413 | baseline_error = np.mean(components, axis=(1, 2)) 414 | logging.info('baseline_error: %s', baseline_error) 415 | 416 | error_scale = np.where(baseline_error > 0, 1.0 / baseline_error, 0) 417 | return error_floor, error_scale 418 | 419 | 420 | def geometric_mean( 421 | x: np.ndarray, 422 | axis: Union[int, Tuple[int, ...]] = None 423 | ) -> Union[np.ndarray, np.generic]: 424 | """Calculate the geometric mean of an array.""" 425 | return np.exp(np.mean(np.log(x), axis)) 426 | 427 | 428 | def safe_abs(x: np.ndarray, epsilon: float = 1e-8) -> np.ndarray: 429 | """Absolute value guaranteed to be larger than epsilon.""" 430 | return np.maximum(abs(x), epsilon) 431 | 432 | 433 | def calculate_metrics( 434 | data: Dict[str, np.ndarray], 435 | equation_type: Type[equations.Equation]) -> Dict[str, float]: 436 | """From a dict of inference results, calculate evaluation metrics. 437 | 438 | Args: 439 | data: evaluation metrics from setup_inference() passed through 440 | run_inference(). 441 | equation_type: type of equation being solved. 442 | 443 | Returns: 444 | Dict from evaluation metrics to scalar values. 445 | """ 446 | labels = data['labels'] 447 | baseline = data['baseline'] 448 | predictions = data['predictions'] 449 | 450 | # Handle cases where we use WENO for only ground truth labels or predictions 451 | if baseline.shape[-1] < labels.shape[-1]: 452 | labels = labels[..., 1:] 453 | elif baseline.shape[-1] > labels.shape[-1]: 454 | labels = np.concatenate([labels[..., :1], labels], axis=-1) 455 | 456 | mae = (np.mean(abs(labels - predictions), axis=(0, 1)) / 457 | np.mean(abs(labels - baseline), axis=(0, 1))) 458 | rms_error = np.sqrt( 459 | np.mean((labels - predictions) ** 2, axis=(0, 1)) / 460 | np.mean((labels - baseline) ** 2, axis=(0, 1))) 461 | mean_abs_relative_error = geometric_mean( 462 | safe_abs(labels - predictions) 463 | / safe_abs(labels - baseline), 464 | axis=(0, 1)) 465 | below_baseline = np.mean( 466 | (labels - predictions) ** 2 467 | < (labels - baseline) ** 2, axis=(0, 1)) 468 | 469 | metrics = {'count': len(labels)} 470 | metrics.update({k: float(v) for k, v in data.items() if 'loss' in k}) 471 | 472 | target_names = list(equation_type.DERIVATIVE_NAMES) + ['u_t'] 473 | assert labels.shape[-1] >= len(target_names) 474 | for i, target in enumerate(target_names): 475 | metrics.update({ 476 | 'mae/' + target: mae[i], 477 | 'rms_error/' + target: rms_error[i], 478 | 'mean_abs_relative_error/' + target: mean_abs_relative_error[i], 479 | 'frac_below_baseline/' + target: below_baseline[i], 480 | }) 481 | time_index = len(target_names) 482 | if time_index < labels.shape[-1]: 483 | target = 'u(t)' 484 | metrics.update({ 485 | 'mae/' + target: mae[time_index:].mean(), 486 | 'rms_error/' + target: rms_error[time_index:].mean(), 487 | 'mean_abs_relative_error/' + target: 488 | mean_abs_relative_error[time_index:].mean(), 489 | 'frac_below_baseline/' + target: below_baseline[time_index:].mean(), 490 | }) 491 | return metrics 492 | 493
| 494 | def metrics_one_linear(metrics: Dict[str, float]) -> str: 495 | """Summarize training metrics into a one line string.""" 496 | 497 | def matching_metrics_string(like, style='{}={:1.4f}', delimiter='/'): 498 | values = [(k.split('/')[-1], v) 499 | for k, v in sorted(metrics.items()) 500 | if like in k] 501 | return delimiter.join(style.format(*kv) for kv in values) 502 | 503 | return ('loss: {:1.7f}, abs_error: {}, rel_error: {}, below_baseline: {}' 504 | .format(metrics['loss'], 505 | matching_metrics_string('mae'), 506 | matching_metrics_string('mean_abs_relative_error'), 507 | matching_metrics_string('frac_below_baseline'))) 508 | 509 | 510 | class SaveAtEnd(tf.train.SessionRunHook): 511 | """A simple hook to save results at the end of training.""" 512 | 513 | def __init__(self, path): 514 | self.path = path 515 | 516 | def begin(self): 517 | self.saver = tf.train.Saver() 518 | 519 | def end(self, sess): 520 | self.saver.save(sess, self.path) 521 | 522 | 523 | def checkpoint_dir_to_path(checkpoint_dir: str) -> str: 524 | return os.path.join(checkpoint_dir, 'model.ckpt') 525 | 526 | 527 | def save_summaries(metrics: Dict[str, float], 528 | writer: tf.summary.FileWriter, 529 | global_step: int) -> None: 530 | """Log metrics with a tf.summary.FileWriter.""" 531 | values = [tf.Summary.Value(tag=k, simple_value=v) for k, v in metrics.items()] 532 | summary = tf.Summary(value=values) 533 | writer.add_summary(summary, global_step) 534 | writer.flush() 535 | 536 | 537 | def metrics_to_dataframe( 538 | logged_metrics: List[Tuple[int, Dict[str, float], Dict[str, float]]] 539 | ) -> pd.DataFrame: 540 | """Convert metrics into a single DataFrame, e.g., for saving as a CSV file.""" 541 | all_metrics = [] 542 | for step, test_metrics, train_metrics in logged_metrics: 543 | metrics = {'test_' + k: v for k, v in test_metrics.items()} 544 | metrics.update({'train_' + k: v for k, v in train_metrics.items()}) 545 | metrics['step'] = step 546 | all_metrics.append(metrics) 547 | return pd.DataFrame(all_metrics) 548 | 549 | 550 | def _session_config(): 551 | """Setup configuration for the TensorFlow session.""" 552 | # Disable graph rewrites (b/92797692) 553 | off = rewriter_config_pb2.RewriterConfig.OFF 554 | rewriter_config = rewriter_config_pb2.RewriterConfig( 555 | disable_model_pruning=True, 556 | constant_folding=off, 557 | arithmetic_optimization=off, 558 | remapping=off, 559 | shape_optimization=off, 560 | dependency_optimization=off, 561 | function_optimization=off, 562 | layout_optimizer=off, 563 | loop_optimization=off, 564 | memory_optimization=rewriter_config_pb2.RewriterConfig.NO_MEM_OPT) 565 | graph_options = config_pb2.GraphOptions( 566 | rewrite_options=rewriter_config) 567 | return config_pb2.ConfigProto(graph_options=graph_options) 568 | 569 | 570 | def training_loop(snapshots: np.ndarray, 571 | checkpoint_dir: str, 572 | hparams: tf.contrib.training.HParams, 573 | master: str = '') -> pd.DataFrame: 574 | """Run training. 575 | 576 | Args: 577 | snapshots: np.ndarray with shape [examples, x] with high-resolution 578 | training data. 579 | checkpoint_dir: directory to which to save model checkpoints. 580 | hparams: hyperparameters for training, as created by create_hparams(). 581 | master: string master to use for MonitoredTrainingSession. 582 | 583 | Returns: 584 | pd.DataFrame with metrics for the full training run. 
585 | """ 586 | hparams = copy.deepcopy(hparams) 587 | set_data_dependent_hparams(hparams, snapshots) 588 | logging.info('Training with hyperparameters:\n%r', hparams) 589 | 590 | hparams_path = os.path.join(checkpoint_dir, 'hparams.pbtxt') 591 | with tf.gfile.GFile(hparams_path, 'w') as f: 592 | f.write(str(hparams.to_proto())) 593 | 594 | logging.info('Setting up training') 595 | _, train_step = setup_training(snapshots, hparams) 596 | train_inferer = Inferer(snapshots, hparams, training=True) 597 | test_inferer = Inferer(snapshots, hparams, training=False) 598 | 599 | global_step = tf.train.get_or_create_global_step() 600 | 601 | logging.info('Variables: %s', '\n'.join(map(str, tf.trainable_variables()))) 602 | 603 | logged_metrics = [] 604 | equation_type = equations.equation_type_from_hparams(hparams) 605 | 606 | with tf.train.MonitoredTrainingSession( 607 | master=master, 608 | checkpoint_dir=checkpoint_dir, 609 | save_checkpoint_secs=300, 610 | config=_session_config(), 611 | hooks=[SaveAtEnd(checkpoint_dir_to_path(checkpoint_dir))]) as sess: 612 | 613 | test_writer = tf.summary.FileWriter( 614 | os.path.join(checkpoint_dir, 'test'), sess.graph, flush_secs=60) 615 | train_writer = tf.summary.FileWriter( 616 | os.path.join(checkpoint_dir, 'train'), sess.graph, flush_secs=60) 617 | 618 | initial_step = sess.run(global_step) 619 | 620 | with test_writer, train_writer: 621 | for step in range(initial_step, hparams.learning_stops[-1]): 622 | sess.run(train_step) 623 | 624 | if (step + 1) % hparams.eval_interval == 0: 625 | train_inference_data = train_inferer.run(sess) 626 | test_inference_data = test_inferer.run(sess) 627 | 628 | train_metrics = calculate_metrics(train_inference_data, equation_type) 629 | test_metrics = calculate_metrics(test_inference_data, equation_type) 630 | logged_metrics.append((step, test_metrics, train_metrics)) 631 | 632 | logging.info(metrics_one_linear(test_metrics)) 633 | save_summaries(test_metrics, test_writer, global_step=step) 634 | save_summaries(train_metrics, train_writer, global_step=step) 635 | 636 | return metrics_to_dataframe(logged_metrics) 637 | 638 | 639 | def load_hparams(checkpoint_dir: str) -> tf.contrib.training.HParams: 640 | """Load saved hyperparameters from a checkpoint.""" 641 | hparams_path = os.path.join(checkpoint_dir, 'hparams.pbtxt') 642 | hparam_def = hparam_pb2.HParamDef() 643 | with tf.gfile.GFile(hparams_path, 'r') as f: 644 | text_format.Merge(f.read(), hparam_def) 645 | hparams = tf.contrib.training.HParams(hparam_def) 646 | # Set any new hparams not found in the file with default values. 647 | return create_hparams(**hparams.values()) 648 | --------------------------------------------------------------------------------
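
Usage sketch (illustrative, not part of the repository): the snippet below shows how the Equation API in equations.py fits together by taking one explicit midpoint-rule step of the forced Burgers' equation. It assumes the package is importable; the helpers crude_derivatives and time_derivative are hypothetical, and simple periodic central differences stand in for the finite-difference coefficients that the trained model (or polynomials.py) would normally supply; num_points=256 is an arbitrary choice.

import numpy as np
from pde_superresolution import equations

def crude_derivatives(y, dx):
  # Stand-in for the package's learned/polynomial coefficients:
  # second-order periodic central differences for u_x and u_xx.
  y_plus, y_minus = np.roll(y, -1), np.roll(y, 1)
  return {'u_x': (y_plus - y_minus) / (2 * dx),
          'u_xx': (y_plus - 2 * y + y_minus) / dx ** 2}

def time_derivative(eq, t, y):
  # equation_of_motion gives du/dt from spatial derivatives;
  # finalize_time_derivative adds the random forcing for Burgers'.
  y_t = eq.equation_of_motion(y, crude_derivatives(y, eq.grid.solution_dx))
  return eq.finalize_time_derivative(t, y_t)

eq = equations.BurgersEquation(num_points=256)
y = eq.initial_value()
dt = eq.time_step

# One step of the explicit midpoint rule, matching the time_step docstring.
y_mid = y + 0.5 * dt * time_derivative(eq, 0.0, y)
y_next = y + dt * time_derivative(eq, 0.5 * dt, y_mid)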