├── .clang-format ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── README.md ├── cmake └── modules │ ├── AddHecate.cmake │ └── CMakeLists.txt ├── config.json ├── config.sh ├── examples ├── benchmarks │ ├── LinearRegression.py │ ├── MLP.py │ └── SobelFilter.py ├── data │ ├── cornertest.jpg │ └── mlp.model └── tests │ ├── LinearRegression.py │ ├── MLP.py │ └── SobelFilter.py ├── include ├── hecate │ ├── CMakeLists.txt │ ├── Conversion │ │ ├── CKKSCommon │ │ │ └── PolyTypeConverter.h │ │ ├── CKKSToCKKS │ │ │ └── UpscaleToMulcp.h │ │ ├── CMakeLists.txt │ │ ├── EarthToCKKS │ │ │ └── EarthToCKKS.h │ │ ├── Passes.h │ │ └── Passes.td │ ├── Dialect │ │ ├── CKKS │ │ │ ├── CMakeLists.txt │ │ │ ├── IR │ │ │ │ ├── CKKSOps.h │ │ │ │ ├── CKKSOps.td │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── PolyTypeInterface.h │ │ │ │ └── PolyTypeInterface.td │ │ │ └── Transforms │ │ │ │ ├── CMakeLists.txt │ │ │ │ ├── Passes.h │ │ │ │ └── Passes.td │ │ ├── CMakeLists.txt │ │ └── Earth │ │ │ ├── Analysis │ │ │ ├── AutoDifferentiation.h │ │ │ └── ScaleManagementUnit.h │ │ │ ├── CMakeLists.txt │ │ │ ├── IR │ │ │ ├── CMakeLists.txt │ │ │ ├── EarthCanonicalizer.td │ │ │ ├── EarthOps.h │ │ │ ├── EarthOps.td │ │ │ ├── ForwardManagementInterface.h │ │ │ ├── ForwardManagementInterface.td │ │ │ ├── HEParameterInterface.h │ │ │ └── HEParameterInterface.td │ │ │ └── Transforms │ │ │ ├── CMakeLists.txt │ │ │ ├── Common.h │ │ │ ├── Passes.h │ │ │ └── Passes.td │ └── Support │ │ ├── HEVMHeader.h │ │ └── Support.h └── nlohmann │ ├── json.hpp │ └── json_fwd.hpp ├── install.sh ├── lib ├── CMakeLists.txt ├── Conversion │ ├── CKKSCommon │ │ ├── CMakeLists.txt │ │ └── PolyTypeConverter.cpp │ ├── CKKSToCKKS │ │ ├── CMakeLists.txt │ │ └── UpscaleToMulcp.cpp │ ├── CMakeLists.txt │ └── EarthToCKKS │ │ ├── CMakeLists.txt │ │ └── EarthToCKKS.cpp ├── Dialect │ ├── CKKS │ │ ├── CMakeLists.txt │ │ ├── IR │ │ │ ├── CKKSDialect.cpp │ │ │ └── CMakeLists.txt │ │ └── Transforms │ │ │ ├── CMakeLists.txt │ │ │ ├── EmitHEVM.cpp │ │ │ 
├── RemoveLevel.cpp │ │ │ └── ReuseBuffer.cpp │ ├── CMakeLists.txt │ └── Earth │ │ ├── Analysis │ │ ├── AutoDifferentiation.cpp │ │ ├── CMakeLists.txt │ │ └── ScaleManagementUnit.cpp │ │ ├── CMakeLists.txt │ │ ├── IR │ │ ├── CMakeLists.txt │ │ ├── EarthDialect.cpp │ │ ├── ForwardManagementInterface.cpp │ │ └── HEParameterInterface.cpp │ │ └── Transforms │ │ ├── CMakeLists.txt │ │ ├── Common.cpp │ │ ├── ELASMExplorer.cpp │ │ ├── EarlyModswitch.cpp │ │ ├── ElideConstant.cpp │ │ ├── ErrorEstimator.cpp │ │ ├── LatencyEstimator.cpp │ │ ├── PrivatizeConstant.cpp │ │ ├── ProactiveRescaling.cpp │ │ ├── SMUChecker.cpp │ │ ├── SMUEmbedding.cpp │ │ ├── SNRRescaling.cpp │ │ ├── ScaleManagementScheduler.cpp │ │ ├── UpscaleBubbling.cpp │ │ └── WaterlineRescaling.cpp └── Runtime │ ├── CMakeLists.txt │ └── SEAL_HEVM.cpp ├── python └── hecate │ ├── hecate │ ├── __init__.py │ ├── expr.py │ └── runner.py │ └── setup.py ├── requirements.txt ├── tools ├── CMakeLists.txt ├── frontend.cpp └── optimizer.cpp └── versions.txt /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: LLVM 4 | AccessModifierOffset: -2 5 | AlignAfterOpenBracket: Align 6 | AlignConsecutiveMacros: false 7 | AlignConsecutiveAssignments: false 8 | AlignConsecutiveDeclarations: false 9 | AlignEscapedNewlines: Right 10 | AlignOperands: true 11 | AlignTrailingComments: true 12 | AllowAllArgumentsOnNextLine: true 13 | AllowAllConstructorInitializersOnNextLine: true 14 | AllowAllParametersOfDeclarationOnNextLine: true 15 | AllowShortBlocksOnASingleLine: Never 16 | AllowShortCaseLabelsOnASingleLine: false 17 | AllowShortFunctionsOnASingleLine: All 18 | AllowShortLambdasOnASingleLine: All 19 | AllowShortIfStatementsOnASingleLine: Never 20 | AllowShortLoopsOnASingleLine: false 21 | AlwaysBreakAfterDefinitionReturnType: None 22 | AlwaysBreakAfterReturnType: None 23 | AlwaysBreakBeforeMultilineStrings: false 24 | 
AlwaysBreakTemplateDeclarations: MultiLine 25 | BinPackArguments: true 26 | BinPackParameters: true 27 | BraceWrapping: 28 | AfterCaseLabel: false 29 | AfterClass: false 30 | AfterControlStatement: false 31 | AfterEnum: false 32 | AfterFunction: false 33 | AfterNamespace: false 34 | AfterObjCDeclaration: false 35 | AfterStruct: false 36 | AfterUnion: false 37 | AfterExternBlock: false 38 | BeforeCatch: false 39 | BeforeElse: false 40 | IndentBraces: false 41 | SplitEmptyFunction: true 42 | SplitEmptyRecord: true 43 | SplitEmptyNamespace: true 44 | BreakBeforeBinaryOperators: None 45 | BreakBeforeBraces: Attach 46 | BreakBeforeInheritanceComma: false 47 | BreakInheritanceList: BeforeColon 48 | BreakBeforeTernaryOperators: true 49 | BreakConstructorInitializersBeforeComma: false 50 | BreakConstructorInitializers: BeforeColon 51 | BreakAfterJavaFieldAnnotations: false 52 | BreakStringLiterals: true 53 | ColumnLimit: 80 54 | CommentPragmas: '^ IWYU pragma:' 55 | CompactNamespaces: false 56 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 57 | ConstructorInitializerIndentWidth: 4 58 | ContinuationIndentWidth: 4 59 | Cpp11BracedListStyle: true 60 | DeriveLineEnding: true 61 | DerivePointerAlignment: false 62 | DisableFormat: false 63 | ExperimentalAutoDetectBinPacking: false 64 | FixNamespaceComments: true 65 | ForEachMacros: 66 | - foreach 67 | - Q_FOREACH 68 | - BOOST_FOREACH 69 | IncludeBlocks: Preserve 70 | IncludeCategories: 71 | - Regex: '^"(llvm|llvm-c|clang|clang-c)/' 72 | Priority: 2 73 | SortPriority: 0 74 | - Regex: '^(<|"(gtest|gmock|isl|json)/)' 75 | Priority: 3 76 | SortPriority: 0 77 | - Regex: '.*' 78 | Priority: 1 79 | SortPriority: 0 80 | IncludeIsMainRegex: '(Test)?$' 81 | IncludeIsMainSourceRegex: '' 82 | IndentCaseLabels: false 83 | IndentGotoLabels: true 84 | IndentPPDirectives: None 85 | IndentWidth: 2 86 | IndentWrappedFunctionNames: false 87 | JavaScriptQuotes: Leave 88 | JavaScriptWrapImports: true 89 | KeepEmptyLinesAtTheStartOfBlocks: 
true 90 | MacroBlockBegin: '' 91 | MacroBlockEnd: '' 92 | MaxEmptyLinesToKeep: 1 93 | NamespaceIndentation: None 94 | ObjCBinPackProtocolList: Auto 95 | ObjCBlockIndentWidth: 2 96 | ObjCSpaceAfterProperty: false 97 | ObjCSpaceBeforeProtocolList: true 98 | PenaltyBreakAssignment: 2 99 | PenaltyBreakBeforeFirstCallParameter: 19 100 | PenaltyBreakComment: 300 101 | PenaltyBreakFirstLessLess: 120 102 | PenaltyBreakString: 1000 103 | PenaltyBreakTemplateDeclaration: 10 104 | PenaltyExcessCharacter: 1000000 105 | PenaltyReturnTypeOnItsOwnLine: 60 106 | PointerAlignment: Right 107 | ReflowComments: true 108 | SortIncludes: true 109 | SortUsingDeclarations: true 110 | SpaceAfterCStyleCast: false 111 | SpaceAfterLogicalNot: false 112 | SpaceAfterTemplateKeyword: true 113 | SpaceBeforeAssignmentOperators: true 114 | SpaceBeforeCpp11BracedList: false 115 | SpaceBeforeCtorInitializerColon: true 116 | SpaceBeforeInheritanceColon: true 117 | SpaceBeforeParens: ControlStatements 118 | SpaceBeforeRangeBasedForLoopColon: true 119 | SpaceInEmptyBlock: false 120 | SpaceInEmptyParentheses: false 121 | SpacesBeforeTrailingComments: 1 122 | SpacesInAngles: false 123 | SpacesInConditionalStatement: false 124 | SpacesInContainerLiterals: true 125 | SpacesInCStyleCastParentheses: false 126 | SpacesInParentheses: false 127 | SpacesInSquareBrackets: false 128 | SpaceBeforeSquareBrackets: false 129 | Standard: Latest 130 | StatementMacros: 131 | - Q_UNUSED 132 | - QT_REQUIRE_VERSION 133 | TabWidth: 8 134 | UseCRLF: false 135 | UseTab: Never 136 | ... 
137 | 138 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | build-debug 3 | .cache 4 | .root 5 | .undodir 6 | *.sw* 7 | *.un~ 8 | dot 9 | *.bc 10 | *.mlir 11 | *.hevm 12 | *.o 13 | *.so 14 | *.txt 15 | *.dot 16 | *.csv 17 | !CMakeLists.txt 18 | *__pycache__ 19 | *.cst 20 | *.nsys-rep 21 | *.sqlite 22 | !.gitignore 23 | !requirements.txt 24 | !versions.txt 25 | python/hecate/dist 26 | python/hecate/hecate.egg-info 27 | examples/optimized 28 | examples/traced 29 | Session.vim 30 | .venv 31 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.12) # NOTE(review): README.md lists cmake >= 3.22.1 as a requirement; confirm whether this minimum should be raised to match. 2 | 3 | 4 | # VERSION SETTING: each component below keeps any value already defined (e.g. passed with -D on the command line). 5 | 6 | if (NOT DEFINED HECATE_VERSION_MAJOR) 7 | set(HECATE_VERSION_MAJOR 0) 8 | endif () 9 | 10 | if (NOT DEFINED HECATE_VERSION_MINOR) 11 | set(HECATE_VERSION_MINOR 1) 12 | endif () 13 | 14 | if (NOT DEFINED HECATE_VERSION_PATCH) 15 | set(HECATE_VERSION_PATCH 0) 16 | endif () 17 | 18 | if (NOT DEFINED HECATE_VERSION_SUFFIX) 19 | set(HECATE_VERSION_SUFFIX git) 20 | endif () 21 | 22 | if (NOT PACKAGE_VERSION) 23 | set(PACKAGE_VERSION 24 | "${HECATE_VERSION_MAJOR}.${HECATE_VERSION_MINOR}.${HECATE_VERSION_PATCH}${HECATE_VERSION_SUFFIX}") 25 | endif() 26 | 27 | # PROJECT SETTING 28 | 29 | project(HECATE 30 | VERSION ${HECATE_VERSION_MAJOR}.${HECATE_VERSION_MINOR}.${HECATE_VERSION_PATCH} 31 | LANGUAGES CXX C) 32 | 33 | 34 | # CXX BUILD SETTING 35 | 36 | set(CMAKE_CXX_STANDARD 17) 37 | set(CMAKE_CXX_STANDARD_REQUIRED True) 38 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 39 | set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--no-undefined") # Make shared-library links fail on unresolved symbols. NOTE(review): set() replaces any CMAKE_SHARED_LINKER_FLAGS given on the command line; consider appending instead. 40 | 41 | if (NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) 42 | message(STATUS "No build type selected, default to Release") 43 | set(CMAKE_BUILD_TYPE "Release" CACHE
STRING "Build type (default Release)" FORCE) 44 | endif() 45 | 46 | string(TOUPPER "${CMAKE_BUILD_TYPE}" uppercase_CMAKE_BUILD_TYPE) 47 | 48 | if (CMAKE_BUILD_TYPE AND 49 | NOT uppercase_CMAKE_BUILD_TYPE MATCHES "^(DEBUG|RELEASE|RELWITHDEBINFO|MINSIZEREL)$") 50 | message(FATAL_ERROR "Invalid value for CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}") 51 | endif() 52 | 53 | set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 54 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) 55 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 56 | # set(CMAKE_BUILD_TYPE Debug) 57 | 58 | # PACKAGE SETTING 59 | find_package(MLIR REQUIRED CONFIG) 60 | find_package(SEAL 4.0 REQUIRED CONFIG) 61 | 62 | # MLIR SETTING 63 | 64 | list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") 65 | list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") 66 | list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake/modules") 67 | include(TableGen) 68 | include(AddLLVM) 69 | include(AddMLIR) 70 | include(AddHecate) # cmake/modules/AddHecate.cmake: defines add_hecate_dialect, add_hecate_pattern and add_hecate_interface 71 | include(HandleLLVMOptions) 72 | 73 | include_directories(${LLVM_INCLUDE_DIRS}) 74 | include_directories(${MLIR_INCLUDE_DIRS}) 75 | include_directories(${PROJECT_SOURCE_DIR}/include) 76 | include_directories(${PROJECT_BINARY_DIR}/include) 77 | link_directories(${LLVM_BUILD_LIBRARY_DIR}) 78 | add_definitions(${LLVM_DEFINITIONS}) 79 | 80 | set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) 81 | set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) 82 | set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) 83 | 84 | # HECATE DIRECTORY 85 | 86 | add_compile_options(-fno-rtti) # NOTE(review): together with -fno-exceptions below, presumably mirrors the LLVM/MLIR build (no RTTI/EH by default) so Hecate links cleanly against it; confirm against the installed MLIR. 87 | add_compile_options(-fno-exceptions) 88 | add_compile_options(-fPIC) 89 | add_definitions (-DMEM_CHECK) 90 | 91 | add_subdirectory (include/hecate) 92 | add_subdirectory (lib) 93 | add_subdirectory (tools) 94 | 95 | 96 | 97 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 |
Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hecate-compiler 2 | Hecate (Homomorphic Encryption Compiler for Approximate TEnsor computation) is an optimizing compiler for the CKKS FHE scheme, built by the Compiler Optimization Research Laboratory (Corelab) at Yonsei University. 3 | Hecate is built on top of the Multi-Level Intermediate Representation (MLIR) compiler framework. 4 | We aim to support privacy-preserving machine learning and deep learning applications.
5 | 6 | 7 | * [Installation](#installation) 8 | + [Requirements](#requirements) 9 | + [Install MLIR](#install-mlir) 10 | + [Install SEAL](#install-seal) 11 | + [Build Hecate](#build-hecate) 12 | + [Configure Hecate](#configure-hecate) 13 | + [Install Hecate Python Binding](#install-hecate-python-binding) 14 | * [Tutorial](#tutorial) 15 | + [Trace the example python file to Encrypted ARiTHmetic IR](#trace-the-example-python-file-to-encrypted-arithmetic-ir) 16 | + [Compile the traced Earth IR](#compile-the-traced-earth-ir) 17 | + [Test the optimized code](#test-the-optimized-code) 18 | + [One-liner for compilation and testing](#one-liner-for-compilation-and-testing) 19 | * [Papers](#papers) 20 | * [Citations](#citations) 21 | 22 | ## Installation 23 | 24 | ### Requirements 25 | ``` 26 | Ninja 27 | git 28 | cmake >= 3.22.1 29 | python >= 3.10 30 | clang,clang++ >= 14.0.0 31 | ``` 32 | 33 | ### Install MLIR 34 | ```bash 35 | git clone https://github.com/llvm/llvm-project.git 36 | cd llvm-project 37 | git checkout llvmorg-16.0.0 38 | cmake -GNinja -Bbuild \ 39 | -DCMAKE_C_COMPILER=clang \ 40 | -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release \ 41 | -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_INSTALL_UTILS=ON \ 42 | -DLLVM_TARGETS_TO_BUILD=host \ 43 | llvm 44 | cmake --build build 45 | sudo cmake --install build 46 | cd .. 47 | ``` 48 | #### Optional: Install to a custom directory (to maintain multiple versions or a debug build) 49 | ```bash 50 | git clone https://github.com/llvm/llvm-project.git 51 | cd llvm-project 52 | git checkout llvmorg-16.0.0 53 | cmake -GNinja -Bbuild \ 54 | -DCMAKE_C_COMPILER=clang \ 55 | -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_BUILD_TYPE=Release \ 56 | -DLLVM_ENABLE_PROJECTS=mlir -DLLVM_INSTALL_UTILS=ON \ 57 | -DLLVM_TARGETS_TO_BUILD=host -DCMAKE_INSTALL_PREFIX=<mlir-install-dir> \ 58 | llvm 59 | cmake --build build 60 | sudo cmake --install build 61 | cd ..
62 | ``` 63 | 64 | ### Install SEAL 65 | ```bash 66 | git clone https://github.com/microsoft/SEAL.git 67 | cd SEAL 68 | git checkout 4.0.0 69 | cmake -S . -B build 70 | cmake --build build 71 | sudo cmake --install build 72 | cd .. 73 | ``` 74 | #### Optional: Install to a custom directory (to maintain multiple versions or a debug build) 75 | ```bash 76 | git clone https://github.com/microsoft/SEAL.git 77 | cd SEAL 78 | git checkout 4.0.0 79 | cmake -S . -B build -DCMAKE_INSTALL_PREFIX=<seal-install-dir> 80 | cmake --build build 81 | sudo cmake --install build 82 | cd .. 83 | ``` 84 | ### Build Hecate 85 | ```bash 86 | git clone <hecate-repository-url> 87 | cd <hecate-repository-dir> 88 | cmake -S . -B build 89 | cmake --build build 90 | ``` 91 | #### Optional: Build against MLIR and SEAL installed in custom directories 92 | ```bash 93 | git clone <hecate-repository-url> 94 | cd <hecate-repository-dir> 95 | cmake -S . -B build -DMLIR_ROOT=<mlir-install-dir> -DSEAL_ROOT=<seal-install-dir> 96 | cmake --build build 97 | ``` 98 | ### Configure Hecate 99 | ```bash 100 | python3 -m venv .venv 101 | source .venv/bin/activate 102 | source config.sh 103 | ``` 104 | 105 | ### Install Hecate Python Binding 106 | ```bash 107 | pip install -r requirements.txt 108 | ./install.sh 109 | ``` 110 | 111 | ## Tutorial 112 | 113 | ### Trace the example python file to Encrypted ARiTHmetic IR 114 | 115 | ```bash 116 | hc-trace <benchmark> 117 | ``` 118 | e.g., 119 | ```bash 120 | hc-trace MLP 121 | ``` 122 | 123 | ### Compile the traced Earth IR 124 | 125 | ```bash 126 | hopts <pass> <waterline> <benchmark> 127 | ``` 128 | e.g., 129 | ```bash 130 | hopts elasm 30 MLP 131 | ``` 132 | 133 | ### Test the optimized code 134 | ```bash 135 | hc-test <pass> <waterline> <benchmark> 136 | ``` 137 | e.g., 138 | ```bash 139 | hc-test elasm 30 MLP 140 | ``` 141 | 142 | This command prints output like the following: 143 | ``` 144 | 1.810851535 145 | 9.63624118628367e-06 146 | ``` 147 | 148 | The first line shows the wall-clock time of the FHE execution. 149 | The second line shows the RMS of the resulting error. 150 | 151 | ### One-liner for compilation and testing 152 | ```bash 153 | hcot <pass> <waterline> <benchmark> 154 | ``` 155 | To also print pass timings: 156 |
```bash 157 | hcott <pass> <waterline> <benchmark> 158 | ``` 159 | 160 | ## DaCapo Source Code 161 | DaCapo, a compiler that supports automatic bootstrapping placement for CKKS, is maintained in a separate repository 162 | due to licensing issues. You can find more details in the [DaCapo Repository](https://github.com/corelab-src/dacapo). 163 | 164 | ## Papers 165 | **DaCapo: Automatic Bootstrapping Management for Efficient Fully Homomorphic Encryption**\ 166 | Seonyoung Cheon, Yongwoo Lee, Ju Min Lee, Dongkwan Kim, Sunchul Jung, Taekyung Kim, Dongyoon Lee, and Hanjun Kim 167 | *33rd USENIX Security Symposium (USENIX Security)*, August 2024. 168 | [[Prepublication](https://www.usenix.org/system/files/sec24summer-prepub-336-cheon.pdf)] 169 | 170 | **ELASM: Error-Latency-Aware Scale Management for Fully Homomorphic Encryption** [[abstract](https://www.usenix.org/conference/usenixsecurity23/presentation/lee-yongwoo)] 171 | Yongwoo Lee, Seonyoung Cheon, Dongkwan Kim, Dongyoon Lee, and Hanjun Kim 172 | *32nd USENIX Security Symposium (USENIX Security)*, August 2023. 173 | [[Prepublication](https://www.usenix.org/system/files/usenixsecurity23-lee-yongwoo.pdf)] 174 | 175 | **HECATE: Performance-Aware Scale Optimization for Homomorphic Encryption Compiler**\[[IEEE Xplore](http://doi.org/10.1109/CGO53902.2022.9741265)] 176 | Yongwoo Lee, Seonyeong Heo, Seonyoung Cheon, Shinnung Jeong, Changsu Kim, Eunkyung Kim, Dongyoon Lee, and Hanjun Kim 177 | *Proceedings of the 2022 International Symposium on Code Generation and Optimization (CGO)*, April 2022.
178 | [[Prepublication](http://corelab.or.kr/Pubs/cgo22_hecate.pdf)] 179 | 180 | ## Citations 181 | ```bibtex 182 | @INPROCEEDINGS{lee:hecate:cgo, 183 | author={Lee, Yongwoo and Heo, Seonyeong and Cheon, Seonyoung and Jeong, Shinnung and Kim, Changsu and Kim, Eunkyung and Lee, Dongyoon and Kim, Hanjun}, 184 | booktitle={2022 IEEE/ACM International Symposium on Code Generation and Optimization (CGO)}, 185 | title={HECATE: Performance-Aware Scale Optimization for Homomorphic Encryption Compiler}, 186 | year={2022}, 187 | volume={}, 188 | number={}, 189 | pages={193-204}, 190 | doi={10.1109/CGO53902.2022.9741265}} 191 | ``` 192 | ```bibtex 193 | @INPROCEEDINGS{lee:elasm:sec, 194 | title={{ELASM}: Error-Latency-Aware Scale Management for Fully Homomorphic Encryption}, 195 | author={Lee, Yongwoo and Cheon, Seonyoung and Kim, Dongkwan and Lee, Dongyoon and Kim, Hanjun}, 196 | booktitle={{32nd} USENIX Security Symposium (USENIX Security 23)}, 197 | year={2023}, 198 | address = {Anaheim, CA}, 199 | publisher = {USENIX Association}, 200 | month = aug 201 | } 202 | ``` 203 | ```bibtex 204 | @INPROCEEDINGS{cheon:dacapo:sec, 205 | title={{DaCapo}: Automatic Bootstrapping Management for Efficient Fully Homomorphic Encryption}, 206 | author={Cheon, Seonyoung and Lee, Yongwoo and Kim, Dongkwan and Lee, Ju Min and Jung, Sunchul and Kim, Taekyung and Lee, Dongyoon and Kim, Hanjun}, 207 | booktitle={{33rd} USENIX Security Symposium (USENIX Security 24)}, 208 | year={2024}, 209 | address = {Philadelphia, CA}, 210 | publisher = {USENIX Association}, 211 | month = aug 212 | } 213 | ``` 214 | -------------------------------------------------------------------------------- /cmake/modules/AddHecate.cmake: -------------------------------------------------------------------------------- 1 | function(add_hecate_dialect dialect dialect_namespace) # Run TableGen on ${dialect}.td to generate op, typedef and dialect declarations/definitions (${dialect}*.h.inc / *.cpp.inc) under the IncGen target Hecate${dialect}IncGen, which mlir-headers depends on. 2 | set(LLVM_TARGET_DEFINITIONS ${dialect}.td) 3 | mlir_tablegen(${dialect}.h.inc -gen-op-decls) 4 | mlir_tablegen(${dialect}.cpp.inc -gen-op-defs)
5 | mlir_tablegen(${dialect}Types.h.inc -gen-typedef-decls -typedefs-dialect=${dialect_namespace}) 6 | mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs -typedefs-dialect=${dialect_namespace}) 7 | mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace}) 8 | mlir_tablegen(${dialect}Dialect.cpp.inc -gen-dialect-defs -dialect=${dialect_namespace}) 9 | add_public_tablegen_target(Hecate${dialect}IncGen) 10 | add_dependencies(mlir-headers Hecate${dialect}IncGen) 11 | endfunction() 12 | 13 | function(add_hecate_pattern pattern dialect_namespace) # Generate declarative rewrite patterns (${pattern}Pattern.inc via -gen-rewriters) from ${pattern}.td under the IncGen target Hecate${pattern}IncGen, which mlir-headers depends on. 14 | set(LLVM_TARGET_DEFINITIONS ${pattern}.td) 15 | mlir_tablegen(${pattern}Pattern.inc -gen-rewriters -dialect=${dialect_namespace}) 16 | add_public_tablegen_target(Hecate${pattern}IncGen) 17 | add_dependencies(mlir-headers Hecate${pattern}IncGen) 18 | endfunction() 19 | 20 | # Generate op-interface and type-interface declarations/definitions (${interface}.h.inc/.cpp.inc and ${interface}Types.h.inc/.cpp.inc) from ${interface}.td under the IncGen target Hecate${interface}IncGen, which mlir-generic-headers depends on. 21 | function(add_hecate_interface interface) 22 | set(LLVM_TARGET_DEFINITIONS ${interface}.td) 23 | mlir_tablegen(${interface}.h.inc -gen-op-interface-decls) 24 | mlir_tablegen(${interface}.cpp.inc -gen-op-interface-defs) 25 | mlir_tablegen(${interface}Types.h.inc -gen-type-interface-decls) 26 | mlir_tablegen(${interface}Types.cpp.inc -gen-type-interface-defs) 27 | add_public_tablegen_target(Hecate${interface}IncGen) 28 | add_dependencies(mlir-generic-headers Hecate${interface}IncGen) 29 | endfunction() 30 | -------------------------------------------------------------------------------- /cmake/modules/CMakeLists.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corelab-src/elasm/3c37c11b29ca480525bb6681e0254bdf90029425/cmake/modules/CMakeLists.txt -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "runtime" : "SEAL-HEVM", 3 | "rescalingFactor" : 60, 4 | "polynomialDegree" :
32768, 5 | "levelLowerBound" : 1, 6 | "levelUpperBound" : 13, 7 | "bootstrapLevelLowerBound" : -1, 8 | "bootstrapLevelUpperBound" : -1, 9 | "latencyTable" : { 10 | "earth.rotate_single" : 11 | [ 12 | 3828, 7966, 13584, 20933, 28832, 13 | 40137, 51080, 64134, 78216, 94012, 14 | 110653, 131245, 150699 15 | ], 16 | "earth.rescale_single" : [ 17 | 1926, 3119, 4525, 5706, 6901, 18 | 8198, 9570, 10781, 12339, 13339, 19 | 14488, 17418, 17418 20 | ], 21 | "earth.modswitch_single" : [ 22 | 48, 86, 156, 208, 286, 23 | 315, 391, 457, 536, 622, 24 | 717, 895 25 | ], 26 | "earth.add_single" : [ 27 | 50, 98, 153, 209, 269, 28 | 335, 409, 472, 561, 638, 29 | 709, 800, 2650 30 | ], 31 | "earth.add_double" : [ 32 | 85, 204, 250, 339, 421, 33 | 531, 615, 723, 827, 1021, 34 | 1093, 1251, 3120 35 | ], 36 | "earth.mul_single" : [ 37 | 211, 421, 642, 853, 1120, 38 | 1260, 1509, 1726, 2031, 2270, 39 | 2518, 2918, 4990 40 | ], 41 | "earth.mul_double" : [ 42 | 4363, 9172, 15658, 23517, 33974, 43 | 43235, 56611, 68785, 85137, 101308, 44 | 119907, 137953, 160732 45 | ] 46 | }, 47 | "noiseTable" : { 48 | "earth.rotate_single" : [ 49 | 1243767652.125024, 50 | 3053517076.303607, 51 | 4202768329.642825, 52 | 5839542263.660615, 53 | 6982415435.517867, 54 | 9066416705.357107, 55 | 10703926878.339247, 56 | 13029700700.035797, 57 | 14563337546.892847, 58 | 15257531833.982080, 59 | 17555923920.678551, 60 | 18115173373.499962, 61 | 19223819509.410683 62 | ], 63 | "earth.rescale_single" : [ 64 | 29041012.461168, 65 | 28989829.196367, 66 | 30513059.012577, 67 | 29249886.451856, 68 | 29939898.437991, 69 | 30011344.885873, 70 | 29425438.800167, 71 | 29278080.626328, 72 | 29791926.468794, 73 | 29763853.483675, 74 | 29026441.533765, 75 | 29574204.309626, 76 | 29574204.309626 77 | ], 78 | "earth.mul_double" : [ 79 | 912420980.422666, 80 | 1626939697.436701, 81 | 2841237898.166427, 82 | 5216678287.872595, 83 | 4656426194.539523, 84 | 6498252289.502644, 85 | 4454104710.442873, 86 | 4809761973.565252, 87 | 
4954123191.027460, 88 | 7101044074.407553, 89 | 8247373646.850708, 90 | 6977441673.353531, 91 | 8918871036.187374 92 | ] 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /config.sh: -------------------------------------------------------------------------------- 1 | export HECATE=$( cd -- "$( dirname -- "$BASH_SOURCE[0]" )" &> /dev/null && pwd ) 2 | 3 | alias hopt=$HECATE/build/bin/hecate-opt 4 | alias hopt-debug=$HECATE/build-debug/bin/hecate-opt 5 | 6 | mkdir -p $HECATE/examples/traced 7 | mkdir -p $HECATE/examples/optimized/eva 8 | mkdir -p $HECATE/examples/optimized/elasm 9 | 10 | build-hopt()( 11 | cd $HECATE/build 12 | ninja 13 | ) 14 | 15 | build-hoptd()( 16 | cd $HECATE/build-debug 17 | ninja 18 | ) 19 | 20 | hc-trace()( 21 | cd $HECATE/examples 22 | python3 $HECATE/examples/benchmarks/$1.py 23 | ) 24 | 25 | hc-test()( 26 | cd $HECATE/examples 27 | python3 $HECATE/examples/tests/$3.py $1 $2 28 | ) 29 | 30 | 31 | hopt-print(){ 32 | hopt --$1 --ckks-config="$HECATE/config.json" --waterline=$2 --enable-debug-printer $HECATE/examples/traced/$3.mlir --mlir-print-debuginfo --mlir-pretty-debuginfo --mlir-print-local-scope --mlir-timing -o $HECATE/examples/optimized/$1/$3.$2.mlir 33 | } 34 | 35 | hopt-debug-print(){ 36 | hopt-debug --$1 --ckks-config="$HECATE/config.json" --waterline=$2 --enable-debug-printer $HECATE/examples/traced/$3.mlir --mlir-print-debuginfo --mlir-pretty-debuginfo --mlir-print-local-scope --mlir-disable-threading --mlir-timing --mlir-print-ir-after-failure 37 | } 38 | 39 | hopt-debug-print-all(){ 40 | hopt-debug --$1 --ckks-config="$HECATE/config.json" --waterline=$2 --enable-debug-printer $HECATE/examples/traced/$3.mlir --mlir-print-debuginfo --mlir-pretty-debuginfo --mlir-print-local-scope --mlir-disable-threading --mlir-timing --mlir-print-ir-after-failure --debug 41 | } 42 | 43 | hopt-timing-only(){ 44 | hopt --$1 --ckks-config="$HECATE/config.json" --waterline=$2 
$HECATE/examples/traced/$3.mlir --mlir-timing -o $HECATE/examples/optimized/$1/$3.$2.mlir 45 | } 46 | 47 | hopt-silent(){ 48 | hopt --$1 --ckks-config="$HECATE/config.json" --waterline=$2 $HECATE/examples/traced/$3.mlir -o $HECATE/examples/optimized/$1/$3.$2.mlir 49 | } 50 | 51 | hc-opt-test() { 52 | hopt-silent $1 $2 $3 && hc-test $1 $2 $3 53 | } 54 | 55 | hc-opt-test-timing() { 56 | hopt-timing-only $1 $2 $3 && hc-test $1 $2 $3 57 | } 58 | 59 | 60 | alias hoptd=hopt-debug-print 61 | alias hopta=hopt-debug-print-all 62 | alias hopts=hopt-silent 63 | alias hoptt=hopt-timing-only 64 | alias hoptp=hopt-print 65 | alias hcot=hc-opt-test 66 | alias hcott=hc-opt-test-timing 67 | -------------------------------------------------------------------------------- /examples/benchmarks/LinearRegression.py: -------------------------------------------------------------------------------- 1 | import hecate as hc 2 | import sys 3 | 4 | 5 | def sum_elements(data): 6 | for i in range(12): 7 | rot = data.rotate(1<<(11-i)) 8 | data = data +rot 9 | 10 | return data 11 | 12 | @hc.func("c,c") 13 | def LinearRegression(x_data, y_data) : 14 | W = hc.Plain([1.0]) 15 | b = hc.Plain([0.0]) 16 | 17 | epochs = 2 18 | learning_rate = hc.Plain([-0.01]) 19 | 20 | for i in range(epochs): 21 | xW = x_data*W 22 | xWb = xW + b 23 | 24 | error = xWb - y_data 25 | 26 | errX = error * x_data 27 | meanErrX = errX * hc.Plain([1/2048]) 28 | gradW = sum_elements(meanErrX) 29 | meanErr = error * hc.Plain([1/2048]) 30 | gradb = sum_elements(meanErr) 31 | Wup = learning_rate * gradW 32 | bup = learning_rate * gradb 33 | W = W + Wup 34 | b = b + bup 35 | 36 | return W, b 37 | 38 | 39 | 40 | modName = hc.save("traced", "traced") 41 | print (modName) 42 | 43 | 44 | -------------------------------------------------------------------------------- /examples/benchmarks/MLP.py: -------------------------------------------------------------------------------- 1 | import hecate as hc 2 | import numpy as np 3 | import sys 
4 | import torch 5 | 6 | 7 | def get_flat_weight(file_name): 8 | x = [] 9 | f = open(file_name,'r') 10 | for y in f.read().split('\n'): 11 | if not y == "": 12 | x.append(float(y)) 13 | return x 14 | 15 | #input 784, output 100 16 | def input_to_layer_MNIST(image, W): 17 | res = [0.00000 for i in range(100)] 18 | new_W = [[0.00000 for j in range(800)] for i in range (100)] 19 | for n in range(100) : 20 | for c in range(8) : 21 | for k in range(100) : 22 | index = c*100 + k 23 | if(index < 784) : 24 | if (index+n >= 784) : 25 | new_W[n][800-n+(index+n)%784] = W[100-n+(index+n)%784][(index+n)%784] 26 | else : 27 | new_W[n][index] = W[k][(index + n) % 784 ] 28 | new_W = [hc.Plain(Win) for Win in new_W] 29 | for n in range(100) : 30 | rot = image.rotate(n) 31 | mul = rot * new_W[n] 32 | result = mul if n == 0 else result + mul 33 | 34 | m = 800 35 | res = result 36 | for i in range(3): 37 | m = m >> 1 38 | temp = res.rotate(m) 39 | res = res + temp 40 | 41 | return res 42 | 43 | #input 100, output 10 44 | def layer_to_output_MNIST(image, W): 45 | res = [0.0 for i in range(10)] 46 | new_W = [[0.0 for j in range(100)] for i in range (10)] 47 | for n in range(10) : 48 | for c in range(10) : 49 | for k in range(10) : 50 | index = c *10 + k 51 | if(c * 10 + k < 100) : 52 | new_W[n][index] = W[k][(index + n) % 100 ] 53 | new_W = [hc.Plain(Win) for Win in new_W] 54 | for n in range(10): 55 | rot = image.rotate(n) 56 | mul = rot * new_W[n] 57 | result = mul if n == 0 else result + mul 58 | 59 | temp = result.rotate(50) 60 | res = result + temp 61 | for i in range(5): 62 | temp = res.rotate(i*10) 63 | res1 = temp if i == 0 else res1 + temp 64 | return res1 65 | 66 | 67 | 68 | @hc.func("c") 69 | def MLP(image) : 70 | from pathlib import Path 71 | 72 | source_path = Path(__file__).resolve() 73 | source_dir = source_path.parent 74 | model = torch.load(str(source_dir)+"/../data/mlp.model", map_location=torch.device('cpu')) 75 | W1 = model["linear1.weight"].cpu().detach().numpy() 76 
| b1 = model["linear1.bias"].cpu().detach().numpy() 77 | W2 = model["linear2.weight"].cpu().detach().numpy() 78 | b2 = model["linear2.bias"].cpu().detach().numpy() 79 | h1 = input_to_layer_MNIST(image, W1) + hc.Plain(list(b1)) 80 | h = h1 * h1 81 | res = layer_to_output_MNIST(h, W2) + hc.Plain(list(b2)) 82 | return res 83 | 84 | 85 | 86 | modName = hc.save("traced", "traced") 87 | print (modName) 88 | 89 | 90 | -------------------------------------------------------------------------------- /examples/benchmarks/SobelFilter.py: -------------------------------------------------------------------------------- 1 | import hecate as hc 2 | import sys 3 | import numpy as np 4 | 5 | 6 | def roll(a, i) : 7 | return a.rotate(i) 8 | 9 | @hc.func("c") 10 | def SobelFilter(lena_array) : 11 | 12 | F = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]] 13 | Ix = 0 14 | Iy = 0 15 | for i in range(3) : 16 | for j in range(3) : 17 | rot = roll (lena_array, i*64 +j) 18 | h = rot * F[i][j] 19 | v = rot * F[j][i] 20 | Ix = Ix + h 21 | Iy = Iy + v 22 | Ix2 = Ix * Ix 23 | Iy2 = Iy * Iy 24 | c = Ix2 + Iy2 25 | d = c*c*c *0.173 - c * c *1.098 + c * 2.214 26 | return d 27 | 28 | 29 | modName = hc.save("traced", "traced") 30 | print (modName) 31 | -------------------------------------------------------------------------------- /examples/data/cornertest.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corelab-src/elasm/3c37c11b29ca480525bb6681e0254bdf90029425/examples/data/cornertest.jpg -------------------------------------------------------------------------------- /examples/data/mlp.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/corelab-src/elasm/3c37c11b29ca480525bb6681e0254bdf90029425/examples/data/mlp.model -------------------------------------------------------------------------------- /examples/tests/LinearRegression.py: 
-------------------------------------------------------------------------------- 1 | 2 | import hecate as hc 3 | from random import * 4 | import numpy as np 5 | import sys 6 | from pathlib import Path 7 | import time 8 | 9 | a_compile_type = sys.argv[1] 10 | a_compile_opt = int(sys.argv[2]) 11 | 12 | 13 | hevm = hc.HEVM() 14 | stem = Path(__file__).stem 15 | hevm.load (f"traced/_hecate_{stem}.cst", f"optimized/{a_compile_type}/{stem}.{a_compile_opt}._hecate_{stem}.hevm") 16 | 17 | x = [ uniform (-1, 1) for a in range(4096)] 18 | a = 2.0 19 | b = 1.0 20 | y = [ a*point +b + uniform (-0.01, 0.01) for point in x] 21 | 22 | 23 | # print(res) 24 | 25 | W = 1.0 26 | c = 0.0 27 | 28 | epochs = 2 29 | learning_rate = -0.01 30 | 31 | for i in range(epochs): 32 | 33 | error = [ W*x[i]+c-y[i] for i in range(4096)] 34 | errX = [ error[i]* x[i] for i in range(4096)] 35 | gradW = sum(errX)/ 2048 36 | gradb = sum(error)/2048 37 | Wup = learning_rate * gradW 38 | bup = learning_rate * gradb 39 | W = W + Wup 40 | c = c + bup 41 | 42 | # print (W, c) 43 | 44 | hevm.setInput(0, x) 45 | hevm.setInput(1, y) 46 | timer = time.perf_counter_ns() 47 | hevm.run() 48 | timer = time.perf_counter_ns() -timer 49 | res = hevm.getOutput() 50 | rms = np.sqrt(np.mean(np.power(res[0] - W, 2) + np.power(res[1] - c, 2))) 51 | print (timer / pow(10,9)) 52 | print(rms) 53 | 54 | -------------------------------------------------------------------------------- /examples/tests/MLP.py: -------------------------------------------------------------------------------- 1 | import hecate as hc 2 | import sys 3 | # import pandas as pd 4 | import torch 5 | from torchvision import datasets, transforms 6 | 7 | 8 | from PIL import Image 9 | import numpy as np 10 | from random import * 11 | import pprint 12 | 13 | 14 | from pathlib import Path 15 | 16 | source_path = Path(__file__).resolve() 17 | source_dir = source_path.parent 18 | 19 | def preprocess(): 20 | x = [ uniform (0.0, 1.0) for a in range(784)] 21 | b = [ 0 
for a in range(16)] 22 | x = x+b 23 | return np.array(x) 24 | 25 | def process(x) : 26 | model = torch.load(str(source_dir)+"/../data/mlp.model", map_location=torch.device('cpu')) 27 | W1 = model["linear1.weight"].cpu().detach().numpy() 28 | b1 = model["linear1.bias"].cpu().detach().numpy() 29 | W2 = model["linear2.weight"].cpu().detach().numpy() 30 | b2 = model["linear2.bias"].cpu().detach().numpy() 31 | 32 | inter = [0.0 for i in range (100)] 33 | res = [0.0 for i in range (10)] 34 | 35 | for i in range(100) : 36 | for j in range(784) : 37 | inter[i] += x[j] * W1[i][j] 38 | inter[i] += b1[i] 39 | inter = [i*i for i in inter] 40 | for i in range(10) : 41 | for j in range(100) : 42 | res[i] += inter[j] * W2[i][j] 43 | res[i] += b2[i] 44 | 45 | return np.array([res]) 46 | 47 | def postprocess(res) : 48 | return res[0,:10] 49 | 50 | 51 | 52 | if __name__ == "__main__" : 53 | 54 | from random import * 55 | import sys 56 | from pathlib import Path 57 | import time 58 | from PIL import Image 59 | 60 | a_compile_type = sys.argv[1] 61 | a_compile_opt = int(sys.argv[2]) 62 | hevm = hc.HEVM() 63 | stem = Path(__file__).stem 64 | hevm.load (f"traced/_hecate_{stem}.cst", f"optimized/{a_compile_type}/{stem}.{a_compile_opt}._hecate_{stem}.hevm") 65 | 66 | input_dat = preprocess() 67 | reference = postprocess(process(input_dat)) 68 | [hevm.setInput(i, dat) for i, dat in enumerate([input_dat])] 69 | timer = time.perf_counter_ns() 70 | hevm.run() 71 | timer = time.perf_counter_ns() -timer 72 | res = hevm.getOutput() 73 | res = postprocess(res) 74 | err = res - reference 75 | rms = np.sqrt( np.sum(err*err) / res.shape[-1]) 76 | print (timer/ (pow(10, 9))) 77 | print (rms) 78 | -------------------------------------------------------------------------------- /examples/tests/SobelFilter.py: -------------------------------------------------------------------------------- 1 | 2 | import hecate as hc 3 | import numpy as np 4 | 5 | 6 | def roll (a, i) : 7 | return np.roll(a, -i) 8 | 9 | 
def preprocess(): 10 | 11 | lena = Image.open(f'{hc.hecate_dir}/examples//data/cornertest.jpg').convert('L') 12 | lena = lena.resize((64,64)) 13 | lena_array = np.asarray(lena.getdata(), dtype=np.float64) / 256 14 | lena_array = lena_array.reshape([64*64]) 15 | 16 | return lena_array.reshape([1, 4096]); 17 | 18 | def process(lena_array) : 19 | 20 | lena_array = lena_array[0] 21 | F = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]] 22 | Ix = 0 23 | Iy = 0 24 | for i in range(3) : 25 | for j in range(3) : 26 | rot = roll (lena_array, i*64 +j) 27 | h = rot * F[i][j] 28 | v = rot * F[j][i] 29 | Ix = Ix + h 30 | Iy = Iy + v 31 | Ix2 = Ix * Ix 32 | Iy2 = Iy * Iy 33 | c = Ix2 + Iy2 34 | d = 0.173*c*c*c - 1.098 * c * c + 2.214*c 35 | 36 | return d.reshape([1, 4096]) 37 | 38 | def postprocess (result) : 39 | return (result *256) [:, :4096] 40 | 41 | ## EVAL 42 | 43 | if __name__ == "__main__" : 44 | 45 | from random import * 46 | import sys 47 | from pathlib import Path 48 | import time 49 | from PIL import Image 50 | 51 | a_compile_type = sys.argv[1] 52 | a_compile_opt = int(sys.argv[2]) 53 | hevm = hc.HEVM() 54 | stem = Path(__file__).stem 55 | hevm.load (f"traced/_hecate_{stem}.cst", f"optimized/{a_compile_type}/{stem}.{a_compile_opt}._hecate_{stem}.hevm") 56 | 57 | input_dat = preprocess() 58 | reference = postprocess(process(input_dat)) 59 | [hevm.setInput(i, dat) for i, dat in enumerate(input_dat)] 60 | timer = time.perf_counter_ns() 61 | hevm.run() 62 | timer = time.perf_counter_ns() -timer 63 | res = hevm.getOutput() 64 | res = postprocess(res) 65 | err = res - reference 66 | rms = np.sqrt( np.sum(err*err) / res.shape[-1]) 67 | print (timer/ (pow(10, 9))) 68 | print (rms) 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /include/hecate/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(Conversion) 2 | add_subdirectory(Dialect) 3 | 
-------------------------------------------------------------------------------- /include/hecate/Conversion/CKKSCommon/PolyTypeConverter.h: -------------------------------------------------------------------------------- 1 | #ifndef HECATE_CONVERSION_CKKSCOMMON_TYPECONVERTER_H 2 | #define HECATE_CONVERSION_CKKSCOMMON_TYPECONVERTER_H 3 | 4 | #include "hecate/Dialect/CKKS/IR/CKKSOps.h" 5 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 6 | #include "mlir/Transforms/DialectConversion.h" 7 | 8 | namespace hecate { 9 | struct PolyTypeConverter : public mlir::TypeConverter { 10 | using TypeConverter::TypeConverter; 11 | 12 | PolyTypeConverter(int64_t base_level); 13 | mlir::Type convertFunctionType(mlir::FunctionType t); 14 | mlir::Type convertTensorType(mlir::TensorType t); 15 | mlir::Type convertCipherType(hecate::earth::CipherType t); 16 | mlir::Type convertPlainType(hecate::earth::PlainType t); 17 | 18 | private: 19 | int64_t base_level; 20 | }; 21 | } // namespace hecate 22 | 23 | #endif 24 | -------------------------------------------------------------------------------- /include/hecate/Conversion/CKKSToCKKS/UpscaleToMulcp.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_CONVERSION_CKKSTOCKKS_UPSCALETOMULCP_H 3 | #define HECATE_CONVERSION_CKKSTOCKKS_UPSCALETOMULCP_H 4 | 5 | #include 6 | 7 | #include "hecate/Conversion/CKKSCommon/PolyTypeConverter.h" 8 | 9 | namespace mlir { 10 | namespace func { 11 | class FuncOp; 12 | } 13 | class RewritePatternSet; 14 | template class OperationPass; 15 | } // namespace mlir 16 | 17 | namespace hecate { 18 | 19 | #define GEN_PASS_DECL_UPSACLETOMULCPCONVERSION 20 | #include "mlir/Conversion/Passes.h.inc" 21 | 22 | namespace ckks { 23 | 24 | std::unique_ptr<::mlir::OperationPass<::mlir::func::FuncOp>> 25 | createUpscaleToMulcpConversionPass(); 26 | 27 | void populateUpscaleToMulcpConversionPatterns( 28 | mlir::MLIRContext *ctxt, mlir::RewritePatternSet &patterns); 29 | 30 | } 
// namespace ckks 31 | 32 | } // namespace hecate 33 | 34 | #endif // MLIR_CONVERSION_ARITHTOLLVM_ARITHTOLLVM_H 35 | -------------------------------------------------------------------------------- /include/hecate/Conversion/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | set (LLVM_TARGET_DEFINITIONS Passes.td) 3 | 4 | mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion) 5 | add_public_tablegen_target(HecateConversionPassIncGen) 6 | -------------------------------------------------------------------------------- /include/hecate/Conversion/EarthToCKKS/EarthToCKKS.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_CONVERSION_EARTHTOCKKS_EARTHTOCKKS_H 3 | #define HECATE_CONVERSION_EARTHTOCKKS_EARTHTOCKKS_H 4 | 5 | #include 6 | 7 | #include "hecate/Conversion/CKKSCommon/PolyTypeConverter.h" 8 | 9 | namespace mlir { 10 | namespace func { 11 | class FuncOp; 12 | } 13 | class RewritePatternSet; 14 | template class OperationPass; 15 | } // namespace mlir 16 | 17 | namespace hecate { 18 | 19 | #define GEN_PASS_DECL_EARTHTOCKKSCONVERSION 20 | #include "mlir/Conversion/Passes.h.inc" 21 | 22 | namespace earth { 23 | std::unique_ptr<::mlir::OperationPass<::mlir::func::FuncOp>> 24 | createEarthToCKKSConversionPass(); 25 | 26 | void populateEarthToCKKSConversionPatterns(mlir::MLIRContext *ctxt, 27 | mlir::TypeConverter &converter, 28 | mlir::RewritePatternSet &patterns, 29 | int64_t init_level); 30 | 31 | } // namespace earth 32 | 33 | } // namespace hecate 34 | 35 | #endif // MLIR_CONVERSION_ARITHTOLLVM_ARITHTOLLVM_H 36 | -------------------------------------------------------------------------------- /include/hecate/Conversion/Passes.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_CONVERSION_PASSES_H 3 | #define HECATE_CONVERSION_PASSES_H 4 | 5 | #include 
"hecate/Conversion/CKKSToCKKS/UpscaleToMulcp.h" 6 | #include "hecate/Conversion/EarthToCKKS/EarthToCKKS.h" 7 | 8 | namespace hecate { 9 | 10 | #define GEN_PASS_REGISTRATION 11 | #include "hecate/Conversion/Passes.h.inc" 12 | 13 | } // namespace hecate 14 | 15 | #endif 16 | -------------------------------------------------------------------------------- /include/hecate/Conversion/Passes.td: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_CONVERSION_PASSES 3 | #define HECATE_CONVERSION_PASSES 4 | 5 | include "mlir/Pass/PassBase.td" 6 | 7 | def EarthToCKKSConversion : Pass<"convert-earth-to-ckks", "::mlir::func::FuncOp"> { 8 | let summary = "Convert Earth dialect to CKKS dialect"; 9 | let description = [{ 10 | This pass converts supported Earth ops to CKKS dialect instructions. 11 | }]; 12 | let constructor = "::hecate::earth::createEarthToCKKSConversionPass()"; 13 | let dependentDialects = ["hecate::earth::EarthDialect", "hecate::ckks::CKKSDialect", "mlir::tensor::TensorDialect"]; 14 | let options = [ 15 | ]; 16 | } 17 | def UpscaleToMulcpConversion : Pass<"convert-upscale-to-mulcp", "::mlir::func::FuncOp"> { 18 | let summary = "Convert upscale operation to mulcp operation"; 19 | let description = [{ 20 | This pass converts upscale ops to mulcp instructions. 
21 | }]; 22 | let constructor = "::hecate::ckks::createUpscaleToMulcpConversionPass()"; 23 | let dependentDialects = ["hecate::ckks::CKKSDialect", "mlir::tensor::TensorDialect"]; 24 | let options = [ 25 | ]; 26 | } 27 | 28 | #endif 29 | -------------------------------------------------------------------------------- /include/hecate/Dialect/CKKS/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(IR) 2 | add_subdirectory(Transforms) 3 | -------------------------------------------------------------------------------- /include/hecate/Dialect/CKKS/IR/CKKSOps.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_DIALECT_CKKS_IR_ARITHOPS_H 3 | #define HECATE_DIALECT_CKKS_IR_ARITHOPS_H 4 | 5 | #include "mlir/Dialect/Func/IR/FuncOps.h" 6 | #include "mlir/IR/Builders.h" 7 | #include "mlir/IR/BuiltinOps.h" 8 | #include "mlir/IR/BuiltinTypes.h" 9 | #include "mlir/IR/Dialect.h" 10 | #include "mlir/IR/OpDefinition.h" 11 | #include "mlir/IR/OpImplementation.h" 12 | #include "mlir/IR/PatternMatch.h" 13 | #include "mlir/Interfaces/CallInterfaces.h" 14 | #include "mlir/Interfaces/CastInterfaces.h" 15 | #include "mlir/Interfaces/ControlFlowInterfaces.h" 16 | #include "mlir/Interfaces/CopyOpInterface.h" 17 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" 18 | #include "mlir/Interfaces/InferTypeOpInterface.h" 19 | #include "mlir/Interfaces/SideEffectInterfaces.h" 20 | #include 21 | 22 | #include "hecate/Dialect/CKKS/IR/PolyTypeInterface.h" 23 | 24 | #include "hecate/Dialect/CKKS/IR/CKKSOpsDialect.h.inc" 25 | #define GET_TYPEDEF_CLASSES 26 | #include "hecate/Dialect/CKKS/IR/CKKSOpsTypes.h.inc" 27 | #define GET_OP_CLASSES 28 | #include "hecate/Dialect/CKKS/IR/CKKSOps.h.inc" 29 | 30 | template 31 | class SameOperandsAndResultLevel 32 | : public mlir::OpTrait::TraitBase { 34 | public: 35 | static mlir::LogicalResult verifyTrait(mlir::Operation *op) { 36 | return 
mlir::success(); 37 | } 38 | }; 39 | 40 | template 41 | class SameOperandsAndLowerResultLevel 42 | : public mlir::OpTrait::TraitBase { 44 | public: 45 | static mlir::LogicalResult verifyTrait(mlir::Operation *op) { 46 | return mlir::success(); 47 | } 48 | }; 49 | 50 | namespace hecate { 51 | namespace ckks { 52 | ::hecate::ckks::PolyTypeInterface getPolyType(mlir::Value v); 53 | ::mlir::RankedTensorType getTensorType(mlir::Value v); 54 | } // namespace ckks 55 | } // namespace hecate 56 | #endif 57 | -------------------------------------------------------------------------------- /include/hecate/Dialect/CKKS/IR/CKKSOps.td: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_CKKS_OPS 3 | #define HECATE_CKKS_OPS 4 | 5 | include "mlir/IR/OpBase.td" 6 | include "mlir/IR/PatternBase.td" 7 | include "mlir/IR/AttrTypeBase.td" 8 | include "mlir/Interfaces/CallInterfaces.td" 9 | include "mlir/Interfaces/SideEffectInterfaces.td" 10 | include "mlir/IR/BuiltinTypeInterfaces.td" 11 | include "mlir/Interfaces/DestinationStyleOpInterface.td" 12 | include "mlir/Interfaces/InferTypeOpInterface.td" 13 | include "mlir/IR/OpAsmInterface.td" 14 | include "hecate/Dialect/CKKS/IR/PolyTypeInterface.td" 15 | /* include "hecate/Dialect/Earth/IR/HEParameterInterface.td" */ 16 | /* include "hecate/Dialect/Earth/IR/ForwardManagementInterface.td" */ 17 | 18 | 19 | def CKKS_Dialect : Dialect { 20 | let name = "ckks"; 21 | let cppNamespace = "::hecate::ckks"; 22 | let useFoldAPI= kEmitFoldAdaptorFolder; 23 | let useDefaultTypePrinterParser = 1; 24 | } 25 | 26 | def PolyType : TypeDef, DeclareTypeInterfaceMethods]> { 27 | let summary = "CKKS Polynomial Type with # of polynomial and level"; 28 | let description = [{ 29 | A type for CKKS API level optimization. 
30 | }]; 31 | let mnemonic = "poly"; 32 | let parameters = (ins "unsigned":$num_poly, "unsigned":$level); 33 | let assemblyFormat = "`<` $num_poly `*` $level `>`"; 34 | let genVerifyDecl = 0; 35 | } 36 | 37 | def PolyTensor : TensorOf<[PolyType]>; 38 | 39 | class CKKS_Op traits = []> : 40 | Op, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods])> { 41 | let results = (outs PolyTensor); 42 | string arithClassDefinition = [{ 43 | // - ScaleOpInterface Non-default methods 44 | // - isSingle, isConsume is remaining methods 45 | std::pair $cppClass::getDpsInitsPositionRange (){ return {0, 1};} 46 | }]; 47 | let extraClassDefinition = arithClassDefinition; 48 | } 49 | 50 | /* ######### Zero-operand Operations ########## */ 51 | 52 | def EncodeOp : CKKS_Op<"encode",[]> { 53 | 54 | let arguments = (ins PolyTensor:$dst, I64Attr:$value, I64Attr:$scale, I64Attr:$level); 55 | // AnyAttr because of constant hoisting 56 | let results = (outs PolyTensor); 57 | 58 | 59 | code extraClassDeclaration = [{ 60 | static bool isCompatibleReturnTypes(::mlir::TypeRange lhs, ::mlir::TypeRange rhs) ; 61 | int64_t getNumOperands(); 62 | }]; 63 | 64 | code extraClassDefinition = arithClassDefinition # [{ 65 | bool $cppClass::isCompatibleReturnTypes(::mlir::TypeRange lhs, ::mlir::TypeRange rhs){ 66 | return rhs.back().dyn_cast().getNumPoly() == 1; 67 | } 68 | HEVMOperation $cppClass::getHEVMOperation (llvm::DenseMap plainMap, llvm::DenseMap cipherMap){ 69 | HEVMOperation op; 70 | op.opcode = 0; 71 | op.dst = plainMap[getDst()]; 72 | op.lhs = getValue(); 73 | op.rhs = (getLevel() << 8) + getScale(); 74 | return op; 75 | } 76 | int64_t $cppClass::getNumOperands() { 77 | return 1; 78 | } 79 | }]; 80 | } 81 | 82 | /* ######### Unary Operations ########## */ 83 | 84 | 85 | def RotateCOp : CKKS_Op<"rotatec", [SameOperandsAndResultType]> { 86 | let arguments = (ins PolyTensor:$dst, PolyTensor:$src, DenseI64ArrayAttr:$offset); 87 | code extraClassDefinition = arithClassDefinition # [{ 88 | 
HEVMOperation $cppClass::getHEVMOperation (llvm::DenseMap plainMap, llvm::DenseMap cipherMap){ 89 | HEVMOperation op; 90 | op.opcode = 1; 91 | op.dst = cipherMap[getDst()]; 92 | op.lhs = cipherMap[getSrc()]; 93 | op.rhs = getOffset()[0]; 94 | return op; 95 | } 96 | }]; 97 | } 98 | def NegateCOp : CKKS_Op<"negatec", [SameOperandsAndResultType]> { 99 | let arguments = (ins PolyTensor:$dst, PolyTensor:$src); 100 | code extraClassDefinition = arithClassDefinition # [{ 101 | HEVMOperation $cppClass::getHEVMOperation (llvm::DenseMap plainMap, llvm::DenseMap cipherMap){ 102 | HEVMOperation op; 103 | op.opcode = 2; 104 | op.dst = cipherMap[getDst()]; 105 | op.lhs = cipherMap[getSrc()]; 106 | op.rhs = 0; 107 | return op; 108 | } 109 | }]; 110 | } 111 | 112 | /* ######## Scale Management Operations ####### */ 113 | 114 | def RescaleCOp : CKKS_Op<"rescalec", [SameOperandsAndResultShape]> { 115 | let arguments = (ins PolyTensor:$dst, PolyTensor:$src); 116 | code extraClassDefinition = arithClassDefinition # [{ 117 | HEVMOperation $cppClass::getHEVMOperation (llvm::DenseMap plainMap, llvm::DenseMap cipherMap){ 118 | HEVMOperation op; 119 | op.opcode = 3; 120 | op.dst = cipherMap[getDst()]; 121 | op.lhs = cipherMap[getSrc()]; 122 | op.rhs = 0; 123 | return op; 124 | } 125 | }]; 126 | } 127 | def ModswitchCOp : CKKS_Op<"modswitchc", [SameOperandsAndResultShape]> { 128 | let arguments = (ins PolyTensor:$dst, PolyTensor:$src, I64Attr:$downFactor); 129 | code extraClassDefinition = arithClassDefinition # [{ 130 | HEVMOperation $cppClass::getHEVMOperation (llvm::DenseMap plainMap, llvm::DenseMap cipherMap){ 131 | HEVMOperation op; 132 | op.opcode = 4; 133 | op.dst = cipherMap[getDst()]; 134 | op.lhs = cipherMap[getSrc()]; 135 | op.rhs = getDownFactor(); 136 | return op; 137 | } 138 | }]; 139 | } 140 | def UpscaleCOp : CKKS_Op<"upscalec", [SameOperandsAndResultType]> { 141 | let arguments = (ins PolyTensor:$dst, PolyTensor:$src, I64Attr:$upFactor); 142 | code extraClassDefinition = 
arithClassDefinition # [{ 143 | HEVMOperation $cppClass::getHEVMOperation (llvm::DenseMap plainMap, llvm::DenseMap cipherMap){ 144 | HEVMOperation op; 145 | op.opcode = 5; 146 | op.dst = cipherMap[getDst()]; 147 | op.lhs = cipherMap[getSrc()]; 148 | op.rhs = getUpFactor(); 149 | return op; 150 | } 151 | }]; 152 | } 153 | 154 | 155 | /* ######### Binary Operations ########## */ 156 | 157 | def AddCCOp: CKKS_Op<"addcc", [Commutative, SameOperandsAndResultType]> { 158 | let arguments = (ins PolyTensor:$dst, PolyTensor:$lhs, PolyTensor:$rhs); 159 | code extraClassDefinition = arithClassDefinition # [{ 160 | HEVMOperation $cppClass::getHEVMOperation (llvm::DenseMap plainMap, llvm::DenseMap cipherMap){ 161 | HEVMOperation op; 162 | op.opcode = 6; 163 | op.dst = cipherMap[getDst()]; 164 | op.lhs = cipherMap[getLhs()]; 165 | op.rhs = cipherMap[getRhs()]; 166 | return op; 167 | } 168 | }]; 169 | } 170 | def AddCPOp: CKKS_Op<"addcp", [Commutative, SameOperandsAndResultShape]> { 171 | let arguments = (ins PolyTensor:$dst, PolyTensor:$lhs, PolyTensor:$rhs); 172 | code extraClassDefinition = arithClassDefinition # [{ 173 | HEVMOperation $cppClass::getHEVMOperation (llvm::DenseMap plainMap, llvm::DenseMap cipherMap){ 174 | HEVMOperation op; 175 | op.opcode = 7; 176 | op.dst = cipherMap[getDst()]; 177 | op.lhs = cipherMap[getLhs()]; 178 | op.rhs = plainMap[getRhs()]; 179 | return op; 180 | } 181 | }]; 182 | } 183 | def MulCCOp: CKKS_Op<"mulcc", [Commutative, SameOperandsAndResultType]> { 184 | let arguments = (ins PolyTensor:$dst, PolyTensor:$lhs, PolyTensor:$rhs); 185 | code extraClassDefinition = arithClassDefinition # [{ 186 | HEVMOperation $cppClass::getHEVMOperation (llvm::DenseMap plainMap, llvm::DenseMap cipherMap){ 187 | HEVMOperation op; 188 | op.opcode = 8; 189 | op.dst = cipherMap[getDst()]; 190 | op.lhs = cipherMap[getLhs()]; 191 | op.rhs = cipherMap[getRhs()]; 192 | return op; 193 | } 194 | }]; 195 | } 196 | def MulCPOp: CKKS_Op<"mulcp", [Commutative, 
SameOperandsAndResultShape]> { 197 | let arguments = (ins PolyTensor:$dst, PolyTensor:$lhs, PolyTensor:$rhs); 198 | code extraClassDefinition = arithClassDefinition # [{ 199 | HEVMOperation $cppClass::getHEVMOperation (llvm::DenseMap plainMap, llvm::DenseMap cipherMap){ 200 | HEVMOperation op; 201 | op.opcode = 9; 202 | op.dst = cipherMap[getDst()]; 203 | op.lhs = cipherMap[getLhs()]; 204 | op.rhs = plainMap[getRhs()]; 205 | return op; 206 | } 207 | }]; 208 | } 209 | 210 | 211 | 212 | #endif 213 | -------------------------------------------------------------------------------- /include/hecate/Dialect/CKKS/IR/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_hecate_interface(PolyTypeInterface ckks) 2 | add_hecate_dialect(CKKSOps ckks) 3 | add_dependencies(HecateCKKSOpsIncGen HecatePolyTypeInterfaceIncGen) 4 | -------------------------------------------------------------------------------- /include/hecate/Dialect/CKKS/IR/PolyTypeInterface.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_DIALECT_CKKS_IR_POLYTYPEINTERFACE_H 3 | #define HECATE_DIALECT_CKKS_IR_POLYTYPEINTERFACE_H 4 | 5 | #include "hecate/Support/HEVMHeader.h" 6 | #include "mlir/IR/BuiltinTypes.h" 7 | #include "mlir/IR/OpDefinition.h" 8 | #include 9 | 10 | #include "hecate/Dialect/CKKS/IR/PolyTypeInterfaceTypes.h.inc" 11 | 12 | #include "hecate/Dialect/CKKS/IR/PolyTypeInterface.h.inc" 13 | 14 | #endif // MLIR_DIALECT_AFFINE_IR_AFFINEMEMORYOPINTERFACES_H 15 | -------------------------------------------------------------------------------- /include/hecate/Dialect/CKKS/IR/PolyTypeInterface.td: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HE_PARAMETER_INTERFACE 3 | #define HE_PARAMETER_INTERFACE 4 | 5 | include "mlir/IR/OpBase.td" 6 | include "mlir/IR/OpBase.td" 7 | 8 | 9 | def PolyTypeInterface : TypeInterface<"PolyTypeInterface"> { 10 | 11 | let 
description = [{ 12 | A type interface for CKKS API level optimization. 13 | Type that implement this interface can be analyzed and managed by CKKS API-level optimization. 14 | }]; 15 | let cppNamespace = "::hecate::ckks"; 16 | let methods = [ 17 | InterfaceMethod< 18 | /*desc=*/[{Get a number of polynomials in a given type. 19 | }], 20 | /*retType=*/"unsigned", 21 | /*methodName=*/"getNumPoly" 22 | >, 23 | InterfaceMethod< 24 | /*desc=*/[{Get a level of a given type. 25 | }], 26 | /*retType=*/"unsigned", 27 | /*methodName=*/"getLevel" 28 | >, 29 | InterfaceMethod< 30 | /*desc=*/[{Get a level of a given type. 31 | }], 32 | /*retType=*/"::hecate::ckks::PolyTypeInterface", 33 | /*methodName=*/"switchLevel", 34 | /*args=*/(ins "unsigned":$level) 35 | >, 36 | InterfaceMethod< 37 | /*desc=*/[{Get a level of a given type. 38 | }], 39 | /*retType=*/"::hecate::ckks::PolyTypeInterface", 40 | /*methodName=*/"switchNumPoly", 41 | /*args=*/(ins "unsigned":$num_poly) 42 | >, 43 | ]; 44 | } 45 | 46 | def HEVMOpInterface : OpInterface<"HEVMOpInterface"> { 47 | let description = [{ 48 | A printer interface for HEVM API . 49 | Operations that implement this interface can be printed to HEVM. 50 | }]; 51 | let cppNamespace = "::hecate::ckks"; 52 | let methods = [ 53 | InterfaceMethod< 54 | /*desc=*/[{Get a number of polynomials in a given type. 
55 | }], 56 | /*retType=*/"HEVMOperation", 57 | /*methodName=*/"getHEVMOperation", 58 | /*args=*/(ins "llvm::DenseMap":$plainMap, "llvm::DenseMap":$cipherMap) 59 | > 60 | ]; 61 | } 62 | 63 | 64 | 65 | #endif 66 | -------------------------------------------------------------------------------- /include/hecate/Dialect/CKKS/Transforms/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | set(LLVM_TARGET_DEFINITIONS Passes.td) 3 | mlir_tablegen(Passes.h.inc -gen-pass-decls -name CKKS) 4 | add_public_tablegen_target(HecateCKKSTransformsIncGen) 5 | -------------------------------------------------------------------------------- /include/hecate/Dialect/CKKS/Transforms/Passes.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | #ifndef HECATE_DIALECT_CKKS_TRANSFORMS_PASSES_H_ 4 | #define HECATE_DIALECT_CKKS_TRANSFORMS_PASSES_H_ 5 | 6 | #include "mlir/Pass/Pass.h" 7 | #include 8 | 9 | namespace hecate { 10 | 11 | namespace ckks { 12 | 13 | #define GEN_PASS_DECL 14 | #include "hecate/Dialect/CKKS/Transforms/Passes.h.inc" 15 | 16 | //===----------------------------------------------------------------------===// 17 | // Registration 18 | //===----------------------------------------------------------------------===// 19 | 20 | /// Generate the code for registering passes. 
21 | #define GEN_PASS_REGISTRATION 22 | #include "hecate/Dialect/CKKS/Transforms/Passes.h.inc" 23 | 24 | } // namespace ckks 25 | } // namespace hecate 26 | 27 | #endif 28 | -------------------------------------------------------------------------------- /include/hecate/Dialect/CKKS/Transforms/Passes.td: -------------------------------------------------------------------------------- 1 | 2 | 3 | #ifndef HECATE_DIALECT_CKKS_TRANSFORMS_PASSES 4 | #define HECATE_DIALECT_CKKS_TRANSFORMS_PASSES 5 | 6 | include "mlir/Pass/PassBase.td" 7 | 8 | def RemoveLevel : Pass<"remove-level", "::mlir::func::FuncOp"> { 9 | let summary = "Remove Level Informations for Ciphertext-granularity Register Allocation"; 10 | let description = [{ 11 | }]; 12 | } 13 | 14 | def ReuseBuffer : Pass<"reuse-buffer", "::mlir::func::FuncOp"> { 15 | let summary = "Remove Level Informations for Ciphertext-granularity Register Allocation"; 16 | let description = [{ 17 | }]; 18 | } 19 | 20 | def EmitHEVM : Pass<"emit-hevm", "::mlir::func::FuncOp"> { 21 | let summary = "Remove Level Informations for Ciphertext-granularity Register Allocation"; 22 | let description = [{ 23 | }]; 24 | let options = [ 25 | Option<"prefix", "prefix", "std::string", /*default=*/[{""}], 26 | "Name prefix of output file">, 27 | ]; 28 | 29 | } 30 | 31 | #endif 32 | -------------------------------------------------------------------------------- /include/hecate/Dialect/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(Earth) 2 | add_subdirectory(CKKS) 3 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/Analysis/AutoDifferentiation.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_ANALYSIS_AUTODIFFERENTIATION 3 | #define HECATE_ANALYSIS_AUTODIFFERENTIATION 4 | 5 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 6 | #include 
"mlir/Dialect/Func/IR/FuncOps.h" 7 | #include "mlir/Pass/AnalysisManager.h" 8 | #include "llvm/ADT/SmallSet.h" 9 | #include 10 | 11 | namespace hecate { 12 | 13 | struct AutoDifferentiation { 14 | public: 15 | AutoDifferentiation(mlir::Operation *op); 16 | 17 | // Passing Operation and Value gives same result 18 | double getBackDiff(mlir::Operation *op); 19 | double getBackDiff(mlir::Value v); 20 | 21 | double getBackDiff(mlir::OpOperand &oper); 22 | 23 | double getValueEstimation(mlir::Operation *op); 24 | double getValueEstimation(mlir::Value v); 25 | 26 | private: 27 | void build(); 28 | 29 | llvm::DenseMap operandDiffMap; 30 | llvm::DenseMap valueDiffMap; 31 | llvm::DenseMap valueMap; 32 | mlir::Operation *_op; 33 | }; 34 | } // namespace hecate 35 | 36 | #endif 37 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/Analysis/ScaleManagementUnit.h: -------------------------------------------------------------------------------- 1 | #ifndef HECATE_ANALYSIS_SCALEMANAGEMENTUNIT 2 | #define HECATE_ANALYSIS_SCALEMANAGEMENTUNIT 3 | 4 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 5 | #include "mlir/Dialect/Func/IR/FuncOps.h" 6 | #include "mlir/Pass/AnalysisManager.h" 7 | #include "llvm/ADT/SmallSet.h" 8 | #include 9 | 10 | namespace hecate { 11 | 12 | struct ScaleManagementUnit { 13 | public: 14 | ScaleManagementUnit(mlir::Operation *op); 15 | 16 | // Default Implementation 17 | int64_t getID(mlir::Value v) const; 18 | int64_t getID(mlir::Operation *op) const; 19 | int64_t getEdge(mlir::OpOperand *op) const; 20 | mlir::SmallVector getEdgeSet(int64_t edge) const; 21 | mlir::SmallVector getValueSet(int64_t id) const; 22 | int64_t getNumEdges() const; 23 | int64_t getNumSMUs() const; 24 | 25 | // Helper for ELASM 26 | bool inNoisyGroup(mlir::Operation *op) const; 27 | bool inNoisyGroup(mlir::Value v) const; 28 | 29 | // This is a helper for passes. 30 | // Should not be called in analysis. 
31 | void attach(); 32 | void detach(); 33 | 34 | bool verify() const; 35 | bool isInvalidated(const mlir::AnalysisManager::PreservedAnalyses &); 36 | 37 | private: 38 | void build(); 39 | 40 | int64_t idMax; 41 | int64_t edgeMax; 42 | 43 | llvm::DenseMap smuIds; 44 | llvm::DenseMap smuEdges; 45 | llvm::SmallVector, 4> idToValue; 46 | llvm::SmallVector, 4> edgeToOper; 47 | 48 | llvm::SmallVector noisyMap; 49 | /* llvm::DenseMap> 50 | * edgeToOper; */ 51 | /* llvm::DenseMap> idToValue; */ 52 | mlir::Operation *_op; 53 | }; 54 | } // namespace hecate 55 | 56 | #endif 57 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | add_subdirectory(IR) 3 | add_subdirectory(Transforms) 4 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/IR/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # set(LLVM_TARGET_DEFINITIONS HEParameterInterface.td) 2 | # mlir_tablegen(ScaleTypeInterface.h.inc -gen-type-interface-decls) 3 | # mlir_tablegen(ScaleTypeInterface.cpp.inc -gen-type-interface-defs) 4 | # mlir_tablegen(ScaleOpInterface.h.inc -gen-op-interface-decls) 5 | # mlir_tablegen(ScaleOpInterface.cpp.inc -gen-op-interface-defs) 6 | # add_public_tablegen_target(HECATEScaleInterfacesIncGen) 7 | 8 | add_hecate_interface(HEParameterInterface earth) 9 | add_hecate_interface(ForwardManagementInterface earth) 10 | add_hecate_dialect(EarthOps earth) 11 | add_hecate_pattern(EarthCanonicalizer earth) 12 | add_dependencies(HecateEarthOpsIncGen HecateHEParameterInterfaceIncGen) 13 | 14 | # set(LLVM_TARGET_DEFINITIONS ArithOps.td) 15 | # mlir_tablegen(ArithOp.h.inc -gen-op-decls) 16 | # mlir_tablegen(ArithOp.cpp.inc -gen-op-defs) 17 | # add_public_tablegen_target(HECATEArithOpsIncGen) 18 | 
-------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/IR/EarthCanonicalizer.td: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_EARTH_CANONICALIZER 3 | #define HECATE_EARTH_CANONICALIZER 4 | include "mlir/IR/PatternBase.td" 5 | include "hecate/Dialect/Earth/IR/EarthOps.td" 6 | 7 | 8 | /* def TestPattern : Pat<(NegateOp(NegateOp $arg)), (replaceWithValue $arg)>; */ 9 | 10 | 11 | def isZero : Constraint().getInt() == 0"> , "check zero">; 12 | def isAllZero : Constraint() && llvm::all_of($0.dyn_cast().getValues(),[&](double d){ return d == 0;})">, "check all zero">; 13 | def isAllOne : Constraint() && llvm::all_of($0.dyn_cast().getValues(),[&](double d){ return d == 1.0;})">, "check all one">; 14 | def isAllMinusOne : Constraint() && llvm::all_of($0.dyn_cast().getValues(),[&](double d){ return d == -1.0;})">, "check all minus one">; 15 | def geZeroAfterRescale : Constraint().getInt() - hecate::earth::EarthDialect::rescalingFactor >= 0"> , "check zero">; 16 | def singleUser : Constraint, "single user">; 17 | def structureNotFixed : ConstraintgetAttr(\"sm_fixed\")">, "scale management structure is not fixed">; 18 | 19 | def ZeroUpscalePattern : Pat<(UpscaleOp $arg, $val), (replaceWithValue $arg), [(isZero $val)]>; 20 | def ZeroModswitchPattern : Pat<(ModswitchOp $arg, $val), (replaceWithValue $arg), [(isZero $val)]>; 21 | def AddZeroPattern : Pat<(AddOp (either $lhs, (ConstantOp $val, $dat2)) ), (replaceWithValue $lhs), [(isAllZero $val)]>; 22 | def MulZeroPattern : Pat<(MulOp (either $lhs, (ConstantOp:$res $val, $dat2)) ), (replaceWithValue $res), [(isAllZero $val)]>; 23 | def MulOnePattern : Pat<(MulOp (either $lhs, (ConstantOp:$res $val, $dat2)) ), (replaceWithValue $lhs), [(isAllOne $val)]>; 24 | def NegMulPattern : Pat<(MulOp (either $lhs, (ConstantOp:$res $val, $dat2)) ), (NegateOp $lhs), [(isAllMinusOne $val)]>; 25 | 26 | 27 | // This patterns 28 | def 
addIntegerAttr : NativeCodeCall<"$_builder.getI64IntegerAttr($0.dyn_cast().getInt() + $1.dyn_cast().getInt())">; 29 | def UpscaleUpscalePattern : Pat<(UpscaleOp (UpscaleOp:$res $arg, $val1), $val2), (UpscaleOp $arg, (addIntegerAttr $val1, $val2)), [(structureNotFixed $res)]>; 30 | def ModswitchModswitchPattern : Pat<(ModswitchOp (ModswitchOp:$res $arg, $val1), $val2), (ModswitchOp $arg, (addIntegerAttr $val1, $val2)), [(structureNotFixed $res)]>; 31 | 32 | def changeConstScale : NativeCodeCall<"$_builder.create<::hecate::earth::ConstantOp>($_loc, $0.getType().dyn_cast<::hecate::earth::HEScaleTypeInterface>().switchScale($0.getType().dyn_cast<::hecate::earth::HEScaleTypeInterface>().getScale() + $1.dyn_cast().getInt()), $2, $3)">; 33 | def UpscaleConstantPattern: Pat<(UpscaleOp (ConstantOp:$res $dat1,$dat2), $up), (changeConstScale $res, $up, $dat1, $dat2)>; 34 | 35 | def changeConstLevel : NativeCodeCall<"$_builder.create<::hecate::earth::ConstantOp>($_loc, $0.getType().dyn_cast<::hecate::earth::HEScaleTypeInterface>().switchLevel($0.getType().dyn_cast<::hecate::earth::HEScaleTypeInterface>().getLevel() + $1.dyn_cast().getInt()), $2, $3)">; 36 | def ModswitchConstantPattern: Pat<(ModswitchOp (ConstantOp:$res $dat1,$dat2), $down), (changeConstLevel $res, $down, $dat1, $dat2)>; 37 | 38 | /* def moduloSlot : NativeCodeCall<"$0 % (hecate::earth::EarthDialect::polynomialDegree/2)">; */ 39 | /* def RotateOffsetModuloPattern: Pat<(RotateOp $arg, $offset), (RotateOp $arg, (moduloSlot $offset))>; */ 40 | 41 | 42 | def UpscaleRescalePattern: Pat<(UpscaleOp (RescaleOp:$res $arg), $up), (RescaleOp (UpscaleOp $arg, $up)), [(singleUser $res)]>; 43 | 44 | def getScaleAfterRescale : NativeCodeCall<"$0.dyn_cast().getInt() - hecate::earth::EarthDialect::rescalingFactor">; 45 | 46 | def RescaleUpscalePattern: Pat<(RescaleOp (UpscaleOp $arg, $up) ), (ModswitchOp (UpscaleOp $arg, (getScaleAfterRescale $up) ), (NativeCodeCall<"$_builder.getI64IntegerAttr(1)">) ), [(geZeroAfterRescale 
$up)]>; 47 | 48 | #endif 49 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/IR/EarthOps.h: -------------------------------------------------------------------------------- 1 | #ifndef HECATE_DIALECT_ARITH_IR_ARITHOPS_H 2 | #define HECATE_DIALECT_ARITH_IR_ARITHOPS_H 3 | 4 | #include "mlir/Dialect/Func/IR/FuncOps.h" 5 | #include "mlir/IR/Builders.h" 6 | #include "mlir/IR/BuiltinOps.h" 7 | #include "mlir/IR/BuiltinTypes.h" 8 | #include "mlir/IR/Dialect.h" 9 | #include "mlir/IR/OpDefinition.h" 10 | #include "mlir/IR/OpImplementation.h" 11 | #include "mlir/IR/PatternMatch.h" 12 | #include "mlir/Interfaces/CallInterfaces.h" 13 | #include "mlir/Interfaces/CastInterfaces.h" 14 | #include "mlir/Interfaces/ControlFlowInterfaces.h" 15 | #include "mlir/Interfaces/CopyOpInterface.h" 16 | #include "mlir/Interfaces/InferTypeOpInterface.h" 17 | #include "mlir/Interfaces/SideEffectInterfaces.h" 18 | #include 19 | 20 | #include "hecate/Support/Support.h" 21 | 22 | #include "hecate/Dialect/Earth/IR/ForwardManagementInterface.h" 23 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 24 | 25 | #include "hecate/Dialect/Earth/IR/EarthOpsDialect.h.inc" 26 | #define GET_TYPEDEF_CLASSES 27 | #include "hecate/Dialect/Earth/IR/EarthOpsTypes.h.inc" 28 | #define GET_OP_CLASSES 29 | #include "hecate/Dialect/Earth/IR/EarthOps.h.inc" 30 | 31 | namespace hecate { 32 | namespace earth { 33 | ::hecate::earth::HEScaleTypeInterface getScaleType(mlir::Value v); 34 | ::mlir::RankedTensorType getTensorType(mlir::Value v); 35 | } // namespace earth 36 | } // namespace hecate 37 | #endif 38 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/IR/ForwardManagementInterface.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_DIALECT_ARITH_IR_FORWARD_H 3 | #define HECATE_DIALECT_ARITH_IR_FORWARD_H 4 | 5 | #include 
"mlir/IR/BuiltinTypes.h" 6 | #include "mlir/IR/OpDefinition.h" 7 | 8 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 9 | 10 | #include "hecate/Dialect/Earth/IR/ForwardManagementInterface.h.inc" 11 | 12 | #endif // MLIR_DIALECT_AFFINE_IR_AFFINEMEMORYOPINTERFACES_H 13 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/IR/ForwardManagementInterface.td: -------------------------------------------------------------------------------- 1 | 2 | #ifndef FORWARD_MANAGEMENT_INTERFACE 3 | #define FORWARD_MANAGEMENT_INTERFACE 4 | 5 | include "mlir/IR/OpBase.td" 6 | include "hecate/Dialect/Earth/IR/HEParameterInterface.td" 7 | 8 | def ForwardMgmtInterface : OpInterface<"ForwardMgmtInterface"> { 9 | let description = [{ 10 | An op interface of CKKS parameter optimization. 11 | Ops that implement this interface can be analyzed and managed by scale management scheme. 12 | }]; 13 | let cppNamespace = "::hecate::earth"; 14 | let methods = [ 15 | InterfaceMethod< 16 | /*desc=*/[{Check that this is constant op. 17 | }], 18 | /*retType=*/"void", 19 | /*methodName=*/"processOperandsEVA", 20 | /*args=*/(ins "int64_t":$param), 21 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }] 22 | >, 23 | InterfaceMethod< 24 | /*desc=*/[{Check that this is constant op. 25 | }], 26 | /*retType=*/"void", 27 | /*methodName=*/"processResultsEVA", 28 | /*args=*/(ins "int64_t":$param), 29 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return; }] 30 | >, 31 | InterfaceMethod< 32 | /*desc=*/[{Check that this is constant op. 33 | }], 34 | /*retType=*/"void", 35 | /*methodName=*/"processOperandsPARS", 36 | /*args=*/(ins "int64_t":$param), 37 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return $_op.processOperandsEVA(param); }] 38 | >, 39 | InterfaceMethod< 40 | /*desc=*/[{Check that this is constant op. 
41 | }], 42 | /*retType=*/"void", 43 | /*methodName=*/"processResultsPARS", 44 | /*args=*/(ins "int64_t":$param), 45 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return $_op.processResultsEVA(param); }] 46 | >, 47 | InterfaceMethod< 48 | /*desc=*/[{Check that this is constant op. 49 | }], 50 | /*retType=*/"void", 51 | /*methodName=*/"processOperandsSNR", 52 | /*args=*/(ins "int64_t":$param), 53 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return $_op.processOperandsPARS(param); }] 54 | >, 55 | InterfaceMethod< 56 | /*desc=*/[{Check that this is constant op. 57 | }], 58 | /*retType=*/"void", 59 | /*methodName=*/"processResultsSNR", 60 | /*args=*/(ins "int64_t":$param), 61 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return $_op.processResultsPARS(param); }] 62 | > 63 | ]; 64 | } 65 | 66 | #endif 67 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/IR/HEParameterInterface.h: -------------------------------------------------------------------------------- 1 | #ifndef HECATE_DIALECT_ARITH_IR_HEPARAMETERINTERFACES_H 2 | #define HECATE_DIALECT_ARITH_IR_HEPARAMETERINTERFACES_H 3 | 4 | #include "mlir/IR/BuiltinTypes.h" 5 | #include "mlir/IR/OpDefinition.h" 6 | 7 | #include "hecate/Dialect/Earth/IR/HEParameterInterfaceTypes.h.inc" 8 | 9 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h.inc" 10 | 11 | #endif // MLIR_DIALECT_AFFINE_IR_AFFINEMEMORYOPINTERFACES_H 12 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/IR/HEParameterInterface.td: -------------------------------------------------------------------------------- 1 | #ifndef HE_PARAMETER_INTERFACE 2 | #define HE_PARAMETER_INTERFACE 3 | 4 | include "mlir/IR/OpBase.td" 5 | include "mlir/IR/OpBase.td" 6 | 7 | def HEScaleOpInterface : OpInterface<"HEScaleOpInterface"> { 8 | let description = [{ 9 | An op interface of CKKS parameter optimization. 
10 | Ops that implement this interface can be analyzed and managed by scale management scheme. 11 | }]; 12 | let cppNamespace = "::hecate::earth"; 13 | let methods = [ 14 | InterfaceMethod< 15 | /*desc=*/[{Check that this is constant op. 16 | }], 17 | /*retType=*/"bool", 18 | /*methodName=*/"isConst", 19 | /*args=*/(ins ), 20 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return $_op->template hasTrait(); }] 21 | >, 22 | InterfaceMethod< 23 | /*desc=*/[{Check that there is only one ciphertext operand. 24 | }], 25 | /*retType=*/"bool", 26 | /*methodName=*/"isSingle" 27 | >, 28 | InterfaceMethod< 29 | /*desc=*/[{Check that the scale consumption is appeared. 30 | }], 31 | /*retType=*/"bool", 32 | /*methodName=*/"isConsume" 33 | >, 34 | InterfaceMethod< 35 | /*desc=*/[{Check that the scale consumption is appeared. 36 | }], 37 | /*retType=*/"bool", 38 | /*methodName=*/"isNoisy", 39 | /*args=*/(ins ), 40 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return false; }] 41 | >, 42 | InterfaceMethod< 43 | /*desc=*/[{Get an operand scale type of the operation. 44 | }], 45 | /*retType=*/"::hecate::earth::HEScaleTypeInterface", 46 | /*methodName=*/"getOperandScaleType", 47 | /*args=*/(ins "unsigned":$idx) 48 | >, 49 | InterfaceMethod< 50 | /*desc=*/[{Get a result scale type of the operation. 51 | }], 52 | /*retType=*/"::hecate::earth::HEScaleTypeInterface", 53 | /*methodName=*/"getScaleType" 54 | >, 55 | InterfaceMethod< 56 | /*desc=*/[{Get a level of the result. 57 | }], 58 | /*retType=*/"bool", 59 | /*methodName=*/"isOperandCipher", 60 | /*args=*/(ins "unsigned":$idx), 61 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return $_op.getOperandScaleType(idx).isCipher(); }] 62 | >, 63 | InterfaceMethod< 64 | /*desc=*/[{Get a level of the result. 
65 | }], 66 | /*retType=*/"unsigned", 67 | /*methodName=*/"getOperandLevel", 68 | /*args=*/(ins "unsigned":$idx), 69 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return $_op.getOperandScaleType(idx).getLevel(); }] 70 | >, 71 | InterfaceMethod< 72 | /*desc=*/[{Get a level of the result. 73 | }], 74 | /*retType=*/"unsigned", 75 | /*methodName=*/"getOperandScale", 76 | /*args=*/(ins "unsigned":$idx), 77 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return $_op.getOperandScaleType(idx).getScale(); }] 78 | >, 79 | InterfaceMethod< 80 | /*desc=*/[{Get a level of the result. 81 | }], 82 | /*retType=*/"bool", 83 | /*methodName=*/"isCipher", 84 | /*args=*/(ins), 85 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return $_op.getScaleType().isCipher(); }] 86 | >, 87 | InterfaceMethod< 88 | /*desc=*/[{Get a level of the result. 89 | }], 90 | /*retType=*/"unsigned", 91 | /*methodName=*/"getRescaleLevel", 92 | /*args=*/(ins), 93 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return $_op.getScaleType().getLevel(); }] 94 | >, 95 | InterfaceMethod< 96 | /*desc=*/[{Get a scale of the result. 97 | }], 98 | /*retType=*/"unsigned", 99 | /*methodName=*/"getScale", 100 | /*args=*/(ins), 101 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ return $_op.getScaleType().getScale(); }] 102 | > 103 | ]; 104 | } 105 | 106 | 107 | def HEProfInterface : OpInterface<"HEProfInterface"> { 108 | let description = [{ 109 | Profile data reader interface for profile-based scale optimization. 110 | This interface depends on the HEScaleOpInterface. 111 | }]; 112 | let cppNamespace = "::hecate::earth"; 113 | let methods = [ 114 | InterfaceMethod< 115 | /*desc=*/[{Get an estimated latency of the given operation on a given level. 116 | }], 117 | /*retType=*/"unsigned", 118 | /*methodName=*/"getLatencyOf", 119 | /*args=*/(ins "unsigned":$level) 120 | >, 121 | InterfaceMethod< 122 | /*desc=*/[{Get an estimated latency of the given operation on a given level. 
123 | }], 124 | /*retType=*/"unsigned", 125 | /*methodName=*/"getNum", 126 | /*args=*/(ins), 127 | /*methodBody=*/[{}], 128 | /*defaultImplementation=*/[{ 129 | int num = 0; 130 | for (auto&& v : $_op->getResults()){ 131 | if (auto ranked = v.getType().template dyn_cast()){ 132 | num += ranked.getNumElements(); 133 | } else { 134 | num +=1; 135 | } 136 | } 137 | return num; 138 | }] 139 | >, 140 | InterfaceMethod< 141 | /*desc=*/[{Get a scale of the result. 142 | }], 143 | /*retType=*/"unsigned", 144 | /*methodName=*/"getNoiseScale", 145 | /*args=*/(ins) 146 | >, 147 | InterfaceMethod< 148 | /*desc=*/[{Check that the scale consumption is appeared. 149 | }], 150 | /*retType=*/"unsigned", 151 | /*methodName=*/"getCipherLevel" 152 | >, 153 | InterfaceMethod< 154 | /*desc=*/[{Get an estimated noise of the given operation on a given level. 155 | }], 156 | /*retType=*/"double", 157 | /*methodName=*/"getNoiseOf", 158 | /*args=*/(ins "unsigned":$level) 159 | >, 160 | InterfaceMethod< 161 | /*desc=*/[{Get an estimated latency of the given operation. 162 | }], 163 | /*retType=*/"unsigned", 164 | /*methodName=*/"getLatency", 165 | /*args=*/(ins), 166 | /*methodBody=*/[{}], 167 | /*defaultImplementation=*/[{ return $_op.getLatencyOf($_op.getCipherLevel()); }] 168 | >, 169 | InterfaceMethod< 170 | /*desc=*/[{Get an estimated noise of the given operation. 171 | }], 172 | /*retType=*/"double", 173 | /*methodName=*/"getNoise", 174 | /*args=*/(ins), 175 | /*methodBody=*/[{}], 176 | /*defaultImplementation=*/[{ return $_op.getNoiseOf($_op.getCipherLevel()); }] 177 | >, 178 | InterfaceMethod< 179 | /*desc=*/[{Get an estimated first derivative of the latency of the given operation on a given level. 
180 | }], 181 | /*retType=*/"unsigned", 182 | /*methodName=*/"getLatencyDiffOf", 183 | /*args=*/(ins "unsigned":$level), 184 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ 185 | if (level >0) 186 | return $_op.getLatencyOf(level)-$_op.getLatencyOf(level-1); 187 | else 188 | return $_op.getLatencyOf(level); 189 | }] 190 | >, 191 | InterfaceMethod< 192 | /*desc=*/[{Get an estimated first derivative of the noise of the given operation on a given level. 193 | }], 194 | /*retType=*/"double", 195 | /*methodName=*/"getNoiseDiffOf", 196 | /*args=*/(ins "unsigned":$level), 197 | /*methodBody=*/[{}], /*defaultImplementation=*/[{ 198 | if (level >0) 199 | return $_op.getNoiseOf(level)-$_op.getNoiseOf(level-1); 200 | else 201 | return $_op.getNoiseOf(level); 202 | }] 203 | >, 204 | InterfaceMethod< 205 | /*desc=*/[{Get an estimated first derivative of the latency of the given operation. 206 | }], 207 | /*retType=*/"unsigned", 208 | /*methodName=*/"getLatencyDiff", 209 | /*args=*/(ins), 210 | /*methodBody=*/[{}], 211 | /*defaultImplementation=*/[{ return getLatencyDiffOf($_op.getCipherLevel()); }] 212 | >, 213 | InterfaceMethod< 214 | /*desc=*/[{Get an estimated first derivative of the noise of the given operation . 215 | }], 216 | /*retType=*/"double", 217 | /*methodName=*/"getNoiseDiff", 218 | /*args=*/(ins), 219 | /*methodBody=*/[{}], 220 | /*defaultImplementation=*/[{ return getNoiseDiffOf($_op.getCipherLevel()); }] 221 | > 222 | ]; 223 | } 224 | 225 | 226 | def HEAutoDiffInterface : OpInterface<"HEAutoDiffInterface"> { 227 | let description = [{ 228 | Auto differentiation framework for HE. 229 | Assuming element-wise differentiation and ciphertext-wise result to reduce the overhead. 230 | fp64 elements takes 512KiB per each ciphertext. 231 | Use backward analysis for diff and forward analysis for value. 
232 | }]; 233 | let cppNamespace = "::hecate::earth"; 234 | let methods = [ 235 | InterfaceMethod< 236 | /*desc=*/[{Analyze error inference structure by automatic differentiation. 237 | }], 238 | /*retType=*/"::llvm::SmallVector", 239 | /*methodName=*/"differentiate", 240 | /*args=*/(ins "::llvm::ArrayRef":$gradient,"::llvm::ArrayRef":$estimation) 241 | >, 242 | InterfaceMethod< 243 | /*desc=*/[{Estimate value. 244 | }], 245 | /*retType=*/"::llvm::SmallVector", 246 | /*methodName=*/"estimateValue", 247 | /*args=*/(ins "::llvm::ArrayRef":$estimation) 248 | > 249 | ]; 250 | } 251 | 252 | def HEScaleTypeInterface : TypeInterface<"HEScaleTypeInterface"> { 253 | 254 | let description = [{ 255 | A type interface of CKKS parameter optimization. 256 | Type that implement this interface can be analyzed and managed by scale management scheme. 257 | }]; 258 | let cppNamespace = "::hecate::earth"; 259 | let methods = [ 260 | InterfaceMethod< 261 | /*desc=*/[{Check that there is only one ciphertext operand. 262 | }], 263 | /*retType=*/"bool", 264 | /*methodName=*/"isCipher" 265 | >, 266 | InterfaceMethod< 267 | /*desc=*/[{Check that there is only one ciphertext operand. 268 | }], 269 | /*retType=*/"::hecate::earth::HEScaleTypeInterface", 270 | /*methodName=*/"toCipher" 271 | >, 272 | InterfaceMethod< 273 | /*desc=*/[{Check that there is only one ciphertext operand. 274 | }], 275 | /*retType=*/"::hecate::earth::HEScaleTypeInterface", 276 | /*methodName=*/"toPlain" 277 | >, 278 | InterfaceMethod< 279 | /*desc=*/[{Get a scale of a given type. 280 | }], 281 | /*retType=*/"unsigned", 282 | /*methodName=*/"getScale" 283 | >, 284 | InterfaceMethod< 285 | /*desc=*/[{Get a scale of a given type. 286 | }], 287 | /*retType=*/"::hecate::earth::HEScaleTypeInterface", 288 | /*methodName=*/"switchScale", 289 | /*args=*/(ins "unsigned":$scale) 290 | >, 291 | InterfaceMethod< 292 | /*desc=*/[{Get a level of a given type. This can be either rescale or ciphertext level. 
293 | }], 294 | /*retType=*/"unsigned", 295 | /*methodName=*/"getLevel" 296 | >, 297 | InterfaceMethod< 298 | /*desc=*/[{Get a scale of a given type. 299 | }], 300 | /*retType=*/"::hecate::earth::HEScaleTypeInterface", 301 | /*methodName=*/"switchLevel", 302 | /*args=*/(ins "unsigned":$level) 303 | >, 304 | ]; 305 | } 306 | 307 | 308 | 309 | #endif 310 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/Transforms/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(LLVM_TARGET_DEFINITIONS Passes.td) 2 | mlir_tablegen(Passes.h.inc -gen-pass-decls -name Earth) 3 | add_public_tablegen_target(HecateEarthTransformsIncGen) 4 | 5 | # add_mlir_doc(Passes EarthPasses ./ -gen-pass-doc) 6 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/Transforms/Common.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_EARTH_TRANSFROMS_COMMON 3 | #define HECATE_EARTH_TRANSFROMS_COMMON 4 | 5 | #include "hecate/Dialect/Earth/Analysis/ScaleManagementUnit.h" 6 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 7 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 8 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 9 | #include "mlir/Dialect/Func/IR/FuncOps.h" 10 | 11 | namespace hecate { 12 | namespace earth { 13 | 14 | void refineReturnValues(mlir::func::FuncOp func, mlir::OpBuilder builder, 15 | llvm::SmallVector inputTypes, 16 | int64_t waterline, int64_t output_val); 17 | void inferTypeForward(hecate::earth::ForwardMgmtInterface sop); 18 | 19 | } // namespace earth 20 | } // namespace hecate 21 | 22 | #endif 23 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/Transforms/Passes.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | #ifndef 
HECATE_DIALECT_EARTH_TRANSFORMS_PASSES_H_ 4 | #define HECATE_DIALECT_EARTH_TRANSFORMS_PASSES_H_ 5 | 6 | #include "mlir/Pass/Pass.h" 7 | #include 8 | 9 | namespace hecate { 10 | 11 | namespace earth { 12 | 13 | #define GEN_PASS_DECL 14 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 15 | 16 | /* #define GEN_PASS_DECL_ARITHINTRANGEOPTS */ 17 | /* #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" */ 18 | 19 | //===----------------------------------------------------------------------===// 20 | // Registration 21 | //===----------------------------------------------------------------------===// 22 | 23 | /// Generate the code for registering passes. 24 | #define GEN_PASS_REGISTRATION 25 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 26 | 27 | } // namespace earth 28 | } // namespace hecate 29 | #include "hecate/Dialect/Earth/Analysis/ScaleManagementUnit.h" 30 | 31 | #endif 32 | -------------------------------------------------------------------------------- /include/hecate/Dialect/Earth/Transforms/Passes.td: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_DIALECT_EARTH_TRANSFORMS_PASSES 3 | #define HECATE_DIALECT_EARTH_TRANSFORMS_PASSES 4 | 5 | include "mlir/Pass/PassBase.td" 6 | 7 | def ElideConstant: Pass<"elide-constant", "::mlir::func::FuncOp"> { 8 | let summary = "Elide Constants"; 9 | let description = [{ 10 | This pass saves the constant datas in a {function name}.cst file. 11 | The file contains the list of the (length, data) pair. 12 | The constant value will be changed to the index. 13 | }]; 14 | let options = [ 15 | Option<"name", "name", "std::string", /*default=*/[{""}], 16 | "Name of output file"> 17 | ]; 18 | } 19 | def PrivatizeConstant: Pass<"privatize-constant", "::mlir::func::FuncOp"> { 20 | let summary = "Privatize Constants"; 21 | let description = [{ 22 | All of the constant op should have a single use after this pass. 
23 | }]; 24 | } 25 | def WaterlineRescaling : Pass<"waterline-rescaling", "::mlir::func::FuncOp"> { 26 | let summary = "Apply Waterline Rescaling"; 27 | let description = [{ 28 | This pass implements the waterline rescaling of EVA. 29 | }]; 30 | let options = [ 31 | Option<"waterline", "waterline", "int64_t", /*default=*/"20", 32 | "The minimal result scale of rescaling">, 33 | Option<"output_val", "output_val", "int64_t", /*default=*/"10", 34 | "The maximum result value bits of return value">, 35 | ]; 36 | } 37 | def SNRRescaling : Pass<"snr-rescaling", "::mlir::func::FuncOp"> { 38 | let summary = "Apply SNR Rescaling"; 39 | let description = [{ 40 | This pass implements the SNR rescaling of ELASM. 41 | }]; 42 | let options = [ 43 | Option<"waterline", "waterline", "int64_t", /*default=*/"20", 44 | "The minimal result scale of rescaling">, 45 | Option<"output_val", "output_val", "int64_t", /*default=*/"10", 46 | "The maximum result value bits of return value">, 47 | ]; 48 | } 49 | def ProactiveRescaling : Pass<"proactive-rescaling", "::mlir::func::FuncOp"> { 50 | let summary = "Apply Proactive Rescaling"; 51 | let description = [{ 52 | This pass implements the proactive rescaling of Hecate. 53 | }]; 54 | let options = [ 55 | Option<"waterline", "waterline", "int64_t", /*default=*/"20", 56 | "The minimal result scale of rescaling">, 57 | Option<"output_val", "output_val", "int64_t", /*default=*/"10", 58 | "The maximum result value bits of return value">, 59 | ]; 60 | } 61 | def EarlyModswitch: Pass<"early-modswitch", "::mlir::func::FuncOp"> { 62 | let summary = "Apply Early Modswitch"; 63 | let description = [{ 64 | This pass implements the early modswitch of EVA. 65 | }]; 66 | } 67 | def UpscaleBubbling : Pass<"upscale-bubbling", "::mlir::func::FuncOp"> { 68 | let summary = "Apply Upscale Bubbling"; 69 | let description = [{ 70 | This pass may accelerate the convergence of ELASM. 71 | This pass is not proposed on the ELASM paper. 
72 | }]; 73 | } 74 | def SMUChecker : Pass<"check-smu", "::mlir::func::FuncOp"> { 75 | let summary = "Check the correctness of SMU generation"; 76 | let description = [{ 77 | This pass checks the correctness of SMU generation. 78 | }]; 79 | } 80 | def SMUEmbedding : Pass<"embed-smu", "::mlir::func::FuncOp"> { 81 | let summary = "Embed SMU IDs to the operation"; 82 | let description = [{ 83 | This pass embeds the SMU analysis. Embdding analysis means that the SMU analysis of copy of the function can be recovered. 84 | }]; 85 | } 86 | def ScaleManagementScheduler : Pass<"schedule-scale-management", "::mlir::func::FuncOp"> { 87 | let summary = "Schedule the scale management with apply_schedule op"; 88 | let description = [{ 89 | This pass schedules the scale management with apply_schedule op }]; 90 | } 91 | def ELASMExplorer : Pass<"elasm-explorer", "::mlir::func::FuncOp"> { 92 | let summary = "Error-latency-aware scale management driver pass"; 93 | let description = [{ 94 | This pass iteratively generates schedule, applies schedule, generates code, evaluate the error and latency.}]; 95 | let options = [ 96 | Option<"waterline", "waterline", "int64_t", /*default=*/"20", 97 | "The minimal result scale of rescaling">, 98 | Option<"output_val", "output_val", "int64_t", /*default=*/"10", 99 | "The maximum result value bits of return value">, 100 | Option<"parallel", "parallel", "int64_t", /*default=*/"20", 101 | "The number of parallel explorer">, 102 | Option<"num_iter", "num_iter", "int64_t", /*default=*/"1000", 103 | "The iteration count of explorer">, 104 | Option<"beta", "beta", "int64_t", /*default=*/"50", 105 | "The minimal result scale of rescaling">, 106 | Option<"gamma", "gamma", "int64_t", /*default=*/"50", 107 | "The minimal result scale of rescaling">, 108 | ]; 109 | } 110 | def ErrorEstimator : Pass<"estimate-error", "::mlir::func::FuncOp"> { 111 | let summary = "Estimate the error of a given function"; 112 | let description = [{ 113 | This pass estimates 
the resulting error of function and attaches attribute. 114 | This pass requires for function to fully scale managed and all operations have HEScaleOpInterface, 115 | HEProfInterface, and HEAutoDiffInterface}]; 116 | } 117 | def LatencyEstimator : Pass<"estimate-latency", "::mlir::func::FuncOp"> { 118 | let summary = "Estimate the latency of a given function"; 119 | let description = [{ 120 | This pass estimates the overall latency of function and attaches attribute. 121 | This pass requires for function to fully scale managed and all operations have HEScaleOpInterface, 122 | HEProfInterface}]; 123 | } 124 | 125 | /* def ScaleManagementStructuring : Pass<> */ 126 | 127 | 128 | #endif 129 | -------------------------------------------------------------------------------- /include/hecate/Support/HEVMHeader.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_HEVMHEADER 3 | #define HECATE_HEVMHEADER 4 | 5 | #include 6 | extern "C" { 7 | 8 | struct Body; 9 | 10 | struct HEVMHeader { 11 | uint32_t magic_number = 0x4845564D; 12 | uint32_t hevm_header_size; // 8 + config header size byte 13 | struct ConfigHeader { 14 | uint64_t arg_length; 15 | uint64_t res_length; 16 | } config_header; 17 | }; 18 | struct ConfigBody { // All of entry is 64 bit 19 | uint64_t config_body_length; 20 | uint64_t num_operations; 21 | uint64_t num_ctxt_buffer; 22 | uint64_t num_ptxt_buffer; 23 | uint64_t init_level; 24 | /* uint64_t arg_scale[arg_length]; */ 25 | /* uint64_t arg_level[arg_length]; */ 26 | /* uint64_t res_scale[res_length]; */ 27 | /* uint64_t res_level[res_length]; */ 28 | /* uint64_t res_dst[res_length]; */ 29 | }; 30 | struct HEVMOperation { 31 | uint16_t opcode; 32 | uint16_t dst; 33 | uint16_t lhs; 34 | uint16_t rhs; 35 | }; // 64 bit 36 | 37 | } // namespace "C" 38 | #endif 39 | -------------------------------------------------------------------------------- /include/hecate/Support/Support.h: 
-------------------------------------------------------------------------------- 1 | 2 | #ifndef HECATE_SUPPORT_SUPPORT 3 | #define HECATE_SUPPORT_SUPPORT 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | namespace hecate { 10 | 11 | inline llvm::SmallVector naf(int value) { 12 | llvm::SmallVector res; 13 | 14 | // Record the sign of the original value and compute abs 15 | bool sign = value < 0; 16 | value = std::abs(value); 17 | 18 | // Transform to non-adjacent form (NAF) 19 | for (int i = 0; value; i++) { 20 | int zi = (value & int(0x1)) ? 2 - (value & int(0x3)) : 0; 21 | value = (value - zi) >> 1; 22 | if (zi) { 23 | res.push_back((sign ? -zi : zi) * (1 << i)); 24 | } 25 | } 26 | 27 | return res; 28 | } 29 | 30 | inline void setIntegerAttr(llvm::StringRef name, mlir::Value v, int64_t data) { 31 | unsigned argnum = 0; 32 | mlir::Operation *op = nullptr; 33 | if (auto ba = v.dyn_cast()) { 34 | argnum = ba.getArgNumber(); 35 | op = ba.getOwner()->getParentOp(); 36 | } else if (auto opr = v.dyn_cast()) { 37 | argnum = opr.getResultNumber(); 38 | op = opr.getOwner(); 39 | } else { 40 | assert(0 && "Value should be either block argument or op result"); 41 | } 42 | auto builder = mlir::OpBuilder(op); 43 | op->setAttr(std::string(name) + std::to_string(argnum), 44 | builder.getI64IntegerAttr(data)); 45 | } 46 | 47 | inline int64_t getIntegerAttr(llvm::StringRef name, mlir::Value v) { 48 | unsigned argnum = 0; 49 | mlir::Operation *op = nullptr; 50 | if (auto ba = v.dyn_cast()) { 51 | argnum = ba.getArgNumber(); 52 | op = ba.getOwner()->getParentOp(); 53 | } else if (auto opr = v.dyn_cast()) { 54 | argnum = opr.getResultNumber(); 55 | op = opr.getOwner(); 56 | } else { 57 | assert(0 && "Value should be either block argument or op result"); 58 | } 59 | if (auto attr = op->getAttr(std::string(name) + std::to_string(argnum))) { 60 | return attr.dyn_cast().getInt(); 61 | } else { 62 | return -1; 63 | } 64 | } 65 | 66 | } // namespace hecate 67 | 68 | #endif 69 | 
-------------------------------------------------------------------------------- /include/nlohmann/json_fwd.hpp: -------------------------------------------------------------------------------- 1 | // __ _____ _____ _____ 2 | // __| | __| | | | JSON for Modern C++ 3 | // | | |__ | | | | | | version 3.11.2 4 | // |_____|_____|_____|_|___| https://github.com/nlohmann/json 5 | // 6 | // SPDX-FileCopyrightText: 2013-2022 Niels Lohmann 7 | // SPDX-License-Identifier: MIT 8 | 9 | #ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_ 10 | #define INCLUDE_NLOHMANN_JSON_FWD_HPP_ 11 | 12 | #include // int64_t, uint64_t 13 | #include // map 14 | #include // allocator 15 | #include // string 16 | #include // vector 17 | 18 | // #include 19 | // __ _____ _____ _____ 20 | // __| | __| | | | JSON for Modern C++ 21 | // | | |__ | | | | | | version 3.11.2 22 | // |_____|_____|_____|_|___| https://github.com/nlohmann/json 23 | // 24 | // SPDX-FileCopyrightText: 2013-2022 Niels Lohmann 25 | // SPDX-License-Identifier: MIT 26 | 27 | 28 | 29 | // This file contains all macro definitions affecting or depending on the ABI 30 | 31 | #ifndef JSON_SKIP_LIBRARY_VERSION_CHECK 32 | #if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH) 33 | #if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 2 34 | #warning "Already included a different version of the library!" 
35 | #endif 36 | #endif 37 | #endif 38 | 39 | #define NLOHMANN_JSON_VERSION_MAJOR 3 // NOLINT(modernize-macro-to-enum) 40 | #define NLOHMANN_JSON_VERSION_MINOR 11 // NOLINT(modernize-macro-to-enum) 41 | #define NLOHMANN_JSON_VERSION_PATCH 2 // NOLINT(modernize-macro-to-enum) 42 | 43 | #ifndef JSON_DIAGNOSTICS 44 | #define JSON_DIAGNOSTICS 0 45 | #endif 46 | 47 | #ifndef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 48 | #define JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 0 49 | #endif 50 | 51 | #if JSON_DIAGNOSTICS 52 | #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS _diag 53 | #else 54 | #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS 55 | #endif 56 | 57 | #if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 58 | #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON _ldvcmp 59 | #else 60 | #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON 61 | #endif 62 | 63 | #ifndef NLOHMANN_JSON_NAMESPACE_NO_VERSION 64 | #define NLOHMANN_JSON_NAMESPACE_NO_VERSION 0 65 | #endif 66 | 67 | // Construct the namespace ABI tags component 68 | #define NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) json_abi ## a ## b 69 | #define NLOHMANN_JSON_ABI_TAGS_CONCAT(a, b) \ 70 | NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b) 71 | 72 | #define NLOHMANN_JSON_ABI_TAGS \ 73 | NLOHMANN_JSON_ABI_TAGS_CONCAT( \ 74 | NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS, \ 75 | NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON) 76 | 77 | // Construct the namespace version component 78 | #define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) \ 79 | _v ## major ## _ ## minor ## _ ## patch 80 | #define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(major, minor, patch) \ 81 | NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) 82 | 83 | #if NLOHMANN_JSON_NAMESPACE_NO_VERSION 84 | #define NLOHMANN_JSON_NAMESPACE_VERSION 85 | #else 86 | #define NLOHMANN_JSON_NAMESPACE_VERSION \ 87 | NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(NLOHMANN_JSON_VERSION_MAJOR, \ 88 | NLOHMANN_JSON_VERSION_MINOR, \ 89 | 
NLOHMANN_JSON_VERSION_PATCH) 90 | #endif 91 | 92 | // Combine namespace components 93 | #define NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) a ## b 94 | #define NLOHMANN_JSON_NAMESPACE_CONCAT(a, b) \ 95 | NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) 96 | 97 | #ifndef NLOHMANN_JSON_NAMESPACE 98 | #define NLOHMANN_JSON_NAMESPACE \ 99 | nlohmann::NLOHMANN_JSON_NAMESPACE_CONCAT( \ 100 | NLOHMANN_JSON_ABI_TAGS, \ 101 | NLOHMANN_JSON_NAMESPACE_VERSION) 102 | #endif 103 | 104 | #ifndef NLOHMANN_JSON_NAMESPACE_BEGIN 105 | #define NLOHMANN_JSON_NAMESPACE_BEGIN \ 106 | namespace nlohmann \ 107 | { \ 108 | inline namespace NLOHMANN_JSON_NAMESPACE_CONCAT( \ 109 | NLOHMANN_JSON_ABI_TAGS, \ 110 | NLOHMANN_JSON_NAMESPACE_VERSION) \ 111 | { 112 | #endif 113 | 114 | #ifndef NLOHMANN_JSON_NAMESPACE_END 115 | #define NLOHMANN_JSON_NAMESPACE_END \ 116 | } /* namespace (inline namespace) NOLINT(readability/namespace) */ \ 117 | } // namespace nlohmann 118 | #endif 119 | 120 | 121 | /*! 122 | @brief namespace for Niels Lohmann 123 | @see https://github.com/nlohmann 124 | @since version 1.0.0 125 | */ 126 | NLOHMANN_JSON_NAMESPACE_BEGIN 127 | 128 | /*! 129 | @brief default JSONSerializer template argument 130 | 131 | This serializer ignores the template arguments and uses ADL 132 | ([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl)) 133 | for serialization. 
134 | */ 135 | template 136 | struct adl_serializer; 137 | 138 | /// a class to store JSON values 139 | /// @sa https://json.nlohmann.me/api/basic_json/ 140 | template class ObjectType = 141 | std::map, 142 | template class ArrayType = std::vector, 143 | class StringType = std::string, class BooleanType = bool, 144 | class NumberIntegerType = std::int64_t, 145 | class NumberUnsignedType = std::uint64_t, 146 | class NumberFloatType = double, 147 | template class AllocatorType = std::allocator, 148 | template class JSONSerializer = 149 | adl_serializer, 150 | class BinaryType = std::vector, // cppcheck-suppress syntaxError 151 | class CustomBaseClass = void> 152 | class basic_json; 153 | 154 | /// @brief JSON Pointer defines a string syntax for identifying a specific value within a JSON document 155 | /// @sa https://json.nlohmann.me/api/json_pointer/ 156 | template 157 | class json_pointer; 158 | 159 | /*! 160 | @brief default specialization 161 | @sa https://json.nlohmann.me/api/json/ 162 | */ 163 | using json = basic_json<>; 164 | 165 | /// @brief a minimal map-like container that preserves insertion order 166 | /// @sa https://json.nlohmann.me/api/ordered_map/ 167 | template 168 | struct ordered_map; 169 | 170 | /// @brief specialization that maintains the insertion order of object keys 171 | /// @sa https://json.nlohmann.me/api/ordered_json/ 172 | using ordered_json = basic_json; 173 | 174 | NLOHMANN_JSON_NAMESPACE_END 175 | 176 | #endif // INCLUDE_NLOHMANN_JSON_FWD_HPP_ 177 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | cd $HECATE/python/hecate 2 | python setup.py sdist --format=tar 3 | pip uninstall hecate 4 | pip install dist/hecate-0.0.1.tar 5 | -------------------------------------------------------------------------------- /lib/CMakeLists.txt: -------------------------------------------------------------------------------- 1 
| add_subdirectory(Conversion) 2 | add_subdirectory(Dialect) 3 | add_subdirectory(Runtime) 4 | -------------------------------------------------------------------------------- /lib/Conversion/CKKSCommon/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | add_mlir_conversion_library(HecateCKKSCommonConversion 3 | PolyTypeConverter.cpp 4 | 5 | LINK_COMPONENTS 6 | Core 7 | 8 | LINK_LIBS PUBLIC 9 | EarthDialect 10 | CKKSDialect 11 | MLIRLLVMCommonConversion 12 | MLIRTransforms 13 | ) 14 | -------------------------------------------------------------------------------- /lib/Conversion/CKKSCommon/PolyTypeConverter.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "hecate/Conversion/CKKSCommon/PolyTypeConverter.h" 3 | #include "mlir/Transforms/DialectConversion.h" 4 | 5 | using namespace mlir; 6 | using namespace hecate; 7 | 8 | PolyTypeConverter::PolyTypeConverter(int64_t base_level) 9 | : base_level(base_level) { 10 | addConversion([&](mlir::Type t) { return t; }); 11 | addConversion([&](mlir::FunctionType t) { return convertFunctionType(t); }); 12 | addConversion([&](mlir::RankedTensorType t) { return convertTensorType(t); }); 13 | 14 | addConversion( 15 | [&](hecate::earth::CipherType t) { return convertCipherType(t); }); 16 | addConversion( 17 | [&](hecate::earth::PlainType t) { return convertPlainType(t); }); 18 | 19 | /* addConversion([&](hecate::earth::HEScaleTypeInterface t) { */ 20 | /* return convertScaleType(t); */ 21 | /* }); */ 22 | } 23 | 24 | mlir::Type PolyTypeConverter::convertFunctionType(mlir::FunctionType t) { 25 | mlir::SmallVector inputTys; 26 | mlir::SmallVector outputTys; 27 | for (auto &&t : t.getInputs()) { 28 | inputTys.push_back(convertType(t)); 29 | } 30 | for (auto &&t : t.getResults()) { 31 | outputTys.push_back(convertType(t)); 32 | } 33 | return mlir::FunctionType::get(t.getContext(), inputTys, outputTys); 34 | } 35 | 36 | mlir::Type 
PolyTypeConverter::convertTensorType(mlir::TensorType t) { 37 | return mlir::RankedTensorType::get(t.getShape(), 38 | convertType(t.getElementType())); 39 | } 40 | 41 | mlir::Type PolyTypeConverter::convertCipherType(hecate::earth::CipherType t) { 42 | return hecate::ckks::PolyType::get(t.getContext(), 2, 43 | base_level - t.getLevel()); 44 | } 45 | mlir::Type PolyTypeConverter::convertPlainType(hecate::earth::PlainType t) { 46 | return hecate::ckks::PolyType::get(t.getContext(), 1, 47 | base_level - t.getLevel()); 48 | } 49 | /* mlir::Type */ 50 | /* PolyTypeConverter::convertScaleType(hecate::earth::HEScaleTypeInterface t) { 51 | */ 52 | /* return hecate::ckks::PolyType::get(t.getContext(), t.isCipher() ? 2 : 1, */ 53 | /* base_level - t.getLevel()); */ 54 | /* } */ 55 | -------------------------------------------------------------------------------- /lib/Conversion/CKKSToCKKS/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | add_mlir_conversion_library(HecateCKKSToCKKS 3 | UpscaleToMulcp.cpp 4 | 5 | DEPENDS 6 | HecateConversionPassIncGen 7 | 8 | LINK_COMPONENTS 9 | Core 10 | 11 | LINK_LIBS PUBLIC 12 | EarthDialect 13 | CKKSDialect 14 | MLIRLLVMCommonConversion 15 | MLIRTransforms 16 | ) 17 | -------------------------------------------------------------------------------- /lib/Conversion/CKKSToCKKS/UpscaleToMulcp.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "hecate/Conversion/CKKSToCKKS/UpscaleToMulcp.h" 4 | #include "hecate/Conversion/CKKSCommon/PolyTypeConverter.h" 5 | 6 | #include "hecate/Dialect/CKKS/IR/CKKSOps.h" 7 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 8 | #include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" 9 | #include "mlir/Dialect/Tensor/IR/Tensor.h" 10 | #include "mlir/IR/TypeUtilities.h" 11 | #include "mlir/Pass/Pass.h" 12 | #include "mlir/Transforms/DialectConversion.h" 13 | #include 14 | 15 | namespace hecate { 16 | 
#define GEN_PASS_DEF_UPSCALETOMULCPCONVERSION 17 | #include "hecate/Conversion/Passes.h.inc" 18 | } // namespace hecate 19 | 20 | using namespace mlir; 21 | using namespace hecate; 22 | 23 | namespace { 24 | 25 | //===----------------------------------------------------------------------===// 26 | // Straightforward Op Lowerings 27 | //===----------------------------------------------------------------------===// 28 | 29 | //===----------------------------------------------------------------------===// 30 | // Op Lowering Patterns 31 | //===----------------------------------------------------------------------===// 32 | 33 | /// Directly lower to LLVM op. 34 | 35 | struct UpscaleCOpLowering 36 | : public OpConversionPattern { 37 | using OpConversionPattern::ConversionPattern; 38 | UpscaleCOpLowering(MLIRContext *ctxt) 39 | : OpConversionPattern(ctxt) {} 40 | 41 | LogicalResult 42 | matchAndRewrite(hecate::ckks::UpscaleCOp op, OpAdaptor adaptor, 43 | ConversionPatternRewriter &rewriter) const override; 44 | }; 45 | 46 | } // namespace 47 | 48 | //===----------------------------------------------------------------------===// 49 | // UpscaleOpLowering 50 | //===----------------------------------------------------------------------===// 51 | 52 | LogicalResult 53 | UpscaleCOpLowering::matchAndRewrite(hecate::ckks::UpscaleCOp op, 54 | OpAdaptor adaptor, 55 | ConversionPatternRewriter &rewriter) const { 56 | auto tt = 57 | op.getType().getElementType().dyn_cast(); 58 | 59 | auto dst = rewriter.create( 60 | op.getLoc(), op.getType().getShape(), tt.switchNumPoly(1)); 61 | 62 | auto rhs = rewriter.create( 63 | op.getLoc(), dst, -1, adaptor.getUpFactor(), tt.getLevel()); 64 | 65 | rewriter.replaceOpWithNewOp(op, adaptor.getDst(), 66 | adaptor.getSrc(), rhs); 67 | 68 | /* rewriter.replaceOpWithNewOp(op, dst, adaptor.getValue(), 69 | */ 70 | /* adaptor.getUpFactor()); */ 71 | return success(); 72 | } 73 | 74 | 
//===----------------------------------------------------------------------===// 75 | // Pass Definition 76 | //===----------------------------------------------------------------------===// 77 | 78 | namespace { 79 | struct UpscaleToMulcpConversion 80 | : public hecate::impl::UpscaleToMulcpConversionBase< 81 | UpscaleToMulcpConversion> { 82 | 83 | using Base::Base; 84 | 85 | void runOnOperation() override { 86 | ConversionTarget target(getContext()); 87 | 88 | auto func = getOperation(); 89 | 90 | mlir::RewritePatternSet patterns(&getContext()); 91 | 92 | target.addIllegalOp(); 93 | target.addLegalDialect(); 94 | target.addLegalDialect(); 95 | target.addLegalDialect(); 96 | 97 | hecate::ckks::populateUpscaleToMulcpConversionPatterns(&getContext(), 98 | patterns); 99 | 100 | if (failed(applyPartialConversion(getOperation(), target, 101 | std::move(patterns)))) 102 | signalPassFailure(); 103 | } 104 | }; 105 | } // namespace 106 | 107 | //===----------------------------------------------------------------------===// 108 | // Pattern Population 109 | //===----------------------------------------------------------------------===// 110 | // 111 | void hecate::ckks::populateUpscaleToMulcpConversionPatterns( 112 | mlir::MLIRContext *ctxt, mlir::RewritePatternSet &patterns) { 113 | // clang-format off 114 | patterns.add (ctxt); 115 | 116 | // clang-format on 117 | } 118 | std::unique_ptr<::mlir::OperationPass<::mlir::func::FuncOp>> 119 | hecate::ckks::createUpscaleToMulcpConversionPass() { 120 | return std::make_unique(); 121 | } 122 | -------------------------------------------------------------------------------- /lib/Conversion/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(CKKSCommon) 2 | add_subdirectory(EarthToCKKS) 3 | add_subdirectory(CKKSToCKKS) 4 | -------------------------------------------------------------------------------- /lib/Conversion/EarthToCKKS/CMakeLists.txt: 
-------------------------------------------------------------------------------- 1 | add_mlir_conversion_library(HecateEarthToCKKS 2 | EarthToCKKS.cpp 3 | 4 | DEPENDS 5 | HecateConversionPassIncGen 6 | 7 | LINK_COMPONENTS 8 | Core 9 | 10 | LINK_LIBS PUBLIC 11 | EarthDialect 12 | CKKSDialect 13 | MLIRLLVMCommonConversion 14 | MLIRTransforms 15 | ) 16 | -------------------------------------------------------------------------------- /lib/Dialect/CKKS/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(IR) 2 | add_subdirectory(Transforms) 3 | -------------------------------------------------------------------------------- /lib/Dialect/CKKS/IR/CKKSDialect.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "hecate/Dialect/CKKS/IR/CKKSOps.h" 3 | #include "mlir/IR/DialectImplementation.h" 4 | #include "mlir/IR/PatternMatch.h" 5 | #include "mlir/IR/TypeSupport.h" 6 | #include "mlir/IR/Types.h" 7 | #include "mlir/Support/LLVM.h" 8 | #include "llvm/ADT/SmallString.h" 9 | #include "llvm/ADT/TypeSwitch.h" 10 | 11 | using namespace mlir; 12 | 13 | #include "hecate/Dialect/CKKS/IR/PolyTypeInterface.h" 14 | 15 | #include "hecate/Dialect/CKKS/IR/PolyTypeInterfaceTypes.cpp.inc" 16 | 17 | #include "hecate/Dialect/CKKS/IR/PolyTypeInterface.cpp.inc" 18 | 19 | #define GET_TYPEDEF_CLASSES 20 | #include "hecate/Dialect/CKKS/IR/CKKSOpsTypes.cpp.inc" 21 | 22 | #define GET_OP_CLASSES 23 | #include "hecate/Dialect/CKKS/IR/CKKSOps.cpp.inc" 24 | 25 | #include "hecate/Dialect/CKKS/IR/CKKSOpsDialect.cpp.inc" 26 | 27 | /* #include "hecate/Dialect/CKKS/IR/EarthCanonicalizerPattern.inc" */ 28 | 29 | struct PolyTypeTensorModel 30 | : public hecate::ckks::PolyTypeInterface::ExternalModel< 31 | PolyTypeTensorModel, mlir::RankedTensorType> { 32 | unsigned getNumPoly(Type t) const { 33 | if (auto polyType = t.dyn_cast() 34 | .getElementType() 35 | .dyn_cast()) { 36 | return 
polyType.getNumPoly(); 37 | } else { 38 | return 0; 39 | } 40 | } 41 | unsigned getLevel(Type t) const { 42 | if (auto polyType = t.dyn_cast() 43 | .getElementType() 44 | .dyn_cast()) { 45 | return polyType.getLevel(); 46 | } else { 47 | return 0; 48 | } 49 | } 50 | 51 | hecate::ckks::PolyTypeInterface switchLevel(Type t, unsigned level) const { 52 | return mlir::RankedTensorType::get( 53 | t.dyn_cast().getShape(), 54 | t.dyn_cast() 55 | .getElementType() 56 | .dyn_cast() 57 | .switchLevel(level)); 58 | } 59 | hecate::ckks::PolyTypeInterface switchNumPoly(Type t, 60 | unsigned num_poly) const { 61 | return mlir::RankedTensorType::get( 62 | t.dyn_cast().getShape(), 63 | t.dyn_cast() 64 | .getElementType() 65 | .dyn_cast() 66 | .switchNumPoly(num_poly)); 67 | } 68 | }; 69 | 70 | void hecate::ckks::CKKSDialect::initialize() { 71 | // Registers all the Types into the EVADialect class 72 | addTypes< 73 | #define GET_TYPEDEF_LIST 74 | #include "hecate/Dialect/CKKS/IR/CKKSOpsTypes.cpp.inc" 75 | >(); 76 | 77 | // Registers all the Operations into the EVADialect class 78 | addOperations< 79 | #define GET_OP_LIST 80 | #include "hecate/Dialect/CKKS/IR/CKKSOps.cpp.inc" 81 | >(); 82 | mlir::RankedTensorType::attachInterface(*getContext()); 83 | } 84 | 85 | ::mlir::LogicalResult hecate::ckks::EncodeOp::inferReturnTypes( 86 | ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 87 | ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 88 | ::mlir::RegionRange regions, 89 | ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 90 | 91 | auto op = EncodeOpAdaptor(operands, attributes, regions); 92 | auto dPoly = ckks::getPolyType(op.getDst()); 93 | if (dPoly.getNumPoly() == 1 && 94 | (dPoly.getLevel() == op.getLevel() || dPoly.getLevel() == 0)) { 95 | inferredReturnTypes.push_back(op.getDst().getType()); 96 | return ::mlir::success(); 97 | } else { 98 | return ::mlir::failure(); 99 | } 100 | } 101 | 102 | ::mlir::LogicalResult 
hecate::ckks::RescaleCOp::inferReturnTypes( 103 | ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 104 | ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 105 | ::mlir::RegionRange regions, 106 | ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 107 | auto op = RescaleCOpAdaptor(operands, attributes, regions); 108 | auto dPoly = ckks::getPolyType(op.getDst()); 109 | auto lPoly = ckks::getPolyType(op.getSrc()); 110 | if (dPoly.getNumPoly() == lPoly.getNumPoly() && 111 | (dPoly.getLevel() == lPoly.getLevel() - 1 || dPoly.getLevel() == 0)) { 112 | inferredReturnTypes.push_back(op.getDst().getType()); 113 | return ::mlir::success(); 114 | } else { 115 | return ::mlir::failure(); 116 | } 117 | } 118 | 119 | ::mlir::LogicalResult hecate::ckks::ModswitchCOp::inferReturnTypes( 120 | ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 121 | ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 122 | ::mlir::RegionRange regions, 123 | ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 124 | auto op = ModswitchCOpAdaptor(operands, attributes, regions); 125 | auto dPoly = ckks::getPolyType(op.getDst()); 126 | auto lPoly = ckks::getPolyType(op.getSrc()); 127 | if (dPoly.getNumPoly() == lPoly.getNumPoly() && 128 | (dPoly.getLevel() == lPoly.getLevel() - op.getDownFactor() || 129 | dPoly.getLevel() == 0)) { 130 | inferredReturnTypes.push_back(op.getDst().getType()); 131 | return ::mlir::success(); 132 | } else { 133 | return ::mlir::failure(); 134 | } 135 | } 136 | 137 | ::mlir::LogicalResult hecate::ckks::AddCPOp::inferReturnTypes( 138 | ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 139 | ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 140 | ::mlir::RegionRange regions, 141 | ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 142 | auto op = AddCPOpAdaptor(operands, attributes, regions); 143 | auto dPoly = 
ckks::getPolyType(op.getDst()); 144 | auto lPoly = ckks::getPolyType(op.getLhs()); 145 | auto rPoly = ckks::getPolyType(op.getRhs()); 146 | if (std::min(lPoly.getNumPoly(), rPoly.getNumPoly()) == 1 && 147 | dPoly.getNumPoly() == std::max(rPoly.getNumPoly(), lPoly.getNumPoly()) && 148 | lPoly.getLevel() == rPoly.getLevel() && 149 | dPoly.getLevel() == lPoly.getLevel()) { 150 | inferredReturnTypes.push_back(op.getDst().getType()); 151 | return ::mlir::success(); 152 | } else { 153 | return ::mlir::failure(); 154 | } 155 | } 156 | 157 | ::mlir::LogicalResult hecate::ckks::MulCPOp::inferReturnTypes( 158 | ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location, 159 | ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, 160 | ::mlir::RegionRange regions, 161 | ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { 162 | auto op = AddCPOpAdaptor(operands, attributes, regions); 163 | auto dPoly = ckks::getPolyType(op.getDst()); 164 | auto lPoly = ckks::getPolyType(op.getLhs()); 165 | auto rPoly = ckks::getPolyType(op.getRhs()); 166 | 167 | if (std::min(lPoly.getNumPoly(), rPoly.getNumPoly()) == 1 && 168 | dPoly.getNumPoly() == std::max(lPoly.getNumPoly(), rPoly.getNumPoly()) && 169 | lPoly.getLevel() == rPoly.getLevel() && 170 | lPoly.getLevel() == dPoly.getLevel()) { 171 | inferredReturnTypes.push_back(op.getDst().getType()); 172 | return ::mlir::success(); 173 | } else { 174 | return ::mlir::failure(); 175 | } 176 | } 177 | hecate::ckks::PolyTypeInterface 178 | hecate::ckks::PolyType::switchLevel(unsigned level) const { 179 | return get(getContext(), getNumPoly(), level); 180 | } 181 | hecate::ckks::PolyTypeInterface 182 | hecate::ckks::PolyType::switchNumPoly(unsigned num_poly) const { 183 | return get(getContext(), num_poly, getLevel()); 184 | } 185 | 186 | mlir::RankedTensorType hecate::ckks::getTensorType(mlir::Value v) { 187 | return v.getType().dyn_cast(); 188 | } 189 | hecate::ckks::PolyTypeInterface 
hecate::ckks::getPolyType(mlir::Value v) { 190 | return v.getType().dyn_cast(); 191 | /* .dyn_cast() */ 192 | /* .getElementType() */ 193 | /* .dyn_cast(); */ 194 | } 195 | -------------------------------------------------------------------------------- /lib/Dialect/CKKS/IR/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_dialect_library(CKKSDialect 2 | CKKSDialect.cpp 3 | DEPENDS 4 | HecateCKKSOpsIncGen 5 | LINK_LIBS 6 | PUBLIC 7 | MLIRIR 8 | ) 9 | -------------------------------------------------------------------------------- /lib/Dialect/CKKS/Transforms/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | add_mlir_dialect_library(CKKSTransforms 3 | RemoveLevel.cpp 4 | ReuseBuffer.cpp 5 | EmitHEVM.cpp 6 | DEPENDS 7 | HecateCKKSTransformsIncGen 8 | LINK_LIBS 9 | PUBLIC 10 | CKKSDialect 11 | MLIRIR 12 | ) 13 | -------------------------------------------------------------------------------- /lib/Dialect/CKKS/Transforms/EmitHEVM.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "hecate/Dialect/CKKS/IR/CKKSOps.h" 3 | #include "hecate/Dialect/CKKS/IR/PolyTypeInterface.h" 4 | #include "hecate/Dialect/CKKS/Transforms/Passes.h" 5 | #include "mlir/Dialect/Func/IR/FuncOps.h" 6 | #include "mlir/Dialect/Tensor/IR/Tensor.h" 7 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 8 | #include 9 | #include 10 | 11 | #include "hecate/Support/HEVMHeader.h" 12 | 13 | namespace hecate { 14 | namespace ckks { 15 | #define GEN_PASS_DEF_EMITHEVM 16 | #include "hecate/Dialect/CKKS/Transforms/Passes.h.inc" 17 | } // namespace ckks 18 | } // namespace hecate 19 | 20 | using namespace mlir; 21 | 22 | namespace { 23 | /// Pass to bufferize Arith ops. 
24 | /* EmitHEVMPass: serializes the function into the binary file PREFIX.FUNCNAME.hevm. File layout: HEVMHeader, ConfigBody, the config integers (arg scales, arg levels, res scales, res levels, then the ciphertext registers of the returned values, sized as uint64_t each) and finally the HEVMOperation stream. Values whose poly type has getNumPoly() == 1 live in the plaintext register file, all others in the ciphertext register file. */ struct EmitHEVMPass : public hecate::ckks::impl::EmitHEVMBase { 25 | EmitHEVMPass() {} 26 | EmitHEVMPass(hecate::ckks::EmitHEVMOptions ops) { this->prefix = ops.prefix; } 27 | 28 | void runOnOperation() override { 29 | markAllAnalysesPreserved(); 30 | auto &&func = getOperation(); 31 | HEVMHeader header{}; /* value-initialized: these structs are written to disk byte for byte, so no field may be left indeterminate */ 32 | ConfigBody config_body{}; 33 | header.hevm_header_size = sizeof(HEVMHeader); 34 | header.config_header.arg_length = func.getNumArguments(); 35 | header.config_header.res_length = func.getNumResults(); 36 | 37 | config_body.config_body_length = sizeof(ConfigBody); 38 | SmallVector config_body_ints; 39 | SmallVector insts; 40 | 41 | int64_t cipher_registers = 0; 42 | int64_t plain_registers = 0; 43 | llvm::DenseMap cipher_register_file; 44 | llvm::DenseMap plain_register_file; 45 | /* arguments take the first registers of each file, in argument order */ for (auto &&arg : func.getArguments()) { 46 | auto tt = arg.getType().dyn_cast(); 47 | if (tt.getNumPoly() == 1) { 48 | plain_register_file.insert({arg, plain_registers++}); 49 | } else { 50 | cipher_register_file.insert({arg, cipher_registers++}); 51 | } 52 | } 53 | func.walk([&](mlir::Operation *op) { 54 | if (auto alloc = dyn_cast(op)) { 55 | auto tt = alloc.getType().dyn_cast(); 56 | HEVMOperation heops{}; /* only opcode is assigned below, so value-initialize the remaining fields instead of writing garbage */ 57 | heops.opcode = -1; /* placeholder instruction emitted for each allocation */ 58 | insts.push_back(heops); 59 | 60 | if (tt.getNumPoly() == 1) { 61 | plain_register_file.insert({alloc, plain_registers++}); 62 | } else { 63 | cipher_register_file.insert({alloc, cipher_registers++}); 64 | } 65 | } else if (auto ops = dyn_cast(op)) { 66 | HEVMOperation heops = 67 | ops.getHEVMOperation(plain_register_file, cipher_register_file); 68 | insts.push_back(heops); 69 | 70 | /* NOTE(review): a positive opcode is taken to produce a ciphertext and anything else a plaintext; confirm against the opcode table in HEVMHeader.h */ if (heops.opcode > 0) { 71 | cipher_register_file.insert({op->getResult(0), heops.dst}); 72 | } else { 73 | plain_register_file.insert({op->getResult(0), heops.dst}); 74 | } 75 | } 76 | }); 77 | 78 | SmallVector ret_dst; 79 | auto retOp = 80 | dyn_cast(func.getBlocks().front().getTerminator()); 81 | /* NOTE(review): operator[] silently inserts register 0 for a returned value that is not in the ciphertext file */ for (auto arg : retOp.getOperands()) { 82 |
ret_dst.push_back(cipher_register_file[arg]); 83 | } 84 | 85 | /* arg_scale, res_scale and init_level are set by refineReturnValues; arg_level and res_level by RemoveLevelPass */ auto arg_scale_array = 86 | func->getAttrOfType("arg_scale").asArrayRef(); 87 | auto arg_level_array = 88 | func->getAttrOfType("arg_level").asArrayRef(); 89 | auto res_scale_array = 90 | func->getAttrOfType("res_scale").asArrayRef(); 91 | auto res_level_array = 92 | func->getAttrOfType("res_level").asArrayRef(); 93 | 94 | config_body.num_operations = insts.size(); 95 | config_body.num_ctxt_buffer = cipher_registers; 96 | config_body.num_ptxt_buffer = plain_registers; 97 | config_body.init_level = 98 | func->getAttrOfType("init_level").getInt(); 99 | 100 | config_body_ints.append(arg_scale_array.begin(), arg_scale_array.end()); 101 | config_body_ints.append(arg_level_array.begin(), arg_level_array.end()); 102 | config_body_ints.append(res_scale_array.begin(), res_scale_array.end()); 103 | config_body_ints.append(res_level_array.begin(), res_level_array.end()); 104 | config_body_ints.append(ret_dst.begin(), ret_dst.end()); 105 | 106 | config_body.config_body_length += 107 | config_body_ints.size() * sizeof(uint64_t); 108 | 109 | std::filesystem::path printpath(prefix.getValue()); 110 | 111 | printpath = std::string(printpath) + "."
+ func.getName().str() + ".hevm"; 112 | 113 | std::ofstream of(printpath, std::ios::binary); if (!of) { func.emitError("cannot open HEVM output file: " + std::string(printpath)); return signalPassFailure(); } 114 | of.write((char *)&header, sizeof(HEVMHeader)); 115 | of.write((char *)&config_body, sizeof(ConfigBody)); 116 | of.write((char *)config_body_ints.data(), 117 | config_body_ints.size() * sizeof(uint64_t)); 118 | of.write((char *)insts.data(), insts.size() * sizeof(HEVMOperation)); 119 | of.close(); 120 | } 121 | 122 | void getDependentDialects(DialectRegistry &registry) const override { 123 | registry.insert(); 124 | } 125 | }; 126 | } // namespace 127 | -------------------------------------------------------------------------------- /lib/Dialect/CKKS/Transforms/RemoveLevel.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "hecate/Dialect/CKKS/IR/CKKSOps.h" 4 | #include "hecate/Dialect/CKKS/IR/PolyTypeInterface.h" 5 | #include "hecate/Dialect/CKKS/Transforms/Passes.h" 6 | #include "mlir/Dialect/Func/IR/FuncOps.h" 7 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 8 | 9 | namespace hecate { 10 | namespace ckks { 11 | #define GEN_PASS_DEF_REMOVELEVEL 12 | #include "hecate/Dialect/CKKS/Transforms/Passes.h.inc" 13 | } // namespace ckks 14 | } // namespace hecate 15 | 16 | using namespace mlir; 17 | 18 | namespace { 19 | /// Records argument/result levels as attributes, then resets every CKKS poly-type level to 0.
20 | /* RemoveLevelPass: stores the level of every argument and result type in the arg_level and res_level function attributes, then rewrites the argument types, every op result type and the function signature with level 0. NOTE(review): the dyn_cast results are used unchecked, so every such type must support getLevel and switchLevel; markAllAnalysesPreserved is called even though types change. */ struct RemoveLevelPass 21 | : public hecate::ckks::impl::RemoveLevelBase { 22 | RemoveLevelPass() {} 23 | 24 | void runOnOperation() override { 25 | markAllAnalysesPreserved(); 26 | auto &&func = getOperation(); 27 | mlir::OpBuilder builder(func); 28 | 29 | SmallVector level_in; 30 | SmallVector level_out; 31 | for (auto &&arg : func.getArguments()) { 32 | level_in.push_back( 33 | arg.getType().dyn_cast().getLevel()); 34 | } 35 | func->setAttr("arg_level", builder.getDenseI64ArrayAttr(level_in)); 36 | for (auto &&restype : func.getResultTypes()) { 37 | level_out.push_back( 38 | restype.dyn_cast().getLevel()); 39 | } 40 | func->setAttr("res_level", builder.getDenseI64ArrayAttr(level_out)); 41 | 42 | for (auto value : func.getArguments()) { 43 | auto &&tt = value.getType().dyn_cast(); 44 | value.setType(tt.switchLevel(0)); 45 | } 46 | func.walk([&](Operation *op) { 47 | for (auto value : op->getResults()) { 48 | auto &&tt = value.getType().dyn_cast(); 49 | value.setType(tt.switchLevel(0)); 50 | } 51 | }); 52 | 53 | /* the signature is stored separately from the SSA value types, so it is rebuilt as well */ auto funcType = func.getFunctionType(); 54 | llvm::SmallVector inputTys; 55 | llvm::SmallVector outputTys; 56 | for (auto &&ty : funcType.getInputs()) { 57 | auto &&tt = ty.dyn_cast(); 58 | inputTys.push_back(tt.switchLevel(0)); 59 | } 60 | for (auto &&ty : funcType.getResults()) { 61 | auto &&tt = ty.dyn_cast(); 62 | outputTys.push_back(tt.switchLevel(0)); 63 | } 64 | func.setFunctionType(builder.getFunctionType(inputTys, outputTys)); 65 | } 66 | 67 | void getDependentDialects(DialectRegistry &registry) const override { 68 | registry.insert(); 69 | } 70 | }; 71 | } // namespace 72 | -------------------------------------------------------------------------------- /lib/Dialect/CKKS/Transforms/ReuseBuffer.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include 4 | 5 | #include "hecate/Dialect/CKKS/IR/CKKSOps.h" 6 | #include "hecate/Dialect/CKKS/IR/PolyTypeInterface.h" 7 | #include
"hecate/Dialect/CKKS/Transforms/Passes.h" 8 | #include "mlir/Analysis/Liveness.h" 9 | #include "mlir/Dialect/Func/IR/FuncOps.h" 10 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 11 | 12 | namespace hecate { 13 | namespace ckks { 14 | #define GEN_PASS_DEF_REUSEBUFFER 15 | #include "hecate/Dialect/CKKS/Transforms/Passes.h.inc" 16 | } // namespace ckks 17 | } // namespace hecate 18 | 19 | using namespace mlir; 20 | 21 | namespace { 22 | /// Recycles ciphertext buffers: inputs that die at an op go onto a free list and are reused as destination (init) operands. 23 | struct ReuseBufferPass 24 | : public hecate::ckks::impl::ReuseBufferBase { 25 | ReuseBufferPass() {} 26 | 27 | void runOnOperation() override { 28 | auto &&func = getOperation(); 29 | mlir::OpBuilder builder(func); 30 | mlir::Liveness l(func); 31 | SmallVector garbage; /* free list of dead ciphertext values, most recently freed last */ 32 | func.walk([&](mlir::DestinationStyleOpInterface op) { 33 | for (int i = 0; i < op.getNumDpsInputs(); i++) { 34 | auto v = op.getDpsInputOperand(i); 35 | if (auto tt = hecate::ckks::getPolyType(v->get())) { 36 | if (tt.getNumPoly() == 1) 37 | continue; /* plaintext buffers are never recycled */ 38 | /* An empty free list must accept the first dead value; comparing with back() only skips an operand repeated back to back (such as x * x). Requiring a non-empty list here would keep it empty forever. */ if (l.isDeadAfter(v->get(), op) && 39 | (garbage.empty() || v->get() != garbage.back())) { 40 | garbage.push_back(v->get()); 41 | } 42 | } 43 | } 44 | /* hand freed buffers to ciphertext destination operands, last freed first */ for (int i = 0; i < op.getNumDpsInits(); i++) { 45 | auto v = op.getDpsInitOperand(i); 46 | if (auto tt = hecate::ckks::getPolyType(v->get())) { 47 | if (tt.getNumPoly() == 1) 48 | continue; 49 | if (!garbage.empty()) { 50 | op.getDpsInitOperand(i)->set(garbage.pop_back_val()); 51 | } 52 | } 53 | } 54 | }); 55 | } 56 | 57 | void getDependentDialects(DialectRegistry &registry) const override { 58 | registry.insert(); 59 | } 60 | }; 61 | } // namespace 62 | -------------------------------------------------------------------------------- /lib/Dialect/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(Earth) 2 | add_subdirectory(CKKS) 3 | --------------------------------------------------------------------------------
/lib/Dialect/Earth/Analysis/AutoDifferentiation.cpp: -------------------------------------------------------------------------------- 1 | #include "hecate/Dialect/Earth/Analysis/AutoDifferentiation.h" 2 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 3 | 4 | using namespace hecate; 5 | using namespace mlir; 6 | 7 | /* Runs the whole analysis eagerly for the ops nested under op. */ hecate::AutoDifferentiation::AutoDifferentiation(mlir::Operation *op) 8 | : _op(op) { 9 | build(); 10 | } 11 | 12 | /* Forward pass: block arguments start at 0.0 and every HEAutoDiffInterface op maps its operand estimates (1.0 for values not seen yet) through estimateValue onto its results. Reverse pass over the same ops in reverse order: a result gradient is the sum of the gradients stored on its uses (1.0 for uses not assigned yet), and differentiate turns the result gradients into per-operand gradients. */ void hecate::AutoDifferentiation::build() { 13 | 14 | SmallVector hops; 15 | _op->walk([&](mlir::Block *block) { 16 | for (auto &&arg : block->getArguments()) { 17 | valueMap[arg] = 0.0; 18 | } 19 | }); 20 | 21 | _op->walk([&](hecate::earth::HEAutoDiffInterface sop) { 22 | SmallVector estimation; 23 | hops.push_back(sop); 24 | for (auto &&arg : sop->getOperands()) { 25 | auto it = valueMap.try_emplace(arg, 1.0); 26 | estimation.push_back(it.first->second); 27 | } 28 | auto &&resultEst = sop.estimateValue(estimation); 29 | for (auto &&[val, est] : 30 | llvm::zip(sop.getOperation()->getResults(), resultEst)) { 31 | valueMap[val] = est; 32 | } 33 | }); 34 | 35 | // We may need masking-aware propagation 36 | for (auto &&hop : llvm::reverse(hops)) { 37 | SmallVector gradients; 38 | for (auto &&val : hop.getOperation()->getResults()) { 39 | double gradient = 0.0; 40 | for (auto &&uses : val.getUses()) { 41 | auto it = operandDiffMap.try_emplace(&uses, 1.0); 42 | gradient += it.first->second; 43 | } 44 | valueDiffMap[val] = gradient; 45 | gradients.push_back(gradient); 46 | } 47 | SmallVector estimation; 48 | for (auto &&arg : hop->getOperands()) { 49 | auto it = valueMap.try_emplace(arg, 1.0); 50 | estimation.push_back(it.first->second); 51 | } 52 | auto &&resultGrad = hop.differentiate(gradients, estimation); 53 | for (auto &&[oper, grad] : llvm::zip(hop->getOpOperands(), resultGrad)) { 54 | operandDiffMap[&oper] = grad; 55 | } 56 | } 57 | } 58 | 59 | /* Gradient of the first result; 1.0 for ops without results. */ double AutoDifferentiation::getBackDiff(mlir::Operation *op) { 60 | if (op->getNumResults()) {
61 | return getBackDiff(op->getResult(0)); 62 | } else { 63 | return 1.0; 64 | } 65 | } 66 | 67 | /* Recorded gradient of v, or 1.0 when none was recorded. */ double AutoDifferentiation::getBackDiff(mlir::Value v) { 68 | auto &&i = valueDiffMap.find(v); 69 | if (i != valueDiffMap.end()) { 70 | return i->second; 71 | } else { 72 | return 1.0; 73 | } 74 | } 75 | 76 | /* Recorded gradient of one operand use, or 1.0 when none was recorded. */ double AutoDifferentiation::getBackDiff(mlir::OpOperand &oper) { 77 | auto &&i = operandDiffMap.find(&oper); 78 | if (i != operandDiffMap.end()) 79 | return i->second; 80 | else { 81 | return 1.0; 82 | } 83 | } 84 | 85 | /* Forward value estimate of the first result; 1.0 for ops without results. It must delegate to getValueEstimation(Value) and return that value: dropping the return would fall off the end of a non-void function. */ double AutoDifferentiation::getValueEstimation(mlir::Operation *op) { 86 | if (op->getNumResults()) { 87 | return getValueEstimation(op->getResult(0)); 88 | } else { 89 | return 1.0; 90 | } 91 | } 92 | /* Forward value estimate of v, or 1.0 when none was recorded. */ double AutoDifferentiation::getValueEstimation(mlir::Value v) { 93 | auto &&i = valueMap.find(v); 94 | if (i != valueMap.end()) 95 | return i->second; 96 | else { 97 | return 1.0; 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Analysis/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_dialect_library(EarthAnalysis 2 | ScaleManagementUnit.cpp 3 | AutoDifferentiation.cpp 4 | DEPENDS 5 | HecateEarthOpsIncGen 6 | HecateEarthTransformsIncGen 7 | LINK_LIBS 8 | PUBLIC 9 | EarthDialect 10 | MLIRIR 11 | MLIRAnalysis 12 | ) 13 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(IR) 2 | add_subdirectory(Transforms) 3 | add_subdirectory(Analysis) 4 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/IR/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_dialect_library(EarthDialect 2 | EarthDialect.cpp 3 | HEParameterInterface.cpp 4 | ForwardManagementInterface.cpp 5 |
DEPENDS 6 | HecateEarthOpsIncGen 7 | HecateEarthTransformsIncGen 8 | LINK_LIBS 9 | PUBLIC 10 | MLIRIR 11 | ) 12 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/IR/ForwardManagementInterface.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "mlir/IR/BuiltinTypes.h" 4 | #include "mlir/IR/OpDefinition.h" 5 | 6 | #include "hecate/Dialect/Earth/IR/ForwardManagementInterface.h" 7 | 8 | #include "hecate/Dialect/Earth/IR/ForwardManagementInterfaceTypes.cpp.inc" 9 | 10 | #include "hecate/Dialect/Earth/IR/ForwardManagementInterface.cpp.inc" 11 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/IR/HEParameterInterface.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "mlir/IR/BuiltinTypes.h" 3 | #include "mlir/IR/OpDefinition.h" 4 | 5 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 6 | 7 | #include "hecate/Dialect/Earth/IR/HEParameterInterfaceTypes.cpp.inc" 8 | 9 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.cpp.inc" 10 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_mlir_dialect_library(EarthTransforms 2 | WaterlineRescaling.cpp 3 | ElideConstant.cpp 4 | PrivatizeConstant.cpp 5 | Common.cpp 6 | UpscaleBubbling.cpp 7 | EarlyModswitch.cpp 8 | SMUChecker.cpp 9 | SNRRescaling.cpp 10 | ProactiveRescaling.cpp 11 | ScaleManagementScheduler.cpp 12 | ELASMExplorer.cpp 13 | SMUEmbedding.cpp 14 | ErrorEstimator.cpp 15 | LatencyEstimator.cpp 16 | DEPENDS 17 | HecateEarthTransformsIncGen 18 | LINK_LIBS 19 | PUBLIC 20 | EarthDialect 21 | EarthAnalysis 22 | MLIRIR 23 | ) 24 | -------------------------------------------------------------------------------- 
/lib/Dialect/Earth/Transforms/Common.cpp: -------------------------------------------------------------------------------- 1 | #include "hecate/Dialect/Earth/Transforms/Common.h" 2 | 3 | using namespace mlir; 4 | 5 | /* refineReturnValues: aligns the returned values and records HE parameters on func. With acc_scale = level * rescalingFactor + scale, a result needs ceil((acc_scale + output_val) / rescalingFactor) levels; init_level is the maximum of that over all results, and each result is replaced by an op created with the level difference (presumably a modswitch down by level_diff). The function type is then rebuilt from inputTypes and the new result types, and the init_level, arg_scale and res_scale attributes are set. NOTE(review): waterline is not used, and the terminator is assumed to be a return op because the dyn_cast result is not checked. */ void hecate::earth::refineReturnValues(mlir::func::FuncOp func, 6 | mlir::OpBuilder builder, 7 | SmallVector inputTypes, 8 | int64_t waterline, int64_t output_val) { 9 | 10 | int64_t max_required_level = 0; 11 | // Reduce the level of the resulting values to reduce the size of returns 12 | // 13 | auto rop = dyn_cast(func.getBlocks().front().getTerminator()); 14 | /* func.walk([&](func::ReturnOp rop) { */ 15 | builder.setInsertionPoint(rop); 16 | 17 | int64_t acc_scale_max = 0; 18 | int64_t rescalingFactor = hecate::earth::EarthDialect::rescalingFactor; 19 | for (auto v : rop.getOperands()) { 20 | auto st = v.getType().dyn_cast(); 21 | auto acc_scale = st.getLevel() * rescalingFactor + st.getScale(); 22 | acc_scale_max = std::max(acc_scale_max, acc_scale); 23 | } 24 | 25 | max_required_level = 26 | (acc_scale_max + output_val + rescalingFactor - 1) / rescalingFactor; 27 | 28 | for (size_t i = 0; i < rop.getNumOperands(); i++) { 29 | auto v = rop.getOperand(i); 30 | auto st = v.getType().dyn_cast(); 31 | auto acc_scale = st.getLevel() * rescalingFactor + st.getScale(); 32 | int64_t required_level = 33 | (acc_scale + output_val + rescalingFactor - 1) / rescalingFactor; 34 | int64_t level_diff = max_required_level - required_level; 35 | rop.setOperand(i, builder.create( 36 | rop.getLoc(), v, level_diff)); 37 | } 38 | 39 | // Remap the return types 40 | func.setFunctionType( 41 | builder.getFunctionType(inputTypes, rop.getOperandTypes())); 42 | /* }); */ 43 | func->setAttr("init_level", builder.getI64IntegerAttr(max_required_level)); 44 | SmallVector scales_in; 45 | SmallVector scales_out; 46 | for (auto &&arg : func.getArguments()) { 47 | scales_in.push_back(arg.getType() 48 | .dyn_cast() 49 | .getScale()); 50 | } 51 | func->setAttr("arg_scale",
builder.getDenseI64ArrayAttr(scales_in)); 52 | for (auto &&restype : func.getResultTypes()) { 53 | scales_out.push_back( 54 | restype.dyn_cast().getScale()); 55 | } 56 | func->setAttr("res_scale", builder.getDenseI64ArrayAttr(scales_out)); 57 | } 58 | 59 | /* inferTypeForward: re-runs return type inference on the op and copies the last inferred type onto its last result. NOTE(review): iop is not null-checked, so every ForwardMgmtInterface op is assumed to implement the return type inference interface. */ void hecate::earth::inferTypeForward(hecate::earth::ForwardMgmtInterface sop) { 60 | Operation *oop = sop.getOperation(); 61 | auto iop = dyn_cast(oop); 62 | SmallVector retTypes; 63 | if (iop.inferReturnTypes(oop->getContext(), oop->getLoc(), oop->getOperands(), 64 | oop->getAttrDictionary(), oop->getRegions(), 65 | retTypes) 66 | .succeeded()) { 67 | oop->getResults().back().setType(retTypes.back()); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/ELASMExplorer.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 4 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 5 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 6 | #include "mlir/Dialect/Func/IR/FuncOps.h" 7 | 8 | #include "hecate/Dialect/Earth/Analysis/ScaleManagementUnit.h" 9 | #include "hecate/Dialect/Earth/Transforms/Common.h" 10 | #include "mlir/Pass/PassManager.h" 11 | #include "mlir/Transforms/Passes.h" 12 | #include 13 | 14 | namespace hecate { 15 | namespace earth { 16 | #define GEN_PASS_DEF_ELASMEXPLORER 17 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 18 | } // namespace earth 19 | } // namespace hecate 20 | 21 | using namespace mlir; 22 | 23 | #define DEBUG_TYPE "elasm" 24 | 25 | namespace { 26 | /// Searches scale-management plans by randomized exploration (ELASM) and keeps the cheapest one.
27 | struct ELASMExplorerPass 28 | : public hecate::earth::impl::ELASMExplorerBase { 29 | ELASMExplorerPass() {} 30 | 31 | ELASMExplorerPass(hecate::earth::ELASMExplorerOptions ops) { 32 | this->waterline = ops.waterline; 33 | this->output_val = ops.output_val; 34 | this->parallel = ops.parallel; 35 | this->num_iter = ops.num_iter; 36 | this->beta = ops.beta; 37 | this->gamma = ops.gamma; 38 | } 39 | ELASMExplorerPass(std::pair ops) { 40 | this->waterline = ops.first; 41 | this->output_val = ops.second; 42 | this->parallel = 20; 43 | this->num_iter = 1000; 44 | this->beta = 50; 45 | this->gamma = 50; 46 | } 47 | 48 | double costFunc(double cost, double noise) { 49 | return std::sqrt(cost) * (beta + std::log2(noise)); 50 | } 51 | 52 | void runOnOperation() override { 53 | 54 | auto func = getOperation(); 55 | mlir::OpBuilder builder(func); 56 | 57 | markAnalysesPreserved(); 58 | hecate::ScaleManagementUnit smu = 59 | getAnalysis(); 60 | 61 | smu.attach(); 62 | 63 | SmallVector funcs; 64 | SmallVector, SmallVector, 65 | SmallVector>> 66 | plans(parallel); 67 | 68 | SmallVector costs(parallel, std::numeric_limits::max()); 69 | double optcost = std::numeric_limits::max(); 70 | std::tuple, SmallVector, 71 | SmallVector> 72 | optplan; 73 | 74 | auto mod = mlir::ModuleOp::create(func.getLoc()); 75 | 76 | PassManager pm(mod.getContext()); 77 | pm.addNestedPass( 78 | hecate::earth::createScaleManagementScheduler()); 79 | pm.addNestedPass( 80 | hecate::earth::createSNRRescaling({waterline, output_val})); 81 | pm.addNestedPass(hecate::earth::createUpscaleBubbling()); 82 | pm.addPass(mlir::createCanonicalizerPass()); 83 | pm.addNestedPass(hecate::earth::createEarlyModswitch()); 84 | pm.addPass(mlir::createCanonicalizerPass()); 85 | pm.addPass(mlir::createCSEPass()); 86 | pm.addNestedPass(hecate::earth::createErrorEstimator()); 87 | pm.addNestedPass(hecate::earth::createLatencyEstimator()); 88 | 89 | std::uniform_real_distribution dd(0.0, 1.0); 90 | std::random_device rd; 91 | 
std::mt19937 gen(rd()); 92 | 93 | for (int n = 0; n < num_iter; n++) { 94 | /* llvm::errs() << n << "th Run Start\n"; */ 95 | for (int i = 0; i < parallel; i++) { 96 | auto dup = func.clone(); 97 | dup.setName((func.getName() + "_" + std::to_string(i)).str()); 98 | mod.push_back(dup); 99 | funcs.push_back(dup); 100 | dup->setAttr("sm_plan_edge", 101 | builder.getDenseI64ArrayAttr(std::get<0>(plans[i]))); 102 | dup->setAttr("sm_plan_scale", 103 | builder.getDenseI64ArrayAttr(std::get<1>(plans[i]))); 104 | dup->setAttr("sm_plan_level", 105 | builder.getDenseI64ArrayAttr(std::get<2>(plans[i]))); 106 | } 107 | 108 | if (pm.run(mod).failed()) { 109 | assert(0 && "Pass failed inside ELASM explorer"); 110 | pm.dump(); 111 | } 112 | /* llvm::errs() << n << "th Run Done\n"; */ 113 | for (int i = 0; i < parallel; i++) { 114 | 115 | double thres = dd(gen); 116 | func::FuncOp &target = funcs[i]; 117 | double cost = 118 | costFunc(target->getAttrOfType("est_latency") 119 | .getValueAsDouble(), 120 | target->getAttrOfType("est_error") 121 | .getValueAsDouble()); 122 | 123 | double alpha = 124 | std::min(1.0, std::pow(2.0, -gamma * (1.0 - costs[i] / cost))); 125 | 126 | if (thres < alpha) { 127 | plans[i] = { 128 | SmallVector( 129 | target->getAttrOfType("sm_plan_edge") 130 | .asArrayRef()), 131 | SmallVector( 132 | target->getAttrOfType("sm_plan_scale") 133 | .asArrayRef()), 134 | SmallVector( 135 | target->getAttrOfType("sm_plan_level") 136 | .asArrayRef())}; 137 | costs[i] = cost; 138 | } 139 | if (cost < optcost) { 140 | LLVM_DEBUG(llvm::dbgs() 141 | << optcost << " " << cost << " " 142 | << target->getAttrOfType("est_latency") 143 | .getValueAsDouble() 144 | << " " 145 | << target->getAttrOfType("est_error") 146 | .getValueAsDouble() 147 | << "\n"); 148 | optplan = { 149 | SmallVector( 150 | target->getAttrOfType("sm_plan_edge") 151 | .asArrayRef()), 152 | SmallVector( 153 | target->getAttrOfType("sm_plan_scale") 154 | .asArrayRef()), 155 | SmallVector( 156 | 
target->getAttrOfType("sm_plan_level") 157 | .asArrayRef())}; 158 | optcost = cost; 159 | } 160 | funcs[i].erase(); 161 | } 162 | funcs.clear(); 163 | } 164 | func->setAttr("sm_plan_edge", 165 | builder.getDenseI64ArrayAttr(std::get<0>(optplan))); 166 | func->setAttr("sm_plan_scale", 167 | builder.getDenseI64ArrayAttr(std::get<1>(optplan))); 168 | func->setAttr("sm_plan_level", 169 | builder.getDenseI64ArrayAttr(std::get<2>(optplan))); 170 | func->setAttr("no_mutation", builder.getBoolAttr(true)); 171 | 172 | /* auto p = pm.run(func); */ 173 | } 174 | 175 | void getDependentDialects(DialectRegistry ®istry) const override { 176 | registry.insert(); 177 | } 178 | }; 179 | } // namespace 180 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/EarlyModswitch.cpp: -------------------------------------------------------------------------------- 1 | 2 | //===- Bufferize.cpp - Bufferization for Arith ops ---------*- C++ -*-===// 3 | // 4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5 | // See https://llvm.org/LICENSE.txt for license information. 
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #include "hecate/Dialect/Earth/Analysis/ScaleManagementUnit.h" 11 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 12 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 13 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 14 | #include "mlir/Dialect/Func/IR/FuncOps.h" 15 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 16 | #include "llvm/Support/Debug.h" 17 | 18 | namespace hecate { 19 | namespace earth { 20 | #define GEN_PASS_DEF_EARLYMODSWITCH 21 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 22 | } // namespace earth 23 | } // namespace hecate 24 | // 25 | #define DEBUG_TYPE "hecate_em" 26 | 27 | using namespace mlir; 28 | 29 | namespace { 30 | /// Sinks modswitches toward producers: when every user of a result is a modswitch, their smallest down factor is applied to the operands (or folded into the op when it is itself a modswitch) and subtracted from each user. 31 | struct EarlyModswitchPass 32 | : public hecate::earth::impl::EarlyModswitchBase { 33 | EarlyModswitchPass() {} 34 | 35 | void runOnOperation() override { 36 | auto func = getOperation(); 37 | markAnalysesPreserved(); 38 | 39 | mlir::OpBuilder builder(func); 40 | mlir::IRRewriter rewriter(builder); 41 | 42 | SmallVector inputTypes; 43 | 44 | auto &&bb = func.getBody().getBlocks().front(); 45 | /* bottom-up over the entry block, so modswitches inserted before an op are presumably visited next and can keep moving toward their producers */ for (auto iter = bb.rbegin(); iter != bb.rend(); ++iter) { 46 | if (auto op = dyn_cast(*iter)) { 47 | // Gather the users and finds the minimum "downFactor" 48 | 49 | uint64_t minModFactor = -1; /* UINT64_MAX, the identity for std::min */ 50 | for (auto &&oper : op->getResult(0).getUses()) { 51 | if (auto oop = 52 | dyn_cast(oper.getOwner())) { 53 | minModFactor = std::min(minModFactor, oop.getDownFactor()); 54 | } else { 55 | minModFactor = 0; 56 | } 57 | } 58 | 59 | // Check that every user needs the "downFactor"ed level 60 | if (!minModFactor || op->getResult(0).use_empty()) { /* with no users minModFactor is still UINT64_MAX and must not be applied */ 61 | continue; // Go to next operation 62 | } 63 | 64 | // Move the modswitch 65 | if (auto oop = 66 | dyn_cast(op.getOperation())) { 67 | // Modswitch movement can be absorbed into modswitch 68 |
oop.setDownFactor(oop.getDownFactor() + minModFactor); 69 | oop.getResult().setType(oop.getScaleType().switchLevel( 70 | oop.getRescaleLevel() + minModFactor)); 71 | } else { 72 | // Modswitch is moved to the operands 73 | for (unsigned i = 0; i < op->getNumOperands(); i++) { 74 | auto oper = op->getOperand(i); 75 | builder.setInsertionPoint(op); 76 | auto newOper = builder.create( 77 | op->getLoc(), oper, minModFactor); 78 | op->setOperand(i, newOper); 79 | } 80 | op->getResult(0).setType(op.getScaleType().switchLevel( 81 | op.getRescaleLevel() + minModFactor)); 82 | } 83 | 84 | // Change the user modswitch downFactors 85 | for (auto &&oper : op->getResult(0).getUsers()) { 86 | if (auto oop = dyn_cast(oper)) { 87 | oop.setDownFactor(oop.getDownFactor() - minModFactor); 88 | } 89 | } 90 | } 91 | } 92 | LLVM_DEBUG(llvm::dbgs() << __FILE__ << ":" << __LINE__ << "\n"); 93 | } 94 | 95 | void getDependentDialects(DialectRegistry &registry) const override { 96 | registry.insert(); 97 | } 98 | }; 99 | } // namespace 100 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/ElideConstant.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 3 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 4 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 5 | #include "mlir/Dialect/Func/IR/FuncOps.h" 6 | #include "mlir/IR/Builders.h" 7 | #include 8 | 9 | namespace hecate { 10 | namespace earth { 11 | #define GEN_PASS_DEF_ELIDECONSTANT 12 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 13 | } // namespace earth 14 | } // namespace hecate 15 | 16 | using namespace mlir; 17 | 18 | namespace { 19 | /// Replaces each constant payload with an index and writes the payloads to a side file.
20 | struct ElideConstantPass 21 | : public hecate::earth::impl::ElideConstantBase { 22 | ElideConstantPass() {} 23 | ElideConstantPass(hecate::earth::ElideConstantOptions ops) { 24 | this->name = ops.name; 25 | } 26 | 27 | void runOnOperation() override { 28 | auto func = getOperation(); 29 | SmallVector, 4> save_data; 30 | 31 | mlir::OpBuilder builder(func.getOperation()); 32 | 33 | func.walk([&](hecate::earth::ConstantOp cop) { 34 | SmallVector datas( 35 | cop.getValue().dyn_cast().getValues()); 36 | save_data.push_back(datas); 37 | cop.setValueAttr(builder.getI64IntegerAttr(save_data.size() - 1)); 38 | }); 39 | 40 | name = name + (func.getName() + ".cst").str(); 41 | llvm::errs() << name << "\n"; 42 | std::ofstream of(name, std::ios::binary); 43 | int64_t a; 44 | a = save_data.size(); 45 | of.write((char *)&a, sizeof(int64_t)); 46 | for (auto k : save_data) { 47 | a = k.size(); 48 | of.write((char *)&a, sizeof(int64_t)); 49 | for (auto d : k) { 50 | of.write((char *)&d, sizeof(double)); 51 | } 52 | } 53 | of.close(); 54 | } 55 | 56 | void getDependentDialects(DialectRegistry ®istry) const override { 57 | registry.insert(); 58 | } 59 | }; 60 | } // namespace 61 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/ErrorEstimator.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "hecate/Dialect/Earth/Analysis/AutoDifferentiation.h" 3 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 4 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 5 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 6 | #include "hecate/Support/Support.h" 7 | #include "mlir/Dialect/Func/IR/FuncOps.h" 8 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 9 | 10 | #include "hecate/Dialect/Earth/Analysis/ScaleManagementUnit.h" 11 | #include "llvm/Support/Debug.h" 12 | 13 | namespace hecate { 14 | namespace earth { 15 | #define GEN_PASS_DEF_ERRORESTIMATOR 16 | #include 
"hecate/Dialect/Earth/Transforms/Passes.h.inc" 17 | } // namespace earth 18 | } // namespace hecate 19 | // 20 | #define DEBUG_TYPE "hecate_ub" 21 | 22 | using namespace mlir; 23 | 24 | namespace { 25 | /// Pass to bufferize Arith ops. 26 | struct ErrorEstimatorPass 27 | : public hecate::earth::impl::ErrorEstimatorBase { 28 | ErrorEstimatorPass() {} 29 | 30 | void runOnOperation() override { 31 | auto func = getOperation(); 32 | auto builder = OpBuilder(func); 33 | 34 | auto &&diff = getAnalysis(); 35 | 36 | markAllAnalysesPreserved(); 37 | double error_square = 0; 38 | 39 | // We cannot track both self term (x+x then error should be doubled but 40 | // forward analysis makes the error sqrt(2)-ed) and quadratic term (x1*x2 41 | // adds e1*e2 term) efficiently. 42 | // 43 | // Forward analysis cannot track self term but can track quadratic term 44 | // Backward analysis can track self term but cannot track quadratic term 45 | // 46 | // ELASM paper uses forward analysis but I changed implementation 47 | // Because quadratic term should be smaller enough to guarantee correctness 48 | // 49 | // I think neither quadratic term tracking on backward nor 50 | // self term tracking on forward is practical. 51 | // We can utilize high level information to selectively use analysis 52 | // Additional evaluation is required. 
53 | 54 | func.walk([&](hecate::earth::HEProfInterface pop) { 55 | auto df = diff.getBackDiff(pop.getOperation()); 56 | error_square += pop.getNoise() * pop.getNum() * 57 | std::pow(diff.getBackDiff(pop.getOperation()), 2) / 58 | std::exp2(pop.getNoiseScale()); 59 | }); 60 | 61 | func->setAttr("est_error", 62 | builder.getF64FloatAttr(std::sqrt(error_square))); 63 | LLVM_DEBUG(llvm::dbgs() << __FILE__ << ":" << __LINE__ << "\n"); 64 | } 65 | 66 | void getDependentDialects(DialectRegistry ®istry) const override { 67 | registry.insert(); 68 | } 69 | }; 70 | } // namespace 71 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/LatencyEstimator.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 3 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 4 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 5 | #include "hecate/Support/Support.h" 6 | #include "mlir/Dialect/Func/IR/FuncOps.h" 7 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 8 | 9 | #include "hecate/Dialect/Earth/Analysis/ScaleManagementUnit.h" 10 | 11 | namespace hecate { 12 | namespace earth { 13 | #define GEN_PASS_DEF_LATENCYESTIMATOR 14 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 15 | } // namespace earth 16 | } // namespace hecate 17 | 18 | using namespace mlir; 19 | 20 | namespace { 21 | /// Pass to bufferize Arith ops. 
22 | /* LatencyEstimatorPass: sums getLatency() * getNum() over every HEProfInterface op and stores the total as the F64 est_latency attribute of the function. */ struct LatencyEstimatorPass 23 | : public hecate::earth::impl::LatencyEstimatorBase { 24 | LatencyEstimatorPass() {} 25 | 26 | void runOnOperation() override { 27 | auto func = getOperation(); 28 | auto builder = OpBuilder(func); 29 | 30 | markAllAnalysesPreserved(); 31 | double latency = 0; 32 | 33 | func.walk([&](hecate::earth::HEProfInterface pop) { 34 | latency += pop.getLatency() * pop.getNum(); 35 | }); 36 | 37 | func->setAttr("est_latency", builder.getF64FloatAttr(latency)); 38 | } 39 | 40 | void getDependentDialects(DialectRegistry ®istry) const override { 41 | registry.insert(); 42 | } 43 | }; 44 | } // namespace 45 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/PrivatizeConstant.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 3 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 4 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 5 | #include "mlir/Dialect/Func/IR/FuncOps.h" 6 | 7 | namespace hecate { 8 | namespace earth { 9 | #define GEN_PASS_DEF_PRIVATIZECONSTANT 10 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 11 | } // namespace earth 12 | } // namespace hecate 13 | 14 | using namespace mlir; 15 | 16 | namespace { 17 | /// Gives every scale-managed op its own copy of each constant operand.
18 | /* PrivatizeConstantPass: for each HEScaleOpInterface op, every operand produced by a constant op is replaced by a fresh constant with the same type, value and getRmsVar value, created right before that user, so one constant is no longer shared between several users. */ struct PrivatizeConstantPass 19 | : public hecate::earth::impl::PrivatizeConstantBase { 20 | PrivatizeConstantPass() {} 21 | 22 | void runOnOperation() override { 23 | auto func = getOperation(); 24 | func.walk([&](hecate::earth::HEScaleOpInterface hop) { 25 | OpBuilder builder(hop); 26 | for (size_t i = 0; i < hop.getOperation()->getNumOperands(); i++) { 27 | if (auto &&cop = 28 | hop->getOperand(i).getDefiningOp()) { 29 | hop->setOperand(i, builder.create( 30 | cop.getLoc(), cop.getType(), cop.getValue(), 31 | cop.getRmsVar())); 32 | } 33 | } 34 | }); 35 | } 36 | 37 | void getDependentDialects(DialectRegistry ®istry) const override { 38 | registry.insert(); 39 | } 40 | }; 41 | } // namespace 42 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/ProactiveRescaling.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 3 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 4 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 5 | #include "mlir/Dialect/Func/IR/FuncOps.h" 6 | 7 | #include "hecate/Dialect/Earth/Analysis/ScaleManagementUnit.h" 8 | #include "hecate/Dialect/Earth/Transforms/Common.h" 9 | 10 | namespace hecate { 11 | namespace earth { 12 | #define GEN_PASS_DEF_PROACTIVERESCALING 13 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 14 | } // namespace earth 15 | } // namespace hecate 16 | 17 | using namespace mlir; 18 | 19 | namespace { 20 | /// Applies proactive rescaling (PARS) with a fixed waterline scale.
21 | struct ProactiveRescalingPass 22 | : public hecate::earth::impl::ProactiveRescalingBase< 23 | ProactiveRescalingPass> { 24 | ProactiveRescalingPass() {} 25 | 26 | ProactiveRescalingPass(hecate::earth::ProactiveRescalingOptions ops) { 27 | this->waterline = ops.waterline; 28 | this->output_val = ops.output_val; 29 | } 30 | 31 | void runOnOperation() override { 32 | 33 | auto func = getOperation(); 34 | 35 | markAnalysesPreserved(); 36 | 37 | mlir::OpBuilder builder(func); 38 | mlir::IRRewriter rewriter(builder); 39 | SmallVector inputTypes; 40 | // Set function argument types 41 | for (auto argval : func.getArguments()) { 42 | argval.setType( 43 | argval.getType().dyn_cast().replaceSubElements( 44 | [&](hecate::earth::HEScaleTypeInterface t) { 45 | return t.switchScale(waterline); 46 | })); 47 | inputTypes.push_back(argval.getType()); 48 | } 49 | 50 | // Apply waterline rescaling for the operations 51 | func.walk([&](hecate::earth::ForwardMgmtInterface sop) { 52 | builder.setInsertionPointAfter(sop.getOperation()); 53 | sop.processOperandsPARS(waterline); 54 | inferTypeForward(sop); 55 | sop.processResultsPARS(waterline); 56 | }); 57 | hecate::earth::refineReturnValues(func, builder, inputTypes, waterline, 58 | output_val); 59 | } 60 | 61 | void getDependentDialects(DialectRegistry ®istry) const override { 62 | registry.insert(); 63 | } 64 | }; 65 | } // namespace 66 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/SMUChecker.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 3 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 4 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 5 | #include "mlir/Dialect/Func/IR/FuncOps.h" 6 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 7 | 8 | #include "hecate/Dialect/Earth/Analysis/ScaleManagementUnit.h" 9 | 10 | namespace hecate { 11 | namespace 
earth { 12 | #define GEN_PASS_DEF_SMUCHECKER 13 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 14 | } // namespace earth 15 | } // namespace hecate 16 | 17 | using namespace mlir; 18 | 19 | namespace { 20 | /// Pass to bufferize Arith ops. 21 | struct SMUCheckerPass 22 | : public hecate::earth::impl::SMUCheckerBase { 23 | SMUCheckerPass() {} 24 | 25 | void runOnOperation() override { 26 | auto func = getOperation(); 27 | 28 | auto smu = getAnalysis(); 29 | markAllAnalysesPreserved(); 30 | 31 | llvm::errs() << smu.getNumSMUs() << " " << smu.getNumEdges() << "\n"; 32 | smu.verify(); 33 | for (int i = 0; i < smu.getNumSMUs(); i++) { 34 | llvm::errs() << "$$$$$$ " << i << " $$$$\n"; 35 | int j = 0; 36 | for (auto &&vv : smu.getValueSet(i)) { 37 | if (++j > 10) 38 | break; 39 | llvm::errs() << "## " << vv << "\n"; 40 | } 41 | } 42 | llvm::errs() << "--------------------\n\n\n\n\n\n "; 43 | 44 | std::map> smuLS; 45 | bool success = true; 46 | func.walk([&](hecate::earth::HEScaleOpInterface sop) { 47 | auto ID = smu.getID(sop->getResult(0)); 48 | if (ID == -1) { 49 | return; 50 | } 51 | 52 | auto dat = smuLS.find(ID); 53 | if (dat != smuLS.end()) { 54 | auto &&record = dat->second; 55 | std::pair data = {sop.getRescaleLevel(), 56 | sop.getScale()}; 57 | if (record != data) { 58 | llvm::errs() << record.first << " " << record.second << " " 59 | << data.first << " " << data.second << "\n"; 60 | sop.dump(); 61 | success = false; 62 | } 63 | } else { 64 | smuLS[ID] = {sop.getRescaleLevel(), sop.getScale()}; 65 | } 66 | }); 67 | if (!success) { 68 | func.walk([&](mlir::Block *block) { 69 | for (auto &&arg : block->getArguments()) { 70 | llvm::errs() << smu.getID(arg) << ":"; 71 | llvm::errs() << " => "; 72 | for (auto &&user : arg.getUsers()) { 73 | llvm::errs() << smu.getID(user->getResult(0)) << " "; 74 | } 75 | llvm::errs() << " \n"; 76 | arg.dump(); 77 | } 78 | }); 79 | func.walk([&](hecate::earth::HEScaleOpInterface sop) { 80 | llvm::errs() << 
smu.getID(sop->getResult(0)) << ":"; 81 | for (auto &&arg : sop->getOperands()) { 82 | llvm::errs() << smu.getID(arg) << " "; 83 | } 84 | llvm::errs() << " => "; 85 | for (auto &&arg : sop->getUsers()) { 86 | llvm::errs() << smu.getID(arg->getResult(0)) << " "; 87 | } 88 | llvm::errs() << " \n"; 89 | sop->getLoc()->dump(); 90 | sop->dump(); 91 | }); 92 | 93 | if (!success) { 94 | assert(0 && "SMU is not correct"); 95 | } 96 | } 97 | } 98 | 99 | void getDependentDialects(DialectRegistry ®istry) const override { 100 | registry.insert(); 101 | } 102 | }; 103 | } // namespace 104 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/SMUEmbedding.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 3 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 4 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 5 | #include "hecate/Support/Support.h" 6 | #include "mlir/Dialect/Func/IR/FuncOps.h" 7 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 8 | 9 | #include "hecate/Dialect/Earth/Analysis/ScaleManagementUnit.h" 10 | 11 | namespace hecate { 12 | namespace earth { 13 | #define GEN_PASS_DEF_SMUEMBEDDING 14 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 15 | } // namespace earth 16 | } // namespace hecate 17 | 18 | using namespace mlir; 19 | 20 | namespace { 21 | /// Pass to bufferize Arith ops. 
22 | struct SMUEmbeddingPass 23 | : public hecate::earth::impl::SMUEmbeddingBase { 24 | SMUEmbeddingPass() {} 25 | 26 | void runOnOperation() override { 27 | auto func = getOperation(); 28 | auto builder = OpBuilder(func); 29 | 30 | auto smu = getAnalysis(); 31 | markAllAnalysesPreserved(); 32 | 33 | smu.attach(); 34 | } 35 | 36 | void getDependentDialects(DialectRegistry ®istry) const override { 37 | registry.insert(); 38 | } 39 | }; 40 | } // namespace 41 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/SNRRescaling.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include 4 | 5 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 6 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 7 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 8 | #include "mlir/Dialect/Func/IR/FuncOps.h" 9 | 10 | #include "hecate/Dialect/Earth/Analysis/ScaleManagementUnit.h" 11 | #include "hecate/Dialect/Earth/Transforms/Common.h" 12 | #include "llvm/Support/Debug.h" 13 | 14 | namespace hecate { 15 | namespace earth { 16 | #define GEN_PASS_DEF_SNRRESCALING 17 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 18 | } // namespace earth 19 | } // namespace hecate 20 | // 21 | #define DEBUG_TYPE "hecate_snr" 22 | 23 | using namespace mlir; 24 | 25 | namespace { 26 | /// Pass to bufferize Arith ops. 27 | struct SNRRescalingPass 28 | : public hecate::earth::impl::SNRRescalingBase { 29 | 30 | SNRRescalingPass() {} 31 | 32 | SNRRescalingPass(hecate::earth::SNRRescalingOptions ops) { 33 | this->waterline = ops.waterline; 34 | this->output_val = ops.output_val; 35 | } 36 | 37 | int64_t calcWaterline(hecate::ScaleManagementUnit &smu, Value v) { 38 | while (smu.getID(v) < 0) { 39 | // Backward movement for -1 value because it means scale management 40 | // operations or apply_schedule op. 
41 | v = v.getUses().begin().getUser()->getResult(0); 42 | } 43 | return smu.inNoisyGroup(v) ? waterline + 4 : waterline; 44 | } 45 | int64_t calcWaterline(hecate::ScaleManagementUnit &smu, Operation *op) { 46 | return calcWaterline(smu, op->getResult(0)); 47 | } 48 | 49 | void runOnOperation() override { 50 | 51 | auto func = getOperation(); 52 | 53 | hecate::ScaleManagementUnit smu = 54 | getAnalysis(); 55 | 56 | markAnalysesPreserved(); 57 | 58 | mlir::OpBuilder builder(func); 59 | mlir::IRRewriter rewriter(builder); 60 | SmallVector inputTypes; 61 | // Set function argument types 62 | for (auto argval : func.getArguments()) { 63 | argval.setType( 64 | argval.getType().dyn_cast().replaceSubElements( 65 | [&](hecate::earth::HEScaleTypeInterface t) { 66 | return t.switchScale(calcWaterline(smu, argval)); 67 | })); 68 | 69 | inputTypes.push_back(argval.getType()); 70 | } 71 | func.setFunctionType(builder.getFunctionType( 72 | inputTypes, func.getFunctionType().getResults())); 73 | 74 | func.walk([&](hecate::earth::ForwardMgmtInterface sop) { 75 | builder.setInsertionPoint(sop.getOperation()); 76 | sop.processOperandsSNR(calcWaterline(smu, sop.getOperation())); 77 | inferTypeForward(sop); 78 | builder.setInsertionPointAfter(sop.getOperation()); 79 | sop.processResultsSNR(calcWaterline(smu, sop.getOperation())); 80 | }); 81 | hecate::earth::refineReturnValues(func, builder, inputTypes, waterline, 82 | output_val); 83 | 84 | std::error_code EC; 85 | 86 | LLVM_DEBUG(llvm::dbgs() << __FILE__ << ":" << __LINE__ << "\n"); 87 | } 88 | 89 | void getDependentDialects(DialectRegistry ®istry) const override { 90 | registry.insert(); 91 | } 92 | }; 93 | } // namespace 94 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/ScaleManagementScheduler.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 4 | #include 
"hecate/Dialect/Earth/IR/HEParameterInterface.h" 5 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 6 | #include "mlir/Dialect/Func/IR/FuncOps.h" 7 | #include "mlir/IR/Builders.h" 8 | #include "llvm/Support/Debug.h" 9 | #include 10 | #include 11 | 12 | namespace hecate { 13 | namespace earth { 14 | #define GEN_PASS_DEF_SCALEMANAGEMENTSCHEDULER 15 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 16 | } // namespace earth 17 | } // namespace hecate 18 | // 19 | #define DEBUG_TYPE "hecate_sms" 20 | 21 | using namespace mlir; 22 | 23 | namespace { 24 | /// Pass to bufferize Arith ops. 25 | struct ScaleManagementSchedulerPass 26 | : public hecate::earth::impl::ScaleManagementSchedulerBase< 27 | ScaleManagementSchedulerPass> { 28 | ScaleManagementSchedulerPass() {} 29 | 30 | void runOnOperation() override { 31 | markAnalysesPreserved(); 32 | auto func = getOperation(); 33 | mlir::OpBuilder builder(func); 34 | hecate::ScaleManagementUnit smu = 35 | getAnalysis(); 36 | 37 | if (!func->hasAttr("sm_plan_edge")) { 38 | // Add Empty Plan 39 | func->setAttr("sm_plan_edge", builder.getDenseI64ArrayAttr({})); 40 | func->setAttr("sm_plan_scale", builder.getDenseI64ArrayAttr({})); 41 | func->setAttr("sm_plan_level", builder.getDenseI64ArrayAttr({})); 42 | } else if (!func->hasAttr("no_mutation") || 43 | !func->getAttrOfType("no_mutation")) { 44 | // Mutate Plan 45 | std::random_device rd; 46 | std::mt19937 gen(rd()); 47 | std::poisson_distribution planRange( 48 | static_cast(std::sqrt(smu.getNumEdges()))); 49 | /* std::uniform_int_distribution planRange(1, 6); */ 50 | std::uniform_int_distribution smRange(0, smu.getNumEdges() - 1); 51 | std::uniform_int_distribution scaleRange(-15, 15); 52 | std::uniform_int_distribution levelRange(0, 2); 53 | DenseMap> planMap; 54 | auto plan_num = planRange(gen); 55 | for (int64_t i = 0; i < plan_num; i++) { 56 | planMap[smRange(gen)] = {std::max(0L, scaleRange(gen)), 57 | std::max(0L, levelRange(gen))}; 58 | } 59 | SmallVector 
planEdge; 60 | SmallVector planScale; 61 | SmallVector planLevel; 62 | for (auto &&it : planMap) { 63 | planEdge.push_back(it.first); 64 | planScale.push_back(it.second.first); 65 | planLevel.push_back(it.second.second); 66 | } 67 | func->setAttr("sm_plan_edge", builder.getDenseI64ArrayAttr(planEdge)); 68 | func->setAttr("sm_plan_scale", builder.getDenseI64ArrayAttr(planScale)); 69 | func->setAttr("sm_plan_level", builder.getDenseI64ArrayAttr(planLevel)); 70 | } 71 | 72 | auto &&smEdge = func->getAttrOfType("sm_plan_edge") 73 | .asArrayRef(); 74 | auto &&smScale = 75 | func->getAttrOfType("sm_plan_scale") 76 | .asArrayRef(); 77 | auto &&smLevel = 78 | func->getAttrOfType("sm_plan_level") 79 | .asArrayRef(); 80 | 81 | for (uint64_t i = 0; i < smEdge.size(); i++) { 82 | auto &&edges = smu.getEdgeSet(smEdge[i]); 83 | for (auto &&edge : edges) { 84 | if (edge->get() 85 | .getType() 86 | .dyn_cast() 87 | .isCipher()) { 88 | builder.setInsertionPoint(edge->getOwner()); 89 | edge->set(builder.create( 90 | edge->getOwner()->getLoc(), edge->get(), smScale[i], smLevel[i])); 91 | } 92 | } 93 | } 94 | LLVM_DEBUG(llvm::dbgs() << __FILE__ << ":" << __LINE__ << "\n"); 95 | } 96 | void getDependentDialects(DialectRegistry ®istry) const override { 97 | registry.insert(); 98 | } 99 | }; 100 | } // namespace 101 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/UpscaleBubbling.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 3 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 4 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 5 | #include "mlir/Dialect/Func/IR/FuncOps.h" 6 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 7 | #include "llvm/Support/Debug.h" 8 | 9 | namespace hecate { 10 | namespace earth { 11 | #define GEN_PASS_DEF_UPSCALEBUBBLING 12 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 13 
| } // namespace earth 14 | } // namespace hecate 15 | // 16 | #define DEBUG_TYPE "hecate_ub" 17 | 18 | using namespace mlir; 19 | 20 | namespace { 21 | /// Pass to bufferize Arith ops. 22 | struct UpscaleBubblingPass 23 | : public hecate::earth::impl::UpscaleBubblingBase { 24 | UpscaleBubblingPass() {} 25 | 26 | void runOnOperation() override { 27 | 28 | auto func = getOperation(); 29 | markAnalysesPreserved(); 30 | 31 | mlir::OpBuilder builder(func); 32 | mlir::IRRewriter rewriter(builder); 33 | 34 | SmallVector inputTypes; 35 | 36 | auto &&bb = func.getBody().getBlocks().front(); 37 | for (auto iter = bb.rbegin(); iter != bb.rend(); ++iter) { 38 | if (auto op = dyn_cast(*iter)) { 39 | // Gather the users and finds the minimum "upFactor" 40 | if (op.isConsume() && !op.isSingle()) { 41 | continue; 42 | } 43 | 44 | uint64_t minUpFactor = -1; 45 | for (auto &&oper : op->getResult(0).getUses()) { 46 | if (auto oop = dyn_cast(oper.getOwner())) { 47 | minUpFactor = std::min(minUpFactor, oop.getUpFactor()); 48 | } else { 49 | minUpFactor = 0; 50 | } 51 | } 52 | 53 | // Check that every user needs the "upFactored"ed scale 54 | if (!minUpFactor) { 55 | continue; // Go to next operation 56 | } 57 | 58 | // Move the modswitch 59 | if (auto oop = dyn_cast(op.getOperation())) { 60 | // Upscale movement can be absorbed into upscale 61 | oop.setUpFactor(oop.getUpFactor() + minUpFactor); 62 | oop.getResult().setType( 63 | oop.getScaleType().switchScale(oop.getScale() + minUpFactor)); 64 | } else if (!op.isConsume()) { 65 | // Upscale is moved to the opreands 66 | for (auto &&i = 0; i < op->getNumOperands(); i++) { 67 | auto oper = op->getOperand(i); 68 | builder.setInsertionPoint(op); 69 | auto newOper = builder.create( 70 | op->getLoc(), oper, minUpFactor); 71 | op->setOperand(i, newOper); 72 | } 73 | op->getResult(0).setType( 74 | op.getScaleType().switchScale(op.getScale() + minUpFactor)); 75 | } else if (op.isConsume() && op.isSingle()) { 76 | // Upscale can be moved to 
ciphertext operand 77 | for (auto &&i = 0; i < op->getNumOperands(); i++) { 78 | if (op.isOperandCipher(i)) { 79 | auto oper = op->getOperand(i); 80 | builder.setInsertionPoint(op); 81 | auto newOper = builder.create( 82 | op->getLoc(), oper, minUpFactor); 83 | op->setOperand(i, newOper); 84 | } 85 | } 86 | op->getResult(0).setType( 87 | op.getScaleType().switchScale(op.getScale() + minUpFactor)); 88 | } 89 | 90 | // Change the user modswitch downFactors 91 | for (auto &&oper : op->getResult(0).getUsers()) { 92 | if (auto oop = dyn_cast(oper)) { 93 | oop.setUpFactor(oop.getUpFactor() - minUpFactor); 94 | } 95 | } 96 | } 97 | } 98 | LLVM_DEBUG(llvm::dbgs() << __FILE__ << ":" << __LINE__ << "\n"); 99 | } 100 | 101 | void getDependentDialects(DialectRegistry ®istry) const override { 102 | registry.insert(); 103 | } 104 | }; 105 | } // namespace 106 | -------------------------------------------------------------------------------- /lib/Dialect/Earth/Transforms/WaterlineRescaling.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 3 | #include "hecate/Dialect/Earth/IR/HEParameterInterface.h" 4 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 5 | #include "mlir/Dialect/Func/IR/FuncOps.h" 6 | 7 | #include "hecate/Dialect/Earth/Analysis/ScaleManagementUnit.h" 8 | #include "hecate/Dialect/Earth/Transforms/Common.h" 9 | 10 | namespace hecate { 11 | namespace earth { 12 | #define GEN_PASS_DEF_WATERLINERESCALING 13 | #include "hecate/Dialect/Earth/Transforms/Passes.h.inc" 14 | } // namespace earth 15 | } // namespace hecate 16 | 17 | using namespace mlir; 18 | 19 | namespace { 20 | /// Pass to bufferize Arith ops. 
21 | struct WaterlineRescalingPass 22 | : public hecate::earth::impl::WaterlineRescalingBase< 23 | WaterlineRescalingPass> { 24 | WaterlineRescalingPass() {} 25 | 26 | WaterlineRescalingPass(hecate::earth::WaterlineRescalingOptions ops) { 27 | this->waterline = ops.waterline; 28 | this->output_val = ops.output_val; 29 | } 30 | 31 | void runOnOperation() override { 32 | 33 | auto func = getOperation(); 34 | 35 | markAnalysesPreserved(); 36 | 37 | mlir::OpBuilder builder(func); 38 | mlir::IRRewriter rewriter(builder); 39 | SmallVector inputTypes; 40 | // Set function argument types 41 | for (auto argval : func.getArguments()) { 42 | argval.setType( 43 | argval.getType().dyn_cast().replaceSubElements( 44 | [&](hecate::earth::HEScaleTypeInterface t) { 45 | return t.switchScale(waterline); 46 | })); 47 | inputTypes.push_back(argval.getType()); 48 | } 49 | 50 | // Apply waterline rescaling for the operations 51 | func.walk([&](hecate::earth::ForwardMgmtInterface sop) { 52 | builder.setInsertionPointAfter(sop.getOperation()); 53 | sop.processOperandsEVA(waterline); 54 | inferTypeForward(sop); 55 | sop.processResultsEVA(waterline); 56 | }); 57 | hecate::earth::refineReturnValues(func, builder, inputTypes, waterline, 58 | output_val); 59 | } 60 | 61 | void getDependentDialects(DialectRegistry ®istry) const override { 62 | registry.insert(); 63 | } 64 | }; 65 | } // namespace 66 | -------------------------------------------------------------------------------- /lib/Runtime/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | 2 | add_library (SEAL_HEVM 3 | SHARED 4 | SEAL_HEVM.cpp) 5 | target_link_libraries(SEAL_HEVM 6 | PUBLIC 7 | SEAL::seal 8 | ) 9 | -------------------------------------------------------------------------------- /python/hecate/hecate/__init__.py: -------------------------------------------------------------------------------- 1 | from .expr import * 2 | from .runner import * 3 | __version__ ="0.0.1" 4 | 
__all__ = ["setMain","Plain", "Model", "sigmoid", "sqrt", "inverse" , 5 | "sum" , "mean", "variance", "func", "compile", 6 | "dump", "loadModule", "getFunctionInfo","loadContext", 7 | "encrypt", "decrypt", "toggleDebug", "precision_cast", 8 | "PlainMat", "reduce", "BackendType", "setBound", "load_mlir", "pprint", 9 | "hecate_dir", "removeCtxt", "Empty", "bootstrap", "save" 10 | ] 11 | -------------------------------------------------------------------------------- /python/hecate/hecate/expr.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import weakref 3 | import re 4 | import inspect 5 | from subprocess import Popen 6 | from collections.abc import Iterable 7 | # import torch 8 | 9 | import os 10 | import time 11 | import numpy as np 12 | import numpy.ctypeslib as npcl 13 | 14 | hecate_dir = os.environ["HECATE"] 15 | hecateBuild = hecate_dir+"/build" 16 | heaan_keyset = "/heaan_keyset" 17 | libpath = hecateBuild + "/lib/" 18 | lt = ctypes.CDLL(libpath+"libHecateFrontend.so") 19 | os.environ['PATH'] = libpath + os.pathsep + os.environ['PATH'] 20 | 21 | 22 | 23 | 24 | 25 | """Object Creation""" 26 | lt.createConstant.argtypes = [ 27 | ctypes.c_void_p, 28 | ctypes.POINTER(ctypes.c_double), ctypes.c_size_t, ctypes.c_char_p, 29 | ctypes.c_size_t 30 | ] 31 | lt.createFunc.argtypes = [ 32 | ctypes.c_void_p, ctypes.c_char_p, 33 | ctypes.POINTER(ctypes.c_int), ctypes.c_size_t, ctypes.c_char_p, 34 | ctypes.c_size_t 35 | ] 36 | lt.initFunc.argtypes = [ 37 | ctypes.c_void_p, ctypes.c_size_t, 38 | ctypes.POINTER(ctypes.c_size_t), ctypes.c_size_t 39 | ] 40 | lt.createConstant.restype = ctypes.c_size_t 41 | lt.createConstant.restype = ctypes.c_size_t 42 | lt.createFunc.restype = ctypes.c_size_t 43 | 44 | """Immediate arguments""" 45 | lt.createRotation.argtypes = [ 46 | ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int64, ctypes.c_char_p, ctypes.c_size_t 47 | ] 48 | lt.createRotation.restype = ctypes.c_size_t 49 | 50 | 51 | 
"""Unary Operation""" 52 | lt.createUnary.argtypes = [ 53 | ctypes.c_void_p, ctypes.c_size_t, ctypes.c_size_t, ctypes.c_char_p, 54 | ctypes.c_size_t 55 | ] 56 | lt.createUnary.restype = ctypes.c_size_t 57 | """Binary Operation""" 58 | lt.createBinary.argtypes = [ 59 | ctypes.c_void_p, ctypes.c_size_t, ctypes.c_size_t, ctypes.c_size_t, 60 | ctypes.c_char_p, ctypes.c_size_t 61 | ] 62 | lt.createBinary.restype = ctypes.c_size_t 63 | """Return""" 64 | lt.setOutput.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.POINTER(ctypes.c_size_t), 65 | ctypes.c_size_t] 66 | """compile""" 67 | lt.save.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_char_p] 68 | lt.save.restype = ctypes.c_char_p 69 | 70 | lt.init.restype = ctypes.c_void_p 71 | lt.finalize.argtypes = [ctypes.c_void_p] 72 | """Context Generation""" 73 | ctxt = lt.init() 74 | import sys 75 | # weakref.finalize(sys.modules[__name__], lt.finalize, ctxt) 76 | 77 | """OpCode Tables""" 78 | toUnary = { 79 | "bootstrap": 0, # Bootstrap 80 | } 81 | toBinary = { 82 | "add": 6, # Addition 83 | "sub": 7, # Subtraction 84 | "mul": 8, # Multiplication 85 | } 86 | toInnerUnary = { 87 | "neg": 13, # Negation 88 | } 89 | 90 | 91 | 92 | def save(dirs="", cst_dirs=""): 93 | (frame, filename, line_number, function_name, lines, 94 | index) = getProperFrame() 95 | 96 | start = time.perf_counter() 97 | proc_start = time.process_time() 98 | if dirs=="" : 99 | dirs = os.getcwd() 100 | if cst_dirs=="" : 101 | cst_dirs = os.getcwd() 102 | 103 | [func.eval() for func in funcList] 104 | name = filename.split("/")[-1].split(".")[0].encode('utf-8') 105 | name = (dirs +"/"+ filename.split("/")[-1].split(".")[0]+ ".mlir" ).encode('utf-8') 106 | cst_name = (cst_dirs).encode('utf-8') 107 | name = lt.save(ctxt, cst_name, name ).decode('utf-8') 108 | 109 | 110 | return name 111 | 112 | """Set UnaryOperators""" 113 | 114 | 115 | def unaryFactory(name, opcode): 116 | def unaryMethod(self): 117 | (frame, filename, line_number, function_name, 
lines, 118 | index) = inspect.stack()[1] 119 | return Expr( 120 | lt.createUnary(ctxt, opcode, self.obj, filename.encode('utf-8'), 121 | line_number)) 122 | 123 | globals()[name] = unaryMethod 124 | [unaryFactory(name, opcode) for name, opcode in toUnary.items()] 125 | """Metaclasses""" 126 | 127 | 128 | class hecateMetaBase(type): 129 | def __new__(cls, name, bases, attrs): 130 | newcls = super().__new__(cls, name, bases, attrs) 131 | 132 | def raiser(): 133 | raise Exception("Copying Hecate object is forbidden") 134 | setattr(newcls, "__copy__", raiser) 135 | setattr(newcls, "__deepcopy__", raiser) 136 | return newcls 137 | 138 | 139 | class hecateMetaBinary(hecateMetaBase): 140 | def __new__(cls, name, bases, attrs): 141 | newcls = super().__new__(cls, name, bases, attrs) 142 | 143 | def binaryFactory(cls, name, opcode): 144 | def binaryMethod(self, other): 145 | (frame, filename, line_number, function_name, lines, 146 | index) = inspect.stack()[1] 147 | tmp = resolveType(other) 148 | return Expr( 149 | lt.createBinary(ctxt, opcode, self.obj, tmp.obj, 150 | filename.encode('utf-8'), line_number)) 151 | 152 | def binaryReverseMethod(self, other): 153 | (frame, filename, line_number, function_name, lines, 154 | index) = inspect.stack()[1] 155 | tmp = resolveType(other) 156 | return Expr( 157 | lt.createBinary(ctxt, opcode, tmp.obj, self.obj, 158 | filename.encode('utf-8'), line_number)) 159 | 160 | def binaryInplaceMethod(self, other): 161 | (frame, filename, line_number, function_name, lines, 162 | index) = inspect.stack()[1] 163 | tmp = resolveType(other) 164 | self.obj = lt.createBinary(ctxt, opcode, self.obj, tmp.obj, 165 | filename.encode('utf-8'), 166 | line_number) 167 | 168 | setattr(cls, f"__{name}__", binaryMethod) 169 | setattr(cls, f"__r{name}__", binaryReverseMethod) 170 | setattr(cls, f"__i{name}__", binaryReverseMethod) 171 | 172 | 173 | def innerFactory(cls, name, opcode): 174 | def innerMethod(self): 175 | (frame, filename, line_number, 
function_name, lines, 176 | index) = inspect.stack()[1] 177 | return Expr( 178 | lt.createUnary(ctxt, opcode, self.obj, 179 | filename.encode('utf-8'), line_number)) 180 | 181 | setattr(cls, f"__{name}__", innerMethod) 182 | 183 | [ 184 | binaryFactory(newcls, name, opcode) 185 | for name, opcode in toBinary.items() 186 | ] 187 | [ 188 | innerFactory(newcls, name, opcode) 189 | for name, opcode in toInnerUnary.items() 190 | ] 191 | 192 | def rotate(self, offset) : 193 | (frame, filename, line_number, function_name, lines, 194 | index) = inspect.stack()[1] 195 | return Expr(lt.createRotation(ctxt, self.obj, offset, 196 | filename.encode('utf-8'), line_number)) 197 | setattr(newcls, "rotate", rotate) 198 | return newcls 199 | 200 | """Helper Functions""" 201 | 202 | def getProperFrame(): 203 | (frame, thisname, line_number, function_name, lines, 204 | index) = inspect.stack()[0] 205 | for (frame, filename, line_number, function_name, lines, 206 | index) in inspect.stack(): 207 | if thisname != filename: 208 | return (frame, filename, line_number, function_name, lines, index) 209 | 210 | def recType(li): 211 | if isinstance(li, int) or isinstance(li, float): 212 | return [] 213 | elif isinstance(li, list): 214 | tys = [recType(x) for x in li] 215 | if all(i == tys[0] for i in tys): 216 | return [len(li)] + tys[0] 217 | else: 218 | raise Exception("Cannot create compatiable type") 219 | else: 220 | raise Exception("Cannot create compatiable type") 221 | 222 | 223 | def flatten(L): 224 | for l in L: 225 | if isinstance(l, list): 226 | yield from flatten(l) 227 | else: 228 | yield l 229 | 230 | 231 | def resolveType(other): 232 | if isinstance(other, Expr): 233 | return other 234 | elif isinstance(other, int): 235 | return Plain(np.array([other], dtype=np.float64)) #Plain([other]) 236 | elif isinstance(other, float): 237 | return Plain(np.array([other], dtype=np.float64)) #Plain([other]) 238 | elif isinstance(other, list): 239 | return Plain(np.array(other, 
dtype=np.float64)) #Plain(other) 240 | # elif isinstance(other, torch.Tensor) : 241 | # return Plain( torch.flatten(other).tolist() ) 242 | elif isinstance(other, np.ndarray): 243 | return Plain(other) 244 | else: 245 | raise Exception("Cannot create compatiable type") 246 | """Binary-operatable expression""" 247 | 248 | 249 | class Expr(metaclass=hecateMetaBinary): 250 | def __init__(self, obj): 251 | self.obj = obj 252 | 253 | class Plain(Expr): 254 | def __init__(self, data, scale = 40): 255 | # carr = (ctypes.c_double * 256 | # len(data))(*[float(x) for x in flatten(data)]) 257 | if not isinstance(data, np.ndarray) : 258 | data = np.array(data, dtype=np.float64) 259 | 260 | carr = npcl.as_ctypes(data) 261 | (frame, filename, line_number, function_name, lines, 262 | index) = getProperFrame() 263 | super().__init__( 264 | lt.createConstant(ctxt, carr, len(data), filename.encode('utf-8'), 265 | line_number)) 266 | 267 | class Empty : 268 | def __init__ (self) : 269 | pass 270 | def __add__(self, other) : 271 | return other 272 | def __radd__(self, other) : 273 | return other 274 | def __iadd__(self, other) : 275 | return other 276 | def __sub__(self, other) : 277 | return other 278 | def __rsub__(self, other) : 279 | return other 280 | def __isub__(self, other) : 281 | return other 282 | 283 | 284 | """Function Decorator""" 285 | funcList = [] 286 | 287 | 288 | def func(param): 289 | def generateFunc(func): 290 | global funcList 291 | (frame, filename, line_number, function_name, lines, 292 | index) = getProperFrame() 293 | a = Func(func, param, filename, line_number) 294 | funcList.append(a) 295 | return a 296 | 297 | return generateFunc 298 | 299 | 300 | """Function object""" 301 | from collections.abc import Iterable 302 | 303 | class Func(metaclass=hecateMetaBase): 304 | def __init__(self, fun, paramstr, filename, line_number): 305 | name = fun.__name__ 306 | self.fun = fun 307 | 308 | arg = paramstr.split(",") 309 | inputs = [ a== "c" for a in arg] 310 | 311 
| self.inputlen = len(inputs) 312 | inputTys = (ctypes.c_int * len(inputs))(*inputs) 313 | self.obj = lt.createFunc(ctxt, name.encode('utf-8'), inputTys, 314 | len(inputs), filename.encode('utf-8'), 315 | line_number) 316 | 317 | def eval(self): 318 | inputarr = (ctypes.c_size_t * self.inputlen)() 319 | lt.initFunc(ctxt, self.obj, inputarr, self.inputlen) 320 | inputs = [Expr(x) for x in inputarr[:self.inputlen]] 321 | returns = self.fun(*inputs) 322 | if not isinstance (returns, Iterable) : 323 | returns = [returns] 324 | outputs = [ x.obj for x in returns] 325 | outputvec = (ctypes.c_size_t * len(outputs))(*outputs) 326 | lt.setOutput(ctxt, self.obj, outputvec, len(outputs)) 327 | 328 | def __call__(self, *args): 329 | (frame, filename, line_number, function_name, lines, 330 | index) = getProperFrame() 331 | tmps = [resolveType(arg) for arg in args] 332 | argarr = (ctypes.c_size_t * len(args))(*[tmp.obj for tmp in tmps]) 333 | 334 | return Expr( 335 | lt.createCall(ctxt, self.obj, argarr, len(args), 336 | filename.encode('utf-8'), line_number)) 337 | 338 | 339 | def removeCtxt() : 340 | lt.removeCtxt() 341 | -------------------------------------------------------------------------------- /python/hecate/hecate/runner.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import ctypes 4 | import weakref 5 | import re 6 | import inspect 7 | from subprocess import Popen 8 | from collections.abc import Iterable 9 | # import torch 10 | 11 | import os 12 | import time 13 | import numpy as np 14 | import numpy.ctypeslib as npcl 15 | from pathlib import Path 16 | 17 | 18 | 19 | hecate_dir = Path(os.environ["HECATE"]) 20 | hecateBuild = hecate_dir / "build" 21 | 22 | 23 | if not hecateBuild.is_dir() : # We expect that this is library path 24 | hecateBuild = hecate_dir 25 | 26 | libpath = hecateBuild / "lib" 27 | lw = ctypes.CDLL(libpath / "libSEAL_HEVM.so") 28 | os.environ['PATH'] = str(libpath) + os.pathsep + os.environ['PATH'] 
29 | 30 | 31 | # Init VM functions 32 | lw.initFullVM.argtypes = [ctypes.c_char_p] 33 | lw.initFullVM.restype = ctypes.c_void_p 34 | lw.initClientVM.argtypes = [ctypes.c_char_p] 35 | lw.initClientVM.restype = ctypes.c_void_p 36 | lw.initServerVM.argtypes = [ctypes.c_char_p] 37 | lw.initServerVM.restype = ctypes.c_void_p 38 | 39 | # Init SEAL Contexts 40 | lw.create_context.argtypes = [ctypes.c_char_p] 41 | lw.load.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_char_p] 42 | lw.loadClient.argtypes = [ctypes.c_void_p, ctypes.c_void_p] 43 | lw.getArgLen.argtypes = [ctypes.c_void_p] 44 | lw.getArgLen.restype = ctypes.c_int64 45 | lw.getResLen.argtypes = [ctypes.c_void_p] 46 | lw.getResLen.restype = ctypes.c_int64 47 | 48 | # Encrypt/Decrypt Functions 49 | lw.encrypt.argtypes = [ctypes.c_void_p, ctypes.c_int64, ctypes.POINTER(ctypes.c_double), ctypes.c_int] 50 | lw.decrypt.argtypes = [ctypes.c_void_p, ctypes.c_int64, ctypes.POINTER(ctypes.c_double)] 51 | lw.decrypt_result.argtypes = [ctypes.c_void_p, ctypes.c_int64, ctypes.POINTER(ctypes.c_double)] 52 | 53 | # Helper Functions for ciphertext access 54 | lw.getResIdx.argtypes = [ctypes.c_void_p, ctypes.c_int64] 55 | lw.getResIdx.restype = ctypes.c_int64 56 | lw.getCtxt.argtypes = [ctypes.c_void_p, ctypes.c_int64] 57 | lw.getCtxt.restype = ctypes.c_void_p 58 | 59 | # Runner Functions 60 | lw.preprocess.argtypes = [ctypes.c_void_p] 61 | lw.run.argtypes = [ctypes.c_void_p] 62 | 63 | #Debug Function 64 | lw.setDebug.argtypes = [ctypes.c_void_p, ctypes.c_bool] 65 | 66 | 67 | class HEVM : 68 | def __init__ (self, path = str((Path.home() / ".hevm" / "seal").absolute()) , option= "full") : 69 | self.option = option 70 | if not Path(path).is_dir() : 71 | print ("Press Any key to generate SEAL files (or just kill with ctrl+c)") 72 | input() 73 | Path(path).mkdir(parents=True) 74 | lw.create_context(path.encode('utf-8')) 75 | 76 | if option == "full" : 77 | self.vm = lw.initFullVM(path.encode('utf-8')) 78 | elif option == 
"client" : 79 | self.vm = lw.initClientVM(path.encode('utf-8')) 80 | elif option == "server" : 81 | self.vm = lw.initServerVM(path.encode('utf-8')) 82 | 83 | # def load (self, func, preprocess=True, const_path =str( (Path(func_dir) / "_hecate_{func}.cst").absoluate() ), hevm_path = str(Path(func_dir) / "_hecate_{func}.hevm"), func_dir = str(Path.cwd()), ) : 84 | def load (self, const_path, hevm_path, preprocess=True) : 85 | if not Path(const_path).is_file() : 86 | raise Exception(f"No file exists in const_path {const_path}") 87 | if not Path(hevm_path).is_file() : 88 | raise Exception(f"No file exists in hevm_path {hevm_path}") 89 | 90 | if self.option == "full" or self.option == "server" : 91 | lw.load(self.vm, const_path.encode('utf-8'), hevm_path.encode('utf-8')) 92 | elif self.option == "client" : 93 | lw.loadClient (self.vm, const_path.encode('utf-8')) 94 | if (preprocess) : 95 | lw.preprocess (self.vm) 96 | else : 97 | raise Exception("Not implemented in SEAL_HEVM") 98 | 99 | self.arglen = lw.getArgLen(self.vm) 100 | self.reslen = lw.getResLen(self.vm) 101 | 102 | def run (self) : 103 | lw.run(self.vm) 104 | 105 | def setInput(self, i, data) : 106 | if not isinstance(data, np.ndarray) : 107 | data = np.array(data, dtype=np.float64) 108 | carr = data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)) 109 | lw.encrypt(self.vm, i, carr, len(data)) 110 | 111 | def setDebug (self, enable) : 112 | lw.setDebug(self.vm, enable) 113 | 114 | 115 | def getOutput (self) : 116 | result = np.zeros( (self.reslen, 1 << 14), dtype=np.float64) 117 | data = np.zeros( 1 << 14, dtype=np.float64) 118 | for i in range(self.reslen) : 119 | # carr = npcl.as_ctypes(data) 120 | carr = data.ctypes.data_as(ctypes.POINTER(ctypes.c_double)) 121 | lw.decrypt_result(self.vm, i, carr) 122 | # result[i] = npcl.as_array(carr, shape= 1<<14) 123 | result[i] = data 124 | 125 | return result 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 
-------------------------------------------------------------------------------- /python/hecate/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='hecate', 5 | version='0.0.1', 6 | description='PYPI package for hecate binding', 7 | author='yongwoo', 8 | author_email='dragonrain96@gmail.com', 9 | install_requires=['numpy'], 10 | packages=find_packages(exclude=[]), 11 | keywords=['yongwoo', 'homomorphic encryption', 'ckks', 'hecate', 'elasm'], 12 | python_requires='>=3.6', 13 | package_data={}, 14 | zip_safe=False, 15 | classifiers=[ 16 | 'Programming Language :: Python :: 3.6', 17 | 'Programming Language :: Python :: 3.7', 18 | 'Programming Language :: Python :: 3.8', 19 | 'Programming Language :: Python :: 3.9', 20 | 'Programming Language :: Python :: 3.10', 21 | ], 22 | ) 23 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2023.7.22 2 | charset-normalizer==3.2.0 3 | cmake==3.27.1 4 | filelock==3.12.2 5 | idna==3.4 6 | Jinja2==3.1.2 7 | lit==16.0.6 8 | MarkupSafe==2.1.3 9 | mpmath==1.3.0 10 | networkx==3.1 11 | numpy==1.25.2 12 | nvidia-cublas-cu11==11.10.3.66 13 | nvidia-cuda-cupti-cu11==11.7.101 14 | nvidia-cuda-nvrtc-cu11==11.7.99 15 | nvidia-cuda-runtime-cu11==11.7.99 16 | nvidia-cudnn-cu11==8.5.0.96 17 | nvidia-cufft-cu11==10.9.0.58 18 | nvidia-curand-cu11==10.2.10.91 19 | nvidia-cusolver-cu11==11.4.0.1 20 | nvidia-cusparse-cu11==11.7.4.91 21 | nvidia-nccl-cu11==2.14.3 22 | nvidia-nvtx-cu11==11.7.91 23 | Pillow==10.0.0 24 | requests==2.31.0 25 | sympy==1.12 26 | torch==2.0.1 27 | torchvision==0.15.2 28 | triton==2.0.0 29 | typing_extensions==4.7.1 30 | urllib3==2.0.4 31 | -------------------------------------------------------------------------------- /tools/CMakeLists.txt: 
-------------------------------------------------------------------------------- 1 | 2 | add_library (HecateFrontend 3 | SHARED 4 | frontend.cpp 5 | ) 6 | 7 | get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) 8 | get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) 9 | 10 | target_link_libraries(HecateFrontend 11 | EarthDialect 12 | EarthTransforms 13 | ${dialect_libs} 14 | ${conversion_libs} 15 | ) 16 | 17 | add_executable (hecate-opt 18 | optimizer.cpp 19 | ) 20 | 21 | target_link_libraries(hecate-opt 22 | MLIREmitCDialect 23 | MLIROptLib 24 | EarthDialect 25 | CKKSDialect 26 | EarthTransforms 27 | CKKSTransforms 28 | HecateCKKSCommonConversion 29 | HecateEarthToCKKS 30 | HecateCKKSToCKKS 31 | ${dialect_libs} 32 | ${conversion_libs} 33 | ) 34 | -------------------------------------------------------------------------------- /tools/frontend.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 9 | #include "mlir/Dialect/Bufferization/Transforms/Passes.h" 10 | #include "mlir/Parser/Parser.h" 11 | #include "mlir/Tools/mlir-opt/MlirOptMain.h" 12 | #include "llvm/Support/SourceMgr.h" 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | #include "hecate/Dialect/Earth/IR/EarthOps.h" 27 | #include "hecate/Dialect/Earth/Transforms/Passes.h" 28 | 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | 35 | using namespace mlir; 36 | 37 | void handler(int sig) { 38 | void *array[10]; 39 | size_t size; 40 | 41 | size = backtrace(array, 10); 42 | 43 | fprintf(stderr, "Error: signal %d:\n", sig); 44 | backtrace_symbols_fd(array, size, STDERR_FILENO); 45 | exit(1); 46 | } 47 | 48 | namespace hecate { 49 | 50 | using valueID = size_t; 51 | using 
funcID = size_t; 52 | 53 | struct Context { 54 | Context(); 55 | mlir::MLIRContext ctxt; 56 | mlir::OwningOpRef mod; 57 | std::unique_ptr builder; 58 | llvm::SmallVector valueMap; 59 | llvm::SmallVector funcMap; 60 | }; 61 | 62 | Context::Context() : ctxt(), mod(), builder() { 63 | 64 | ctxt.getOrLoadDialect(); 65 | auto ed = ctxt.getOrLoadDialect(); 66 | 67 | auto tmp = std::make_unique(&ctxt); 68 | builder.swap(tmp); 69 | 70 | mod = mlir::OwningOpRef( 71 | mlir::ModuleOp::create(builder->getUnknownLoc())); 72 | } 73 | 74 | extern "C" { 75 | 76 | valueID createConstant(Context *ctxt, double *data, int64_t len, char *filename, 77 | size_t line) { 78 | auto &&builder = *ctxt->builder; 79 | auto cons = builder.create( 80 | mlir::FileLineColLoc::get(builder.getStringAttr(filename), line, 0), 81 | llvm::ArrayRef(data, len)); 82 | 83 | ctxt->valueMap.push_back(cons); 84 | return ctxt->valueMap.size() - 1; 85 | } 86 | funcID createFunc(Context *ctxt, char *name, int *inputTys, size_t len, 87 | char *filename, size_t line) { 88 | auto &&builder = *ctxt->builder; 89 | auto &&funcMap = ctxt->funcMap; 90 | llvm::SmallVector arg_types(len); 91 | std::transform(inputTys, inputTys + len, arg_types.begin(), [&](auto a) { 92 | return mlir::RankedTensorType::get( 93 | llvm::SmallVector{1}, 94 | builder.getType(0, 0)); 95 | }); 96 | auto funcType = builder.getFunctionType( 97 | arg_types, mlir::RankedTensorType::get( 98 | llvm::SmallVector{1}, 99 | builder.getType(0, 0))); 100 | auto funcOp = mlir::func::FuncOp::create( 101 | mlir::FileLineColLoc::get(builder.getStringAttr(filename), line, 0), 102 | std::string("_hecate_") + name, funcType); 103 | funcMap.push_back(funcOp); 104 | ctxt->mod->push_back(funcOp); 105 | return funcMap.size() - 1; 106 | } 107 | 108 | void initFunc(Context *ctxt, funcID fun, valueID *args, size_t len) { 109 | auto &&funcOp = ctxt->funcMap[fun]; 110 | auto &&valueMap = ctxt->valueMap; 111 | auto entryBlock = funcOp.addEntryBlock(); 112 | auto funcInput = 
entryBlock->getArguments(); 113 | ctxt->builder->setInsertionPointToStart(entryBlock); 114 | { 115 | int i = 0; 116 | for (auto a : funcInput) { 117 | valueMap.push_back(a); 118 | args[i++] = valueMap.size() - 1; 119 | } 120 | } 121 | } 122 | 123 | char *save(Context *c, char *const_name, char *mlir_name) { 124 | c->mod->getOperation()->setAttr(mlir::SymbolTable::getSymbolAttrName(), 125 | c->builder->getStringAttr(mlir_name)); 126 | std::string s_const_name(const_name); 127 | mlir::PassManager pm(&c->ctxt); 128 | pm.addPass(createCSEPass()); 129 | pm.addPass(createCanonicalizerPass()); 130 | pm.addNestedPass( 131 | earth::createElideConstant({s_const_name + "/"})); 132 | pm.addNestedPass(earth::createPrivatizeConstant()); 133 | pm.addPass(createCanonicalizerPass()); 134 | 135 | auto ret = pm.run(*c->mod); 136 | 137 | std::error_code EC; 138 | llvm::raw_fd_ostream outputFile(mlir_name, EC); 139 | c->mod->print(outputFile, mlir::OpPrintingFlags() 140 | .printGenericOpForm() 141 | .enableDebugInfo() 142 | .useLocalScope()); 143 | c->valueMap.clear(); 144 | c->funcMap.clear(); 145 | c->mod.release(); 146 | return mlir_name; 147 | } 148 | 149 | /* Unary Operation */ 150 | valueID createUnary(Context *ctxt, size_t opcode, valueID lhs, char *filename, 151 | size_t line) { 152 | auto &&builder = *ctxt->builder; 153 | auto &&valueMap = ctxt->valueMap; 154 | auto location = 155 | mlir::FileLineColLoc::get(builder.getStringAttr(filename), line, 0); 156 | auto &&source = valueMap[lhs]; 157 | switch (opcode) { 158 | case 0: { 159 | break; 160 | } 161 | case 13: { 162 | auto res = builder.create(location, source); 163 | valueMap.push_back(res); 164 | break; 165 | } 166 | 167 | default: 168 | assert(0 && "Unary Operation type is wrong"); 169 | } 170 | return valueMap.size() - 1; 171 | } 172 | 173 | /* Binary Operation */ 174 | valueID createBinary(Context *ctxt, size_t opcode, valueID lhs, valueID rhs, 175 | char *filename, size_t line) { 176 | auto &&builder = *ctxt->builder; 
177 | auto &&valueMap = ctxt->valueMap; 178 | auto location = 179 | mlir::FileLineColLoc::get(builder.getStringAttr(filename), line, 0); 180 | 181 | auto &&srcl = valueMap[lhs]; 182 | auto &&srcr = valueMap[rhs]; 183 | 184 | switch (opcode) { 185 | case 6: { 186 | auto res = builder.create(location, srcl, srcr); 187 | valueMap.push_back(res); 188 | break; 189 | } 190 | case 7: { 191 | auto neg = builder.create(location, srcr); 192 | valueMap.push_back(neg); 193 | auto res = builder.create(location, srcl, neg); 194 | valueMap.push_back(res); 195 | break; 196 | } 197 | 198 | case 8: { 199 | auto res = builder.create(location, srcl, srcr); 200 | valueMap.push_back(res); 201 | break; 202 | } 203 | 204 | default: 205 | assert(0 && "Binary Operation type is wrong"); 206 | } 207 | return valueMap.size() - 1; 208 | } 209 | 210 | valueID createRotation(Context *ctxt, size_t valueID, int offset, 211 | char *filename, size_t line) { 212 | auto &&builder = *ctxt->builder; 213 | auto &&srcl = ctxt->valueMap[valueID]; 214 | auto cons = builder.create( 215 | mlir::FileLineColLoc::get(builder.getStringAttr(filename), line, 0), srcl, 216 | offset); 217 | ctxt->valueMap.push_back(cons); 218 | return ctxt->valueMap.size() - 1; 219 | } 220 | 221 | void setOutput(Context *ctxt, funcID fun, valueID *ret, size_t len) { 222 | llvm::SmallVector rets; 223 | llvm::SmallVector types; 224 | for (int i = 0; i < len; i++) { 225 | rets.push_back(ctxt->valueMap[ret[i]]); 226 | types.push_back(ctxt->valueMap[ret[i]].getType()); 227 | } 228 | auto func = ctxt->funcMap[fun]; 229 | ctxt->builder->create(func.getLoc(), rets); 230 | auto retType = func.getFunctionType(); 231 | func.setFunctionType( 232 | ctxt->builder->getFunctionType(retType.getInputs(), types)); 233 | } 234 | 235 | Context *init() { 236 | /* signal(SIGSEGV, handler); */ 237 | return new ::hecate::Context(); 238 | } 239 | void finalize(Context *ctxt) { delete ctxt; } 240 | } // namespace hecate 241 | } // namespace hecate 242 | 243 | 
/* int main() {} */ 244 | -------------------------------------------------------------------------------- /versions.txt: -------------------------------------------------------------------------------- 1 | python=3.10.12 2 | llvm-project=llvmorg-16.0.0 3 | cmake=3.22.1 4 | SEAL=4.0.0 5 | --------------------------------------------------------------------------------