├── configs ├── __init__.py └── default.yaml ├── crypten ├── mpc │ ├── provider │ │ ├── tuple_cache │ │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── homomorphic_provider.py │ │ ├── tfp_provider.py │ │ └── provider.py │ ├── primitives │ │ ├── __init__.py │ │ ├── converters.py │ │ ├── replicated.py │ │ ├── circuit.py │ │ ├── ot │ │ │ └── baseOT.py │ │ └── beaver.py │ ├── ptype.py │ ├── __init__.py │ ├── context.py │ └── mpc.py ├── cuda │ └── __init__.py ├── common │ ├── __init__.py │ ├── functions │ │ ├── __init__.py │ │ ├── dropout.py │ │ ├── logic.py │ │ ├── sampling.py │ │ ├── power.py │ │ └── regular.py │ ├── tensor_types.py │ ├── rng.py │ ├── util.py │ └── serial.py ├── config │ ├── __init__.py │ └── config.py ├── nn │ ├── privacy │ │ └── __init__.py │ ├── init.py │ ├── distances.py │ ├── tensorboard.py │ ├── __init__.py │ └── loss.py ├── optim │ ├── __init__.py │ ├── optimizer.py │ └── sgd.py ├── debug │ ├── __init__.py │ └── debug.py ├── autograd_cryptensor.py ├── communicator │ ├── __init__.py │ ├── in_process_communicator.py │ └── communicator.py ├── encoder.py └── models │ └── __init__.py ├── .gitattributes ├── examples ├── text-generation │ ├── requirements.txt │ ├── test_gpt2_128_comm.sh │ ├── test_gpt2_64_comm.sh │ ├── test_gpt2_128_comp.sh │ ├── test_gpt2_64_comp.sh │ ├── README.md │ └── multiprocess_launcher.py ├── image-classification │ ├── requirements.txt │ ├── test_vit_base_224_comm.sh │ ├── test_vit_base_224_comp.sh │ ├── README.md │ ├── multiprocess_launcher.py │ └── run_image_classification_private.py ├── text-classification │ ├── requirements.txt │ ├── test_bert_base_plain.sh │ ├── test_bert_base_acc.sh │ ├── test_bert_base_128_comm.sh │ ├── test_bert_large_128_comm.sh │ ├── test_bert_base_128_comp.sh │ ├── test_bert_large_128_comp.sh │ ├── README.md │ └── multiprocess_launcher.py ├── unit-test │ ├── README.md │ ├── run_test_softmax.py │ ├── run_test_gelu.py │ └── multiprocess_launcher.py └── ttp-test │ ├── README.md │ ├── run_test_bert.py │ 
└── run_test_ttp.py ├── requirements.txt ├── setup.py ├── README.md ├── LICENSE └── .gitignore /configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /crypten/mpc/provider/tuple_cache/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /crypten/cuda/__init__.py: -------------------------------------------------------------------------------- 1 | from .cuda_tensor import CUDALongTensor 2 | 3 | 4 | __all__ = ["CUDALongTensor"] 5 | -------------------------------------------------------------------------------- /examples/text-generation/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.34.2 2 | sentencepiece != 0.1.92 3 | protobuf 4 | torch >= 1.3 5 | -------------------------------------------------------------------------------- /examples/image-classification/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.34.2 2 | torch>=1.5.0 3 | torchvision>=0.6.0 4 | datasets>=2.14.0 5 | evaluate -------------------------------------------------------------------------------- /examples/text-classification/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.34.2 2 | datasets >= 1.8.0 3 | sentencepiece != 0.1.92 4 | scipy 5 | scikit-learn 6 | protobuf 7 | torch >= 1.3 8 | evaluate -------------------------------------------------------------------------------- 
/examples/text-generation/test_gpt2_128_comm.sh: -------------------------------------------------------------------------------- 1 | python run_generation_private.py \ 2 | --model_type=gpt2 \ 3 | --model_name_or_path=openai-community/gpt2 \ 4 | --len_data 128 \ 5 | --length 1 -------------------------------------------------------------------------------- /examples/text-generation/test_gpt2_64_comm.sh: -------------------------------------------------------------------------------- 1 | python run_generation_private.py \ 2 | --model_type=gpt2 \ 3 | --model_name_or_path=openai-community/gpt2 \ 4 | --len_data 64 \ 5 | --length 1 -------------------------------------------------------------------------------- /examples/image-classification/test_vit_base_224_comm.sh: -------------------------------------------------------------------------------- 1 | python run_image_classification_private.py \ 2 | --model_name_or_path google/vit-base-patch16-224 \ 3 | --max_eval_samples 1 \ 4 | --seed 42 -------------------------------------------------------------------------------- /examples/text-generation/test_gpt2_128_comp.sh: -------------------------------------------------------------------------------- 1 | python run_generation_private.py \ 2 | --model_type=gpt2 \ 3 | --model_name_or_path=openai-community/gpt2 \ 4 | --len_data 128 \ 5 | --comp \ 6 | --length 1 -------------------------------------------------------------------------------- /examples/text-generation/test_gpt2_64_comp.sh: -------------------------------------------------------------------------------- 1 | python run_generation_private.py \ 2 | --model_type=gpt2 \ 3 | --model_name_or_path=openai-community/gpt2 \ 4 | --len_data 64 \ 5 | --comp \ 6 | --length 1 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.0 2 | torchvision>=0.9.1 3 | numpy<2 4 | omegaconf>=2.0.6 5 | 
onnx>=1.7.0 6 | pandas>=1.2.2 7 | pyyaml>=5.3.1 8 | tensorboard 9 | future 10 | scipy>=1.6.0 11 | scikit-learn 12 | -------------------------------------------------------------------------------- /examples/image-classification/test_vit_base_224_comp.sh: -------------------------------------------------------------------------------- 1 | python run_image_classification_private.py \ 2 | --model_name_or_path google/vit-base-patch16-224 \ 3 | --max_eval_samples 1 \ 4 | --comp \ 5 | --seed 42 -------------------------------------------------------------------------------- /examples/text-classification/test_bert_base_plain.sh: -------------------------------------------------------------------------------- 1 | export TASK_NAME=sst2 2 | 3 | python run_glue_eval.py \ 4 | --model_name_or_path andeskyl/bert-base-cased-$TASK_NAME \ 5 | --task_name $TASK_NAME \ 6 | --max_length 128 \ 7 | --per_device_eval_batch_size 8 \ 8 | --output_dir eval/$TASK_NAME/ 9 | -------------------------------------------------------------------------------- /crypten/common/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | __all__ = ["functions", "rng", "tensor_types", "util", "serial"] 9 | -------------------------------------------------------------------------------- /examples/text-classification/test_bert_base_acc.sh: -------------------------------------------------------------------------------- 1 | export TASK_NAME=sst2 2 | 3 | python run_glue_private.py \ 4 | --model_name_or_path andeskyl/bert-base-cased-$TASK_NAME \ 5 | --task_name $TASK_NAME \ 6 | --max_length 128 \ 7 | --acc \ 8 | --per_device_eval_batch_size 1 \ 9 | --output_dir eval_private/$TASK_NAME/ 10 | -------------------------------------------------------------------------------- /examples/text-classification/test_bert_base_128_comm.sh: -------------------------------------------------------------------------------- 1 | export TASK_NAME=qnli 2 | 3 | python run_glue_private.py \ 4 | --model_name_or_path andeskyl/bert-base-cased-$TASK_NAME \ 5 | --task_name $TASK_NAME \ 6 | --len_data 128 \ 7 | --max_length 128 \ 8 | --per_device_eval_batch_size 1 \ 9 | --output_dir eval_private/$TASK_NAME/ 10 | -------------------------------------------------------------------------------- /examples/text-classification/test_bert_large_128_comm.sh: -------------------------------------------------------------------------------- 1 | export TASK_NAME=qnli 2 | 3 | python run_glue_private.py \ 4 | --model_name_or_path andeskyl/bert-large-cased-$TASK_NAME \ 5 | --task_name $TASK_NAME \ 6 | --len_data 128 \ 7 | --max_length 128 \ 8 | --per_device_eval_batch_size 1 \ 9 | --output_dir eval_private/$TASK_NAME/ 10 | -------------------------------------------------------------------------------- /crypten/config/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | from .config import CrypTenConfig 9 | 10 | cfg = CrypTenConfig() 11 | 12 | __all__ = ["cfg"] 13 | -------------------------------------------------------------------------------- /examples/text-classification/test_bert_base_128_comp.sh: -------------------------------------------------------------------------------- 1 | export TASK_NAME=qnli 2 | 3 | python run_glue_private.py \ 4 | --model_name_or_path andeskyl/bert-base-cased-$TASK_NAME \ 5 | --task_name $TASK_NAME \ 6 | --len_data 128 \ 7 | --max_length 128 \ 8 | --comp \ 9 | --per_device_eval_batch_size 1 \ 10 | --output_dir eval_private/$TASK_NAME/ 11 | -------------------------------------------------------------------------------- /crypten/nn/privacy/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from .dp_split import DPSplitModel, SkippedLoss 9 | 10 | __all__ = ["DPSplitModel", "SkippedLoss"] 11 | -------------------------------------------------------------------------------- /examples/text-classification/test_bert_large_128_comp.sh: -------------------------------------------------------------------------------- 1 | export TASK_NAME=qnli 2 | 3 | python run_glue_private.py \ 4 | --model_name_or_path andeskyl/bert-large-cased-$TASK_NAME \ 5 | --task_name $TASK_NAME \ 6 | --len_data 128 \ 7 | --max_length 128 \ 8 | --comp \ 9 | --per_device_eval_batch_size 1 \ 10 | --output_dir eval_private/$TASK_NAME/ 11 | -------------------------------------------------------------------------------- /crypten/optim/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 
4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from .optimizer import Optimizer 9 | from .sgd import SGD 10 | 11 | # Export only names defined above; Adam/AdamW are not implemented in this package. 12 | __all__ = ["Optimizer", "SGD"] 13 | -------------------------------------------------------------------------------- /examples/unit-test/README.md: -------------------------------------------------------------------------------- 1 | # Performance of Private Softmax and GELU 2 | This directory evaluates the costs of private softmax and GELU protocols. 3 | 4 | ## Softmax Performance 5 | Evaluates the performance of private softmax protocol: 6 | ```bash 7 | python3 run_test_softmax.py 8 | ``` 9 | 10 | ## GELU Performance 11 | Evaluates the performance of private GELU protocol: 12 | ```bash 13 | python3 run_test_gelu.py 14 | ``` -------------------------------------------------------------------------------- /crypten/mpc/primitives/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree.
7 | 8 | from .arithmetic import ArithmeticSharedTensor 9 | from .binary import BinarySharedTensor 10 | 11 | 12 | __all__ = ["ArithmeticSharedTensor", "BinarySharedTensor"] 13 | -------------------------------------------------------------------------------- /examples/ttp-test/README.md: -------------------------------------------------------------------------------- 1 | # Tests for the Trusted Third Party (TTP) Provider 2 | 3 | ## Run 4 | 5 | ```shell 6 | python3 run_test_ttp.py 7 | ``` 8 | 9 | ```shell 10 | python3 run_test_bert.py 11 | ``` 12 | 13 | ## Usage 14 | 15 | Select the provider in code: 16 | 17 | ```python 18 | import crypten 19 | 20 | crypten.cfg.mpc.provider = "TTP" 21 | ``` 22 | 23 | or set it in `configs/default.yaml`: 24 | 25 | ```yaml 26 | mpc: 27 | provider: "TTP" # default is TFP 28 | ``` -------------------------------------------------------------------------------- /crypten/mpc/provider/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from .homomorphic_provider import HomomorphicProvider 9 | from .tfp_provider import TrustedFirstParty 10 | from .ttp_provider import TrustedThirdParty, TTPServer 11 | 12 | 13 | __all__ = ["TrustedFirstParty", "TrustedThirdParty", "TTPServer", "HomomorphicProvider"] 14 | -------------------------------------------------------------------------------- /examples/image-classification/README.md: -------------------------------------------------------------------------------- 1 | # Private Inference Cost of ViT-base 2 | This directory evaluates the private inference cost of ViT-base.
3 | ## Preparation 4 | Install dependencies: 5 | ```bash 6 | pip install -r requirements.txt 7 | ``` 8 | ## Running Experiments 9 | Computation cost of private ViT-base inference for a 224×224 RGB image: 10 | ```bash 11 | bash test_vit_base_224_comp.sh 12 | ``` 13 | Communication cost of private ViT-base inference for a 224×224 RGB image: 14 | ```bash 15 | bash test_vit_base_224_comm.sh 16 | ``` -------------------------------------------------------------------------------- /crypten/common/functions/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from . import approximations, dropout, logic, maximum, pooling, power, regular, sampling 9 | 10 | __all__ = [ 11 | "approximations", 12 | "dropout", 13 | "logic", 14 | "maximum", 15 | "pooling", 16 | "power", 17 | "regular", 18 | "sampling", 19 | ] 20 | -------------------------------------------------------------------------------- /examples/text-generation/README.md: -------------------------------------------------------------------------------- 1 | # Private Inference Cost of GPT-2 2 | This directory evaluates the cost of generating one token with GPT-2 in private. 
3 | ## Preparation 4 | Install dependencies: 5 | ```bash 6 | pip install -r requirements.txt 7 | ``` 8 | ## Running Experiments 9 | Computation cost of private GPT-2 inference for a length-64 input: 10 | ```bash 11 | bash test_gpt2_64_comp.sh 12 | ``` 13 | Communication cost of private GPT-2 inference for a length-64 input: 14 | ```bash 15 | bash test_gpt2_64_comm.sh 16 | ``` 17 | Computation cost of private GPT-2 inference for a length-128 input: 18 | ```bash 19 | bash test_gpt2_128_comp.sh 20 | ``` 21 | Communication cost of private GPT-2 inference for a length-128 input: 22 | ```bash 23 | bash test_gpt2_128_comm.sh 24 | ``` 25 | -------------------------------------------------------------------------------- /crypten/mpc/ptype.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from enum import Enum 9 | 10 | from .primitives import ArithmeticSharedTensor, BinarySharedTensor 11 | 12 | 13 | class ptype(Enum): 14 | """Enumeration defining the private type attributes of encrypted tensors""" 15 | 16 | arithmetic = 0 17 | binary = 1 18 | 19 | def to_tensor(self): 20 | if self.value == 0: 21 | return ArithmeticSharedTensor 22 | elif self.value == 1: 23 | return BinarySharedTensor 24 | else: 25 | raise ValueError("Cannot convert %s to encrypted tensor" % (self.name)) 26 | -------------------------------------------------------------------------------- /crypten/debug/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | 9 | from functools import wraps 10 | 11 | from crypten.config import cfg 12 | 13 | from .debug import configure_logging, MultiprocessingPdb, validate_correctness 14 | 15 | 16 | pdb = MultiprocessingPdb() 17 | # Public API; register_validation (defined below) is the validation decorator. 18 | __all__ = ["pdb", "configure_logging", "validate_correctness", "register_validation"] 19 | 20 | 21 | def register_validation(getattr_function): 22 | @wraps(getattr_function) 23 | def validate_attribute(self, name): 24 | # Get dispatched function call 25 | function = getattr_function(self, name) 26 | 27 | if not cfg.debug.validation_mode: 28 | return function 29 | 30 | # Run validation 31 | return validate_correctness(self, function, name) 32 | 33 | return validate_attribute 34 | -------------------------------------------------------------------------------- /examples/text-classification/README.md: -------------------------------------------------------------------------------- 1 | # Private Inference Cost of BERT-base and BERT-large 2 | This directory evaluates the private inference costs of BERT-base and BERT-large.
3 | ## Preparation 4 | Install dependencies: 5 | ```bash 6 | pip install -r requirements.txt 7 | ``` 8 | ## Running Experiments 9 | Computation cost of private BERT-base inference for a length-128 input: 10 | ```bash 11 | bash test_bert_base_128_comp.sh 12 | ``` 13 | Communication cost of private BERT-base inference for a length-128 input: 14 | ```bash 15 | bash test_bert_base_128_comm.sh 16 | ``` 17 | Computation cost of private BERT-large inference for a length-128 input: 18 | ```bash 19 | bash test_bert_large_128_comp.sh 20 | ``` 21 | Communication cost of private BERT-large inference for a length-128 input: 22 | ```bash 23 | bash test_bert_large_128_comm.sh 24 | ``` 25 | Private inference accuracy of BERT-base: 26 | ```bash 27 | bash test_bert_base_acc.sh 28 | ``` 29 | (Optional) Plaintext inference accuracy of BERT-base: 30 | ```bash 31 | bash test_bert_base_plain.sh 32 | ``` 33 | -------------------------------------------------------------------------------- /crypten/mpc/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | from crypten.config import cfg 9 | from crypten.mpc import primitives, provider # noqa: F401 10 | 11 | from .context import run_multiprocess 12 | from .mpc import MPCTensor 13 | from .ptype import ptype 14 | 15 | 16 | __all__ = [ 17 | "MPCTensor", 18 | "primitives", 19 | "provider", 20 | "ptype", 21 | "run_multiprocess", 22 | ] 23 | 24 | # the different private type attributes of an mpc encrypted tensor 25 | arithmetic = ptype.arithmetic 26 | binary = ptype.binary 27 | 28 | # Set provider 29 | __SUPPORTED_PROVIDERS = { 30 | "TFP": provider.TrustedFirstParty(), 31 | "TTP": provider.TrustedThirdParty(), 32 | "HE": provider.HomomorphicProvider(), 33 | } 34 | 35 | 36 | def get_default_provider(): 37 | return __SUPPORTED_PROVIDERS[cfg.mpc.provider] 38 | 39 | 40 | def ttp_required(): 41 | return cfg.mpc.provider == "TTP" 42 | -------------------------------------------------------------------------------- /crypten/nn/init.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import torch 9 | 10 | 11 | # Makes nn.init functions 12 | def make_crypten_compatible(initialization_function): 13 | def wrapper_func(tensor, *args, **kwargs): 14 | if not torch.is_tensor(tensor): 15 | result = torch.empty(tensor.size()) 16 | result = initialization_function(result, *args, **kwargs) 17 | tensor.set(result) 18 | return tensor 19 | 20 | return initialization_function(tensor, *args, **kwargs) 21 | 22 | return wrapper_func 23 | 24 | 25 | __all__ = [ # noqa: F822 26 | "constant_", 27 | "dirac_", 28 | "kaiming_normal_", 29 | "kaiming_uniform_", 30 | "normal_", 31 | "ones_", 32 | "orthogonal_", 33 | "sparse_", 34 | "trunc_normal_", 35 | "uniform_", 36 | "xavier_normal_", 37 | "xavier_uniform_", 38 | "zeros_", 39 | ] 40 | 41 | 42 | for func_name in __all__: 43 | globals()[func_name] = make_crypten_compatible(getattr(torch.nn.init, func_name)) 44 | -------------------------------------------------------------------------------- /crypten/common/tensor_types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import torch 9 | from crypten.cuda import CUDALongTensor 10 | 11 | 12 | # helper functions that determine if input is float, int, or base tensor: 13 | def _is_type_tensor(tensor, types): 14 | """Checks whether the elements of the input tensor are of a given type""" 15 | if is_tensor(tensor): 16 | if any(tensor.dtype == type_ for type_ in types): 17 | return True 18 | return False 19 | 20 | 21 | def is_tensor(tensor): 22 | """Checks if the input tensor is a Torch tensor or a CUDALongTensor""" 23 | return torch.is_tensor(tensor) or isinstance(tensor, CUDALongTensor) 24 | 25 | 26 | def is_float_tensor(tensor): 27 | """Checks if the input tensor is a Torch tensor of a float type.""" 28 | return _is_type_tensor(tensor, [torch.float16, torch.float32, torch.float64]) 29 | 30 | 31 | def is_int_tensor(tensor): 32 | """Checks if the input tensor is a Torch tensor of an int type.""" 33 | return _is_type_tensor( 34 | tensor, [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64] 35 | ) 36 | -------------------------------------------------------------------------------- /crypten/autograd_cryptensor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import crypten 9 | import torch 10 | 11 | from .gradients import AutogradContext as _AutogradContext 12 | 13 | 14 | class AutogradContext(_AutogradContext): 15 | """ 16 | DEPRECATED: Object used by AutogradFunctions for saving context information. 17 | """ 18 | 19 | def __init__(self): 20 | raise DeprecationWarning( 21 | "crypten.autograd_cryptensor.AutogradContext is deprecated. Please " 22 | "use crypten.gradients.AutogradContext instead." 
23 | ) 24 | super().__init__(self) 25 | 26 | 27 | def AutogradCrypTensor(tensor, requires_grad=True): 28 | """ 29 | DEPRECATED: CrypTensor with support for autograd, akin to the `Variable` 30 | originally in PyTorch. 31 | """ 32 | raise DeprecationWarning( 33 | "AutogradCrypTensor is deprecated. Please set the " 34 | "requires_grad attribute on the CrypTensor instead." 35 | ) 36 | if torch.is_tensor(tensor): 37 | tensor = crypten.cryptensor(tensor) 38 | tensor.requires_grad = requires_grad 39 | return tensor 40 | -------------------------------------------------------------------------------- /examples/unit-test/run_test_softmax.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import crypten 4 | from multiprocess_launcher import MultiProcessLauncher 5 | 6 | def main(): 7 | crypten.init() 8 | device = "cuda" 9 | runs = 10 10 | softmax_time, softmax_bytes, softmax_rounds = {}, {}, {} 11 | 12 | # softmax test 13 | for softmax_l in [32, 64, 128, 256]: 14 | softmax_in = crypten.cryptensor(torch.zeros([softmax_l]), device=device) 15 | crypten.reset_communication_stats() 16 | start_time = time.time() 17 | for _ in range(runs): 18 | softmax_in.softmax(-1) 19 | softmax_time[softmax_l] = time.time() - start_time 20 | stats = crypten.get_communication_stats() 21 | softmax_bytes[softmax_l] = stats["bytes"] 22 | softmax_rounds[softmax_l] = stats["rounds"] 23 | 24 | if crypten.comm.get().get_rank() == 0: 25 | for softmax_l in [32, 64, 128, 256]: 26 | print(f"l={softmax_l} " 27 | f"time: {softmax_time[softmax_l] / runs:.4f}s, " 28 | f"bytes: {softmax_bytes[softmax_l] / 1048576 / runs:.4f} MB, " 29 | f"rounds: {softmax_rounds[softmax_l] / runs:.0f}") 30 | 31 | if __name__ == "__main__": 32 | launcher = MultiProcessLauncher(2, main) 33 | launcher.start() 34 | launcher.join() 35 | launcher.terminate() -------------------------------------------------------------------------------- 
/crypten/mpc/provider/homomorphic_provider.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from .provider import TupleProvider 9 | 10 | 11 | class HomomorphicProvider(TupleProvider): 12 | NAME = "HE" 13 | 14 | def generate_additive_triple(self, size0, size1, op, *args, **kwargs): 15 | """Generate multiplicative triples of given sizes""" 16 | raise NotImplementedError("HomomorphicProvider not implemented") 17 | 18 | def square(self, size): 19 | """Generate square double of given size""" 20 | raise NotImplementedError("HomomorphicProvider not implemented") 21 | 22 | def generate_xor_triple(self, size0, size1): 23 | """Generate xor triples of given size""" 24 | raise NotImplementedError("HomomorphicProvider not implemented") 25 | 26 | def wrap_rng(self, size, num_parties): 27 | """Generate random shared tensor of given size and sharing of its wraps""" 28 | raise NotImplementedError("HomomorphicProvider not implemented") 29 | 30 | def B2A_rng(self, size): 31 | """Generate random bit tensor as arithmetic and binary shared tensors""" 32 | raise NotImplementedError("HomomorphicProvider not implemented") 33 | -------------------------------------------------------------------------------- /crypten/nn/distances.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | 9 | from .loss import _Loss 10 | from .module import Module 11 | 12 | 13 | class CosineSimilarity(Module): 14 | r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along dim. 15 | 16 | .. math :: 17 | \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}. 18 | 19 | Args: 20 | dim (int, optional): Dimension where cosine similarity is computed. Default: 1 21 | eps (float, optional): Not used in CrypTen 22 | Shape: 23 | - Input1: :math:`(\ast_1, D, \ast_2)` where D is at position `dim` 24 | - Input2: :math:`(\ast_1, D, \ast_2)`, same shape as the Input1 25 | - Output: :math:`(\ast_1, \ast_2)` 26 | Examples:: 27 | >>> input1 = crypten.randn(100, 128) 28 | >>> input2 = crypten.randn(100, 128) 29 | >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6) 30 | >>> output = cos(input1, input2) 31 | """ 32 | 33 | def __init__(self, dim=1, eps=1e-8): 34 | super(CosineSimilarity, self).__init__() 35 | self.dim = dim 36 | 37 | def forward(self, x1, x2): 38 | return x1.cosine_similarity(x2, self.dim) 39 | 40 | # Remove need to call module.encrypt() 41 | __getattribute__ = _Loss.__getattribute__ 42 | -------------------------------------------------------------------------------- /crypten/communicator/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from .communicator import Communicator 9 | from .distributed_communicator import DistributedCommunicator 10 | from .in_process_communicator import InProcessCommunicator 11 | 12 | 13 | __use_threads = False 14 | 15 | 16 | def get(): 17 | cls = InProcessCommunicator if __use_threads else DistributedCommunicator 18 | if not cls.is_initialized(): 19 | raise RuntimeError("Crypten not initialized. 
Please call crypten.init() first.") 20 | 21 | return cls.get() 22 | 23 | 24 | def _init(use_threads, rank=0, world_size=1, init_ttp=False): 25 | global __tls, __use_threads 26 | __use_threads = use_threads 27 | cls = InProcessCommunicator if __use_threads else DistributedCommunicator 28 | 29 | if cls.is_initialized(): 30 | return 31 | 32 | cls.initialize(rank, world_size, init_ttp=init_ttp) 33 | 34 | 35 | def uninit(): 36 | global __use_threads 37 | cls = InProcessCommunicator if __use_threads else DistributedCommunicator 38 | cls.shutdown() 39 | __use_threads = False 40 | 41 | 42 | def is_initialized(): 43 | cls = InProcessCommunicator if __use_threads else DistributedCommunicator 44 | return cls.is_initialized() 45 | 46 | 47 | # expose classes and functions in package: 48 | __all__ = ["Communicator", "DistributedCommunicator", "get", "uninit", "is_initialized"] 49 | -------------------------------------------------------------------------------- /crypten/common/functions/dropout.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import logging 9 | 10 | import crypten 11 | 12 | __all__ = ["dropout"] 13 | 14 | 15 | def dropout(self, p=0.5, training=True, inplace=False): 16 | r""" 17 | Randomly zeroes some of the elements of the input tensor with 18 | probability :attr:`p`. 19 | 20 | Args: 21 | p: probability of a channel to be zeroed. Default: 0.5 22 | training: apply dropout if is ``True``. Default: ``True`` 23 | inplace: If set to ``True``, will do this operation in-place. 
24 | Default: ``False`` 25 | """ 26 | if p == 0.0: 27 | return self 28 | elif p == 1.0: 29 | return self - self 30 | 31 | assert p > 0.0 and p < 1.0, "dropout probability has to be between 0 and 1" 32 | if training and inplace: 33 | logging.warning( 34 | "CrypTen dropout does not support inplace computation during training." 35 | ) 36 | 37 | if not training: 38 | if inplace: 39 | return self 40 | else: 41 | return self.clone() 42 | 43 | rand_tensor = crypten.rand(self.size(), device=self.device) 44 | dropout_tensor = rand_tensor > p 45 | if inplace: 46 | result_tensor = self.div_(1 - p) 47 | result_tensor = result_tensor.mul_(dropout_tensor) 48 | else: 49 | result_tensor = self.div(1 - p) 50 | result_tensor = result_tensor.mul(dropout_tensor) 51 | return result_tensor 52 | -------------------------------------------------------------------------------- /examples/unit-test/run_test_gelu.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import crypten 4 | from multiprocess_launcher import MultiProcessLauncher 5 | 6 | def main(): 7 | crypten.init() 8 | device = "cuda" 9 | runs = 10 10 | gelu_time, gelu_bytes, gelu_rounds = {}, {}, {} 11 | approximate = "none" 12 | 13 | x = torch.arange(-5, 5, 0.001) 14 | y_original = torch.nn.functional.gelu(x, approximate=approximate) 15 | y_actual = crypten.cryptensor(x).gelu(approximate=approximate).get_plain_text() 16 | max_err = (y_original - y_actual).abs().max() 17 | avg_err = (y_original - y_actual).abs().mean() 18 | 19 | for gelu_size in [(128, 3072), (128, 4096)]: 20 | gelu_in = crypten.cryptensor(torch.zeros(gelu_size), device=device) 21 | crypten.reset_communication_stats() 22 | start_time = time.time() 23 | 24 | for _ in range(runs): 25 | gelu_in.gelu(approximate=approximate) 26 | gelu_time[gelu_size[1]] = time.time() - start_time 27 | stats = crypten.get_communication_stats() 28 | gelu_bytes[gelu_size[1]] = stats["bytes"] 29 | 
gelu_rounds[gelu_size[1]] = stats["rounds"] 30 | 31 | if crypten.comm.get().get_rank() == 0: 32 | print(f"max error: {max_err:.4f}, avg error: {avg_err:.6f}") 33 | for gelu_size in [[128, 3072], [128, 4096]]: 34 | print(f"({gelu_size[0]}, {gelu_size[1]}) " 35 | f"time: {gelu_time[gelu_size[1]] / runs:.4f}s, " 36 | f"bytes: {gelu_bytes[gelu_size[1]] / 1048576 / runs:.0f} MB, " 37 | f"rounds: {gelu_rounds[gelu_size[1]] / runs:.0f}" 38 | ) 39 | 40 | if __name__ == "__main__": 41 | launcher = MultiProcessLauncher(2, main) 42 | launcher.start() 43 | launcher.join() 44 | launcher.terminate() -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | communicator: 2 | verbose: True 3 | debug: 4 | report_cost: False 5 | debug_mode: False 6 | validation_mode: False 7 | encoder: 8 | precision_bits: 16 9 | cost: 10 | estimate_cost: True 11 | estimate_mode: "comm" 12 | functions: 13 | max_method: "log_reduction" 14 | 15 | # exponential function 16 | exp_method: "limit" 17 | exp_iterations: 8 18 | 19 | # reciprocal configuration 20 | reciprocal_method: "NR" 21 | reciprocal_nr_iters: 10 22 | reciprocal_log_iters: 1 23 | reciprocal_all_pos: False 24 | reciprocal_initial: null 25 | 26 | # sqrt configuration 27 | sqrt_method: "NR" 28 | sqrt_nr_iters: 5 29 | sqrt_nr_initial: null 30 | 31 | # sigmoid / tanh configuration 32 | sigmoid_tanh_method: "reciprocal" 33 | sigmoid_tanh_terms: 32 34 | sigmoid_fs_m: 5 35 | sigmoid_fs_terms: 5 36 | tanh_fs_m: 4 37 | tanh_fs_terms: 5 38 | tanh_ode_iter_num: 1024 39 | 40 | # softmax configuration 41 | softmax_method: "ode" 42 | softmax_ode_clip: True 43 | softmax_ode_iter_num: 16 44 | softmax_ode_ub: 12 45 | softmax_ode_lb: -4 46 | 47 | # log configuration 48 | log_iterations: 2 49 | log_exp_iterations: 8 50 | log_order: 8 51 | 52 | # trigonometry configuration 53 | trig_iterations: 10 54 | 55 | # error function 
configuration 56 | erf_method: "fourier" 57 | erf_iterations: 8 58 | erf_fs_period: 16 59 | erf_fs_terms: 5 60 | 61 | # gelu function configuration 62 | gelu_method: "fourier" 63 | gelu_fs_period: 8 64 | gelu_fs_terms: 8 65 | 66 | # silu function configuration 67 | silu_method: "fourier" 68 | silu_fs_period: 16 69 | silu_fs_terms: 12 70 | mpc: 71 | active_security: False 72 | provider: "TFP" 73 | protocol: "beaver" 74 | nn: 75 | dpsmpc: 76 | protocol: "layer_estimation" 77 | skip_loss_forward: True 78 | cache_pred_size: True -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Code modified by SHAFT's team: Updated package information 4 | # 5 | # Copyright (c) Facebook, Inc. and its affiliates. 6 | # 7 | # This source code is licensed under the MIT license found in the 8 | # LICENSE file in the root directory of this source tree. 9 | 10 | import os.path 11 | import re 12 | import sys 13 | 14 | import setuptools 15 | 16 | 17 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "crypten")) 18 | 19 | # Read description and requirements. 20 | with open("README.md", encoding="utf8") as f: 21 | readme = f.read() 22 | with open("requirements.txt") as f: 23 | reqs = f.read() 24 | 25 | # get version string from module 26 | init_path = os.path.join(os.path.dirname(__file__), "crypten/__init__.py") 27 | with open(init_path, "r") as f: 28 | version = re.search(r"__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M).group(1) 29 | 30 | # Set key package information. 31 | DISTNAME = "shaft" 32 | DESCRIPTION = "Secure, Handy, Accurate, and Fast Transformer inference." 33 | LONG_DESCRIPTION = readme 34 | AUTHOR = "SHAFT's team" 35 | LICENSE = "MIT licensed, as found in the LICENSE file" 36 | REQUIREMENTS = (reqs.strip().split("\n"),) 37 | VERSION = version 38 | 39 | # Run installer. 
40 | if __name__ == "__main__": 41 | if sys.version_info < (3, 7): 42 | sys.exit("Sorry, Python >=3.7 is required for CrypTen.") 43 | 44 | setuptools.setup( 45 | name=DISTNAME, 46 | install_requires=REQUIREMENTS, 47 | packages=setuptools.find_packages(), 48 | dependency_links=[], 49 | version=VERSION, 50 | description=DESCRIPTION, 51 | long_description=LONG_DESCRIPTION, 52 | long_description_content_type="text/markdown", 53 | url="", 54 | author=AUTHOR, 55 | license=LICENSE, 56 | setup_requires=["pytest-runner"], 57 | tests_require=["pytest"], 58 | data_files=[("/configs", ["configs/default.yaml"])], 59 | ) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SHAFT: Secure, Handy, Accurate, and Fast Transformer Inference 2 | This repository implements secure, handy, accurate, and fast transformer inference based on [CrypTen](https://github.com/facebookresearch/CrypTen). 3 | 4 | ## Installing SHAFT 5 | The following commands run successfully on Ubuntu 22.04 with Python 3.10.12. 6 | ### 0. Set up Virtual Environment (Recommended) 7 | ```bash 8 | python3 -m venv ~/env/shaft 9 | source ~/env/shaft/bin/activate 10 | ``` 11 | ### 1. Install Dependencies 12 | ```bash 13 | pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118 14 | pip install wheel==0.40.0 15 | ``` 16 | ### 2. Install SHAFT 17 | ```bash 18 | git clone https://github.com/andeskyl/SHAFT 19 | cd SHAFT 20 | pip install . 21 | ``` 22 | 23 | ### 3. Install Transformers (for Hugging Face Integration) 24 | ```bash 25 | git clone -b 'v4.45.0' --depth 1 https://github.com/huggingface/transformers 26 | pip install ./transformers 27 | ``` 28 | 29 | ## Running Experiments 30 | We have a set of sub-directories in the `examples` directory for reproducible experimental results. 
Additional dependencies for the experiments are included in the `requirements.txt` file in each subdirectory under the folder. Please refer to the `README.md` file in the sub-directories for instructions on how to set up and run the experiments. 31 | 32 | 1. `unit-test` - Costs of private softmax and GELU protocols. 33 | 2. `text-classification` - Private inference costs of BERT-base and BERT-large. 34 | 3. `text-generation` - Private inference cost of GPT-2. 35 | 4. `image-classification` - Private inference cost of ViT-base. 36 | 37 | ## Citation 38 | You can cite our paper as follows: 39 | ```bibtex 40 | @inproceedings{ndss/KeiC25, 41 | author = {Andes Y. L. Kei and Sherman S. M. Chow}, 42 | title = {{SHAFT}: {Secure}, Handy, Accurate, and Fast Transformer Inference}, 43 | booktitle = {{NDSS}}, 44 | year = {2025} 45 | } 46 | ``` 47 | 48 | ## License 49 | SHAFT is MIT licensed, as found in the LICENSE file. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Andes Y. L. Kei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | ----------------------------- License for CrypTen ----------------------------- 24 | MIT License 25 | 26 | Copyright (c) Facebook, Inc. and its affiliates. 27 | 28 | Permission is hereby granted, free of charge, to any person obtaining a copy 29 | of this software and associated documentation files (the "Software"), to deal 30 | in the Software without restriction, including without limitation the rights 31 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 32 | copies of the Software, and to permit persons to whom the Software is 33 | furnished to do so, subject to the following conditions: 34 | 35 | The above copyright notice and this permission notice shall be included in all 36 | copies or substantial portions of the Software. 37 | 38 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 39 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 40 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 41 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 42 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 43 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 44 | SOFTWARE. -------------------------------------------------------------------------------- /examples/unit-test/multiprocess_launcher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import logging 9 | import multiprocessing 10 | import os 11 | import uuid 12 | 13 | import crypten 14 | 15 | class MultiProcessLauncher: 16 | 17 | # run_process_fn will be run in subprocesses. 18 | def __init__(self, world_size, run_process_fn, fn_args=None): 19 | env = os.environ.copy() 20 | env["WORLD_SIZE"] = str(world_size) 21 | multiprocessing.set_start_method("spawn") 22 | 23 | # Use random file so multiple jobs can be run simultaneously 24 | INIT_METHOD = "file:///tmp/crypten-rendezvous-{}".format(uuid.uuid1()) 25 | env["RENDEZVOUS"] = INIT_METHOD 26 | 27 | self.processes = [] 28 | for rank in range(world_size): 29 | process_name = "process " + str(rank) 30 | process = multiprocessing.Process( 31 | target=self.__class__._run_process, 32 | name=process_name, 33 | args=(rank, world_size, env, run_process_fn, fn_args), 34 | ) 35 | self.processes.append(process) 36 | 37 | if crypten.mpc.ttp_required(): 38 | ttp_process = multiprocessing.Process( 39 | target=self.__class__._run_process, 40 | name="TTP", 41 | args=( 42 | world_size, 43 | world_size, 44 | env, 45 | crypten.mpc.provider.TTPServer, 46 | None, 47 | ), 48 | ) 49 | self.processes.append(ttp_process) 50 | 51 | @classmethod 52 | def _run_process(cls, rank, world_size, env, run_process_fn, fn_args): 53 | for env_key, env_value in env.items(): 54 | os.environ[env_key] = env_value 55 | os.environ["RANK"] = str(rank) 56 | orig_logging_level = logging.getLogger().level 57 | logging.getLogger().setLevel(logging.INFO) 58 | crypten.init() 59 | logging.getLogger().setLevel(orig_logging_level) 60 | if fn_args is None: 61 | run_process_fn() 62 | else: 63 | run_process_fn(fn_args) 64 | 65 | def start(self): 66 | for process in self.processes: 67 | process.start() 68 | 69 | def join(self): 70 | for process in self.processes: 71 | process.join() 72 | assert ( 73 | process.exitcode == 0 74 | ), f"{process.name} has non-zero exit code {process.exitcode}" 75 | 76 | def terminate(self): 77 | for process in 
self.processes: 78 | process.terminate() -------------------------------------------------------------------------------- /examples/text-generation/multiprocess_launcher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import logging 9 | import multiprocessing 10 | import os 11 | import uuid 12 | 13 | import crypten 14 | 15 | class MultiProcessLauncher: 16 | 17 | # run_process_fn will be run in subprocesses. 18 | def __init__(self, world_size, run_process_fn, fn_args=None): 19 | env = os.environ.copy() 20 | env["WORLD_SIZE"] = str(world_size) 21 | multiprocessing.set_start_method("spawn") 22 | 23 | # Use random file so multiple jobs can be run simultaneously 24 | INIT_METHOD = "file:///tmp/crypten-rendezvous-{}".format(uuid.uuid1()) 25 | env["RENDEZVOUS"] = INIT_METHOD 26 | 27 | self.processes = [] 28 | for rank in range(world_size): 29 | process_name = "process " + str(rank) 30 | process = multiprocessing.Process( 31 | target=self.__class__._run_process, 32 | name=process_name, 33 | args=(rank, world_size, env, run_process_fn, fn_args), 34 | ) 35 | self.processes.append(process) 36 | 37 | if crypten.mpc.ttp_required(): 38 | ttp_process = multiprocessing.Process( 39 | target=self.__class__._run_process, 40 | name="TTP", 41 | args=( 42 | world_size, 43 | world_size, 44 | env, 45 | crypten.mpc.provider.TTPServer, 46 | None, 47 | ), 48 | ) 49 | self.processes.append(ttp_process) 50 | 51 | @classmethod 52 | def _run_process(cls, rank, world_size, env, run_process_fn, fn_args): 53 | for env_key, env_value in env.items(): 54 | os.environ[env_key] = env_value 55 | os.environ["RANK"] = str(rank) 56 | orig_logging_level = logging.getLogger().level 57 | logging.getLogger().setLevel(logging.INFO) 58 | 
crypten.init() 59 | logging.getLogger().setLevel(orig_logging_level) 60 | if fn_args is None: 61 | run_process_fn() 62 | else: 63 | run_process_fn(fn_args) 64 | 65 | def start(self): 66 | for process in self.processes: 67 | process.start() 68 | 69 | def join(self): 70 | for process in self.processes: 71 | process.join() 72 | assert ( 73 | process.exitcode == 0 74 | ), f"{process.name} has non-zero exit code {process.exitcode}" 75 | 76 | def terminate(self): 77 | for process in self.processes: 78 | process.terminate() -------------------------------------------------------------------------------- /examples/image-classification/multiprocess_launcher.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import logging 9 | import multiprocessing 10 | import os 11 | import uuid 12 | 13 | import crypten 14 | 15 | class MultiProcessLauncher: 16 | 17 | # run_process_fn will be run in subprocesses. 
18 | def __init__(self, world_size, run_process_fn, fn_args=None): 19 | env = os.environ.copy() 20 | env["WORLD_SIZE"] = str(world_size) 21 | multiprocessing.set_start_method("spawn") 22 | 23 | # Use random file so multiple jobs can be run simultaneously 24 | INIT_METHOD = "file:///tmp/crypten-rendezvous-{}".format(uuid.uuid1()) 25 | env["RENDEZVOUS"] = INIT_METHOD 26 | 27 | self.processes = [] 28 | for rank in range(world_size): 29 | process_name = "process " + str(rank) 30 | process = multiprocessing.Process( 31 | target=self.__class__._run_process, 32 | name=process_name, 33 | args=(rank, world_size, env, run_process_fn, fn_args), 34 | ) 35 | self.processes.append(process) 36 | 37 | if crypten.mpc.ttp_required(): 38 | ttp_process = multiprocessing.Process( 39 | target=self.__class__._run_process, 40 | name="TTP", 41 | args=( 42 | world_size, 43 | world_size, 44 | env, 45 | crypten.mpc.provider.TTPServer, 46 | None, 47 | ), 48 | ) 49 | self.processes.append(ttp_process) 50 | 51 | @classmethod 52 | def _run_process(cls, rank, world_size, env, run_process_fn, fn_args): 53 | for env_key, env_value in env.items(): 54 | os.environ[env_key] = env_value 55 | os.environ["RANK"] = str(rank) 56 | orig_logging_level = logging.getLogger().level 57 | logging.getLogger().setLevel(logging.INFO) 58 | crypten.init() 59 | logging.getLogger().setLevel(orig_logging_level) 60 | if fn_args is None: 61 | run_process_fn() 62 | else: 63 | run_process_fn(fn_args) 64 | 65 | def start(self): 66 | for process in self.processes: 67 | process.start() 68 | 69 | def join(self): 70 | for process in self.processes: 71 | process.join() 72 | assert ( 73 | process.exitcode == 0 74 | ), f"{process.name} has non-zero exit code {process.exitcode}" 75 | 76 | def terminate(self): 77 | for process in self.processes: 78 | process.terminate() -------------------------------------------------------------------------------- /examples/text-classification/multiprocess_launcher.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import logging 9 | import multiprocessing 10 | import os 11 | import uuid 12 | 13 | import crypten 14 | 15 | class MultiProcessLauncher: 16 | 17 | # run_process_fn will be run in subprocesses. 18 | def __init__(self, world_size, run_process_fn, fn_args=None): 19 | env = os.environ.copy() 20 | env["WORLD_SIZE"] = str(world_size) 21 | multiprocessing.set_start_method("spawn") 22 | 23 | # Use random file so multiple jobs can be run simultaneously 24 | INIT_METHOD = "file:///tmp/crypten-rendezvous-{}".format(uuid.uuid1()) 25 | env["RENDEZVOUS"] = INIT_METHOD 26 | 27 | self.processes = [] 28 | for rank in range(world_size): 29 | process_name = "process " + str(rank) 30 | process = multiprocessing.Process( 31 | target=self.__class__._run_process, 32 | name=process_name, 33 | args=(rank, world_size, env, run_process_fn, fn_args), 34 | ) 35 | self.processes.append(process) 36 | 37 | if crypten.mpc.ttp_required(): 38 | ttp_process = multiprocessing.Process( 39 | target=self.__class__._run_process, 40 | name="TTP", 41 | args=( 42 | world_size, 43 | world_size, 44 | env, 45 | crypten.mpc.provider.TTPServer, 46 | None, 47 | ), 48 | ) 49 | self.processes.append(ttp_process) 50 | 51 | @classmethod 52 | def _run_process(cls, rank, world_size, env, run_process_fn, fn_args): 53 | for env_key, env_value in env.items(): 54 | os.environ[env_key] = env_value 55 | os.environ["RANK"] = str(rank) 56 | orig_logging_level = logging.getLogger().level 57 | logging.getLogger().setLevel(logging.INFO) 58 | crypten.init() 59 | logging.getLogger().setLevel(orig_logging_level) 60 | if fn_args is None: 61 | run_process_fn() 62 | else: 63 | run_process_fn(fn_args) 64 | 65 | def start(self): 66 
| for process in self.processes: 67 | process.start() 68 | 69 | def join(self): 70 | for process in self.processes: 71 | process.join() 72 | assert ( 73 | process.exitcode == 0 74 | ), f"{process.name} has non-zero exit code {process.exitcode}" 75 | 76 | def terminate(self): 77 | for process in self.processes: 78 | process.terminate() -------------------------------------------------------------------------------- /crypten/nn/tensorboard.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import crypten.nn as nn 9 | from tensorboard.compat.proto.attr_value_pb2 import AttrValue 10 | from tensorboard.compat.proto.graph_pb2 import GraphDef 11 | from tensorboard.compat.proto.node_def_pb2 import NodeDef 12 | from tensorboard.compat.proto.versions_pb2 import VersionDef 13 | from torch.utils.tensorboard import SummaryWriter as _SummaryWriter 14 | 15 | 16 | def graph(model): 17 | """Converts a crypten.nn graph for consumption by TensorBoard.""" 18 | 19 | # convert individual module to graph: 20 | assert isinstance(model, nn.Module), "model must be crypten.nn.Module" 21 | if not isinstance(model, nn.Graph): 22 | graph = nn.Graph("input", "output") 23 | graph.add_module("output", model, ["input"]) 24 | model = graph 25 | 26 | # create mapping to more interpretable node naming: 27 | mapping = {input_name: input_name for input_name in model.input_names} 28 | modules = {name: module for name, module in model.named_modules()} 29 | for name, module in modules.items(): 30 | op = str(type(module))[26:-2] 31 | mapping[name] = "%s_%s" % (op, name) 32 | 33 | # create input variables: 34 | nodes = [ 35 | NodeDef( 36 | name=mapping[input_name].encode(encoding="utf_8"), 37 | op="Variable", 38 | input=[], 39 | ) 40 | for 
input_name in model.input_names 41 | ] 42 | 43 | # loop all graph connections: 44 | for output_name, input_names in model._graph.items(): 45 | 46 | # get parameters and type of module: 47 | module = modules[output_name] 48 | op = str(type(module)) 49 | input_names = [mapping[name] for name in input_names] 50 | parameters = [ 51 | "%s: %s" % (name, parameter.size()) 52 | for name, parameter in module.named_parameters() 53 | ] 54 | parameter_string = "; ".join(parameters).encode(encoding="utf_8") 55 | 56 | # add to graph: 57 | nodes.append( 58 | NodeDef( 59 | name=mapping[output_name].encode(encoding="utf_8"), 60 | op=op, 61 | input=input_names, 62 | attr={"attr": AttrValue(s=parameter_string)}, 63 | ) 64 | ) 65 | 66 | # return graph definition: 67 | return GraphDef(node=nodes, versions=VersionDef(producer=22)) 68 | 69 | 70 | class SummaryWriter(_SummaryWriter): 71 | """ 72 | Adapts the PyTorch SummaryWriter to output crypten graphs. 73 | """ 74 | 75 | def add_graph(self, model, input_to_model=None, verbose=False): 76 | self._get_file_writer().add_onnx_graph(graph(model)) 77 | -------------------------------------------------------------------------------- /crypten/common/rng.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | import crypten 8 | import torch 9 | from crypten.cuda import CUDALongTensor 10 | 11 | 12 | def generate_random_ring_element(size, ring_size=(2**64), generator=None, **kwargs): 13 | """Helper function to generate a random number from a signed ring""" 14 | if generator is None: 15 | device = kwargs.get("device", torch.device("cpu")) 16 | device = torch.device("cpu") if device is None else device 17 | device = torch.device(device) if isinstance(device, str) else device 18 | generator = crypten.generators["local"][device] 19 | # TODO (brianknott): Check whether this RNG contains the full range we want. 20 | rand_element = torch.randint( 21 | -(ring_size // 2), 22 | (ring_size - 1) // 2, 23 | size, 24 | generator=generator, 25 | dtype=torch.long, 26 | **kwargs, 27 | ) 28 | if rand_element.is_cuda: 29 | return CUDALongTensor(rand_element) 30 | return rand_element 31 | 32 | 33 | def generate_unsigned_random_ring_element(size, ring_size=(2**64), generator=None, **kwargs): 34 | """Helper function to generate a random number from a signed ring""" 35 | if generator is None: 36 | device = kwargs.get("device", torch.device("cpu")) 37 | device = torch.device("cpu") if device is None else device 38 | device = torch.device(device) if isinstance(device, str) else device 39 | generator = crypten.generators["local"][device] 40 | # TODO (brianknott): Check whether this RNG contains the full range we want. 
41 | rand_element = torch.randint( 42 | 0, 43 | ring_size, 44 | size, 45 | generator=generator, 46 | dtype=torch.long, 47 | **kwargs, 48 | ) 49 | if rand_element.is_cuda: 50 | return CUDALongTensor(rand_element) 51 | return rand_element 52 | 53 | 54 | def generate_kbit_random_tensor(size, bitlength=None, generator=None, **kwargs): 55 | """Helper function to generate a random k-bit number""" 56 | if bitlength is None: 57 | bitlength = torch.iinfo(torch.long).bits 58 | if bitlength == 64: 59 | return generate_random_ring_element(size, generator=generator, **kwargs) 60 | if generator is None: 61 | device = kwargs.get("device", torch.device("cpu")) 62 | device = torch.device("cpu") if device is None else device 63 | device = torch.device(device) if isinstance(device, str) else device 64 | generator = crypten.generators["local"][device] 65 | rand_tensor = torch.randint( 66 | 0, 2**bitlength, size, generator=generator, dtype=torch.long, **kwargs 67 | ) 68 | if rand_tensor.is_cuda: 69 | return CUDALongTensor(rand_tensor) 70 | return rand_tensor 71 | -------------------------------------------------------------------------------- /crypten/mpc/primitives/converters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import crypten.communicator as comm 9 | import torch 10 | from crypten.encoder import FixedPointEncoder 11 | 12 | from ..ptype import ptype as Ptype 13 | from . 
import beaver 14 | from .arithmetic import ArithmeticSharedTensor 15 | from .binary import BinarySharedTensor 16 | 17 | 18 | def _A2B(arithmetic_tensor): 19 | 20 | # first try memory-inefficient implementation that takes O(log P) rounds: 21 | try: 22 | binary_tensor = BinarySharedTensor.stack( 23 | [ 24 | BinarySharedTensor(arithmetic_tensor.share, src=i) 25 | for i in range(comm.get().get_world_size()) 26 | ] 27 | ) 28 | binary_tensor = binary_tensor.sum(dim=0) 29 | 30 | # if we OOM, try memory-efficient implementation that uses O(P) rounds: 31 | except RuntimeError: 32 | binary_tensor = None 33 | for i in range(comm.get().get_world_size()): 34 | binary_share = BinarySharedTensor(arithmetic_tensor.share, src=i) 35 | binary_tensor = binary_share if i == 0 else binary_tensor + binary_share 36 | 37 | # return the result: 38 | binary_tensor.encoder = arithmetic_tensor.encoder 39 | return binary_tensor 40 | 41 | 42 | def _B2A(binary_tensor, precision=None, bits=None): 43 | if bits is None: 44 | bits = torch.iinfo(torch.long).bits 45 | 46 | if bits == 1: 47 | binary_bit = binary_tensor & 1 48 | arithmetic_tensor = beaver.B2A_single_bit(binary_bit) 49 | else: 50 | binary_bits = BinarySharedTensor.stack( 51 | [binary_tensor >> i for i in range(bits)] 52 | ) 53 | binary_bits = binary_bits & 1 54 | arithmetic_bits = beaver.B2A_single_bit(binary_bits) 55 | 56 | multiplier = torch.cat( 57 | [ 58 | torch.tensor([1], dtype=torch.long, device=binary_tensor.device) << i 59 | for i in range(bits) 60 | ] 61 | ) 62 | while multiplier.dim() < arithmetic_bits.dim(): 63 | multiplier = multiplier.unsqueeze(1) 64 | 65 | arithmetic_tensor = arithmetic_bits.mul_(multiplier).sum(0) 66 | 67 | arithmetic_tensor.encoder = FixedPointEncoder(precision_bits=precision) 68 | scale = arithmetic_tensor.encoder._scale // binary_tensor.encoder._scale 69 | arithmetic_tensor *= scale 70 | return arithmetic_tensor 71 | 72 | 73 | def convert(tensor, ptype, **kwargs): 74 | tensor_name = ptype.to_tensor() 75 
| if isinstance(tensor, tensor_name): 76 | return tensor 77 | if isinstance(tensor, ArithmeticSharedTensor) and ptype == Ptype.binary: 78 | return _A2B(tensor) 79 | elif isinstance(tensor, BinarySharedTensor) and ptype == Ptype.arithmetic: 80 | return _B2A(tensor, **kwargs) 81 | else: 82 | raise TypeError("Cannot convert %s to %s" % (type(tensor), ptype.__name__)) 83 | -------------------------------------------------------------------------------- /crypten/nn/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | from .init import * # noqa: F403 10 | from .distances import CosineSimilarity 11 | from .loss import _Loss, BCELoss, BCEWithLogitsLoss, CrossEntropyLoss, L1Loss, MSELoss 12 | from .module import ( 13 | AdaptiveAvgPool2d, 14 | AdaptiveMaxPool2d, 15 | Add, 16 | AvgPool2d, 17 | BatchNorm1d, 18 | BatchNorm2d, 19 | BatchNorm3d, 20 | Cast, 21 | Concat, 22 | Constant, 23 | ConstantOfShape, 24 | ConstantPad1d, 25 | ConstantPad2d, 26 | ConstantPad3d, 27 | Conv, 28 | Conv1d, 29 | Conv2d, 30 | Div, 31 | Dropout, 32 | Dropout2d, 33 | Dropout3d, 34 | DropoutNd, 35 | Equal, 36 | Erf, 37 | Exp, 38 | Expand, 39 | Flatten, 40 | Gather, 41 | Gemm, 42 | GlobalAveragePool, 43 | Graph, 44 | GroupNorm, 45 | Hardtanh, 46 | Linear, 47 | LogSoftmax, 48 | MatMul, 49 | MaxPool2d, 50 | Mean, 51 | Module, 52 | ModuleDict, 53 | ModuleList, 54 | Mul, 55 | Parameter, 56 | Pow, 57 | Range, 58 | ReLU, 59 | ReLU6, 60 | Reshape, 61 | Sequential, 62 | Shape, 63 | Sigmoid, 64 | Slice, 65 | Softmax, 66 | Sqrt, 67 | Squeeze, 68 | Sub, 69 | Sum, 70 | Transpose, 71 | Unsqueeze, 72 | Where, 73 | ) 74 | from .onnx_converter import from_onnx, from_pytorch, from_tensorflow, TF_AND_TF2ONNX 75 | 76 | 77 | # expose contents of 
package 78 | __all__ = [ # noqa: F405 79 | "_Loss", 80 | "AdaptiveAvgPool2d", 81 | "AdaptiveMaxPool2d", 82 | "Add", 83 | "AvgPool2d", 84 | "BatchNorm1d", 85 | "BatchNorm2d", 86 | "BatchNorm3d", 87 | "BCELoss", 88 | "BCEWithLogitsLoss", 89 | "Cast", 90 | "Concat", 91 | "Constant", 92 | "ConstantOfShape", 93 | "ConstantPad1d", 94 | "ConstantPad2d", 95 | "ConstantPad3d", 96 | "Conv", 97 | "Conv1d", 98 | "Conv2d", 99 | "CosineSimilarity", 100 | "CrossEntropyLoss", 101 | "Div", 102 | "Dropout", 103 | "Dropout2d", 104 | "Dropout3d", 105 | "DropoutNd", 106 | "Erf", 107 | "Equal", 108 | "Exp", 109 | "Expand", 110 | "Flatten", 111 | "from_pytorch", 112 | "from_onnx", 113 | "from_tensorflow", 114 | "Gather", 115 | "Gemm", 116 | "GlobalAveragePool", 117 | "Graph", 118 | "GroupNorm", 119 | "Hardtanh", 120 | "L1Loss", 121 | "Linear", 122 | "LogSoftmax", 123 | "MatMul", 124 | "MaxPool2d", 125 | "Mean", 126 | "Module", 127 | "ModuleDict", 128 | "ModuleList", 129 | "MSELoss", 130 | "Mul", 131 | "Parameter", 132 | "Pow", 133 | "Range", 134 | "ReLU", 135 | "ReLU6", 136 | "Reshape", 137 | "Sequential", 138 | "Shape", 139 | "Sigmoid", 140 | "Slice", 141 | "Softmax", 142 | "Sqrt", 143 | "Squeeze", 144 | "Sub", 145 | "Sum", 146 | "TF_AND_TF2ONNX", 147 | "Transpose", 148 | "Unsqueeze", 149 | "Where", 150 | "init", 151 | ] 152 | -------------------------------------------------------------------------------- /crypten/config/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import os 9 | from contextlib import contextmanager 10 | 11 | import yaml 12 | from omegaconf import OmegaConf 13 | 14 | 15 | class CrypTenConfig(object): 16 | """ 17 | Configuration object used to store configurable parameters for CrypTen. 
18 | 19 | This object acts as a nested dictionary, but can be queried using dot-notation( 20 | e.g. querying or setting `cfg.a.b` is equivalent to `cfg['a']['b']`). 21 | 22 | Users can load a CrypTen config from a file using `cfg.load_config(filepath)`. 23 | 24 | Users can temporarily override a config parameter using the contextmanager temp_override: 25 | 26 | .. code-block:: python 27 | 28 | cfg.a.b = outer # sets cfg["a"]["b"] to outer value 29 | 30 | with cfg.temp_override("a.b", inner): 31 | print(cfg.a.b) # prints inner value 32 | 33 | print(cfg.a.b) # prints outer value 34 | """ 35 | 36 | __DEFAULT_CONFIG_PATH = os.path.normpath( 37 | os.path.join(__file__, "../../../configs/default.yaml") 38 | ) 39 | 40 | def __init__(self, config_file=None): 41 | self.load_config(config_file) 42 | 43 | def load_config(self, config_file): 44 | """Loads config from a yaml file""" 45 | if config_file is None: 46 | config_file = CrypTenConfig.__DEFAULT_CONFIG_PATH 47 | 48 | # Use yaml to open stream for safe load 49 | with open(config_file) as stream: 50 | config_dict = yaml.safe_load(stream) 51 | self.config = OmegaConf.create(config_dict) 52 | 53 | def set_config(self, config): 54 | if isinstance(config, CrypTenConfig): 55 | self.config = config.config 56 | else: 57 | self.config = config 58 | 59 | def __getattribute__(self, name): 60 | try: 61 | return object.__getattribute__(self, name) 62 | except AttributeError: 63 | keys = name.split(".") 64 | result = getattr(self.config, keys[0]) 65 | for key in keys[1:]: 66 | result = getattr(result, key) 67 | return result 68 | 69 | def __getitem__(self, name): 70 | return self.__getattribute__(name) 71 | 72 | def __setattr__(self, name, value): 73 | if name == "config": 74 | object.__setattr__(self, name, value) 75 | try: 76 | # Can only set attribute if already exists 77 | object.__getattribute__(self, name) 78 | object.__setattr__(self, name, value) 79 | except AttributeError: 80 | dotlist = [f"{name}={value}"] 81 | update = 
OmegaConf.from_dotlist(dotlist) 82 | self.config = OmegaConf.merge(self.config, update) 83 | 84 | def __setitem__(self, name, value): 85 | self.__setattr__(name, value) 86 | 87 | @contextmanager 88 | def temp_override(self, override_dict): 89 | old_config = self.config 90 | try: 91 | dotlist = [f"{k}={v}" for k, v in override_dict.items()] 92 | update = OmegaConf.from_dotlist(dotlist) 93 | self.config = OmegaConf.merge(self.config, update) 94 | yield 95 | finally: 96 | self.config = old_config 97 | -------------------------------------------------------------------------------- /crypten/common/functions/logic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import crypten 9 | from crypten.common.tensor_types import is_tensor 10 | 11 | 12 | __all__ = [ 13 | "__eq__", 14 | "__ge__", 15 | "__gt__", 16 | "__le__", 17 | "__lt__", 18 | "__ne__", 19 | "abs", 20 | "eq", 21 | "ge", 22 | "gt", 23 | "hardtanh", 24 | "le", 25 | "lt", 26 | "ne", 27 | "relu", 28 | "sign", 29 | "where", 30 | ] 31 | 32 | 33 | def ge(self, y): 34 | """Returns self >= y""" 35 | return 1 - self.lt(y) 36 | 37 | 38 | def gt(self, y): 39 | """Returns self > y""" 40 | return (-self + y)._ltz() 41 | 42 | 43 | def le(self, y): 44 | """Returns self <= y""" 45 | return 1 - self.gt(y) 46 | 47 | 48 | def lt(self, y): 49 | """Returns self < y""" 50 | return (self - y)._ltz() 51 | 52 | 53 | def eq(self, y): 54 | """Returns self == y""" 55 | return 1 - self.ne(y) 56 | 57 | 58 | def ne(self, y): 59 | """Returns self != y""" 60 | difference = self - y 61 | difference = type(difference).stack([difference, -difference]) 62 | return difference._ltz().sum(0) 63 | 64 | 65 | __eq__ = eq 66 | __ge__ = ge 67 | __gt__ = gt 68 | __le__ = le 69 | __lt__ = lt 70 | 
__ne__ = ne 71 | 72 | 73 | def sign(self): 74 | """Computes the sign value of a tensor (0 is considered positive)""" 75 | return 1 - 2 * self._ltz() 76 | 77 | 78 | def abs(self): 79 | """Computes the absolute value of a tensor""" 80 | return self * self.sign() 81 | 82 | 83 | def relu(self): 84 | """Compute a Rectified Linear function on the input tensor.""" 85 | return self * self.ge(0) 86 | 87 | 88 | def hardtanh(self, min_value=-1, max_value=1): 89 | r"""Applies the HardTanh function element-wise 90 | 91 | HardTanh is defined as: 92 | 93 | .. math:: 94 | \text{HardTanh}(x) = \begin{cases} 95 | 1 & \text{ if } x > 1 \\ 96 | -1 & \text{ if } x < -1 \\ 97 | x & \text{ otherwise } \\ 98 | \end{cases} 99 | 100 | The range of the linear region :math:`[-1, 1]` can be adjusted using 101 | :attr:`min_val` and :attr:`max_val`. 102 | 103 | Args: 104 | min_val: minimum value of the linear region range. Default: -1 105 | max_val: maximum value of the linear region range. Default: 1 106 | """ 107 | intermediate = crypten.stack([self - min_value, self - max_value]).relu() 108 | intermediate = intermediate[0].sub(intermediate[1]) 109 | return intermediate.add_(min_value) 110 | 111 | 112 | def where(self, condition, y): 113 | """Selects elements from self or y based on condition 114 | 115 | Args: 116 | condition (torch.bool or MPCTensor): when True yield self, 117 | otherwise yield y 118 | y (torch.tensor or MPCTensor): values selected at indices 119 | where condition is False. 
120 | 121 | Returns: MPCTensor or torch.tensor 122 | """ 123 | if is_tensor(condition): 124 | condition = condition.float() 125 | y_masked = y * (1 - condition) 126 | else: 127 | # encrypted tensor must be first operand 128 | y_masked = (1 - condition) * y 129 | 130 | return self * condition + y_masked 131 | -------------------------------------------------------------------------------- /crypten/common/functions/sampling.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import crypten 9 | import torch 10 | 11 | 12 | __all__ = [ 13 | "bernoulli", 14 | "randn", 15 | "weighted_index", 16 | "weighted_sample", 17 | ] 18 | 19 | 20 | def randn(*sizes, device=None): 21 | """ 22 | Returns a tensor with normally distributed elements. Samples are 23 | generated using the Box-Muller transform with optimizations for 24 | numerical precision and MPC efficiency. 25 | """ 26 | u = crypten.rand(*sizes, device=device).flatten() 27 | odd_numel = u.numel() % 2 == 1 28 | if odd_numel: 29 | u = crypten.cat([u, crypten.rand((1,), device=device)]) 30 | 31 | n = u.numel() // 2 32 | u1 = u[:n] 33 | u2 = u[n:] 34 | 35 | # Radius = sqrt(- 2 * log(u1)) 36 | r2 = -2 * u1.log(input_in_01=True) 37 | r = r2.sqrt() 38 | 39 | # Theta = cos(2 * pi * u2) or sin(2 * pi * u2) 40 | cos, sin = u2.sub(0.5).mul(6.28318531).cossin() 41 | 42 | # Generating 2 independent normal random variables using 43 | x = r.mul(sin) 44 | y = r.mul(cos) 45 | z = crypten.cat([x, y]) 46 | 47 | if odd_numel: 48 | z = z[1:] 49 | 50 | return z.view(*sizes) 51 | 52 | 53 | def bernoulli(self): 54 | """Returns a tensor with elements in {0, 1}. 
The i-th element of the 55 | output will be 1 with probability according to the i-th value of the 56 | input tensor.""" 57 | return self > crypten.rand(self.size(), device=self.device) 58 | 59 | 60 | def weighted_index(self, dim=None): 61 | """ 62 | Returns a tensor with entries that are one-hot along dimension `dim`. 63 | These one-hot entries are set at random with weights given by the input 64 | `self`. 65 | 66 | Examples:: 67 | 68 | >>> encrypted_tensor = MPCTensor(torch.tensor([1., 6.])) 69 | >>> index = encrypted_tensor.weighted_index().get_plain_text() 70 | # With 1 / 7 probability 71 | torch.tensor([1., 0.]) 72 | 73 | # With 6 / 7 probability 74 | torch.tensor([0., 1.]) 75 | """ 76 | if dim is None: 77 | return self.flatten().weighted_index(dim=0).view(self.size()) 78 | 79 | x = self.cumsum(dim) 80 | max_weight = x.index_select(dim, torch.tensor(x.size(dim) - 1, device=self.device)) 81 | r = crypten.rand(max_weight.size(), device=self.device) * max_weight 82 | 83 | gt = x.gt(r) 84 | shifted = gt.roll(1, dims=dim) 85 | shifted.data.index_fill_(dim, torch.tensor(0, device=self.device), 0) 86 | 87 | return gt - shifted 88 | 89 | 90 | def weighted_sample(self, dim=None): 91 | """ 92 | Samples a single value across dimension `dim` with weights corresponding 93 | to the values in `self` 94 | 95 | Returns the sample and the one-hot index of the sample. 
96 | 97 | Examples:: 98 | 99 | >>> encrypted_tensor = MPCTensor(torch.tensor([1., 6.])) 100 | >>> index = encrypted_tensor.weighted_sample().get_plain_text() 101 | # With 1 / 7 probability 102 | (torch.tensor([1., 0.]), torch.tensor([1., 0.])) 103 | 104 | # With 6 / 7 probability 105 | (torch.tensor([0., 6.]), torch.tensor([0., 1.])) 106 | """ 107 | indices = self.weighted_index(dim) 108 | sample = self.mul(indices).sum(dim) 109 | return sample, indices 110 | -------------------------------------------------------------------------------- /examples/ttp-test/run_test_bert.py: -------------------------------------------------------------------------------- 1 | import crypten 2 | import torch 3 | from transformers import BertForSequenceClassification, BertConfig 4 | 5 | crypten.cfg.debug.report_cost = True 6 | 7 | 8 | class BertTinyConfig(BertConfig): 9 | def __init__(self): 10 | super().__init__( 11 | vocab_size=30522, 12 | hidden_size=128, 13 | num_hidden_layers=2, 14 | num_attention_heads=2, 15 | intermediate_size=512, 16 | max_position_embeddings=512, 17 | type_vocab_size=2, 18 | layer_norm_eps=1e-12, 19 | pad_token_id=0, 20 | num_labels=2 21 | ) 22 | 23 | 24 | class BertBaseConfig(BertConfig): 25 | def __init__(self): 26 | super().__init__() 27 | 28 | 29 | class BertLargeConfig(BertConfig): 30 | def __init__(self): 31 | super().__init__( 32 | vocab_size=30522, 33 | hidden_size=1024, 34 | num_hidden_layers=24, 35 | num_attention_heads=16, 36 | intermediate_size=4096, 37 | max_position_embeddings=512, 38 | type_vocab_size=2, 39 | layer_norm_eps=1e-12, 40 | pad_token_id=0, 41 | num_labels=2 42 | ) 43 | 44 | 45 | @crypten.mpc.context.run_multiprocess(2) 46 | def test_bert(config: BertConfig = BertTinyConfig(), input_shape: tuple = (1, 128), device: str = "cuda"): 47 | crypten.init() 48 | 49 | print(config.__class__.__name__) 50 | model = BertForSequenceClassification(config) 51 | model = model.to(device) 52 | model = model.eval() 53 | # BertTiny: 54 | # TTP: comm 
byte: 0.48 GB, round: 1180 55 | # TFP: comm byte: 0.37 GB, round: 296 56 | # BertBase: 57 | # TTP: comm byte: 13.41 GB, round: 5980 58 | # TFP: comm byte: 10.48 GB, round: 1496 59 | # BertLarge: 60 | # TFP: comm byte: 28.49 GB, round: 2936 61 | # TTP: comm byte: 36.26 GB, round: 11744 62 | ct_model = crypten.nn.from_pytorch(model, ( 63 | torch.zeros(input_shape, dtype=torch.int64).to(device), 64 | torch.zeros(input_shape, dtype=torch.int64).to(device), 65 | torch.zeros(input_shape, dtype=torch.int64).to(device), 66 | )).encrypt() 67 | with torch.no_grad(): 68 | print(f"[rank {crypten.communicator.get().rank}]", "input shape: ", input_shape) 69 | ct_input_ids = crypten.cryptensor(torch.zeros(input_shape, dtype=torch.int64).to(device)) 70 | ct_attention_mask = crypten.cryptensor(torch.zeros(input_shape, dtype=torch.int64).to(device)) 71 | ct_token_type_ids = crypten.cryptensor(torch.zeros(input_shape, dtype=torch.int64).to(device)) 72 | 73 | get_v = ct_model(ct_input_ids, ct_attention_mask, ct_token_type_ids) 74 | 75 | get_v = get_v.get_plain_text() 76 | need_v = model.forward(torch.zeros(input_shape, dtype=torch.int64).to(device), 77 | torch.zeros(input_shape, dtype=torch.int64).to(device), 78 | torch.zeros(input_shape, dtype=torch.int64).to(device), 79 | ) 80 | print("ct model results", get_v) 81 | print("pt model results", need_v) 82 | 83 | 84 | if __name__ == "__main__": 85 | torch.manual_seed(101) 86 | torch.random.manual_seed(101) 87 | torch.cuda.manual_seed(101) 88 | torch.cuda.random.manual_seed(101) 89 | crypten.cfg.mpc.provider = "TTP" 90 | test_bert(BertTinyConfig(), (1, 128), "cuda") 91 | test_bert(BertBaseConfig(), (1, 128), "cuda") 92 | test_bert(BertLargeConfig(), (1, 128), "cuda") 93 | -------------------------------------------------------------------------------- /crypten/encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 
4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import math 9 | 10 | import numpy as np 11 | import torch 12 | 13 | from .common.tensor_types import is_float_tensor, is_int_tensor 14 | from .config import cfg 15 | from .cryptensor import CrypTensor 16 | 17 | 18 | def nearest_integer_division(tensor, integer): 19 | """Performs division of integer tensor, rounding to nearest integer.""" 20 | assert integer > 0, "only supports positive divisors" 21 | assert is_int_tensor(tensor), "unsupported type: %s" % type(tensor) 22 | 23 | lez = (tensor < 0).long() 24 | pos_remainder = (1 - lez) * tensor % integer 25 | neg_remainder = lez * ((integer - tensor) % integer) 26 | remainder = pos_remainder + neg_remainder 27 | quotient = tensor.div(integer, rounding_mode="trunc") 28 | correction = (2 * remainder > integer).long() 29 | return quotient + tensor.sign() * correction 30 | 31 | 32 | class FixedPointEncoder: 33 | """Encoder that encodes long or float tensors into scaled integer tensors.""" 34 | 35 | def __init__(self, precision_bits=None): 36 | if precision_bits is None: 37 | precision_bits = cfg.encoder.precision_bits 38 | self._precision_bits = precision_bits 39 | self._scale = int(2**precision_bits) 40 | 41 | def encode(self, x, device=None): 42 | """Helper function to wrap data if needed""" 43 | if isinstance(x, CrypTensor): 44 | return x 45 | elif isinstance(x, int) or isinstance(x, float): 46 | # Squeeze in order to get a 0-dim tensor with value `x` 47 | return torch.tensor( 48 | [self._scale * x], dtype=torch.long, device=device 49 | ).squeeze() 50 | elif isinstance(x, list): 51 | return ( 52 | torch.tensor(x, dtype=torch.float, device=device) 53 | .mul_(self._scale) 54 | .long() 55 | ) 56 | elif is_float_tensor(x): 57 | return (self._scale * x).long() 58 | # For integer types cast to long prior to scaling to avoid overflow. 
59 | elif is_int_tensor(x): 60 | return self._scale * x.long() 61 | elif isinstance(x, np.ndarray): 62 | return self._scale * torch.from_numpy(x).long().to(device) 63 | elif torch.is_tensor(x): 64 | raise TypeError("Cannot encode input with dtype %s" % x.dtype) 65 | else: 66 | raise TypeError("Unknown tensor type: %s." % type(x)) 67 | 68 | def decode(self, tensor): 69 | """Helper function that decodes from scaled tensor""" 70 | if tensor is None: 71 | return None 72 | assert is_int_tensor(tensor), "input must be a LongTensor" 73 | if self._scale > 1: 74 | correction = (tensor < 0).long() 75 | dividend = tensor.div(self._scale - correction, rounding_mode="floor") 76 | remainder = tensor % self._scale 77 | remainder += (remainder == 0).long() * self._scale * correction 78 | 79 | tensor = dividend.float() + remainder.float() / self._scale 80 | else: 81 | tensor = nearest_integer_division(tensor, self._scale) 82 | 83 | return tensor.data 84 | 85 | def __setattr__(self, name, value): 86 | if name == "_precision_bits": 87 | dict.__setattr__(self, "_scale", int(2**value)) 88 | elif name == "_scale": 89 | dict.__setattr__(self, "_precision_bits", int(math.log2(value))) 90 | dict.__setattr__(self, name, value) 91 | 92 | @property 93 | def scale(self): 94 | return self._scale 95 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller 
builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /crypten/mpc/provider/tfp_provider.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Modified by Andes Y. L. Kei: Implemented generate_trig_triple, generate_one_hot_pair 4 | # 5 | # Copyright (c) Facebook, Inc. and its affiliates. 6 | # 7 | # This source code is licensed under the MIT license found in the 8 | # LICENSE file in the root directory of this source tree. 
9 | 10 | import crypten 11 | import crypten.communicator as comm 12 | import math 13 | import torch 14 | from crypten.common.rng import generate_kbit_random_tensor, generate_random_ring_element, generate_unsigned_random_ring_element 15 | from crypten.common.util import count_wraps, torch_stack 16 | from crypten.mpc.primitives import ArithmeticSharedTensor, BinarySharedTensor 17 | 18 | from .provider import TupleProvider 19 | 20 | 21 | class TrustedFirstParty(TupleProvider): 22 | NAME = "TFP" 23 | 24 | def generate_additive_triple(self, size0, size1, op, device=None, *args, **kwargs): 25 | """Generate multiplicative triples of given sizes""" 26 | a = generate_random_ring_element(size0, device=device) 27 | b = generate_random_ring_element(size1, device=device) 28 | 29 | c = getattr(torch, op)(a, b, *args, **kwargs) 30 | 31 | a = ArithmeticSharedTensor(a, precision=0, src=0) 32 | b = ArithmeticSharedTensor(b, precision=0, src=0) 33 | c = ArithmeticSharedTensor(c, precision=0, src=0) 34 | 35 | return a, b, c 36 | 37 | def square(self, size, device=None): 38 | """Generate square double of given size""" 39 | r = generate_random_ring_element(size, device=device) 40 | r2 = r.mul(r) 41 | 42 | # Stack to vectorize scatter function 43 | stacked = torch_stack([r, r2]) 44 | stacked = ArithmeticSharedTensor(stacked, precision=0, src=0) 45 | return stacked[0], stacked[1] 46 | 47 | def generate_trig_triple(self, size, period, terms, device=None): 48 | """Generate trigonometric triple of given size""" 49 | t = torch.rand(size, device=device) * period 50 | k = [i * 2 * math.pi / period for i in range(1, terms + 1)] 51 | tk = torch_stack([i * t for i in k]) 52 | u, v = torch.sin(tk), torch.cos(tk) 53 | 54 | t = ArithmeticSharedTensor(t, src=0) 55 | u = ArithmeticSharedTensor(u, src=0) 56 | v = ArithmeticSharedTensor(v, src=0) 57 | return t, u, v 58 | 59 | def generate_one_hot_pair(self, size, length, device=None): 60 | """Generate one hot encoding of given size (of output) and length 
(of one hot vector)""" 61 | r = generate_unsigned_random_ring_element(size, ring_size=length, device=device) 62 | v = torch.nn.functional.one_hot(r, num_classes=length) 63 | 64 | r = crypten.cryptensor(r, device=device) 65 | v = crypten.cryptensor(v, device=device) 66 | return r, v 67 | 68 | def generate_binary_triple(self, size0, size1, device=None): 69 | """Generate xor triples of given size""" 70 | a = generate_kbit_random_tensor(size0, device=device) 71 | b = generate_kbit_random_tensor(size1, device=device) 72 | c = a & b 73 | 74 | a = BinarySharedTensor(a, src=0) 75 | b = BinarySharedTensor(b, src=0) 76 | c = BinarySharedTensor(c, src=0) 77 | 78 | return a, b, c 79 | 80 | def wrap_rng(self, size, device=None): 81 | """Generate random shared tensor of given size and sharing of its wraps""" 82 | num_parties = comm.get().get_world_size() 83 | r = [ 84 | generate_random_ring_element(size, device=device) 85 | for _ in range(num_parties) 86 | ] 87 | theta_r = count_wraps(r) 88 | 89 | shares = comm.get().scatter(r, 0) 90 | r = ArithmeticSharedTensor.from_shares(shares, precision=0) 91 | theta_r = ArithmeticSharedTensor(theta_r, precision=0, src=0) 92 | 93 | return r, theta_r 94 | 95 | def B2A_rng(self, size, device=None): 96 | """Generate random bit tensor as arithmetic and binary shared tensors""" 97 | # generate random bit 98 | r = generate_kbit_random_tensor(size, bitlength=1, device=device) 99 | 100 | rA = ArithmeticSharedTensor(r, precision=0, src=0) 101 | rB = BinarySharedTensor(r, src=0) 102 | 103 | return rA, rB 104 | -------------------------------------------------------------------------------- /crypten/mpc/context.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import functools 9 | import logging 10 | import multiprocessing 11 | import os 12 | import tempfile 13 | from operator import itemgetter 14 | 15 | import crypten 16 | from crypten.communicator import DistributedCommunicator 17 | 18 | 19 | def _launch(func, rank, world_size, rendezvous_file, queue, func_args, func_kwargs): 20 | communicator_args = { 21 | "WORLD_SIZE": world_size, 22 | "RANK": rank, 23 | "RENDEZVOUS": "file://%s" % rendezvous_file, 24 | "DISTRIBUTED_BACKEND": "gloo", 25 | } 26 | for key, val in communicator_args.items(): 27 | os.environ[key] = str(val) 28 | 29 | crypten.init() 30 | return_value = func(*func_args, **func_kwargs) 31 | crypten.uninit() 32 | 33 | queue.put((rank, return_value)) 34 | 35 | 36 | def run_multiprocess(world_size, maxsize=None): 37 | """Defines decorator to run function across multiple processes 38 | 39 | Args: 40 | world_size (int): number of parties / processes to initiate. 41 | maxsize: Enables the user to increase the size of returnable values 42 | (See https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Queue) 43 | """ 44 | 45 | def decorator(func): 46 | @functools.wraps(func) 47 | def wrapper(*args, **kwargs): 48 | rendezvous_file = tempfile.NamedTemporaryFile(delete=True).name 49 | 50 | if maxsize is None: 51 | queue = multiprocessing.Queue() 52 | else: 53 | queue = multiprocessing.Queue(maxsize) 54 | 55 | processes = [ 56 | multiprocessing.Process( 57 | target=_launch, 58 | args=(func, rank, world_size, rendezvous_file, queue, args, kwargs), 59 | ) 60 | for rank in range(world_size) 61 | ] 62 | 63 | # Initialize TTP process 64 | if crypten.mpc.ttp_required(): 65 | processes += [ 66 | multiprocessing.Process( 67 | target=_launch, 68 | args=( 69 | crypten.mpc.provider.TTPServer, 70 | world_size, 71 | world_size, 72 | rendezvous_file, 73 | queue, 74 | (), 75 | {}, 76 | ), 77 | ) 78 | ] 79 | 80 | # This process will be forked and we need to re-initialize the 81 | # communicator in the children. 
If the parent process happened to 82 | # call crypten.init(), which might be valid in a Jupyter notebook 83 | # for instance, then the crypten.init() call on the children 84 | # process will not do anything. The call to uninit here makes sure 85 | # we actually get to initialize the communicator on the child 86 | # process. An alternative fix for this issue would be to use spawn 87 | # instead of fork, but we run into issues serializing the function 88 | # in that case. 89 | was_initialized = DistributedCommunicator.is_initialized() 90 | if was_initialized: 91 | crypten.uninit() 92 | 93 | for process in processes: 94 | process.start() 95 | 96 | for process in processes: 97 | process.join() 98 | 99 | if was_initialized: 100 | crypten.init() 101 | 102 | successful = [process.exitcode == 0 for process in processes] 103 | if not all(successful): 104 | logging.error("One of the parties failed. Check past logs") 105 | return None 106 | 107 | return_values = [] 108 | while not queue.empty(): 109 | return_values.append(queue.get()) 110 | 111 | return [value for _, value in sorted(return_values, key=itemgetter(0))] 112 | 113 | return wrapper 114 | 115 | return decorator 116 | -------------------------------------------------------------------------------- /examples/image-classification/run_image_classification_private.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Modified by SHAFT's team: Private Text Classification on ImageNet-1k. 3 | # 4 | # Copyright 2022 The HuggingFace Inc. team. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """Finetuning any 🤗 Transformers model for image classification leveraging 🤗 Accelerate.""" 18 | 19 | import argparse 20 | import logging 21 | import torch 22 | 23 | import datasets 24 | import crypten as ct 25 | from crypten.config import cfg 26 | from multiprocess_launcher import MultiProcessLauncher 27 | 28 | from tqdm.auto import tqdm 29 | 30 | import transformers 31 | from transformers import AutoConfig, AutoModelForImageClassification 32 | from transformers.utils import check_min_version 33 | from transformers.utils.versions import require_version 34 | 35 | 36 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 37 | check_min_version("4.42.0.dev0") 38 | 39 | require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/image-classification/requirements.txt") 40 | 41 | 42 | def parse_args(): 43 | parser = argparse.ArgumentParser(description="Fine-tune a Transformers model on an image classification dataset") 44 | parser.add_argument( 45 | "--comp", 46 | action="store_true", 47 | help="If passed, estimate computation time (without communication).", 48 | ) 49 | parser.add_argument("--validation_dir", type=str, default=None, help="A folder containing the validation data.") 50 | parser.add_argument( 51 | "--max_eval_samples", 52 | type=int, 53 | default=None, 54 | help=( 55 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 56 | "value if set." 
57 | ), 58 | ) 59 | parser.add_argument( 60 | "--model_name_or_path", 61 | type=str, 62 | help="Path to pretrained model or model identifier from huggingface.co/models.", 63 | default="google/vit-base-patch16-224-in21k", 64 | ) 65 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible result.") 66 | args = parser.parse_args() 67 | 68 | return args 69 | 70 | 71 | def main(): 72 | args = parse_args() 73 | 74 | # Make one log on every process with the configuration for debugging. 75 | logging.basicConfig( 76 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 77 | datefmt="%m/%d/%Y %H:%M:%S", 78 | level=logging.INFO, 79 | ) 80 | datasets.utils.logging.set_verbosity_warning() 81 | transformers.utils.logging.set_verbosity_info() 82 | 83 | config = AutoConfig.from_pretrained( 84 | args.model_name_or_path, 85 | num_labels=1000, 86 | finetuning_task="image-classification", 87 | ) 88 | model = AutoModelForImageClassification.from_pretrained( 89 | args.model_name_or_path, 90 | from_tf=bool(".ckpt" in args.model_name_or_path), 91 | config=config, 92 | ) 93 | 94 | device = "cuda" 95 | ct.init() 96 | 97 | dummy_input = torch.rand([1, 3, 224, 224]) 98 | private_model = ct.nn.from_pytorch(model, dummy_input).encrypt().to(device) 99 | 100 | model.eval() 101 | for _ in range(args.max_eval_samples): 102 | input = torch.rand([1, 3, 224, 224], device=device) 103 | input_enc = ct.cryptensor(input).to(device) 104 | with ct.no_grad(): 105 | private_model(input_enc) 106 | 107 | 108 | if __name__ == "__main__": 109 | args = parse_args() 110 | if args.comp: 111 | # run without communication 112 | with cfg.temp_override({"cost.estimate_cost": True, "cost.estimate_mode": "comp"}): 113 | main() 114 | else: 115 | # run with communication 116 | launcher = MultiProcessLauncher(2, main) 117 | launcher.start() 118 | launcher.join() 119 | launcher.terminate() 120 | -------------------------------------------------------------------------------- 
/crypten/optim/optimizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import crypten 9 | import torch 10 | from torch.optim.optimizer import required 11 | 12 | 13 | class Optimizer(torch.optim.Optimizer): 14 | r"""Base class for all optimizers. 15 | .. warning:: 16 | Parameters need to be specified as collections that have a deterministic 17 | ordering that is consistent between runs. Examples of objects that don't 18 | satisfy those properties are sets and iterators over values of dictionaries. 19 | Arguments: 20 | params (iterable): an iterable of :class:`torch.Tensor` s, 21 | :class:`dict` s, or :class:`crypten.CrypTensor`s. Specifies what Tensors 22 | should be optimized. 23 | defaults: (dict): a dict containing default values of optimization 24 | options (used when a parameter group doesn't specify them). 25 | 26 | Note: This optimizer is adapted from torch.optim.Optimizer to work with CrypTensors 27 | """ 28 | 29 | def add_param_group(self, param_group): 30 | r"""Add a param group to the :class:`Optimizer` s `param_groups`. 31 | This can be useful when fine tuning a pre-trained network as frozen layers can be made 32 | trainable and added to the :class:`Optimizer` as training progresses. 33 | Arguments: 34 | param_group (dict): Specifies what Tensors should be optimized along with group 35 | specific optimization options. 
36 | """ 37 | assert isinstance(param_group, dict), "param group must be a dict" 38 | 39 | params = param_group["params"] 40 | if isinstance(params, (torch.Tensor, crypten.CrypTensor)): 41 | param_group["params"] = [params] 42 | elif isinstance(params, set): 43 | raise TypeError( 44 | "optimizer parameters need to be organized in ordered collections, but " 45 | "the ordering of tensors in sets will change between runs. Please use a list instead." 46 | ) 47 | else: 48 | param_group["params"] = list(params) 49 | 50 | for param in param_group["params"]: 51 | if not isinstance(param, (torch.Tensor, crypten.CrypTensor)): 52 | raise TypeError( 53 | "optimizer can only optimize Tensors, " 54 | "but one of the params is " + torch.typename(param) 55 | ) 56 | 57 | for name, default in self.defaults.items(): 58 | if default is required and name not in param_group: 59 | raise ValueError( 60 | "parameter group didn't specify a value of required optimization parameter " 61 | + name 62 | ) 63 | else: 64 | param_group.setdefault(name, default) 65 | 66 | self.param_groups.append(param_group) 67 | 68 | def zero_grad(self, set_to_none=True): 69 | r"""Sets the gradients of all optimized parameters to zero or None. 70 | Args: 71 | set_to_none (bool): instead of setting to zero, set the grads to None. 72 | This will in general have lower memory footprint, and can modestly improve performance. 73 | However, it changes certain behaviors. For example: 74 | 1. When the user tries to access a gradient and perform manual ops on it, 75 | a None attribute or a Tensor full of 0s will behave differently. 76 | 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s 77 | are guaranteed to be None for params that did not receive a gradient. 78 | 3. ``crypten.optim`` optimizers have a different behavior if the gradient is 0 or None 79 | (in one case it does the step with a gradient of 0 and in the other it skips 80 | the step altogether). 
81 | 82 | Note that CrypTen differs from PyTorch by setting the default value of `set_to_none` to True. 83 | This is because in CrypTen, it is often advantageous to set to None rather than to a zero-valued 84 | CrypTensor. 85 | """ 86 | if set_to_none: 87 | for group in self.param_groups: 88 | for param in group["params"]: 89 | param.grad = None 90 | else: 91 | for group in self.param_groups: 92 | for param in group["params"]: 93 | if param.grad is not None: 94 | param.grad -= param.grad 95 | -------------------------------------------------------------------------------- /crypten/common/functions/power.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import crypten 9 | import torch 10 | 11 | from ..tensor_types import is_tensor 12 | 13 | __all__ = ["norm", "polynomial", "pos_pow", "pow"] 14 | 15 | 16 | def pow(self, p, **kwargs): 17 | """ 18 | Computes an element-wise exponent `p` of a tensor, where `p` is an 19 | integer. 20 | """ 21 | if isinstance(p, float) and int(p) == p: 22 | p = int(p) 23 | 24 | if not isinstance(p, int): 25 | raise TypeError( 26 | "pow must take an integer exponent. For non-integer powers, use" 27 | " pos_pow with positive-valued base." 28 | ) 29 | if p < -1: 30 | return self.reciprocal().pow(-p) 31 | elif p == -1: 32 | return self.reciprocal() 33 | elif p == 0: 34 | # Note: This returns 0 ** 0 -> 1 when inputs have zeros. 35 | # This is consistent with PyTorch's pow function. 
36 | return self.new(torch.ones_like(self.data)) 37 | elif p == 1: 38 | return self.clone() 39 | elif p == 2: 40 | return self.square() 41 | elif p % 2 == 0: 42 | return self.square().pow(p // 2) 43 | else: 44 | x = self.square().mul_(self) 45 | return x.pow((p - 1) // 2) 46 | 47 | 48 | def pos_pow(self, p): 49 | """ 50 | Approximates self ** p by computing: :math:`x^p = exp(p * log(x))` 51 | 52 | Note that this requires that the base `self` contain only positive values 53 | since log can only be computed on positive numbers. 54 | 55 | Note that the value of `p` can be an integer, float, public tensor, or 56 | encrypted tensor. 57 | """ 58 | if isinstance(p, int) or (isinstance(p, float) and int(p) == p): 59 | return self.pow(p) 60 | return self.log().mul_(p).exp() 61 | 62 | 63 | def polynomial(self, coeffs, func="mul"): 64 | """Computes a polynomial function on a tensor with given coefficients, 65 | `coeffs`, that can be a list of values or a 1-D tensor. 66 | 67 | Coefficients should be ordered from the order 1 (linear) term first, 68 | ending with the highest order term. (Constant is not included). 
69 | """ 70 | # Coefficient input type-checking 71 | if isinstance(coeffs, list): 72 | coeffs = torch.tensor(coeffs, device=self.device) 73 | assert is_tensor(coeffs) or crypten.is_encrypted_tensor( 74 | coeffs 75 | ), "Polynomial coefficients must be a list or tensor" 76 | assert coeffs.dim() == 1, "Polynomial coefficients must be a 1-D tensor" 77 | 78 | # Handle linear case 79 | if coeffs.size(0) == 1: 80 | return self.mul(coeffs) 81 | 82 | # Compute terms of polynomial using exponentially growing tree 83 | terms = crypten.stack([self, self.square()]) 84 | while terms.size(0) < coeffs.size(0): 85 | highest_term = terms.index_select( 86 | 0, torch.tensor(terms.size(0) - 1, device=self.device) 87 | ) 88 | new_terms = getattr(terms, func)(highest_term) 89 | terms = crypten.cat([terms, new_terms]) 90 | 91 | # Resize the coefficients for broadcast 92 | terms = terms[: coeffs.size(0)] 93 | for _ in range(terms.dim() - 1): 94 | coeffs = coeffs.unsqueeze(1) 95 | 96 | # Multiply terms by coefficients and sum 97 | return terms.mul(coeffs).sum(0) 98 | 99 | 100 | def norm(self, p="fro", dim=None, keepdim=False): 101 | """Computes the p-norm of the input tensor (or along a dimension).""" 102 | if p == "fro": 103 | p = 2 104 | 105 | if isinstance(p, (int, float)): 106 | assert p >= 1, "p-norm requires p >= 1" 107 | if p == 1: 108 | if dim is None: 109 | return self.abs().sum() 110 | return self.abs().sum(dim, keepdim=keepdim) 111 | elif p == 2: 112 | if dim is None: 113 | return self.square().sum().sqrt() 114 | return self.square().sum(dim, keepdim=keepdim).sqrt() 115 | elif p == float("inf"): 116 | if dim is None: 117 | return self.abs().max() 118 | return self.abs().max(dim=dim, keepdim=keepdim)[0] 119 | else: 120 | if dim is None: 121 | return self.abs().pos_pow(p).sum().pos_pow(1 / p) 122 | return self.abs().pos_pow(p).sum(dim, keepdim=keepdim).pos_pow(1 / p) 123 | elif p == "nuc": 124 | raise NotImplementedError("Nuclear norm is not implemented") 125 | else: 126 | raise 
ValueError(f"Improper value p ({p})for p-norm") 127 | -------------------------------------------------------------------------------- /crypten/mpc/primitives/replicated.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | # This file implements Replicated Secret Sharing protocols 9 | # from the CryptGPU repo 10 | 11 | import crypten.communicator as comm 12 | import torch 13 | 14 | 15 | def replicate_shares(share_list): 16 | world_size = comm.get().get_world_size() 17 | if world_size < 3: 18 | raise ValueError("Cannot utilize Replicated Sharing securely with < 3 parties.") 19 | rank = comm.get().get_rank() 20 | prev_rank = (rank - 1) % world_size 21 | next_rank = (rank + 1) % world_size 22 | 23 | reqs = [] 24 | rep_shares = [] 25 | for share in share_list: 26 | rep_shares.append(torch.zeros_like(share)) 27 | 28 | send_req = comm.get().isend(share.contiguous(), dst=next_rank) 29 | recv_req = comm.get().irecv(rep_shares[-1], src=prev_rank) 30 | 31 | reqs.extend([send_req, recv_req]) 32 | 33 | for req in reqs: 34 | req.wait() 35 | 36 | # Order [(x1, x2), (y1, y2), ...] 37 | shares = [(share_list[i], rep_shares[i]) for i in range(len(share_list))] 38 | 39 | return shares 40 | 41 | 42 | def __replicated_secret_sharing_protocol(op, x, y, *args, **kwargs): 43 | """Implements bilinear functions using replicated secret shares. 44 | Shares are input as ArithmeticSharedTensors and are replicated 45 | within this function to perform computations. 46 | 47 | The protocol used here is that of section 3.2 of ABY3 48 | (https://eprint.iacr.org/2018/403.pdf). 
49 | """ 50 | assert op in { 51 | "mul", 52 | "matmul", 53 | "conv1d", 54 | "conv2d", 55 | "conv_transpose1d", 56 | "conv_transpose2d", 57 | } 58 | x_shares, y_shares = replicate_shares([x.share, y.share]) 59 | x1, x2 = x_shares 60 | y1, y2 = y_shares 61 | 62 | z = x.shallow_copy() 63 | z.share = getattr(torch, op)(x1, y1, *args, **kwargs) 64 | z.share += getattr(torch, op)(x1, y2, *args, **kwargs) 65 | z.share += getattr(torch, op)(x2, y1, *args, **kwargs) 66 | 67 | return z 68 | 69 | 70 | def mul(x, y): 71 | return __replicated_secret_sharing_protocol("mul", x, y) 72 | 73 | 74 | def matmul(x, y): 75 | return __replicated_secret_sharing_protocol("matmul", x, y) 76 | 77 | 78 | def conv1d(x, y, **kwargs): 79 | return __replicated_secret_sharing_protocol("conv1d", x, y, **kwargs) 80 | 81 | 82 | def conv2d(x, y, **kwargs): 83 | return __replicated_secret_sharing_protocol("conv2d", x, y, **kwargs) 84 | 85 | 86 | def conv_transpose1d(x, y, **kwargs): 87 | return __replicated_secret_sharing_protocol("conv_transpose1d", x, y, **kwargs) 88 | 89 | 90 | def conv_transpose2d(x, y, **kwargs): 91 | return __replicated_secret_sharing_protocol("conv_transpose2d", x, y, **kwargs) 92 | 93 | 94 | def square(x): 95 | (x_shares,) = replicate_shares([x.share]) 96 | x1, x2 = x_shares 97 | 98 | x_square = x1**2 + 2 * x1 * x2 99 | 100 | z = x.shallow_copy() 101 | z.share = x_square 102 | return z 103 | 104 | 105 | def truncate(x, y): 106 | """Protocol to divide an ArithmeticSharedTensor `x` by a constant integer `y` 107 | using RSS (see ABY3 Figure 2: https://eprint.iacr.org/2018/403.pdf). 108 | 109 | Note: This is currently supported under 3PC only. This is because the protocol 110 | requires 2-out-of-N secret sharing since only 2 parties can perform division to 111 | provide statistical guarantees equivalent to 2-out-of-2 truncation. 112 | """ 113 | if comm.get().get_world_size() != 3: 114 | raise NotImplementedError( 115 | "RSS truncation is only implemented for world_size == 3." 
116 | ) 117 | 118 | rank = x.rank 119 | 120 | if rank == 0: 121 | x.share = x.share.div(y, rounding_mode="trunc") 122 | elif rank == 1: 123 | x2 = comm.get().recv(x.share, 2) 124 | x.share = x.share.add(x2).div(y, rounding_mode="trunc") 125 | elif rank == 2: 126 | comm.get().send(x.share, 1) 127 | x.share -= x.share 128 | 129 | # Add PRZS - this takes the place of r 130 | x.share += x.PRZS(x.size(), device=x.device).share 131 | 132 | return x 133 | 134 | 135 | def AND(x, y): 136 | from .binary import BinarySharedTensor 137 | 138 | x_share = x 139 | y_share = y 140 | if isinstance(x, BinarySharedTensor): 141 | x_share = x.share 142 | y_share = y.share 143 | 144 | x_shares, y_shares = replicate_shares([x_share, y_share]) 145 | x1, x2 = x_shares 146 | y1, y2 = y_shares 147 | 148 | z = x.shallow_copy() 149 | z.share = (x1 & y1) ^ (x2 & y1) ^ (x1 & y2) 150 | 151 | return z 152 | -------------------------------------------------------------------------------- /crypten/debug/debug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import logging 9 | import pdb as pythondebugger 10 | import sys 11 | 12 | from crypten.config import cfg 13 | 14 | 15 | class MultiprocessingPdb(pythondebugger.Pdb): 16 | """A Pdb subclass that may be used 17 | from a forked multiprocessing child 18 | 19 | """ 20 | 21 | def interaction(self, *args, **kwargs): 22 | _stdin = sys.stdin 23 | try: 24 | with open("/dev/stdin") as file: 25 | sys.stdin = file 26 | pythondebugger.Pdb.interaction(self, *args, **kwargs) 27 | finally: 28 | sys.stdin = _stdin 29 | 30 | 31 | def configure_logging(): 32 | """Configures a logging template useful for debugging multiple processes.""" 33 | 34 | level = logging.INFO 35 | logging.getLogger().setLevel(level) 36 | logging.basicConfig( 37 | level=level, 38 | format=( 39 | "[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d]" 40 | + "[%(processName)s] %(message)s" 41 | ), 42 | ) 43 | 44 | 45 | def crypten_print(*args, dst=0, **kwargs): 46 | """ 47 | Prints a message to only parties whose rank is contained by `dst` kwarg (default: 0). 48 | """ 49 | if isinstance(dst, int): 50 | dst = [dst] 51 | assert isinstance( 52 | dst, (list, tuple) 53 | ), "print destination must be a list or tuple of party ranks" 54 | import crypten.communicator as comm 55 | 56 | if comm.get().get_rank() in dst: 57 | print(*args, **kwargs) 58 | 59 | 60 | def crypten_log(*args, level=logging.INFO, dst=0, **kwargs): 61 | """ 62 | Logs a message to logger of parties whose rank is contained by `dst` kwarg (default: 0). 63 | 64 | Uses logging.INFO as default level. 
65 | """ 66 | if isinstance(dst, int): 67 | dst = [dst] 68 | assert isinstance( 69 | dst, (list, tuple) 70 | ), "log destination must be a list or tuple of party ranks" 71 | import crypten.communicator as comm 72 | 73 | if comm.get().get_rank() in dst: 74 | logging.log(level, *args, **kwargs) 75 | 76 | 77 | def crypten_print_in_order(*args, **kwargs): 78 | """ 79 | Calls print(*args, **kwargs) on each party in rank order to ensure each party 80 | can print its full message uninterrupted and the full output is deterministic 81 | """ 82 | import crypten.communicator as comm 83 | 84 | for i in range(comm.get().get_world_size()): 85 | if comm.get().get_rank() == i: 86 | print(*args, **kwargs) 87 | comm.get().barrier() 88 | 89 | 90 | def validate_correctness(self, func, func_name, tolerance=0.5): 91 | import crypten 92 | import torch 93 | 94 | if not hasattr(torch.tensor([]), func_name): 95 | return func 96 | 97 | def validation_function(*args, **kwargs): 98 | with cfg.temp_override({"debug.validation_mode": False}): 99 | # Compute crypten result 100 | result_enc = func(*args, **kwargs) 101 | result = ( 102 | result_enc.get_plain_text() 103 | if crypten.is_encrypted_tensor(result_enc) 104 | else result_enc 105 | ) 106 | 107 | args = list(args) 108 | 109 | # Compute torch result for corresponding function 110 | for i, arg in enumerate(args): 111 | if crypten.is_encrypted_tensor(arg): 112 | args[i] = args[i].get_plain_text() 113 | 114 | kwargs.pop("input_in_01", None) 115 | for key, value in kwargs.items(): 116 | if crypten.is_encrypted_tensor(value): 117 | kwargs[key] = value.get_plain_text() 118 | reference = getattr(self.get_plain_text(), func_name)(*args, **kwargs) 119 | 120 | # TODO: Validate properties - Issue is tuples can contain encrypted tensors 121 | if not torch.is_tensor(reference): 122 | return result_enc 123 | 124 | # Check sizes match 125 | if result.size() != reference.size(): 126 | crypten_log( 127 | f"Size mismatch: Expected {reference.size()} but got 
{result.size()}" 128 | ) 129 | raise ValueError(f"Function {func_name} returned incorrect size") 130 | 131 | # Check that results match 132 | diff = (result - reference).abs_() 133 | norm_diff = diff.div(result.abs() + reference.abs()).abs_() 134 | test_passed = norm_diff.le(tolerance) + diff.le(tolerance * 0.1) 135 | test_passed = test_passed.gt(0).all().item() == 1 136 | if not test_passed: 137 | crypten_log(f"Function {func_name} returned incorrect values") 138 | crypten_log("Result %s" % result) 139 | crypten_log("Result - Reference = %s" % (result - reference)) 140 | raise ValueError(f"Function {func_name} returned incorrect values") 141 | 142 | return result_enc 143 | 144 | return validation_function 145 | -------------------------------------------------------------------------------- /examples/ttp-test/run_test_ttp.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import crypten 4 | 5 | crypten.cfg.communicator.verbose = True 6 | 7 | 8 | def test_gelu(runs: int = 10, device: str = "cpu"): 9 | gelu_approximate = "none" 10 | 11 | x = torch.arange(-5, 5, 0.001) 12 | y_original = torch.nn.functional.gelu(x, approximate=gelu_approximate) 13 | y_actual = crypten.cryptensor(x).gelu(approximate=gelu_approximate).get_plain_text() 14 | max_err = (y_original - y_actual).abs().max() 15 | min_err = (y_original - y_actual).abs().min() 16 | avg_err = (y_original - y_actual).abs().mean() 17 | 18 | size = (128, 3072) 19 | x = crypten.cryptensor(torch.zeros(size), device=device) 20 | crypten.reset_communication_stats() 21 | start_time = time.time() 22 | 23 | for _ in range(runs): 24 | x.gelu(approximate=gelu_approximate) 25 | comm_time = time.time() - start_time 26 | stats = crypten.get_communication_stats() 27 | comm_bytes = stats["bytes"] 28 | comm_rounds = stats["rounds"] 29 | 30 | log = f"gelu: max error: {max_err}, avg error: {avg_err}, min error: {min_err} " 31 | log += f"{size} time: {comm_time / 
runs}s, bytes: {comm_bytes / (2 ** 20) / runs} MB, rounds: {comm_rounds / runs}" 32 | log = f"[rank {crypten.comm.get().get_rank()}] [provider {crypten.cfg.mpc.provider}] " + log 33 | print(log) 34 | 35 | 36 | def test_softmax(runs: int = 10, device: str = "cpu"): 37 | x = torch.arange(-5, 5, 0.001) 38 | y_original = torch.nn.functional.softmax(x, -1) 39 | y_actual = crypten.cryptensor(x).softmax(-1).get_plain_text() 40 | max_err = (y_original - y_actual).abs().max() 41 | avg_err = (y_original - y_actual).abs().mean() 42 | min_err = (y_original - y_actual).abs().min() 43 | 44 | size = (12, 128, 128) 45 | x = crypten.cryptensor(torch.zeros(size), device=device) 46 | crypten.reset_communication_stats() 47 | start_time = time.time() 48 | 49 | for _ in range(runs): 50 | y = x.softmax(-1) 51 | comm_time = time.time() - start_time 52 | stats = crypten.get_communication_stats() 53 | comm_bytes = stats["bytes"] 54 | comm_rounds = stats["rounds"] 55 | 56 | log = f"softmax: max error: {max_err}, avg error: {avg_err}, min error: {min_err} " 57 | log += f"{size} time: {comm_time / runs}s, bytes: {comm_bytes / (2 ** 20) / runs} MB, rounds: {comm_rounds / runs}" 58 | log = f"[rank {crypten.comm.get().get_rank()}] [provider {crypten.cfg.mpc.provider}] " + log 59 | print(log) 60 | 61 | 62 | def test_embedding(runs: int = 10, device: str = "cuda"): 63 | num_embeddings = 32 64 | embedding_dim = 768 65 | 66 | ct_emb = crypten.nn.module.Embedding().to(device) 67 | pt_emb = torch.nn.Embedding(num_embeddings, embedding_dim, device=device) 68 | pt_x = torch.randint(0, num_embeddings, (1, 128), device=device) 69 | y_original = pt_emb.forward(pt_x) 70 | y_actual = ct_emb.forward( 71 | (crypten.cryptensor(pt_emb.weight), crypten.cryptensor(pt_x), None) 72 | ).get_plain_text() 73 | max_err = (y_original - y_actual).abs().max() 74 | avg_err = (y_original - y_actual).abs().mean() 75 | min_err = (y_original - y_actual).abs().min() 76 | print(f"[rank {crypten.comm.get().get_rank()}] [provider 
{crypten.cfg.mpc.provider}] " 77 | f"Embedding max error: {max_err:.4f}, avg error: {avg_err:.6f}, min error: {min_err:.6f}") 78 | embedding_time = {} 79 | embedding_bytes = {} 80 | embedding_rounds = {} 81 | 82 | embedding_sizes = [(1, 128), (1, 256)] 83 | pt_emb = torch.nn.Embedding(num_embeddings, embedding_dim).to(device) 84 | for embedding_size in embedding_sizes: 85 | embedding_in = crypten.cryptensor(torch.randint(0, num_embeddings, embedding_size), device=device) 86 | crypten.reset_communication_stats() 87 | start_time = time.time() 88 | 89 | for _ in range(runs): 90 | ct_emb.forward( 91 | (crypten.cryptensor(pt_emb.weight), embedding_in, None) 92 | ) 93 | embedding_time[embedding_size[1]] = time.time() - start_time 94 | stats = crypten.get_communication_stats() 95 | embedding_bytes[embedding_size[1]] = stats["bytes"] 96 | embedding_rounds[embedding_size[1]] = stats["rounds"] 97 | 98 | for embedding_size in embedding_sizes: 99 | print(f"[rank {crypten.comm.get().get_rank()}] [provider {crypten.cfg.mpc.provider}] " 100 | f"Embedding ({embedding_size[0]}, {embedding_size[1]}) " 101 | f"time: {embedding_time[embedding_size[1]] / runs:.4f}s, " 102 | f"bytes: {embedding_bytes[embedding_size[1]] / 2 ** 20 / runs:.0f} MB, " 103 | f"rounds: {embedding_rounds[embedding_size[1]] / runs:.0f}" 104 | ) 105 | 106 | 107 | @crypten.mpc.context.run_multiprocess(2) 108 | def main(device: str = "cuda"): 109 | test_gelu(device=device) 110 | test_softmax(device=device) 111 | test_embedding(device=device) 112 | 113 | 114 | if __name__ == "__main__": 115 | torch.manual_seed(101) 116 | torch.random.manual_seed(101) 117 | torch.cuda.manual_seed(101) 118 | torch.cuda.random.manual_seed(101) 119 | crypten.cfg.mpc.provider = "TTP" 120 | main() 121 | crypten.cfg.mpc.provider = "TFP" 122 | main() 123 | -------------------------------------------------------------------------------- /crypten/mpc/primitives/circuit.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import functools 9 | import math 10 | 11 | import torch 12 | 13 | # Cache masks and constants to skip computation during each call 14 | __BITS = torch.iinfo(torch.long).bits 15 | __LOG_BITS = int(math.log2(torch.iinfo(torch.long).bits)) 16 | 17 | 18 | @functools.lru_cache(maxsize=None) 19 | def __SPK_circuit_constants(device): 20 | """ 21 | Generate the __MASKS, __OUT_MASKS, and __MULTIPLIERS constants 22 | used by __SPK_circuit. 23 | """ 24 | # Cached SPK masks are: 25 | # [0] -> 010101010101....0101 = 01 x 32 26 | # [1] -> 001000100010....0010 = 0010 x 16 27 | # [2] -> 000010000000....0010 = 00001000 x 8 28 | # [n] -> [2^n 0s, 1, (2^n -1) 0s] x (32 / (2^n)) 29 | __MASKS = torch.tensor( 30 | [ 31 | 6148914691236517205, 32 | 2459565876494606882, 33 | 578721382704613384, 34 | 36029346783166592, 35 | 140737488388096, 36 | 2147483648, 37 | ], 38 | dtype=torch.long, 39 | device=device, 40 | ) 41 | 42 | __MULTIPLIERS = torch.tensor( 43 | [(1 << (2**iter + 1)) - 2 for iter in range(__LOG_BITS)], device=device 44 | ) 45 | __OUT_MASKS = __MASKS * __MULTIPLIERS 46 | 47 | return __MASKS, __OUT_MASKS, __MULTIPLIERS 48 | 49 | 50 | def __SPK_circuit(S, P): 51 | """ 52 | Computes the Set-Propagate-Kill Tree circuit for a set (S, P) 53 | (K is implied by S, P since (SPK) is one-hot) 54 | 55 | (See section 6.3 of Damgard, "Unconditionally Secure Constant-Rounds 56 | Multi-Party Computation for Equality, Comparison, Bits and Exponentiation") 57 | 58 | At each stage: 59 | S <- S0 ^ (P0 & S1) 60 | P <- P0 & P1 61 | K <- K0 ^ (P0 & K1) <- don't need K since it is implied by S and P 62 | """ 63 | from .binary import BinarySharedTensor 64 | 65 | # Vectorize private AND calls to 
reduce rounds: 66 | SP = BinarySharedTensor.stack([S, P]) 67 | 68 | __MASKS, __OUT_MASKS, __MULTIPLIERS = __SPK_circuit_constants(SP.device) 69 | 70 | # fmt: off 71 | # Tree reduction circuit 72 | for i in range(__LOG_BITS): 73 | in_mask = __MASKS[i] # Start of arrows 74 | out_mask = __OUT_MASKS[i] # End of arrows 75 | not_out_mask = out_mask ^ -1 # Not (end of arrows) 76 | 77 | # Set up S0, S1, P0, and P1 78 | P0 = SP[1] & out_mask # Mask P0 from P 79 | S1P1 = SP & in_mask # Mask S1P1 from SP 80 | S1P1._tensor *= __MULTIPLIERS[i] # Fan out S1P1 along arrows 81 | 82 | # Update S and P 83 | update = P0 & S1P1 # S0 ^= P0 & S1, P0 = P0 & P1 84 | SP[1] &= not_out_mask 85 | SP ^= update 86 | # fmt: on 87 | return SP[0], SP[1] 88 | 89 | 90 | def __P_circuit(P): 91 | """ 92 | Computes the Propagate Tree circuit for input P. 93 | The P circuit will return 1 only if the binary of 94 | the input is all ones (i.e. the value is -1). 95 | 96 | Otherwise this circuit returns 0 97 | 98 | At each stage: 99 | P <- P0 & P1 100 | """ 101 | shift = __BITS // 2 102 | for _ in range(__LOG_BITS): 103 | P &= P << shift # using lshift since rshift was modified to arithmetic 104 | shift //= 2 105 | return P 106 | 107 | 108 | def __flip_sign_bit(x): 109 | return x ^ -(2**63) 110 | 111 | 112 | def __get_sign_bit(x): 113 | from .binary import BinarySharedTensor 114 | 115 | y = x >> 63 116 | 117 | # NOTE: __rshift__ was changed to arithmetic shift 118 | if isinstance(y, BinarySharedTensor): 119 | y.share = y.share.eq(-1).long() 120 | else: 121 | y = y.eq(-1).long() 122 | return y 123 | 124 | 125 | def add(x, y): 126 | """Returns x + y from BinarySharedTensors `x` and `y`""" 127 | S = x & y 128 | P = x ^ y 129 | carry, _ = __SPK_circuit(S, P) 130 | return P ^ (carry << 1) 131 | 132 | 133 | def eq(x, y): 134 | """Returns x == y from BinarySharedTensors `x` and `y`""" 135 | bitwise_equal = ~(x ^ y) 136 | P = __P_circuit(bitwise_equal) 137 | return __get_sign_bit(P) 138 | 139 | 140 | def lt(x, y): 
141 | """Returns x < y from BinarySharedTensors `x` and `y`""" 142 | x, y = __flip_sign_bit(x), __flip_sign_bit(y) 143 | 144 | S = y & ~x 145 | P = ~(x ^ y) 146 | S, _ = __SPK_circuit(S, P) 147 | return __get_sign_bit(S) 148 | 149 | 150 | def le(x, y): 151 | """Returns x <= y from BinarySharedTensors `x` and `y`""" 152 | x, y = __flip_sign_bit(x), __flip_sign_bit(y) 153 | 154 | S = y & ~x 155 | P = ~(x ^ y) 156 | S, P = __SPK_circuit(S, P) 157 | return __get_sign_bit(S ^ P) 158 | 159 | 160 | def gt(x, y): 161 | """Returns x > y from BinarySharedTensors `x` and `y`""" 162 | x, y = __flip_sign_bit(x), __flip_sign_bit(y) 163 | 164 | S = x & ~y 165 | P = ~(x ^ y) 166 | S, _ = __SPK_circuit(S, P) 167 | return __get_sign_bit(S) 168 | 169 | 170 | def ge(x, y): 171 | """Returns x >= y from BinarySharedTensors `x` and `y`""" 172 | x, y = __flip_sign_bit(x), __flip_sign_bit(y) 173 | 174 | S = x & ~y 175 | P = ~(x ^ y) 176 | S, P = __SPK_circuit(S, P) 177 | return __get_sign_bit(S ^ P) 178 | -------------------------------------------------------------------------------- /crypten/mpc/primitives/ot/baseOT.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
import random
import secrets
from hashlib import sha256
from typing import List

import crypten.communicator as comm


# NOTE: the hard-coded group parameters below are intended for testing
# purposes and have not been vetted for production-grade security. Secret
# exponents are drawn from the `secrets` CSPRNG rather than `random`.


class BaseOT:
    """
    Base 1-out-of-2 oblivious transfer with the party at ``partner_rank``.

    One party calls :meth:`send` with two equally long message lists, the other
    calls :meth:`receive` with one choice flag per message pair; for every index
    the receiver obtains the message selected by its flag.

    Hardcoded public parameters:
        log2(__prime) > 128
        __generator is a primitive root of __prime
    """

    __prime = 631276824160446938136046282957027762913
    __generator = 3
    # g^-1 mod p via Fermat's little theorem: g^(p-2) == g^-1 for prime p
    __inverse__generator = pow(__generator, (__prime - 2), __prime)

    @staticmethod
    def string_xor(s1, s2):
        """
        XOR of two strings, character by character.

        The result is as long as the shorter input (``zip`` truncates).
        """
        return "".join(chr(ord(a) ^ ord(b)) for a, b in zip(s1, s2))

    def __init__(self, partner_rank):
        # rank of the party on the other side of the transfer
        self.partner_rank = partner_rank

    def send(self, message0s: List[str], message1s: List[str]):
        """
        Run the sender side of the transfer.

        Message flow: send every g^alpha, receive one mask per transfer from the
        receiver, then send each (message0, message1) pair encrypted with
        one-time pads derived from that mask.

        Args:
            message0s: messages delivered when the receiver's choice is False.
            message1s: messages delivered when the receiver's choice is True.
                Each message may be at most 64 characters long (the length of
                the SHA-256 hex digest used as its one-time pad).

        Raises:
            ValueError: if the two lists differ in length, or a message is
                longer than its pad.
        """
        if len(message0s) != len(message1s):
            raise ValueError("inconsistent input size!")

        alphas = []
        masks_for_message1s = []
        for _ in range(len(message1s)):
            # pick a random element from Z_p
            alpha = secrets.randbelow(self.__prime)
            alphas.append(alpha)

            # g^alpha
            masks_for_message1s.append(pow(self.__generator, alpha, self.__prime))

        # send mask_for_message1
        for mask_for_message1 in masks_for_message1s:
            comm.get().send_obj(mask_for_message1, self.partner_rank)

        # compute (g^alpha)^-alpha while waiting for the response:
        # (g^-1)^(alpha^2) = (g^-1)^(alpha^2 mod (p-1))
        dividers = [
            pow(
                self.__inverse__generator,
                alpha * alpha % (self.__prime - 1),
                self.__prime,
            )
            for alpha in alphas
        ]

        # recv mask_for_choice (one per transfer, in order)
        masks_for_choices = [
            comm.get().recv_obj(self.partner_rank) for _ in range(len(message1s))
        ]

        for i, (message0, message1) in enumerate(zip(message0s, message1s)):
            masks_for_choices[i] = pow(masks_for_choices[i], alphas[i], self.__prime)

            # hash: pad0 = H(B^alpha), pad1 = H(B^alpha * g^(-alpha^2))
            pad0 = sha256(str(masks_for_choices[i]).encode("utf-8")).hexdigest()
            pad1 = sha256(
                str(masks_for_choices[i] * dividers[i] % self.__prime).encode("utf-8")
            ).hexdigest()

            if len(pad0) < len(message0):
                raise ValueError(str(i) + "-th message0 is too long")
            if len(pad1) < len(message1):
                raise ValueError(str(i) + "-th message1 is too long")

            # encrypt with one time pad
            message0_enc = self.string_xor(pad0, message0)
            message1_enc = self.string_xor(pad1, message1)

            # send message0, message1
            comm.get().send_obj(message0_enc, self.partner_rank)
            comm.get().send_obj(message1_enc, self.partner_rank)

    def receive(self, choices: List[bool]):
        """
        Run the receiver side of the transfer.

        Args:
            choices: one flag per transfer;
                False: pick message0
                True: pick message1

        Returns:
            List[str]: the decrypted message selected by each choice.
        """
        betas = []
        masks_for_choices = []
        for _ in range(len(choices)):
            # pick a random element from Z_p
            beta = secrets.randbelow(self.__prime)
            betas.append(beta)
            masks_for_choices.append(pow(self.__generator, beta, self.__prime))

        masks_for_message1s = []
        for i, choice in enumerate(choices):
            # recv mask_for_message1
            mask_for_message1 = comm.get().recv_obj(self.partner_rank)
            masks_for_message1s.append(mask_for_message1)
            if choice:
                # shift our mask by g^alpha so that the sender's pad1 (rather
                # than pad0) equals H(g^(alpha * beta)), the key computed below
                masks_for_choices[i] = (
                    masks_for_choices[i] * mask_for_message1
                ) % self.__prime

        for mask_for_choice in masks_for_choices:
            # send mask_for_choice
            comm.get().send_obj(mask_for_choice, self.partner_rank)

        # compute the hash H((g^alpha)^beta) while waiting for the response
        keys = [
            sha256(str(pow(mask, beta, self.__prime)).encode("utf-8")).hexdigest()
            for mask, beta in zip(masks_for_message1s, betas)
        ]

        rst = []
        for key, choice in zip(keys, choices):
            # recv message0, message1 (both are always sent; only one decrypts)
            message0_enc = comm.get().recv_obj(self.partner_rank)
            message1_enc = comm.get().recv_obj(self.partner_rank)
            rst.append(self.string_xor(key, message1_enc if choice else message0_enc))
        return rst


# ------------------------------------------------------------------------------
# /crypten/common/util.py:
# ------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import abc
import functools

import numpy as np
import torch
from crypten.cuda import CUDALongTensor


def count_wraps(share_list):
    """Computes the number of overflows or underflows in a set of shares

    We compute this by counting the number of overflows and underflows as we
    traverse the list of shares.

    Args:
        share_list: sequence of integer tensors (presumably ``torch.long``
            shares -- confirm at call sites), summed left to right.

    Returns:
        A long tensor shaped like ``share_list[0]`` holding, per element,
        (+1 per overflow) + (-1 per underflow) seen during the running sum.
    """
    result = torch.zeros_like(share_list[0], dtype=torch.long)
    prev = share_list[0]
    for cur in share_list[1:]:
        # running sum; named `total` to avoid shadowing the builtin `next`
        total = cur + prev
        result -= ((prev < 0) & (cur < 0) & (total > 0)).long()  # underflow
        result += ((prev > 0) & (cur > 0) & (total < 0)).long()  # overflow
        prev = total
    return result


@functools.lru_cache(maxsize=10)
def chebyshev_series(func, width, terms):
    r"""Computes Chebyshev coefficients

    For n = terms, the ith Chebyshev series coefficient is

    .. math::
        c_i = 2/n \sum_{k=1}^n \cos(i(2k-1)\pi / 2n) f(w\cos((2k-1)\pi / 2n))

    Args:
        func (function): function to be approximated
        width (int): approximation will support inputs in range [-width, width]
        terms (int): number of Chebyshev terms used in approximation

    Returns:
        Chebyshev coefficients with shape equal to num of terms.

    Note:
        Results are memoized with ``lru_cache``, so ``func`` must be hashable
        and cache hits require passing the same function object.
    """
    n_range = torch.arange(start=0, end=terms).float()
    # Chebyshev nodes: w * cos((2k+1) * pi / (2n)) for k = 0..n-1
    x = width * torch.cos((n_range + 0.5) * np.pi / terms)
    y = func(x)
    # cos_term[i, k] = cos(i * (2k+1) * pi / (2n))
    cos_term = torch.cos(torch.ger(n_range, n_range + 0.5) * np.pi / terms)
    coeffs = (2 / terms) * torch.sum(y * cos_term, axis=1)
    return coeffs


@functools.lru_cache(maxsize=10)
def fourier_series(func, width, terms, step=1e-6):
    r"""Computes Fourier coefficients

    For L = width, A = beta_cos, B = beta_sin, the Fourier series coefficients
    are (for i = 1..terms)

    .. math::
        alpha = 1 / 2L * \int_{-L}^L f(x) dx,
        A_i = 1 / L * \int_{-L}^L f(x) \cos((i\pi x) / L) dx,
        B_i = 1 / L * \int_{-L}^L f(x) \sin((i\pi x) / L) dx

    The integrals are evaluated with the trapezoidal rule on the grid
    ``arange(-width, width, step)`` (the endpoint ``+width`` is excluded). The
    intermediate matrices have ``terms * (2 * width / step)`` entries, which
    can be large for small ``step``.

    Args:
        func (function): function to be approximated (assumed elementwise)
        width (int): approximation will support inputs in range [-width, width]
        terms (int): number of Fourier terms used in approximation
        step (float): interval for numerical integration

    Returns:
        Tuple ``(alpha, beta_cos, beta_sin)``: ``alpha`` is a scalar tensor;
        ``beta_cos`` and ``beta_sin`` hold one coefficient per term.
    """
    n_range = torch.arange(start=1, end=terms + 1).float()
    x = torch.arange(start=-width, end=width, step=step).float()
    # npx[i, j] = (i + 1) * pi * x_j / width
    npx = np.pi * torch.unsqueeze(n_range, -1) @ torch.unsqueeze(x, 0) / width
    y, cos_npx, sin_npx = func(x), torch.cos(npx), torch.sin(npx)
    alpha = torch.trapz(y, x) / (2 * width)
    beta_cos = torch.trapz(y * cos_npx, x) / width
    beta_sin = torch.trapz(y * sin_npx, x) / width
    return alpha, beta_cos, beta_sin


# FIXME: pytorch currently does not register `torch.cat` and
# `torch.stack` in __torch_function__.
# We therefore can not call torch.stack/torch.cat with CUDALongTensor as
# parameters. This is a temporary solution before pytorch fix their issue.
# See https://github.com/pytorch/pytorch/issues/34294 for details
def torch_cat(tensors, dim=0, out=None):
    """Concatenate ``tensors`` along ``dim``.

    If any input lives on a CUDA device the call is routed through
    ``CUDALongTensor.cat``; otherwise it falls through to ``torch.cat``.
    """
    is_cuda = any(t.is_cuda for t in tensors)
    if is_cuda:
        return CUDALongTensor.cat(tensors, dim=dim, out=out)
    return torch.cat(tensors, dim=dim, out=out)


def torch_stack(tensors, dim=0, out=None):
    """Stack ``tensors`` along a new ``dim``.

    If any input lives on a CUDA device the call is routed through
    ``CUDALongTensor.stack``; otherwise it falls through to ``torch.stack``.
    """
    is_cuda = any(t.is_cuda for t in tensors)
    if is_cuda:
        return CUDALongTensor.stack(tensors, dim=dim, out=out)
    return torch.stack(tensors, dim=dim, out=out)


# TODO: Remove this function and change the calling locations accordingly.
# See https://github.com/pytorch/pytorch/commit/445ee5620ec203cfccefd6f3dca4f0962a83b03e
def _grad_input_padding(
    grad_output, input_size, stride, padding, kernel_size, dilation=None
):
    """Return, per spatial dimension, how much ``input_size`` exceeds the
    minimum size implied by ``grad_output``, ``stride``, ``padding``,
    ``kernel_size`` and ``dilation``.

    ``input_size`` may include the two leading (batch, channel) entries; they
    are dropped. Presumably the result is used as ``output_padding`` for a
    transposed convolution in a conv backward pass -- confirm at call sites.

    Raises:
        ValueError: if ``input_size`` has the wrong number of entries or a
            requested size lies outside the valid range.
    """
    if dilation is None:
        # For backward compatibility
        dilation = [1] * len(stride)

    input_size = list(input_size)
    k = grad_output.dim() - 2  # number of spatial dimensions

    if len(input_size) == k + 2:
        input_size = input_size[-k:]
    if len(input_size) != k:
        raise ValueError(
            "input_size must have {} elements (got {})".format(k + 2, len(input_size))
        )

    def dim_size(d):
        # smallest input extent along spatial dim `d` that yields grad_output
        return (
            (grad_output.size(d + 2) - 1) * stride[d]
            - 2 * padding[d]
            + 1
            + dilation[d] * (kernel_size[d] - 1)
        )

    min_sizes = [dim_size(d) for d in range(k)]
    max_sizes = [min_sizes[d] + stride[d] - 1 for d in range(k)]
    for size, min_size, max_size in zip(input_size, min_sizes, max_sizes):
        if size < min_size or size > max_size:
            raise ValueError(
                (
                    "requested an input grad size of {}, but valid sizes range "
                    "from {} to {} (for a grad_output of {})"
                ).format(input_size, min_sizes, max_sizes, grad_output.size()[2:])
            )

    return tuple(input_size[d] - min_sizes[d] for d in range(k))


# ------------------------------------------------------------------------------
# /crypten/common/serial.py:
# ------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import builtins  # noqa
import collections
import inspect
import io
import logging
import pickle

import torch


def _safe_load_from_bytes(b):
    """Replacement for ``torch.storage._load_from_bytes`` that routes the
    payload through the restricted legacy loader."""
    return _safe_legacy_load(io.BytesIO(b))


# Legacy code from torch._utils_internal
def get_source_lines_and_file(obj, error_msg=None):
    """
    Wrapper around inspect.getsourcelines and inspect.getsourcefile.

    Returns: (sourcelines, file_lino, filename)

    Raises:
        OSError: if the source cannot be retrieved; ``error_msg`` (if given)
            is appended to the message.
    """
    filename = None  # in case getsourcefile throws
    try:
        filename = inspect.getsourcefile(obj)
        sourcelines, file_lineno = inspect.getsourcelines(obj)
    except OSError as e:
        msg = f"Can't get source for {obj}."
        if error_msg:
            msg += "\n" + error_msg
        raise OSError(msg) from e

    return sourcelines, file_lineno, filename


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals present in an allowlist.

    ``find_class`` raises ``ValueError`` for any ``module.name`` not found in
    ``__SAFE_CLASSES``; further classes can be added via
    :meth:`register_safe_class`.
    """

    __SAFE_CLASSES = {
        "torch.storage._load_from_bytes": _safe_load_from_bytes,
    }
    __ALLOWLIST = [
        "builtins.set",
        "collections.OrderedDict",
        "torch.nn.modules.activation.LogSigmoid",
        "torch.nn.modules.activation.LogSoftmax",
        "torch.nn.modules.activation.ReLU",
        "torch.nn.modules.activation.Sigmoid",
        "torch.nn.modules.activation.Softmax",
        "torch.nn.modules.batchnorm.BatchNorm1d",
        "torch.nn.modules.batchnorm.BatchNorm2d",
        "torch.nn.modules.batchnorm.BatchNorm3d",
        "torch.nn.modules.conv.Conv1d",
        "torch.nn.modules.conv.Conv2d",
        "torch.nn.modules.conv.ConvTranspose1d",
        "torch.nn.modules.conv.ConvTranspose2d",
        "torch.nn.modules.dropout.Dropout2d",
        "torch.nn.modules.dropout.Dropout3d",
        "torch.nn.modules.flatten.Flatten",
        "torch.nn.modules.linear.Linear",
        "torch.nn.modules.loss.BCELoss",
        "torch.nn.modules.loss.BCEWithLogitsLoss",
        "torch.nn.modules.loss.CrossEntropyLoss",
        "torch.nn.modules.loss.L1Loss",
        "torch.nn.modules.loss.MSELoss",
        "torch.nn.modules.pooling.AvgPool2d",
        "torch.nn.modules.pooling.MaxPool2d",
        "torch._utils._rebuild_parameter",
        "torch._utils._rebuild_tensor_v2",
        "torch.Size",
        "torch.BFloat16Storage",
        "torch.BoolStorage",
        "torch.CharStorage",
        "torch.ComplexDoubleStorage",
        "torch.ComplexFloatStorage",
        "torch.HalfStorage",
        "torch.IntStorage",
        "torch.LongStorage",
        "torch.QInt32Storage",
        "torch.QInt8Storage",
        "torch.QUInt8Storage",
        "torch.ShortStorage",
        "torch.storage._StorageBase",
        "torch.ByteStorage",
        "torch.DoubleStorage",
        "torch.FloatStorage",
        "torch._C.HalfStorageBase",
        "torch._C.QInt32StorageBase",
        "torch._C.QInt8StorageBase",
        "torch.storage._TypedStorage",
    ]

    # Resolve each allowlisted dotted path against this module's globals
    # (hence the `builtins` / `collections` / `torch` imports above); entries
    # missing from the installed torch version are skipped with a log message.
    for item in __ALLOWLIST:
        try:
            attrs = item.split(".")
            g = globals()[attrs[0]]
            for attr in attrs[1:]:
                g = getattr(g, attr)
            __SAFE_CLASSES[item] = g
        except (KeyError, AttributeError):
            logging.info(f"Could not find {item} to register as a SAFE_CLASS")

    @classmethod
    def register_safe_class(cls, input_class):
        """Add ``input_class`` to the set of classes allowed during unpickling.

        Raises:
            AssertionError: if ``input_class`` is not a class.
        """
        # Explicit raise (not `assert`) so the check survives `python -O`;
        # the exception type is unchanged for existing callers.
        if not isinstance(input_class, type):
            raise AssertionError(
                "Cannot register %s type as safe" % type(input_class)
            )
        # str(cls) looks like "<class 'pkg.mod.Name'>"; take the quoted part
        classname = str(input_class).split("'")[1]
        logging.info(f"Registering {classname} class as safe for deserialization.")
        cls.__SAFE_CLASSES[classname] = input_class

    def find_class(self, module, name):
        classname = f"{module}.{name}"
        if classname not in self.__SAFE_CLASSES.keys():
            raise ValueError(
                f"Deserialization is restricted for pickled module {classname}"
            )
        return self.__SAFE_CLASSES[classname]


def register_safe_class(input_class):
    """Module-level alias for :meth:`RestrictedUnpickler.register_safe_class`."""
    RestrictedUnpickler.register_safe_class(input_class)


def _assert_empty_ordered_dict(x):
    """Raise ``AssertionError`` unless ``x`` is an empty ``OrderedDict``.

    Uses explicit raises rather than ``assert`` because this is a security
    check on deserialized data and must not be disabled by ``python -O``.
    """
    if not isinstance(x, collections.OrderedDict):
        raise AssertionError
    if len(x) != 0:
        raise AssertionError


def _check_hooks_are_valid(result, hook_name):
    """Reject ``result`` if it, its parameters, or its submodules carry hooks."""
    if hasattr(result, hook_name):
        _assert_empty_ordered_dict(getattr(result, hook_name))
    if hasattr(result, "parameters"):
        for param in result.parameters():
            _assert_empty_ordered_dict(getattr(param, hook_name))
    if hasattr(result, "modules"):
        for module in result.modules():
            _assert_empty_ordered_dict(getattr(module, hook_name))


def restricted_loads(s):
    """Unpickle the bytes ``s`` using only allowlisted globals.

    Tensors and modules are additionally rejected if they carry any
    ``_backward_hooks``.
    """
    result = RestrictedUnpickler(io.BytesIO(s)).load()
    if torch.is_tensor(result) or isinstance(result, torch.nn.Module):
        _check_hooks_are_valid(result, "_backward_hooks")
    return result


class safe_pickle:
    """Minimal ``pickle``-module stand-in handed to torch's legacy loader."""

    Unpickler = RestrictedUnpickler

    @staticmethod
    def load(f):
        return RestrictedUnpickler(f).load()


def _safe_legacy_load(f):
    """Run torch's legacy deserializer with the restricted pickle module."""
    return torch.serialization._legacy_load(
        f, map_location=None, pickle_module=safe_pickle
    )


# ------------------------------------------------------------------------------
# /crypten/models/__init__.py:
# ------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib.util
import logging
import sys

import crypten.nn as cnn
import torch


# List of modules to import and additional classes to update from them
__import_list = {
    "alexnet": [],
    "densenet": ["_DenseLayer", "_DenseBlock", "_Transition"],
    "googlenet": ["Inception", "InceptionAux", "BasicConv2d"],
    "inception": [
        "BasicConv2d",
        "InceptionA",
        "InceptionB",
        "InceptionC",
        "InceptionD",
        "InceptionE",
        "InceptionAux",
    ],
    "mnasnet": ["_InvertedResidual"],
    "mobilenet": [],
    "resnet": ["BasicBlock", "Bottleneck"],
    "shufflenetv2": ["InvertedResidual"],
    "squeezenet": ["Fire"],
    "vgg": [],
}


__all__ = []


def __import_module_copy(module_name):
    """
    Returns a copy of an imported module so it can be modified
    without modifying future imports of the given module
    """
    starting_modules = sys.modules.copy()

    module_spec = importlib.util.find_spec(module_name)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    # Drop every module the import pulled into sys.modules so later imports
    # of the same names get fresh, unmodified copies.
    new_modules = set(sys.modules) - set(starting_modules)

    del module_spec
    for m in new_modules:
        del sys.modules[m]

    return module


def __import_model_package_copy(import_name):
    """
    Returns a copy of an imported model whose package contains
    a function of the same name.
    """
    starting_modules = sys.modules.copy()

    model_type = importlib.import_module(f"torchvision.models.{import_name}")
    new_modules = set(sys.modules) - set(starting_modules)
    for m in new_modules:
        del sys.modules[m]

    return model_type


def __update_model_class_inheritance(cls):
    """
    Updates the class inheritance of a torch.nn.Module to instead use
    crypten.nn.Module
    """
    bases = []
    for m in cls.__bases__:
        if m == torch.nn.Module:
            bases.append(cnn.Module)
        elif m == torch.nn.Sequential:
            bases.append(cnn.Sequential)
        elif m == torch.nn.ModuleDict:
            bases.append(cnn.ModuleDict)
        else:
            bases.append(m)

    cls.__bases__ = tuple(bases)


class FunctionalReplacement:
    """Replacement for `torch.nn.functional` that overwrites torch functionals to be crypten compatible"""

    @staticmethod
    def dropout(x, **kwargs):
        return x.dropout(**kwargs)

    @staticmethod
    def relu(x, **kwargs):
        return x.relu()

    @staticmethod
    def adaptive_avg_pool2d(x, *args):
        return cnn.AdaptiveAvgPool2d(*args)(x)

    @staticmethod
    def avg_pool2d(x, *args, **kwargs):
        return x.avg_pool2d(*args, **kwargs)

    @staticmethod
    def max_pool2d(x, *args, **kwargs):
        return x.max_pool2d(*args, **kwargs)


def __update_torch_functions(module):
    """Point ``module.nn`` / ``module.F`` and some ``module.torch`` functions
    at CrypTen-compatible replacements."""
    if hasattr(module, "nn"):
        module.nn = cnn

    # TODO: fix replacement in global `torch` module - perhaps use __torch_function__
    if hasattr(module, "torch"):
        module.torch.flatten = lambda x, *args: x.flatten(*args)
        module.torch.transpose = lambda x, *args: x.transpose(*args)
        # module.torch.cat = lambda *args, **kwargs: args[0].cat(*args, **kwargs)

    if hasattr(module, "F"):
        module.F = FunctionalReplacement()


def __get_module_list(model_name, model_type):
    """Names to patch for ``model_name``: extra classes plus the package's ``__all__``."""
    return __import_list[model_name] + model_type.__all__


try:
    models = __import_module_copy("torchvision").models

except ModuleNotFoundError:
    models = None
    logging.warning("Unable to load torchvision models.")


if models is not None:
    for import_name in __import_list:
        try:
            model_type = getattr(models, import_name)
        except AttributeError:
            logging.warning(f"Could not load {import_name} from torchvision.models")
            continue

        try:
            # If function imported rather than package, replace with package
            if not hasattr(model_type, "__all__"):
                model_type = __import_model_package_copy(import_name)

            __update_torch_functions(model_type)
            module_list = __get_module_list(import_name, model_type)
            for module_name in module_list:
                module = getattr(model_type, module_name)

                # Replace class inheritance from torch.nn.Module to crypten.nn.Module
                if isinstance(module, type):
                    __update_model_class_inheritance(module)

                    module.load_state_dict = (
                        lambda *args, **kwargs: cnn.Module.load_state_dict(
                            *args, strict=False, **kwargs
                        )
                    )

                if module_name in model_type.__all__:
                    globals()[module_name] = module
                    __all__.append(module_name)
        except (RuntimeError, AttributeError) as e:
            # Log that module produced an error
            logging.warning(e)


# A single message string: passing the fragments as separate arguments would
# turn the exception message into a tuple.
raise DeprecationWarning(
    "crypten.models is being deprecated. To import models from torchvision, "
    "please import them directly and use crypten.nn.from_pytorch() to convert"
    " to CrypTen models."
)


# ------------------------------------------------------------------------------
# /crypten/optim/sgd.py:
# ------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import crypten

from .optimizer import Optimizer


class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).
    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        grad_threshold (float, optional): imposes a threshold on the magnitude of gradient values.
            Gradient values with magnitude above the threshold will be replaced with 0.
    Example:
        >>> optimizer = crypten.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.
        Considering the specific case of Momentum, the update can be written as
        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}
        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.
        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form
        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}
        The Nesterov version is analogously modified.
    """

    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
        grad_threshold=None,
    ):
        if not isinstance(lr, (int, float)) or lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not isinstance(momentum, (int, float)) or momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if not isinstance(dampening, (int, float)):
            raise ValueError("Invalid dampening value {}".format(dampening))
        if not isinstance(weight_decay, (int, float)) or weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = {
            "lr": lr,
            "momentum": momentum,
            "dampening": dampening,
            "weight_decay": weight_decay,
            "nesterov": nesterov,
        }
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        # Compute thresholding based on square value since abs is more expensive
        self.square_threshold = (
            None if grad_threshold is None else grad_threshold * grad_threshold
        )

        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        Returns:
            The loss returned by ``closure``, or ``None`` if no closure is given.
        """
        with crypten.no_grad():
            loss = None
            if closure is not None:
                with crypten.enable_grad():
                    loss = closure()

            for group in self.param_groups:
                weight_decay = group["weight_decay"]
                momentum = group["momentum"]
                dampening = group["dampening"]
                nesterov = group["nesterov"]

                for p in group["params"]:
                    if p.grad is None:
                        continue

                    # Threshold gradients to prevent gradient explosion:
                    # entries whose square reaches the threshold are zeroed
                    if self.square_threshold is not None:
                        d_p = p.grad.mul(p.grad.square().lt(self.square_threshold))
                    else:
                        d_p = p.grad

                    if weight_decay != 0:
                        d_p = d_p.add(p.mul(weight_decay))
                    if momentum != 0:
                        param_state = self.state[id(p)]
                        if "momentum_buffer" not in param_state:
                            buf = param_state["momentum_buffer"] = d_p.clone().detach()
                        else:
                            buf = param_state["momentum_buffer"]
                            buf.mul_(momentum).add_(d_p.mul(1 - dampening))
                        if nesterov:
                            d_p = d_p.add(buf.mul(momentum))
                        else:
                            d_p = buf

                    p.sub_(d_p.mul(group["lr"]))

            return loss


# ------------------------------------------------------------------------------
# /crypten/communicator/in_process_communicator.py:
# ------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
7 | 8 | import logging 9 | import threading 10 | from operator import itemgetter 11 | from queue import Queue 12 | 13 | import torch 14 | from torch.distributed import ReduceOp 15 | 16 | from .communicator import Communicator 17 | 18 | 19 | class InProcessCommunicator(Communicator): 20 | 21 | BYTES_PER_ELEMENT = 8 22 | tls = threading.local() 23 | mailbox = None 24 | barrier = None 25 | lock = threading.Lock() 26 | 27 | @classmethod 28 | def initialize(cls, rank, world_size, init_ttp=False): 29 | cls.tls.instance = cls(rank, world_size) 30 | 31 | def __init__(self, rank, world_size, init_ttp=False): 32 | self.world_size = world_size 33 | self.rank = rank 34 | self.reset_communication_stats() 35 | self._name = f"rank{rank}" 36 | 37 | with InProcessCommunicator.lock: 38 | if InProcessCommunicator.mailbox is None: 39 | InProcessCommunicator.mailbox = [ 40 | Queue() for _ in range(self.world_size) 41 | ] 42 | 43 | # This prevents one thread from running ahead of the others and doing 44 | # multiple puts that would show up in the get calls below 45 | InProcessCommunicator.barrier = threading.Barrier(self.world_size) 46 | 47 | # logging: 48 | level = logging.getLogger().level 49 | logging.getLogger().setLevel(logging.INFO) 50 | logging.info("==================") 51 | logging.info("InProcessCommunicator with rank %d" % self.rank) 52 | logging.info("==================") 53 | 54 | logging.info("World size = %d" % self.get_world_size()) 55 | logging.getLogger().setLevel(level) 56 | 57 | @classmethod 58 | def get(cls): 59 | if not hasattr(cls.tls, "instance"): 60 | return None 61 | 62 | return cls.tls.instance 63 | 64 | @classmethod 65 | def is_initialized(cls): 66 | return hasattr(cls.tls, "instance") 67 | 68 | def send(self, tensor, dst): 69 | """Sends the specified tensor to the destination dst.""" 70 | self.mailbox[dst].put((self.rank, tensor.clone())) 71 | 72 | def recv(self, tensor, src=None): 73 | """Receives a tensor from an (optional) source src.""" 74 | rank, result 
= self.mailbox[self.rank].get() 75 | if src is not None and rank != src: 76 | raise NotImplementedError("Can't receive messages out of order yet") 77 | return result 78 | 79 | def isend(self, tensor, dst): 80 | """Sends the specified tensor to the destination dst.""" 81 | self.send(tensor, dst) 82 | 83 | class Result: 84 | def is_completed(self): 85 | return True 86 | 87 | def wait(self): 88 | pass 89 | 90 | return Result() 91 | 92 | def irecv(self, tensor, src=None): 93 | """Receives a tensor from an (optional) source src.""" 94 | 95 | class Result: 96 | def __init__(self, mailbox, rank): 97 | self.completed = False 98 | self.mailbox = mailbox 99 | self.rank = rank 100 | 101 | def is_completed(self): 102 | return self.completed 103 | 104 | def wait(self): 105 | rank, result = self.mailbox[self.rank].get() 106 | if src is not None and rank != src: 107 | raise NotImplementedError("Can't receive messages out of order yet") 108 | tensor.copy_(result) 109 | 110 | return Result(self.mailbox, self.rank) 111 | 112 | def scatter(self, scatter_list, src, size=None, async_op=False): 113 | """Scatters a list of tensors to all parties.""" 114 | if async_op: 115 | raise NotImplementedError() 116 | 117 | if src == self.rank: 118 | for i in range(self.world_size): 119 | self.mailbox[i].put(scatter_list[i].clone()) 120 | 121 | self.barrier.wait() 122 | 123 | return self.mailbox[self.rank].get() 124 | 125 | def reduce(self, tensor, dst, op=ReduceOp.SUM, async_op=False): 126 | """Reduces the tensor data across all parties.""" 127 | tensors = self.gather(tensor, dst) 128 | if self.rank == dst: 129 | reduce_fn = self._reduce_op_to_function(op) 130 | return reduce_fn(torch.stack(tensors), dim=0) 131 | 132 | @classmethod 133 | def shutdown(cls): 134 | # Destroy all thread-local instances 135 | cls.tls = threading.local() 136 | cls.mailbox = None 137 | cls.barrier = None 138 | 139 | def _reduce_op_to_function(self, op): 140 | if op == ReduceOp.SUM: 141 | return torch.sum 142 | 143 | 
raise NotImplementedError() 144 | 145 | def all_reduce(self, tensor, op=ReduceOp.SUM, async_op=False): 146 | """Reduces the tensor data across all parties; all get the final result.""" 147 | if async_op: 148 | raise NotImplementedError() 149 | 150 | ag = self.all_gather(tensor) 151 | reduce_fn = self._reduce_op_to_function(op) 152 | return reduce_fn(torch.stack(ag), dim=0) 153 | 154 | def gather(self, tensor, dst, async_op=False): 155 | """Gathers a list of tensors in a single party.""" 156 | if async_op: 157 | raise NotImplementedError() 158 | 159 | self.mailbox[dst].put((self.rank, tensor.clone())) 160 | 161 | self.barrier.wait() 162 | 163 | if self.rank == dst: 164 | result = [self.mailbox[dst].get() for _ in range(self.world_size)] 165 | return [tensor for rank, tensor in sorted(result, key=itemgetter(0))] 166 | 167 | def all_gather(self, tensor, async_op=False): 168 | """Gathers tensors from all parties in a list.""" 169 | if async_op: 170 | raise NotImplementedError() 171 | 172 | for i in range(self.world_size): 173 | self.mailbox[i].put((self.rank, tensor.clone())) 174 | 175 | self.barrier.wait() 176 | 177 | result = sorted( 178 | (self.mailbox[self.rank].get() for _ in range(self.world_size)), 179 | key=itemgetter(0), 180 | ) 181 | 182 | return [tensor for (rank, tensor) in result] 183 | 184 | def broadcast(self, tensor, src, async_op=False): 185 | """Broadcasts the tensor to all parties.""" 186 | if async_op: 187 | raise NotImplementedError() 188 | 189 | if self.rank == src: 190 | for i in range(self.get_world_size()): 191 | self.mailbox[i].put(tensor.clone()) 192 | 193 | # No need for a barrier here. 
194 | 195 | return self.mailbox[self.rank].get() 196 | 197 | def get_world_size(self): 198 | """Returns the size of the world.""" 199 | return self.world_size 200 | 201 | def get_rank(self): 202 | """Returns the rank of the current process.""" 203 | return self.rank 204 | 205 | def set_name(self, name): 206 | """Sets the party name of the current rank.""" 207 | assert isinstance( 208 | name, str 209 | ), f"Improper name provided to process on rank {self.get_rank()}" 210 | self._name = name 211 | 212 | def get_name(self): 213 | """Returns the party name of the current rank.""" 214 | return self._name 215 | -------------------------------------------------------------------------------- /crypten/nn/loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import crypten 9 | import torch 10 | 11 | from .module import Module 12 | 13 | 14 | class _Loss(Module): 15 | """ 16 | Base criterion class that mimics Pytorch's Loss. 
17 | """ 18 | 19 | def __init__(self, reduction="mean", skip_forward=False): 20 | super(_Loss, self).__init__() 21 | if reduction != "mean": 22 | raise NotImplementedError("reduction %s not supported") 23 | self.reduction = reduction 24 | self.skip_forward = skip_forward 25 | 26 | def forward(self, *args, **kwargs): 27 | raise NotImplementedError("forward not implemented") 28 | 29 | def __call__(self, *args, **kwargs): 30 | return self.forward(*args, **kwargs) 31 | 32 | def __getattribute__(self, name): 33 | if name != "forward": 34 | return object.__getattribute__(self, name) 35 | 36 | def forward_function(*args, **kwargs): 37 | """Silently encrypt Torch tensors if needed.""" 38 | if self.encrypted or any( 39 | isinstance(arg, crypten.CrypTensor) for arg in args 40 | ): 41 | args = list(args) 42 | for idx, arg in enumerate(args): 43 | if torch.is_tensor(arg): 44 | args[idx] = crypten.cryptensor(arg) 45 | return object.__getattribute__(self, name)(*tuple(args), **kwargs) 46 | 47 | return forward_function 48 | 49 | 50 | class MSELoss(_Loss): 51 | r""" 52 | Creates a criterion that measures the mean squared error (squared L2 norm) between 53 | each element in the prediction :math:`x` and target :math:`y`. 54 | 55 | The loss can be described as: 56 | 57 | .. math:: 58 | \ell(x, y) = mean(L) = mean(\{l_1,\dots,l_N\}^\top), \quad 59 | l_n = (x_n - y_n)^2, 60 | 61 | where :math:`N` is the batch size, :math:`x` and :math:`y` are tensors of 62 | arbitrary shapes with a total of :math:`n` elements each. 63 | """ # noqa: W605 64 | 65 | def forward(self, x, y): 66 | assert x.size() == y.size(), "input and target must have the same size" 67 | return (x - y).square().mean() 68 | 69 | 70 | class L1Loss(_Loss): 71 | r""" 72 | Creates a criterion that measures the mean absolute error between each element in 73 | the prediction :math:`x` and target :math:`y`. 74 | 75 | The loss can be described as: 76 | 77 | .. 
math:: 78 | \ell(x, y) = mean(L) = mean(\{l_1,\dots,l_N\}^\top), \quad 79 | l_n = \left | x_n - y_n \right |, 80 | 81 | where :math:`N` is the batch size, :math:`x` and :math:`y` are tensors of 82 | arbitrary shapes with a total of :math:`n` elements each. 83 | """ # noqa: W605 84 | 85 | def forward(self, x, y): 86 | assert x.size() == y.size(), "input and target must have the same size" 87 | return (x - y).abs().mean() 88 | 89 | 90 | class BCELoss(_Loss): 91 | r""" 92 | Creates a criterion that measures the Binary Cross Entropy 93 | between the prediction :math:`x` and the target :math:`y`. 94 | 95 | The loss can be described as: 96 | 97 | .. math:: 98 | \ell(x, y) = mean(L) = mean(\{l_1,\dots,l_N\}^\top), \quad 99 | l_n = - \left [ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right ], 100 | 101 | where :math:`N` is the batch size, :math:`x` and :math:`y` are tensors of 102 | arbitrary shapes with a total of :math:`n` elements each. 103 | 104 | This is used for measuring the error of a reconstruction in for example 105 | an auto-encoder. Note that the targets :math:`y` should be numbers 106 | between 0 and 1. 107 | """ # noqa: W605 108 | 109 | def forward(self, x, y): 110 | assert x.size() == y.size(), "input and target must have the same size" 111 | return x.binary_cross_entropy(y, skip_forward=self.skip_forward) 112 | 113 | 114 | class CrossEntropyLoss(_Loss): 115 | r""" 116 | Creates a criterion that measures cross-entropy loss between the 117 | prediction :math:`x` and the target :math:`y`. It is useful when 118 | training a classification problem with `C` classes. 119 | 120 | The prediction `x` is expected to contain raw, unnormalized scores for each class. 121 | 122 | The prediction `x` has to be a Tensor of size either :math:`(N, C)` or 123 | :math:`(N, C, d_1, d_2, ..., d_K)`, where :math:`N` is the size of the minibatch, 124 | and with :math:`K \geq 1` for the `K`-dimensional case (described later). 
125 | 126 | This criterion expects a class index in the range :math:`[0, C-1]` as the 127 | target `y` for each value of a 1D tensor of size `N`. 128 | 129 | The loss can be described as: 130 | 131 | .. math:: 132 | \text{loss}(x, class) = -\log \left( 133 | \frac{\exp(x[class])}{\sum_j \exp(x[j])} \right ) 134 | = -x[class] + \log \left (\sum_j \exp(x[j]) \right) 135 | 136 | The losses are averaged across observations for each batch 137 | 138 | Can also be used for higher dimension inputs, such as 2D images, by providing 139 | an input of size :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`, 140 | where :math:`K` is the number of dimensions, and a target of appropriate shape. 141 | """ # noqa: W605 142 | 143 | def forward(self, x, y): 144 | x = x.squeeze() 145 | y = y.squeeze() 146 | assert x.size() == y.size(), "input and target must have the same size" 147 | return x.cross_entropy(y, skip_forward=self.skip_forward) 148 | 149 | 150 | class BCEWithLogitsLoss(_Loss): 151 | r""" 152 | This loss combines a Sigmoid layer and the BCELoss in one single class. 153 | 154 | The loss can be described as: 155 | 156 | .. math:: 157 | p = \sigma(x) 158 | 159 | .. math:: 160 | \ell(x, y) = mean(L) = mean(\{l_1,\dots,l_N\}^\top), \quad 161 | l_n = - \left [ y_n \cdot \log p_n + (1 - y_n) \cdot \log (1 - p_n) \right ], 162 | 163 | This is used for measuring the error of a reconstruction in for example an 164 | auto-encoder. Note that the targets t[i] should be numbers between 0 and 1. 165 | """ # noqa: W605 166 | 167 | def forward(self, x, y): 168 | assert x.size() == y.size(), "input and target must have the same size" 169 | return x.binary_cross_entropy_with_logits(y, skip_forward=self.skip_forward) 170 | 171 | 172 | class RAPPORLoss(_Loss): 173 | r""" 174 | This loss computes the BCEWithLogitsLoss with corrections applied to account 175 | for randomized response, where the input `alpha` represents the probability 176 | of flipping a label. 
177 | 178 | The loss can be described as: 179 | 180 | .. math:: 181 | p = \sigma(x) 182 | 183 | .. math:: 184 | r = \alpha * p + (1 - \alpha) * (1 - p) 185 | 186 | .. math:: 187 | \ell(x, y) = mean(L) = mean(\{l_1,\dots,l_N\}^\top), \quad 188 | l_n = - \left [ y_n \cdot \log r_n + (1 - y_n) \cdot \log (1 - r_n) \right ], 189 | 190 | This is used for measuring the error of a reconstruction in for example an 191 | auto-encoder. Note that the targets t[i] should be numbers between 0 and 1. 192 | """ 193 | 194 | def __init__(self, alpha, reduction="mean", skip_forward=False): 195 | super(RAPPORLoss, self).__init__(reduction=reduction, skip_forward=skip_forward) 196 | self.alpha = alpha 197 | 198 | def forward(self, x, y): 199 | assert x.size() == y.size(), "input and target must have the same size" 200 | return x.rappor_loss(y, self.alpha, skip_forward=self.skip_forward) 201 | -------------------------------------------------------------------------------- /crypten/mpc/provider/provider.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
import os

import crypten
import crypten.communicator as comm
import torch


class TupleProvider:
    """Base provider of correlated randomness ("tuples") for MPC protocols,
    with optional request tracing and tuple caching.

    Calls to any method named in ``TRACEABLE_FUNCTIONS`` are intercepted by
    ``__getattribute__``:

    * while ``tracing`` is True, each request is recorded in ``request_cache``
      and then served by a direct call;
    * otherwise, if ``tuple_cache`` is non-empty, a matching precomputed result
      is popped from it, falling back to a direct call when no cached result
      remains for that request.
    """

    TRACEABLE_FUNCTIONS = [
        "generate_additive_triple",
        "square",
        "generate_binary_triple",
        "generate_trig_triple",
        "generate_one_hot_pair",
        "wrap_rng",
        "B2A_rng",
    ]

    _DEFAULT_CACHE_PATH = os.path.normpath(os.path.join(__file__, "../tuple_cache/"))

    def __init__(self):
        self.tracing = False
        # Recorded requests: list of (func_name, args, kwargs) tuples.
        self.request_cache = []
        # Maps (func_name, args, frozenset(kwargs.items())) -> list of results.
        self.tuple_cache = {}

    @property
    def rank(self):
        """Rank of the current party, as reported by the communicator."""
        return comm.get().get_rank()

    @staticmethod
    def _hashable_request(func_name, args, kwargs):
        """Return the ``tuple_cache`` key for a request (kwargs frozen so the
        key is independent of keyword order)."""
        return (func_name, args, frozenset(kwargs.items()))

    def _get_request_path(self, prefix=None):
        if prefix is None:
            prefix = self._DEFAULT_CACHE_PATH
        return prefix + f"/request_cache-{self.rank}"

    def _get_tuple_path(self, prefix=None):
        if prefix is None:
            prefix = self._DEFAULT_CACHE_PATH
        return prefix + f"/tuple_cache-{self.rank}"

    def trace(self, tracing=True):
        """Sets tracing attribute.

        When tracing is True, provider caches all tuple requests.
        When tracing is False, provider attempts to load tuples from cache.
        """
        self.tracing = tracing

    def trace_once(self):
        """Enables tracing only if no requests have been recorded yet.

        If ``request_cache`` already holds requests (e.g. a previous
        ``trace_once()`` traced a run), tracing is set to False instead.
        """
        untraced = len(self.request_cache) == 0
        self.trace(tracing=untraced)

    def _save_requests(self, filepath=None):
        # TODO: Deal with any overwrite issues
        if not self.request_cache:
            crypten.log("Request cache not saved - cache is empty")
            return
        filepath = self._get_request_path(prefix=filepath)
        torch.save(self.request_cache, filepath)
        self.request_cache = []

    def _load_requests(self, filepath=None):
        filepath = self._get_request_path(prefix=filepath)
        if os.path.exists(filepath):
            # NOTE: torch.load unpickles the file; only load caches written by
            # a trusted party. The file is consumed (deleted) after loading.
            self.request_cache = torch.load(filepath)
            os.remove(filepath)
        else:
            crypten.log(f"Cache requests not loaded - File `{filepath}` not found")

    def _save_tuples(self, filepath=None):
        # TODO: Deal with any overwrite issues
        if not self.tuple_cache:
            crypten.log("Tuple cache not saved - cache is empty")
            return
        filepath = self._get_tuple_path(prefix=filepath)
        torch.save(self.tuple_cache, filepath)
        self.tuple_cache = {}

    def _load_tuples(self, filepath=None):
        filepath = self._get_tuple_path(prefix=filepath)
        if os.path.exists(filepath):
            # NOTE: torch.load unpickles the file; only load caches written by
            # a trusted party. The file is consumed (deleted) after loading.
            self.tuple_cache = torch.load(filepath)
            os.remove(filepath)
        else:
            crypten.log(f"Tuple cache not loaded - File `{filepath}` not found")

    def save_cache(self, filepath=None):
        """Saves request and tuple cache to a file.

        args:
            filepath - base filepath for cache folder (default: "provider/tuple_cache/")
        """
        self._save_requests(filepath=filepath)
        self._save_tuples(filepath=filepath)

    def load_cache(self, filepath=None):
        """Loads request and tuple cache from a file.

        args:
            filepath - base filepath for cache folder (default: "provider/tuple_cache/")
        """
        self._load_requests(filepath=filepath)
        self._load_tuples(filepath=filepath)

    def __getattribute__(self, func_name):
        """Deals with caching logic for the names in ``TRACEABLE_FUNCTIONS``."""
        if func_name not in TupleProvider.TRACEABLE_FUNCTIONS:
            return object.__getattribute__(self, func_name)

        # Trace requests while tracing
        if self.tracing:

            def func_with_trace(*args, **kwargs):
                self.request_cache.append((func_name, args, kwargs))
                return object.__getattribute__(self, func_name)(*args, **kwargs)

            return func_with_trace

        # If the cache is empty, call function directly
        if not self.tuple_cache:
            return object.__getattribute__(self, func_name)

        # Return results from cache if available
        def func_from_cache(*args, **kwargs):
            request = self._hashable_request(func_name, args, kwargs)
            cached = self.tuple_cache.get(request)
            # A request whose cached results are all consumed is a miss;
            # popping from the empty list would raise IndexError.
            if cached:
                return cached.pop()
            # Cache miss
            return object.__getattribute__(self, func_name)(*args, **kwargs)

        return func_from_cache

    def fill_cache(self):
        """Fills tuple_cache with tuples requested in the request_cache.

        Requests whose generation or key construction raises are skipped and
        logged rather than aborting the whole fill.
        """
        # TODO: parallelize / async this
        for func_name, args, kwargs in self.request_cache:
            try:
                result = object.__getattribute__(self, func_name)(*args, **kwargs)
                request = self._hashable_request(func_name, args, kwargs)
                self.tuple_cache.setdefault(request, []).append(result)
            except Exception as err:
                crypten.log(f"Skipped caching request `{func_name}`: {err}")

    def generate_additive_triple(self, size0, size1, op, device=None, *args, **kwargs):
        """Generate Beaver (multiplicative) triples of given sizes"""
        raise NotImplementedError(
            "TupleProvider generate_additive_triple not implemented."
        )

    def square(self, size, device=None):
        """Generate square double of given size"""
        raise NotImplementedError("TupleProvider square not implemented.")

    def generate_trig_triple(self, size, period, terms, device=None):
        """Generate trigonometric triple of given size"""
        # TODO: Implement generate_trig_triple() for TTP provider
        raise NotImplementedError("TupleProvider generate_trig_triple not implemented")

    def generate_one_hot_pair(self, size, length, device=None):
        """Generate a one-hot pair of given size and length"""
        # TODO: Implement generate_one_hot_pair() for TTP provider
        raise NotImplementedError("TupleProvider generate_one_hot_pair not implemented")

    def generate_binary_triple(self, size0, size1, device=None):
        """Generate xor triples of given size"""
        raise NotImplementedError(
            "TupleProvider generate_binary_triple not implemented."
        )

    def wrap_rng(self, size, device=None):
        """Generate random shared tensor of given size and sharing of its wraps"""
        raise NotImplementedError("TupleProvider wrap_rng not implemented.")

    def B2A_rng(self, size, device=None):
        """Generate random bit tensor as arithmetic and binary shared tensors"""
        raise NotImplementedError("TupleProvider B2A_rng not implemented.")


# ---- /crypten/mpc/primitives/beaver.py ----
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import crypten
import crypten.communicator as comm
import torch
from crypten.common.util import count_wraps
from crypten.config import cfg


class IgnoreEncodings:
    """Context Manager to ignore tensor encodings.

    On entry, sets the encoder scale of every tensor in ``list_of_tensors`` to
    1; on exit, restores the scales captured when the manager was constructed.
    """

    def __init__(self, list_of_tensors):
        self.list_of_tensors = list_of_tensors
        # Scales are captured at construction time, not in __enter__.
        self.encodings_cache = [tensor.encoder.scale for tensor in list_of_tensors]

    def __enter__(self):
        for tensor in self.list_of_tensors:
            tensor.encoder._scale = 1

    def __exit__(self, exc_type, exc_value, exc_traceback):
        for tensor, scale in zip(self.list_of_tensors, self.encodings_cache):
            tensor.encoder._scale = scale


def __beaver_protocol(op, x, y, *args, **kwargs):
    """Performs Beaver protocol for additively secret-shared tensors x and y

    1. Obtain uniformly random sharings [a],[b] and [c] = [a * b]
    2. Additively hide [x] and [y] with appropriately sized [a] and [b]
    3. Open ([epsilon] = [x] - [a]) and ([delta] = [y] - [b])
    4. Return [z] = [c] + (epsilon * [b]) + ([a] * delta) + (epsilon * delta)

    Args:
        op: name of the ``torch`` function to evaluate (one of the names in
            the assertion below).
        x, y: secret-shared operands; must live on the same device.
        *args, **kwargs: forwarded both to the triple generator and to the
            ``torch`` function (e.g. convolution arguments).

    Raises:
        ValueError: if ``x`` and ``y`` are on different devices, or if the
            active-security triple check fails.
    """
    assert op in {
        "mul",
        "matmul",
        "conv1d",
        "conv2d",
        "conv_transpose1d",
        "conv_transpose2d",
    }
    if x.device != y.device:
        raise ValueError(f"x lives on device {x.device} but y on device {y.device}")

    provider = crypten.mpc.get_default_provider()
    a, b, c = provider.generate_additive_triple(
        x.size(), y.size(), op, *args, device=x.device, **kwargs
    )

    from .arithmetic import ArithmeticSharedTensor

    if cfg.mpc.active_security:
        # Verify (a, b, c) against a second triple (f, g, h) using a shared
        # random multiplier t.
        # Reference: "Multiparty Computation from Somewhat Homomorphic Encryption"
        # Link: https://eprint.iacr.org/2011/535.pdf
        # NOTE(review): the check below uses elementwise `*`, which matches the
        # protocol only for op == "mul"; for matmul/conv the products should
        # use `op` (with f, rho on the left). TODO confirm before relying on it.
        f, g, h = provider.generate_additive_triple(
            x.size(), y.size(), op, *args, device=x.device, **kwargs
        )

        t = ArithmeticSharedTensor.PRSS(a.size(), device=x.device)
        t_plain_text = t.get_plain_text()

        rho = (t_plain_text * a - f).get_plain_text()
        sigma = (b - g).get_plain_text()
        triples_check = t_plain_text * c - h - sigma * f - rho * g - rho * sigma
        triples_check = triples_check.get_plain_text()

        if torch.any(triples_check != 0):
            raise ValueError("Beaver Triples verification failed!")

    # Vectorized reveal to reduce rounds of communication
    with IgnoreEncodings([a, b, x, y]):
        epsilon, delta = ArithmeticSharedTensor.reveal_batch([x - a, y - b])

    # z = c + (a * delta) + (epsilon * b) + epsilon * delta
    c._tensor += getattr(torch, op)(epsilon, b._tensor, *args, **kwargs)
    c._tensor += getattr(torch, op)(a._tensor, delta, *args, **kwargs)
    c += getattr(torch, op)(epsilon, delta, *args, **kwargs)

    return c


def mul(x, y):
    """Elementwise product of secret-shared tensors via the Beaver protocol."""
    return __beaver_protocol("mul", x, y)


def matmul(x, y):
    """Matrix product of secret-shared tensors via the Beaver protocol."""
    return __beaver_protocol("matmul", x, y)


def conv1d(x, y, **kwargs):
    """1D convolution of secret-shared tensors via the Beaver protocol."""
    return __beaver_protocol("conv1d", x, y, **kwargs)


def conv2d(x, y, **kwargs):
    """2D convolution of secret-shared tensors via the Beaver protocol."""
    return __beaver_protocol("conv2d", x, y, **kwargs)


def conv_transpose1d(x, y, **kwargs):
    """1D transposed convolution via the Beaver protocol."""
    return __beaver_protocol("conv_transpose1d", x, y, **kwargs)


def conv_transpose2d(x, y, **kwargs):
    """2D transposed convolution via the Beaver protocol."""
    return __beaver_protocol("conv_transpose2d", x, y, **kwargs)


def square(x):
    """Computes the square of `x` for additively secret-shared tensor `x`

    1. Obtain uniformly random sharings [r] and [r2] = [r * r]
    2. Additively hide [x] with appropriately sized [r]
    3. Open ([epsilon] = [x] - [r])
    4. Return z = [r2] + 2 * epsilon * [r] + epsilon ** 2
    """
    provider = crypten.mpc.get_default_provider()
    r, r2 = provider.square(x.size(), device=x.device)

    with IgnoreEncodings([x, r]):
        epsilon = (x - r).reveal()
    return r2 + 2 * r * epsilon + epsilon * epsilon


def wraps(x):
    """Privately computes the number of wraparounds for a set of shares

    To do so, we note that:
        [theta_x] = theta_z + [beta_xr] - [theta_r] - [eta_xr]

    Where [theta_i] is the wraps for a variable i
          [beta_ij] is the differential wraps for variables i and j
          [eta_ij]  is the plaintext wraps for variables i and j

    Note: Since [eta_xr] = 0 with probability 1 - |x| / Q for modulus Q, we
    can make the assumption that [eta_xr] = 0 with high probability.
    """
    provider = crypten.mpc.get_default_provider()
    r, theta_r = provider.wrap_rng(x.size(), device=x.device)
    beta_xr = theta_r.clone()
    beta_xr._tensor = count_wraps([x._tensor, r._tensor])

    with IgnoreEncodings([x, r]):
        z = x + r
    # Every party sends its share of z to party 0.
    theta_z = comm.get().gather(z._tensor, 0)
    theta_x = beta_xr - theta_r

    # TODO: Incorporate eta_xr
    if x.rank == 0:
        # Only party 0 holds the gathered shares and adds the public wraps.
        theta_z = count_wraps(theta_z)
        theta_x._tensor += theta_z
    return theta_x


def truncate(x, y):
    """Protocol to divide an ArithmeticSharedTensor `x` by a constant integer `y`.

    Divides each share in place, then subtracts a wrap-count correction.
    """
    wrap_count = wraps(x)
    x.share = x.share.div_(y, rounding_mode="trunc")
    # NOTE: The multiplication here must be split into two parts
    # to avoid long out-of-bounds when y <= 2 since (2 ** 63) is
    # larger than the largest long integer.
    correction = wrap_count * 4 * (int(2**62) // y)
    x.share -= correction.share
    return x


def AND(x, y):
    """
    Performs Beaver protocol for binary secret-shared tensors x and y

    1. Obtain uniformly random sharings [a],[b] and [c] = [a & b]
    2. XOR hide [x] and [y] with appropriately sized [a] and [b]
    3. Open ([epsilon] = [x] ^ [a]) and ([delta] = [y] ^ [b])
    4. Return [c] ^ (epsilon & [b]) ^ ([a] & delta) ^ (epsilon & delta)
    """
    from .binary import BinarySharedTensor

    provider = crypten.mpc.get_default_provider()
    a, b, c = provider.generate_binary_triple(x.size(), y.size(), device=x.device)

    # Reveal both masked operands in one batched call
    epsilon, delta = BinarySharedTensor.reveal_batch([x ^ a, y ^ b])

    return (b & epsilon) ^ (a & delta) ^ (epsilon & delta) ^ c


def B2A_single_bit(xB):
    """Converts a single-bit BinarySharedTensor xB into an
        ArithmeticSharedTensor. This is done by:

    1. Generate ArithmeticSharedTensor [rA] and BinarySharedTensor [rB] with
        a common 1-bit value r.
    2. Hide xB with rB and open xB ^ rB
    3. If xB ^ rB = 0, then return [rA], otherwise return 1 - [rA]
        Note: This is an arithmetic xor of a single bit.
    """
    if comm.get().get_world_size() < 2:
        from .arithmetic import ArithmeticSharedTensor

        return ArithmeticSharedTensor(xB._tensor, precision=0, src=0)

    provider = crypten.mpc.get_default_provider()
    rA, rB = provider.B2A_rng(xB.size(), device=xB.device)

    z = (xB ^ rB).reveal()
    # rA * (1 - 2z) + z equals rA when z == 0 and 1 - rA when z == 1.
    rA = rA * (1 - 2 * z) + z
    return rA


# ---- /crypten/communicator/communicator.py ----
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import sys
import timeit

from crypten.config import cfg


class Communicator:
    """
    Abstract class defining the functions that a Communicator should implement.

    Subclasses must also define ``BYTES_PER_ELEMENT``, which the logging
    helpers use to convert element counts into bytes.
    """

    @classmethod
    def is_initialized(cls):
        """Returns whether the communicator has been initialized"""
        raise NotImplementedError("is_initialized is not implemented")

    @classmethod
    def get(cls):
        """Returns an instance of the communicator"""
        raise NotImplementedError("get is not implemented")

    @classmethod
    def initialize(cls, **kwargs):
        """Initializes the communicator. Call this function before using it."""
        raise NotImplementedError("initialize is not implemented")

    @classmethod
    def shutdown(cls):
        """Shuts down the communicator."""
        raise NotImplementedError("shutdown is not implemented")

    def send(self, tensor, dst):
        """Sends the specified tensor to the destination dst."""
        raise NotImplementedError("send is not implemented")

    def recv(self, tensor, src=None):
        """Receives a tensor from an (optional) source src."""
        raise NotImplementedError("recv is not implemented")

    def scatter(self, scatter_list, src, size=None, async_op=False):
        """Scatters a list of tensors to all parties."""
        raise NotImplementedError("scatter is not implemented")

    def reduce(self, tensor, op=None, async_op=False):
        """Reduces the tensor data across all parties."""
        raise NotImplementedError("reduce is not implemented")

    def all_reduce(self, tensor, op=None, async_op=False):
        """Reduces the tensor data across all parties; all get the final result."""
        raise NotImplementedError("all_reduce is not implemented")

    def gather(self, tensor, dst, async_op=False):
        """Gathers a list of tensors in a single party."""
        raise NotImplementedError("gather is not implemented")

    def all_gather(self, tensor, async_op=False):
        """Gathers tensors from all parties in a list."""
        raise NotImplementedError("all_gather is not implemented")

    def broadcast(self, tensor, src, async_op=False):
        """Broadcasts the tensor to all parties."""
        raise NotImplementedError("broadcast is not implemented")

    def barrier(self):
        """Synchronizes all processes.

        This collective blocks processes until the whole group enters this
        function.
        """
        raise NotImplementedError("barrier is not implemented")

    def send_obj(self, obj, dst):
        """Sends the specified object to the destination `dst`."""
        raise NotImplementedError("send_obj is not implemented")

    def recv_obj(self, src):
        """Receives an object from a source src."""
        raise NotImplementedError("recv_obj is not implemented")

    def broadcast_obj(self, obj, src):
        """Broadcasts a given object to all parties."""
        raise NotImplementedError("broadcast_obj is not implemented")

    def get_world_size(self):
        """Returns the size of the world."""
        raise NotImplementedError("get_world_size is not implemented")

    def get_rank(self):
        """Returns the rank of the current process."""
        raise NotImplementedError("get_rank is not implemented")

    def set_name(self, name=None):
        """Sets the party name of the current process to `name`."""
        raise NotImplementedError("set_name is not implemented")

    def get_name(self):
        """Returns the party name of the current process."""
        raise NotImplementedError("get_name is not implemented")

    def reset_communication_stats(self):
        """Resets communication statistics."""
        self.comm_rounds = 0
        self.comm_bytes = 0
        self.comm_time = 0

    def print_communication_stats(self):
        """
        Prints communication statistics.

        NOTE: Each party performs its own logging of communication, so one needs
        to sum the number of bytes communicated over all parties and divide by
        two (to prevent double-counting) to obtain the number of bytes
        communicated in the overall system.
        """
        import crypten

        crypten.log("====Communication Stats====")
        crypten.log(f"Rounds: {self.comm_rounds}")
        crypten.log(f"Bytes: {self.comm_bytes}")
        crypten.log(f"Communication time: {self.comm_time}")

    def get_communication_stats(self):
        """
        Returns communication statistics in a Python dict.

        NOTE: Each party performs its own logging of communication, so one needs
        to sum the number of bytes communicated over all parties and divide by
        two (to prevent double-counting) to obtain the number of bytes
        communicated in the overall system.
        """
        return {
            "rounds": self.comm_rounds,
            "bytes": self.comm_bytes,
            "time": self.comm_time,
        }

    def _log_communication(self, nelement):
        """Updates log of communication statistics (one round, nelement items)."""
        self.comm_rounds += 1
        self.comm_bytes += nelement * self.BYTES_PER_ELEMENT

    def _log_communication_time(self, comm_time):
        """Adds `comm_time` seconds to the accumulated communication time."""
        self.comm_time += comm_time


def _logging(func):
    """
    Decorator that performs logging of communication statistics.

    NOTE: Each party performs its own logging of communication, so one needs to
    sum the number of bytes communicated over all parties and divide by two
    (to prevent double-counting) to obtain the number of bytes communicated in
    the overall system.

    With a world size below 2, calls that receive positional arguments return
    their first argument directly (wrapped in a list for ``gather`` and
    ``all_gather``) without invoking ``func``.
    """
    from functools import wraps

    @wraps(func)
    def logging_wrapper(self, *args, **kwargs):

        # TODO: Replace this
        # - hacks the inputs into some of the functions for world_size 1:
        world_size = self.get_world_size()
        if world_size < 2:
            if func.__name__ in ["gather", "all_gather"]:
                return [args[0]]
            elif len(args) > 0:
                return args[0]

        # only log communication if needed:
        if cfg.communicator.verbose:
            rank = self.get_rank()
            _log = self._log_communication
            batched = kwargs.get("batched")

            # count number of bytes communicates for each MPI collective:
            if func.__name__ == "barrier":
                _log(0)
            elif func.__name__ in ["send", "recv", "isend", "irecv"]:
                _log(args[0].nelement())  # party sends or receives tensor
            elif func.__name__ == "scatter":
                if args[1] == rank:  # party scatters P - 1 tensors
                    nelements = sum(
                        x.nelement() for idx, x in enumerate(args[0]) if idx != rank
                    )
                    _log(nelements)  # NOTE: We deal with other parties later
            elif func.__name__ == "all_gather":
                _log(2 * (world_size - 1) * args[0].nelement())
                # party sends and receives P - 1 tensors
            elif func.__name__ == "send_obj":
                nbytes = sys.getsizeof(args[0])
                _log(nbytes / self.BYTES_PER_ELEMENT)  # party sends object
            elif func.__name__ == "broadcast_obj":
                nbytes = sys.getsizeof(args[0])
                _log(nbytes / self.BYTES_PER_ELEMENT * (world_size - 1))
                # party sends object to P - 1 parties
            elif func.__name__ in ["broadcast", "gather", "reduce"]:
                multiplier = world_size - 1 if args[1] == rank else 1
                # broadcast: party sends tensor to P - 1 parties, or receives 1 tensor
                # gather: party receives P - 1 tensors, or sends 1 tensor
                # reduce: party receives P - 1 tensors, or sends 1 tensor
                if batched:
                    nelements = sum(x.nelement() for x in args[0])
                    _log(nelements * multiplier)
                else:
                    _log(args[0].nelement() * multiplier)
            elif func.__name__ == "all_reduce":
                # each party sends and receives one tensor in ring implementation
                if batched:
                    nelements = sum(2 * x.nelement() for x in args[0])
                    _log(nelements)
                else:
                    _log(2 * args[0].nelement())

            # execute and time the MPI collective:
            tic = timeit.default_timer()
            result = func(self, *args, **kwargs)
            toc = timeit.default_timer()
            self._log_communication_time(toc - tic)

            # for some function, we only know the object size now:
            if func.__name__ == "scatter" and args[1] != rank:
                _log(result.nelement())  # party receives 1 tensor
            if func.__name__ == "recv_obj":
                _log(sys.getsizeof(result) / self.BYTES_PER_ELEMENT)
                # party receives 1 object

            return result

        return func(self, *args, **kwargs)

    return logging_wrapper


# ---- /crypten/common/functions/regular.py ----
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
7 | 8 | import torch 9 | 10 | from ..tensor_types import is_tensor 11 | from ..util import torch_cat, torch_stack 12 | 13 | 14 | __all__ = [ # noqa: F822 15 | "__getitem__", 16 | "__len__", 17 | "__setitem__", 18 | "cat", 19 | "cumsum", 20 | "dim", 21 | "dot", 22 | "expand", 23 | "flatten", 24 | "flip", 25 | "gather", 26 | "ger", 27 | "index_add", 28 | "index_select", 29 | "mean", 30 | "narrow", 31 | "nelement", 32 | "numel", 33 | "pad", 34 | "permute", 35 | "prod", 36 | "repeat", 37 | "reshape", 38 | "roll", 39 | "scatter", 40 | "scatter_add", 41 | "size", 42 | "split", 43 | "squeeze", 44 | "stack", 45 | "sum", 46 | "t", 47 | "take", 48 | "trace", 49 | "transpose", 50 | "unbind", 51 | "unfold", 52 | "unsqueeze", 53 | "var", 54 | "view", 55 | ] 56 | 57 | 58 | PROPERTY_FUNCTIONS = ["__len__", "nelement", "dim", "size", "numel"] 59 | 60 | 61 | def __setitem__(self, index, value): 62 | """Set tensor values by index""" 63 | if not isinstance(value, type(self)): 64 | kwargs = {"device": self.device} 65 | if hasattr(self, "ptype"): 66 | kwargs["ptype"] = self.ptype 67 | value = self.new(value, **kwargs) 68 | self._tensor.__setitem__(index, value._tensor) 69 | 70 | 71 | def pad(self, pad, mode="constant", value=0): 72 | result = self.shallow_copy() 73 | if hasattr(value, "_tensor"): 74 | value = value._tensor 75 | 76 | if hasattr(result._tensor, "pad"): 77 | result._tensor = self._tensor.pad(pad, mode=mode, value=value) 78 | else: 79 | result._tensor = torch.nn.functional.pad( 80 | self._tensor, pad, mode=mode, value=value 81 | ) 82 | return result 83 | 84 | 85 | def index_add(self, dim, index, tensor): 86 | """Performs out-of-place index_add: Accumulate the elements of tensor into the 87 | self tensor by adding to the indices in the order given in index. 
88 | """ 89 | result = self.clone() 90 | assert index.dim() == 1, "index needs to be a vector" 91 | tensor = getattr(tensor, "_tensor", tensor) 92 | result._tensor.index_add_(dim, index, tensor) 93 | return result 94 | 95 | 96 | def scatter_add(self, dim, index, other): 97 | """Adds all values from the tensor other into self at the indices 98 | specified in the index tensor in a similar fashion as scatter_(). For 99 | each value in other, it is added to an index in self which is specified 100 | by its index in other for dimension != dim and by the corresponding 101 | value in index for dimension = dim. 102 | """ 103 | result = self.clone() 104 | other = getattr(other, "_tensor", other) 105 | result._tensor.scatter_add_(dim, index, other) 106 | return result 107 | 108 | 109 | def scatter(self, dim, index, src): 110 | """Out-of-place version of :math:`CrypTensor.scatter_`""" 111 | result = self.clone() 112 | if is_tensor(src): 113 | src = self.new(src) 114 | assert isinstance(src, type(self)), "Unrecognized scatter src type: %s" % type(src) 115 | result._tensor.scatter_(dim, index, src._tensor) 116 | return result 117 | 118 | 119 | def unbind(self, dim=0): 120 | tensors = self._tensor.unbind(dim=dim) 121 | results = tuple(self.shallow_copy() for _ in range(len(tensors))) 122 | for i in range(len(tensors)): 123 | results[i]._tensor = tensors[i] 124 | return results 125 | 126 | 127 | def split(self, split_size, dim=0): 128 | tensors = self._tensor.split(split_size, dim=dim) 129 | results = tuple(self.shallow_copy() for _ in range(len(tensors))) 130 | for i in range(len(tensors)): 131 | results[i]._tensor = tensors[i] 132 | return results 133 | 134 | 135 | def take(self, index, dimension=None): 136 | """Take entries of tensor along a dimension according to the index. 137 | This function is identical to torch.take() when dimension=None, 138 | otherwise, it is identical to ONNX gather() function. 
139 | """ 140 | result = self.shallow_copy() 141 | index = index.long() 142 | if dimension is None or self.dim() == 0: 143 | result._tensor = self._tensor.take(index) 144 | else: 145 | all_indices = [slice(0, x) for x in self.size()] 146 | all_indices[dimension] = index 147 | result._tensor = self._tensor[all_indices] 148 | return result 149 | 150 | 151 | def mean(self, *args, **kwargs): 152 | """Computes mean of given tensor""" 153 | result = self.sum(*args, **kwargs) 154 | 155 | # Handle special case where input has 0 dimensions 156 | if self.dim() == 0: 157 | return result 158 | 159 | # Compute divisor to use to compute mean 160 | divisor = self.nelement() // result.nelement() 161 | return result.div(divisor) 162 | 163 | 164 | def var(self, *args, **kwargs): 165 | """Computes variance of tensor along specified dimensions.""" 166 | # preprocess inputs: 167 | if len(args) == 0: 168 | dim = kwargs.get("dim", None) 169 | unbiased = kwargs.get("unbiased", False) 170 | keepdim = kwargs.get("keepdim", False) 171 | elif len(args) == 1: 172 | dim = args[0] 173 | unbiased = kwargs.get("unbiased", False) 174 | keepdim = kwargs.get("keepdim", False) 175 | elif len(args) == 2: 176 | dim, unbiased = args[0], args[1] 177 | keepdim = kwargs.get("keepdim", False) 178 | else: 179 | dim, unbiased, keepdim = args[0], args[1], args[2] 180 | 181 | if dim is not None: # dimension is specified 182 | mean = self.mean(dim, keepdim=True) 183 | else: 184 | mean = self.mean() 185 | 186 | # Compute square error 187 | result = (self - mean).square() 188 | if dim is None: 189 | result = result.sum() 190 | else: 191 | result = result.sum(dim, keepdim=keepdim) 192 | 193 | # Determine divisor 194 | divisor = self.nelement() // result.nelement() 195 | if not unbiased: 196 | divisor -= 1 197 | 198 | # Compute mean square error 199 | if divisor in [0, 1]: 200 | return result 201 | return result.div(divisor) 202 | 203 | 204 | def prod(self, dim=None, keepdim=False): 205 | """ 206 | Returns the 
product of each row of the `input` tensor in the given 207 | dimension `dim`. 208 | 209 | If `keepdim` is `True`, the output tensor is of the same size as `input` 210 | except in the dimension `dim` where it is of size 1. Otherwise, `dim` is 211 | squeezed, resulting in the output tensor having 1 fewer dimension than 212 | `input`. 213 | """ 214 | if dim is None: 215 | return self.flatten().prod(dim=0) 216 | 217 | result = self.clone() 218 | while result.size(dim) > 1: 219 | size = result.size(dim) 220 | x, y, remainder = result.split([size // 2, size // 2, size % 2], dim=dim) 221 | result = x.mul_(y) 222 | result = type(self).cat([result, remainder], dim=dim) 223 | 224 | # Squeeze result if necessary 225 | if not keepdim: 226 | result = result.squeeze(dim) 227 | return result 228 | 229 | 230 | def dot(self, y, weights=None): 231 | """Compute a dot product between two tensors""" 232 | assert self.size() == y.size(), "Number of elements do not match" 233 | if weights is not None: 234 | assert weights.size() == self.size(), "Incorrect number of weights" 235 | result = self * weights 236 | else: 237 | result = self.clone() 238 | 239 | return result.mul(y).sum() 240 | 241 | 242 | def ger(self, y): 243 | """Computer an outer product between two vectors""" 244 | assert self.dim() == 1 and y.dim() == 1, "Outer product must be on 1D tensors" 245 | return self.view((-1, 1)).matmul(y.view((1, -1))) 246 | 247 | 248 | def __cat_stack_helper(op, tensors, *args, **kwargs): 249 | assert op in ["cat", "stack"], "Unsupported op for helper function" 250 | assert isinstance(tensors, list), "%s input must be a list" % op 251 | assert len(tensors) > 0, "expected a non-empty list of CrypTensors" 252 | 253 | # Determine op-type 254 | funcs = {"cat": torch_cat, "stack": torch_stack} 255 | func = funcs[op] 256 | if hasattr(tensors[0]._tensor, op): 257 | func = getattr(tensors[0]._tensor, op) 258 | 259 | # type coordination 260 | for i, tensor in enumerate(tensors[1:]): 261 | if 
torch.is_tensor(tensor) or isinstance(tensor, (int, float)): 262 | tensors[i] = tensors[0].new(tensor) 263 | assert isinstance(tensors[i], type(tensors[0])), f"{op} tensor type mismatch" 264 | 265 | # Operate on all input tensors 266 | result = tensors[0].clone() 267 | result._tensor = func([tensor._tensor for tensor in tensors], *args, **kwargs) 268 | return result 269 | 270 | 271 | def cat(tensors, *args, **kwargs): 272 | """Perform tensor concatenation""" 273 | return __cat_stack_helper("cat", tensors, *args, **kwargs) 274 | 275 | 276 | def stack(tensors, *args, **kwargs): 277 | """Perform tensor stacking""" 278 | return __cat_stack_helper("stack", tensors, *args, **kwargs) 279 | 280 | 281 | # Make static methods static 282 | cat = staticmethod(cat) 283 | stack = staticmethod(stack) 284 | 285 | 286 | # Add remaining regular functions 287 | def _add_regular_function(function_name): 288 | """ 289 | Adds regular function that is applied directly on the underlying 290 | `_tensor` attribute, and stores the result in the same attribute. 291 | """ 292 | 293 | def regular_func(self, *args, **kwargs): 294 | result = self.shallow_copy() 295 | result._tensor = getattr(result._tensor, function_name)(*args, **kwargs) 296 | return result 297 | 298 | if function_name not in globals(): 299 | globals()[function_name] = regular_func 300 | 301 | 302 | def _add_property_function(function_name): 303 | """ 304 | Adds regular function that is applied directly on the underlying 305 | `_tensor` attribute, and returns the result of that function. 
306 | """ 307 | 308 | def property_func(self, *args, **kwargs): 309 | return getattr(self._tensor, function_name)(*args, **kwargs) 310 | 311 | if function_name not in globals(): 312 | globals()[function_name] = property_func 313 | 314 | 315 | for function_name in __all__: 316 | if function_name in PROPERTY_FUNCTIONS: 317 | continue 318 | _add_regular_function(function_name) 319 | 320 | for function_name in PROPERTY_FUNCTIONS: 321 | _add_property_function(function_name) 322 | -------------------------------------------------------------------------------- /crypten/mpc/mpc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import torch 9 | from crypten import communicator as comm 10 | from crypten.common.tensor_types import is_tensor 11 | from crypten.common.util import torch_stack 12 | from crypten.config import cfg 13 | from crypten.cuda import CUDALongTensor 14 | 15 | from ..cryptensor import CrypTensor 16 | from ..encoder import FixedPointEncoder 17 | from .primitives.binary import BinarySharedTensor 18 | from .primitives.converters import convert 19 | from .ptype import ptype as Ptype 20 | 21 | 22 | @CrypTensor.register_cryptensor("mpc") 23 | class MPCTensor(CrypTensor): 24 | def __init__(self, tensor, ptype=Ptype.arithmetic, device=None, *args, **kwargs): 25 | """ 26 | Creates the shared tensor from the input `tensor` provided by party `src`. 27 | The `ptype` defines the type of sharing used (default: arithmetic). 28 | 29 | The other parties can specify a `tensor` or `size` to determine the size 30 | of the shared tensor object to create. In this case, all parties must 31 | specify the same (tensor) size to prevent the party's shares from varying 32 | in size, which leads to undefined behavior. 
33 | 34 | Alternatively, the parties can set `broadcast_size` to `True` to have the 35 | `src` party broadcast the correct size. The parties who do not know the 36 | tensor size beforehand can provide an empty tensor as input. This is 37 | guaranteed to produce correct behavior but requires an additional 38 | communication round. 39 | 40 | The parties can also set the `precision` and `device` for their share of 41 | the tensor. If `device` is unspecified, it is set to `tensor.device`. 42 | """ 43 | if tensor is None: 44 | raise ValueError("Cannot initialize tensor with None.") 45 | 46 | # take required_grad from kwargs, input tensor, or set to False: 47 | default = tensor.requires_grad if torch.is_tensor(tensor) else False 48 | requires_grad = kwargs.pop("requires_grad", default) 49 | 50 | # call CrypTensor constructor: 51 | super().__init__(requires_grad=requires_grad) 52 | if device is None and hasattr(tensor, "device"): 53 | device = tensor.device 54 | 55 | # create the MPCTensor: 56 | tensor_type = ptype.to_tensor() 57 | if tensor is []: 58 | self._tensor = torch.tensor([], device=device) 59 | else: 60 | self._tensor = tensor_type(tensor=tensor, device=device, *args, **kwargs) 61 | self.ptype = ptype 62 | 63 | @staticmethod 64 | def new(*args, **kwargs): 65 | """ 66 | Creates a new MPCTensor, passing all args and kwargs into the constructor. 67 | """ 68 | return MPCTensor(*args, **kwargs) 69 | 70 | @staticmethod 71 | def from_shares(share, precision=None, ptype=Ptype.arithmetic): 72 | result = MPCTensor([]) 73 | from_shares = ptype.to_tensor().from_shares 74 | result._tensor = from_shares(share, precision=precision) 75 | result.ptype = ptype 76 | return result 77 | 78 | def clone(self): 79 | """Create a deep copy of the input tensor.""" 80 | # TODO: Rename this to __deepcopy__()? 
81 | result = MPCTensor([]) 82 | result._tensor = self._tensor.clone() 83 | result.ptype = self.ptype 84 | return result 85 | 86 | def shallow_copy(self): 87 | """Create a shallow copy of the input tensor.""" 88 | # TODO: Rename this to __copy__()? 89 | result = MPCTensor([]) 90 | result._tensor = self._tensor 91 | result.ptype = self.ptype 92 | return result 93 | 94 | def copy_(self, other): 95 | """Copies value of other MPCTensor into this MPCTensor.""" 96 | assert isinstance(other, MPCTensor), "other must be MPCTensor" 97 | self._tensor.copy_(other._tensor) 98 | self.ptype = other.ptype 99 | 100 | def to(self, *args, **kwargs): 101 | r""" 102 | Depending on the input arguments, 103 | converts underlying share to the given ptype or 104 | performs `torch.to` on the underlying torch tensor 105 | 106 | To convert underlying share to the given ptype, call `to` as: 107 | to(ptype, **kwargs) 108 | 109 | It will call MPCTensor.to_ptype with the arguments provided above. 110 | 111 | Otherwise, `to` performs `torch.to` on the underlying 112 | torch tensor. See 113 | https://pytorch.org/docs/stable/tensors.html?highlight=#torch.Tensor.to 114 | for a reference of the parameters that can be passed in. 115 | 116 | Args: 117 | ptype: Ptype.arithmetic or Ptype.binary. 118 | """ 119 | if "ptype" in kwargs: 120 | return self._to_ptype(**kwargs) 121 | elif args and isinstance(args[0], Ptype): 122 | ptype = args[0] 123 | return self._to_ptype(ptype, **kwargs) 124 | else: 125 | share = self.share.to(*args, **kwargs) 126 | if share.is_cuda: 127 | share = CUDALongTensor(share) 128 | self.share = share 129 | return self 130 | 131 | def _to_ptype(self, ptype, **kwargs): 132 | r""" 133 | Convert MPCTensor's underlying share to the corresponding ptype 134 | (ArithmeticSharedTensor, BinarySharedTensor) 135 | 136 | Args: 137 | ptype (Ptype.arithmetic or Ptype.binary): The ptype to convert 138 | the shares to. 
139 | precision (int, optional): Precision of the fixed point encoder when 140 | converting a binary share to an arithmetic share. It will be ignored 141 | if the ptype doesn't match. 142 | bits (int, optional): If specified, will only preserve the bottom `bits` bits 143 | of a binary tensor when converting from a binary share to an arithmetic share. 144 | It will be ignored if the ptype doesn't match. 145 | """ 146 | retval = self.clone() 147 | if retval.ptype == ptype: 148 | return retval 149 | retval._tensor = convert(self._tensor, ptype, **kwargs) 150 | retval.ptype = ptype 151 | return retval 152 | 153 | @property 154 | def device(self): 155 | """Return the `torch.device` of the underlying share""" 156 | return self.share.device 157 | 158 | @property 159 | def is_cuda(self): 160 | """Return True if the underlying share is stored on GPU, False otherwise""" 161 | return self.share.is_cuda 162 | 163 | def cuda(self, *args, **kwargs): 164 | """Call `torch.Tensor.cuda` on the underlying share""" 165 | self.share = CUDALongTensor(self.share.cuda(*args, **kwargs)) 166 | return self 167 | 168 | def cpu(self): 169 | """Call `torch.Tensor.cpu` on the underlying share""" 170 | self.share = self.share.cpu() 171 | return self 172 | 173 | def get_plain_text(self, dst=None): 174 | """Decrypts the tensor.""" 175 | return self._tensor.get_plain_text(dst=dst) 176 | 177 | def reveal(self, dst=None): 178 | """Decrypts the tensor without any downscaling.""" 179 | return self._tensor.reveal(dst=dst) 180 | 181 | def __repr__(self): 182 | """Returns a representation of the tensor useful for debugging.""" 183 | debug_mode = cfg.debug.debug_mode 184 | 185 | share = self.share 186 | plain_text = self._tensor.get_plain_text() if debug_mode else "HIDDEN" 187 | ptype = self.ptype 188 | return ( 189 | f"MPCTensor(\n\t_tensor={share}\n" 190 | f"\tplain_text={plain_text}\n\tptype={ptype}\n)" 191 | ) 192 | 193 | def __hash__(self): 194 | return hash(self.share) 195 | 196 | @property 197 | def 
share(self): 198 | """Returns underlying share""" 199 | return self._tensor.share 200 | 201 | @share.setter 202 | def share(self, value): 203 | """Sets share to value""" 204 | self._tensor.share = value 205 | 206 | @property 207 | def encoder(self): 208 | """Returns underlying encoder""" 209 | return self._tensor.encoder 210 | 211 | @encoder.setter 212 | def encoder(self, value): 213 | """Sets encoder to value""" 214 | self._tensor.encoder = value 215 | 216 | @staticmethod 217 | def rand(*sizes, device=None): 218 | """ 219 | Returns a tensor with elements uniformly sampled in [0, 1). The uniform 220 | random samples are generated by generating random bits using fixed-point 221 | encoding and converting the result to an ArithmeticSharedTensor. 222 | """ 223 | rand = MPCTensor([]) 224 | encoder = FixedPointEncoder() 225 | rand._tensor = BinarySharedTensor.rand( 226 | *sizes, bits=encoder._precision_bits, device=device 227 | ) 228 | rand._tensor.encoder = encoder 229 | rand.ptype = Ptype.binary 230 | return rand.to(Ptype.arithmetic, bits=encoder._precision_bits) 231 | 232 | # Comparators 233 | def _ltz(self): 234 | """Returns 1 for elements that are < 0 and 0 otherwise""" 235 | shift = torch.iinfo(torch.long).bits - 1 236 | precision = 0 if self.encoder.scale == 1 else None 237 | 238 | result = self._to_ptype(Ptype.binary) 239 | result.share >>= shift 240 | result = result._to_ptype(Ptype.arithmetic, precision=precision, bits=1) 241 | result.encoder._scale = 1 242 | return result 243 | 244 | def eq(self, y): 245 | """Returns self == y""" 246 | if comm.get().get_world_size() == 2: 247 | return (self - y)._eqz_2PC() 248 | 249 | return 1 - self.ne(y) 250 | 251 | def ne(self, y): 252 | """Returns self != y""" 253 | if comm.get().get_world_size() == 2: 254 | return 1 - self.eq(y) 255 | 256 | difference = self - y 257 | difference.share = torch_stack([difference.share, -(difference.share)]) 258 | return difference._ltz().sum(0) 259 | 260 | def _eqz_2PC(self): 261 | 
"""Returns self == 0""" 262 | # Create BinarySharedTensors from shares 263 | x0 = MPCTensor(self.share, src=0, ptype=Ptype.binary) 264 | x1 = MPCTensor(-self.share, src=1, ptype=Ptype.binary) 265 | 266 | # Perform equality testing using binary shares 267 | x0._tensor = x0._tensor.eq(x1._tensor) 268 | x0.encoder = self.encoder 269 | 270 | # Convert to Arithmetic sharing 271 | result = x0.to(Ptype.arithmetic, bits=1) 272 | result.encoder._scale = 1 273 | 274 | return result 275 | 276 | def div(self, y): 277 | r"""Divides each element of :attr:`self` with the scalar :attr:`y` or 278 | each element of the tensor :attr:`y` and returns a new resulting tensor. 279 | 280 | For `y` a scalar: 281 | 282 | .. math:: 283 | \text{out}_i = \frac{\text{self}_i}{\text{y}} 284 | 285 | For `y` a tensor: 286 | 287 | .. math:: 288 | \text{out}_i = \frac{\text{self}_i}{\text{y}_i} 289 | 290 | Note for :attr:`y` a tensor, the shapes of :attr:`self` and :attr:`y` must be 291 | `broadcastable`_. 292 | 293 | .. _broadcastable: 294 | https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics""" # noqa: B950 295 | result = self.clone() 296 | if isinstance(y, CrypTensor): 297 | result.share = torch.broadcast_tensors(result.share, y.share)[0].clone() 298 | elif is_tensor(y): 299 | result.share = torch.broadcast_tensors(result.share, y)[0].clone() 300 | 301 | if isinstance(y, MPCTensor): 302 | return result.mul(y.reciprocal()) 303 | result._tensor.div_(y) 304 | return result 305 | 306 | 307 | UNARY_FUNCTIONS = [ 308 | "avg_pool2d", 309 | "square", 310 | "neg", 311 | ] 312 | 313 | BINARY_FUNCTIONS = [ 314 | "add", 315 | "sub", 316 | "mul", 317 | "matmul", 318 | "conv1d", 319 | "conv2d", 320 | "conv_transpose1d", 321 | "conv_transpose2d", 322 | ] 323 | 324 | 325 | def _add_unary_passthrough_function(name): 326 | def unary_wrapper_function(self, *args, **kwargs): 327 | result = self.shallow_copy() 328 | result._tensor = getattr(result._tensor, name)(*args, **kwargs) 329 | return 
result 330 | 331 | setattr(MPCTensor, name, unary_wrapper_function) 332 | 333 | 334 | def _add_binary_passthrough_function(name): 335 | def binary_wrapper_function(self, value, *args, **kwargs): 336 | result = self.shallow_copy() 337 | if isinstance(value, MPCTensor): 338 | value = value._tensor 339 | result._tensor = getattr(result._tensor, name)(value, *args, **kwargs) 340 | return result 341 | 342 | setattr(MPCTensor, name, binary_wrapper_function) 343 | 344 | 345 | for func_name in UNARY_FUNCTIONS: 346 | _add_unary_passthrough_function(func_name) 347 | 348 | for func_name in BINARY_FUNCTIONS: 349 | _add_binary_passthrough_function(func_name) 350 | --------------------------------------------------------------------------------