├── BUILD
├── README.md
├── reference_tests
│   ├── README.md
│   ├── call_to_tf_function_test.py
│   ├── call_to_cast_function_test.py
│   ├── call_to_numpy_function_test.py
│   ├── assertion_test.py
│   ├── two_level_call_tree_test.py
│   ├── call_to_lambda_function_test.py
│   ├── loop_with_else_test.py
│   ├── generator_test.py
│   ├── basic_ifexp_test.py
│   ├── call_to_named_tuple_test.py
│   ├── dynamic_call_tree_test.py
│   ├── while_loop_function_call_mix_test.py
│   ├── call_to_print_function_test.py
│   ├── call_to_builtin_function_test.py
│   ├── logical_expression_test.py
│   ├── composite_names_in_control_flow_test.py
│   ├── loop_control_flow_test.py
│   ├── nested_control_flow_test.py
│   ├── distributed_dataset_test.py
│   ├── datasets_test.py
│   ├── loop_scoping_test.py
│   ├── early_return_test.py
│   ├── loop_with_variable_type_illegal_cases_test.py
│   ├── loop_with_function_call_test.py
│   ├── cond_basic_test.py
│   ├── loop_basic_test.py
│   ├── loop_with_variable_type_test.py
│   └── reference_test_base.py
├── examples
│   └── sysml2019
│       ├── README.md
│       ├── benchmark_base.py
│       ├── maml_benchmark.py
│       ├── seq2seq_benchmark.py
│       ├── mnist_benchmark.py
│       ├── lbfgs_benchmark.py
│       ├── benchmark_dashboard.ipynb
│       ├── rnn_benchmark.py
│       └── beam_search_benchmark.py
├── CONTRIBUTING.md
└── LICENSE

/BUILD:
--------------------------------------------------------------------------------
1 | licenses(["notice"])
2 |
3 | exports_files(["LICENSE"])
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AutoGraph
2 |
3 | This repository contains tests and example code for [TensorFlow AutoGraph](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/autograph). For more information, see:
4 |
5 | * [tf.function and AutoGraph guide](https://www.tensorflow.org/beta/guide/autograph)
6 | * [AutoGraph reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/index.md)
7 |
--------------------------------------------------------------------------------
/reference_tests/README.md:
--------------------------------------------------------------------------------
1 | # AutoGraph Reference Tests
2 |
3 | This directory contains tests for Python idioms that AutoGraph supports.
4 | Since they are easy to read, they also double as small code samples.
5 |
6 | The BUILD file contains the full list of tests.
7 |
8 | ## Locating the samples inside tests
9 |
10 | Each test is structured as:
11 |
12 | * a module docstring describing the Python idiom being exercised
13 | * one or more sample functions written in that idiom
14 | * a test class that runs each sample both as regular Python and in converted
15 |   form, and compares the results
16 | * the usual test runner boilerplate
17 |
18 | The sample functions demonstrate how code is authored for AutoGraph.
19 |
20 | The tests generally ensure that the sample code produces the same results when
21 | run in a TF graph as when executed as regular Python.
22 |
--------------------------------------------------------------------------------
/examples/sysml2019/README.md:
--------------------------------------------------------------------------------
1 | To run the benchmarks on Linux, use the commands below.
2 |
3 | Tip: use a virtual environment to avoid clobbering your existing TF installation.
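For example, one way to create and activate such an environment (a minimal sketch assuming Python 3 with the standard venv module; the directory name is only illustrative):

    python3 -m venv ~/autograph-sysml2019-env
    source ~/autograph-sysml2019-env/bin/activate
    pip install --upgrade pip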
4 | 5 | Create a directory for the benchmark result files: 6 | 7 | export BENCHMARKS_DIR=/tmp/autograph/sysml2019_benchmarks/ 8 | export TEST_REPORT_FILE_PREFIX=${BENCHMARKS_DIR} 9 | export BENCHMARK_NUM_EXECUTIONS=100 10 | # Optionally, clean up the target directory: rm -Rf ${BENCHMARKS_DIR} 11 | mkdir -p ${BENCHMARKS_DIR} 12 | 13 | Run benchmarks: 14 | 15 | pip install tensorflow 16 | python beam_search_benchmark.py --benchmarks=. 17 | python lbfgs_benchmark.py --benchmarks=. 18 | python maml_benchmark.py --benchmarks=. 19 | python seq2seq_benchmark.py --benchmarks=. 20 | 21 | To parse the result files, start a Jupyter client and use the notebook: 22 | 23 | pip install jupyter 24 | pip install pandas 25 | jupyter notebook benchmark_dashboard.ipynb 26 | -------------------------------------------------------------------------------- /reference_tests/call_to_tf_function_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Simple call to a TF API function. 16 | 17 | The call will remain unchanged. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import reference_test_base 25 | import tensorflow.compat.v1 as tf 26 | 27 | 28 | def core_tf_call(x): 29 | return x * tf.constant(2) 30 | 31 | 32 | class ReferenceTest(reference_test_base.TestCase): 33 | 34 | def test_basic(self): 35 | self.assertTfMatchesCompiled(core_tf_call, 1) 36 | 37 | 38 | if __name__ == '__main__': 39 | tf.test.main() 40 | -------------------------------------------------------------------------------- /reference_tests/call_to_cast_function_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Simple call to cast functions and other computations.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import reference_test_base 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | def nested_cast(x): 26 | return float(int(x)) 27 | 28 | 29 | class ReferenceTest(reference_test_base.TestCase): 30 | 31 | def test_basic(self): 32 | self.assertNativeMatchesCompiled(nested_cast, 5) 33 | self.assertNativeMatchesCompiled(nested_cast, 3.0) 34 | 35 | 36 | if __name__ == '__main__': 37 | tf.test.main() 38 | -------------------------------------------------------------------------------- /reference_tests/call_to_numpy_function_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Simple call to a whitelisted Numpy function. 16 | 17 | The call should be wrapped in py_func. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import reference_test_base 25 | import numpy as np 26 | import tensorflow.compat.v1 as tf 27 | 28 | 29 | def f(): 30 | return 2 * np.random.binomial(1, 0.5, size=(10,)) - 1 31 | 32 | 33 | class ReferenceTest(reference_test_base.TestCase): 34 | 35 | def test_basic(self): 36 | self.assertNativeMatchesCompiled(f) 37 | 38 | 39 | if __name__ == '__main__': 40 | tf.test.main() 41 | -------------------------------------------------------------------------------- /reference_tests/assertion_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Basic assertions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import reference_test_base 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | def simple_assertion(x): 26 | assert x > 0 27 | return x 28 | 29 | 30 | class ReferenceTest(reference_test_base.TestCase): 31 | 32 | def test_basic(self): 33 | self.assertNativeMatchesCompiled(simple_assertion, 1) 34 | with self.assertRaises(tf.errors.InvalidArgumentError): 35 | self.try_execute_compiled(simple_assertion, 0) 36 | 37 | 38 | if __name__ == '__main__': 39 | tf.test.main() 40 | -------------------------------------------------------------------------------- /reference_tests/two_level_call_tree_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Call to a second user function. 16 | 17 | The second function will be converted as well. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import reference_test_base 25 | import tensorflow.compat.v1 as tf 26 | 27 | 28 | def f1(x): 29 | return x + 1 30 | 31 | 32 | def f2(x): 33 | return 2 * f1(x) 34 | 35 | 36 | class ReferenceTest(reference_test_base.TestCase): 37 | 38 | def test_basic(self): 39 | self.assertNativeMatchesCompiled(f1, 1) 40 | self.assertNativeMatchesCompiled(f2, 1) 41 | 42 | 43 | if __name__ == '__main__': 44 | tf.test.main() 45 | -------------------------------------------------------------------------------- /reference_tests/call_to_lambda_function_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | """Simple call to lambda functions."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import reference_test_base
22 | import tensorflow.compat.v1 as tf
23 |
24 |
25 | def inline_lambda(x):
26 |   l = lambda x: x * x if x > 0 else -x
27 |   return l(x)
28 |
29 |
30 | def external_lambda(x, l):
31 |   return l(x)
32 |
33 |
34 | class ReferenceTest(reference_test_base.TestCase):
35 |
36 |   def test_inline(self):
37 |     self.assertNativeMatchesCompiled(inline_lambda, 1)
38 |
39 |   def test_external(self):
40 |     self.assertNativeMatchesCompiled(external_lambda, 1, lambda x: x == 0)
41 |
42 | if __name__ == '__main__':
43 |   tf.test.main()
44 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to contribute
2 |
3 | We'd love to have your patches and contributions!
4 |
5 | ## Contributor License Agreement
6 |
7 | Contributions to this project must be accompanied by a Contributor License
8 | Agreement. You (or your employer) retain the copyright to your contribution;
9 | this simply gives us permission to use and redistribute your contributions as
10 | part of the project. Head over to https://cla.developers.google.com/ to see
11 | your current agreements on file or to sign a new one.
12 |
13 | You generally only need to submit a CLA once, so if you've already submitted one
14 | (even if it was for a different project), you probably don't need to do it
15 | again.
16 |
17 | ## Code reviews
18 |
19 | All submissions, including submissions by project members, require review. We
20 | use GitHub pull requests for this purpose. Consult [GitHub
21 | Help](https://help.github.com/articles/about-pull-requests/) for more
22 | information on using pull requests.
23 |
24 | After a pull request is approved, we merge it. Note our merging process differs
25 | from GitHub in that we pull and submit the change into an internal version
26 | control system. This system automatically pushes a git commit to the GitHub
27 | repository (with credit to the original author) and closes the pull request.
28 |
29 | ## Unit tests
30 |
31 | Please include unit tests when contributing new features, as they help to a)
32 | prove that your code works correctly, and b) guard against future breaking
33 | changes to lower the maintenance cost. It's also helpful to check that any
34 | changes you propose do not break existing unit tests. You can run tests using
35 | the command,
--------------------------------------------------------------------------------
/reference_tests/loop_with_else_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | """Loops with the exotic else construct.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import reference_test_base 22 | import tensorflow.compat.v2 as tf 23 | 24 | 25 | tf.enable_v2_behavior() 26 | 27 | 28 | def for_else(l1, l2): 29 | s = 0 30 | for c in l1: 31 | if c in l2: 32 | break 33 | s = s * 10 + c 34 | else: 35 | s = -1000 36 | return s 37 | 38 | 39 | def while_else(x, y): 40 | s = 0 41 | while x > 0: 42 | x -= 1 43 | if x > y: 44 | break 45 | s += x 46 | else: 47 | s = -100 48 | return s 49 | 50 | 51 | class LoopControlFlowTest(reference_test_base.TestCase): 52 | 53 | def test_for_else(self): 54 | with self.assertRaisesRegex(NotImplementedError, 'for/else'): 55 | tf.function(for_else)([], []) 56 | 57 | def test_while_else(self): 58 | with self.assertRaisesRegex(NotImplementedError, 'while/else'): 59 | tf.function(while_else)(0, 0) 60 | 61 | 62 | if __name__ == '__main__': 63 | tf.test.main() 64 | -------------------------------------------------------------------------------- /reference_tests/generator_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Generators.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import reference_test_base 22 | import tensorflow.compat.v2 as tf 23 | 24 | 25 | tf.enable_v2_behavior() 26 | 27 | 28 | def basic_generator(): 29 | yield 1 30 | 31 | 32 | def generator_in_for(n): 33 | for i in range(n): 34 | yield i 35 | 36 | 37 | def generator_in_while(n): 38 | i = 0 39 | while i < n: 40 | i += 1 41 | yield i 42 | 43 | 44 | class LoopControlFlowTest(reference_test_base.TestCase): 45 | 46 | def test_basic_generator(self): 47 | with self.assertRaisesRegex(NotImplementedError, 'generators'): 48 | tf.function(basic_generator)() 49 | 50 | def test_generator_in_for(self): 51 | with self.assertRaisesRegex(NotImplementedError, 'generators'): 52 | tf.function(generator_in_for)([]) 53 | 54 | def test_generator_in_while(self): 55 | with self.assertRaisesRegex(NotImplementedError, 'generators'): 56 | tf.function(generator_in_while)(0) 57 | 58 | 59 | if __name__ == '__main__': 60 | tf.test.main() 61 | -------------------------------------------------------------------------------- /reference_tests/basic_ifexp_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Basic if conditional.
16 |
17 | The conditional is converted to tf.cond.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import reference_test_base
25 | import tensorflow.compat.v1 as tf
26 |
27 |
28 | def consecutive_conds(x):
29 |   if x > 0:
30 |     x = -x if x < 5 else x
31 |   else:
32 |     x = -2 * x if x < 5 else 1
33 |   if x > 0:
34 |     x = -x if x < 5 else x
35 |   else:
36 |     x = (3 * x) if x < 5 else x
37 |   return x
38 |
39 |
40 | def cond_with_multiple_values(x):
41 |   if x > 0:
42 |     x = -x if x < 5 else x
43 |     y = 2 * x if x < 5 else x
44 |     z = -y if y < 5 else y
45 |   else:
46 |     x = 2 * x if x < 5 else x
47 |     y = -x if x < 5 else x
48 |     z = -y if y < 5 else y
49 |   return x, y, z
50 |
51 |
52 | class ReferenceTest(reference_test_base.TestCase):
53 |
54 |   def test_basic(self):
55 |     for x in [-1, 1, 5]:
56 |       self.assertNativeMatchesCompiled(consecutive_conds, x)
57 |       self.assertNativeMatchesCompiled(cond_with_multiple_values, x)
58 |
59 |
60 | if __name__ == '__main__':
61 |   tf.test.main()
62 |
--------------------------------------------------------------------------------
/reference_tests/call_to_named_tuple_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ============================================================================== 15 | """Simple call to construct a namedtuple.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | 23 | import reference_test_base 24 | import tensorflow.compat.v1 as tf 25 | 26 | 27 | def inline_namedtuple(x): 28 | nt = collections.namedtuple('TestNamedTuple', ('a', 'b')) 29 | n = nt(a=1, b=x) 30 | return n 31 | 32 | 33 | def external_namedtuple(x, nt): 34 | return nt(a=1, b=x) 35 | 36 | 37 | class NamedTupleSubclass(collections.namedtuple('TestNamedTuple', ('a',))): 38 | 39 | def foo(self): 40 | return self.a + 1 41 | 42 | 43 | def namedtuple_subclass(x): 44 | nt = NamedTupleSubclass(x) 45 | return nt.foo() 46 | 47 | 48 | class ReferenceTest(reference_test_base.TestCase): 49 | 50 | def test_inline(self): 51 | self.assertNativeMatchesCompiled(inline_namedtuple, 1) 52 | 53 | def test_external(self): 54 | nt = collections.namedtuple('TestNamedTuple', ('a', 'b')) 55 | self.assertNativeMatchesCompiled(external_namedtuple, 1, nt) 56 | 57 | def test_subclass(self): 58 | self.assertNativeMatchesCompiled(namedtuple_subclass, 1) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /reference_tests/dynamic_call_tree_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Calls to dynamic functions. 
16 | 17 | Dynamic functions include: 18 | * function variables 19 | * function parameters 20 | * factories 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import reference_test_base 28 | import tensorflow.compat.v1 as tf 29 | 30 | 31 | def function_1(x): 32 | return x * x * x 33 | 34 | 35 | def function_2(x): 36 | return -1 * x + 11 37 | 38 | 39 | def factory(): 40 | return function_1 41 | 42 | 43 | def factory_dynamic_fn(x): 44 | f = factory() 45 | return f(x) 46 | 47 | 48 | def param_dynamic_fn(f, x): 49 | return f(x) 50 | 51 | 52 | def variable_dynamic_fn(x): 53 | f = function_1 54 | return f(x) 55 | 56 | 57 | def variable_dynamic_whitelisted_fn(x): 58 | f = tf.identity 59 | return f(x) 60 | 61 | 62 | def dynamic_fn_with_kwargs(f, x): 63 | return f(x=x) 64 | 65 | 66 | class ReferenceTest(reference_test_base.TestCase): 67 | 68 | def test_basic(self): 69 | self.assertNativeMatchesCompiled(factory_dynamic_fn, 1) 70 | self.assertNativeMatchesCompiled(param_dynamic_fn, function_1, 1) 71 | self.assertNativeMatchesCompiled(variable_dynamic_fn, 1) 72 | self.assertTfMatchesCompiled(variable_dynamic_whitelisted_fn, 1) 73 | self.assertTfMatchesCompiled(dynamic_fn_with_kwargs, function_1, 1) 74 | 75 | 76 | if __name__ == '__main__': 77 | tf.test.main() 78 | -------------------------------------------------------------------------------- /reference_tests/while_loop_function_call_mix_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """While loops mixed with function calls.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import reference_test_base 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | def basic_fn(x): 26 | return x * 2 27 | 28 | 29 | def function_call_inside_cond(n): 30 | i = 0 31 | s = 0 32 | while i < basic_fn(n): 33 | s += i 34 | i += 1 35 | return s 36 | 37 | 38 | def function_call_inside_body(n): 39 | i = 0 40 | s = 0 41 | while i < n: 42 | s += basic_fn(i) 43 | i += 1 44 | return s 45 | 46 | 47 | def print_inside_body(n): 48 | i = 0 49 | s = 0 50 | while i < n: 51 | s += i 52 | print(s) 53 | i += 1 54 | return s 55 | 56 | 57 | class ReferenceTest(reference_test_base.TestCase): 58 | """Base class for the reference tests.""" 59 | 60 | def setUp(self): 61 | super(ReferenceTest, self).setUp() 62 | self.convert = reference_test_base.tf_function_custom( 63 | tf.autograph.experimental.Feature.all_but( 64 | tf.autograph.experimental.Feature.AUTO_CONTROL_DEPS)) 65 | 66 | def test_basic(self): 67 | self.assertNativeMatchesCompiled(function_call_inside_cond, 3) 68 | self.assertNativeMatchesCompiled(function_call_inside_body, 3) 69 | self.assertNativeMatchesCompiled(print_inside_body, 3) 70 | 71 | 72 | if __name__ == '__main__': 73 | tf.test.main() 74 | -------------------------------------------------------------------------------- /examples/sysml2019/benchmark_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Common benchmarking code. 16 | 17 | See https://www.tensorflow.org/community/benchmarks for usage. 18 | To run all benchmarks, use "--benchmarks=.". 19 | Control the output directory using the "TEST_REPORT_FILE_PREFIX" environment 20 | variable. 
21 | 22 | For the benchmarks in this directory, we used: 23 | 24 | TEST_REPORT_FILE_PREFIX=/tmp/autograph/sysml2019_benchmarks/ 25 | 26 | """ 27 | 28 | from __future__ import absolute_import 29 | from __future__ import division 30 | from __future__ import print_function 31 | 32 | import os 33 | import time 34 | 35 | import tensorflow as tf 36 | 37 | 38 | class ReportingBenchmark(tf.test.Benchmark): 39 | """Base class for a benchmark that reports general performance metrics.""" 40 | 41 | def time_execution(self, 42 | name, 43 | target, 44 | iters=None, 45 | warm_up_iters=3, 46 | iter_volume=None, 47 | iter_unit=None, 48 | extras=None): 49 | if iters is None: 50 | iters = int(os.environ.get('BENCHMARK_NUM_EXECUTIONS', 50)) 51 | 52 | for _ in range(warm_up_iters): 53 | target() 54 | 55 | all_times = [] 56 | for _ in range(iters): 57 | iter_time = time.time() 58 | target() 59 | all_times.append(time.time() - iter_time) 60 | 61 | extras = dict(extras) if extras else {} 62 | 63 | extras['all_times'] = all_times 64 | 65 | extras['name'] = name 66 | if isinstance(name, tuple): 67 | name = '_'.join(str(piece) for piece in name) 68 | 69 | # TODO(mdanatg): This is unnecessary - use normal extras. 70 | if iter_volume is not None: 71 | assert iter_unit is not None 72 | extras['iter_volume'] = iter_volume 73 | extras['iter_unit'] = iter_unit 74 | 75 | self.report_benchmark( 76 | iters=iters, wall_time=sum(all_times), name=name, extras=extras) 77 | 78 | 79 | if __name__ == '__main__': 80 | tf.test.main() 81 | -------------------------------------------------------------------------------- /reference_tests/call_to_print_function_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Simple call to a print function preceding other computations. 16 | 17 | The call may be wrapped inside a py_func, but tf.Print should be used if 18 | possible. The subsequent computations will be gated by the print function 19 | execution. 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import reference_test_base 27 | import numpy as np 28 | import tensorflow.compat.v1 as tf 29 | 30 | 31 | # TODO(mdan): Allow functions that do not have a return value. 
32 | # (either that or raise an error if a return value is missing) 33 | 34 | 35 | def lone_print(x): 36 | print(x) 37 | return x + 1 38 | 39 | 40 | def print_multiple_values(x): 41 | print('x is', x) 42 | return x + 1 43 | 44 | 45 | def multiple_prints(x, y): 46 | print('x is', x) 47 | print('y is', y) 48 | return x + 1 49 | 50 | 51 | def print_with_nontf_values(x): 52 | print('x is', x, {'foo': 'bar'}) 53 | return x + 1 54 | 55 | 56 | def print_in_cond(x): 57 | if x == 0: 58 | print(x) 59 | return x 60 | 61 | 62 | def tf_print(x): 63 | tf.print(x) 64 | return x + 1 65 | 66 | 67 | class ReferenceTest(reference_test_base.TestCase): 68 | 69 | def test_lone_print(self): 70 | self.assertNativeMatchesCompiled(lone_print, 1) 71 | self.assertNativeMatchesCompiled(lone_print, np.array([1, 2, 3])) 72 | 73 | def test_print_multiple_values(self): 74 | self.assertNativeMatchesCompiled(print_multiple_values, 1) 75 | self.assertNativeMatchesCompiled(print_multiple_values, np.array([1, 2, 3])) 76 | 77 | def test_multiple_prints(self): 78 | self.assertNativeMatchesCompiled(multiple_prints, 1, 2) 79 | self.assertNativeMatchesCompiled(multiple_prints, np.array([1, 2, 3]), 4) 80 | 81 | def test_print_with_nontf_values(self): 82 | self.assertNativeMatchesCompiled(print_with_nontf_values, 1) 83 | self.assertNativeMatchesCompiled(print_with_nontf_values, 84 | np.array([1, 2, 3])) 85 | 86 | def test_print_in_cond(self): 87 | self.assertNativeMatchesCompiled(print_in_cond, 0) 88 | self.assertNativeMatchesCompiled(print_in_cond, 1) 89 | 90 | def test_tf_print(self): 91 | self.assertTfMatchesCompiled(tf_print, 0) 92 | 93 | 94 | if __name__ == '__main__': 95 | tf.test.main() 96 | -------------------------------------------------------------------------------- /reference_tests/call_to_builtin_function_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Simple call to a builtin function.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import parameterized 22 | import reference_test_base 23 | import mock 24 | import six 25 | import tensorflow.compat.v1 as tf 26 | 27 | 28 | # TODO(mdan): Add tests for all builtins. 
29 | 30 | 31 | def xrange_call(x): 32 | return list(six.moves.xrange(x)) 33 | 34 | 35 | def dict_call(x): 36 | return dict(foo=x) 37 | 38 | 39 | def dict_call_aliased(x): 40 | def fake_dict(x): 41 | return x 42 | 43 | dict = fake_dict # pylint:disable=redefined-builtin 44 | return dict(x) 45 | 46 | 47 | def dict_call_dynamic(x): 48 | def gen_dict(): 49 | return dict 50 | 51 | d = gen_dict() 52 | return d(foo=x) 53 | 54 | 55 | def len_call(x): 56 | return len(x) 57 | 58 | 59 | def nested_call(x): 60 | return list(range(len(x))) 61 | 62 | 63 | def len_call_aliased(x): 64 | 65 | def fake_len(x): 66 | return x 67 | 68 | len = fake_len # pylint:disable=redefined-builtin 69 | return len(x) 70 | 71 | 72 | def len_call_dynamic(x): 73 | 74 | def gen_len(): 75 | return len 76 | 77 | l = gen_len() 78 | return l(x) 79 | 80 | 81 | def len_call_on_mock(): 82 | x = mock.MagicMock() 83 | return len(x) 84 | 85 | 86 | class ReferenceTest(reference_test_base.TestCase, parameterized.TestCase): 87 | 88 | @parameterized.named_parameters([ 89 | ('to_graph', reference_test_base.to_graph), 90 | ('to_graph_nonrecursive', reference_test_base.to_graph_nonrecursive), 91 | ]) 92 | def test_basic(self, conversion_func): 93 | self.convert = conversion_func 94 | self.assertNativeMatchesCompiled(dict_call, 1) 95 | self.assertNativeMatchesCompiled(len_call, [1, 2]) 96 | self.assertNativeMatchesCompiled(dict_call_aliased, 1) 97 | self.assertNativeMatchesCompiled(len_call_aliased, [1, 2]) 98 | self.assertNativeMatchesCompiled(dict_call_dynamic, 1) 99 | self.assertNativeMatchesCompiled(len_call_dynamic, [1, 2]) 100 | self.assertNativeMatchesCompiled(nested_call, [1, 2, 3]) 101 | self.assertNativeMatchesCompiled(nested_call, [1, 2, 3]) 102 | if six.PY2: 103 | self.assertNativeMatchesCompiled(xrange_call, 3) 104 | 105 | 106 | if __name__ == '__main__': 107 | tf.test.main() 108 | -------------------------------------------------------------------------------- /reference_tests/logical_expression_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Basic logical expressions that are not autoboxed to TF.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import reference_test_base 22 | import tensorflow.compat.v1 as tf 23 | 24 | 25 | def composite_ors_with_callable(x, y, z): 26 | z1 = lambda: z 27 | return x or y or z1() 28 | 29 | 30 | def composite_ors(x, y, z): 31 | return x or y or z 32 | 33 | 34 | def composite_ands(x, y, z): 35 | return x and y and z 36 | 37 | 38 | def composite_mixed(x, y, z): 39 | return x or y or z and y and z 40 | 41 | 42 | def equality(x, y): 43 | return x == y 44 | 45 | 46 | def inequality(x, y): 47 | return x != y 48 | 49 | 50 | def multiple_equality(x, y, z): 51 | return x == y == z 52 | 53 | 54 | def comparison(x, y, z): 55 | return x < y and y < z 56 | 57 | 58 | class ReferenceTest(reference_test_base.TestCase): 59 | 60 | def test_basic(self): 61 | self.assertNativeMatchesCompiled(composite_ors, False, True, False) 62 | self.assertNativeMatchesCompiled(composite_ors, False, False, False) 63 | self.assertNativeMatchesCompiled(composite_ands, True, True, True) 64 | self.assertNativeMatchesCompiled(composite_ands, True, False, True) 65 | self.assertNativeMatchesCompiled(composite_mixed, False, True, True) 66 | self.assertNativeMatchesCompiled(composite_ors_with_callable, False, True, 67 | False) 68 | self.assertNativeMatchesCompiled(composite_ors_with_callable, False, False, 69 | True) 70 | self.assertNativeMatchesCompiled(composite_ors_with_callable, False, False, 71 | False) 72 | 73 | self.assertNativeMatchesCompiled(equality, 1, 1) 74 | self.assertNativeMatchesCompiled(equality, 1, 2) 75 | self.assertNativeMatchesCompiled(inequality, 1, 1) 76 | self.assertNativeMatchesCompiled(inequality, 1, 2) 77 | self.assertNativeMatchesCompiled(multiple_equality, 1, 1, 2) 78 | self.assertNativeMatchesCompiled(multiple_equality, 1, 1, 1) 79 | 80 | self.assertNativeMatchesCompiled(comparison, 1, 2, 3) 81 | self.assertNativeMatchesCompiled(comparison, 2, 1, 3) 82 | self.assertNativeMatchesCompiled(comparison, 3, 2, 1) 83 | self.assertNativeMatchesCompiled(comparison, 3, 1, 2) 84 | self.assertNativeMatchesCompiled(comparison, 1, 3, 2) 85 | self.assertNativeMatchesCompiled(comparison, 2, 3, 1) 86 | 87 | 88 | if __name__ == '__main__': 89 | tf.test.main() 90 | -------------------------------------------------------------------------------- /reference_tests/composite_names_in_control_flow_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Composite names (attributes) in control flow. 16 | 17 | Generally, composite symbols should be treated like regular ones. 
18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import reference_test_base 25 | import tensorflow.compat.v1 as tf 26 | 27 | 28 | def cond_basic(x): 29 | if x.a > 0: 30 | x.b = 1 31 | else: 32 | x.b = -1 33 | return x 34 | 35 | 36 | def while_basic(x, y): 37 | while x > 0: 38 | x -= 1 39 | y.a += 1 40 | return y 41 | 42 | 43 | def while_state_only(y): 44 | while y.b <= 10: 45 | y.a += 1 46 | y.b *= 2 47 | return y 48 | 49 | 50 | def for_basic(n, x, y): 51 | for i in range(n): 52 | x -= 1 53 | y.a += i 54 | return y 55 | 56 | 57 | def for_state_only(n, y): 58 | for _ in range(n): 59 | y.a += 1 60 | return y 61 | 62 | 63 | # TODO(mdan): More tests needed. Many pitfalls around mutating objects this way. 64 | 65 | 66 | class ReferenceTest(reference_test_base.TestCase): 67 | 68 | def test_cond_basic(self): 69 | self.assertNativeMatchesCompiled( 70 | cond_basic, 71 | reference_test_base.MutableContainer(a=1, b=0), 72 | ) 73 | self.assertNativeMatchesCompiled( 74 | cond_basic, 75 | reference_test_base.MutableContainer(a=0, b=0), 76 | ) 77 | 78 | def test_while_basic(self): 79 | self.assertNativeMatchesCompiled( 80 | while_basic, 81 | 3, 82 | reference_test_base.MutableContainer(a=3, b=0), 83 | ) 84 | self.assertNativeMatchesCompiled( 85 | while_basic, 86 | 0, 87 | reference_test_base.MutableContainer(a=7, b=0), 88 | ) 89 | 90 | def test_while_state_only(self): 91 | self.assertNativeMatchesCompiled( 92 | while_state_only, 93 | reference_test_base.MutableContainer(a=3, b=1), 94 | ) 95 | self.assertNativeMatchesCompiled( 96 | while_state_only, 97 | reference_test_base.MutableContainer(a=7, b=10), 98 | ) 99 | 100 | def test_for_basic(self): 101 | self.assertNativeMatchesCompiled( 102 | for_basic, 103 | 5, 104 | 3, 105 | reference_test_base.MutableContainer(a=3, b=0), 106 | ) 107 | self.assertNativeMatchesCompiled( 108 | for_basic, 109 | 5, 110 | 0, 111 | reference_test_base.MutableContainer(a=7, b=0), 112 | ) 113 | 114 | def test_for_state_only(self): 115 | self.assertNativeMatchesCompiled( 116 | for_state_only, 117 | 5, 118 | reference_test_base.MutableContainer(a=3, b=0), 119 | ) 120 | self.assertNativeMatchesCompiled( 121 | for_state_only, 122 | 0, 123 | reference_test_base.MutableContainer(a=7, b=0), 124 | ) 125 | 126 | 127 | if __name__ == '__main__': 128 | tf.test.main() 129 | -------------------------------------------------------------------------------- /reference_tests/loop_control_flow_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Nested loops and loop control statements (e.g. break and continue). 
16 | 17 | Meant to verify that: 18 | * break/continue in the inner loop does not affect outer loop 19 | * break/continue inside nested conditionals still works 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import reference_test_base 27 | import tensorflow.compat.v1 as tf 28 | 29 | 30 | def continue_in_single_for(l): 31 | s = 0 32 | for c in l: 33 | if c % 2 > 0: 34 | continue 35 | s += c 36 | return s 37 | 38 | 39 | def continue_in_single_while(x): 40 | s = 0 41 | while x > 0: 42 | x -= 1 43 | if x % 2 > 0: 44 | continue 45 | s += x 46 | return s 47 | 48 | 49 | def continue_in_inner_for(m): 50 | s = 0 51 | for l in m: 52 | for c in l: 53 | if c % 2 > 0: 54 | continue 55 | s += c 56 | return s 57 | 58 | 59 | def continue_in_inner_while(x, y): 60 | s = 0 61 | while x > 0: 62 | x -= 1 63 | while y > 0: 64 | y -= 1 65 | if (x + y) % 2 > 0: 66 | continue 67 | s += x + y 68 | return s 69 | 70 | 71 | def break_in_single_for(l): 72 | s = 0 73 | for c in l: 74 | if c % 2 > 0: 75 | break 76 | s += c 77 | return s 78 | 79 | 80 | def break_in_single_while(x): 81 | s = 0 82 | while x > 0: 83 | x -= 1 84 | if x % 2 > 0: 85 | break 86 | s += x 87 | return s 88 | 89 | 90 | def break_in_inner_for(m): 91 | s = 0 92 | for l in m: 93 | for c in l: 94 | if c % 2 > 0: 95 | break 96 | s += c 97 | return s 98 | 99 | 100 | def break_in_inner_while(x, y): 101 | s = 0 102 | while x > 0: 103 | x -= 1 104 | while y > 0: 105 | y -= 1 106 | if ((x + y) % 2) == 0: 107 | break 108 | s += x + y 109 | return s 110 | 111 | 112 | def break_continue_in_inner_for(m): 113 | s = 0 114 | for l in m: 115 | for c in l: 116 | if c % 2 > 0: 117 | break 118 | else: 119 | continue 120 | s += c 121 | return s 122 | 123 | 124 | def break_continue_in_inner_while(x, y): 125 | s = 0 126 | while x > 0: 127 | x -= 1 128 | while y > 0: 129 | y -= 1 130 | if (x + y) % 2 > 0: 131 | break 132 | else: 133 | continue 134 | s += x + y 135 | return s 136 | 137 | 138 | def break_followed_by_cond_in_single_for(x, y): 139 | for i in range(y): 140 | if i == 2: 141 | break 142 | if x > 0: 143 | x -= 1 144 | return x 145 | 146 | 147 | def break_followed_by_cond_in_single_while(x): 148 | while x > 0: 149 | if x == 2: 150 | break 151 | if x > 0: 152 | x -= 1 153 | return x 154 | 155 | 156 | def multiple_breaks_in_single_while(n): 157 | s = 1 158 | i = 0 159 | while i < n: 160 | i += 1 161 | if i > 10 * n: 162 | break 163 | if i == n: 164 | break 165 | s = s * 10 + i 166 | return i, s 167 | 168 | 169 | class LoopControlFlowTest(reference_test_base.TestCase): 170 | 171 | def test_continue_in_single_for(self): 172 | self.assertNativeMatchesCompiled(continue_in_single_for, 173 | [1, 2, 3, 4, 5, 6]) 174 | 175 | def test_continue_in_single_while(self): 176 | self.assertNativeMatchesCompiled(continue_in_single_while, 7) 177 | 178 | def test_continue_in_inner_for(self): 179 | self.assertNativeMatchesCompiled(continue_in_inner_for, 180 | [[1, 2, 3], [4, 5, 6]]) 181 | 182 | def test_continue_in_inner_while(self): 183 | self.assertNativeMatchesCompiled(continue_in_inner_while, 10, 11) 184 | 185 | def test_break_in_single_for(self): 186 | self.assertNativeMatchesCompiled(break_in_single_for, [1, 2, 3, 4, 5, 6]) 187 | 188 | def test_break_in_single_while(self): 189 | self.assertNativeMatchesCompiled(break_in_single_while, 7) 190 | 191 | def test_break_in_inner_for(self): 192 | self.assertNativeMatchesCompiled(break_in_inner_for, 193 | [[1, 2, 3], [4, 5, 6]]) 194 | 195 | 
def test_break_in_inner_while(self): 196 | self.assertNativeMatchesCompiled(break_in_inner_while, 10, 11) 197 | 198 | def test_break_continue_in_inner_for(self): 199 | self.assertNativeMatchesCompiled(break_continue_in_inner_for, 200 | [[1, 2, 3], [4, 5, 6]]) 201 | 202 | def test_break_continue_in_inner_while(self): 203 | self.assertNativeMatchesCompiled(break_continue_in_inner_while, 10, 11) 204 | 205 | def test_break_followed_by_cond_in_single_for(self): 206 | self.assertNativeMatchesCompiled(break_followed_by_cond_in_single_for, 3, 3) 207 | 208 | def test_break_followed_by_cond_in_single_while(self): 209 | self.assertNativeMatchesCompiled(break_followed_by_cond_in_single_while, 3) 210 | 211 | def test_multiple_breaks_in_single_while(self): 212 | self.assertNativeMatchesCompiled(multiple_breaks_in_single_while, 0) 213 | self.assertNativeMatchesCompiled(multiple_breaks_in_single_while, 2) 214 | self.assertNativeMatchesCompiled(multiple_breaks_in_single_while, 5) 215 | 216 | 217 | if __name__ == '__main__': 218 | tf.test.main() 219 | -------------------------------------------------------------------------------- /reference_tests/nested_control_flow_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Nested loops and conditional statements (e.g. while, for, if). 16 | 17 | Meant to verify that arbitrarily nested statements are processed correctly. 
18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import reference_test_base 25 | import tensorflow.compat.v1 as tf 26 | 27 | 28 | def independent_ifs(x, y): 29 | z = 0 30 | if x > 0: 31 | if y > 0: 32 | z = x + y 33 | return z 34 | 35 | 36 | def dependent_inner_if(x): 37 | y = 0 38 | if x > 0: 39 | y = -2 * x 40 | if y > 0: 41 | x = -3 * x 42 | else: 43 | y = 4 * x 44 | return x, y 45 | 46 | 47 | def dependent_imbalanced_inner_if(x): 48 | y = 0 49 | if x > 0: 50 | if x < 3: 51 | y = -2 * x 52 | x = -3 * x 53 | return x, y 54 | 55 | 56 | def independent_inner_for(a, b): 57 | p = 0 58 | for _ in a: 59 | tmp = b 60 | for j in tmp: 61 | p += j 62 | return p 63 | 64 | 65 | def independent_inner_while(a, b): 66 | p = 0 67 | while a > 0: 68 | tmp = b 69 | while tmp > 0: 70 | p += 1 71 | tmp -= 1 72 | a -= 1 73 | return p 74 | 75 | 76 | def dependent_inner_for(a, b): 77 | r = 1 78 | s = 0 79 | for _ in a: 80 | r += s 81 | tmp = b 82 | for j in tmp: 83 | s += j 84 | return r 85 | 86 | 87 | def dependent_inner_while(a, b): 88 | r = 1 89 | while a > 0: 90 | r += 1 91 | tmp = b 92 | while tmp > 0: 93 | a -= 1 94 | tmp -= 1 95 | return r 96 | 97 | 98 | def if_in_for(a): 99 | k = 0 100 | for i in a: 101 | if i % 2 > 0: 102 | j = i // 2 103 | k += j 104 | return k 105 | 106 | 107 | def while_with_continue_in_context_manager(x): 108 | z = 0 109 | while x > 0: 110 | with tf.name_scope(''): 111 | x = x - 1 112 | if x < 5: 113 | continue 114 | z = z + 1 115 | return z 116 | 117 | 118 | def while_continue_in_try(x): 119 | z = 0 120 | while x > 0: 121 | x = x - 1 122 | try: 123 | if x < 5: 124 | continue 125 | z = z + 1 126 | finally: 127 | z = z + 10 128 | return z 129 | 130 | 131 | def while_break_in_context_manager(x): 132 | z = 0 133 | while x > 0: 134 | with tf.name_scope(''): 135 | x = x - 1 136 | if x < 5: 137 | break 138 | z = z + 1 139 | return z 140 | 141 | 142 | def while_break_in_try(x): 143 | z = 0 144 | while x > 0: 145 | x = x - 1 146 | try: 147 | if x < 5: 148 | break 149 | z = z + 1 150 | finally: 151 | z = z + 10 152 | return z 153 | 154 | 155 | class NestedControlFlowTest(reference_test_base.TestCase): 156 | 157 | def test_independent_ifs(self): 158 | self.assertNativeMatchesCompiled(independent_ifs, 1, 1) 159 | self.assertNativeMatchesCompiled(independent_ifs, 1, -1) 160 | self.assertNativeMatchesCompiled(independent_ifs, -1, 1) 161 | self.assertNativeMatchesCompiled(independent_ifs, -1, 1) 162 | 163 | def test_dependent_inner_if(self): 164 | self.assertNativeMatchesCompiled(dependent_inner_if, 1) 165 | self.assertNativeMatchesCompiled(dependent_inner_if, -1) 166 | 167 | def test_dependent_imbalanced_inner_if(self): 168 | self.assertNativeMatchesCompiled(dependent_imbalanced_inner_if, 1) 169 | self.assertNativeMatchesCompiled(dependent_imbalanced_inner_if, -1) 170 | 171 | def test_independent_inner_for(self): 172 | self.assertNativeMatchesCompiled( 173 | independent_inner_for, list(range(3)), list(range(5))) 174 | 175 | def test_independent_inner_while(self): 176 | self.assertNativeMatchesCompiled(independent_inner_while, 3, 5) 177 | 178 | def test_dependent_inner_for(self): 179 | self.assertNativeMatchesCompiled( 180 | dependent_inner_for, list(range(31)), list(range(7))) 181 | 182 | def test_dependent_inner_while(self): 183 | self.assertNativeMatchesCompiled(dependent_inner_while, 31, 7) 184 | 185 | def test_if_in_for(self): 186 | self.assertNativeMatchesCompiled(if_in_for, list(range(7))) 
187 |
188 |   def test_while_continue_in_context_manager(self):
189 |     self.assertNativeMatchesCompiled(while_with_continue_in_context_manager, 10)
190 |     self.assertNativeMatchesCompiled(while_with_continue_in_context_manager, 4)
191 |     self.assertNativeMatchesCompiled(while_with_continue_in_context_manager, 0)
192 |
193 |   def test_while_continue_in_try(self):
194 |     self.assertNativeMatchesCompiled(while_continue_in_try, 10)
195 |     self.assertNativeMatchesCompiled(while_continue_in_try, 4)
196 |     self.assertNativeMatchesCompiled(while_continue_in_try, 0)
197 |
198 |   def test_while_break_in_context_manager(self):
199 |     self.assertNativeMatchesCompiled(while_break_in_context_manager, 10)
200 |     self.assertNativeMatchesCompiled(while_break_in_context_manager, 4)
201 |     self.assertNativeMatchesCompiled(while_break_in_context_manager, 0)
202 |
203 |   def test_while_break_in_try(self):
204 |     self.assertNativeMatchesCompiled(while_break_in_try, 10)
205 |     self.assertNativeMatchesCompiled(while_break_in_try, 4)
206 |     self.assertNativeMatchesCompiled(while_break_in_try, 0)
207 |
208 | if __name__ == '__main__':
209 |   tf.test.main()
210 |
--------------------------------------------------------------------------------
/reference_tests/distributed_dataset_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Tests involving distributed datasets from tf.distribute."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import reference_test_base
22 | import tensorflow.compat.v2 as tf
23 |
24 |
25 | tf.enable_v2_behavior()
26 |
27 |
28 | def dataset_no_vars_loop(ds, dds):
29 |   for pr in dds:
30 |     tf.print(ds.reduce('SUM', pr, axis=None))
31 |
32 |
33 | def iterator_no_vars_loop(ds, dds):
34 |   for pr in iter(dds):
35 |     tf.print(ds.reduce('SUM', pr, axis=None))
36 |
37 |
38 | def dataset_single_var_loop(ds, dds):
39 |   s = 0
40 |   for pr in dds:
41 |     # TODO(mdan): It would be nice to be able to write s = s * 10 + pr.
42 |     s = s * 10 + ds.reduce('SUM', pr, axis=None)
43 |     # TODO(mdan): This looks like a bug.
44 | s.set_shape(()) 45 | return s 46 | 47 | 48 | def iterator_single_var_loop(ds, dds): 49 | s = 0 50 | for pr in iter(dds): 51 | s = s * 10 + ds.reduce('SUM', pr, axis=None) 52 | return s 53 | 54 | 55 | def dataset_two_vars_loop(ds, dds): 56 | s = 0 57 | p = 1 58 | for pr in dds: 59 | e = ds.reduce('SUM', pr, axis=None) 60 | e.set_shape(()) 61 | s += e 62 | p *= e 63 | return s, p 64 | 65 | 66 | def iterator_two_vars_loop(ds, dds): 67 | s = 0 68 | p = 1 69 | for pr in iter(dds): 70 | e = ds.reduce('SUM', pr, axis=None) 71 | e.set_shape(()) 72 | s += e 73 | p *= e 74 | return s, p 75 | 76 | 77 | def dataset_enumeration(ds, dds): 78 | s = 0 79 | p = 1 80 | for i, pr in enumerate(dds): 81 | e = ds.reduce('SUM', pr, axis=None) 82 | e.set_shape(()) 83 | s = s * 10 + e 84 | p *= i 85 | return s, p 86 | 87 | 88 | def iterator_next(ds, dds): 89 | itr = iter(dds) 90 | return ds.reduce('SUM', next(itr), axis=None) 91 | 92 | 93 | def iterator_next_multiple_calls(ds, dds): 94 | itr = iter(dds) 95 | a = ds.reduce('SUM', next(itr), axis=None) 96 | b = ds.reduce('SUM', next(itr), axis=None) 97 | return a * 10 + b 98 | 99 | 100 | def iterator_next_in_limited_loop(ds, dds, n): 101 | itr = iter(dds) 102 | s = 0 103 | for _ in range(n): 104 | s = s * 10 + ds.reduce('SUM', next(itr), axis=None) 105 | return s 106 | 107 | 108 | def iterator_next_stopping(ds, dds, cond): 109 | # This case will raise, but not the expected StopIteration error. 110 | itr = iter(dds) 111 | while cond: 112 | ds.reduce('SUM', next(itr), axis=None) 113 | 114 | 115 | def iterator_next_with_catching_stop_iteration(ds, dds, cond): 116 | # This is the one instance when the use of TF iterators does not work as 117 | # intended. In graph mode, the `except` below will never catch, and the 118 | # tf.function will raise the error instead. 119 | # TODO(b/132311724): The error should be friendlier here. 120 | # Note: b/132298783 covers actually supporting this pattern. 
121 | itr = iter(dds) 122 | try: 123 | while cond: 124 | ds.reduce('SUM', next(itr), axis=None) 125 | except StopIteration: 126 | pass 127 | 128 | 129 | class ReferenceTest(reference_test_base.TestCase): 130 | 131 | def setUp(self): 132 | super(ReferenceTest, self).setUp() 133 | cpus = tf.config.experimental.list_physical_devices('CPU') 134 | tf.config.experimental.set_virtual_device_configuration( 135 | cpus[0], [tf.config.experimental.VirtualDeviceConfiguration()] * 2) 136 | 137 | strategy = tf.distribute.MirroredStrategy() 138 | dataset = tf.data.Dataset.from_tensor_slices( 139 | tf.reshape(tf.range(40), (10, 4))) 140 | 141 | self.ds = strategy 142 | self.dds = strategy.experimental_distribute_dataset(dataset) 143 | 144 | def test_dataset_no_vars_loop(self): 145 | self.assertFunctionMatchesEager(dataset_no_vars_loop, self.ds, self.dds) 146 | 147 | def test_iterator_no_vars_loop(self): 148 | with self.assertRaises(RuntimeError): 149 | tf.function(iterator_no_vars_loop)(self.ds, self.dds) 150 | 151 | def test_dataset_single_var_loop(self): 152 | self.assertFunctionMatchesEager(dataset_single_var_loop, self.ds, self.dds) 153 | 154 | def test_iterator_single_var_loop(self): 155 | with self.assertRaises(RuntimeError): 156 | tf.function(iterator_single_var_loop)(self.ds, self.dds) 157 | 158 | def test_dataset_two_vars_loop(self): 159 | self.assertFunctionMatchesEager(dataset_two_vars_loop, self.ds, self.dds) 160 | 161 | def test_iterator_two_vars_loop(self): 162 | with self.assertRaises(RuntimeError): 163 | tf.function(iterator_two_vars_loop)(self.ds, self.dds) 164 | 165 | def test_iterator_next(self): 166 | self.assertFunctionMatchesEager(iterator_next, self.ds, self.dds) 167 | 168 | def test_iterator_next_multiple_calls(self): 169 | self.assertFunctionMatchesEager(iterator_next_multiple_calls, self.ds, 170 | self.dds) 171 | 172 | def test_iterator_next_in_limited_loop(self): 173 | self.assertFunctionMatchesEager(iterator_next_in_limited_loop, self.ds, 174 | self.dds, 0) 175 | self.assertFunctionMatchesEager(iterator_next_in_limited_loop, self.ds, 176 | self.dds, 1) 177 | self.assertFunctionMatchesEager(iterator_next_in_limited_loop, self.ds, 178 | self.dds, 3) 179 | 180 | def test_iterator_next_stopping(self): 181 | with self.assertRaises(tf.errors.OutOfRangeError): 182 | tf.function(iterator_next_stopping)(self.ds, self.dds, tf.constant(True)) 183 | 184 | def test_iterator_next_with_catching_stop_iteration(self): 185 | with self.assertRaises(tf.errors.OutOfRangeError): 186 | tf.function(iterator_next_with_catching_stop_iteration)(self.ds, self.dds, 187 | tf.constant(True)) 188 | 189 | 190 | if __name__ == '__main__': 191 | tf.test.main() 192 | -------------------------------------------------------------------------------- /examples/sysml2019/maml_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Benchmark for a basic MAML implementation. 16 | 17 | Based on implementation found at https://github.com/cbfinn/maml. 18 | 19 | This code requires TF 2.0 or newer. 20 | Use `pip install tf-nightly-2.0-preview` to install. 21 | Tip: Run the `pip install` command in a separate virtual environment 22 | (Virtualenv, Anaconda) to avoid clobbering an existing TF installation. 23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import benchmark_base 30 | 31 | import numpy as np 32 | import tensorflow as tf 33 | 34 | 35 | AMP_RANGE = (0.1, 5.0) 36 | PHASE_RANGE = (0.0, np.pi) 37 | INPUT_RANGE = (-5.0, 5.0) 38 | INPUT_SIZE = 1 39 | OUTPUT_SIZE = 1 40 | LOCAL_TRAIN_EXAMPLES = 10 41 | LOCAL_VALID_EXAMPLES = 10 42 | META_BATCH_SIZE = 25 43 | 44 | HIDDEN_SIZES = (40, 40) 45 | 46 | LOCAL_LEARNING_RATE = 1e-3 47 | LOCAL_LEARNING_STEPS = 1 48 | LOCAL_LEARNING_STEPS_TEST = 5 49 | TRAIN_SAMPES_SLICE = slice(0, LOCAL_TRAIN_EXAMPLES) 50 | VALID_SAMPES_SLICE = slice(LOCAL_TRAIN_EXAMPLES, 51 | LOCAL_TRAIN_EXAMPLES + LOCAL_VALID_EXAMPLES) 52 | 53 | META_LR = 0.001 54 | 55 | 56 | def sin_dataset(): 57 | """Procedural dataset that generates samples for the sin function.""" 58 | assert INPUT_SIZE == OUTPUT_SIZE 59 | 60 | total_samples = LOCAL_TRAIN_EXAMPLES + LOCAL_VALID_EXAMPLES 61 | 62 | def gen_input_outputs(_): 63 | amp = tf.random.uniform((), AMP_RANGE[0], AMP_RANGE[1]) 64 | phase = tf.random.uniform((), PHASE_RANGE[0], PHASE_RANGE[1]) 65 | init_inputs = tf.random.uniform((total_samples, INPUT_SIZE), INPUT_RANGE[0], 66 | INPUT_RANGE[1]) 67 | outputs = amp * tf.math.sin(init_inputs - phase) 68 | return init_inputs, outputs, amp, phase 69 | 70 | ds = tf.data.Dataset.range(2) 71 | ds = ds.map(gen_input_outputs) 72 | 73 | return ds 74 | 75 | 76 | def w(i): 77 | return 'w{}'.format(i) 78 | 79 | 80 | def b(i): 81 | return 'b{}'.format(i) 82 | 83 | 84 | def model_weights(): 85 | weights = {} 86 | aug_sizes = (INPUT_SIZE,) + HIDDEN_SIZES + (OUTPUT_SIZE,) 87 | for i in range(1, len(aug_sizes)): 88 | weights[w(i)] = tf.Variable( 89 | tf.random.truncated_normal((aug_sizes[i - 1], aug_sizes[i]), 90 | stddev=0.01)) 91 | weights[b(i)] = tf.Variable(tf.zeros((aug_sizes[i],))) 92 | return weights 93 | 94 | 95 | def model(inputs, weights): 96 | x = inputs 97 | for i in range(1, len(HIDDEN_SIZES) + 1): 98 | x = tf.nn.relu(tf.matmul(x, weights[w(i)]) + weights[b(i)]) 99 | i = len(HIDDEN_SIZES) + 1 100 | return tf.matmul(x, weights[w(i)]) + weights[b(i)] 101 | 102 | 103 | def mse(preds, labels): 104 | return tf.reduce_mean(tf.square(preds - labels)) 105 | 106 | 107 | def local_learn(inputs, outputs, weights, num_steps): 108 | """Runs a classical training loop.""" 109 | learned_weights = tf.nest.map_structure(lambda w: w.read_value(), weights) 110 | 111 | for _ in tf.range(num_steps): 112 | # Inference 113 | with tf.GradientTape() as tape: 114 | tf.nest.map_structure(tape.watch, learned_weights) 115 | y_pred = model(inputs, learned_weights) 116 | step_loss = mse(y_pred, outputs) 117 | 118 | # SGD step 119 | grads = tape.gradient(step_loss, learned_weights) 120 | learned_weights = tf.nest.map_structure( 121 | lambda w, g: w - LOCAL_LEARNING_RATE * g, learned_weights, grads) 122 | 123 | return learned_weights 124 | 125 | 126 | def metalearn(weights, opt, meta_steps): 127 | """Runs a MAML learning loop.""" 128 | ds = 
sin_dataset().repeat().batch(META_BATCH_SIZE).take(meta_steps) 129 | 130 | for inputs, outputs, _, _ in ds: 131 | train_inputs = inputs[:, TRAIN_SAMPES_SLICE, :] 132 | valid_inputs = inputs[:, VALID_SAMPES_SLICE, :] 133 | train_outputs = outputs[:, TRAIN_SAMPES_SLICE, :] 134 | valid_outputs = outputs[:, VALID_SAMPES_SLICE, :] 135 | 136 | with tf.GradientTape() as tape: 137 | tf.nest.map_structure(tape.watch, weights) 138 | 139 | # Per-task learning 140 | task_losses = tf.TensorArray(tf.float32, size=META_BATCH_SIZE) 141 | for i in tf.range(META_BATCH_SIZE): 142 | # Train on the training data points 143 | learned_weights = local_learn(train_inputs[i], train_outputs[i], 144 | weights, LOCAL_LEARNING_STEPS) 145 | # Calucalate loss on the validation data points 146 | learned_valid_outputs = model(valid_inputs[i], learned_weights) 147 | # Use the validation error for meta training 148 | task_loss = mse(learned_valid_outputs, valid_outputs[i]) 149 | task_losses = task_losses.write(i, task_loss) 150 | 151 | # Average per-task validation errors. 152 | meta_loss = tf.reduce_mean(task_losses.stack()) 153 | 154 | # Take a single meta-training step. 155 | meta_grads = tape.gradient(meta_loss, weights) 156 | grads_and_vars = zip(tf.nest.flatten(meta_grads), tf.nest.flatten(weights)) 157 | opt.apply_gradients(grads_and_vars) 158 | 159 | 160 | class MAMLBenchmark(benchmark_base.ReportingBenchmark): 161 | """Basic benchmark for the MAML example model.""" 162 | 163 | def _run_benchmark(self, name, metalearn_function, meta_steps): 164 | init_weights = model_weights() 165 | opt = tf.keras.optimizers.Adam() 166 | 167 | def target(): 168 | metalearn_function(init_weights, opt, meta_steps) 169 | 170 | self.time_execution((name, meta_steps), 171 | target, 172 | extras={ 173 | 'meta_steps': meta_steps, 174 | }) 175 | 176 | def benchmark_maml(self): 177 | # TODO(mdanatg): Remove this override. 178 | all_current_features = tf.autograph.experimental.Feature.all_but( 179 | tf.autograph.experimental.Feature.AUTO_CONTROL_DEPS) 180 | 181 | for meta_steps in (1, 10): 182 | self._run_benchmark('Eager', metalearn, meta_steps) 183 | metalearn_autograph = tf.function( 184 | metalearn, experimental_autograph_options=all_current_features) 185 | self._run_benchmark('AutoGraph', metalearn_autograph, meta_steps) 186 | 187 | 188 | if __name__ == '__main__': 189 | tf.test.main() 190 | -------------------------------------------------------------------------------- /reference_tests/datasets_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests involving the tf.data.Datasets API.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import reference_test_base 22 | import tensorflow.compat.v2 as tf 23 | 24 | 25 | tf.enable_v2_behavior() 26 | 27 | 28 | def dataset_no_vars_loop(ds): 29 | for e in ds: 30 | tf.print(e) 31 | 32 | 33 | def iterator_no_vars_loop(ds): 34 | for e in iter(ds): 35 | tf.print(e) 36 | 37 | 38 | def dataset_single_var_loop(ds): 39 | s = tf.constant(0, dtype=tf.int64) 40 | for e in ds: 41 | s = s * 10 + e 42 | return s 43 | 44 | 45 | def iterator_single_var_loop(ds): 46 | s = tf.constant(0, dtype=tf.int64) 47 | for e in iter(ds): 48 | s = s * 10 + e 49 | return s 50 | 51 | 52 | def dataset_two_vars_loop(ds): 53 | s = tf.constant(0, dtype=tf.int64) 54 | p = tf.constant(1, dtype=tf.int64) 55 | for e in ds: 56 | s += e 57 | p *= e 58 | return s, p 59 | 60 | 61 | def iterator_two_vars_loop(ds): 62 | s = tf.constant(0, dtype=tf.int64) 63 | p = tf.constant(1, dtype=tf.int64) 64 | for e in iter(ds): 65 | s += e 66 | p *= e 67 | return s, p 68 | 69 | 70 | def dataset_loop_with_break(ds): 71 | s = tf.constant(0, dtype=tf.int64) 72 | for e in ds: 73 | s = s * 10 + e 74 | if s > 100: 75 | break 76 | return s 77 | 78 | 79 | def iterator_loop_with_break(ds): 80 | s = tf.constant(0, dtype=tf.int64) 81 | for e in iter(ds): 82 | s = s + e 83 | if s > 100: 84 | break 85 | return s 86 | 87 | 88 | def iterator_resuming_loop(ds): 89 | s = tf.constant(0, dtype=tf.int64) 90 | itr = iter(ds) 91 | for e in itr: 92 | s = s * 10 + e 93 | break 94 | for e in itr: 95 | s = s * 10 + e 96 | break 97 | for e in itr: 98 | s = s * 10 + e 99 | return s 100 | 101 | 102 | def dataset_loop_with_return(ds): 103 | y = tf.constant(0, dtype=tf.int64) 104 | for e in ds: 105 | y = e 106 | return y 107 | return y 108 | 109 | 110 | def iterator_loop_with_return(ds): 111 | y = tf.constant(0, dtype=tf.int64) 112 | for e in iter(ds): 113 | y = e 114 | return y 115 | return y 116 | 117 | 118 | def iterator_next(ds): 119 | itr = iter(ds) 120 | return next(itr) 121 | 122 | 123 | def iterator_next_multiple_calls(ds): 124 | itr = iter(ds) 125 | return 10 * next(itr) + next(itr) 126 | 127 | 128 | def iterator_next_in_loop(ds, n): 129 | itr = iter(ds) 130 | s = tf.constant(0, dtype=tf.int64) 131 | for _ in range(n): 132 | s = s * 10 + next(itr) 133 | return s 134 | 135 | 136 | def iterator_next_stopping(ds, cond): 137 | # This case will raise, but not the expected StopIteration error. 138 | itr = iter(ds) 139 | while cond: 140 | next(itr) 141 | 142 | 143 | def iterator_next_with_catching_stop_iteration(ds, cond): 144 | # This is the only instance when the use of TF iterators does not work as 145 | # intended. In graph mode, the `except` below will never catch, and the 146 | # tf.function will raise the error instead. 147 | # TODO(b/132311724): The error should be friendlier here. 148 | # Note: b/132298783 covers actually supporting this pattern. 
149 | itr = iter(ds) 150 | try: 151 | while cond: 152 | next(itr) 153 | except StopIteration: 154 | pass 155 | 156 | 157 | class ReferenceTest(reference_test_base.TestCase): 158 | 159 | def setUp(self): 160 | super(ReferenceTest, self).setUp() 161 | self.ds = tf.data.Dataset.range(7) 162 | 163 | def test_dataset_no_vars_loop(self): 164 | self.assertFunctionMatchesEager(dataset_no_vars_loop, self.ds) 165 | 166 | def test_iterator_no_vars_loop(self): 167 | self.assertFunctionMatchesEager(iterator_no_vars_loop, self.ds) 168 | 169 | def test_dataset_single_var_loop(self): 170 | self.assertFunctionMatchesEager(dataset_single_var_loop, self.ds) 171 | 172 | def test_iterator_single_var_loop(self): 173 | self.assertFunctionMatchesEager(iterator_single_var_loop, self.ds) 174 | 175 | def test_dataset_two_vars_loop(self): 176 | self.assertFunctionMatchesEager(dataset_two_vars_loop, self.ds) 177 | 178 | def test_iterator_two_vars_loop(self): 179 | self.assertFunctionMatchesEager(iterator_two_vars_loop, self.ds) 180 | 181 | def test_dataset_loop_with_break(self): 182 | self.assertFunctionMatchesEager(dataset_loop_with_break, self.ds) 183 | 184 | def test_iterator_loop_with_break(self): 185 | self.assertFunctionMatchesEager(iterator_loop_with_break, self.ds) 186 | 187 | def test_dataset_loop_with_return_raises(self): 188 | # This is for the same reason why returns in loops aren't allowed. 189 | # TODO(mdan): This might be resolved by unrolling the loop once. 190 | with self.assertRaisesRegex( 191 | ValueError, 192 | 'return statements are not supported within a TensorFlow loop.'): 193 | tf.function(dataset_loop_with_return)(self.ds) 194 | 195 | def test_iterator_loop_with_return_raises(self): 196 | # This is for the same reason why returns in loops aren't allowed. 197 | # TODO(mdan): This might be resolved by unrolling the loop once. 198 | with self.assertRaisesRegex( 199 | ValueError, 200 | 'return statements are not supported within a TensorFlow loop.'): 201 | tf.function(iterator_loop_with_return)(self.ds) 202 | 203 | def test_iterator_next(self): 204 | self.assertFunctionMatchesEager(iterator_next, self.ds) 205 | 206 | def test_iterator_next_multiple_calls(self): 207 | self.assertFunctionMatchesEager(iterator_next_multiple_calls, self.ds) 208 | 209 | def test_iterator_next_in_loop(self): 210 | self.assertFunctionMatchesEager(iterator_next_in_loop, self.ds, 7) 211 | 212 | def test_iterator_next_stopping(self): 213 | # Graph ops raise OutOfRangeError, but eager ops raise StopIteration 214 | with self.assertRaises(tf.errors.OutOfRangeError): 215 | tf.function(iterator_next_stopping)(self.ds, tf.constant(True)) 216 | 217 | def test_iterator_next_with_catching_stop_iteration(self): 218 | # Graph ops raise OutOfRangeError, but eager ops raise StopIteration 219 | with self.assertRaises(tf.errors.OutOfRangeError): 220 | tf.function(iterator_next_with_catching_stop_iteration)( 221 | self.ds, tf.constant(True)) 222 | 223 | 224 | if __name__ == '__main__': 225 | tf.test.main() 226 | -------------------------------------------------------------------------------- /examples/sysml2019/seq2seq_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Benchmark for a seq2seq encoder/decoder implementation. 16 | 17 | This code requires TF 2.0 or newer. 18 | Use `pip install tf-nightly-2.0-preview` to install. 19 | Tip: Run the `pip install` command in a separate virtual environment 20 | (Virtualenv, Anaconda) to avoid clobbering an existing TF installation. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import benchmark_base 28 | 29 | import numpy as np 30 | import tensorflow as tf 31 | 32 | 33 | BATCH_SIZE = 32 34 | MAX_SEQ_LEN = 100 35 | HIDDEN_SIZE = 256 36 | VOCAB_SIZE = 1000 37 | EMBEDDING_SIZE = 50 38 | 39 | 40 | def seq2seq(rnn_cell, embedding, input_seq, input_seq_lengths, eos_id, 41 | target_seq=None, max_output_len=100): 42 | """An implementation of seq2seq in AutoGraph-friendly format. 43 | 44 | Args: 45 | rnn_cell: An RNNCell. Is used for both encoding and decoding. 46 | embedding: An embedding lookup matrix. Used for both encoding and decoding. 47 | input_seq: Tensor of shape [batch_size, seq_len] of vocab_ids. 48 | input_seq_lengths: Tensor of shape [batch_size] with sequence lengths. 49 | eos_id: The vocab_id corresponding to . 50 | target_seq: (Optional) Target sequence for teacher forcing. 51 | max_output_len: Maximum output length before cutting off decoder. (Otherwise 52 | it can run forever without emitting EOS). 53 | Returns: 54 | Tensor of shape [seq_length, batch_size, vocab_size]. 55 | """ 56 | batch_size = input_seq.shape[0] 57 | input_seq = tf.nn.embedding_lookup(embedding, input_seq) 58 | state = rnn_cell.zero_state(batch_size, dtype=tf.float32) 59 | max_input_seq_len = tf.reduce_max(input_seq_lengths) 60 | # Encoder half of model 61 | for i in tf.range(max_input_seq_len): 62 | _, new_state = rnn_cell(input_seq[:, i], state) 63 | state = tf.where(i < input_seq_lengths, new_state, state) 64 | 65 | # Decoder half of model 66 | if target_seq is not None: 67 | max_output_len = target_seq.shape[1] 68 | 69 | outputs = tf.TensorArray(tf.float32, size=max_output_len) 70 | is_done = tf.zeros([batch_size], dtype=tf.bool) 71 | # Initial input is the end of seq token. 72 | eos_vector = tf.nn.embedding_lookup(embedding, tf.constant([eos_id])) 73 | dec_input = tf.tile(eos_vector, [batch_size, 1]) 74 | # Run up to max_output_len steps; can exit earlier if all sequences done. 
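  # AutoGraph lowers this loop, including the conditional break at the end of the body, into a
  # single tf.while_loop, folding the break condition into the loop predicate so decoding can
  # stop before max_output_len steps.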
75 | for i in tf.range(max_output_len): 76 | new_output, new_state = rnn_cell(dec_input, state) 77 | output = tf.where(is_done, tf.zeros_like(new_output), new_output) 78 | outputs.write(i, output) 79 | if target_seq is not None: 80 | # if target is known, use teacher forcing 81 | target_word = target_seq[:, i] 82 | else: 83 | # Otherwise, pick the most likely continuation (greedy search) 84 | target_word = tf.argmax(output, axis=1) 85 | dec_input = tf.nn.embedding_lookup(embedding, target_word) 86 | is_done = tf.logical_or(is_done, tf.equal(target_word, eos_id)) 87 | if tf.reduce_all(is_done): 88 | break 89 | return outputs.stack() 90 | 91 | 92 | class Seq2SeqBenchmark(benchmark_base.ReportingBenchmark): 93 | """Runs benchmarks for eager/autograph/graph variants of seq2seq.""" 94 | 95 | def _generate_fake_rnn_inputs(self, 96 | batch_size=BATCH_SIZE, 97 | max_seq_len=MAX_SEQ_LEN): 98 | np.random.seed(17) 99 | 100 | input_data = np.random.randint( 101 | 0, VOCAB_SIZE, size=[batch_size, max_seq_len]).astype(np.int32) 102 | # Generate some varying sequence lengths but keep max(sequence_lengths) 103 | # a constant, for more reproducible benchmarks. 104 | sequence_lengths = np.concatenate(([max_seq_len], 105 | np.random.randint( 106 | max_seq_len // 2, 107 | max_seq_len, 108 | size=[batch_size - 1]))).astype( 109 | np.int32) 110 | 111 | for i, seq_len in enumerate(sequence_lengths): 112 | input_data[i, seq_len:] = 0 113 | 114 | input_data = tf.constant(input_data) 115 | sequence_lengths = tf.constant(sequence_lengths) 116 | 117 | return input_data, sequence_lengths 118 | 119 | def _create_embedding(self): 120 | return tf.random.uniform([VOCAB_SIZE, EMBEDDING_SIZE]) 121 | 122 | def _create_rnn_cell(self, batch_size=BATCH_SIZE): 123 | rnn_cell = tf_v1.nn.rnn_cell.BasicRNNCell(HIDDEN_SIZE, dtype=tf.float32) 124 | rnn_cell.build(tf.TensorShape([batch_size, EMBEDDING_SIZE])) 125 | return rnn_cell 126 | 127 | def _benchmark_seq2seq(self, mode, seq2seq_variant, batch_size, max_seq_len, 128 | use_teacher_forcing): 129 | input_seq, input_seq_lengths = self._generate_fake_rnn_inputs( 130 | batch_size=batch_size, max_seq_len=max_seq_len) 131 | target_seq, _ = self._generate_fake_rnn_inputs( 132 | batch_size=batch_size, max_seq_len=max_seq_len) 133 | rnn_cell = self._create_rnn_cell(batch_size=batch_size) 134 | embedding = self._create_embedding() 135 | 136 | if not use_teacher_forcing: 137 | target_seq = None 138 | 139 | def target(): 140 | return seq2seq_variant(rnn_cell, embedding, input_seq, 141 | input_seq_lengths, 0, target_seq=target_seq) 142 | 143 | self.time_execution( 144 | (mode, max_seq_len, batch_size, use_teacher_forcing), 145 | target, 146 | iter_volume=batch_size, 147 | iter_unit='examples', 148 | extras={ 149 | 'max_seq_len': max_seq_len, 150 | 'batch_size': batch_size, 151 | 'use_teacher_forcing': use_teacher_forcing 152 | }) 153 | 154 | def benchmark_seq2seq(self): 155 | for batch_size in (32, 64, 128): 156 | for max_seq_len in (64, 128): 157 | for use_teacher_forcing in (True, False): 158 | self._benchmark_seq2seq('AutoGraph', tf.function(seq2seq), batch_size, 159 | max_seq_len, use_teacher_forcing) 160 | self._benchmark_seq2seq('Eager', seq2seq, batch_size, max_seq_len, 161 | use_teacher_forcing) 162 | 163 | 164 | if __name__ == '__main__': 165 | tf.test.main() 166 | -------------------------------------------------------------------------------- /examples/sysml2019/mnist_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 
The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Benchmark comparing MNIST with eager mode, graph mode, and autograph. 16 | 17 | This code assumes TF 1.X. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import benchmark_base 25 | 26 | import tensorflow.compat.v1 as tf 27 | from google3.third_party.tensorflow.contrib import training as contrib_training 28 | from google3.third_party.tensorflow.contrib.eager.python import tfe as contrib_eager 29 | 30 | 31 | tf.enable_eager_execution() 32 | 33 | 34 | def get_data_and_params(): 35 | """Set up input dataset and variables.""" 36 | (train_x, train_y), _ = tf.keras.datasets.mnist.load_data() 37 | tf.set_random_seed(0) 38 | hparams = contrib_training.HParams( 39 | batch_size=200, 40 | learning_rate=0.1, 41 | train_steps=101, 42 | ) 43 | dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y)) 44 | dataset = dataset.repeat() 45 | dataset = dataset.shuffle(hparams.batch_size * 10) 46 | dataset = dataset.batch(hparams.batch_size) 47 | 48 | def reshape_ex(x, y): 49 | return (tf.to_float(tf.reshape(x, (-1, 28 * 28))) / 256.0, 50 | tf.one_hot(tf.squeeze(y), 10)) 51 | 52 | dataset = dataset.map(reshape_ex) 53 | w = tf.get_variable('w0', (28 * 28, 10)) 54 | b = tf.get_variable('b0', (10,), initializer=tf.zeros_initializer()) 55 | opt = tf.train.GradientDescentOptimizer(hparams.learning_rate) 56 | return dataset, opt, hparams, w, b 57 | 58 | 59 | def model_fn(x, w, b): 60 | return tf.matmul(x, w) + b 61 | 62 | 63 | def loss_fn(x, y, w, b): 64 | y_ = model_fn(x, w, b) 65 | return tf.losses.softmax_cross_entropy(y, y_) 66 | 67 | 68 | class MNISTBenchmark(benchmark_base.ReportingBenchmark): 69 | """Benchmark a simple model training loop on MNIST digits dataset.""" 70 | 71 | def benchmark_eager(self): 72 | ds, opt, hp, w, b = get_data_and_params() 73 | iterator = iter(ds) 74 | 75 | def target(): 76 | """Eager implementation of training loop.""" 77 | for i, (x, y) in enumerate(iterator): 78 | if i >= hp.train_steps: 79 | break 80 | with contrib_eager.GradientTape() as tape: 81 | tape.watch(w) 82 | tape.watch(b) 83 | loss_val = loss_fn(x, y, w, b) 84 | dw, db = tape.gradient(loss_val, (w, b)) 85 | opt.apply_gradients(((dw, w), (db, b))) 86 | if i % 100 == 0: 87 | print('Step', i, ':', loss_val) 88 | assert 0.1 < loss_val < 1, loss_val 89 | 90 | self.time_execution( 91 | 'Eager', 92 | target, 93 | iter_volume=hp.train_steps, 94 | iter_unit='training steps') 95 | 96 | def benchmark_legacy_tf(self, loss_fun=loss_fn): 97 | with tf.Graph().as_default(): 98 | ds, opt, hp, w, b = get_data_and_params() 99 | x, y = ds.make_one_shot_iterator().get_next() 100 | loss_t = loss_fn(x, y, w, b) 101 | train_op = opt.minimize(loss_t, var_list=(w, b)) 102 | with tf.Session() as sess: 103 | 
sess.run(tf.global_variables_initializer()) 104 | 105 | def target(): 106 | for i in range(hp.train_steps): 107 | loss_val, _ = sess.run([loss_t, train_op]) 108 | if i % 100 == 0: 109 | print('Step', i, ':', loss_val) 110 | assert 0.1 < loss_val < 1, loss_val 111 | 112 | self.time_execution( 113 | 'Classical', 114 | target, 115 | iter_volume=hp.train_steps, 116 | iter_unit='training steps') 117 | 118 | def benchmark_autograph(self): 119 | 120 | def loop(ds, opt, hp, w, b): 121 | """AG implementation of training loop.""" 122 | loss = 0.0 123 | iterator = ds.make_one_shot_iterator() 124 | # TODO(brianklee): Rewrite with only one usage of iterator.get_next(). 125 | # Currently needs two calls because of the control_dependencies clause. 126 | # See b/109924949, b/117497661 127 | x, y = iterator.get_next() 128 | for i in tf.range(hp.train_steps): 129 | loss = loss_fn(x, y, w, b) 130 | if i % 100 == 0: 131 | print('Step', i, ':', loss) 132 | with tf.control_dependencies([opt.minimize(loss, var_list=(w, b))]): 133 | # This ensures that each iteration of the loop has a dependency 134 | # on the previous iteration completing. Otherwise you get async SGD. 135 | x, y = iterator.get_next() 136 | 137 | return loss 138 | 139 | loop = tf.autograph.to_graph( 140 | loop, experimental_optional_features=tf.autograph.Feature.ALL) 141 | 142 | with tf.Graph().as_default(): 143 | ds, opt, hp, w, b = get_data_and_params() 144 | loss = loop(ds, opt, hp, w, b) 145 | with tf.Session() as sess: 146 | sess.run(tf.global_variables_initializer()) 147 | 148 | def target(): 149 | loss_val = sess.run(loss) 150 | assert 0.1 < loss_val < 1, loss_val 151 | 152 | self.time_execution( 153 | 'AutoGraph', 154 | target, 155 | iter_volume=hp.train_steps, 156 | iter_unit='training steps') 157 | 158 | def benchmark_handwritten(self): 159 | with tf.Graph().as_default(): 160 | ds, opt, hp, w, b = get_data_and_params() 161 | iterator = ds.make_one_shot_iterator() 162 | 163 | def loop_body(i, unused_previous_loss_t): 164 | """Manual implementation of training loop.""" 165 | # Call get_next() inside body or else training happens repeatedly on 166 | # the first minibatch only. 167 | x, y = iterator.get_next() 168 | loss_t = loss_fn(x, y, w, b) 169 | train_op = opt.minimize(loss_t, var_list=(w, b)) 170 | i = tf.cond(tf.equal(i % 100, 0), 171 | lambda: tf.Print(i, [i, loss_t], message='Step, loss: '), 172 | lambda: i) 173 | 174 | with tf.control_dependencies([train_op]): 175 | return i + 1, loss_t 176 | 177 | _, final_loss_t = tf.while_loop( 178 | lambda i, _: i < hp.train_steps, 179 | loop_body, 180 | [tf.constant(0), tf.constant(0.0)]) 181 | 182 | with tf.Session() as sess: 183 | sess.run(tf.global_variables_initializer()) 184 | 185 | def target(): 186 | loss_val = sess.run(final_loss_t) 187 | assert 0.1 < loss_val < 1, loss_val 188 | 189 | self.time_execution( 190 | 'Handwritten', 191 | target, 192 | iter_volume=hp.train_steps, 193 | iter_unit='training steps') 194 | 195 | 196 | if __name__ == '__main__': 197 | tf.test.main() 198 | -------------------------------------------------------------------------------- /reference_tests/loop_scoping_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests that verify scoping around loops.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import itertools 22 | 23 | from absl.testing import parameterized 24 | import reference_test_base 25 | import tensorflow as tf 26 | 27 | 28 | def for_with_local_var(l): 29 | s = 0 30 | for i in l: 31 | x = i + 2 32 | s = s * 10 + x 33 | return s 34 | 35 | 36 | def while_with_local_var(x): 37 | s = 0 38 | while x > 0: 39 | y = x + 2 40 | s = s * 10 + y 41 | x -= 1 42 | return s 43 | 44 | 45 | def for_initializes_local_var(l): 46 | s = 0 47 | for i in l: 48 | if i == l[0]: 49 | x = 0 50 | else: 51 | x += 1 52 | s = s * 10 + x 53 | return s 54 | 55 | 56 | def while_initializes_local_var(x): 57 | s = 0 58 | while x > 0: 59 | if x > 0: 60 | y = 0 61 | else: 62 | y += 1 63 | s = s * 10 + y 64 | x -= 1 65 | return s 66 | 67 | 68 | def for_defines_var(l): 69 | for i in l: 70 | x = i + 2 71 | return x 72 | 73 | 74 | def while_defines_var(x): 75 | while x > 0: 76 | y = x + 2 77 | x -= 1 78 | return y 79 | 80 | 81 | def for_defines_iterate(n, fn): 82 | s = 0 83 | for i in fn(n): 84 | s = s * 10 + i 85 | return i, s # pylint:disable=undefined-loop-variable 86 | 87 | 88 | def for_reuses_iterate(n, fn): 89 | i = 7 90 | s = 0 91 | for i in fn(n): 92 | s = s * 10 + i 93 | return i, s 94 | 95 | 96 | def for_alters_iterate(n, fn): 97 | i = 7 98 | s = 0 99 | for i in fn(n): 100 | i = 3 * i + 1 101 | s = s * 10 + i 102 | return i, s 103 | 104 | 105 | def _int_tensor(x): 106 | return tf.constant(x, dtype=tf.int32) 107 | 108 | 109 | class LoopScopingTest(reference_test_base.TestCase, parameterized.TestCase): 110 | 111 | @parameterized.parameters(*itertools.product( 112 | ([], [1], [1, 2]), 113 | (list, _int_tensor), 114 | )) 115 | def test_for_with_local_var(self, l, type_): 116 | l = type_(l) 117 | self.assertFunctionMatchesEager(for_with_local_var, l) 118 | 119 | @parameterized.parameters(*itertools.product( 120 | (0, 1, 2), 121 | (range, tf.range), 122 | )) 123 | def test_for_with_local_var_range(self, l, type_): 124 | l = type_(l) 125 | self.assertFunctionMatchesEager(for_with_local_var, l) 126 | 127 | @parameterized.parameters(*itertools.product( 128 | (0, 1, 2), 129 | (int, _int_tensor), 130 | )) 131 | def test_while_with_local_var(self, x, type_): 132 | x = type_(x) 133 | self.assertFunctionMatchesEager(while_with_local_var, x) 134 | 135 | @parameterized.parameters( 136 | ([],), 137 | ([1],), 138 | ([1, 2],), 139 | ) 140 | def test_for_initializes_local_var_legal_cases(self, l): 141 | self.assertFunctionMatchesEager(for_initializes_local_var, l) 142 | 143 | @parameterized.parameters( 144 | ([],), 145 | ([1],), 146 | ([1, 2],), 147 | ) 148 | def test_for_initializes_local_var_illegal_cases(self, l): 149 | l = tf.constant(l) 150 | with self.assertRaisesRegex(ValueError, '"x" must be defined'): 151 | tf.function(for_initializes_local_var)(l) 152 | 153 | @parameterized.parameters( 154 | 0, 155 | 1, 
156 | 2, 157 | ) 158 | def test_while_initializes_local_var_legal_cases(self, x): 159 | self.assertFunctionMatchesEager(while_initializes_local_var, x) 160 | 161 | @parameterized.parameters( 162 | 0, 163 | 1, 164 | 2, 165 | ) 166 | def test_while_initializes_local_var_illegal_cases(self, x): 167 | x = tf.constant(x) 168 | with self.assertRaisesRegex(ValueError, '"y" must be defined'): 169 | tf.function(while_initializes_local_var)(x) 170 | 171 | @parameterized.parameters( 172 | # TODO(b/155171694): Enable once the error message here is corrected. 173 | # ([],), 174 | ([1],), 175 | ([1, 2],), 176 | ) 177 | def test_for_defines_var_legal_cases(self, l): 178 | self.assertFunctionMatchesEager(for_defines_var, l) 179 | 180 | @parameterized.parameters( 181 | ([],), 182 | ([1],), 183 | ([1, 2],), 184 | ) 185 | def test_for_defines_var_illegal_cases(self, l): 186 | l = tf.constant(l) 187 | with self.assertRaisesRegex(ValueError, '"x" must be defined'): 188 | tf.function(for_defines_var)(l) 189 | 190 | @parameterized.parameters( 191 | # TODO(b/155171694): Enable once the error message here is corrected. 192 | # (0,), 193 | (1,), 194 | (2,), 195 | ) 196 | def test_while_defines_var_legal_cases(self, x): 197 | self.assertFunctionMatchesEager(while_defines_var, x) 198 | 199 | @parameterized.parameters( 200 | (0,), 201 | (1,), 202 | (2,), 203 | ) 204 | def test_while_defines_var_illegal_cases(self, x): 205 | x = tf.constant(x) 206 | with self.assertRaisesRegex(ValueError, '"y" must be defined'): 207 | tf.function(while_defines_var)(x) 208 | 209 | @parameterized.parameters(*itertools.product( 210 | (1, 2), 211 | (range, tf.range), 212 | )) 213 | def test_for_defines_iterate_legal_cases(self, n, fn): 214 | self.assertFunctionMatchesEager(for_defines_iterate, n, fn) 215 | 216 | def test_for_defines_iterate_range(self): 217 | self.skipTest('b/155171694') 218 | 219 | def test_for_defines_iterate_tf_range(self): 220 | # Deviating from the normal Python semantics here to avoid inserting 221 | # an extra assert op. If needed, we can insert it and raise an error 222 | # to mimic the eager behavior, but this is an exceptionally uncummon 223 | # use case. 224 | self.assertAllEqual(tf.function(for_defines_iterate)(0, tf.range), (0, 0)) 225 | 226 | @parameterized.parameters(*itertools.product( 227 | ([], [1], [1, 2]), 228 | (list, _int_tensor), 229 | )) 230 | def test_for_reuses_iterate(self, l, fn): 231 | self.assertFunctionMatchesEager(for_reuses_iterate, l, fn) 232 | 233 | @parameterized.parameters(*itertools.product( 234 | (0, 1, 2), 235 | (range, tf.range), 236 | )) 237 | def test_for_reuses_iterate_range(self, n, fn): 238 | self.assertFunctionMatchesEager(for_reuses_iterate, n, fn) 239 | 240 | @parameterized.parameters(*itertools.product( 241 | ([], [1], [1, 2]), 242 | (list, _int_tensor), 243 | )) 244 | def test_for_alters_iterate(self, l, fn): 245 | self.assertFunctionMatchesEager(for_alters_iterate, l, fn) 246 | 247 | @parameterized.parameters(*itertools.product( 248 | (0, 1, 2), 249 | (range, tf.range), 250 | )) 251 | def test_for_alters_iterate_range(self, n, fn): 252 | self.assertFunctionMatchesEager(for_alters_iterate, n, fn) 253 | 254 | 255 | if __name__ == '__main__': 256 | tf.test.main() 257 | -------------------------------------------------------------------------------- /reference_tests/early_return_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Multiple returns, some in conditionals.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import itertools 22 | 23 | from absl.testing import parameterized 24 | import reference_test_base 25 | import tensorflow.compat.v2 as tf 26 | 27 | 28 | tf.enable_v2_behavior() 29 | 30 | 31 | def return_with_default(x): 32 | if x > 0: 33 | tf.print('x', x) 34 | return x 35 | return x * x 36 | 37 | 38 | def return_possibly_undefined(x): 39 | if x > 0: 40 | if x < 5: 41 | return x 42 | else: 43 | return x * x * x 44 | 45 | 46 | def nested_ifs(x): 47 | if x > 0: 48 | if x < 5: 49 | return x 50 | else: 51 | return x * x 52 | else: 53 | return x * x * x 54 | 55 | 56 | def possible_return_before_loop(c1, c2, n): 57 | if c1: 58 | if c2: 59 | return 1 60 | for _ in range(n): 61 | pass 62 | return 2 63 | 64 | 65 | def nested_ifs_and_context_managers(x): 66 | with tf.name_scope(''): 67 | if x > 0: 68 | if x < 5: 69 | with tf.name_scope(''): 70 | return x 71 | else: 72 | return x * x 73 | else: 74 | return x * x * x 75 | 76 | 77 | def unreachable_return(x): 78 | with tf.name_scope(''): 79 | if x > 0: 80 | if x < 5: 81 | with tf.name_scope(''): 82 | return x 83 | else: 84 | return x * x 85 | else: 86 | return x * x * x 87 | return x * x * x * x 88 | 89 | 90 | def return_with_default_in_contexmanager(x): 91 | with tf.name_scope(''): 92 | if x > 0: 93 | return 1 94 | return 0 95 | 96 | 97 | def return_in_try_with_finally(x): 98 | try: 99 | if x > 0: 100 | return 1 101 | else: 102 | return 0 103 | finally: 104 | x = x + 1 105 | 106 | 107 | def return_with_default_in_try_with_finally(x): 108 | try: 109 | if x > 0: 110 | return 1 111 | return 0 112 | finally: 113 | x = x + 1 114 | 115 | 116 | def return_in_finally(x): 117 | try: 118 | return 2 119 | finally: 120 | if x > 0: 121 | return 1 # pylint: disable=lost-exception 122 | else: 123 | return 0 # pylint: disable=lost-exception 124 | 125 | 126 | def return_with_default_in_finally(x): 127 | try: 128 | return 2 129 | finally: 130 | if x > 0: 131 | return 1 # pylint: disable=lost-exception 132 | return 0 # pylint: disable=lost-exception 133 | 134 | 135 | def return_in_finally_default_in_try(x): 136 | try: 137 | if x > 0: 138 | return 0 139 | finally: 140 | return 1 # pylint: disable=lost-exception 141 | 142 | 143 | def _raising_helper(): 144 | raise ValueError() 145 | 146 | 147 | def raise_during_return_caught(): 148 | try: 149 | return _raising_helper() 150 | except ValueError: 151 | pass 152 | return 1 153 | 154 | 155 | def raise_during_return_caught_in_tail_branch(c): 156 | if c: 157 | return 2 158 | try: 159 | return _raising_helper() 160 | except ValueError: 161 | pass 162 | return 1 163 | 164 | 165 | class ReferenceTest(reference_test_base.TestCase, parameterized.TestCase): 166 | """Base class for 
the reference tests.""" 167 | 168 | @parameterized.parameters(*itertools.product( 169 | (0, 1), 170 | (int, tf.constant), 171 | )) 172 | def test_return_with_default(self, n, type_): 173 | self.assertFunctionMatchesEager(return_with_default, type_(n)) 174 | 175 | @parameterized.parameters((0,), (3,), (5,)) 176 | def test_return_possibly_undefined_legal(self, n): 177 | self.assertFunctionMatchesEager(return_possibly_undefined, n) 178 | 179 | @parameterized.parameters((0,), (3,), (5,)) 180 | def test_return_possibly_undefined_illegal(self, n): 181 | with self.assertRaisesRegex( 182 | ValueError, '.*must also.*return.*else branch.*'): 183 | tf.function(return_possibly_undefined)(tf.constant(n)) 184 | 185 | @parameterized.parameters(*itertools.product( 186 | (-1, 3, 6), 187 | (int, tf.constant), 188 | )) 189 | def test_nested_ifs(self, n, type_): 190 | self.assertFunctionMatchesEager(nested_ifs, type_(n)) 191 | 192 | @parameterized.parameters(*itertools.product( 193 | (True, False), 194 | (True, False), 195 | (0, 1, 2), 196 | )) 197 | def test_possible_return_before_loop(self, c1, c2, n): 198 | self.assertFunctionMatchesEager(possible_return_before_loop, c1, c2, n) 199 | 200 | @parameterized.parameters(*itertools.product( 201 | (0, 3, 5), 202 | (int, tf.constant), 203 | )) 204 | def test_nested_ifs_and_context_managers(self, x, type_): 205 | self.assertFunctionMatchesEager(nested_ifs_and_context_managers, type_(x)) 206 | 207 | @parameterized.parameters(*itertools.product( 208 | (0, 3, 5), 209 | (int, tf.constant), 210 | )) 211 | def test_unreachable_return(self, x, type_): 212 | self.assertFunctionMatchesEager(unreachable_return, type_(x)) 213 | 214 | @parameterized.parameters(*itertools.product( 215 | (0, 1), 216 | (int, tf.constant), 217 | )) 218 | def test_return_with_default_in_contexmanager(self, x, type_): 219 | self.assertFunctionMatchesEager( 220 | return_with_default_in_contexmanager, type_(x)) 221 | 222 | @parameterized.parameters(*itertools.product( 223 | (0, 1), 224 | (int, tf.constant), 225 | )) 226 | def test_return_in_try_finally(self, x, type_): 227 | self.assertFunctionMatchesEager(return_in_try_with_finally, type_(x)) 228 | 229 | @parameterized.parameters(*itertools.product( 230 | (0, 1), 231 | (int, tf.constant), 232 | )) 233 | def test_return_with_default_try_finally(self, x, type_): 234 | self.assertFunctionMatchesEager( 235 | return_with_default_in_try_with_finally, type_(x)) 236 | 237 | @parameterized.parameters(*itertools.product( 238 | (0, 1), 239 | (int, tf.constant), 240 | )) 241 | def test_return_in_finally(self, x, type_): 242 | self.assertFunctionMatchesEager(return_in_finally, type_(x)) 243 | 244 | @parameterized.parameters(*itertools.product( 245 | (0, 1), 246 | (int, tf.constant), 247 | )) 248 | def test_return_with_default_in_finally(self, x, type_): 249 | self.assertFunctionMatchesEager(return_with_default_in_finally, type_(x)) 250 | 251 | @parameterized.parameters(*itertools.product( 252 | (0, 1), 253 | (int, tf.constant), 254 | )) 255 | def test_return_in_finally_default_in_try(self, x, type_): 256 | self.assertFunctionMatchesEager(return_in_finally_default_in_try, type_(x)) 257 | 258 | def test_raise_during_return_caught(self): 259 | self.assertFunctionMatchesEager(raise_during_return_caught) 260 | 261 | @parameterized.parameters(*itertools.product( 262 | (True, False), 263 | (int, tf.constant), 264 | )) 265 | def test_raise_during_return_caught_in_tail_branch(self, c, type_): 266 | self.assertFunctionMatchesEager( 267 | 
raise_during_return_caught_in_tail_branch, type_(c)) 268 | 269 | 270 | if __name__ == '__main__': 271 | tf.test.main() 272 | -------------------------------------------------------------------------------- /examples/sysml2019/lbfgs_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Benchmark for a basic L-BFGS implementation without beam search. 16 | 17 | Adapted from 18 | https://github.com/yaroslavvb/stuff/blob/master/eager_lbfgs/eager_lbfgs.py. 19 | 20 | This code requires TF 2.0 or newer. 21 | Use `pip install tf-nightly-2.0-preview` to install. 22 | Tip: Run the `pip install` command in a separate virtual environment 23 | (Virtualenv, Anaconda) to avoid clobbering an existing TF installation. 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import benchmark_base 31 | 32 | import tensorflow as tf 33 | 34 | 35 | INPUT_SIZE = 28 * 28 36 | BATCH_SIZE = 100 37 | 38 | MAX_ITER = 20 39 | MAX_EVAL = 25 40 | TOL_F = 1e-5 41 | TOL_X = 1e-9 42 | N_CORRECTIONS = 100 43 | LEARNING_RATE = 1.0 44 | 45 | 46 | def mnist_dataset(): 47 | """Loads the MNIST dataset.""" 48 | 49 | def prepare_mnist_features_and_labels(x, y): 50 | x = tf.cast(x, tf.float32) / 255.0 51 | x = tf.reshape(x, (INPUT_SIZE,)) 52 | y = tf.cast(y, tf.int64) 53 | return x, y 54 | 55 | (x, y), _ = tf.keras.datasets.mnist.load_data() 56 | ds = tf.data.Dataset.from_tensor_slices((x, y)) 57 | ds = ds.map(prepare_mnist_features_and_labels) 58 | ds = ds.take(BATCH_SIZE).batch(BATCH_SIZE) 59 | return ds 60 | 61 | 62 | def cb_allocate(el, size): 63 | """Basic primitive to allocate a circular buffer.""" 64 | # TODO(mdanatg): Rewrite as a namedtuple. 
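  # The circular buffer is represented as a fixed-size TensorArray plus (begin, end) indices;
  # cb_append writes at end, advances end modulo the capacity, and bumps begin once the buffer wraps.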
65 | el = tf.convert_to_tensor(el) 66 | buff = tf.TensorArray( 67 | el.dtype, 68 | size=size + 1, 69 | element_shape=tf.TensorShape(None), 70 | clear_after_read=False) 71 | for i in tf.range(size): 72 | buff = buff.write(i, el) 73 | begin = 0 74 | end = 0 75 | return buff, begin, end 76 | 77 | 78 | def cb_append(buff, begin, end, size, el): 79 | """Circular buffer append primitive.""" 80 | buff = buff.write(end, el) 81 | end = (end + 1) % size 82 | if tf.equal(end, begin): 83 | begin = (begin + 1) % size 84 | return buff, begin, end 85 | 86 | 87 | def cb_len(begin, end, size): 88 | """Circular buffer length primitive.""" 89 | return (end - begin) % size 90 | 91 | 92 | def cb_range(begin, end, size): 93 | """Circular buffer range primitive.""" 94 | if end < begin: 95 | virtual_end = end + size 96 | else: 97 | virtual_end = end 98 | return tf.range(begin, virtual_end) % size 99 | 100 | 101 | def cb_rev_range(begin, end, size): 102 | """Circular buffer reversed range primitive.""" 103 | if end < begin: 104 | virtual_end = end + size 105 | else: 106 | virtual_end = end 107 | return tf.range(virtual_end - 1, begin - 1, -1) % size 108 | 109 | 110 | def dot(a, b): 111 | return tf.reduce_sum(a * b) 112 | 113 | 114 | def loss_fn(w_flat, data): 115 | w = tf.reshape(w_flat, [INPUT_SIZE, -1]) 116 | x = tf.matmul(data, w) 117 | x = tf.sigmoid(x) 118 | x = tf.matmul(x, w, transpose_b=True) 119 | x = tf.sigmoid(x) 120 | return tf.reduce_mean(tf.square(x - data)) 121 | 122 | 123 | def loss_and_grad(w_flat, data): 124 | with tf.GradientTape() as g: 125 | g.watch(w_flat) 126 | f = loss_fn(w_flat, data) 127 | g = g.gradient(f, w_flat) 128 | return f, g 129 | 130 | 131 | def lbfgs_eager(x, data): 132 | """Implementation of L-BFGS in AutoGraph / TF Eager.""" 133 | 134 | f, g = loss_and_grad(x, data) 135 | f_hist = tf.TensorArray(f.dtype, size=0, dynamic_size=True) 136 | f_hist = f_hist.write(0, f) 137 | f_evals = 1 138 | 139 | # Check optimality of initial point. 140 | if tf.reduce_sum(tf.abs(g)) <= TOL_F: 141 | tf.print('Optimality condition below TOL_F.') 142 | return x, f_hist.stack() 143 | 144 | # Pre-allocate some buffers. 145 | dirs_buff, dirs_begin, dirs_end = cb_allocate(tf.zeros_like(g), N_CORRECTIONS) 146 | steps_buff, steps_begin, steps_end = cb_allocate( 147 | tf.zeros_like(g), N_CORRECTIONS) 148 | ro, _, _ = cb_allocate(0.0, N_CORRECTIONS) 149 | al, _, _ = cb_allocate(0.0, N_CORRECTIONS) 150 | 151 | n_iter = tf.constant(0) 152 | d = -g 153 | prev_g = g 154 | h_diag = 1.0 155 | t = tf.minimum(1.0, 1.0 / tf.reduce_sum(tf.abs(g))) 156 | 157 | while n_iter <= MAX_ITER: 158 | n_iter += 1 159 | 160 | if n_iter > 1: 161 | y = g - prev_g 162 | s = d * t 163 | ys = dot(y, s) 164 | 165 | if ys > 1e-10: 166 | dirs_buff, dirs_begin, dirs_end = cb_append(dirs_buff, dirs_begin, 167 | dirs_end, N_CORRECTIONS, s) 168 | steps_buff, steps_begin, steps_end = cb_append( 169 | steps_buff, steps_begin, steps_end, N_CORRECTIONS, y) 170 | h_diag = ys / dot(y, y) 171 | 172 | # Approximate inverse Hessian-gradient product. 
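    # The two loops below implement the classic L-BFGS two-loop recursion over the stored
    # correction pairs; note that dirs_buff holds the position deltas s and steps_buff holds
    # the gradient deltas y.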
173 | q = -g 174 | for i in cb_rev_range(dirs_begin, dirs_end, N_CORRECTIONS): 175 | ro = ro.write(i, 1 / dot(steps_buff.read(i), dirs_buff.read(i))) 176 | al = al.write(i, dot(dirs_buff.read(i), q) * ro.read(i)) 177 | q = q - al.read(i) * steps_buff.read(i) 178 | 179 | r = q * h_diag 180 | for i in cb_range(dirs_begin, dirs_end, N_CORRECTIONS): 181 | be = dot(steps_buff.read(i), r) * ro.read(i) 182 | r += (al.read(i) - be) * dirs_buff.read(i) 183 | 184 | d = r 185 | 186 | prev_g = g 187 | prev_f = f 188 | 189 | # Step direction (directional derivative). 190 | gtd = dot(g, d) 191 | 192 | if gtd > -TOL_X: 193 | tf.print('Can not make progress along direction.') 194 | break 195 | 196 | # Step size 197 | if n_iter > 1: 198 | t = LEARNING_RATE 199 | 200 | # No line search, simply move with fixed step. 201 | x += t * d 202 | 203 | if n_iter < MAX_ITER: 204 | # Skip re-evaluation after last iteration. 205 | f, g = loss_and_grad(x, data) 206 | f_evals += 1 # This becomes less trivial when using line search. 207 | 208 | f_hist = f_hist.write(f_hist.size(), f) 209 | 210 | # Check conditions, again on all-but-final-eval only. 211 | if tf.equal(n_iter, MAX_ITER): 212 | break 213 | 214 | if f_evals >= MAX_EVAL: 215 | tf.print('Max number of function evals.') 216 | break 217 | 218 | if tf.reduce_sum(tf.abs(d * t)) <= TOL_X: 219 | tf.print('Step size below TOL_X.') 220 | break 221 | 222 | f_delta = tf.abs(f - prev_f) 223 | if f_delta < TOL_X: 224 | tf.print('Function value changing less than TOL_X.', f_delta) 225 | break 226 | 227 | return x, f_hist.stack() 228 | 229 | 230 | lbfgs_autograph = tf.function( 231 | lbfgs_eager, 232 | experimental_autograph_options=tf.autograph.experimental.Feature.ALL) 233 | 234 | 235 | class LBFGSBenchmark(benchmark_base.ReportingBenchmark): 236 | """Basic benchmark for the L-BFGS algorithm.""" 237 | 238 | def _run_benchmark(self, name, algorithm_function, hidden_size, data): 239 | w_flat = tf.Variable(tf.zeros((INPUT_SIZE * hidden_size,))) 240 | 241 | def target(): 242 | new_w_flat, _ = algorithm_function(w_flat.read_value(), data) 243 | _ = new_w_flat.numpy() 244 | 245 | self.time_execution((name, hidden_size), 246 | target, 247 | extras={ 248 | 'hidden_size': hidden_size, 249 | }) 250 | 251 | def benchmark_lbfgs(self): 252 | data, _ = next(iter(mnist_dataset())) 253 | # TODO(mdanatg): Use more interesting parametrizations. 254 | # TODO(mdanatg): Double check correctness. 255 | for hidden_size in (100,): 256 | self._run_benchmark('Eager', lbfgs_eager, hidden_size, data) 257 | self._run_benchmark('AutoGraph', lbfgs_autograph, hidden_size, data) 258 | 259 | 260 | if __name__ == '__main__': 261 | tf.test.main() 262 | -------------------------------------------------------------------------------- /reference_tests/loop_with_variable_type_illegal_cases_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Loops with type changing variables.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl.testing import parameterized 22 | import reference_test_base 23 | import tensorflow.compat.v2 as tf 24 | 25 | 26 | tf.enable_v2_behavior() 27 | 28 | 29 | def while_with_variable_dtype(): 30 | n = tf.constant(0, dtype=tf.int32) 31 | while tf.constant(True): 32 | n = tf.constant(0, dtype=tf.float32) 33 | return n 34 | 35 | 36 | def while_with_variable_dtype_and_early_stopping(): 37 | n = tf.constant(0, dtype=tf.int32) 38 | while tf.constant(True): 39 | n = tf.constant(0, dtype=tf.float32) 40 | break 41 | return n 42 | 43 | 44 | def for_with_variable_dtype(l): 45 | n = tf.constant(0, dtype=tf.int32) 46 | for _ in l: 47 | n = tf.constant(0, dtype=tf.float32) 48 | return n 49 | 50 | 51 | def for_with_variable_dtype_and_early_stopping(l): 52 | n = tf.constant(0, dtype=tf.int32) 53 | for _ in l: 54 | n = tf.constant(0, dtype=tf.float32) 55 | break 56 | return n 57 | 58 | 59 | def while_with_variable_shape(): 60 | t = tf.constant([1]) 61 | while tf.constant(True): 62 | t = tf.constant([1, 1]) 63 | return t 64 | 65 | 66 | def for_with_variable_shape(l): 67 | t = tf.constant([1]) 68 | for _ in l: 69 | t = tf.constant([1, 1]) 70 | return t 71 | 72 | 73 | def while_with_shape_erasure(): 74 | t = tf.constant([1]) 75 | while tf.constant(True): 76 | t = tf.range(tf.random.uniform((), 2, 3, dtype=tf.int32)) 77 | return t 78 | 79 | 80 | def for_with_shape_erasure(l): 81 | t = tf.constant([1]) 82 | for _ in l: 83 | t = tf.range(tf.random.uniform((), 2, 3, dtype=tf.int32)) 84 | return t 85 | 86 | 87 | def while_with_shape_invariant_violation(): 88 | t = tf.constant([1]) 89 | while tf.constant(True): 90 | tf.autograph.experimental.set_loop_options( 91 | shape_invariants=((t, tf.TensorShape([1])),)) 92 | t = tf.range(tf.random.uniform((), 2, 3, dtype=tf.int32)) 93 | return t 94 | 95 | 96 | def for_with_shape_invariant_violation(l): 97 | t = tf.constant([1]) 98 | for _ in l: 99 | tf.autograph.experimental.set_loop_options( 100 | shape_invariants=((t, tf.TensorShape([1])),)) 101 | t = tf.range(tf.random.uniform((), 2, 3, dtype=tf.int32)) 102 | return t 103 | 104 | 105 | def while_with_variable_structure(): 106 | s = {'a': tf.constant(0)} 107 | while tf.constant(True): 108 | s = tf.constant(7.0) 109 | return s 110 | 111 | 112 | def for_with_variable_structure(l): 113 | s = [tf.constant(0)] 114 | for _ in l: 115 | s = s + [tf.constant(0)] 116 | return s 117 | 118 | 119 | def _tf_range(l): 120 | return tf.range(len(l)) 121 | 122 | 123 | def _dataset(l): 124 | return tf.data.Dataset.from_tensor_slices(l) 125 | 126 | 127 | def _dataset_iterator(l): 128 | return iter(tf.data.Dataset.from_tensor_slices(l)) 129 | 130 | 131 | def _distributed_dataset(l): 132 | ds = tf.data.Dataset.from_tensor_slices([l] * 2) 133 | return tf.distribute.MirroredStrategy().experimental_distribute_dataset(ds) 134 | 135 | 136 | class ReferenceTest(reference_test_base.TestCase, parameterized.TestCase): 137 | 138 | def test_while_with_variable_dtype(self): 139 | with self.assertRaisesRegex( 140 | TypeError, 141 | '"n" has dtype int32 before the loop, but dtype float32 after'): 142 | tf.function(while_with_variable_dtype)() 143 | 144 | def 
test_while_with_variable_dtype_and_early_stopping(self): 145 | with self.assertRaisesRegex( 146 | TypeError, 147 | '"n" has dtype int32 before the loop, but dtype float32 after'): 148 | tf.function(while_with_variable_dtype_and_early_stopping)() 149 | 150 | @parameterized.parameters( 151 | (tf.constant,), 152 | (_tf_range,), 153 | (_dataset,), 154 | (_dataset_iterator,), 155 | (_distributed_dataset,), 156 | ) 157 | def test_for_with_variable_dtype(self, type_): 158 | l = type_([1, 2, 3]) 159 | with self.assertRaisesRegex( 160 | TypeError, 161 | '"n" has dtype int32 before the loop, but dtype float32 after'): 162 | tf.function(for_with_variable_dtype)(l) 163 | 164 | # Note: distributed datasets don't allow early stopping. 165 | @parameterized.parameters( 166 | (tf.constant,), 167 | (_tf_range,), 168 | (_dataset,), 169 | (_dataset_iterator,), 170 | ) 171 | def test_for_with_variable_dtype_and_early_stopping(self, type_): 172 | l = type_([1, 2, 3]) 173 | with self.assertRaisesRegex( 174 | TypeError, 175 | '"n" has dtype int32 before the loop, but dtype float32 after'): 176 | tf.function(for_with_variable_dtype_and_early_stopping)(l) 177 | 178 | def test_while_with_variable_shape(self): 179 | with self.assertRaisesRegex( 180 | ValueError, 181 | r'"t" has shape \(1,\) before the loop, but shape \(2,\) after'): 182 | tf.function(while_with_variable_shape)() 183 | 184 | # Note: datasets do allow variable shape. 185 | @parameterized.parameters( 186 | (tf.constant,), 187 | (_tf_range,), 188 | (_dataset_iterator,), 189 | (_distributed_dataset,), 190 | ) 191 | def test_for_with_variable_shape(self, type_): 192 | l = type_([1, 2, 3]) 193 | with self.assertRaisesRegex( 194 | ValueError, 195 | r'"t" has shape \(1,\) before the loop, but shape \(2,\) after'): 196 | tf.function(for_with_variable_shape)(l) 197 | 198 | def test_while_with_shape_erasure(self): 199 | with self.assertRaisesRegex( 200 | ValueError, 201 | r'"t" has shape \(1,\) before the loop, but shape \(None,\) after'): 202 | tf.function(while_with_shape_erasure)() 203 | 204 | # Note: datasets do allow variable shape. 205 | @parameterized.parameters( 206 | (tf.constant,), 207 | (_tf_range,), 208 | (_dataset_iterator,), 209 | (_distributed_dataset,), 210 | ) 211 | def test_for_with_shape_erasure(self, type_): 212 | l = type_([1, 2, 3]) 213 | with self.assertRaisesRegex( 214 | ValueError, 215 | r'"t" has shape \(1,\) before the loop, but shape \(None,\) after'): 216 | tf.function(for_with_shape_erasure)(l) 217 | 218 | def test_while_with_shape_invariant_violation(self): 219 | with self.assertRaisesRegex( 220 | ValueError, 221 | r'"t" has shape \(None,\) after one iteration, which does not conform'): 222 | tf.function(while_with_shape_invariant_violation)() 223 | 224 | # Note: dataset loops ignore shape invariants. 
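  # Only the tensor, tf.range, dataset-iterator, and distributed-dataset inputs are
  # parameterized below: a plain `for _ in dataset` loop would not raise the invariant
  # violation, so it is left out of this illegal-cases check.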
225 | @parameterized.parameters( 226 | (tf.constant,), 227 | (_tf_range,), 228 | (_dataset_iterator,), 229 | (_distributed_dataset,), 230 | ) 231 | def test_for_with_shape_invariant_violation(self, type_): 232 | l = type_([1, 2, 3]) 233 | with self.assertRaisesRegex( 234 | ValueError, 235 | r'"t" has shape \(None,\) after one iteration, which does not conform'): 236 | tf.function(for_with_shape_invariant_violation)(l) 237 | 238 | def test_while_with_variable_structure(self): 239 | with self.assertRaisesRegex( 240 | TypeError, 241 | '"s" does not have the same nested structure'): 242 | tf.function(while_with_variable_structure)() 243 | 244 | @parameterized.parameters( 245 | (tf.constant,), 246 | (_tf_range,), 247 | (_dataset,), 248 | (_dataset_iterator,), 249 | (_distributed_dataset,), 250 | ) 251 | def test_for_with_variable_structure(self, type_): 252 | l = type_([1, 2, 3]) 253 | with self.assertRaisesRegex( 254 | TypeError, 255 | '"s" does not have the same nested structure'): 256 | tf.function(for_with_variable_structure)(l) 257 | 258 | 259 | if __name__ == '__main__': 260 | tf.test.main() 261 | -------------------------------------------------------------------------------- /reference_tests/loop_with_function_call_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Function calls inside the while loop body.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import itertools 22 | 23 | from absl.testing import parameterized 24 | import reference_test_base 25 | import tensorflow.compat.v2 as tf 26 | 27 | 28 | tf.enable_v2_behavior() 29 | 30 | 31 | def while_with_call_in_cond(n, fn): 32 | i = 0 33 | s = 0 34 | while i < fn(n): 35 | s = s * 10 + i 36 | i += 1 37 | return s 38 | 39 | 40 | def for_with_call_in_target(l, fn): 41 | s = 0 42 | for i in fn(l): 43 | s = s * 10 + i 44 | return s 45 | 46 | 47 | def while_with_local_call_in_cond(n): 48 | 49 | def local_fn(x): 50 | return x * 3 51 | 52 | i = 0 53 | s = 0 54 | while i < local_fn(n): 55 | s = s * 10 + i 56 | i += 1 57 | return s 58 | 59 | 60 | def for_with_local_call_in_target(l): 61 | 62 | def local_fn(l): 63 | return l * 1 64 | 65 | s = 0 66 | for i in local_fn(l): 67 | s = s * 10 + i 68 | return s 69 | 70 | 71 | def while_with_call(n, fn): 72 | i = 0 73 | s = 0 74 | while i < n: 75 | s = s * 10 + fn(i) 76 | i += 1 77 | return s 78 | 79 | 80 | def for_with_call(l, fn): 81 | s = 0 82 | for i in l: 83 | s = s * 10 + fn(i) 84 | return s 85 | 86 | 87 | def while_with_local_call(n): 88 | 89 | def local_fn(x): 90 | return x * 3 91 | 92 | i = 0 93 | s = 0 94 | while i < n: 95 | s = s * 10 + local_fn(i) 96 | i += 1 97 | return s 98 | 99 | 100 | def for_with_local_call(l): 101 | 102 | def local_fn(x): 103 | return x * 3 104 | 105 | s = 0 106 | for i in l: 107 | s = s * 10 + local_fn(i) 108 | return s 109 | 110 | 111 | def while_with_closure_call(n): 112 | i = 0 113 | 114 | def i_via_closure(): 115 | return i + 2 116 | 117 | i = 0 118 | s = 0 119 | while i < n: 120 | s = s * 10 + i_via_closure() 121 | i += 1 122 | return s 123 | 124 | 125 | def for_with_closure_call(l): 126 | i = 0 127 | 128 | def i_via_closure(): 129 | return i + 2 130 | 131 | s = 0 132 | for i in l: 133 | s = s * 10 + i_via_closure() 134 | # TODO(b/134822197): Remove i from return values. 135 | return s, i 136 | 137 | 138 | def while_with_lambda_closure_call(n): 139 | i = 0 140 | s = 0 141 | i_via_closure = lambda: i + 2 142 | while i < n: 143 | s = s * 10 + i_via_closure() 144 | i += 1 145 | return s 146 | 147 | 148 | def for_with_lambda_closure_call(l): 149 | i = 0 150 | s = 0 151 | i_via_closure = lambda: i + 2 152 | for i in l: 153 | s = s * 10 + i_via_closure() 154 | # TODO(b/134822197): Remove i from return values. 155 | return s, i 156 | 157 | 158 | def while_with_method_closure_call(n): 159 | i = 0 160 | 161 | class Callable(object): 162 | 163 | def __call__(self): 164 | return i 165 | 166 | i_via_closure = Callable() 167 | i = 0 168 | s = 0 169 | while i < n: 170 | s = s * 10 + i_via_closure() 171 | i += 1 172 | return s 173 | 174 | 175 | def for_with_method_closure_call(l): 176 | i = 0 177 | 178 | class Callable(object): 179 | 180 | def __call__(self): 181 | return i 182 | 183 | i_via_closure = Callable() 184 | i = 0 185 | s = 0 186 | for i in l: 187 | s = s * 10 + i_via_closure() 188 | # TODO(b/134822197): Remove i from return values. 
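  # Returning `i` keeps it live after the loop; without that, the closure may not
  # observe the per-iteration value (see the referenced bug).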
189 | return s, i 190 | 191 | 192 | def global_fn(x): 193 | return x * 2 194 | 195 | 196 | class TestClass(object): 197 | 198 | def method(self, x): 199 | return x * 4 200 | 201 | 202 | def _int_tensor(x): 203 | return tf.constant(x, dtype=tf.int32) 204 | 205 | 206 | class ReferenceTest(reference_test_base.TestCase, parameterized.TestCase): 207 | 208 | @parameterized.parameters(*itertools.product( 209 | (0, 1, 2), 210 | (int, tf.constant), 211 | (global_fn, lambda x: x * 1, TestClass().method, abs), 212 | )) 213 | def test_while_with_call_in_cond(self, n, type_, fn): 214 | n = type_(n) 215 | self.assertFunctionMatchesEager(while_with_call_in_cond, n, fn) 216 | 217 | @parameterized.parameters(*itertools.product( 218 | ([], [1], [1, 2]), 219 | (list, _int_tensor), 220 | (global_fn, lambda x: x * 1, TestClass().method, abs), 221 | )) 222 | def test_for_with_call_in_target(self, l, type_, fn): 223 | l = type_(l) 224 | self.assertFunctionMatchesEager(for_with_call_in_target, l, fn) 225 | 226 | @parameterized.parameters(*itertools.product( 227 | (0, 1, 2), 228 | (int, _int_tensor), 229 | (range, tf.range), 230 | )) 231 | def test_for_with_range_call_in_target(self, l, type_, fn): 232 | l = type_(l) 233 | self.assertFunctionMatchesEager(for_with_call_in_target, l, fn) 234 | 235 | @parameterized.parameters(*itertools.product( 236 | (0, 1, 2), 237 | (int, tf.constant), 238 | (global_fn, lambda x: x * 1, TestClass().method, abs), 239 | )) 240 | def test_while_with_call(self, n, type_, fn): 241 | n = type_(n) 242 | self.assertFunctionMatchesEager(while_with_call, n, fn) 243 | 244 | @parameterized.parameters(*itertools.product( 245 | ([], [1], [1, 2]), 246 | (list, _int_tensor), 247 | (global_fn, lambda x: x * 1, TestClass().method, abs), 248 | )) 249 | def test_for_with_call(self, l, type_, fn): 250 | l = type_(l) 251 | self.assertFunctionMatchesEager(for_with_call, l, fn) 252 | 253 | @parameterized.parameters(*itertools.product( 254 | (0, 1, 2), 255 | (int, tf.constant), 256 | )) 257 | def test_while_with_local_call(self, n, type_): 258 | n = type_(n) 259 | self.assertFunctionMatchesEager(while_with_local_call, n) 260 | 261 | @parameterized.parameters(*itertools.product( 262 | ([], [1], [1, 2]), 263 | (list, _int_tensor), 264 | )) 265 | def test_for_with_local_call(self, l, type_): 266 | l = type_(l) 267 | self.assertFunctionMatchesEager(for_with_local_call, l) 268 | 269 | @parameterized.parameters(*itertools.product( 270 | (0, 1, 2), 271 | (int, tf.constant), 272 | )) 273 | def test_while_with_closure_call(self, n, type_): 274 | n = type_(n) 275 | self.assertFunctionMatchesEager(while_with_closure_call, n) 276 | 277 | @parameterized.parameters(*itertools.product( 278 | ([], [1], [1, 2]), 279 | (list, _int_tensor), 280 | )) 281 | def test_for_with_closure_call(self, l, type_): 282 | l = type_(l) 283 | self.assertFunctionMatchesEager(for_with_closure_call, l) 284 | 285 | @parameterized.parameters(*itertools.product( 286 | (0, 1, 2), 287 | (int, tf.constant), 288 | )) 289 | def test_while_with_lambda_closure_call(self, n, type_): 290 | n = type_(n) 291 | self.assertFunctionMatchesEager(while_with_lambda_closure_call, n) 292 | 293 | @parameterized.parameters(*itertools.product( 294 | ([], [1], [1, 2]), 295 | (list, _int_tensor), 296 | )) 297 | def test_for_with_lambda_closure_call(self, l, type_): 298 | l = type_(l) 299 | self.assertFunctionMatchesEager(for_with_lambda_closure_call, l) 300 | 301 | @parameterized.parameters(*itertools.product( 302 | (0, 1, 2), 303 | (int, tf.constant), 304 | )) 305 | 
def test_while_with_method_closure_call(self, n, type_): 306 | self.skipTest('fix static analysis for nested classes') 307 | n = type_(n) 308 | self.assertFunctionMatchesEager(while_with_method_closure_call, n) 309 | 310 | @parameterized.parameters(*itertools.product( 311 | ([], [1], [1, 2]), 312 | (list, _int_tensor), 313 | )) 314 | def test_for_with_method_closure_call(self, l, type_): 315 | self.skipTest('fix static analysis for nested classes') 316 | l = type_(l) 317 | self.assertFunctionMatchesEager(for_with_method_closure_call, l) 318 | 319 | 320 | if __name__ == '__main__': 321 | tf.test.main() 322 | -------------------------------------------------------------------------------- /reference_tests/cond_basic_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Basic conditionals.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import itertools 22 | 23 | import reference_test_base 24 | from absl.testing import parameterized 25 | import tensorflow.compat.v2 as tf 26 | 27 | 28 | tf.enable_v2_behavior() 29 | 30 | 31 | def if_no_vars(c, v): 32 | v.assign(0) 33 | if c: 34 | v.assign_add(1) 35 | return v.read_value() 36 | 37 | 38 | def if_else_no_vars(c, v): 39 | v.assign(0) 40 | if c: 41 | v.assign_add(1) 42 | else: 43 | v.assign_add(2) 44 | return v.read_value() 45 | 46 | 47 | def if_one_var(n): 48 | i = 0 49 | if i < n: 50 | i += 1 51 | return i 52 | 53 | 54 | def if_else_one_var(n): 55 | i = 0 56 | if i < n: 57 | i += 1 58 | else: 59 | i += 2 60 | return i 61 | 62 | 63 | def if_two_vars(n): 64 | i = 0 65 | j = 1 66 | if i < n: 67 | i += 1 68 | j *= 10 69 | return i, j 70 | 71 | 72 | def if_else_two_vars(n): 73 | i = 0 74 | j = 1 75 | if i < n: 76 | i += 1 77 | j *= 10 78 | else: 79 | i += 2 80 | j *= 20 81 | return i, j 82 | 83 | 84 | def if_creates_var(c): 85 | if c: 86 | i = 1 87 | return i 88 | 89 | 90 | def if_else_creates_var(c): 91 | if c: 92 | i = 1 93 | else: 94 | i = 2 95 | return i 96 | 97 | 98 | def else_creates_var(c): 99 | if c: 100 | pass 101 | else: 102 | i = 2 103 | return i 104 | 105 | 106 | def if_returns_none(c): 107 | i = 0 108 | j = 1 109 | if c: 110 | i = None 111 | j = 2 112 | return i, j 113 | 114 | 115 | def if_else_returns_none(c): 116 | if c: 117 | i = None 118 | j = 1 119 | else: 120 | i = None 121 | j = 2 122 | return i, j 123 | 124 | 125 | def else_returns_none(c): 126 | i = 1 127 | j = 1 128 | if c: 129 | pass 130 | else: 131 | i = None 132 | j = 2 133 | return i, j 134 | 135 | 136 | def if_local_var(c): 137 | i = 0 138 | if c: 139 | j = 1 140 | i = j + 1 141 | return i 142 | 143 | 144 | def if_else_local_var(c): 145 | i = 0 146 | if c: 147 | j = 1 148 | else: 149 | j = 2 150 | i = j + 1 
151 | return i 152 | 153 | 154 | def successive_ifs(n1, n2): 155 | s = 0 156 | i = 0 157 | if i < n1: 158 | s = s * 10 + i 159 | i += 1 160 | i = 0 161 | if i < n2: 162 | s = s * 10 + i 163 | i += 1 164 | return s 165 | 166 | 167 | def successive_if_elses(n1, n2): 168 | s = 0 169 | i = 0 170 | if i < n1: 171 | s = s * 10 + i 172 | i += 1 173 | else: 174 | s = s * 11 + i 175 | i += 2 176 | i = 0 177 | if i < n2: 178 | s = s * 10 + i 179 | i += 1 180 | else: 181 | s = s * 11 + i 182 | i += 2 183 | return s 184 | 185 | 186 | def nested_ifs(n1, n2): 187 | i = 0 188 | l = 0 189 | if i < n1: 190 | j = 0 191 | s = 0 192 | if j < n2: 193 | s = s * 10 + i * j 194 | j += 1 195 | l = l * 1000 + s 196 | i += 1 197 | return l 198 | 199 | 200 | def nested_if_elses(n1, n2): 201 | i = 0 202 | l = 0 203 | if i < n1: 204 | j = 0 205 | s = 0 206 | if j < n2: 207 | s = s * 10 + i * j 208 | j += 1 209 | else: 210 | s = s * 11 + i * j 211 | j += 2 212 | l = l * 1000 + s 213 | i += 1 214 | else: 215 | j = 0 216 | s = 0 217 | if j < n2: 218 | s = s * 12 + i * j 219 | j += 3 220 | else: 221 | s = s * 13 + i * j 222 | j += 4 223 | l = l * 2000 + s 224 | i += 1 225 | return l 226 | 227 | 228 | class ReferenceTest(reference_test_base.TestCase, parameterized.TestCase): 229 | 230 | @parameterized.parameters(*itertools.product( 231 | ( 232 | if_no_vars, 233 | if_else_no_vars, 234 | ), 235 | ( 236 | True, 237 | False, 238 | ), 239 | ( 240 | bool, 241 | tf.constant, 242 | ), 243 | )) 244 | def test_no_vars(self, target, c, type_): 245 | c = type_(c) 246 | self.assertFunctionMatchesEager(target, c, tf.Variable(0)) 247 | 248 | @parameterized.parameters(*itertools.product( 249 | ( 250 | if_one_var, 251 | if_else_one_var, 252 | if_two_vars, 253 | if_else_two_vars, 254 | ), 255 | ( 256 | 0, 257 | 1, 258 | ), 259 | ( 260 | int, 261 | tf.constant, 262 | ), 263 | )) 264 | def test_several_vars(self, target, n, type_): 265 | n = type_(n) 266 | self.assertFunctionMatchesEager(target, n) 267 | 268 | def test_creates_var_imbalanced_legal(self): 269 | self.assertFunctionMatchesEager(if_creates_var, True) 270 | self.assertFunctionMatchesEager(else_creates_var, False) 271 | 272 | @parameterized.parameters(*itertools.product( 273 | ( 274 | True, 275 | False, 276 | ), 277 | ( 278 | int, 279 | tf.constant, 280 | ), 281 | )) 282 | def test_if_else_creates_var(self, c, type_): 283 | c = type_(c) 284 | self.assertFunctionMatchesEager(if_else_creates_var, c) 285 | 286 | # The odd special_values.Undefined messages are due to b/131412459. 
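  # With a Python bool condition the untaken branch never runs, so the branch-only
  # variable is still AutoGraph's internal Undefined placeholder when returned (hence
  # the TypeError cases); with a tensor condition both branches are traced, and the
  # imbalance is reported as a ValueError instead.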
287 | @parameterized.parameters( 288 | (if_creates_var, False, bool, TypeError, "special_values.Undefined"), 289 | (if_creates_var, True, tf.constant, ValueError, 290 | r"must also be initialized in the else branch: \('i',\)"), 291 | (if_creates_var, False, tf.constant, ValueError, 292 | r"must also be initialized in the else branch: \('i',\)"), 293 | (else_creates_var, True, bool, TypeError, "special_values.Undefined"), 294 | (else_creates_var, True, tf.constant, ValueError, 295 | r"must also be initialized in the if branch: \('i',\)"), 296 | (else_creates_var, False, tf.constant, ValueError, 297 | r"must also be initialized in the if branch: \('i',\)"), 298 | ) 299 | def test_creates_var_imbalanced_illegal(self, target, c, type_, exc_type, 300 | exc_regex): 301 | c = type_(c) 302 | with self.assertRaisesRegex(exc_type, exc_regex): 303 | tf.function(target)(c) 304 | 305 | def test_returns_none_legal(self): 306 | self.assertFunctionMatchesEager(if_returns_none, True) 307 | self.assertFunctionMatchesEager(if_else_returns_none, False) 308 | self.assertFunctionMatchesEager(else_returns_none, False) 309 | 310 | @parameterized.parameters( 311 | (if_returns_none, True), 312 | (if_returns_none, False), 313 | (else_returns_none, True), 314 | (else_returns_none, False), 315 | (if_else_returns_none, True), 316 | (if_else_returns_none, False), 317 | ) 318 | def test_returns_none_illegal(self, target, c): 319 | c = tf.constant(c) 320 | with self.assertRaisesRegex(ValueError, '"i" is None'): 321 | tf.function(target)(c) 322 | 323 | @parameterized.parameters(*itertools.product( 324 | ( 325 | if_local_var, 326 | if_else_local_var, 327 | ), 328 | ( 329 | True, 330 | False, 331 | ), 332 | ( 333 | bool, 334 | tf.constant, 335 | ), 336 | )) 337 | def test_local_vars(self, target, c, type_): 338 | c = type_(c) 339 | self.assertFunctionMatchesEager(target, c) 340 | 341 | @parameterized.parameters(*itertools.product( 342 | ( 343 | successive_ifs, 344 | successive_if_elses, 345 | nested_ifs, 346 | nested_if_elses, 347 | ), 348 | ( 349 | 0, 350 | 1, 351 | ), 352 | ( 353 | bool, 354 | tf.constant, 355 | ), 356 | ( 357 | 0, 358 | 1, 359 | ), 360 | ( 361 | bool, 362 | tf.constant, 363 | ), 364 | )) 365 | def test_composition(self, target, n1, n1_type, n2, n2_type): 366 | n1 = n1_type(n1) 367 | n2 = n2_type(n2) 368 | self.assertFunctionMatchesEager(target, n1, n2) 369 | 370 | 371 | if __name__ == "__main__": 372 | tf.test.main() 373 | -------------------------------------------------------------------------------- /examples/sysml2019/benchmark_dashboard.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "Jxv6goXm7oGF" 8 | }, 9 | "source": [ 10 | "##### Copyright 2017 The TensorFlow Authors.\n", 11 | "\n", 12 | "Licensed under the Apache License, Version 2.0 (the \"License\");" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 0, 18 | "metadata": { 19 | "colab": {}, 20 | "colab_type": "code", 21 | "id": "llMNufAK7nfK" 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "#@title Licensed under the Apache License, Version 2.0 (the \"License\"); { display-mode: \"form\" }\n", 26 | "# you may not use this file except in compliance with the License.\n", 27 | "# You may obtain a copy of the License at\n", 28 | "#\n", 29 | "# https://www.apache.org/licenses/LICENSE-2.0\n", 30 | "#\n", 31 | "# Unless required by applicable law or agreed to in writing, 
software\n", 32 | "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", 33 | "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 34 | "# See the License for the specific language governing permissions and\n", 35 | "# limitations under the License." 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": { 41 | "colab_type": "text", 42 | "id": "ql6OnsW4n9Hi" 43 | }, 44 | "source": [ 45 | "This notebook allows inspecting the benchmark results for the examples found in this directory.\n", 46 | "\n", 47 | "To generate data, set the following env variable:\n", 48 | "\n", 49 | " TEST_REPORT_FILE_PREFIX=/tmp/autograph/sysml2019_benchmarks/\n", 50 | "\n", 51 | "Then run the benchmarks with this argument:\n", 52 | "\n", 53 | " --benchmarks=." 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 0, 59 | "metadata": { 60 | "colab": {}, 61 | "colab_type": "code", 62 | "id": "wnYwsa-HT8La" 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "import numpy as np\n", 67 | "import pandas as pd\n", 68 | "import tensorflow as tf\n", 69 | "from tensorflow.core.util import test_log_pb2" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 0, 75 | "metadata": { 76 | "colab": {}, 77 | "colab_type": "code", 78 | "id": "xUdSDbyoVKo2" 79 | }, 80 | "outputs": [], 81 | "source": [ 82 | "def load_benchmarks(path, columns, extra_cols=None):\n", 83 | " results = []\n", 84 | "\n", 85 | " for f in tf.io.gfile.glob(path):\n", 86 | " with tf.io.gfile.GFile(f, 'rb') as infile:\n", 87 | " serialized_entry = infile.read()\n", 88 | " benchmark_item = test_log_pb2.BenchmarkEntries.FromString(\n", 89 | " serialized_entry)\n", 90 | " entry, = benchmark_item.entry\n", 91 | " extras = entry.extras\n", 92 | " \n", 93 | " names = extras['name'].string_value\n", 94 | " if names.startswith('('):\n", 95 | " names = tuple(str(n) for n in names[1:-2].split(','))\n", 96 | " else:\n", 97 | " names = (names,)\n", 98 | "\n", 99 | " all_times = None\n", 100 | " all_times = extras['all_times'].string_value[1:-1].split(', ')\n", 101 | " all_times = list(map(float, all_times))\n", 102 | " \n", 103 | " extra_col_values = ()\n", 104 | " if extra_cols:\n", 105 | " for c in extra_cols:\n", 106 | " extra_col_values += (extras[c].double_value,)\n", 107 | " \n", 108 | " for time in all_times:\n", 109 | " results.append(names + (time,) + extra_col_values)\n", 110 | "\n", 111 | " if extra_cols:\n", 112 | " columns += extra_cols\n", 113 | " \n", 114 | " return pd.DataFrame(results, columns=columns)" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 0, 120 | "metadata": { 121 | "colab": {}, 122 | "colab_type": "code", 123 | "id": "RmPPny0p9x5V" 124 | }, 125 | "outputs": [], 126 | "source": [ 127 | "data = load_benchmarks(\n", 128 | " '/tmp/autograph/sysml2019_benchmarks/BeamSearchBenchmark.*',\n", 129 | " ('benchmark', 'max_seq_len', 'vocab_size', 'time'))\n", 130 | "\n", 131 | "data.groupby(['benchmark', 'max_seq_len', 'vocab_size']).agg([np.mean, np.std])" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": 0, 137 | "metadata": { 138 | "colab": {}, 139 | "colab_type": "code", 140 | "id": "qBgytDLM8iim" 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "data = load_benchmarks('/tmp/autograph/sysml2019_benchmarks/MAMLBenchmark.*',\n", 145 | " ('benchmark', 'meta_steps', 'time'))\n", 146 | "\n", 147 | "data.groupby(['benchmark', 'meta_steps']).agg([np.mean, np.std])" 148 | ] 149 | }, 150 | { 151 | 
"cell_type": "code", 152 | "execution_count": 0, 153 | "metadata": { 154 | "colab": {}, 155 | "colab_type": "code", 156 | "id": "uCXjNIquVLz0" 157 | }, 158 | "outputs": [], 159 | "source": [ 160 | "data = load_benchmarks('/tmp/autograph/sysml2019_benchmarks/LBFGSBenchmark.*',\n", 161 | " ('benchmark', 'batch_size', 'time'))\n", 162 | "\n", 163 | "data.groupby(['benchmark', 'batch_size']).agg([np.mean, np.std])" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 0, 169 | "metadata": { 170 | "colab": {}, 171 | "colab_type": "code", 172 | "id": "dCufVwhyP8vx" 173 | }, 174 | "outputs": [], 175 | "source": [ 176 | "data = load_benchmarks('/tmp/autograph/sysml2019_benchmarks/MAMLBenchmark.*',\n", 177 | " ('benchmark', 'meta_steps', 'time'))\n", 178 | "\n", 179 | "data.groupby(['benchmark', 'meta_steps']).agg([np.mean, np.std])" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": 0, 185 | "metadata": { 186 | "colab": {}, 187 | "colab_type": "code", 188 | "id": "IQ8vT5R4-o4i" 189 | }, 190 | "outputs": [], 191 | "source": [ 192 | "data = load_benchmarks(\n", 193 | " '/tmp/autograph/sysml2019_benchmarks/MNISTBenchmark.*',\n", 194 | " ('benchmark', 'time'),\n", 195 | " extra_cols=('iter_volume',))\n", 196 | "\n", 197 | "data['examples_per_sec'] = data['iter_volume'].values / data['time'].values\n", 198 | "data = data[['benchmark', 'examples_per_sec']]\n", 199 | "\n", 200 | "data.groupby(['benchmark']).agg([np.mean, np.std])" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 0, 206 | "metadata": { 207 | "colab": {}, 208 | "colab_type": "code", 209 | "id": "tSTC5EirQKCV" 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "data = load_benchmarks(\n", 214 | " '/tmp/autograph/sysml2019_benchmarks/RNNBenchmark.*',\n", 215 | " ('benchmark', 'batch_size', 'max_seq_len', 'time'),\n", 216 | " extra_cols=('iter_volume',))\n", 217 | "\n", 218 | "data['examples_per_sec'] = data['iter_volume'].values / data['time'].values \n", 219 | "data = data[['benchmark', 'batch_size', 'max_seq_len', 'examples_per_sec']]\n", 220 | "\n", 221 | "data.groupby(['benchmark', 'batch_size', 'max_seq_len']).agg([np.mean, np.std])" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 0, 227 | "metadata": { 228 | "colab": {}, 229 | "colab_type": "code", 230 | "id": "zsE0a1njQYg_" 231 | }, 232 | "outputs": [], 233 | "source": [ 234 | "data = load_benchmarks(\n", 235 | " '/tmp/autograph/sysml2019_benchmarks/Seq2SeqBenchmark.*',\n", 236 | " ('benchmark', 'max_seq_len', 'vocab_size', 'teacher_forcing', 'time'))\n", 237 | "\n", 238 | "data.groupby(['benchmark', 'max_seq_len', 'vocab_size', 'teacher_forcing']).agg([np.mean, np.std])" 239 | ] 240 | } 241 | ], 242 | "metadata": { 243 | "colab": { 244 | "collapsed_sections": [], 245 | "last_runtime": { 246 | "build_target": "", 247 | "kind": "local" 248 | }, 249 | "name": "Benchmark result analysis", 250 | "provenance": [], 251 | "version": "0.3.2" 252 | } 253 | }, 254 | "nbformat": 4, 255 | "nbformat_minor": 0 256 | } 257 | -------------------------------------------------------------------------------- /reference_tests/loop_basic_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Basic loops iterating over various types.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import functools 22 | import itertools 23 | 24 | from absl.testing import parameterized 25 | import reference_test_base 26 | import tensorflow.compat.v2 as tf 27 | 28 | 29 | tf.enable_v2_behavior() 30 | 31 | 32 | def while_no_vars(n, v): 33 | v.assign(0) 34 | while v.read_value() < n: 35 | v.assign_add(1) 36 | return v.read_value() 37 | 38 | 39 | def for_no_vars(l, v): 40 | v.assign(0) 41 | for _ in l: 42 | v.assign_add(1) 43 | return v.read_value() 44 | 45 | 46 | def while_one_var(n): 47 | i = 0 48 | while i < n: 49 | i = i * 10 + 1 50 | return i 51 | 52 | 53 | def for_one_var(l): 54 | i = 0 55 | for i in l: 56 | pass 57 | return i 58 | 59 | 60 | def while_two_vars(n): 61 | s = 0 62 | i = 0 63 | while i < n: 64 | s = s * 10 + i 65 | i += 1 66 | return s 67 | 68 | 69 | def for_two_vars(l): 70 | s = 0 71 | for i in l: 72 | s = s * 10 + i 73 | return s 74 | 75 | 76 | def successive_while_loops(n1, n2): 77 | s = 0 78 | i = 0 79 | while i < n1: 80 | s = s * 10 + i 81 | i += 1 82 | i = 0 83 | while i < n2: 84 | s = s * 10 + i 85 | i += 1 86 | return s 87 | 88 | 89 | def successive_for_loops(l1, l2): 90 | s = 0 91 | for i in l1: 92 | s = s * 10 + i 93 | for i in l2: 94 | s = s * 10 + i 95 | return s 96 | 97 | 98 | def nested_while_loops(n1, n2): 99 | i = 0 100 | l = tf.TensorArray(tf.int32, size=0, dynamic_size=True, element_shape=()) 101 | while i < n1: 102 | j = 0 103 | s = 0 104 | while j < n2: 105 | s = s * 10 + i * j 106 | j += 1 107 | l = l.write(i, s) 108 | i += 1 109 | return l.stack() 110 | 111 | 112 | def nested_for_loops(m): 113 | l = tf.TensorArray(tf.int32, size=0, dynamic_size=True, element_shape=()) 114 | for i in m: 115 | s = 0 116 | for j in i: 117 | s = s * 10 + j 118 | l = l.write(l.size(), s) 119 | return l.stack() 120 | 121 | 122 | def _int_tensor(x): 123 | return tf.constant(x, dtype=tf.int32) 124 | 125 | 126 | def _int_dataset(l): 127 | return tf.data.Dataset.from_tensor_slices(tf.constant(l, dtype=tf.int32)) 128 | 129 | 130 | def double_product(l1, l2): 131 | for i in l1: 132 | for j in l2: 133 | for k in l1: 134 | for l in l2: 135 | yield i, j, k, l 136 | 137 | 138 | class ReferenceTest(reference_test_base.TestCase, parameterized.TestCase): 139 | 140 | @parameterized.parameters(*itertools.product( 141 | ( 142 | 0, 143 | 1, 144 | 2, 145 | ), 146 | ( 147 | int, 148 | _int_tensor, 149 | ), 150 | )) 151 | def test_while_no_vars(self, n, type_): 152 | n = type_(n) 153 | self.assertFunctionMatchesEager(while_no_vars, n, tf.Variable(0)) 154 | 155 | @parameterized.parameters(*itertools.product( 156 | ( 157 | [], 158 | [1], 159 | [1, 2], 160 | ), 161 | ( 162 | list, 163 | _int_tensor, 164 | # TODO(mdan): Enable this once #35335 is fixed. 
165 | # _int_dataset, 166 | ))) 167 | def test_for_no_vars(self, l, type_): 168 | l = type_(l) 169 | self.assertFunctionMatchesEager(for_no_vars, l, tf.Variable(0)) 170 | 171 | @parameterized.parameters( 172 | ([],), 173 | ([1],), 174 | ([1, 2],), 175 | ) 176 | def test_for_no_vars_ds_iterator(self, l): 177 | inputs_ = lambda: (iter(_int_dataset(l)), tf.Variable(0)) 178 | self.assertFunctionMatchesEagerStatefulInput(for_no_vars, inputs_) 179 | 180 | @parameterized.parameters(*itertools.product( 181 | ( 182 | 0, 183 | 1, 184 | 2, 185 | ), 186 | ( 187 | int, 188 | _int_tensor, 189 | ), 190 | )) 191 | def test_while_one_var(self, n, type_): 192 | n = type_(n) 193 | self.assertFunctionMatchesEager(while_one_var, n) 194 | 195 | @parameterized.parameters(*itertools.product( 196 | ( 197 | [], 198 | [1], 199 | [1, 2], 200 | ), 201 | ( 202 | list, 203 | _int_tensor, 204 | _int_dataset, 205 | ), 206 | )) 207 | def test_for_one_var(self, l, type_): 208 | l = type_(l) 209 | self.assertFunctionMatchesEager(for_one_var, l) 210 | 211 | @parameterized.parameters( 212 | ([],), 213 | ([1],), 214 | ([1, 2],), 215 | ) 216 | def test_for_one_var_ds_iterator(self, l): 217 | inputs_ = lambda: (iter(_int_dataset(l)), tf.Variable(0)) 218 | self.assertFunctionMatchesEagerStatefulInput(for_one_var, inputs_) 219 | 220 | @parameterized.parameters(*itertools.product( 221 | ( 222 | 0, 223 | 1, 224 | 2, 225 | ), 226 | ( 227 | int, 228 | _int_tensor, 229 | ), 230 | )) 231 | def test_while_two_vars(self, n, type_): 232 | n = type_(n) 233 | self.assertFunctionMatchesEager(while_two_vars, n) 234 | 235 | @parameterized.parameters(*itertools.product( 236 | ( 237 | [], 238 | [1], 239 | [1, 2], 240 | ), 241 | ( 242 | list, 243 | _int_tensor, 244 | _int_dataset, 245 | ), 246 | )) 247 | def test_for_two_vars(self, l, type_): 248 | l = type_(l) 249 | self.assertFunctionMatchesEager(for_two_vars, l) 250 | 251 | @parameterized.parameters( 252 | ([],), 253 | ([1],), 254 | ([1, 2],), 255 | ) 256 | def test_for_two_vars_ds_iterator(self, l): 257 | inputs_ = lambda: (iter(_int_dataset(l)), tf.Variable(0)) 258 | self.assertFunctionMatchesEagerStatefulInput(for_two_vars, inputs_) 259 | 260 | @parameterized.parameters(*double_product( 261 | ( 262 | 0, 263 | 1, 264 | 2, 265 | ), 266 | ( 267 | int, 268 | _int_tensor, 269 | ), 270 | )) 271 | def test_successive_while_loops(self, n1, type1, n2, type2): 272 | n1 = type1(n1) 273 | n2 = type1(n2) 274 | self.assertFunctionMatchesEager(successive_while_loops, n1, n2) 275 | 276 | @parameterized.parameters(*double_product( 277 | ( 278 | [], 279 | [1], 280 | [1, 2], 281 | ), 282 | ( 283 | list, 284 | _int_tensor, 285 | _int_dataset, 286 | ), 287 | )) 288 | def test_successive_for_loops(self, l1, type1, l2, type2): 289 | l1 = type1(l1) 290 | l2 = type1(l2) 291 | self.assertFunctionMatchesEager(successive_for_loops, l1, l2) 292 | 293 | @parameterized.parameters(*double_product( 294 | ( 295 | [], 296 | [1], 297 | [1, 2], 298 | ), 299 | ( 300 | list, 301 | _int_dataset, 302 | ), 303 | )) 304 | def test_successive_for_loops_iterators(self, l1, type1, l2, type2): 305 | inputs_ = lambda: (iter(type1(l1)), iter(type2(l2))) 306 | self.assertFunctionMatchesEagerStatefulInput(successive_for_loops, inputs_) 307 | 308 | @parameterized.parameters(*double_product( 309 | ( 310 | 0, 311 | 1, 312 | 2, 313 | ), 314 | ( 315 | int, 316 | _int_tensor, 317 | ), 318 | )) 319 | def test_nested_while_loops(self, n1, type1, n2, type2): 320 | n1 = type1(n1) 321 | n2 = type1(n2) 322 | 
self.assertFunctionMatchesEager(nested_while_loops, n1, n2) 323 | 324 | @parameterized.parameters(*itertools.product( 325 | ( 326 | [[]], 327 | [[], []], 328 | [[1]], 329 | [[1], [2]], 330 | [[1, 2]], 331 | [[1, 2], [3, 4]], 332 | ), 333 | ( 334 | _int_tensor, 335 | _int_dataset, 336 | ), 337 | )) 338 | def test_nested_for_loops_dense(self, m, type_): 339 | m = type_(m) 340 | self.assertFunctionMatchesEager(nested_for_loops, m) 341 | 342 | @parameterized.parameters(*itertools.product( 343 | ( 344 | [[]], 345 | [[], [1]], 346 | [[], [1], [1, 2]], 347 | ), 348 | ( 349 | list, 350 | functools.partial(tf.ragged.constant, dtype=tf.int32), 351 | ), 352 | )) 353 | def test_nested_for_loops_ragged(self, m, type_): 354 | m = type_(m) 355 | self.assertFunctionMatchesEager(nested_for_loops, m) 356 | 357 | def test_nested_for_loops_mixed_list(self): 358 | m = [[], _int_tensor([]), [1], _int_tensor([1]), [1, 2]] 359 | self.assertFunctionMatchesEager(nested_for_loops, m) 360 | 361 | 362 | if __name__ == '__main__': 363 | tf.test.main() 364 | -------------------------------------------------------------------------------- /examples/sysml2019/rnn_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Benchmark comparing eager, autograph, and official dynamic_rnn. 16 | 17 | This code is tested on TF 1.13. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import benchmark_base 25 | 26 | import numpy as np 27 | import tensorflow.compat.v1 as tf 28 | 29 | 30 | tf.enable_eager_execution() 31 | 32 | 33 | BATCH_SIZE = 32 34 | MAX_SEQ_LEN = 100 35 | FEATURE_SIZE = 50 36 | HIDDEN_SIZE = 256 37 | 38 | 39 | class RNNBenchmark(benchmark_base.ReportingBenchmark): 40 | """Runs benchmarks for eager/autograph/graph variants of dynamic_rnn.""" 41 | 42 | def _generate_fake_rnn_inputs(self, 43 | batch_size=BATCH_SIZE, 44 | max_seq_len=MAX_SEQ_LEN): 45 | np.random.seed(17) 46 | 47 | input_data = np.random.random([batch_size, max_seq_len, 48 | FEATURE_SIZE]).astype(np.float32) 49 | # Generate some varying sequence lengths but keep max(sequence_lengths) 50 | # a constant, for more reproducible benchmarks. 
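    # The first sequence length is pinned to max_seq_len; the remaining
    # batch_size - 1 lengths are drawn uniformly from [max_seq_len // 2, max_seq_len).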
51 | sequence_lengths = np.concatenate(([max_seq_len], 52 | np.random.randint( 53 | max_seq_len // 2, 54 | max_seq_len, 55 | size=[batch_size - 1]))).astype( 56 | np.int32) 57 | 58 | for i, seq_len in enumerate(sequence_lengths): 59 | input_data[i, seq_len:, :] = 0 60 | 61 | input_data = tf.constant(input_data) 62 | sequence_lengths = tf.constant(sequence_lengths) 63 | 64 | return input_data, sequence_lengths 65 | 66 | def _create_rnn_cell(self, batch_size=BATCH_SIZE): 67 | rnn_cell = tf.nn.rnn_cell.BasicRNNCell(HIDDEN_SIZE, dtype=tf.float32) 68 | rnn_cell.build(tf.TensorShape([batch_size, FEATURE_SIZE])) 69 | return rnn_cell, rnn_cell.zero_state(batch_size, dtype=tf.float32) 70 | 71 | def _benchmark_eager_dynamic_rnn(self, batch_size, max_seq_len): 72 | input_data, sequence_lengths = self._generate_fake_rnn_inputs( 73 | batch_size=batch_size, max_seq_len=max_seq_len) 74 | rnn_cell, initial_state = self._create_rnn_cell(batch_size=batch_size) 75 | 76 | def eager_dynamic_rnn(rnn_cell, 77 | input_data, 78 | initial_state, 79 | sequence_length=None): 80 | """An eager version of dynamic_rnn.""" 81 | # [batch, time, features] -> [time, batch, features] 82 | input_data = tf.transpose(input_data, [1, 0, 2]) 83 | outputs = [] 84 | state = initial_state 85 | if sequence_length is None: 86 | max_seq_len = input_data.shape[0] 87 | else: 88 | max_seq_len = tf.reduce_max(sequence_length) 89 | for i in range(max_seq_len): 90 | new_output, new_state = rnn_cell(input_data[i], state) 91 | output = tf.where(i < sequence_length, new_output, 92 | tf.zeros(new_output.shape)) 93 | state = tf.where(i < sequence_length, new_state, state) 94 | outputs.append(output) 95 | return tf.transpose(tf.stack(outputs), [1, 0, 2]), state 96 | 97 | def target(): 98 | eager_dynamic_rnn(rnn_cell, input_data, initial_state, sequence_lengths) 99 | 100 | self.time_execution( 101 | ('Eager', batch_size, max_seq_len), 102 | target, 103 | iter_volume=batch_size, 104 | iter_unit='examples', 105 | extras={ 106 | 'max_seq_len': max_seq_len, 107 | 'batch_size': batch_size, 108 | }) 109 | 110 | def _benchmark_handwritten_dynamic_rnn(self, batch_size, max_seq_len): 111 | 112 | def my_dynamic_rnn(rnn_cell, 113 | input_data, 114 | initial_state, 115 | sequence_length=None): 116 | """A handwritten reimplementation of dynamic_rnn.""" 117 | input_data = tf.transpose(input_data, [1, 0, 2]) 118 | outputs = tf.TensorArray(tf.float32, input_data.shape[0]) 119 | if sequence_length is None: 120 | max_seq_len = input_data.shape[0] 121 | else: 122 | max_seq_len = tf.reduce_max(sequence_length) 123 | 124 | def while_body(i, state, outputs): 125 | new_output, new_state = rnn_cell(input_data[i], state) 126 | output = tf.where(i < sequence_length, new_output, 127 | tf.zeros(new_output.shape)) 128 | state = tf.where(i < sequence_length, new_state, state) 129 | outputs = outputs.write(i, output) 130 | return i + 1, state, outputs 131 | 132 | def while_cond(i, unused_state, unused_outputs): 133 | return i < max_seq_len 134 | 135 | _, state, outputs = tf.while_loop( 136 | while_cond, 137 | while_body, 138 | loop_vars=(tf.constant(0), initial_state, outputs)) 139 | return tf.transpose(outputs.stack(), [1, 0, 2]), state 140 | 141 | with tf.Graph().as_default(): 142 | input_data, sequence_lengths = self._generate_fake_rnn_inputs( 143 | batch_size=batch_size, max_seq_len=max_seq_len) 144 | rnn_cell, initial_state = self._create_rnn_cell(batch_size=batch_size) 145 | graph_output_t = my_dynamic_rnn(rnn_cell, input_data, initial_state, 146 | sequence_lengths) 147 | 
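      # The graph is built once above; the timed target below only runs the
      # prebuilt fetches via session.run.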
148 | with tf.Session() as sess: 149 | sess.run(tf.global_variables_initializer()) 150 | 151 | def target(): 152 | sess.run(graph_output_t) 153 | 154 | self.time_execution( 155 | ('Handwritten', batch_size, max_seq_len), 156 | target, 157 | iter_volume=batch_size, 158 | iter_unit='examples', 159 | extras={ 160 | 'max_seq_len': max_seq_len, 161 | 'batch_size': batch_size, 162 | }) 163 | 164 | def benchmark_dynamic_rnn(self): 165 | for batch_size in (32, 64, 128): 166 | for max_seq_len in (64, 128): 167 | self._benchmark_eager_dynamic_rnn(batch_size, max_seq_len) 168 | self._benchmark_handwritten_dynamic_rnn(batch_size, max_seq_len) 169 | self._benchmark_ag_dynamic_rnn(batch_size, max_seq_len) 170 | self._benchmark_official_dynamic_rnn(batch_size, max_seq_len) 171 | 172 | def _benchmark_ag_dynamic_rnn(self, batch_size, max_seq_len): 173 | 174 | def ag_dynamic_rnn(rnn_cell, 175 | input_data, 176 | initial_state, 177 | sequence_length=None): 178 | """An autograph-able reimplementation of subset of dynamic_rnn.""" 179 | # [batch, time, features] -> [time, batch, features] 180 | input_data = tf.transpose(input_data, [1, 0, 2]) 181 | if sequence_length is None: 182 | max_seq_len = input_data.shape[0] 183 | else: 184 | max_seq_len = tf.reduce_max(sequence_length) 185 | 186 | outputs = tf.TensorArray(tf.float32, size=max_seq_len) 187 | state = initial_state 188 | for i in tf.range(max_seq_len): 189 | new_output, new_state = rnn_cell(input_data[i], state) 190 | output = tf.where(i < sequence_length, new_output, 191 | tf.zeros(new_output.shape)) 192 | state = tf.where(i < sequence_length, new_state, state) 193 | outputs = outputs.write(i, output) 194 | return tf.transpose(outputs.stack(), [1, 0, 2]), state 195 | 196 | ag_dynamic_rnn = tf.autograph.to_graph(ag_dynamic_rnn) 197 | 198 | with tf.Graph().as_default(): 199 | input_data, sequence_lengths = self._generate_fake_rnn_inputs( 200 | batch_size=batch_size, max_seq_len=max_seq_len) 201 | rnn_cell, initial_state = self._create_rnn_cell(batch_size=batch_size) 202 | rnn_output = ag_dynamic_rnn(rnn_cell, input_data, initial_state, 203 | sequence_lengths) 204 | 205 | with tf.Session() as sess: 206 | sess.run(tf.global_variables_initializer()) 207 | 208 | def target(): 209 | sess.run(rnn_output) 210 | 211 | self.time_execution( 212 | ('AutoGraph', batch_size, max_seq_len), 213 | target, 214 | iter_volume=batch_size, 215 | iter_unit='examples', 216 | extras={ 217 | 'max_seq_len': max_seq_len, 218 | 'batch_size': batch_size, 219 | }) 220 | 221 | def _benchmark_official_dynamic_rnn(self, batch_size, max_seq_len): 222 | with tf.Graph().as_default(): 223 | input_data, sequence_lengths = self._generate_fake_rnn_inputs( 224 | batch_size=batch_size, max_seq_len=max_seq_len) 225 | rnn_cell, initial_state = self._create_rnn_cell(batch_size=batch_size) 226 | 227 | rnn_output = tf.nn.dynamic_rnn( 228 | rnn_cell, 229 | input_data, 230 | initial_state=initial_state, 231 | sequence_length=sequence_lengths) 232 | 233 | with tf.Session() as sess: 234 | sess.run(tf.global_variables_initializer()) 235 | 236 | def target(): 237 | sess.run(rnn_output) 238 | 239 | self.time_execution( 240 | ('tf.nn.dynamic_rnn', batch_size, max_seq_len), 241 | target, 242 | iter_volume=batch_size, 243 | iter_unit='examples', 244 | extras={ 245 | 'max_seq_len': max_seq_len, 246 | 'batch_size': batch_size, 247 | }) 248 | 249 | 250 | if __name__ == '__main__': 251 | tf.test.main() 252 | -------------------------------------------------------------------------------- 
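The loop_with_variable_type_test.py file that follows exercises
tf.autograph.experimental.set_loop_options(shape_invariants=...) for loop variables whose
shape changes across iterations. A minimal, self-contained sketch of that pattern
(illustrative only: the function name append_squares does not appear in the tests, and
TF 2.x behavior is assumed):

    import tensorflow as tf


    @tf.function
    def append_squares(n):
      # Accumulate squares into a vector that grows by one element per iteration.
      v = tf.zeros([0], dtype=tf.int32)
      for i in tf.range(n):
        # Relax the loop's shape check from "same shape on every iteration" to
        # "any rank-1 shape"; the dtype must still stay fixed.
        tf.autograph.experimental.set_loop_options(
            shape_invariants=[(v, tf.TensorShape([None]))])
        v = tf.concat([v, [i * i]], axis=0)
      return v


    print(append_squares(tf.constant(4)))  # tf.Tensor([0 1 4 9], shape=(4,), dtype=int32)

Without the set_loop_options call, tracing the loop fails with the same kind of error that
loop_with_variable_type_illegal_cases_test.py asserts on ('"v" has shape (0,) before the
loop, but shape (1,) after' one iteration).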
/reference_tests/loop_with_variable_type_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Loops with type changing variables.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import itertools 22 | 23 | from absl.testing import parameterized 24 | import reference_test_base 25 | import tensorflow.compat.v2 as tf 26 | 27 | 28 | tf.enable_v2_behavior() 29 | 30 | 31 | def while_with_variable_shape_growing_vector(n): 32 | v = tf.constant([0, 0]) 33 | i = 0 34 | while i < n: 35 | tf.autograph.experimental.set_loop_options( 36 | shape_invariants=[(v, tf.TensorShape([None]))]) 37 | v = tf.concat((v, [i]), 0) 38 | i += 1 39 | return v 40 | 41 | 42 | def for_with_variable_shape_growing_vector(l): 43 | v = tf.constant([0, 0]) 44 | for i in l: 45 | tf.autograph.experimental.set_loop_options( 46 | shape_invariants=[(v, tf.TensorShape([None]))]) 47 | v = tf.concat((v, [i]), 0) 48 | return v 49 | 50 | 51 | def while_with_variable_shape_growing_matrix_rows(n): 52 | m = tf.constant([[0]]) 53 | i = 0 54 | while i < n: 55 | tf.autograph.experimental.set_loop_options( 56 | shape_invariants=[(m, tf.TensorShape([None, 1]))]) 57 | m = tf.concat((m, [[i]]), 0) 58 | i += 1 59 | return m 60 | 61 | 62 | def for_with_variable_shape_growing_matrix_rows(l): 63 | m = tf.constant([[0]]) 64 | for i in l: 65 | tf.autograph.experimental.set_loop_options( 66 | shape_invariants=[(m, tf.TensorShape([None, 1]))]) 67 | m = tf.concat((m, [[i]]), 0) 68 | return m 69 | 70 | 71 | def while_with_variable_shape_growing_matrix_cols(n): 72 | m = tf.constant([[0, 0]]) 73 | i = 0 74 | while i < n: 75 | tf.autograph.experimental.set_loop_options( 76 | shape_invariants=[(m, tf.TensorShape([1, None]))]) 77 | m = tf.concat((m, [[i]]), 1) 78 | i += 1 79 | return m 80 | 81 | 82 | def for_with_variable_shape_growing_matrix_cols(l): 83 | m = tf.constant([[0, 0]]) 84 | for i in l: 85 | tf.autograph.experimental.set_loop_options( 86 | shape_invariants=[(m, tf.TensorShape([1, None]))]) 87 | m = tf.concat((m, [[i]]), 1) 88 | return m 89 | 90 | 91 | def while_with_variable_shape_growing_matrix(n): 92 | m = tf.constant([[0, 0], [0, 0]]) 93 | i = 0 94 | while i < n: 95 | tf.autograph.experimental.set_loop_options( 96 | shape_invariants=[(m, tf.TensorShape(None))]) 97 | m = tf.pad(m, [[1, 1], [1, 1]], constant_values=i) 98 | i += 1 99 | return m 100 | 101 | 102 | def for_with_variable_shape_growing_matrix(l): 103 | m = tf.constant([[0, 0], [0, 0]]) 104 | for i in l: 105 | tf.autograph.experimental.set_loop_options( 106 | shape_invariants=[(m, tf.TensorShape(None))]) 107 | m = tf.pad(m, [[1, 1], [1, 1]], constant_values=i) 108 | return m 109 | 110 | 111 | def 
while_with_variable_shape_inside_if(n): 112 | v = tf.constant([0, 0]) 113 | i = 0 114 | if n > 1: 115 | while i < n: 116 | tf.autograph.experimental.set_loop_options( 117 | shape_invariants=[(v, tf.TensorShape([None]))]) 118 | v = tf.concat((v, [i]), 0) 119 | i += 1 120 | else: 121 | v = tf.constant([1, 2, 3]) 122 | return v 123 | 124 | 125 | def for_with_variable_shape_inside_if(n): 126 | v = tf.constant([0, 0]) 127 | if n > 1: 128 | for i in range(n): 129 | tf.autograph.experimental.set_loop_options( 130 | shape_invariants=[(v, tf.TensorShape([None]))]) 131 | v = tf.concat((v, [i]), 0) 132 | i += 1 133 | else: 134 | v = tf.constant([1, 2, 3]) 135 | return v 136 | 137 | 138 | def while_with_variable_shape_and_break(n): 139 | v = tf.constant([0, 0]) 140 | i = 0 141 | if n > 1: 142 | while i < n: 143 | tf.autograph.experimental.set_loop_options( 144 | shape_invariants=[(v, tf.TensorShape([None]))]) 145 | v = tf.concat((v, [i]), 0) 146 | i += 1 147 | if i > 3: 148 | break 149 | else: 150 | v = tf.constant([1, 2, 3]) 151 | return v 152 | 153 | 154 | def for_with_variable_shape_and_break(n): 155 | v = tf.constant([0, 0]) 156 | if n > 1: 157 | for i in range(n): 158 | tf.autograph.experimental.set_loop_options( 159 | shape_invariants=[(v, tf.TensorShape([None]))]) 160 | v = tf.concat((v, [i]), 0) 161 | i += 1 162 | if i > 3: 163 | break 164 | else: 165 | v = tf.constant([1, 2, 3]) 166 | return v 167 | 168 | 169 | def while_with_composite_tensor_shape_invariant(n): 170 | v = tf.SparseTensor( 171 | indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[3, 3]) 172 | i = 0 173 | while i < n: 174 | tf.autograph.experimental.set_loop_options( 175 | shape_invariants=[(v, tf.TensorShape(None))]) 176 | v = tf.sparse.expand_dims(v) 177 | i += 1 178 | return v 179 | 180 | 181 | def for_with_composite_tensor_shape_invariant(l): 182 | v = tf.SparseTensor( 183 | indices=[[0, 0], [1, 1]], values=[1, 2], dense_shape=[3, 3]) 184 | for _ in l: 185 | tf.autograph.experimental.set_loop_options( 186 | shape_invariants=[(v, tf.TensorShape(None))]) 187 | v = tf.sparse.expand_dims(v) 188 | return v 189 | 190 | 191 | def _int_dataset_range(n): 192 | return tf.data.Dataset.range(n).map(lambda x: tf.cast(x, tf.int32)) 193 | 194 | 195 | class ReferenceTest(reference_test_base.TestCase, parameterized.TestCase): 196 | 197 | @parameterized.parameters(*itertools.product( 198 | (0, 1, 2), 199 | (int, tf.constant), 200 | )) 201 | def test_while_with_variable_shape_growing_vector(self, n, type_): 202 | n = type_(n) 203 | self.assertFunctionMatchesEager(while_with_variable_shape_growing_vector, n) 204 | 205 | @parameterized.parameters(*itertools.product( 206 | (0, 1, 2), 207 | (range, tf.range, tf.data.Dataset.range), 208 | )) 209 | def test_for_with_variable_shape_growing_vector(self, n, list_type): 210 | l = list_type(n) 211 | self.assertFunctionMatchesEager(for_with_variable_shape_growing_vector, l) 212 | 213 | @parameterized.parameters(*itertools.product( 214 | (0, 1, 2), 215 | (int, tf.constant), 216 | )) 217 | def test_while_with_variable_shape_growing_matrix_rows(self, n, type_): 218 | n = type_(n) 219 | self.assertFunctionMatchesEager( 220 | while_with_variable_shape_growing_matrix_rows, n) 221 | 222 | @parameterized.parameters(*itertools.product( 223 | (0, 1, 2), 224 | (range, tf.range, _int_dataset_range), 225 | )) 226 | def test_for_with_variable_shape_growing_matrix_rows(self, l, type_): 227 | l = type_(l) 228 | self.assertFunctionMatchesEager( 229 | for_with_variable_shape_growing_matrix_rows, l) 230 | 231 | 
@parameterized.parameters(*itertools.product( 232 | (0, 1, 2), 233 | (int, tf.constant), 234 | )) 235 | def test_while_with_variable_shape_growing_matrix_cols(self, n, type_): 236 | n = type_(n) 237 | self.assertFunctionMatchesEager( 238 | while_with_variable_shape_growing_matrix_cols, n) 239 | 240 | @parameterized.parameters(*itertools.product( 241 | (0, 1, 2), 242 | (range, tf.range, tf.data.Dataset.range), 243 | )) 244 | def test_for_with_variable_shape_growing_matrix_cols(self, l, type_): 245 | l = type_(l) 246 | self.assertFunctionMatchesEager( 247 | for_with_variable_shape_growing_matrix_cols, l) 248 | 249 | @parameterized.parameters(*itertools.product( 250 | (0, 1, 2), 251 | (int, tf.constant), 252 | )) 253 | def test_while_with_variable_shape_growing_matrix(self, n, type_): 254 | n = type_(n) 255 | self.assertFunctionMatchesEager(while_with_variable_shape_growing_matrix, n) 256 | 257 | @parameterized.parameters(*itertools.product( 258 | (0, 1, 2), 259 | (range, tf.range, _int_dataset_range), 260 | )) 261 | def test_for_with_variable_shape_growing_matrix(self, n, type_): 262 | l = type_(n) 263 | self.assertFunctionMatchesEager(for_with_variable_shape_growing_matrix, l) 264 | 265 | @parameterized.parameters(*itertools.product( 266 | (0, 1, 2), 267 | (int, tf.constant), 268 | )) 269 | def test_while_with_variable_shape_inside_if(self, n, type_): 270 | n = type_(n) 271 | self.assertFunctionMatchesEager(while_with_variable_shape_inside_if, n) 272 | 273 | @parameterized.parameters(*itertools.product( 274 | (0, 1, 2), 275 | (int, tf.constant), 276 | )) 277 | def test_for_with_variable_shape_inside_if(self, n, type_): 278 | n = type_(n) 279 | self.assertFunctionMatchesEager(for_with_variable_shape_inside_if, n) 280 | 281 | @parameterized.parameters(*itertools.product( 282 | (0, 1, 2), 283 | (int, tf.constant), 284 | )) 285 | def test_while_with_variable_shape_and_break(self, n, type_): 286 | n = type_(n) 287 | self.assertFunctionMatchesEager(while_with_variable_shape_and_break, n) 288 | 289 | @parameterized.parameters(*itertools.product( 290 | (0, 1, 2, 5), 291 | (int, tf.constant), 292 | )) 293 | def test_for_with_variable_shape_and_break(self, n, type_): 294 | n = type_(n) 295 | self.assertFunctionMatchesEager(for_with_variable_shape_and_break, n) 296 | 297 | @parameterized.parameters(*itertools.product( 298 | (0, 1, 2, 5), 299 | (int, tf.constant), 300 | )) 301 | def test_while_with_composite_tensor_shape_invariant(self, n, type_): 302 | n = type_(n) 303 | self.assertFunctionMatchesEager( 304 | while_with_composite_tensor_shape_invariant, n) 305 | 306 | @parameterized.parameters(*itertools.product( 307 | (0, 1, 2), 308 | (range, tf.range, _int_dataset_range), 309 | )) 310 | def test_for_with_composite_tensor_shape_invariant(self, n, type_): 311 | l = type_(n) 312 | self.assertFunctionMatchesEager( 313 | for_with_composite_tensor_shape_invariant, l) 314 | 315 | 316 | if __name__ == '__main__': 317 | tf.test.main() 318 | -------------------------------------------------------------------------------- /examples/sysml2019/beam_search_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Benchmark for a beam search implementation. 16 | 17 | This code requires TF 2.0 or newer. 18 | Use `pip install tf-nightly-2.0-preview` to install. 19 | Tip: Run the `pip install` command in a separate virtual environment 20 | (Virtualenv, Anaconda) to avoid clobbering an existing TF installation. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import benchmark_base 28 | 29 | import numpy as np 30 | import tensorflow as tf 31 | 32 | 33 | # Implicitly assumed by code: OOV_TOKEN = 0 34 | INIT_TOKEN = 1 35 | EOS_TOKEN = 2 36 | 37 | BEAM_SIZE = 5 38 | NEG_INF = -1e10 39 | 40 | MAX_SEQ_LEN = 50 41 | 42 | 43 | class FakeDecoder(tf.keras.Model): 44 | """Stub decoder implementation that just returns a random array of logits.""" 45 | 46 | def __init__(self, vocab_size=1000, hidden_size=256): 47 | super(FakeDecoder, self). __init__() 48 | self.vocab_size = vocab_size 49 | self.lstm = tf.keras.layers.LSTM(hidden_size, return_state=True) 50 | 51 | def get_initial_state(self, batch_size): 52 | return self.lstm.cell.get_initial_state( 53 | batch_size=batch_size, dtype=tf.float32) 54 | 55 | def call(self, x, hidden_state): 56 | np_logits = -np.random.random([x.shape[0], self.vocab_size]) 57 | np_logits[:, EOS_TOKEN] = -10000 # don't allow beam search to exit early 58 | return tf.constant(np_logits, dtype=tf.float32), hidden_state 59 | 60 | 61 | def get_best_alive(cumulative_logprob, eos_token, beam_size, vocab_size): 62 | """Get top_k sequences/log probs, masking out completed sequences. 63 | 64 | Args: 65 | cumulative_logprob: Tensor of shape [beam_size, vocab_size] with cumulative 66 | log-probability for each possible continuation of each beam. 67 | eos_token: ID of the end-of-sequence token. 68 | beam_size: beam size. 69 | vocab_size: vocab size. 70 | 71 | Returns: 72 | tuple of (chosen_sequences, alive_logprobs, alive_indices) 73 | chosen_sequences: Tensor of shape [beam_size] with the indices of the 74 | sequences that are being continued. 75 | alive_logprobs: Tensor of shape [beam_size] with log probs of the top K 76 | alive sequence continuations 77 | alive_indices: Tensor of shape [beam_size] with the token that comes next 78 | for each sequence continuation. 79 | """ 80 | # Mask finished sequences with -INF so that top_k ignores them. 
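  # The one-hot mask selects the eos_token column of every beam row; those entries are
  # pushed to NEG_INF so that a finished continuation can never claim one of the top_k
  # alive slots (the caller collects EOS continuations separately before calling this
  # helper).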
81 | cumulative_logprob = tf.where( 82 | tf.equal(tf.one_hot([eos_token] * beam_size, vocab_size), 1), 83 | tf.tile([[NEG_INF]], [beam_size, vocab_size]), 84 | cumulative_logprob) 85 | 86 | # [beam_size * vocab_size] 87 | flat_logprobs = tf.reshape(cumulative_logprob, [-1]) 88 | alive_logprobs, alive_kn_indices = tf.nn.top_k(flat_logprobs, k=beam_size) 89 | # [beam_size], [beam_size] 90 | alive_indices, chosen_sequences = alive_kn_indices % vocab_size, alive_kn_indices // vocab_size 91 | return chosen_sequences, alive_logprobs, alive_indices 92 | 93 | 94 | def get_best_sequences(beam_size, seq1, logprobs1, seq2, logprobs2): 95 | """Helper function to get top_k from two sets of sequences/log probs. 96 | 97 | Args: 98 | beam_size: Beam size. 99 | seq1: Tensor of shape [beam_size, max_seq_len] 100 | logprobs1: Tensor of shape [beam_size] 101 | seq2: Tensor of shape [beam_size, max_seq_len] 102 | logprobs2: Tensor of shape [beam_size] 103 | 104 | Returns: 105 | top beam_size sequences and log probs from the two sets as tensors of 106 | shape [beam_size, max_seq_len] and [beam_size] 107 | """ 108 | both_seq = tf.concat([seq1, seq2], axis=0) 109 | both_logprobs = tf.concat([logprobs1, logprobs2], axis=0) 110 | chosen_logprobs, chosen_idx = tf.nn.top_k(both_logprobs, k=beam_size) 111 | return tf.gather(both_seq, chosen_idx), chosen_logprobs 112 | 113 | 114 | def beam_search(decoder, init_hidden_state, init_token, eos_token, beam_size, 115 | vocab_size, max_seq_len=MAX_SEQ_LEN): 116 | """Beam search. 117 | 118 | Keeps beam_size living sequences and beam_size completed sequences at each 119 | iteration. Completes when all living sequences have dropped far enough in 120 | probability that no living sequence has any chance of beating one of the 121 | known completed sequences, or when the search limit has been reached. 122 | 123 | If, at the end, an incomplete sequence of length max_seq_len has higher 124 | probability than any complete sequence, it will be ranked higher than the 125 | completed sequence. 126 | 127 | Args: 128 | decoder: Decoder module. 129 | init_hidden_state: A hidden state representing decoding context. Should have 130 | a batch dimension with size 1. 131 | init_token: Token to seed decoding with. 132 | eos_token: Token to compare against to see whether a sequence has ended. 133 | beam_size: beam size. 134 | vocab_size: vocab size. 135 | max_seq_len: Maximum sequence length before stopping and returning what we have. 136 | 137 | Returns: 138 | Tuple of sequences, log probs. 139 | sequences: Tensor of shape [beam_size, max_seq_len] 140 | log_probs: Tensor of shape [beam_size] 141 | """ 142 | init_logits, hidden_state = decoder(tf.constant([init_token]), 143 | init_hidden_state) 144 | start_logprobs = tf.nn.log_softmax(tf.squeeze(init_logits)) 145 | 146 | # Seed the starting sequences by executing the decoder once and taking top k. 147 | # [beam_size], [beam_size] 148 | alive_logprobs, alive_indices = tf.nn.top_k(start_logprobs, k=beam_size) 149 | # [beam_size, max_seq_len] 150 | alive_sequences = tf.concat([ 151 | tf.expand_dims(alive_indices, 1), 152 | tf.zeros([beam_size, max_seq_len - 1], dtype=tf.int32)], axis=1) 153 | # [[beam_size, hidden_size], ...] 154 | alive_hidden = tf.nest.map_structure( 155 | lambda s: tf.tile(s, [beam_size, 1]), 156 | hidden_state) 157 | 158 | # Seed finished sequences as the empty sequence, i.e. [eos_token, 0, 0...] and 159 | # zeros everywhere else.
160 | # Mark all other sequences with logprob = -INF 161 | finished_sequences = eos_token * tf.one_hot( 162 | [0], beam_size * max_seq_len, dtype=tf.int32) 163 | finished_sequences = tf.reshape(finished_sequences, [beam_size, max_seq_len]) 164 | finished_logprobs = tf.where( 165 | tf.equal(tf.one_hot(0, beam_size), 1), 166 | tf.tile([start_logprobs[eos_token]], [beam_size]), 167 | tf.tile([NEG_INF], [beam_size])) 168 | 169 | for i in tf.range(1, max_seq_len): 170 | # [beam_size, vocab_size], [[beam_size, hidden_size], ..] 171 | next_char_logits, hidden_state = decoder(alive_indices, alive_hidden) 172 | # Adding log probabilities is equivalent to multiplying probabilities. 173 | # [beam_size, vocab_size] 174 | cumulative_logprob = (tf.expand_dims(alive_logprobs, 1) + 175 | tf.nn.log_softmax(next_char_logits)) 176 | 177 | # Pad all the finished/alive sequences so that they maintain the same shape 178 | # with each iteration. (A limitation of AutoGraph-generated tf.while_loops.) 179 | sequence_padding = tf.zeros([beam_size, max_seq_len - i - 1], 180 | dtype=tf.int32) 181 | 182 | # Gather sequences/log probs for finished sequences 183 | newly_finished_sequences = tf.concat([ 184 | alive_sequences[:, :i], 185 | tf.tile([[eos_token]], [beam_size, 1]), 186 | sequence_padding], axis=1) 187 | newly_finished_logprobs = cumulative_logprob[:, eos_token] 188 | finished_sequences, finished_logprobs = get_best_sequences( 189 | beam_size, finished_sequences, finished_logprobs, 190 | newly_finished_sequences, newly_finished_logprobs) 191 | 192 | # Gather sequences/log probs for alive sequences 193 | chosen_sequences, alive_logprobs, alive_indices = get_best_alive( 194 | cumulative_logprob, eos_token, beam_size, vocab_size) 195 | new_sequence_history = tf.gather(alive_sequences, chosen_sequences) 196 | # [beam_size, max_seq_len] 197 | alive_sequences = tf.concat([ 198 | new_sequence_history[:, :i], 199 | tf.expand_dims(alive_indices, 1), 200 | sequence_padding], axis=1) 201 | alive_sequences.set_shape([beam_size, max_seq_len]) 202 | # [[beam_size, hidden_size], ...] 203 | alive_hidden = tf.nest.map_structure( 204 | lambda s: tf.gather(s, chosen_sequences), # pylint: disable=cell-var-from-loop 205 | hidden_state) 206 | 207 | # Exit if all alive sequences are worse than any finished sequence. 208 | if tf.reduce_min(finished_logprobs) > tf.reduce_max(alive_logprobs): 209 | break 210 | # Execute one final collation, just in case any of the alive sequences are 211 | # higher in probability than any of the finished sequences. 
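  # (The loop can also stop at max_seq_len with unfinished sequences that
  # outscore every finished one; this final merge guarantees they are still
  # returned, as promised in the docstring.)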
212 | finished_sequences, finished_logprobs = get_best_sequences( 213 | beam_size, finished_sequences, finished_logprobs, 214 | alive_sequences, alive_logprobs) 215 | return finished_sequences, finished_logprobs 216 | 217 | 218 | class BeamSearchBenchmark(benchmark_base.ReportingBenchmark): 219 | """Runs benchmarks for eager/autograph variants of beam search.""" 220 | 221 | def _get_decoder(self, vocab_size): 222 | return FakeDecoder(vocab_size=vocab_size) 223 | 224 | def _benchmark_eager(self, max_seq_len, vocab_size): 225 | 226 | decoder = self._get_decoder(vocab_size) 227 | 228 | def target(): 229 | return beam_search(decoder, decoder.get_initial_state(1), INIT_TOKEN, 230 | EOS_TOKEN, BEAM_SIZE, vocab_size, 231 | max_seq_len=max_seq_len) 232 | 233 | self.time_execution(('Eager', max_seq_len, vocab_size), 234 | target, 235 | extras={ 236 | 'max_seq_len': max_seq_len, 237 | 'vocab_size': vocab_size 238 | }) 239 | 240 | def _benchmark_ag(self, max_seq_len, vocab_size): 241 | 242 | decoder = self._get_decoder(vocab_size) 243 | compiled_fn = tf.function(beam_search) 244 | 245 | def target(): 246 | return compiled_fn(decoder, decoder.get_initial_state(1), INIT_TOKEN, 247 | EOS_TOKEN, BEAM_SIZE, vocab_size, 248 | max_seq_len=max_seq_len) 249 | 250 | self.time_execution(('AutoGraph', max_seq_len, vocab_size), 251 | target, 252 | extras={ 253 | 'max_seq_len': max_seq_len, 254 | 'vocab_size': vocab_size 255 | }) 256 | 257 | def benchmark_beamsearch(self): 258 | for max_seq_len in (10, 20, 40, 80): 259 | for vocab_size in (1000, 3000, 10000, 30000): 260 | self._benchmark_eager(max_seq_len, vocab_size) 261 | self._benchmark_ag(max_seq_len, vocab_size) 262 | 263 | 264 | if __name__ == '__main__': 265 | tf.test.main() 266 | -------------------------------------------------------------------------------- /reference_tests/reference_test_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Reference tests check that a function is compiled correctly.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import inspect 22 | import numbers 23 | import os 24 | import sys 25 | import traceback 26 | 27 | import numpy as np 28 | import six 29 | import tensorflow.compat.v2 as tf 30 | import termcolor 31 | 32 | 33 | class MutableContainer(object): 34 | """Testing helper that can create objects with properties.""" 35 | 36 | def __init__(self, **kwargs): 37 | self.__dict__ = kwargs 38 | for k in kwargs: 39 | setattr(self, k, kwargs[k]) 40 | 41 | def __str__(self): 42 | return 'MutableContainer%s' % self.__dict__ 43 | 44 | def __ne__(self, other): 45 | return not self.__eq__(other) 46 | 47 | def __eq__(self, other): 48 | if not isinstance(other, MutableContainer): 49 | return False 50 | if self.__dict__.keys() != other.__dict__.keys(): 51 | return False 52 | return all( 53 | self.__dict__[k] == other.__dict__[k] for k in self.__dict__.keys()) 54 | 55 | 56 | def to_graph(func, recursive=True): 57 | new_func = tf.autograph.to_graph( 58 | func, 59 | recursive=recursive, 60 | experimental_optional_features=tf.autograph.experimental.Feature.ALL) 61 | # TODO(b/127686409): Remove this. 62 | if inspect.ismethod(func): 63 | return six.create_bound_method(new_func, func.__self__) 64 | return new_func 65 | 66 | 67 | def to_graph_nonrecursive(func): 68 | return to_graph(func, recursive=False) 69 | 70 | 71 | def tf_function(func): 72 | return tf.function(func) 73 | 74 | 75 | def tf_function_all(func): 76 | return tf.function( 77 | func, 78 | experimental_autograph_options=tf.autograph.experimental.Feature.ALL) 79 | 80 | 81 | def tf_function_custom(options=None): 82 | def fn(func): 83 | return tf.function( 84 | func, 85 | experimental_autograph_options=options) 86 | return fn 87 | 88 | 89 | class TestCase(tf.test.TestCase): 90 | """Base class for the reference tests.""" 91 | 92 | def setUp(self): 93 | super(TestCase, self).setUp() 94 | 95 | os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1' 96 | # TODO(mdan): tf_function should be default here. 97 | self.convert = to_graph 98 | 99 | # TODO(mdan): Consider rewriting as a context manager. 
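  # Example: _run_with_output_capture(lambda: print('hi')) returns
  # (None, 'hi\n', None) -- the callable's return value, everything it wrote
  # to stdout, and the exception it raised, if any.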
100 | def _run_with_output_capture(self, func): 101 | out_capturer = six.StringIO() 102 | results = None 103 | captured_out = None 104 | captured_err = None 105 | try: 106 | sys.stdout = out_capturer 107 | results = func() 108 | captured_out = out_capturer.getvalue() 109 | except Exception as e: # pylint:disable=broad-except 110 | sys.stdout = sys.__stdout__ 111 | captured_err = e 112 | print('*** Capturing exception:\n{}\n'.format(traceback.format_exc())) 113 | finally: 114 | sys.stdout = sys.__stdout__ 115 | out_capturer.close() 116 | return results, captured_out, captured_err 117 | 118 | def _as_tensors(self, args): 119 | tensor_args = [] 120 | for a in args: 121 | if isinstance(a, (numbers.Number, list, np.ndarray)): 122 | tensor_arg = tf.constant(a) 123 | elif isinstance(a, dict): 124 | keys = tuple(a.keys()) 125 | tensor_arg = dict(zip(keys, self._as_tensors([a[k] for k in keys]))) 126 | elif isinstance(a, MutableContainer): 127 | tensor_arg = MutableContainer(**self._as_tensors([a.__dict__])[0]) 128 | else: 129 | tensor_arg = a 130 | tensor_args.append(tensor_arg) 131 | return tensor_args 132 | 133 | def _as_ndarrays(self, args): 134 | return tuple( 135 | np.array(a) if isinstance(a, (numbers.Number, list, tuple)) else a 136 | for a in args 137 | ) 138 | 139 | # TODO(mdan): Rename these to snake_case. 140 | def runCompiled(self, f, *args): 141 | return self.runTf(self.convert(f), *args) 142 | 143 | def runNumpy(self, f, *args): 144 | return self._run_with_output_capture(lambda: f(*self._as_ndarrays(args))) 145 | 146 | def runNative(self, f, *args): 147 | return self._run_with_output_capture(lambda: f(*args)) 148 | 149 | def runTf(self, f, *args): 150 | with self.test_session() as sess: 151 | f_outs = f(*self._as_tensors(args)) 152 | 153 | if isinstance(f_outs, tuple): 154 | outs = f_outs 155 | else: 156 | outs = (f_outs,) 157 | if f_outs is None: 158 | return None, '', None 159 | 160 | primitive_outs = tuple( 161 | o.__dict__ if isinstance(o, MutableContainer) else o for o in outs) 162 | # Convert any remaining primitives to tensors. 
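    # (A compiled function may return plain Python numbers or lists unchanged;
    # wrapping them in tf.constant lets sess.run fetch all outputs uniformly.)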
163 | primitive_outs = self._as_tensors(primitive_outs) 164 | 165 | (primitive_results, captured_out, captured_err 166 | ) = self._run_with_output_capture(lambda: sess.run(primitive_outs)) 167 | if primitive_results is not None: 168 | final_outs = tuple( 169 | MutableContainer(**r) if isinstance(o, MutableContainer) else r 170 | for r, o in zip(primitive_results, outs)) 171 | else: 172 | final_outs = (None,) 173 | 174 | if isinstance(f_outs, tuple): 175 | return final_outs, captured_out, captured_err 176 | else: 177 | return final_outs[0], captured_out, captured_err 178 | return final_outs 179 | 180 | def _deep_equal(self, left, right): 181 | if isinstance(left, tf.Tensor): 182 | return self._deep_equal(left.numpy(), right) 183 | if isinstance(right, tf.Tensor): 184 | return self._deep_equal(left, right.numpy()) 185 | if isinstance(left, tf.SparseTensor) and isinstance(right, tf.SparseTensor): 186 | return (self._deep_equal(left.indices, right.indices) 187 | and self._deep_equal(left.values, right.values) 188 | and self._deep_equal(left.shape, right.shape)) 189 | if isinstance(left, np.ndarray) or isinstance(right, np.ndarray): 190 | return np.array_equal(left, right) 191 | if isinstance(left, (list, tuple)) and isinstance(right, (list, tuple)): 192 | return all(self._deep_equal(l, r) for l, r in zip(left, right)) 193 | return left == right 194 | 195 | def assertResultsMatch(self, 196 | f, 197 | args, 198 | native_data, 199 | compiled_data): 200 | native_results, native_out, native_err = native_data 201 | compiled_results, compiled_out, compiled_err = compiled_data 202 | str_args = '(%s)' % ', '.join(str(a) for a in args) 203 | # Using a manual verification to avoid a second compilation on success. 204 | # For exceptions, we don't enforce that they are the same, only that 205 | # both paths raised. 206 | # TODO(mdan): Add an API that returns both object and source code instead. 207 | outputs_equal = ( 208 | self._deep_equal(native_results, compiled_results) and 209 | native_out == compiled_out) 210 | errors_equivalent = type(native_err) == type(compiled_err) # pylint:disable=unidiomatic-typecheck 211 | if (not outputs_equal or not errors_equivalent): 212 | self.fail('Native and compiled functions are not equivalent.\n\n' 213 | 'Native results: %s\n' 214 | 'Compiled results: %s\n' 215 | 'Native out: %s\n' 216 | 'Compiled out: %s\n' 217 | 'Native error: %s\n' 218 | 'Compiled error: %s\n' 219 | 'Native call: %s%s\n' 220 | 'Check the logs for the generated code.' 
221 | '' % 222 | (termcolor.colored(native_results, 'green', attrs=['bold']), 223 | termcolor.colored(compiled_results, 'red', attrs=['bold']), 224 | termcolor.colored(native_out, 'green', attrs=['bold']), 225 | termcolor.colored(compiled_out, 'red', attrs=['bold']), 226 | termcolor.colored( 227 | '%s: %s' % (type(native_err).__name__, native_err), 228 | 'green', 229 | attrs=['bold']), 230 | termcolor.colored( 231 | '%s: %s' % (type(compiled_err).__name__, compiled_err), 232 | 'red', 233 | attrs=['bold']), 234 | termcolor.colored(f.__name__, 'blue', attrs=['bold']), 235 | termcolor.colored(str_args, 'blue', attrs=['bold']))) 236 | 237 | def assertFunctionMatchesEagerStatefulInput(self, f, args): 238 | """Like assertFunctionMatchesEager but creates new inputs each time.""" 239 | compiled_data = self.runNative(tf.function(f), *args()) 240 | native_data = self.runNative(f, *args()) 241 | self.assertResultsMatch(f, args(), native_data, compiled_data) 242 | 243 | def assertFunctionMatchesEager(self, f, *args): 244 | compiled_data = self.runNative(tf.function(f), *args) 245 | native_data = self.runNative(f, *args) 246 | self.assertResultsMatch(f, args, native_data, compiled_data) 247 | 248 | def assertNativeMatchesCompiled(self, f, *args): 249 | compiled_data = self.runCompiled(f, *args) 250 | native_data = self.runNative(f, *args) 251 | self.assertResultsMatch(f, args, native_data, compiled_data) 252 | 253 | def assertTfMatchesCompiled(self, f, *args): 254 | compiled_data = self.runCompiled(f, *args) 255 | native_data = self.runTf(f, *args) 256 | self.assertResultsMatch(f, args, native_data, compiled_data) 257 | 258 | def assertNativeMatchesCompiledMethod(self, m, *args): 259 | compiled_data = self.runCompiled(m, *args) 260 | native_data = self.runNative(m, *args) 261 | self.assertResultsMatch(m, args, native_data, compiled_data) 262 | 263 | def assertMatchesObject(self, c, methods_and_args, native_run, compiled_run): 264 | init_func, init_args = methods_and_args[0] 265 | assert init_func == '__init__' 266 | 267 | native_object = c(*init_args) 268 | compiled_c = self.convert(c) 269 | compiled_object = compiled_c(*self._as_tensors(init_args)) 270 | for name, args in methods_and_args[1:]: 271 | native_method = getattr(native_object, name) 272 | compiled_method = getattr(compiled_object, name) 273 | native_data = native_run(native_method, *args) 274 | compiled_data = compiled_run(compiled_method, *args) 275 | self.assertResultsMatch(native_method, args, native_data, compiled_data) 276 | 277 | def assertTfMatchesCompiledObject(self, c, methods_and_args): 278 | self.assertMatchesObject(c, methods_and_args, self.runTf, self.runTf) 279 | 280 | def assertNativeMatchesCompiledObject(self, c, methods_and_args): 281 | self.assertMatchesObject(c, methods_and_args, self.runNative, self.runTf) 282 | 283 | def try_execute_compiled(self, f, *args): 284 | _, _, err = self.runCompiled(f, *args) 285 | if err: 286 | raise err 287 | 288 | 289 | if __name__ == '__main__': 290 | tf.test.main() 291 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 
12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. --------------------------------------------------------------------------------
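The listings above can also be exercised by hand. The following sketch is illustrative only: it assumes a TF 2.x installation and that examples/sysml2019 is on the Python path so that beam_search_benchmark can be imported, and it simply mirrors the calls that BeamSearchBenchmark makes, once eagerly and once through tf.function.

    import tensorflow as tf
    import beam_search_benchmark as bsb

    decoder = bsb.FakeDecoder(vocab_size=1000)

    # Run the beam search eagerly.
    seqs, logprobs = bsb.beam_search(
        decoder, decoder.get_initial_state(1), bsb.INIT_TOKEN, bsb.EOS_TOKEN,
        bsb.BEAM_SIZE, decoder.vocab_size, max_seq_len=10)

    # Same call staged into a graph; AutoGraph rewrites the Python loop and
    # its early `break` into a tf.while_loop.
    compiled = tf.function(bsb.beam_search)
    graph_seqs, graph_logprobs = compiled(
        decoder, decoder.get_initial_state(1), bsb.INIT_TOKEN, bsb.EOS_TOKEN,
        bsb.BEAM_SIZE, decoder.vocab_size, max_seq_len=10)

    print(seqs.shape, logprobs.shape)  # (5, 10) and (5,)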