├── .gitignore ├── .coveragerc ├── setup.cfg ├── .travis.yml ├── pyproject.toml ├── examples └── torch_from_tensorflow.py ├── LICENSE ├── tests └── test_adapters.py ├── README.md ├── setup.py └── tfpyth └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | .eggs/ 2 | __pycache__/ 3 | .idea/ 4 | build/ 5 | dist/ 6 | *.egg-info -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source=tfpyth/* 3 | omit= 4 | */tests/* 5 | setup.py 6 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = --cov=. 6 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | matrix: 3 | include: 4 | - python: 3.7 5 | dist: xenial 6 | sudo: true 7 | install: 8 | - pip install -q -e .[dev,test] 9 | script: 10 | - python setup.py test 11 | 12 | # Push the results back to codecov 13 | after_success: 14 | - codecov 15 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py36', 'py37', 'py38'] 4 | include = '\.pyi?$' 5 | exclude = ''' 6 | /( 7 | \.eggs 8 | | \.git 9 | | \.hg 10 | | \.mypy_cache 11 | | \.tox 12 | | \.venv 13 | | _build 14 | | buck-out 15 | | build 16 | | dist 17 | )/ 18 | ''' 19 | -------------------------------------------------------------------------------- /examples/torch_from_tensorflow.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import torch as th 3 | import numpy as np 4 | import tfpyth 5 | 6 | session = tf.Session() 7 | 8 | 9 | def get_torch_function(): 10 | a = tf.placeholder(tf.float32, name="a") 11 | b = tf.placeholder(tf.float32, name="b") 12 | c = 3 * a + 4 * b * b 13 | 14 | f = tfpyth.torch_from_tensorflow(session, [a, b], c).apply 15 | return f 16 | 17 | 18 | f = get_torch_function() 19 | a = th.tensor(1, dtype=th.float32, requires_grad=True) 20 | b = th.tensor(3, dtype=th.float32, requires_grad=True) 21 | x = f(a, b) 22 | 23 | assert x == 39.0 24 | 25 | x.backward() 26 | 27 | assert np.allclose((a.grad, b.grad), (3.0, 24.0)) 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Andreas @blackhc Kirsch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/test_adapters.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import torch as th 3 | import numpy as np 4 | import tfpyth 5 | 6 | 7 | def test_pytorch_in_tensorflow_eager_mode(): 8 | tf.enable_eager_execution() 9 | tfe = tf.contrib.eager 10 | 11 | def pytorch_expr(a, b): 12 | return 3 * a + 4 * b * b 13 | 14 | x = tfpyth.eager_tensorflow_from_torch(pytorch_expr) 15 | 16 | assert tf.math.equal(x(tf.convert_to_tensor(1.0), tf.convert_to_tensor(3.0)), 39.0) 17 | 18 | dx = tfe.gradients_function(x) 19 | assert all(tf.math.equal(dx(tf.convert_to_tensor(1.0), tf.convert_to_tensor(3.0)), [3.0, 24.0])) 20 | tf.disable_eager_execution() 21 | 22 | 23 | def test_pytorch_in_tensorflow_graph_mode(): 24 | session = tf.Session() 25 | 26 | def pytorch_expr(a, b): 27 | return 3 * a + 4 * b * b 28 | 29 | a = tf.placeholder(tf.float32, name="a") 30 | b = tf.placeholder(tf.float32, name="b") 31 | c = tfpyth.tensorflow_from_torch(pytorch_expr, [a, b], tf.float32) 32 | c_grad = tf.gradients([c], [a, b], unconnected_gradients="zero") 33 | 34 | assert np.allclose(session.run([c, c_grad[0], c_grad[1]], {a: 1.0, b: 3.0}), [39.0, 3.0, 24.0]) 35 | 36 | 37 | def test_tensorflow_in_pytorch(): 38 | session = tf.Session() 39 | 40 | def get_tf_function(): 41 | a = tf.placeholder(tf.float32, name="a") 42 | b = tf.placeholder(tf.float32, name="b") 43 | c = 3 * a + 4 * b * b 44 | 45 | f = tfpyth.torch_from_tensorflow(session, [a, b], c).apply 46 | return f 47 | 48 | f = get_tf_function() 49 | a_ = th.tensor(1, dtype=th.float32, requires_grad=True) 50 | b_ = th.tensor(3, dtype=th.float32, requires_grad=True) 51 | x = f(a_, b_) 52 | 53 | assert x == 39.0 54 | 55 | x.backward() 56 | 57 | assert np.allclose((a_.grad, b_.grad), (3.0, 24.0)) 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TfPyTh 2 | 3 | [![Build Status](https://travis-ci.com/BlackHC/tfpyth.svg?branch=master)](https://travis-ci.com/BlackHC/tfpyth) [![codecov](https://codecov.io/gh/BlackHC/tfpyth/branch/master/graph/badge.svg)](https://codecov.io/gh/BlackHC/tfpyth) 4 | 5 | > Putting TensorFlow back in PyTorch, back in TensorFlow (with differentiable TensorFlow PyTorch adapters). 6 | 7 | Do you have a codebase that uses TensorFlow and one that uses PyTorch and want to train a model that uses both end-to-end? 8 | 9 | This library makes it possible without having to rewrite either codebase! 10 | 11 | It allows you to wrap a TensorFlow graph to make it callable (and differentiable) through PyTorch, and vice-versa, using simple functions. 12 | 13 | The only caveat is that tensors have to be copied and routed through the CPU until TensorFlow supports `__cuda_array_interface` (please star the [GitHub issue](https://github.com/tensorflow/tensorflow/issues/29039)). 14 | 15 | ## Install 16 | 17 | ``` 18 | pip install tfpyth 19 | ``` 20 | 21 | ### Example 22 | 23 | ```python 24 | import tensorflow as tf 25 | import torch as th 26 | import numpy as np 27 | import tfpyth 28 | 29 | session = tf.Session() 30 | 31 | def get_torch_function(): 32 | a = tf.placeholder(tf.float32, name='a') 33 | b = tf.placeholder(tf.float32, name='b') 34 | c = 3 * a + 4 * b * b 35 | 36 | f = tfpyth.torch_from_tensorflow(session, [a, b], c).apply 37 | return f 38 | 39 | f = get_torch_function() 40 | a = th.tensor(1, dtype=th.float32, requires_grad=True) 41 | b = th.tensor(3, dtype=th.float32, requires_grad=True) 42 | x = f(a, b) 43 | 44 | assert x == 39. 45 | 46 | x.backward() 47 | 48 | assert np.allclose((a.grad, b.grad), (3., 24.)) 49 | ``` 50 | 51 | ## What it's got 52 | 53 | ### `torch_from_tensorflow` 54 | 55 | Creates a PyTorch function that is differentiable by evaluating a TensorFlow output tensor given input placeholders. 56 | 57 | ### `eager_tensorflow_from_torch` 58 | 59 | Creates an eager Tensorflow function from a PyTorch function. 60 | 61 | ### `tensorflow_from_torch` 62 | 63 | Creates a TensorFlow op/tensor from a PyTorch function. 64 | 65 | ## Future work 66 | 67 | - [ ] support JAX 68 | - [ ] support higher-order derivatives 69 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Always prefer setuptools over distutils 2 | from setuptools import setup 3 | 4 | # To use a consistent encoding 5 | from codecs import open 6 | from os import path 7 | 8 | here = path.abspath(path.dirname(__file__)) 9 | 10 | # Get the long description from the README file 11 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 12 | long_description = f.read() 13 | 14 | setup( 15 | name="tfpyth", 16 | # Versions should comply with PEP440. For a discussion on single-sourcing 17 | # the version across setup.py and the project code, see 18 | # https://packaging.python.org/en/latest/single_source_version.html 19 | version="1.0.1", 20 | description="Putting TensorFlow back in PyTorch, back in Tensorflow (differentiable TensorFlow PyTorch adapters).", 21 | # Fix windows newlines. 22 | long_description=long_description.replace("\r\n", "\n"), 23 | long_description_content_type="text/markdown", 24 | # The project's main homepage. 25 | url="https://github.com/blackhc/tfpyth", 26 | # Author details 27 | author="Andreas @blackhc Kirsch", 28 | author_email="blackhc+tfpyth@gmail.com", 29 | # Choose your license 30 | license="MIT", 31 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 32 | classifiers=[ 33 | # How mature is this project? Common values are 34 | # 3 - Alpha 35 | # 4 - Beta 36 | # 5 - Production/Stable 37 | "Development Status :: 4 - Beta", 38 | # Indicate who your project is intended for 39 | "Intended Audience :: Developers", 40 | "Intended Audience :: Science/Research", 41 | "Topic :: Software Development :: Libraries :: Python Modules", 42 | # Pick your license as you wish (should match "license" above) 43 | "License :: OSI Approved :: MIT License", 44 | "Programming Language :: Python :: 3.6", 45 | ], 46 | # What does your project relate to? 47 | keywords="ml machine learning", 48 | # You can just specify the packages manually here if your project is 49 | # simple. Or you can use find_packages(). 50 | packages=["tfpyth"], 51 | #package_dir={"": ""}, 52 | # List run-time dependencies here. These will be installed by pip when 53 | # your project is installed. For an analysis of "install_requires" vs pip's 54 | # requirements files see: 55 | # https://packaging.python.org/en/latest/requirements.html 56 | install_requires=["tensorflow~=1.14", "torch~=1.1"], 57 | # List additional groups of dependencies here (e.g. development 58 | # dependencies). You can install these using the following syntax, 59 | # for example: 60 | # $ pip install -e .[dev,test] 61 | extras_require={"dev": ["check-manifest"], "test": ["coverage", "codecov", "pytest", "pytest-cov"]}, 62 | setup_requires=["pytest-runner"], 63 | ) 64 | -------------------------------------------------------------------------------- /tfpyth/__init__.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import torch as th 3 | 4 | 5 | class TensorFlowFunction(th.autograd.Function): 6 | """ 7 | Wrapper class for Tensorflow input/output nodes (incl gradient) in PyTorch. 8 | """ 9 | 10 | inputs: list = None 11 | output: tf.Tensor = None 12 | gradient_placeholder = None 13 | gradient_outputs = None 14 | 15 | 16 | def torch_from_tensorflow(tf_session, tf_inputs, tf_output, tf_dtype=tf.float32): 17 | """ 18 | Create a PyTorch TensorFlowFunction with forward and backward methods which executes evaluates the passed 19 | TensorFlow tensors. 20 | 21 | ```python 22 | my_tensorflow_func = MyTensorFlowFunction.apply 23 | 24 | result = my_tensorflow_func(th_a, th_b) 25 | ``` 26 | 27 | :param tf_session: TensorFlow session to use 28 | :param tf_inputs: TensorFlow input tensors/placeholders 29 | :param tf_output: TensorFlow output tensor 30 | :param tf_dtype: dtype to use for gradient placeholder. 31 | :return: TensorflowFunction which can be applied to PyTorch tensors. 32 | """ 33 | # create gradient placeholders 34 | tf_gradient_placeholder = tf.placeholder(dtype=tf_dtype, name=f"gradient") 35 | tf_gradient_outputs = tf.gradients( 36 | ys=tf_output, xs=tf_inputs, grad_ys=[tf_gradient_placeholder], unconnected_gradients="zero" 37 | ) 38 | 39 | class _TensorFlowFunction(TensorFlowFunction): 40 | inputs = tf_inputs 41 | output = tf_output 42 | gradient_placeholder = tf_gradient_placeholder 43 | gradient_outputs = tf_gradient_outputs 44 | 45 | @staticmethod 46 | def forward(ctx, *args): 47 | assert len(args) == len(tf_inputs) 48 | 49 | feed_dict = {tf_input: th_input.detach().numpy() for tf_input, th_input in zip(tf_inputs, args)} 50 | output = tf_session.run(tf_output, feed_dict) 51 | 52 | ctx.save_for_backward(*args) 53 | 54 | th_output = th.as_tensor(output) 55 | return th_output 56 | 57 | # See https://www.janfreyberg.com/blog/2019-04-01-testing-pytorch-functions/ for why "no cover" 58 | @staticmethod 59 | def backward(ctx, grad_output): # pragma: no cover 60 | th_inputs = ctx.saved_tensors 61 | 62 | feed_dict = {} 63 | feed_dict.update({tf_input: th_input.detach().numpy() for tf_input, th_input in zip(tf_inputs, th_inputs)}) 64 | feed_dict.update({tf_gradient_placeholder: grad_output.detach().numpy()}) 65 | 66 | tf_gradients = tf_session.run(tf_gradient_outputs, feed_dict) 67 | return tuple(th.as_tensor(tf_gradient) for tf_gradient in tf_gradients) 68 | 69 | return _TensorFlowFunction() 70 | 71 | 72 | def eager_tensorflow_from_torch(func): 73 | """ 74 | Wraps a PyTorch function into a TensorFlow eager-mode function (ie can be executed within Tensorflow eager-mode). 75 | 76 | :param func: Function that takes PyTorch tensors and returns a PyTorch tensor. 77 | :return: Differentiable Tensorflow eager-mode function. 78 | """ 79 | 80 | @tf.custom_gradient 81 | def compute(*inputs): 82 | th_inputs = [th.tensor(tf_input.numpy(), requires_grad=True) for tf_input in inputs] 83 | th_output = func(*th_inputs) 84 | 85 | def compute_grad(d_output): 86 | th_d_output = th.tensor(d_output.numpy(), requires_grad=False) 87 | th_gradients = th.autograd.grad([th_output], th_inputs, grad_outputs=[th_d_output], allow_unused=True) 88 | tf_gradients = [tf.convert_to_tensor(th_gradient.numpy()) for th_gradient in th_gradients] 89 | return tf_gradients 90 | 91 | return tf.convert_to_tensor(th_output.detach().numpy()), compute_grad 92 | 93 | return compute 94 | 95 | 96 | def tensorflow_from_torch(func, inp, Tout, name=None): 97 | """ 98 | Executes a PyTorch function into a TensorFlow op and output tensor (ie can be evaluated within Tensorflow).\ 99 | 100 | :param func: Function that takes PyTorch tensors and returns a PyTorch tensor. 101 | :param inp: TensorFlow input tensors 102 | :param Tout: TensorFlow output dtype 103 | :param name: Name of the output tensor 104 | :return: Differentiable Tensorflow output tensor. 105 | """ 106 | eager_compute = eager_tensorflow_from_torch(func) 107 | 108 | return tf.py_function(eager_compute, inp, Tout, name=name) 109 | --------------------------------------------------------------------------------