├── .github ├── actions │ └── build-manylinux │ │ └── action.yml └── workflows │ └── build.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── configs ├── lola.yml ├── mlp.yml └── resnet.yml ├── examples ├── run_lola.py ├── run_mlp.py └── run_resnet.py ├── models ├── __init__.py ├── alexnet.py ├── lenet.py ├── lola.py ├── mlp.py ├── resnet.py ├── vgg.py └── yolo.py ├── orion ├── __init__.py ├── backend │ ├── __init__.py │ ├── heaan │ │ └── README.md │ ├── lattigo │ │ ├── __init__.py │ │ ├── bindings.py │ │ ├── bootstrapper.go │ │ ├── encoder.go │ │ ├── encryptor.go │ │ ├── evaluator.go │ │ ├── go.mod │ │ ├── go.sum │ │ ├── keygenerator.go │ │ ├── lineartransform.go │ │ ├── main.go │ │ ├── minheap.go │ │ ├── polyeval.go │ │ ├── scheme.go │ │ ├── tensors.go │ │ └── utils.go │ ├── openfhe │ │ └── README.md │ └── python │ │ ├── __init__.py │ │ ├── bootstrapper.py │ │ ├── encoder.py │ │ ├── encryptor.py │ │ ├── evaluator.py │ │ ├── key_generator.py │ │ ├── lt_evaluator.py │ │ ├── parameters.py │ │ ├── poly_evaluator.py │ │ └── tensors.py ├── core │ ├── __init__.py │ ├── auto_bootstrap.py │ ├── fuser.py │ ├── level_dag.py │ ├── network_dag.py │ ├── orion.py │ ├── packing.py │ ├── tracer.py │ └── utils.py ├── models │ ├── __init__.py │ ├── alexnet.py │ ├── lenet.py │ ├── lola.py │ ├── mlp.py │ ├── resnet.py │ ├── vgg.py │ └── yolo.py └── nn │ ├── __init__.py │ ├── activation.py │ ├── linear.py │ ├── module.py │ ├── normalization.py │ ├── operations.py │ ├── pooling.py │ └── reshape.py ├── pyproject.toml ├── tests ├── __init__.py ├── configs │ └── mlp.yml ├── models │ └── test_mlp.py ├── test_imports.py └── test_placeholder.py └── tools └── build_lattigo.py /.github/actions/build-manylinux/action.yml: -------------------------------------------------------------------------------- 1 | name: 'build manylinux wheels' 2 | description: 'builds manylinux wheels for orion' 3 | inputs: 4 | python-version: 5 | description: 'python version to use' 6 | required: true 7 | 8 | runs: 9 | using: 'composite' 10 | steps: 11 | - name: set up go 12 | uses: actions/setup-go@v5 13 | with: 14 | go-version: '1.21.x' 15 | 16 | - name: install poetry 17 | shell: bash 18 | run: | 19 | yum install -y python3-pip 20 | python${{ inputs.python-version }} -m pip install poetry 21 | 22 | - name: build package 23 | shell: bash -l {0} 24 | run: | 25 | python${{ inputs.python-version }} -m poetry build 26 | 27 | echo "Contents of the wheel file:" 28 | unzip -l dist/*.whl 29 | 30 | - name: use auditwheel to support earlier manylinux 31 | shell: bash 32 | run: | 33 | auditwheel repair dist/*.whl --plat manylinux2014_x86_64 34 | 35 | echo "Contents of the wheelhouse directory:" 36 | ls -l wheelhouse/ 37 | 38 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build and Publish 2 | 3 | on: 4 | push: 5 | branches: [main, dev] 6 | pull_request: 7 | 8 | permissions: 9 | contents: write 10 | 11 | jobs: 12 | # In order to build general manylinux wheels that support older Linux 13 | # distributions, we'll need to use Docker.
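  # (A rough local equivalent of the build job below, for illustration only;
  # the container would also need Go available, since the wheel bundles the
  # compiled Lattigo backend:
  #   docker run --rm -v "$PWD":/io quay.io/pypa/manylinux_2_28_x86_64 \
  #     bash -c "cd /io && python3.11 -m pip install poetry && python3.11 -m poetry build")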
14 | build-manylinux: 15 | runs-on: ubuntu-latest 16 | container: 17 | image: quay.io/pypa/manylinux_2_28_x86_64 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | python-version: ["3.9", "3.10", "3.11", "3.12"] 22 | steps: 23 | - name: checkout repository 24 | uses: actions/checkout@v4 25 | 26 | - name: build in docker 27 | uses: ./.github/actions/build-manylinux 28 | with: 29 | python-version: ${{ matrix.python-version }} 30 | 31 | - name: upload linux wheels 32 | uses: actions/upload-artifact@v4 33 | with: 34 | name: wheels-manylinux2014-py${{ matrix.python-version }}-x64 35 | path: wheelhouse/*.whl 36 | 37 | # For all other operating systems we support (Windows, macOS Intel (x64), 38 | # and macOS Apple Silicon (arm64)), we can use the default runners GitHub 39 | # Actions provides. 40 | build-other: 41 | runs-on: ${{ matrix.config.os }} 42 | strategy: 43 | fail-fast: false 44 | matrix: 45 | config: 46 | - {os: windows-latest, arch: "x64"} 47 | - {os: macos-latest, arch: "x64"} 48 | - {os: macos-latest, arch: "arm64"} 49 | python-version: ["3.9", "3.10", "3.11", "3.12"] 50 | 51 | steps: 52 | - name: checkout orion repository 53 | uses: actions/checkout@v4 54 | 55 | # Conda is necessary here; otherwise we can't build macOS arm64 wheels on 56 | # Python versions < 3.11. 57 | - name: setup python 58 | uses: conda-incubator/setup-miniconda@v2 59 | with: 60 | miniconda-version: "latest" 61 | python-version: ${{ matrix.python-version }} 62 | architecture: ${{ matrix.config.arch }} 63 | auto-activate-base: true 64 | 65 | - name: set up golang 66 | uses: actions/setup-go@v5 67 | with: 68 | go-version: '1.21.x' 69 | 70 | - name: build wheels 71 | shell: bash -l {0} 72 | run: | 73 | pip install poetry 74 | poetry build 75 | 76 | echo "Contents of the wheel file:" 77 | unzip -l dist/*.whl 78 | 79 | - name: upload wheels 80 | uses: actions/upload-artifact@v4 81 | with: 82 | name: wheels-${{ matrix.config.os }}-py${{ matrix.python-version }}-${{ matrix.config.arch }} 83 | path: dist/*.whl 84 | 85 | # Now we can test the wheels we just created. We'll do this by downloading 86 | # all wheels that were added as artifacts in the previous build jobs.
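  # (Roughly the local equivalent of the test steps below, for illustration:
  #   pip install --find-links=dist orion-fhe pytest pytest-cov
  #   python -m pytest tests/)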
87 | test: 88 | needs: [build-manylinux, build-other] 89 | runs-on: ${{ matrix.config.os }} 90 | strategy: 91 | fail-fast: false 92 | matrix: 93 | config: 94 | - {os: ubuntu-latest, arch: "x64"} 95 | - {os: windows-latest, arch: "x64"} 96 | - {os: macos-latest, arch: "x64"} 97 | - {os: macos-latest, arch: "arm64"} 98 | python-version: ["3.9", "3.10", "3.11", "3.12"] 99 | 100 | steps: 101 | - name: checkout orion repository 102 | uses: actions/checkout@v4 103 | 104 | - name: set up python 105 | uses: actions/setup-python@v5 106 | with: 107 | python-version: ${{ matrix.python-version }} 108 | 109 | - name: download artifact wheels 110 | uses: actions/download-artifact@v4 111 | with: 112 | pattern: wheels-* 113 | path: dist/ 114 | merge-multiple: true 115 | 116 | - name: install packages 117 | shell: bash 118 | run: | 119 | pip install --find-links=dist orion-fhe pytest pytest-cov 120 | 121 | - name: run pytest 122 | shell: bash 123 | run: | 124 | # Run from a temp directory so tests import the installed wheel 125 | cd $RUNNER_TEMP 126 | echo "Testing from directory: $RUNNER_TEMP" 127 | python -m pytest $GITHUB_WORKSPACE/tests/ 128 | 129 | publish: 130 | needs: [test] 131 | runs-on: ubuntu-latest 132 | if: github.event_name == 'push' && github.ref == 'refs/heads/main' 133 | 134 | steps: 135 | - name: checkout repository 136 | uses: actions/checkout@v4 137 | 138 | - name: set up python 139 | uses: actions/setup-python@v5 140 | with: 141 | python-version: "3.12" 142 | 143 | - name: download all artifacts 144 | uses: actions/download-artifact@v4 145 | with: 146 | pattern: wheels-* 147 | path: dist/ 148 | merge-multiple: true 149 | 150 | - name: publish to pypi 151 | run: | 152 | pip install poetry 153 | poetry config pypi-token.pypi ${{ secrets.PYPI_API_TOKEN }} 154 | poetry publish --skip-existing 155 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | 3 | *.DS_Store 4 | *data/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | *__pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | *.h 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | *.vscode/ 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other info into it.
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # UV 105 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | #uv.lock 109 | 110 | # poetry 111 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 112 | # This is especially recommended for binary packages to ensure reproducibility, and is more 113 | # commonly ignored for libraries. 114 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 115 | #poetry.lock 116 | 117 | # pdm 118 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 119 | #pdm.lock 120 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 121 | # in version control. 122 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 123 | .pdm.toml 124 | .pdm-python 125 | .pdm-build/ 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .venv 140 | env/ 141 | venv/ 142 | .env/ 143 | .venv/ 144 | ENV/ 145 | env.bak/ 146 | venv.bak/ 147 | 148 | # Spyder project settings 149 | .spyderproject 150 | .spyproject 151 | 152 | # Rope project settings 153 | .ropeproject 154 | 155 | # mkdocs documentation 156 | /site 157 | 158 | # mypy 159 | .mypy_cache/ 160 | .dmypy.json 161 | dmypy.json 162 | 163 | # Pyre type checker 164 | .pyre/ 165 | 166 | # pytype static type analyzer 167 | .pytype/ 168 | 169 | # Cython debug symbols 170 | cython_debug/ 171 | 172 | # PyCharm 173 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 174 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 175 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 176 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 177 | #.idea/ 178 | 179 | # Ruff stuff: 180 | .ruff_cache/ 181 | 182 | # PyPI configuration file 183 | .pypirc 184 | 185 | # Golang 186 | 187 | # Binaries for programs and plugins 188 | *.exe 189 | *.exe~ 190 | *.dll 191 | *.so 192 | *.dylib 193 | 194 | # Test binary, built with `go test -c` 195 | *.test 196 | 197 | # Output of the go coverage tool, specifically when used with LiteIDE 198 | *.out 199 | 200 | # Dependency directories (remove the comment below to include it) 201 | # vendor/ 202 | 203 | # Go workspace file 204 | go.work 205 | go.work.sum 206 | 207 | # env file 208 | .env 209 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Austin Ebel 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include CONTRIBUTING.md 2 | include LICENSE 3 | include README.md 4 | 5 | recursive-include orion/backend/lattigo *.so *.dylib *.dll 6 | 7 | recursive-exclude * __pycache__ 8 | recursive-exclude * *.py[co] 9 | 10 | exclude tests 11 | exclude tests/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Orion 2 | 3 | Adding installation instructions and additional examples shortly. 
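In the meantime, here is a minimal sketch of the intended workflow, condensed from `examples/run_mlp.py` in this repository (the API calls, config file, and MNIST helper are all taken from that example; paths assume you run from the repository root):

```python
import orion
import orion.models as models
from orion.core.utils import get_mnist_datasets

# Initialize the CKKS scheme from a YAML config
scheme = orion.init_scheme("configs/mlp.yml")
trainloader, testloader = get_mnist_datasets(data_dir="data", batch_size=1)

net = models.MLP()
net.eval()

# Calibrate activation input ranges on training data, then compile for FHE
orion.fit(net, trainloader)
input_level = orion.compile(net)

# Encrypt one test batch, run encrypted inference, then decrypt + decode
inp, _ = next(iter(testloader))
vec_ctxt = orion.encrypt(orion.encode(inp, input_level))
net.he()  # switch the network to FHE mode
out_fhe = net(vec_ctxt).decrypt().decode()
```

See `examples/` for complete, runnable versions of this workflow.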
-------------------------------------------------------------------------------- /configs/lola.yml: -------------------------------------------------------------------------------- 1 | comment: Config for LoLA from Figure 3 of https://arxiv.org/pdf/1812.10659 2 | 3 | ckks_params: 4 | LogN: 13 5 | LogQ: [29, 26, 26, 26, 26, 26] 6 | LogP: [29, 29] 7 | LogScale: 26 8 | H: 8192 9 | RingType: ConjugateInvariant 10 | 11 | orion: 12 | margin: 2 # >= 1 13 | embedding_method: hybrid # [hybrid, square] 14 | backend: lattigo # [lattigo, openfhe, heaan] 15 | 16 | fuse_modules: true 17 | debug: false 18 | 19 | diags_path: ../data/diagonals.h5 # "path/to/diags" | "" 20 | keys_path: ../data/keys.h5 # "path/to/keys" | "" 21 | io_mode: none # "load" | "save" | "none" 22 | -------------------------------------------------------------------------------- /configs/mlp.yml: -------------------------------------------------------------------------------- 1 | comment: Config for MLP from https://eprint.iacr.org/2017/396.pdf 2 | 3 | ckks_params: 4 | LogN: 13 5 | LogQ: [29, 26, 26, 26, 26, 26] 6 | LogP: [29, 29] 7 | LogScale: 26 8 | H: 8192 9 | RingType: ConjugateInvariant 10 | 11 | orion: 12 | margin: 2 # >= 1 13 | embedding_method: hybrid # [hybrid, square] 14 | backend: lattigo # [lattigo, openfhe, heaan] 15 | 16 | fuse_modules: true 17 | debug: false 18 | 19 | diags_path: ../data/diagonals.h5 # "path/to/diags" | "" 20 | keys_path: ../data/keys.h5 # "path/to/keys" | "" 21 | io_mode: none # "load" | "save" | "none" 22 | -------------------------------------------------------------------------------- /configs/resnet.yml: -------------------------------------------------------------------------------- 1 | comment: "ResNet Parameter Set" 2 | 3 | ckks_params: 4 | LogN: 16 5 | LogQ: [55, 40, 40, 40, 40, 40, 40, 40, 40, 40, 40] 6 | LogP: [61, 61, 61] 7 | LogScale: 40 8 | H: 192 9 | RingType: standard 10 | 11 | boot_params: 12 | LogP: [61, 61, 61, 61, 61, 61, 61, 61] 13 | 14 | orion: 15 | margin: 2 # >= 1 16 | embedding_method: hybrid # [hybrid, square] 17 | backend: lattigo # [lattigo, openfhe, heaan] 18 | 19 | fuse_modules: true 20 | debug: true 21 | 22 | diags_path: ../data/diagonals.h5 # "path/to/diags" | "" 23 | keys_path: ../data/keys.h5 # "path/to/keys" | "" 24 | io_mode: none # "load" | "save" | "none" -------------------------------------------------------------------------------- /examples/run_lola.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import torch 4 | import orion 5 | import orion.models as models 6 | from orion.core.utils import ( 7 | get_mnist_datasets, 8 | mae, 9 | train_on_mnist 10 | ) 11 | 12 | # Set seed for reproducibility 13 | torch.manual_seed(42) 14 | 15 | # Initialize the Orion scheme, model, and data 16 | scheme = orion.init_scheme("../configs/lola.yml") 17 | trainloader, testloader = get_mnist_datasets(data_dir="../data", batch_size=1) 18 | net = models.LoLA() 19 | 20 | # Train model (optional) 21 | # device = "cuda" if torch.cuda.is_available() else "cpu" 22 | # train_on_mnist(net, data_dir="../data", epochs=1, device=device) 23 | 24 | # Get a test batch to pass through our network 25 | inp, _ = next(iter(testloader)) 26 | 27 | # Run cleartext inference 28 | net.eval() 29 | out_clear = net(inp) 30 | 31 | # Prepare for FHE inference. 32 | # Certain polynomial activation functions require us to know the precise range 33 | # of possible input values. 
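# (Polynomial approximations, such as the Quad activation these examples use, are only accurate on a bounded input interval, which is why this calibration step is needed.)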
We'll determine these ranges by aggregating 34 | # statistics from the training set and applying a tolerance factor = margin. 35 | orion.fit(net, trainloader) 36 | input_level = orion.compile(net) 37 | 38 | # Encode and encrypt the input vector 39 | vec_ptxt = orion.encode(inp, input_level) 40 | vec_ctxt = orion.encrypt(vec_ptxt) 41 | net.he() # Switch to FHE mode 42 | 43 | # Run FHE inference 44 | print("\nStarting FHE inference", flush=True) 45 | start = time.time() 46 | out_ctxt = net(vec_ctxt) 47 | end = time.time() 48 | 49 | # Get the FHE results and decrypt + decode. 50 | out_ptxt = out_ctxt.decrypt() 51 | out_fhe = out_ptxt.decode() 52 | 53 | # Compare the cleartext and FHE results. 54 | print() 55 | print(out_clear) 56 | print(out_fhe) 57 | 58 | dist = mae(out_clear, out_fhe) 59 | print(f"\nMAE: {dist:.4f}") 60 | print(f"Precision: {-math.log2(dist):.4f}") 61 | print(f"Runtime: {end-start:.4f} secs.\n") -------------------------------------------------------------------------------- /examples/run_mlp.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import torch 4 | import orion 5 | import orion.models as models 6 | from orion.core.utils import ( 7 | get_mnist_datasets, 8 | mae, 9 | train_on_mnist 10 | ) 11 | 12 | # Set seed for reproducibility 13 | torch.manual_seed(42) 14 | 15 | # Initialize the Orion scheme, model, and data 16 | scheme = orion.init_scheme("../configs/mlp.yml") 17 | trainloader, testloader = get_mnist_datasets(data_dir="../data", batch_size=1) 18 | net = models.MLP() 19 | 20 | # Train model (optional) 21 | # device = "cuda" if torch.cuda.is_available() else "cpu" 22 | # train_on_mnist(net, data_dir="../data", epochs=1, device=device) 23 | 24 | # Get a test batch to pass through our network 25 | inp, _ = next(iter(testloader)) 26 | 27 | # Run cleartext inference 28 | net.eval() 29 | out_clear = net(inp) 30 | 31 | # Prepare for FHE inference. 32 | # Certain polynomial activation functions require us to know the precise range 33 | # of possible input values. We'll determine these ranges by aggregating 34 | # statistics from the training set and applying a tolerance factor = margin. 35 | orion.fit(net, trainloader) 36 | input_level = orion.compile(net) 37 | 38 | # Encode and encrypt the input vector 39 | vec_ptxt = orion.encode(inp, input_level) 40 | vec_ctxt = orion.encrypt(vec_ptxt) 41 | net.he() # Switch to FHE mode 42 | 43 | # Run FHE inference 44 | print("\nStarting FHE inference", flush=True) 45 | start = time.time() 46 | out_ctxt = net(vec_ctxt) 47 | end = time.time() 48 | 49 | # Get the FHE results and decrypt + decode. 50 | out_ptxt = out_ctxt.decrypt() 51 | out_fhe = out_ptxt.decode() 52 | 53 | # Compare the cleartext and FHE results.
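# (The precision reported below is -log2(MAE): roughly the number of bits after the binary point on which the FHE output agrees with the cleartext output.)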
54 | print() 55 | print(out_clear) 56 | print(out_fhe) 57 | 58 | dist = mae(out_clear, out_fhe) 59 | print(f"\nMAE: {dist:.4f}") 60 | print(f"Precision: {-math.log2(dist):.4f}") 61 | print(f"Runtime: {end-start:.4f} secs.\n") -------------------------------------------------------------------------------- /examples/run_resnet.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import torch 4 | import orion 5 | import orion.models as models 6 | from orion.core.utils import ( 7 | get_cifar_datasets, 8 | mae, 9 | train_on_cifar 10 | ) 11 | 12 | # Set seed for reproducibility 13 | torch.manual_seed(42) 14 | 15 | # Initialize the Orion scheme, model, and data 16 | scheme = orion.init_scheme("../configs/resnet.yml") 17 | trainloader, testloader = get_cifar_datasets(data_dir="../data", batch_size=1) 18 | net = models.ResNet20() 19 | 20 | # Train model (optional) 21 | # device = "cuda" if torch.cuda.is_available() else "cpu" 22 | # train_on_cifar(net, data_dir="../data", epochs=1, device=device) 23 | 24 | # Get a test batch to pass through our network 25 | inp, _ = next(iter(testloader)) 26 | 27 | # Run cleartext inference 28 | net.eval() 29 | out_clear = net(inp) 30 | 31 | # Prepare for FHE inference. 32 | # Some polynomial activation functions require knowing the range of possible 33 | # input values. We'll estimate these ranges using training set statistics, 34 | # adjusted to be wider by a tolerance factor (= margin). 35 | orion.fit(net, trainloader) 36 | input_level = orion.compile(net) 37 | 38 | # Encode and encrypt the input vector 39 | vec_ptxt = orion.encode(inp, input_level) 40 | vec_ctxt = orion.encrypt(vec_ptxt) 41 | net.he() # Switch to FHE mode 42 | 43 | # Run FHE inference 44 | print("\nStarting FHE inference", flush=True) 45 | start = time.time() 46 | out_ctxt = net(vec_ctxt) 47 | end = time.time() 48 | 49 | # Get the FHE results and decrypt + decode. 50 | out_ptxt = out_ctxt.decrypt() 51 | out_fhe = out_ptxt.decode() 52 | 53 | # Compare the cleartext and FHE results.
54 | print() 55 | print(out_clear) 56 | print(out_fhe) 57 | 58 | dist = mae(out_clear, out_fhe) 59 | print(f"\nMAE: {dist:.4f}") 60 | print(f"Precision: {-math.log2(dist):.4f}") 61 | print(f"Runtime: {end-start:.4f} secs.\n") -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .alexnet import * 2 | from .lenet import * 3 | from .lola import * 4 | from .resnet import * 5 | from .vgg import * 6 | from .yolo import * 7 | from .mlp import * -------------------------------------------------------------------------------- /models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import orion.nn as on 3 | 4 | 5 | class ConvBlock(on.Module): 6 | def __init__(self, Ci, Co, kernel_size, stride, padding): 7 | super().__init__() 8 | self.conv = nn.Sequential( 9 | on.Conv2d(Ci, Co, kernel_size, stride, padding, bias=False), 10 | on.BatchNorm2d(Co), 11 | on.SiLU(degree=127)) 12 | 13 | def forward(self, x): 14 | return self.conv(x) 15 | 16 | 17 | class LinearBlock(on.Module): 18 | def __init__(self, ni, no): 19 | super().__init__() 20 | self.linear = nn.Sequential( 21 | on.Linear(ni, no), 22 | on.BatchNorm1d(no), 23 | on.SiLU(degree=127)) 24 | 25 | def forward(self, x): 26 | return self.linear(x) 27 | 28 | 29 | class AlexNet(on.Module): 30 | cfg = [64, 'M', 192, 'M', 384, 256, 256, 'A'] 31 | 32 | def __init__(self, num_classes=10): 33 | super().__init__() 34 | self.features = self._make_layers() 35 | self.flatten = on.Flatten() 36 | self.classifier = nn.Sequential( 37 | LinearBlock(1024, 4096), 38 | LinearBlock(4096, 4096), 39 | on.Linear(4096, num_classes)) 40 | 41 | def _make_layers(self): 42 | layers = [] 43 | in_channels = 3 44 | for x in self.cfg: 45 | if x == 'M': 46 | layers += [on.AvgPool2d(kernel_size=2, stride=2)] 47 | elif x == 'A': 48 | layers += [on.AdaptiveAvgPool2d((2, 2))] 49 | else: 50 | layers += [ConvBlock(in_channels, x, kernel_size=3, 51 | stride=1, padding=1)] 52 | in_channels = x 53 | return nn.Sequential(*layers) 54 | 55 | def forward(self, x): 56 | x = self.features(x) 57 | x = self.flatten(x) 58 | x = self.classifier(x) 59 | return x 60 | 61 | 62 | if __name__ == "__main__": 63 | import torch 64 | from torchsummary import summary 65 | from fvcore.nn import FlopCountAnalysis 66 | 67 | net = AlexNet() 68 | net.eval() 69 | 70 | x = torch.randn(1,3,32,32) 71 | total_flops = FlopCountAnalysis(net, x).total() 72 | 73 | summary(net, (3,32,32), depth=10) 74 | print("Total flops: ", total_flops) 75 | -------------------------------------------------------------------------------- /models/lenet.py: -------------------------------------------------------------------------------- 1 | import orion.nn as on 2 | 3 | class LeNet(on.Module): 4 | def __init__(self, num_classes=10): 5 | super().__init__() 6 | self.conv1 = on.Conv2d(1, 32, kernel_size=5, padding=2, stride=2) 7 | self.bn1 = on.BatchNorm2d(32) 8 | self.act1 = on.Quad() 9 | 10 | self.conv2 = on.Conv2d(32, 64, kernel_size=5, padding=2, stride=2) 11 | self.bn2 = on.BatchNorm2d(64) 12 | self.act2 = on.Quad() 13 | 14 | self.flatten = on.Flatten() 15 | self.fc1 = on.Linear(7*7*64, 512) 16 | self.bn3 = on.BatchNorm1d(512) 17 | self.act3 = on.Quad() 18 | 19 | self.fc2 = on.Linear(512, num_classes) 20 | 21 | def forward(self, x): 22 | x = self.act1(self.bn1(self.conv1(x))) 23 | x = self.act2(self.bn2(self.conv2(x))) 24 | x = 
self.flatten(x) 25 | x = self.act3(self.bn3(self.fc1(x))) 26 | return self.fc2(x) 27 | 28 | 29 | if __name__ == "__main__": 30 | import torch 31 | from torchsummary import summary 32 | from fvcore.nn import FlopCountAnalysis 33 | 34 | net = LeNet() 35 | net.eval() 36 | 37 | x = torch.randn(1,1,28,28) 38 | total_flops = FlopCountAnalysis(net, x).total() 39 | 40 | summary(net, (1,28,28), depth=10) 41 | print("Total flops: ", total_flops) 42 | -------------------------------------------------------------------------------- /models/lola.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import orion.nn as on 3 | 4 | class LoLA(on.Module): 5 | def __init__(self, num_classes=10): 6 | super().__init__() 7 | self.conv1 = on.Conv2d(1, 5, kernel_size=2, padding=0, stride=2) 8 | self.bn1 = on.BatchNorm2d(5) 9 | self.act1 = on.Quad() 10 | 11 | self.fc1 = on.Linear(980, 100) 12 | self.bn2 = on.BatchNorm1d(100) 13 | self.act2 = on.Quad() 14 | 15 | self.fc2 = on.Linear(100, num_classes) 16 | self.flatten = on.Flatten() 17 | 18 | def forward(self, x): 19 | x = self.act1(self.bn1(self.conv1(x))) 20 | x = self.flatten(x) 21 | x = self.act2(self.bn2(self.fc1(x))) 22 | return self.fc2(x) 23 | 24 | 25 | if __name__ == "__main__": 26 | import torch 27 | from torchsummary import summary 28 | from fvcore.nn import FlopCountAnalysis 29 | 30 | net = LoLA() 31 | net.eval() 32 | 33 | x = torch.randn(1,1,28,28) 34 | total_flops = FlopCountAnalysis(net, x).total() 35 | 36 | summary(net, (1,28,28), depth=10) 37 | print("Total flops: ", total_flops) 38 | -------------------------------------------------------------------------------- /models/mlp.py: -------------------------------------------------------------------------------- 1 | import orion.nn as on 2 | 3 | class MLP(on.Module): 4 | def __init__(self, num_classes=10): 5 | super().__init__() 6 | self.flatten = on.Flatten() 7 | 8 | self.fc1 = on.Linear(784, 128) 9 | self.bn1 = on.BatchNorm1d(128) 10 | self.act1 = on.Quad() 11 | 12 | self.fc2 = on.Linear(128, 128) 13 | self.bn2 = on.BatchNorm1d(128) 14 | self.act2 = on.Quad() 15 | 16 | self.fc3 = on.Linear(128, num_classes) 17 | 18 | def forward(self, x): 19 | x = self.flatten(x) 20 | x = self.act1(self.bn1(self.fc1(x))) 21 | x = self.act2(self.bn2(self.fc2(x))) 22 | return self.fc3(x) 23 | 24 | 25 | if __name__ == "__main__": 26 | import torch 27 | from torchsummary import summary 28 | from fvcore.nn import FlopCountAnalysis 29 | 30 | net = MLP() 31 | net.eval() 32 | 33 | x = torch.randn(1,1,28,28) 34 | total_flops = FlopCountAnalysis(net, x).total() 35 | 36 | summary(net, (1,28,28), depth=10) 37 | print("Total flops: ", total_flops) 38 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import orion.nn as on 3 | 4 | 5 | class BasicBlock(on.Module): 6 | expansion = 1 7 | 8 | def __init__(self, Ci, Co, stride=1): 9 | super().__init__() 10 | self.conv1 = on.Conv2d(Ci, Co, kernel_size=3, stride=stride, padding=1, bias=False) 11 | self.bn1 = on.BatchNorm2d(Co) 12 | self.act1 = on.ReLU() 13 | 14 | self.conv2 = on.Conv2d(Co, Co, kernel_size=3, stride=1, padding=1, bias=False) 15 | self.bn2 = on.BatchNorm2d(Co) 16 | self.act2 = on.ReLU() 17 | 18 | self.add = on.Add() 19 | self.shortcut = nn.Sequential() 20 | if stride != 1 or Ci != self.expansion*Co: 21 | self.shortcut = nn.Sequential( 22 | on.Conv2d(Ci, 
self.expansion*Co, kernel_size=1, stride=stride, bias=False), 23 | on.BatchNorm2d(self.expansion*Co)) 24 | 25 | def forward(self, x): 26 | out = self.act1(self.bn1(self.conv1(x))) 27 | out = self.bn2(self.conv2(out)) 28 | out = self.add(out, self.shortcut(x)) 29 | return self.act2(out) 30 | 31 | 32 | class Bottleneck(on.Module): 33 | expansion = 4 34 | 35 | def __init__(self, Ci, Co, stride=1): 36 | super().__init__() 37 | self.conv1 = on.Conv2d(Ci, Co, kernel_size=1, bias=False) 38 | self.bn1 = on.BatchNorm2d(Co) 39 | self.act1 = on.SiLU(degree=127) 40 | 41 | self.conv2 = on.Conv2d(Co, Co, kernel_size=3, stride=stride, padding=1, bias=False) 42 | self.bn2 = on.BatchNorm2d(Co) 43 | self.act2 = on.SiLU(degree=127) 44 | 45 | self.conv3 = on.Conv2d(Co, Co*self.expansion, kernel_size=1, stride=1, bias=False) 46 | self.bn3 = on.BatchNorm2d(Co*self.expansion) 47 | self.act3 = on.SiLU(degree=127) 48 | 49 | self.add = on.Add() 50 | self.shortcut = nn.Sequential() 51 | if stride != 1 or Ci != self.expansion*Co: 52 | self.shortcut = nn.Sequential( 53 | on.Conv2d(Ci, self.expansion*Co, kernel_size=1, stride=stride, bias=False), 54 | on.BatchNorm2d(self.expansion*Co)) 55 | 56 | def forward(self, x): 57 | out = self.act1(self.bn1(self.conv1(x))) 58 | out = self.act2(self.bn2(self.conv2(out))) 59 | out = self.bn3(self.conv3(out)) 60 | out = self.add(out, self.shortcut(x)) 61 | return self.act3(out) 62 | 63 | 64 | class ResNet(on.Module): 65 | def __init__(self, dataset, block, num_blocks, num_chans, conv1_params, num_classes): 66 | super().__init__() 67 | self.in_chans = num_chans[0] 68 | self.last_chans = num_chans[-1] 69 | 70 | self.conv1 = on.Conv2d(3, self.in_chans, **conv1_params, bias=False) 71 | self.bn1 = on.BatchNorm2d(self.in_chans) 72 | self.act = on.ReLU() 73 | 74 | self.pool = nn.Identity() 75 | if dataset == 'imagenet': # for imagenet we must also downsample 76 | self.pool = on.AvgPool2d(kernel_size=3, stride=2, padding=1) 77 | 78 | self.layers = nn.ModuleList() 79 | for i in range(len(num_blocks)): 80 | stride = 1 if i == 0 else 2 81 | self.layers.append(self.layer(block, num_chans[i], num_blocks[i], stride)) 82 | 83 | self.avgpool = on.AdaptiveAvgPool2d(output_size=(1,1)) 84 | self.flatten = on.Flatten() 85 | self.linear = on.Linear(self.last_chans * block.expansion, num_classes) 86 | 87 | def layer(self, block, chans, num_blocks, stride): 88 | strides = [stride] + [1]*(num_blocks-1) 89 | layers = [] 90 | for stride in strides: 91 | layers.append(block(self.in_chans, chans, stride)) 92 | self.in_chans = chans * block.expansion 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | out = self.act(self.bn1(self.conv1(x))) 97 | out = self.pool(out) 98 | for layer in self.layers: 99 | out = layer(out) 100 | 101 | out = self.avgpool(out) 102 | out = self.flatten(out) 103 | return self.linear(out) 104 | 105 | 106 | ################################ 107 | # CIFAR-10 / CIFAR-100 ResNets # 108 | ################################ 109 | 110 | def ResNet20(dataset='cifar10'): 111 | conv1_params, num_classes = get_resnet_config(dataset) 112 | return ResNet(dataset, BasicBlock, [3,3,3], [16,32,64], conv1_params, num_classes) 113 | 114 | def ResNet32(dataset='cifar10'): 115 | conv1_params, num_classes = get_resnet_config(dataset) 116 | return ResNet(dataset, BasicBlock, [5,5,5], [16,32,64], conv1_params, num_classes) 117 | 118 | def ResNet44(dataset='cifar10'): 119 | conv1_params, num_classes = get_resnet_config(dataset) 120 | return ResNet(dataset, BasicBlock, [7,7,7], [16,32,64], 
conv1_params, num_classes) 121 | 122 | def ResNet56(dataset='cifar10'): 123 | conv1_params, num_classes = get_resnet_config(dataset) 124 | return ResNet(dataset, BasicBlock, [9,9,9], [16,32,64], conv1_params, num_classes) 125 | 126 | def ResNet110(dataset='cifar10'): 127 | conv1_params, num_classes = get_resnet_config(dataset) 128 | return ResNet(dataset, BasicBlock, [18,18,18], [16,32,64], conv1_params, num_classes) 129 | 130 | def ResNet1202(dataset='cifar10'): 131 | conv1_params, num_classes = get_resnet_config(dataset) 132 | return ResNet(dataset, BasicBlock, [200,200,200], [16,32,64], conv1_params, num_classes) 133 | 134 | #################################### 135 | # Tiny ImageNet / ImageNet ResNets # 136 | #################################### 137 | 138 | def ResNet18(dataset='imagenet'): 139 | conv1_params, num_classes = get_resnet_config(dataset) 140 | return ResNet(dataset, BasicBlock, [2,2,2,2], [64,128,256,512], conv1_params, num_classes) 141 | 142 | def ResNet34(dataset='imagenet'): 143 | conv1_params, num_classes = get_resnet_config(dataset) 144 | return ResNet(dataset, BasicBlock, [3,4,6,3], [64,128,256,512], conv1_params, num_classes) 145 | 146 | def ResNet50(dataset='imagenet'): 147 | conv1_params, num_classes = get_resnet_config(dataset) 148 | return ResNet(dataset, Bottleneck, [3,4,6,3], [64,128,256,512], conv1_params, num_classes) 149 | 150 | def ResNet101(dataset='imagenet'): 151 | conv1_params, num_classes = get_resnet_config(dataset) 152 | return ResNet(dataset, Bottleneck, [3,4,23,3], [64,128,256,512], conv1_params, num_classes) 153 | 154 | def ResNet152(dataset='imagenet'): 155 | conv1_params, num_classes = get_resnet_config(dataset) 156 | return ResNet(dataset, Bottleneck, [3,8,36,3], [64,128,256,512], conv1_params, num_classes) 157 | 158 | 159 | def get_resnet_config(dataset): 160 | configs = { 161 | "cifar10": {"kernel_size": 3, "stride": 1, "padding": 1, "num_classes": 10}, 162 | "cifar100": {"kernel_size": 3, "stride": 1, "padding": 1, "num_classes": 100}, 163 | "tiny": {"kernel_size": 7, "stride": 1, "padding": 3, "num_classes": 200}, 164 | "imagenet": {"kernel_size": 7, "stride": 2, "padding": 3, "num_classes": 1000}, 165 | } 166 | 167 | if dataset not in configs: 168 | raise ValueError(f"ResNet with dataset {dataset} is not supported.") 169 | 170 | config = configs[dataset] 171 | conv1_params = { 172 | 'kernel_size': config["kernel_size"], 173 | 'stride': config["stride"], 174 | 'padding': config["padding"] 175 | } 176 | 177 | return conv1_params, config["num_classes"] 178 | 179 | 180 | if __name__ == "__main__": 181 | import torch 182 | from torchsummary import summary 183 | from fvcore.nn import FlopCountAnalysis 184 | 185 | net = ResNet50() 186 | net.eval() 187 | 188 | x = torch.randn(1,3,224,224) 189 | total_flops = FlopCountAnalysis(net, x).total() 190 | 191 | summary(net, (3,224,224), depth=10) 192 | print("Total flops: ", total_flops) 193 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import orion.nn as on 3 | 4 | cfg = { 5 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 6 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 7 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 8 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 
512, 'M'], 9 | } 10 | 11 | class VGG(on.Module): 12 | def __init__(self, vgg_name): 13 | super().__init__() 14 | self.features = self._make_layers(cfg[vgg_name]) 15 | self.classifier = on.Linear(512, 10) 16 | self.flatten = on.Flatten() 17 | 18 | def forward(self, x): 19 | out = self.features(x) 20 | out = self.flatten(out) 21 | out = self.classifier(out) 22 | return out 23 | 24 | def _make_layers(self, cfg): 25 | layers = [] 26 | in_channels = 3 27 | for x in cfg: 28 | if x == 'M': 29 | layers += [on.AvgPool2d(kernel_size=2, stride=2)] 30 | else: 31 | layers += [on.Conv2d(in_channels, x, kernel_size=3, padding=1), 32 | on.BatchNorm2d(x), 33 | on.ReLU(degrees=[15,15,27])] 34 | in_channels = x 35 | layers += [on.AvgPool2d(kernel_size=1, stride=1)] 36 | return nn.Sequential(*layers) 37 | 38 | 39 | if __name__ == "__main__": 40 | import torch 41 | from torchsummary import summary 42 | from fvcore.nn import FlopCountAnalysis 43 | 44 | net = VGG('VGG16') 45 | net.eval() 46 | 47 | x = torch.randn(1,3,32,32) 48 | total_flops = FlopCountAnalysis(net, x).total() 49 | 50 | summary(net, (3,32,32), depth=10) 51 | print("Total flops: ", total_flops) 52 | -------------------------------------------------------------------------------- /models/yolo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import orion.nn as on 4 | 5 | from .resnet import * 6 | 7 | 8 | class YOLOv1(on.Module): 9 | def __init__(self, backbone, num_bboxes=2, num_classes=20, model_path=None): 10 | super().__init__() 11 | 12 | self.feature_size = 7 13 | self.num_bboxes = num_bboxes 14 | self.num_classes = num_classes 15 | self.model_path = model_path 16 | 17 | self.backbone = backbone 18 | self.conv_layers = self._make_conv_layers() 19 | self.fc_layers = self._make_fc_layers() 20 | 21 | # Remove last layers of backbone 22 | self.backbone.avgpool = nn.Identity() 23 | self.backbone.flatten = nn.Identity() 24 | self.backbone.linear = nn.Identity() 25 | 26 | self._init_weights() 27 | 28 | def _init_weights(self): 29 | if self.model_path: 30 | state_dict = torch.load(self.model_path, map_location='cpu', weights_only=False) 31 | self.load_state_dict(state_dict, strict=False) 32 | 33 | def _make_conv_layers(self): 34 | net = nn.Sequential( 35 | on.Conv2d(512, 512, 3, padding=1), 36 | on.SiLU(degree=127), 37 | on.Conv2d(512, 512, 3, stride=2, padding=1), 38 | on.SiLU(degree=127), 39 | 40 | on.Conv2d(512, 512, 3, padding=1), 41 | on.SiLU(degree=127), 42 | on.Conv2d(512, 512, 3, padding=1), 43 | on.SiLU(degree=127) 44 | ) 45 | 46 | return net 47 | 48 | def _make_fc_layers(self): 49 | S, B, C = self.feature_size, self.num_bboxes, self.num_classes 50 | net = nn.Sequential( 51 | on.Flatten(), 52 | on.Linear(7 * 7 * 512, 4096), 53 | on.SiLU(degree=127), 54 | on.Linear(4096, S * S * (5 * B + C)), 55 | ) 56 | 57 | return net 58 | 59 | def forward(self, x): 60 | x = self.backbone(x) 61 | x = self.conv_layers(x) 62 | x = self.fc_layers(x) 63 | return x 64 | 65 | 66 | def YOLOv1_ResNet34(model_path=None): 67 | backbone = ResNet34() 68 | net = YOLOv1(backbone, num_bboxes=2, num_classes=20, model_path=model_path) 69 | return net 70 | 71 | 72 | if __name__ == "__main__": 73 | import torch 74 | from torchsummary import summary 75 | from fvcore.nn import FlopCountAnalysis 76 | 77 | net = YOLOv1_ResNet34() 78 | net.eval() 79 | 80 | x = torch.randn(1,3,448,448) 81 | total_flops = FlopCountAnalysis(net, x).total() 82 | 83 | summary(net, (3,448,448), depth=10) 84 | print("Total 
flops: ", total_flops) 85 | -------------------------------------------------------------------------------- /orion/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import ( 2 | init_scheme, 3 | delete_scheme, 4 | encode, 5 | decode, 6 | encrypt, 7 | decrypt, 8 | fit, 9 | compile 10 | ) 11 | 12 | __version__ = "1.0.1" -------------------------------------------------------------------------------- /orion/backend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baahl-nyu/orion/f0581052b28d02a00299cce742949930b3260aa8/orion/backend/__init__.py -------------------------------------------------------------------------------- /orion/backend/heaan/README.md: -------------------------------------------------------------------------------- 1 | # Orion -------------------------------------------------------------------------------- /orion/backend/lattigo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baahl-nyu/orion/f0581052b28d02a00299cce742949930b3260aa8/orion/backend/lattigo/__init__.py -------------------------------------------------------------------------------- /orion/backend/lattigo/bootstrapper.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "C" 5 | ) 6 | import ( 7 | "fmt" 8 | "math" 9 | 10 | "github.com/baahl-nyu/lattigo/v6/circuits/ckks/bootstrapping" 11 | "github.com/baahl-nyu/lattigo/v6/utils" 12 | ) 13 | 14 | // Map to store bootstrapping.Evaluators by their slot count 15 | // Initialize the map at package level 16 | var bootstrapperMap = make(map[int]*bootstrapping.Evaluator) 17 | 18 | //export NewBootstrapper 19 | func NewBootstrapper( 20 | LogPs *C.int, 21 | lenLogPs C.int, 22 | numSlots C.int, 23 | ) { 24 | slots := int(numSlots) 25 | 26 | if _, exists := bootstrapperMap[slots]; exists { 27 | return 28 | } 29 | 30 | // If not initialized for this slot count, create a new one 31 | logP := CArrayToSlice(LogPs, lenLogPs, convertCIntToInt) 32 | 33 | btpParametersLit := bootstrapping.ParametersLiteral{ 34 | LogN: utils.Pointy(scheme.Params.LogN()), 35 | LogP: logP, 36 | Xs: scheme.Params.Xs(), 37 | LogSlots: utils.Pointy(int(math.Log2(float64(slots)))), 38 | } 39 | 40 | btpParams, err := bootstrapping.NewParametersFromLiteral( 41 | *scheme.Params, btpParametersLit) 42 | if err != nil { 43 | panic(err) 44 | } 45 | 46 | btpKeys, _, err := btpParams.GenEvaluationKeys(scheme.SecretKey) 47 | if err != nil { 48 | panic(err) 49 | } 50 | 51 | var btpEval *bootstrapping.Evaluator 52 | if btpEval, err = bootstrapping.NewEvaluator(btpParams, btpKeys); err != nil { 53 | panic(err) 54 | } 55 | 56 | // Store the new evaluator in the map 57 | bootstrapperMap[slots] = btpEval 58 | } 59 | 60 | //export Bootstrap 61 | func Bootstrap(ciphertextID, numSlots C.int) C.int { 62 | ctIn := RetrieveCiphertext(int(ciphertextID)) 63 | bootstrapper := GetBootstrapper(int(numSlots)) 64 | 65 | ctBtp := ctIn.CopyNew() 66 | ctBtp.LogDimensions.Cols = bootstrapper.LogMaxSlots() 67 | 68 | ctOut, err := bootstrapper.Bootstrap(ctBtp) 69 | if err != nil { 70 | panic(err) 71 | } 72 | 73 | postscale := int(1 << (scheme.Params.LogMaxSlots() - bootstrapper.LogMaxSlots())) 74 | scheme.Evaluator.Mul(ctOut, postscale, ctOut) 75 | 76 | ctOut.LogDimensions.Cols = scheme.Params.LogMaxSlots() 77 | 78 | idx := PushCiphertext(ctOut) 79 | return 
C.int(idx) 80 | } 81 | 82 | func GetBootstrapper(numSlots int) *bootstrapping.Evaluator { 83 | bootstrapper, exists := bootstrapperMap[numSlots] 84 | if !exists { 85 | panic(fmt.Errorf("no bootstrapper found for slot count: %d", numSlots)) 86 | } 87 | return bootstrapper 88 | } 89 | 90 | //export DeleteBootstrappers 91 | func DeleteBootstrappers() { 92 | bootstrapperMap = make(map[int]*bootstrapping.Evaluator) 93 | } 94 | -------------------------------------------------------------------------------- /orion/backend/lattigo/encoder.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "C" 5 | 6 | "github.com/baahl-nyu/lattigo/v6/core/rlwe" 7 | "github.com/baahl-nyu/lattigo/v6/schemes/ckks" 8 | ) 9 | 10 | //export NewEncoder 11 | func NewEncoder() { 12 | scheme.Encoder = ckks.NewEncoder(*scheme.Params) 13 | } 14 | 15 | //export Encode 16 | func Encode( 17 | valuesPtr *C.float, 18 | lenValues C.int, 19 | level C.int, 20 | scale C.ulong, 21 | ) C.int { 22 | values := CArrayToSlice(valuesPtr, lenValues, convertCFloatToFloat) 23 | plaintext := ckks.NewPlaintext(*scheme.Params, int(level)) 24 | plaintext.Scale = rlwe.NewScale(uint64(scale)) 25 | 26 | scheme.Encoder.Encode(values, plaintext) 27 | 28 | idx := PushPlaintext(plaintext) 29 | return C.int(idx) 30 | } 31 | 32 | //export Decode 33 | func Decode( 34 | plaintextID C.int, 35 | ) (*C.float, C.ulong) { 36 | plaintext := RetrievePlaintext(int(plaintextID)) 37 | result := make([]float64, scheme.Params.MaxSlots()) 38 | scheme.Encoder.Decode(plaintext, result) 39 | 40 | arrPtr, length := SliceToCArray(result, convertFloatToCFloat) 41 | return arrPtr, length 42 | } 43 | -------------------------------------------------------------------------------- /orion/backend/lattigo/encryptor.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "C" 5 | 6 | "github.com/baahl-nyu/lattigo/v6/schemes/ckks" 7 | ) 8 | 9 | //export NewEncryptor 10 | func NewEncryptor() { 11 | scheme.Encryptor = ckks.NewEncryptor(*scheme.Params, scheme.PublicKey) 12 | } 13 | 14 | //export NewDecryptor 15 | func NewDecryptor() { 16 | scheme.Decryptor = ckks.NewDecryptor(*scheme.Params, scheme.SecretKey) 17 | } 18 | 19 | //export Encrypt 20 | func Encrypt(plaintextID C.int) C.int { 21 | plaintext := RetrievePlaintext(int(plaintextID)) 22 | ciphertext := ckks.NewCiphertext(*scheme.Params, 1, plaintext.Level()) 23 | scheme.Encryptor.Encrypt(plaintext, ciphertext) 24 | 25 | idx := PushCiphertext(ciphertext) 26 | return C.int(idx) 27 | } 28 | 29 | //export Decrypt 30 | func Decrypt(ciphertextID C.int) C.int { 31 | ciphertext := RetrieveCiphertext(int(ciphertextID)) 32 | 33 | plaintext := ckks.NewPlaintext(*scheme.Params, ciphertext.Level()) 34 | scheme.Decryptor.Decrypt(ciphertext, plaintext) 35 | 36 | idx := PushPlaintext(plaintext) 37 | return C.int(idx) 38 | } 39 | -------------------------------------------------------------------------------- /orion/backend/lattigo/evaluator.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "C" 5 | 6 | "github.com/baahl-nyu/lattigo/v6/core/rlwe" 7 | "github.com/baahl-nyu/lattigo/v6/schemes/ckks" 8 | ) 9 | 10 | var liveRotKeys = make(map[uint64]*rlwe.GaloisKey) 11 | var savedRotKeys = []uint64{} 12 | 13 | //export NewEvaluator 14 | func NewEvaluator() { 15 | scheme.Evaluator = ckks.NewEvaluator( 16 | *scheme.Params, 
rlwe.NewMemEvaluationKeySet(scheme.RelinKey)) 17 | 18 | // After declaring the evaluator, we'll also generate and 19 | // store in memory all power-of-two rotation keys. This ensures 20 | // all keys needed for the rotations and summations in the hybrid 21 | // method remain alive. 22 | AddPo2RotationKeys() 23 | } 24 | 25 | func AddPo2RotationKeys() { 26 | maxSlots := scheme.Params.MaxSlots() 27 | // Generate all positive power-of-two rotation keys 28 | for i := 1; i < maxSlots; i *= 2 { 29 | AddRotationKey(C.int(i)) 30 | } 31 | } 32 | 33 | //export AddRotationKey 34 | func AddRotationKey(rotation C.int) { 35 | galEl := scheme.Params.GaloisElement(int(rotation)) 36 | 37 | // Generate the required rotation key if it doesn't exist 38 | if _, exists := liveRotKeys[galEl]; !exists { 39 | rotKey := scheme.KeyGen.GenGaloisKeyNew(galEl, scheme.SecretKey) 40 | liveRotKeys[galEl] = rotKey 41 | 42 | allKeysList := GetValuesFromMap(liveRotKeys) 43 | keys := rlwe.NewMemEvaluationKeySet(scheme.RelinKey, allKeysList...) 44 | scheme.Evaluator = scheme.Evaluator.WithKey(keys) 45 | } 46 | } 47 | 48 | //export Negate 49 | func Negate(ciphertextID C.int) C.int { 50 | ctIn := RetrieveCiphertext(int(ciphertextID)) 51 | ctOut, err := scheme.Evaluator.MulNew(ctIn, -1.0) 52 | if err != nil { 53 | panic(err) 54 | } 55 | 56 | idx := PushCiphertext(ctOut) 57 | return C.int(idx) 58 | } 59 | 60 | //export Rotate 61 | func Rotate(ciphertextID, amount C.int) C.int { 62 | ctIn := RetrieveCiphertext(int(ciphertextID)) 63 | AddRotationKey(amount) 64 | scheme.Evaluator.Rotate(ctIn, int(amount), ctIn) 65 | 66 | return ciphertextID 67 | } 68 | 69 | //export RotateNew 70 | func RotateNew(ciphertextID, amount C.int) C.int { 71 | ctIn := RetrieveCiphertext(int(ciphertextID)) 72 | AddRotationKey(amount) 73 | 74 | ctOut, err := scheme.Evaluator.RotateNew(ctIn, int(amount)) 75 | if err != nil { 76 | panic(err) 77 | } 78 | 79 | idx := PushCiphertext(ctOut) 80 | return C.int(idx) 81 | } 82 | 83 | //export Rescale 84 | func Rescale(ciphertextID C.int) C.int { 85 | ctIn := RetrieveCiphertext(int(ciphertextID)) 86 | scheme.Evaluator.Rescale(ctIn, ctIn) 87 | 88 | return ciphertextID 89 | } 90 | 91 | //export RescaleNew 92 | func RescaleNew(ciphertextID C.int) C.int { 93 | ctIn := RetrieveCiphertext(int(ciphertextID)) 94 | ctOut := ctIn.CopyNew() // rescale into a copy so ctIn is left unmodified 95 | scheme.Evaluator.Rescale(ctIn, ctOut) 96 | 97 | idx := PushCiphertext(ctOut) 98 | return C.int(idx) 99 | } 100 | 101 | //export AddScalar 102 | func AddScalar(ciphertextID C.int, scalar C.float) C.int { 103 | ctIn := RetrieveCiphertext(int(ciphertextID)) 104 | scheme.Evaluator.Add(ctIn, float64(scalar), ctIn) 105 | 106 | return ciphertextID 107 | } 108 | 109 | //export AddScalarNew 110 | func AddScalarNew(ciphertextID C.int, scalar C.float) C.int { 111 | ctIn := RetrieveCiphertext(int(ciphertextID)) 112 | ctOut, err := scheme.Evaluator.AddNew(ctIn, float64(scalar)) 113 | if err != nil { 114 | panic(err) 115 | } 116 | 117 | idx := PushCiphertext(ctOut) 118 | return C.int(idx) 119 | } 120 | 121 | //export SubScalar 122 | func SubScalar(ciphertextID C.int, scalar C.float) C.int { 123 | ctIn := RetrieveCiphertext(int(ciphertextID)) 124 | scheme.Evaluator.Sub(ctIn, float64(scalar), ctIn) 125 | 126 | return ciphertextID 127 | } 128 | 129 | //export SubScalarNew 130 | func SubScalarNew(ciphertextID C.int, scalar C.float) C.int { 131 | ctIn := RetrieveCiphertext(int(ciphertextID)) 132 | ctOut, err := scheme.Evaluator.SubNew(ctIn, float64(scalar)) 133 | if err != nil { 134 | panic(err) 135 | }
136 | 137 | idx := PushCiphertext(ctOut) 138 | return C.int(idx) 139 | } 140 | 141 | //export MulScalarInt 142 | func MulScalarInt(ciphertextID C.int, scalar C.int) C.int { 143 | ctIn := RetrieveCiphertext(int(ciphertextID)) 144 | scheme.Evaluator.Mul(ctIn, int(scalar), ctIn) 145 | 146 | return ciphertextID 147 | } 148 | 149 | //export MulScalarIntNew 150 | func MulScalarIntNew(ciphertextID C.int, scalar C.int) C.int { 151 | ctIn := RetrieveCiphertext(int(ciphertextID)) 152 | ctOut, err := scheme.Evaluator.MulNew(ctIn, int(scalar)) 153 | if err != nil { 154 | panic(err) 155 | } 156 | 157 | idx := PushCiphertext(ctOut) 158 | return C.int(idx) 159 | } 160 | 161 | //export MulScalarFloat 162 | func MulScalarFloat(ciphertextID C.int, scalar C.float) C.int { 163 | ctIn := RetrieveCiphertext(int(ciphertextID)) 164 | scheme.Evaluator.Mul(ctIn, float64(scalar), ctIn) 165 | 166 | return ciphertextID 167 | } 168 | 169 | //export MulScalarFloatNew 170 | func MulScalarFloatNew(ciphertextID C.int, scalar C.float) C.int { 171 | ctIn := RetrieveCiphertext(int(ciphertextID)) 172 | ctOut, err := scheme.Evaluator.MulNew(ctIn, float64(scalar)) 173 | if err != nil { 174 | panic(err) 175 | } 176 | 177 | idx := PushCiphertext(ctOut) 178 | return C.int(idx) 179 | } 180 | 181 | //export AddPlaintext 182 | func AddPlaintext(ciphertextID, plaintextID C.int) C.int { 183 | ctIn := RetrieveCiphertext(int(ciphertextID)) 184 | ptIn := RetrievePlaintext(int(plaintextID)) 185 | scheme.Evaluator.Add(ctIn, ptIn, ctIn) 186 | 187 | return ciphertextID 188 | } 189 | 190 | //export AddPlaintextNew 191 | func AddPlaintextNew(ciphertextID, plaintextID C.int) C.int { 192 | ctIn := RetrieveCiphertext(int(ciphertextID)) 193 | ptIn := RetrievePlaintext(int(plaintextID)) 194 | 195 | ctOut, err := scheme.Evaluator.AddNew(ctIn, ptIn) 196 | if err != nil { 197 | panic(err) 198 | } 199 | 200 | idx := PushCiphertext(ctOut) 201 | return C.int(idx) 202 | } 203 | 204 | //export SubPlaintext 205 | func SubPlaintext(ciphertextID, plaintextID C.int) C.int { 206 | ctIn := RetrieveCiphertext(int(ciphertextID)) 207 | ptIn := RetrievePlaintext(int(plaintextID)) 208 | scheme.Evaluator.Sub(ctIn, ptIn, ctIn) 209 | 210 | return ciphertextID 211 | } 212 | 213 | //export SubPlaintextNew 214 | func SubPlaintextNew(ciphertextID, plaintextID C.int) C.int { 215 | ctIn := RetrieveCiphertext(int(ciphertextID)) 216 | ptIn := RetrievePlaintext(int(plaintextID)) 217 | 218 | ctOut, err := scheme.Evaluator.SubNew(ctIn, ptIn) 219 | if err != nil { 220 | panic(err) 221 | } 222 | 223 | idx := PushCiphertext(ctOut) 224 | return C.int(idx) 225 | } 226 | 227 | //export MulPlaintext 228 | func MulPlaintext(ciphertextID, plaintextID C.int) C.int { 229 | ctIn := RetrieveCiphertext(int(ciphertextID)) 230 | ptIn := RetrievePlaintext(int(plaintextID)) 231 | scheme.Evaluator.Mul(ctIn, ptIn, ctIn) 232 | 233 | return ciphertextID 234 | } 235 | 236 | //export MulPlaintextNew 237 | func MulPlaintextNew(ciphertextID, plaintextID C.int) C.int { 238 | ctIn := RetrieveCiphertext(int(ciphertextID)) 239 | ptIn := RetrievePlaintext(int(plaintextID)) 240 | 241 | ctOut, err := scheme.Evaluator.MulNew(ctIn, ptIn) 242 | if err != nil { 243 | panic(err) 244 | } 245 | 246 | idx := PushCiphertext(ctOut) 247 | return C.int(idx) 248 | } 249 | 250 | //export AddCiphertext 251 | func AddCiphertext(ctID0, ctID1 C.int) C.int { 252 | ctIn0 := RetrieveCiphertext(int(ctID0)) 253 | ctIn1 := RetrieveCiphertext((int(ctID1))) 254 | scheme.Evaluator.Add(ctIn0, ctIn1, ctIn0) 255 | 256 | return ctID0 257 | } 258 
| 259 | //export AddCiphertextNew 260 | func AddCiphertextNew(ctID0, ctID1 C.int) C.int { 261 | ctIn0 := RetrieveCiphertext(int(ctID0)) 262 | ctIn1 := RetrieveCiphertext((int(ctID1))) 263 | 264 | ctOut, err := scheme.Evaluator.AddNew(ctIn0, ctIn1) 265 | if err != nil { 266 | panic(err) 267 | } 268 | 269 | idx := PushCiphertext(ctOut) 270 | return C.int(idx) 271 | } 272 | 273 | //export SubCiphertext 274 | func SubCiphertext(ctID0, ctID1 C.int) C.int { 275 | ctIn0 := RetrieveCiphertext(int(ctID0)) 276 | ctIn1 := RetrieveCiphertext((int(ctID1))) 277 | scheme.Evaluator.Sub(ctIn0, ctIn1, ctIn0) 278 | 279 | return ctID0 280 | } 281 | 282 | //export SubCiphertextNew 283 | func SubCiphertextNew(ctID0, ctID1 C.int) C.int { 284 | ctIn0 := RetrieveCiphertext(int(ctID0)) 285 | ctIn1 := RetrieveCiphertext((int(ctID1))) 286 | 287 | ctOut, err := scheme.Evaluator.SubNew(ctIn0, ctIn1) 288 | if err != nil { 289 | panic(err) 290 | } 291 | 292 | idx := PushCiphertext(ctOut) 293 | return C.int(idx) 294 | } 295 | 296 | //export MulRelinCiphertext 297 | func MulRelinCiphertext(ctID0, ctID1 C.int) C.int { 298 | ctIn0 := RetrieveCiphertext(int(ctID0)) 299 | ctIn1 := RetrieveCiphertext((int(ctID1))) 300 | scheme.Evaluator.MulRelin(ctIn0, ctIn1, ctIn0) 301 | 302 | return ctID0 303 | } 304 | 305 | //export MulRelinCiphertextNew 306 | func MulRelinCiphertextNew(ctID0, ctID1 C.int) C.int { 307 | ctIn0 := RetrieveCiphertext(int(ctID0)) 308 | ctIn1 := RetrieveCiphertext((int(ctID1))) 309 | 310 | ctOut, err := scheme.Evaluator.MulRelinNew(ctIn0, ctIn1) 311 | if err != nil { 312 | panic(err) 313 | } 314 | 315 | idx := PushCiphertext(ctOut) 316 | return C.int(idx) 317 | } 318 | 319 | func DeleteRotationKeys() { 320 | liveRotKeys = make(map[uint64]*rlwe.GaloisKey) 321 | savedRotKeys = []uint64{} 322 | } 323 | -------------------------------------------------------------------------------- /orion/backend/lattigo/go.mod: -------------------------------------------------------------------------------- 1 | module github.com/baahl-nyu/orion 2 | 3 | go 1.23.0 4 | 5 | require github.com/baahl-nyu/lattigo/v6 v6.2.0 6 | 7 | require ( 8 | github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924 // indirect 9 | github.com/davecgh/go-spew v1.1.1 // indirect 10 | github.com/google/go-cmp v0.5.8 // indirect 11 | github.com/kr/text v0.2.0 // indirect 12 | github.com/pmezard/go-difflib v1.0.0 // indirect 13 | github.com/stretchr/testify v1.8.0 // indirect 14 | golang.org/x/crypto v0.31.0 // indirect 15 | golang.org/x/exp v0.0.0-20230321023759-10a507213a29 // indirect 16 | golang.org/x/sys v0.28.0 // indirect 17 | gopkg.in/yaml.v3 v3.0.1 // indirect 18 | ) 19 | -------------------------------------------------------------------------------- /orion/backend/lattigo/go.sum: -------------------------------------------------------------------------------- 1 | github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924 h1:DG4UyTVIujioxwJc8Zj8Nabz1L1wTgQ/xNBSQDfdP3I= 2 | github.com/ALTree/bigfloat v0.0.0-20220102081255-38c8b72a9924/go.mod h1:+NaH2gLeY6RPBPPQf4aRotPPStg+eXc8f9ZaE4vRfD4= 3 | github.com/baahl-nyu/lattigo/v6 v6.2.0 h1:q3Q3D7BsQeRmiyH2DbjSeR9HOgg1BXRYtDvZnnk9tUU= 4 | github.com/baahl-nyu/lattigo/v6 v6.2.0/go.mod h1:v2QIZBeS8QD8uaKkGzUYKK9qWrqSxKqYLlJ2JePoZWU= 5 | github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= 6 | github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 7 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 8 | 
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 9 | github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= 10 | github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= 11 | github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= 12 | github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= 13 | github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 14 | github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 15 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 16 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 17 | github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= 18 | github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= 19 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 20 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 21 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 22 | github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= 23 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 24 | golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= 25 | golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= 26 | golang.org/x/exp v0.0.0-20230321023759-10a507213a29 h1:ooxPy7fPvB4kwsA2h+iBNHkAbp/4JxTSwCmvdjEYmug= 27 | golang.org/x/exp v0.0.0-20230321023759-10a507213a29/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= 28 | golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= 29 | golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= 30 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 31 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= 32 | gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= 33 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 34 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 35 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 36 | -------------------------------------------------------------------------------- /orion/backend/lattigo/keygenerator.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "C" 5 | 6 | "github.com/baahl-nyu/lattigo/v6/core/rlwe" 7 | ) 8 | import ( 9 | "unsafe" 10 | ) 11 | 12 | //export NewKeyGenerator 13 | func NewKeyGenerator() { 14 | scheme.KeyGen = rlwe.NewKeyGenerator(scheme.Params) 15 | } 16 | 17 | //export GenerateSecretKey 18 | func GenerateSecretKey() { 19 | scheme.SecretKey = scheme.KeyGen.GenSecretKeyNew() 20 | } 21 | 22 | //export GeneratePublicKey 23 | func GeneratePublicKey() { 24 | scheme.PublicKey = scheme.KeyGen.GenPublicKeyNew(scheme.SecretKey) 25 | } 26 | 27 | //export GenerateRelinearizationKey 28 | func GenerateRelinearizationKey() { 29 | scheme.RelinKey = scheme.KeyGen.GenRelinearizationKeyNew(scheme.SecretKey) 30 | } 31 | 32 | //export 
GenerateEvaluationKeys 33 | func GenerateEvaluationKeys() { 34 | scheme.EvalKeys = rlwe.NewMemEvaluationKeySet(scheme.RelinKey) 35 | } 36 | 37 | //export SerializeSecretKey 38 | func SerializeSecretKey() (*C.char, C.ulong) { 39 | data, err := scheme.SecretKey.MarshalBinary() 40 | if err != nil { 41 | panic(err) 42 | } 43 | 44 | arrPtr, length := SliceToCArray(data, convertByteToCChar) 45 | return arrPtr, length 46 | } 47 | 48 | //export LoadSecretKey 49 | func LoadSecretKey(dataPtr *C.char, lenData C.ulong) { 50 | skSerial := CArrayToByteSlice(unsafe.Pointer(dataPtr), uint64(lenData)) 51 | 52 | sk := &rlwe.SecretKey{} 53 | if err := sk.UnmarshalBinary(skSerial); err != nil { 54 | panic(err) 55 | } 56 | 57 | scheme.SecretKey = sk 58 | } 59 | -------------------------------------------------------------------------------- /orion/backend/lattigo/lineartransform.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "C" 5 | "math" 6 | "unsafe" 7 | 8 | "github.com/baahl-nyu/lattigo/v6/circuits/ckks/lintrans" 9 | "github.com/baahl-nyu/lattigo/v6/core/rlwe" 10 | "github.com/baahl-nyu/lattigo/v6/ring" 11 | "github.com/baahl-nyu/lattigo/v6/ring/ringqp" 12 | "github.com/baahl-nyu/lattigo/v6/schemes/ckks" 13 | ) 14 | 15 | var ltHeap = NewHeapAllocator() 16 | 17 | func AddLinearTransform(lt lintrans.LinearTransformation) int { 18 | return ltHeap.Add(lt) 19 | } 20 | 21 | func RetrieveLinearTransform(id int) lintrans.LinearTransformation { 22 | return ltHeap.Retrieve(id).(lintrans.LinearTransformation) 23 | } 24 | 25 | //export DeleteLinearTransform 26 | func DeleteLinearTransform(id C.int) { 27 | ltHeap.Delete(int(id)) 28 | } 29 | 30 | //export NewLinearTransformEvaluator 31 | func NewLinearTransformEvaluator() { 32 | scheme.LinEvaluator = lintrans.NewEvaluator( 33 | ckks.NewEvaluator(*scheme.Params, scheme.EvalKeys)) 34 | } 35 | 36 | //export GenerateLinearTransform 37 | func GenerateLinearTransform( 38 | diagIdxsC *C.int, diagIdxsLen C.int, 39 | diagDataC *C.float, diagDataLen C.int, 40 | level C.int, 41 | bsgsRatio C.float, 42 | ioModeC *C.char, 43 | ) C.int { 44 | ioMode := C.GoString(ioModeC) 45 | 46 | // Unload diags data 47 | diagIdxs := CArrayToSlice(diagIdxsC, diagIdxsLen, convertCIntToInt) 48 | diagDataFlat := CArrayToSlice(diagDataC, diagDataLen, convertCFloatToFloat) 49 | 50 | // diagDataFlat is a flattened array of length len(diagIdxs) * slots. 51 | // The first element in diagIdxs corresponds to the first [0, slots] 52 | // values in diagsDataFlat, and so on. We'll extract these into a 53 | // dictionary that can be passed to Lattigo's LinearTransform evaluator. 
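// A small worked example of the flattened layout described above
// (illustrative values): with slots = 4 and diagIdxs = [0, 2], a buffer
// diagDataFlat = [a0 a1 a2 a3 b0 b1 b2 b3] unpacks to
// diagonals[0] = [a0 a1 a2 a3] and diagonals[2] = [b0 b1 b2 b3].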
54 | slots := scheme.Params.MaxSlots() 55 | diagonals := make(lintrans.Diagonals[float64]) 56 | 57 | for i, key := range diagIdxs { 58 | diagonals[key] = diagDataFlat[i*slots : (i+1)*slots] 59 | } 60 | 61 | ltparams := lintrans.Parameters{ 62 | DiagonalsIndexList: diagonals.DiagonalsIndexList(), 63 | LevelQ: int(level), 64 | LevelP: scheme.Params.MaxLevelP(), 65 | Scale: rlwe.NewScale(scheme.Params.Q()[int(level)]), 66 | LogDimensions: ring.Dimensions{Rows: 0, Cols: scheme.Params.LogMaxSlots()}, 67 | LogBabyStepGiantStepRatio: int(math.Log2(float64(bsgsRatio))), // base-2 log of the BSGS ratio 68 | } 69 | 70 | lt := lintrans.NewTransformation(scheme.Params, ltparams) 71 | 72 | // ---------------------------- // 73 | // Diagonal Generation/Saving // 74 | // ---------------------------- // 75 | 76 | // If ioMode is "load", then we expect the diagonals to have already been 77 | // generated and serialized, so there's no need to regenerate them here. 78 | // We do, however, still need to instantiate empty plaintext diagonals. 79 | if ioMode == "load" { 80 | lt.Vec = make(map[int]ringqp.Poly) 81 | for _, diag := range diagIdxs { 82 | lt.Vec[diag] = ringqp.Poly{} 83 | } 84 | } else { // otherwise, generate diagonals here. 85 | if err := lintrans.Encode(scheme.Encoder, diagonals, lt); err != nil { 86 | panic(err) 87 | } 88 | } 89 | 90 | // Return reference to linear transform object we just created 91 | ltID := ltHeap.Add(lt) 92 | return C.int(ltID) 93 | } 94 | 95 | //export EvaluateLinearTransform 96 | func EvaluateLinearTransform(transformID, ctxtID C.int) C.int { 97 | transform := RetrieveLinearTransform(int(transformID)) 98 | ctIn := RetrieveCiphertext(int(ctxtID)) 99 | 100 | // Update the linear transform evaluator to have the most 101 | // recent set of rotation keys. 102 | scheme.LinEvaluator = lintrans.NewEvaluator( 103 | scheme.Evaluator.WithKey(scheme.EvalKeys), 104 | ) 105 | 106 | ctOut, err := scheme.LinEvaluator.EvaluateNew(ctIn, transform) 107 | if err != nil { 108 | panic(err) 109 | } 110 | 111 | idx := PushCiphertext(ctOut) 112 | return C.int(idx) 113 | } 114 | 115 | //export GetLinearTransformRotationKeys 116 | func GetLinearTransformRotationKeys(transformID C.int) (*C.int, C.ulong) { 117 | transform := RetrieveLinearTransform(int(transformID)) 118 | galEls := transform.GaloisElements(scheme.Params) 119 | 120 | arrPtr, length := SliceToCArray(galEls, convertULongtoInt) 121 | return arrPtr, length 122 | } 123 | 124 | //export GenerateLinearTransformRotationKey 125 | func GenerateLinearTransformRotationKey(galEl C.int) { 126 | rotKey := scheme.KeyGen.GenGaloisKeyNew(uint64(galEl), scheme.SecretKey) 127 | scheme.EvalKeys.GaloisKeys[uint64(galEl)] = rotKey 128 | } 129 | 130 | //export GenerateAndSerializeRotationKey 131 | func GenerateAndSerializeRotationKey(galEl C.int) (*C.char, C.ulong) { 132 | rotKey := scheme.KeyGen.GenGaloisKeyNew(uint64(galEl), scheme.SecretKey) 133 | data, err := rotKey.MarshalBinary() // Marshal the key to binary 134 | if err != nil { 135 | panic(err) 136 | } 137 | 138 | arrPtr, length := SliceToCArray(data, convertByteToCChar) 139 | return arrPtr, length 140 | } 141 | 142 | //export LoadRotationKey 143 | func LoadRotationKey( 144 | dataPtr *C.char, lenData C.ulong, 145 | galEl C.ulong, 146 | ) { 147 | rotKeySerial := CArrayToByteSlice(unsafe.Pointer(dataPtr), uint64(lenData)) 148 | 149 | // Unmarshal the binary data into a GaloisKey 150 | var rotKey rlwe.GaloisKey 151 | if err := rotKey.UnmarshalBinary(rotKeySerial); err != nil { 152 | panic(err) 153 | } 154 | 155 | // Update our global 
map of evaluation keys to include what 156 | // we just loaded. This will eventually get used by the 157 | // current linear transform and then deleted from RAM. 158 | scheme.EvalKeys.GaloisKeys[uint64(galEl)] = &rotKey 159 | } 160 | 161 | //export SerializeDiagonal 162 | func SerializeDiagonal(transformID, diagIdx C.int) (*C.char, C.ulong) { 163 | transform := RetrieveLinearTransform(int(transformID)) 164 | diag := transform.Vec[int(diagIdx)] 165 | 166 | data, err := diag.MarshalBinary() // Marshal the diag to binary 167 | if err != nil { 168 | panic(err) 169 | } 170 | 171 | // Since it will be saved to disk, we can delete it from our 172 | // linear transform object and load it in dynamically at runtime 173 | transform.Vec[int(diagIdx)] = ringqp.Poly{} 174 | 175 | arrPtr, length := SliceToCArray(data, convertByteToCChar) 176 | return arrPtr, length 177 | } 178 | 179 | //export LoadPlaintextDiagonal 180 | func LoadPlaintextDiagonal( 181 | dataPtr *C.char, lenData C.ulong, 182 | transformID C.int, 183 | diagIdx C.ulong, 184 | ) { 185 | transform := RetrieveLinearTransform(int(transformID)) 186 | diagSerial := CArrayToByteSlice(unsafe.Pointer(dataPtr), uint64(lenData)) 187 | 188 | var poly ringqp.Poly 189 | if err := poly.UnmarshalBinary(diagSerial); err != nil { 190 | panic(err) 191 | } 192 | transform.Vec[int(diagIdx)] = poly 193 | } 194 | 195 | //export RemovePlaintextDiagonals 196 | func RemovePlaintextDiagonals(transformID C.int) { 197 | linTransf := RetrieveLinearTransform(int(transformID)) 198 | for diag := range linTransf.Vec { 199 | linTransf.Vec[diag] = ringqp.Poly{} 200 | } 201 | } 202 | 203 | //export RemoveRotationKeys 204 | func RemoveRotationKeys() { 205 | // We'll just update the linear transform evaluator to no longer have 206 | // access to the Galois keys it had before. GC should do the rest. 207 | scheme.EvalKeys = rlwe.NewMemEvaluationKeySet(scheme.RelinKey) 208 | scheme.LinEvaluator = lintrans.NewEvaluator(scheme.Evaluator.WithKey( 209 | scheme.EvalKeys, 210 | )) 211 | } 212 | -------------------------------------------------------------------------------- /orion/backend/lattigo/main.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import "C" 4 | 5 | func main() {} 6 | -------------------------------------------------------------------------------- /orion/backend/lattigo/minheap.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "container/heap" 5 | "fmt" 6 | ) 7 | 8 | type MinHeap []int 9 | 10 | func (h MinHeap) Len() int { return len(h) } 11 | func (h MinHeap) Less(i, j int) bool { return h[i] < h[j] } 12 | func (h MinHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } 13 | 14 | func (h *MinHeap) Push(x interface{}) { 15 | *h = append(*h, x.(int)) 16 | } 17 | 18 | func (h *MinHeap) Pop() interface{} { 19 | old := *h 20 | n := len(old) 21 | x := old[n-1] 22 | *h = old[0 : n-1] 23 | return x 24 | } 25 | 26 | // HeapAllocator updated to store pointers 27 | type HeapAllocator struct { 28 | nextInt int // The next integer to allocate 29 | freedIntegers MinHeap // Min-heap to store freed integers 30 | InterfaceMap map[int]*interface{} // Map to store/retrieve pointers to structs 31 | } 32 | 33 | // NewHeapAllocator initializes and returns a new HeapAllocator. 
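// A typical lifecycle, as a hedged usage sketch (myObj is a placeholder
// for any value stored on the heap, e.g. a ciphertext or linear transform):
//
//	ha := NewHeapAllocator()
//	id := ha.Add(myObj)    // assigns the lowest free integer ID
//	obj := ha.Retrieve(id) // returns myObj
//	ha.Delete(id)          // recycles id for future Add calls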
34 | func NewHeapAllocator() *HeapAllocator { 35 | allocator := &HeapAllocator{ 36 | nextInt: 0, 37 | freedIntegers: MinHeap{}, 38 | InterfaceMap: make(map[int]*interface{}), 39 | } 40 | heap.Init(&allocator.freedIntegers) 41 | return allocator 42 | } 43 | 44 | // Add assigns the lowest available integer to the provided object and 45 | // returns the integer. Now ensures we're storing a pointer. 46 | func (ha *HeapAllocator) Add(obj interface{}) int { 47 | var allocated int 48 | if len(ha.freedIntegers) > 0 { 49 | // Reuse the smallest available integer from the heap 50 | allocated = heap.Pop(&ha.freedIntegers).(int) 51 | } else { 52 | // Allocate a new integer 53 | allocated = ha.nextInt 54 | ha.nextInt++ 55 | } 56 | 57 | // Create a pointer to the interface value 58 | objCopy := obj 59 | objPtr := &objCopy 60 | 61 | // Store the pointer in the map 62 | ha.InterfaceMap[allocated] = objPtr 63 | return allocated 64 | } 65 | 66 | // Retrieve returns the associated object with integer. 67 | func (ha *HeapAllocator) Retrieve(integer int) interface{} { 68 | if objPtr, exists := ha.InterfaceMap[integer]; exists { 69 | // Dereference the pointer to get the original interface value 70 | return *objPtr 71 | } 72 | panic(fmt.Sprintf("Heap object not found for integer: %d", integer)) 73 | } 74 | 75 | // Delete removes the integer and its associated object from the allocator 76 | // and adds the integer back to the pool of available integers. 77 | func (ha *HeapAllocator) Delete(integer int) { 78 | if _, exists := ha.InterfaceMap[integer]; exists { 79 | heap.Push(&ha.freedIntegers, integer) 80 | delete(ha.InterfaceMap, integer) 81 | } 82 | } 83 | 84 | // Reset clears the allocator's state, reinitializing its fields. 85 | func (ha *HeapAllocator) Reset() { 86 | ha.nextInt = 0 87 | ha.freedIntegers = MinHeap{} // Reinitialize the slice 88 | heap.Init(&ha.freedIntegers) // Reinitialize the heap properties 89 | ha.InterfaceMap = make(map[int]*interface{}) 90 | } 91 | 92 | func (ha *HeapAllocator) GetLiveKeys() []int { 93 | keys := make([]int, 0, len(ha.InterfaceMap)) 94 | for k := range ha.InterfaceMap { 95 | keys = append(keys, k) 96 | } 97 | return keys 98 | } 99 | -------------------------------------------------------------------------------- /orion/backend/lattigo/polyeval.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "C" 5 | 6 | "fmt" 7 | "math/big" 8 | "strings" 9 | 10 | "github.com/baahl-nyu/lattigo/v6/circuits/ckks/minimax" 11 | "github.com/baahl-nyu/lattigo/v6/circuits/ckks/polynomial" 12 | "github.com/baahl-nyu/lattigo/v6/core/rlwe" 13 | "github.com/baahl-nyu/lattigo/v6/schemes/ckks" 14 | "github.com/baahl-nyu/lattigo/v6/utils/bignum" 15 | ) 16 | 17 | var polyHeap = NewHeapAllocator() 18 | var minimaxSignMap = make(map[string][][]float64) 19 | 20 | func AddPoly(poly bignum.Polynomial) int { 21 | return polyHeap.Add(poly) 22 | } 23 | 24 | func RetrievePoly(polyID int) bignum.Polynomial { 25 | return polyHeap.Retrieve(polyID).(bignum.Polynomial) 26 | } 27 | 28 | func DeletePoly(polyID int) { 29 | polyHeap.Delete(polyID) 30 | } 31 | 32 | //export NewPolynomialEvaluator 33 | func NewPolynomialEvaluator() { 34 | scheme.PolyEvaluator = polynomial.NewEvaluator(*scheme.Params, scheme.Evaluator) 35 | } 36 | 37 | //export GenerateMonomial 38 | func GenerateMonomial( 39 | coeffsPtr *C.float, 40 | lenCoeffs C.int, 41 | ) C.int { 42 | coeffs := CArrayToSlice(coeffsPtr, lenCoeffs, convertCFloatToFloat) 43 | poly := 
bignum.NewPolynomial(bignum.Monomial, coeffs, nil) 44 | 45 | idx := AddPoly(poly) 46 | return C.int(idx) 47 | } 48 | 49 | //export GenerateChebyshev 50 | func GenerateChebyshev( 51 | coeffsPtr *C.float, 52 | lenCoeffs C.int, 53 | ) C.int { 54 | coeffs := CArrayToSlice(coeffsPtr, lenCoeffs, convertCFloatToFloat) 55 | poly := bignum.NewPolynomial( 56 | bignum.Chebyshev, coeffs, [2]float64{-1.0, 1.0}) 57 | 58 | idx := AddPoly(poly) 59 | return C.int(idx) 60 | } 61 | 62 | //export EvaluatePolynomial 63 | func EvaluatePolynomial( 64 | ctInID C.int, 65 | polyID C.int, 66 | outScale C.ulong, 67 | ) C.int { 68 | poly := RetrievePoly(int(polyID)) 69 | ctIn := RetrieveCiphertext(int(ctInID)) 70 | 71 | // Often times we'll want to keep the original input ciphertext unchanged. 72 | ctTmp := ckks.NewCiphertext(*scheme.Params, 1, ctIn.Level()) 73 | ctTmp.Copy(ctIn) 74 | 75 | res, err := scheme.PolyEvaluator.Evaluate( 76 | ctTmp, poly, rlwe.NewScale(uint64(outScale)), 77 | ) 78 | if err != nil { 79 | panic(err) 80 | } 81 | 82 | ctOutID := PushCiphertext(res) 83 | return C.int(ctOutID) 84 | } 85 | 86 | // ------------------------------ // 87 | // Minimax Sign Helper Functions // 88 | // ------------------------------ // 89 | 90 | //export GenerateMinimaxSignCoeffs 91 | func GenerateMinimaxSignCoeffs( 92 | degreesPtr *C.int, lenDegrees C.int, 93 | prec C.int, 94 | logalpha C.int, 95 | logerr C.int, 96 | debug C.int, 97 | ) (*C.double, C.ulong) { 98 | degrees := CArrayToSlice(degreesPtr, lenDegrees, convertCIntToInt) 99 | 100 | // We'll eventually return this flattened list of coefficients 101 | sumDegrees := 0 102 | for _, d := range degrees { 103 | sumDegrees += d + 1 104 | } 105 | flatCoeffs := make([]float64, sumDegrees) 106 | 107 | // Generate key for given minimax sign parameters 108 | key := GenerateUniqueKey( 109 | degrees, 110 | uint(prec), 111 | int(logalpha), 112 | int(logerr), 113 | ) 114 | 115 | // Check if coefficients already exist in the map 116 | if existingCoeffs, exists := minimaxSignMap[key]; exists { 117 | // If so, avoid generating them and instead return these 118 | idx := 0 119 | for _, poly := range existingCoeffs { 120 | for _, coeff := range poly { 121 | flatCoeffs[idx] = coeff 122 | idx++ 123 | } 124 | } 125 | } else { 126 | // Otherwise, generate new coefficients 127 | coeffs := minimax.GenMinimaxCompositePolynomial( 128 | uint(prec), 129 | int(logalpha), 130 | int(logerr), 131 | degrees, 132 | bignum.Sign, 133 | int(debug) != 0, 134 | ) 135 | 136 | // Divide last poly by 2 to scale from [-1,1] -> [-0.5, 0.5] 137 | for i := range coeffs[len(degrees)-1] { 138 | coeffs[len(degrees)-1][i].Quo(coeffs[len(degrees)-1][i], big.NewFloat(2)) 139 | } 140 | 141 | // Add 0.5 to last polynomial so sign outputs in range [0, 1] 142 | coeffs[len(degrees)-1][0] = coeffs[len(degrees)-1][0].Add( 143 | coeffs[len(degrees)-1][0], big.NewFloat(0.5)) 144 | 145 | // Create 2D array of float64 to store in map 146 | float64Coeffs := make([][]float64, len(coeffs)) 147 | for i := range coeffs { 148 | float64Coeffs[i] = make([]float64, len(coeffs[i])) 149 | } 150 | 151 | idx := 0 152 | for i, poly := range coeffs { 153 | for j, coeff := range poly { 154 | f64, _ := coeff.Float64() 155 | flatCoeffs[idx] = f64 156 | float64Coeffs[i][j] = f64 157 | idx++ 158 | } 159 | } 160 | 161 | // Store coefficients in the map for future use 162 | minimaxSignMap[key] = float64Coeffs 163 | } 164 | 165 | arrPtr, arrLen := SliceToCArray(flatCoeffs, convertFloat64ToCDouble) 166 | return arrPtr, arrLen 167 | } 168 | 169 | // 
Create a unique string from the minimax parameters to use as an 170 | // index for the sign map. 171 | func GenerateUniqueKey( 172 | degrees []int, 173 | prec uint, 174 | logAlpha int, 175 | logErr int, 176 | ) string { 177 | degreesStr := make([]string, len(degrees)) 178 | for i, deg := range degrees { 179 | degreesStr[i] = fmt.Sprintf("%d", deg) 180 | } 181 | 182 | // Create a composite key 183 | return fmt.Sprintf("%s|%d|%d|%d", 184 | strings.Join(degreesStr, ","), 185 | prec, 186 | logAlpha, 187 | logErr) 188 | } 189 | 190 | func DeleteMinimaxSignMap() { 191 | minimaxSignMap = make(map[string][][]float64) 192 | } 193 | -------------------------------------------------------------------------------- /orion/backend/lattigo/scheme.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "C" 5 | 6 | "github.com/baahl-nyu/lattigo/v6/circuits/ckks/bootstrapping" 7 | "github.com/baahl-nyu/lattigo/v6/circuits/ckks/polynomial" 8 | "github.com/baahl-nyu/lattigo/v6/core/rlwe" 9 | "github.com/baahl-nyu/lattigo/v6/ring" 10 | "github.com/baahl-nyu/lattigo/v6/schemes/ckks" 11 | ) 12 | import ( 13 | "github.com/baahl-nyu/lattigo/v6/circuits/ckks/lintrans" 14 | ) 15 | 16 | type Scheme struct { 17 | Params *ckks.Parameters 18 | KeyGen *rlwe.KeyGenerator 19 | SecretKey *rlwe.SecretKey 20 | PublicKey *rlwe.PublicKey 21 | RelinKey *rlwe.RelinearizationKey 22 | EvalKeys *rlwe.MemEvaluationKeySet 23 | Encoder *ckks.Encoder 24 | Encryptor *rlwe.Encryptor 25 | Decryptor *rlwe.Decryptor 26 | Evaluator *ckks.Evaluator 27 | PolyEvaluator *polynomial.Evaluator 28 | LinEvaluator *lintrans.Evaluator 29 | Bootstrapper *bootstrapping.Evaluator 30 | } 31 | 32 | var scheme Scheme 33 | 34 | //export NewScheme 35 | func NewScheme( 36 | logN C.int, 37 | logQPtr *C.int, lenQ C.int, 38 | logPPtr *C.int, lenP C.int, 39 | logScale C.int, 40 | h C.int, 41 | ringType *C.char, 42 | keysPath *C.char, 43 | ioMode *C.char, 44 | ) { 45 | // Convert LogQ and LogP to Go slices 46 | logQ := CArrayToSlice(logQPtr, lenQ, convertCIntToInt) 47 | logP := CArrayToSlice(logPPtr, lenP, convertCIntToInt) 48 | 49 | ringT := ring.Standard 50 | if C.GoString(ringType) != "standard" { 51 | ringT = ring.ConjugateInvariant 52 | } 53 | 54 | var err error 55 | var params ckks.Parameters 56 | 57 | if params, err = ckks.NewParametersFromLiteral( 58 | ckks.ParametersLiteral{ 59 | LogN: int(logN), 60 | LogQ: logQ, 61 | LogP: logP, 62 | LogDefaultScale: int(logScale), 63 | Xs: ring.Ternary{H: int(h)}, 64 | RingType: ringT, 65 | }); err != nil { 66 | panic(err) 67 | } 68 | 69 | keyGen := rlwe.NewKeyGenerator(params) 70 | 71 | scheme = Scheme{ 72 | Params: ¶ms, 73 | KeyGen: keyGen, 74 | SecretKey: nil, 75 | PublicKey: nil, 76 | RelinKey: nil, 77 | EvalKeys: nil, 78 | Encoder: nil, 79 | Encryptor: nil, 80 | Decryptor: nil, 81 | Evaluator: nil, 82 | PolyEvaluator: nil, 83 | LinEvaluator: nil, 84 | Bootstrapper: nil, 85 | } 86 | } 87 | 88 | //export DeleteScheme 89 | func DeleteScheme() { 90 | scheme = Scheme{} 91 | 92 | DeleteRotationKeys() 93 | DeleteBootstrappers() 94 | DeleteMinimaxSignMap() 95 | 96 | ltHeap.Reset() 97 | polyHeap.Reset() 98 | ptHeap.Reset() 99 | ctHeap.Reset() 100 | } 101 | -------------------------------------------------------------------------------- /orion/backend/lattigo/tensors.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "C" 5 | 6 | "github.com/baahl-nyu/lattigo/v6/core/rlwe" 7 | ) 8 | 9 | var 
( 10 | ptHeap = NewHeapAllocator() 11 | ctHeap = NewHeapAllocator() 12 | ) 13 | 14 | func PushPlaintext(plaintext *rlwe.Plaintext) int { 15 | return ptHeap.Add(plaintext) 16 | } 17 | 18 | func PushCiphertext(ciphertext *rlwe.Ciphertext) int { 19 | return ctHeap.Add(ciphertext) 20 | } 21 | 22 | func RetrievePlaintext(plaintextID int) *rlwe.Plaintext { 23 | return ptHeap.Retrieve(plaintextID).(*rlwe.Plaintext) 24 | } 25 | 26 | func RetrieveCiphertext(ciphertextID int) *rlwe.Ciphertext { 27 | return ctHeap.Retrieve(ciphertextID).(*rlwe.Ciphertext) 28 | } 29 | 30 | // ---------------------------------------- // 31 | // PYTHON BINDINGS // 32 | // ---------------------------------------- // 33 | 34 | //export DeletePlaintext 35 | func DeletePlaintext(plaintextID C.int) { 36 | ptHeap.Delete(int(plaintextID)) 37 | } 38 | 39 | //export DeleteCiphertext 40 | func DeleteCiphertext(ciphertextID C.int) { 41 | ctHeap.Delete(int(ciphertextID)) 42 | } 43 | 44 | //export GetPlaintextScale 45 | func GetPlaintextScale(plaintextID C.int) C.ulong { 46 | plaintext := RetrievePlaintext(int(plaintextID)) 47 | scaleBig := &plaintext.Scale.Value 48 | scale, _ := scaleBig.Uint64() 49 | return C.ulong(scale) 50 | } 51 | 52 | //export GetCiphertextScale 53 | func GetCiphertextScale(ciphertextID C.int) C.ulong { 54 | ciphertext := RetrieveCiphertext(int(ciphertextID)) 55 | scaleBig := &ciphertext.Scale.Value 56 | scale, _ := scaleBig.Uint64() 57 | return C.ulong(scale) 58 | } 59 | 60 | //export SetPlaintextScale 61 | func SetPlaintextScale(plaintextID C.int, scale C.ulong) { 62 | plaintext := RetrievePlaintext(int(plaintextID)) 63 | plaintext.Scale = rlwe.NewScale(uint64(scale)) 64 | } 65 | 66 | //export SetCiphertextScale 67 | func SetCiphertextScale(ciphertextID C.int, scale C.ulong) { 68 | ciphertext := RetrieveCiphertext(int(ciphertextID)) 69 | ciphertext.Scale = rlwe.NewScale(uint64(scale)) 70 | } 71 | 72 | //export GetPlaintextLevel 73 | func GetPlaintextLevel(plaintextID C.int) C.int { 74 | plaintext := RetrievePlaintext(int(plaintextID)) 75 | return C.int(plaintext.Level()) 76 | } 77 | 78 | //export GetCiphertextLevel 79 | func GetCiphertextLevel(ciphertextID int) C.int { 80 | ciphertext := RetrieveCiphertext(ciphertextID) 81 | return C.int(ciphertext.Level()) 82 | } 83 | 84 | //export GetPlaintextSlots 85 | func GetPlaintextSlots(plaintextID int) C.int { 86 | plaintext := RetrievePlaintext(plaintextID) 87 | slots := 1 << plaintext.LogDimensions.Cols 88 | return C.int(slots) 89 | } 90 | 91 | //export GetCiphertextSlots 92 | func GetCiphertextSlots(ciphertextID int) C.int { 93 | ciphertext := RetrieveCiphertext(ciphertextID) 94 | slots := 1 << ciphertext.LogDimensions.Cols 95 | return C.int(slots) 96 | } 97 | 98 | //export GetCiphertextDegree 99 | func GetCiphertextDegree(ciphertextID int) C.int { 100 | ciphertext := RetrieveCiphertext(ciphertextID) 101 | return C.int(ciphertext.Degree()) 102 | } 103 | 104 | //export GetModuliChain 105 | func GetModuliChain() (*C.ulong, C.ulong) { 106 | moduli := scheme.Params.Q() 107 | arrPtr, length := SliceToCArray(moduli, convertULongtoCULong) 108 | return arrPtr, length 109 | } 110 | 111 | //export GetLivePlaintexts 112 | func GetLivePlaintexts() (*C.int, C.ulong) { 113 | ids := ptHeap.GetLiveKeys() 114 | arrPtr, length := SliceToCArray(ids, convertIntToCInt) 115 | return arrPtr, length 116 | } 117 | 118 | //export GetLiveCiphertexts 119 | func GetLiveCiphertexts() (*C.int, C.ulong) { 120 | ids := ctHeap.GetLiveKeys() 121 | arrPtr, length := SliceToCArray(ids, 
convertIntToCInt) 122 | return arrPtr, length 123 | } 124 | -------------------------------------------------------------------------------- /orion/backend/lattigo/utils.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | //#include <stdlib.h> 4 | import "C" 5 | 6 | import ( 7 | "fmt" 8 | "unsafe" 9 | 10 | "github.com/baahl-nyu/lattigo/v6/core/rlwe" 11 | ) 12 | 13 | func convertCIntToInt(v C.int) int { 14 | return int(v) 15 | } 16 | func convertCFloatToFloat(v C.float) float64 { 17 | return float64(v) 18 | } 19 | 20 | func CArrayToByteSlice(dataPtr unsafe.Pointer, length uint64) []byte { 21 | return unsafe.Slice((*byte)(dataPtr), length) 22 | } 23 | 24 | func convertFloatToCFloat(v float64) C.float { 25 | return C.float(v) 26 | } 27 | 28 | func convertFloat64ToCDouble(v float64) C.double { 29 | return C.double(v) 30 | } 31 | 32 | func convertIntToCInt(v int) C.int { 33 | return C.int(v) 34 | } 35 | 36 | func convertULongtoCULong(v uint64) C.ulong { 37 | return C.ulong(v) 38 | } 39 | 40 | func convertULongtoInt(v uint64) C.int { 41 | return C.int(v) 42 | } 43 | 44 | func convertByteToCChar(b byte) C.char { 45 | return C.char(b) 46 | } 47 | 48 | func CArrayToSlice[T, U any](ptr *U, length C.int, conv func(U) T) []T { 49 | cSlice := unsafe.Slice(ptr, int(length)) 50 | result := make([]T, int(length)) 51 | for i, v := range cSlice { 52 | result[i] = conv(v) 53 | } 54 | return result 55 | } 56 | 57 | func SliceToCArray[T, U any](slice []T, conv func(T) U) (*U, C.ulong) { 58 | n := len(slice) 59 | if n == 0 { 60 | return nil, 0 61 | } 62 | size := C.size_t(n) * C.size_t(unsafe.Sizeof(*new(U))) 63 | ptr := C.malloc(size) 64 | if ptr == nil { 65 | panic("C.malloc failed") 66 | } 67 | cArray := unsafe.Slice((*U)(ptr), n) 68 | for i, v := range slice { 69 | cArray[i] = conv(v) 70 | } 71 | return (*U)(ptr), C.ulong(n) 72 | } 73 | 74 | // GetKeysFromMap returns a slice of keys from the provided map. 75 | func GetKeysFromMap[K comparable, V any](m map[K]V) []K { 76 | keys := make([]K, 0, len(m)) 77 | for k := range m { 78 | keys = append(keys, k) 79 | } 80 | return keys 81 | } 82 | 83 | // GetValuesFromMap returns a slice of values from the provided map. 
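// A hedged usage sketch for both map helpers (the example map below is
// hypothetical; iteration order over Go maps is not guaranteed):
//
//	m := map[uint64]string{2: "a", 4: "b"}
//	keys := GetKeysFromMap(m)   // e.g. [2 4]
//	vals := GetValuesFromMap(m) // e.g. ["a" "b"]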
84 | func GetValuesFromMap[K comparable, V any](m map[K]V) []V { 85 | values := make([]V, 0, len(m)) 86 | for _, v := range m { 87 | values = append(values, v) 88 | } 89 | return values 90 | } 91 | 92 | //export FreeCArray 93 | func FreeCArray(ptr unsafe.Pointer) { 94 | C.free(ptr) 95 | } 96 | 97 | func PrintCipher(scheme Scheme, ctxt *rlwe.Ciphertext) { 98 | msg := make([]float64, ctxt.Slots()) 99 | 100 | // Decode and check result 101 | ptxt := scheme.Decryptor.DecryptNew(ctxt) 102 | _ = scheme.Encoder.Decode(ptxt, msg) 103 | 104 | for i := 0; i < min(16, ctxt.Slots()); i++ { 105 | fmt.Printf("msg[%d]: %.5f\n", i, msg[i]) 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /orion/backend/openfhe/README.md: -------------------------------------------------------------------------------- 1 | # Orion -------------------------------------------------------------------------------- /orion/backend/python/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/baahl-nyu/orion/f0581052b28d02a00299cce742949930b3260aa8/orion/backend/python/__init__.py -------------------------------------------------------------------------------- /orion/backend/python/bootstrapper.py: -------------------------------------------------------------------------------- 1 | class NewEvaluator: 2 | def __init__(self, scheme): 3 | self.scheme = scheme 4 | self.backend = scheme.backend 5 | 6 | def __del__(self): 7 | self.backend.DeleteBootstrappers() 8 | 9 | def generate_bootstrapper(self, slots): 10 | # We will wait to instantiate any bootstrappers until our bootstrap 11 | # placement algorithm determines they're necessary. 12 | logp = self.scheme.params.get_boot_logp() 13 | return self.backend.NewBootstrapper(logp, slots) 14 | 15 | def bootstrap(self, ctxt, slots): 16 | return self.backend.Bootstrap(ctxt, slots) -------------------------------------------------------------------------------- /orion/backend/python/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .tensors import PlainTensor 3 | 4 | class NewEncoder: 5 | def __init__(self, scheme): 6 | self.scheme = scheme 7 | self.params = scheme.params 8 | self.backend = scheme.backend 9 | self.setup_encoder() 10 | 11 | def setup_encoder(self): 12 | self.backend.NewEncoder() 13 | 14 | def encode(self, values, level=None, scale=None): 15 | if isinstance(values, list): 16 | values = torch.tensor(values) 17 | elif not isinstance(values, torch.Tensor): 18 | raise TypeError( 19 | f"Expected 'values' passed to encode() to be either a list " 20 | f"or a torch.Tensor, but got {type(values)}.") 21 | 22 | if level is None: 23 | level = self.params.get_max_level() 24 | if scale is None: 25 | scale = self.params.get_default_scale() 26 | 27 | num_slots = self.params.get_slots() 28 | num_elements = values.numel() 29 | 30 | values = values.cpu() 31 | pad_length = (-num_elements) % num_slots 32 | vector = torch.zeros(num_elements + pad_length) 33 | vector[:num_elements] = values.flatten() 34 | num_plaintexts = len(vector) // num_slots 35 | 36 | plaintext_ids = [] 37 | for i in range(num_plaintexts): 38 | to_encode = vector[i*num_slots:(i+1)*num_slots].tolist() 39 | plaintext_id = self.backend.Encode(to_encode, level, scale) 40 | plaintext_ids.append(plaintext_id) 41 | 42 | return PlainTensor(self.scheme, plaintext_ids, values.shape) 43 | 44 | def decode(self, plaintensor: PlainTensor): 45 | values = []
46 | for plaintext_id in plaintensor.ids: 47 | values.extend(self.backend.Decode(plaintext_id)) 48 | 49 | values = torch.tensor(values)[:plaintensor.on_shape.numel()] 50 | return values.reshape(plaintensor.on_shape) 51 | 52 | def get_moduli_chain(self): 53 | return self.backend.GetModuliChain() 54 | -------------------------------------------------------------------------------- /orion/backend/python/encryptor.py: -------------------------------------------------------------------------------- 1 | from .tensors import PlainTensor, CipherTensor 2 | 3 | class NewEncryptor: 4 | def __init__(self, scheme): 5 | self.scheme = scheme 6 | self.backend = scheme.backend 7 | self.new_encryptor() 8 | self.new_decryptor() 9 | 10 | def new_encryptor(self): 11 | self.backend.NewEncryptor() 12 | 13 | def new_decryptor(self): 14 | self.backend.NewDecryptor() 15 | 16 | def encrypt(self, plaintensor): 17 | ciphertext_ids = [] 18 | for ptxt in plaintensor.ids: 19 | ciphertext_id = self.backend.Encrypt(ptxt) 20 | ciphertext_ids.append(ciphertext_id) 21 | 22 | return CipherTensor( 23 | self.scheme, ciphertext_ids, plaintensor.shape, plaintensor.on_shape) 24 | 25 | def decrypt(self, ciphertensor): 26 | plaintext_ids = [] 27 | for ctxt in ciphertensor.ids: 28 | plaintext_id = self.backend.Decrypt(ctxt) 29 | plaintext_ids.append(plaintext_id) 30 | 31 | return PlainTensor( 32 | self.scheme, plaintext_ids, ciphertensor.shape, ciphertensor.on_shape 33 | ) -------------------------------------------------------------------------------- /orion/backend/python/evaluator.py: -------------------------------------------------------------------------------- 1 | class NewEvaluator: 2 | def __init__(self, scheme): 3 | self.backend = scheme.backend 4 | self.new_evaluator() 5 | 6 | def new_evaluator(self): 7 | self.backend.NewEvaluator() 8 | 9 | def add_rotation_key(self, amount: int): 10 | self.backend.AddRotationKey(amount) 11 | 12 | def negate(self, ctxt): 13 | return self.backend.Negate(ctxt) 14 | 15 | def rotate(self, ctxt, amount, in_place): 16 | if in_place: 17 | return self.backend.Rotate(ctxt, amount) 18 | return self.backend.RotateNew(ctxt, amount) 19 | 20 | def add_scalar(self, ctxt, scalar, in_place): 21 | if in_place: 22 | return self.backend.AddScalar(ctxt, float(scalar)) 23 | return self.backend.AddScalarNew(ctxt, float(scalar)) 24 | 25 | def sub_scalar(self, ctxt, scalar, in_place): 26 | if in_place: 27 | return self.backend.SubScalar(ctxt, float(scalar)) 28 | return self.backend.SubScalarNew(ctxt, float(scalar)) 29 | 30 | def mul_scalar(self, ctxt, scalar, in_place): 31 | if isinstance(scalar, float) and scalar.is_integer(): 32 | scalar = int(scalar) # (e.g., 1.00 -> 1) 33 | 34 | if isinstance(scalar, int): 35 | ct_out = (self.backend.MulScalarInt if in_place 36 | else self.backend.MulScalarIntNew)(ctxt, scalar) 37 | else: 38 | ct_out = (self.backend.MulScalarFloat if in_place 39 | else self.backend.MulScalarFloatNew)(ctxt, scalar) 40 | ct_out = self.backend.Rescale(ct_out) 41 | 42 | return ct_out 43 | 44 | def add_plaintext(self, ctxt, ptxt, in_place): 45 | if in_place: 46 | return self.backend.AddPlaintext(ctxt, ptxt) 47 | return self.backend.AddPlaintextNew(ctxt, ptxt) 48 | 49 | def sub_plaintext(self, ctxt, ptxt, in_place): 50 | if in_place: 51 | return self.backend.SubPlaintext(ctxt, ptxt) 52 | return self.backend.SubPlaintextNew(ctxt, ptxt) 53 | 54 | def mul_plaintext(self, ctxt, ptxt, in_place): 55 | if in_place: # ct_out = ctxt 56 | ct_out = self.backend.MulPlaintext(ctxt, ptxt) 57 | else: 58 | ct_out = 
self.backend.MulPlaintextNew(ctxt, ptxt) 59 | 60 | return self.backend.Rescale(ct_out) 61 | 62 | def add_ciphertext(self, ctxt0, ctxt1, in_place): 63 | if in_place: 64 | return self.backend.AddCiphertext(ctxt0, ctxt1) 65 | return self.backend.AddCiphertextNew(ctxt0, ctxt1) 66 | 67 | def sub_ciphertext(self, ctxt0, ctxt1, in_place): 68 | if in_place: 69 | return self.backend.SubCiphertext(ctxt0, ctxt1) 70 | return self.backend.SubCiphertextNew(ctxt0, ctxt1) 71 | 72 | def mul_ciphertext(self, ctxt0, ctxt1, in_place): 73 | if in_place: # ct_out = ctxt 74 | ct_out = self.backend.MulRelinCiphertext(ctxt0, ctxt1) 75 | else: 76 | ct_out = self.backend.MulRelinCiphertextNew(ctxt0, ctxt1) 77 | 78 | return self.backend.Rescale(ct_out) 79 | 80 | def rescale(self, ctxt, in_place): 81 | if in_place: 82 | return self.backend.Rescale(ctxt) 83 | return self.backend.RescaleNew(ctxt) 84 | 85 | def get_live_plaintexts(self): 86 | return self.backend.GetLivePlaintexts() 87 | 88 | def get_live_ciphertexts(self): 89 | return self.backend.GetLiveCiphertexts() 90 | 91 | -------------------------------------------------------------------------------- /orion/backend/python/key_generator.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | 3 | class NewKeyGenerator: 4 | def __init__(self, scheme): 5 | self.backend = scheme.backend 6 | self.io_mode = scheme.params.get_io_mode() 7 | self.keys_path = scheme.params.get_keys_path() 8 | self.new_key_generator() 9 | 10 | def new_key_generator(self): 11 | self.backend.NewKeyGenerator() 12 | self.generate_secret_key() 13 | self.generate_public_key() 14 | self.generate_relinearization_key() 15 | self.generate_evaluation_keys() 16 | 17 | def generate_secret_key(self): 18 | if self.io_mode != "load": # we'll need to generate a fresh sk 19 | self.backend.GenerateSecretKey() 20 | 21 | # Save key if in "save" mode 22 | if self.io_mode == "save": 23 | sk_serial, _ = self.backend.SerializeSecretKey() 24 | with h5py.File(self.keys_path, "a") as f: 25 | f.create_dataset("sk", data=sk_serial) 26 | 27 | # Load key if in "load" mode 28 | elif self.io_mode == "load": 29 | with h5py.File(self.keys_path, "r") as f: 30 | serial_sk = f["sk"][()] 31 | self.backend.LoadSecretKey(serial_sk) 32 | 33 | def generate_public_key(self): 34 | self.backend.GeneratePublicKey() 35 | 36 | def generate_relinearization_key(self): 37 | self.backend.GenerateRelinearizationKey() 38 | 39 | def generate_evaluation_keys(self): 40 | self.backend.GenerateEvaluationKeys() -------------------------------------------------------------------------------- /orion/backend/python/parameters.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Literal, List 3 | from dataclasses import dataclass, field 4 | 5 | 6 | @dataclass 7 | class CKKSParameters: 8 | logn: int 9 | logq: List[int] 10 | logp: List[int] 11 | logscale: int = field(default=None) 12 | h: int = 192 13 | ringtype: str = "standard" 14 | boot_logp: List[int] = field(default=None) 15 | 16 | def __post_init__(self): 17 | if self.logq and self.logp and len(self.logp) > len(self.logq): 18 | raise ValueError( 19 | f"Invalid parameters: The length of logp ({len(self.logp)}) " 20 | f"cannot exceed the length of logq ({len(self.logq)})." 21 | ) 22 | 23 | valid_ringtypes = {"standard", "conjugateinvariant"} 24 | ring = self.ringtype.lower() 25 | if ring not in valid_ringtypes: 26 | raise ValueError( 27 | f"Invalid ringtype: {self.ringtype}. 
Only 'Standard' or " 28 | f"'ConjugateInvariant' ring types are supported." 29 | ) 30 | 31 | self.logscale = self.logscale or self.logq[-1] 32 | self.boot_logp = self.boot_logp or self.logp 33 | self.logslots = ( 34 | self.logn-1 if self.ringtype.lower() == "standard" 35 | else self.logn 36 | ) 37 | 38 | def __str__(self): 39 | if self.ringtype.lower() == "standard": 40 | ring_type_display = "Standard" 41 | else: 42 | ring_type_display = "Conjugate invariant" 43 | 44 | output = [ 45 | "CKKS Parameters:", 46 | f" Ring degree (N): {1 << self.logn} (LogN = {self.logn})", 47 | f" Number of slots (n): {1 << self.logslots}", 48 | f" Effective levels (L_eff): {len(self.logq) - 1}", 49 | f" Ring type: {ring_type_display}", 50 | f" Scale: 2^{self.logscale}", 51 | f" Hamming weight: {self.h}" 52 | ] 53 | 54 | # Format LogQ values 55 | logq_str = ", ".join(str(q) for q in self.logq) 56 | output.append(f" LogQ: [{logq_str}] (length: {len(self.logq)})") 57 | 58 | # Format LogP values 59 | logp_str = ", ".join(str(p) for p in self.logp) 60 | output.append(f" LogP: [{logp_str}] (length: {len(self.logp)})") 61 | 62 | # Format Boot LogP values if different from LogP 63 | if self.boot_logp != self.logp: 64 | boot_logp_str = ", ".join(str(p) for p in self.boot_logp) 65 | output.append(f" Boot LogP: [{boot_logp_str}] (length: {len(self.boot_logp)})") 66 | 67 | return "\n".join(output) 68 | 69 | 70 | @dataclass 71 | class OrionParameters: 72 | margin: int = 2 73 | fuse_modules: bool = True 74 | debug: bool = True 75 | embedding_method: Literal["hybrid", "square"] = "hybrid" 76 | backend: Literal["lattigo", "openfhe", "heaan"] = "lattigo" 77 | io_mode: Literal["none", "save", "load"] = "none" 78 | diags_path: str = "" 79 | keys_path: str = "" 80 | 81 | def __str__(self) -> str: 82 | output = [ 83 | "Orion Parameters:", 84 | f" Backend: {self.backend}", 85 | f" Margin: {self.margin}", 86 | f" Embedding Method: {self.embedding_method}", 87 | f" Fuse Modules: {self.fuse_modules}", 88 | f" Debug Mode: {self.debug}" 89 | ] 90 | 91 | output.append(f" I/O Mode: {self.io_mode}") 92 | if self.diags_path: 93 | output.append(f" Diagonals Path: {self.diags_path}") 94 | if self.keys_path: 95 | output.append(f" Keys Path: {self.keys_path}") 96 | 97 | return "\n".join(output) 98 | 99 | 100 | @dataclass 101 | class NewParameters: 102 | params_json: dict 103 | ckks_params: CKKSParameters = field(init=False) 104 | orion_params: OrionParameters = field(init=False) 105 | 106 | def __post_init__(self): 107 | params = self.params_json 108 | ckks_params = { 109 | k.lower(): v for k, v in params.get("ckks_params", {}).items()} 110 | boot_params = { 111 | k.lower(): v for k, v in params.get("boot_params", {}).items()} 112 | orion_params = { 113 | k.lower(): v for k, v in params.get("orion", {}).items()} 114 | 115 | self.ckks_params = CKKSParameters( 116 | **ckks_params, boot_logp=boot_params.get("logp") 117 | ) 118 | self.orion_params = OrionParameters(**orion_params) 119 | 120 | # Finally, we'll delete existing keys/diagonals if the user 121 | # specifies to overwrite them. 
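# For example (illustrative; the file names below are hypothetical):
# with io_mode="save", keys_path="keys.h5", and diags_path="diags.h5",
# any stale keys.h5/diags.h5 left over from a previous run are removed
# here before fresh keys and diagonals are written.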
122 | if self.get_io_mode() == "save" and self.io_paths_exist(): 123 | self.reset_stored_keys() 124 | self.reset_stored_diags() 125 | 126 | def __str__(self) -> str: 127 | border = "=" * 50 128 | return f"\n{border}\n{self.ckks_params}\n\n{self.orion_params}\n{border}\n" 129 | 130 | def get_logn(self): 131 | return self.ckks_params.logn 132 | 133 | def get_margin(self): 134 | return self.orion_params.margin 135 | 136 | def get_fuse_modules(self): 137 | return self.orion_params.fuse_modules 138 | 139 | def get_debug_status(self): 140 | return self.orion_params.debug 141 | 142 | def get_backend(self): 143 | return self.orion_params.backend.lower() 144 | 145 | def get_logq(self): 146 | return self.ckks_params.logq 147 | 148 | def get_logp(self): 149 | return self.ckks_params.logp 150 | 151 | def get_logscale(self): 152 | return self.ckks_params.logscale 153 | 154 | def get_default_scale(self): 155 | return 1 << self.ckks_params.logscale 156 | 157 | def get_hamming_weight(self): 158 | return self.ckks_params.h 159 | 160 | def get_ringtype(self): 161 | return self.ckks_params.ringtype.lower() 162 | 163 | def get_max_level(self): 164 | return len(self.ckks_params.logq) - 1 165 | 166 | def get_slots(self): 167 | return int(1 << self.ckks_params.logslots) 168 | 169 | def get_ring_degree(self): 170 | return int(1 << self.ckks_params.logn) 171 | 172 | def get_embedding_method(self): 173 | return self.orion_params.embedding_method.lower() 174 | 175 | def get_diags_path(self): 176 | path = self.orion_params.diags_path 177 | return os.path.abspath(os.path.join(os.getcwd(), path)) 178 | 179 | def get_keys_path(self): 180 | path = self.orion_params.keys_path 181 | return os.path.abspath(os.path.join(os.getcwd(), path)) 182 | 183 | def get_io_mode(self): 184 | return self.orion_params.io_mode.lower() 185 | 186 | def get_boot_logp(self): 187 | return self.ckks_params.boot_logp 188 | 189 | def io_paths_exist(self): 190 | return bool(self.get_diags_path()) and bool(self.get_keys_path()) 191 | 192 | def reset_stored_file(self, path: str, file_type: str): 193 | if self.get_io_mode() == "save" and path: 194 | print(f"Deleting existing {file_type} at {path}") 195 | abs_path = os.path.abspath(os.path.join(os.getcwd(), path)) 196 | if os.path.exists(abs_path): 197 | os.remove(abs_path) 198 | 199 | def reset_stored_diags(self): 200 | self.reset_stored_file(self.get_diags_path(), "diagonals") 201 | 202 | def reset_stored_keys(self): 203 | self.reset_stored_file(self.get_keys_path(), "keys") -------------------------------------------------------------------------------- /orion/backend/python/poly_evaluator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .tensors import CipherTensor 5 | 6 | class NewEvaluator: 7 | def __init__(self, scheme): 8 | self.scheme = scheme 9 | self.backend = scheme.backend 10 | self.new_polynomial_evaluator() 11 | 12 | def new_polynomial_evaluator(self): 13 | self.backend.NewPolynomialEvaluator() 14 | 15 | def generate_monomial(self, coeffs): 16 | if isinstance(coeffs, (torch.Tensor, np.ndarray)): 17 | coeffs = coeffs.tolist() 18 | return self.backend.GenerateMonomial(coeffs[::-1]) 19 | 20 | def generate_chebyshev(self, coeffs): 21 | if isinstance(coeffs, (torch.Tensor, np.ndarray)): 22 | coeffs = coeffs.tolist() 23 | return self.backend.GenerateChebyshev(coeffs) 24 | 25 | def evaluate_polynomial(self, ciphertensor, poly, out_scale=None): 26 | out_scale = out_scale or 
self.scheme.params.get_default_scale() 27 | 28 | cts_out = [] 29 | for ctxt in ciphertensor.ids: 30 | ct_out = self.backend.EvaluatePolynomial(ctxt, poly, out_scale) 31 | cts_out.append(ct_out) 32 | 33 | return CipherTensor( 34 | self.scheme, cts_out, ciphertensor.shape, ciphertensor.on_shape) 35 | 36 | def generate_minimax_sign_coeffs(self, degrees, prec=128, logalpha=12, 37 | logerr=12, debug=False): 38 | if isinstance(degrees, int): 39 | degrees = [degrees] 40 | else: 41 | degrees = list(degrees) 42 | 43 | degrees = [d for d in degrees if d != 0] 44 | if len(degrees) == 0: 45 | raise ValueError( 46 | "At least one non-zero degree polynomial must be provided to " 47 | "generate_minimax_sign_coeffs(). " 48 | ) 49 | 50 | coeffs_flat = self.backend.GenerateMinimaxSignCoeffs( 51 | degrees, prec, logalpha, logerr, int(debug) 52 | ) 53 | 54 | coeffs_flat = torch.tensor(coeffs_flat) 55 | splits = [degree + 1 for degree in degrees] 56 | return torch.split(coeffs_flat, splits) 57 | 58 | def get_depth(self, poly): 59 | return self.backend.GetPolyDepth(poly) -------------------------------------------------------------------------------- /orion/backend/python/tensors.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | 4 | class PlainTensor: 5 | def __init__(self, scheme, ptxt_ids, shape, on_shape=None): 6 | self.scheme = scheme 7 | self.backend = scheme.backend 8 | self.encoder = scheme.encoder 9 | self.evaluator = scheme.evaluator 10 | self.ids = [ptxt_ids] if isinstance(ptxt_ids, int) else ptxt_ids 11 | self.shape = shape 12 | self.on_shape = on_shape or shape 13 | 14 | def __del__(self): 15 | if 'sys' in globals() and sys.modules and self.scheme: 16 | try: 17 | for idx in self.ids: 18 | self.backend.DeletePlaintext(idx) 19 | except Exception: 20 | pass # avoids errors for GC at program termination 21 | 22 | def __len__(self): 23 | return len(self.ids) 24 | 25 | def __str__(self): 26 | return str(self.decode()) 27 | 28 | def mul(self, other, in_place=False): 29 | if not isinstance(other, CipherTensor): 30 | raise ValueError(f"Multiplication between PlainTensor and " 31 | f"{type(other)} is not supported.") 32 | 33 | mul_ids = [] 34 | for i in range(len(self.ids)): 35 | mul_id = self.evaluator.mul_plaintext( 36 | other.ids[i], self.ids[i], in_place) 37 | mul_ids.append(mul_id) 38 | 39 | if in_place: 40 | return other 41 | return CipherTensor(self.scheme, mul_ids, self.shape, self.on_shape) 42 | 43 | def __mul__(self, other): 44 | return self.mul(other, in_place=False) 45 | 46 | def __imul__(self, other): 47 | return self.mul(other, in_place=True) 48 | 49 | def _check_valid(self, other): 50 | return 51 | 52 | def get_ids(self): 53 | return self.ids 54 | 55 | def scale(self): 56 | return self.backend.GetPlaintextScale(self.ids[0]) 57 | 58 | def set_scale(self, scale): 59 | for ptxt in self.ids: 60 | self.backend.SetPlaintextScale(ptxt, scale) 61 | 62 | def level(self): 63 | return self.backend.GetPlaintextLevel(self.ids[0]) 64 | 65 | def slots(self): 66 | return self.backend.GetPlaintextSlots(self.ids[0]) 67 | 68 | def min(self): 69 | return self.decode().min() 70 | 71 | def max(self): 72 | return self.decode().max() 73 | 74 | def moduli(self): 75 | return self.backend.GetModuliChain() 76 | 77 | def decode(self): 78 | return self.encoder.decode(self) 79 | 80 | 81 | class CipherTensor: 82 | def __init__(self, scheme, ctxt_ids, shape, on_shape=None): 83 | self.scheme = scheme 84 | self.backend = scheme.backend 85 | self.encryptor = scheme.encryptor 86 | 
self.evaluator = scheme.evaluator 87 | self.bootstrapper = scheme.bootstrapper 88 | 89 | self.ids = [ctxt_ids] if isinstance(ctxt_ids, int) else ctxt_ids 90 | self.shape = shape 91 | self.on_shape = on_shape or shape 92 | 93 | def __del__(self): 94 | if 'sys' in globals() and sys.modules and self.scheme: 95 | try: 96 | for idx in self.ids: 97 | self.backend.DeleteCiphertext(idx) 98 | except Exception: 99 | pass # avoids errors for GC at program termination 100 | 101 | def __len__(self): 102 | return len(self.ids) 103 | 104 | def __str__(self): 105 | ptxt = self.decrypt() 106 | return str(ptxt.decode()) 107 | 108 | #--------------# 109 | # Operations # 110 | #--------------# 111 | 112 | def __neg__(self): 113 | neg_ids = [] 114 | for ctxt in self.ids: 115 | neg_id = self.evaluator.negate(ctxt) 116 | neg_ids.append(neg_id) 117 | 118 | return CipherTensor(self.scheme, neg_ids, self.shape, self.on_shape) 119 | 120 | def add(self, other, in_place=False): 121 | self._check_valid(other) 122 | 123 | add_ids = [] 124 | for i in range(len(self.ids)): 125 | if isinstance(other, (int, float)): 126 | add_id = self.evaluator.add_scalar( 127 | self.ids[i], other, in_place) 128 | elif isinstance(other, PlainTensor): 129 | add_id = self.evaluator.add_plaintext( 130 | self.ids[i], other.ids[i], in_place) 131 | elif isinstance(other, CipherTensor): 132 | add_id = self.evaluator.add_ciphertext( 133 | self.ids[i], other.ids[i], in_place) 134 | else: 135 | raise ValueError(f"Addition between CipherTensor and " 136 | f"{type(other)} is not supported.") 137 | 138 | add_ids.append(add_id) 139 | 140 | if in_place: 141 | return self 142 | return CipherTensor(self.scheme, add_ids, self.shape, self.on_shape) 143 | 144 | def __add__(self, other): 145 | return self.add(other, in_place=False) 146 | 147 | def __iadd__(self, other): 148 | return self.add(other, in_place=True) 149 | 150 | def sub(self, other, in_place=False): 151 | self._check_valid(other) 152 | 153 | sub_ids = [] 154 | for i in range(len(self.ids)): 155 | if isinstance(other, (int, float)): 156 | sub_id = self.evaluator.sub_scalar( 157 | self.ids[i], other, in_place) 158 | elif isinstance(other, PlainTensor): 159 | sub_id = self.evaluator.sub_plaintext( 160 | self.ids[i], other.ids[i], in_place) 161 | elif isinstance(other, CipherTensor): 162 | sub_id = self.evaluator.sub_ciphertext( 163 | self.ids[i], other.ids[i], in_place) 164 | else: 165 | raise ValueError(f"Subtraction between CipherTensor and " 166 | f"{type(other)} is not supported.") 167 | 168 | sub_ids.append(sub_id) 169 | 170 | if in_place: 171 | return self 172 | return CipherTensor(self.scheme, sub_ids, self.shape, self.on_shape) 173 | 174 | def __sub__(self, other): 175 | return self.sub(other, in_place=False) 176 | 177 | def __isub__(self, other): 178 | return self.sub(other, in_place=True) 179 | 180 | def mul(self, other, in_place=False): 181 | self._check_valid(other) 182 | 183 | mul_ids = [] 184 | for i in range(len(self.ids)): 185 | if isinstance(other, (int, float)): 186 | mul_id = self.evaluator.mul_scalar( 187 | self.ids[i], other, in_place) 188 | elif isinstance(other, PlainTensor): 189 | mul_id = self.evaluator.mul_plaintext( 190 | self.ids[i], other.ids[i], in_place) 191 | elif isinstance(other, CipherTensor): 192 | mul_id = self.evaluator.mul_ciphertext( 193 | self.ids[i], other.ids[i], in_place) 194 | else: 195 | raise ValueError(f"Multiplication between CipherTensor and " 196 | f"{type(other)} is not supported.") 197 | 198 | mul_ids.append(mul_id) 199 | 200 | if in_place: 201 | 
return self 202 | return CipherTensor(self.scheme, mul_ids, self.shape, self.on_shape) 203 | 204 | def __mul__(self, other): 205 | return self.mul(other, in_place=False) 206 | 207 | def __imul__(self, other): 208 | return self.mul(other, in_place=True) 209 | 210 | def roll(self, amount, in_place=False): 211 | rot_ids = [] 212 | for ctxt in self.ids: 213 | rot_id = self.evaluator.rotate(ctxt, amount, in_place) 214 | rot_ids.append(rot_id) 215 | 216 | return CipherTensor(self.scheme, rot_ids, self.shape, self.on_shape) 217 | 218 | def _check_valid(self, other): 219 | return 220 | 221 | #---------------------- 222 | # 223 | #--------------------- 224 | 225 | def scale(self): 226 | return self.backend.GetCiphertextScale(self.ids[0]) 227 | 228 | def set_scale(self, scale): 229 | for ctxt in self.ids: 230 | self.backend.SetCiphertextScale(ctxt, scale) 231 | 232 | def level(self): 233 | return self.backend.GetCiphertextLevel(self.ids[0]) 234 | 235 | def slots(self): 236 | return self.backend.GetCiphertextSlots(self.ids[0]) 237 | 238 | def degree(self): 239 | return self.backend.GetCiphertextDegree(self.ids[0]) 240 | 241 | def min(self): 242 | return self.decrypt().min() 243 | 244 | def max(self): 245 | return self.decrypt().max() 246 | 247 | def moduli(self): 248 | return self.backend.GetModuliChain() 249 | 250 | def bootstrap(self): 251 | elements = self.on_shape.numel() 252 | slots = 2 ** math.ceil(math.log2(elements)) 253 | slots = int(min(self.slots(), slots)) # sparse bootstrapping 254 | 255 | btp_ids = [] 256 | for ctxt in self.ids: 257 | btp_id = self.bootstrapper.bootstrap(ctxt, slots) 258 | btp_ids.append(btp_id) 259 | 260 | return CipherTensor(self.scheme, btp_ids, self.shape, self.on_shape) 261 | 262 | def decrypt(self): 263 | return self.encryptor.decrypt(self) -------------------------------------------------------------------------------- /orion/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .orion import scheme 2 | 3 | init_scheme = scheme.init_scheme 4 | delete_scheme = scheme.delete_scheme 5 | encode = scheme.encode 6 | decode = scheme.decode 7 | encrypt = scheme.encrypt 8 | decrypt = scheme.decrypt 9 | fit = scheme.fit 10 | compile = scheme.compile -------------------------------------------------------------------------------- /orion/core/auto_bootstrap.py: -------------------------------------------------------------------------------- 1 | import math 2 | import networkx as nx 3 | import matplotlib.pyplot as plt 4 | 5 | from .level_dag import LevelDAG 6 | from orion.nn.operations import Bootstrap 7 | 8 | 9 | class BootstrapSolver: 10 | def __init__(self, net, network_dag, l_eff): 11 | self.net = net 12 | self.network_dag = network_dag 13 | self.l_eff = l_eff 14 | self.full_level_dag = LevelDAG(l_eff, network_dag) 15 | 16 | def extract_all_residual_subgraphs(self): 17 | all_residual_subgraphs = [] 18 | for fork in self.network_dag.residuals.keys(): 19 | subgraph = self.network_dag.extract_residual_subgraph(fork) 20 | all_residual_subgraphs.append(subgraph) 21 | 22 | return all_residual_subgraphs 23 | 24 | def sort_residual_subgraphs(self): 25 | # Sort the residual subgraphs by their number of paths from fork 26 | # to join node. 
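# For example (illustrative): in a basic ResNet block, the fork is the
# layer whose output feeds both the identity skip and the conv branch,
# and the join is the addition that merges them. nx.all_simple_paths()
# returns every fork-to-join path, while unique_paths below keeps one
# representative path per child of the fork.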
27 |         all_residual_subgraphs = self.extract_all_residual_subgraphs()
28 | 
29 |         residuals = []
30 |         for i, (fork, join) in enumerate(self.network_dag.residuals.items()):
31 |             subgraph = all_residual_subgraphs[i]
32 |             paths = list(nx.all_simple_paths(subgraph, fork, join))
33 | 
34 |             unique_paths = []
35 |             visited_children = set()
36 |             for path in paths:
37 |                 if path[1] not in visited_children:
38 |                     unique_paths.append(path)
39 |                     visited_children.add(path[1])
40 | 
41 |             residuals.append((fork, paths, unique_paths))
42 | 
43 |         # Sort by the number of simple paths from fork to join in the graph.
44 |         # This way, we're guaranteed to always solve the "inner-most"
45 |         # residual subgraph in the event it is entirely encapsulated by
46 |         # a larger residual connection.
47 |         sorted_subgraphs = sorted(residuals, key=lambda x: len(x[1]))
48 | 
49 |         return sorted_subgraphs
50 | 
51 |     def first_solve_residual_subgraphs(self):
52 |         # We'll first extract all residual subgraphs in the network and create
53 |         # their aggregate level DAGs. We'll be iterating over DAGs sorted in
54 |         # increasing order by the number of paths from their corresponding fork
55 |         # and join nodes. This guarantees we solve the "inner-most" level DAGs
56 |         # first, which can then be inserted into subsequent calls.
57 | 
58 |         sorted_residual_subgraphs = self.sort_residual_subgraphs()
59 |         self.network_dag.solved_residual_level_dags = {}
60 | 
61 |         for (fork, _, paths) in sorted_residual_subgraphs:
62 |             aggregate_level_dag = LevelDAG(self.l_eff, self.network_dag, path=None)
63 |             for path in paths:
64 |                 path_dag = nx.DiGraph()
65 | 
66 |                 # Then we'll just create a new DAG by extracting the
67 |                 # subgraph along the path.
68 |                 nodes_in_path = [
69 |                     (node, self.network_dag.nodes[node])
70 |                     for node in path
71 |                 ]
72 |                 edges_in_path = [
73 |                     (u, v, self.network_dag[u][v])
74 |                     for u, v in zip(path[:-1], path[1:])
75 |                 ]
76 | 
77 |                 path_dag.add_nodes_from(nodes_in_path)
78 |                 path_dag.add_edges_from(edges_in_path)
79 | 
80 |                 # And create the level DAG based on the path.
81 |                 aggregate_level_dag += LevelDAG(
82 |                     self.l_eff, self.network_dag, path_dag
83 |                 )
84 | 
85 |             self.network_dag.solved_residual_level_dags[fork] = aggregate_level_dag
86 | 
87 |         return self.network_dag.solved_residual_level_dags
88 | 
89 |     def then_build_full_level_dag(self, solved_residual_level_dags):
90 |         # Build the full level DAG in topological order by appending either a
91 |         # solved residual aggregate level DAG or a single-node level DAG.
92 | 
93 |         all_forks = self.network_dag.residuals.keys()
94 | 
95 |         visited = set()
96 |         for node in nx.topological_sort(self.network_dag):
97 |             if node not in visited:
98 |                 if node in all_forks:
99 |                     # It is a fork node and so this subgraph has already
100 |                     # been solved. We'll just connect it to the existing
101 |                     # full_level_dag.
102 |                     next_level_dag = solved_residual_level_dags[node]
103 |                     subgraph = self.network_dag.extract_residual_subgraph(node)
104 |                     visited.update(subgraph.nodes)
105 |                 else:
106 |                     node_dag = nx.DiGraph()
107 |                     node_dag.add_nodes_from([(node, self.network_dag.nodes[node])])
108 |                     next_level_dag = LevelDAG(
109 |                         self.l_eff, self.network_dag, node_dag
110 |                     )
111 |                     visited.add(node)  # add(), not update(): node is a string
112 | 
113 |                 self.full_level_dag.append(next_level_dag)
114 | 
115 |     def finally_solve_full_level_dag(self):
116 |         # Now that we've built our full level DAG, we can run one final
117 |         # shortest-path query on it to determine the optimal level
118 |         # management policy for our network.
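        # Level DAG nodes are keyed as "<layer name>@...=<level>" strings;
        # this is why layer names are recovered with split("@") and levels
        # with split("=") further below.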
119 | 
120 |         heads = self.full_level_dag.head()
121 |         tails = self.full_level_dag.tail()
122 | 
123 |         self.full_level_dag.add_node("source", weight=0)
124 |         self.full_level_dag.add_node("target", weight=0)
125 | 
126 |         for head, tail in zip(heads, tails):
127 |             self.full_level_dag.add_edge("source", head, weight=0)
128 |             self.full_level_dag.add_edge(tail, "target", weight=0)
129 | 
130 |         shortest_path, latency = self.full_level_dag.shortest_path(
131 |             source="source", target="target"
132 |         )
133 | 
134 |         if latency == float("inf"):
135 |             raise ValueError(
136 |                 "Automatic bootstrap placement failed. First try increasing "
137 |                 "the length of your LogQ moduli chain in the associated "
138 |                 "parameters YAML file. If this fails, double check that the "
139 |                 "network was instantiated properly."
140 |             )
141 | 
142 |         # Just remove the source/target we added
143 |         shortest_path = shortest_path[1:-1]
144 | 
145 |         # The shortest path above, while correct, also black-boxes the paths
146 |         # within skip connections. We haven't lost this data; we just need
147 |         # to access it via the edge attributes designed to track it.
148 |         reconstructed_path = set()
149 |         for u, v in zip(shortest_path[:-1], shortest_path[1:]):
150 |             edge = self.full_level_dag[u][v]
151 |             reconstructed_path.update(edge["path"])
152 | 
153 |         self.shortest_path = reconstructed_path
154 | 
155 |         input_level = int(shortest_path[1].split("=")[-1])
156 |         return input_level
157 | 
158 |     def solve(self):
159 |         solved_residual_dags = self.first_solve_residual_subgraphs()
160 |         self.then_build_full_level_dag(solved_residual_dags)
161 |         input_level = self.finally_solve_full_level_dag()
162 | 
163 |         self.assign_levels_to_layers()
164 |         num_bootstraps, bootstrapper_slots = self.mark_bootstrap_locations()
165 | 
166 |         return input_level, num_bootstraps, bootstrapper_slots
167 | 
168 |     def assign_levels_to_layers(self):
169 |         # Set each Orion module's level attribute to the level found by
170 |         # this algorithm. This lets linear transforms be encoded at the
171 |         # correct level.
172 |         for node in self.network_dag.nodes:
173 |             node_module = self.network_dag.nodes[node]["module"]
174 |             for layer in self.shortest_path:
175 |                 name = layer.split("@")[0]
176 |                 level = int(layer.split("=")[-1])
177 | 
178 |                 if node == name:
179 |                     self.network_dag.nodes[node]["level"] = level
180 |                     if node_module:
181 |                         node_module.level = level
182 |                     break  # node names are unique; stop once matched
183 | 
184 |     def mark_bootstrap_locations(self):
185 |         # Makes things a bit easier below
186 |         node_map = {}
187 |         for node in self.shortest_path:
188 |             name = node.split("@")[0]
189 |             node_map[name] = node
190 | 
191 |         # We'll use this empty level DAG to query the number of
192 |         # bootstraps per layer of the network dag.
193 |         query = LevelDAG(self.l_eff, self.network_dag, path=None)
194 | 
195 |         total_bootstraps = 0
196 |         bootstrapper_slots = []
197 | 
198 |         for node in self.network_dag.nodes:
199 |             node_w_level = node_map[node]
200 | 
201 |             children = self.network_dag.successors(node)
202 |             self.network_dag.nodes[node]["bootstrap"] = False
203 | 
204 |             # Iterate over the layer's children to determine if their assigned
205 |             # levels necessitate a bootstrap of the current layer.
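            # estimate_bootstrap_latency() returns the added latency and the
            # number of bootstraps implied by realizing the edge node -> child
            # at the two assigned levels.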
206 |             for child in children:
207 |                 child_w_level = node_map[child]
208 |                 _, curr_boots = query.estimate_bootstrap_latency(
209 |                     node_w_level, child_w_level)
210 | 
211 |                 total_bootstraps += curr_boots
212 |                 if curr_boots > 0:
213 |                     self.network_dag.nodes[node]["bootstrap"] = True
214 |                     slots = self.get_bootstrap_slots(node)
215 | 
216 |                     # Record this slot count so we generate a bootstrapper for it
217 |                     if slots not in bootstrapper_slots:
218 |                         bootstrapper_slots.append(slots)
219 |                     break
220 | 
221 |         return total_bootstraps, bootstrapper_slots
222 | 
223 |     def get_bootstrap_slots(self, node):
224 |         # If we're here, then our auto-bootstrapper has determined that the
225 |         # output of this node will be bootstrapped. Therefore it must be an
226 |         # Orion module, and so a module attribute exists.
227 |         module = self.network_dag.nodes[node]["module"]
228 |         max_slots = module.scheme.params.get_slots()
229 | 
230 |         elements = module.fhe_output_shape.numel()
231 |         curr_slots = 2 ** math.ceil(math.log2(elements))
232 |         slots = int(min(max_slots, curr_slots))  # sparse bootstrapping
233 | 
234 |         return slots
235 | 
236 |     def plot_shortest_path(self, save_path="", figsize=(10,10)):
237 |         """Plot the network digraph. For the best visualization, please install
238 |         Graphviz and PyGraphviz."""
239 | 
240 |         nodes = {}
241 |         for node in self.shortest_path:
242 |             name = node.split("@")[0]
243 |             level = node.split("=")[-1]
244 |             nodes[name] = level
245 | 
246 |         network = nx.DiGraph(self.network_dag)
247 |         shortest_graph = nx.DiGraph()
248 | 
249 |         for name, level in nodes.items():
250 |             shortest_graph.add_node(name, level=level)
251 | 
252 |         # Add edges from the original graph
253 |         for u, v in network.edges():
254 |             if u in nodes and v in nodes:
255 |                 shortest_graph.add_edge(u, v)
256 | 
257 |         try:
258 |             pos = nx.nx_agraph.graphviz_layout(shortest_graph, prog='dot')
259 |         except Exception:
260 |             print("Graphviz not installed. 
Defaulting to worse visualization.\n") 261 | pos = nx.kamada_kawai_layout(shortest_graph) 262 | 263 | plt.figure(figsize=figsize) 264 | nx.draw( 265 | shortest_graph, pos, with_labels=False, arrows=True, font_size=8) 266 | 267 | node_labels = { 268 | node: f"{node}\n(level: {data['level']})" 269 | for node, data in shortest_graph.nodes(data=True) 270 | } 271 | nx.draw_networkx_labels( 272 | shortest_graph, pos, labels=node_labels, font_size=8) 273 | 274 | if save_path: 275 | plt.savefig(save_path) 276 | plt.show() 277 | 278 | 279 | class BootstrapPlacer: 280 | def __init__(self, net, network_dag): 281 | self.net = net 282 | self.network_dag = network_dag 283 | 284 | def place_bootstraps(self): 285 | for node in self.network_dag.nodes: 286 | if self.network_dag.nodes[node]["bootstrap"]: 287 | module = self.network_dag.nodes[node]["module"] 288 | self._apply_bootstrap_hook(module) 289 | 290 | def _apply_bootstrap_hook(self, module): 291 | bootstrapper = self._create_bootstrapper(module) 292 | module.bootstrapper = bootstrapper 293 | 294 | # Register a forward hook that applies bootstrapping to outputs 295 | module.register_forward_hook(lambda mod, input, output: bootstrapper(output)) 296 | 297 | def _create_bootstrapper(self, module): 298 | # Set bootstrap statistics to scale into [-1, 1] 299 | btp_input_level = module.level - module.depth 300 | btp_input_min = module.output_min 301 | btp_input_max = module.output_max 302 | 303 | bootstrapper = Bootstrap(btp_input_min, btp_input_max, btp_input_level) 304 | 305 | bootstrapper.scheme = self.net.scheme 306 | bootstrapper.margin = self.net.margin 307 | bootstrapper.fhe_input_shape = module.fhe_output_shape 308 | bootstrapper.fit() 309 | bootstrapper.compile() 310 | 311 | return bootstrapper -------------------------------------------------------------------------------- /orion/core/fuser.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import networkx as nx 3 | 4 | from orion.nn.activation import Chebyshev 5 | from orion.nn.linear import Linear, Conv2d 6 | from orion.nn.normalization import BatchNorm1d, BatchNorm2d 7 | 8 | 9 | class Fuser: 10 | def __init__(self, network_dag: nx.DiGraph): 11 | self.network_dag = network_dag 12 | 13 | def _fuse_linear_chebyshev(self, linear, cheb): 14 | linear.on_weight = linear.on_weight * cheb.prescale 15 | linear.on_bias = linear.on_bias * cheb.prescale + cheb.constant 16 | 17 | cheb.fused = True 18 | cheb.depth -= 1 # The prescale no longer consumes a level 19 | 20 | def _fuse_bn_chebyshev(self, bn, cheb): 21 | if bn.affine: 22 | bn.on_weight = bn.on_weight * cheb.prescale 23 | bn.on_bias = bn.on_bias * cheb.prescale + cheb.constant 24 | else: 25 | bn.affine = True 26 | bn.on_weight = torch.ones(bn.num_features) * cheb.prescale 27 | bn.on_bias = torch.ones(bn.num_features) * cheb.constant 28 | 29 | cheb.fused = True 30 | cheb.depth -= 1 31 | 32 | def _fuse_linear_bn(self, linear, bn): 33 | on_inv_running_std = 1 / torch.sqrt(bn.on_running_var + bn.eps) 34 | scale = bn.on_weight * on_inv_running_std 35 | 36 | if len(linear.on_weight.shape) == 2: # fc layer 37 | linear.on_weight *= scale.reshape(-1, 1) 38 | else: # conv2d layer 39 | linear.on_weight *= scale.reshape(-1, 1, 1, 1) 40 | 41 | linear.on_bias = scale * (linear.on_bias - bn.running_mean) + bn.on_bias 42 | 43 | bn.fused = True 44 | bn.depth -= (2 if bn.affine else 1) 45 | 46 | def fuse_two_layers(self, parent_class, child_class, fusing_function): 47 | 48 | def get_parent_modules(node): 49 | 
parent_modules = [] 50 | for parent in self.network_dag.predecessors(node): 51 | parent_module = self.network_dag.nodes[parent]["module"] 52 | if isinstance(parent_module, parent_class): 53 | parent_modules.append(parent_module) 54 | 55 | return parent_modules 56 | 57 | # We'll iterate over all nodes in our network to determine if the 58 | # pattern ever goes parent_class -> child_class sequentially. If 59 | # so, then we'll apply `fusing_function` to those two modules. 60 | for node in self.network_dag.nodes: 61 | child_module = self.network_dag.nodes[node]["module"] 62 | if isinstance(child_module, child_class): 63 | parent_modules = get_parent_modules(node) 64 | 65 | for parent_module in parent_modules: 66 | fusing_function(parent_module, child_module) 67 | 68 | def fuse_linear_chebyshev(self): 69 | self.fuse_two_layers((Linear, Conv2d), Chebyshev, 70 | self._fuse_linear_chebyshev) 71 | 72 | def fuse_bn_chebyshev(self): 73 | self.fuse_two_layers((BatchNorm1d, BatchNorm2d), Chebyshev, 74 | self._fuse_bn_chebyshev) 75 | 76 | def fuse_linear_bn(self): 77 | self.fuse_two_layers(Linear, BatchNorm1d, self._fuse_linear_bn) 78 | self.fuse_two_layers(Conv2d, BatchNorm2d, self._fuse_linear_bn) 79 | 80 | def fuse_modules(self): 81 | self.fuse_linear_chebyshev() 82 | self.fuse_bn_chebyshev() 83 | self.fuse_linear_bn() 84 | 85 | -------------------------------------------------------------------------------- /orion/core/network_dag.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import matplotlib.pyplot as plt 3 | 4 | from orion.nn.normalization import BatchNormNd 5 | 6 | class NetworkDAG(nx.DiGraph): 7 | """ 8 | Represents a neural network as a directed acyclic graph (DAG) using 9 | NetworkX. This class builds a DAG from a symbolic trace of a PyTorch 10 | network, identifies residual connections, and provides several useful 11 | methods that we will use in our automatic bootstrap placement algorithm. 12 | """ 13 | def __init__(self, trace): 14 | super().__init__() 15 | self.trace = trace 16 | self.residuals = {} 17 | 18 | def build_dag(self): 19 | """Builds the DAG representation of the neural network based on 20 | the provided symbolic trace.""" 21 | 22 | for node in self.trace.graph.nodes: 23 | # If the user assumes a default layer parameter (e.g. bias=False) 24 | # this will appear as an unconnected node with node.users = 0. 25 | # It is fine to disregard these cases. 26 | if len(node.users) > 0: 27 | module = None 28 | if node.op == "call_module": 29 | module = self.trace.get_submodule(node.target) 30 | 31 | # Insert the node into the graph 32 | self.add_node(node.name, fx_node=node, op=node.op, module=module) 33 | for input_node in node.all_input_nodes: 34 | self.add_edge(input_node.name, node.name) 35 | 36 | def find_residuals(self): 37 | """Finds pairs of fork/join nodes representing residual connections. 38 | We consider a fork (join) node to be any Orion module or arithmetic 39 | operation in our computational graph that has two or more children 40 | (parents). Each residual connection creates a pair of fork/join nodes 41 | that become the start/end nodes of each subgraph that we will 42 | ultimately extract in our automatic bootstrap placement algorithm.""" 43 | 44 | # Residual connections in FHE are particularly difficult to deal with. 45 | # Each residual connection creates a pair of fork and join nodes in our 46 | # graph. For every fork, there is a join somewhere later in the graph. 
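        # (In a ResNet BasicBlock, for example, the block input forks into the
        # main convolutional path and the shortcut, and the Add module joins
        # them back together.)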
47 | # Our automatic bootstrap placement algorithm relies on extracting the 48 | # subgraphs between pairs of fork/join nodes. This function nicely finds 49 | # fork/join pairs and stores them in the self.residuals dictionary so 50 | # we can reference them later. 51 | topo = list(nx.topological_sort(self)) 52 | for start_node in list(self.nodes): 53 | successors = list(self.successors(start_node)) 54 | 55 | # Fork node found 56 | if len(successors) > 1: 57 | paths = [] 58 | # For each child of the node, get a path from that child to the 59 | # last node in the network. 60 | for source in successors: 61 | path = nx.shortest_path(self, source, topo[-1]) 62 | paths.append(set(path)) 63 | 64 | # By set intersecting all paths from child -> end, we can find 65 | # nodes common between all paths. 66 | common_nodes = list(set.intersection(*paths)) 67 | 68 | # The join node is the "first" common node of the graph in 69 | # topological order. 70 | end_node = [node for node in topo if node in common_nodes][0] 71 | 72 | # Finally, we'll insert special (auxiliary) fork/join nodes into 73 | # the graph. This makes our automatic bootstrap placement 74 | # algorithm slightly cleaner. 75 | fork, join = self.insert_fork_and_join_nodes(start_node, end_node) 76 | self.residuals[fork] = join 77 | 78 | def insert_fork_and_join_nodes(self, start, end): 79 | """Inserts special fork/join nodes into the graph around the residual 80 | connection. This makes our automatic bootstrap placement algorithm 81 | slightly cleaner.""" 82 | 83 | fork = f"{start}_fork" 84 | join = f"{end}_join" 85 | 86 | # Add fork/join nodes to the network 87 | self.add_node(fork, op="fork", module=None) 88 | self.add_node(join, op="join", module=None) 89 | 90 | # Insert fork node and update edges 91 | for child in list(self.successors(start)): 92 | self.remove_edge(start, child) 93 | self.add_edge(fork, child) 94 | self.add_edge(start, fork) 95 | 96 | # Insert join node and update edges 97 | for parent in list(self.predecessors(end)): 98 | self.remove_edge(parent, end) 99 | self.add_edge(parent, join) 100 | self.add_edge(join, end) 101 | 102 | return fork, join 103 | 104 | def extract_residual_subgraph(self, fork): 105 | """A helper function designed to extract the subgraphs between 106 | the fork/join nodes of a residual connection.""" 107 | 108 | nodes_in_residual = set() 109 | edges_in_residual = set() 110 | 111 | # Get all paths from fork -> join and build up a set of unique 112 | # nodes/edges in its subgraph 113 | join = self.residuals[fork] 114 | for path in nx.all_simple_paths(self, fork, join): 115 | nodes_in_residual.update(path) 116 | edges_in_residual.update(zip(path[:-1], path[1:])) 117 | 118 | # Rebuild the subgraph from the nodes/edges 119 | residual_subgraph = nx.DiGraph() 120 | residual_subgraph.add_nodes_from(nodes_in_residual) 121 | residual_subgraph.add_edges_from(edges_in_residual) 122 | 123 | return residual_subgraph 124 | 125 | def remove_fused_batchnorms(self): 126 | """Removes BatchNorm nodes from the graph when it is known that they 127 | can be fused with preceding linear layers.""" 128 | 129 | for node in list(self.nodes): 130 | node_module = self.nodes[node]["module"] 131 | 132 | if isinstance(node_module, BatchNormNd): 133 | # Get the parents and children of the batchnorm node 134 | parent_nodes = list(self.predecessors(node)) 135 | child_nodes = list(self.successors(node)) 136 | 137 | # Our tracer has already verified that the BN node only has 138 | # one parent. 
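                # (e.g. conv -> bn -> act collapses to conv -> act once the BN
                # has been folded into the preceding convolution.)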
139 |                 parent = parent_nodes[0]
140 | 
141 |                 # Our fuser will have fused this BN node if it was possible,
142 |                 # and its fused attribute will have been set. Remove this
143 |                 # BN node so that it isn't counted when we assign levels to
144 |                 # layers further into the compilation process.
145 |                 if node_module.fused:
146 |                     for child in child_nodes:
147 |                         self.add_edge(parent, child)
148 |                     self.remove_node(node)
149 | 
150 |     def topological_sort(self):
151 |         return nx.topological_sort(self)
152 | 
153 |     def plot(self, save_path="", figsize=(10,10)):
154 |         """Plot the network digraph. For the best visualization, please install
155 |         Graphviz and PyGraphviz."""
156 | 
157 |         try:
158 |             pos = nx.nx_agraph.graphviz_layout(self, prog='dot')
159 |         except Exception:
160 |             print("Graphviz not installed. Defaulting to worse visualization.\n")
161 |             pos = nx.kamada_kawai_layout(self)
162 | 
163 |         plt.figure(figsize=figsize)
164 |         nx.draw(self, pos, with_labels=True, arrows=True, font_size=8)
165 | 
166 |         if save_path:
167 |             plt.savefig(save_path)
168 |         plt.show()
--------------------------------------------------------------------------------
/orion/core/orion.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | from typing import Union, Dict, Any
4 | 
5 | import yaml
6 | import torch
7 | from tqdm import tqdm
8 | from torch.utils.data import DataLoader, RandomSampler
9 | 
10 | from orion.nn.module import Module
11 | from orion.nn.linear import LinearTransform
12 | from orion.backend.lattigo import bindings as lgo
13 | from orion.backend.python import (
14 |     parameters,
15 |     key_generator,
16 |     encoder,
17 |     encryptor,
18 |     evaluator,
19 |     poly_evaluator,
20 |     lt_evaluator,
21 |     bootstrapper
22 | )
23 | 
24 | from .tracer import StatsTracker, OrionTracer
25 | from .fuser import Fuser
26 | from .network_dag import NetworkDAG
27 | from .auto_bootstrap import BootstrapSolver, BootstrapPlacer
28 | 
29 | 
30 | class Scheme:
31 |     """
32 |     This Scheme class drives most of the functionality in Orion. It
33 |     configures and manages how our framework interfaces with FHE backends,
34 |     and exposes this functionality to the user through attributes such as
35 |     the encoder, evaluators (linear transform, polynomial, etc.) and
36 |     bootstrappers.
37 | 
38 |     It also serves two important purposes required before running FHE
39 |     inference: fitting the network and then compiling it. The fit() method
40 |     runs cleartext forward passes through the network to determine per-layer
41 |     input ranges, which are then used to fit polynomial approximations to
42 |     common activation functions (e.g., SiLU, ReLU).
43 | 
44 |     The compile() function is responsible for all packing of data and
45 |     determines a level management policy by running our automatic bootstrap
46 |     placement algorithm. Once done, each Orion module is automatically
47 |     assigned a level that can then be used in its compilation. This primarily
48 |     includes generating the plaintexts needed for each linear transform.
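
    A typical end-to-end flow through the public aliases exposed in
    orion/core/__init__.py looks like:

        orion.init_scheme("configs/mlp.yml")
        orion.fit(net, input_data)
        input_level = orion.compile(net)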
49 | """ 50 | 51 | def __init__(self): 52 | self.backend = None 53 | self.traced = None 54 | 55 | def init_scheme(self, config: Union[str, Dict[str, Any]]): 56 | """Initializes the scheme.""" 57 | if isinstance(config, str): 58 | try: 59 | with open(config, "r") as f: 60 | config = yaml.safe_load(f) 61 | except FileNotFoundError: 62 | raise ValueError(f"Configuration file '{config}' not found.") 63 | elif not isinstance(config, dict): 64 | raise TypeError("Config must be a file path (str) or a dictionary.") 65 | 66 | self.params = parameters.NewParameters(config) 67 | self.backend = self.setup_backend(self.params) 68 | 69 | self.keygen = key_generator.NewKeyGenerator(self) 70 | self.encoder = encoder.NewEncoder(self) 71 | self.encryptor = encryptor.NewEncryptor(self) 72 | self.evaluator = evaluator.NewEvaluator(self) 73 | self.poly_evaluator = poly_evaluator.NewEvaluator(self) 74 | self.lt_evaluator = lt_evaluator.NewEvaluator(self) 75 | self.bootstrapper = bootstrapper.NewEvaluator(self) 76 | 77 | return self 78 | 79 | def delete_scheme(self): 80 | if self.backend: 81 | self.backend.DeleteScheme() 82 | 83 | def __del__(self): 84 | self.delete_scheme() 85 | 86 | def __str__(self): 87 | return str(self.params) 88 | 89 | def setup_backend(self, params): 90 | backend = params.get_backend() 91 | if backend == "lattigo": 92 | py_lattigo = lgo.LattigoLibrary() 93 | py_lattigo.setup_bindings(params) 94 | return py_lattigo 95 | elif backend in ("heaan", "openfhe"): 96 | raise ValueError(f"Backend {backend} not yet supported.") 97 | else: 98 | raise ValueError( 99 | f"Invalid {backend}. Set the backend to Lattigo until " 100 | f"further notice." 101 | ) 102 | 103 | def encode(self, tensor, level=None, scale=None): 104 | self._check_initialization() 105 | return self.encoder.encode(tensor, level, scale) 106 | 107 | def decode(self, ptxt): 108 | self._check_initialization() 109 | return self.encoder.decode(ptxt) 110 | 111 | def encrypt(self, ptxt): 112 | self._check_initialization() 113 | return self.encryptor.encrypt(ptxt) 114 | 115 | def decrypt(self, ctxt): 116 | self._check_initialization() 117 | return self.encryptor.decrypt(ctxt) 118 | 119 | def fit(self, net, input_data, batch_size=128): 120 | self._check_initialization() 121 | 122 | net.set_scheme(self) 123 | net.set_margin(self.params.get_margin()) 124 | 125 | tracer = OrionTracer() 126 | traced = tracer.trace_model(net) 127 | self.traced = traced 128 | 129 | stats_tracker = StatsTracker(traced) 130 | 131 | #-----------------------------------------# 132 | # Populate layers with useful metadata # 133 | #-----------------------------------------# 134 | 135 | # Send input_data to the same device as the model. 136 | param = next(iter(net.parameters()), None) 137 | device = param.device if param is not None else torch.device("cpu") 138 | 139 | print("\n{1} Finding per-layer input/output ranges and shapes...", 140 | flush=True) 141 | start = time.time() 142 | if isinstance(input_data, DataLoader): 143 | # Users often specify small batch sizes for FHE operations. 144 | # However, fitting statistics with large datasets would take 145 | # unnecessarily long with small batches. To speed this up, we'll 146 | # temporarily increase the batch size during the statistics-fitting 147 | # step, and then restore the original batch size afterward. 
148 |             user_batch_size = input_data.batch_size
149 |             if batch_size > user_batch_size:
150 |                 dataset = input_data.dataset
151 |                 shuffle = input_data.sampler is None or isinstance(input_data.sampler, RandomSampler)
152 | 
153 |                 input_data = DataLoader(
154 |                     dataset=dataset,
155 |                     batch_size=batch_size,
156 |                     shuffle=shuffle,
157 |                     num_workers=input_data.num_workers,
158 |                     pin_memory=input_data.pin_memory,
159 |                     drop_last=input_data.drop_last
160 |                 )
161 | 
162 |             # Use this (potentially new) dataloader
163 |             for batch in tqdm(input_data, desc="Processing input data",
164 |                               unit="batch", leave=True):
165 |                 stats_tracker.propagate(batch[0].to(device))
166 | 
167 |             # Now we'll reset the batch size back to what the user specified.
168 |             stats_tracker.update_batch_size(user_batch_size)
169 | 
170 |         elif isinstance(input_data, torch.Tensor):
171 |             stats_tracker.propagate(input_data.to(device))
172 |         else:
173 |             raise ValueError(
174 |                 "Input data must be a torch.Tensor or DataLoader, but "
175 |                 f"received {type(input_data)}."
176 |             )
177 | 
178 |         #-------------------------------------#
179 |         #    Fit polynomial activations       #
180 |         #-------------------------------------#
181 | 
182 |         # Now we can use the statistics we just obtained above to fit
183 |         # all polynomial activation functions.
184 |         print("\n{2} Fitting polynomials... ", end="", flush=True)
185 |         start = time.time()
186 |         for module in net.modules():
187 |             if hasattr(module, "fit") and callable(module.fit):
188 |                 module.fit()
189 |         print(f"done! [{time.time()-start:.3f} secs.]")
190 | 
191 |     def compile(self, net):
192 |         self._check_initialization()
193 | 
194 |         if self.traced is None:
195 |             raise ValueError(
196 |                 "Network has not been fit yet! Before running orion.compile(net) "
197 |                 "you must run orion.fit(net, input_data)."
198 |             )
199 | 
200 |         #------------------------------------------------#
201 |         #    Build DAG representation of neural network  #
202 |         #------------------------------------------------#
203 | 
204 |         network_dag = NetworkDAG(self.traced)
205 |         network_dag.build_dag()
206 | 
207 |         # Before fusing, we'll instantiate our own Orion parameters (e.g.
208 |         # weights and biases) that can be fused/modified without affecting
209 |         # the original network's parameters.
210 |         for module in net.modules():
211 |             if (hasattr(module, "init_orion_params") and
212 |                 callable(module.init_orion_params)):
213 |                 module.init_orion_params()
214 | 
215 |         #-------------------------------------#
216 |         #       Resolve pooling kernels       #
217 |         #-------------------------------------#
218 | 
219 |         # AvgPools are implemented as grouped convolutions in Orion but, for
220 |         # consistency with PyTorch's API, are not given their channel count
221 |         # as a constructor argument. Now that the clear passes above have run
222 |         # (through torch.nn.functional), we can resolve these kernel shapes.
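        # (Illustrative sketch, not the actual implementation: a 2x2 AvgPool
        # over C channels behaves like the grouped convolution
        #     F.conv2d(x, torch.full((C, 1, 2, 2), 0.25), stride=2, groups=C),
        # and update_params() below fills in that channel count once the
        # traced shapes are known.)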
223 |         for module in net.modules():
224 |             if hasattr(module, "update_params") and callable(module.update_params):
225 |                 module.update_params()
226 | 
227 |         #------------------------------------------#
228 |         #   Fuse Orion modules (Conv -> BN, etc)   #
229 |         #------------------------------------------#
230 | 
231 |         enable_fusing = self.params.get_fuse_modules()
232 |         if enable_fusing:
233 |             fuser = Fuser(network_dag)
234 |             fuser.fuse_modules()
235 |             network_dag.remove_fused_batchnorms()
236 | 
237 |         #---------------------------------------------#
238 |         #   Pack diagonals of all linear transforms   #
239 |         #---------------------------------------------#
240 | 
241 |         # We must also ensure that there is no junk data left in the slots
242 |         # of the final linear layer (leaking information about partials).
243 |         # This would occur when using the hybrid embedding method. We could
244 |         # use an additional level to zero things out, but instead, we'll
245 |         # just force the last linear layer to use the "square" embedding
246 |         # method, which solves this while consuming just one level (albeit
247 |         # usually for more ciphertext rotations).
248 |         topo_sort = list(network_dag.topological_sort())
249 | 
250 |         last_linear = None
251 |         for node in reversed(topo_sort):
252 |             module = network_dag.nodes[node]["module"]
253 |             if isinstance(module, LinearTransform):
254 |                 last_linear = node
255 |                 break
256 | 
257 |         # Now we can generate the diagonals
258 |         print("\n{3} Generating matrix diagonals...", flush=True)
259 |         for node in topo_sort:
260 |             module = network_dag.nodes[node]["module"]
261 |             if isinstance(module, LinearTransform):
262 |                 print(f"\nPacking {node}:")
263 |                 module.generate_diagonals(last=(node == last_linear))
264 | 
265 |         #------------------------------#
266 |         #  Find and place bootstraps   #
267 |         #------------------------------#
268 | 
269 |         network_dag.find_residuals()
270 |         # network_dag.plot(save_path="network.png", figsize=(8,30))  # optional plot
271 | 
272 |         print("\n{4} Running bootstrap placement... ", end="", flush=True)
273 |         start = time.time()
274 |         l_eff = len(self.params.get_logq()) - 1
275 |         btp_solver = BootstrapSolver(net, network_dag, l_eff=l_eff)
276 |         input_level, num_bootstraps, bootstrapper_slots = btp_solver.solve()
277 |         print(f"done! [{time.time()-start:.3f} secs.]", flush=True)
278 |         print(f"├── Network requires {num_bootstraps} bootstrap "
279 |               f"{'operation' if num_bootstraps == 1 else 'operations'}.")
280 | 
281 |         # btp_solver.plot_shortest_path(
282 |         #     save_path="network-with-levels.png", figsize=(8,30)  # optional plot
283 |         # )
284 | 
285 |         if bootstrapper_slots:
286 |             start = time.time()
287 |             slots_str = ", ".join([str(int(math.log2(slot))) for slot in bootstrapper_slots])
288 |             print(f"├── Generating bootstrappers for logslots = {slots_str} ... ",
289 |                   end="", flush=True)
290 | 
291 |             # Generate the required (potentially sparse) bootstrappers.
292 |             for slot_count in bootstrapper_slots:
293 |                 self.bootstrapper.generate_bootstrapper(slot_count)
294 |             print(f"done! 
[{time.time()-start:.3f} secs.]") 295 | 296 | btp_placer = BootstrapPlacer(net, network_dag) 297 | btp_placer.place_bootstraps() 298 | 299 | #------------------------------------------# 300 | # Compile Orion modules in the network # 301 | #------------------------------------------# 302 | 303 | print("\n{5} Compiling network layers...", flush=True) 304 | for node in topo_sort: 305 | node_attrs = network_dag.nodes[node] 306 | module = node_attrs["module"] 307 | if isinstance(module, Module): 308 | print(f"├── {node} @ level={module.level}", flush=True) 309 | module.compile() 310 | 311 | return input_level # level at which to encrypt the input. 312 | 313 | def _check_initialization(self): 314 | if self.backend is None: 315 | raise ValueError( 316 | "Scheme not initialized. Call `orion.init_scheme()` first.") 317 | 318 | scheme = Scheme() 319 | -------------------------------------------------------------------------------- /orion/core/tracer.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.fx as fx 6 | 7 | import orion.nn as on 8 | from orion.nn.module import Module 9 | from orion.nn.linear import LinearTransform 10 | from orion.nn.normalization import BatchNormNd 11 | 12 | 13 | class OrionTracer(fx.Tracer): 14 | """ 15 | Overrides the default fx.Tracer that does not recursively access all 16 | modules in the network. This is a deeper trace. 17 | """ 18 | def is_leaf_module(self, m, _): 19 | if not isinstance(m, nn.Module): 20 | return False 21 | if isinstance(m, (nn.Sequential, nn.ModuleList, nn.ModuleDict)): 22 | return False 23 | return not any(True for _ in m.children()) 24 | 25 | def trace_model(self, model): 26 | # Tracing outputs are slightly different when the user provides 27 | # a leaf module (e.g on.Conv2d) rather than a network. We'll wrap 28 | # it temporarily to consistently track FHE statistics. 29 | if self.is_leaf_module(model, ""): 30 | model = ModuleWrapper(model) 31 | 32 | with warnings.catch_warnings(): 33 | warnings.simplefilter("ignore") 34 | return fx.GraphModule(model, super().trace(model)) 35 | 36 | 37 | class ModuleWrapper(on.Module): 38 | """Wrapper for leaf modules to make them traceable.""" 39 | def __init__(self, module): 40 | super().__init__() 41 | self.module = module 42 | 43 | def forward(self, x): 44 | return self.module(x) 45 | 46 | 47 | class StatsTracker(fx.Interpreter): 48 | """Tracks important FHE statistics. 
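    Per node, this interpreter records input/output min/max ranges, the
    cleartext and FHE shapes, and the multiplexed slot gaps; these values
    are later synced onto the corresponding Orion modules.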
""" 49 | 50 | def __init__(self, module: fx.GraphModule) -> None: 51 | super().__init__(module) 52 | self._init_node_attributes() 53 | 54 | def _init_node_attributes(self): 55 | # Tracks min/max values and shapes for FHE-friendly inference 56 | for node in self.module.graph.nodes: 57 | node.input_min = float("inf") 58 | node.input_max = float("-inf") 59 | node.output_min = float("inf") 60 | node.output_max = float("-inf") 61 | node.input_shape = None 62 | node.output_shape = None 63 | node.fhe_input_shape = None 64 | node.fhe_output_shape = None 65 | node.input_gap = 1 66 | node.output_gap = 1 67 | 68 | def run_node(self, node: fx.Node): 69 | # Run one node and track its input/output stats 70 | self._validate_node(node) 71 | 72 | inp = self.map_nodes_to_values(node.args, node) 73 | if inp: 74 | self.update_input_stats(inp, node) 75 | 76 | result = super().run_node(node) # Forward pass the node 77 | self.update_output_stats(result, node) 78 | 79 | if node.op == "call_module": 80 | module = self.module.get_submodule(node.target) 81 | if isinstance(module, Module): 82 | self.sync_module_attributes(node) 83 | 84 | return result 85 | 86 | def _validate_node(self, node): 87 | # Validate that the layer works under FHE 88 | self._validate_shapes_and_gaps(node) 89 | 90 | if node.op == "call_module": 91 | self._validate_module_properties(node) 92 | 93 | def _validate_shapes_and_gaps(self, node): 94 | # Ensure consistent shapes and gaps across inputs 95 | parents = node.all_input_nodes 96 | if not parents: 97 | return 98 | 99 | # Helper function to check consistency 100 | def check_consistency(attr_name, label): 101 | values = [getattr(p, attr_name) for p in parents 102 | if getattr(p, attr_name) is not None] 103 | if len(set(values)) > 1: 104 | raise ValueError( 105 | f"Inconsistent {label} for {node.name}: {set(values)}" 106 | ) 107 | 108 | # Check all required consistencies 109 | check_consistency('output_shape', 'input shapes') 110 | check_consistency('fhe_output_shape', 'FHE shapes') 111 | check_consistency('output_gap', 'input gaps') 112 | 113 | def _validate_module_properties(self, node): 114 | # Check module-specific FHE compatibility requirements 115 | submodule = self.module.get_submodule(node.target) 116 | 117 | # Check stride equality in pooling layers 118 | stride = getattr(submodule, "stride", None) 119 | if stride and len(set(stride)) > 1: 120 | raise ValueError( 121 | f"Stride for {node.name} must be equal in all directions: {stride}" 122 | ) 123 | 124 | # Check BatchNorm parent count 125 | is_batchnorm = isinstance(submodule, BatchNormNd) 126 | has_multiple_parents = len(node.all_input_nodes) > 1 127 | 128 | if is_batchnorm and has_multiple_parents: 129 | raise ValueError( 130 | f"BatchNorm node {node} has multiple parents which prevents fusion" 131 | ) 132 | 133 | def update_input_stats(self, inp: tuple, node: fx.Node): 134 | # Update input statistics from actual tensor values 135 | min_values = [] 136 | max_values = [] 137 | 138 | for e in inp: 139 | if isinstance(e, torch.Tensor): 140 | min_values.append(e.detach().min()) 141 | max_values.append(e.detach().max()) 142 | else: # scalars 143 | scalar_tensor = torch.tensor(e) 144 | min_values.append(scalar_tensor) 145 | max_values.append(scalar_tensor) 146 | 147 | current_min = min(min_values) 148 | current_max = max(max_values) 149 | node.input_min = min(node.input_min, current_min) 150 | node.input_max = max(node.input_max, current_max) 151 | 152 | # Set input shape from parent's output shape for structure preservation 153 | if 
node.all_input_nodes:
154 |             parent = node.all_input_nodes[0]
155 |             node.input_shape = parent.output_shape
156 |             node.input_gap = parent.output_gap
157 |             node.fhe_input_shape = parent.fhe_output_shape
158 |         else:
159 |             # For input nodes with no parents, use actual tensor shape
160 |             node.input_shape = inp[0].shape
161 | 
162 |     def update_output_stats(self, result: torch.Tensor, node: fx.Node):
163 |         # Update output statistics based on actual result tensor
164 |         node.output_min = min(node.output_min, result.min())
165 |         node.output_max = max(node.output_max, result.max())
166 | 
167 |         # Determine appropriate output shape based on module type
168 |         node.output_shape = self.compute_clear_output_shape(node, result)
169 |         node.fhe_output_shape = self.compute_fhe_output_shape(node)
170 |         node.output_gap = self.compute_fhe_output_gap(node)
171 | 
172 |     def compute_clear_output_shape(self, node: fx.Node, result):
173 |         # Determine output shape, preserving structure except for transforming ops
174 |         if not node.input_shape:
175 |             return result.shape
176 | 
177 |         # Only LinearTransform modules change the output shape
178 |         if node.op == "call_module":
179 |             module = self.module.get_submodule(node.target)
180 |             if isinstance(module, LinearTransform):
181 |                 return result.shape
182 | 
183 |         # For all other modules, preserve the input shape
184 |         return node.input_shape
185 | 
186 |     def compute_fhe_output_gap(self, node: fx.Node):
187 |         if node.op == "call_module":
188 |             module = self.module.get_submodule(node.target)
189 |             if isinstance(module, LinearTransform):
190 |                 return module.compute_fhe_output_gap(
191 |                     input_gap=node.input_gap,
192 |                     input_shape=node.input_shape,
193 |                     output_shape=node.output_shape,
194 |                 )
195 |         return node.input_gap
196 | 
197 |     def compute_fhe_output_shape(self, node: fx.Node):
198 |         if not node.input_shape:
199 |             return node.output_shape
200 | 
201 |         if node.op == "call_module":
202 |             module = self.module.get_submodule(node.target)
203 |             if isinstance(module, LinearTransform):
204 |                 return module.compute_fhe_output_shape(
205 |                     input_gap=node.input_gap,
206 |                     input_shape=node.input_shape,
207 |                     output_shape=node.output_shape,
208 |                     fhe_input_shape=node.fhe_input_shape,
209 |                     output_gap=node.output_gap,
210 |                     clear_output_shape=node.output_shape
211 |                 )
212 |         return node.fhe_input_shape
213 | 
214 |     def sync_module_attributes(self, node: fx.Node):
215 |         # Sync tracked node statistics to the corresponding module
216 |         module = self.module.get_submodule(node.target)
217 |         module.name = node.name
218 | 
219 |         # Min/max values
220 |         module.input_min = node.input_min
221 |         module.input_max = node.input_max
222 |         module.output_min = node.output_min
223 |         module.output_max = node.output_max
224 | 
225 |         # Shapes
226 |         module.input_shape = node.input_shape
227 |         module.output_shape = node.output_shape
228 |         module.fhe_input_shape = node.fhe_input_shape
229 |         module.fhe_output_shape = node.fhe_output_shape
230 | 
231 |         # Multiplexed gaps
232 |         module.input_gap = node.input_gap
233 |         module.output_gap = node.output_gap
234 | 
235 |     def update_batch_size(self, batch_size):
236 |         for node in self.module.graph.nodes:
237 |             if node.op == "call_module":
238 |                 module = self.module.get_submodule(node.target)
239 | 
240 |                 shape_attrs = [
241 |                     'input_shape',
242 |                     'output_shape',
243 |                     'fhe_input_shape',
244 |                     'fhe_output_shape'
245 |                 ]
246 | 
247 |                 # Update only batch dimension
248 |                 for attr in shape_attrs:
249 |                     current_shape = getattr(module, attr)
250 |                     new_shape = torch.Size([batch_size] + 
list(current_shape[1:])) 251 | setattr(module, attr, new_shape) 252 | 253 | def propagate(self, *args) -> None: 254 | # Run the graph with the provided inputs 255 | self.run(*args) -------------------------------------------------------------------------------- /orion/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .alexnet import * 2 | from .lenet import * 3 | from .lola import * 4 | from .resnet import * 5 | from .vgg import * 6 | from .yolo import * 7 | from .mlp import * -------------------------------------------------------------------------------- /orion/models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import orion.nn as on 3 | 4 | 5 | class ConvBlock(on.Module): 6 | def __init__(self, Ci, Co, kernel_size, stride, padding): 7 | super().__init__() 8 | self.conv = nn.Sequential( 9 | on.Conv2d(Ci, Co, kernel_size, stride, padding, bias=False), 10 | on.BatchNorm2d(Co), 11 | on.SiLU(degree=127)) 12 | 13 | def forward(self, x): 14 | return self.conv(x) 15 | 16 | 17 | class LinearBlock(on.Module): 18 | def __init__(self, ni, no): 19 | super().__init__() 20 | self.linear = nn.Sequential( 21 | on.Linear(ni, no), 22 | on.BatchNorm1d(no), 23 | on.SiLU(degree=127)) 24 | 25 | def forward(self, x): 26 | return self.linear(x) 27 | 28 | 29 | class AlexNet(on.Module): 30 | cfg = [64, 'M', 192, 'M', 384, 256, 256, 'A'] 31 | 32 | def __init__(self, num_classes=10): 33 | super().__init__() 34 | self.features = self._make_layers() 35 | self.flatten = on.Flatten() 36 | self.classifier = nn.Sequential( 37 | LinearBlock(1024, 4096), 38 | LinearBlock(4096, 4096), 39 | on.Linear(4096, num_classes)) 40 | 41 | def _make_layers(self): 42 | layers = [] 43 | in_channels = 3 44 | for x in self.cfg: 45 | if x == 'M': 46 | layers += [on.AvgPool2d(kernel_size=2, stride=2)] 47 | elif x == 'A': 48 | layers += [on.AdaptiveAvgPool2d((2, 2))] 49 | else: 50 | layers += [ConvBlock(in_channels, x, kernel_size=3, 51 | stride=1, padding=1)] 52 | in_channels = x 53 | return nn.Sequential(*layers) 54 | 55 | def forward(self, x): 56 | x = self.features(x) 57 | x = self.flatten(x) 58 | x = self.classifier(x) 59 | return x 60 | 61 | 62 | if __name__ == "__main__": 63 | import torch 64 | from torchsummary import summary 65 | from fvcore.nn import FlopCountAnalysis 66 | 67 | net = AlexNet() 68 | net.eval() 69 | 70 | x = torch.randn(1,3,32,32) 71 | total_flops = FlopCountAnalysis(net, x).total() 72 | 73 | summary(net, (3,32,32), depth=10) 74 | print("Total flops: ", total_flops) 75 | -------------------------------------------------------------------------------- /orion/models/lenet.py: -------------------------------------------------------------------------------- 1 | import orion.nn as on 2 | 3 | class LeNet(on.Module): 4 | def __init__(self, num_classes=10): 5 | super().__init__() 6 | self.conv1 = on.Conv2d(1, 32, kernel_size=5, padding=2, stride=2) 7 | self.bn1 = on.BatchNorm2d(32) 8 | self.act1 = on.Quad() 9 | 10 | self.conv2 = on.Conv2d(32, 64, kernel_size=5, padding=2, stride=2) 11 | self.bn2 = on.BatchNorm2d(64) 12 | self.act2 = on.Quad() 13 | 14 | self.flatten = on.Flatten() 15 | self.fc1 = on.Linear(7*7*64, 512) 16 | self.bn3 = on.BatchNorm1d(512) 17 | self.act3 = on.Quad() 18 | 19 | self.fc2 = on.Linear(512, num_classes) 20 | 21 | def forward(self, x): 22 | x = self.act1(self.bn1(self.conv1(x))) 23 | x = self.act2(self.bn2(self.conv2(x))) 24 | x = self.flatten(x) 25 | x = 
self.act3(self.bn3(self.fc1(x))) 26 | return self.fc2(x) 27 | 28 | 29 | if __name__ == "__main__": 30 | import torch 31 | from torchsummary import summary 32 | from fvcore.nn import FlopCountAnalysis 33 | 34 | net = LeNet() 35 | net.eval() 36 | 37 | x = torch.randn(1,1,28,28) 38 | total_flops = FlopCountAnalysis(net, x).total() 39 | 40 | summary(net, (1,28,28), depth=10) 41 | print("Total flops: ", total_flops) 42 | -------------------------------------------------------------------------------- /orion/models/lola.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import orion.nn as on 3 | 4 | class LoLA(on.Module): 5 | def __init__(self, num_classes=10): 6 | super().__init__() 7 | self.conv1 = on.Conv2d(1, 5, kernel_size=2, padding=0, stride=2) 8 | self.bn1 = on.BatchNorm2d(5) 9 | self.act1 = on.Quad() 10 | 11 | self.fc1 = on.Linear(980, 100) 12 | self.bn2 = on.BatchNorm1d(100) 13 | self.act2 = on.Quad() 14 | 15 | self.fc2 = on.Linear(100, num_classes) 16 | self.flatten = on.Flatten() 17 | 18 | def forward(self, x): 19 | x = self.act1(self.bn1(self.conv1(x))) 20 | x = self.flatten(x) 21 | x = self.act2(self.bn2(self.fc1(x))) 22 | return self.fc2(x) 23 | 24 | 25 | if __name__ == "__main__": 26 | import torch 27 | from torchsummary import summary 28 | from fvcore.nn import FlopCountAnalysis 29 | 30 | net = LoLA() 31 | net.eval() 32 | 33 | x = torch.randn(1,1,28,28) 34 | total_flops = FlopCountAnalysis(net, x).total() 35 | 36 | summary(net, (1,28,28), depth=10) 37 | print("Total flops: ", total_flops) 38 | -------------------------------------------------------------------------------- /orion/models/mlp.py: -------------------------------------------------------------------------------- 1 | import orion.nn as on 2 | 3 | class MLP(on.Module): 4 | def __init__(self, num_classes=10): 5 | super().__init__() 6 | self.flatten = on.Flatten() 7 | 8 | self.fc1 = on.Linear(784, 128) 9 | self.bn1 = on.BatchNorm1d(128) 10 | self.act1 = on.Quad() 11 | 12 | self.fc2 = on.Linear(128, 128) 13 | self.bn2 = on.BatchNorm1d(128) 14 | self.act2 = on.Quad() 15 | 16 | self.fc3 = on.Linear(128, num_classes) 17 | 18 | def forward(self, x): 19 | x = self.flatten(x) 20 | x = self.act1(self.bn1(self.fc1(x))) 21 | x = self.act2(self.bn2(self.fc2(x))) 22 | return self.fc3(x) 23 | 24 | 25 | if __name__ == "__main__": 26 | import torch 27 | from torchsummary import summary 28 | from fvcore.nn import FlopCountAnalysis 29 | 30 | net = MLP() 31 | net.eval() 32 | 33 | x = torch.randn(1,1,28,28) 34 | total_flops = FlopCountAnalysis(net, x).total() 35 | 36 | summary(net, (1,28,28), depth=10) 37 | print("Total flops: ", total_flops) 38 | -------------------------------------------------------------------------------- /orion/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import orion.nn as on 3 | 4 | 5 | class BasicBlock(on.Module): 6 | expansion = 1 7 | 8 | def __init__(self, Ci, Co, stride=1): 9 | super().__init__() 10 | self.conv1 = on.Conv2d(Ci, Co, kernel_size=3, stride=stride, padding=1, bias=False) 11 | self.bn1 = on.BatchNorm2d(Co) 12 | self.act1 = on.ReLU() 13 | 14 | self.conv2 = on.Conv2d(Co, Co, kernel_size=3, stride=1, padding=1, bias=False) 15 | self.bn2 = on.BatchNorm2d(Co) 16 | self.act2 = on.ReLU() 17 | 18 | self.add = on.Add() 19 | self.shortcut = nn.Sequential() 20 | if stride != 1 or Ci != self.expansion*Co: 21 | self.shortcut = nn.Sequential( 22 | on.Conv2d(Ci, 
self.expansion*Co, kernel_size=1, stride=stride, bias=False), 23 | on.BatchNorm2d(self.expansion*Co)) 24 | 25 | def forward(self, x): 26 | out = self.act1(self.bn1(self.conv1(x))) 27 | out = self.bn2(self.conv2(out)) 28 | out = self.add(out, self.shortcut(x)) 29 | return self.act2(out) 30 | 31 | 32 | class Bottleneck(on.Module): 33 | expansion = 4 34 | 35 | def __init__(self, Ci, Co, stride=1): 36 | super().__init__() 37 | self.conv1 = on.Conv2d(Ci, Co, kernel_size=1, bias=False) 38 | self.bn1 = on.BatchNorm2d(Co) 39 | self.act1 = on.SiLU(degree=127) 40 | 41 | self.conv2 = on.Conv2d(Co, Co, kernel_size=3, stride=stride, padding=1, bias=False) 42 | self.bn2 = on.BatchNorm2d(Co) 43 | self.act2 = on.SiLU(degree=127) 44 | 45 | self.conv3 = on.Conv2d(Co, Co*self.expansion, kernel_size=1, stride=1, bias=False) 46 | self.bn3 = on.BatchNorm2d(Co*self.expansion) 47 | self.act3 = on.SiLU(degree=127) 48 | 49 | self.add = on.Add() 50 | self.shortcut = nn.Sequential() 51 | if stride != 1 or Ci != self.expansion*Co: 52 | self.shortcut = nn.Sequential( 53 | on.Conv2d(Ci, self.expansion*Co, kernel_size=1, stride=stride, bias=False), 54 | on.BatchNorm2d(self.expansion*Co)) 55 | 56 | def forward(self, x): 57 | out = self.act1(self.bn1(self.conv1(x))) 58 | out = self.act2(self.bn2(self.conv2(out))) 59 | out = self.bn3(self.conv3(out)) 60 | out = self.add(out, self.shortcut(x)) 61 | return self.act3(out) 62 | 63 | 64 | class ResNet(on.Module): 65 | def __init__(self, dataset, block, num_blocks, num_chans, conv1_params, num_classes): 66 | super().__init__() 67 | self.in_chans = num_chans[0] 68 | self.last_chans = num_chans[-1] 69 | 70 | self.conv1 = on.Conv2d(3, self.in_chans, **conv1_params, bias=False) 71 | self.bn1 = on.BatchNorm2d(self.in_chans) 72 | self.act = on.ReLU() 73 | 74 | self.pool = nn.Identity() 75 | if dataset == 'imagenet': # for imagenet we must also downsample 76 | self.pool = on.AvgPool2d(kernel_size=3, stride=2, padding=1) 77 | 78 | self.layers = nn.ModuleList() 79 | for i in range(len(num_blocks)): 80 | stride = 1 if i == 0 else 2 81 | self.layers.append(self.layer(block, num_chans[i], num_blocks[i], stride)) 82 | 83 | self.avgpool = on.AdaptiveAvgPool2d(output_size=(1,1)) 84 | self.flatten = on.Flatten() 85 | self.linear = on.Linear(self.last_chans * block.expansion, num_classes) 86 | 87 | def layer(self, block, chans, num_blocks, stride): 88 | strides = [stride] + [1]*(num_blocks-1) 89 | layers = [] 90 | for stride in strides: 91 | layers.append(block(self.in_chans, chans, stride)) 92 | self.in_chans = chans * block.expansion 93 | return nn.Sequential(*layers) 94 | 95 | def forward(self, x): 96 | out = self.act(self.bn1(self.conv1(x))) 97 | out = self.pool(out) 98 | for layer in self.layers: 99 | out = layer(out) 100 | 101 | out = self.avgpool(out) 102 | out = self.flatten(out) 103 | return self.linear(out) 104 | 105 | 106 | ################################ 107 | # CIFAR-10 / CIFAR-100 ResNets # 108 | ################################ 109 | 110 | def ResNet20(dataset='cifar10'): 111 | conv1_params, num_classes = get_resnet_config(dataset) 112 | return ResNet(dataset, BasicBlock, [3,3,3], [16,32,64], conv1_params, num_classes) 113 | 114 | def ResNet32(dataset='cifar10'): 115 | conv1_params, num_classes = get_resnet_config(dataset) 116 | return ResNet(dataset, BasicBlock, [5,5,5], [16,32,64], conv1_params, num_classes) 117 | 118 | def ResNet44(dataset='cifar10'): 119 | conv1_params, num_classes = get_resnet_config(dataset) 120 | return ResNet(dataset, BasicBlock, [7,7,7], [16,32,64], 
conv1_params, num_classes) 121 | 122 | def ResNet56(dataset='cifar10'): 123 | conv1_params, num_classes = get_resnet_config(dataset) 124 | return ResNet(dataset, BasicBlock, [9,9,9], [16,32,64], conv1_params, num_classes) 125 | 126 | def ResNet110(dataset='cifar10'): 127 | conv1_params, num_classes = get_resnet_config(dataset) 128 | return ResNet(dataset, BasicBlock, [18,18,18], [16,32,64], conv1_params, num_classes) 129 | 130 | def ResNet1202(dataset='cifar10'): 131 | conv1_params, num_classes = get_resnet_config(dataset) 132 | return ResNet(dataset, BasicBlock, [200,200,200], [16,32,64], conv1_params, num_classes) 133 | 134 | #################################### 135 | # Tiny ImageNet / ImageNet ResNets # 136 | #################################### 137 | 138 | def ResNet18(dataset='imagenet'): 139 | conv1_params, num_classes = get_resnet_config(dataset) 140 | return ResNet(dataset, BasicBlock, [2,2,2,2], [64,128,256,512], conv1_params, num_classes) 141 | 142 | def ResNet34(dataset='imagenet'): 143 | conv1_params, num_classes = get_resnet_config(dataset) 144 | return ResNet(dataset, BasicBlock, [3,4,6,3], [64,128,256,512], conv1_params, num_classes) 145 | 146 | def ResNet50(dataset='imagenet'): 147 | conv1_params, num_classes = get_resnet_config(dataset) 148 | return ResNet(dataset, Bottleneck, [3,4,6,3], [64,128,256,512], conv1_params, num_classes) 149 | 150 | def ResNet101(dataset='imagenet'): 151 | conv1_params, num_classes = get_resnet_config(dataset) 152 | return ResNet(dataset, Bottleneck, [3,4,23,3], [64,128,256,512], conv1_params, num_classes) 153 | 154 | def ResNet152(dataset='imagenet'): 155 | conv1_params, num_classes = get_resnet_config(dataset) 156 | return ResNet(dataset, Bottleneck, [3,8,36,3], [64,128,256,512], conv1_params, num_classes) 157 | 158 | 159 | def get_resnet_config(dataset): 160 | configs = { 161 | "cifar10": {"kernel_size": 3, "stride": 1, "padding": 1, "num_classes": 10}, 162 | "cifar100": {"kernel_size": 3, "stride": 1, "padding": 1, "num_classes": 100}, 163 | "tiny": {"kernel_size": 7, "stride": 1, "padding": 3, "num_classes": 200}, 164 | "imagenet": {"kernel_size": 7, "stride": 2, "padding": 3, "num_classes": 1000}, 165 | } 166 | 167 | if dataset not in configs: 168 | raise ValueError(f"ResNet with dataset {dataset} is not supported.") 169 | 170 | config = configs[dataset] 171 | conv1_params = { 172 | 'kernel_size': config["kernel_size"], 173 | 'stride': config["stride"], 174 | 'padding': config["padding"] 175 | } 176 | 177 | return conv1_params, config["num_classes"] 178 | 179 | 180 | if __name__ == "__main__": 181 | import torch 182 | from torchsummary import summary 183 | from fvcore.nn import FlopCountAnalysis 184 | 185 | net = ResNet50() 186 | net.eval() 187 | 188 | x = torch.randn(1,3,224,224) 189 | total_flops = FlopCountAnalysis(net, x).total() 190 | 191 | summary(net, (3,224,224), depth=10) 192 | print("Total flops: ", total_flops) 193 | -------------------------------------------------------------------------------- /orion/models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import orion.nn as on 3 | 4 | cfg = { 5 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 6 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 7 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 8 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 
512, 512, 'M'], 9 | } 10 | 11 | class VGG(on.Module): 12 | def __init__(self, vgg_name): 13 | super().__init__() 14 | self.features = self._make_layers(cfg[vgg_name]) 15 | self.classifier = on.Linear(512, 10) 16 | self.flatten = on.Flatten() 17 | 18 | def forward(self, x): 19 | out = self.features(x) 20 | out = self.flatten(out) 21 | out = self.classifier(out) 22 | return out 23 | 24 | def _make_layers(self, cfg): 25 | layers = [] 26 | in_channels = 3 27 | for x in cfg: 28 | if x == 'M': 29 | layers += [on.AvgPool2d(kernel_size=2, stride=2)] 30 | else: 31 | layers += [on.Conv2d(in_channels, x, kernel_size=3, padding=1), 32 | on.BatchNorm2d(x), 33 | on.ReLU(degrees=[15,15,27])] 34 | in_channels = x 35 | layers += [on.AvgPool2d(kernel_size=1, stride=1)] 36 | return nn.Sequential(*layers) 37 | 38 | 39 | if __name__ == "__main__": 40 | import torch 41 | from torchsummary import summary 42 | from fvcore.nn import FlopCountAnalysis 43 | 44 | net = VGG('VGG16') 45 | net.eval() 46 | 47 | x = torch.randn(1,3,32,32) 48 | total_flops = FlopCountAnalysis(net, x).total() 49 | 50 | summary(net, (3,32,32), depth=10) 51 | print("Total flops: ", total_flops) 52 | -------------------------------------------------------------------------------- /orion/models/yolo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import orion.nn as on 4 | 5 | from .resnet import * 6 | 7 | 8 | class YOLOv1(on.Module): 9 | def __init__(self, backbone, num_bboxes=2, num_classes=20, model_path=None): 10 | super().__init__() 11 | 12 | self.feature_size = 7 13 | self.num_bboxes = num_bboxes 14 | self.num_classes = num_classes 15 | self.model_path = model_path 16 | 17 | self.backbone = backbone 18 | self.conv_layers = self._make_conv_layers() 19 | self.fc_layers = self._make_fc_layers() 20 | 21 | # Remove last layers of backbone 22 | self.backbone.avgpool = nn.Identity() 23 | self.backbone.flatten = nn.Identity() 24 | self.backbone.linear = nn.Identity() 25 | 26 | self._init_weights() 27 | 28 | def _init_weights(self): 29 | if self.model_path: 30 | state_dict = torch.load(self.model_path, map_location='cpu', weights_only=False) 31 | self.load_state_dict(state_dict, strict=False) 32 | 33 | def _make_conv_layers(self): 34 | net = nn.Sequential( 35 | on.Conv2d(512, 512, 3, padding=1), 36 | on.SiLU(degree=127), 37 | on.Conv2d(512, 512, 3, stride=2, padding=1), 38 | on.SiLU(degree=127), 39 | 40 | on.Conv2d(512, 512, 3, padding=1), 41 | on.SiLU(degree=127), 42 | on.Conv2d(512, 512, 3, padding=1), 43 | on.SiLU(degree=127) 44 | ) 45 | 46 | return net 47 | 48 | def _make_fc_layers(self): 49 | S, B, C = self.feature_size, self.num_bboxes, self.num_classes 50 | net = nn.Sequential( 51 | on.Flatten(), 52 | on.Linear(7 * 7 * 512, 4096), 53 | on.SiLU(degree=127), 54 | on.Linear(4096, S * S * (5 * B + C)), 55 | ) 56 | 57 | return net 58 | 59 | def forward(self, x): 60 | x = self.backbone(x) 61 | x = self.conv_layers(x) 62 | x = self.fc_layers(x) 63 | return x 64 | 65 | 66 | def YOLOv1_ResNet34(model_path=None): 67 | backbone = ResNet34() 68 | net = YOLOv1(backbone, num_bboxes=2, num_classes=20, model_path=model_path) 69 | return net 70 | 71 | 72 | if __name__ == "__main__": 73 | import torch 74 | from torchsummary import summary 75 | from fvcore.nn import FlopCountAnalysis 76 | 77 | net = YOLOv1_ResNet34() 78 | net.eval() 79 | 80 | x = torch.randn(1,3,448,448) 81 | total_flops = FlopCountAnalysis(net, x).total() 82 | 83 | summary(net, (3,448,448), depth=10) 84 | 
print("Total flops: ", total_flops) 85 | -------------------------------------------------------------------------------- /orion/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .activation import * 2 | from .linear import * 3 | from .module import * 4 | from .normalization import * 5 | from .operations import * 6 | from .pooling import * 7 | from .reshape import * -------------------------------------------------------------------------------- /orion/nn/activation.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | from orion.nn.module import Module, timer 9 | from orion.nn.operations import Mult 10 | 11 | 12 | class Activation(Module): 13 | def __init__(self, coeffs): 14 | super().__init__() 15 | self.coeffs = coeffs 16 | self.output_scale = None 17 | self.set_depth() 18 | 19 | def extra_repr(self): 20 | return super().extra_repr() + f", degree={len(self.coeffs)-1}" 21 | 22 | def set_depth(self): 23 | self.depth = int(math.ceil(math.log2(len(self.coeffs)))) 24 | 25 | def set_output_scale(self, output_scale): 26 | self.output_scale = output_scale 27 | 28 | def compile(self): 29 | self.poly = self.scheme.poly_evaluator.generate_monomial(self.coeffs) 30 | 31 | @timer 32 | def forward(self, x): 33 | if self.he_mode: 34 | return self.scheme.poly_evaluator.evaluate_polynomial( 35 | x, self.poly, self.output_scale) 36 | 37 | # Horner's method 38 | out = 0 39 | for coeff in self.coeffs: 40 | out = coeff + x * out 41 | 42 | return out 43 | 44 | 45 | class Quad(Module): 46 | def __init__(self): 47 | super().__init__() 48 | self.set_depth(1) 49 | 50 | def forward(self, x): 51 | out = x * x 52 | if self.he_mode: 53 | out.set_scale(x.scale()) 54 | return out 55 | 56 | 57 | class Chebyshev(Module): 58 | def __init__(self, degree: int, fn, within_composite=False): 59 | super().__init__() 60 | self.degree = degree 61 | self.fn = fn 62 | self.within_composite = within_composite 63 | self.coeffs = None 64 | 65 | self.output_scale = None 66 | self.prescale = 1 67 | self.constant = 0 68 | 69 | def extra_repr(self): 70 | return super().extra_repr() + f", degree={self.degree}" 71 | 72 | def fit(self): 73 | if not self.within_composite: 74 | center = (self.input_min + self.input_max) / 2 75 | half_range = (self.input_max - self.input_min) / 2 76 | self.low = (center - (self.margin * half_range)).item() 77 | self.high = (center + (self.margin * half_range)).item() 78 | 79 | nodes = np.polynomial.chebyshev.chebpts1(self.degree + 1) 80 | if self.low < -1 or self.high > 1: 81 | self.prescale = 2 / (self.high - self.low) 82 | self.constant = -self.prescale * (self.low + self.high) / 2 83 | evals = (nodes + 1) * (self.high - self.low) / 2 + self.low 84 | else: 85 | evals = nodes 86 | 87 | evals = torch.tensor(evals) 88 | T = np.polynomial.Chebyshev.fit(nodes, self.fn(evals), self.degree) 89 | self.set_coeffs(T.coef.tolist()) 90 | self.set_depth() 91 | 92 | def set_coeffs(self, coeffs): 93 | self.coeffs = coeffs 94 | 95 | def set_depth(self): 96 | self.depth = int(math.ceil(math.log2(self.degree+1))) 97 | if self.prescale != 1: # additional level needed 98 | self.depth += 1 99 | 100 | def set_output_scale(self, output_scale): 101 | self.output_scale = output_scale 102 | 103 | def compile(self): 104 | self.poly = self.scheme.poly_evaluator.generate_chebyshev(self.coeffs) 105 | 106 | @timer 107 | def 
forward(self, x): 108 | if not self.he_mode: 109 | return self.fn(x) 110 | 111 | # Scale into [-1, 1] if needed. 112 | if not self.fused: 113 | if self.prescale != 1: 114 | x *= self.prescale 115 | if self.constant != 0: 116 | x += self.constant 117 | 118 | return self.scheme.poly_evaluator.evaluate_polynomial( 119 | x, self.poly, self.output_scale) 120 | 121 | 122 | class ELU(Chebyshev): 123 | def __init__(self, alpha=1.0, degree=31): 124 | self.alpha = alpha 125 | super().__init__(degree, self.fn) 126 | 127 | def fn(self, x): 128 | return torch.where(x > 0, x, self.alpha * (torch.exp(x) - 1)) 129 | 130 | 131 | class Hardshrink(Chebyshev): 132 | def __init__(self, degree=31, lambd=0.5): 133 | self.lambd = lambd 134 | super().__init__(degree, self.fn) 135 | 136 | def fn(self, x): 137 | return torch.where((x > self.lambd) | (x < -self.lambd), x, torch.tensor(0.0)) 138 | 139 | 140 | class GELU(Chebyshev): 141 | def __init__(self, degree=31): 142 | super().__init__(degree, self.fn) 143 | 144 | def fn(self, x): 145 | return F.gelu(x) 146 | 147 | 148 | class SiLU(Chebyshev): 149 | def __init__(self, degree=31): 150 | super().__init__(degree, self.fn) 151 | 152 | def fn(self, x): 153 | return F.silu(x) 154 | 155 | 156 | class Sigmoid(Chebyshev): 157 | def __init__(self, degree=31): 158 | super().__init__(degree, self.fn) 159 | 160 | def fn(self, x): 161 | return F.sigmoid(x) 162 | 163 | 164 | class SELU(Chebyshev): 165 | def __init__(self, degree=31): 166 | super().__init__(degree, self.fn) 167 | 168 | def fn(self, x): 169 | alpha = 1.6732632423543772 170 | scale = 1.0507009873554805 171 | return scale * torch.where(x > 0, x, alpha * (torch.exp(x) - 1)) 172 | 173 | 174 | class Softplus(Chebyshev): 175 | def __init__(self, degree=31): 176 | super().__init__(degree, self.fn) 177 | 178 | def fn(self, x): 179 | return F.softplus(x) 180 | 181 | 182 | class Mish(Chebyshev): 183 | def __init__(self, degree=31): 184 | super().__init__(degree, self.fn) 185 | 186 | def fn(self, x): 187 | return x * torch.tanh(F.softplus(x)) 188 | 189 | 190 | class _Sign(Module): 191 | def __init__( 192 | self, 193 | degrees=[15,15,27], 194 | prec=128, 195 | logalpha=6, 196 | logerr=12, 197 | ): 198 | super().__init__() 199 | self.degrees = degrees 200 | self.prec = prec 201 | self.logalpha = logalpha 202 | self.logerr = logerr 203 | self.mult = Mult() 204 | 205 | acts = [] 206 | for i, degree in enumerate(degrees): 207 | is_last = (i == len(degrees) - 1) 208 | fn = self.fn1 if not is_last else self.fn2 209 | act = Chebyshev(degree, fn, within_composite=True) 210 | acts.append(act) 211 | 212 | self.acts = nn.Sequential(*acts) 213 | 214 | def extra_repr(self): 215 | return super().extra_repr() + f", degrees={self.degrees}" 216 | 217 | def fit(self): 218 | debug = self.scheme.params.get_debug_status() 219 | self.coeffs = self.scheme.poly_evaluator.generate_minimax_sign_coeffs( 220 | self.degrees, self.prec, self.logalpha, self.logerr, debug) 221 | 222 | for i, coeffs in enumerate(self.coeffs): 223 | self.acts[i].set_coeffs(coeffs) 224 | self.acts[i].set_depth() 225 | 226 | def fn1(self, x): 227 | return torch.where(x <= 0, torch.tensor(-1.0), torch.tensor(1.0)) 228 | 229 | def fn2(self, x): 230 | return torch.where(x <= 0, torch.tensor(0.0), torch.tensor(1.0)) 231 | 232 | def forward(self, x): 233 | if self.he_mode: 234 | l1 = x.level() 235 | l2 = self.acts[-1].level - self.acts[-1].depth 236 | 237 | # We'll calculate the output level of sign on the fly by 238 | # comparing and taking the minimum of x and sign(x), as FHE 239 
| # multiplication will do the same. Then, we'll set the output 240 | # scale of sign to be the modulus in the chain at this level. 241 | # This way, rescaling divides ql / ql and is exact. 242 | output_level = min(l1, l2) 243 | ql = self.scheme.encoder.get_moduli_chain()[output_level] 244 | self.acts[-1].set_output_scale(ql) 245 | 246 | # Composite polynomial evaluation 247 | for act in self.acts: 248 | x = act(x) 249 | return x 250 | 251 | 252 | class ReLU(Module): 253 | def __init__(self, 254 | degrees=[15,15,27], 255 | prec=128, 256 | logalpha=6, 257 | logerr=12, 258 | ): 259 | super().__init__() 260 | self.degrees = degrees 261 | self.prec = prec 262 | self.logalpha = logalpha 263 | self.logerr = logerr 264 | self.sign = _Sign(degrees, prec, logalpha, logerr) 265 | self.mult1 = Mult() 266 | self.mult2 = Mult() 267 | 268 | self.prescale = 1 269 | self.postscale = 1 270 | 271 | def extra_repr(self): 272 | return super().extra_repr() + f", degrees={self.degrees}" 273 | 274 | def fit(self): 275 | self.input_min = self.mult1.input_min 276 | self.input_max = self.mult1.input_max 277 | 278 | absmax = max(abs(self.input_min), abs(self.input_max)) * self.margin 279 | if absmax > 1: 280 | self.postscale = int(math.ceil(absmax)) 281 | self.prescale = 1 / self.postscale 282 | 283 | @timer 284 | def forward(self, x): 285 | x = self.mult1(x, self.prescale) 286 | x = self.mult2(x, self.sign(x)) 287 | x *= self.postscale # integer mult, no level consumed 288 | return x -------------------------------------------------------------------------------- /orion/nn/linear.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import math 3 | from abc import abstractmethod 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .module import Module, timer 9 | from ..core import packing 10 | 11 | 12 | class LinearTransform(Module): 13 | def __init__(self, bsgs_ratio, level) -> None: 14 | super().__init__() 15 | self.bsgs_ratio = float(bsgs_ratio) 16 | self.set_depth(1) 17 | self.set_level(level) 18 | 19 | self.diagonals = {} # diags[(row, col)] = {0: [...], 1: [...], ...} 20 | self.transform_ids = {} # ids[(row, col)] = int 21 | self.output_rotations = 0 22 | 23 | def __del__(self): 24 | if 'sys' in globals() and sys.modules and self.scheme: 25 | try: 26 | self.scheme.lt_evaluator.delete_transforms(self.transform_ids) 27 | except Exception: 28 | pass # avoids errors for GC at program termination 29 | 30 | def extra_repr(self): 31 | return super().extra_repr() + f", bsgs_ratio={self.bsgs_ratio}" 32 | 33 | def init_orion_params(self): 34 | # Initialize additional Orion-specific weights/biases. 
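        # These working copies are kept separate from the trained parameters
        # on purpose: module fusing (e.g., folding a following BatchNorm into
        # this layer) may rewrite on_weight/on_bias before compile(), while
        # self.weight and self.bias remain untouched for cleartext use.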
35 |         self.on_weight = self.weight.data.clone()
36 |         self.on_bias = (self.bias.data.clone() if hasattr(self, 'bias') and
37 |                         self.bias is not None else torch.zeros(self.weight.shape[0]))
38 | 
39 |     @abstractmethod
40 |     def compute_fhe_output_gap(self, **kwargs):
41 |         """Compute the multiplexed output gap."""
42 |         pass
43 | 
44 |     @abstractmethod
45 |     def compute_fhe_output_shape(self, **kwargs) -> tuple:
46 |         """Compute the FHE output dimensions after multiplexing."""
47 |         pass
48 | 
49 |     @abstractmethod
50 |     def generate_diagonals(self, last: bool):
51 |         pass
52 | 
53 |     def get_io_mode(self):
54 |         return self.scheme.params.get_io_mode()
55 | 
56 |     def save_transforms(self):
57 |         self.scheme.lt_evaluator.save_transforms(self)
58 | 
59 |     def load_transforms(self):
60 |         return self.scheme.lt_evaluator.load_transforms(self)
61 | 
62 |     def compile(self):
63 |         self.transform_ids = self.scheme.lt_evaluator.generate_transforms(self)
64 | 
65 |     @timer
66 |     def evaluate_transforms(self, x):
67 |         out = self.scheme.lt_evaluator.evaluate_transforms(self, x)
68 | 
69 |         # Hybrid method's output rotations
70 |         slots = self.scheme.params.get_slots()
71 |         for i in range(1, self.output_rotations+1):
72 |             out += out.roll(slots // (2**i))
73 | 
74 |         out += self.on_bias_ptxt
75 |         return out
76 | 
77 | 
78 | class Linear(LinearTransform):
79 |     def __init__(
80 |         self,
81 |         in_features: int,
82 |         out_features: int,
83 |         bias: bool = True,
84 |         bsgs_ratio: int = 2,
85 |         level: int = None,
86 |     ) -> None:
87 |         super().__init__(bsgs_ratio, level)
88 | 
89 |         self.in_features = in_features
90 |         self.out_features = out_features
91 |         self.weight = nn.Parameter(torch.empty((out_features, in_features)))
92 |         self.bias = nn.Parameter(torch.empty(out_features)) if bias else None
93 |         self.reset_parameters()
94 | 
95 |     def extra_repr(self):
96 |         return (f"in_features={self.in_features}, out_features={self.out_features}, " +
97 |                 super().extra_repr())
98 | 
99 |     def reset_parameters(self):
100 |         # Initialize weights and biases following standard PyTorch instantiation.
101 |         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
102 | 
103 |         if self.bias is not None:
104 |             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
105 |             bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
106 |             nn.init.uniform_(self.bias, -bound, bound)
107 | 
108 |     def compute_fhe_output_gap(self, **kwargs):
109 |         return 1 # linear layers reset the multiplexed gap to 1.
110 | 
111 |     def compute_fhe_output_shape(self, **kwargs) -> tuple:
112 |         # Linear layers also remove any padded zeros. Therefore, the output
113 |         # shape under FHE inference is identical to cleartext inference.
114 |         return kwargs["clear_output_shape"]
115 | 
116 |     def generate_diagonals(self, last):
117 |         # Here, we'll apply our packing strategies to return the diagonals
118 |         # of our linear layer. When using the "hybrid" method of packing, this
119 |         # may also require several output rotations and summations.
120 |         self.diagonals, self.output_rotations = packing.pack_linear(self, last)
121 |         if self.get_io_mode() == "save":
122 |             self.save_transforms()
123 | 
124 |     def compile(self):
125 |         # If the user specifies an I/O mode = "save" or "load", then diagonals will
126 |         # be temporarily stored to disk to save memory. Load right before they're
127 |         # needed to generate the backend transforms themselves.
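        # For example, with the test config tests/configs/mlp.yml (io_mode:
        # none), everything stays in memory. With io_mode: save, the diagonals
        # built in generate_diagonals() are written to disk (see the
        # diags_path entry in the config), and this branch reloads them --
        # in both "save" and "load" modes -- just before the backend
        # transforms are built.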
128 | if self.get_io_mode() != "none": 129 | self.diagonals, self.on_bias, self.output_rotations = self.load_transforms() 130 | 131 | # We delay constructing the bias until now, so that any fusing can 132 | # modify the bias variable beforehand. 133 | bias = packing.construct_linear_bias(self) 134 | self.on_bias_ptxt = self.scheme.encoder.encode(bias, self.level-self.depth) 135 | self.transform_ids = self.scheme.lt_evaluator.generate_transforms(self) 136 | 137 | def forward(self, x): 138 | if not self.he_mode: 139 | if x.dim() != 2: 140 | extra = " Forgot to call on.Flatten() first?" if x.dim() == 4 else "" 141 | raise ValueError( 142 | f"Expected input to {self.__class__.__name__} to have " 143 | f"2 dimensions (N, in_features), but got {x.dim()} " 144 | f"dimension(s): {x.shape}." + extra 145 | ) 146 | 147 | # If we're not in FHE inference mode, then we'll just return 148 | # the default PyTorch result. 149 | return torch.nn.functional.linear(x, self.weight, self.bias) 150 | 151 | # Otherwise, call parent evaluation for FHE. 152 | return self.evaluate_transforms(x) 153 | 154 | 155 | class Conv2d(LinearTransform): 156 | def __init__( 157 | self, 158 | in_channels: int, 159 | out_channels: int, 160 | kernel_size: int, 161 | stride: int = 1, 162 | padding: int = 0, 163 | dilation: int = 1, 164 | groups: int = 1, 165 | bias: bool = True, 166 | bsgs_ratio: int = 2, 167 | level: int = None, 168 | ) -> None: 169 | super().__init__(bsgs_ratio, level) 170 | 171 | # Standard PyTorch Conv2d attributes 172 | self.in_channels = in_channels 173 | self.out_channels = out_channels 174 | 175 | # Convert int parameters to tuples 176 | self.kernel_size = self._make_tuple(kernel_size) 177 | self.stride = self._make_tuple(stride) 178 | self.padding = self._make_tuple(padding) 179 | self.dilation = self._make_tuple(dilation) 180 | self.groups = groups 181 | 182 | self.weight = nn.Parameter( 183 | torch.empty(out_channels, in_channels // groups, *self.kernel_size) 184 | ) 185 | self.bias = nn.Parameter(torch.empty(out_channels)) if bias else None 186 | self.reset_parameters() 187 | 188 | def _make_tuple(self, value): 189 | return (value, value) if isinstance(value, int) else value 190 | 191 | def reset_parameters(self): 192 | """Initialize parameters using PyTorch's standard approach.""" 193 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 194 | 195 | if self.bias is not None: 196 | fan_in = self.weight.size(1) * self.weight.size(2) * self.weight.size(3) 197 | bound = 1 / math.sqrt(fan_in) 198 | nn.init.uniform_(self.bias, -bound, bound) 199 | 200 | def extra_repr(self): 201 | return (f"in_channels={self.in_channels}, out_channels={self.out_channels}, " 202 | f"kernel_size={self.kernel_size}, stride={self.stride}, " 203 | f"padding={self.padding}, dilation={self.dilation}, " 204 | f"groups={self.groups}, " + super().extra_repr()) 205 | 206 | def compute_fhe_output_gap(self, **kwargs): 207 | # Strided convolutions increase the multiplexed gap by a factor 208 | # of the stride. 
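        # Worked example: a stride-2 convolution on a gap-1 input produces a
        # gap-2 output, i.e., each output channel's valid values sit on every
        # other row/column of the spatial grid; a following stride-2
        # convolution would raise the gap to 4. A gap of g lets g*g channels
        # share one multiplexed spatial plane (see compute_fhe_output_shape
        # below).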
209 |         input_gap = kwargs['input_gap']
210 |         return input_gap * self.stride[0]
211 | 
212 |     def compute_fhe_output_shape(self, **kwargs) -> tuple:
213 |         input_shape = kwargs['input_shape']
214 |         clear_output_shape = kwargs['clear_output_shape']
215 |         input_gap = kwargs['input_gap']
216 | 
217 |         Hi, Wi = input_shape[2:]
218 |         N, Co, Ho, Wo = clear_output_shape
219 |         output_gap = self.compute_fhe_output_gap(input_gap=input_gap)
220 | 
221 |         on_Co = math.ceil(Co / (output_gap**2))
222 |         on_Ho = max(Hi, Ho*output_gap)
223 |         on_Wo = max(Wi, Wo*output_gap)
224 | 
225 |         return torch.Size((N, on_Co, on_Ho, on_Wo))
226 | 
227 |     def generate_diagonals(self, last):
228 |         # Generate Toeplitz diagonals and determine the number of output
229 |         # rotations if the `hybrid` packing method is used.
230 |         self.diagonals, self.output_rotations = packing.pack_conv2d(self, last)
231 |         if self.get_io_mode() == "save":
232 |             self.save_transforms()
233 | 
234 |     def compile(self):
235 |         # If the user specifies an io mode = "save" or "load", then diagonals will
236 |         # be temporarily stored to disk to save memory. Load right before they're
237 |         # needed to generate the backend transforms themselves.
238 |         if self.get_io_mode() != "none":
239 |             self.diagonals, self.on_bias, self.output_rotations = self.load_transforms()
240 | 
241 |         # We delay constructing the bias until now, so that any fusing can
242 |         # modify the bias variable beforehand.
243 |         bias = packing.construct_conv2d_bias(self)
244 |         self.on_bias_ptxt = self.scheme.encoder.encode(bias, self.level-self.depth)
245 |         self.transform_ids = self.scheme.lt_evaluator.generate_transforms(self)
246 | 
247 |     def forward(self, x):
248 |         # Forward pass that handles both cleartext and FHE inference.
249 |         if not self.he_mode: # cleartext mode
250 |             if x.dim() != 4:
251 |                 raise ValueError(
252 |                     f"Expected input to {self.__class__.__name__} to have "
253 |                     f"4 dimensions (N, C, H, W), but got {x.dim()} "
254 |                     f"dimension(s): {x.shape}."
255 | ) 256 | return torch.nn.functional.conv2d( 257 | x, self.weight, self.bias, self.stride, 258 | self.padding, self.dilation, self.groups 259 | ) 260 | 261 | return self.evaluate_transforms(x) # FHE mode -------------------------------------------------------------------------------- /orion/nn/module.py: -------------------------------------------------------------------------------- 1 | import time 2 | import functools 3 | from abc import ABC, abstractmethod 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class Module(nn.Module, ABC): 10 | scheme = None 11 | margin = None 12 | 13 | def __init__(self): 14 | super().__init__() 15 | self.level = None 16 | self.depth = None 17 | self.fused = False 18 | self.he_mode = False 19 | 20 | @staticmethod 21 | def set_scheme(scheme): 22 | Module.scheme = scheme 23 | 24 | @staticmethod 25 | def set_margin(margin): 26 | Module.margin = margin 27 | 28 | def _set_mode_for_all(self, he_mode=False, training=True): 29 | for m in self.modules(): 30 | m.training = training 31 | if hasattr(m, "he_mode"): 32 | m.he_mode = he_mode 33 | 34 | def _set_attribute_for_all(self, attr, value): 35 | for m in self.modules(): 36 | setattr(m, attr, value) 37 | 38 | def extra_repr(self): 39 | torch_repr = super().extra_repr() 40 | orion_repr = (", " if torch_repr else "") + f"level={self.level}" 41 | return torch_repr + orion_repr 42 | 43 | def train(self, mode=True): 44 | self._set_mode_for_all(he_mode=False, training=mode) 45 | 46 | def eval(self): 47 | self._set_mode_for_all(he_mode=False, training=False) 48 | 49 | def he(self): 50 | self._set_mode_for_all(he_mode=True, training=False) 51 | 52 | def set_depth(self, depth): 53 | self.depth = depth 54 | 55 | def set_level(self, level): 56 | self.level = level 57 | 58 | @abstractmethod 59 | def forward(self, x): 60 | raise NotImplementedError( 61 | f"The 'forward' method is not implemented in {type(self).__name__}. " 62 | "All Orion modules must override this method with a custom " 63 | "implementation." 64 | ) 65 | 66 | 67 | def timer(func): 68 | @functools.wraps(func) 69 | @torch.compiler.disable 70 | def wrapper(self, *args, **kwargs): 71 | if not self.he_mode: 72 | return func(self, *args, **kwargs) 73 | 74 | debug_enabled = self.scheme.params.get_debug_status() 75 | if debug_enabled: 76 | layer_name = getattr(self, "name", self.__class__.__name__) 77 | print(f"\n{layer_name}:") 78 | 79 | # Print input statistics 80 | print(f"Clear input min/max: {self.input_min:.3f} / {self.input_max:.3f}") 81 | print(f"FHE input min/max: {args[0].min():.3f} / {args[0].max():.3f}") 82 | 83 | start = time.time() # start timer that ends after module finishes 84 | 85 | result = func(self, *args, **kwargs) 86 | 87 | # Finish timing and print output stats if in debug mode 88 | if debug_enabled: 89 | if hasattr(self, "output_min"): 90 | output_min = self.output_min 91 | output_max = self.output_max 92 | else: # for bootstrap 93 | output_min = self.input_min 94 | output_max = self.input_max 95 | 96 | elapsed = time.time() - start 97 | 98 | print(f"Clear output min/max: {output_min:.3f} / {output_max:.3f}") 99 | print(f"FHE output min/max: {result.min():.3f} / {result.max():.3f}") 100 | print(f"done! 
[{elapsed:.3f} secs.]") 101 | 102 | return result 103 | 104 | return wrapper -------------------------------------------------------------------------------- /orion/nn/normalization.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .module import Module, timer 7 | from ..core import packing 8 | 9 | class BatchNormNd(Module): 10 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 11 | super().__init__() 12 | self.num_features = num_features 13 | self.eps = eps 14 | self.momentum = momentum 15 | self.affine = affine 16 | self.fused = False 17 | self.set_depth(2 if affine else 1) 18 | 19 | if self.affine: 20 | self.weight = nn.Parameter(torch.ones(num_features)) 21 | self.bias = nn.Parameter(torch.zeros(num_features)) 22 | else: 23 | self.register_parameter('weight', None) 24 | self.register_parameter('bias', None) 25 | 26 | self.register_buffer('running_mean', torch.zeros(num_features)) 27 | self.register_buffer('running_var', torch.ones(num_features)) 28 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 29 | 30 | @abstractmethod 31 | def _check_input_dim(self, x): 32 | raise NotImplementedError("Subclasses must implement _check_input_dim") 33 | 34 | def init_orion_params(self): 35 | self.on_running_mean = self.running_mean.data.clone() 36 | self.on_running_var = self.running_var.data.clone() 37 | 38 | if self.affine: 39 | self.on_weight = self.weight.data.clone() 40 | self.on_bias = self.bias.data.clone() 41 | else: 42 | self.on_weight = torch.ones_like(self.on_running_mean) 43 | self.on_bias = torch.zeros_like(self.on_running_mean) 44 | 45 | def extra_repr(self): 46 | return super().extra_repr() + f", level={self.level}, fused={self.fused}" 47 | 48 | def compile(self, a, b, c, d): 49 | level = self.level 50 | encoder = self.scheme.encoder 51 | 52 | q1 = encoder.get_moduli_chain()[level] 53 | q2 = encoder.get_moduli_chain()[level - 1] 54 | 55 | # In order to ensure errorless neural network evaluation, we'll 56 | # need to encode the scaling/shifting and affine maps at the 57 | # correct scale value. 58 | self.on_running_mean_ptxt = encoder.encode(a, level=level, scale=q1) 59 | self.on_inv_running_std_ptxt = encoder.encode(b, level=level, scale=q1) 60 | 61 | if self.affine: 62 | self.on_weight_ptxt = encoder.encode(c, level=level-1, scale=q2) 63 | self.on_bias_ptxt = encoder.encode(d, level=level-1, scale=q2) 64 | 65 | @timer 66 | def forward(self, x): 67 | if not self.he_mode: 68 | self._check_input_dim(x) 69 | 70 | if self.training: 71 | exponential_average_factor = 0.0 72 | if self.momentum is not None: 73 | exponential_average_factor = self.momentum 74 | if self.num_batches_tracked is not None: 75 | self.num_batches_tracked += 1 76 | if self.momentum is None: # use cumulative moving average 77 | exponential_average_factor = 1.0 / self.num_batches_tracked 78 | else: 79 | exponential_average_factor = 0.0 80 | 81 | if not self.he_mode: 82 | return torch.nn.functional.batch_norm( 83 | x, 84 | self.running_mean, 85 | self.running_var, 86 | self.weight, 87 | self.bias, 88 | self.training, 89 | exponential_average_factor, 90 | self.eps 91 | ) 92 | 93 | # In HE evaluation mode. 
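        # The plaintexts below were encoded in compile() at scale q_l (the
        # modulus at this level) or q_{l-1}, so each ciphertext-plaintext
        # multiply rescales exactly: a ciphertext at scale D times a plaintext
        # at scale q_l has scale D * q_l, and the division by q_l on rescale
        # lands it back at exactly D. This is what keeps the normalization
        # errorless.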
94 |         if not self.fused:
95 |             x -= self.on_running_mean_ptxt
96 |             x *= self.on_inv_running_std_ptxt
97 | 
98 |         if self.affine:
99 |             x *= self.on_weight_ptxt
100 |             x += self.on_bias_ptxt
101 | 
102 |         return x
103 | 
104 | 
105 | class BatchNorm1d(BatchNormNd):
106 |     def _check_input_dim(self, x):
107 |         if x.dim() != 2 and x.dim() != 3:
108 |             raise ValueError(f'expected 2D or 3D input (got {x.dim()}D input)')
109 | 
110 |     def compile(self):
111 |         a, b, c, d = packing.pack_bn1d(self)
112 |         super().compile(a, b, c, d)
113 | 
114 | 
115 | class BatchNorm2d(BatchNormNd):
116 |     def _check_input_dim(self, x):
117 |         if x.dim() != 4:
118 |             raise ValueError(f'expected 4D input (got {x.dim()}D input)')
119 | 
120 |     def compile(self):
121 |         a, b, c, d = packing.pack_bn2d(self)
122 |         super().compile(a, b, c, d)
123 | 
124 | 
125 | 
-------------------------------------------------------------------------------- /orion/nn/operations.py: --------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | 
4 | from .module import Module, timer
5 | 
6 | class Add(Module):
7 |     def __init__(self):
8 |         super().__init__()
9 |         self.set_depth(0)
10 | 
11 |     def forward(self, x, y):
12 |         return x + y
13 | 
14 | 
15 | class Mult(Module):
16 |     def __init__(self):
17 |         super().__init__()
18 |         self.set_depth(1)
19 | 
20 |     def forward(self, x, y):
21 |         return x * y
22 | 
23 | 
24 | class Bootstrap(Module):
25 |     def __init__(self, input_min, input_max, input_level):
26 |         super().__init__()
27 |         self.input_min = input_min
28 |         self.input_max = input_max
29 |         self.input_level = input_level
30 |         self.prescale = 1
31 |         self.postscale = 1
32 |         self.constant = 0
33 | 
34 |     def extra_repr(self):
35 |         l_eff = len(self.scheme.params.get_logq()) - 1
36 |         return f"l_eff={l_eff}"
37 | 
38 |     def fit(self):
39 |         center = (self.input_min + self.input_max) / 2
40 |         half_range = (self.input_max - self.input_min) / 2
41 |         self.low = (center - (self.margin * half_range)).item()
42 |         self.high = (center + (self.margin * half_range)).item()
43 | 
44 |         # We'll want to scale from [A, B] into [-1, 1] using a value of the
45 |         # form 1 / integer, so that way our multiplication back to the range
46 |         # [A, B] (by integer) after bootstrapping doesn't consume a level.
47 |         if self.high - self.low > 2:
48 |             self.postscale = math.ceil((self.high - self.low) / 2)
49 |             self.prescale = 1 / self.postscale
50 | 
51 |         self.constant = -(self.low + self.high) / 2
52 | 
53 |     def compile(self):
54 |         # We'll then encode the prescale at the level of the input ciphertext
55 |         # to ensure its rescaling is errorless
56 |         elements = self.fhe_input_shape.numel()
57 |         curr_slots = 2 ** math.ceil(math.log2(elements))
58 | 
59 |         prescale_vec = torch.zeros(curr_slots)
60 |         prescale_vec[:elements] = self.prescale
61 | 
62 |         ql = self.scheme.encoder.get_moduli_chain()[self.input_level]
63 |         self.prescale_ptxt = self.scheme.encoder.encode(
64 |             prescale_vec, level=self.input_level, scale=ql)
65 | 
66 |     @timer
67 |     def forward(self, x):
68 |         if not self.he_mode:
69 |             return x
70 | 
71 |         # Shift and scale into range [-1, 1]. Important caveat -- here we first
72 |         # shift, then scale. This lets us zero out unused slots and enables
73 |         # sparse bootstrapping (i.e., where slots < N/2).
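        # Worked example: with a margin-widened range [low, high] = [-3, 5],
        # fit() gives constant = -(-3+5)/2 = -1, postscale = ceil(8/2) = 4,
        # and prescale = 1/4. Shifting first maps [-3, 5] -> [-4, 4];
        # multiplying by the prescale plaintext (zero in unused slots) then
        # maps into [-1, 1] and simultaneously clears the unused slots.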
74 |         if self.constant != 0:
75 |             x += self.constant
76 |         x *= self.prescale_ptxt
77 | 
78 |         x = x.bootstrap()
79 | 
80 |         # Scale and shift back to the original range
81 |         if self.postscale != 1:
82 |             x *= self.postscale
83 |         if self.constant != 0:
84 |             x -= self.constant
85 | 
86 |         return x
87 | 
88 | 
89 | 
90 | 
91 | 
-------------------------------------------------------------------------------- /orion/nn/pooling.py: --------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | 
5 | from .linear import Conv2d
6 | 
7 | 
8 | class AvgPool2d(Conv2d):
9 |     def __init__(
10 |         self,
11 |         kernel_size,
12 |         stride=None,
13 |         padding=0,
14 |         bsgs_ratio=2,
15 |         level=None,
16 |     ):
17 |         super().__init__(1, 1, kernel_size, stride or kernel_size, padding,
18 |                          dilation=1, groups=1, bias=False,
19 |                          bsgs_ratio=bsgs_ratio, level=level)
20 | 
21 |     def extra_repr(self):
22 |         return (f"AvgPool2d(kernel_size={self.kernel_size}, stride={self.stride}, " +
23 |                 f"level={self.level})"
24 |         )
25 | 
26 |     def update_params(self):
27 |         kH, kW = self.kernel_size
28 |         self.in_channels = self.out_channels = self.groups = self.input_shape[1]
29 |         self.on_weight = torch.ones(self.out_channels, 1, kH, kW) / (kH * kW)
30 |         self.on_bias = torch.zeros(self.out_channels)
31 | 
32 |     def forward(self, x):
33 |         if not self.he_mode:
34 |             return F.avg_pool2d(x, self.kernel_size, self.stride, self.padding)
35 | 
36 |         return super().forward(x)
37 | 
38 | 
39 | class AdaptiveAvgPool2d(AvgPool2d):
40 |     def __init__(
41 |         self,
42 |         output_size,
43 |         bsgs_ratio=2,
44 |         level=None
45 |     ):
46 |         super().__init__(kernel_size=1, stride=1, padding=0,
47 |                          bsgs_ratio=bsgs_ratio, level=level)
48 | 
49 |         if isinstance(output_size, int):
50 |             output_size = (output_size, output_size)
51 |         self.output_size = output_size
52 | 
53 |     def extra_repr(self):
54 |         return (f"AdaptiveAvgPool2d(output_size={self.output_size}, " +
55 |                 f"level={self.level})"
56 |         )
57 | 
58 |     def update_params(self):
59 |         Hi, Wi = self.input_shape[2:]
60 |         Ho, Wo = self.output_size
61 | 
62 |         self.stride = (Hi // Ho, Wi // Wo)
63 |         kH = Hi - (Ho - 1) * self.stride[0]
64 |         kW = Wi - (Wo - 1) * self.stride[1]
65 |         self.kernel_size = (kH, kW)
66 |         super().update_params()
67 | 
68 |     def compute_fhe_output_gap(self, **kwargs):
69 |         input_gap = kwargs['input_gap']
70 |         input_shape = kwargs['input_shape']
71 |         output_shape = kwargs['output_shape']
72 | 
73 |         # We'll have to manually calculate the stride here because it is not
74 |         # passed as an argument to AdaptiveAvgPool2d, yet we need it ASAP
75 |         # to propagate FHE shapes and multiplexed gaps.
76 |         return input_gap * (input_shape[2] // output_shape[2])
77 | 
78 |     def compute_fhe_output_shape(self, **kwargs):
79 |         input_shape = kwargs['input_shape']
80 |         output_shape = kwargs['clear_output_shape']
81 |         input_gap = kwargs['input_gap']
82 | 
83 |         Hi, Wi = input_shape[2:]
84 |         No, Co, Ho, Wo = output_shape
85 | 
86 |         output_gap = self.compute_fhe_output_gap(
87 |             input_gap=input_gap, input_shape=input_shape, output_shape=output_shape
88 |         )
89 | 
90 |         # We'll also need to compute this ASAP, since FHE shapes are
91 |         # propagated to future layers in orion.fit().
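        # Worked example: pooling an 8x8 input (gap 1) down to 4x4 gives
        # output_gap = 1 * (8 // 4) = 2, so below: on_Co = ceil(Co / 4)
        # multiplexed channel planes, each still max(8, 4*2) = 8x8 spatially.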
92 |         on_Co = math.ceil(Co / (output_gap**2))
93 |         on_Ho = max(Hi, Ho*output_gap)
94 |         on_Wo = max(Wi, Wo*output_gap)
95 | 
96 |         return torch.Size((No, on_Co, on_Ho, on_Wo))
97 | 
98 |     def forward(self, x):
99 |         if not self.he_mode:
100 |             Ho, Wo = self.output_size
101 |             if x.shape[2] % Ho != 0 or x.shape[3] % Wo != 0:
102 |                 raise ValueError(
103 |                     f"Input spatial dimensions {x.shape[2:]} are not a " +
104 |                     f"multiple of the output size {self.output_size}."
105 |                 )
106 |             return F.adaptive_avg_pool2d(x, self.output_size)
107 | 
108 |         return super().forward(x)
-------------------------------------------------------------------------------- /orion/nn/reshape.py: --------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from .module import Module
4 | 
5 | class Flatten(Module):
6 |     def __init__(self):
7 |         super().__init__()
8 |         self.set_depth(0)
9 | 
10 |     def extra_repr(self):
11 |         return super().extra_repr() + ", start_dim=1"
12 | 
13 |     def forward(self, x):
14 |         if self.he_mode:
15 |             return x
16 |         return torch.flatten(x, start_dim=1)
-------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["poetry-core>=1.0.2", "setuptools>=61.0"]
3 | build-backend = "poetry.core.masonry.api"
4 | 
5 | [project]
6 | name = "orion-fhe"
7 | version = "1.0.2"
8 | description = "A Fully Homomorphic Encryption Framework for Deep Learning"
9 | authors = [
10 |     {name = "Austin Ebel", email = "abe5240@nyu.edu"}
11 | ]
12 | readme = "README.md"
13 | requires-python = ">=3.9,<3.13"
14 | dependencies = [
15 |     "PyYAML>=6.0",
16 |     "torch>=2.2.0",
17 |     "torchvision>=0.17.0",
18 |     "tqdm>=4.30.0",
19 |     "numpy>=1.21.0",
20 |     "scipy>=1.7.0,<=1.14.1",
21 |     "matplotlib>=3.1.0",
22 |     "h5py>=3.5.0",
23 |     "certifi>=2024.2.2",
24 | ]
25 | 
26 | [tool.poetry]
27 | packages = [{include = "orion"}]
28 | 
29 | [tool.poetry.build]
30 | generate-setup-file = true
31 | script = "tools/build_lattigo.py"
32 | 
-------------------------------------------------------------------------------- /tests/__init__.py: https://raw.githubusercontent.com/baahl-nyu/orion/f0581052b28d02a00299cce742949930b3260aa8/tests/__init__.py --------------------------------------------------------------------------------
/tests/configs/mlp.yml: --------------------------------------------------------------------------------
1 | comment: Config for MLP from https://eprint.iacr.org/2017/396.pdf
2 | 
3 | ckks_params:
4 |   LogN: 13
5 |   LogQ: [29, 26, 26, 26, 26, 26]
6 |   LogP: [29, 29]
7 |   LogScale: 26
8 |   H: 8192
9 |   RingType: ConjugateInvariant
10 | 
11 | orion:
12 |   margin: 2 # >= 1
13 |   embedding_method: hybrid # [hybrid, square]
14 |   backend: lattigo # [lattigo, openfhe, heaan]
15 | 
16 |   fuse_modules: true
17 |   debug: false
18 | 
19 |   diags_path: ../data/diagonals.h5 # "path/to/diags" | ""
20 |   keys_path: ../data/keys.h5 # "path/to/keys" | ""
21 |   io_mode: none # "load" | "save" | "none"
-------------------------------------------------------------------------------- /tests/models/test_mlp.py: --------------------------------------------------------------------------------
1 | import time
2 | import math
3 | from pathlib import Path
4 | 
5 | import torch
6 | import orion
7 | import orion.models as models
8 | from orion.core.utils import get_mnist_datasets, mae
9 | 
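# The test below exercises the full Orion pipeline end to end: initialize the
# scheme from a config, run cleartext inference, fit ranges and compile,
# encode/encrypt the input, run the same network under net.he(), then
# decrypt/decode and compare against the cleartext output with MAE.
10 | def 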
get_config_path(yml_str): 11 | orion_path = Path(__file__).parent.parent 12 | return str(orion_path / "configs" / f"{yml_str}") 13 | 14 | def test_mlp(): 15 | torch.manual_seed(42) # set seed 16 | 17 | # Initialize the Orion scheme and model 18 | orion.init_scheme(get_config_path("mlp.yml")) 19 | trainloader, testloader = get_mnist_datasets(data_dir="./data", batch_size=1) 20 | net = models.MLP() 21 | 22 | # Get a test batch to pass through our network 23 | inp, _ = next(iter(testloader)) 24 | 25 | # Run cleartext inference 26 | net.eval() 27 | out_clear = net(inp) 28 | 29 | # Fit and compile 30 | orion.fit(net, trainloader) 31 | input_level = orion.compile(net) 32 | 33 | # Encode and encrypt the input vector 34 | vec_ptxt = orion.encode(inp, input_level) 35 | vec_ctxt = orion.encrypt(vec_ptxt) 36 | net.he() # Switch to FHE mode 37 | 38 | # Run FHE inference 39 | out_ctxt = net(vec_ctxt) 40 | 41 | # Get the FHE results and decrypt + decode. 42 | out_ptxt = out_ctxt.decrypt() 43 | out_fhe = out_ptxt.decode() 44 | 45 | dist = mae(out_clear, out_fhe) 46 | 47 | # small tolerable difference depends on parameter set 48 | assert dist < 0.005 49 | -------------------------------------------------------------------------------- /tests/test_imports.py: -------------------------------------------------------------------------------- 1 | def test_orion_core_imports(): 2 | """Test that core Orion modules can be imported.""" 3 | import orion 4 | import orion.nn as on 5 | import orion.models as models 6 | import orion.core.utils as utils 7 | 8 | assert orion.__name__ == "orion" 9 | assert on.__name__ == "orion.nn" 10 | assert models.__name__ == "orion.models" 11 | assert utils.__name__ == "orion.core.utils" 12 | 13 | def test_linear_transforms(): 14 | """Test that linear transform modules can be instantiated.""" 15 | import orion.nn as on 16 | 17 | linear = on.Linear(10, 10) 18 | assert isinstance(linear, on.Linear) 19 | 20 | conv = on.Conv2d(1, 3, kernel_size=3, stride=1, padding=1) 21 | assert isinstance(conv, on.Conv2d) 22 | 23 | avg_pool = on.AvgPool2d(kernel_size=3, stride=3) 24 | assert isinstance(avg_pool, on.AvgPool2d) 25 | 26 | adaptive_pool = on.AdaptiveAvgPool2d(output_size=1) 27 | assert isinstance(adaptive_pool, on.AdaptiveAvgPool2d) 28 | 29 | def test_activation_functions(): 30 | """Test that activation function modules can be instantiated.""" 31 | import orion.nn as on 32 | 33 | activation = on.Activation(coeffs=[1,0,0]) 34 | assert isinstance(activation, on.Activation) 35 | 36 | quad = on.Quad() 37 | assert isinstance(quad, on.Quad) 38 | 39 | sigmoid = on.Sigmoid(degree=31) 40 | assert isinstance(sigmoid, on.Sigmoid) 41 | 42 | silu = on.SiLU(degree=127) 43 | assert isinstance(silu, on.SiLU) 44 | 45 | gelu = on.GELU() 46 | assert isinstance(gelu, on.GELU) 47 | 48 | relu = on.ReLU(degrees=[15,15,27], logalpha=6, logerr=12) 49 | assert isinstance(relu, on.ReLU) 50 | 51 | def test_normalization(): 52 | """Test that normalization modules can be instantiated.""" 53 | import orion.nn as on 54 | 55 | bn1d = on.BatchNorm1d(32) 56 | assert isinstance(bn1d, on.BatchNorm1d) 57 | 58 | bn2d = on.BatchNorm2d(32) 59 | assert isinstance(bn2d, on.BatchNorm2d) 60 | 61 | def test_operations(): 62 | """Test that operation modules can be instantiated.""" 63 | import orion.nn as on 64 | 65 | add = on.Add() 66 | assert isinstance(add, on.Add) 67 | 68 | mult = on.Mult() 69 | assert isinstance(mult, on.Mult) 70 | 71 | bootstrap = on.Bootstrap(-1, 1, input_level=1) # internal module 72 | assert 
isinstance(bootstrap, on.Bootstrap)
73 | 
74 | def test_reshape():
75 |     """Test that reshape modules can be instantiated."""
76 |     import orion.nn as on
77 | 
78 |     flatten = on.Flatten()
79 |     assert isinstance(flatten, on.Flatten)
-------------------------------------------------------------------------------- /tests/test_placeholder.py: --------------------------------------------------------------------------------
1 | def test_placeholder():
2 |     """Temporary test to prevent pytest from failing due to no tests."""
3 |     assert True # Always passes
-------------------------------------------------------------------------------- /tools/build_lattigo.py: --------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import platform
4 | import subprocess
5 | from pathlib import Path
6 | 
7 | def build(setup_kwargs=None):
8 |     """Build the Go shared library for Lattigo."""
9 |     print("=== Building Go shared library ===")
10 | 
11 |     # Determine the output filename based on platform
12 |     if platform.system() == "Windows":
13 |         output_file = "lattigo-windows.dll"
14 |     elif platform.system() == "Darwin": # macOS
15 |         if platform.machine().lower() in ("arm64", "aarch64"):
16 |             output_file = "lattigo-mac-arm64.dylib"
17 |         else:
18 |             output_file = "lattigo-mac.dylib"
19 |     elif platform.system() == "Linux":
20 |         output_file = "lattigo-linux.so"
21 |     else:
22 |         raise RuntimeError("Unsupported platform")
23 | 
24 |     # Set up paths
25 |     root_dir = Path(__file__).parent.parent
26 |     backend_dir = root_dir / "orion" / "backend" / "lattigo"
27 |     output_path = backend_dir / output_file
28 | 
29 |     # Set up CGO for Go build
30 |     env = os.environ.copy()
31 |     env["CGO_ENABLED"] = "1"
32 | 
33 |     # Set architecture for macOS
34 |     if platform.system() == "Darwin":
35 |         if platform.machine().lower() in ("arm64", "aarch64"):
36 |             env["GOARCH"] = "arm64"
37 |         else:
38 |             env["GOARCH"] = "amd64"
39 | 
40 |     # Build command
41 |     build_cmd = [
42 |         "go", "build",
43 |         "-buildmode=c-shared",
44 |         "-buildvcs=false",
45 |         "-o", str(output_path),
46 |         str(backend_dir)
47 |     ]
48 | 
49 |     # Run the build command with the configured environment
50 |     try:
51 |         print(f"Running: {' '.join(build_cmd)}")
52 |         subprocess.run(build_cmd, cwd=str(backend_dir), env=env, check=True)
53 |         print(f"Successfully built {output_file}")
54 |     except subprocess.CalledProcessError as e:
55 |         print(f"Go build failed with exit code {e.returncode}")
56 |         sys.exit(1)
57 | 
58 |     # Return setup_kwargs for Poetry
59 |     return setup_kwargs or {}
60 | 
61 | if __name__ == "__main__":
62 |     build() # build() exits with code 1 itself if the Go build fails
63 |     sys.exit(0)
--------------------------------------------------------------------------------
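A usage note on the build script above: Poetry runs tools/build_lattigo.py automatically through the [tool.poetry.build] hook in pyproject.toml, but the shared library can also be rebuilt by hand. A minimal sketch, assuming Go 1.21+ and a working CGO toolchain are installed (the same prerequisites the CI workflows set up):

# Hypothetical manual rebuild, equivalent to running
# `python tools/build_lattigo.py` from the repository root.
# Produces orion/backend/lattigo/lattigo-<platform>.so/.dylib/.dll.
from tools.build_lattigo import build

build()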