├── .github
│   └── workflows
│       └── tests.yml
├── .gitignore
├── Cargo.toml
├── LICENSE
├── README.md
├── bench
│   ├── Cargo.toml
│   └── benches
│       ├── double.rs
│       └── modops.rs
└── src
    ├── barrett.rs
    ├── bigint.rs
    ├── double.rs
    ├── lib.rs
    ├── mersenne.rs
    ├── monty.rs
    ├── preinv.rs
    ├── prim.rs
    ├── reduced.rs
    └── word.rs
/.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: 4 | - master 5 | pull_request: 6 | branches: 7 | - master 8 | 9 | name: Tests 10 | 11 | jobs: 12 | check: 13 | name: Check 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | rust: [stable, 1.57] 18 | steps: 19 | - uses: actions/checkout@v2 20 | - uses: actions-rs/toolchain@v1 21 | with: 22 | profile: minimal 23 | toolchain: ${{ matrix.rust }} 24 | override: true 25 | - uses: actions-rs/cargo@v1 26 | with: 27 | command: check 28 | args: --all-features 29 | 30 | test: 31 | name: Test 32 | strategy: 33 | matrix: 34 | bits: [16, 32, 64] 35 | runs-on: ubuntu-latest 36 | env: 37 | RUST_BACKTRACE: 1 38 | RUSTFLAGS: -D warnings --cfg force_bits="${{ matrix.bits }}" 39 | steps: 40 | - uses: actions/checkout@v2 41 | - uses: actions-rs/toolchain@v1 42 | with: 43 | profile: minimal 44 | toolchain: stable 45 | override: true 46 | - uses: actions-rs/cargo@v1 47 | with: 48 | command: test 49 | args: --all-features 50 | 51 | test-x86: 52 | name: Test x86 53 | runs-on: ubuntu-latest 54 | env: 55 | RUST_BACKTRACE: 1 56 | RUSTFLAGS: -D warnings 57 | steps: 58 | - uses: actions/checkout@v2 59 | - uses: actions-rs/toolchain@v1 60 | with: 61 | profile: minimal 62 | toolchain: stable-i686-unknown-linux-gnu 63 | override: true 64 | - run: | 65 | sudo apt update 66 | sudo apt install gcc-multilib 67 | - uses: actions-rs/cargo@v1 68 | with: 69 | command: test 70 | 71 | test-x86_64: 72 | name: Test x86_64 73 | runs-on: ubuntu-latest 74 | env: 75 | RUST_BACKTRACE: 1 76 | RUSTFLAGS: -D warnings 77 | steps: 78 | - uses: actions/checkout@v2 79 | - uses: actions-rs/toolchain@v1 80 | with: 81 | profile: minimal 82 | toolchain: stable-x86_64-unknown-linux-gnu 83 | override: true 84 | - uses: actions-rs/cargo@v1 85 | with: 86 | command: test 87 | 88 | test-no-std: 89 | name: Test no-std 90 | runs-on: ubuntu-latest 91 | env: 92 | RUSTFLAGS: -D warnings 93 | steps: 94 | - uses: actions/checkout@v2 95 | - uses: actions-rs/toolchain@v1 96 | with: 97 | profile: minimal 98 | toolchain: stable 99 | override: true 100 | - uses: actions-rs/cargo@v1 101 | with: 102 | command: test 103 | args: --no-default-features 104 | 105 | build-aarch64: 106 | name: Build aarch64 107 | runs-on: ubuntu-latest 108 | env: 109 | RUSTFLAGS: -D warnings 110 | steps: 111 | - uses: actions/checkout@v2 112 | - uses: actions-rs/toolchain@v1 113 | with: 114 | profile: minimal 115 | toolchain: stable 116 | target: aarch64-unknown-linux-gnu 117 | override: true 118 | - uses: actions-rs/cargo@v1 119 | with: 120 | command: build 121 | args: --target aarch64-unknown-linux-gnu --all-features --workspace --exclude benchmark 122 | 123 | build-benchmark: 124 | name: Build benchmark 125 | runs-on: ubuntu-latest 126 | env: 127 | RUSTFLAGS: -D warnings 128 | steps: 129 | - uses: actions/checkout@v2 130 | - uses: actions-rs/toolchain@v1 131 | with: 132 | profile: minimal 133 | toolchain: stable 134 | override: true 135 | - uses: actions-rs/cargo@v1 136 | with: 137 | command: build 138 | args: -p num-modular-bench 139 | 140 | fmt: 141 | name: Rustfmt 142 | runs-on: ubuntu-latest 143 | steps: 144 | - uses:
actions/checkout@v2 145 | - uses: actions-rs/toolchain@v1 146 | with: 147 | profile: minimal 148 | toolchain: stable 149 | override: true 150 | components: rustfmt 151 | - uses: actions-rs/cargo@v1 152 | with: 153 | command: fmt 154 | args: --all -- --check 155 | 156 | clippy: 157 | name: Clippy 158 | runs-on: ubuntu-latest 159 | steps: 160 | - uses: actions/checkout@v2 161 | - uses: actions-rs/toolchain@v1 162 | with: 163 | profile: minimal 164 | toolchain: stable 165 | override: true 166 | components: clippy 167 | - uses: actions-rs/cargo@v1 168 | with: 169 | command: clippy 170 | args: --all-features --all-targets -- -D warnings -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | Cargo.lock 3 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "num-modular" 3 | version = "0.6.1" 4 | edition = "2018" 5 | 6 | repository = "https://github.com/cmpute/num-modular" 7 | keywords = ["mathematics", "numeric", "number-theory", "modular", "montgomery"] 8 | categories = ["mathematics", "algorithms", "no-std"] 9 | documentation = "https://docs.rs/num-modular" 10 | license = "Apache-2.0" 11 | description = """ 12 | Implementation of efficient integer division and modular arithmetic operations with generic number types. 13 | Supports various backends, including num-bigint. 14 | """ 15 | readme = "README.md" 16 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 17 | 18 | [dependencies] 19 | num-integer = { version = "0.1.44", optional = true } 20 | num-traits = { version = "0.2.14", optional = true } 21 | 22 | [dependencies.num-bigint] 23 | optional = true 24 | version = "0.4.3" 25 | default-features = false 26 | 27 | [dev-dependencies] 28 | rand = "0.8.4" 29 | 30 | [workspace] 31 | members = [ 32 | "bench", 33 | ] 34 | 35 | [package.metadata.docs.rs] 36 | all-features = true 37 | 38 | [features] 39 | std = [] 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License.
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2022] [Jacob Zhong] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # num-modular 2 | 3 | A generic implementation of integer division and modular arithmetic in Rust. It provides basic operators and a type to represent integers in a modulo ring. Specifically, the following features are supported: 4 | 5 | - Common modular arithmetic operations: `add`, `sub`, `mul`, `div`, `neg`, `double`, `square`, `inv`, `pow` 6 | - Optimized modular arithmetic in **Montgomery form** 7 | - Optimized modular arithmetic with **pseudo-Mersenne primes** as moduli 8 | - Fast **integer divisibility** checks 9 | - **Legendre**, **Jacobi** and **Kronecker** symbols 10 | 11 | It also supports various integer backends, including primitive integers and `num-bigint`. Note that this crate also supports `no_std`. To enable `std`-related functionality, enable the `std` feature of the crate. 12 | 13 | 17 | -------------------------------------------------------------------------------- /bench/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "num-modular-bench" 3 | version = "0.0.0" 4 | publish = false 5 | edition = "2018" 6 | 7 | [[bench]] 8 | name = "modops" 9 | harness = false 10 | 11 | [[bench]] 12 | name = "double" 13 | harness = false 14 | 15 | [dependencies] 16 | num-integer = "0.1.0" 17 | num-traits = "0.2.0" 18 | num-modular = { path = ".." } 19 | rand = "0.8.4" 20 | criterion = "0.3" 21 | 22 | ethnum = { version = "1.1.1", optional = true } 23 | primitive-types = { version = "0.11", optional = true } 24 | crypto-bigint = { version = "0.3.2", optional = true } 25 | uint = { version = "0.9.3", optional = true } 26 | zkp-u256 = { version = "0.2.1", optional = true } 27 | 28 | [features] 29 | default = ['ethnum'] 30 | -------------------------------------------------------------------------------- /bench/benches/double.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | extern crate criterion; 3 | use criterion::Criterion; 4 | use rand::random; 5 | 6 | use num_modular::udouble; 7 | 8 | pub fn bench_mul(c: &mut Criterion) { 9 | let mut group = c.benchmark_group("single mul single"); 10 | 11 | const N: usize = 12; 12 | let mut lhs: [u128; N] = [0; N]; 13 | let mut rhs: [u128; N] = [0; N]; 14 | for i in 0..N { 15 | lhs[i] = random(); 16 | rhs[i] = random(); 17 | } 18 | 19 | group.bench_function("ours", |b| { 20 | b.iter(|| { 21 | lhs.iter() 22 | .zip(rhs.iter()) 23 | .map(|(&a, &b)| udouble::widening_mul(a, b)) 24 | .reduce(|a, b| udouble::from(a.lo.wrapping_add(b.lo))) 25 | }) 26 | }); 27 | 28 | #[cfg(feature = "ethnum")] 29 | { 30 | use ethnum::U256; 31 | group.bench_function("ethnum", |b| { 32 | b.iter(|| { 33 | lhs.iter() 34 | .zip(rhs.iter()) 35 | .map(|(&a, &b)| U256::from(a) * U256::from(b)) 36 | .reduce(|a, b| U256::from(a.0[0].wrapping_add(b.0[0]))) 37 | }) 38 | }); 39 | } 40 | 41 | #[cfg(feature = "primitive-types")] 42 | { 43 | use primitive_types::U256; 44 | group.bench_function("primitive-types", |b| { 45 | b.iter(|| { 46 | lhs.iter() 47 | .zip(rhs.iter()) 48 | .map(|(&a, &b)| U256::from(a) * U256::from(b)) 49 | .reduce(|a, b| U256::from(a.0[0].wrapping_add(b.0[0]))) 50 | }) 51 | }); 52 | } 53 | 54 | #[cfg(feature = "uint")] 55 | { 56 | use uint::construct_uint; 57 | construct_uint!
{ 58 | pub struct U256(4); 59 | } 60 | group.bench_function("uint", |b| { 61 | b.iter(|| { 62 | lhs.iter() 63 | .zip(rhs.iter()) 64 | .map(|(&a, &b)| U256::from(a) * U256::from(b)) 65 | .reduce(|a, b| U256::from(a.0[0].wrapping_add(b.0[0]))) 66 | }) 67 | }); 68 | } 69 | 70 | #[cfg(feature = "zkp-u256")] 71 | { 72 | use zkp_u256::U256; 73 | group.bench_function("zkp-u256", |b| { 74 | b.iter(|| { 75 | lhs.iter() 76 | .zip(rhs.iter()) 77 | .map(|(&a, &b)| U256::from(a) * U256::from(b)) 78 | .reduce(|a, b| U256::from(a.as_u128().wrapping_add(b.as_u128()))) 79 | }) 80 | }); 81 | } 82 | 83 | #[cfg(feature = "crypto-bigint")] 84 | { 85 | use crypto_bigint::{Split, U128, U256}; 86 | group.bench_function("crypto-bigint", |b| { 87 | b.iter(|| { 88 | lhs.iter() 89 | .zip(rhs.iter()) 90 | .map(|(&a, &b)| U256::from(a).saturating_mul(&U256::from(b))) 91 | .reduce(|a, b| U256::from((U128::ZERO, a.split().0.wrapping_add(&b.split().0)))) 92 | }) 93 | }); 94 | } 95 | 96 | group.finish(); 97 | } 98 | 99 | pub fn bench_div(c: &mut Criterion) { 100 | let mut group = c.benchmark_group("double div single"); 101 | 102 | const N: usize = 12; 103 | let mut lhs: [(u128, u128); N] = [(0, 0); N]; 104 | let mut rhs: [u128; N] = [0; N]; 105 | for i in 0..N { 106 | lhs[i] = (random(), random()); 107 | rhs[i] = random(); 108 | } 109 | 110 | group.bench_function("ours", |b| { 111 | b.iter(|| { 112 | lhs.iter() 113 | .zip(rhs.iter()) 114 | .map(|(&a, &b)| udouble { lo: a.0, hi: a.1 } / b) 115 | .reduce(|a, b| udouble::from(a.lo.wrapping_add(b.lo))) 116 | }) 117 | }); 118 | 119 | #[cfg(feature = "ethnum")] 120 | { 121 | use ethnum::U256; 122 | group.bench_function("ethnum", |b| { 123 | b.iter(|| { 124 | lhs.iter() 125 | .zip(rhs.iter()) 126 | .map(|(&a, &b)| U256([a.0, a.1]) / b) 127 | .reduce(|a, b| U256::from(a.0[0].wrapping_add(b.0[0]))) 128 | }) 129 | }); 130 | } 131 | 132 | #[cfg(feature = "primitive-types")] 133 | { 134 | use primitive_types::U256; 135 | const MASK: u128 = (1 << 64) - 1; 136 | group.bench_function("primitive-types", |b| { 137 | b.iter(|| { 138 | lhs.iter() 139 | .zip(rhs.iter()) 140 | .map(|(&a, &b)| { 141 | U256([ 142 | (a.0 >> 64) as u64, 143 | (a.0 & MASK) as u64, 144 | (a.1 >> 64) as u64, 145 | (a.1 & MASK) as u64, 146 | ]) / b 147 | }) 148 | .reduce(|a, b| U256::from(a.0[0].wrapping_add(b.0[0]))) 149 | }) 150 | }); 151 | } 152 | 153 | #[cfg(feature = "uint")] 154 | { 155 | use uint::construct_uint; 156 | construct_uint! 
{ 157 | pub struct U256(4); 158 | } 159 | const MASK: u128 = (1 << 64) - 1; 160 | group.bench_function("uint", |b| { 161 | b.iter(|| { 162 | lhs.iter() 163 | .zip(rhs.iter()) 164 | .map(|(&a, &b)| { 165 | U256([ 166 | (a.0 >> 64) as u64, 167 | (a.0 & MASK) as u64, 168 | (a.1 >> 64) as u64, 169 | (a.1 & MASK) as u64, 170 | ]) / b 171 | }) 172 | .reduce(|a, b| U256::from(a.0[0].wrapping_add(b.0[0]))) 173 | }) 174 | }); 175 | } 176 | 177 | #[cfg(feature = "zkp-u256")] 178 | { 179 | use zkp_u256::U256; 180 | group.bench_function("zkp-u256", |b| { 181 | b.iter(|| { 182 | lhs.iter() 183 | .zip(rhs.iter()) 184 | .map(|(&a, &b)| ((U256::from(a.0) << 128) + U256::from(a.1)) / U256::from(b)) 185 | .reduce(|a, b| U256::from(a.as_u128().wrapping_add(b.as_u128()))) 186 | }) 187 | }); 188 | } 189 | 190 | #[cfg(feature = "crypto-bigint")] 191 | { 192 | use crypto_bigint::{Split, U128, U256}; 193 | group.bench_function("crypto-bigint", |b| { 194 | b.iter(|| { 195 | lhs.iter() 196 | .zip(rhs.iter()) 197 | .map(|(&a, &b)| { 198 | ((U256::from(a.0) << 128).wrapping_add(&U256::from(a.1))) 199 | .div_rem(&U256::from(b)) 200 | .unwrap() 201 | .0 202 | }) 203 | .reduce(|a, b| U256::from((U128::ZERO, a.split().0.wrapping_add(&b.split().0)))) 204 | }) 205 | }); 206 | } 207 | 208 | group.finish(); 209 | } 210 | 211 | criterion_group!(benches, bench_mul, bench_div); 212 | criterion_main!(benches); 213 | -------------------------------------------------------------------------------- /bench/benches/modops.rs: -------------------------------------------------------------------------------- 1 | #[macro_use] 2 | extern crate criterion; 3 | use criterion::Criterion; 4 | use num_modular::{FixedMersenneInt, ModularCoreOps, ModularPow, ModularUnaryOps}; 5 | use rand::random; 6 | 7 | pub fn bench_u128(c: &mut Criterion) { 8 | const N: usize = 256; 9 | let mut cases: [(u128, u128, u128); N] = [(0, 0, 0); N]; 10 | for i in 0..N { 11 | cases[i] = (random(), random(), random()); 12 | } 13 | 14 | let mut group = c.benchmark_group("u128 modular ops"); 15 | group.bench_function("addm", |b| { 16 | b.iter(|| { 17 | cases 18 | .iter() 19 | .map(|&(a, b, m)| a.addm(b, &m)) 20 | .reduce(|a, b| a.wrapping_add(b)) 21 | }) 22 | }); 23 | group.bench_function("mulm", |b| { 24 | b.iter(|| { 25 | cases 26 | .iter() 27 | .map(|&(a, b, m)| a.mulm(b, &m)) 28 | .reduce(|a, b| a.wrapping_add(b)) 29 | }) 30 | }); 31 | } 32 | 33 | pub fn bench_modinv(c: &mut Criterion) { 34 | const M1: u64 = (1 << 56) - 5; 35 | let mut group = c.benchmark_group("modular inverse (small operands)"); 36 | 37 | group.bench_function("extended gcd", |b| { 38 | b.iter(|| { 39 | (100u64..400u64) 40 | .map(|n| n.invm(&M1).unwrap()) 41 | .reduce(|a, b| a.addm(b, &M1)) 42 | }) 43 | }); 44 | group.bench_function("fermat theorem", |b| { 45 | b.iter(|| { 46 | (100u64..400u64) 47 | .map(|n| n.powm(M1 - 2, &M1)) 48 | .reduce(|a, b| a.addm(b, &M1)) 49 | }) 50 | }); 51 | group.bench_function("mersenne + extended gcd", |b| { 52 | b.iter(|| { 53 | (100u64..400u64) 54 | .map(|n| { 55 | FixedMersenneInt::<56, 5>::new(n as u128, &(M1 as u128)) 56 | .inv() 57 | .unwrap() 58 | }) 59 | .reduce(|a, b| a + b) 60 | }) 61 | }); 62 | group.bench_function("mersenne + fermat theorem", |b| { 63 | b.iter(|| { 64 | (100u64..400u64) 65 | .map(|n| { 66 | FixedMersenneInt::<56, 5>::new(n as u128, &(M1 as u128)).pow(&(M1 as u128 - 2)) 67 | }) 68 | .reduce(|a, b| a + b) 69 | }) 70 | }); 71 | 72 | group.finish(); 73 | 74 | const M2: u128 = (1 << 94) - 3; 75 | let mut group = c.benchmark_group("modular 
inverse (large operands)"); 76 | 77 | group.bench_function("extended gcd", |b| { 78 | b.iter(|| { 79 | (1_000_000_000u128..1_000_000_300u128) 80 | .map(|n| n.invm(&M2).unwrap()) 81 | .reduce(|a, b| a.addm(b, &M2)) 82 | }) 83 | }); 84 | group.bench_function("fermat theorem", |b| { 85 | b.iter(|| { 86 | (1_000_000_000u128..1_000_000_300u128) 87 | .map(|n| n.powm(M2 - 2, &M2)) 88 | .reduce(|a, b| a.addm(b, &M2)) 89 | }) 90 | }); 91 | group.bench_function("mersenne + extended gcd", |b| { 92 | b.iter(|| { 93 | (1_000_000_000u128..1_000_000_300u128) 94 | .map(|n| { 95 | FixedMersenneInt::<94, 3>::new(n, &(M2 as u128)) 96 | .inv() 97 | .unwrap() 98 | }) 99 | .reduce(|a, b| a + b) 100 | }) 101 | }); 102 | group.bench_function("mersenne + fermat theorem", |b| { 103 | b.iter(|| { 104 | (1_000_000_000u128..1_000_000_300u128) 105 | .map(|n| FixedMersenneInt::<94, 3>::new(n, &(M2 as u128)).pow(&(M2 - 2))) 106 | .reduce(|a, b| a + b) 107 | }) 108 | }); 109 | 110 | group.finish(); 111 | } 112 | 113 | criterion_group!(benches, bench_modinv, bench_u128); 114 | criterion_main!(benches); 115 | -------------------------------------------------------------------------------- /src/barrett.rs: -------------------------------------------------------------------------------- 1 | //! All methods that use a pre-computed inverse of the modulus are contained in this module, 2 | //! as they share the idea of Barrett reduction. 3 | 4 | // Version 1: Vanilla Barrett reduction (for x mod n, x < n^2) 5 | // - Choose k = ceil(log2(n)) 6 | // - Precompute r = floor(2^(k+1)/n) 7 | // - t = x - floor(x*r/2^(k+1)) * n 8 | // - if t > n, t -= n 9 | // - return t 10 | // 11 | // Version 2: Full-width Barrett reduction 12 | // - Similar to version 1 but supports n up to full width 13 | // - Ref (u128): 14 | // 15 | // Version 3: Floating-point Barrett reduction 16 | // - Using floating point to store r 17 | // - Ref: 18 | // 19 | // Version 4: "Improved division by invariant integers" by Möller and Granlund 20 | // - Ref: 21 | // 22 | // 23 | // Comparison between vanilla Barrett reduction and Montgomery reduction: 24 | // - Barrett reduction requires one 2k-by-k-bit and one k-by-k-bit multiplication, while Montgomery only involves two k-by-k multiplications 25 | // - An extra conversion step is required for Montgomery form to get a normal integer 26 | // (Reference: ) 27 | // 28 | // The latter two versions are efficient and practical for use. 29 | 30 | use crate::reduced::{impl_reduced_binary_pow, Vanilla}; 31 | use crate::{DivExact, ModularUnaryOps, Reducer}; 32 | 33 | /// Divide a Word by a prearranged divisor. 34 | /// 35 | /// Granlund, Montgomery, "Division by Invariant Integers using Multiplication", 36 | /// Algorithm 4.1. 37 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 38 | pub struct PreMulInv1by1<T> { 39 | // Let n = ceil(log_2(divisor)) 40 | // 2^(n-1) < divisor <= 2^n 41 | // m = floor(B * 2^n / divisor) + 1 - B, where B = 2^N 42 | m: T, 43 | 44 | // shift = n - 1 45 | shift: u32, 46 | } 47 | 48 | macro_rules!
impl_premulinv_1by1_for { 49 | ($T:ty) => { 50 | impl PreMulInv1by1<$T> { 51 | pub const fn new(divisor: $T) -> Self { 52 | debug_assert!(divisor > 1); 53 | 54 | // n = ceil(log2(divisor)) 55 | let n = <$T>::BITS - (divisor - 1).leading_zeros(); 56 | 57 | /* Calculate: 58 | * m = floor(B * 2^n / divisor) + 1 - B 59 | * m >= B + 1 - B >= 1 60 | * m <= B * 2^n / (2^(n-1) + 1) + 1 - B 61 | * = (B * 2^n + 2^(n-1) + 1) / (2^(n-1) + 1) - B 62 | * = B * (2^n + 2^(n-1-N) + 2^-N) / (2^(n-1)+1) - B 63 | * < B * (2^n + 2^1) / (2^(n-1)+1) - B 64 | * = B 65 | * So m fits in a Word. 66 | * 67 | * Note: 68 | * divisor * (B + m) = divisor * floor(B * 2^n / divisor + 1) 69 | * = B * 2^n + k, 1 <= k <= divisor 70 | */ 71 | 72 | // m = floor(B * (2^n-1 - (divisor-1)) / divisor) + 1 73 | let (lo, _hi) = split(merge(0, ones(n) - (divisor - 1)) / extend(divisor)); 74 | debug_assert!(_hi == 0); 75 | Self { 76 | shift: n - 1, 77 | m: lo + 1, 78 | } 79 | } 80 | 81 | /// Returns (a / divisor, a % divisor) 82 | #[inline] 83 | pub const fn div_rem(&self, a: $T, d: $T) -> ($T, $T) { 84 | // q = floor( (B + m) * a / (B * 2^n) ) 85 | /* 86 | * Remember that divisor * (B + m) = B * 2^n + k, 1 <= k <= 2^n 87 | * 88 | * (B + m) * a / (B * 2^n) 89 | * = a / divisor * (B * 2^n + k) / (B * 2^n) 90 | * = a / divisor + k * a / (divisor * B * 2^n) 91 | * On one hand, this is >= a / divisor 92 | * On the other hand, this is: 93 | * <= a / divisor + 2^n * (B-1) / (2^n * B) / divisor 94 | * < (a + 1) / divisor 95 | * 96 | * Therefore the floor is always the exact quotient. 97 | */ 98 | 99 | // t = m * a / B 100 | let (_, t) = split(wmul(self.m, a)); 101 | // q = (t + a) / 2^n = (t + (a - t)/2) / 2^(n-1) 102 | let q = (t + ((a - t) >> 1)) >> self.shift; 103 | let r = a - q * d; 104 | (q, r) 105 | } 106 | } 107 | 108 | impl DivExact<$T, PreMulInv1by1<$T>> for $T { 109 | type Output = $T; 110 | 111 | #[inline] 112 | fn div_exact(self, d: $T, pre: &PreMulInv1by1<$T>) -> Option<Self::Output> { 113 | let (q, r) = pre.div_rem(self, d); 114 | if r == 0 { 115 | Some(q) 116 | } else { 117 | None 118 | } 119 | } 120 | } 121 | }; 122 | } 123 | 124 | /// Divide a DoubleWord by a prearranged divisor. 125 | /// 126 | /// Assumes quotient fits in a Word. 127 | /// 128 | /// Möller, Granlund, "Improved division by invariant integers", Algorithm 4. 129 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 130 | pub struct Normalized2by1Divisor<T> { 131 | // Normalized (top bit must be set). 132 | divisor: T, 133 | 134 | // floor((B^2 - 1) / divisor) - B, where B = 2^T::BITS 135 | m: T, 136 | } 137 | 138 | macro_rules! impl_normdiv_2by1_for { 139 | ($T:ty, $D:ty) => { 140 | impl Normalized2by1Divisor<$T> { 141 | /// Calculate the inverse m > 0 of a normalized divisor (fitting in a word), such that 142 | /// 143 | /// (m + B) * divisor = B^2 - k for some 1 <= k <= divisor 144 | /// 145 | #[inline] 146 | pub const fn invert_word(divisor: $T) -> $T { 147 | let (m, _hi) = split(<$D>::MAX / extend(divisor)); 148 | debug_assert!(_hi == 1); 149 | m 150 | } 151 | 152 | /// Initialize from a given normalized divisor.
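/// (A recap of the contract from `invert_word` above, to make the intent of the
/// precomputation explicit: with B = 2^Word::BITS and a normalized divisor d, the stored
/// m satisfies (m + B) * d = B^2 - k for some 1 <= k <= d, which is exactly what lets
/// `div_rem_2by1` below recover the quotient from a single widening multiplication plus
/// cheap fix-up steps.)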
153 | /// 154 | /// The divisor must have its top bit set 155 | #[inline] 156 | pub const fn new(divisor: $T) -> Self { 157 | assert!(divisor.leading_zeros() == 0); 158 | Self { 159 | divisor, 160 | m: Self::invert_word(divisor), 161 | } 162 | } 163 | 164 | /// Returns (a / divisor, a % divisor) 165 | #[inline] 166 | pub const fn div_rem_1by1(&self, a: $T) -> ($T, $T) { 167 | if a < self.divisor { 168 | (0, a) 169 | } else { 170 | (1, a - self.divisor) // because self.divisor is normalized 171 | } 172 | } 173 | 174 | /// Returns (a / divisor, a % divisor) 175 | /// The result must fit in a single word. 176 | #[inline] 177 | pub const fn div_rem_2by1(&self, a: $D) -> ($T, $T) { 178 | let (a_lo, a_hi) = split(a); 179 | debug_assert!(a_hi < self.divisor); 180 | 181 | // Approximate quotient is (m + B) * a / B^2 ~= (m * a/B + a)/B. 182 | // This is q1 below. 183 | // This doesn't overflow because a_hi < self.divisor <= Word::MAX. 184 | let (q0, q1) = split(wmul(self.m, a_hi) + a); 185 | 186 | // q = q1 + 1 is our first approximation, but calculate mod B. 187 | // r = a - q * d 188 | let q = q1.wrapping_add(1); 189 | let r = a_lo.wrapping_sub(q.wrapping_mul(self.divisor)); 190 | 191 | /* Theorem: max(-d, q0+1-B) <= r < max(B-d, q0) 192 | * Proof: 193 | * r = a - q * d = a - q1 * d - d 194 | * = a - (q1 * B + q0 - q0) * d/B - d 195 | * = a - (m * a_hi + a - q0) * d/B - d 196 | * = a - ((m+B) * a_hi + a_lo - q0) * d/B - d 197 | * = a - ((B^2-k)/d * a_hi + a_lo - q0) * d/B - d 198 | * = a - B * a_hi + (a_hi * k - a_lo * d + q0 * d) / B - d 199 | * = (a_hi * k + a_lo * (B - d) + q0 * d) / B - d 200 | * 201 | * r >= q0 * d / B - d 202 | * r >= -d 203 | * r >= d/B (q0 - B) > q0-B 204 | * r >= max(-d, q0+1-B) 205 | * 206 | * r < (d * d + B * (B-d) + q0 * d) / B - d 207 | * = (B-d)^2 / B + q0 * d / B 208 | * = (1 - d/B) * (B-d) + (d/B) * q0 209 | * <= max(B-d, q0) 210 | * QED 211 | */ 212 | 213 | // if r mod B > q0 { q -= 1; r += d; } 214 | // 215 | // Consider two cases: 216 | // a) r >= 0: 217 | // Then r = r mod B > q0, hence r < B-d. Adding d will not overflow r. 218 | // b) r < 0: 219 | // Then r mod B = r + B > q0, and r >= -d, so adding d will make r non-negative. 220 | // In either case, this will result in 0 <= r < B. 221 | 222 | // In a branch-free way: 223 | // decrease = 0xfff...f (i.e. -1) if r mod B > q0, 0 otherwise. 224 | let (_, decrease) = split(extend(q0).wrapping_sub(extend(r))); 225 | let mut q = q.wrapping_add(decrease); 226 | let mut r = r.wrapping_add(decrease & self.divisor); 227 | 228 | // At this point 0 <= r < B, i.e. 0 <= r < 2d. 229 | // the following fix step is unlikely to happen 230 | if r >= self.divisor { 231 | q += 1; 232 | r -= self.divisor; 233 | } 234 | 235 | (q, r) 236 | } 237 | } 238 | }; 239 | } 240 | 241 | /// A wrapper of [Normalized2by1Divisor] that can be used as a [Reducer] 242 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 243 | pub struct PreMulInv2by1<T> { 244 | div: Normalized2by1Divisor<T>, 245 | shift: u32, 246 | } 247 | 248 | impl<T> PreMulInv2by1<T> { 249 | #[inline] 250 | pub const fn divider(&self) -> &Normalized2by1Divisor<T> { 251 | &self.div 252 | } 253 | #[inline] 254 | pub const fn shift(&self) -> u32 { 255 | self.shift 256 | } 257 | } 258 | 259 | macro_rules!
impl_premulinv_2by1_reducer_for { 260 | ($T:ty) => { 261 | impl PreMulInv2by1<$T> { 262 | #[inline] 263 | pub const fn new(divisor: $T) -> Self { 264 | let shift = divisor.leading_zeros(); 265 | let div = Normalized2by1Divisor::<$T>::new(divisor << shift); 266 | Self { div, shift } 267 | } 268 | 269 | /// Get the **normalized** divisor. 270 | #[inline] 271 | pub const fn divisor(&self) -> $T { 272 | self.div.divisor 273 | } 274 | } 275 | 276 | impl Reducer<$T> for PreMulInv2by1<$T> { 277 | #[inline] 278 | fn new(m: &$T) -> Self { 279 | PreMulInv2by1::<$T>::new(*m) 280 | } 281 | #[inline] 282 | fn transform(&self, target: $T) -> $T { 283 | if self.shift == 0 { 284 | self.div.div_rem_1by1(target).1 285 | } else { 286 | self.div.div_rem_2by1(extend(target) << self.shift).1 287 | } 288 | } 289 | #[inline] 290 | fn check(&self, target: &$T) -> bool { 291 | *target < self.div.divisor && target & ones(self.shift) == 0 292 | } 293 | #[inline] 294 | fn residue(&self, target: $T) -> $T { 295 | target >> self.shift 296 | } 297 | #[inline] 298 | fn modulus(&self) -> $T { 299 | self.div.divisor >> self.shift 300 | } 301 | #[inline] 302 | fn is_zero(&self, target: &$T) -> bool { 303 | *target == 0 304 | } 305 | 306 | #[inline(always)] 307 | fn add(&self, lhs: &$T, rhs: &$T) -> $T { 308 | Vanilla::<$T>::add(&self.div.divisor, *lhs, *rhs) 309 | } 310 | #[inline(always)] 311 | fn dbl(&self, target: $T) -> $T { 312 | Vanilla::<$T>::dbl(&self.div.divisor, target) 313 | } 314 | #[inline(always)] 315 | fn sub(&self, lhs: &$T, rhs: &$T) -> $T { 316 | Vanilla::<$T>::sub(&self.div.divisor, *lhs, *rhs) 317 | } 318 | #[inline(always)] 319 | fn neg(&self, target: $T) -> $T { 320 | Vanilla::<$T>::neg(&self.div.divisor, target) 321 | } 322 | 323 | #[inline(always)] 324 | fn inv(&self, target: $T) -> Option<$T> { 325 | self.residue(target) 326 | .invm(&self.modulus()) 327 | .map(|v| v << self.shift) 328 | } 329 | #[inline] 330 | fn mul(&self, lhs: &$T, rhs: &$T) -> $T { 331 | self.div.div_rem_2by1(wmul(lhs >> self.shift, *rhs)).1 332 | } 333 | #[inline] 334 | fn sqr(&self, target: $T) -> $T { 335 | self.div.div_rem_2by1(wsqr(target) >> self.shift).1 336 | } 337 | 338 | impl_reduced_binary_pow!($T); 339 | } 340 | }; 341 | } 342 | 343 | /// Divide a 3-Word by a prearranged DoubleWord divisor. 344 | /// 345 | /// Assumes quotient fits in a Word. 346 | /// 347 | /// Möller, Granlund, "Improved division by invariant integers", 348 | /// Algorithm 5. 349 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 350 | pub struct Normalized3by2Divisor<T, D> { 351 | // Top bit must be 1. 352 | divisor: D, 353 | 354 | // floor ((B^3 - 1) / divisor) - B, where B = 2^WORD_BITS 355 | m: T, 356 | } 357 | 358 | macro_rules! impl_normdiv_3by2_for { 359 | ($T:ty, $D:ty) => { 360 | impl Normalized3by2Divisor<$T, $D> { 361 | /// Calculate the inverse m > 0 of a normalized divisor (fitting in a DoubleWord), such that 362 | /// 363 | /// (m + B) * divisor = B^3 - k for some 1 <= k <= divisor 364 | /// 365 | /// Möller, Granlund, "Improved division by invariant integers", Algorithm 6.
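/// (Sketch of the steps implemented below: v is seeded with the single-word inverse of
/// the high word d1 via `invert_word`, then decremented at most twice while folding in
/// the low word d0, and at most twice more while folding in the cross product v * d0,
/// so that the invariant above holds for the full double-word divisor.)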
366 | #[inline] 367 | pub const fn invert_double_word(divisor: $D) -> $T { 368 | let (d0, d1) = split(divisor); 369 | let mut v = Normalized2by1Divisor::<$T>::invert_word(d1); 370 | // then B^2 - d1 <= (B + v)d1 < B^2 371 | 372 | let (mut p, c) = d1.wrapping_mul(v).overflowing_add(d0); 373 | if c { 374 | v -= 1; 375 | if p >= d1 { 376 | v -= 1; 377 | p -= d1; 378 | } 379 | p = p.wrapping_sub(d1); 380 | } 381 | // then B^2 - d1 <= (B + v)d1 + d0 < B^2 382 | 383 | let (t0, t1) = split(extend(v) * extend(d0)); 384 | let (p, c) = p.overflowing_add(t1); 385 | if c { 386 | v -= 1; 387 | if merge(t0, p) >= divisor { 388 | v -= 1; 389 | } 390 | } 391 | 392 | v 393 | } 394 | 395 | /// Initialize from a given normalized divisor. 396 | /// 397 | /// The divisor must have its top bit set 398 | #[inline] 399 | pub const fn new(divisor: $D) -> Self { 400 | assert!(divisor.leading_zeros() == 0); 401 | Self { 402 | divisor, 403 | m: Self::invert_double_word(divisor), 404 | } 405 | } 406 | 407 | #[inline] 408 | pub const fn div_rem_2by2(&self, a: $D) -> ($D, $D) { 409 | if a < self.divisor { 410 | (0, a) 411 | } else { 412 | (1, a - self.divisor) // because self.divisor is normalized 413 | } 414 | } 415 | 416 | /// The input a is arranged as (lo, mi & hi) 417 | /// The output is (a / divisor, a % divisor) 418 | pub const fn div_rem_3by2(&self, a_lo: $T, a_hi: $D) -> ($T, $D) { 419 | debug_assert!(a_hi < self.divisor); 420 | let (a1, a2) = split(a_hi); 421 | let (d0, d1) = split(self.divisor); 422 | 423 | // This doesn't overflow because a2 <= self.divisor / B <= Word::MAX. 424 | let (q0, q1) = split(wmul(self.m, a2) + a_hi); 425 | let r1 = a1.wrapping_sub(q1.wrapping_mul(d1)); 426 | let t = wmul(d0, q1); 427 | let r = merge(a_lo, r1).wrapping_sub(t).wrapping_sub(self.divisor); 428 | 429 | // The first guess of quotient is q1 + 1 430 | // if r1 >= q0 { r += d; } else { q1 += 1; } 431 | // In a branch-free way: 432 | // decrease = 0 if r1 >= q0, = 0xfff...f (i.e. -1) otherwise 433 | let (_, r1) = split(r); 434 | let (_, decrease) = split(extend(r1).wrapping_sub(extend(q0))); 435 | let mut q1 = q1.wrapping_sub(decrease); 436 | let mut r = r.wrapping_add(merge(!decrease, !decrease) & self.divisor); 437 | 438 | // the following fix step is unlikely to happen 439 | if r >= self.divisor { 440 | q1 += 1; 441 | r -= self.divisor; 442 | } 443 | 444 | (q1, r) 445 | } 446 | 447 | /// Divide a 4-word number by a double-word divisor 448 | /// 449 | /// The output is (a / divisor, a % divisor) 450 | pub const fn div_rem_4by2(&self, a_lo: $D, a_hi: $D) -> ($D, $D) { 451 | let (a0, a1) = split(a_lo); 452 | let (q1, r1) = self.div_rem_3by2(a1, a_hi); 453 | let (q0, r0) = self.div_rem_3by2(a0, r1); 454 | (merge(q0, q1), r0) 455 | } 456 | } 457 | }; 458 | } 459 | 460 | /// A wrapper of [Normalized3by2Divisor] that can be used as a [Reducer] 461 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 462 | pub struct PreMulInv3by2<T, D> { 463 | div: Normalized3by2Divisor<T, D>, 464 | shift: u32, 465 | } 466 | 467 | impl<T, D> PreMulInv3by2<T, D> { 468 | #[inline] 469 | pub const fn divider(&self) -> &Normalized3by2Divisor<T, D> { 470 | &self.div 471 | } 472 | #[inline] 473 | pub const fn shift(&self) -> u32 { 474 | self.shift 475 | } 476 | } 477 | 478 | macro_rules!
impl_premulinv_3by2_reducer_for { 479 | ($T:ty, $D:ty) => { 480 | impl PreMulInv3by2<$T, $D> { 481 | #[inline] 482 | pub const fn new(divisor: $D) -> Self { 483 | let shift = divisor.leading_zeros(); 484 | let div = Normalized3by2Divisor::<$T, $D>::new(divisor << shift); 485 | Self { div, shift } 486 | } 487 | 488 | /// Get the **normalized** divisor. 489 | #[inline] 490 | pub const fn divisor(&self) -> $D { 491 | self.div.divisor 492 | } 493 | } 494 | 495 | impl Reducer<$D> for PreMulInv3by2<$T, $D> { 496 | #[inline] 497 | fn new(m: &$D) -> Self { 498 | assert!(*m > <$T>::MAX as $D); 499 | let shift = m.leading_zeros(); 500 | let div = Normalized3by2Divisor::<$T, $D>::new(m << shift); 501 | Self { div, shift } 502 | } 503 | #[inline] 504 | fn transform(&self, target: $D) -> $D { 505 | if self.shift == 0 { 506 | self.div.div_rem_2by2(target).1 507 | } else { 508 | let (lo, hi) = split(target); 509 | let (n0, carry) = split(extend(lo) << self.shift); 510 | let n12 = (extend(hi) << self.shift) | extend(carry); 511 | self.div.div_rem_3by2(n0, n12).1 512 | } 513 | } 514 | #[inline] 515 | fn check(&self, target: &$D) -> bool { 516 | *target < self.div.divisor && split(*target).0 & ones(self.shift) == 0 517 | } 518 | #[inline] 519 | fn residue(&self, target: $D) -> $D { 520 | target >> self.shift 521 | } 522 | #[inline] 523 | fn modulus(&self) -> $D { 524 | self.div.divisor >> self.shift 525 | } 526 | #[inline] 527 | fn is_zero(&self, target: &$D) -> bool { 528 | *target == 0 529 | } 530 | 531 | #[inline(always)] 532 | fn add(&self, lhs: &$D, rhs: &$D) -> $D { 533 | Vanilla::<$D>::add(&self.div.divisor, *lhs, *rhs) 534 | } 535 | #[inline(always)] 536 | fn dbl(&self, target: $D) -> $D { 537 | Vanilla::<$D>::dbl(&self.div.divisor, target) 538 | } 539 | #[inline(always)] 540 | fn sub(&self, lhs: &$D, rhs: &$D) -> $D { 541 | Vanilla::<$D>::sub(&self.div.divisor, *lhs, *rhs) 542 | } 543 | #[inline(always)] 544 | fn neg(&self, target: $D) -> $D { 545 | Vanilla::<$D>::neg(&self.div.divisor, target) 546 | } 547 | 548 | #[inline(always)] 549 | fn inv(&self, target: $D) -> Option<$D> { 550 | self.residue(target) 551 | .invm(&self.modulus()) 552 | .map(|v| v << self.shift) 553 | } 554 | #[inline] 555 | fn mul(&self, lhs: &$D, rhs: &$D) -> $D { 556 | let prod = DoubleWordModule::wmul(lhs >> self.shift, *rhs); 557 | let (lo, hi) = DoubleWordModule::split(prod); 558 | self.div.div_rem_4by2(lo, hi).1 559 | } 560 | #[inline] 561 | fn sqr(&self, target: $D) -> $D { 562 | let prod = DoubleWordModule::wsqr(target) >> self.shift; 563 | let (lo, hi) = DoubleWordModule::split(prod); 564 | self.div.div_rem_4by2(lo, hi).1 565 | } 566 | 567 | impl_reduced_binary_pow!($D); 568 | } 569 | }; 570 | } 571 | 572 | macro_rules! 
collect_impls { 573 | ($T:ident, $ns:ident) => { 574 | mod $ns { 575 | use super::*; 576 | use crate::word::$T::*; 577 | 578 | impl_premulinv_1by1_for!(Word); 579 | impl_normdiv_2by1_for!(Word, DoubleWord); 580 | impl_premulinv_2by1_reducer_for!(Word); 581 | impl_normdiv_3by2_for!(Word, DoubleWord); 582 | impl_premulinv_3by2_reducer_for!(Word, DoubleWord); 583 | } 584 | }; 585 | } 586 | collect_impls!(u8, u8_impl); 587 | collect_impls!(u16, u16_impl); 588 | collect_impls!(u32, u32_impl); 589 | collect_impls!(u64, u64_impl); 590 | collect_impls!(usize, usize_impl); 591 | 592 | #[cfg(test)] 593 | mod tests { 594 | use super::*; 595 | use crate::reduced::tests::ReducedTester; 596 | use rand::prelude::*; 597 | 598 | #[test] 599 | fn test_mul_inv_1by1() { 600 | type Word = u64; 601 | let mut rng = StdRng::seed_from_u64(1); 602 | for _ in 0..400000 { 603 | let d_bits = rng.gen_range(2..=Word::BITS); 604 | let max_d = Word::MAX >> (Word::BITS - d_bits); 605 | let d = rng.gen_range(max_d / 2 + 1..=max_d); 606 | let fast_div = PreMulInv1by1::<Word>::new(d); 607 | let n = rng.gen(); 608 | let (q, r) = fast_div.div_rem(n, d); 609 | assert_eq!(q, n / d); 610 | assert_eq!(r, n % d); 611 | 612 | if r == 0 { 613 | assert_eq!(n.div_exact(d, &fast_div), Some(q)); 614 | } else { 615 | assert_eq!(n.div_exact(d, &fast_div), None); 616 | } 617 | } 618 | } 619 | 620 | #[test] 621 | fn test_mul_inv_2by1() { 622 | type Word = u64; 623 | type Divider = Normalized2by1Divisor<Word>; 624 | use crate::word::u64::*; 625 | 626 | let fast_div = Divider::new(Word::MAX); 627 | assert_eq!(fast_div.div_rem_2by1(0), (0, 0)); 628 | 629 | let mut rng = StdRng::seed_from_u64(1); 630 | for _ in 0..200000 { 631 | let d = rng.gen_range(Word::MAX / 2 + 1..=Word::MAX); 632 | let q = rng.gen(); 633 | let r = rng.gen_range(0..d); 634 | let (a0, a1) = split(wmul(q, d) + extend(r)); 635 | let fast_div = Divider::new(d); 636 | assert_eq!(fast_div.div_rem_2by1(merge(a0, a1)), (q, r)); 637 | } 638 | } 639 | 640 | #[test] 641 | fn test_mul_inv_3by2() { 642 | type Word = u64; 643 | type DoubleWord = u128; 644 | type Divider = Normalized3by2Divisor<Word, DoubleWord>; 645 | use crate::word::u64::*; 646 | 647 | let d = DoubleWord::MAX; 648 | let fast_div = Divider::new(d); 649 | assert_eq!(fast_div.div_rem_3by2(0, 0), (0, 0)); 650 | 651 | let mut rng = StdRng::seed_from_u64(1); 652 | for _ in 0..100000 { 653 | let d = rng.gen_range(DoubleWord::MAX / 2 + 1..=DoubleWord::MAX); 654 | let r = rng.gen_range(0..d); 655 | let q = rng.gen(); 656 | 657 | let (d0, d1) = split(d); 658 | let (r0, r1) = split(r); 659 | let (a0, c) = split(wmul(q, d0) + extend(r0)); 660 | let (a1, a2) = split(wmul(q, d1) + extend(r1) + extend(c)); 661 | let a12 = merge(a1, a2); 662 | 663 | let fast_div = Divider::new(d); 664 | assert_eq!( 665 | fast_div.div_rem_3by2(a0, a12), 666 | (q, r), 667 | "failed at {:?} / {}", 668 | (a0, a12), 669 | d 670 | ); 671 | } 672 | } 673 | 674 | #[test] 675 | fn test_mul_inv_4by2() { 676 | type Word = u64; 677 | type DoubleWord = u128; 678 | type Divider = Normalized3by2Divisor<Word, DoubleWord>; 679 | use crate::word::u128::*; 680 | 681 | let mut rng = StdRng::seed_from_u64(1); 682 | for _ in 0..20000 { 683 | let d = rng.gen_range(DoubleWord::MAX / 2 + 1..=DoubleWord::MAX); 684 | let q = rng.gen(); 685 | let r = rng.gen_range(0..d); 686 | let (a_lo, a_hi) = split(wmul(q, d) + r as DoubleWord); 687 | let fast_div = Divider::new(d); 688 | assert_eq!(fast_div.div_rem_4by2(a_lo, a_hi), (q, r)); 689 | } 690 | } 691 | 692 | #[test] 693 | fn test_2by1_against_modops() { 694 | for _ in 0..10 {
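// A sketch of what this exercises: the `ReducedTester` harness (from `reduced.rs`)
// pits the `PreMulInv2by1` reducer against the plain modular-op implementations on
// random inputs, once per supported word size; the integer argument is simply
// forwarded to the harness (`0` here, `2` in the 3-by-2 test below).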
695 | ReducedTester::<u8>::test_against_modops::<PreMulInv2by1<u8>>(0); 696 | ReducedTester::<u16>::test_against_modops::<PreMulInv2by1<u16>>(0); 697 | ReducedTester::<u32>::test_against_modops::<PreMulInv2by1<u32>>(0); 698 | ReducedTester::<u64>::test_against_modops::<PreMulInv2by1<u64>>(0); 699 | // ReducedTester::<u128>::test_against_modops::<PreMulInv2by1<u128>>(); 700 | ReducedTester::<usize>::test_against_modops::<PreMulInv2by1<usize>>(0); 701 | } 702 | } 703 | 704 | #[test] 705 | fn test_3by2_against_modops() { 706 | for _ in 0..10 { 707 | ReducedTester::<u16>::test_against_modops::<PreMulInv3by2<u8, u16>>(2); 708 | ReducedTester::<u32>::test_against_modops::<PreMulInv3by2<u16, u32>>(2); 709 | ReducedTester::<u64>::test_against_modops::<PreMulInv3by2<u32, u64>>(2); 710 | ReducedTester::<u128>::test_against_modops::<PreMulInv3by2<u64, u128>>(2); 711 | } 712 | } 713 | } 714 | -------------------------------------------------------------------------------- /src/bigint.rs: -------------------------------------------------------------------------------- 1 | use crate::{ModularAbs, ModularCoreOps, ModularPow, ModularSymbols, ModularUnaryOps}; 2 | use core::convert::TryInto; 3 | use num_integer::Integer; 4 | use num_traits::{One, ToPrimitive, Zero}; 5 | 6 | // Efficient implementations for bigints can be found in the "Handbook of Applied Cryptography" 7 | // Reference: https://cacr.uwaterloo.ca/hac/about/chap14.pdf 8 | 9 | // Forward modular operations to ref by ref 10 | macro_rules! impl_mod_ops_by_ref { 11 | ($T:ty) => { 12 | // core ops 13 | impl ModularCoreOps<$T, &$T> for &$T { 14 | type Output = $T; 15 | #[inline] 16 | fn addm(self, rhs: $T, m: &$T) -> $T { 17 | self.addm(&rhs, &m) 18 | } 19 | #[inline] 20 | fn subm(self, rhs: $T, m: &$T) -> $T { 21 | self.subm(&rhs, &m) 22 | } 23 | #[inline] 24 | fn mulm(self, rhs: $T, m: &$T) -> $T { 25 | self.mulm(&rhs, &m) 26 | } 27 | } 28 | impl ModularCoreOps<&$T, &$T> for $T { 29 | type Output = $T; 30 | #[inline] 31 | fn addm(self, rhs: &$T, m: &$T) -> $T { 32 | (&self).addm(rhs, &m) 33 | } 34 | #[inline] 35 | fn subm(self, rhs: &$T, m: &$T) -> $T { 36 | (&self).subm(rhs, &m) 37 | } 38 | #[inline] 39 | fn mulm(self, rhs: &$T, m: &$T) -> $T { 40 | (&self).mulm(rhs, &m) 41 | } 42 | } 43 | impl ModularCoreOps<$T, &$T> for $T { 44 | type Output = $T; 45 | #[inline] 46 | fn addm(self, rhs: $T, m: &$T) -> $T { 47 | (&self).addm(&rhs, &m) 48 | } 49 | #[inline] 50 | fn subm(self, rhs: $T, m: &$T) -> $T { 51 | (&self).subm(&rhs, &m) 52 | } 53 | #[inline] 54 | fn mulm(self, rhs: $T, m: &$T) -> $T { 55 | (&self).mulm(&rhs, &m) 56 | } 57 | } 58 | 59 | // pow 60 | impl ModularPow<$T, &$T> for &$T { 61 | type Output = $T; 62 | #[inline] 63 | fn powm(self, exp: $T, m: &$T) -> $T { 64 | self.powm(&exp, &m) 65 | } 66 | } 67 | impl ModularPow<&$T, &$T> for $T { 68 | type Output = $T; 69 | #[inline] 70 | fn powm(self, exp: &$T, m: &$T) -> $T { 71 | (&self).powm(exp, &m) 72 | } 73 | } 74 | impl ModularPow<$T, &$T> for $T { 75 | type Output = $T; 76 | #[inline] 77 | fn powm(self, exp: $T, m: &$T) -> $T { 78 | (&self).powm(&exp, &m) 79 | } 80 | } 81 | 82 | // unary ops and symbols 83 | impl ModularUnaryOps<&$T> for $T { 84 | type Output = $T; 85 | #[inline] 86 | fn negm(self, m: &$T) -> $T { 87 | ModularUnaryOps::<&$T>::negm(&self, m) 88 | } 89 | #[inline] 90 | fn invm(self, m: &$T) -> Option<$T> { 91 | ModularUnaryOps::<&$T>::invm(&self, m) 92 | } 93 | #[inline] 94 | fn dblm(self, m: &$T) -> $T { 95 | ModularUnaryOps::<&$T>::dblm(&self, m) 96 | } 97 | #[inline] 98 | fn sqm(self, m: &$T) -> $T { 99 | ModularUnaryOps::<&$T>::sqm(&self, m) 100 | } 101 | } 102 | }; 103 | } 104 | 105 | #[cfg(feature = "num-bigint")] 106 | mod _num_bigint { 107 | use super::*; 108 | use num_bigint::{BigInt, BigUint}; 109 | use num_traits::Signed; 110
| 111 | impl ModularCoreOps<&BigUint, &BigUint> for &BigUint { 112 | type Output = BigUint; 113 | 114 | #[inline] 115 | fn addm(self, rhs: &BigUint, m: &BigUint) -> BigUint { 116 | (self + rhs) % m 117 | } 118 | fn subm(self, rhs: &BigUint, m: &BigUint) -> BigUint { 119 | let (lhs, rhs) = (self % m, rhs % m); 120 | if lhs >= rhs { 121 | lhs - rhs 122 | } else { 123 | m - (rhs - lhs) 124 | } 125 | } 126 | 127 | fn mulm(self, rhs: &BigUint, m: &BigUint) -> BigUint { 128 | let a = self % m; 129 | let b = rhs % m; 130 | 131 | if let Some(sm) = m.to_usize() { 132 | let sself = a.to_usize().unwrap(); 133 | let srhs = b.to_usize().unwrap(); 134 | return BigUint::from(sself.mulm(srhs, &sm)); 135 | } 136 | 137 | (a * b) % m 138 | } 139 | } 140 | 141 | impl ModularUnaryOps<&BigUint> for &BigUint { 142 | type Output = BigUint; 143 | #[inline] 144 | fn negm(self, m: &BigUint) -> BigUint { 145 | let x = self % m; 146 | if x.is_zero() { 147 | BigUint::zero() 148 | } else { 149 | m - x 150 | } 151 | } 152 | 153 | fn invm(self, m: &BigUint) -> Option<BigUint> { 154 | let x = if self >= m { self % m } else { self.clone() }; 155 | 156 | let (mut last_r, mut r) = (m.clone(), x); 157 | let (mut last_t, mut t) = (BigUint::zero(), BigUint::one()); 158 | 159 | while r > BigUint::zero() { 160 | let (quo, rem) = last_r.div_rem(&r); 161 | last_r = r; 162 | r = rem; 163 | 164 | let new_t = last_t.subm(&quo.mulm(&t, m), m); 165 | last_t = t; 166 | t = new_t; 167 | } 168 | 169 | // if r = gcd(self, m) > 1, then inverse doesn't exist 170 | if last_r > BigUint::one() { 171 | None 172 | } else { 173 | Some(last_t) 174 | } 175 | } 176 | 177 | #[inline] 178 | fn dblm(self, m: &BigUint) -> BigUint { 179 | let x = self % m; 180 | let d = x << 1; 181 | if &d >= m { 182 | d - m 183 | } else { 184 | d 185 | } 186 | } 187 | 188 | #[inline] 189 | fn sqm(self, m: &BigUint) -> BigUint { 190 | self.modpow(&BigUint::from(2u8), m) 191 | } 192 | } 193 | 194 | impl ModularPow<&BigUint, &BigUint> for &BigUint { 195 | type Output = BigUint; 196 | #[inline] 197 | fn powm(self, exp: &BigUint, m: &BigUint) -> BigUint { 198 | self.modpow(exp, m) 199 | } 200 | } 201 | 202 | impl ModularSymbols<&BigUint> for BigUint { 203 | #[inline] 204 | fn checked_legendre(&self, n: &BigUint) -> Option<i8> { 205 | let r = self.powm((n - 1u8) >> 1u8, n); 206 | if r.is_zero() { 207 | Some(0) 208 | } else if r.is_one() { 209 | Some(1) 210 | } else if &(r + 1u8) == n { 211 | Some(-1) 212 | } else { 213 | None 214 | } 215 | } 216 | 217 | fn checked_jacobi(&self, n: &BigUint) -> Option<i8> { 218 | if n.is_even() { 219 | return None; 220 | } 221 | if self.is_zero() { 222 | return Some(if n.is_one() { 1 } else { 0 }); 223 | } 224 | if self.is_one() { 225 | return Some(1); 226 | } 227 | 228 | let three = BigUint::from(3u8); 229 | let five = BigUint::from(5u8); 230 | let seven = BigUint::from(7u8); 231 | 232 | let mut a = self % n; 233 | let mut n = n.clone(); 234 | let mut t = 1; 235 | while a > BigUint::zero() { 236 | while a.is_even() { 237 | a >>= 1; 238 | if &n & &seven == three || &n & &seven == five { 239 | t *= -1; 240 | } 241 | } 242 | core::mem::swap(&mut a, &mut n); 243 | if (&a & &three) == three && (&n & &three) == three { 244 | t *= -1; 245 | } 246 | a %= &n; 247 | } 248 | Some(if n.is_one() { t } else { 0 }) 249 | } 250 | 251 | #[inline] 252 | fn kronecker(&self, n: &BigUint) -> i8 { 253 | if n.is_zero() { 254 | return if self.is_one() { 1 } else { 0 }; 255 | } 256 | if n.is_one() { 257 | return 1; 258 | } 259 | if n == &BigUint::from(2u8) { 260 | return if
self.is_even() { 261 | 0 262 | } else { 263 | let seven = BigUint::from(7u8); 264 | if (self & &seven).is_one() || self & &seven == seven { 265 | 1 266 | } else { 267 | -1 268 | } 269 | }; 270 | } 271 | 272 | let f = n.trailing_zeros().unwrap_or(0); 273 | let n = n >> f; 274 | let t1 = self.kronecker(&BigUint::from(2u8)); 275 | let t2 = self.jacobi(&n); 276 | t1.pow(f.try_into().unwrap()) * t2 277 | } 278 | } 279 | 280 | impl ModularSymbols<&BigInt> for BigInt { 281 | #[inline] 282 | fn checked_legendre(&self, n: &BigInt) -> Option<i8> { 283 | if n < &BigInt::one() { 284 | return None; 285 | } 286 | self.mod_floor(n) 287 | .magnitude() 288 | .checked_legendre(n.magnitude()) 289 | } 290 | 291 | fn checked_jacobi(&self, n: &BigInt) -> Option<i8> { 292 | if n < &BigInt::one() { 293 | return None; 294 | } 295 | self.mod_floor(n).magnitude().checked_jacobi(n.magnitude()) 296 | } 297 | 298 | #[inline] 299 | fn kronecker(&self, n: &BigInt) -> i8 { 300 | if n.is_negative() { 301 | if n.magnitude().is_one() { 302 | return if self.is_negative() { -1 } else { 1 }; 303 | } else { 304 | return self.kronecker(&-BigInt::one()) * self.kronecker(&-n); 305 | } 306 | } 307 | 308 | // n is positive from now on 309 | let n = n.magnitude(); 310 | if n.is_zero() { 311 | return if self.is_one() { 1 } else { 0 }; 312 | } 313 | if n.is_one() { 314 | return 1; 315 | } 316 | if n == &BigUint::from(2u8) { 317 | return if self.is_even() { 318 | 0 319 | } else { 320 | let eight = BigInt::from(8u8); 321 | if (self.mod_floor(&eight)).is_one() 322 | || self.mod_floor(&eight) == BigInt::from(7u8) 323 | { 324 | 1 325 | } else { 326 | -1 327 | } 328 | }; 329 | } 330 | 331 | let f = n.trailing_zeros().unwrap_or(0); 332 | let n = n >> f; 333 | let t1 = self.kronecker(&BigInt::from(2u8)); 334 | let t2 = self.jacobi(&n.into()); 335 | t1.pow(f.try_into().unwrap()) * t2 336 | } 337 | } 338 | 339 | impl_mod_ops_by_ref!(BigUint); 340 | 341 | impl ModularAbs<BigUint> for BigInt { 342 | fn absm(self, m: &BigUint) -> BigUint { 343 | if self.is_negative() { 344 | self.magnitude().negm(m) 345 | } else { 346 | self.magnitude() % m 347 | } 348 | } 349 | } 350 | 351 | #[cfg(test)] 352 | mod tests { 353 | use super::*; 354 | use rand::random; 355 | 356 | const NRANDOM: u32 = 10; // number of random tests to run 357 | 358 | #[test] 359 | fn basic_tests() { 360 | for _ in 0..NRANDOM { 361 | let a = random::<u128>(); 362 | let ra = &BigUint::from(a); 363 | let b = random::<u128>(); 364 | let rb = &BigUint::from(b); 365 | let m = random::<u128>() | 1; 366 | let rm = &BigUint::from(m); 367 | assert_eq!(ra.addm(rb, rm), (ra + rb) % rm); 368 | assert_eq!(ra.mulm(rb, rm), (ra * rb) % rm); 369 | 370 | let a = random::<u128>(); 371 | let ra = &BigUint::from(a); 372 | let e = random::<u8>(); 373 | let re = &BigUint::from(e); 374 | let m = random::<u128>() | 1; 375 | let rm = &BigUint::from(m); 376 | assert_eq!(ra.powm(re, rm), ra.pow(e as u32) % rm); 377 | } 378 | } 379 | 380 | #[test] 381 | fn test_against_prim() { 382 | for _ in 0..NRANDOM { 383 | let a = random::<u128>(); 384 | let ra = &BigUint::from(a); 385 | let b = random::<u128>(); 386 | let rb = &BigUint::from(b); 387 | let m = random::<u128>(); 388 | let rm = &BigUint::from(m); 389 | assert_eq!(ra.addm(rb, rm), a.addm(b, &m).into()); 390 | assert_eq!(ra.subm(rb, rm), a.subm(b, &m).into()); 391 | assert_eq!(ra.mulm(rb, rm), a.mulm(b, &m).into()); 392 | assert_eq!(ra.negm(rm), a.negm(&m).into()); 393 | assert_eq!(ra.invm(rm), a.invm(&m).map(|v| v.into())); 394 | assert_eq!(ra.checked_legendre(rm), a.checked_legendre(&m)); 395 | assert_eq!(ra.checked_jacobi(rm),
a.checked_jacobi(&m));
396 |             assert_eq!(ra.kronecker(rm), a.kronecker(&m));
397 | 
398 |             let e = random::();
399 |             let re = &BigUint::from(e);
400 |             assert_eq!(ra.powm(re, rm), a.powm(e as u128, &m).into());
401 | 
402 |             // signed integers
403 |             let a = random::();
404 |             let ra = &BigInt::from(a);
405 |             let m = random::();
406 |             let rm = &BigInt::from(m);
407 |             assert_eq!(ra.checked_legendre(rm), a.checked_legendre(&m));
408 |             assert_eq!(ra.checked_jacobi(rm), a.checked_jacobi(&m));
409 |             assert_eq!(ra.kronecker(rm), a.kronecker(&m));
410 |         }
411 |     }
412 | }
413 | }
414 | 
--------------------------------------------------------------------------------
/src/double.rs:
--------------------------------------------------------------------------------
1 | //! This module implements a double-width integer type based on the largest built-in integer (u128).
2 | //! Part of the optimization comes from the `ethnum` and `zkp-u256` crates.
3 | 
4 | use core::ops::*;
5 | 
6 | /// Alias of the built-in integer type with maximum width (currently [u128])
7 | #[allow(non_camel_case_types)]
8 | pub type umax = u128;
9 | 
10 | const HALF_BITS: u32 = umax::BITS / 2;
11 | 
12 | // Split umax into hi and lo parts. It's critical to use inline here
13 | #[inline(always)]
14 | const fn split(v: umax) -> (umax, umax) {
15 |     (v >> HALF_BITS, v & (umax::MAX >> HALF_BITS))
16 | }
17 | 
18 | #[inline(always)]
19 | const fn div_rem(n: umax, d: umax) -> (umax, umax) {
20 |     (n / d, n % d)
21 | }
22 | 
23 | #[allow(non_camel_case_types)]
24 | #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
25 | /// A double-width integer type based on the largest built-in integer type [umax] (currently [u128]);
26 | /// supporting double-width operations on it is the only goal of this type.
27 | ///
28 | /// Although it can be regarded as u256, it's not as feature-rich as in other crates,
29 | /// since it's only designed to support this crate and a few others (noted in the comments).
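// Value represented: hi * 2^umax::BITS + lo.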
30 | pub struct udouble { 31 | /// Most significant part 32 | pub hi: umax, 33 | /// Least significant part 34 | pub lo: umax, 35 | } 36 | 37 | impl udouble { 38 | pub const MAX: Self = Self { 39 | lo: umax::MAX, 40 | hi: umax::MAX, 41 | }; 42 | 43 | //> (used in u128::addm) 44 | #[inline] 45 | pub const fn widening_add(lhs: umax, rhs: umax) -> Self { 46 | let (sum, carry) = lhs.overflowing_add(rhs); 47 | udouble { 48 | hi: carry as umax, 49 | lo: sum, 50 | } 51 | } 52 | 53 | /// Calculate multiplication of two [umax] integers with result represented in double width integer 54 | // equivalent to umul_ppmm, can be implemented efficiently with carrying_mul and widening_mul implemented (rust#85532) 55 | //> (used in u128::mulm, MersenneInt, Montgomery::::{reduce, mul}, num-order::NumHash) 56 | #[inline] 57 | pub const fn widening_mul(lhs: umax, rhs: umax) -> Self { 58 | let ((x1, x0), (y1, y0)) = (split(lhs), split(rhs)); 59 | 60 | let z2 = x1 * y1; 61 | let (c0, z0) = split(x0 * y0); // c0 <= umax::MAX - 1 62 | let (c1, z1) = split(x1 * y0 + c0); 63 | let z2 = z2 + c1; 64 | let (c1, z1) = split(x0 * y1 + z1); 65 | Self { 66 | hi: z2 + c1, 67 | lo: z0 | z1 << HALF_BITS, 68 | } 69 | } 70 | 71 | /// Optimized squaring function for [umax] integers 72 | //> (used in Montgomery::::{square}) 73 | #[inline] 74 | pub const fn widening_square(x: umax) -> Self { 75 | // the algorithm here is basically the same as widening_mul 76 | let (x1, x0) = split(x); 77 | 78 | let z2 = x1 * x1; 79 | let m = x1 * x0; 80 | let (c0, z0) = split(x0 * x0); 81 | let (c1, z1) = split(m + c0); 82 | let z2 = z2 + c1; 83 | let (c1, z1) = split(m + z1); 84 | Self { 85 | hi: z2 + c1, 86 | lo: z0 | z1 << HALF_BITS, 87 | } 88 | } 89 | 90 | //> (used in Montgomery::::reduce) 91 | #[inline] 92 | pub const fn overflowing_add(&self, rhs: Self) -> (Self, bool) { 93 | let (lo, carry) = self.lo.overflowing_add(rhs.lo); 94 | let (hi, of1) = self.hi.overflowing_add(rhs.hi); 95 | let (hi, of2) = hi.overflowing_add(carry as umax); 96 | (Self { lo, hi }, of1 || of2) 97 | } 98 | 99 | // double by double multiplication, listed here in case of future use 100 | #[allow(dead_code)] 101 | fn overflowing_mul(&self, rhs: Self) -> (Self, bool) { 102 | let c2 = self.hi != 0 && rhs.hi != 0; 103 | let Self { lo: z0, hi: c0 } = Self::widening_mul(self.lo, rhs.lo); 104 | let (z1x, c1x) = umax::overflowing_mul(self.lo, rhs.hi); 105 | let (z1y, c1y) = umax::overflowing_mul(self.hi, rhs.lo); 106 | let (z1z, c1z) = umax::overflowing_add(z1x, z1y); 107 | let (z1, c1) = z1z.overflowing_add(c0); 108 | (Self { hi: z1, lo: z0 }, c1x | c1y | c1z | c1 | c2) 109 | } 110 | 111 | /// Multiplication of double width and single width 112 | //> (used in num-order:NumHash) 113 | #[inline] 114 | pub const fn overflowing_mul1(&self, rhs: umax) -> (Self, bool) { 115 | let Self { lo: z0, hi: c0 } = Self::widening_mul(self.lo, rhs); 116 | let (z1, c1) = self.hi.overflowing_mul(rhs); 117 | let (z1, cs1) = z1.overflowing_add(c0); 118 | (Self { hi: z1, lo: z0 }, c1 | cs1) 119 | } 120 | 121 | /// Multiplication of double width and single width 122 | //> (used in Self::mul::) 123 | #[inline] 124 | pub fn checked_mul1(&self, rhs: umax) -> Option { 125 | let Self { lo: z0, hi: c0 } = Self::widening_mul(self.lo, rhs); 126 | let z1 = self.hi.checked_mul(rhs)?.checked_add(c0)?; 127 | Some(Self { hi: z1, lo: z0 }) 128 | } 129 | 130 | //> (used in num-order::NumHash) 131 | #[inline] 132 | pub fn checked_shl(self, rhs: u32) -> Option { 133 | if rhs < umax::BITS * 2 { 134 | Some(self << 
rhs) 135 | } else { 136 | None 137 | } 138 | } 139 | 140 | //> (not used yet) 141 | #[inline] 142 | pub fn checked_shr(self, rhs: u32) -> Option { 143 | if rhs < umax::BITS * 2 { 144 | Some(self >> rhs) 145 | } else { 146 | None 147 | } 148 | } 149 | } 150 | 151 | impl From for udouble { 152 | #[inline] 153 | fn from(v: umax) -> Self { 154 | Self { lo: v, hi: 0 } 155 | } 156 | } 157 | 158 | impl Add for udouble { 159 | type Output = udouble; 160 | 161 | // equivalent to add_ssaaaa 162 | #[inline] 163 | fn add(self, rhs: Self) -> Self::Output { 164 | let (lo, carry) = self.lo.overflowing_add(rhs.lo); 165 | let hi = self.hi + rhs.hi + carry as umax; 166 | Self { lo, hi } 167 | } 168 | } 169 | //> (used in Self::div_rem) 170 | impl Add for udouble { 171 | type Output = udouble; 172 | #[inline] 173 | fn add(self, rhs: umax) -> Self::Output { 174 | let (lo, carry) = self.lo.overflowing_add(rhs); 175 | let hi = if carry { self.hi + 1 } else { self.hi }; 176 | Self { lo, hi } 177 | } 178 | } 179 | impl AddAssign for udouble { 180 | #[inline] 181 | fn add_assign(&mut self, rhs: Self) { 182 | let (lo, carry) = self.lo.overflowing_add(rhs.lo); 183 | self.lo = lo; 184 | self.hi += rhs.hi + carry as umax; 185 | } 186 | } 187 | impl AddAssign for udouble { 188 | #[inline] 189 | fn add_assign(&mut self, rhs: umax) { 190 | let (lo, carry) = self.lo.overflowing_add(rhs); 191 | self.lo = lo; 192 | if carry { 193 | self.hi += 1 194 | } 195 | } 196 | } 197 | 198 | //> (used in test of Add) 199 | impl Sub for udouble { 200 | type Output = Self; 201 | #[inline] 202 | fn sub(self, rhs: Self) -> Self::Output { 203 | let carry = self.lo < rhs.lo; 204 | let lo = self.lo.wrapping_sub(rhs.lo); 205 | let hi = self.hi - rhs.hi - carry as umax; 206 | Self { lo, hi } 207 | } 208 | } 209 | impl Sub for udouble { 210 | type Output = Self; 211 | #[inline] 212 | fn sub(self, rhs: umax) -> Self::Output { 213 | let carry = self.lo < rhs; 214 | let lo = self.lo.wrapping_sub(rhs); 215 | let hi = if carry { self.hi - 1 } else { self.hi }; 216 | Self { lo, hi } 217 | } 218 | } 219 | //> (used in test of AddAssign) 220 | impl SubAssign for udouble { 221 | #[inline] 222 | fn sub_assign(&mut self, rhs: Self) { 223 | let carry = self.lo < rhs.lo; 224 | self.lo = self.lo.wrapping_sub(rhs.lo); 225 | self.hi -= rhs.hi + carry as umax; 226 | } 227 | } 228 | impl SubAssign for udouble { 229 | #[inline] 230 | fn sub_assign(&mut self, rhs: umax) { 231 | let carry = self.lo < rhs; 232 | self.lo = self.lo.wrapping_sub(rhs); 233 | if carry { 234 | self.hi -= 1; 235 | } 236 | } 237 | } 238 | 239 | macro_rules! 
impl_sh_ops { 240 | ($t:ty) => { 241 | //> (used in Self::checked_shl) 242 | impl Shl<$t> for udouble { 243 | type Output = Self; 244 | #[inline] 245 | fn shl(self, rhs: $t) -> Self::Output { 246 | match rhs { 247 | 0 => self, // avoid shifting by full bits, which is UB 248 | s if s >= umax::BITS as $t => Self { 249 | hi: self.lo << (s - umax::BITS as $t), 250 | lo: 0, 251 | }, 252 | s => Self { 253 | lo: self.lo << s, 254 | hi: (self.hi << s) | (self.lo >> (umax::BITS as $t - s)), 255 | }, 256 | } 257 | } 258 | } 259 | //> (not used yet) 260 | impl ShlAssign<$t> for udouble { 261 | #[inline] 262 | fn shl_assign(&mut self, rhs: $t) { 263 | match rhs { 264 | 0 => {} 265 | s if s >= umax::BITS as $t => { 266 | self.hi = self.lo << (s - umax::BITS as $t); 267 | self.lo = 0; 268 | } 269 | s => { 270 | self.hi <<= s; 271 | self.hi |= self.lo >> (umax::BITS as $t - s); 272 | self.lo <<= s; 273 | } 274 | } 275 | } 276 | } 277 | //> (used in Self::checked_shr) 278 | impl Shr<$t> for udouble { 279 | type Output = Self; 280 | #[inline] 281 | fn shr(self, rhs: $t) -> Self::Output { 282 | match rhs { 283 | 0 => self, 284 | s if s >= umax::BITS as $t => Self { 285 | lo: self.hi >> (rhs - umax::BITS as $t), 286 | hi: 0, 287 | }, 288 | s => Self { 289 | hi: self.hi >> s, 290 | lo: (self.lo >> s) | (self.hi << (umax::BITS as $t - s)), 291 | }, 292 | } 293 | } 294 | } 295 | //> (not used yet) 296 | impl ShrAssign<$t> for udouble { 297 | #[inline] 298 | fn shr_assign(&mut self, rhs: $t) { 299 | match rhs { 300 | 0 => {} 301 | s if s >= umax::BITS as $t => { 302 | self.lo = self.hi >> (rhs - umax::BITS as $t); 303 | self.hi = 0; 304 | } 305 | s => { 306 | self.lo >>= s; 307 | self.lo |= self.hi << (umax::BITS as $t - s); 308 | self.hi >>= s; 309 | } 310 | } 311 | } 312 | } 313 | }; 314 | } 315 | 316 | // only implement most useful ones, so that we don't need to optimize so many variants 317 | impl_sh_ops!(u8); 318 | impl_sh_ops!(u16); 319 | impl_sh_ops!(u32); 320 | 321 | //> (not used yet) 322 | impl BitAnd for udouble { 323 | type Output = Self; 324 | #[inline] 325 | fn bitand(self, rhs: Self) -> Self::Output { 326 | Self { 327 | lo: self.lo & rhs.lo, 328 | hi: self.hi & rhs.hi, 329 | } 330 | } 331 | } 332 | //> (not used yet) 333 | impl BitAndAssign for udouble { 334 | #[inline] 335 | fn bitand_assign(&mut self, rhs: Self) { 336 | self.lo &= rhs.lo; 337 | self.hi &= rhs.hi; 338 | } 339 | } 340 | //> (not used yet) 341 | impl BitOr for udouble { 342 | type Output = Self; 343 | #[inline] 344 | fn bitor(self, rhs: Self) -> Self::Output { 345 | Self { 346 | lo: self.lo | rhs.lo, 347 | hi: self.hi | rhs.hi, 348 | } 349 | } 350 | } 351 | //> (not used yet) 352 | impl BitOrAssign for udouble { 353 | #[inline] 354 | fn bitor_assign(&mut self, rhs: Self) { 355 | self.lo |= rhs.lo; 356 | self.hi |= rhs.hi; 357 | } 358 | } 359 | //> (not used yet) 360 | impl BitXor for udouble { 361 | type Output = Self; 362 | #[inline] 363 | fn bitxor(self, rhs: Self) -> Self::Output { 364 | Self { 365 | lo: self.lo ^ rhs.lo, 366 | hi: self.hi ^ rhs.hi, 367 | } 368 | } 369 | } 370 | //> (not used yet) 371 | impl BitXorAssign for udouble { 372 | #[inline] 373 | fn bitxor_assign(&mut self, rhs: Self) { 374 | self.lo ^= rhs.lo; 375 | self.hi ^= rhs.hi; 376 | } 377 | } 378 | //> (not used yet) 379 | impl Not for udouble { 380 | type Output = Self; 381 | #[inline] 382 | fn not(self) -> Self::Output { 383 | Self { 384 | lo: !self.lo, 385 | hi: !self.hi, 386 | } 387 | } 388 | } 389 | 390 | impl udouble { 391 | //> (used in Self::div_rem) 
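// leading_zeros below counts across the full 256-bit value: when the high
// word is empty, the low word's count is offset by umax::BITS.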
392 | #[inline] 393 | pub const fn leading_zeros(self) -> u32 { 394 | if self.hi == 0 { 395 | self.lo.leading_zeros() + umax::BITS 396 | } else { 397 | self.hi.leading_zeros() 398 | } 399 | } 400 | 401 | // double by double division (long division), it's not the most efficient algorithm. 402 | // listed here in case of future use 403 | #[allow(dead_code)] 404 | fn div_rem_2by2(self, other: Self) -> (Self, Self) { 405 | let mut n = self; // numerator 406 | let mut d = other; // denominator 407 | let mut q = Self { lo: 0, hi: 0 }; // quotient 408 | 409 | let nbits = (2 * umax::BITS - n.leading_zeros()) as u16; // assuming umax = u128 410 | let dbits = (2 * umax::BITS - d.leading_zeros()) as u16; 411 | assert!(dbits != 0, "division by zero"); 412 | 413 | // Early return in case we are dividing by a larger number than us 414 | if nbits < dbits { 415 | return (q, n); 416 | } 417 | 418 | // Bitwise long division 419 | let mut shift = nbits - dbits; 420 | d <<= shift; 421 | loop { 422 | if n >= d { 423 | q += 1; 424 | n -= d; 425 | } 426 | if shift == 0 { 427 | break; 428 | } 429 | 430 | d >>= 1u8; 431 | q <<= 1u8; 432 | shift -= 1; 433 | } 434 | (q, n) 435 | } 436 | 437 | // double by single to single division. 438 | // equivalent to `udiv_qrnnd` in C or `divq` in assembly. 439 | //> (used in Self::{div, rem}::) 440 | fn div_rem_2by1(self, other: umax) -> (umax, umax) { 441 | // the following algorithm comes from `ethnum` crate 442 | const B: umax = 1 << HALF_BITS; // number base (64 bits) 443 | 444 | // Normalize the divisor. 445 | let s = other.leading_zeros(); 446 | let (n, d) = (self << s, other << s); // numerator, denominator 447 | let (d1, d0) = split(d); 448 | let (n1, n0) = split(n.lo); // split lower part of dividend 449 | 450 | // Compute the first quotient digit q1. 451 | let (mut q1, mut rhat) = div_rem(n.hi, d1); 452 | 453 | // q1 has at most error 2. No more than 2 iterations. 454 | while q1 >= B || q1 * d0 > B * rhat + n1 { 455 | q1 -= 1; 456 | rhat += d1; 457 | if rhat >= B { 458 | break; 459 | } 460 | } 461 | 462 | let r21 = 463 | n.hi.wrapping_mul(B) 464 | .wrapping_add(n1) 465 | .wrapping_sub(q1.wrapping_mul(d)); 466 | 467 | // Compute the second quotient digit q0. 468 | let (mut q0, mut rhat) = div_rem(r21, d1); 469 | 470 | // q0 has at most error 2. No more than 2 iterations. 471 | while q0 >= B || q0 * d0 > B * rhat + n0 { 472 | q0 -= 1; 473 | rhat += d1; 474 | if rhat >= B { 475 | break; 476 | } 477 | } 478 | 479 | let r = (r21 480 | .wrapping_mul(B) 481 | .wrapping_add(n0) 482 | .wrapping_sub(q0.wrapping_mul(d))) 483 | >> s; 484 | let q = q1 * B + q0; 485 | (q, r) 486 | } 487 | } 488 | 489 | impl Mul for udouble { 490 | type Output = Self; 491 | #[inline] 492 | fn mul(self, rhs: umax) -> Self::Output { 493 | self.checked_mul1(rhs).expect("multiplication overflow!") 494 | } 495 | } 496 | 497 | impl Div for udouble { 498 | type Output = Self; 499 | #[inline] 500 | fn div(self, rhs: umax) -> Self::Output { 501 | // self.div_rem(rhs.into()).0 502 | if self.hi < rhs { 503 | // The result fits in 128 bits. 504 | Self { 505 | lo: self.div_rem_2by1(rhs).0, 506 | hi: 0, 507 | } 508 | } else { 509 | let (q, r) = div_rem(self.hi, rhs); 510 | Self { 511 | lo: Self { lo: self.lo, hi: r }.div_rem_2by1(rhs).0, 512 | hi: q, 513 | } 514 | } 515 | } 516 | } 517 | 518 | //> (used in Montgomery::::transform) 519 | impl Rem for udouble { 520 | type Output = umax; 521 | #[inline] 522 | fn rem(self, rhs: umax) -> Self::Output { 523 | if self.hi < rhs { 524 | // The result fits in 128 bits. 
525 | self.div_rem_2by1(rhs).1 526 | } else { 527 | Self { 528 | lo: self.lo, 529 | hi: self.hi % rhs, 530 | } 531 | .div_rem_2by1(rhs) 532 | .1 533 | } 534 | } 535 | } 536 | 537 | #[cfg(test)] 538 | mod tests { 539 | use super::*; 540 | use rand::random; 541 | 542 | #[test] 543 | fn test_construction() { 544 | // from widening operators 545 | assert_eq!(udouble { hi: 0, lo: 2 }, udouble::widening_add(1, 1)); 546 | assert_eq!( 547 | udouble { 548 | hi: 1, 549 | lo: umax::MAX - 1 550 | }, 551 | udouble::widening_add(umax::MAX, umax::MAX) 552 | ); 553 | 554 | assert_eq!(udouble { hi: 0, lo: 1 }, udouble::widening_mul(1, 1)); 555 | assert_eq!(udouble { hi: 0, lo: 1 }, udouble::widening_square(1)); 556 | assert_eq!( 557 | udouble { hi: 1 << 32, lo: 0 }, 558 | udouble::widening_mul(1 << 80, 1 << 80) 559 | ); 560 | assert_eq!( 561 | udouble { hi: 1 << 32, lo: 0 }, 562 | udouble::widening_square(1 << 80) 563 | ); 564 | assert_eq!( 565 | udouble { 566 | hi: 1 << 32, 567 | lo: 2 << 120 | 1 << 80 568 | }, 569 | udouble::widening_mul(1 << 80 | 1 << 40, 1 << 80 | 1 << 40) 570 | ); 571 | assert_eq!( 572 | udouble { 573 | hi: 1 << 32, 574 | lo: 2 << 120 | 1 << 80 575 | }, 576 | udouble::widening_square(1 << 80 | 1 << 40) 577 | ); 578 | assert_eq!( 579 | udouble { 580 | hi: umax::MAX - 1, 581 | lo: 1 582 | }, 583 | udouble::widening_mul(umax::MAX, umax::MAX) 584 | ); 585 | assert_eq!( 586 | udouble { 587 | hi: umax::MAX - 1, 588 | lo: 1 589 | }, 590 | udouble::widening_square(umax::MAX) 591 | ); 592 | } 593 | 594 | #[test] 595 | fn test_ops() { 596 | const ONE: udouble = udouble { hi: 0, lo: 1 }; 597 | const TWO: udouble = udouble { hi: 0, lo: 2 }; 598 | const MAX: udouble = udouble { 599 | hi: 0, 600 | lo: umax::MAX, 601 | }; 602 | const ONEZERO: udouble = udouble { hi: 1, lo: 0 }; 603 | const ONEMAX: udouble = udouble { 604 | hi: 1, 605 | lo: umax::MAX, 606 | }; 607 | const TWOZERO: udouble = udouble { hi: 2, lo: 0 }; 608 | 609 | assert_eq!(ONE + MAX, ONEZERO); 610 | assert_eq!(ONE + ONEMAX, TWOZERO); 611 | assert_eq!(ONEZERO - ONE, MAX); 612 | assert_eq!(ONEZERO - MAX, ONE); 613 | assert_eq!(TWOZERO - ONE, ONEMAX); 614 | assert_eq!(TWOZERO - ONEMAX, ONE); 615 | 616 | assert_eq!(ONE << umax::BITS, ONEZERO); 617 | assert_eq!((MAX << 1u8) + 1, ONEMAX); 618 | assert_eq!( 619 | ONE << 200u8, 620 | udouble { 621 | lo: 0, 622 | hi: 1 << (200 - umax::BITS) 623 | } 624 | ); 625 | assert_eq!(ONEZERO >> umax::BITS, ONE); 626 | assert_eq!(ONEMAX >> 1u8, MAX); 627 | 628 | assert_eq!(ONE * MAX.lo, MAX); 629 | assert_eq!(ONEMAX * ONE.lo, ONEMAX); 630 | assert_eq!(ONEMAX * TWO.lo, ONEMAX + ONEMAX); 631 | assert_eq!(MAX / ONE.lo, MAX); 632 | assert_eq!(MAX / MAX.lo, ONE); 633 | assert_eq!(ONE / MAX.lo, udouble { lo: 0, hi: 0 }); 634 | assert_eq!(ONEMAX / ONE.lo, ONEMAX); 635 | assert_eq!(ONEMAX / MAX.lo, TWO); 636 | assert_eq!(ONEMAX / TWO.lo, MAX); 637 | assert_eq!(ONE % MAX.lo, 1); 638 | assert_eq!(TWO % MAX.lo, 2); 639 | assert_eq!(ONEMAX % MAX.lo, 1); 640 | assert_eq!(ONEMAX % TWO.lo, 1); 641 | 642 | assert_eq!(ONEMAX.checked_mul1(MAX.lo), None); 643 | assert_eq!(TWOZERO.checked_mul1(MAX.lo), None); 644 | } 645 | 646 | #[test] 647 | fn test_assign_ops() { 648 | for _ in 0..10 { 649 | let x = udouble { 650 | hi: random::() as umax, 651 | lo: random(), 652 | }; 653 | let y = udouble { 654 | hi: random::() as umax, 655 | lo: random(), 656 | }; 657 | let mut z = x; 658 | 659 | z += y; 660 | assert_eq!(z, x + y); 661 | z -= y; 662 | assert_eq!(z, x); 663 | } 664 | } 665 | } 666 | 
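Together, `widening_mul` and the `Rem<umax>` impl above already form a complete 128-bit modular multiplication. Below is a minimal sketch of the round trip; it is not part of the repository, it assumes the crate is consumed as the `num_modular` dependency (which re-exports `udouble` and `umax`), and the helper name `mulmod` is ours:

```rust
use num_modular::{udouble, umax};

/// (a * b) % m computed through a 256-bit intermediate, so it never overflows.
fn mulmod(a: umax, b: umax, m: umax) -> umax {
    udouble::widening_mul(a, b) % m
}

fn main() {
    let m = (1u128 << 127) - 1; // the Mersenne prime 2^127 - 1
    assert_eq!(mulmod(m - 1, m - 1, m), 1); // (-1)^2 == 1 (mod m)
}
```

This is essentially the fallback path taken by `u128::mulm` in src/prim.rs when the plain product would overflow.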
--------------------------------------------------------------------------------
/src/lib.rs:
--------------------------------------------------------------------------------
1 | //! This crate provides efficient modular arithmetic operations for various integer types,
2 | //! including primitive integers and `num-bigint`; the latter is an optional feature.
3 | //!
4 | //! To achieve fast modular arithmetic, convert integers to any [ModularInteger] implementation
5 | //! using the static `new()` constructor or the associated [ModularInteger::convert()] method. Built-in
6 | //! implementations of [ModularInteger] include [MontgomeryInt] and [FixedMersenneInt].
7 | //!
8 | //! Example code:
9 | //! ```rust
10 | //! use num_modular::{ModularCoreOps, ModularInteger, MontgomeryInt};
11 | //!
12 | //! // directly using methods in ModularCoreOps
13 | //! let (x, y, m) = (12u8, 13u8, 5u8);
14 | //! assert_eq!(x.mulm(y, &m), x * y % m);
15 | //!
16 | //! // convert integers into ModularInteger
17 | //! let mx = MontgomeryInt::new(x, &m);
18 | //! let my = mx.convert(y); // faster than the static MontgomeryInt::new(y, m)
19 | //! assert_eq!((mx * my).residue(), x * y % m);
20 | //! ```
21 | //!
22 | //! # Comparison of fast division / modular arithmetic tricks
23 | //! Several fast division / modulo tricks are provided in this crate; their differences are listed below:
24 | //! - [PreModInv]: pre-computes the modular inverse of the divisor; only applicable to exact division
25 | //! - Barrett (to be implemented): pre-computes a (rational approximation of the) reciprocal of the divisor;
26 | //!   applicable to fast division and modulo
27 | //! - [Montgomery]: converts the dividend into a special form by shifting and pre-computes a modular inverse;
28 | //!   only applicable to fast modulo, but faster than Barrett reduction
29 | //! - [FixedMersenne]: specialization of modulo for moduli of the form `2^P-K`, below 2^127
30 | //!
31 | 
32 | // XXX: Other fast modular arithmetic tricks
33 | // REF: https://github.com/lemire/fastmod & https://arxiv.org/pdf/1902.01961.pdf
34 | // REF: https://eprint.iacr.org/2014/040.pdf
35 | // REF: https://github.com/ridiculousfish/libdivide/
36 | // REF: Faster Interleaved Modular Multiplication Based on Barrett and Montgomery Reduction Methods (works for moduli of a certain form)
37 | 
38 | #![no_std]
39 | #[cfg(any(feature = "std", test))]
40 | extern crate std;
41 | 
42 | use core::ops::{Add, Mul, Neg, Sub};
43 | 
44 | /// Core modular arithmetic operations.
45 | ///
46 | /// Note that all functions will panic if the modulus is zero.
47 | pub trait ModularCoreOps<Rhs, Modulus> {
48 |     type Output;
49 | 
50 |     /// Return (self + rhs) % m
51 |     fn addm(self, rhs: Rhs, m: Modulus) -> Self::Output;
52 | 
53 |     /// Return (self - rhs) % m
54 |     fn subm(self, rhs: Rhs, m: Modulus) -> Self::Output;
55 | 
56 |     /// Return (self * rhs) % m
57 |     fn mulm(self, rhs: Rhs, m: Modulus) -> Self::Output;
58 | }
59 | 
60 | /// Core unary modular arithmetic operations.
61 | ///
62 | /// Note that all functions will panic if the modulus is zero.
63 | pub trait ModularUnaryOps<Modulus> {
64 |     type Output;
65 | 
66 |     /// Return (-self) % m and make sure the result is normalized to the range [0,m)
67 |     fn negm(self, m: Modulus) -> Self::Output;
68 | 
69 |     /// Calculate the modular inverse (x such that self*x ≡ 1 mod m).
70 |     ///
71 |     /// This operation is only available for an integer that is coprime to `m`. If not,
72 |     /// the result will be [None].
73 |     fn invm(self, m: Modulus) -> Option<Self::Output>;
74 | 
75 |     /// Calculate modular double (x+x mod m)
76 |     fn dblm(self, m: Modulus) -> Self::Output;
77 | 
78 |     /// Calculate modular square (x*x mod m)
79 |     fn sqm(self, m: Modulus) -> Self::Output;
80 | 
81 |     // TODO: Modular sqrt aka Quadratic residue, follow the behavior of FLINT `n_sqrtmod`
82 |     // fn sqrtm(self, m: Modulus) -> Option<Self::Output>;
83 |     // REF: https://stackoverflow.com/questions/6752374/cube-root-modulo-p-how-do-i-do-this
84 | }
85 | 
86 | /// Modular power functions
87 | pub trait ModularPow<Exp, Modulus> {
88 |     type Output;
89 | 
90 |     /// Return (self ^ exp) % m
91 |     fn powm(self, exp: Exp, m: Modulus) -> Self::Output;
92 | }
93 | 
94 | /// Math symbols related to modular arithmetic
95 | pub trait ModularSymbols<Modulus> {
96 |     /// Calculate the Legendre symbol (a|n), where a is `self`.
97 |     ///
98 |     /// Note that this function doesn't perform a full primality check, since
99 |     /// that is costly. So if n is not a prime, the result may be meaningless.
100 |     ///
101 |     /// # Panics
102 |     /// Only if n is not prime
103 |     #[inline]
104 |     fn legendre(&self, n: Modulus) -> i8 {
105 |         self.checked_legendre(n).expect("n should be a prime")
106 |     }
107 | 
108 |     /// Calculate the Legendre symbol (a|n), where a is `self`. Returns [None] only if n is
109 |     /// not a prime.
110 |     ///
111 |     /// Note that this function doesn't perform a full primality check, since
112 |     /// that is costly. So if n is not a prime, the result may be meaningless.
113 |     ///
114 |     /// Unlike [legendre](Self::legendre), this function returns [None]
115 |     /// instead of panicking when n is found not to be a prime.
116 |     fn checked_legendre(&self, n: Modulus) -> Option<i8>;
117 | 
118 |     /// Calculate the Jacobi symbol (a|n), where a is `self`
119 |     ///
120 |     /// # Panics
121 |     /// if n is negative or even
122 |     #[inline]
123 |     fn jacobi(&self, n: Modulus) -> i8 {
124 |         self.checked_jacobi(n)
125 |             .expect("the Jacobi symbol is only defined for non-negative odd integers")
126 |     }
127 | 
128 |     /// Calculate the Jacobi symbol (a|n), where a is `self`. Returns [None] if n is negative or even.
129 |     fn checked_jacobi(&self, n: Modulus) -> Option<i8>;
130 | 
131 |     /// Calculate the Kronecker symbol (a|n), where a is `self`
132 |     fn kronecker(&self, n: Modulus) -> i8;
133 | }
134 | 
135 | // TODO: Discrete log aka index, follow the behavior of FLINT `n_discrete_log_bsgs`
136 | // REF: https://github.com/vks/discrete-log
137 | // fn logm(self, base: Modulus, m: Modulus);
138 | 
139 | /// Collection of common modular arithmetic operations
140 | pub trait ModularOps<Rhs, Modulus>:
141 |     ModularCoreOps<Rhs, Modulus>
142 |     + ModularUnaryOps<Modulus>
143 |     + ModularPow<Rhs, Modulus>
144 |     + ModularSymbols<Modulus>
145 | {
146 | }
147 | impl<T, Rhs, Modulus> ModularOps<Rhs, Modulus> for T where
148 |     T: ModularCoreOps<Rhs, Modulus>
149 |         + ModularUnaryOps<Modulus>
150 |         + ModularPow<Rhs, Modulus>
151 |         + ModularSymbols<Modulus>
152 | {
153 | }
154 | 
155 | /// Collection of operations similar to [ModularOps], but taking operands by reference
156 | pub trait ModularRefOps: for<'r> ModularOps<&'r Self, &'r Self> + Sized {}
157 | impl<T> ModularRefOps for T where T: for<'r> ModularOps<&'r T, &'r T> {}
158 | 
159 | /// Provides a utility function to convert signed integers into unsigned modular form
160 | pub trait ModularAbs<Modulus> {
161 |     /// Return self % m, but accepting signed integers
162 |     fn absm(self, m: &Modulus) -> Modulus;
163 | }
164 | 
165 | /// Represents a number defined in a modulo ring ℤ/nℤ
166 | ///
167 | /// The operators should panic if the moduli of the two numbers
168 | /// are not the same.
169 | pub trait ModularInteger:
170 |     Sized
171 |     + PartialEq
172 |     + Add<Output = Self>
173 |     + Sub<Output = Self>
174 |     + Neg<Output = Self>
175 |     + Mul<Output = Self>
176 | {
177 |     /// The underlying representation type of the integer
178 |     type Base;
179 | 
180 |     /// Return the modulus of the ring
181 |     fn modulus(&self) -> Self::Base;
182 | 
183 |     /// Return the normalized residue of this integer in the ring
184 |     fn residue(&self) -> Self::Base;
185 | 
186 |     /// Check if the integer is zero
187 |     fn is_zero(&self) -> bool;
188 | 
189 |     /// Convert a normal integer into the same ring.
190 |     ///
191 |     /// This method should be preferred over the static
192 |     /// constructor to prevent unnecessary overhead from pre-computation.
193 |     fn convert(&self, n: Self::Base) -> Self;
194 | 
195 |     /// Calculate the value of self + self
196 |     fn double(self) -> Self;
197 | 
198 |     /// Calculate the value of self * self
199 |     fn square(self) -> Self;
200 | }
201 | 
202 | // XXX: implement ModularInteger for ff::PrimeField?
203 | // TODO: implement invm_range (modular inverse in a certain range) and crt (Chinese Remainder Theorem), REF: bubblemath crate
204 | 
205 | /// Utility function for exact division, with precomputed helper values
206 | ///
207 | /// # Available pre-computation types:
208 | /// - `()`: no pre-computation, the implementation relies on native integer division
209 | /// - [PreModInv]: with a pre-computed modular inverse
210 | pub trait DivExact<Rhs, Precompute>: Sized {
211 |     type Output;
212 | 
213 |     /// Check if d divides self with the help of the precomputation. If d divides self,
214 |     /// then the quotient is returned.
215 |     fn div_exact(self, d: Rhs, pre: &Precompute) -> Option<Self::Output>;
216 | }
217 | 
218 | /// A modular reducer that ensures that operations on integers are all performed
219 | /// in a modular ring.
220 | ///
221 | /// Essential information for performing the modulo operation will be stored in the reducer.
222 | pub trait Reducer<T> {
223 |     /// Create a reducer for a modulus m
224 |     fn new(m: &T) -> Self;
225 | 
226 |     /// Transform a normal integer into reduced form
227 |     fn transform(&self, target: T) -> T;
228 | 
229 |     /// Check whether target is a valid reduced form
230 |     fn check(&self, target: &T) -> bool;
231 | 
232 |     /// Get the modulus in the original integer type
233 |     fn modulus(&self) -> T;
234 | 
235 |     /// Transform a reduced form back to a normal integer
236 |     fn residue(&self, target: T) -> T;
237 | 
238 |     /// Test if residue(target) == 0
239 |     fn is_zero(&self, target: &T) -> bool;
240 | 
241 |     /// Calculate (lhs + rhs) mod m in reduced form
242 |     fn add(&self, lhs: &T, rhs: &T) -> T;
243 | 
244 |     #[inline]
245 |     fn add_in_place(&self, lhs: &mut T, rhs: &T) {
246 |         *lhs = self.add(lhs, rhs)
247 |     }
248 | 
249 |     /// Calculate 2*target mod m
250 |     fn dbl(&self, target: T) -> T;
251 | 
252 |     /// Calculate (lhs - rhs) mod m in reduced form
253 |     fn sub(&self, lhs: &T, rhs: &T) -> T;
254 | 
255 |     #[inline]
256 |     fn sub_in_place(&self, lhs: &mut T, rhs: &T) {
257 |         *lhs = self.sub(lhs, rhs);
258 |     }
259 | 
260 |     /// Calculate -target mod m in reduced form
261 |     fn neg(&self, target: T) -> T;
262 | 
263 |     /// Calculate (lhs * rhs) mod m in reduced form
264 |     fn mul(&self, lhs: &T, rhs: &T) -> T;
265 | 
266 |     #[inline]
267 |     fn mul_in_place(&self, lhs: &mut T, rhs: &T) {
268 |         *lhs = self.mul(lhs, rhs);
269 |     }
270 | 
271 |     /// Calculate target^-1 mod m in reduced form;
272 |     /// it may return [None] when there is no modular inverse.
273 |     fn inv(&self, target: T) -> Option<T>;
274 | 
275 |     /// Calculate target^2 mod m in reduced form
276 |     fn sqr(&self, target: T) -> T;
277 | 
278 |     /// Calculate base ^ exp mod m in reduced form
279 |     fn pow(&self, base: T, exp: &T) -> T;
280 | }
281 | 
282 | mod barrett;
283 | mod double;
284 | mod mersenne;
285 | mod monty;
286 | mod preinv;
287 | mod prim;
288 | mod reduced;
289 | mod word;
290 | 
291 | pub use barrett::{
292 |     Normalized2by1Divisor, Normalized3by2Divisor, PreMulInv1by1, PreMulInv2by1, PreMulInv3by2,
293 | };
294 | pub use double::{udouble, umax};
295 | pub use mersenne::FixedMersenne;
296 | pub use monty::Montgomery;
297 | pub use preinv::PreModInv;
298 | pub use reduced::{ReducedInt, Vanilla, VanillaInt};
299 | 
300 | /// An integer in a modulo ring based on the [Montgomery form](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication#Montgomery_form)
301 | pub type MontgomeryInt<T> = ReducedInt<T, Montgomery<T>>;
302 | 
303 | /// An integer in a modulo ring with a fixed (pseudo-)Mersenne number as the modulus
304 | pub type FixedMersenneInt<const P: u8, const K: umax> = ReducedInt<umax, FixedMersenne<P, K>>;
305 | 
306 | // pub type BarrettInt = ReducedInt>;
307 | 
308 | #[cfg(feature = "num-bigint")]
309 | mod bigint;
310 | 
--------------------------------------------------------------------------------
/src/mersenne.rs:
--------------------------------------------------------------------------------
1 | use crate::reduced::impl_reduced_binary_pow;
2 | use crate::{udouble, umax, ModularUnaryOps, Reducer};
3 | 
4 | // FIXME: use unchecked operators to speed up calculation (after https://github.com/rust-lang/rust/issues/85122)
5 | /// A modular reducer for (pseudo-)Mersenne numbers `2^P - K` as the modulus. It supports `P` up to 127 and `K < 2^(P-1)`.
6 | ///
7 | /// `P` is limited to 127 so that overflow checks are unnecessary. This limit won't be a problem for any
8 | /// Mersenne primes within the range of [umax] (i.e. [u128]).
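// For example, P = 127, K = 1 gives the Mersenne prime 2^127 - 1, while K > 1
// covers pseudo-Mersenne moduli such as 2^32 - 5 (see the tests below).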
9 | #[derive(Debug, Clone, Copy)] 10 | pub struct FixedMersenne(); 11 | 12 | // XXX: support other primes as modulo, such as solinas prime, proth prime and support multi precision 13 | // REF: Handbook of Cryptography 14.3.4 14 | 15 | impl FixedMersenne { 16 | const BITMASK: umax = (1 << P) - 1; 17 | pub const MODULUS: umax = (1 << P) - K; 18 | 19 | // Calculate v % Self::MODULUS, where v is a umax integer 20 | const fn reduce_single(v: umax) -> umax { 21 | let mut lo = v & Self::BITMASK; 22 | let mut hi = v >> P; 23 | while hi > 0 { 24 | let sum = if K == 1 { hi + lo } else { hi * K + lo }; 25 | lo = sum & Self::BITMASK; 26 | hi = sum >> P; 27 | } 28 | 29 | if lo >= Self::MODULUS { 30 | lo - Self::MODULUS 31 | } else { 32 | lo 33 | } 34 | } 35 | 36 | // Calculate v % Self::MODULUS, where v is a udouble integer 37 | fn reduce_double(v: udouble) -> umax { 38 | // reduce modulo 39 | let mut lo = v.lo & Self::BITMASK; 40 | let mut hi = v >> P; 41 | while hi.hi > 0 { 42 | // first reduce until high bits fit in umax 43 | let sum = if K == 1 { hi + lo } else { hi * K + lo }; 44 | lo = sum.lo & Self::BITMASK; 45 | hi = sum >> P; 46 | } 47 | 48 | let mut hi = hi.lo; 49 | while hi > 0 { 50 | // then reduce the smaller high bits 51 | let sum = if K == 1 { hi + lo } else { hi * K + lo }; 52 | lo = sum & Self::BITMASK; 53 | hi = sum >> P; 54 | } 55 | 56 | if lo >= Self::MODULUS { 57 | lo - Self::MODULUS 58 | } else { 59 | lo 60 | } 61 | } 62 | } 63 | 64 | impl Reducer for FixedMersenne { 65 | #[inline] 66 | fn new(m: &umax) -> Self { 67 | assert!( 68 | *m == Self::MODULUS, 69 | "the given modulus doesn't match with the generic params" 70 | ); 71 | debug_assert!(P <= 127); 72 | debug_assert!(K > 0 && K < (2 as umax).pow(P as u32 - 1) && K % 2 == 1); 73 | debug_assert!( 74 | Self::MODULUS % 3 != 0 75 | && Self::MODULUS % 5 != 0 76 | && Self::MODULUS % 7 != 0 77 | && Self::MODULUS % 11 != 0 78 | && Self::MODULUS % 13 != 0 79 | ); // error on easy composites 80 | Self {} 81 | } 82 | #[inline] 83 | fn transform(&self, target: umax) -> umax { 84 | Self::reduce_single(target) 85 | } 86 | fn check(&self, target: &umax) -> bool { 87 | *target < Self::MODULUS 88 | } 89 | #[inline] 90 | fn residue(&self, target: umax) -> umax { 91 | target 92 | } 93 | #[inline] 94 | fn modulus(&self) -> umax { 95 | Self::MODULUS 96 | } 97 | #[inline] 98 | fn is_zero(&self, target: &umax) -> bool { 99 | target == &0 100 | } 101 | 102 | #[inline] 103 | fn add(&self, lhs: &umax, rhs: &umax) -> umax { 104 | let mut sum = lhs + rhs; 105 | if sum >= Self::MODULUS { 106 | sum -= Self::MODULUS 107 | } 108 | sum 109 | } 110 | #[inline] 111 | fn sub(&self, lhs: &umax, rhs: &umax) -> umax { 112 | if lhs >= rhs { 113 | lhs - rhs 114 | } else { 115 | Self::MODULUS - (rhs - lhs) 116 | } 117 | } 118 | #[inline] 119 | fn dbl(&self, target: umax) -> umax { 120 | self.add(&target, &target) 121 | } 122 | #[inline] 123 | fn neg(&self, target: umax) -> umax { 124 | if target == 0 { 125 | 0 126 | } else { 127 | Self::MODULUS - target 128 | } 129 | } 130 | #[inline] 131 | fn mul(&self, lhs: &umax, rhs: &umax) -> umax { 132 | if (P as u32) < (umax::BITS / 2) { 133 | Self::reduce_single(lhs * rhs) 134 | } else { 135 | Self::reduce_double(udouble::widening_mul(*lhs, *rhs)) 136 | } 137 | } 138 | #[inline] 139 | fn inv(&self, target: umax) -> Option { 140 | if (P as u32) < usize::BITS { 141 | (target as usize) 142 | .invm(&(Self::MODULUS as usize)) 143 | .map(|v| v as umax) 144 | } else { 145 | target.invm(&Self::MODULUS) 146 | } 147 | } 148 | #[inline] 
149 | fn sqr(&self, target: umax) -> umax { 150 | if (P as u32) < (umax::BITS / 2) { 151 | Self::reduce_single(target * target) 152 | } else { 153 | Self::reduce_double(udouble::widening_square(target)) 154 | } 155 | } 156 | 157 | impl_reduced_binary_pow!(umax); 158 | } 159 | 160 | #[cfg(test)] 161 | mod tests { 162 | use super::*; 163 | use crate::{ModularCoreOps, ModularPow}; 164 | use rand::random; 165 | 166 | type M = FixedMersenne<31, 1>; 167 | type M1 = FixedMersenne<31, 1>; 168 | type M2 = FixedMersenne<61, 1>; 169 | type M3 = FixedMersenne<127, 1>; 170 | type M4 = FixedMersenne<32, 5>; 171 | type M5 = FixedMersenne<56, 5>; 172 | type M6 = FixedMersenne<122, 3>; 173 | 174 | const NRANDOM: u32 = 10; 175 | 176 | #[test] 177 | fn creation_test() { 178 | const P: umax = (1 << 31) - 1; 179 | let m = M::new(&P); 180 | assert_eq!(m.residue(m.transform(0)), 0); 181 | assert_eq!(m.residue(m.transform(1)), 1); 182 | assert_eq!(m.residue(m.transform(P)), 0); 183 | assert_eq!(m.residue(m.transform(P - 1)), P - 1); 184 | assert_eq!(m.residue(m.transform(P + 1)), 1); 185 | 186 | // random creation test 187 | for _ in 0..NRANDOM { 188 | let a = random::(); 189 | 190 | const P1: umax = (1 << 31) - 1; 191 | let m1 = M1::new(&P1); 192 | assert_eq!(m1.residue(m1.transform(a)), a % P1); 193 | const P2: umax = (1 << 61) - 1; 194 | let m2 = M2::new(&P2); 195 | assert_eq!(m2.residue(m2.transform(a)), a % P2); 196 | const P3: umax = (1 << 127) - 1; 197 | let m3 = M3::new(&P3); 198 | assert_eq!(m3.residue(m3.transform(a)), a % P3); 199 | const P4: umax = (1 << 32) - 5; 200 | let m4 = M4::new(&P4); 201 | assert_eq!(m4.residue(m4.transform(a)), a % P4); 202 | const P5: umax = (1 << 56) - 5; 203 | let m5 = M5::new(&P5); 204 | assert_eq!(m5.residue(m5.transform(a)), a % P5); 205 | const P6: umax = (1 << 122) - 3; 206 | let m6 = M6::new(&P6); 207 | assert_eq!(m6.residue(m6.transform(a)), a % P6); 208 | } 209 | } 210 | 211 | #[test] 212 | fn test_against_modops() { 213 | macro_rules! tests_for { 214 | ($a:tt, $b:tt, $e:tt; $($M:ty)*) => ($({ 215 | const P: umax = <$M>::MODULUS; 216 | let r = <$M>::new(&P); 217 | let am = r.transform($a); 218 | let bm = r.transform($b); 219 | assert_eq!(r.add(&am, &bm), $a.addm($b, &P)); 220 | assert_eq!(r.sub(&am, &bm), $a.subm($b, &P)); 221 | assert_eq!(r.mul(&am, &bm), $a.mulm($b, &P)); 222 | assert_eq!(r.neg(am), $a.negm(&P)); 223 | assert_eq!(r.inv(am), $a.invm(&P)); 224 | assert_eq!(r.dbl(am), $a.dblm(&P)); 225 | assert_eq!(r.sqr(am), $a.sqm(&P)); 226 | assert_eq!(r.pow(am, &$e), $a.powm($e, &P)); 227 | })*); 228 | } 229 | 230 | for _ in 0..NRANDOM { 231 | let (a, b) = (random::(), random::()); 232 | let e = random::() as umax; 233 | tests_for!(a, b, e; M1 M2 M3 M4 M5 M6); 234 | } 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /src/monty.rs: -------------------------------------------------------------------------------- 1 | use crate::reduced::impl_reduced_binary_pow; 2 | use crate::{ModularUnaryOps, Reducer, Vanilla}; 3 | 4 | /// Negated modular inverse on binary bases 5 | /// `neginv` calculates `-(m^-1) mod R`, `R = 2^k. If m is odd, then result of m + 1 will be returned. 6 | mod neg_mod_inv { 7 | // Entry i contains (2i+1)^(-1) mod 256. 
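// Each Hensel-lifting step below, i = i*(2 - i*m), doubles the number of correct
// low bits of the inverse: the 8-bit table entry is lifted once for u16, twice
// for u32, three times for u64 and four times for u128.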
8 | #[rustfmt::skip] 9 | const BINV_TABLE: [u8; 128] = [ 10 | 0x01, 0xAB, 0xCD, 0xB7, 0x39, 0xA3, 0xC5, 0xEF, 0xF1, 0x1B, 0x3D, 0xA7, 0x29, 0x13, 0x35, 0xDF, 11 | 0xE1, 0x8B, 0xAD, 0x97, 0x19, 0x83, 0xA5, 0xCF, 0xD1, 0xFB, 0x1D, 0x87, 0x09, 0xF3, 0x15, 0xBF, 12 | 0xC1, 0x6B, 0x8D, 0x77, 0xF9, 0x63, 0x85, 0xAF, 0xB1, 0xDB, 0xFD, 0x67, 0xE9, 0xD3, 0xF5, 0x9F, 13 | 0xA1, 0x4B, 0x6D, 0x57, 0xD9, 0x43, 0x65, 0x8F, 0x91, 0xBB, 0xDD, 0x47, 0xC9, 0xB3, 0xD5, 0x7F, 14 | 0x81, 0x2B, 0x4D, 0x37, 0xB9, 0x23, 0x45, 0x6F, 0x71, 0x9B, 0xBD, 0x27, 0xA9, 0x93, 0xB5, 0x5F, 15 | 0x61, 0x0B, 0x2D, 0x17, 0x99, 0x03, 0x25, 0x4F, 0x51, 0x7B, 0x9D, 0x07, 0x89, 0x73, 0x95, 0x3F, 16 | 0x41, 0xEB, 0x0D, 0xF7, 0x79, 0xE3, 0x05, 0x2F, 0x31, 0x5B, 0x7D, 0xE7, 0x69, 0x53, 0x75, 0x1F, 17 | 0x21, 0xCB, 0xED, 0xD7, 0x59, 0xC3, 0xE5, 0x0F, 0x11, 0x3B, 0x5D, 0xC7, 0x49, 0x33, 0x55, 0xFF, 18 | ]; 19 | 20 | pub mod u8 { 21 | use super::*; 22 | pub const fn neginv(m: u8) -> u8 { 23 | let i = BINV_TABLE[((m >> 1) & 0x7F) as usize]; 24 | i.wrapping_neg() 25 | } 26 | } 27 | 28 | pub mod u16 { 29 | use super::*; 30 | pub const fn neginv(m: u16) -> u16 { 31 | let mut i = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u16; 32 | // hensel lifting 33 | i = 2u16.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i); 34 | i.wrapping_neg() 35 | } 36 | } 37 | 38 | pub mod u32 { 39 | use super::*; 40 | pub const fn neginv(m: u32) -> u32 { 41 | let mut i = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u32; 42 | i = 2u32.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i); 43 | i = 2u32.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i); 44 | i.wrapping_neg() 45 | } 46 | } 47 | 48 | pub mod u64 { 49 | use super::*; 50 | pub const fn neginv(m: u64) -> u64 { 51 | let mut i = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u64; 52 | i = 2u64.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i); 53 | i = 2u64.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i); 54 | i = 2u64.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i); 55 | i.wrapping_neg() 56 | } 57 | } 58 | 59 | pub mod u128 { 60 | use super::*; 61 | pub const fn neginv(m: u128) -> u128 { 62 | let mut i = BINV_TABLE[((m >> 1) & 0x7F) as usize] as u128; 63 | i = 2u128.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i); 64 | i = 2u128.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i); 65 | i = 2u128.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i); 66 | i = 2u128.wrapping_sub(i.wrapping_mul(m)).wrapping_mul(i); 67 | i.wrapping_neg() 68 | } 69 | } 70 | 71 | pub mod usize { 72 | #[inline] 73 | pub const fn neginv(m: usize) -> usize { 74 | #[cfg(target_pointer_width = "16")] 75 | return super::u16::neginv(m as _) as _; 76 | #[cfg(target_pointer_width = "32")] 77 | return super::u32::neginv(m as _) as _; 78 | #[cfg(target_pointer_width = "64")] 79 | return super::u64::neginv(m as _) as _; 80 | } 81 | } 82 | } 83 | 84 | /// A modular reducer based on [Montgomery form](https://en.wikipedia.org/wiki/Montgomery_modular_multiplication#Montgomery_form), only supports odd modulus. 85 | /// 86 | /// The generic type T represents the underlying integer representation for modular inverse `-m^-1 mod R`, 87 | /// and `R=2^B` will be used as the auxiliary modulus, where B is automatically selected 88 | /// based on the size of T. 89 | #[derive(Debug, Clone, Copy)] 90 | pub struct Montgomery { 91 | m: T, // modulus 92 | inv: T, // modular inverse of the modulus 93 | } 94 | 95 | macro_rules! 
impl_montgomery_for { 96 | ($t:ident, $ns:ident) => { 97 | mod $ns { 98 | use super::*; 99 | use crate::word::$t::*; 100 | use neg_mod_inv::$t::neginv; 101 | 102 | impl Montgomery<$t> { 103 | pub const fn new(m: $t) -> Self { 104 | assert!( 105 | m & 1 != 0, 106 | "Only odd modulus are supported by the Montgomery form" 107 | ); 108 | Self { m, inv: neginv(m) } 109 | } 110 | const fn reduce(&self, monty: DoubleWord) -> $t { 111 | debug_assert!(high(monty) < self.m); 112 | 113 | // REDC algorithm 114 | let tm = low(monty).wrapping_mul(self.inv); 115 | let (t, overflow) = monty.overflowing_add(wmul(tm, self.m)); 116 | let t = high(t); 117 | 118 | if overflow { 119 | t + self.m.wrapping_neg() 120 | } else if t >= self.m { 121 | t - self.m 122 | } else { 123 | t 124 | } 125 | } 126 | } 127 | 128 | impl Reducer<$t> for Montgomery<$t> { 129 | #[inline] 130 | fn new(m: &$t) -> Self { 131 | Self::new(*m) 132 | } 133 | #[inline] 134 | fn transform(&self, target: $t) -> $t { 135 | if target == 0 { 136 | return 0; 137 | } 138 | nrem(merge(0, target), self.m) 139 | } 140 | #[inline] 141 | fn check(&self, target: &$t) -> bool { 142 | *target < self.m 143 | } 144 | 145 | #[inline] 146 | fn residue(&self, target: $t) -> $t { 147 | self.reduce(extend(target)) 148 | } 149 | #[inline(always)] 150 | fn modulus(&self) -> $t { 151 | self.m 152 | } 153 | #[inline(always)] 154 | fn is_zero(&self, target: &$t) -> bool { 155 | *target == 0 156 | } 157 | 158 | #[inline(always)] 159 | fn add(&self, lhs: &$t, rhs: &$t) -> $t { 160 | Vanilla::<$t>::add(&self.m, *lhs, *rhs) 161 | } 162 | 163 | #[inline(always)] 164 | fn dbl(&self, target: $t) -> $t { 165 | Vanilla::<$t>::dbl(&self.m, target) 166 | } 167 | 168 | #[inline(always)] 169 | fn sub(&self, lhs: &$t, rhs: &$t) -> $t { 170 | Vanilla::<$t>::sub(&self.m, *lhs, *rhs) 171 | } 172 | 173 | #[inline(always)] 174 | fn neg(&self, target: $t) -> $t { 175 | Vanilla::<$t>::neg(&self.m, target) 176 | } 177 | 178 | #[inline] 179 | fn mul(&self, lhs: &$t, rhs: &$t) -> $t { 180 | self.reduce(wmul(*lhs, *rhs)) 181 | } 182 | 183 | #[inline] 184 | fn sqr(&self, target: $t) -> $t { 185 | self.reduce(wsqr(target)) 186 | } 187 | 188 | #[inline(always)] 189 | fn inv(&self, target: $t) -> Option<$t> { 190 | // TODO: support direct montgomery inverse 191 | // REF: http://cetinkayakoc.net/docs/j82.pdf 192 | self.residue(target) 193 | .invm(&self.m) 194 | .map(|v| self.transform(v)) 195 | } 196 | 197 | impl_reduced_binary_pow!(Word); 198 | } 199 | } 200 | }; 201 | } 202 | impl_montgomery_for!(u8, u8_impl); 203 | impl_montgomery_for!(u16, u16_impl); 204 | impl_montgomery_for!(u32, u32_impl); 205 | impl_montgomery_for!(u64, u64_impl); 206 | impl_montgomery_for!(u128, u128_impl); 207 | impl_montgomery_for!(usize, usize_impl); 208 | 209 | // TODO(v0.6.x): accept even numbers by removing 2 factors from m and store the exponent 210 | // Requirement: 1. A separate class to perform modular arithmetics with 2^n as modulus 211 | // 2. Algorithm for construct residue from two components (see http://koclab.cs.ucsb.edu/teaching/cs154/docx/Notes7-Montgomery.pdf) 212 | // Or we can just provide crt function, and let the implementation of monty int with full modulus support as an example code. 
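// A minimal round-trip sketch of the reducer above (a hypothetical test, not
// from the crate; any odd modulus works):
#[test]
fn montgomery_round_trip_sketch() {
    let m = 10007u64;
    let r = Montgomery::<u64>::new(m);
    let (a, b) = (123_456u64, 654_321u64);
    let (am, bm) = (r.transform(a), r.transform(b));
    // mul() REDC-reduces the widened product of the two Montgomery residues
    assert_eq!(r.residue(r.mul(&am, &bm)), a * b % m);
}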
213 | 214 | #[cfg(test)] 215 | mod tests { 216 | use super::*; 217 | use rand::random; 218 | 219 | const NRANDOM: u32 = 10; 220 | 221 | #[test] 222 | fn creation_test() { 223 | // a deterministic test case for u128 224 | let a = (0x81u128 << 120) - 1; 225 | let m = (0x81u128 << 119) - 1; 226 | let m = m >> m.trailing_zeros(); 227 | let r = Montgomery::::new(m); 228 | assert_eq!(r.residue(r.transform(a)), a % m); 229 | 230 | // is_zero test 231 | let r = Montgomery::::new(11u8); 232 | assert!(r.is_zero(&r.transform(0))); 233 | let five = r.transform(5u8); 234 | let six = r.transform(6u8); 235 | assert!(r.is_zero(&r.add(&five, &six))); 236 | 237 | // random creation test 238 | for _ in 0..NRANDOM { 239 | let a = random::(); 240 | let m = random::() | 1; 241 | let r = Montgomery::::new(m); 242 | assert_eq!(r.residue(r.transform(a)), a % m); 243 | 244 | let a = random::(); 245 | let m = random::() | 1; 246 | let r = Montgomery::::new(m); 247 | assert_eq!(r.residue(r.transform(a)), a % m); 248 | 249 | let a = random::(); 250 | let m = random::() | 1; 251 | let r = Montgomery::::new(m); 252 | assert_eq!(r.residue(r.transform(a)), a % m); 253 | 254 | let a = random::(); 255 | let m = random::() | 1; 256 | let r = Montgomery::::new(m); 257 | assert_eq!(r.residue(r.transform(a)), a % m); 258 | 259 | let a = random::(); 260 | let m = random::() | 1; 261 | let r = Montgomery::::new(m); 262 | assert_eq!(r.residue(r.transform(a)), a % m); 263 | } 264 | } 265 | 266 | #[test] 267 | fn test_against_modops() { 268 | use crate::reduced::tests::ReducedTester; 269 | for _ in 0..NRANDOM { 270 | ReducedTester::::test_against_modops::>(1); 271 | ReducedTester::::test_against_modops::>(1); 272 | ReducedTester::::test_against_modops::>(1); 273 | ReducedTester::::test_against_modops::>(1); 274 | ReducedTester::::test_against_modops::>(1); 275 | ReducedTester::::test_against_modops::>(1); 276 | } 277 | } 278 | } 279 | -------------------------------------------------------------------------------- /src/preinv.rs: -------------------------------------------------------------------------------- 1 | use crate::{DivExact, ModularUnaryOps}; 2 | 3 | /// Pre-computing the modular inverse for fast divisibility check. 4 | /// 5 | /// This struct stores the modular inverse of a divisor, and a limit for divisibility check. 6 | /// See for the explanation of the trick 7 | #[derive(Debug, Clone, Copy)] 8 | pub struct PreModInv { 9 | d_inv: T, // modular inverse of divisor 10 | q_lim: T, // limit of residue 11 | } 12 | 13 | macro_rules! impl_preinv_for_prim_int { 14 | ($t:ident, $ns:ident) => { 15 | mod $ns { 16 | use super::*; 17 | use crate::word::$t::*; 18 | 19 | impl PreModInv<$t> { 20 | /// Construct the preinv instance with raw values. 21 | /// 22 | /// This function can be used to initialize preinv in a constant context, the divisor d 23 | /// is required only for verification of d_inv and q_lim. 
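// (The divisor itself is not stored; consistency of d_inv and q_lim with d is
// only checked by debug_check at the div_exact call sites.)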
24 | #[inline] 25 | pub const fn new(d_inv: $t, q_lim: $t) -> Self { 26 | Self { d_inv, q_lim } 27 | } 28 | 29 | // check if the divisor is consistent in debug mode 30 | #[inline] 31 | fn debug_check(&self, d: $t) { 32 | debug_assert!(d % 2 != 0, "only odd divisors are supported"); 33 | debug_assert!(d.wrapping_mul(self.d_inv) == 1); 34 | debug_assert!(self.q_lim * d > (<$t>::MAX - d)); 35 | } 36 | } 37 | 38 | impl From<$t> for PreModInv<$t> { 39 | #[inline] 40 | fn from(v: $t) -> Self { 41 | use crate::word::$t::*; 42 | 43 | debug_assert!(v % 2 != 0, "only odd divisors are supported"); 44 | let d_inv = extend(v).invm(&merge(0, 1)).unwrap() as $t; 45 | let q_lim = <$t>::MAX / v; 46 | Self { d_inv, q_lim } 47 | } 48 | } 49 | 50 | impl DivExact<$t, PreModInv<$t>> for $t { 51 | type Output = $t; 52 | #[inline] 53 | fn div_exact(self, d: $t, pre: &PreModInv<$t>) -> Option { 54 | pre.debug_check(d); 55 | let q = self.wrapping_mul(pre.d_inv); 56 | if q <= pre.q_lim { 57 | Some(q) 58 | } else { 59 | None 60 | } 61 | } 62 | } 63 | 64 | impl DivExact<$t, PreModInv<$t>> for DoubleWord { 65 | type Output = DoubleWord; 66 | 67 | #[inline] 68 | fn div_exact(self, d: $t, pre: &PreModInv<$t>) -> Option { 69 | pre.debug_check(d); 70 | 71 | // this implementation comes from GNU factor, 72 | // see https://math.stackexchange.com/q/4436380/815652 for explanation 73 | 74 | let (n0, n1) = split(self); 75 | let q0 = n0.wrapping_mul(pre.d_inv); 76 | let nr0 = wmul(q0, d); 77 | let nr0 = split(nr0).1; 78 | if nr0 > n1 { 79 | return None; 80 | } 81 | let nr1 = n1 - nr0; 82 | let q1 = nr1.wrapping_mul(pre.d_inv); 83 | if q1 > pre.q_lim { 84 | return None; 85 | } 86 | Some(merge(q0, q1)) 87 | } 88 | } 89 | } 90 | }; 91 | } 92 | impl_preinv_for_prim_int!(u8, u8_impl); 93 | impl_preinv_for_prim_int!(u16, u16_impl); 94 | impl_preinv_for_prim_int!(u32, u32_impl); 95 | impl_preinv_for_prim_int!(u64, u64_impl); 96 | impl_preinv_for_prim_int!(usize, usize_impl); 97 | 98 | // XXX: unchecked div_exact can be introduced by not checking the q_lim, 99 | // investigate this after `exact_div` is introduced or removed from core lib 100 | // https://github.com/rust-lang/rust/issues/85122 101 | 102 | #[cfg(test)] 103 | mod tests { 104 | use super::*; 105 | use rand::random; 106 | 107 | #[test] 108 | fn div_exact_test() { 109 | const N: u8 = 100; 110 | for _ in 0..N { 111 | // u8 test 112 | let d = random::() | 1; 113 | let pre: PreModInv<_> = d.into(); 114 | 115 | let n: u8 = random(); 116 | let expect = if n % d == 0 { Some(n / d) } else { None }; 117 | assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d); 118 | let n: u16 = random(); 119 | let expect = if n % (d as u16) == 0 { 120 | Some(n / (d as u16)) 121 | } else { 122 | None 123 | }; 124 | assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d); 125 | 126 | // u16 test 127 | let d = random::() | 1; 128 | let pre: PreModInv<_> = d.into(); 129 | 130 | let n: u16 = random(); 131 | let expect = if n % d == 0 { Some(n / d) } else { None }; 132 | assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d); 133 | let n: u32 = random(); 134 | let expect = if n % (d as u32) == 0 { 135 | Some(n / (d as u32)) 136 | } else { 137 | None 138 | }; 139 | assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d); 140 | 141 | // u32 test 142 | let d = random::() | 1; 143 | let pre: PreModInv<_> = d.into(); 144 | 145 | let n: u32 = random(); 146 | let expect = if n % d == 0 { Some(n / d) } else { None }; 147 | assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d); 148 | let n: u64 = 
random(); 149 | let expect = if n % (d as u64) == 0 { 150 | Some(n / (d as u64)) 151 | } else { 152 | None 153 | }; 154 | assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d); 155 | 156 | // u64 test 157 | let d = random::() | 1; 158 | let pre: PreModInv<_> = d.into(); 159 | 160 | let n: u64 = random(); 161 | let expect = if n % d == 0 { Some(n / d) } else { None }; 162 | assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d); 163 | let n: u128 = random(); 164 | let expect = if n % (d as u128) == 0 { 165 | Some(n / (d as u128)) 166 | } else { 167 | None 168 | }; 169 | assert_eq!(n.div_exact(d, &pre), expect, "{} / {}", n, d); 170 | } 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /src/prim.rs: -------------------------------------------------------------------------------- 1 | //! Implementations for modular operations on primitive integers 2 | 3 | use crate::{udouble, Reducer, Vanilla}; 4 | use crate::{DivExact, ModularAbs, ModularCoreOps, ModularPow, ModularSymbols, ModularUnaryOps}; 5 | 6 | // FIXME: implement the modular functions as const after https://github.com/rust-lang/rust/pull/68847 7 | 8 | macro_rules! impl_core_ops_uu { 9 | ($($T:ty => $Tdouble:ty;)*) => ($( 10 | impl ModularCoreOps<$T, &$T> for $T { 11 | type Output = $T; 12 | #[inline(always)] 13 | fn addm(self, rhs: $T, m: &$T) -> $T { 14 | (((self as $Tdouble) + (rhs as $Tdouble)) % (*m as $Tdouble)) as $T 15 | } 16 | #[inline] 17 | fn subm(self, rhs: $T, m: &$T) -> $T { 18 | if self >= rhs { 19 | (self - rhs) % m 20 | } else { 21 | ((rhs - self) % m).negm(m) 22 | } 23 | } 24 | #[inline(always)] 25 | fn mulm(self, rhs: $T, m: &$T) -> $T { 26 | (((self as $Tdouble) * (rhs as $Tdouble)) % (*m as $Tdouble)) as $T 27 | } 28 | } 29 | )*); 30 | } 31 | impl_core_ops_uu! { u8 => u16; u16 => u32; u32 => u64; u64 => u128; } 32 | 33 | #[cfg(target_pointer_width = "16")] 34 | impl_core_ops_uu! { usize => u32; } 35 | #[cfg(target_pointer_width = "32")] 36 | impl_core_ops_uu! { usize => u64; } 37 | #[cfg(target_pointer_width = "64")] 38 | impl_core_ops_uu! { usize => u128; } 39 | 40 | impl ModularCoreOps for u128 { 41 | type Output = u128; 42 | 43 | #[inline] 44 | fn addm(self, rhs: u128, m: &u128) -> u128 { 45 | if let Some(ab) = self.checked_add(rhs) { 46 | ab % m 47 | } else { 48 | udouble::widening_add(self, rhs) % *m 49 | } 50 | } 51 | 52 | #[inline] 53 | fn subm(self, rhs: u128, m: &u128) -> u128 { 54 | if self >= rhs { 55 | (self - rhs) % m 56 | } else { 57 | ((rhs - self) % m).negm(m) 58 | } 59 | } 60 | 61 | #[inline] 62 | fn mulm(self, rhs: u128, m: &u128) -> u128 { 63 | if let Some(ab) = self.checked_mul(rhs) { 64 | ab % m 65 | } else { 66 | udouble::widening_mul(self, rhs) % *m 67 | } 68 | } 69 | } 70 | 71 | macro_rules! impl_powm_uprim { 72 | ($($T:ty)*) => ($( 73 | impl ModularPow<$T, &$T> for $T { 74 | type Output = $T; 75 | #[inline(always)] 76 | fn powm(self, exp: $T, m: &$T) -> $T { 77 | Vanilla::<$T>::new(&m).pow(self % m, &exp) 78 | } 79 | } 80 | )*); 81 | } 82 | impl_powm_uprim!(u8 u16 u32 u64 u128 usize); 83 | 84 | macro_rules! 
impl_symbols_uprim { 85 | ($($T:ty)*) => ($( 86 | impl ModularSymbols<&$T> for $T { 87 | #[inline] 88 | fn checked_legendre(&self, n: &$T) -> Option { 89 | match self.powm((n - 1)/2, &n) { 90 | 0 => Some(0), 91 | 1 => Some(1), 92 | x if x == n - 1 => Some(-1), 93 | _ => None, 94 | } 95 | } 96 | 97 | fn checked_jacobi(&self, n: &$T) -> Option { 98 | if n % 2 == 0 { 99 | return None; 100 | } 101 | if self == &0 { 102 | return Some(if n == &1 { 103 | 1 104 | } else { 105 | 0 106 | }); 107 | } 108 | if self == &1 { 109 | return Some(1); 110 | } 111 | 112 | let mut a = self % n; 113 | let mut n = *n; 114 | let mut t = 1; 115 | while a > 0 { 116 | while a % 2 == 0 { 117 | a /= 2; 118 | if n % 8 == 3 || n % 8 == 5 { 119 | t *= -1; 120 | } 121 | } 122 | core::mem::swap(&mut a, &mut n); 123 | if a % 4 == 3 && n % 4 == 3 { 124 | t *= -1; 125 | } 126 | a %= n; 127 | } 128 | Some(if n == 1 { 129 | t 130 | } else { 131 | 0 132 | }) 133 | } 134 | 135 | fn kronecker(&self, n: &$T) -> i8 { 136 | match n { 137 | 0 => { 138 | if self == &1 { 139 | 1 140 | } else { 141 | 0 142 | } 143 | } 144 | 1 => 1, 145 | 2 => { 146 | if self % 2 == 0 { 147 | 0 148 | } else if self % 8 == 1 || self % 8 == 7 { 149 | 1 150 | } else { 151 | -1 152 | } 153 | } 154 | _ => { 155 | let f = n.trailing_zeros(); 156 | let n = n >> f; 157 | self.kronecker(&2).pow(f) 158 | * self.jacobi(&n) 159 | } 160 | } 161 | } 162 | } 163 | )*); 164 | } 165 | impl_symbols_uprim!(u8 u16 u32 u64 u128 usize); 166 | 167 | macro_rules! impl_symbols_iprim { 168 | ($($T:ty, $U:ty;)*) => ($( 169 | impl ModularSymbols<&$T> for $T { 170 | #[inline] 171 | fn checked_legendre(&self, n: &$T) -> Option { 172 | if n < &1 { 173 | return None; 174 | } 175 | let a = self.rem_euclid(*n) as $U; 176 | a.checked_legendre(&(*n as $U)) 177 | } 178 | 179 | #[inline] 180 | fn checked_jacobi(&self, n: &$T) -> Option { 181 | if n < &1 { 182 | return None; 183 | } 184 | let a = self.rem_euclid(*n) as $U; 185 | a.checked_jacobi(&(*n as $U)) 186 | } 187 | 188 | #[inline] 189 | fn kronecker(&self, n: &$T) -> i8 { 190 | match n { 191 | -1 => { 192 | if self < &0 { 193 | -1 194 | } else { 195 | 1 196 | } 197 | } 198 | 0 => { 199 | if self == &1 { 200 | 1 201 | } else { 202 | 0 203 | } 204 | } 205 | 1 => 1, 206 | 2 => { 207 | if self % 2 == 0 { 208 | 0 209 | } else if self.rem_euclid(8) == 1 || self.rem_euclid(8) == 7 { 210 | 1 211 | } else { 212 | -1 213 | } 214 | }, 215 | i if i < &-1 => { 216 | self.kronecker(&-1) * self.kronecker(&-i) 217 | }, 218 | _ => { 219 | let f = n.trailing_zeros(); 220 | self.kronecker(&2).pow(f) 221 | * self.jacobi(&(n >> f)) 222 | } 223 | } 224 | } 225 | } 226 | )*); 227 | } 228 | 229 | impl_symbols_iprim!(i8, u8; i16, u16; i32, u32; i64, u64; i128, u128; isize, usize;); 230 | 231 | macro_rules! 
impl_unary_uprim { 232 | ($($T:ty)*) => ($( 233 | impl ModularUnaryOps<&$T> for $T { 234 | type Output = $T; 235 | #[inline] 236 | fn negm(self, m: &$T) -> $T { 237 | let x = self % m; 238 | if x == 0 { 239 | 0 240 | } else { 241 | m - x 242 | } 243 | } 244 | 245 | // inverse mod using extended euclidean algorithm 246 | fn invm(self, m: &$T) -> Option<$T> { 247 | // TODO: optimize using https://eprint.iacr.org/2020/972.pdf 248 | let x = if &self >= m { self % m } else { self.clone() }; 249 | 250 | let (mut last_r, mut r) = (m.clone(), x); 251 | let (mut last_t, mut t) = (0, 1); 252 | 253 | while r > 0 { 254 | let (quo, rem) = (last_r / r, last_r % r); 255 | last_r = r; 256 | r = rem; 257 | 258 | let new_t = last_t.subm(quo.mulm(t, m), m); 259 | last_t = t; 260 | t = new_t; 261 | } 262 | 263 | // if r = gcd(self, m) > 1, then inverse doesn't exist 264 | if last_r > 1 { 265 | None 266 | } else { 267 | Some(last_t) 268 | } 269 | } 270 | 271 | #[inline(always)] 272 | fn dblm(self, m: &$T) -> $T { 273 | self.addm(self, m) 274 | } 275 | #[inline(always)] 276 | fn sqm(self, m: &$T) -> $T { 277 | self.mulm(self, m) 278 | } 279 | } 280 | )*); 281 | } 282 | impl_unary_uprim!(u8 u16 u32 u64 u128 usize); 283 | 284 | // forward modular operations to valye by value 285 | macro_rules! impl_mod_ops_by_deref { 286 | ($($T:ty)*) => {$( 287 | // core ops 288 | impl ModularCoreOps<$T, &$T> for &$T { 289 | type Output = $T; 290 | #[inline] 291 | fn addm(self, rhs: $T, m: &$T) -> $T { 292 | (*self).addm(rhs, &m) 293 | } 294 | #[inline] 295 | fn subm(self, rhs: $T, m: &$T) -> $T { 296 | (*self).subm(rhs, &m) 297 | } 298 | #[inline] 299 | fn mulm(self, rhs: $T, m: &$T) -> $T { 300 | (*self).mulm(rhs, &m) 301 | } 302 | } 303 | impl ModularCoreOps<&$T, &$T> for $T { 304 | type Output = $T; 305 | #[inline] 306 | fn addm(self, rhs: &$T, m: &$T) -> $T { 307 | self.addm(*rhs, &m) 308 | } 309 | #[inline] 310 | fn subm(self, rhs: &$T, m: &$T) -> $T { 311 | self.subm(*rhs, &m) 312 | } 313 | #[inline] 314 | fn mulm(self, rhs: &$T, m: &$T) -> $T { 315 | self.mulm(*rhs, &m) 316 | } 317 | } 318 | impl ModularCoreOps<&$T, &$T> for &$T { 319 | type Output = $T; 320 | #[inline] 321 | fn addm(self, rhs: &$T, m: &$T) -> $T { 322 | (*self).addm(*rhs, &m) 323 | } 324 | #[inline] 325 | fn subm(self, rhs: &$T, m: &$T) -> $T { 326 | (*self).subm(*rhs, &m) 327 | } 328 | #[inline] 329 | fn mulm(self, rhs: &$T, m: &$T) -> $T { 330 | (*self).mulm(*rhs, &m) 331 | } 332 | } 333 | 334 | // pow 335 | impl ModularPow<$T, &$T> for &$T { 336 | type Output = $T; 337 | #[inline] 338 | fn powm(self, exp: $T, m: &$T) -> $T { 339 | (*self).powm(exp, &m) 340 | } 341 | } 342 | impl ModularPow<&$T, &$T> for $T { 343 | type Output = $T; 344 | #[inline] 345 | fn powm(self, exp: &$T, m: &$T) -> $T { 346 | self.powm(*exp, &m) 347 | } 348 | } 349 | impl ModularPow<&$T, &$T> for &$T { 350 | type Output = $T; 351 | #[inline] 352 | fn powm(self, exp: &$T, m: &$T) -> $T { 353 | (*self).powm(*exp, &m) 354 | } 355 | } 356 | 357 | // unary ops 358 | impl ModularUnaryOps<&$T> for &$T { 359 | type Output = $T; 360 | 361 | #[inline] 362 | fn negm(self, m: &$T) -> $T { 363 | ModularUnaryOps::<&$T>::negm(*self, m) 364 | } 365 | #[inline] 366 | fn invm(self, m: &$T) -> Option<$T> { 367 | ModularUnaryOps::<&$T>::invm(*self, m) 368 | } 369 | #[inline] 370 | fn dblm(self, m: &$T) -> $T { 371 | ModularUnaryOps::<&$T>::dblm(*self, m) 372 | } 373 | #[inline] 374 | fn sqm(self, m: &$T) -> $T { 375 | ModularUnaryOps::<&$T>::sqm(*self, m) 376 | } 377 | } 378 | )*}; 379 | } 380 | 
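// A short sketch (hypothetical test, not from the crate) of what the forwarding
// impls instantiated by the invocation below buy: every value/reference
// combination of the operands resolves to the same by-value implementation.
#[test]
fn ref_forwarding_sketch() {
    let (a, b, m) = (17u64, 19u64, 23u64);
    assert_eq!(a.addm(b, &m), (&a).addm(&b, &m));
    assert_eq!(a.mulm(&b, &m), (&a).mulm(b, &m));
}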
impl_mod_ops_by_deref!(u8 u16 u32 u64 u128 usize);

macro_rules! impl_absm_for_prim {
    ($($signed:ty => $unsigned:ty;)*) => {$(
        impl ModularAbs<$unsigned> for $signed {
            fn absm(self, m: &$unsigned) -> $unsigned {
                if self >= 0 {
                    (self as $unsigned) % m
                } else {
                    (-self as $unsigned).negm(m)
                }
            }
        }
    )*};
}

impl_absm_for_prim! {
    i8 => u8; i16 => u16; i32 => u32; i64 => u64; i128 => u128; isize => usize;
}

macro_rules! impl_div_exact_for_prim {
    ($($t:ty)*) => {$(
        impl DivExact<$t, ()> for $t {
            type Output = $t;
            #[inline]
            fn div_exact(self, d: $t, _: &()) -> Option<$t> {
                let (q, r) = (self / d, self % d);
                if r == 0 {
                    Some(q)
                } else {
                    None
                }
            }
        }
    )*};
}

impl_div_exact_for_prim!(u8 u16 u32 u64 u128);

#[cfg(test)]
mod tests {
    use super::*;
    use core::ops::Neg;
    use rand::random;

    const NRANDOM: u32 = 10; // number of random tests to run

    #[test]
    fn addm_test() {
        // fixed cases
        const CASES: [(u8, u8, u8, u8); 10] = [
            // [m, x, y, rem]: x + y = rem (mod m)
            (5, 0, 0, 0),
            (5, 1, 2, 3),
            (5, 2, 1, 3),
            (5, 2, 2, 4),
            (5, 3, 2, 0),
            (5, 2, 3, 0),
            (5, 6, 1, 2),
            (5, 1, 6, 2),
            (5, 11, 7, 3),
            (5, 7, 11, 3),
        ];

        for &(m, x, y, r) in CASES.iter() {
            assert_eq!(x.addm(y, &m), r);
            assert_eq!((x as u16).addm(y as u16, &(m as _)), r as _);
            assert_eq!((x as u32).addm(y as u32, &(m as _)), r as _);
            assert_eq!((x as u64).addm(y as u64, &(m as _)), r as _);
            assert_eq!((x as u128).addm(y as u128, &(m as _)), r as _);
        }

        // random cases for u64 and u128
        for _ in 0..NRANDOM {
            let a = random::<u32>() as u64;
            let b = random::<u32>() as u64;
            let m = random::<u32>() as u64;
            assert_eq!(a.addm(b, &m), (a + b) % m);
            assert_eq!(
                a.addm(b, &(1u64 << 32)) as u32,
                (a as u32).wrapping_add(b as u32)
            );

            let a = random::<u64>() as u128;
            let b = random::<u64>() as u128;
            let m = random::<u64>() as u128;
            assert_eq!(a.addm(b, &m), (a + b) % m);
            assert_eq!(
                a.addm(b, &(1u128 << 64)) as u64,
                (a as u64).wrapping_add(b as u64)
            );
        }
    }
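
    // Added illustration (not from the original suite): arithmetic modulo
    // `1 << 32` coincides with wrapping u32 arithmetic, which is the identity
    // the random cases above rely on.
    #[test]
    fn addm_wrapping_example() {
        let (a, b) = (u32::MAX as u64, 5u64);
        assert_eq!(a.addm(b, &(1u64 << 32)) as u32, (a as u32).wrapping_add(5));
    }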

    #[test]
    fn subm_test() {
        // fixed cases
        const CASES: [(u8, u8, u8, u8); 10] = [
            // [m, x, y, rem]: x - y = rem (mod m)
            (7, 0, 0, 0),
            (7, 11, 9, 2),
            (7, 5, 2, 3),
            (7, 2, 5, 4),
            (7, 6, 7, 6),
            (7, 1, 7, 1),
            (7, 7, 1, 6),
            (7, 0, 6, 1),
            (7, 15, 1, 0),
            (7, 1, 15, 0),
        ];

        for &(m, x, y, r) in CASES.iter() {
            assert_eq!(x.subm(y, &m), r);
            assert_eq!((x as u16).subm(y as u16, &(m as _)), r as _);
            assert_eq!((x as u32).subm(y as u32, &(m as _)), r as _);
            assert_eq!((x as u64).subm(y as u64, &(m as _)), r as _);
            assert_eq!((x as u128).subm(y as u128, &(m as _)), r as _);
        }

        // random cases for u64 and u128
        for _ in 0..NRANDOM {
            let a = random::<u32>() as u64;
            let b = random::<u32>() as u64;
            let m = random::<u32>() as u64;
            assert_eq!(
                a.subm(b, &m),
                (a as i64 - b as i64).rem_euclid(m as i64) as u64
            );
            assert_eq!(
                a.subm(b, &(1u64 << 32)) as u32,
                (a as u32).wrapping_sub(b as u32)
            );

            let a = random::<u64>() as u128;
            let b = random::<u64>() as u128;
            let m = random::<u64>() as u128;
            assert_eq!(
                a.subm(b, &m),
                (a as i128 - b as i128).rem_euclid(m as i128) as u128
            );
            assert_eq!(
                a.subm(b, &(1u128 << 64)) as u64,
                (a as u64).wrapping_sub(b as u64)
            );
        }
    }

    #[test]
    fn negm_and_absm_test() {
        // fixed cases
        const CASES: [(u8, u8, u8); 5] = [
            // [m, x, rem]: -x = rem (mod m)
            (5, 0, 0),
            (5, 2, 3),
            (5, 1, 4),
            (5, 5, 0),
            (5, 12, 3),
        ];

        for &(m, x, r) in CASES.iter() {
            assert_eq!(x.negm(&m), r);
            assert_eq!((x as i8).neg().absm(&m), r);
            assert_eq!((x as u16).negm(&(m as _)), r as _);
            assert_eq!((x as i16).neg().absm(&(m as u16)), r as _);
            assert_eq!((x as u32).negm(&(m as _)), r as _);
            assert_eq!((x as i32).neg().absm(&(m as u32)), r as _);
            assert_eq!((x as u64).negm(&(m as _)), r as _);
            assert_eq!((x as i64).neg().absm(&(m as u64)), r as _);
            assert_eq!((x as u128).negm(&(m as _)), r as _);
            assert_eq!((x as i128).neg().absm(&(m as u128)), r as _);
        }

        // random cases for u64 and u128
        for _ in 0..NRANDOM {
            let a = random::<u32>() as u64;
            let m = random::<u32>() as u64;
            assert_eq!(a.negm(&m), (a as i64).neg().rem_euclid(m as i64) as u64);
            assert_eq!(a.negm(&(1u64 << 32)) as u32, (a as u32).wrapping_neg());

            let a = random::<u64>() as u128;
            let m = random::<u64>() as u128;
            assert_eq!(a.negm(&m), (a as i128).neg().rem_euclid(m as i128) as u128);
            assert_eq!(a.negm(&(1u128 << 64)) as u64, (a as u64).wrapping_neg());
        }
    }

    #[test]
    fn mulm_test() {
        // fixed cases
        const CASES: [(u8, u8, u8, u8); 10] = [
            // [m, x, y, rem]: x*y = rem (mod m)
            (7, 0, 0, 0),
            (7, 11, 9, 1),
            (7, 5, 2, 3),
            (7, 2, 5, 3),
            (7, 6, 7, 0),
            (7, 1, 7, 0),
            (7, 7, 1, 0),
            (7, 0, 6, 0),
            (7, 15, 1, 1),
            (7, 1, 15, 1),
        ];

        for &(m, x, y, r) in CASES.iter() {
            assert_eq!(x.mulm(y, &m), r);
            assert_eq!((x as u16).mulm(y as u16, &(m as _)), r as _);
            assert_eq!((x as u32).mulm(y as u32, &(m as _)), r as _);
            assert_eq!((x as u64).mulm(y as u64, &(m as _)), r as _);
            assert_eq!((x as u128).mulm(y as u128, &(m as _)), r as _);
        }

        // random cases for u64 and u128
        for _ in 0..NRANDOM {
            let a = random::<u32>() as u64;
            let b = random::<u32>() as u64;
            let m = random::<u32>() as u64;
            assert_eq!(a.mulm(b, &m), (a * b) % m);
            assert_eq!(
                a.mulm(b, &(1u64 << 32)) as u32,
                (a as u32).wrapping_mul(b as u32)
            );

            let a = random::<u64>() as u128;
            let b = random::<u64>() as u128;
            let m = random::<u64>() as u128;
            assert_eq!(a.mulm(b, &m), (a * b) % m);
            assert_eq!(
                a.mulm(b, &(1u128 << 32)) as u32,
                (a as u32).wrapping_mul(b as u32)
            );
        }
    }
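
    // Added illustration (not from the original suite): Fermat's little
    // theorem, a^(p-1) = 1 (mod p) for a prime p and any base coprime to it,
    // exercised here through powm.
    #[test]
    fn powm_fermat_example() {
        let p = 101u64; // an odd prime
        for a in 1..p {
            assert_eq!(a.powm(p - 1, &p), 1);
        }
    }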

    #[test]
    fn powm_test() {
        // fixed cases
        const CASES: [(u8, u8, u8, u8); 12] = [
            // [m, x, y, rem]: x^y = rem (mod m)
            (7, 0, 0, 1),
            (7, 11, 9, 1),
            (7, 5, 2, 4),
            (7, 2, 5, 4),
            (7, 6, 7, 6),
            (7, 1, 7, 1),
            (7, 7, 1, 0),
            (7, 0, 6, 0),
            (7, 15, 1, 1),
            (7, 1, 15, 1),
            (7, 255, 255, 6),
            (10, 255, 255, 5),
        ];

        for &(m, x, y, r) in CASES.iter() {
            assert_eq!(x.powm(y, &m), r);
            assert_eq!((x as u16).powm(y as u16, &(m as _)), r as _);
            assert_eq!((x as u32).powm(y as u32, &(m as _)), r as _);
            assert_eq!((x as u64).powm(y as u64, &(m as _)), r as _);
            assert_eq!((x as u128).powm(y as u128, &(m as _)), r as _);
        }
    }

    #[test]
    fn invm_test() {
        // fixed cases
        const CASES: [(u64, u64, u64); 8] = [
            // [a, m, x] s.t. a*x = 1 (mod m) is satisfied
            (5, 11, 9),
            (8, 11, 7),
            (10, 11, 10),
            (3, 5000, 1667),
            (1667, 5000, 3),
            (999, 5000, 3999),
            (999, 9_223_372_036_854_775_807, 3_619_181_019_466_538_655),
            (
                9_223_372_036_854_775_804,
                9_223_372_036_854_775_807,
                3_074_457_345_618_258_602,
            ),
        ];

        for &(a, m, x) in CASES.iter() {
            assert_eq!(a.invm(&m).unwrap(), x);
        }

        // random cases for u64 and u128
        for _ in 0..NRANDOM {
            let a = random::<u32>() as u64;
            let m = random::<u32>() as u64;
            if let Some(ia) = a.invm(&m) {
                assert_eq!(a.mulm(ia, &m), 1);
            }

            let a = random::<u64>() as u128;
            let m = random::<u64>() as u128;
            if let Some(ia) = a.invm(&m) {
                assert_eq!(a.mulm(ia, &m), 1);
            }
        }
    }

    #[test]
    fn dblm_and_sqm_test() {
        // random cases for u64 and u128
        for _ in 0..NRANDOM {
            let a = random::<u64>();
            let m = random::<u64>();
            assert_eq!(a.addm(a, &m), a.dblm(&m));
            assert_eq!(a.mulm(2, &m), a.dblm(&m));
            assert_eq!(a.mulm(a, &m), a.sqm(&m));
            assert_eq!(a.powm(2, &m), a.sqm(&m));

            let a = random::<u128>();
            let m = random::<u128>();
            assert_eq!(a.addm(a, &m), a.dblm(&m));
            assert_eq!(a.mulm(2, &m), a.dblm(&m));
            assert_eq!(a.mulm(a, &m), a.sqm(&m));
            assert_eq!(a.powm(2, &m), a.sqm(&m));
        }
    }

    #[test]
    fn legendre_test() {
        const CASES: [(u8, u8, i8); 18] = [
            (0, 11, 0),
            (1, 11, 1),
            (2, 11, -1),
            (4, 11, 1),
            (7, 11, -1),
            (10, 11, -1),
            (0, 17, 0),
            (1, 17, 1),
            (2, 17, 1),
            (4, 17, 1),
            (9, 17, 1),
            (10, 17, -1),
            (0, 101, 0),
            (1, 101, 1),
            (2, 101, -1),
            (4, 101, 1),
            (9, 101, 1),
            (10, 101, -1),
        ];

        for &(a, n, res) in CASES.iter() {
            assert_eq!(a.legendre(&n), res);
            assert_eq!((a as u16).legendre(&(n as u16)), res);
            assert_eq!((a as u32).legendre(&(n as u32)), res);
            assert_eq!((a as u64).legendre(&(n as u64)), res);
            assert_eq!((a as u128).legendre(&(n as u128)), res);
        }

        const SIGNED_CASES: [(i8, i8, i8); 15] = [
            (-10, 11, 1),
            (-7, 11, 1),
            (-4, 11, -1),
            (-2, 11, 1),
            (-1, 11, -1),
            (-10, 17, -1),
            (-9, 17, 1),
            (-4, 17, 1),
            (-2, 17, 1),
            (-1, 17, 1),
            (-10, 101, -1),
            (-9, 101, 1),
            (-4, 101, 1),
            (-2, 101, -1),
            (-1, 101, 1),
        ];

        for &(a, n, res) in SIGNED_CASES.iter() {
            assert_eq!(a.legendre(&n), res);
            assert_eq!((a as i16).legendre(&(n as i16)), res);
            assert_eq!((a as i32).legendre(&(n as i32)), res);
            assert_eq!((a as i64).legendre(&(n as i64)), res);
            assert_eq!((a as i128).legendre(&(n as i128)), res);
        }
    }
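
    // Added illustration (not from the original suite): for an odd prime
    // modulus the Jacobi symbol reduces to the Legendre symbol, so the two
    // implementations must agree there.
    #[test]
    fn legendre_jacobi_agree_example() {
        let p = 31u8; // an odd prime
        for a in 0..p {
            assert_eq!(a.legendre(&p), a.jacobi(&p));
        }
    }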

    #[test]
    fn jacobi_test() {
        const CASES: [(u8, u8, i8); 15] = [
            (1, 1, 1),
            (15, 1, 1),
            (2, 3, -1),
            (29, 9, 1),
            (4, 11, 1),
            (17, 11, -1),
            (19, 29, -1),
            (10, 33, -1),
            (11, 33, 0),
            (12, 33, 0),
            (14, 33, -1),
            (15, 33, 0),
            (15, 37, -1),
            (29, 59, 1),
            (30, 59, -1),
        ];

        for &(a, n, res) in CASES.iter() {
            assert_eq!(a.jacobi(&n), res, "{}, {}", a, n);
            assert_eq!((a as u16).jacobi(&(n as u16)), res);
            assert_eq!((a as u32).jacobi(&(n as u32)), res);
            assert_eq!((a as u64).jacobi(&(n as u64)), res);
            assert_eq!((a as u128).jacobi(&(n as u128)), res);
        }

        const SIGNED_CASES: [(i8, i8, i8); 15] = [
            (-10, 15, 0),
            (-7, 15, 1),
            (-4, 15, -1),
            (-2, 15, -1),
            (-1, 15, -1),
            (-10, 13, 1),
            (-9, 13, 1),
            (-4, 13, 1),
            (-2, 13, -1),
            (-1, 13, 1),
            (-10, 11, 1),
            (-9, 11, -1),
            (-4, 11, -1),
            (-2, 11, 1),
            (-1, 11, -1),
        ];

        for &(a, n, res) in SIGNED_CASES.iter() {
            assert_eq!(a.jacobi(&n), res);
            assert_eq!((a as i16).jacobi(&(n as i16)), res);
            assert_eq!((a as i32).jacobi(&(n as i32)), res);
            assert_eq!((a as i64).jacobi(&(n as i64)), res);
            assert_eq!((a as i128).jacobi(&(n as i128)), res);
        }
    }

    #[test]
    fn kronecker_test() {
        const CASES: [(u8, u8, i8); 18] = [
            (0, 15, 0),
            (1, 15, 1),
            (2, 15, 1),
            (4, 15, 1),
            (7, 15, -1),
            (10, 15, 0),
            (0, 14, 0),
            (1, 14, 1),
            (2, 14, 0),
            (4, 14, 0),
            (9, 14, 1),
            (10, 14, 0),
            (0, 11, 0),
            (1, 11, 1),
            (2, 11, -1),
            (4, 11, 1),
            (9, 11, 1),
            (10, 11, -1),
        ];

        for &(a, n, res) in CASES.iter() {
            assert_eq!(a.kronecker(&n), res);
            assert_eq!((a as u16).kronecker(&(n as u16)), res);
            assert_eq!((a as u32).kronecker(&(n as u32)), res);
            assert_eq!((a as u64).kronecker(&(n as u64)), res);
            assert_eq!((a as u128).kronecker(&(n as u128)), res);
        }

        const SIGNED_CASES: [(i8, i8, i8); 37] = [
            (-10, 15, 0),
            (-7, 15, 1),
            (-4, 15, -1),
            (-2, 15, -1),
            (-1, 15, -1),
            (-10, 14, 0),
            (-9, 14, -1),
            (-4, 14, 0),
            (-2, 14, 0),
            (-1, 14, -1),
            (-10, 11, 1),
            (-9, 11, -1),
            (-4, 11, -1),
            (-2, 11, 1),
            (-1, 11, -1),
            (-10, -11, -1),
            (-9, -11, 1),
            (-4, -11, 1),
            (-2, -11, -1),
            (-1, -11, 1),
            (0, -11, 0),
            (1, -11, 1),
            (2, -11, -1),
            (4, -11, 1),
            (9, -11, 1),
            (10, -11, -1),
            (-10, 32, 0),
            (-9, 32, 1),
            (-4, 32, 0),
            (-2, 32, 0),
            (-1, 32, 1),
            (0, 32, 0),
            (1, 32, 1),
            (2, 32, 0),
            (4, 32, 0),
            (9, 32, 1),
            (10, 32, 0),
        ];

        for &(a, n, res) in SIGNED_CASES.iter() {
            assert_eq!(a.kronecker(&n), res, "{}, {}", a, n);
            assert_eq!((a as i16).kronecker(&(n as i16)), res);
            assert_eq!((a as i32).kronecker(&(n as i32)), res);
            assert_eq!((a as i64).kronecker(&(n as i64)), res);
            assert_eq!((a as i128).kronecker(&(n as i128)), res);
        }
    }
}
--------------------------------------------------------------------------------
/src/reduced.rs:
--------------------------------------------------------------------------------
use crate::{udouble, ModularInteger, ModularUnaryOps, Reducer};
use core::ops::*;
#[cfg(feature = "num-traits")]
use num_traits::{Inv, Pow};
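
// Usage sketch (added, not from the original source), assuming the `Vanilla`
// reducer defined at the bottom of this file:
//
//     let x = ReducedInt::<u64, Vanilla<u64>>::new(7, &13);
//     let y = x * x; // 7 * 7 = 49 = 10 (mod 13)
//     assert_eq!(y.residue(), 10);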

/// An integer in a modulo ring
#[derive(Debug, Clone, Copy)]
pub struct ReducedInt<T, R: Reducer<T>> {
    /// The reduced representation of the integer in a modulo ring
    a: T,

    /// The reducer for the integer
    r: R,
}

impl<T, R: Reducer<T>> ReducedInt<T, R> {
    /// Convert n into the modulo ring ℤ/mℤ (i.e. `n % m`)
    #[inline]
    pub fn new(n: T, m: &T) -> Self {
        let r = R::new(m);
        let a = r.transform(n);
        Self { a, r }
    }

    #[inline(always)]
    fn check_modulus_eq(&self, rhs: &Self)
    where
        T: PartialEq,
    {
        // the moduli are compared through `modulus()` because the reducer itself
        // might not store m directly (e.g. for a Mersenne modular integer)
        if cfg!(debug_assertions) && self.r.modulus() != rhs.r.modulus() {
            panic!("The modulus of the two operands should be the same!");
        }
    }

    #[inline(always)]
    pub fn repr(&self) -> &T {
        &self.a
    }

    #[inline(always)]
    pub fn inv(self) -> Option<Self> {
        Some(Self {
            a: self.r.inv(self.a)?,
            r: self.r,
        })
    }

    #[inline(always)]
    pub fn pow(self, exp: &T) -> Self {
        Self {
            a: self.r.pow(self.a, exp),
            r: self.r,
        }
    }
}

impl<T: PartialEq, R: Reducer<T>> PartialEq for ReducedInt<T, R> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.check_modulus_eq(other);
        self.a == other.a
    }
}

macro_rules! impl_binops {
    ($method:ident, impl $op:ident) => {
        impl<T: PartialEq, R: Reducer<T>> $op for ReducedInt<T, R> {
            type Output = Self;
            fn $method(self, rhs: Self) -> Self::Output {
                self.check_modulus_eq(&rhs);
                let Self { a, r } = self;
                let a = r.$method(&a, &rhs.a);
                Self { a, r }
            }
        }

        impl<T: PartialEq, R: Reducer<T>> $op<&Self> for ReducedInt<T, R> {
            type Output = Self;
            #[inline]
            fn $method(self, rhs: &Self) -> Self::Output {
                self.check_modulus_eq(&rhs);
                let Self { a, r } = self;
                let a = r.$method(&a, &rhs.a);
                Self { a, r }
            }
        }

        impl<T: PartialEq, R: Reducer<T>> $op<ReducedInt<T, R>> for &ReducedInt<T, R> {
            type Output = ReducedInt<T, R>;
            #[inline]
            fn $method(self, rhs: ReducedInt<T, R>) -> Self::Output {
                self.check_modulus_eq(&rhs);
                let ReducedInt { a, r } = rhs;
                let a = r.$method(&self.a, &a);
                ReducedInt { a, r }
            }
        }

        impl<T: PartialEq, R: Reducer<T> + Clone> $op<&ReducedInt<T, R>>
            for &ReducedInt<T, R>
        {
            type Output = ReducedInt<T, R>;
            #[inline]
            fn $method(self, rhs: &ReducedInt<T, R>) -> Self::Output {
                self.check_modulus_eq(&rhs);
                let a = self.r.$method(&self.a, &rhs.a);
                ReducedInt {
                    a,
                    r: self.r.clone(),
                }
            }
        }

        impl<T: PartialEq, R: Reducer<T>> $op<T> for ReducedInt<T, R> {
            type Output = Self;
            fn $method(self, rhs: T) -> Self::Output {
                let Self { a, r } = self;
                let rhs = r.transform(rhs);
                let a = r.$method(&a, &rhs);
                Self { a, r }
            }
        }
    };
}
impl_binops!(add, impl Add);
impl_binops!(sub, impl Sub);
impl_binops!(mul, impl Mul);

impl<T, R: Reducer<T>> Neg for ReducedInt<T, R> {
    type Output = Self;
    #[inline]
    fn neg(self) -> Self::Output {
        let Self { a, r } = self;
        let a = r.neg(a);
        Self { a, r }
    }
}
impl<T: Clone, R: Reducer<T> + Clone> Neg for &ReducedInt<T, R> {
    type Output = ReducedInt<T, R>;
    #[inline]
    fn neg(self) -> Self::Output {
        let a = self.r.neg(self.a.clone());
        ReducedInt {
            a,
            r: self.r.clone(),
        }
    }
}

const INV_ERR_MSG: &str = "the modular inverse doesn't exist!";

#[cfg(feature = "num-traits")]
impl<T, R: Reducer<T>> Inv for ReducedInt<T, R> {
    type Output = Self;
    #[inline]
    fn inv(self) -> Self::Output {
        self.inv().expect(INV_ERR_MSG)
    }
}
#[cfg(feature = "num-traits")]
impl<T: Clone, R: Reducer<T> + Clone> Inv for &ReducedInt<T, R> {
    type Output = ReducedInt<T, R>;
    #[inline]
    fn inv(self) -> Self::Output {
        self.clone().inv().expect(INV_ERR_MSG)
    }
}
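
// Note (added): division below is implemented as multiplication by the modular
// inverse of the divisor; like `Inv`, it panics when that inverse does not
// exist (i.e. when gcd(rhs, modulus) > 1).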
"num-traits")] 162 | impl + Clone> Inv for &ReducedInt { 163 | type Output = ReducedInt; 164 | #[inline] 165 | fn inv(self) -> Self::Output { 166 | self.clone().inv().expect(INV_ERR_MSG) 167 | } 168 | } 169 | 170 | impl> Div for ReducedInt { 171 | type Output = Self; 172 | #[inline] 173 | fn div(self, rhs: Self) -> Self::Output { 174 | self.check_modulus_eq(&rhs); 175 | let ReducedInt { a, r } = rhs; 176 | let a = r.mul(&self.a, &r.inv(a).expect(INV_ERR_MSG)); 177 | ReducedInt { a, r } 178 | } 179 | } 180 | impl> Div<&ReducedInt> for ReducedInt { 181 | type Output = Self; 182 | #[inline] 183 | fn div(self, rhs: &Self) -> Self::Output { 184 | self.check_modulus_eq(rhs); 185 | let Self { a, r } = self; 186 | let a = r.mul(&a, &r.inv(rhs.a.clone()).expect(INV_ERR_MSG)); 187 | ReducedInt { a, r } 188 | } 189 | } 190 | impl> Div> for &ReducedInt { 191 | type Output = ReducedInt; 192 | #[inline] 193 | fn div(self, rhs: ReducedInt) -> Self::Output { 194 | self.check_modulus_eq(&rhs); 195 | let ReducedInt { a, r } = rhs; 196 | let a = r.mul(&self.a, &r.inv(a).expect(INV_ERR_MSG)); 197 | ReducedInt { a, r } 198 | } 199 | } 200 | impl + Clone> Div<&ReducedInt> for &ReducedInt { 201 | type Output = ReducedInt; 202 | #[inline] 203 | fn div(self, rhs: &ReducedInt) -> Self::Output { 204 | self.check_modulus_eq(rhs); 205 | let a = self 206 | .r 207 | .mul(&self.a, &self.r.inv(rhs.a.clone()).expect(INV_ERR_MSG)); 208 | ReducedInt { 209 | a, 210 | r: self.r.clone(), 211 | } 212 | } 213 | } 214 | 215 | #[cfg(feature = "num-traits")] 216 | impl> Pow for ReducedInt { 217 | type Output = Self; 218 | #[inline] 219 | fn pow(self, rhs: T) -> Self::Output { 220 | ReducedInt::pow(self, &rhs) 221 | } 222 | } 223 | #[cfg(feature = "num-traits")] 224 | impl + Clone> Pow for &ReducedInt { 225 | type Output = ReducedInt; 226 | #[inline] 227 | fn pow(self, rhs: T) -> Self::Output { 228 | let a = self.r.pow(self.a.clone(), &rhs); 229 | ReducedInt { 230 | a, 231 | r: self.r.clone(), 232 | } 233 | } 234 | } 235 | 236 | impl + Clone> ModularInteger for ReducedInt { 237 | type Base = T; 238 | 239 | #[inline] 240 | fn modulus(&self) -> T { 241 | self.r.modulus() 242 | } 243 | 244 | #[inline(always)] 245 | fn residue(&self) -> T { 246 | debug_assert!(self.r.check(&self.a)); 247 | self.r.residue(self.a.clone()) 248 | } 249 | 250 | #[inline(always)] 251 | fn is_zero(&self) -> bool { 252 | self.r.is_zero(&self.a) 253 | } 254 | 255 | #[inline] 256 | fn convert(&self, n: T) -> Self { 257 | Self { 258 | a: self.r.transform(n), 259 | r: self.r.clone(), 260 | } 261 | } 262 | 263 | #[inline] 264 | fn double(self) -> Self { 265 | let Self { a, r } = self; 266 | let a = r.dbl(a); 267 | Self { a, r } 268 | } 269 | 270 | #[inline] 271 | fn square(self) -> Self { 272 | let Self { a, r } = self; 273 | let a = r.sqr(a); 274 | Self { a, r } 275 | } 276 | } 277 | 278 | // An vanilla reducer is also provided here 279 | /// A plain reducer that just use normal [Rem] operators. It will keep the integer 280 | /// in range [0, modulus) after each operation. 281 | #[derive(Debug, Clone, Copy)] 282 | pub struct Vanilla(T); 283 | 284 | macro_rules! 

macro_rules! impl_uprim_vanilla_core_const {
    ($($T:ty)*) => {$(
        // These methods are for internal use only, pending the introduction of const traits in Rust
        impl Vanilla<$T> {
            #[inline]
            pub(crate) const fn add(m: &$T, lhs: $T, rhs: $T) -> $T {
                let (sum, overflow) = lhs.overflowing_add(rhs);
                if overflow || sum >= *m {
                    let (sum2, overflow2) = sum.overflowing_sub(*m);
                    debug_assert!(overflow == overflow2);
                    sum2
                } else {
                    sum
                }
            }

            #[inline]
            pub(crate) const fn dbl(m: &$T, target: $T) -> $T {
                Self::add(m, target, target)
            }

            #[inline]
            pub(crate) const fn sub(m: &$T, lhs: $T, rhs: $T) -> $T {
                // this implementation should be equivalent to using overflowing_add and _sub after optimization
                if lhs >= rhs {
                    lhs - rhs
                } else {
                    *m - (rhs - lhs)
                }
            }

            #[inline]
            pub(crate) const fn neg(m: &$T, target: $T) -> $T {
                match target {
                    0 => 0,
                    x => *m - x,
                }
            }
        }
    )*};
}
impl_uprim_vanilla_core_const!(u8 u16 u32 u64 u128 usize);

macro_rules! impl_reduced_binary_pow {
    ($T:ty) => {
        fn pow(&self, base: $T, exp: &$T) -> $T {
            match *exp {
                1 => base,
                2 => self.sqr(base),
                e => {
                    // standard square-and-multiply binary exponentiation
                    let mut multi = base;
                    let mut exp = e;
                    let mut result = self.transform(1);
                    while exp > 0 {
                        if exp & 1 != 0 {
                            result = self.mul(&result, &multi);
                        }
                        multi = self.sqr(multi);
                        exp >>= 1;
                    }
                    result
                }
            }
        }
    };
}

pub(crate) use impl_reduced_binary_pow;

macro_rules! impl_uprim_vanilla_core {
    ($single:ty) => {
        #[inline(always)]
        fn new(m: &$single) -> Self {
            assert!(m > &0);
            Self(*m)
        }
        #[inline(always)]
        fn transform(&self, target: $single) -> $single {
            target % self.0
        }
        #[inline(always)]
        fn check(&self, target: &$single) -> bool {
            *target < self.0
        }
        #[inline(always)]
        fn residue(&self, target: $single) -> $single {
            target
        }
        #[inline(always)]
        fn modulus(&self) -> $single {
            self.0
        }
        #[inline(always)]
        fn is_zero(&self, target: &$single) -> bool {
            *target == 0
        }

        #[inline(always)]
        fn add(&self, lhs: &$single, rhs: &$single) -> $single {
            Vanilla::<$single>::add(&self.0, *lhs, *rhs)
        }

        #[inline(always)]
        fn dbl(&self, target: $single) -> $single {
            Vanilla::<$single>::dbl(&self.0, target)
        }

        #[inline(always)]
        fn sub(&self, lhs: &$single, rhs: &$single) -> $single {
            Vanilla::<$single>::sub(&self.0, *lhs, *rhs)
        }

        #[inline(always)]
        fn neg(&self, target: $single) -> $single {
            Vanilla::<$single>::neg(&self.0, target)
        }

        #[inline(always)]
        fn inv(&self, target: $single) -> Option<$single> {
            target.invm(&self.0)
        }

        impl_reduced_binary_pow!($single);
    };
}
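
// Note (added): the width-specific impls below only need to supply `mul` and
// `sqr`: each widens the operands to a double-width word (see `src/word.rs`),
// multiplies without overflow, then reduces once. Expanded for u8, the body is
// roughly (illustrative only):
//
//     fn mul(&self, lhs: &u8, rhs: &u8) -> u8 {
//         ((*lhs as u16 * *rhs as u16) % self.0 as u16) as u8
//     }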

macro_rules! impl_uprim_vanilla {
    ($t:ident, $ns:ident) => {
        mod $ns {
            use super::*;
            use crate::word::$t::*;

            impl Reducer<$t> for Vanilla<$t> {
                impl_uprim_vanilla_core!($t);

                #[inline]
                fn mul(&self, lhs: &$t, rhs: &$t) -> $t {
                    (wmul(*lhs, *rhs) % extend(self.0)) as $t
                }

                #[inline]
                fn sqr(&self, target: $t) -> $t {
                    (wsqr(target) % extend(self.0)) as $t
                }
            }
        }
    };
}

impl_uprim_vanilla!(u8, u8_impl);
impl_uprim_vanilla!(u16, u16_impl);
impl_uprim_vanilla!(u32, u32_impl);
impl_uprim_vanilla!(u64, u64_impl);
impl_uprim_vanilla!(usize, usize_impl);

impl Reducer<u128> for Vanilla<u128> {
    impl_uprim_vanilla_core!(u128);

    #[inline]
    fn mul(&self, lhs: &u128, rhs: &u128) -> u128 {
        udouble::widening_mul(*lhs, *rhs) % self.0
    }

    #[inline]
    fn sqr(&self, target: u128) -> u128 {
        udouble::widening_square(target) % self.0
    }
}

/// An integer in a modulo ring based on conventional [Rem] operations
pub type VanillaInt<T> = ReducedInt<T, Vanilla<T>>;

#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::{ModularCoreOps, ModularPow, ModularUnaryOps};
    use core::marker::PhantomData;
    use rand::random;

    pub(crate) struct ReducedTester<T>(PhantomData<T>);

    macro_rules! impl_reduced_test_for {
        ($($T:ty)*) => {$(
            impl ReducedTester<$T> {
                /// Range of modulus:
                /// - random_mode = 0: [1, $T::MAX]
                /// - random_mode = 1: [1, $T::MAX] and odd
                /// - random_mode = 2: [$T::MAX >> $T::BITS/2, $T::MAX]
                pub fn test_against_modops<R: Reducer<$T> + Copy>(random_mode: i32) {
                    let m = match random_mode {
                        0 => random::<$T>().saturating_add(1),
                        1 => random::<$T>().saturating_add(1) | 1,
                        2 => random::<$T>().saturating_add(1 << (<$T>::BITS / 2)),
                        _ => unreachable!(),
                    };

                    let (a, b) = (random::<$T>(), random::<$T>());
                    let am = ReducedInt::<$T, R>::new(a, &m);
                    let bm = ReducedInt::<$T, R>::new(b, &m);
                    assert_eq!((am + bm).residue(), a.addm(b, &m), "incorrect add");
                    assert_eq!((am - bm).residue(), a.subm(b, &m), "incorrect sub");
                    assert_eq!((am * bm).residue(), a.mulm(b, &m), "incorrect mul");
                    assert_eq!(am.neg().residue(), a.negm(&m), "incorrect neg");
                    assert_eq!(am.double().residue(), a.dblm(&m), "incorrect dbl");
                    assert_eq!(am.square().residue(), a.sqm(&m), "incorrect sqr");

                    let e = random::<u8>() as $T;
                    assert_eq!(am.pow(&e).residue(), a.powm(e, &m), "incorrect pow");
                    if let Some(v) = a.invm(&m) {
                        assert_eq!(am.inv().unwrap().residue(), v, "incorrect inv");
                    }
                }
            }
        )*};
    }
    impl_reduced_test_for!(u8 u16 u32 u64 u128 usize);

    #[test]
    fn test_against_modops() {
        for _ in 0..10 {
            ReducedTester::<u8>::test_against_modops::<Vanilla<u8>>(0);
            ReducedTester::<u16>::test_against_modops::<Vanilla<u16>>(0);
            ReducedTester::<u32>::test_against_modops::<Vanilla<u32>>(0);
            ReducedTester::<u64>::test_against_modops::<Vanilla<u64>>(0);
            ReducedTester::<u128>::test_against_modops::<Vanilla<u128>>(0);
            ReducedTester::<usize>::test_against_modops::<Vanilla<usize>>(0);
        }
    }
}
--------------------------------------------------------------------------------
/src/word.rs:
--------------------------------------------------------------------------------
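// Note (added): every width-specific module below exposes the same interface -
// a `Word` type, a `DoubleWord` with twice the bits, and helpers for widening
// multiplication (`wmul`), widening squaring (`wsqr`) and narrowing remainder
// (`nrem`) - so that callers such as the `Vanilla` reducer can be written
// uniformly over the word size.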
macro_rules! simple_word_impl {
    ($S:ty, $D:ident) => {
        pub type Word = $S;
        pub type DoubleWord = $D;
        pub use super::$D as DoubleWordModule;

        #[inline(always)]
        pub const fn ones(n: u32) -> Word {
            if n == 0 {
                0
            } else {
                Word::MAX >> (Word::BITS - n)
            }
        }

        #[inline(always)]
        pub const fn extend(word: Word) -> DoubleWord {
            word as DoubleWord
        }

        #[inline(always)]
        pub const fn low(dw: DoubleWord) -> Word {
            dw as Word
        }

        #[inline(always)]
        pub const fn high(dw: DoubleWord) -> Word {
            (dw >> Word::BITS) as Word
        }

        #[inline(always)]
        pub const fn split(dw: DoubleWord) -> (Word, Word) {
            (low(dw), high(dw))
        }

        #[inline(always)]
        pub const fn merge(low: Word, high: Word) -> DoubleWord {
            extend(low) | extend(high) << Word::BITS
        }

        /// Widening multiplication
        #[inline(always)]
        pub const fn wmul(a: Word, b: Word) -> DoubleWord {
            extend(a) * extend(b)
        }

        /// Widening squaring
        #[inline(always)]
        pub const fn wsqr(a: Word) -> DoubleWord {
            extend(a) * extend(a)
        }

        /// Narrowing remainder
        pub const fn nrem(n: DoubleWord, d: Word) -> Word {
            (n % d as DoubleWord) as _
        }
    };
}
use simple_word_impl;

pub mod u8 {
    super::simple_word_impl!(u8, u16);
}

pub mod u16 {
    super::simple_word_impl!(u16, u32);
}

pub mod u32 {
    super::simple_word_impl!(u32, u64);
}

pub mod u64 {
    super::simple_word_impl!(u64, u128);
}

pub mod usize {
    #[cfg(target_pointer_width = "16")]
    super::simple_word_impl!(usize, u32);
    #[cfg(target_pointer_width = "32")]
    super::simple_word_impl!(usize, u64);
    #[cfg(target_pointer_width = "64")]
    super::simple_word_impl!(usize, u128);
}

pub mod u128 {
    use crate::double::udouble;
    pub type Word = u128;
    pub type DoubleWord = udouble;

    #[inline]
    pub const fn extend(word: Word) -> DoubleWord {
        udouble { lo: word, hi: 0 }
    }

    #[inline(always)]
    pub const fn low(dw: DoubleWord) -> Word {
        dw.lo
    }

    #[inline(always)]
    pub const fn high(dw: DoubleWord) -> Word {
        dw.hi
    }

    #[inline]
    pub const fn split(dw: DoubleWord) -> (Word, Word) {
        (dw.lo, dw.hi)
    }

    #[inline]
    pub const fn merge(low: Word, high: Word) -> DoubleWord {
        udouble { lo: low, hi: high }
    }

    #[inline]
    pub const fn wmul(a: Word, b: Word) -> DoubleWord {
        udouble::widening_mul(a, b)
    }

    #[inline]
    pub const fn wsqr(a: Word) -> DoubleWord {
        udouble::widening_square(a)
    }

    #[inline]
    pub fn nrem(n: DoubleWord, d: Word) -> Word {
        n % d
    }
}
--------------------------------------------------------------------------------