├── .gitignore ├── LICENSE ├── README.md ├── examples ├── intro.py └── paillier.py ├── federated_aggregations ├── __init__.py ├── channels │ ├── __init__.py │ ├── channel.py │ ├── channel_grid.py │ ├── channel_grid_test.py │ ├── channel_test.py │ ├── channel_test_utils.py │ ├── computations.py │ └── key_store.py ├── paillier │ ├── __init__.py │ ├── computations.py │ ├── factory.py │ ├── factory_test.py │ ├── placement.py │ ├── strategy.py │ └── strategy_test.py ├── utils.py └── version.py ├── protocol.gif ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Federated Aggregation 2 | Using [TF Encrypted](https://github.com/tf-encrypted/tf-encrypted) primitives for secure aggregation in [TensorFlow Federated](https://github.com/tensorflow/federated) 3 | 4 | This project implements specific protocols for secure aggregation using secure computation primitives from [TF Encrypted](https://github.com/tf-encrypted/tf-encrypted). Our aim is to express secure aggregations with the full breadth of TFE's language for secure computations; however, this prototype is much smaller in scope. We implement a specific aggregation protocol based on Paillier homomorphic encryption; see the [accompanying blog post](https://medium.com/dropoutlabs/building-secure-aggregation-into-tensorflow-federated-4514fca40cc0) or the sections below for more details.
5 | 6 | ```python 7 | import numpy as np 8 | import tensorflow as tf 9 | import tensorflow_federated as tff 10 | from federated_aggregations import paillier 11 | 12 | paillier_factory = paillier.local_paillier_executor_factory() 13 | paillier_context = tff.framework.ExecutionContext(paillier_factory) 14 | tff.framework.set_default_context(paillier_context) 15 | 16 | # data from 5 clients 17 | x = [np.array([i, i + 1], dtype=np.int32) for i in range(5)] 18 | x_type = tff.TensorType(tf.int32, [2]) 19 | 20 | @tff.federated_computation(tff.FederatedType(x_type, tff.CLIENTS)) 21 | def secure_paillier_addition(x): 22 | return tff.federated_secure_sum(x, bitwidth=32) 23 | 24 | result = secure_paillier_addition(x) 25 | print(result) 26 | # Output: [10 15] 27 | ``` 28 | 29 | # Installation 30 | This library is offered as a Python package but is not currently published on PyPI, so you must install it from source in your preferred Python environment. The code has been tested with Python 3.7 on macOS. 31 | 32 | ``` 33 | pip install -r requirements.txt 34 | python setup.py install 35 | ``` 36 | 37 | # Features 38 | ## Federated computation 39 | Currently, we simply add an implementation of the [`tff.federated_secure_sum`](https://www.tensorflow.org/federated/api_docs/python/tff/federated_secure_sum) intrinsic to the default TFF simulation stack. We do not rewrite any of the higher-level APIs for federated averaging, but these should be straightforward to implement. 40 | 41 | ## Protocols 42 | Currently, we implement secure aggregation via the Paillier homomorphic encryption scheme. This protocol is well suited for federated averaging with highly available clients, e.g. cross-silo federated learning between organizations. Please see the [accompanying blog post](https://medium.com/dropoutlabs/building-secure-aggregation-into-tensorflow-federated-4514fca40cc0) or the [illustration below](#illustrated-protocol) for more details.
43 | 44 | We outsource the Paillier aggregation to an "aggregation service" running separately from the Server role in traditional FL. This fits into the bulletin-board style of FL, where a service separate from the coordinator is responsible for aggregating model updates securely, and the Server (i.e. coordinator) periodically pulls & decrypts the latest model from that service. This achieves the specific functionality outlined in [this section](https://github.com/tf-encrypted/rfcs/tree/master/20190924-tensorflow-federated#specific-encrypted-executors) of the corresponding RFC. 45 | 46 | ## Secure channels 47 | TensorFlow Federated does not implement point-to-point communication directly between placements; it instead routes all communications through a computation "driver" (i.e. the host running the TFF Python script, also usually responsible for unplaced computation). To reduce communication when using the native backend, this driver is usually collocated with the tff.SERVER placement's executor stack, so that any values communicated between the driver and the tff.SERVER don't incur a network cost. 48 | 49 | This communication pattern presents a problem for implementing secure aggregation, since many SMPC protocols assume the existence of authenticated channels between parties. In order to realize this in the specific case of a bulletin-board aggregation service, we follow the approach outlined in [this section](https://github.com/tf-encrypted/rfcs/tree/master/20190924-tensorflow-federated#network-strategy-and-secure-channels) of our RFC. Please see the [accompanying blog post](https://medium.com/dropoutlabs/building-secure-aggregation-into-tensorflow-federated-4514fca40cc0) for an illustration and more details. 50 | 51 | # Illustrated protocol 52 | 53 | ![Secure aggregation protocol](./protocol.gif) 54 | 55 | # Development 56 | If you want to get up and running, please follow these steps. We strongly encourage using a virtual environment. 57 | 1. 
Install dependencies with `pip install -r requirements.txt`. Depending on your platform, you may need to build some of these dependencies from source. See instructions specific to [tf-encrypted-primitives](https://github.com/tf-encrypted/tf-encrypted/tree/master/primitives) or [tf-big](https://github.com/tf-encrypted/tf-big) for more information. We do not guarantee support for all platforms. 58 | 2. Install this package using pip (e.g. `pip install -e .`). 59 | 3. Run the tests (the `*_test.py` modules under `federated_aggregations/`). 60 | 61 | If you run into issues, please [reach out](#support-and-feedback). 62 | 63 | # Roadmap 64 | Please see the original [TFF Integration RFC](https://github.com/tf-encrypted/rfcs/tree/master/20190924-tensorflow-federated) for an overview of our goals. While the implementation in this project isn't identical, and our plans have evolved since then, the high-level objectives have not changed. 65 | 66 | # Support and Feedback 67 | Bug reports and feature requests? Please [open an issue](https://github.com/tf-encrypted/federated-aggregations/issues) on GitHub. 68 | 69 | For any other questions or feedback, please reach out directly on [Slack](https://join.slack.com/t/tf-encrypted/shared_invite/enQtNjI5NjY5NTc0NjczLWM4MTVjOGVmNGFkMWU2MGEzM2Q5ZWFjMTdmZjdmMTM2ZTU4YjJmNTVjYmE1NDAwMDIzMjllZjJjMWNiMTlmZTQ), or send an email to [contact@tf-encrypted.io](mailto:contact@tf-encrypted.io). 70 | 71 | # License 72 | 73 | Licensed under the Apache License, Version 2.0 (see [LICENSE](./LICENSE) or http://www.apache.org/licenses/LICENSE-2.0). Copyright as specified in [NOTICE](./NOTICE).
74 | -------------------------------------------------------------------------------- /examples/intro.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_federated as tff 3 | from federated_aggregations import paillier 4 | 5 | NUM_CLIENTS = 5 6 | paillier_factory = paillier.local_paillier_executor_factory() 7 | paillier_context = tff.framework.ExecutionContext(paillier_factory) 8 | tff.framework.set_default_context(paillier_context) 9 | 10 | @tff.federated_computation( 11 | tff.FederatedType(tff.TensorType(tf.int32, [2]), tff.CLIENTS)) 12 | def secure_paillier_addition(x): 13 | return tff.federated_secure_sum(x, 32) 14 | 15 | x = [[i, i + 1] for i in range(NUM_CLIENTS)] 16 | result = secure_paillier_addition(x) 17 | print(result) 18 | -------------------------------------------------------------------------------- /examples/paillier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow_federated as tff 4 | 5 | from federated_aggregations import paillier 6 | 7 | NUM_CLIENTS = 5 8 | 9 | paillier_factory = paillier.local_paillier_executor_factory(NUM_CLIENTS) 10 | paillier_context = tff.framework.ExecutionContext(paillier_factory) 11 | tff.framework.set_default_context(paillier_context) 12 | 13 | @tff.federated_computation(tff.FederatedType(tff.TensorType(tf.int32, [2]), tff.CLIENTS), tff.TensorType(tf.int32)) 14 | def secure_paillier_addition(x, bitwidth): 15 | return tff.federated_secure_sum(x, bitwidth) 16 | 17 | base = np.array([1, 2], np.int32) 18 | x = [base + i for i in range(NUM_CLIENTS)] 19 | result = secure_paillier_addition(x, 32) 20 | print(result) 21 | -------------------------------------------------------------------------------- /federated_aggregations/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import channels 2 | from . import paillier 3 | -------------------------------------------------------------------------------- /federated_aggregations/channels/__init__.py: -------------------------------------------------------------------------------- 1 | from federated_aggregations.channels.channel_grid import ChannelGrid 2 | from federated_aggregations.channels.channel import Channel 3 | from federated_aggregations.channels.channel import PlaintextChannel 4 | from federated_aggregations.channels.channel import EasyBoxChannel 5 | -------------------------------------------------------------------------------- /federated_aggregations/channels/channel.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import asyncio 3 | import itertools 4 | from typing import Tuple 5 | 6 | import tensorflow_federated as tff 7 | from tensorflow_federated.python.common_libs import py_typecheck 8 | from tensorflow_federated.python.common_libs import structure 9 | from tensorflow_federated.python.core.api import computations 10 | from tensorflow_federated.python.core.impl.executors import federated_resolving_strategy 11 | from tensorflow_federated.python.core.impl.types import placement_literals 12 | from tensorflow_federated.python.core.impl.types import type_analysis 13 | from tensorflow_federated.python.core.impl.types import type_conversions 14 | 15 | from federated_aggregations import utils 16 | from federated_aggregations.channels import computations as sodium_comp 17 | from federated_aggregations.channels import key_store 18 | 19 | PlacementPair = Tuple[ 20 | placement_literals.PlacementLiteral, 21 | placement_literals.PlacementLiteral] 22 | 23 | 24 | class Channel(metaclass=abc.ABCMeta): 25 | @abc.abstractmethod 26 | async def transfer(self, value): 27 | pass 28 | 29 | @abc.abstractmethod 30 | async def send(self, value, source, recipient): 31 | pass 32 | 33 | @abc.abstractmethod 34 | async def receive(self, value, source, 
recipient): 35 | pass 36 | 37 | @abc.abstractmethod 38 | async def setup(self): 39 | pass 40 | 41 | 42 | class BaseChannel(Channel): 43 | """Abstract interface for communication channels between placement pairs. 44 | 45 | This class defines the contract between FederatingStrategy and concrete 46 | implementations of the Channel abstraction. For any subclass ConcreteChannel, 47 | FederatingStrategy should only use the ConcreteChannel.transfer method. 48 | Implementations of the Channel interface need only be concerned with 49 | extending Channel.send, Channel.receive, and Channel.setup. 50 | 51 | Attributes: 52 | strategy: The FederatingStrategy this Channel belongs to. 53 | placements: The (unordered) pair of placements that will communicate 54 | through this channel. 55 | """ 56 | def __init__( 57 | self, 58 | strategy, 59 | *placements: PlacementPair): 60 | self.strategy = strategy 61 | self.placements = placements 62 | 63 | async def transfer(self, value): 64 | _check_value_placement(value, self.placements) 65 | sender_placement = value.type_signature.placement 66 | receiver_placement = _get_other_placement( 67 | sender_placement, self.placements) 68 | sent = await self.send(value, sender_placement, receiver_placement) 69 | rcv_children = self.strategy._get_child_executors(receiver_placement) 70 | message = await sent.compute() 71 | message_type = type_conversions.infer_type(message) 72 | if receiver_placement is tff.CLIENTS: 73 | if isinstance(message_type, tff.StructType): 74 | iterator = zip(rcv_children, message, message_type) 75 | member_type = message_type[0] 76 | all_equal = False 77 | else: 78 | iterator = zip(rcv_children, 79 | itertools.repeat(message), 80 | itertools.repeat(message_type)) 81 | member_type = message_type 82 | all_equal = True 83 | message_value = federated_resolving_strategy.FederatedResolvingStrategyValue( 84 | await asyncio.gather(*[c.create_value(m, t) for c, m, t in iterator]), 85 | tff.FederatedType(member_type, 
receiver_placement, all_equal)) 86 | else: 87 | rcv_child = rcv_children[0] 88 | if isinstance(message_type, tff.StructType): 89 | message_value = federated_resolving_strategy.FederatedResolvingStrategyValue( 90 | structure.from_container( 91 | await asyncio.gather(*[ 92 | rcv_child.create_value(m, t) 93 | for m, t in zip(message, message_type)])), 94 | tff.StructType([ 95 | tff.FederatedType(mt, receiver_placement, True) 96 | for mt in message_type])) 97 | else: 98 | message_value = federated_resolving_strategy.FederatedResolvingStrategyValue( 99 | await rcv_child.create_value(message, message_type), 100 | tff.FederatedType(message_type, receiver_placement, 101 | value.type_signature.all_equal)) 102 | return await self.receive(message_value, sender_placement, receiver_placement) 103 | 104 | 105 | class PlaintextChannel(BaseChannel): 106 | """An insecure Channel implementation that communicates tensors in plaintext. 107 | 108 | Attributes: 109 | strategy: The FederatingStrategy this Channel belongs to. 110 | placements: The (unordered) pair of placements that will use this channel 111 | to communicate. 112 | """ 113 | async def send(self, value, source, recipient): 114 | del source, recipient 115 | return value 116 | 117 | async def receive(self, value, source, recipient): 118 | del source, recipient 119 | return value 120 | 121 | async def setup(self): 122 | pass 123 | 124 | 125 | class EasyBoxChannel(BaseChannel): 126 | """A secure Channel using authenticated encryption. 127 | 128 | EasyBoxChannel uses libsodium's EasyBox for authenticated encryption. This 129 | channel is responsible for key setup, including key generation and exchange. 130 | It also defines the pre- and post-processing steps required to achieve 131 | encryption-in-transit under the native runtime's communication model. 132 | 133 | TODO: surface _key_generator to avoid keygen/exchange, and instead accept 134 | keys from a trusted PKI. 
135 | 136 | Reference for encryption scheme: 137 | https://libsodium.gitbook.io/doc/public-key_cryptography/authenticated_encryption 138 | 139 | Attributes: 140 | key_references: A KeyStore responsible for managing each placement's 141 | public & secret keys. 142 | strategy: The FederatingStrategy this Channel belongs to. 143 | placements: The (unordered) pair of placements that will use this channel 144 | to communicate. 145 | """ 146 | def __init__( 147 | self, 148 | strategy, 149 | *placements: PlacementPair): 150 | super().__init__(strategy, *placements) 151 | self.key_references = key_store.KeyStore() 152 | self._requires_setup = True 153 | self._key_generator = None # lazy key generation 154 | self._encryptor_cache = {} 155 | self._decryptor_cache = {} 156 | 157 | async def send(self, value, sender_placement, receiver_placement): 158 | if sender_placement is tff.CLIENTS: 159 | return await self._encrypt_values_on_clients(value, sender_placement, 160 | receiver_placement) 161 | return await self._encrypt_values_on_singleton(value, sender_placement, 162 | receiver_placement) 163 | 164 | async def receive(self, value, sender_placement, receiver_placement): 165 | if receiver_placement is tff.CLIENTS: 166 | return await self._decrypt_values_on_clients(value, sender_placement, 167 | receiver_placement) 168 | return await self._decrypt_values_on_singleton(value, sender_placement, 169 | receiver_placement) 170 | 171 | async def setup(self): 172 | if self._requires_setup: 173 | p0, p1 = self.placements 174 | await asyncio.gather(*[ 175 | self._generate_keys(p0), 176 | self._generate_keys(p1)]) 177 | await asyncio.gather(*[ 178 | self._share_public_key(p0, p1), 179 | self._share_public_key(p1, p0)]) 180 | self._requires_setup = False 181 | 182 | async def _encrypt_values_on_singleton(self, val, sender, receiver): 183 | ### 184 | # we can safely assume sender has cardinality=1 when receiver is CLIENTS 185 | ### 186 | # Case 1: receiver=CLIENTS 187 | # plaintext: 
Fed(Tensor, sender, all_equal=True) 188 | # pk_receiver: Fed(Tuple(Tensor), sender, all_equal=True) 189 | # sk_sender: Fed(Tensor, sender, all_equal=True) 190 | # Returns: 191 | # encrypted_values: Tuple(Fed(Tensor, sender, all_equal=True)) 192 | ### 193 | ### Check proper key placement 194 | sk_sender = self.key_references.get_secret_key(sender) 195 | pk_receiver = self.key_references.get_public_key(receiver) 196 | type_analysis.check_federated_type(sk_sender.type_signature, placement=sender) 197 | assert sk_sender.type_signature.placement is sender 198 | assert pk_receiver.type_signature.placement is sender 199 | ### Check placement cardinalities 200 | rcv_children = self.strategy._get_child_executors(receiver) 201 | snd_children = self.strategy._get_child_executors(sender) 202 | py_typecheck.check_len(snd_children, 1) 203 | snd_child = snd_children[0] 204 | ### Check value cardinalities 205 | type_analysis.check_federated_type(val.type_signature, placement=sender) 206 | py_typecheck.check_len(val.internal_representation, 1) 207 | py_typecheck.check_type(pk_receiver.type_signature.member, 208 | tff.StructType) 209 | py_typecheck.check_len(pk_receiver.internal_representation, 210 | len(rcv_children)) 211 | py_typecheck.check_len(sk_sender.internal_representation, 1) 212 | ### Materialize encryptor function definition & type spec 213 | input_type = val.type_signature.member 214 | self._input_type_cache = input_type 215 | pk_rcv_type = pk_receiver.type_signature.member 216 | sk_snd_type = sk_sender.type_signature.member 217 | pk_element_type = pk_rcv_type[0] 218 | encryptor_arg_spec = (input_type, pk_element_type, sk_snd_type) 219 | encryptor_proto, encryptor_type = utils.materialize_computation_from_cache( 220 | sodium_comp.make_encryptor, self._encryptor_cache, encryptor_arg_spec) 221 | ### Prepare encryption arguments 222 | v = val.internal_representation[0] 223 | sk = sk_sender.internal_representation[0] 224 | ### Encrypt values and return them 225 | 
encryptor_fn = await snd_child.create_value(encryptor_proto, encryptor_type) 226 | encryptor_args = await asyncio.gather(*[ 227 | snd_child.create_struct([v, this_pk, sk]) 228 | for this_pk in pk_receiver.internal_representation]) 229 | encrypted_values = await asyncio.gather(*[ 230 | snd_child.create_call(encryptor_fn, arg) for arg in encryptor_args]) 231 | encrypted_value_types = [encryptor_type.result] * len(encrypted_values) 232 | return federated_resolving_strategy.FederatedResolvingStrategyValue( 233 | structure.from_container(encrypted_values), 234 | tff.StructType([tff.FederatedType(evt, sender, all_equal=False) 235 | for evt in encrypted_value_types])) 236 | 237 | async def _encrypt_values_on_clients(self, val, sender, receiver): 238 | ### 239 | # Case 2: sender=CLIENTS 240 | # plaintext: Fed(Tensor, CLIENTS, all_equal=False) 241 | # pk_receiver: Fed(Tensor, CLIENTS, all_equal=True) 242 | # sk_sender: Fed(Tensor, CLIENTS, all_equal=False) 243 | # Returns: 244 | # encrypted_values: Fed(Tensor, CLIENTS, all_equal=False) 245 | ### 246 | ### Check proper key placement 247 | sk_sender = self.key_references.get_secret_key(sender) 248 | pk_receiver = self.key_references.get_public_key(receiver) 249 | type_analysis.check_federated_type(sk_sender.type_signature, placement=sender) 250 | assert sk_sender.type_signature.placement is sender 251 | assert pk_receiver.type_signature.placement is sender 252 | ### Check placement cardinalities 253 | snd_children = self.strategy._get_child_executors(sender) 254 | rcv_children = self.strategy._get_child_executors(receiver) 255 | py_typecheck.check_len(rcv_children, 1) 256 | ### Check value cardinalities 257 | type_analysis.check_federated_type(val.type_signature, placement=sender) 258 | federated_value_internals = [ 259 | val.internal_representation, 260 | pk_receiver.internal_representation, 261 | sk_sender.internal_representation] 262 | for v in federated_value_internals: 263 | py_typecheck.check_len(v, len(snd_children)) 
264 | ### Materialize encryptor function definition & type spec 265 | input_type = val.type_signature.member 266 | self._input_type_cache = input_type 267 | pk_rcv_type = pk_receiver.type_signature.member 268 | sk_snd_type = sk_sender.type_signature.member 269 | pk_element_type = pk_rcv_type 270 | encryptor_arg_spec = (input_type, pk_element_type, sk_snd_type) 271 | encryptor_proto, encryptor_type = utils.materialize_computation_from_cache( 272 | sodium_comp.make_encryptor, self._encryptor_cache, encryptor_arg_spec) 273 | ### Encrypt values and return them 274 | encryptor_fns = asyncio.gather(*[ 275 | snd_child.create_value(encryptor_proto, encryptor_type) 276 | for snd_child in snd_children]) 277 | encryptor_args = asyncio.gather(*[ 278 | snd_child.create_struct([v, pk, sk]) 279 | for v, pk, sk, snd_child in zip( 280 | *federated_value_internals, snd_children)]) 281 | encryptor_fns, encryptor_args = await asyncio.gather( 282 | encryptor_fns, encryptor_args) 283 | encrypted_values = [ 284 | snd_child.create_call(encryptor, arg) 285 | for encryptor, arg, snd_child in zip( 286 | encryptor_fns, encryptor_args, snd_children)] 287 | return federated_resolving_strategy.FederatedResolvingStrategyValue( 288 | await asyncio.gather(*encrypted_values), 289 | tff.FederatedType(encryptor_type.result, sender, 290 | all_equal=val.type_signature.all_equal)) 291 | 292 | async def _decrypt_values_on_clients(self, val, sender, receiver): 293 | ### Check proper key placement 294 | pk_sender = self.key_references.get_public_key(sender) 295 | sk_receiver = self.key_references.get_secret_key(receiver) 296 | type_analysis.check_federated_type(pk_sender.type_signature, 297 | placement=receiver) 298 | type_analysis.check_federated_type(sk_receiver.type_signature, 299 | placement=receiver) 300 | pk_snd_type = pk_sender.type_signature.member 301 | sk_rcv_type = sk_receiver.type_signature.member 302 | ### Check value cardinalities 303 | rcv_children = 
self.strategy._get_child_executors(receiver) 304 | federated_value_internals = [ 305 | val.internal_representation, 306 | pk_sender.internal_representation, 307 | sk_receiver.internal_representation] 308 | for fv in federated_value_internals: 309 | py_typecheck.check_len(fv, len(rcv_children)) 310 | ### Materialize decryptor type_spec & function definition 311 | input_type = val.type_signature.member 312 | # input_type[0] is a tff.TensorType, thus input_type represents the 313 | # tuple needed for a single value to be decrypted. 314 | py_typecheck.check_type(input_type[0], tff.TensorType) 315 | py_typecheck.check_type(pk_snd_type, tff.TensorType) 316 | input_element_type = input_type 317 | pk_element_type = pk_snd_type 318 | decryptor_arg_spec = (input_element_type, pk_element_type, sk_rcv_type) 319 | decryptor_proto, decryptor_type = utils.materialize_computation_from_cache( 320 | sodium_comp.make_decryptor, 321 | self._decryptor_cache, 322 | decryptor_arg_spec, 323 | orig_tensor_dtype=self._input_type_cache.dtype) 324 | ### Decrypt values and return them 325 | decryptor_fns = asyncio.gather(*[ 326 | rcv_child.create_value(decryptor_proto, decryptor_type) 327 | for rcv_child in rcv_children]) 328 | decryptor_args = asyncio.gather(*[ 329 | rcv_child.create_struct([v, pk, sk]) 330 | for v, pk, sk, rcv_child in zip( 331 | *federated_value_internals, rcv_children)]) 332 | decryptor_fns, decryptor_args = await asyncio.gather( 333 | decryptor_fns, decryptor_args) 334 | decrypted_values = [ 335 | rcv_child.create_call(decryptor, arg) 336 | for decryptor, arg, rcv_child in zip( 337 | decryptor_fns, decryptor_args, rcv_children)] 338 | return federated_resolving_strategy.FederatedResolvingStrategyValue( 339 | await asyncio.gather(*decrypted_values), 340 | tff.FederatedType(decryptor_type.result, receiver, 341 | all_equal=val.type_signature.all_equal)) 342 | 343 | async def _decrypt_values_on_singleton(self, val, sender, receiver): 344 | ### Check proper key placement 345 | 
pk_sender = self.key_references.get_public_key(sender) 346 | sk_receiver = self.key_references.get_secret_key(receiver) 347 | type_analysis.check_federated_type(pk_sender.type_signature, 348 | placement=receiver) 349 | type_analysis.check_federated_type(sk_receiver.type_signature, 350 | placement=receiver) 351 | pk_snd_type = pk_sender.type_signature.member 352 | sk_rcv_type = sk_receiver.type_signature.member 353 | ### Check placement cardinalities 354 | snd_children = self.strategy._get_child_executors(sender) 355 | rcv_children = self.strategy._get_child_executors(receiver) 356 | py_typecheck.check_len(rcv_children, 1) 357 | rcv_child = rcv_children[0] 358 | ### Check value cardinalities 359 | py_typecheck.check_len(pk_sender.internal_representation, len(snd_children)) 360 | py_typecheck.check_len(sk_receiver.internal_representation, 1) 361 | ### Materialize decryptor type_spec & function definition 362 | py_typecheck.check_type(val.type_signature, tff.StructType) 363 | type_analysis.check_federated_type(val.type_signature[0], 364 | placement=receiver, all_equal=True) 365 | input_type = val.type_signature[0].member 366 | # each input_type is a tuple needed for one value to be decrypted 367 | py_typecheck.check_type(input_type, tff.StructType) 368 | py_typecheck.check_type(pk_snd_type, tff.StructType) 369 | py_typecheck.check_len(val.type_signature, len(pk_snd_type)) 370 | input_element_type = input_type 371 | pk_element_type = pk_snd_type[0] 372 | decryptor_arg_spec = (input_element_type, pk_element_type, sk_rcv_type) 373 | decryptor_proto, decryptor_type = utils.materialize_computation_from_cache( 374 | sodium_comp.make_decryptor, 375 | self._decryptor_cache, 376 | decryptor_arg_spec, 377 | orig_tensor_dtype=self._input_type_cache.dtype) 378 | ### Decrypt values and return them 379 | vals = val.internal_representation 380 | sk = sk_receiver.internal_representation[0] 381 | decryptor_fn = await rcv_child.create_value(decryptor_proto, decryptor_type) 382 | 
decryptor_args = await asyncio.gather(*[ 383 | rcv_child.create_struct([v, pk, sk]) 384 | for v, pk in zip(vals, pk_sender.internal_representation)]) 385 | decrypted_values = await asyncio.gather(*[ 386 | rcv_child.create_call(decryptor_fn, arg) 387 | for arg in decryptor_args]) 388 | decrypted_value_types = [decryptor_type.result] * len(decrypted_values) 389 | return federated_resolving_strategy.FederatedResolvingStrategyValue( 390 | structure.from_container(decrypted_values), 391 | tff.StructType([ 392 | tff.FederatedType(dvt, receiver, all_equal=True) 393 | for dvt in decrypted_value_types])) 394 | 395 | async def _generate_keys(self, key_owner): 396 | py_typecheck.check_type(key_owner, placement_literals.PlacementLiteral) 397 | executors = self.strategy._get_child_executors(key_owner) 398 | if self._key_generator is None: 399 | self._key_generator = sodium_comp.make_keygen() 400 | keygen, keygen_type = utils.lift_to_computation_spec(self._key_generator) 401 | pk_vals, sk_vals = [], [] 402 | async def keygen_call(child): 403 | return await child.create_call(await child.create_value( 404 | keygen, keygen_type)) 405 | 406 | key_generators = await asyncio.gather(*[ 407 | keygen_call(executor) for executor in executors]) 408 | public_keys = asyncio.gather(*[ 409 | executor.create_selection(key_generator, 0) 410 | for executor, key_generator in zip(executors, key_generators)]) 411 | secret_keys = asyncio.gather(*[ 412 | executor.create_selection(key_generator, 1) 413 | for executor, key_generator in zip(executors, key_generators)]) 414 | pk_vals, sk_vals = await asyncio.gather(public_keys, secret_keys) 415 | pk_type = pk_vals[0].type_signature 416 | sk_type = sk_vals[0].type_signature 417 | # all_equal whenever owner is non-CLIENTS singleton placement 418 | val_all_equal = len(executors) == 1 and key_owner != tff.CLIENTS 419 | pk_fed_val = federated_resolving_strategy.FederatedResolvingStrategyValue( 420 | pk_vals, tff.FederatedType(pk_type, key_owner, 
val_all_equal)) 421 | sk_fed_val = federated_resolving_strategy.FederatedResolvingStrategyValue( 422 | sk_vals, tff.FederatedType(sk_type, key_owner, val_all_equal)) 423 | self.key_references.update_keys( 424 | key_owner, public_key=pk_fed_val, secret_key=sk_fed_val) 425 | 426 | async def _share_public_key(self, key_owner, key_receiver): 427 | public_key = self.key_references.get_public_key(key_owner) 428 | children = self.strategy._get_child_executors(key_receiver) 429 | val = await public_key.compute() 430 | key_type = public_key.type_signature.member 431 | # we currently only support sharing n keys with 1 executor, 432 | # or sharing 1 key with n executors 433 | if isinstance(val, list): 434 | # sharing n keys with 1 executor 435 | py_typecheck.check_len(children, 1) 436 | executor = children[0] 437 | vals = [executor.create_value(v, key_type) for v in val] 438 | vals_type = tff.FederatedType(type_conversions.infer_type(val), key_receiver) 439 | else: 440 | # sharing 1 key with n executors 441 | # val is a single tensor 442 | vals = [c.create_value(val, key_type) for c in children] 443 | vals_type = tff.FederatedType(key_type, key_receiver, all_equal=True) 444 | public_key_rcv = federated_resolving_strategy.FederatedResolvingStrategyValue( 445 | await asyncio.gather(*vals), vals_type) 446 | self.key_references.update_keys(key_owner, public_key=public_key_rcv) 447 | 448 | 449 | def _get_other_placement(this_placement, both_placements): 450 | for p in both_placements: 451 | if p != this_placement: 452 | return p 453 | 454 | 455 | def _check_value_placement(arg, placements): 456 | py_typecheck.check_type(arg, federated_resolving_strategy.FederatedResolvingStrategyValue) 457 | py_typecheck.check_type(arg.type_signature, (tff.FederatedType, tff.StructType)) 458 | value_type = arg.type_signature 459 | sender_placement = arg.type_signature.placement 460 | if sender_placement not in placements: 461 | raise ValueError( 462 | 'Tried to send a value with placement {} 
through channel for ' 463 | 'placements ({},{}).'.format( 464 | str(sender_placement), *(str(p) for p in placements))) 465 | -------------------------------------------------------------------------------- /federated_aggregations/channels/channel_grid.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from dataclasses import dataclass 3 | from typing import Dict 4 | 5 | from tensorflow_federated.python.common_libs import py_typecheck 6 | from tensorflow_federated.python.core.impl.types import placement_literals 7 | 8 | from federated_aggregations.channels import channel 9 | from federated_aggregations import utils 10 | 11 | @dataclass 12 | class ChannelGrid: 13 | """A container fully characterizing the network topology between placements. 14 | 15 | FederatingStrategy implementations can use this class to keep track of 16 | Channels between any pair of placements. It assumes that each Channel is 17 | two-way, i.e. it describes an unordered pair of placements. 18 | 19 | Attributes: 20 | requires_setup: Tracks whether the underlying channels of the grid have 21 | been set up; only some Channels require this setup phase. 
22 | """ 23 | _channel_dict: Dict[channel.PlacementPair, channel.Channel] 24 | requires_setup: bool = True 25 | 26 | async def setup_channels(self, strategy): 27 | if self.requires_setup: 28 | setup_steps = [] 29 | tmp_channel_dict = {} 30 | for placement_pair, channel_factory in self._channel_dict.items(): 31 | pair = tuple(sorted(placement_pair, key=lambda p: p.uri)) 32 | channel = channel_factory(strategy, *pair) 33 | setup_steps.append(channel.setup()) 34 | tmp_channel_dict[pair] = channel 35 | await asyncio.gather(*setup_steps) 36 | self._channel_dict = tmp_channel_dict 37 | self.requires_setup = False 38 | 39 | def __getitem__(self, placements: channel.PlacementPair): 40 | py_typecheck.check_type(placements, tuple) 41 | py_typecheck.check_len(placements, 2) 42 | sorted_placements = sorted(placements, key=lambda p: p.uri) 43 | return self._channel_dict.get(tuple(sorted_placements)) 44 | -------------------------------------------------------------------------------- /federated_aggregations/channels/channel_grid_test.py: -------------------------------------------------------------------------------- 1 | import tensorflow_federated as tff 2 | 3 | from federated_aggregations.channels import channel as ch 4 | from federated_aggregations.channels import channel_grid as grid 5 | from federated_aggregations.channels import channel_test_utils as utils 6 | 7 | class ChannelGridTest(utils.AsyncTestCase): 8 | def test_channel_grid_setup(self): 9 | channel_grid = grid.ChannelGrid( 10 | {(tff.CLIENTS, tff.SERVER): ch.PlaintextChannel}) 11 | ex = utils.create_test_executor(channel_grid=channel_grid) 12 | self.run_sync(channel_grid.setup_channels(ex._strategy)) 13 | 14 | channel = channel_grid[(tff.CLIENTS, tff.SERVER)] 15 | 16 | assert isinstance(channel, ch.PlaintextChannel) 17 | -------------------------------------------------------------------------------- /federated_aggregations/channels/channel_test.py: 
-------------------------------------------------------------------------------- 1 | from unittest import mock 2 | 3 | from absl.testing import absltest 4 | from absl.testing import parameterized 5 | 6 | import asyncio 7 | import tensorflow as tf 8 | import tensorflow_federated as tff 9 | from tensorflow_federated.python.common_libs import structure 10 | from tensorflow_federated.python.core.impl.executors import federated_resolving_strategy 11 | from tensorflow_federated.python.core.impl.types import placement_literals 12 | from tensorflow_federated.python.core.impl.types import type_conversions 13 | 14 | from federated_aggregations.channels import channel as ch 15 | from federated_aggregations.channels import channel_grid as grid 16 | from federated_aggregations.channels import channel_test_utils as utils 17 | 18 | 19 | class PlaintextChannelTest(utils.AsyncTestCase): 20 | @parameterized.named_parameters( 21 | ("clients_to_server", [1., 2., 3.], tff.CLIENTS, tff.SERVER, 22 | [tf.constant(i + 1, dtype=tf.float32) for i in range(3)]), 23 | ("server_to_clients", 2.0, tff.SERVER, tff.CLIENTS, 24 | tf.constant(2.0, dtype=tf.float32))) 25 | def test_transfer(self, value, source_placement, target_placement, expected): 26 | fed_ex = utils.create_test_executor(channel=ch.PlaintextChannel) 27 | strategy = fed_ex._strategy 28 | channel_grid = strategy.channel_grid 29 | self.run_sync(channel_grid.setup_channels(strategy)) 30 | 31 | channel = channel_grid[(tff.CLIENTS, tff.SERVER)] 32 | val = self.run_sync(fed_ex.create_value(value, 33 | tff.FederatedType(tf.float32, source_placement))) 34 | transferred = self.run_sync(channel.transfer(val)) 35 | result = self.run_sync(transferred.compute()) 36 | 37 | expected_type = type_conversions.infer_type(expected) 38 | assert isinstance(transferred, 39 | federated_resolving_strategy.FederatedResolvingStrategyValue) 40 | if isinstance(expected, list): 41 | assert isinstance(transferred.type_signature, tff.StructType) 42 | for i, 
elt_type_spec in enumerate(transferred.type_signature): 43 | self.assertEqual(elt_type_spec, 44 | tff.FederatedType(expected_type[i], target_placement, True)) 45 | result = structure.flatten(result) 46 | else: 47 | self.assertEqual(transferred.type_signature, 48 | tff.FederatedType(expected_type, target_placement, True)) 49 | self.assertEqual(result, expected) 50 | 51 | 52 | class EasyBoxChannelTest(utils.AsyncTestCase): 53 | def test_generate_aggregator_keys(self): 54 | fed_ex = utils.create_test_executor() 55 | strategy = fed_ex._strategy 56 | channel_grid = strategy.channel_grid 57 | self.run_sync(channel_grid.setup_channels(strategy)) 58 | 59 | channel = channel_grid[(tff.CLIENTS, tff.SERVER)] 60 | pk_clients, sk_clients = channel.key_references.get_key_pair(tff.CLIENTS) 61 | pk_server, sk_server = channel.key_references.get_key_pair(tff.SERVER) 62 | 63 | self.assertEqual(str(pk_clients.type_signature), '@SERVER') 64 | self.assertEqual(str(sk_clients.type_signature), '{uint8[32]}@CLIENTS') 65 | self.assertEqual(str(pk_server.type_signature), 'uint8[32]@CLIENTS') 66 | self.assertEqual(str(sk_server.type_signature), 'uint8[32]@SERVER') 67 | 68 | @parameterized.named_parameters( 69 | ("clients_to_server", [1., 2., 3.], tff.CLIENTS, tff.SERVER, 70 | [tf.constant(i + 1, dtype=tf.float32) for i in range(3)]), 71 | ("server_to_clients", 2.0, tff.SERVER, tff.CLIENTS, 72 | tf.constant(2.0, dtype=tf.float32))) 73 | def test_transfer(self, value, source_placement, target_placement, expected): 74 | fed_ex = utils.create_test_executor() 75 | strategy = fed_ex._strategy 76 | channel_grid = strategy.channel_grid 77 | self.run_sync(channel_grid.setup_channels(fed_ex._strategy)) 78 | 79 | channel = channel_grid[(placement_literals.CLIENTS, 80 | placement_literals.SERVER)] 81 | val = self.run_sync(fed_ex.create_value(value, 82 | tff.FederatedType(tf.float32, source_placement))) 83 | transferred = self.run_sync(channel.transfer(val)) 84 | decrypted = 
self.run_sync(transferred.compute()) 85 | 86 | if isinstance(expected, list): 87 | decrypted = structure.flatten(decrypted) 88 | self.assertEqual(decrypted, expected) 89 | else: 90 | for d in decrypted: 91 | self.assertEqual(d, expected) 92 | -------------------------------------------------------------------------------- /federated_aggregations/channels/channel_test_utils.py: -------------------------------------------------------------------------------- 1 | # Lint as: python3 2 | # Copyright 2019, The TensorFlow Federated Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Utils for testing channels.""" 16 | from absl.testing import parameterized 17 | import asyncio 18 | 19 | import tensorflow_federated as tff 20 | from tensorflow_federated.python.core.impl.executors import federated_resolving_strategy 21 | 22 | from federated_aggregations.channels import channel_grid as grid 23 | from federated_aggregations.channels import channel as ch 24 | 25 | 26 | def create_bottom_stack(): 27 | executor = tff.framework.EagerTFExecutor() 28 | return tff.framework.ReferenceResolvingExecutor(executor) 29 | 30 | 31 | def create_test_executor( 32 | number_of_clients: int = 3, 33 | channel_grid: grid.ChannelGrid = None, 34 | channel: ch.Channel = ch.EasyBoxChannel): 35 | if channel_grid is None: 36 | channel_grid = grid.ChannelGrid({(tff.CLIENTS, tff.SERVER): channel}) 37 | strategy_executors = { 38 | tff.SERVER: create_bottom_stack(), 39 | tff.CLIENTS: [create_bottom_stack() for _ in range(number_of_clients)], 40 | } 41 | return tff.framework.FederatingExecutor( 42 | MockStrategy.factory(strategy_executors, channel_grid), 43 | unplaced_executor=create_bottom_stack()) 44 | 45 | 46 | class AsyncTestCase(parameterized.TestCase): 47 | """A test case that manages a new event loop for each test. 48 | 49 | Each test will have a new event loop instead of using the current event loop. 50 | This ensures that tests are isolated from each other and avoid unexpected side 51 | effects. 52 | 53 | Attributes: 54 | loop: An `asyncio` event loop. 55 | """ 56 | 57 | def setUp(self): 58 | super().setUp() 59 | self.loop = asyncio.new_event_loop() 60 | 61 | # If `setUp()` fails, then `tearDown()` is not called; however cleanup 62 | # functions will be called. Register the newly created loop `close()` 63 | # function here to ensure it is closed after each test. 
64 | self.addCleanup(self.loop.close) 65 | 66 | def run_sync(self, coro): 67 | return self.loop.run_until_complete(coro) 68 | 69 | 70 | class MockStrategy(federated_resolving_strategy.FederatedResolvingStrategy): 71 | def __init__(self, executor, target_executors, channel_grid): 72 | super().__init__(executor, target_executors) 73 | self.channel_grid = channel_grid 74 | 75 | @classmethod 76 | def factory(cls, target_executors, channel_grid): 77 | return lambda executor: cls(executor, target_executors, channel_grid) 78 | 79 | def _get_child_executors(self, placement, index=None): 80 | child_executors = self._target_executors[placement] 81 | if index is not None: 82 | return child_executors[index] 83 | return child_executors 84 | -------------------------------------------------------------------------------- /federated_aggregations/channels/computations.py: -------------------------------------------------------------------------------- 1 | import tensorflow_federated as tff 2 | from tf_encrypted.primitives.sodium import easy_box 3 | 4 | 5 | def make_encryptor(plaintext_type, pk_rcv_type, sk_snd_type): 6 | @tff.tf_computation(plaintext_type, pk_rcv_type, sk_snd_type) 7 | def encrypt_tensor(plaintext, pk_rcv, sk_snd): 8 | pk_rcv = easy_box.PublicKey(pk_rcv) 9 | sk_snd = easy_box.SecretKey(sk_snd) 10 | nonce = easy_box.gen_nonce() 11 | ciphertext, mac = easy_box.seal_detached(plaintext, nonce, pk_rcv, sk_snd) 12 | return ciphertext.raw, mac.raw, nonce.raw 13 | 14 | return encrypt_tensor 15 | 16 | 17 | def make_decryptor(sender_values_type, pk_snd_type, sk_rcv_snd, 18 | orig_tensor_dtype): 19 | @tff.tf_computation(sender_values_type, pk_snd_type, sk_rcv_snd) 20 | def decrypt_tensor(sender_values, pk_snd, sk_rcv): 21 | ciphertext = easy_box.Ciphertext(sender_values[0]) 22 | mac = easy_box.Mac(sender_values[1]) 23 | nonce = easy_box.Nonce(sender_values[2]) 24 | pk_snd = easy_box.PublicKey(pk_snd) 25 | sk_rcv = easy_box.SecretKey(sk_rcv) 26 | plaintext_recovered = 
easy_box.open_detached( 27 | ciphertext, mac, nonce, pk_snd, sk_rcv, orig_tensor_dtype) 28 | return plaintext_recovered 29 | 30 | return decrypt_tensor 31 | 32 | 33 | def make_keygen(): 34 | @tff.tf_computation() 35 | def key_generator(): 36 | pk, sk = easy_box.gen_keypair() 37 | return pk.raw, sk.raw 38 | 39 | return key_generator 40 | -------------------------------------------------------------------------------- /federated_aggregations/channels/key_store.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | import tensorflow_federated as tff 4 | from tensorflow_federated.python.common_libs import py_typecheck 5 | from tensorflow_federated.python.core.impl.executors import federated_resolving_strategy 6 | from tensorflow_federated.python.core.impl.types import placement_literals 7 | 8 | 9 | class KeyStore: 10 | """A container for key management and storage. 11 | 12 | This is only used during the setup phase of EasyBoxChannel. 
13 | """ 14 | _default_store = lambda k: {'pk': None, 'sk': None} 15 | 16 | def __init__(self): 17 | self._key_store = collections.defaultdict(self._default_store) 18 | 19 | def get_key_pair(self, key_owner): 20 | key_owner_cache = self._get_keys(key_owner) 21 | return key_owner_cache['pk'], key_owner_cache['sk'] 22 | 23 | def get_public_key(self, key_owner): 24 | return self._get_keys(key_owner)['pk'] 25 | 26 | def get_secret_key(self, key_owner): 27 | return self._get_keys(key_owner)['sk'] 28 | 29 | def _get_keys(self, key_owner): 30 | py_typecheck.check_type(key_owner, placement_literals.PlacementLiteral) 31 | return self._key_store[key_owner.name] 32 | 33 | def update_keys(self, key_owner, public_key=None, secret_key=None): 34 | key_owner_cache = self._get_keys(key_owner) 35 | if public_key is not None: 36 | self._check_key_type(public_key) 37 | key_owner_cache['pk'] = public_key 38 | if secret_key is not None: 39 | self._check_key_type(secret_key) 40 | key_owner_cache['sk'] = secret_key 41 | 42 | def _check_key_type(self, key): 43 | py_typecheck.check_type(key, 44 | federated_resolving_strategy.FederatedResolvingStrategyValue) 45 | py_typecheck.check_type(key.type_signature, 46 | (tff.StructType, tff.FederatedType)) -------------------------------------------------------------------------------- /federated_aggregations/paillier/__init__.py: -------------------------------------------------------------------------------- 1 | from .factory import local_paillier_executor_factory -------------------------------------------------------------------------------- /federated_aggregations/paillier/computations.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_federated as tff 3 | from tf_encrypted.primitives import paillier 4 | 5 | 6 | def make_keygen(transit_dtype=tf.int32, modulus_bitlength=2048): 7 | @tff.tf_computation 8 | def _keygen(): 9 | encryption_key, decryption_key = 
paillier.gen_keypair(modulus_bitlength) 10 | ek_raw = encryption_key.export(dtype=transit_dtype) 11 | dk_raw = decryption_key.export(dtype=transit_dtype) 12 | return ek_raw, dk_raw 13 | 14 | return _keygen 15 | 16 | 17 | def make_encryptor(transit_dtype=tf.int32): 18 | @tff.tf_computation 19 | def _encrypt(encryption_key_raw, plaintext): 20 | ek = paillier.EncryptionKey(encryption_key_raw) 21 | ciphertext = paillier.encrypt(ek, plaintext) 22 | return ciphertext.export(dtype=transit_dtype) 23 | 24 | return _encrypt 25 | 26 | 27 | def make_decryptor( 28 | decryption_key_type, 29 | encryption_key_type, 30 | ciphertext_type, 31 | export_dtype, 32 | ): 33 | @tff.tf_computation( 34 | decryption_key_type, encryption_key_type, ciphertext_type) 35 | def _decrypt(decryption_key_raw, encryption_key_raw, ciphertext_raw): 36 | dk = paillier.DecryptionKey(*decryption_key_raw) 37 | ek = paillier.EncryptionKey(encryption_key_raw) 38 | ciphertext = paillier.Ciphertext(ek, ciphertext_raw) 39 | return paillier.decrypt(dk, ciphertext, export_dtype) 40 | 41 | return _decrypt 42 | 43 | 44 | def make_sequence_sum(transit_dtype=tf.int32): 45 | def adder(ek, xs): 46 | assert len(xs) >= 1 47 | if len(xs) == 1: 48 | return xs[0] 49 | split = len(xs) // 2 50 | left = xs[:split] 51 | right = xs[split:] 52 | return paillier.add(ek, adder(ek, left), adder(ek, right), do_refresh=False) 53 | 54 | @tff.tf_computation 55 | def _sequence_sum(encryption_key_raw, summands_raw): 56 | ek = paillier.EncryptionKey(encryption_key_raw) 57 | summands = [ 58 | paillier.Ciphertext(ek, summand) 59 | for summand in summands_raw 60 | ] 61 | result = adder(ek, summands) 62 | refreshed_result = paillier.refresh(ek, result) 63 | return refreshed_result.export(dtype=transit_dtype) 64 | 65 | return _sequence_sum 66 | 67 | 68 | def make_reshape_tensor(tensor_type, output_shape): 69 | @tff.tf_computation 70 | def _reshape_tensor(tensor): 71 | return tf.reshape(tensor, output_shape) 72 | 73 | return _reshape_tensor 74 | 
# ------------------------------------------------------------------------------
# /federated_aggregations/paillier/factory.py:
# ------------------------------------------------------------------------------
import functools
from typing import List, Callable, Optional, Sequence

import tensorflow as tf
import tensorflow_federated as tff
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.impl.executors import eager_tf_executor
from tensorflow_federated.python.core.impl.executors import executor_base
from tensorflow_federated.python.core.impl.executors import executor_factory
from tensorflow_federated.python.core.impl.executors import executor_stacks
from tensorflow_federated.python.core.impl.executors import federating_executor
from tensorflow_federated.python.core.impl.executors import sizing_executor
from tensorflow_federated.python.core.impl.types import placement_literals

from federated_aggregations import channels
from federated_aggregations.paillier import placement as paillier_placement
from federated_aggregations.paillier import strategy as paillier_strategy
from federated_aggregations.paillier import computations as paillier_comp


# TODO: add more factory functions, including:
#   - composite executor factory
#   - worker pool factory (for use with RemoteExecutor)


class AggregatingUnplacedExecutorFactory(
    executor_stacks.UnplacedExecutorFactory):
  """An UnplacedExecutorFactory that can also build the AGGREGATOR stack.

  For `paillier_placement.AGGREGATOR` it builds an `EagerTFExecutor` on
  `aggregator_device`, wrapped in TFF's threading stack; every other
  placement is delegated to the parent class.
  """

  def __init__(
      self,
      *,
      use_caching: bool,
      server_device: Optional[tf.config.LogicalDevice] = None,
      aggregator_device: Optional[tf.config.LogicalDevice] = None,
      client_devices: Optional[Sequence[tf.config.LogicalDevice]] = ()):
    """Initializes the factory.

    Args:
      use_caching: Forwarded to `UnplacedExecutorFactory`.
      server_device: Forwarded to `UnplacedExecutorFactory`.
      aggregator_device: Passed as `device=` to the `EagerTFExecutor` built
        for the AGGREGATOR placement.
      client_devices: Forwarded to `UnplacedExecutorFactory`.
    """
    super().__init__(
        use_caching=use_caching,
        server_device=server_device,
        client_devices=client_devices)
    self._aggregator_device = aggregator_device

  def create_executor(
      self,
      *,
      cardinalities: Optional[executor_factory.CardinalitiesType] = None,
      placement: Optional[placement_literals.PlacementLiteral] = None
  ) -> executor_base.Executor:
    """Creates an unplaced executor stack, optionally for `placement`.

    Args:
      cardinalities: Must be empty or `None`; unplaced stacks have none.
      placement: The placement the stack is built for. AGGREGATOR is handled
        here; anything else is delegated to the parent factory.

    Returns:
      An `executor_base.Executor`.

    Raises:
      ValueError: If `cardinalities` is non-empty.
    """
    if cardinalities:
      raise ValueError(
          'Unplaced executors cannot accept nonempty cardinalities as '
          'arguments. Received cardinalities: {}.'.format(cardinalities))
    if placement == paillier_placement.AGGREGATOR:
      ex = eager_tf_executor.EagerTFExecutor(device=self._aggregator_device)
      # Private TFF helper; also used below for the federating executor.
      return executor_stacks._wrap_executor_in_threading_stack(ex)  # pylint: disable=protected-access
    return super().create_executor(
        cardinalities=cardinalities, placement=placement)


class PaillierAggregatingExecutorFactory(
    executor_stacks.FederatingExecutorFactory):
  """A FederatingExecutorFactory backed by `PaillierAggregatingStrategy`.

  Besides the CLIENTS and SERVER stacks, it builds a single AGGREGATOR stack.
  It connects the placements through a `channels.ChannelGrid`:
  CLIENTS<->AGGREGATOR uses `EasyBoxChannel`, and the pairs involving SERVER
  use `PlaintextChannel`.
  """

  def create_executor(
      self, cardinalities: executor_factory.CardinalitiesType
  ) -> executor_base.Executor:
    """Constructs a federated executor with requested cardinalities."""
    num_clients = self._validate_requested_clients(cardinalities)
    client_stacks = [
        self._unplaced_executor_factory.create_executor(
            cardinalities={}, placement=placement_literals.CLIENTS)
        for _ in range(self._num_client_executors)
    ]
    if self._use_sizing:
      client_stacks = [
          sizing_executor.SizingExecutor(ex) for ex in client_stacks
      ]
      self._sizing_executors.extend(client_stacks)
    paillier_stack = self._unplaced_executor_factory.create_executor(
        cardinalities={}, placement=paillier_placement.AGGREGATOR)
    if self._use_sizing:
      paillier_stack = sizing_executor.SizingExecutor(paillier_stack)
    # Set up secure channel between clients & Paillier executor
    secure_channel_grid = channels.ChannelGrid({
        (tff.CLIENTS,
         paillier_placement.AGGREGATOR): channels.EasyBoxChannel,
        (tff.CLIENTS,
         tff.SERVER): channels.PlaintextChannel,
        (paillier_placement.AGGREGATOR,
         tff.SERVER): channels.PlaintextChannel})
    # Build a FederatingStrategy factory for Paillier aggregation with the
    # secure channel setup. Clients are spread round-robin over the
    # available client stacks.
    strategy_factory = paillier_strategy.PaillierAggregatingStrategy.factory(
        {
            placement_literals.CLIENTS: [
                client_stacks[k % len(client_stacks)]
                for k in range(num_clients)
            ],
            placement_literals.SERVER:
                self._unplaced_executor_factory.create_executor(
                    cardinalities={}, placement=placement_literals.SERVER),
            paillier_placement.AGGREGATOR: paillier_stack,
        },
        channel_grid=secure_channel_grid,
        # NOTE: we let the server generate its own key here, but for proper
        # deployment we would want to supply a key verified by proper PKI
        key_inputter=paillier_comp.make_keygen(modulus_bitlength=2048))
    unplaced_executor = self._unplaced_executor_factory.create_executor(
        cardinalities={})
    executor = federating_executor.FederatingExecutor(
        strategy_factory, unplaced_executor)
    return executor_stacks._wrap_executor_in_threading_stack(executor)  # pylint: disable=protected-access


def local_paillier_executor_factory(
    num_clients=None,
    num_client_executors=32,
    server_tf_device=None,
    aggregator_tf_device=None,
    client_tf_devices=tuple()):
  """Like tff.framework.local_executor_factory, but with Paillier aggregation.

  The resulting factory function does not implement composing executor stacks,
  so there is no max_fanout argument.

  Args:
    num_clients: The number of clients. If specified, the executor factory
      function returned by `local_paillier_executor_factory` will be configured
      to have exactly `num_clients` clients. If unspecified (`None`), then the
      function returned will attempt to infer cardinalities of all placements
      for which it is passed values.
    num_client_executors: The number of distinct client executors to run
      concurrently; executing more clients than this number results in
      multiple clients having their work pinned on a single executor in a
      synchronous fashion.
    server_tf_device: A `tf.config.LogicalDevice` to place server and
      other computation without explicit TFF placement.
    aggregator_tf_device: A `tf.config.LogicalDevice` to place computation
      of the Paillier aggregation. See README for a clearer description.
    client_tf_devices: List/tuple of `tf.config.LogicalDevice` to place clients
      for simulation. Possibly accelerators returned by
      `tf.config.list_logical_devices()`.

  Returns:
    The executor factory produced by `tff.framework.create_executor_factory`.
  """
  # TODO consider parameterizing this function with channel_grid
  if server_tf_device is not None:
    py_typecheck.check_type(server_tf_device, tf.config.LogicalDevice)
  if aggregator_tf_device is not None:
    py_typecheck.check_type(aggregator_tf_device, tf.config.LogicalDevice)
  py_typecheck.check_type(client_tf_devices, (tuple, list))
  py_typecheck.check_type(num_client_executors, int)
  if num_clients is not None:
    py_typecheck.check_type(num_clients, int)
  unplaced_ex_factory = AggregatingUnplacedExecutorFactory(
      use_caching=True,
      server_device=server_tf_device,
      # Previously dropped, which silently ignored aggregator_tf_device.
      aggregator_device=aggregator_tf_device,
      client_devices=client_tf_devices)
  paillier_aggregating_executor_factory = PaillierAggregatingExecutorFactory(
      num_client_executors=num_client_executors,
      unplaced_ex_factory=unplaced_ex_factory,
      num_clients=num_clients,
      use_sizing=False)
  return tff.framework.create_executor_factory(
      paillier_aggregating_executor_factory.create_executor)


# ------------------------------------------------------------------------------
# /federated_aggregations/paillier/factory_test.py:
# ------------------------------------------------------------------------------
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow_federated.python.core.impl.executors import execution_context
from tensorflow_federated.python.core.impl.executors import executor_factory

from federated_aggregations.paillier import factory


def _temperature_sensor_example_next_fn():
  """Builds a computation: mean count of readings above a threshold.

  Each client's count is weighted by that client's total number of readings.
  """

  @tff.tf_computation(
      tff.SequenceType(tf.float32), tf.float32)
  def count_over(ds, t):
    return ds.reduce(
        np.float32(0), lambda n, x: n + tf.cast(tf.greater(x, t), tf.float32))

  @tff.tf_computation(tff.SequenceType(tf.float32))
  def count_total(ds):
    return ds.reduce(np.float32(0.0), lambda n, _: n + 1.0)

  @tff.federated_computation(
      tff.FederatedType(tff.SequenceType(tf.float32), tff.CLIENTS),
      tff.FederatedType(tf.float32, tff.SERVER))
  def comp(temperatures, threshold):
    return tff.federated_mean(
        tff.federated_map(
            count_over,
            tff.federated_zip(
                [temperatures,
                 tff.federated_broadcast(threshold)])),
        tff.federated_map(count_total, temperatures))

  return comp


def _install_executor(executor_factory_instance):
  context = execution_context.ExecutionContext(executor_factory_instance)
  return tff.framework.get_context_stack().install(context)


class ExecutorMock(mock.MagicMock, tff.framework.Executor):

  def create_value(self, *args):
    pass

  def create_call(self, *args):
    pass

  def create_selection(self, *args):
    pass

  def create_struct(self, *args):
    pass

  def close(self, *args):
    pass


class ExecutorStacksTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ('paillier_executor_factory', factory.local_paillier_executor_factory),
  )
  def test_construction_with_no_args(self, executor_factory_fn):
    executor_factory_impl = executor_factory_fn()
    self.assertIsInstance(executor_factory_impl,
                          executor_factory.ExecutorFactoryImpl)

  @parameterized.named_parameters(
      ('paillier_executor_factory_none_clients',
       factory.local_paillier_executor_factory()),
      ('paillier_executor_factory_three_clients',
       factory.local_paillier_executor_factory(num_clients=3)),
  )
  def test_execution_of_temperature_sensor_example(self, executor):
    comp = _temperature_sensor_example_next_fn()
    to_float = lambda x: tf.cast(x, tf.float32)
    temperatures = [
        tf.data.Dataset.range(10).map(to_float),
        tf.data.Dataset.range(20).map(to_float),
        tf.data.Dataset.range(30).map(to_float),
    ]
    threshold = 15.0

    with _install_executor(executor):
      result = comp(temperatures, threshold)

    # (0 * 10 + 4 * 20 + 14 * 30) / (10 + 20 + 30) = 500 / 60
    self.assertAlmostEqual(result, 8.333, places=3)

  @parameterized.named_parameters(
      ('paillier_executor_factory_none_clients',
       factory.local_paillier_executor_factory()),
      ('paillier_executor_factory_one_client',
       factory.local_paillier_executor_factory(num_clients=1)),
  )
  def test_execution_of_tensorflow(self, executor):

    @tff.tf_computation
    def comp():
      return tf.math.add(5, 5)

    with _install_executor(executor):
      result = comp()

    self.assertEqual(result, 10)

  @parameterized.named_parameters(
      ('paillier_executor_factory', factory.local_paillier_executor_factory),
  )
  def test_create_executor_raises_with_wrong_cardinalities(
      self, executor_factory_fn):
    executor_factory_impl = executor_factory_fn(num_clients=5)
    cardinalities = {
        tff.SERVER: 1,
        None: 1,
        tff.CLIENTS: 1,
    }
    with self.assertRaises(ValueError):
      executor_factory_impl.create_executor(cardinalities)


if __name__ == '__main__':
  absltest.main()
# ------------------------------------------------------------------------------
# /federated_aggregations/paillier/placement.py:
# ------------------------------------------------------------------------------
from tensorflow_federated.python.core.impl.types import placement_literals

# Placement of the executor that sums client ciphertexts in
# PaillierAggregatingStrategy (which requires exactly one such executor).
AGGREGATOR = placement_literals.PlacementLiteral(
    'AGGREGATOR',
    'aggregator',
    default_all_equal=True,
    description='An "unplacement" for aggregations.')


# ------------------------------------------------------------------------------
# /federated_aggregations/paillier/strategy.py:
# ------------------------------------------------------------------------------
import asyncio
from collections import OrderedDict

import tensorflow_federated as tff
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.impl.executors import federated_resolving_strategy
from tensorflow_federated.python.core.impl.types import placement_literals
from tensorflow_federated.python.core.impl.types import type_analysis
from tensorflow_federated.python.core.impl.types import type_conversions

from federated_aggregations import utils
from federated_aggregations.channels import channel
from federated_aggregations.paillier import placement as paillier_placement
from federated_aggregations.paillier import computations as paillier_comp


class PaillierAggregatingStrategy(tff.framework.FederatedResolvingStrategy):
  """A FederatingStrategy implementing secure_sum with a Paillier cryptoscheme.

  Outside of its ability to handle a tff.federated_secure_sum intrinsic, this
  strategy is essentially identical to FederatedResolvingStrategy. Our
  implementation of federated_secure_sum has an important semantic difference
  from the one described in TFF; we explicitly do not handle the positional
  bitwidth argument. This is because we do not optimize for efficient packing
  of client ciphertexts during communication and computation, which would
  require the bitwidth argument in order to maximize the number of native
  int32s that can be packed into a single Paillier ciphertext tensor.

  Attributes:
    channel_grid: A ChannelGrid containing the explicit communication Channels
      to use between each unordered pair of placements. A minimally secure
      instantiation of this strategy will use a secure channel for the
      `(AGGREGATOR, CLIENTS)` placement pair, but it could also be a plaintext
      channel (e.g. for tests or baseline benchmarks).

  Raises:
    ValueError if the `target_executors` argument supplied for Strategy
    construction is malformed. This is the case if it is missing an
    executor stack for the AGGREGATOR placement, or if the executor stack
    has cardinality other than 1.
  """

  @classmethod
  def factory(cls, target_executors, channel_grid, key_inputter):
    """Returns a callable that builds this strategy for a given executor."""
    return lambda executor: cls(
        executor, target_executors, channel_grid, key_inputter)

  def __init__(self, executor, target_executors, channel_grid, key_inputter):
    super().__init__(executor, target_executors)
    self._check_for_paillier_placement()
    self.channel_grid = channel_grid
    # Keys are generated lazily, on the first federated_secure_sum call.
    self._requires_setup = True
    self._key_inputter = key_inputter
    self._paillier_encryptor = paillier_comp.make_encryptor()
    self._paillier_sequence_sum = paillier_comp.make_sequence_sum()
    # Filled by utils.materialize_computation_from_cache.
    self._paillier_decryptor_cache = {}
    self._reshape_function_cache = {}

  def _get_child_executors(self, placement, index=None):
    """Returns the child executors for `placement`, or just the `index`-th."""
    child_executors = self._target_executors[placement]
    if index is not None:
      return child_executors[index]
    return child_executors

  def _check_for_paillier_placement(self):
    """Raises ValueError unless exactly one AGGREGATOR executor exists."""
    if paillier_placement.AGGREGATOR not in self._target_executors:
      raise ValueError('Missing Paillier aggregator placement.')
    paillier_executor = self._target_executors[paillier_placement.AGGREGATOR]
    paillier_cardinality = len(paillier_executor)
    if paillier_cardinality != 1:
      raise ValueError(
          'Unsupported cardinality for Paillier aggregator placement {}: '
          '{}.'.format(paillier_placement.AGGREGATOR, paillier_cardinality))

  async def _move(self, value, source_placement, target_placement):
    """Transfers `value` over the channel registered for the placement pair."""
    await self.channel_grid.setup_channels(self)
    # Named so it does not shadow the imported `channel` module.
    transfer_channel = self.channel_grid[(source_placement, target_placement)]
    return await transfer_channel.transfer(value)

  async def _paillier_setup(self):
    """Runs key_inputter on SERVER and distributes the encryption key.

    Sets `encryption_key_server`, `encryption_key_clients`,
    `encryption_key_paillier` and `decryption_key` on `self`.
    """
    # Load paillier keys on server
    key_inputter = await self._executor.create_value(self._key_inputter)
    _check_key_inputter(key_inputter)
    fed_output = await self._eval(key_inputter, tff.SERVER, all_equal=True)
    output = fed_output.internal_representation[0]
    # Broadcast encryption key to all placements
    server_executor = self._get_child_executors(tff.SERVER, index=0)
    ek_ref = await server_executor.create_selection(output, index=0)
    ek = federated_resolving_strategy.FederatedResolvingStrategyValue(
        ek_ref, tff.FederatedType(ek_ref.type_signature, tff.SERVER, True))
    placed = await asyncio.gather(
        self._move(ek, tff.SERVER, tff.CLIENTS),
        self._move(ek, tff.SERVER, paillier_placement.AGGREGATOR))
    self.encryption_key_server = ek
    self.encryption_key_clients = placed[0]
    self.encryption_key_paillier = placed[1]
    # Keep decryption key on server with formal placement
    dk_ref = await server_executor.create_selection(output, index=1)
    self.decryption_key = (
        federated_resolving_strategy.FederatedResolvingStrategyValue(
            dk_ref,
            tff.FederatedType(
                dk_ref.type_signature, tff.SERVER, all_equal=True)))

  async def compute_federated_secure_sum(self, arg):
    """Sums a CLIENTS-placed tensor under Paillier encryption.

    Pipeline: encrypt on CLIENTS, move to AGGREGATOR and sum the ciphertexts
    there, then move to SERVER and decrypt.

    Args:
      arg: A two-element structure `(value, bitwidth)`, where `value` is a
        CLIENTS-placed tensor. `bitwidth` is ignored (see class docstring).

    Returns:
      A SERVER-placed value reshaped back to the input tensor's shape. It is
      decrypted with `export_dtype` set to the input tensor's dtype.
    """
    self._check_arg_is_structure(arg)
    py_typecheck.check_len(arg.internal_representation, 2)
    value_type = arg.type_signature[0]
    type_analysis.check_federated_type(value_type, placement=tff.CLIENTS)
    py_typecheck.check_type(value_type.member, tff.TensorType)
    # Stash input dtype for later
    input_tensor_dtype = value_type.member.dtype
    # Paillier setup phase
    if self._requires_setup:
      await self._paillier_setup()
      self._requires_setup = False
    # Stash input shape, and reshape input tensor to matrix-form
    input_tensor_shape = value_type.member.shape
    if len(input_tensor_shape) != 2:
      clients_value = await self._compute_reshape_on_tensor(
          await self._executor.create_selection(arg, index=0),
          output_shape=[1, input_tensor_shape.num_elements()])
    else:
      clients_value = await self._executor.create_selection(arg, index=0)
    # Encrypt summands on tff.CLIENTS
    encrypted_values = await self._compute_paillier_encryption(
        self.encryption_key_clients, clients_value)
    # Perform Paillier sum on ciphertexts
    encrypted_values = await self._move(
        encrypted_values, tff.CLIENTS, paillier_placement.AGGREGATOR)
    encrypted_sum = await self._compute_paillier_sum(
        self.encryption_key_paillier, encrypted_values)
    # Move to server and decrypt the result
    encrypted_sum = await self._move(
        encrypted_sum, paillier_placement.AGGREGATOR, tff.SERVER)
    decrypted_result = await self._compute_paillier_decryption(
        self.decryption_key,
        self.encryption_key_server,
        encrypted_sum,
        export_dtype=input_tensor_dtype)
    # Undo the matrix-form reshape.
    return await self._compute_reshape_on_tensor(
        decrypted_result, output_shape=input_tensor_shape.as_list())

  async def _compute_paillier_encryption(
      self,
      client_encryption_keys: federated_resolving_strategy.FederatedResolvingStrategyValue,
      clients_value: federated_resolving_strategy.FederatedResolvingStrategyValue):
    """Encrypts each client's value with that client's encryption key."""
    client_children = self._get_child_executors(tff.CLIENTS)
    num_clients = len(client_children)
    py_typecheck.check_len(client_encryption_keys.internal_representation,
                           num_clients)
    py_typecheck.check_len(clients_value.internal_representation, num_clients)
    encryptor_proto, encryptor_type = utils.lift_to_computation_spec(
        self._paillier_encryptor,
        input_arg_type=tff.StructType((
            client_encryption_keys.type_signature.member,
            clients_value.type_signature.member)))
    encryptor_fns = asyncio.gather(*[
        c.create_value(encryptor_proto, encryptor_type)
        for c in client_children])
    encryptor_args = asyncio.gather(*[
        c.create_struct((ek, v)) for c, ek, v in zip(
            client_children,
            client_encryption_keys.internal_representation,
            clients_value.internal_representation)])
    encryptor_fns, encryptor_args = await asyncio.gather(
        encryptor_fns, encryptor_args)
    encrypted_values = await asyncio.gather(*[
        c.create_call(fn, arg) for c, fn, arg in zip(
            client_children, encryptor_fns, encryptor_args)])
    return federated_resolving_strategy.FederatedResolvingStrategyValue(
        encrypted_values,
        tff.FederatedType(encryptor_type.result, tff.CLIENTS,
                          clients_value.type_signature.all_equal))

  async def _compute_paillier_sum(
      self,
      encryption_key: federated_resolving_strategy.FederatedResolvingStrategyValue,
      values: federated_resolving_strategy.FederatedResolvingStrategyValue):
    """Sums the ciphertexts on the AGGREGATOR with the sequence-sum comp."""
    paillier_child = self._get_child_executors(
        paillier_placement.AGGREGATOR, index=0)
    sum_proto, sum_type = utils.lift_to_computation_spec(
        self._paillier_sequence_sum,
        input_arg_type=tff.StructType((
            encryption_key.type_signature.member,
            tff.StructType([vt.member for vt in values.type_signature]))))
    sum_fn = paillier_child.create_value(sum_proto, sum_type)
    sum_arg = paillier_child.create_struct((
        encryption_key.internal_representation,
        await paillier_child.create_struct(values.internal_representation)))
    sum_fn, sum_arg = await asyncio.gather(sum_fn, sum_arg)
    encrypted_sum = await paillier_child.create_call(sum_fn, sum_arg)
    return federated_resolving_strategy.FederatedResolvingStrategyValue(
        encrypted_sum,
        tff.FederatedType(sum_type.result, paillier_placement.AGGREGATOR, True))

  async def _compute_paillier_decryption(
      self,
      decryption_key: federated_resolving_strategy.FederatedResolvingStrategyValue,
      encryption_key: federated_resolving_strategy.FederatedResolvingStrategyValue,
      value: federated_resolving_strategy.FederatedResolvingStrategyValue,
      export_dtype):
    """Decrypts `value` on SERVER; the decryptor is cached per signature."""
    server_child = self._get_child_executors(tff.SERVER, index=0)
    decryptor_arg_spec = (decryption_key.type_signature.member,
                          encryption_key.type_signature.member,
                          value.type_signature.member)
    decryptor_proto, decryptor_type = utils.materialize_computation_from_cache(
        paillier_comp.make_decryptor,
        self._paillier_decryptor_cache,
        arg_spec=decryptor_arg_spec,
        export_dtype=export_dtype)
    decryptor_fn = server_child.create_value(decryptor_proto, decryptor_type)
    decryptor_arg = server_child.create_struct((
        decryption_key.internal_representation,
        encryption_key.internal_representation,
        value.internal_representation))
    decryptor_fn, decryptor_arg = await asyncio.gather(
        decryptor_fn, decryptor_arg)
    decrypted_value = await server_child.create_call(
        decryptor_fn, decryptor_arg)
    return federated_resolving_strategy.FederatedResolvingStrategyValue(
        [decrypted_value],
        tff.FederatedType(decryptor_type.result, tff.SERVER, True))

  async def _compute_reshape_on_tensor(self, tensor, output_shape):
    """Reshapes every member of a federated tensor to `output_shape`."""
    tensor_type = tensor.type_signature.member
    reshaper_proto, reshaper_type = utils.materialize_computation_from_cache(
        paillier_comp.make_reshape_tensor,
        self._reshape_function_cache,
        arg_spec=(tensor_type,),
        output_shape=output_shape)
    tensor_placement = tensor.type_signature.placement
    children = self._get_child_executors(tensor_placement)
    py_typecheck.check_len(tensor.internal_representation, len(children))
    reshaper_fns = await asyncio.gather(*[
        ex.create_value(reshaper_proto, reshaper_type) for ex in children])
    reshaped_tensors = await asyncio.gather(*[
        ex.create_call(fn, arg) for ex, fn, arg in zip(
            children, reshaper_fns, tensor.internal_representation)])
    output_tensor_spec = tff.FederatedType(
        tff.TensorType(tensor_type.dtype, output_shape),
        tensor_placement,
        tensor.type_signature.all_equal)
    return federated_resolving_strategy.FederatedResolvingStrategyValue(
        reshaped_tensors, output_tensor_spec)


def _check_key_inputter(fn_value):
  """Validates the type signature of the `key_inputter` computation.

  Expects a function whose result is `(ek, (dk_0, dk_1))`, where `ek`,
  `dk_0` and `dk_1` are tensors.

  Raises:
    ValueError: If the result, or the decryption key inside it, does not have
      exactly two elements. Other type mismatches are reported by
      `py_typecheck.check_type`.
  """
  fn_type = fn_value.type_signature
  py_typecheck.check_type(fn_type, tff.FunctionType)
  try:
    py_typecheck.check_len(fn_type.result, 2)
  except ValueError as e:
    raise ValueError(
        'Expected 2 elements in the output of key_inputter, '
        'found {}.'.format(len(fn_type.result))) from e
  ek_type, dk_type = fn_type.result
  py_typecheck.check_type(ek_type, tff.TensorType)
  py_typecheck.check_type(dk_type, tff.StructType)
  try:
    py_typecheck.check_len(dk_type, 2)
  except ValueError as e:
    # Report the decryption key's own length, not the outer result's.
    raise ValueError(
        'Expected a two element tuple for the decryption key from '
        'key_inputter, found {} elements.'.format(len(dk_type))) from e
  py_typecheck.check_type(dk_type[0], tff.TensorType)
  py_typecheck.check_type(dk_type[1], tff.TensorType)


# ------------------------------------------------------------------------------
# /federated_aggregations/paillier/strategy_test.py:
# ------------------------------------------------------------------------------
from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow_federated.python.core.impl.executors import execution_context

from federated_aggregations.paillier import factory


def _install_executor(executor_factory_instance):
  context = execution_context.ExecutionContext(executor_factory_instance)
  return tff.framework.get_context_stack().install(context)


def make_integer_secure_sum(input_shape):
  """Builds a federated computation applying federated_secure_sum to int32s.

  `input_shape=None` means a scalar int32 per client.
  """
  if input_shape is None:
    member_type = tf.int32
  else:
    member_type = tff.TensorType(tf.int32, input_shape)

  @tff.federated_computation(tff.FederatedType(member_type, tff.CLIENTS))
  def secure_paillier_addition(x):
    return tff.federated_secure_sum(x, 64)

  return secure_paillier_addition


class PaillierAggregatingStrategyTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ('paillier_executor_factory_none_clients',
       factory.local_paillier_executor_factory()),
      ('paillier_executor_factory_five_clients',
       factory.local_paillier_executor_factory(num_clients=5)))
  def test_federated_secure_sum_with(self, executor_factory_instance):
    # Parameter renamed so it no longer shadows the `factory` module.
    secure_paillier_addition = make_integer_secure_sum(None)
    with _install_executor(executor_factory_instance):
      result = secure_paillier_addition([1, 2, 3, 4, 5])
    self.assertAlmostEqual(result, 15.0)

  @parameterized.named_parameters(
      (('rank{}'.format(i), [2] * i) for i in range(1, 6)))
  def test_secure_sum_inputs(self, input_shape):
    input_tensor = np.ones(input_shape, dtype=np.int32)
    NUM_CLIENTS = 5
    expected = input_tensor * NUM_CLIENTS
    secure_paillier_addition = make_integer_secure_sum(input_shape)
    with _install_executor(factory.local_paillier_executor_factory()):
      result = secure_paillier_addition([input_tensor] * NUM_CLIENTS)
    np.testing.assert_almost_equal(result, expected)

  @parameterized.named_parameters(
      (('{}'.format(n), n) for n in [5, 20, 50]))
  def test_secure_sum_many_clients(self, num_clients):
    secure_paillier_addition = make_integer_secure_sum([1, 1])
    with _install_executor(factory.local_paillier_executor_factory()):
      result = secure_paillier_addition([[[1]]] * num_clients)
    self.assertAlmostEqual(result, num_clients)

  @parameterized.named_parameters(
      ('{}x{}'.format(r, c), r, c)
      for r, c in [(1, 1), (2, 2), (5, 5), (10, 10)])
  def test_secure_sum_larger_matrices(self, first_dim, second_dim):
    NUM_CLIENTS = 5
    shape = (first_dim, second_dim)
    input_tensor = np.ones(shape, dtype=np.int32)
    member_type = tff.TensorType(tf.int32, shape)

    @tff.federated_computation(tff.FederatedType(member_type, tff.CLIENTS))
    def secure_paillier_addition(x):
      return tff.federated_secure_sum(x, 64)

    with _install_executor(factory.local_paillier_executor_factory()):
      result = secure_paillier_addition([input_tensor] * NUM_CLIENTS)
    expected = input_tensor * NUM_CLIENTS
    np.testing.assert_almost_equal(result, expected)


if __name__ == '__main__':
  absltest.main()


# ------------------------------------------------------------------------------
# /federated_aggregations/utils.py:
# ------------------------------------------------------------------------------
from typing import Tuple

import tensorflow_federated as tff


def materialize_computation_from_cache(
    factory_func, cache, arg_spec, **factory_kwargs):
  """Materialize a tf_computation generated by factory_func.

  If this function has already been called with a given combination of
  factory_func, arg_spec, and factory_kwargs, the resulting function proto &
  type spec will be retrieved from cache instead of re-tracing.

  Args:
    factory_func: Called as `factory_func(*arg_spec, **factory_kwargs)` on a
      cache miss to build the computation.
    cache: A dict storing `(proto, type_signature)` pairs; mutated in place.
    arg_spec: A sequence of TFF types describing the computation's input.
      More than one type is wrapped in a `tff.StructType`.
    **factory_kwargs: Extra keyword arguments for `factory_func`. Their names
      and `repr`s are part of the cache key.

  Returns:
    A `(computation_proto, type_signature)` tuple.
  """
  # factory_func and kwarg names are part of the key, as the docstring
  # promises, so different factories or differently-named kwargs that share
  # a cache cannot collide.
  cache_key = (
      factory_func,
      tuple(sorted(
          (name, repr(value)) for name, value in factory_kwargs.items())),
      tuple(x.compact_representation() for x in arg_spec))
  fn_proto, fn_type = cache.get(cache_key, (None, None))
  if fn_proto is None:
    func = factory_func(*arg_spec, **factory_kwargs)
    if len(arg_spec) > 1:
      input_type = tff.StructType(arg_spec)
    else:
      input_type = arg_spec[0]
    fn_proto, fn_type = lift_to_computation_spec(
        func, input_arg_type=input_type)
    cache[cache_key] = (fn_proto, fn_type)
  return fn_proto, fn_type


def lift_to_computation_spec(tf_func, input_arg_type=None):
  """Determine computation definition & type spec from a tf_computation.

  If tf_func is polymorphic, first make it concrete with input_arg_type.

  Raises:
    ValueError: If tf_func is polymorphic and input_arg_type is None.
  """
  if not hasattr(tf_func, '_computation_proto'):
    if input_arg_type is None:
      raise ValueError('Polymorphic tf_computation requires input_arg_type '
                       'to be made concrete.')
    tf_func = tf_func.fn_for_argument_type(input_arg_type)
  return tf_func._computation_proto, tf_func.type_signature  # pylint: disable=protected-access


# ------------------------------------------------------------------------------
# /federated_aggregations/version.py:
# ------------------------------------------------------------------------------
"""TFF Aggregations version."""
__version__ = '0.0.1'


# ------------------------------------------------------------------------------
# /protocol.gif:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/tf-encrypted/federated-aggregations/b4ab7a15c2719d4119db7d9d609f8c06d9df8958/protocol.gif


# ------------------------------------------------------------------------------
# /requirements.txt:
# ------------------------------------------------------------------------------
# tensorflow-federated~=0.16.1
# tf-encrypted-primitives~=0.1.1


# ------------------------------------------------------------------------------
# /setup.py:
# ------------------------------------------------------------------------------
"""TFF Aggregations is a lib for secure aggregation in TensorFlow Federated.

TFF Aggregations uses primitives from
TF Encrypted (https://github.com/tf-encrypted/tf-encrypted) to define and
execute secure versions of federated aggregation functions from TFF's
Federated Core."""
import setuptools

DOCLINES = __doc__.split('\n')
REQUIRED_PACKAGES = [
    'tensorflow-federated>=0.16.1',
    'tf-encrypted-primitives>=0.1.0',
]

with open('federated_aggregations/version.py', encoding='utf-8') as fp:
  globals_dict = {}
  exec(fp.read(), globals_dict)  # pylint: disable=exec-used
  VERSION = globals_dict['__version__']

setuptools.setup(
    name='federated_aggregations',
    version=VERSION,
    # Must be a tuple: the former ('examples') was a plain string, which
    # setuptools iterated character by character.
    packages=setuptools.find_packages(exclude=('examples', 'examples.*')),
    description=DOCLINES[0],
    long_description='\n'.join(DOCLINES[2:]),
    long_description_content_type='text/markdown',
    author='The TF Encrypted Authors',
    author_email='contact@tf-encrypted.io',
    url='https://github.com/tf-encrypted/federated-aggregations',
    download_url='https://github.com/tf-encrypted/federated-aggregations/tags',
    install_requires=REQUIRED_PACKAGES,
    # PyPI package information.
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: Apache Software License",
        "Development Status :: 2 - Pre-Alpha",
        "Operating System :: OS Independent",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Security :: Cryptography",
    ],
    license='Apache 2.0',
    keywords='tensorflow encrypted secure paillier federated machine learning',
)
# ------------------------------------------------------------------------------