├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── DPF.md ├── LICENSE ├── README.md ├── benchmark.py ├── dpf.py ├── dpf_base ├── aes_core.h ├── dpf.cc └── dpf.h ├── dpf_gpu ├── dpf │ ├── dpf_breadth_first.cu │ ├── dpf_coop.cu │ ├── dpf_hybrid.cu │ └── dpf_naive.cu ├── dpf_benchmark.cu ├── matmul │ └── matmul.cu ├── matmul_benchmark.cu ├── prf │ ├── prf.cu │ └── prf_algos │ │ └── aes_core.h ├── tests │ └── test_128_bit.cu └── utils.h ├── dpf_wrapper.cu ├── imgs ├── 1.png ├── 2.png ├── 3.png ├── 4.png ├── 5.png ├── 6.png ├── 7.png └── dpf.png ├── install.sh ├── paper ├── experimental │ ├── .gitignore │ ├── batch_pir │ │ ├── batch_pir_optimization.py │ │ ├── batch_pir_optimization_old.py │ │ ├── modules │ │ │ ├── language_model │ │ │ │ ├── data.py │ │ │ │ ├── language_model.py │ │ │ │ ├── language_model_dataset.py │ │ │ │ ├── main.py │ │ │ │ └── train_model.sh │ │ │ ├── movielens_rec │ │ │ │ └── movielens_dataset.py │ │ │ └── taobao_rec │ │ │ │ ├── taobao_rec_dataset.py │ │ │ │ └── taobao_rec_dataset_v2.py │ │ ├── setup.sh │ │ └── sweep │ │ │ ├── language_model_plot.py │ │ │ ├── movielens_plot.py │ │ │ ├── sweep.py │ │ │ └── taobao_plot.py │ └── codesign │ │ ├── join_batch_pir_accuracy_with_gpu_dpf.py │ │ ├── plot_lm.py │ │ └── plot_rec.py └── kernel │ ├── cpu │ └── dpf_google │ │ ├── Makefile │ │ ├── benchmark.cu │ │ ├── benchmark_multithread_dpf.sh │ │ ├── dpf_helpers.cc │ │ ├── dpf_helpers.h │ │ └── test_dpf_so.cc │ └── gpu │ ├── Makefile │ ├── dpf_base │ ├── dpf.cc │ └── dpf.h │ ├── dpf_gpu │ ├── dpf │ │ ├── dpf_breadth_first.cu │ │ ├── dpf_coop.cu │ │ ├── dpf_hybrid.cu │ │ └── dpf_naive.cu │ ├── dpf_benchmark.cu │ ├── matmul │ │ └── matmul.cu │ ├── matmul_benchmark.cu │ ├── prf │ │ └── prf.cu │ ├── tests │ │ └── test_128_bit.cu │ └── utils.h │ └── scripts │ ├── scrape.py │ └── sweep.sh ├── sample.py └── setup.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to GPU-DPF 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 
4 | 5 | ## Our Development Process 6 | This codebase is primarily supported through open-source development. As such, any changes to the codebase should be made through pull requests and integrated after code review. 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `main`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Meta's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## Coding Style 33 | * Please use the same coding style as in other files. 34 | 35 | ## License 36 | By contributing to GPU-DPF, you agree that your contributions will be licensed 37 | under the LICENSE file in the root directory of this source tree. 38 | -------------------------------------------------------------------------------- /DPF.md: -------------------------------------------------------------------------------- 1 | # Private Information Retrieval with Distributed Point Functions 2 | 3 | This note describes how **distributed point functions (DPF)** can be used to enable private accesses to a table stored on two non-colluding servers. 4 | 5 | ## Private Table Lookups 6 | Suppose Meta servers hold a database (**T**) as shown below. The content of the table is known to Meta, but the table needs to be accessed with private indices. 7 | 8 |

11 | 12 | Let’s say that a user wants to access the 4th entry (index 3) of the table, obtaining 4, but without Meta knowing which index he or she accessed. How can we do this? 13 | 14 | A useful framework for understanding table lookups is to view them as a dot product. That is, a table lookup is essentially doing a dot product between the entire table and a one-hot vector containing a ‘1’ at the index of the item the user wants to look up, and ‘0’ everywhere else. 15 | The previous example of looking up the 4th entry of table T is shown below. 16 | 17 |
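For concreteness, here is a minimal NumPy sketch of the lookup-as-dot-product view; the table contents are made up for illustration, except that index 3 holds the value 4 as in the example above.

```python
import numpy as np

# Hypothetical table contents; only T[3] = 4 is fixed by the example above.
T = np.array([7, 1, 9, 4, 6, 2, 8, 5])

# One-hot query vector selecting index 3.
Q = np.zeros(len(T), dtype=int)
Q[3] = 1

# A table lookup is just the dot product of the table with the one-hot query.
assert T @ Q == T[3] == 4
```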

20 | 21 | Performing the dot product and returning the result would give the correct answer, 4 – but it would not hide the user’s query! 22 | 23 | To fix this, we can introduce a second non-colluding server that holds the same table T. 24 | 25 | Then, the user can generate a random vector R, and query one server to return the result of the dot product T\*(R+Q), and the other to return the result of the dot-product T\*R (assuming we are working within a finite field). 26 | 27 | The difference between the results of the two servers gives T\*Q, which is exactly the entry the user wanted! Furthermore, no information is leaked as R and (R+Q), when working over a finite field, individually reveal no information about the query! This technique is known as additive blinding, and R and R+Q are referred to as additive secret shares of Q. 28 | 29 |
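A short NumPy sketch of this blinding step, assuming a hypothetical table known to both servers and an arbitrary prime modulus p chosen just for illustration:

```python
import numpy as np

p = 2**31 - 1                                   # finite field modulus (illustrative choice)
rng = np.random.default_rng()

T = np.array([7, 1, 9, 4, 6, 2, 8, 5])          # hypothetical table, held by both servers
Q = np.zeros(len(T), dtype=np.int64)            # client's one-hot query for index 3
Q[3] = 1

# Client: blind the query with a random vector R.
R = rng.integers(0, p, size=len(T))
share_1 = (R + Q) % p                           # sent to server 1
share_2 = R                                     # sent to server 2

# Servers: each returns a dot product with its share, reduced mod p.
ans_1 = int(T @ share_1) % p
ans_2 = int(T @ share_2) % p

# Client: the difference of the two answers is T*Q, i.e. the requested entry.
assert (ans_1 - ans_2) % p == T[3]
```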

32 | 33 | This is great, but there’s a new issue: the size of the vectors that the user needs to send is big. 34 | 35 | Concretely, the number of entries in these vectors is proportional to the number of entries in the table. If we were privately looking up an element in a table with 10,000,000,000 entries, we would have to send over 1GB of data to the servers! Clearly, this dot-product method, as it is, is impractical for even moderately sized tables (e.g: tables with thousands or millions of entries would incur kilobytes or megabytes of communication cost). 36 | 37 | This is where distributed point functions (DPF) come into play. 38 | 39 | A DPF is a way to compress the secret-shared one-hot vector significantly. Specifically, a **DPF is a cryptographic primitive that, when evaluated, yields secret shares of a vector that is zero everywhere except at a single location**. 40 | 41 | So, how much can DPFs compress secret sharings of one-hot vectors? Amazingly, if N is the number of entries in the table, a DPF can compress these shares to O(log(N)). This means, privately accessing an element in a 10,000,000,000 entry table no longer needs >1GB of communication, but only around 2-3 KB with a DPF! 42 | 43 | Below we’ll see how the distributed point function achieves this. 44 | 45 | ## How Distributed Point Functions Work 46 | 47 | A **DPF** is a cryptographic primitive that allows us to compactly represent secret shares of a one-hot vector. 48 | 49 | Concretely, DPFs allow us to generate a pair of compact keys K1 and K2, which, when expanded, yield secret shares of a one hot vector with alpha as the target non-zero index (note, any scale beta, can be chosen for the output value for the target index): 50 | 51 |
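In code, this contract is easy to state and check. The sketch below assumes a hypothetical expand(key, n) routine standing in for key expansion; the names are illustrative, not this repository's dpf.py API:

```python
import numpy as np

def check_dpf_keys(expand, k1, k2, n, alpha, beta, modulus=2**32):
    """Check the DPF contract: expand(k1) - expand(k2) is a length-n vector
    that equals beta at index alpha and zero everywhere else (mod modulus)."""
    share_1 = np.asarray(expand(k1, n), dtype=np.uint64)
    share_2 = np.asarray(expand(k2, n), dtype=np.uint64)
    diff = (share_1 - share_2) % modulus
    expected = np.zeros(n, dtype=np.uint64)
    expected[alpha] = beta
    return np.array_equal(diff, expected)
```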

54 | 55 | Given this functionality, to perform a private table access, a user would construct keys K1 and K2, send each to the corresponding server, who would then expand the DPF with their key to obtain secret shares of the one-hot vector, then finally perform a dot product with their table to return the result. The user adds the two results from the servers to obtain the plaintext table entry. 56 | 57 | Amazingly, the size of the keys K1 and K2 can be made to be O(log(N)) the size of the length of the table. The key idea is to view the server’s table as a 2-D grid rather than a 1-D array, where the number of rows and columns of the grid is O(sqrt(N)) the number of table elements. 58 | 59 |
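As a small numerical illustration of the grid view (the column-major layout below is an arbitrary choice):

```python
import math

def to_grid(index, num_rows):
    # Map a flat table index to (row, column); column-major is an arbitrary choice.
    return index % num_rows, index // num_rows

N = 1_000_000
rows = cols = math.isqrt(N)              # 1000 x 1000 grid
print(to_grid(123_456, rows))            # (456, 123)

# One nonce per column plus two column-length correction words:
# roughly 3*sqrt(N) values instead of N.
print(cols + 2 * rows, "values vs", N)
```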

62 | 63 | Now, if the user’s target table index were say, 5, what the user can do is, assign server 1’s key K1 to be 3 random “nonces” (one for each column of the 2-D grid T’), and assign server 2’s key K2 to be the same “nonces”, except at the column containing the one-hot index. 64 | The two servers would then use these “nonces” as seeds for a random number generator (RNG) to generate values for each column. 65 | 66 |
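The snippet below illustrates this with a toy hash-based RNG (a stand-in for the PRFs used in this codebase; the nonce values are made up): columns whose nonces agree expand to identical values, and only the target column differs.

```python
import hashlib
import numpy as np

def expand_column(nonce, num_rows):
    # Toy RNG: stretch a nonce into a column of pseudorandom 32-bit values.
    h = lambda i: hashlib.sha256(f"{nonce}:{i}".encode()).digest()
    return np.array([int.from_bytes(h(i)[:4], "little") for i in range(num_rows)],
                    dtype=np.uint64)

k1 = [31337, 424242, 777]   # server 1's nonces, one per column (made-up values)
k2 = [31337, 424242, 555]   # server 2's nonces: identical except at the target column

for col, (s1, s2) in enumerate(zip(k1, k2)):
    match = np.array_equal(expand_column(s1, 3), expand_column(s2, 3))
    print(f"column {col}: servers agree = {match}")   # agree on columns 0 and 1, differ on 2
```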

69 | 70 | As shown, for the green columns, which represent indices that are not the target, the DPF expansion would yield the exact same values, which is a secret share of 0! 71 | 72 | For the red column, which contains the target index, the DPF expansion would evaluate to something different, something random. Unfortunately, this entire column is different (but we only want a single particlar entry to be different) – so we need to do some error correction. 73 | 74 | To account for this difference, the user additionally sends two correction words to both servers, each the size of the length of a column. These two correction words have values that look random, but the difference between them is chosen to be the difference between the two RNGs (red column) plus a one at the row index for the target. 75 | 76 | During DPF expansion, the servers add the correction word indexed by the last bit of the nonce for that particular column. By doing this, we effectively “correct” the red column as the last bits for the differing nonces are chosen to be different. The corresponding “green” columns still evaluate to the same number, since they’ve added the exact same correction words. 77 | 78 |
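Putting the pieces together, here is a toy end-to-end sketch of the sqrt(N) construction just described: one nonce per column, a hash-based stand-in for the RNG, and two correction words selected by the last bit of each nonce. It works over integers mod 2^32 and is meant for intuition only; it is not the optimized implementation in this repository.

```python
import hashlib
import secrets
import numpy as np

MOD = 2**32

def prg(nonce, num_rows):
    # Toy RNG: expand a nonce into a column of num_rows pseudorandom 32-bit values.
    out = [int.from_bytes(hashlib.sha256(f"{nonce}:{i}".encode()).digest()[:4], "little")
           for i in range(num_rows)]
    return np.array(out, dtype=np.uint64)

def gen(alpha, num_rows, num_cols, beta=1):
    # Target position in the (column-major) grid view of the table.
    t_row, t_col = alpha % num_rows, alpha // num_rows

    seeds_1 = [secrets.randbits(32) for _ in range(num_cols)]
    seeds_2 = list(seeds_1)
    # At the target column, server 2 gets a fresh nonce whose last bit differs.
    s2 = secrets.randbits(32)
    if (s2 & 1) == (seeds_1[t_col] & 1):
        s2 ^= 1
    seeds_2[t_col] = s2

    # Correction words: one is random, the other is chosen so that the difference
    # of the two corrected target columns is the one-hot column e.
    e = np.zeros(num_rows, dtype=np.uint64)
    e[t_row] = beta
    b = seeds_1[t_col] & 1
    cw = [None, None]
    cw[1 - b] = np.array([secrets.randbits(32) for _ in range(num_rows)], dtype=np.uint64)
    cw[b] = (e + prg(seeds_2[t_col], num_rows) - prg(seeds_1[t_col], num_rows) + cw[1 - b]) % MOD

    return (seeds_1, cw), (seeds_2, cw)          # both servers see the same correction words

def expand(key, num_rows, num_cols):
    # Server-side expansion: RNG column per nonce, plus the correction word
    # selected by the last bit of that nonce.
    seeds, cw = key
    grid = np.zeros((num_rows, num_cols), dtype=np.uint64)
    for col, seed in enumerate(seeds):
        grid[:, col] = (prg(seed, num_rows) + cw[seed & 1]) % MOD
    return grid.flatten(order="F")               # back to a flat vector (column-major)

# The difference of the two expansions is a one-hot vector at alpha.
num_rows, num_cols, alpha = 3, 3, 5
k1, k2 = gen(alpha, num_rows, num_cols)
diff = (expand(k1, num_rows, num_cols) - expand(k2, num_rows, num_cols)) % MOD
expected = np.zeros(num_rows * num_cols, dtype=np.uint64)
expected[alpha] = 1
assert np.array_equal(diff, expected)
```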

81 | 82 | As seen, these codewords allow us to “correct” the differences of the last column. With this, our expansion method now outputs secret shares of a one-hot vector. 83 | 84 | Security is maintained because all data sent to the servers (K1 individually, K2 individually, correction words), as well as the expanded output, look random. Assuming we use 128-bit nonces, it is computationally hard to brute force the seeds (e.g: try out all possible seeds) to do a pattern matching attack. 85 | 86 | What is the required amount of communication for this scheme? As shown, if we have a table with N elements, then the length of the nonces (K1, K2) and codewords is sqrt(N). This is a large improvement over O(N) communication with the naive method. 87 | 88 | But, as discussed earlier, we can do even better. We can reduce the communication down to log(N). 89 | 90 | The critical observation is that, the nonces (e.g: {3, 5, 9} and {3, 5, 2}) can themselves be represented as a DPF, as they are the same at all places except one. We can recursively construct a DPF to generate them. Concretely, instead of viewing the table as a sqrt(N)-by-sqrt(N) grid, we view it as a [2-by-N/2] grid, and recursively construct DPFs for the N/2 nonces. This allows us to reduce communication to log(N) because the number of times we recurse is log(N), and we have only two correction words per level. In practice, using the log(N) scheme, a private query to a table with ~1,000,000 entries takes around 1 KB of communication. 91 | 92 | ## Conclusion 93 | A DPF is a cryptographic primitive that, when evaluated, yields secret shares of a vector that is zero everywhere except at a single location. 94 | 95 | DPFs can be used to enable private table accesses to two untrusted non-colluding servers that share a table. They derive their security from other cryptographic primitives such as random number generators and PRFs. 96 | 97 | DPFs have applications to a wide range of cryptographic applications and is an important cryptographic tool for privacy preserving computation. 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GPU Distributed Point Functions 2 | 3 | This codebase implements a high-performance GPU implementation of a distributed point function, and exposes a simple and easy to use python interface. 4 | 5 | This repository is the source code for the paper [GPU-based Private Information Retrieval for On-Device ML Inference](https://arxiv.org/abs/2301.10904). 6 | 7 | 8 | ## Background 9 | 10 | A **Distributed Point Function (DPF)** is a cryptographic primitive that allows a client to **efficiently** and **privately** access an entry of a table replicated across two non-colluding servers. 11 | 12 |

15 | 16 | The workflow for private table accesses using distributed point functions is: 17 | - Client **generates** two compact keys k_a, k_b that represents the secret index they wish to retrieve 18 | - Client **sends** k_a, k_b across the network to two non-colluding servers Server 1, Server 2 respectively 19 | - Server 1 and Server 2 **evaluate** the keys, and return the result 20 | - Client **sums** the returned shares to obtain the table entry 21 | 22 | By using a DPF 23 | 1) **No information about the client's secret index is revealed** 24 | * Assuming no collusion between servers. 25 | 2) **Network communication costs are minimized** 26 | * Key sizes are *compact* on the order of 2KB for tables with up to 2^32 entries. 27 | 28 | ## How do Distributed Point Functions Work? 29 | 30 | We describe how distributed point functions work [here](https://github.com/facebookresearch/GPU-DPF/blob/main/DPF.md). 31 | 32 | ## Accelerating Distributed Point Functions with GPUs 33 | 34 | Evaluating DPFs is computationally intensive, making GPU acceleration _key_ to obtaining high performance. 35 | 36 | This codebase implements a high-performance GPU implementation of a distributed point function, and exposes a simple and easy to use python interface. By leveraging the GPU we are able to speed up DPF evaluation by over an order of magnitude over a multi-core CPU. 37 | 38 | We accelerate the DPF construction described [here](https://www.iacr.org/archive/eurocrypt2014/84410245/84410245.pdf). This DPF construction generates keys that are **log(n)** the size of the number of entries of the table, and require **O(n)** computation to evaluate the keys. 39 | 40 | ## Requirements 41 | 42 | - python, pytorch, numpy 43 | - CUDA GPU (tested on cuda > 11.4) 44 | 45 | ## Installation 46 | 47 | ``` 48 | bash install.sh 49 | ``` 50 | 51 | To check success, run `python dpf.py`. All checks should pass. 52 | 53 | ## Example 54 | 55 | Example usage (from `sample.py`). See `dpf.py` for more. 56 | 57 | ```python 58 | import sys 59 | import dpf 60 | import torch 61 | 62 | # Table parameters 63 | table_size = 16384 64 | entry_size = 1 65 | 66 | # The actual table (replicated on 2 non-colluding servers) 67 | table = torch.randint(2**31, (table_size, entry_size)).int() 68 | table[42,:] = 42 69 | 70 | def server(k): 71 | 72 | # Server initializes DPF w/ table 73 | dpf_ = dpf.DPF() 74 | dpf_.eval_init(table) 75 | 76 | # Server evaluates DPF on table 77 | return dpf_.eval_gpu([k]) 78 | 79 | def client(): 80 | secret_indx = 42 81 | 82 | # Generate two keys that represents the secret indx 83 | dpf_ = dpf.DPF() 84 | k1, k2 = dpf_.gen(secret_indx, table_size) 85 | 86 | # Send one key to each server to evaluate. 87 | # 88 | # Assuming that these two servers do not collude, 89 | # the servers learn _nothing_ about secret_indx. 90 | a = server(k1).item() 91 | b = server(k2).item() 92 | 93 | rec = a-b 94 | 95 | print(a, b, rec) 96 | assert(rec == 42) 97 | 98 | if __name__=="__main__": 99 | client() 100 | ``` 101 | 102 | ## Benchmark 103 | Benchmark with `python benchmark.py`. Sample output on a P100 GPU. 
104 | ``` 105 | DPF(entries=16384, entry_size=16, prf_method=AES128) Key Size: 2096 bytes, Perf: 23954 dpfs/sec 106 | DPF(entries=16384, entry_size=16, prf_method=SALSA20) Key Size: 2096 bytes, Perf: 76073 dpfs/sec 107 | DPF(entries=16384, entry_size=16, prf_method=CHACHA20) Key Size: 2096 bytes, Perf: 75679 dpfs/sec 108 | DPF(entries=65536, entry_size=16, prf_method=AES128) Key Size: 2096 bytes, Perf: 6131 dpfs/sec 109 | DPF(entries=65536, entry_size=16, prf_method=SALSA20) Key Size: 2096 bytes, Perf: 23141 dpfs/sec 110 | DPF(entries=65536, entry_size=16, prf_method=CHACHA20) Key Size: 2096 bytes, Perf: 22433 dpfs/sec 111 | DPF(entries=262144, entry_size=16, prf_method=AES128) Key Size: 2096 bytes, Perf: 1443 dpfs/sec 112 | DPF(entries=262144, entry_size=16, prf_method=SALSA20) Key Size: 2096 bytes, Perf: 5849 dpfs/sec 113 | DPF(entries=262144, entry_size=16, prf_method=CHACHA20) Key Size: 2096 bytes, Perf: 5830 dpfs/sec 114 | DPF(entries=1048576, entry_size=16, prf_method=AES128) Key Size: 2096 bytes, Perf: 379 dpfs/sec 115 | DPF(entries=1048576, entry_size=16, prf_method=SALSA20) Key Size: 2096 bytes, Perf: 1447 dpfs/sec 116 | DPF(entries=1048576, entry_size=16, prf_method=CHACHA20) Key Size: 2096 bytes, Perf: 1424 dpfs/sec 117 | ``` 118 | 119 | Our current implementation supports tables of sizes up to 2^32, at a fixed key size of ~2KB, assuming 16 integers per entry. These can be configured by editing the C++ wrapper code. 120 | 121 | **Note:** We also provide a CPU implementation of a DPF, however it is less optimized, and not the one we compare against in our paper. Please see [google_dpf](https://github.com/google/distributed_point_functions) for a more optimized CPU implementation. 122 | 123 | ## Results 124 | 125 | We compare performance between GPU DPF on a V100 GPU vs [CPU](https://github.com/google/distributed_point_functions), using AES-128 for the PRF, with 16 32-bit values per table entry. 126 | 127 | | # Table Entries | PRF | CPU 1-thread DPFs/sec | CPU 32-thread DPFs/sec | V100 GPU DPFs/sec | Speedup vs 1-thread CPU | Speedup vs 32-thread CPU | 128 | | :-: | :-: | :-: | :-: | :-: | :-: | :-: | 129 | | 16384 | AES128 | 220 | 2,810 | **52,536** | 238x | 18.7x | 130 | | 65536 | AES128 | 50 | 688 | **15,392** | 308x | 22.3x | 131 | | 262144 | AES128 | 13 | 212 | **3,967** | 305x | 18.7x | 132 | | 1048576 | AES128 | 3 | 55 | **923** | 307x | 16.8x | 133 | 134 | Overall, our GPU DPF implementation attains over **200x** speedup over an optimized single-threaded CPU DPF implementation, and over **15x** speedup over an optimized multi-threaded CPU DPF implementation. 135 | 136 | Further performance numbers for GPU DPF on a V100 GPU for Salsa20/Chacha20 PRFs. 137 | | # Table Entries | PRF | V100 GPU DPFs/sec | 138 | | :-: | :-: | :-: | 139 | | 16384 | SALSA20 | 145,646 | 140 | | 65536 | SALSA20 | 54,892 | 141 | | 262144 | SALSA20 | 16,650 | 142 | | 1048576 | SALSA20 | 3,894 | 143 | | 16384 | CHACHA20 | 139,590 | 144 | | 65536 | CHACHA20 | 56,120 | 145 | | 262144 | CHACHA20 | 16,086 | 146 | | 1048576 | CHACHA20 | 4,054 | 147 | 148 | 149 | ## License 150 | 151 | GPU-DPF is released under the [Apache 2.0 license](https://github.com/facebookresearch/GPU-DPF/blob/main/LICENSE). 152 | 153 | ## Citation 154 | 155 | ``` 156 | @article{lam2023gpudpf, 157 | title={GPU-based Private Information Retrieval for On-Device Machine Learning Inference}, 158 | author={Maximilian Lam and Jeff Johnson and Wenjie Xiong and Kiwan Maeng and Udit Gupta and Minsoo Rhu and Hsien-Hsin S. 
Lee and Vijay Janapa Reddi and Gu-Yeon Wei and David Brooks and Edward Suh}, 159 | year={2023}, 160 | eprint={2301.10904}, 161 | archivePrefix={arXiv}, 162 | primaryClass={cs.CR} 163 | } 164 | ``` 165 | 166 | ## Disclaimer 167 | 168 | **This open source project is for proof of concept purposes only and should not be used in production environments. The code has not been officially vetted for security vulnerabilities and provides no guarantees of correctness or security. Users should carefully review the code and conduct their own security assessments before using the software.** 169 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | import dpf 3 | 4 | for N in [16384, 65536, 262144, 1048576]: 5 | dpf.test_gpu_dpf_perf(N=N, prf=dpf.DPF.PRF_AES128) 6 | dpf.test_gpu_dpf_perf(N=N, prf=dpf.DPF.PRF_SALSA20) 7 | dpf.test_gpu_dpf_perf(N=N, prf=dpf.DPF.PRF_CHACHA20) 8 | -------------------------------------------------------------------------------- /dpf_base/dpf.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | /* 4 | Serial CPU dpf function based on the sqrt(n) grid trick described 5 | - https://www.youtube.com/watch?v=y2aVgxD7DJc 6 | - https://www.iacr.org/archive/eurocrypt2014/84410245/84410245.pdf 7 | */ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include "dpf.h" 18 | 19 | int main(int argc, char *argv[]) { 20 | test_log_n_method(); 21 | test_sqrt_n_method(); 22 | benchmark_log_n_method_perf(); 23 | test_flat_codewords(); 24 | } 25 | -------------------------------------------------------------------------------- /dpf_gpu/dpf/dpf_breadth_first.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
2 | 3 | // This is like batch size: a block expands _multiple_ dpfs 4 | #define DPF_BREADTH_PARALLEL_THREADS_PER_BLOCK 256 5 | #define DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK 1 6 | 7 | uint128_t_gpu *DPF_BREADTH_PARALLEL_KEYS_1, *DPF_BREADTH_PARALLEL_KEYS_2; 8 | 9 | void dpf_breadth_first_initialize(int batch_size, int num_entries) { 10 | cudaMalloc(&DPF_BREADTH_PARALLEL_KEYS_1, sizeof(uint128_t_gpu)*batch_size*num_entries); 11 | cudaMalloc(&DPF_BREADTH_PARALLEL_KEYS_2, sizeof(uint128_t_gpu)*batch_size*num_entries); 12 | } 13 | 14 | __global__ void dpf_breadth_first_kernel(SeedsCodewordsFlatGPU *cw, uint128_t_gpu *out, 15 | uint128_t_gpu *DPF_BREADTH_PARALLEL_KEYS_1, 16 | uint128_t_gpu *DPF_BREADTH_PARALLEL_KEYS_2, 17 | int batch_size, int num_entries) { 18 | 19 | // Computes DPF expansion in a breadth parallel way 20 | int thread_idx = threadIdx.x; 21 | int block_idx = blockIdx.x; 22 | 23 | // This block handles DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK DPF expansions 24 | int cw_start = blockIdx.x*DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK; 25 | int cw_end = cw_start+DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK; 26 | 27 | // Load cw to shared memory 28 | __shared__ SeedsCodewordsFlatGPU cw_shared[DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK]; 29 | if (thread_idx < DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK) { 30 | cw_shared[thread_idx] = cw[cw_start+thread_idx]; 31 | } 32 | 33 | __syncthreads(); 34 | 35 | // Simple recurrence relation for expanding binary tree-based DPF. 36 | // Nodes numbered with following format: 37 | // 0 38 | // / \ 39 | // 0 1 40 | // / \ / \ 41 | // 0 1 2 3 42 | // 43 | // Relation: 44 | // k_1 = seed 45 | // k_i = PRF(k_{i//2}, i % 2) + CW_{k_{i//2} & 1}(i % 2) 46 | // 47 | // Output k_{2^{depth-1}} to k_{2^{depth-1}} + N 48 | // 49 | // Do note, we are expanding multiple binary tree DPFs. 50 | // In this algo, each threadblock expands DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK dpfs. 
51 | // Following checks ensure blocking params are correct: 52 | // assert(DPF_BREADTH_PARALLEL_THREADS_PER_BLOCK/DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK >= 1) 53 | // assert(DPF_BREADTH_PARALLEL_THREADS_PER_BLOCK%DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK == 0) 54 | uint128_t_gpu *key_write = DPF_BREADTH_PARALLEL_KEYS_1; 55 | uint128_t_gpu *key_read = DPF_BREADTH_PARALLEL_KEYS_2; 56 | uint128_t_gpu *tmp; 57 | 58 | constexpr int parallel_work_per_threadblock_per_dpf = (DPF_BREADTH_PARALLEL_THREADS_PER_BLOCK/DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK); 59 | 60 | // Init the first seed 61 | int batch_idx = thread_idx % DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK; 62 | key_write[0 + (block_idx+batch_idx)*num_entries] = cw_shared[batch_idx].last_keys[0]; 63 | 64 | // Outer loop loops from top level of tree down to bottom 65 | for (int i = cw_shared[0].depth-1; i >= 0; i--) { 66 | 67 | // Swap read and write buffers 68 | tmp = key_read; 69 | key_read = key_write; 70 | key_write = tmp; 71 | 72 | // Can parallelize _within_ a level of the tree, but not _across_ levels of the tree 73 | __syncthreads(); 74 | 75 | // Inner loop scans the current level of the tree (in parallel batches) 76 | int start = 0, end = 1<<(cw_shared[0].depth-i); 77 | for (int j = start; j < end; j += parallel_work_per_threadblock_per_dpf) { 78 | int expansion_idx = j + (thread_idx % parallel_work_per_threadblock_per_dpf); 79 | int batch_idx = thread_idx / parallel_work_per_threadblock_per_dpf; 80 | 81 | if (expansion_idx < end) { 82 | int idx_into_codewords = expansion_idx % 2; 83 | uint128_t_gpu key = key_read[(expansion_idx/2) + (block_idx+batch_idx)*num_entries]; 84 | uint128_t_gpu value = PRF(key, idx_into_codewords); 85 | uint128_t_gpu *cw = (key.x & 1) == 0 ? cw_shared[batch_idx].cw_1 : cw_shared[batch_idx].cw_2; 86 | cw = &cw[i*2]; 87 | key_write[expansion_idx + (block_idx+batch_idx)*num_entries] = add_uint128(value, cw[idx_into_codewords]); 88 | } 89 | } 90 | } 91 | 92 | // Postamble, write to output 93 | for (int i = 0; i < num_entries; i+= parallel_work_per_threadblock_per_dpf) { 94 | int expansion_idx = i + (thread_idx % parallel_work_per_threadblock_per_dpf); 95 | int batch_idx = thread_idx / parallel_work_per_threadblock_per_dpf; 96 | int dst_idx = __brev(expansion_idx) >> (32-cw_shared[0].depth); 97 | 98 | // Do note: the best way to write memory is with _coalescing_. 99 | // Without it, huge performance slowdowns (2.5x slowdown!) 100 | // However, this writes to the output buffer in a permutated order. 101 | out[expansion_idx + batch_idx*num_entries + block_idx*DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK*num_entries] = 102 | key_write[expansion_idx + (block_idx+batch_idx)*num_entries]; 103 | } 104 | } 105 | 106 | void dpf_breadth_first(SeedsCodewordsFlatGPU *cw, 107 | uint128_t_gpu *out, 108 | int batch_size, int num_entries, 109 | cudaStream_t s) { 110 | dim3 n_blocks_breadth_parallel(batch_size / DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK); 111 | dim3 n_threads_breadth_parallel(DPF_BREADTH_PARALLEL_THREADS_PER_BLOCK); 112 | 113 | dpf_breadth_first_kernel<<>>(cw, out, 114 | DPF_BREADTH_PARALLEL_KEYS_1, 115 | DPF_BREADTH_PARALLEL_KEYS_2, 116 | batch_size, 117 | num_entries); 118 | } 119 | -------------------------------------------------------------------------------- /dpf_gpu/dpf/dpf_coop.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
2 | 3 | #include 4 | 5 | using namespace cooperative_groups; 6 | 7 | #ifndef FUSES_MATMUL 8 | #define FUSES_MATMUL 1 9 | #endif 10 | 11 | //#define DPF_COOP_N_BLOCKS 64 12 | #define DPF_COOP_THREADS_PER_BLOCK 128 13 | 14 | int DPF_COOP_N_BLOCKS = -1; 15 | 16 | uint128_t_gpu *DPF_COOP_KEYS_1, *DPF_COOP_KEYS_2; 17 | uint128_t_gpu *TABLE_REDUCTION; 18 | 19 | __global__ void dpf_coop_kernel(SeedsCodewordsFlatGPU *cw, 20 | uint128_t_gpu *TABLE, 21 | uint128_t_gpu *TABLE_REDUCTION, 22 | uint128_t_gpu *out, 23 | uint128_t_gpu *DPF_COOP_KEYS_1, 24 | uint128_t_gpu *DPF_COOP_KEYS_2, 25 | int batch_size, int num_entries, 26 | int DPF_COOP_N_BLOCKS) { 27 | 28 | // Computes DPF expansion in a breadth parallel way 29 | int thread_idx = threadIdx.x; 30 | int block_idx = blockIdx.x; 31 | 32 | // Load cw to shared memory. Recall only 1 cw as batchsize=1 33 | __shared__ SeedsCodewordsFlatGPU cw_shared[1]; 34 | if (thread_idx == 0) { 35 | cw_shared[thread_idx] = cw[0]; 36 | } 37 | 38 | // Use cooperative groups to sync blocks 39 | this_grid().sync(); 40 | __syncthreads(); 41 | 42 | // Algorithm same as breadth parallel, see breadth parallel method for high level DPF strat 43 | uint128_t_gpu *key_write = DPF_COOP_KEYS_1; 44 | uint128_t_gpu *key_read = DPF_COOP_KEYS_2; 45 | uint128_t_gpu *tmp; 46 | 47 | // Init the first seed 48 | key_write[0] = cw_shared[0].last_keys[0]; 49 | 50 | // Outer loop loops from top level of tree down to bottom 51 | for (int i = cw_shared[0].depth-1; i >= 0; i--) { 52 | 53 | // Swap read and write buffers 54 | tmp = key_read; 55 | key_read = key_write; 56 | key_write = tmp; 57 | 58 | // Can parallelize _within_ a level of the tree, but not _across_ levels of the tree 59 | this_grid().sync(); 60 | __syncthreads(); 61 | 62 | // Inner loop scans the current level of the tree (in parallel batches) 63 | int start = 0, end = 1<<(cw_shared[0].depth-i); 64 | 65 | // Scan through the work. All threads of each block eval a single PRF 66 | for (int j = start; j < end; j += DPF_COOP_N_BLOCKS*DPF_COOP_THREADS_PER_BLOCK) { 67 | int expansion_idx = j + (block_idx*DPF_COOP_THREADS_PER_BLOCK + thread_idx); 68 | 69 | if (expansion_idx < end) { 70 | int idx_into_codewords = expansion_idx % 2; 71 | uint128_t_gpu key = key_read[(expansion_idx/2)]; 72 | uint128_t_gpu value = PRF(key, idx_into_codewords); 73 | uint128_t_gpu *cw = (key.x & 1) == 0 ? cw_shared[0].cw_1 : cw_shared[0].cw_2; 74 | cw = &cw[i*2]; 75 | key_write[expansion_idx] = add_uint128(value, cw[idx_into_codewords]); 76 | } 77 | } 78 | } 79 | 80 | #if(!FUSES_MATMUL) 81 | // Postamble, write to output 82 | for (int i = 0; i < num_entries; i += DPF_COOP_N_BLOCKS*DPF_COOP_THREADS_PER_BLOCK) { 83 | int expansion_idx = i + (block_idx*DPF_COOP_THREADS_PER_BLOCK + thread_idx); 84 | 85 | // Do note: the best way to write memory is with _coalescing_. 86 | // Without it, huge performance slowdowns (2.5x slowdown!) 87 | // However, this writes to the output buffer in a permutated order. 88 | if (expansion_idx < num_entries) { 89 | out[expansion_idx] = key_write[expansion_idx]; 90 | } 91 | } 92 | #else 93 | 94 | // Fused matmul. 
Recall MM is num_elements_per_entry 95 | uint128_t_gpu per_thread_accumulate[MM] = {0}; 96 | for (int i = 0; i < num_entries; i += DPF_COOP_N_BLOCKS*DPF_COOP_THREADS_PER_BLOCK) { 97 | int expansion_idx = i + (block_idx*DPF_COOP_THREADS_PER_BLOCK + thread_idx); 98 | if (expansion_idx < num_entries) { 99 | for (int z = 0; z < MM; z++) { 100 | per_thread_accumulate[z] = add_uint128(mul_uint128(key_write[expansion_idx], TABLE[expansion_idx]), 101 | per_thread_accumulate[z]); 102 | } 103 | } 104 | } 105 | 106 | // Tree sum reduction on accumulates 107 | int total_threads = DPF_COOP_N_BLOCKS*DPF_COOP_THREADS_PER_BLOCK; 108 | int glob_thread_idx = block_idx*DPF_COOP_THREADS_PER_BLOCK+thread_idx; 109 | 110 | // Write local accumulates to table 111 | for (int i = 0; i < MM; i++) { 112 | TABLE_REDUCTION[i*total_threads+glob_thread_idx] = per_thread_accumulate[i]; 113 | } 114 | 115 | this_grid().sync(); 116 | __syncthreads(); 117 | 118 | for (int neighbor = 1; neighbor < total_threads; neighbor*=2) { 119 | if (glob_thread_idx % (neighbor*2) == 0 && glob_thread_idx+neighbor < total_threads) { 120 | for (int z = 0; z < MM; z++) { 121 | TABLE_REDUCTION[z*total_threads+glob_thread_idx] = 122 | add_uint128(TABLE_REDUCTION[z*total_threads+glob_thread_idx], 123 | TABLE_REDUCTION[z*total_threads+glob_thread_idx+neighbor]); 124 | } 125 | } 126 | this_grid().sync(); 127 | __syncthreads(); 128 | } 129 | 130 | if (glob_thread_idx == 0) { 131 | for (int z = 0; z < MM; z++) { 132 | out[z] = TABLE_REDUCTION[z*total_threads+0]; 133 | } 134 | } 135 | 136 | #endif 137 | } 138 | 139 | int getMaxInterpreterGrid(int numThreads) { 140 | int maxBlocksPerSM = 0; 141 | cudaOccupancyMaxActiveBlocksPerMultiprocessor(&maxBlocksPerSM, dpf_coop_kernel, numThreads, 0); 142 | 143 | cudaDeviceProp deviceProp; 144 | cudaGetDeviceProperties(&deviceProp, 0); 145 | int numSM = deviceProp.multiProcessorCount; 146 | 147 | return maxBlocksPerSM * numSM; 148 | } 149 | 150 | void dpf_coop_initialize(int batch_size, int num_entries, int entry_size) { 151 | // Same as breadth parallel strategy, except always uses 152 | // batch 1. 
153 | if (batch_size != 1) { 154 | printf("Cooperative threads DPF strategy requires batch_size=1\n"); 155 | } 156 | assert(batch_size == 1); 157 | 158 | DPF_COOP_N_BLOCKS = getMaxInterpreterGrid(DPF_COOP_THREADS_PER_BLOCK); 159 | printf("Coooperative threads DPF strategy with grid size %d\n", DPF_COOP_N_BLOCKS); 160 | 161 | cudaMalloc(&DPF_COOP_KEYS_1, sizeof(uint128_t_gpu)*batch_size*num_entries); 162 | cudaMalloc(&DPF_COOP_KEYS_2, sizeof(uint128_t_gpu)*batch_size*num_entries); 163 | 164 | // Given batch size 1, we also initialize a table of size num_entries*entry_size 165 | // for the purpose of reducing the final accumulates 166 | cudaMalloc(&TABLE_REDUCTION, sizeof(uint128_t_gpu)*entry_size*DPF_COOP_N_BLOCKS*DPF_COOP_THREADS_PER_BLOCK); 167 | } 168 | 169 | void dpf_coop(SeedsCodewordsFlatGPU * cw, 170 | uint128_t_gpu *out, 171 | uint128_t_gpu *TABLE, 172 | int batch_size, int num_entries, 173 | cudaStream_t s) { 174 | dim3 n_blocks(DPF_COOP_N_BLOCKS); 175 | dim3 n_threads(DPF_COOP_THREADS_PER_BLOCK); 176 | 177 | void *kernel_args[] = 178 | { 179 | &cw, &TABLE, &TABLE_REDUCTION, &out, 180 | &DPF_COOP_KEYS_1, 181 | &DPF_COOP_KEYS_2, 182 | &batch_size, 183 | &num_entries, 184 | &DPF_COOP_N_BLOCKS, 185 | }; 186 | cudaLaunchCooperativeKernel((void *)dpf_coop_kernel, 187 | n_blocks, n_threads, kernel_args); 188 | } 189 | 190 | -------------------------------------------------------------------------------- /dpf_gpu/dpf/dpf_naive.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | #include "../utils.h" 4 | 5 | #define DPF_NAIVE_BLOCK_W (8) 6 | #define DPF_NAIVE_BLOCK_H (128) 7 | 8 | __device__ uint128_t_gpu expand_dpf_naive_kernel(const SeedsCodewordsFlatGPU *s, int indx) { 9 | 10 | int indx_remaining = indx; 11 | uint128_t_gpu key = s->last_keys[0]; 12 | uint128_t_gpu value; 13 | 14 | for (int i = s->depth-1; i >= 0; i--) { 15 | int indx_into_codewords = indx_remaining % 2; 16 | value = PRF(key, indx_into_codewords); 17 | const uint128_t_gpu *cw = (key.x & 1) == 0 ? s->cw_1 : s->cw_2; 18 | cw = &cw[i*2]; 19 | key = add_uint128(value, cw[indx_into_codewords]); 20 | indx_remaining >>= 1; 21 | } 22 | 23 | return key; 24 | } 25 | 26 | __global__ void dpf_naive_kernel(SeedsCodewordsFlatGPU *cw, 27 | uint128_t_gpu *out, 28 | int batch_size) { 29 | 30 | int x_indx = blockIdx.x*DPF_NAIVE_BLOCK_W + threadIdx.x; 31 | int y_indx = blockIdx.y*DPF_NAIVE_BLOCK_H + threadIdx.y; 32 | int out_indx = y_indx*batch_size + x_indx; 33 | 34 | out[out_indx] = expand_dpf_naive_kernel(&cw[x_indx], y_indx); 35 | } 36 | 37 | void dpf_naive(SeedsCodewordsFlatGPU *cw, 38 | uint128_t_gpu *out, 39 | int batch_size, int num_entries, 40 | cudaStream_t s) { 41 | dim3 threads_per_block_naive(DPF_NAIVE_BLOCK_W, DPF_NAIVE_BLOCK_H); 42 | dim3 n_blocks_naive(batch_size/DPF_NAIVE_BLOCK_W, num_entries/DPF_NAIVE_BLOCK_H); 43 | 44 | //printf("%d %d\n", batch_size/DPF_NAIVE_BLOCK_W, num_entries/DPF_NAIVE_BLOCK_H); 45 | //return; 46 | 47 | dpf_naive_kernel<<>>(cw, out, batch_size); 48 | } 49 | -------------------------------------------------------------------------------- /dpf_gpu/matmul/matmul.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
2 | 3 | #include "../utils.h" 4 | 5 | // Define stride K to iterate over 6 | #define BLOCK_TILE_K 16 7 | 8 | // Define sizes of blocks of output C to operate over in parallel 9 | #define BLOCK_H 4 10 | #define BLOCK_W 4 11 | 12 | // If K is really large, might exceed launch config size restrictions 13 | // Going to hack this to set the right size (TODO: fix) 14 | #define MAX(a,b) \ 15 | ({ __typeof__ (a) _a = (a); \ 16 | __typeof__ (b) _b = (b); \ 17 | _a > _b ? _a : _b; }) 18 | #define BLOCK_K (MAX(128, K/32768)) 19 | 20 | // Tile inner loop, by outer products with dimension 21 | // K reduction dimension is iterated 1 by 1 22 | #define THREAD_BLOCK_H 1 23 | #define THREAD_BLOCK_W 1 24 | 25 | //// Reduction params 26 | #define REDUCTION_THREADS_PER_BLOCK 128 27 | 28 | uint128_t_gpu *MATMUL_TABLE_REDUCTION; 29 | 30 | void initialize_matmul(int M, int K, int N) { 31 | // We further initialize global memory for reducing across the K dimension 32 | assert((K&(K-1)) == 0); 33 | cudaMalloc(&MATMUL_TABLE_REDUCTION, sizeof(uint128_t_gpu)*M*N*K/BLOCK_K); 34 | cudaMemset(MATMUL_TABLE_REDUCTION, 0, sizeof(uint128_t_gpu)*M*N*K/BLOCK_K); 35 | } 36 | 37 | // Matmul of shape: MxK * KxN -> MxN 38 | __global__ void GEMM128_kernel(uint128_t_gpu *A, 39 | uint128_t_gpu *C, 40 | uint128_t_gpu *B, 41 | uint128_t_gpu *MATMUL_TABLE_REDUCTION, 42 | int M, int K, int N) { 43 | 44 | int block_indx_x = blockIdx.x; 45 | int block_indx_y = blockIdx.y; 46 | int block_indx_k = blockIdx.z; 47 | 48 | int thread_indx_x = threadIdx.x; 49 | int thread_indx_y = threadIdx.y; 50 | 51 | int thread_id_within_block = thread_indx_y*BLOCK_W + thread_indx_x; 52 | 53 | // Threads in a block handle block starting from 54 | int block_C_indx_start = block_indx_y*N*BLOCK_H + block_indx_x*BLOCK_W; 55 | 56 | int threads_per_block = (BLOCK_H/THREAD_BLOCK_H)*(BLOCK_W/THREAD_BLOCK_W); 57 | int thread_id = thread_indx_y*(BLOCK_W/THREAD_BLOCK_W)+thread_indx_x; 58 | 59 | __shared__ uint128_t_gpu A_block_local[BLOCK_H][BLOCK_TILE_K+1]; 60 | __shared__ uint128_t_gpu B_block_local[BLOCK_TILE_K][BLOCK_W+1]; 61 | uint128_t_gpu C_frag_local[THREAD_BLOCK_H][THREAD_BLOCK_W] = {0}; 62 | 63 | // This is the same as the nvidia post, loop over entire K dimension 64 | for (int k = block_indx_k*BLOCK_K; k < block_indx_k*BLOCK_K + BLOCK_K; k += BLOCK_TILE_K) { 65 | 66 | // Load blocks of A,B into shared memory in parallel 67 | int block_A_indx_start = block_indx_y*K*BLOCK_H; 68 | int block_B_indx_start = block_indx_x*BLOCK_W; 69 | 70 | for (int i = 0; i < BLOCK_H*BLOCK_TILE_K; i+= threads_per_block) { 71 | int ii = (i+thread_id) / BLOCK_TILE_K; 72 | int jj = (i+thread_id) % BLOCK_TILE_K; 73 | A_block_local[ii][jj] = A[k+block_A_indx_start + ii*K + jj]; 74 | } 75 | 76 | for (int i = 0; i < BLOCK_TILE_K*BLOCK_W; i+= threads_per_block) { 77 | int ii = (i+thread_id) / BLOCK_W; 78 | int jj = (i+thread_id) % BLOCK_W; 79 | //B_block_local[ii][jj] = B[block_B_indx_start + k*N + ii*N + jj]; 80 | B_block_local[ii][jj] = B[(block_B_indx_start+jj)*K + (k+ii)]; 81 | } 82 | 83 | __syncthreads(); 84 | 85 | // Compute over thread block tiles 86 | for (int i = 0; i < BLOCK_TILE_K; i++) { 87 | 88 | // More efficient method should be outer product 89 | // Load fragments into registers 90 | uint128_t_gpu A_frag_local[THREAD_BLOCK_H]; 91 | uint128_t_gpu B_frag_local[THREAD_BLOCK_W]; 92 | 93 | for (int j = 0; j < THREAD_BLOCK_H; j++) { 94 | A_frag_local[j] = A_block_local[j+thread_indx_y*THREAD_BLOCK_H][i]; 95 | } 96 | for (int j = 0; j < THREAD_BLOCK_W; j++) { 97 | B_frag_local[j] = 
B_block_local[i][j+thread_indx_x*THREAD_BLOCK_W]; 98 | } 99 | 100 | // Outer product into per-thread mem 101 | for (int jj = 0; jj < THREAD_BLOCK_H; jj++) { 102 | for (int kk = 0; kk < THREAD_BLOCK_W; kk++) { 103 | C_frag_local[jj][kk] = add_uint128(C_frag_local[jj][kk], 104 | mul_uint128(A_frag_local[jj], B_frag_local[kk])); 105 | } 106 | } 107 | } 108 | } 109 | 110 | ////////////////////////////////////////////////// 111 | // Reduction across threads in the K dimension // 112 | ///////////////////////////////////////////////// 113 | 114 | // Write C frag locals to intermediate output 115 | int k_stride = M*N; 116 | for (int j = 0; j < THREAD_BLOCK_W; j++) { 117 | for (int i = 0; i < THREAD_BLOCK_H; i++) { 118 | MATMUL_TABLE_REDUCTION[block_indx_k*k_stride + block_C_indx_start + thread_indx_y*THREAD_BLOCK_H*N + thread_indx_x*THREAD_BLOCK_W + i*N + j] = C_frag_local[i][j]; 119 | } 120 | } 121 | } 122 | 123 | __global__ void GEMM128_reduction_kernel(uint128_t_gpu *MATMUL_TABLE_REDUCTION, 124 | uint128_t_gpu *out, 125 | int M, int K, int N) { 126 | int block_indx = blockIdx.x; 127 | int thread_idx = threadIdx.x; 128 | int work_per_block = REDUCTION_THREADS_PER_BLOCK; 129 | int work_indx = block_indx*work_per_block + thread_idx; 130 | 131 | if (work_indx >= M*N) return; 132 | 133 | int k_stride = M*N; 134 | uint128_t_gpu accum[1] = {0}; 135 | for (int k = 0; k < K/BLOCK_K; k++) { 136 | uint128_t_gpu op2 = MATMUL_TABLE_REDUCTION[k*k_stride + work_indx]; 137 | accum[0] = add_uint128(accum[0], op2); 138 | } 139 | 140 | out[work_indx] = accum[0]; 141 | } 142 | 143 | void GEMM128(uint128_t_gpu *A, 144 | uint128_t_gpu *C, 145 | uint128_t_gpu *B, 146 | int M, int K, int N, 147 | cudaStream_t s) { 148 | 149 | assert(BLOCK_W%THREAD_BLOCK_W == 0); 150 | assert(BLOCK_H%THREAD_BLOCK_H == 0); 151 | assert(N%BLOCK_W == 0); 152 | assert(M%BLOCK_H == 0); 153 | 154 | dim3 threads_per_block(BLOCK_W/THREAD_BLOCK_W, BLOCK_H/THREAD_BLOCK_H); 155 | dim3 n_blocks(N/BLOCK_W, M/BLOCK_H, K/BLOCK_K); 156 | 157 | GEMM128_kernel<<>>(A, C, B, MATMUL_TABLE_REDUCTION, M, K, N); 158 | 159 | dim3 threads_per_block_reduce(REDUCTION_THREADS_PER_BLOCK); 160 | dim3 n_blocks_reduce((M*N)/REDUCTION_THREADS_PER_BLOCK+1); 161 | GEMM128_reduction_kernel<<>>(MATMUL_TABLE_REDUCTION, C, M, K, N); 162 | } 163 | -------------------------------------------------------------------------------- /dpf_gpu/matmul_benchmark.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
2 | 3 | // Benchmark and test 128-bit matmul for DPF 4 | 5 | #include "utils.h" 6 | #include "matmul/matmul.cu" 7 | 8 | #ifndef REPS 9 | #define REPS 10 10 | #endif 11 | 12 | void print_params() { 13 | printf("------------------------------------------------------\n"); 14 | printf("matmul_benchmark.cu:\n"); 15 | printf("------------------------------------------------------\n"); 16 | printf("- Entries in table (K): %d\n", KK); 17 | printf("- Batch size (N): %d\n", NN); 18 | printf("- Entry size (M): %d\n", MM); 19 | printf("------------------------------------------------------\n"); 20 | } 21 | 22 | void alloc_test_matrix(uint128_t_gpu **A_gpu, 23 | uint128_t_gpu **A_cpu, 24 | int M, int N) { 25 | *A_cpu = new uint128_t_gpu[M*N]; 26 | for (int i = 0; i < M; i++) { 27 | for (int j = 0; j < N; j++) { 28 | (*A_cpu)[i*N+j] = uint128_gpu_from((uint128_t)i*N+j); 29 | } 30 | } 31 | 32 | cudaMalloc(A_gpu, sizeof(uint128_t_gpu)*M*N); 33 | cudaMemcpy(*A_gpu, *A_cpu, sizeof(uint128_t_gpu)*M*N, cudaMemcpyHostToDevice); 34 | } 35 | 36 | void check_correct(uint128_t_gpu *A, 37 | uint128_t_gpu *B, 38 | uint128_t_gpu *C, 39 | int M, int K, int N) { 40 | uint128_t_gpu *C_ref = new uint128_t_gpu[M*N]; 41 | memset(C_ref, 0, sizeof(uint128_t_gpu)*M*N); 42 | 43 | // Compute ref solution 44 | for (int i = 0; i < M; i++) { 45 | for (int j = 0; j < K; j++) { 46 | for (int k = 0; k < N; k++) { 47 | uint128_t c = uint128_from_gpu(C_ref[i*N+k]); 48 | uint128_t a = uint128_from_gpu(A[i*K+j]); 49 | uint128_t b = uint128_from_gpu(B[j+k*K]); 50 | uint128_t accum = c+a*b; 51 | C_ref[i*N+k] = uint128_gpu_from(accum); 52 | } 53 | } 54 | } 55 | 56 | // Assert same 57 | for (int i = 0; i < M; i++) { 58 | for (int j = 0; j < N; j++) { 59 | uint128_t_gpu got = C[i*N+j]; 60 | uint128_t_gpu expected = C_ref[i*N+j]; 61 | 62 | assert(got.x == expected.x && 63 | got.y == expected.y && 64 | got.z == expected.z && 65 | got.w == expected.w); 66 | } 67 | } 68 | 69 | printf("PASS CHECKS\n"); 70 | } 71 | 72 | int main(void) { 73 | print_params(); 74 | 75 | // Alloc & Init buffers 76 | uint128_t_gpu *A_gpu, *B_gpu, *C_gpu; 77 | uint128_t_gpu *A_cpu, *B_cpu, *C_cpu; 78 | 79 | alloc_test_matrix(&A_gpu, &A_cpu, MM, KK); 80 | alloc_test_matrix(&B_gpu, &B_cpu, KK, NN); 81 | alloc_test_matrix(&C_gpu, &C_cpu, MM, NN); 82 | 83 | cudaMemset(C_gpu, 0, sizeof(uint128_t_gpu)*MM*NN); 84 | 85 | // Init 86 | initialize_matmul(MM, KK, NN); 87 | 88 | // Kernel benchmark 89 | cudaStream_t s1; 90 | cudaStreamCreate(&s1); 91 | cudaEvent_t start, stop; 92 | cudaEventCreate(&start); 93 | cudaEventCreate(&stop); 94 | 95 | // Run throughput benchmark 96 | cudaEventRecord(start); 97 | for (int i = 0; i < REPS; i++) { 98 | GEMM128(A_gpu, C_gpu, B_gpu, MM, KK, NN, s1); 99 | } 100 | cudaEventRecord(stop); 101 | cudaEventSynchronize(stop); 102 | 103 | // Run latency benchmark 104 | cudaEvent_t start_latency, stop_latency; 105 | cudaEventCreate(&start_latency); 106 | cudaEventCreate(&stop_latency); 107 | cudaEventRecord(start_latency); 108 | 109 | GEMM128(A_gpu, C_gpu, B_gpu, MM, KK, NN, s1); 110 | 111 | cudaEventRecord(stop_latency); 112 | cudaEventSynchronize(stop_latency); 113 | CUDA_CHECK(cudaGetLastError()); 114 | 115 | // Correctness checks 116 | cudaMemcpy(C_cpu, C_gpu, sizeof(uint128_t_gpu)*MM*NN, cudaMemcpyDeviceToHost); 117 | // check_correct(A_cpu, B_cpu, C_cpu, MM, KK, NN); 118 | 119 | // Stats 120 | float ms = 0; 121 | cudaEventElapsedTime(&ms, start, stop); 122 | float throughput_per_query = NN*REPS/ms; 123 | 124 | float ms_latency = 0; 125 | 
cudaEventElapsedTime(&ms_latency, start_latency, stop_latency); 126 | 127 | // Final logging output 128 | printf("{'entries (K)': %d, 'entry_size_ints (M)': %d, 'batch_size (N)': %d," 129 | "'latency_ms' : %f, 'throughput_queries_per_ms' : %f'}\n", 130 | KK, MM, NN, 131 | ms_latency, throughput_per_query); 132 | 133 | cudaFree(A_gpu); 134 | cudaFree(B_gpu); 135 | cudaFree(C_gpu); 136 | 137 | delete A_cpu; 138 | delete B_cpu; 139 | delete C_cpu; 140 | } 141 | -------------------------------------------------------------------------------- /dpf_gpu/prf/prf.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | #include "../utils.h" 4 | #include "prf_algos/aes_core.h" 5 | 6 | #define DUMMY 0 7 | #define SALSA20 1 8 | #define CHACHA20 2 9 | #define AES128 3 10 | 11 | #define PRF_METHOD DUMMY 12 | 13 | std::string get_PRF_method() { 14 | if (PRF_METHOD == DUMMY) return "DUMMY"; 15 | if (PRF_METHOD == SALSA20) return "SALSA20"; 16 | if (PRF_METHOD == CHACHA20) return "CHACHA20"; 17 | if (PRF_METHOD == AES128) return "AES128"; 18 | assert(0); 19 | } 20 | 21 | // Ignore warnings since there are unused variables due to 22 | // swapping out included files 23 | #pragma push 24 | #pragma diag_suppress = 253-D 25 | #pragma diag_suppress = 549-D 26 | #pragma diag_suppress = 550-D 27 | #pragma diag_suppress = code_is_unreachable 28 | #pragma diag_suppress = declared_but_not_referenced 29 | 30 | __device__ uint128_t_gpu PRF_DUMMY(uint128_t_gpu seed, uint32_t i) { 31 | uint128_t_gpu val_4242 = uint128_from(0, 4242); 32 | uint128_t_gpu val_i = uint128_from(0, i); 33 | return add_uint128(mul_uint128(seed, add_uint128(val_4242, val_i)), 34 | add_uint128(val_4242, val_i)); 35 | } 36 | 37 | 38 | // Salsa20 Source: https://en.wikipedia.org/wiki/Salsa20 39 | #define ROTL(a,b) (((a) << (b)) | ((a) >> (32 - (b)))) 40 | #define QR(a, b, c, d)( \ 41 | b ^= ROTL(a + d, 7), \ 42 | c ^= ROTL(b + a, 9), \ 43 | d ^= ROTL(c + b,13), \ 44 | a ^= ROTL(d + c,18)) 45 | 46 | __device__ uint128_t_gpu salsa20_12_gpu(uint128_t_gpu seed, uint32_t pos) { 47 | 48 | // Set up initial state 49 | uint32_t in[16] = {0}; 50 | uint32_t out[16] = {0}; 51 | 52 | // Only use the upper half of 256-bit key 53 | in[1] = (seed.w) & 0xFFFFFFFF; 54 | in[2] = (seed.z) & 0xFFFFFFFF; 55 | in[3] = (seed.y) & 0xFFFFFFFF; 56 | in[4] = (seed.x) & 0xFFFFFFFF; 57 | 58 | // Set position in stream (pos actual value is 32-bit) 59 | in[8] = (pos >> 32) & 0xFFFFFFFF; 60 | in[9] = (pos >> 0) & 0xFFFFFFFF; 61 | 62 | // Rest 63 | in[0] = 0x65787061; 64 | in[5] = 0x6e642033; 65 | in[10] = 0x322d6279; 66 | in[15] = 0x7465206b; 67 | 68 | int i; 69 | uint32_t x[16]; 70 | 71 | for (i = 0; i < 16; ++i) 72 | x[i] = in[i]; 73 | // 10 loops × 2 rounds/loop = 20 rounds 74 | for (i = 0; i < 12; i += 2) { 75 | // Odd round 76 | QR(x[ 0], x[ 4], x[ 8], x[12]); // column 1 77 | QR(x[ 5], x[ 9], x[13], x[ 1]); // column 2 78 | QR(x[10], x[14], x[ 2], x[ 6]); // column 3 79 | QR(x[15], x[ 3], x[ 7], x[11]); // column 4 80 | // Even round 81 | QR(x[ 0], x[ 1], x[ 2], x[ 3]); // row 1 82 | QR(x[ 5], x[ 6], x[ 7], x[ 4]); // row 2 83 | QR(x[10], x[11], x[ 8], x[ 9]); // row 3 84 | QR(x[15], x[12], x[13], x[14]); // row 4 85 | } 86 | for (i = 0; i < 16; ++i) 87 | out[i] = x[i] + in[i]; 88 | 89 | // Use upper half as result of PRF 90 | uint128_t_gpu result; 91 | result.x = out[4]; 92 | result.y = out[3]; 93 | result.z = out[2]; 94 | result.w = out[1]; 95 | return result; 
96 | } 97 | 98 | // ChaCha20 Source: https://en.wikipedia.org/wiki/Salsa20 99 | #define ROTL_CHA(a,b) (((a) << (b)) | ((a) >> (32 - (b)))) 100 | #define QR_CHA(a, b, c, d) ( \ 101 | a += b, d ^= a, d = ROTL_CHA(d,16), \ 102 | c += d, b ^= c, b = ROTL_CHA(b,12), \ 103 | a += b, d ^= a, d = ROTL_CHA(d, 8), \ 104 | c += d, b ^= c, b = ROTL_CHA(b, 7)) 105 | 106 | __device__ uint128_t_gpu chacha20_12_gpu(uint128_t_gpu seed, uint32_t pos) 107 | { 108 | 109 | // Set up initial state 110 | uint32_t in[16] = {0}; 111 | uint32_t out[16] = {0}; 112 | 113 | // Only use the upper half of 256-bit key 114 | in[4] = (seed.w) & 0xFFFFFFFF; 115 | in[5] = (seed.z) & 0xFFFFFFFF; 116 | in[6] = (seed.y) & 0xFFFFFFFF; 117 | in[7] = (seed.x) & 0xFFFFFFFF; 118 | 119 | // Set position in stream (pos actual value is 32-bit) 120 | in[12] = (pos >> 32) & 0xFFFFFFFF; 121 | in[13] = (pos >> 0) & 0xFFFFFFFF; 122 | 123 | // Rest 124 | in[0] = 0x65787061; 125 | in[1] = 0x6e642033; 126 | in[2] = 0x322d6279; 127 | in[3] = 0x7465206b; 128 | 129 | int i; 130 | uint32_t x[16]; 131 | 132 | for (i = 0; i < 16; ++i) 133 | x[i] = in[i]; 134 | // 10 loops × 2 rounds/loop = 20 rounds 135 | for (i = 0; i < 12; i += 2) { 136 | // Odd round 137 | QR_CHA(x[0], x[4], x[ 8], x[12]); // column 0 138 | QR_CHA(x[1], x[5], x[ 9], x[13]); // column 1 139 | QR_CHA(x[2], x[6], x[10], x[14]); // column 2 140 | QR_CHA(x[3], x[7], x[11], x[15]); // column 3 141 | // Even round 142 | QR_CHA(x[0], x[5], x[10], x[15]); // diagonal 1 (main diagonal) 143 | QR_CHA(x[1], x[6], x[11], x[12]); // diagonal 2 144 | QR_CHA(x[2], x[7], x[ 8], x[13]); // diagonal 3 145 | QR_CHA(x[3], x[4], x[ 9], x[14]); // diagonal 4 146 | } 147 | for (i = 0; i < 16; ++i) 148 | out[i] = x[i] + in[i]; 149 | 150 | // Use upper half as result of PRF 151 | uint128_t_gpu result; 152 | result.x = out[7]; 153 | result.y = out[6]; 154 | result.z = out[5]; 155 | result.w = out[4]; 156 | return result; 157 | } 158 | 159 | __device__ uint128_t_gpu aes128_gpu(uint128_t_gpu seed, uint32_t pos) { 160 | unsigned char in[16] = {0}; 161 | unsigned char out[16] = {0}; 162 | const int nr = 10; 163 | 164 | // Input to AES is just the counter (no nonce) 165 | uint128_t_gpu pos_128 = {0}; 166 | pos_128.x = pos; 167 | memcpy(in, &pos_128, 16); 168 | 169 | // Key expansion 170 | AES_KEY k; 171 | AES_set_encrypt_key((const unsigned char *)&seed, 172 | 128, &k); 173 | 174 | 175 | // AES 128 176 | AES_encrypt(in, out, &k); 177 | 178 | // Return output 179 | uint128_t_gpu r = {0}; 180 | memcpy(&r, out, 16); 181 | 182 | return r; 183 | 184 | } 185 | 186 | template 187 | __device__ uint128_t_gpu PRF(uint128_t_gpu seed, uint32_t i) { 188 | if (prf_method == DUMMY) { 189 | return PRF_DUMMY(seed, i); 190 | } 191 | if (prf_method == SALSA20) { 192 | return salsa20_12_gpu(seed, i); 193 | } 194 | if (prf_method == CHACHA20) { 195 | return chacha20_12_gpu(seed, i); 196 | } 197 | if (prf_method == AES128) { 198 | return aes128_gpu(seed, i); 199 | } 200 | assert(0); 201 | } 202 | 203 | #pragma pop 204 | 205 | 206 | -------------------------------------------------------------------------------- /dpf_gpu/tests/test_128_bit.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
2 | 3 | #include "../utils.h" 4 | 5 | void test_uint128_gpu_from() { 6 | uint128_t val = 0x1122334422334455; 7 | val <<= 64; 8 | val |= 0x3344556644556677; 9 | 10 | uint128_t_gpu g1 = uint128_gpu_from(val); 11 | assert(g1.w == 0x11223344); 12 | assert(g1.z == 0x22334455); 13 | assert(g1.y == 0x33445566); 14 | assert(g1.x == 0x44556677); 15 | } 16 | 17 | __global__ void test_uint128_from_kernel(uint128_t_gpu *r) { 18 | *r = uint128_from(0x1234567823456789, 19 | 0x2345678934567890); 20 | } 21 | 22 | void test_uint128_from() { 23 | uint128_t_gpu *r; 24 | cudaMalloc((void **)&r, sizeof(uint128_t_gpu)); 25 | test_uint128_from_kernel<<<1, 1>>>(r); 26 | uint128_t_gpu r_cpu; 27 | cudaMemcpy(&r_cpu, r, sizeof(uint128_t_gpu), cudaMemcpyDeviceToHost); 28 | 29 | assert(r_cpu.w == 0x12345678); 30 | assert(r_cpu.z == 0x23456789); 31 | assert(r_cpu.y == 0x23456789); 32 | assert(r_cpu.x == 0x34567890); 33 | 34 | cudaFree(r); 35 | } 36 | 37 | __global__ void test_add_uint128_kernel(uint128_t_gpu *a, 38 | uint128_t_gpu *b, 39 | uint128_t_gpu *r) { 40 | *r = add_uint128(*a, *b); 41 | } 42 | 43 | void test_add_uint128() { 44 | 45 | // Init v1 and v2 for mult 46 | uint128_t v1 = 0x12345678; 47 | v1 <<= 64; 48 | v1 |= 0x23456789; 49 | 50 | uint128_t v2 = 0x34567890; 51 | v2 <<= 64; 52 | v2 |= 0x45678901; 53 | 54 | uint128_t_gpu a = uint128_gpu_from(v1); 55 | uint128_t_gpu b = uint128_gpu_from(v2); 56 | 57 | // Alloc gpu mem 58 | uint128_t_gpu *r; 59 | cudaMalloc((void **)&r, sizeof(uint128_t_gpu)); 60 | 61 | uint128_t_gpu *a_gpu, *b_gpu; 62 | cudaMalloc((void **)&a_gpu, sizeof(uint128_t_gpu)); 63 | cudaMalloc((void **)&b_gpu, sizeof(uint128_t_gpu)); 64 | cudaMemcpy(a_gpu, &a, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); 65 | cudaMemcpy(b_gpu, &b, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); 66 | 67 | test_add_uint128_kernel<<<1, 1>>>(a_gpu, b_gpu, r); 68 | uint128_t_gpu r_cpu; 69 | cudaMemcpy(&r_cpu, r, sizeof(uint128_t_gpu), cudaMemcpyDeviceToHost); 70 | 71 | uint128_t truth = v1+v2; 72 | assert(r_cpu.x == (truth & 0xFFFFFFFF)); 73 | assert(r_cpu.y == ((truth & 0xFFFFFFFF00000000) >> 32)); 74 | assert(r_cpu.w == truth >> 96); 75 | assert(r_cpu.z == ((truth >> 64) & 0xFFFFFFFF)); 76 | 77 | cudaFree(r); 78 | cudaFree(a_gpu); 79 | cudaFree(b_gpu); 80 | } 81 | 82 | __global__ void test_mul_uint128_kernel(uint128_t_gpu *a, 83 | uint128_t_gpu *b, 84 | uint128_t_gpu *r) { 85 | *r = mul_uint128(*a, *b); 86 | } 87 | 88 | void test_mul_uint128() { 89 | 90 | // Init v1 and v2 for mult 91 | uint128_t v1 = 0x12345678; 92 | v1 <<= 64; 93 | v1 |= 0x23456789; 94 | 95 | uint128_t v2 = 0x34567890; 96 | v2 <<= 64; 97 | v2 |= 0x45678901; 98 | 99 | uint128_t_gpu a = uint128_gpu_from(v1); 100 | uint128_t_gpu b = uint128_gpu_from(v2); 101 | 102 | // Alloc gpu mem 103 | uint128_t_gpu *r; 104 | cudaMalloc((void **)&r, sizeof(uint128_t_gpu)); 105 | 106 | uint128_t_gpu *a_gpu, *b_gpu; 107 | cudaMalloc((void **)&a_gpu, sizeof(uint128_t_gpu)); 108 | cudaMalloc((void **)&b_gpu, sizeof(uint128_t_gpu)); 109 | cudaMemcpy(a_gpu, &a, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); 110 | cudaMemcpy(b_gpu, &b, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); 111 | 112 | test_mul_uint128_kernel<<<1, 1>>>(a_gpu, b_gpu, r); 113 | uint128_t_gpu r_cpu; 114 | cudaMemcpy(&r_cpu, r, sizeof(uint128_t_gpu), cudaMemcpyDeviceToHost); 115 | 116 | uint128_t truth = v1*v2; 117 | 118 | assert(r_cpu.x == (truth & 0xFFFFFFFF)); 119 | assert(r_cpu.y == ((truth & 0xFFFFFFFF00000000) >> 32)); 120 | assert(r_cpu.w == truth >> 96); 121 | assert(r_cpu.z == ((truth 
>> 64) & 0xFFFFFFFF)); 122 | 123 | cudaFree(r); 124 | cudaFree(a_gpu); 125 | cudaFree(b_gpu); 126 | } 127 | 128 | __global__ void test_mul_uint128_kernel_twice(uint128_t_gpu *a, 129 | uint128_t_gpu *b, 130 | uint128_t_gpu *c, 131 | uint128_t_gpu *r) { 132 | *r = mul_uint128(mul_uint128(*a, *b), *c); 133 | } 134 | 135 | void test_mul_uint128_twice() { 136 | 137 | // Init v1 and v2 for mult 138 | uint128_t v1 = 0x12345678; 139 | v1 <<= 64; 140 | v1 |= 0x23456789; 141 | 142 | uint128_t v2 = 0x34567890; 143 | v2 <<= 64; 144 | v2 |= 0x45678901; 145 | 146 | uint128_t v3 = 0x123; 147 | v3 <<= 64; 148 | v3 |= 0x456; 149 | 150 | uint128_t_gpu a = uint128_gpu_from(v1); 151 | uint128_t_gpu b = uint128_gpu_from(v2); 152 | uint128_t_gpu c = uint128_gpu_from(v3); 153 | 154 | // Alloc gpu mem 155 | uint128_t_gpu *r; 156 | cudaMalloc((void **)&r, sizeof(uint128_t_gpu)); 157 | 158 | uint128_t_gpu *a_gpu, *b_gpu, *c_gpu; 159 | cudaMalloc((void **)&a_gpu, sizeof(uint128_t_gpu)); 160 | cudaMalloc((void **)&b_gpu, sizeof(uint128_t_gpu)); 161 | cudaMalloc((void **)&c_gpu, sizeof(uint128_t_gpu)); 162 | cudaMemcpy(a_gpu, &a, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); 163 | cudaMemcpy(b_gpu, &b, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); cudaMemcpy(c_gpu, &c, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); 164 | 165 | test_mul_uint128_kernel_twice<<<1, 1>>>(a_gpu, b_gpu, c_gpu, r); 166 | uint128_t_gpu r_cpu; 167 | cudaMemcpy(&r_cpu, r, sizeof(uint128_t_gpu), cudaMemcpyDeviceToHost); 168 | 169 | uint128_t truth = (v1*v2)*v3; 170 | 171 | assert(r_cpu.x == (truth & 0xFFFFFFFF)); 172 | assert(r_cpu.y == ((truth & 0xFFFFFFFF00000000) >> 32)); 173 | assert(r_cpu.w == truth >> 96); 174 | assert(r_cpu.z == ((truth >> 64) & 0xFFFFFFFF)); 175 | 176 | cudaFree(r); 177 | cudaFree(a_gpu); 178 | cudaFree(b_gpu); 179 | } 180 | 181 | void test_uint128_gpu_conversion() { 182 | for (int i = 0; i < 1000; i++) { 183 | uint128_t k = i * 0x12345; 184 | 185 | uint128_t_gpu v = uint128_gpu_from(k); 186 | uint128_t v_back = uint128_from_gpu(v); 187 | 188 | assert(v_back == k); 189 | } 190 | } 191 | 192 | int main(void) { 193 | test_uint128_gpu_from(); 194 | test_uint128_from(); 195 | test_add_uint128(); 196 | test_mul_uint128(); 197 | test_mul_uint128_twice(); 198 | test_uint128_gpu_conversion(); 199 | printf("PASS\n"); 200 | } 201 | -------------------------------------------------------------------------------- /dpf_gpu/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
2 | 3 | #ifndef UTILS 4 | #define UTILS 5 | 6 | #include "../dpf_base/dpf.h" 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | //////////////////////////////////////////////////////////////////////////////// 13 | // 128-bit functionalities // 14 | // from: https://stackoverflow.com/questions/6162140/128-bit-integer-on-cuda // 15 | //////////////////////////////////////////////////////////////////////////////// 16 | typedef uint4 uint128_t_gpu; 17 | 18 | uint128_t_gpu uint128_gpu_from(uint128_t val) { 19 | uint128_t_gpu res; 20 | res.w = (val >> 96) & 0xFFFFFFFF; 21 | res.z = (val >> 64) & 0xFFFFFFFF; 22 | res.y = (val >> 32) & 0xFFFFFFFF; 23 | res.x = (val >> 0) & 0xFFFFFFFF; 24 | return res; 25 | } 26 | 27 | uint128_t uint128_from_gpu(uint128_t_gpu val) { 28 | uint128_t res = 0; 29 | return val.x + 30 | ((uint128_t)val.y << 32) + 31 | ((uint128_t)val.z << 64) + 32 | ((uint128_t)val.w << 96); 33 | } 34 | 35 | __device__ uint128_t_gpu uint128_from(uint64_t hi, 36 | uint64_t lo) { 37 | uint128_t_gpu res; 38 | res.w = (hi >> 32); 39 | res.z = hi & 0x00000000FFFFFFFF; 40 | res.y = (lo >> 32); 41 | res.x = lo & 0x00000000FFFFFFFF; 42 | return res; 43 | } 44 | 45 | __device__ uint128_t_gpu add_uint128(uint128_t_gpu addend, uint128_t_gpu augend) 46 | { 47 | uint128_t_gpu res; 48 | asm ("add.cc.u32 %0, %4, %8;\n\t" 49 | "addc.cc.u32 %1, %5, %9;\n\t" 50 | "addc.cc.u32 %2, %6, %10;\n\t" 51 | "addc.u32 %3, %7, %11;\n\t" 52 | : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) 53 | : "r"(addend.x), "r"(addend.y), "r"(addend.z), "r"(addend.w), 54 | "r"(augend.x), "r"(augend.y), "r"(augend.z), "r"(augend.w)); 55 | return res; 56 | } 57 | 58 | __device__ uint128_t_gpu mul_uint128(uint128_t_gpu a, uint128_t_gpu b) 59 | { 60 | uint128_t_gpu res; 61 | asm ("{\n\t" 62 | "mul.lo.u32 %0, %4, %8; \n\t" 63 | "mul.hi.u32 %1, %4, %8; \n\t" 64 | "mad.lo.cc.u32 %1, %4, %9, %1;\n\t" 65 | "madc.hi.u32 %2, %4, %9, 0;\n\t" 66 | "mad.lo.cc.u32 %1, %5, %8, %1;\n\t" 67 | "madc.hi.cc.u32 %2, %5, %8, %2;\n\t" 68 | "madc.hi.u32 %3, %4,%10, 0;\n\t" 69 | "mad.lo.cc.u32 %2, %4,%10, %2;\n\t" 70 | "madc.hi.u32 %3, %5, %9, %3;\n\t" 71 | "mad.lo.cc.u32 %2, %5, %9, %2;\n\t" 72 | "madc.hi.u32 %3, %6, %8, %3;\n\t" 73 | "mad.lo.cc.u32 %2, %6, %8, %2;\n\t" 74 | "madc.lo.u32 %3, %4,%11, %3;\n\t" 75 | "mad.lo.u32 %3, %5,%10, %3;\n\t" 76 | "mad.lo.u32 %3, %6, %9, %3;\n\t" 77 | "mad.lo.u32 %3, %7, %8, %3;\n\t" 78 | "}" 79 | : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) 80 | : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), 81 | "r"(b.x), "r"(b.y), "r"(b.z), "r"(b.w)); 82 | return res; 83 | } 84 | 85 | // Error check functionality 86 | inline void error_check(cudaError_t err, const char* file, int line) { 87 | if(err != cudaSuccess) { 88 | ::fprintf(stderr, "CUDA ERROR at %s[%d] : %s\n", file, line, cudaGetErrorString(err)); 89 | abort(); 90 | } 91 | } 92 | #define CUDA_CHECK(err) do { error_check(err, __FILE__, __LINE__); } while(0) 93 | 94 | // SeedsCodewordsFlat for GPU (replaces uint128_t with vector version) 95 | struct SeedsCodewordsFlatGPU { 96 | int depth; 97 | uint128_t_gpu cw_1[64], cw_2[64]; 98 | uint128_t_gpu last_keys[1]; 99 | }; 100 | 101 | SeedsCodewordsFlatGPU SeedsCodewordsFlatGPUFromCPU(SeedsCodewordsFlat &f) { 102 | SeedsCodewordsFlatGPU g; 103 | g.depth = f.depth; 104 | for (int i = 0; i < 64; i++) { 105 | g.cw_1[i] = uint128_gpu_from(f.cw_1[i]); 106 | g.cw_2[i] = uint128_gpu_from(f.cw_2[i]); 107 | } 108 | g.last_keys[0] = uint128_gpu_from(f.last_keys[0]); 109 | return g; 110 | } 111 | 112 | /*// Generates 
dummy codewords for testint 113 | std::vector GenCodewords(int k, int n, 114 | SeedsCodewordsFlatGPU **cw_gpu) { 115 | 116 | //auto cw_cpu = std::vector(n); 117 | for (int i = 0; i < n; i++) { 118 | 119 | std::mt19937 g_gen(i); 120 | int alpha = (100+i) % k; 121 | int beta = 4242+i; 122 | 123 | SeedsCodewords *s = GenerateSeedsAndCodewordsLog(alpha, beta, k, g_gen); 124 | FlattenCodewords(s, 0, &cw_cpu[i]); 125 | FreeSeedsCodewords(s); 126 | } 127 | 128 | // Convert codewords to gpu rep 129 | SeedsCodewordsFlatGPU *cw_intermediate = (SeedsCodewordsFlatGPU *)malloc(sizeof(SeedsCodewordsFlatGPU)*n); 130 | for (int i = 0; i < n; i++) { 131 | cw_intermediate[i] = SeedsCodewordsFlatGPUFromCPU(cw_cpu[i]); 132 | } 133 | 134 | cudaMalloc((void **)cw_gpu, sizeof(SeedsCodewordsFlatGPU)*n); 135 | cudaMemcpy(*cw_gpu, cw_intermediate, sizeof(SeedsCodewordsFlatGPU)*(n), cudaMemcpyHostToDevice); 136 | free(cw_intermediate); 137 | 138 | return cw_cpu; 139 | }*/ 140 | 141 | // https://stackoverflow.com/questions/9144800/c-reverse-bits-in-unsigned-integer 142 | uint32_t brev_cpu(uint32_t x) { 143 | x = ((x >> 1) & 0x55555555u) | ((x & 0x55555555u) << 1); 144 | x = ((x >> 2) & 0x33333333u) | ((x & 0x33333333u) << 2); 145 | x = ((x >> 4) & 0x0f0f0f0fu) | ((x & 0x0f0f0f0fu) << 4); 146 | x = ((x >> 8) & 0x00ff00ffu) | ((x & 0x00ff00ffu) << 8); 147 | x = ((x >> 16) & 0xffffu) | ((x & 0xffffu) << 16); 148 | return x; 149 | } 150 | 151 | // Correctness checks the output of GPU kernel code 152 | void check_correct(SeedsCodewordsFlat *cw, uint128_t_gpu *target, 153 | int batch_size, int num_entries, 154 | int permutated_ordering) { 155 | int zz = 0; 156 | for (int i = 0; i < batch_size; i++) { 157 | for (int j = 0; j < num_entries; j++) { 158 | 159 | uint128_t truth = EvaluateFlat(&cw[i], j, 0); 160 | uint128_t_gpu truth_128_t_gpu = uint128_gpu_from(truth); 161 | 162 | // This is the "standard" ordering 163 | uint128_t_gpu got; 164 | if (!permutated_ordering) { 165 | got = target[j*batch_size+i]; 166 | } 167 | else { 168 | // This is the "permutated" ordering 169 | //int tgt_indx = brev_cpu(j) >> 32 - cw[0].depth; 170 | int tgt_indx = brev_cpu(j) >> 32 - (int)log2(num_entries); 171 | got = target[tgt_indx + i*num_entries]; 172 | } 173 | 174 | // For debugging 175 | //printf("Got : %d %d %d %d\n", got.x, got.y, got.z, got.w); 176 | //printf("Expect: %d %d %d %d\n", truth_128_t_gpu.x, truth_128_t_gpu.y, truth_128_t_gpu.z, truth_128_t_gpu.w); 177 | //zz += 1; 178 | //if (zz >= 100) return; 179 | 180 | assert(got.x == truth_128_t_gpu.x && 181 | got.y == truth_128_t_gpu.y && 182 | got.z == truth_128_t_gpu.z && 183 | got.w == truth_128_t_gpu.w); 184 | } 185 | } 186 | printf("PASS\n"); 187 | } 188 | 189 | void check_correct_fused(SeedsCodewordsFlat *cw, uint128_t_gpu *target, uint128_t_gpu *table, 190 | int entry_size, int batch_size, int num_entries) { 191 | for (int i = 0; i < batch_size; i++) { 192 | for (int k = 0; k < entry_size; k++) { 193 | uint128_t accum = 0; 194 | for (int j = 0; j < num_entries; j++) { 195 | uint128_t truth = EvaluateFlat(&cw[i], j, 0); 196 | accum += truth * uint128_from_gpu(table[j+k*num_entries]); 197 | } 198 | 199 | uint128_t_gpu cmp = uint128_gpu_from(accum); 200 | uint128_t_gpu got = target[i+k*batch_size]; 201 | 202 | assert(got.x == cmp.x && 203 | got.y == cmp.y && 204 | got.z == cmp.z && 205 | got.w == cmp.w); 206 | } 207 | } 208 | printf("PASS MATMUL CHECK\n"); 209 | } 210 | 211 | #endif 212 | -------------------------------------------------------------------------------- 
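A note on the 128-bit layout used above: uint128_gpu_from / uint128_from_gpu in dpf_gpu/utils.h store a 128-bit value as four 32-bit limbs of a uint4, little-endian by limb (x = bits 0-31, y = bits 32-63, z = bits 64-95, w = bits 96-127), and the assertions in dpf_gpu/tests/test_128_bit.cu check the GPU results against exactly these limbs. A minimal Python sketch of the same packing, for illustration only (the helpers below are hypothetical, not part of the repository):

MASK32 = 0xFFFFFFFF

def limbs_from_uint128(val):
    # Mirror uint128_gpu_from: split into 32-bit limbs, low to high.
    return {"x": val & MASK32,
            "y": (val >> 32) & MASK32,
            "z": (val >> 64) & MASK32,
            "w": (val >> 96) & MASK32}

def uint128_from_limbs(limbs):
    # Mirror uint128_from_gpu: reassemble the value from its limbs.
    return (limbs["x"]
            + (limbs["y"] << 32)
            + (limbs["z"] << 64)
            + (limbs["w"] << 96))

# Round-trip, plus the limb-by-limb comparison pattern used by
# test_add_uint128 / test_mul_uint128: compute the reference result with
# plain 128-bit arithmetic (mod 2**128) and compare each limb.
v1 = (0x12345678 << 64) | 0x23456789
v2 = (0x34567890 << 64) | 0x45678901
truth = (v1 + v2) & ((1 << 128) - 1)
limbs = limbs_from_uint128(truth)
assert uint128_from_limbs(limbs) == truth
assert limbs["w"] == truth >> 96 and limbs["x"] == truth & MASK32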
/dpf_wrapper.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #define gpuErrchk(ans) { gpuAssert((ans), __FILE__, __LINE__); } 8 | inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) 9 | { 10 | if (code != cudaSuccess) 11 | { 12 | fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line); 13 | if (abort) exit(code); 14 | } 15 | } 16 | 17 | // Number of uint128_ts per entry 18 | #define MM 16 19 | 20 | // Batch size 21 | #define BATCH_SIZE 512 22 | 23 | #include "dpf_base/dpf.h" 24 | #include "dpf_gpu/dpf/dpf_hybrid.cu" 25 | 26 | at::Tensor key_from_codewords(SeedsCodewordsFlat *k, int n) { 27 | at::Tensor key = torch::zeros({524}, at::kInt); 28 | uint128_t *key_ptr = (uint128_t *)key.data_ptr(); 29 | key_ptr[0] = k->depth; 30 | memcpy(&key_ptr[1], k->cw_1, sizeof(uint128_t)*64); 31 | memcpy(&key_ptr[65], k->cw_2, sizeof(uint128_t)*64); 32 | key_ptr[129] = k->last_keys[0]; 33 | key_ptr[130] = n; 34 | return key; 35 | } 36 | 37 | SeedsCodewordsFlat *codewords_from_key(at::Tensor &key, int *n) { 38 | SeedsCodewordsFlat *k = new SeedsCodewordsFlat; 39 | uint128_t *key_ptr = (uint128_t *)key.data_ptr(); 40 | k->depth = key_ptr[0]; 41 | memcpy(k->cw_1, &key_ptr[1], sizeof(uint128_t)*64); 42 | memcpy(k->cw_2, &key_ptr[65], sizeof(uint128_t)*64); 43 | k->last_keys[0] = key_ptr[129]; 44 | *n = key_ptr[130]; 45 | return k; 46 | } 47 | 48 | // Note: s1, and s2 are seeds that concatenate to form a 128-bit seed 49 | std::vector gen(int k, int n, char *seed, int prf_method) { 50 | 51 | // Generate DPF codewords and flatten them 52 | std::mt19937 g_gen(*(uint128_t *)seed); 53 | SeedsCodewords *s = GenerateSeedsAndCodewordsLog(k, 1, n, g_gen, prf_method); 54 | SeedsCodewordsFlat *k_1 = new SeedsCodewordsFlat; 55 | SeedsCodewordsFlat *k_2 = new SeedsCodewordsFlat; 56 | FlattenCodewords(s, 0, k_1); 57 | FlattenCodewords(s, 1, k_2); 58 | 59 | // Copy over to tensor 60 | at::Tensor key_1 = key_from_codewords(k_1, n); 61 | at::Tensor key_2 = key_from_codewords(k_2, n); 62 | 63 | FreeSeedsCodewords(s); 64 | free(k_1); 65 | free(k_2); 66 | 67 | return {key_1, key_2}; 68 | } 69 | 70 | at::Tensor eval_dpf_cpu(at::Tensor key, int prf_method) { 71 | 72 | int n; 73 | SeedsCodewordsFlat *k = codewords_from_key(key, &n); 74 | 75 | // Expand codewords 76 | at::Tensor result = torch::ones({n}, at::kInt); 77 | 78 | // CPU expansion 79 | for (int i = 0; i < n; i++) { 80 | result[i] = (int)EvaluateFlat(k, i, prf_method); 81 | } 82 | 83 | return result; 84 | } 85 | 86 | void eval_free(std::vector buffers) { 87 | cudaFree(buffers[0]); 88 | cudaFree(buffers[1]); 89 | cudaFree(buffers[2]); 90 | dpf_hybrid_deinitialize(); 91 | } 92 | 93 | std::vector eval_init(at::Tensor table) { 94 | 95 | 96 | int num_entries = table.size(0); 97 | int entry_size = table.size(1); 98 | 99 | assert((num_entries & (num_entries-1)) == 0); 100 | assert(entry_size == MM); 101 | 102 | // Initialize the table on GPU memory 103 | uint128_t_gpu *table_reordered_cvted = new uint128_t_gpu[num_entries*entry_size]; 104 | for (int j = 0; j < entry_size; j++) { 105 | for (int i = 0; i < num_entries; i++) { 106 | int reordered_indx = brev_cpu(i) >> 32 - (int)log2(num_entries); 107 | table_reordered_cvted[i+j*num_entries] = uint128_gpu_from((uint128_t)table[reordered_indx][j].item()); 108 | } 109 | } 110 | 111 | uint128_t_gpu *TABLE; 112 | 113 | // 
Alloc and cpy to uint128_t_gpu array 114 | gpuErrchk(cudaMalloc(&TABLE, sizeof(uint128_t_gpu)*num_entries*entry_size)); 115 | cudaMemcpy(TABLE, table_reordered_cvted, sizeof(uint128_t_gpu)*num_entries*entry_size, cudaMemcpyHostToDevice); 116 | 117 | delete table_reordered_cvted; 118 | 119 | // Allocate gpu buffer for the input keys 120 | SeedsCodewordsFlatGPU *CW_GPU; 121 | gpuErrchk(cudaMalloc((void **)&CW_GPU, sizeof(SeedsCodewordsFlatGPU)*BATCH_SIZE)); 122 | 123 | // Allocate gpu buffer for the output 124 | uint128_t_gpu *OUT; 125 | gpuErrchk(cudaMalloc((void **)&OUT, sizeof(uint128_t_gpu)*BATCH_SIZE*MM)); 126 | cudaMemset(OUT, sizeof(uint128_t_gpu)*BATCH_SIZE*MM, 0); 127 | 128 | // Initialize hybrid strat 129 | dpf_hybrid_initialize(BATCH_SIZE, num_entries); 130 | 131 | return {TABLE, CW_GPU, OUT}; 132 | } 133 | 134 | at::Tensor eval_gpu(std::vector keys, std::vector buffers, int n, int prf_method) { 135 | assert(keys.size() == BATCH_SIZE); 136 | 137 | SeedsCodewordsFlatGPU *cw_intermediate = (SeedsCodewordsFlatGPU *)malloc(sizeof(SeedsCodewordsFlatGPU)*BATCH_SIZE); 138 | 139 | // Convert seeds/codewords to CW_GPU 140 | for (int i = 0; i < keys.size(); i++) { 141 | int num_entries = 0; 142 | SeedsCodewordsFlat *k = codewords_from_key(keys[i], &num_entries); 143 | assert(num_entries == n); 144 | cw_intermediate[i] = SeedsCodewordsFlatGPUFromCPU(*k); 145 | free(k); 146 | } 147 | 148 | // Copy to codewords to GPU buffer 149 | SeedsCodewordsFlatGPU *CW_GPU = (SeedsCodewordsFlatGPU *)buffers[1]; 150 | cudaMemcpy(CW_GPU, cw_intermediate, sizeof(SeedsCodewordsFlatGPU)*(keys.size()), cudaMemcpyHostToDevice); 151 | uint128_t_gpu *TABLE = (uint128_t_gpu *)buffers[0]; 152 | uint128_t_gpu *OUT = (uint128_t_gpu *)buffers[2]; 153 | 154 | // Perform batched dpf lookup 155 | cudaStream_t s; 156 | cudaStreamCreate(&s); 157 | if (prf_method == DUMMY) { 158 | dpf_hybrid(CW_GPU, OUT, TABLE, BATCH_SIZE, n, s); 159 | } 160 | else if (prf_method == SALSA20) { 161 | dpf_hybrid(CW_GPU, OUT, TABLE, BATCH_SIZE, n, s); 162 | } 163 | else if (prf_method == CHACHA20) { 164 | dpf_hybrid(CW_GPU, OUT, TABLE, BATCH_SIZE, n, s); 165 | } 166 | else if (prf_method == AES128) { 167 | dpf_hybrid(CW_GPU, OUT, TABLE, BATCH_SIZE, n, s); 168 | } 169 | 170 | else { 171 | assert(0); 172 | } 173 | 174 | // Cvt GPU output to CPU output 175 | uint128_t_gpu out_cpu[BATCH_SIZE*MM]; 176 | cudaMemcpy(out_cpu, OUT, sizeof(uint128_t_gpu)*BATCH_SIZE*MM, cudaMemcpyDeviceToHost); 177 | 178 | at::Tensor result = torch::zeros({BATCH_SIZE, MM}, at::kInt); 179 | uint32_t *r = (uint32_t *)result.data_ptr(); 180 | for (int i = 0; i < BATCH_SIZE; i++) { 181 | for (int j = 0; j < MM; j++) { 182 | r[i*MM+j] = (uint32_t)uint128_from_gpu(out_cpu[i+j*BATCH_SIZE]); 183 | } 184 | } 185 | return result; 186 | } 187 | 188 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 189 | // Funcs 190 | m.def("gen", &gen, "dpf gen"); 191 | m.def("eval_cpu", &eval_dpf_cpu, "dpf eval cpu"); 192 | m.def("eval_gpu", &eval_gpu, "dpf eval gpu"); 193 | m.def("eval_init", &eval_init, "dpf eval init"); 194 | m.def("eval_free", &eval_free, "dpf eval free"); 195 | 196 | // Consts 197 | m.attr("ENTRY_SIZE") = py::int_(MM); 198 | m.attr("BATCH_SIZE") = py::int_(BATCH_SIZE); 199 | 200 | m.attr("PRF_DUMMY") = py::int_(DUMMY); 201 | m.attr("PRF_SALSA20") = py::int_(SALSA20); 202 | m.attr("PRF_CHACHA20") = py::int_(CHACHA20); 203 | m.attr("PRF_AES128") = py::int_(AES128); 204 | } 205 | -------------------------------------------------------------------------------- /imgs/1.png: 
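For reference, dpf_wrapper.cu above flattens each DPF key into a 524-element int32 tensor, i.e. 131 consecutive 128-bit slots: slot 0 is depth, slots 1-64 are cw_1, slots 65-128 are cw_2, slot 129 is last_keys[0], and slot 130 is n. eval_init additionally lays the table out in bit-reversed row order before copying it to the GPU. A small Python sketch of these two conventions (the names below are illustrative, not part of the repository):

# Slot map used by key_from_codewords / codewords_from_key, viewing the
# 524 int32s as 131 consecutive 128-bit slots.
KEY_SLOTS = {
    "depth":     0,
    "cw_1":      range(1, 65),
    "cw_2":      range(65, 129),
    "last_keys": 129,
    "n":         130,
}
assert 131 * 4 == 524  # 131 x 128-bit slots == 524 int32 entries

def bit_reversed_index(i, num_entries):
    # The permutation applied in eval_init:
    # reordered_indx = brev_cpu(i) >> (32 - log2(num_entries)).
    bits = num_entries.bit_length() - 1       # log2 of a power-of-two table
    rev = int(format(i, "032b")[::-1], 2)     # 32-bit bit reversal (brev_cpu)
    return rev >> (32 - bits)

# Example: with 8 entries, row 1 (0b001) is read from row 4 (0b100).
assert bit_reversed_index(1, 8) == 4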
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/GPU-DPF/ce23a06af884ee54300b5bc5fd5350e445f10b0b/imgs/1.png -------------------------------------------------------------------------------- /imgs/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/GPU-DPF/ce23a06af884ee54300b5bc5fd5350e445f10b0b/imgs/2.png -------------------------------------------------------------------------------- /imgs/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/GPU-DPF/ce23a06af884ee54300b5bc5fd5350e445f10b0b/imgs/3.png -------------------------------------------------------------------------------- /imgs/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/GPU-DPF/ce23a06af884ee54300b5bc5fd5350e445f10b0b/imgs/4.png -------------------------------------------------------------------------------- /imgs/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/GPU-DPF/ce23a06af884ee54300b5bc5fd5350e445f10b0b/imgs/5.png -------------------------------------------------------------------------------- /imgs/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/GPU-DPF/ce23a06af884ee54300b5bc5fd5350e445f10b0b/imgs/6.png -------------------------------------------------------------------------------- /imgs/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/GPU-DPF/ce23a06af884ee54300b5bc5fd5350e445f10b0b/imgs/7.png -------------------------------------------------------------------------------- /imgs/dpf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/GPU-DPF/ce23a06af884ee54300b5bc5fd5350e445f10b0b/imgs/dpf.png -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | CC=g++ python setup.py install 3 | -------------------------------------------------------------------------------- /paper/experimental/.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *# 3 | -------------------------------------------------------------------------------- /paper/experimental/batch_pir/modules/language_model/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
2 | 3 | import os 4 | from io import open 5 | import torch 6 | 7 | class Dictionary(object): 8 | def __init__(self): 9 | self.word2idx = {} 10 | self.idx2word = [] 11 | 12 | def add_word(self, word): 13 | if word not in self.word2idx: 14 | self.idx2word.append(word) 15 | self.word2idx[word] = len(self.idx2word) - 1 16 | return self.word2idx[word] 17 | 18 | def __len__(self): 19 | return len(self.idx2word) 20 | 21 | 22 | class Corpus(object): 23 | def __init__(self, path): 24 | self.dictionary = Dictionary() 25 | self.train = self.tokenize(os.path.join(path, 'train.txt')) 26 | self.valid = self.tokenize(os.path.join(path, 'valid.txt')) 27 | self.test = self.tokenize(os.path.join(path, 'test.txt')) 28 | 29 | def tokenize(self, path): 30 | """Tokenizes a text file.""" 31 | assert os.path.exists(path) 32 | # Add words to the dictionary 33 | with open(path, 'r', encoding="utf8") as f: 34 | for line in f: 35 | words = line.split() + [''] 36 | for word in words: 37 | self.dictionary.add_word(word) 38 | 39 | # Tokenize file content 40 | with open(path, 'r', encoding="utf8") as f: 41 | idss = [] 42 | for line in f: 43 | words = line.split() + [''] 44 | ids = [] 45 | for word in words: 46 | ids.append(self.dictionary.word2idx[word]) 47 | idss.append(torch.tensor(ids).type(torch.int64)) 48 | ids = torch.cat(idss) 49 | 50 | return ids 51 | -------------------------------------------------------------------------------- /paper/experimental/batch_pir/modules/language_model/language_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | class RNNModel(nn.Module): 9 | """Container module with an encoder, a recurrent module, and a decoder.""" 10 | 11 | def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False): 12 | super(RNNModel, self).__init__() 13 | self.ntoken = ntoken 14 | self.drop = nn.Dropout(dropout) 15 | self.encoder = nn.Embedding(ntoken, ninp) 16 | if rnn_type in ['LSTM', 'GRU']: 17 | self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout) 18 | else: 19 | try: 20 | nonlinearity = {'RNN_TANH': 'tanh', 'RNN_RELU': 'relu'}[rnn_type] 21 | except KeyError: 22 | raise ValueError( """An invalid option for `--model` was supplied, 23 | options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']""") 24 | self.rnn = nn.RNN(ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout) 25 | self.decoder = nn.Linear(nhid, ntoken) 26 | 27 | # Optionally tie weights as in: 28 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) 29 | # https://arxiv.org/abs/1608.05859 30 | # and 31 | # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 
2016) 32 | # https://arxiv.org/abs/1611.01462 33 | if tie_weights: 34 | if nhid != ninp: 35 | raise ValueError('When using the tied flag, nhid must be equal to emsize') 36 | self.decoder.weight = self.encoder.weight 37 | 38 | self.init_weights() 39 | 40 | self.rnn_type = rnn_type 41 | self.nhid = nhid 42 | self.nlayers = nlayers 43 | 44 | def init_weights(self): 45 | initrange = 0.1 46 | nn.init.uniform_(self.encoder.weight, -initrange, initrange) 47 | nn.init.zeros_(self.decoder.bias) 48 | nn.init.uniform_(self.decoder.weight, -initrange, initrange) 49 | 50 | def forward(self, input, hidden): 51 | emb = self.drop(self.encoder(input)) 52 | output, hidden = self.rnn(emb, hidden) 53 | output = self.drop(output) 54 | decoded = self.decoder(output) 55 | decoded = decoded.view(-1, self.ntoken) 56 | return F.log_softmax(decoded, dim=1), hidden 57 | 58 | def init_hidden(self, bsz): 59 | weight = next(self.parameters()) 60 | if self.rnn_type == 'LSTM': 61 | return (weight.new_zeros(self.nlayers, bsz, self.nhid), 62 | weight.new_zeros(self.nlayers, bsz, self.nhid)) 63 | else: 64 | return weight.new_zeros(self.nlayers, bsz, self.nhid) 65 | 66 | # Temporarily leave PositionalEncoding module here. Will be moved somewhere else. 67 | class PositionalEncoding(nn.Module): 68 | r"""Inject some information about the relative or absolute position of the tokens in the sequence. 69 | The positional encodings have the same dimension as the embeddings, so that the two can be summed. 70 | Here, we use sine and cosine functions of different frequencies. 71 | .. math: 72 | \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) 73 | \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) 74 | \text{where pos is the word position and i is the embed idx) 75 | Args: 76 | d_model: the embed dim (required). 77 | dropout: the dropout value (default=0.1). 78 | max_len: the max. length of the incoming sequence (default=5000). 79 | Examples: 80 | >>> pos_encoder = PositionalEncoding(d_model) 81 | """ 82 | 83 | def __init__(self, d_model, dropout=0.1, max_len=5000): 84 | super(PositionalEncoding, self).__init__() 85 | self.dropout = nn.Dropout(p=dropout) 86 | 87 | pe = torch.zeros(max_len, d_model) 88 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 89 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 90 | pe[:, 0::2] = torch.sin(position * div_term) 91 | pe[:, 1::2] = torch.cos(position * div_term) 92 | pe = pe.unsqueeze(0).transpose(0, 1) 93 | self.register_buffer('pe', pe) 94 | 95 | def forward(self, x): 96 | r"""Inputs of forward function 97 | Args: 98 | x: the sequence fed to the positional encoder model (required). 
99 | Shape: 100 | x: [sequence length, batch size, embed dim] 101 | output: [sequence length, batch size, embed dim] 102 | Examples: 103 | >>> output = pos_encoder(x) 104 | """ 105 | 106 | x = x + self.pe[:x.size(0), :] 107 | return self.dropout(x) 108 | 109 | class TransformerModel(nn.Module): 110 | """Container module with an encoder, a recurrent or transformer module, and a decoder.""" 111 | 112 | def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): 113 | super(TransformerModel, self).__init__() 114 | try: 115 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 116 | except: 117 | raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or lower.') 118 | self.model_type = 'Transformer' 119 | self.src_mask = None 120 | self.pos_encoder = PositionalEncoding(ninp, dropout) 121 | encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) 122 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 123 | self.encoder = nn.Embedding(ntoken, ninp) 124 | self.ninp = ninp 125 | self.decoder = nn.Linear(ninp, ntoken) 126 | 127 | self.init_weights() 128 | 129 | def _generate_square_subsequent_mask(self, sz): 130 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 131 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 132 | return mask 133 | 134 | def init_weights(self): 135 | initrange = 0.1 136 | nn.init.uniform_(self.encoder.weight, -initrange, initrange) 137 | nn.init.zeros_(self.decoder.bias) 138 | nn.init.uniform_(self.decoder.weight, -initrange, initrange) 139 | 140 | def forward(self, src, has_mask=True): 141 | if has_mask: 142 | device = src.device 143 | if self.src_mask is None or self.src_mask.size(0) != len(src): 144 | mask = self._generate_square_subsequent_mask(len(src)).to(device) 145 | self.src_mask = mask 146 | else: 147 | self.src_mask = None 148 | 149 | src = self.encoder(src) * math.sqrt(self.ninp) 150 | src = self.pos_encoder(src) 151 | output = self.transformer_encoder(src, self.src_mask) 152 | output = self.decoder(output) 153 | return F.log_softmax(output, dim=-1) 154 | -------------------------------------------------------------------------------- /paper/experimental/batch_pir/modules/language_model/language_model_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
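The PositionalEncoding module above implements the sinusoidal table from its docstring, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). A short NumPy sketch of the same computation, for illustration only:

import numpy as np

def positional_encoding(max_len, d_model):
    # Same formula as PositionalEncoding.__init__ (max_len defaults to 5000 there).
    position = np.arange(max_len)[:, None]
    div_term = np.exp(np.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)   # even embedding indices
    pe[:, 1::2] = np.cos(position * div_term)   # odd embedding indices
    return pe

pe = positional_encoding(max_len=100, d_model=16)
assert pe.shape == (100, 16)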
2 | 3 | # coding: utf-8 4 | import argparse 5 | import time 6 | import math 7 | import os 8 | import torch 9 | import torch.nn as nn 10 | import torch.onnx 11 | 12 | import data 13 | 14 | import os 15 | import sys 16 | from io import open 17 | import torch 18 | import language_model as model_module 19 | import numpy as np 20 | 21 | class Dictionary(object): 22 | def __init__(self): 23 | self.word2idx = {} 24 | self.idx2word = [] 25 | 26 | def add_word(self, word): 27 | if word not in self.word2idx: 28 | self.idx2word.append(word) 29 | self.word2idx[word] = len(self.idx2word) - 1 30 | return self.word2idx[word] 31 | 32 | def __len__(self): 33 | return len(self.idx2word) 34 | 35 | 36 | class Corpus(object): 37 | def __init__(self, path): 38 | self.dictionary = Dictionary() 39 | self.train = self.tokenize(os.path.join(path, 'train.txt')) 40 | self.valid = self.tokenize(os.path.join(path, 'valid.txt')) 41 | self.test = self.tokenize(os.path.join(path, 'test.txt')) 42 | 43 | def tokenize(self, path): 44 | """Tokenizes a text file.""" 45 | assert os.path.exists(path) 46 | # Add words to the dictionary 47 | with open(path, 'r', encoding="utf8") as f: 48 | for line in f: 49 | words = line.split() + [''] 50 | for word in words: 51 | self.dictionary.add_word(word) 52 | 53 | # Tokenize file content 54 | with open(path, 'r', encoding="utf8") as f: 55 | idss = [] 56 | for line in f: 57 | words = line.split() + [''] 58 | ids = [] 59 | for word in words: 60 | ids.append(self.dictionary.word2idx[word]) 61 | idss.append(torch.tensor(ids).type(torch.int64)) 62 | ids = torch.cat(idss) 63 | 64 | return ids 65 | 66 | 67 | corpus = Corpus("./data/wikitext-2/") 68 | 69 | # Starting from sequential data, batchify arranges the dataset into columns. 70 | # For instance, with the alphabet as the sequence and batch size 4, we'd get 71 | # ┌ a g m s ┐ 72 | # │ b h n t │ 73 | # │ c i o u │ 74 | # │ d j p v │ 75 | # │ e k q w │ 76 | # └ f l r x ┘. 77 | # These columns are treated as independent by the model, which means that the 78 | # dependence of e. g. 'g' on 'f' can not be learned, but allows more efficient 79 | # batch processing. 80 | 81 | def batchify(data, bsz): 82 | # Work out how cleanly we can divide the dataset into bsz parts. 83 | nbatch = data.size(0) // bsz 84 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 85 | data = data.narrow(0, 0, nbatch * bsz) 86 | # Evenly divide the data across the bsz batches. 87 | data = data.view(bsz, -1).t().contiguous() 88 | data = data.numpy().tolist() 89 | return data 90 | 91 | def batchify_pytorch(data, bsz): 92 | # Work out how cleanly we can divide the dataset into bsz parts. 93 | nbatch = data.size(0) // bsz 94 | # Trim off any extra elements that wouldn't cleanly fit (remainders). 95 | data = data.narrow(0, 0, nbatch * bsz) 96 | # Evenly divide the data across the bsz batches. 
97 | data = data.view(bsz, -1).t().contiguous() 98 | return data.to("cpu") 99 | 100 | def repackage_hidden(h): 101 | """Wraps hidden states in new Tensors, to detach them from their history.""" 102 | 103 | if isinstance(h, torch.Tensor): 104 | return h.detach() 105 | else: 106 | return tuple(repackage_hidden(v) for v in h) 107 | 108 | def get_batch(source, i): 109 | seq_len = min(bptt, len(source) - 1 - i) 110 | data = [x[0] for x in source[i:i+seq_len]] 111 | target = source[i+1:i+1+seq_len] 112 | return data, target 113 | 114 | def get_batch_pytorch(source, i): 115 | seq_len = min(bptt, len(source) - 1 - i) 116 | data = source[i:i+seq_len] 117 | target = source[i+1:i+1+seq_len].view(-1) 118 | return data, target 119 | 120 | def get_access_pattern(data_source): 121 | access_pattern = [] 122 | 123 | # 35 is the defaul bptt for pytorch lang model 124 | for i in range(0, len(data_source), bptt): 125 | data, targets = get_batch(data_source, i) 126 | access_pattern.append(data) 127 | return access_pattern 128 | 129 | train_data = batchify(corpus.train, 1) 130 | val_data = batchify(corpus.valid, 1) 131 | test_data = batchify(corpus.test, 1) 132 | bptt = 35 133 | 134 | train_access_pattern = None 135 | test_access_pattern = None 136 | val_access_pattern = None 137 | num_embeddings = None 138 | test_words = None 139 | 140 | def wordify(source): 141 | sentences = [] 142 | for b in source: 143 | bb = [corpus.dictionary.idx2word[i] for i in b] 144 | sentences.append(" ".join(bb)) 145 | return sentences 146 | 147 | 148 | def initialize(): 149 | global train_access_pattern 150 | global test_access_pattern 151 | global val_access_pattern 152 | global num_embeddings 153 | 154 | train_access_pattern = get_access_pattern(train_data) 155 | test_access_pattern = get_access_pattern(test_data) 156 | val_access_pattern = get_access_pattern(val_data) 157 | 158 | #test_words = wordify(test_access_pattern) 159 | #print(wordify([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])) 160 | #sys.exit(0) 161 | num_embeddings = len(corpus.dictionary.idx2word) 162 | 163 | def evaluate(pir_optimize): 164 | print("Language model evaluating...") 165 | ntokens = len(corpus.dictionary) 166 | emsize = 650 167 | nhid = 650 168 | nlayers = 2 169 | dropout = .5 170 | tied = True 171 | eval_batch_size = 64 172 | bptt = 35 173 | 174 | model = model_module.RNNModel("LSTM", ntokens, emsize, nhid, nlayers, 175 | dropout, tied).to("cpu") 176 | 177 | # Load the best saved model. 178 | dir_to_use = os.path.dirname(__file__) 179 | with open(f"{dir_to_use}/model.pt", 'rb') as f: 180 | model = torch.load(f) 181 | model.to("cpu") 182 | # after load the rnn params are not a continuous chunk of memory 183 | # this makes them a continuous chunk, and will speed up forward pass 184 | # Currently, only rnn model supports flatten_parameters function. 185 | model.rnn.flatten_parameters() 186 | 187 | criterion = nn.NLLLoss() 188 | 189 | data_source = batchify_pytorch(corpus.valid, eval_batch_size) 190 | 191 | # Turn on evaluation mode which disables dropout. 192 | model.eval() 193 | total_loss = 0. 
194 | hidden = model.init_hidden(eval_batch_size) 195 | with torch.no_grad(): 196 | for i in range(0, data_source.size(0) - 1, bptt): 197 | data, targets = get_batch_pytorch(data_source, i) 198 | 199 | ##################################################### 200 | # PIR optimization: drop according to pir_optimize 201 | data_pir = [] 202 | for batch in data: 203 | b = batch.detach().numpy().tolist() 204 | recovered, _ = pir_optimize.fetch(b) 205 | # 9 is 206 | new_b = [x if x in recovered else 9 for x in b] 207 | data_pir.append(new_b) 208 | data_pir = np.array(data_pir) 209 | data_pir = torch.from_numpy(data_pir) 210 | 211 | assert(data_pir.shape == data.shape) 212 | 213 | data = data_pir 214 | 215 | ##################################################### 216 | 217 | output, hidden = model(data, hidden) 218 | hidden = repackage_hidden(hidden) 219 | total_loss += len(data) * criterion(output, targets).item() 220 | ppl = total_loss / (len(data_source) - 1) 221 | print("Language model ppl: %f" % ppl) 222 | return {"ppl" : math.exp(ppl)} 223 | 224 | if __name__=="__main__": 225 | initialize() 226 | -------------------------------------------------------------------------------- /paper/experimental/batch_pir/modules/language_model/train_model.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | python main.py --cuda --emsize 650 --nhid 650 --dropout 0.5 --epochs 40 --tied --cuda 3 | -------------------------------------------------------------------------------- /paper/experimental/batch_pir/modules/movielens_rec/movielens_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
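The batchify helper in language_model_dataset.py above is what produces the column layout sketched in its alphabet comment: the sequence is trimmed to a multiple of the batch size and folded so that each column is an independent stream. A self-contained PyTorch sketch of that reshaping, for illustration only:

import torch

def batchify(data, bsz):
    # Same trimming and reshaping as batchify_pytorch above.
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    return data.view(bsz, -1).t().contiguous()

# The a..x example from the comment: 24 tokens with batch size 4 give a
# 6x4 matrix whose first column is a..f, second g..l, and so on.
alphabet = torch.arange(24)            # stand-ins for tokens a..x
out = batchify(alphabet, 4)
assert out.shape == (6, 4)
assert out[:, 0].tolist() == [0, 1, 2, 3, 4, 5]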
2 | 3 | import sys 4 | import tqdm 5 | import scipy 6 | import numpy as np 7 | import os 8 | import torch 9 | import torch.nn.functional as F 10 | from sklearn.metrics import roc_auc_score 11 | 12 | train_access_pattern = [] 13 | val_access_pattern = [] 14 | 15 | train_dataset = [] 16 | val_dataset = [] 17 | 18 | num_movies = 0 19 | num_users = 0 20 | 21 | LIM = 10000000000 22 | #LIM = 10000 23 | split = .8 24 | 25 | class RecModel(torch.nn.Module): 26 | 27 | def __init__(self, n_movies, em_size=32): 28 | super().__init__() 29 | 30 | # First embedding table index is the ads embedding 31 | self.table = torch.nn.EmbeddingBag(n_movies+1, em_size, mode="sum") 32 | self.em_size = em_size 33 | 34 | self.fc1 = torch.nn.Linear(64, 200) 35 | self.fc2 = torch.nn.Linear(200, 80) 36 | self.fc3 = torch.nn.Linear(80, 2) 37 | 38 | self.d = torch.nn.Dropout(.5) 39 | 40 | def forward(self, movie_history, target_movie): 41 | target_movie = target_movie.reshape((-1, 1)) 42 | 43 | self.table.weight.data[0,:] = 0 44 | 45 | target_embedding = self.table(target_movie+1) 46 | movie_history_embedding = self.table(movie_history+1) 47 | 48 | x = torch.cat([target_embedding, movie_history_embedding], dim=1) 49 | 50 | x = self.fc1(x) 51 | x = F.relu(x) 52 | x = self.d(x) 53 | x = self.fc2(x) 54 | x = F.relu(x) 55 | x = self.fc3(x) 56 | 57 | return x 58 | 59 | def initialize(): 60 | 61 | user_ratings = {} 62 | 63 | with open("data/ml-20m/ratings.csv", "r") as f: 64 | lines = f.readlines()[1:] 65 | # userId,movieId,rating,timestamp 66 | for i,line in enumerate(lines): 67 | if i >= LIM: 68 | break 69 | line = line.split(",") 70 | user_id, movie_id, rating, timestamp = line[0], line[1], line[2], line[3] 71 | user_id, movie_id, rating, timestamp = int(user_id), int(movie_id), float(rating), int(timestamp) 72 | 73 | click = rating >= 4 74 | 75 | if user_id not in user_ratings: 76 | user_ratings[user_id] = [] 77 | 78 | user_ratings[user_id].append((movie_id, click, timestamp)) 79 | 80 | 81 | global train_dataset 82 | global val_dataset 83 | global num_users 84 | global num_movies 85 | train_dataset = [] 86 | val_dataset = [] 87 | num_users = 0 88 | num_movies = 0 89 | for i, (user_id, d) in enumerate(user_ratings.items()): 90 | test = i >= int(split*len(user_ratings)) 91 | user_click_history = [(x[0], x[2]) for x in d if x[1]] 92 | num_users = max(num_users, user_id) 93 | for movie_id, click, timestamp in d: 94 | if test: 95 | val_dataset.append((user_click_history, movie_id, timestamp, click)) 96 | else: 97 | train_dataset.append((user_click_history, movie_id, timestamp, click)) 98 | num_movies = max(num_movies, movie_id) 99 | 100 | num_users += 1 101 | num_movies += 1 102 | 103 | print("movies: ", num_movies) 104 | 105 | # Extract train and val access pattern 106 | print("Extracting access pattern...") 107 | for i, (user_id, d) in enumerate(user_ratings.items()): 108 | test = i >= int(split*len(user_ratings)) 109 | user_click_history = [x[0] for x in d if x[1]] 110 | if test: 111 | val_access_pattern.append(user_click_history) 112 | else: 113 | train_access_pattern.append(user_click_history) 114 | 115 | def obtain_click_history(point, timestamp): 116 | L = 5000 117 | click_history = point[0] 118 | click_history = [x[0] for x in click_history if x[1] < timestamp] 119 | if len(click_history) < L: 120 | click_history += [-1]*(L-len(click_history)) 121 | if len(click_history) != L: 122 | print(len(click_history)) 123 | assert(len(click_history) == L) 124 | return click_history 125 | 126 | def evaluate(pir_optimize): 127 | 
dir_to_use = os.path.dirname(__file__) 128 | model = RecModel(num_movies) 129 | with open(f"{dir_to_use}/recmodel_epoch=1.pt", 'rb') as f: 130 | model = torch.load(f) 131 | pass 132 | model.to("cpu") 133 | 134 | auc = evaluate_model(model, val_dataset, pir_optimize=pir_optimize) 135 | print(f"AUC: {auc}") 136 | return {"auc" : auc} 137 | 138 | def evaluate_model(model, dataset, batch=256, pir_optimize=None): 139 | model.eval() 140 | groundtruths, preds = [], [] 141 | 142 | indices = list(range(len(dataset)//10)) 143 | for b in range(0, len(indices), batch): 144 | print(f"evaluate_model {b}/{len(indices)}") 145 | 146 | # Get user "clicks" 147 | points = [train_dataset[x] for x in indices[b:b+batch]] 148 | timestamps = [x[2] for x in points] 149 | click_history = [obtain_click_history(x, timestamps[i]) for i,x in enumerate(points)] 150 | 151 | ############################ 152 | # PIR 153 | data_pir = [] 154 | for bbatch in click_history: 155 | n_fillers = bbatch.count(-1) 156 | bb = [x for x in bbatch if x != -1] 157 | if pir_optimize is not None: 158 | recovered, _ = pir_optimize.fetch(bb) 159 | else: 160 | recovered = bb 161 | # 9 is 162 | new_b = [x if x in recovered else -1 for x in bb] 163 | data_pir.append(new_b + [-1]*n_fillers) 164 | #data_pir = np.array(data_pir) 165 | #data_pir = torch.from_numpy(data_pir) 166 | #data_pir = data_pir.to(next(model.parameters()).device) 167 | 168 | #assert(data_pir.shape == click_history.shape) 169 | 170 | click_history= data_pir 171 | 172 | ############################ """ 173 | 174 | target_movie = [x[1] for x in points] 175 | targets = [x[-1] for x in points] 176 | 177 | click_history = torch.from_numpy(np.array(click_history)).long() 178 | target_movie = torch.from_numpy(np.array(target_movie)).long() 179 | targets = torch.from_numpy(np.array(targets)).long() 180 | 181 | click_history = click_history.to(next(model.parameters()).device) 182 | target_movie = target_movie.to(next(model.parameters()).device) 183 | targets = targets.to(next(model.parameters()).device) 184 | 185 | pred = model(click_history, target_movie) 186 | 187 | prob_click = F.softmax(pred, dim=1)[:,1] 188 | prob_click = prob_click.detach().cpu().numpy().flatten().tolist() 189 | 190 | preds += prob_click 191 | groundtruths += targets.detach().cpu().numpy().flatten().tolist() 192 | 193 | score = roc_auc_score(groundtruths, preds) 194 | model.train() 195 | return score 196 | 197 | def train_movielens(epochs=100, batch=64): 198 | print("Training...") 199 | 200 | model = RecModel(num_movies) 201 | model.to("cuda") 202 | loss = torch.nn.CrossEntropyLoss() 203 | optim = torch.optim.Adam(model.parameters()) 204 | 205 | # Train on train users 206 | for epoch in range(epochs): 207 | print(f"Epoch {epoch}/{epochs}") 208 | indices = list(range(len(train_dataset))) 209 | np.random.shuffle(indices) 210 | 211 | train_loss = 0 212 | for b in tqdm.tqdm(range(0, len(indices), batch)): 213 | # Get user "clicks" 214 | points = [train_dataset[x] for x in indices[b:b+batch]] 215 | timestamps = [x[2] for x in points] 216 | click_history = [obtain_click_history(x, timestamps[i]) for i,x in enumerate(points)] 217 | target_movie = [x[1] for x in points] 218 | targets = [x[-1] for x in points] 219 | 220 | click_history = torch.from_numpy(np.array(click_history)).long() 221 | target_movie = torch.from_numpy(np.array(target_movie)).long() 222 | targets = torch.from_numpy(np.array(targets)).long() 223 | 224 | click_history = click_history.to("cuda") 225 | target_movie = target_movie.to("cuda") 226 | targets = 
targets.to("cuda") 227 | 228 | model.zero_grad() 229 | pred = model(click_history, target_movie) 230 | output = loss(pred, targets) 231 | 232 | output.backward() 233 | optim.step() 234 | 235 | train_loss += output.detach().cpu().item() 236 | 237 | score = evaluate_model(model, val_dataset) 238 | #score = evaluate_model(model, train_dataset) 239 | print("Eval score", score) 240 | 241 | torch.save(model, f"recmodel_epoch={epoch}.pt") 242 | 243 | if __name__=="__main__": 244 | initialize() 245 | #train_movielens() 246 | evaluate(None) 247 | -------------------------------------------------------------------------------- /paper/experimental/batch_pir/setup.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | export PYTHONPATH=$PYTHONPATH:`pwd`/modules/language_model/ 4 | export PYTHONPATH=$PYTHONPATH:`pwd`/modules/taobao_rec/ 5 | export PYTHONPATH=$PYTHONPATH:`pwd`/modules/movielens_rec/ 6 | -------------------------------------------------------------------------------- /paper/experimental/batch_pir/sweep/language_model_plot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import os 6 | import json 7 | import glob 8 | import sys 9 | 10 | dir_out = "language_model_sweep_out" 11 | files = glob.glob(dir_out + "/*") 12 | 13 | data = [] 14 | 15 | for f in files: 16 | with open(f, "r") as ff: 17 | d = json.load(ff) 18 | data.append(d) 19 | 20 | # Fairly fast for many datapoints, less fast for many costs, somewhat readable 21 | def is_pareto_efficient_simple(costs, reverse=False): 22 | """ 23 | Find the pareto-efficient points 24 | :param costs: An (n_points, n_costs) array 25 | :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient 26 | """ 27 | is_efficient = np.ones(costs.shape[0], dtype = bool) 28 | for i, c in enumerate(costs): 29 | if is_efficient[i]: 30 | if not reverse: 31 | is_efficient[is_efficient] = np.any(costs[is_efficient]<c, axis=1) # Keep any point with a lower cost 32 | else: 33 | is_efficient[is_efficient] = np.any(costs[is_efficient]>c, axis=1) # Keep any point with a lower cost 34 | is_efficient[i] = True # And keep self 35 | return is_efficient 36 | 37 | def get_pareto_points(xs, ys, reverse=False): 38 | points = list(zip(xs, ys)) 39 | is_efficient = is_pareto_efficient_simple(np.array(points), reverse=reverse) 40 | selected = [x for i,x in enumerate(points) if is_efficient[i]] 41 | selected.sort(key=lambda x:x[0]) 42 | return [x[0] for x in selected], [x[1] for x in selected] 43 | 44 | def plot_computation_vs_accuracy(data): 45 | accuracy_baseline = max([x["accuracy_stats"]["auc"] for x in data]) 46 | 47 | data = [x for x in data if x["cost"]["upload_communication"] + x["cost"]["download_communication"] <= 300000] 48 | print(len(data)) 49 | 50 | plt.cla() 51 | plt.clf() 52 | 53 | plt.axhline(accuracy_baseline) 54 | 55 | plain = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 56 | collocate_only = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1] 57 | hotcold_only = [x for x in data if x["collocate_config"]["num_collocate"] == 0] 58 | basic = [x for x in data if x["pir_config"]["num_bins"] == 1 and x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 59 | 60 | print("plain", len(plain)) 61 | print("coll", len(collocate_only)) 62 | print("hotc", len(hotcold_only)) 63 | 
print("basic", len(basic)) 64 | 65 | def plot_single(data, label, marker, color): 66 | accuracy = [x["accuracy_stats"]["auc"] for x in data] 67 | computation = [x["cost"]["computation"]/1000 for x in data] 68 | 69 | 70 | computation, accuracy = get_pareto_points(computation, [-x for x in accuracy], reverse=False) 71 | accuracy = [-x for x in accuracy] 72 | 73 | print(label, list(zip(computation, accuracy))) 74 | 75 | #plt.scatter(computation, accuracy, label=label, alpha=.3) 76 | plt.plot(computation, accuracy, label=label, markersize=15, marker=marker, alpha=1, color=color, linewidth=5) 77 | 78 | plot_single(plain, "batch-pir", "o", "black") 79 | plot_single(collocate_only, "batch-pir +c", "x", "red") 80 | plot_single(hotcold_only, "batch_pir +h", "^", "green") 81 | plot_single(data, "batch-pir +c +h", "v", "blue") 82 | 83 | plt.xlabel("Computation (kPRFs)", fontsize=28) 84 | plt.ylabel("Accuracy (auc)", fontsize=28) 85 | 86 | plt.xticks(fontsize=22) 87 | plt.yticks(fontsize=22) 88 | 89 | #plt.yscale("log") 90 | 91 | plt.legend(loc="best", fontsize=14) 92 | plt.tight_layout() 93 | plt.savefig(f"movielens_computation_vs_auc.pdf", tight_layout=True) 94 | 95 | def plot_communication_vs_accuracy(data): 96 | accuracy_baseline = max([x["accuracy_stats"]["auc"] for x in data]) 97 | 98 | data = [x for x in data if x["cost"]["computation"] <= 100000] 99 | data = [x for x in data if (x["cost"]["upload_communication"]+x["cost"]["download_communication"])/1000 < 1000] 100 | data = [x for x in data if x["accuracy_stats"]["auc"] >= .75] 101 | 102 | print(len(data)) 103 | 104 | plt.cla() 105 | plt.clf() 106 | 107 | plt.axhline(accuracy_baseline) 108 | 109 | plain = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 110 | collocate_only = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1] 111 | hotcold_only = [x for x in data if x["collocate_config"]["num_collocate"] == 0] 112 | basic = [x for x in data if x["pir_config"]["num_bins"] == 1 and x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 113 | 114 | print("plain", len(plain)) 115 | print("coll", len(collocate_only)) 116 | print("hotc", len(hotcold_only)) 117 | print("basic", len(basic)) 118 | 119 | def plot_single(data, label, marker, color): 120 | accuracy = [x["accuracy_stats"]["auc"] for x in data] 121 | computation = [(x["cost"]["upload_communication"]+x["cost"]["download_communication"])/1000 for x in data] 122 | 123 | 124 | computation, accuracy = get_pareto_points(computation, [-x for x in accuracy], reverse=False) 125 | accuracy = [-x for x in accuracy] 126 | 127 | print(label, list(zip(computation, accuracy))) 128 | 129 | #plt.scatter(computation, accuracy, label=label, alpha=.3) 130 | plt.plot(computation, accuracy, label=label, markersize=15, marker=marker, alpha=1, color=color, linewidth=5) 131 | 132 | plot_single(plain, "batch-pir", "o", "black") 133 | plot_single(collocate_only, "batch-pir +c", "x", "red") 134 | plot_single(hotcold_only, "batch_pir +h", "^", "green") 135 | plot_single(data, "batch-pir +c +h", "v", "blue") 136 | 137 | plt.xlabel("Communication (kBytes)", fontsize=28) 138 | plt.ylabel("Accuracy (auc)", fontsize=28) 139 | 140 | plt.xticks(fontsize=22) 141 | plt.yticks(fontsize=22) 142 | 143 | #plt.yscale("log") 144 | 145 | plt.legend(loc="best", fontsize=14) 146 | plt.tight_layout() 147 | plt.savefig(f"movielens_communication_vs_auc.pdf", tight_layout=True) 148 | 149 | def 
plot_communication_vs_computation_cost(data): 150 | 151 | data = [x for x in data if x["accuracy_stats"]["auc"] >= .77] 152 | print(len(data)) 153 | 154 | plt.cla() 155 | plt.clf() 156 | plain = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 157 | collocate_only = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1] 158 | hotcold_only = [x for x in data if x["collocate_config"]["num_collocate"] == 0] 159 | basic = [x for x in data if x["pir_config"]["num_bins"] == 1 and x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 160 | 161 | print("plain", len(plain)) 162 | print("coll", len(collocate_only)) 163 | print("hotc", len(hotcold_only)) 164 | print("basic", len(basic)) 165 | 166 | def plot_single(data, label, marker, color): 167 | communication = [(x["cost"]["upload_communication"]+x["cost"]["download_communication"])/1000 for x in data] 168 | computation = [x["cost"]["computation"]/1000 for x in data] 169 | communication, computation = get_pareto_points(communication, computation) 170 | #plt.scatter(communication, computation, label=label) 171 | plt.plot(communication, computation, label=label, markersize=15, marker=marker, alpha=1, color=color, linewidth=5) 172 | 173 | plot_single(plain, "batch-pir", "o", "black") 174 | plot_single(collocate_only, "batch-pir +c", "x", "red") 175 | plot_single(hotcold_only, "batch_pir +h", "^", "green") 176 | plot_single(data, "batch-pir +c +h", "v", "blue") 177 | 178 | plt.xlabel("Communication (kBytes)", fontsize=28) 179 | plt.ylabel("Computation (kPRFs)", fontsize=28) 180 | 181 | plt.xticks(fontsize=18) 182 | plt.yticks(fontsize=18) 183 | 184 | #plt.yscale("log") 185 | 186 | plt.legend(loc="best", fontsize=14) 187 | plt.tight_layout() 188 | plt.savefig(f"movielens_communication_vs_computation.pdf", tight_layout=True) 189 | 190 | #plot_computation_cost(data) 191 | #plot_communication_cost(data) 192 | plot_communication_vs_computation_cost(data) 193 | plot_computation_vs_accuracy(data) 194 | plot_communication_vs_accuracy(data) 195 | -------------------------------------------------------------------------------- /paper/experimental/batch_pir/sweep/sweep.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
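As a quick sanity check of the Pareto-front helper shared by language_model_plot.py above and taobao_plot.py below: with reverse=False, a point is dropped once some other kept point is at least as good in every cost dimension. A standalone sketch, with pareto_mask as an illustrative stand-in for is_pareto_efficient_simple:

import numpy as np

def pareto_mask(costs):
    # reverse=False branch of is_pareto_efficient_simple: keep a point only
    # while it beats every other kept point in at least one dimension.
    is_efficient = np.ones(costs.shape[0], dtype=bool)
    for i, c in enumerate(costs):
        if is_efficient[i]:
            is_efficient[is_efficient] = np.any(costs[is_efficient] < c, axis=1)
            is_efficient[i] = True
    return is_efficient

costs = np.array([[1.0, 4.0],   # on the front
                  [2.0, 2.0],   # on the front
                  [3.0, 3.0],   # dominated by (2, 2)
                  [4.0, 1.0]])  # on the front
assert pareto_mask(costs).tolist() == [True, True, False, True]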
2 | 3 | import sys 4 | import os 5 | import json 6 | import pprint 7 | from collections import namedtuple 8 | import numpy as np 9 | import random 10 | import pprint 11 | import matplotlib.pyplot as plt 12 | import batch_pir_optimization 13 | 14 | mode = sys.argv[1] 15 | 16 | if mode == "lm": 17 | import language_model_dataset as dataset 18 | from language_model_dataset import * 19 | dir_out = "language_model_sweep_out" 20 | 21 | if mode == "movielens": 22 | import movielens_dataset as dataset 23 | from movielens_dataset import * 24 | dir_out = "movielens_sweep_out" 25 | initialize_collocate_load_from="initialize_collocate_movielens.json" 26 | 27 | if mode == "taobao": 28 | import taobao_rec_dataset_v2 as dataset 29 | from taobao_rec_dataset_v2 import * 30 | 31 | dir_out = "taobao_sweep_out" 32 | initialize_collocate_load_from="initialize_collocate_taobao.json" 33 | 34 | from torch.multiprocessing import Pool, Process, set_start_method 35 | try: 36 | set_start_method('spawn') 37 | except RuntimeError: 38 | pass 39 | 40 | dataset.initialize() 41 | 42 | # Config params (try to be independent of dataset) 43 | assert(dataset.num_embeddings > 0) 44 | 45 | # Get average number of accesses 46 | access_lengths = [] 47 | for d in dataset.val_access_pattern: 48 | access_lengths.append(len(d)) 49 | avg_access_length = np.mean(access_lengths) 50 | max_access_length = np.max(access_lengths) 51 | access_length_90 = np.percentile(access_lengths, 90) 52 | 53 | hot_cold_ratios = [1, .05, .1, .15, .2, .25, .3] 54 | num_collocation = [0, 1, 2, 3, 4, 5] 55 | 56 | #pir_num_bins = [int(x) for x in list(np.arange(1, dataset.num_embeddings, dataset.num_embeddings//10))] 57 | 58 | # We are going change num_bins -> bin_fraction (i.e: fraction of total dataset) 59 | pir_num_bins = [.1, .2, .3, .4, .5, .6, .7, .8, .9, 1] 60 | 61 | # We got by powers of 2 but not as much since generally most users don't click on many ads 62 | pir_hot_queries = [int(x) for x in list(np.arange(1, access_length_90, max(1, access_length_90//10)))] 63 | pir_cold_queries = [int(x) for x in list(np.arange(1, access_length_90, max(1,access_length_90//10)))] 64 | 65 | with open(initialize_collocate_load_from, "r") as f: 66 | initialize_collocate_load_from = json.load(f) 67 | 68 | def run(hot_cold_ratio, num_collocation, pir_num_bins, num_hot_queries, num_cold_queries): 69 | hotcold_config = batch_pir_optimization.HotColdConfig(hot_cold_ratio) 70 | collocate_config = batch_pir_optimization.CollocateConfig(num_collocation) 71 | pir_config = batch_pir_optimization.PIRConfig(pir_num_bins, 200, num_hot_queries, num_cold_queries) 72 | 73 | b = batch_pir_optimization.BatchPIROptimize(dataset.train_access_pattern, 74 | dataset.val_access_pattern, 75 | hotcold_config, collocate_config, pir_config, 76 | initialize_collocate_load_from=initialize_collocate_load_from) 77 | b.evaluate_real() 78 | results = b.summarize_evaluation() 79 | 80 | f_out = f"{dir_out}/{hot_cold_ratio}_{num_collocation}_{ pir_num_bins}_{num_hot_queries}_{num_cold_queries}" 81 | with open(f_out, "w") as f: 82 | json.dump(results, f) 83 | 84 | if __name__=="__main__": 85 | 86 | if not os.path.exists(dir_out): 87 | os.makedirs(dir_out) 88 | 89 | args = [] 90 | for r in hot_cold_ratios: 91 | for c in num_collocation: 92 | for b in pir_num_bins: 93 | for q in pir_hot_queries: 94 | for x in pir_cold_queries: 95 | args_set = (r, c, b, q, x) 96 | args.append(args_set) 97 | 98 | np.random.shuffle(args) 99 | 100 | batch_pir_only = [x for x in args if x[0] == 1 and x[1] == 0] 101 | 
batch_pir_with_hot_cold = [x for x in args if x[1] == 0 and x[0] < 1] 102 | batch_pir_with_hot_cold_and_coll = [x for x in args if x[1] > 1 and x[0] < 1] 103 | batch_pir_with_coll = [x for x in args if x[1] > 0 and x[0] == 1] 104 | 105 | 106 | max_len = max([len(x) for x in [batch_pir_only, batch_pir_with_hot_cold, batch_pir_with_hot_cold_and_coll, batch_pir_with_coll]]) 107 | 108 | batch_pir_only = (batch_pir_only*max_len)[:max_len] 109 | batch_pir_with_hot_cold = (batch_pir_with_hot_cold*max_len)[:max_len] 110 | batch_pir_with_hot_cold_and_coll = (batch_pir_with_hot_cold_and_coll*max_len)[:max_len] 111 | batch_pir_with_coll = (batch_pir_with_coll*max_len)[:max_len] 112 | 113 | args = batch_pir_only + batch_pir_with_hot_cold + batch_pir_with_hot_cold_and_coll + batch_pir_with_coll 114 | np.random.shuffle(args) 115 | 116 | #args.sort(key=lambda x: 0 if (x[0] == 1 and 117 | # x[1] == 0 and 118 | # (x[2] == 1 or x[2] == 4 or x[2] == 256) and 119 | # (x[3] == 1 or x[3] == 4 or x[3] == 16) and 120 | # (x[4] == 1 or x[4] == 4 or x[4] == 16)) else 1) 121 | 122 | #print(args[0:3]) 123 | #args = [x for x in args if x[2] == .1] 124 | #run(*args[0]) 125 | #run(*args[1]) 126 | #run(*args[2]) 127 | 128 | with Pool(processes=8) as pool: 129 | pool.starmap(run, args) 130 | 131 | 132 | -------------------------------------------------------------------------------- /paper/experimental/batch_pir/sweep/taobao_plot.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import os 6 | import json 7 | import glob 8 | import sys 9 | 10 | dir_out = "taobao_sweep_out" 11 | files = glob.glob(dir_out + "/*") 12 | 13 | data = [] 14 | 15 | for f in files: 16 | with open(f, "r") as ff: 17 | d = json.load(ff) 18 | data.append(d) 19 | 20 | # Fairly fast for many datapoints, less fast for many costs, somewhat readable 21 | def is_pareto_efficient_simple(costs, reverse=False): 22 | """ 23 | Find the pareto-efficient points 24 | :param costs: An (n_points, n_costs) array 25 | :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient 26 | """ 27 | is_efficient = np.ones(costs.shape[0], dtype = bool) 28 | for i, c in enumerate(costs): 29 | if is_efficient[i]: 30 | if not reverse: 31 | is_efficient[is_efficient] = np.any(costs[is_efficient]<c, axis=1) # Keep any point with a lower cost 32 | else: 33 | is_efficient[is_efficient] = np.any(costs[is_efficient]>c, axis=1) # Keep any point with a lower cost 34 | is_efficient[i] = True # And keep self 35 | return is_efficient 36 | 37 | def get_pareto_points(xs, ys, reverse=False): 38 | points = list(zip(xs, ys)) 39 | is_efficient = is_pareto_efficient_simple(np.array(points), reverse=reverse) 40 | selected = [x for i,x in enumerate(points) if is_efficient[i]] 41 | selected.sort(key=lambda x:x[0]) 42 | return [x[0] for x in selected], [x[1] for x in selected] 43 | 44 | def plot_computation_vs_accuracy(data): 45 | accuracy_baseline = max([x["accuracy_stats"]["auc"] for x in data]) 46 | 47 | data = [x for x in data if x["cost"]["upload_communication"] + x["cost"]["download_communication"] <= 300000] 48 | print(len(data)) 49 | 50 | plt.cla() 51 | plt.clf() 52 | 53 | plt.axhline(accuracy_baseline) 54 | 55 | plain = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 56 | collocate_only = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1] 57 | hotcold_only = [x for x in data if x["collocate_config"]["num_collocate"] == 0]
58 | basic = [x for x in data if x["pir_config"]["num_bins"] == 1 and x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 59 | 60 | print("plain", len(plain)) 61 | print("coll", len(collocate_only)) 62 | print("hotc", len(hotcold_only)) 63 | print("basic", len(basic)) 64 | 65 | 66 | def plot_single(data, label, marker, color): 67 | accuracy = [x["accuracy_stats"]["auc"] for x in data] 68 | computation = [x["cost"]["computation"]/1000 for x in data] 69 | 70 | 71 | computation, accuracy = get_pareto_points(computation, [-x for x in accuracy], reverse=False) 72 | accuracy = [-x for x in accuracy] 73 | 74 | print(label, list(zip(computation, accuracy))) 75 | 76 | #plt.scatter(computation, accuracy, label=label, alpha=.3) 77 | plt.plot(computation, accuracy, label=label, markersize=15, marker=marker, alpha=1, color=color, linewidth=5) 78 | 79 | plot_single(plain, "batch-pir", "o", "black") 80 | plot_single(collocate_only, "batch-pir +c", "x", "red") 81 | plot_single(hotcold_only, "batch_pir +h", "^", "green") 82 | plot_single(data, "batch-pir +c +h", "v", "blue") 83 | 84 | plt.xlabel("Computation (kPRFs)", fontsize=28) 85 | plt.ylabel("Accuracy (auc)", fontsize=28) 86 | 87 | plt.xticks(fontsize=22) 88 | plt.yticks(fontsize=22) 89 | 90 | #plt.yscale("log") 91 | 92 | plt.legend(loc="best", fontsize=14) 93 | plt.tight_layout() 94 | plt.savefig(f"taobao_computation_vs_auc.pdf", tight_layout=True) 95 | 96 | def plot_communication_vs_accuracy(data): 97 | accuracy_baseline = max([x["accuracy_stats"]["auc"] for x in data]) 98 | 99 | data = [x for x in data if x["cost"]["computation"] <= 100000000000] 100 | #data = [x for x in data if (x["cost"]["upload_communication"]+x["cost"]["download_communication"])/1000 < 150] 101 | data = [x for x in data if x["accuracy_stats"]["auc"] >= .58] 102 | 103 | print(len(data)) 104 | 105 | plt.cla() 106 | plt.clf() 107 | 108 | plt.axhline(accuracy_baseline) 109 | 110 | plain = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 111 | collocate_only = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1] 112 | hotcold_only = [x for x in data if x["collocate_config"]["num_collocate"] == 0] 113 | basic = [x for x in data if x["pir_config"]["num_bins"] == 1 and x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 114 | 115 | print("plain", len(plain)) 116 | print("coll", len(collocate_only)) 117 | print("hotc", len(hotcold_only)) 118 | print("basic", len(basic)) 119 | 120 | def plot_single(data, label, marker, color): 121 | accuracy = [x["accuracy_stats"]["auc"] for x in data] 122 | computation = [(x["cost"]["upload_communication"]+x["cost"]["download_communication"])/1000 for x in data] 123 | 124 | 125 | computation, accuracy = get_pareto_points(computation, [-x for x in accuracy], reverse=False) 126 | accuracy = [-x for x in accuracy] 127 | 128 | print(label, list(zip(computation, accuracy))) 129 | 130 | #plt.scatter(computation, accuracy, label=label, alpha=.3) 131 | plt.plot(computation, accuracy, label=label, markersize=15, marker=marker, alpha=1, color=color, linewidth=5) 132 | 133 | plot_single(plain, "batch-pir", "o", "black") 134 | plot_single(collocate_only, "batch-pir +c", "x", "red") 135 | plot_single(hotcold_only, "batch_pir +h", "^", "green") 136 | plot_single(data, "batch-pir +c +h", "v", "blue") 137 | 138 | plt.xlabel("Communication (kBytes)", fontsize=28) 139 | plt.ylabel("Accuracy (auc)", 
fontsize=28) 140 | 141 | plt.xticks(fontsize=22) 142 | plt.yticks(fontsize=22) 143 | 144 | #plt.yscale("log") 145 | 146 | plt.legend(loc="best", fontsize=14) 147 | plt.tight_layout() 148 | plt.savefig(f"taobao_communication_vs_auc.pdf", tight_layout=True) 149 | 150 | def plot_communication_vs_computation_cost(data): 151 | 152 | data = [x for x in data if x["accuracy_stats"]["auc"] >= .58] 153 | print(len(data)) 154 | 155 | plt.cla() 156 | plt.clf() 157 | plain = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 158 | collocate_only = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1] 159 | hotcold_only = [x for x in data if x["collocate_config"]["num_collocate"] == 0] 160 | basic = [x for x in data if x["pir_config"]["num_bins"] == 1 and x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 161 | 162 | print("plain", len(plain)) 163 | print("coll", len(collocate_only)) 164 | print("hotc", len(hotcold_only)) 165 | print("basic", len(basic)) 166 | 167 | def plot_single(data, label, marker, color): 168 | communication = [(x["cost"]["upload_communication"]+x["cost"]["download_communication"])/1000 for x in data] 169 | computation = [x["cost"]["computation"]/1000 for x in data] 170 | communication, computation = get_pareto_points(communication, computation) 171 | #plt.scatter(communication, computation, label=label) 172 | plt.plot(communication, computation, label=label, markersize=15, marker=marker, alpha=1, color=color, linewidth=5) 173 | 174 | plot_single(plain, "batch-pir", "o", "black") 175 | plot_single(collocate_only, "batch-pir +c", "x", "red") 176 | plot_single(hotcold_only, "batch_pir +h", "^", "green") 177 | plot_single(data, "batch-pir +c +h", "v", "blue") 178 | 179 | plt.xlabel("Communication (kBytes)", fontsize=28) 180 | plt.ylabel("Computation (kPRFs)", fontsize=28) 181 | 182 | plt.xticks(fontsize=18) 183 | plt.yticks(fontsize=18) 184 | 185 | #plt.yscale("log") 186 | 187 | plt.legend(loc="best", fontsize=14) 188 | plt.tight_layout() 189 | plt.savefig(f"taobao_communication_vs_computation.pdf", tight_layout=True) 190 | 191 | #plot_computation_cost(data) 192 | #plot_communication_cost(data) 193 | plot_communication_vs_computation_cost(data) 194 | plot_computation_vs_accuracy(data) 195 | plot_communication_vs_accuracy(data) 196 | -------------------------------------------------------------------------------- /paper/experimental/codesign/join_batch_pir_accuracy_with_gpu_dpf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
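# Usage sketch (not part of the original file; inferred from the sys.argv reads below, and the
# directory names are placeholders):
#
#   python join_batch_pir_accuracy_with_gpu_dpf.py <acc_sweep_dir> <app: lm|movielens|taobao> <dpf_perf_dir> <out_dir>
#
# The script joins batch-PIR accuracy sweep results (JSON files in <acc_sweep_dir>) with GPU DPF
# performance numbers (the dict printed as the last line of each file in <dpf_perf_dir>, in the
# format emitted by dpf_benchmark.cu) and writes the merged records to <out_dir>.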
2 | 3 | import sys 4 | import pprint 5 | import os 6 | import glob 7 | import numpy as np 8 | import json 9 | 10 | d_in = sys.argv[1] 11 | app = sys.argv[2] 12 | dpf_in = sys.argv[3] 13 | d_out = sys.argv[4] 14 | 15 | def application_config(app): 16 | if app == "lm": 17 | return 33278 18 | if app == "movielens": 19 | return 131263 20 | if app == "taobao": 21 | return 846811 22 | assert(0) 23 | 24 | def load_acc_data(d_in): 25 | files = glob.glob(f"{d_in}/*") 26 | print(len(files)) 27 | 28 | d = [] 29 | for f in files: 30 | with open(f, "r") as ff: 31 | d.append((f, json.load(ff))) 32 | 33 | return d 34 | 35 | def load_dpf_perf_data(dpf_in): 36 | files = glob.glob(f"{dpf_in}/*") 37 | d = [] 38 | 39 | for f in files: 40 | with open(f, "r") as ff: 41 | last_line = ff.readlines()[-1] 42 | try: 43 | data = eval(last_line) 44 | d.append(data) 45 | 46 | except: 47 | pass 48 | return d 49 | 50 | def compute_joined_data(num_embeddings, pir_stat, dpf_perf_numbers): 51 | 52 | hot_table_embeddings = int(pir_stat["hotcold_config"]["cache_size_fraction"]*num_embeddings) 53 | cold_table_embeddings = num_embeddings - hot_table_embeddings 54 | 55 | num_collocate = pir_stat["collocate_config"]["num_collocate"] 56 | 57 | queries_to_hot = pir_stat["pir_config"]["queries_to_hot"] 58 | queries_to_cold = pir_stat["pir_config"]["queries_to_cold"] 59 | 60 | num_bins = pir_stat["pir_config"]["num_bins"] 61 | 62 | if num_bins == 1: 63 | return None 64 | 65 | if num_bins > 1: 66 | bin_size_hot = num_bins 67 | bin_size_cold = num_bins 68 | elif num_bins < 1: 69 | bin_size_hot = int(num_bins * hot_table_embeddings) 70 | bin_size_cold = int(num_bins * cold_table_embeddings) 71 | 72 | num_bins_hot = hot_table_embeddings // bin_size_hot + 1 73 | num_bins_cold = 0 if cold_table_embeddings == 0 else cold_table_embeddings // bin_size_cold + 1 74 | 75 | hot_numbers = [] 76 | cold_numbers = [] 77 | 78 | for perf_number in dpf_perf_numbers: 79 | throughput = perf_number["throughput_queries_per_ms"] 80 | latency = perf_number["latency_ms"] 81 | batchsize = perf_number["batch_size"] 82 | 83 | if perf_number["entries"] >= bin_size_hot and perf_number["entry_size_ints"]*128/8 >= num_collocate*16: 84 | n_queries_for_hot = num_bins_hot * queries_to_hot 85 | #print(n_queries_for_hot, bin_size_hot) 86 | #if n_queries_for_hot > batchsize: 87 | # continue 88 | hot_throughput = batchsize / n_queries_for_hot 89 | hot_latency = np.ceil(n_queries_for_hot / batchsize)*latency 90 | hot_numbers.append((hot_throughput, hot_latency)) 91 | 92 | if perf_number["entries"] >= bin_size_cold and perf_number["entry_size_ints"]*128/8 >= num_collocate*16: 93 | n_queries_for_cold = num_bins_cold * queries_to_cold 94 | #if n_queries_for_cold > batchsize or n_queries_for_cold == 0: 95 | # continue 96 | if n_queries_for_cold == 0: 97 | cold_throughput = float("inf") 98 | cold_latency = float("0") 99 | else: 100 | cold_throughput = batchsize / n_queries_for_cold 101 | cold_latency = np.ceil(n_queries_for_cold / batchsize)*latency 102 | cold_numbers.append((cold_throughput, cold_latency)) 103 | 104 | 105 | # Since we need to compute both] hot and cold numbers, assuming we have 2 GV100 gpus 106 | # we take the maxs of latency and mins of throughputs 107 | latency_throughputs = [] 108 | for h in hot_numbers: 109 | for c in cold_numbers: 110 | #latency = max(h[1],c[1]) 111 | latency = h[1]+c[1] 112 | #throughput = min(h[0],c[0]) 113 | throughput = min(h[0],c[0])/2 114 | latency_throughputs.append((latency, throughput)) 115 | latency_throughputs = 
list(set(latency_throughputs)) 116 | 117 | #print(len(latency_throughputs)) 118 | 119 | latency_throughputs_final = [] 120 | for latency, throughput in latency_throughputs: 121 | best_throughput = max([x[1] for x in latency_throughputs if x[0] <= latency]) 122 | #print(best_throughput) 123 | latency_throughputs_final.append((latency, best_throughput)) 124 | 125 | if len(latency_throughputs_final) <= 0: 126 | return None 127 | 128 | assert(len(latency_throughputs_final) > 0) 129 | #print(sorted(latency_throughputs_final)) 130 | #sys.exit(0) 131 | 132 | return {**pir_stat, **{"latency_throughputs" : latency_throughputs_final}} 133 | 134 | def join_data(num_embeddings, data, dpf_perf_numbers): 135 | final = [] 136 | for fname, pir_stats in data: 137 | joined_data = compute_joined_data(num_embeddings, pir_stats, dpf_perf_numbers) 138 | final.append((fname, joined_data)) 139 | 140 | for fname, final_d in final: 141 | bname = os.path.basename(fname) 142 | fpath = f"{d_out}/{bname}" 143 | 144 | if not os.path.exists(d_out): 145 | os.makedirs(d_out) 146 | 147 | #print(fpath) 148 | if final_d is not None: 149 | with open(fpath, "w") as ff: 150 | json.dump(final_d, ff) 151 | 152 | if __name__ == "__main__": 153 | num_embeddings = application_config(app) 154 | data = load_acc_data(d_in) 155 | dpf_perf_numbers = load_dpf_perf_data(dpf_in) 156 | 157 | join_data(num_embeddings, data, dpf_perf_numbers) 158 | -------------------------------------------------------------------------------- /paper/experimental/codesign/plot_lm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | import sys 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import glob 7 | import json 8 | 9 | app_name = "lm" 10 | 11 | def load_data(files): 12 | data = [] 13 | for fname in files: 14 | with open(fname, "r") as f: 15 | data.append(json.load(f)) 16 | return data 17 | 18 | # Fairly fast for many datapoints, less fast for many costs, somewhat readable 19 | def is_pareto_efficient_simple(costs): 20 | """ 21 | Find the pareto-efficient points 22 | :param costs: An (n_points, n_costs) array 23 | :return: A (n_points, ) boolean array, indicating whether each point is Pareto efficient 24 | """ 25 | is_efficient = np.ones(costs.shape[0], dtype = bool) 26 | for i, c in enumerate(costs): 27 | if is_efficient[i]: 28 | is_efficient[is_efficient] = np.any(costs[is_efficient] < c, axis=1) # Keep any point with a lower cost 29 | is_efficient[i] = True # And keep self 30 | return is_efficient 31 | 32 | def get_pareto_points(xs, ys): 33 | points = list(zip(xs, ys)) 34 | is_efficient = is_pareto_efficient_simple(np.array(points)) 35 | selected = [x for i,x in enumerate(points) if is_efficient[i]] 36 | selected.sort(key=lambda x:x[0]) 37 | return [x[0] for x in selected], [x[1] for x in selected] 38 | 39 | def plot_accuracy_vs_throughput(data): 48 | data = [x for x in data if x["accuracy_stats"]["auc"] >= .50] 49 | accuracy_baseline = max([x["accuracy_stats"]["auc"] for x in data]) 50 | 51 | plain = [x for x in data if x["hotcold_config"]["cache_size_fraction"] == 1 and x["collocate_config"]["num_collocate"] == 0] 52 | 53 | for d in data: 54 | continue 55 | pprint.pprint(d) 56 | sys.stdin.readline() 57 | 58 | plt.axhline(accuracy_baseline) 59 | 60 | def plot_single(data, label, marker, color): 61 | 62 | all_accuracies = [] 63 | all_throughputs = [] 64 | 65 | for d in data: 66 | for latency, throughput in d["latency_throughputs"]: 67 | if latency < 100: 68 | all_accuracies.append(d["accuracy_stats"]["auc"]) 69 | all_throughputs.append(throughput) 70 | #all_throughputs.append(d["cost"]["computation"]) 71 | 72 | 73 | #print(list(set(all_accuracies))) 74 | #sys.exit(0) 75 | 76 | #print(list(set(all_throughputs))) 77 | 78 | #zipped = list(zip(all_throughputs, all_accuracies)) 79 | #zipped.sort(key=lambda x:x[0]) 80 | #print(zipped[:10]) 81 | #sys.exit(0) 82 | 83 | """ 84 | all_throughputs_accuracies = list(zip(all_throughputs, all_accuracies)) 85 |
np.random.shuffle(all_throughputs_accuracies) 86 | all_throughputs = [x[0] for x in all_throughputs_accuracies] 87 | all_accuracies = [x[1] for x in all_throughputs_accuracies] 88 | all_throughputs = all_throughputs[:1000] 89 | all_accuracies = all_accuracies[:1000] 90 | print(sorted(list(zip(all_throughputs, all_accuracies)))) 91 | """ 92 | 93 | #""" 94 | all_accuracies = [-x for x in all_accuracies] 95 | all_throughputs = [-x for x in all_throughputs] 96 | print(max(all_throughputs)) 97 | all_throughputs, all_accuracies = get_pareto_points(all_throughputs, all_accuracies) 98 | all_throughputs = [-x for x in all_throughputs] 99 | all_accuracies = [-x for x in all_accuracies] 100 | #all_throughputs = [-x for x in all_throughputs] 101 | plt.plot(all_throughputs, all_accuracies, label=label, markersize=15, marker=marker, alpha=1, color=color, linewidth=5) 102 | #""" 103 | 104 | #print(list(zip(all_throughputs, all_accuracies))) 105 | 106 | print(len(all_throughputs)) 107 | #plt.scatter(all_throughputs, all_accuracies, label=label, marker=marker, alpha=1, color=color) 108 | 109 | plot_single(plain, "batch-pir", "o", "black") 110 | plot_single(data, "batch-pir w/ co-design", "x", "blue") 111 | 112 | plt.xlabel("Throughput (q/ms)", fontsize=28) 113 | plt.ylabel("Accuracy (auc)", fontsize=28) 114 | 115 | plt.xticks(fontsize=22) 116 | plt.yticks(fontsize=22) 117 | 118 | plt.legend(loc="best", fontsize=14); 119 | 120 | _, max_number = plt.xlim() 121 | #plt.xticks(np.arange(0, max_number, 5)) 122 | 123 | plt.tight_layout() 124 | 125 | plt.savefig(f"{app_name}_throughpt_auc.pdf", tight_layout=True) 126 | 127 | 128 | d_in = sys.argv[1] 129 | data = load_data(glob.glob(f"{d_in}/*")) 130 | 131 | plot_accuracy_vs_throughput(data) 132 | 133 | -------------------------------------------------------------------------------- /paper/kernel/cpu/dpf_google/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | # NOTE: REMOVED GOOGLE DPF CODEBASE, DOWNLOAD IT AGAIN HERE: https://github.com/google/distributed_point_functions 3 | # 4 | # Really long compile command for linking with Google's DPF library 5 | # To get this working I 6 | # - Add build target to DPF bazel build file and build using bazel 7 | # - Find all the header dependencies that bazel uses + external ones (e.g: boringssl) 8 | # - Run bazel test //... 
and copy over generated protobuf files for distributed_point_function.proto 9 | # 10 | # Make sure when running to add path to dpf.so 11 | # - e.g: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH://private/home/maxlam/pir/dpf/distributed_point_functions/bazel-bin/dpf 12 | 13 | CC=g++ 14 | CFLAGS=-std=c++17 -O3 -fopenmp 15 | NVCC=nvcc -Xcompiler -fopenmp 16 | CUDA_FLAGS=-lcuda -lcudart -lcublas 17 | DPF_INCS=-L ./distributed_point_functions/bazel-bin/dpf/ -I /private/home/maxlam/pir/dpf/distributed_point_functions/dpf -I/private/home/maxlam/.conda/envs/pir/share/bazel/dacb9117b7bba7784d38f8b74d338ba9/external/boringssl/src/include/ -I /private/home/maxlam/.conda/envs/pir/share/bazel/dacb9117b7bba7784d38f8b74d338ba9/external/com_google_absl -I /private/home/maxlam/pir/dpf/distributed_point_functions -I /private/home/maxlam/.conda/envs/pir/share/bazel/dacb9117b7bba7784d38f8b74d338ba9/external/com_google_protobuf_protoc_linux_x86_64/include/ -I /private/home/maxlam/.conda/envs/pir/share/bazel/dacb9117b7bba7784d38f8b74d338ba9/external/com_github_protocolbuffers_protobuf/src/ -I /private/home/maxlam/.conda/envs/pir/share/bazel/dacb9117b7bba7784d38f8b74d338ba9/external/com_github_google_highway/ -ldpf 18 | 19 | sanity: 20 | $(CC) $(CFLAGS) test_dpf_so.cc $(DPF_INCS) -o sanity 21 | 22 | benchmark: 23 | $(NVCC) -c benchmark.cu -o benchmark.o $(CUDA_FLAGS) 24 | $(CC) $(CFLAGS) -c dpf_helpers.cc $(DPF_INCS) -o dpf_helpers.o 25 | $(NVCC) -o benchmark benchmark.o dpf_helpers.o $(DPF_INCS) $(CUDA_FLAGS) 26 | 27 | # Make sure when running to add path to dpf.so 28 | # - e.g: export LD_LIBRARY_PATH=$$LD_LIBRARY_PATH://private/home/maxlam/pir/dpf/distributed_point_functions/bazel-bin/dpf 29 | 30 | # benchmark [n_embedding_entries] [embedding_length] [use_dpf] [batch] [reps] [use_gemm] [dpf_threads] 31 | #./benchmark 10000000 256 0 32 100 1 1 32 | clean: 33 | rm -f benchmark 34 | rm -f sanity 35 | rm -f *.o 36 | -------------------------------------------------------------------------------- /paper/kernel/cpu/dpf_google/benchmark.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | 4 | /* 5 | Simulates the performance of 2 server PIR 6 | - Does not measure communication latency 7 | - Performs DPF on the CPU _only_ 8 | - Measures runtime for a single server to expand DPF vector + CUDA Matvecmul 9 | (Note: since the other server does the same thing, runtime will be same) 10 | 11 | Current workflow 12 | - Initialize embedding tables (w/ random numbers) assumed to be a secret 13 | share of the true table. Table size is parameterized. 
Put on GPU 14 | - Initialize DPF key using google's DPF CPU library 15 | - Start timer 16 | - Expand DPF on vector 17 | - Transfer vector GPU memory, perform call to Cublas GEMV 18 | - Stop timer 19 | */ 20 | 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include "dpf_helpers.h" 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | 34 | void *AllocEmbeddingTable(int n, int l) { 35 | // n - number of table entries 36 | // l - number of elements per table entry (vector dimension) 37 | // Assum each entry is uint32_t 38 | 39 | // Allocate on CPU, populate with some values, cpy to GPU 40 | uint32_t *cpu_table = (uint32_t *)malloc(n*l*sizeof(uint32_t)); 41 | if (cpu_table == NULL) assert(0); 42 | for (int i = 0; i < n; i++) { 43 | for (int j = 0; j < l; j++) { 44 | cpu_table[i*l+j] = i; 45 | } 46 | } 47 | 48 | void *gpu_table; 49 | cudaMalloc((void**)&gpu_table, sizeof(uint32_t)*n*l); 50 | cudaMemcpy(gpu_table, cpu_table, sizeof(uint32_t)*n*l, cudaMemcpyHostToDevice); 51 | 52 | free(cpu_table); 53 | 54 | return gpu_table; 55 | } 56 | 57 | 58 | int main(int argc, char *argv[]) { 59 | int N_EMBEDDING_ENTRIES = atoi(argv[1]); 60 | int EMBEDDING_LENGTH = atoi(argv[2]); 61 | int USE_DPF = atoi(argv[3]); 62 | int BATCH_SIZE = atoi(argv[4]); 63 | int REPS = atoi(argv[5]); 64 | int USE_GEMM = atoi(argv[6]); 65 | int DPF_THREADS = atoi(argv[7]); 66 | 67 | printf("Params: n_embedding_entries=%d, embedding_length=%d, use_dpf=%d, batch=%d reps=%d\n", N_EMBEDDING_ENTRIES, EMBEDDING_LENGTH, USE_DPF, BATCH_SIZE, REPS); 68 | 69 | std::cout << "Init CUDA" << std::endl; 70 | cudaStream_t cudaStream; 71 | cublasHandle_t handle; 72 | cublasCreate(&handle); 73 | if (CUBLAS_STATUS_SUCCESS != cublasSetStream(handle, cudaStream)) 74 | { 75 | printf("Cublas set stream failed\n"); 76 | exit(-1); 77 | } 78 | 79 | std::cout << "Malloc embedding table onto GPU" << std::endl; 80 | void *embedding_table = AllocEmbeddingTable(N_EMBEDDING_ENTRIES, 81 | EMBEDDING_LENGTH); 82 | 83 | std::cout << "Initializing DPF" << std::endl; 84 | void *dpf = DPFInitialize(32, 32); 85 | void *k1, *k2; 86 | DPFGetKey(dpf, 42, 21, &k1, &k2); 87 | 88 | std::cout << "Mallocing indicator vector" << std::endl; 89 | uint32_t *indicator_vector; 90 | cudaMallocManaged(&indicator_vector, sizeof(uint32_t)*N_EMBEDDING_ENTRIES*BATCH_SIZE); 91 | 92 | std::cout << "Mallocing output vector" << std::endl; 93 | uint32_t *o; 94 | cudaMallocManaged(&o, sizeof(uint32_t)*EMBEDDING_LENGTH); 95 | cudaMemset(o, 0, sizeof(uint32_t)*EMBEDDING_LENGTH); 96 | 97 | //////////////////////////////////////////////////////////////////////// 98 | // Past this point the benchmark times 99 | std::cout << "Benchmarking..." 
<< std::endl; 100 | 101 | cudaEvent_t start, stop; 102 | cudaEventCreate(&start); 103 | cudaEventCreate(&stop); 104 | 105 | float gemm_ms_time_cumulative = 0; 106 | float dpf_ms_time_cumulative = 0; 107 | 108 | omp_set_num_threads(DPF_THREADS); 109 | 110 | for (int i = 0; i < REPS; i++) { 111 | std::cout << i << std::endl; 112 | 113 | // DPF initialize batch of secret shared indicator vectors 114 | if (USE_DPF) { 115 | auto t0 = std::chrono::high_resolution_clock::now(); 116 | 117 | #pragma omp parallel for 118 | for (int j = 0; j < BATCH_SIZE; j++) { 119 | DPFExpand(dpf, k1, N_EMBEDDING_ENTRIES, indicator_vector + N_EMBEDDING_ENTRIES*j); 120 | } 121 | 122 | auto t1 = std::chrono::high_resolution_clock::now(); 123 | std::chrono::duration< double > fs = t1 - t0; 124 | std::chrono::milliseconds d = std::chrono::duration_cast< std::chrono::milliseconds >(fs); 125 | 126 | dpf_ms_time_cumulative += d.count(); 127 | } 128 | 129 | if (USE_GEMM) { 130 | // Batch matmul 131 | int alpha = 1; 132 | int beta = 0; 133 | cudaEventRecord(start); 134 | auto status = cublasGemmEx(handle, 135 | CUBLAS_OP_T, // Embed Table is row major 136 | CUBLAS_OP_N, // indicator vec is col major 137 | 138 | EMBEDDING_LENGTH, //m, where mxkxn, n=1 139 | BATCH_SIZE, //n (batch) 140 | N_EMBEDDING_ENTRIES, //k 141 | 142 | &alpha, //alpha 143 | embedding_table, //A 144 | CUDA_R_32F, //dtype of A 145 | N_EMBEDDING_ENTRIES, //lda 146 | indicator_vector, //B 147 | CUDA_R_32F, //dtype of B 148 | N_EMBEDDING_ENTRIES, //ldb 149 | &beta, //beta 150 | o, //C 151 | CUDA_R_32F, //dtype of C. idea is values wrap around and hardware overflow implicitly does mod math 152 | EMBEDDING_LENGTH, //ldc 153 | CUDA_R_32F, 154 | CUBLAS_GEMM_DEFAULT_TENSOR_OP); 155 | cudaEventRecord(stop); 156 | if (status != CUBLAS_STATUS_SUCCESS) { 157 | std::cout << "GemmEx Failed " << status << " " << CUBLAS_STATUS_NOT_SUPPORTED << std::endl; 158 | } 159 | 160 | cudaEventSynchronize(stop); 161 | float milliseconds = 0; 162 | cudaEventElapsedTime(&milliseconds, start, stop); 163 | gemm_ms_time_cumulative += milliseconds; 164 | } 165 | } 166 | 167 | float total_time_ms = gemm_ms_time_cumulative + dpf_ms_time_cumulative; 168 | float total_throughput = (REPS*BATCH_SIZE)/total_time_ms; 169 | float dpf_throughput = (REPS*BATCH_SIZE)/dpf_ms_time_cumulative; 170 | float gemm_throughput = (REPS*BATCH_SIZE)/gemm_ms_time_cumulative; 171 | 172 | printf("{'total_time_ms' : %f, 'total_throughput': %f," 173 | "'dpf_throughput': %f, 'gemm_throughput': %f," 174 | "'gemm_ms_time_cumulative':%f," 175 | "'dpf_ms_time_cumulative': %f," 176 | "'N_EMBEDDING_ENTRIES': %d," 177 | "'EMBEDDING_LENGTH': %d," 178 | "'USE_DPF': %d," 179 | "'BATCH_SIZE': %d," 180 | "'REPS': %d," 181 | "'USE_GEMM': %d," 182 | "'DPF_THREADS': %d}", 183 | total_time_ms, total_throughput, dpf_throughput, 184 | gemm_throughput, gemm_ms_time_cumulative, dpf_ms_time_cumulative, 185 | N_EMBEDDING_ENTRIES, EMBEDDING_LENGTH, USE_DPF, BATCH_SIZE, REPS, USE_GEMM, DPF_THREADS); 186 | } 187 | -------------------------------------------------------------------------------- /paper/kernel/cpu/dpf_google/benchmark_multithread_dpf.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
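# Note (not part of the original script): the ./benchmark invocation in the loop below follows
# the argument order documented in the Makefile above,
#   ./benchmark [n_embedding_entries] [embedding_length] [use_dpf] [batch] [reps] [use_gemm] [dpf_threads]
# so this sweep times CPU DPF expansion only (use_dpf=1, use_gemm=0) for 100 reps per
# configuration, writing each run's output (ending in the Python-dict result line printed by
# benchmark.cu) to a uniquely named file.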
2 | 3 | batch_sizes=( 1 32 128 512 ) 4 | lengths=( 16 128 1024 16384 1048576 ) 5 | emb_sizs=( 32 64 128 256 512 ) 6 | threads=( 100 56 28 16 8 4 2 1 ) 7 | 8 | for b in "${batch_sizes[@]}"; do 9 | for length in "${lengths[@]}"; do 10 | for emb_siz in "${emb_sizs[@]}"; do 11 | for thread in "${threads[@]}"; do 12 | ./benchmark $length $emb_siz 1 $b 100 0 $thread > multithread_dpf_benchmark_n=${length}_d=${emb_siz}_b=${b}_threads=${thread} 13 | done 14 | done 15 | done 16 | done 17 | -------------------------------------------------------------------------------- /paper/kernel/cpu/dpf_google/dpf_helpers.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | #include "dpf_helpers.h" 4 | #include "distributed_point_function.h" 5 | #include "dpf/distributed_point_function.pb.h" 6 | 7 | void *DPFInitialize(int log_domain_size, int bitsize) { 8 | distributed_point_functions::DpfParameters parameters; 9 | parameters.set_log_domain_size(log_domain_size); 10 | parameters.mutable_value_type()->mutable_integer()->set_bitsize(bitsize); 11 | std::unique_ptr<distributed_point_functions::DistributedPointFunction> dpf = 12 | distributed_point_functions::DistributedPointFunction::Create(parameters).value(); 13 | if (dpf == NULL) { 14 | std::cout << "Error dpf NULL... Exiting" << std::endl; 15 | exit(0); 16 | } 17 | 18 | return (void *)dpf.release(); 19 | } 20 | 21 | void DPFGetKey(void *dpf_ptr, int alpha, int beta, void **k1, void **k2) { 22 | distributed_point_functions::DistributedPointFunction *dpf = (distributed_point_functions::DistributedPointFunction *)dpf_ptr; 23 | auto keypair = dpf->GenerateKeys((uint32_t)alpha, (uint32_t)beta).value(); 24 | *k1 = keypair.first.New(); 25 | *k2 = keypair.second.New(); 26 | keypair.first.Swap((distributed_point_functions::DpfKey *)*k1); 27 | keypair.second.Swap((distributed_point_functions::DpfKey *)*k2); 28 | } 29 | 30 | void DPFExpand(void *dpf_ptr, void *key_data, int N, uint32_t *out) { 31 | distributed_point_functions::DistributedPointFunction *dpf = (distributed_point_functions::DistributedPointFunction *)dpf_ptr; 32 | distributed_point_functions::DpfKey key = *((distributed_point_functions::DpfKey *)key_data); 33 | std::vector<absl::uint128> evaluation_points(N); 34 | for (int i = 0; i < N; i++) evaluation_points[i] = i; 35 | std::vector<uint32_t> expanded = dpf->EvaluateAt<uint32_t>(key, 0, evaluation_points).value(); 36 | //memcpy(out, expanded.data(), sizeof(uint32_t)*N); 37 | } 38 | -------------------------------------------------------------------------------- /paper/kernel/cpu/dpf_google/dpf_helpers.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | void *DPFInitialize(int log_domain_size, int bitsize); 9 | void DPFGetKey(void *dpf, int alpha, int beta, void **k1, void **k2); 10 | void DPFExpand(void *dpf_ptr, void *key_data, int N, uint32_t *out); 11 | -------------------------------------------------------------------------------- /paper/kernel/cpu/dpf_google/test_dpf_so.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
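// A minimal usage sketch (not part of the original test; parameter values are illustrative) of
// how the three helpers declared in dpf_helpers.h above fit together:
//
//   void *dpf = DPFInitialize(/*log_domain_size=*/20, /*bitsize=*/32);
//   void *k1, *k2;
//   DPFGetKey(dpf, /*alpha=*/42, /*beta=*/21, &k1, &k2);   // generate the two-server key pair
//   std::vector<uint32_t> share(1 << 20);
//   DPFExpand(dpf, k1, share.size(), share.data());        // expand one share over the domain
//
// Note that in this repository's DPFExpand the final memcpy into `out` is commented out, so the
// helper only exercises the expansion (useful for benchmarking); the expanded values are not
// copied back to the caller.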
2 | 3 | 4 | #include 5 | #include "distributed_point_function.h" 6 | #include "dpf/distributed_point_function.pb.h" 7 | 8 | using namespace std; 9 | 10 | int main(int argc, char *argv[]) { 11 | 12 | // Init 13 | distributed_point_functions::DpfParameters parameters; 14 | parameters.set_log_domain_size(64); 15 | parameters.mutable_value_type()->mutable_integer()->set_bitsize(64); 16 | std::unique_ptr<distributed_point_functions::DistributedPointFunction> dpf = 17 | distributed_point_functions::DistributedPointFunction::Create(parameters).value(); 18 | 19 | if (dpf == NULL) { 20 | cout << "Error dpf NULL... Exiting" << endl; 21 | exit(0); 22 | } 23 | 24 | // Actual key generation 25 | // "Generates a pair of keys for a DPF that evaluates to `beta` when evaluated 26 | // `alpha`" 27 | //absl::uint128 alpha = 42; 28 | //absl::uint128 beta = 21; 29 | uint64_t alpha = 42; 30 | uint64_t beta = 21; 31 | auto keypair = dpf->GenerateKeys(alpha, beta).value(); 32 | 33 | // Test 34 | int num_evaluation_points = 100; 35 | std::vector<absl::uint128> evaluation_points(num_evaluation_points); 36 | for (int i = 0; i < num_evaluation_points; ++i) { 37 | evaluation_points[i] = i; 38 | } 39 | 40 | auto r1 = dpf->EvaluateAt<uint64_t>(keypair.first, 0, evaluation_points).value(); 41 | auto r2 = dpf->EvaluateAt<uint64_t>(keypair.second, 0, evaluation_points).value(); 42 | 43 | int failed = 0; 44 | for (int i = 0; i < num_evaluation_points; i++) { 45 | auto sum = r1[i] + r2[i]; 46 | auto truth = i == alpha ? beta : 0; 47 | if (sum != truth) { 48 | failed = 1; 49 | } 50 | cout << "Index " << i << " w/ sum: " << sum << " (expect: " << truth << ")" << endl; 51 | } 52 | 53 | if (failed) cout << "FAIL" << endl; 54 | else cout << "SUCCESS" << endl; 55 | } 56 | -------------------------------------------------------------------------------- /paper/kernel/gpu/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
2 | 3 | DPF_STRATEGY="DPF_BREADTH_FIRST" 4 | PRF_METHOD="SALSA20_12_CIPHER" 5 | #DPF_STRATEGY="DPF_HYBRID" 6 | #PRF_METHOD="DUMMY" 7 | NUM_ENTRIES=16384 8 | BATCH_SIZE=32 9 | ENTRY_SIZE=16 10 | FUSES_MATMUL=0 11 | PERFORM_MATMUL=0 12 | 13 | PRF_CONFIG_FLAGS=-DPRF_METHOD=$(PRF_METHOD) -DDPF_STRATEGY=$(DPF_STRATEGY) -DMM=$(ENTRY_SIZE) -DKK=$(NUM_ENTRIES) -DNN=$(BATCH_SIZE) -DFUSES_MATMUL=$(FUSES_MATMUL) -DPERFORM_MATMUL=$(PERFORM_MATMUL) 14 | 15 | CC=g++ 16 | NVCC=nvcc 17 | NVCC_FLAGS=-lineinfo --no-host-device-initializer-list -Xcudafe --display_error_number 18 | 19 | benchmark: 20 | $(NVCC) $(PRF_CONFIG_FLAGS) dpf_gpu/dpf_benchmark.cu $(NVCC_FLAGS) -o dpf_benchmark 21 | 22 | profile: 23 | $(NVCC) $(PRF_CONFIG_FLAGS) -DREPS=1 dpf_gpu/dpf_benchmark.cu $(NVCC_FLAGS) -o dpf_benchmark 24 | ncu -f --import-source yes --set full --replay-mode application -o $(DPF_STRATEGY),num_entries=$(NUM_ENTRIES),batch_size=$(BATCH_SIZE),entry_size=$(ENTRY_SIZE),fuses_matmul=$(FUSES_MATMUL),perform_matmul=$(PERFORM_MATMUL),prf_method=$(PRF_METHOD) --target-processes all ./dpf_benchmark 25 | 26 | benchmark_matmul: 27 | $(NVCC) $(PRF_CONFIG_FLAGS) dpf_gpu/matmul_benchmark.cu $(NVCC_FLAGS) -o matmul_benchmark 28 | 29 | profile_matmul: 30 | $(NVCC) $(PRF_CONFIG_FLAGS) -DREPS=1 dpf_gpu/matmul_benchmark.cu $(NVCC_FLAGS) -o matmul_benchmark 31 | ncu -f --import-source yes --set full -o matmul,num_entries=$(NUM_ENTRIES),batch_size=$(BATCH_SIZE),entry_size=$(ENTRY_SIZE) --target-processes all ./matmul_benchmark 32 | 33 | test_128_bit_functionality: 34 | $(NVCC) dpf_gpu/tests/test_128_bit.cu $(NVCC_FLAGS) -o test_128_bit 35 | ./test_128_bit 36 | rm -f *.o 37 | rm test_128_bit 38 | 39 | test_dpf_base: 40 | $(CC) -Ofast dpf_base/dpf.cc -o dpf_cpu_base 41 | ./dpf_cpu_base 42 | rm -f *.o 43 | rm dpf_cpu_base 44 | 45 | -------------------------------------------------------------------------------- /paper/kernel/gpu/dpf_base/dpf.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | /* 4 | Serial CPU dpf function based on the sqrt(n) grid trick described 5 | - https://www.youtube.com/watch?v=y2aVgxD7DJc 6 | - https://www.iacr.org/archive/eurocrypt2014/84410245/84410245.pdf 7 | */ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include "dpf.h" 18 | 19 | int main(int argc, char *argv[]) { 20 | test_log_n_method(); 21 | test_sqrt_n_method(); 22 | benchmark_log_n_method_perf(); 23 | test_flat_codewords(); 24 | } 25 | -------------------------------------------------------------------------------- /paper/kernel/gpu/dpf_gpu/dpf/dpf_breadth_first.cu: -------------------------------------------------------------------------------- 1 | 2 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
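// Usage sketch (not part of the original source; values are illustrative): the Makefile above
// fixes the kernel variant and problem size at compile time, so a typical benchmark run
// overrides those variables on the make command line (standard `make VAR=value` overrides):
//
//   make benchmark DPF_STRATEGY="DPF_HYBRID" NUM_ENTRIES=1048576 BATCH_SIZE=64 ENTRY_SIZE=16 PERFORM_MATMUL=1
//   ./dpf_benchmark
//
// and `make profile ...` wraps the same build in an Nsight Compute (ncu) profiling run.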
3 | 4 | // This is like batch size: a block expands _multiple_ dpfs 5 | #define DPF_BREADTH_PARALLEL_THREADS_PER_BLOCK 1024 6 | #define DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK 1 7 | 8 | uint128_t_gpu *DPF_BREADTH_PARALLEL_KEYS_1, *DPF_BREADTH_PARALLEL_KEYS_2; 9 | 10 | void dpf_breadth_first_initialize(int batch_size, int num_entries) { 11 | cudaMalloc(&DPF_BREADTH_PARALLEL_KEYS_1, sizeof(uint128_t_gpu)*batch_size*num_entries); 12 | cudaMalloc(&DPF_BREADTH_PARALLEL_KEYS_2, sizeof(uint128_t_gpu)*batch_size*num_entries); 13 | } 14 | 15 | __global__ void dpf_breadth_first_kernel(SeedsCodewordsFlatGPU *cw, uint128_t_gpu *out, 16 | uint128_t_gpu *DPF_BREADTH_PARALLEL_KEYS_1, 17 | uint128_t_gpu *DPF_BREADTH_PARALLEL_KEYS_2, 18 | int batch_size, int num_entries) { 19 | 20 | // Computes DPF expansion in a breadth parallel way 21 | int thread_idx = threadIdx.x; 22 | int block_idx = blockIdx.x; 23 | 24 | // This block handles DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK DPF expansions 25 | int cw_start = blockIdx.x*DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK; 26 | int cw_end = cw_start+DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK; 27 | 28 | // Load cw to shared memory 29 | __shared__ SeedsCodewordsFlatGPU cw_shared[DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK]; 30 | if (thread_idx < DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK) { 31 | cw_shared[thread_idx] = cw[cw_start+thread_idx]; 32 | } 33 | 34 | __syncthreads(); 35 | 36 | // Simple recurrence relation for expanding binary tree-based DPF. 37 | // Nodes numbered with following format: 38 | // 0 39 | // / \ 40 | // 0 1 41 | // / \ / \ 42 | // 0 1 2 3 43 | // 44 | // Relation: 45 | // k_1 = seed 46 | // k_i = PRF(k_{i//2}, i % 2) + CW_{k_{i//2} & 1}(i % 2) 47 | // 48 | // Output k_{2^{depth-1}} to k_{2^{depth-1}} + N 49 | // 50 | // Do note, we are expanding multiple binary tree DPFs. 51 | // In this algo, each threadblock expands DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK dpfs. 
52 | // Following checks ensure blocking params are correct: 53 | // assert(DPF_BREADTH_PARALLEL_THREADS_PER_BLOCK/DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK >= 1) 54 | // assert(DPF_BREADTH_PARALLEL_THREADS_PER_BLOCK%DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK == 0) 55 | uint128_t_gpu *key_write = DPF_BREADTH_PARALLEL_KEYS_1; 56 | uint128_t_gpu *key_read = DPF_BREADTH_PARALLEL_KEYS_2; 57 | uint128_t_gpu *tmp; 58 | 59 | constexpr int parallel_work_per_threadblock_per_dpf = (DPF_BREADTH_PARALLEL_THREADS_PER_BLOCK/DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK); 60 | 61 | // Init the first seed 62 | int batch_idx = thread_idx % DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK; 63 | key_write[0 + (block_idx+batch_idx)*num_entries] = cw_shared[batch_idx].last_keys[0]; 64 | 65 | // Outer loop loops from top level of tree down to bottom 66 | for (int i = cw_shared[0].depth-1; i >= 0; i--) { 67 | 68 | // Swap read and write buffers 69 | tmp = key_read; 70 | key_read = key_write; 71 | key_write = tmp; 72 | 73 | // Can parallelize _within_ a level of the tree, but not _across_ levels of the tree 74 | __syncthreads(); 75 | 76 | // Inner loop scans the current level of the tree (in parallel batches) 77 | int start = 0, end = 1<<(cw_shared[0].depth-i); 78 | for (int j = start; j < end; j += parallel_work_per_threadblock_per_dpf) { 79 | int expansion_idx = j + (thread_idx % parallel_work_per_threadblock_per_dpf); 80 | int batch_idx = thread_idx / parallel_work_per_threadblock_per_dpf; 81 | 82 | if (expansion_idx < end) { 83 | int idx_into_codewords = expansion_idx % 2; 84 | uint128_t_gpu key = key_read[(expansion_idx/2) + (block_idx+batch_idx)*num_entries]; 85 | uint128_t_gpu value = PRF(key, idx_into_codewords); 86 | uint128_t_gpu *cw = (key.x & 1) == 0 ? cw_shared[batch_idx].cw_1 : cw_shared[batch_idx].cw_2; 87 | cw = &cw[i*2]; 88 | key_write[expansion_idx + (block_idx+batch_idx)*num_entries] = add_uint128(value, cw[idx_into_codewords]); 89 | } 90 | } 91 | } 92 | 93 | // Postamble, write to output 94 | for (int i = 0; i < num_entries; i+= parallel_work_per_threadblock_per_dpf) { 95 | int expansion_idx = i + (thread_idx % parallel_work_per_threadblock_per_dpf); 96 | int batch_idx = thread_idx / parallel_work_per_threadblock_per_dpf; 97 | int dst_idx = __brev(expansion_idx) >> (32-cw_shared[0].depth); 98 | 99 | // Do note: the best way to write memory is with _coalescing_. 100 | // Without it, huge performance slowdowns (2.5x slowdown!) 101 | // However, this writes to the output buffer in a permutated order. 102 | out[expansion_idx + batch_idx*num_entries + block_idx*DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK*num_entries] = 103 | key_write[expansion_idx + (block_idx+batch_idx)*num_entries]; 104 | } 105 | } 106 | 107 | void dpf_breadth_first(SeedsCodewordsFlatGPU *cw, 108 | uint128_t_gpu *out, 109 | int batch_size, int num_entries, 110 | cudaStream_t s) { 111 | dim3 n_blocks_breadth_parallel(batch_size / DPF_BREADTH_PARALLEL_DPFS_PER_BLOCK); 112 | dim3 n_threads_breadth_parallel(DPF_BREADTH_PARALLEL_THREADS_PER_BLOCK); 113 | 114 | dpf_breadth_first_kernel<<<n_blocks_breadth_parallel, n_threads_breadth_parallel, 0, s>>>(cw, out, 115 | DPF_BREADTH_PARALLEL_KEYS_1, 116 | DPF_BREADTH_PARALLEL_KEYS_2, 117 | batch_size, 118 | num_entries); 119 | } 120 | -------------------------------------------------------------------------------- /paper/kernel/gpu/dpf_gpu/dpf/dpf_coop.cu: -------------------------------------------------------------------------------- 1 | 2 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
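// A small worked trace (not part of the original source) of the level-by-level expansion used by
// the breadth-first kernel above, for a toy depth of 3 (8 leaves). Level L holds 2^L nodes, and
// node e at level L is derived from its parent e/2 as
//
//   key[L][e] = PRF(key[L-1][e/2], e % 2) + cw[e % 2]
//
// with the codeword pair (cw_1 vs cw_2) selected by the parity bit of the parent key. Reading the
// path bits of e most-significant-first gives a leaf's position in this layout, whereas the naive
// kernel (and the reference order used by the correctness checks) consumes index bits
// least-significant-first; the two orders are bit-reversals of each other. For example, with
// depth 3 the logical index 3 = 0b011 ends up in output slot 6 = 0b110. This is why the kernel
// notes that it writes the output buffer in a "permutated" (bit-reversed) order, and why
// dpf_benchmark.cu pre-reorders the table with brev_cpu() so the subsequent matmul pairs each
// table entry with the matching DPF output.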
3 | 4 | #include 5 | 6 | using namespace cooperative_groups; 7 | 8 | #ifndef FUSES_MATMUL 9 | #define FUSES_MATMUL 1 10 | #endif 11 | 12 | //#define DPF_COOP_N_BLOCKS 64 13 | #define DPF_COOP_THREADS_PER_BLOCK 128 14 | 15 | int DPF_COOP_N_BLOCKS = -1; 16 | 17 | uint128_t_gpu *DPF_COOP_KEYS_1, *DPF_COOP_KEYS_2; 18 | uint128_t_gpu *TABLE_REDUCTION; 19 | 20 | __global__ void dpf_coop_kernel(SeedsCodewordsFlatGPU *cw, 21 | uint128_t_gpu *TABLE, 22 | uint128_t_gpu *TABLE_REDUCTION, 23 | uint128_t_gpu *out, 24 | uint128_t_gpu *DPF_COOP_KEYS_1, 25 | uint128_t_gpu *DPF_COOP_KEYS_2, 26 | int batch_size, int num_entries, 27 | int DPF_COOP_N_BLOCKS) { 28 | 29 | // Computes DPF expansion in a breadth parallel way 30 | int thread_idx = threadIdx.x; 31 | int block_idx = blockIdx.x; 32 | 33 | // Load cw to shared memory. Recall only 1 cw as batchsize=1 34 | __shared__ SeedsCodewordsFlatGPU cw_shared[1]; 35 | if (thread_idx == 0) { 36 | cw_shared[thread_idx] = cw[0]; 37 | } 38 | 39 | // Use cooperative groups to sync blocks 40 | this_grid().sync(); 41 | __syncthreads(); 42 | 43 | // Algorithm same as breadth parallel, see breadth parallel method for high level DPF strat 44 | uint128_t_gpu *key_write = DPF_COOP_KEYS_1; 45 | uint128_t_gpu *key_read = DPF_COOP_KEYS_2; 46 | uint128_t_gpu *tmp; 47 | 48 | // Init the first seed 49 | key_write[0] = cw_shared[0].last_keys[0]; 50 | 51 | // Outer loop loops from top level of tree down to bottom 52 | for (int i = cw_shared[0].depth-1; i >= 0; i--) { 53 | 54 | // Swap read and write buffers 55 | tmp = key_read; 56 | key_read = key_write; 57 | key_write = tmp; 58 | 59 | // Can parallelize _within_ a level of the tree, but not _across_ levels of the tree 60 | this_grid().sync(); 61 | __syncthreads(); 62 | 63 | // Inner loop scans the current level of the tree (in parallel batches) 64 | int start = 0, end = 1<<(cw_shared[0].depth-i); 65 | 66 | // Scan through the work. All threads of each block eval a single PRF 67 | for (int j = start; j < end; j += DPF_COOP_N_BLOCKS*DPF_COOP_THREADS_PER_BLOCK) { 68 | int expansion_idx = j + (block_idx*DPF_COOP_THREADS_PER_BLOCK + thread_idx); 69 | 70 | if (expansion_idx < end) { 71 | int idx_into_codewords = expansion_idx % 2; 72 | uint128_t_gpu key = key_read[(expansion_idx/2)]; 73 | uint128_t_gpu value = PRF(key, idx_into_codewords); 74 | uint128_t_gpu *cw = (key.x & 1) == 0 ? cw_shared[0].cw_1 : cw_shared[0].cw_2; 75 | cw = &cw[i*2]; 76 | key_write[expansion_idx] = add_uint128(value, cw[idx_into_codewords]); 77 | } 78 | } 79 | } 80 | 81 | #if(!FUSES_MATMUL) 82 | // Postamble, write to output 83 | for (int i = 0; i < num_entries; i += DPF_COOP_N_BLOCKS*DPF_COOP_THREADS_PER_BLOCK) { 84 | int expansion_idx = i + (block_idx*DPF_COOP_THREADS_PER_BLOCK + thread_idx); 85 | 86 | // Do note: the best way to write memory is with _coalescing_. 87 | // Without it, huge performance slowdowns (2.5x slowdown!) 88 | // However, this writes to the output buffer in a permutated order. 89 | if (expansion_idx < num_entries) { 90 | out[expansion_idx] = key_write[expansion_idx]; 91 | } 92 | } 93 | #else 94 | 95 | // Fused matmul. 
Recall MM is num_elements_per_entry 96 | uint128_t_gpu per_thread_accumulate[MM] = {0}; 97 | for (int i = 0; i < num_entries; i += DPF_COOP_N_BLOCKS*DPF_COOP_THREADS_PER_BLOCK) { 98 | int expansion_idx = i + (block_idx*DPF_COOP_THREADS_PER_BLOCK + thread_idx); 99 | if (expansion_idx < num_entries) { 100 | for (int z = 0; z < MM; z++) { 101 | per_thread_accumulate[z] = add_uint128(mul_uint128(key_write[expansion_idx], TABLE[expansion_idx]), 102 | per_thread_accumulate[z]); 103 | } 104 | } 105 | } 106 | 107 | // Tree sum reduction on accumulates 108 | int total_threads = DPF_COOP_N_BLOCKS*DPF_COOP_THREADS_PER_BLOCK; 109 | int glob_thread_idx = block_idx*DPF_COOP_THREADS_PER_BLOCK+thread_idx; 110 | 111 | // Write local accumulates to table 112 | for (int i = 0; i < MM; i++) { 113 | TABLE_REDUCTION[i*total_threads+glob_thread_idx] = per_thread_accumulate[i]; 114 | } 115 | 116 | this_grid().sync(); 117 | __syncthreads(); 118 | 119 | for (int neighbor = 1; neighbor < total_threads; neighbor*=2) { 120 | if (glob_thread_idx % (neighbor*2) == 0 && glob_thread_idx+neighbor < total_threads) { 121 | for (int z = 0; z < MM; z++) { 122 | TABLE_REDUCTION[z*total_threads+glob_thread_idx] = 123 | add_uint128(TABLE_REDUCTION[z*total_threads+glob_thread_idx], 124 | TABLE_REDUCTION[z*total_threads+glob_thread_idx+neighbor]); 125 | } 126 | } 127 | this_grid().sync(); 128 | __syncthreads(); 129 | } 130 | 131 | if (glob_thread_idx == 0) { 132 | for (int z = 0; z < MM; z++) { 133 | out[z] = TABLE_REDUCTION[z*total_threads+0]; 134 | } 135 | } 136 | 137 | #endif 138 | } 139 | 140 | int getMaxInterpreterGrid(int numThreads) { 141 | int maxBlocksPerSM = 0; 142 | cudaOccupancyMaxActiveBlocksPerMultiprocessor(&maxBlocksPerSM, dpf_coop_kernel, numThreads, 0); 143 | 144 | cudaDeviceProp deviceProp; 145 | cudaGetDeviceProperties(&deviceProp, 0); 146 | int numSM = deviceProp.multiProcessorCount; 147 | 148 | return maxBlocksPerSM * numSM; 149 | } 150 | 151 | void dpf_coop_initialize(int batch_size, int num_entries, int entry_size) { 152 | // Same as breadth parallel strategy, except always uses 153 | // batch 1. 
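// Added note (not in the original source): this strategy synchronizes the whole grid with
// this_grid().sync(), which requires a cooperative launch in which every block is resident on
// the GPU simultaneously. That is why DPF_COOP_N_BLOCKS below is derived from
// cudaOccupancyMaxActiveBlocksPerMultiprocessor() * multiProcessorCount (the largest grid that
// can be co-resident for dpf_coop_kernel) rather than from the problem size.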
154 | if (batch_size != 1) { 155 | printf("Cooperative threads DPF strategy requires batch_size=1\n"); 156 | } 157 | assert(batch_size == 1); 158 | 159 | DPF_COOP_N_BLOCKS = getMaxInterpreterGrid(DPF_COOP_THREADS_PER_BLOCK); 160 | printf("Cooperative threads DPF strategy with grid size %d\n", DPF_COOP_N_BLOCKS); 161 | 162 | cudaMalloc(&DPF_COOP_KEYS_1, sizeof(uint128_t_gpu)*batch_size*num_entries); 163 | cudaMalloc(&DPF_COOP_KEYS_2, sizeof(uint128_t_gpu)*batch_size*num_entries); 164 | 165 | // Given batch size 1, we also initialize a table of size num_entries*entry_size 166 | // for the purpose of reducing the final accumulates 167 | cudaMalloc(&TABLE_REDUCTION, sizeof(uint128_t_gpu)*entry_size*DPF_COOP_N_BLOCKS*DPF_COOP_THREADS_PER_BLOCK); 168 | } 169 | 170 | void dpf_coop(SeedsCodewordsFlatGPU * cw, 171 | uint128_t_gpu *out, 172 | uint128_t_gpu *TABLE, 173 | int batch_size, int num_entries, 174 | cudaStream_t s) { 175 | dim3 n_blocks(DPF_COOP_N_BLOCKS); 176 | dim3 n_threads(DPF_COOP_THREADS_PER_BLOCK); 177 | 178 | void *kernel_args[] = 179 | { 180 | &cw, &TABLE, &TABLE_REDUCTION, &out, 181 | &DPF_COOP_KEYS_1, 182 | &DPF_COOP_KEYS_2, 183 | &batch_size, 184 | &num_entries, 185 | &DPF_COOP_N_BLOCKS, 186 | }; 187 | cudaLaunchCooperativeKernel((void *)dpf_coop_kernel, 188 | n_blocks, n_threads, kernel_args); 189 | } 190 | 191 | -------------------------------------------------------------------------------- /paper/kernel/gpu/dpf_gpu/dpf/dpf_naive.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | 4 | #include "../utils.h" 5 | 6 | #define DPF_NAIVE_BLOCK_W 8 7 | #define DPF_NAIVE_BLOCK_H 16 8 | 9 | __device__ uint128_t_gpu expand_dpf_naive_kernel(const SeedsCodewordsFlatGPU *s, int indx) { 10 | 11 | int indx_remaining = indx; 12 | uint128_t_gpu key = s->last_keys[0]; 13 | uint128_t_gpu value; 14 | 15 | for (int i = s->depth-1; i >= 0; i--) { 16 | int indx_into_codewords = indx_remaining % 2; 17 | value = PRF(key, indx_into_codewords); 18 | const uint128_t_gpu *cw = (key.x & 1) == 0 ? s->cw_1 : s->cw_2; 19 | cw = &cw[i*2]; 20 | key = add_uint128(value, cw[indx_into_codewords]); 21 | indx_remaining >>= 1; 22 | } 23 | 24 | return key; 25 | } 26 | 27 | __global__ void dpf_naive_kernel(SeedsCodewordsFlatGPU *cw, 28 | uint128_t_gpu *out, 29 | int batch_size) { 30 | 31 | int x_indx = blockIdx.x*DPF_NAIVE_BLOCK_W + threadIdx.x; 32 | int y_indx = blockIdx.y*DPF_NAIVE_BLOCK_H + threadIdx.y; 33 | int out_indx = y_indx*batch_size + x_indx; 34 | 35 | out[out_indx] = expand_dpf_naive_kernel(&cw[x_indx], y_indx); 36 | } 37 | 38 | void dpf_naive(SeedsCodewordsFlatGPU *cw, 39 | uint128_t_gpu *out, 40 | int batch_size, int num_entries, 41 | cudaStream_t s) { 42 | dim3 threads_per_block_naive(DPF_NAIVE_BLOCK_W, DPF_NAIVE_BLOCK_H); 43 | dim3 n_blocks_naive(batch_size/DPF_NAIVE_BLOCK_W, num_entries/DPF_NAIVE_BLOCK_H); 44 | 45 | dpf_naive_kernel<<<n_blocks_naive, threads_per_block_naive, 0, s>>>(cw, out, batch_size); 46 | } 47 | -------------------------------------------------------------------------------- /paper/kernel/gpu/dpf_gpu/dpf_benchmark.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
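// Added note (not in the original source): a rough work comparison of the strategies included
// below. The naive kernel above recomputes the entire root-to-leaf path for every output element,
// i.e. about num_entries * log2(num_entries) PRF evaluations per DPF, while the breadth-first and
// cooperative variants expand the tree level by level and evaluate each node once, i.e. about
// 2 * num_entries PRF evaluations per DPF (2 + 4 + ... + num_entries). At num_entries = 2^20 that
// is roughly 21M versus ~2M PRF calls per expanded DPF.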
2 | 3 | 4 | #include "prf/prf.cu" 5 | #include "matmul/matmul.cu" 6 | #include "utils.h" 7 | 8 | // By default expand dpf in full 9 | #ifndef FUSES_MATMUL 10 | #define FUSES_MATMUL 0 11 | #endif 12 | 13 | /**************************** 14 | * DPF Strategies * 15 | ****************************/ 16 | 17 | #define DPF_NAIVE 0 18 | #define DPF_BREADTH_FIRST 1 19 | #define DPF_HYBRID 2 20 | #define DPF_COOP 3 21 | 22 | #ifndef DPF_STRATEGY 23 | #define DPF_STRATEGY DPF_HYBRID 24 | #endif 25 | 26 | // Include different DPF methods depending on set strategy. 27 | // We do this as some DPF methods may pre-initialize or use 28 | // global memory. 29 | #if(DPF_STRATEGY == DPF_NAIVE) 30 | #include "dpf/dpf_naive.cu" 31 | #endif 32 | #if(DPF_STRATEGY == DPF_BREADTH_FIRST) 33 | #include "dpf/dpf_breadth_first.cu" 34 | #endif 35 | #if(DPF_STRATEGY == DPF_HYBRID) 36 | #include "dpf/dpf_hybrid.cu" 37 | #endif 38 | #if(DPF_STRATEGY == DPF_COOP) 39 | #include "dpf/dpf_coop.cu" 40 | #endif 41 | 42 | 43 | // Flag if we perform matmul at end 44 | #ifndef PERFORM_MATMUL 45 | #define PERFORM_MATMUL 0 46 | #endif 47 | 48 | #if(FUSES_MATMUL == 1) 49 | #undef PERFORM_MATMUL 50 | #define PERFORM_MATMUL 1 51 | #endif 52 | 53 | /********************************* 54 | * Compile time table constants * 55 | *********************************/ 56 | // MM - entry size (note this does not affect DPF expansion) 57 | // KK - number of entries 58 | // NN - batch size 59 | #ifndef MM 60 | #define MM 1 61 | #endif 62 | 63 | #ifndef KK 64 | #define KK 1048576 65 | #endif 66 | 67 | #ifndef NN 68 | #define NN 64 69 | #endif 70 | 71 | // Benchmarking constants 72 | #ifndef REPS 73 | #define REPS 10 74 | #endif 75 | 76 | // Table on GPU 77 | uint128_t_gpu *TABLE; 78 | 79 | void initialize_table(uint128_t_gpu *table, int num_entries, int entry_size) { 80 | 81 | // Make sure num entries is pow of 2 82 | assert((num_entries & (num_entries-1)) == 0); 83 | 84 | // First, re-order the table according to DPF scattered output and cvt to uint128_t_gpu 85 | uint128_t_gpu *table_reordered_cvted = new uint128_t_gpu[num_entries*entry_size]; 86 | for (int j = 0; j < entry_size; j++) { 87 | for (int i = 0; i < num_entries; i++) { 88 | int reordered_indx = brev_cpu(i) >> 32 - (int)log2(num_entries); 89 | table_reordered_cvted[i+j*num_entries] = table[reordered_indx+j*num_entries]; 90 | } 91 | } 92 | 93 | // Alloc and cpy to uint128_t_gpu array 94 | cudaMalloc(&TABLE, sizeof(uint128_t_gpu)*num_entries*entry_size); 95 | cudaMemcpy(TABLE, table_reordered_cvted, sizeof(uint128_t_gpu)*num_entries*entry_size, cudaMemcpyHostToDevice); 96 | 97 | delete table_reordered_cvted; 98 | } 99 | 100 | std::string get_DPF_strategy() { 101 | if (DPF_STRATEGY == DPF_NAIVE) return "Naive"; 102 | if (DPF_STRATEGY == DPF_BREADTH_FIRST) return "Breadth-first"; 103 | if (DPF_STRATEGY == DPF_HYBRID) return "Memory-efficient"; 104 | if (DPF_STRATEGY == DPF_COOP) return "Cooperative threads"; 105 | } 106 | 107 | void print_params() { 108 | printf("------------------------------------------------------\n"); 109 | printf("dpf_benchmark.cu:\n"); 110 | printf("------------------------------------------------------\n"); 111 | printf("- Entries in table: %d\n", KK); 112 | printf("- Batch size: %d\n", NN); 113 | printf("- Entry size: %d\n", MM); 114 | printf("- PRF method: %s\n", get_PRF_method().c_str()); 115 | printf("- DPF Strategy: %s\n", get_DPF_strategy().c_str()); 116 | printf("- Matmul fusion: %d\n", FUSES_MATMUL); 117 | printf("- Perform final matmul: %d\n", PERFORM_MATMUL); 
118 | printf("------------------------------------------------------\n"); 119 | } 120 | 121 | int main() { 122 | 123 | print_params(); 124 | 125 | // Allocate codewords 126 | SeedsCodewordsFlatGPU *cw_gpu; 127 | auto cw_cpu = GenCodewords(KK, NN, &cw_gpu); 128 | 129 | 130 | #if(PERFORM_MATMUL) 131 | // Allocate & init CPU table 132 | uint128_t_gpu *table = new uint128_t_gpu[KK*MM]; 133 | for (int j = 0; j < MM; j++) { 134 | for (int i = 0; i < KK; i++) table[i+KK*j] = uint128_gpu_from((uint128_t)i); 135 | } 136 | 137 | // Initialize gpu table 138 | initialize_table(table, KK, MM); 139 | 140 | // If perform final matmul, create output buffer for it 141 | uint128_t_gpu *final_output_gpu; 142 | cudaMalloc((void **)&final_output_gpu, sizeof(uint128_t_gpu)*NN*MM); 143 | cudaMemset(final_output_gpu, 0, sizeof(uint128_t_gpu)*NN*MM); 144 | #endif 145 | 146 | // Allocate the B vector which holds either 147 | // - the DPF expanded one-hot secret share (C=A*B) 148 | // - batch size uint128_t_gpu for output 149 | uint128_t_gpu *B_gpu; 150 | #if(FUSES_MATMUL) 151 | B_gpu = final_output_gpu; 152 | #else 153 | cudaMalloc((void **)&B_gpu, sizeof(uint128_t_gpu)*KK*NN); 154 | cudaMemset(B_gpu, 0, sizeof(uint128_t_gpu)*KK*NN); 155 | #endif 156 | 157 | // Timer event trackers 158 | cudaEvent_t start, stop; 159 | cudaEventCreate(&start); 160 | cudaEventCreate(&stop); 161 | 162 | cudaEventRecord(start); 163 | 164 | // Do any initialization if needed 165 | #if(DPF_STRATEGY == DPF_BREADTH_FIRST) 166 | dpf_breadth_first_initialize(NN, KK); 167 | #endif 168 | #if(DPF_STRATEGY == DPF_HYBRID) 169 | dpf_hybrid_initialize(NN, KK); 170 | #endif 171 | #if(DPF_STRATEGY == DPF_COOP) 172 | dpf_coop_initialize(NN, KK, MM); 173 | #endif 174 | 175 | #if(PERFORM_MATMUL && !FUSES_MATMUL) 176 | initialize_matmul(MM, KK, NN); 177 | #endif 178 | 179 | 180 | // 181 | // Throughput benchmarks 182 | // 183 | 184 | // Use 2 streams. 
185 | // This interleaves matmul and dpf expansion across multiple runs 186 | cudaStream_t s1, s2; 187 | cudaStreamCreate(&s1); 188 | cudaStreamCreate(&s2); 189 | 190 | for (int i = 0; i < REPS; i++) { 191 | #if(DPF_STRATEGY == DPF_NAIVE) 192 | dpf_naive(cw_gpu, B_gpu, NN, KK, s1); 193 | #endif 194 | #if(DPF_STRATEGY == DPF_BREADTH_FIRST) 195 | dpf_breadth_first(cw_gpu, B_gpu, NN, KK, s1); 196 | #endif 197 | #if(DPF_STRATEGY == DPF_HYBRID) 198 | dpf_hybrid(cw_gpu, B_gpu, TABLE, NN, KK, s1); 199 | #endif 200 | #if(DPF_STRATEGY == DPF_COOP) 201 | dpf_coop(cw_gpu, B_gpu, TABLE, NN, KK, s1); 202 | #endif 203 | 204 | 205 | #if(!FUSES_MATMUL && PERFORM_MATMUL) 206 | // Final matmul for obtaining results 207 | GEMM128(TABLE, final_output_gpu, B_gpu, MM, KK, NN, s1); 208 | #endif 209 | 210 | #if(DPF_STRATEGY == DPF_NAIVE) 211 | dpf_naive(cw_gpu, B_gpu, NN, KK, s2); 212 | #endif 213 | #if(DPF_STRATEGY == DPF_BREADTH_FIRST) 214 | dpf_breadth_first(cw_gpu, B_gpu, NN, KK, s2); 215 | #endif 216 | #if(DPF_STRATEGY == DPF_HYBRID) 217 | dpf_hybrid(cw_gpu, B_gpu, TABLE, NN, KK, s2); 218 | #endif 219 | 220 | #if(!FUSES_MATMUL && PERFORM_MATMUL) 221 | // Final matmul for obtaining results 222 | GEMM128(TABLE, final_output_gpu, B_gpu, MM, KK, NN, s2); 223 | #endif 224 | } 225 | 226 | cudaEventRecord(stop); 227 | cudaEventSynchronize(stop); 228 | 229 | CUDA_CHECK(cudaGetLastError()); 230 | 231 | // 232 | // End throughput benchmark 233 | // 234 | 235 | // 236 | // Latency benchmark 237 | // 238 | cudaEvent_t start_latency, stop_latency; 239 | cudaEventCreate(&start_latency); 240 | cudaEventCreate(&stop_latency); 241 | cudaEventRecord(start_latency); 242 | 243 | #if(DPF_STRATEGY == DPF_NAIVE) 244 | dpf_naive(cw_gpu, B_gpu, NN, KK, s1); 245 | #endif 246 | #if(DPF_STRATEGY == DPF_BREADTH_FIRST) 247 | dpf_breadth_first(cw_gpu, B_gpu, NN, KK, s1); 248 | #endif 249 | #if(DPF_STRATEGY == DPF_HYBRID) 250 | dpf_hybrid(cw_gpu, B_gpu, TABLE, NN, KK, s1); 251 | #endif 252 | 253 | #if(!FUSES_MATMUL && PERFORM_MATMUL) 254 | // Final matmul for obtaining results 255 | GEMM128(TABLE, final_output_gpu, B_gpu, MM, KK, NN, s1); 256 | #endif 257 | 258 | #if(DPF_STRATEGY == DPF_COOP) 259 | dpf_coop(cw_gpu, B_gpu, TABLE, NN, KK, s1); 260 | #endif 261 | 262 | 263 | cudaEventRecord(stop_latency); 264 | cudaEventSynchronize(stop_latency); 265 | CUDA_CHECK(cudaGetLastError()); 266 | 267 | // 268 | // End latency benchmark 269 | // 270 | 271 | // 272 | // Check correctness if PRF is dummy method 273 | // 274 | if (PRF_METHOD == DUMMY) { 275 | 276 | #if(PERFORM_MATMUL) 277 | // If fuses matmul, check correctness of the dot product 278 | auto B_cpu = std::vector(MM * NN); 279 | cudaMemcpy(B_cpu.data(), final_output_gpu, sizeof(uint128_t_gpu)*MM*NN, cudaMemcpyDeviceToHost); 280 | check_correct_fused(cw_cpu.data(), B_cpu.data(), table, MM, NN, KK); 281 | #else 282 | // If expanding the full DPF, check the correctness of the expanded DPF only 283 | auto B_cpu = std::vector(KK * NN); 284 | cudaMemcpy(B_cpu.data(), B_gpu, sizeof(uint128_t_gpu)*KK*NN, cudaMemcpyDeviceToHost); 285 | check_correct(cw_cpu.data(), B_cpu.data(), NN, KK, DPF_STRATEGY==DPF_BREADTH_FIRST || DPF_STRATEGY==DPF_HYBRID || DPF_STRATEGY==DPF_COOP); 286 | #endif 287 | } 288 | 289 | // 290 | // Log benchmark output to dict 291 | // 292 | float ms = 0; 293 | cudaEventElapsedTime(&ms, start, stop); 294 | float throughput_per_query = NN*REPS*2/ms; 295 | 296 | float ms_latency = 0; 297 | cudaEventElapsedTime(&ms_latency, start_latency, stop_latency); 298 | 299 | // Final logging 
output 300 | printf("{'entries': %d, 'entry_size_ints': %d, 'batch_size': %d," 301 | "'prf_method': '%s', 'dpf_strategy': '%s', " 302 | "'latency_ms' : %f, 'throughput_queries_per_ms' : %f," 303 | "'fuses_matmul' : %d, 'performs_matmul' : %d}\n", 304 | KK, MM, NN, 305 | get_PRF_method().c_str(), get_DPF_strategy().c_str(), 306 | ms_latency, throughput_per_query, FUSES_MATMUL, 307 | PERFORM_MATMUL); 308 | 309 | cudaFree(B_gpu); 310 | cudaFree(cw_gpu); 311 | 312 | #if(PERFORM_MATMUL) 313 | cudaFree(final_output_gpu); 314 | delete table; 315 | #endif 316 | } 317 | -------------------------------------------------------------------------------- /paper/kernel/gpu/dpf_gpu/matmul/matmul.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | 4 | #include "../utils.h" 5 | 6 | // Define stride K to iterate over 7 | #define BLOCK_TILE_K 16 8 | 9 | // Define sizes of blocks of output C to operate over in parallel 10 | #define BLOCK_H 4 11 | #define BLOCK_W 4 12 | 13 | // If K is really large, might exceed launch config size restrictions 14 | // Going to hack this to set the right size (TODO: fix) 15 | #define MAX(a,b) \ 16 | ({ __typeof__ (a) _a = (a); \ 17 | __typeof__ (b) _b = (b); \ 18 | _a > _b ? _a : _b; }) 19 | #define BLOCK_K (MAX(128, K/32768)) 20 | 21 | // Tile inner loop, by outer products with dimension 22 | // K reduction dimension is iterated 1 by 1 23 | #define THREAD_BLOCK_H 1 24 | #define THREAD_BLOCK_W 1 25 | 26 | //// Reduction params 27 | #define REDUCTION_THREADS_PER_BLOCK 128 28 | 29 | uint128_t_gpu *MATMUL_TABLE_REDUCTION; 30 | 31 | void initialize_matmul(int M, int K, int N) { 32 | // We further initialize global memory for reducing across the K dimension 33 | assert((K&(K-1)) == 0); 34 | cudaMalloc(&MATMUL_TABLE_REDUCTION, sizeof(uint128_t_gpu)*M*N*K/BLOCK_K); 35 | cudaMemset(MATMUL_TABLE_REDUCTION, 0, sizeof(uint128_t_gpu)*M*N*K/BLOCK_K); 36 | } 37 | 38 | // Matmul of shape: MxK * KxN -> MxN 39 | __global__ void GEMM128_kernel(uint128_t_gpu *A, 40 | uint128_t_gpu *C, 41 | uint128_t_gpu *B, 42 | uint128_t_gpu *MATMUL_TABLE_REDUCTION, 43 | int M, int K, int N) { 44 | 45 | int block_indx_x = blockIdx.x; 46 | int block_indx_y = blockIdx.y; 47 | int block_indx_k = blockIdx.z; 48 | 49 | int thread_indx_x = threadIdx.x; 50 | int thread_indx_y = threadIdx.y; 51 | 52 | int thread_id_within_block = thread_indx_y*BLOCK_W + thread_indx_x; 53 | 54 | // Threads in a block handle block starting from 55 | int block_C_indx_start = block_indx_y*N*BLOCK_H + block_indx_x*BLOCK_W; 56 | 57 | int threads_per_block = (BLOCK_H/THREAD_BLOCK_H)*(BLOCK_W/THREAD_BLOCK_W); 58 | int thread_id = thread_indx_y*(BLOCK_W/THREAD_BLOCK_W)+thread_indx_x; 59 | 60 | __shared__ uint128_t_gpu A_block_local[BLOCK_H][BLOCK_TILE_K+1]; 61 | __shared__ uint128_t_gpu B_block_local[BLOCK_TILE_K][BLOCK_W+1]; 62 | uint128_t_gpu C_frag_local[THREAD_BLOCK_H][THREAD_BLOCK_W] = {0}; 63 | 64 | // This is the same as the nvidia post, loop over entire K dimension 65 | for (int k = block_indx_k*BLOCK_K; k < block_indx_k*BLOCK_K + BLOCK_K; k += BLOCK_TILE_K) { 66 | 67 | // Load blocks of A,B into shared memory in parallel 68 | int block_A_indx_start = block_indx_y*K*BLOCK_H; 69 | int block_B_indx_start = block_indx_x*BLOCK_W; 70 | 71 | for (int i = 0; i < BLOCK_H*BLOCK_TILE_K; i+= threads_per_block) { 72 | int ii = (i+thread_id) / BLOCK_TILE_K; 73 | int jj = (i+thread_id) % BLOCK_TILE_K; 74 | A_block_local[ii][jj] = 
A[k+block_A_indx_start + ii*K + jj]; 75 | } 76 | 77 | for (int i = 0; i < BLOCK_TILE_K*BLOCK_W; i+= threads_per_block) { 78 | int ii = (i+thread_id) / BLOCK_W; 79 | int jj = (i+thread_id) % BLOCK_W; 80 | //B_block_local[ii][jj] = B[block_B_indx_start + k*N + ii*N + jj]; 81 | B_block_local[ii][jj] = B[(block_B_indx_start+jj)*K + (k+ii)]; 82 | } 83 | 84 | __syncthreads(); 85 | 86 | // Compute over thread block tiles 87 | for (int i = 0; i < BLOCK_TILE_K; i++) { 88 | 89 | // More efficient method should be outer product 90 | // Load fragments into registers 91 | uint128_t_gpu A_frag_local[THREAD_BLOCK_H]; 92 | uint128_t_gpu B_frag_local[THREAD_BLOCK_W]; 93 | 94 | for (int j = 0; j < THREAD_BLOCK_H; j++) { 95 | A_frag_local[j] = A_block_local[j+thread_indx_y*THREAD_BLOCK_H][i]; 96 | } 97 | for (int j = 0; j < THREAD_BLOCK_W; j++) { 98 | B_frag_local[j] = B_block_local[i][j+thread_indx_x*THREAD_BLOCK_W]; 99 | } 100 | 101 | // Outer product into per-thread mem 102 | for (int jj = 0; jj < THREAD_BLOCK_H; jj++) { 103 | for (int kk = 0; kk < THREAD_BLOCK_W; kk++) { 104 | C_frag_local[jj][kk] = add_uint128(C_frag_local[jj][kk], 105 | mul_uint128(A_frag_local[jj], B_frag_local[kk])); 106 | } 107 | } 108 | } 109 | } 110 | 111 | ////////////////////////////////////////////////// 112 | // Reduction across threads in the K dimension // 113 | ///////////////////////////////////////////////// 114 | 115 | // Write C frag locals to intermediate output 116 | int k_stride = M*N; 117 | for (int j = 0; j < THREAD_BLOCK_W; j++) { 118 | for (int i = 0; i < THREAD_BLOCK_H; i++) { 119 | MATMUL_TABLE_REDUCTION[block_indx_k*k_stride + block_C_indx_start + thread_indx_y*THREAD_BLOCK_H*N + thread_indx_x*THREAD_BLOCK_W + i*N + j] = C_frag_local[i][j]; 120 | } 121 | } 122 | } 123 | 124 | __global__ void GEMM128_reduction_kernel(uint128_t_gpu *MATMUL_TABLE_REDUCTION, 125 | uint128_t_gpu *out, 126 | int M, int K, int N) { 127 | int block_indx = blockIdx.x; 128 | int thread_idx = threadIdx.x; 129 | int work_per_block = REDUCTION_THREADS_PER_BLOCK; 130 | int work_indx = block_indx*work_per_block + thread_idx; 131 | 132 | if (work_indx >= M*N) return; 133 | 134 | int k_stride = M*N; 135 | uint128_t_gpu accum[1] = {0}; 136 | for (int k = 0; k < K/BLOCK_K; k++) { 137 | uint128_t_gpu op2 = MATMUL_TABLE_REDUCTION[k*k_stride + work_indx]; 138 | accum[0] = add_uint128(accum[0], op2); 139 | } 140 | 141 | out[work_indx] = accum[0]; 142 | } 143 | 144 | void GEMM128(uint128_t_gpu *A, 145 | uint128_t_gpu *C, 146 | uint128_t_gpu *B, 147 | int M, int K, int N, 148 | cudaStream_t s) { 149 | 150 | assert(BLOCK_W%THREAD_BLOCK_W == 0); 151 | assert(BLOCK_H%THREAD_BLOCK_H == 0); 152 | assert(N%BLOCK_W == 0); 153 | assert(M%BLOCK_H == 0); 154 | 155 | dim3 threads_per_block(BLOCK_W/THREAD_BLOCK_W, BLOCK_H/THREAD_BLOCK_H); 156 | dim3 n_blocks(N/BLOCK_W, M/BLOCK_H, K/BLOCK_K); 157 | 158 | GEMM128_kernel<<<n_blocks, threads_per_block, 0, s>>>(A, C, B, MATMUL_TABLE_REDUCTION, M, K, N); 159 | 160 | dim3 threads_per_block_reduce(REDUCTION_THREADS_PER_BLOCK); 161 | dim3 n_blocks_reduce((M*N)/REDUCTION_THREADS_PER_BLOCK+1); 162 | GEMM128_reduction_kernel<<<n_blocks_reduce, threads_per_block_reduce, 0, s>>>(MATMUL_TABLE_REDUCTION, C, M, K, N); 163 | } 164 | -------------------------------------------------------------------------------- /paper/kernel/gpu/dpf_gpu/matmul_benchmark.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
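For reference, the GEMM128 path above splits the K (reduction) dimension across blockIdx.z, has each K-slice write its own partial M x N product into MATMUL_TABLE_REDUCTION, and then collapses the K/BLOCK_K slices with GEMM128_reduction_kernel. The plain C++ sketch below mirrors that two-pass (split-K) structure for checking the indexing; it is not part of the repo. The name gemm_split_k_reference is illustrative, unsigned __int128 stands in for uint128_t_gpu (GCC/Clang extension), A is M x K row-major, B is laid out as B[n*K + k] (as the kernel reads it), and block_k is assumed to divide K.

// Reference for the two-pass (split-K) structure of GEMM128 (sketch only).
#include <cstdint>
#include <vector>

using u128 = unsigned __int128;

void gemm_split_k_reference(const std::vector<u128>& A,
                            const std::vector<u128>& B,
                            std::vector<u128>& C,
                            int M, int K, int N, int block_k) {
    int slices = K / block_k;                       // one slice per blockIdx.z
    std::vector<u128> partial((size_t)slices * M * N, 0);

    // Pass 1: each K-slice accumulates its own partial M x N product
    // (the role of GEMM128_kernel writing into MATMUL_TABLE_REDUCTION).
    for (int s = 0; s < slices; s++)
        for (int m = 0; m < M; m++)
            for (int n = 0; n < N; n++)
                for (int k = s * block_k; k < (s + 1) * block_k; k++)
                    partial[(size_t)s * M * N + (size_t)m * N + n] +=
                        A[(size_t)m * K + k] * B[(size_t)n * K + k];

    // Pass 2: sum the per-slice partials (the role of GEMM128_reduction_kernel).
    for (int m = 0; m < M; m++)
        for (int n = 0; n < N; n++) {
            u128 acc = 0;
            for (int s = 0; s < slices; s++)
                acc += partial[(size_t)s * M * N + (size_t)m * N + n];
            C[(size_t)m * N + n] = acc;
        }
}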
2 | 3 | 4 | // Benchmark and test 128-bit matmul for DPF 5 | 6 | #include "utils.h" 7 | #include "matmul/matmul.cu" 8 | 9 | #ifndef REPS 10 | #define REPS 10 11 | #endif 12 | 13 | void print_params() { 14 | printf("------------------------------------------------------\n"); 15 | printf("matmul_benchmark.cu:\n"); 16 | printf("------------------------------------------------------\n"); 17 | printf("- Entries in table (K): %d\n", KK); 18 | printf("- Batch size (N): %d\n", NN); 19 | printf("- Entry size (M): %d\n", MM); 20 | printf("------------------------------------------------------\n"); 21 | } 22 | 23 | void alloc_test_matrix(uint128_t_gpu **A_gpu, 24 | uint128_t_gpu **A_cpu, 25 | int M, int N) { 26 | *A_cpu = new uint128_t_gpu[M*N]; 27 | for (int i = 0; i < M; i++) { 28 | for (int j = 0; j < N; j++) { 29 | (*A_cpu)[i*N+j] = uint128_gpu_from((uint128_t)i*N+j); 30 | } 31 | } 32 | 33 | cudaMalloc(A_gpu, sizeof(uint128_t_gpu)*M*N); 34 | cudaMemcpy(*A_gpu, *A_cpu, sizeof(uint128_t_gpu)*M*N, cudaMemcpyHostToDevice); 35 | } 36 | 37 | void check_correct(uint128_t_gpu *A, 38 | uint128_t_gpu *B, 39 | uint128_t_gpu *C, 40 | int M, int K, int N) { 41 | uint128_t_gpu *C_ref = new uint128_t_gpu[M*N]; 42 | memset(C_ref, 0, sizeof(uint128_t_gpu)*M*N); 43 | 44 | // Compute ref solution 45 | for (int i = 0; i < M; i++) { 46 | for (int j = 0; j < K; j++) { 47 | for (int k = 0; k < N; k++) { 48 | uint128_t c = uint128_from_gpu(C_ref[i*N+k]); 49 | uint128_t a = uint128_from_gpu(A[i*K+j]); 50 | uint128_t b = uint128_from_gpu(B[j+k*K]); 51 | uint128_t accum = c+a*b; 52 | C_ref[i*N+k] = uint128_gpu_from(accum); 53 | } 54 | } 55 | } 56 | 57 | // Assert same 58 | for (int i = 0; i < M; i++) { 59 | for (int j = 0; j < N; j++) { 60 | uint128_t_gpu got = C[i*N+j]; 61 | uint128_t_gpu expected = C_ref[i*N+j]; 62 | 63 | assert(got.x == expected.x && 64 | got.y == expected.y && 65 | got.z == expected.z && 66 | got.w == expected.w); 67 | } 68 | } 69 | 70 | printf("PASS CHECKS\n"); 71 | } 72 | 73 | int main(void) { 74 | print_params(); 75 | 76 | // Alloc & Init buffers 77 | uint128_t_gpu *A_gpu, *B_gpu, *C_gpu; 78 | uint128_t_gpu *A_cpu, *B_cpu, *C_cpu; 79 | 80 | alloc_test_matrix(&A_gpu, &A_cpu, MM, KK); 81 | alloc_test_matrix(&B_gpu, &B_cpu, KK, NN); 82 | alloc_test_matrix(&C_gpu, &C_cpu, MM, NN); 83 | 84 | cudaMemset(C_gpu, 0, sizeof(uint128_t_gpu)*MM*NN); 85 | 86 | // Init 87 | initialize_matmul(MM, KK, NN); 88 | 89 | // Kernel benchmark 90 | cudaStream_t s1; 91 | cudaStreamCreate(&s1); 92 | cudaEvent_t start, stop; 93 | cudaEventCreate(&start); 94 | cudaEventCreate(&stop); 95 | 96 | // Run throughput benchmark 97 | cudaEventRecord(start); 98 | for (int i = 0; i < REPS; i++) { 99 | GEMM128(A_gpu, C_gpu, B_gpu, MM, KK, NN, s1); 100 | } 101 | cudaEventRecord(stop); 102 | cudaEventSynchronize(stop); 103 | 104 | // Run latency benchmark 105 | cudaEvent_t start_latency, stop_latency; 106 | cudaEventCreate(&start_latency); 107 | cudaEventCreate(&stop_latency); 108 | cudaEventRecord(start_latency); 109 | 110 | GEMM128(A_gpu, C_gpu, B_gpu, MM, KK, NN, s1); 111 | 112 | cudaEventRecord(stop_latency); 113 | cudaEventSynchronize(stop_latency); 114 | CUDA_CHECK(cudaGetLastError()); 115 | 116 | // Correctness checks 117 | cudaMemcpy(C_cpu, C_gpu, sizeof(uint128_t_gpu)*MM*NN, cudaMemcpyDeviceToHost); 118 | check_correct(A_cpu, B_cpu, C_cpu, MM, KK, NN); 119 | 120 | // Stats 121 | float ms = 0; 122 | cudaEventElapsedTime(&ms, start, stop); 123 | float throughput_per_query = NN*REPS/ms; 124 | 125 | float ms_latency = 0; 126 | 
cudaEventElapsedTime(&ms_latency, start_latency, stop_latency); 127 | 128 | // Final logging output 129 | printf("{'entries (K)': %d, 'entry_size_ints (M)': %d, 'batch_size (N)': %d," 130 | "'latency_ms' : %f, 'throughput_queries_per_ms' : %f'}\n", 131 | KK, MM, NN, 132 | ms_latency, throughput_per_query); 133 | 134 | cudaFree(A_gpu); 135 | cudaFree(B_gpu); 136 | cudaFree(C_gpu); 137 | 138 | delete A_cpu; 139 | delete B_cpu; 140 | delete C_cpu; 141 | } 142 | -------------------------------------------------------------------------------- /paper/kernel/gpu/dpf_gpu/prf/prf.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | 4 | #include "../utils.h" 5 | 6 | #define DUMMY 0 7 | 8 | // These hash methods are from: https://github.com/mochimodev/cuda-hashing-algos 9 | // 10 | // One thing to note is that these operate over "batches" of inputs. 11 | // But as it is, we're calling it one function at a time on DPF codewords (non-batch) 12 | // which could be why it's showing extremely poor performance (besides MD5) 13 | // 14 | // These are don't work well. 15 | #define MD5_HASH_SLOW 1 16 | #define BLAKE2B_HASH_SLOW 2 17 | #define KECCAK_HASH_SLOW 3 18 | #define MD2_HASH_SLOW 4 19 | #define SHA256_HASH_SLOW 5 20 | /////////////////////// 21 | 22 | // Hash functions 23 | #define MD5_FASTER 6 24 | #define SHA256_FASTER 11 25 | 26 | // Pure PRFs 27 | #define SIPHASH 12 28 | #define HIGHWAYHASH 13 29 | 30 | // Ciphers (stream + block) 31 | #define SALSA20_CIPHER 7 32 | #define SALSA20_12_CIPHER 8 33 | #define SALSA20_8_CIPHER 9 34 | #define AES_128_CTR_CIPHER 10 35 | 36 | #ifndef PRF_METHOD 37 | #define PRF_METHOD DUMMY 38 | //#define PRF_METHOD HIGHWAYHASH 39 | //#define PRF_METHOD SIPHASH 40 | //#define PRF_METHOD SHA256_FASTER 41 | //#define PRF_METHOD AES_128_CTR_CIPHER 42 | //#define PRF_METHOD SALSA20_CIPHER 43 | //#define PRF_METHOD SALSA20_8_CIPHER 44 | //#define PRF_METHOD SALSA20_12_CIPHER 45 | //#define PRF_METHOD MD5_FASTER 46 | //#define PRF_METHOD MD2_HASH_SLOW 47 | //#define PRF_METHOD KECCAK_HASH_SLOW 48 | //#define PRF_METHOD BLAKE2B_HASH_SLOW 49 | //#define PRF_METHOD MD5_HASH_SLOW 50 | //#define PRF_METHOD SHA256_HASH_SLOW 51 | #endif 52 | 53 | #if(PRF_METHOD == MD5_HASH_SLOW) 54 | #include "prf_algos/md5_mochi.cu" 55 | #endif 56 | 57 | #if(PRF_METHOD == BLAKE2B_HASH_SLOW) 58 | #include "prf_algos/blake2b_mochi.cu" 59 | #endif 60 | 61 | #if(PRF_METHOD == KECCAK_HASH_SLOW) 62 | #include "prf_algos/keccak_mochi.cu" 63 | #endif 64 | 65 | #if(PRF_METHOD == MD2_HASH_SLOW) 66 | #include "prf_algos/md2_mochi.cu" 67 | #endif 68 | 69 | #if(PRF_METHOD == SHA256_HASH_SLOW) 70 | #include "prf_algos/sha256_mochi.cu" 71 | #endif 72 | 73 | #if(PRF_METHOD == MD5_FASTER) 74 | #include "prf_algos/md5.cu" 75 | #endif 76 | 77 | #if(PRF_METHOD == SALSA20_CIPHER || PRF_METHOD == SALSA20_8_CIPHER || PRF_METHOD == SALSA20_12_CIPHER) 78 | #include "prf_algos/salsa20.cu" 79 | #endif 80 | 81 | #if(PRF_METHOD == AES_128_CTR_CIPHER) 82 | #include "prf_algos/aes_cuda.cu" 83 | #endif 84 | 85 | #if(PRF_METHOD == SHA256_FASTER) 86 | #include "prf_algos/sha256.cu" 87 | #endif 88 | 89 | #if(PRF_METHOD == SIPHASH) 90 | #include "prf_algos/siphash.cu" 91 | #endif 92 | 93 | #if(PRF_METHOD == HIGHWAYHASH) 94 | #include "prf_algos/highwayhash.cu" 95 | #endif 96 | 97 | std::string get_PRF_method() { 98 | if (PRF_METHOD == DUMMY) return "DUMMY"; 99 | if (PRF_METHOD == HIGHWAYHASH) return "HIGHWAYHASH"; 100 
| if (PRF_METHOD == SIPHASH) return "SIPHASH"; 101 | if (PRF_METHOD == MD5_FASTER) return "MD5"; 102 | if (PRF_METHOD == SHA256_FASTER) return "SHA256"; 103 | if (PRF_METHOD == SALSA20_8_CIPHER) return "SALSA20_8"; 104 | if (PRF_METHOD == SALSA20_12_CIPHER) return "SALSA20_12"; 105 | if (PRF_METHOD == SALSA20_CIPHER) return "SHA25620_20"; 106 | if (PRF_METHOD == AES_128_CTR_CIPHER) return "AES128"; 107 | } 108 | 109 | // Ignore warnings since there are unused variables due to 110 | // swapping out included files 111 | #pragma push 112 | #pragma diag_suppress = 253-D 113 | #pragma diag_suppress = 549-D 114 | #pragma diag_suppress = 550-D 115 | #pragma diag_suppress = code_is_unreachable 116 | #pragma diag_suppress = declared_but_not_referenced 117 | 118 | __device__ void hash(uint64_t *out_data, uint64_t *in_data) { 119 | 120 | unsigned char *in = (unsigned char *)in_data; 121 | unsigned char *out = (unsigned char *)out_data; 122 | 123 | #if(PRF_METHOD == MD5_HASH_SLOW) 124 | CUDA_MD5_CTX ctx; 125 | cuda_md5_init(&ctx); 126 | cuda_md5_update(&ctx, in, 16); 127 | cuda_md5_final(&ctx, out); 128 | #endif 129 | 130 | #if(PRF_METHOD == BLAKE2B_HASH_SLOW) 131 | CUDA_BLAKE2B_CTX ctx = c_CTX; 132 | // IMPORTANT: c_CTX is not initialized, but couldn't get it to work 133 | // with initialization (seg fault) 134 | cuda_blake2b_update(&ctx, in, 16); 135 | cuda_blake2b_final(&ctx, out); 136 | #endif 137 | 138 | #if(PRF_METHOD == KECCAK_HASH_SLOW) 139 | CUDA_KECCAK_CTX ctx; 140 | cuda_keccak_init(&ctx, 128); 141 | cuda_keccak_update(&ctx, in, 16); 142 | cuda_keccak_final(&ctx, out); 143 | #endif 144 | 145 | #if(PRF_METHOD == MD2_HASH_SLOW) 146 | CUDA_MD2_CTX ctx; 147 | cuda_md2_init(&ctx); 148 | cuda_md2_update(&ctx, in, 16); 149 | cuda_md2_final(&ctx, out); 150 | #endif 151 | 152 | #if(PRF_METHOD == SHA256_HASH_SLOW) 153 | CUDA_SHA256_CTX ctx; 154 | cuda_sha256_init(&ctx); 155 | cuda_sha256_update(&ctx, in, 16); 156 | cuda_sha256_final(&ctx, out); 157 | #endif 158 | 159 | #if(PRF_METHOD == MD5_FASTER) 160 | md5Hash(in, 16, ((uint32_t *)out) + 0, ((uint32_t *)out) + 1, ((uint32_t *)out) + 2, ((uint32_t *)out) + 3); 161 | #endif 162 | 163 | #if(PRF_METHOD == SHA256_FASTER) 164 | uint32_t out_larger[8]; 165 | sha256(out_larger, 166 | ((uint32_t *)in)[0], 167 | ((uint32_t *)in)[1], 168 | ((uint32_t *)in)[2], 169 | ((uint32_t *)in)[3], 170 | 0, 0, 0, 0, 171 | 0, 0, 0, 0, 172 | 0, 0, 0, 0); 173 | ((uint32_t *)out)[0] = out_larger[0]; 174 | ((uint32_t *)out)[1] = out_larger[1]; 175 | ((uint32_t *)out)[2] = out_larger[2]; 176 | ((uint32_t *)out)[3] = out_larger[3]; 177 | #endif 178 | } 179 | 180 | __device__ uint128_t_gpu HMAC(uint128_t_gpu seed, int i) { 181 | // PERFORMS HMAC 182 | uint64_t in[2]; 183 | uint64_t out[2]; 184 | 185 | uint64_t seed_first = (((uint64_t)seed.x) << 32) | seed.y; 186 | uint64_t seed_second = (((uint64_t)seed.z) << 32) | seed.w; 187 | 188 | in[0] = (seed_first ^ 0x3636363636363636) | i; 189 | in[1] = (seed_second ^ 0x3636363636363636) | i; 190 | 191 | hash(out, in); 192 | 193 | in[0] = out[0] | (seed_first ^ 0x5c5c5c5c5c5c5c5c); 194 | in[1] = out[1] | (seed_second ^ 0x5c5c5c5c5c5c5c5c); 195 | 196 | hash(out, in); 197 | 198 | uint128_t_gpu r; 199 | r.x = out[0] >> 32; 200 | r.y = out[0] & 0xFFFFFFFF; 201 | r.z = out[1] >> 32; 202 | r.w = out[1] & 0xFFFFFFFF; 203 | return r; 204 | } 205 | 206 | __device__ uint128_t_gpu PRF(uint128_t_gpu seed, uint32_t i) { 207 | 208 | if (PRF_METHOD == DUMMY) { 209 | uint128_t_gpu val_4242 = uint128_from(0, 4242); 210 | uint128_t_gpu val_i = 
uint128_from(0, i); 211 | return add_uint128(mul_uint128(seed, add_uint128(val_4242, val_i)), 212 | add_uint128(val_4242, val_i)); 213 | } 214 | 215 | // Check if stream cipher (otherwise uses hash + HMAC) 216 | if (PRF_METHOD == SALSA20_CIPHER || 217 | PRF_METHOD == SALSA20_8_CIPHER || 218 | PRF_METHOD == SALSA20_12_CIPHER) { 219 | 220 | // Set initial state 221 | uint32_t in[16]; 222 | uint32_t out[16]; 223 | 224 | // Fixed words 225 | in[0] = 0x65787061; 226 | in[5] = 0x6e642033; 227 | in[10] = 0x322d6279; 228 | in[15] = 0x7465206b; 229 | 230 | // Nonce (this is fixed as well) 231 | in[6] = 0; 232 | in[7] = 0; 233 | 234 | // CTR 235 | in[8] = i; 236 | in[9] = 0; 237 | 238 | // Keys 239 | in[1] = seed.x; 240 | in[2] = seed.y; 241 | in[3] = seed.z; 242 | in[4] = seed.w; 243 | 244 | in[11] = in[12] = in[13] = in[14] = 0; 245 | 246 | #if(PRF_METHOD == SALSA20_CIPHER) 247 | SALSA20(out, in); 248 | #endif 249 | #if(PRF_METHOD == SALSA20_8_CIPHER) 250 | SALSA20_8(out, in); 251 | #endif 252 | #if(PRF_METHOD == SALSA20_12_CIPHER) 253 | SALSA20_12(out, in); 254 | #endif 255 | uint128_t_gpu r; 256 | r.x = out[0]; 257 | r.y = out[1]; 258 | r.w = out[2]; 259 | r.z = out[3]; 260 | return r; 261 | } 262 | else if (PRF_METHOD == AES_128_CTR_CIPHER) { 263 | // Note that this method does not do key expansion 264 | // and assumes it is already done. Performance is 265 | // poor enough and key expansion makes it worse 266 | 267 | uint32_t block[4]; 268 | for (int ii = 0; ii < 4; ii++) block[ii] = i; 269 | uint32_t key[44]; 270 | for (int ii = 0; ii < 44; ii++) key[ii] = seed.x; 271 | 272 | #if(PRF_METHOD == AES_128_CTR_CIPHER) 273 | encrypt((uint8_t *)block, (uint8_t *)key, 0); 274 | #endif 275 | 276 | uint128_t_gpu r; 277 | r.x = block[0]; 278 | r.y = block[1]; 279 | r.w = block[2]; 280 | r.z = block[3]; 281 | return r; 282 | } 283 | else if (PRF_METHOD == SIPHASH) { 284 | // Note that siphash isn't really a hash function but instead 285 | // is directly a PRF, so we can call it directly 286 | // TODO: wikipedia says 64-bit output. CONVERT TO 128 bit! 287 | uint8_t out[16]; 288 | uint32_t in[4] = {i, 0, 0, 0}; 289 | uint32_t k[4] = {seed.x, seed.y, seed.z, seed.w}; 290 | 291 | #if(PRF_METHOD == SIPHASH) 292 | siphash(out, (const uint8_t *)in, 16, (const uint8_t *)k); 293 | #endif 294 | 295 | uint128_t_gpu r; 296 | r.x = out[0]; 297 | r.y = out[1]; 298 | 299 | // TODO: Note these two should be out[2], out[3] for 128-bit security 300 | r.w = out[0]; 301 | r.z = out[1]; 302 | return r; 303 | } 304 | else if (PRF_METHOD == HIGHWAYHASH) { 305 | // Note there are criticisms of HighwayHash as a cryptographically 306 | // secure hash, but the authors claim it is a PRF. 307 | uint32_t data[4] = {i, 0, 0, 0}; 308 | size_t size = 16; 309 | uint64_t key[4] = {seed.x, seed.y, seed.z, seed.w}; 310 | uint64_t hash[2]; 311 | 312 | #if(PRF_METHOD == HIGHWAYHASH) 313 | HighwayHash128((uint8_t *)data, size, key, hash); 314 | #endif 315 | 316 | uint128_t_gpu r; 317 | r.x = hash[0] >> 32; 318 | r.y = hash[0] & 0xFFFFFFFF; 319 | r.w = hash[1] >> 32; 320 | r.z = hash[1] & 0xFFFFFFFF; 321 | 322 | return r; 323 | } 324 | else { 325 | return HMAC(seed, i); 326 | } 327 | } 328 | #pragma pop 329 | 330 | 331 | -------------------------------------------------------------------------------- /paper/kernel/gpu/dpf_gpu/tests/test_128_bit.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
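When PRF_METHOD is DUMMY, the PRF() above reduces to seed * (4242 + i) + (4242 + i) computed mod 2^128, which is what the correctness checks rely on. The small host-side mirror below can be handy when debugging the GPU path; it is a sketch, the name dummy_prf_reference is not used in the repo, and unsigned __int128 (GCC/Clang extension) stands in for the repo's uint128_t.

// Host-side mirror of the DUMMY PRF used by the correctness checks (sketch).
// dummy_prf_reference(seed, i) = seed * (4242 + i) + (4242 + i)  (mod 2^128)
#include <cstdint>

using u128 = unsigned __int128;

inline u128 dummy_prf_reference(u128 seed, uint32_t i) {
    u128 t = (u128)4242 + i;
    return seed * t + t;
}

// The GPU-side PRF(seed, i) result, once unpacked with uint128_from_gpu(),
// should match dummy_prf_reference(seed, i) word for word.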
2 | 3 | 4 | #include "../utils.h" 5 | 6 | void test_uint128_gpu_from() { 7 | uint128_t val = 0x1122334422334455; 8 | val <<= 64; 9 | val |= 0x3344556644556677; 10 | 11 | uint128_t_gpu g1 = uint128_gpu_from(val); 12 | assert(g1.w == 0x11223344); 13 | assert(g1.z == 0x22334455); 14 | assert(g1.y == 0x33445566); 15 | assert(g1.x == 0x44556677); 16 | } 17 | 18 | __global__ void test_uint128_from_kernel(uint128_t_gpu *r) { 19 | *r = uint128_from(0x1234567823456789, 20 | 0x2345678934567890); 21 | } 22 | 23 | void test_uint128_from() { 24 | uint128_t_gpu *r; 25 | cudaMalloc((void **)&r, sizeof(uint128_t_gpu)); 26 | test_uint128_from_kernel<<<1, 1>>>(r); 27 | uint128_t_gpu r_cpu; 28 | cudaMemcpy(&r_cpu, r, sizeof(uint128_t_gpu), cudaMemcpyDeviceToHost); 29 | 30 | assert(r_cpu.w == 0x12345678); 31 | assert(r_cpu.z == 0x23456789); 32 | assert(r_cpu.y == 0x23456789); 33 | assert(r_cpu.x == 0x34567890); 34 | 35 | cudaFree(r); 36 | } 37 | 38 | __global__ void test_add_uint128_kernel(uint128_t_gpu *a, 39 | uint128_t_gpu *b, 40 | uint128_t_gpu *r) { 41 | *r = add_uint128(*a, *b); 42 | } 43 | 44 | void test_add_uint128() { 45 | 46 | // Init v1 and v2 for mult 47 | uint128_t v1 = 0x12345678; 48 | v1 <<= 64; 49 | v1 |= 0x23456789; 50 | 51 | uint128_t v2 = 0x34567890; 52 | v2 <<= 64; 53 | v2 |= 0x45678901; 54 | 55 | uint128_t_gpu a = uint128_gpu_from(v1); 56 | uint128_t_gpu b = uint128_gpu_from(v2); 57 | 58 | // Alloc gpu mem 59 | uint128_t_gpu *r; 60 | cudaMalloc((void **)&r, sizeof(uint128_t_gpu)); 61 | 62 | uint128_t_gpu *a_gpu, *b_gpu; 63 | cudaMalloc((void **)&a_gpu, sizeof(uint128_t_gpu)); 64 | cudaMalloc((void **)&b_gpu, sizeof(uint128_t_gpu)); 65 | cudaMemcpy(a_gpu, &a, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); 66 | cudaMemcpy(b_gpu, &b, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); 67 | 68 | test_add_uint128_kernel<<<1, 1>>>(a_gpu, b_gpu, r); 69 | uint128_t_gpu r_cpu; 70 | cudaMemcpy(&r_cpu, r, sizeof(uint128_t_gpu), cudaMemcpyDeviceToHost); 71 | 72 | uint128_t truth = v1+v2; 73 | assert(r_cpu.x == (truth & 0xFFFFFFFF)); 74 | assert(r_cpu.y == ((truth & 0xFFFFFFFF00000000) >> 32)); 75 | assert(r_cpu.w == truth >> 96); 76 | assert(r_cpu.z == ((truth >> 64) & 0xFFFFFFFF)); 77 | 78 | cudaFree(r); 79 | cudaFree(a_gpu); 80 | cudaFree(b_gpu); 81 | } 82 | 83 | __global__ void test_mul_uint128_kernel(uint128_t_gpu *a, 84 | uint128_t_gpu *b, 85 | uint128_t_gpu *r) { 86 | *r = mul_uint128(*a, *b); 87 | } 88 | 89 | void test_mul_uint128() { 90 | 91 | // Init v1 and v2 for mult 92 | uint128_t v1 = 0x12345678; 93 | v1 <<= 64; 94 | v1 |= 0x23456789; 95 | 96 | uint128_t v2 = 0x34567890; 97 | v2 <<= 64; 98 | v2 |= 0x45678901; 99 | 100 | uint128_t_gpu a = uint128_gpu_from(v1); 101 | uint128_t_gpu b = uint128_gpu_from(v2); 102 | 103 | // Alloc gpu mem 104 | uint128_t_gpu *r; 105 | cudaMalloc((void **)&r, sizeof(uint128_t_gpu)); 106 | 107 | uint128_t_gpu *a_gpu, *b_gpu; 108 | cudaMalloc((void **)&a_gpu, sizeof(uint128_t_gpu)); 109 | cudaMalloc((void **)&b_gpu, sizeof(uint128_t_gpu)); 110 | cudaMemcpy(a_gpu, &a, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); 111 | cudaMemcpy(b_gpu, &b, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); 112 | 113 | test_mul_uint128_kernel<<<1, 1>>>(a_gpu, b_gpu, r); 114 | uint128_t_gpu r_cpu; 115 | cudaMemcpy(&r_cpu, r, sizeof(uint128_t_gpu), cudaMemcpyDeviceToHost); 116 | 117 | uint128_t truth = v1*v2; 118 | 119 | assert(r_cpu.x == (truth & 0xFFFFFFFF)); 120 | assert(r_cpu.y == ((truth & 0xFFFFFFFF00000000) >> 32)); 121 | assert(r_cpu.w == truth >> 96); 122 | assert(r_cpu.z == 
((truth >> 64) & 0xFFFFFFFF)); 123 | 124 | cudaFree(r); 125 | cudaFree(a_gpu); 126 | cudaFree(b_gpu); 127 | } 128 | 129 | __global__ void test_mul_uint128_kernel_twice(uint128_t_gpu *a, 130 | uint128_t_gpu *b, 131 | uint128_t_gpu *c, 132 | uint128_t_gpu *r) { 133 | *r = mul_uint128(mul_uint128(*a, *b), *c); 134 | } 135 | 136 | void test_mul_uint128_twice() { 137 | 138 | // Init v1 and v2 for mult 139 | uint128_t v1 = 0x12345678; 140 | v1 <<= 64; 141 | v1 |= 0x23456789; 142 | 143 | uint128_t v2 = 0x34567890; 144 | v2 <<= 64; 145 | v2 |= 0x45678901; 146 | 147 | uint128_t v3 = 0x123; 148 | v3 <<= 64; 149 | v3 |= 0x456; 150 | 151 | uint128_t_gpu a = uint128_gpu_from(v1); 152 | uint128_t_gpu b = uint128_gpu_from(v2); 153 | uint128_t_gpu c = uint128_gpu_from(v3); 154 | 155 | // Alloc gpu mem 156 | uint128_t_gpu *r; 157 | cudaMalloc((void **)&r, sizeof(uint128_t_gpu)); 158 | 159 | uint128_t_gpu *a_gpu, *b_gpu, *c_gpu; 160 | cudaMalloc((void **)&a_gpu, sizeof(uint128_t_gpu)); 161 | cudaMalloc((void **)&b_gpu, sizeof(uint128_t_gpu)); 162 | cudaMalloc((void **)&c_gpu, sizeof(uint128_t_gpu)); 163 | cudaMemcpy(a_gpu, &a, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); 164 | cudaMemcpy(b_gpu, &b, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); cudaMemcpy(c_gpu, &c, sizeof(uint128_t_gpu), cudaMemcpyHostToDevice); 165 | 166 | test_mul_uint128_kernel_twice<<<1, 1>>>(a_gpu, b_gpu, c_gpu, r); 167 | uint128_t_gpu r_cpu; 168 | cudaMemcpy(&r_cpu, r, sizeof(uint128_t_gpu), cudaMemcpyDeviceToHost); 169 | 170 | uint128_t truth = (v1*v2)*v3; 171 | 172 | assert(r_cpu.x == (truth & 0xFFFFFFFF)); 173 | assert(r_cpu.y == ((truth & 0xFFFFFFFF00000000) >> 32)); 174 | assert(r_cpu.w == truth >> 96); 175 | assert(r_cpu.z == ((truth >> 64) & 0xFFFFFFFF)); 176 | 177 | cudaFree(r); 178 | cudaFree(a_gpu); 179 | cudaFree(b_gpu); 180 | } 181 | 182 | void test_uint128_gpu_conversion() { 183 | for (int i = 0; i < 1000; i++) { 184 | uint128_t k = i * 0x12345; 185 | 186 | uint128_t_gpu v = uint128_gpu_from(k); 187 | uint128_t v_back = uint128_from_gpu(v); 188 | 189 | assert(v_back == k); 190 | } 191 | } 192 | 193 | int main(void) { 194 | test_uint128_gpu_from(); 195 | test_uint128_from(); 196 | test_add_uint128(); 197 | test_mul_uint128(); 198 | test_mul_uint128_twice(); 199 | test_uint128_gpu_conversion(); 200 | printf("PASS\n"); 201 | } 202 | -------------------------------------------------------------------------------- /paper/kernel/gpu/dpf_gpu/utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
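The tests above repeat the same four-word comparison of a uint128_t_gpu result against a uint128_t ground truth. A shared helper in that spirit is sketched below, keeping the word order in one place (x holds bits 0-31, y bits 32-63, z bits 64-95, w bits 96-127); expect_eq_u128 and the stand-in typedefs are illustrative and do not exist in utils.h.

// Sketch of a shared assert helper for the word order used by the tests above.
#include <cassert>
#include <cstdint>

typedef unsigned __int128 uint128_t_host;    // stand-in for the repo's uint128_t

struct u128_words { uint32_t x, y, z, w; };  // stand-in for uint128_t_gpu

inline void expect_eq_u128(u128_words got, uint128_t_host truth) {
    assert(got.x == (uint32_t)(truth & 0xFFFFFFFF));
    assert(got.y == (uint32_t)((truth >> 32) & 0xFFFFFFFF));
    assert(got.z == (uint32_t)((truth >> 64) & 0xFFFFFFFF));
    assert(got.w == (uint32_t)(truth >> 96));
}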
2 | 3 | 4 | #ifndef UTILS 5 | #define UTILS 6 | 7 | #include "../dpf_base/dpf.h" 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | //////////////////////////////////////////////////////////////////////////////// 14 | // 128-bit functionalities // 15 | // from: https://stackoverflow.com/questions/6162140/128-bit-integer-on-cuda // 16 | //////////////////////////////////////////////////////////////////////////////// 17 | typedef uint4 uint128_t_gpu; 18 | 19 | uint128_t_gpu uint128_gpu_from(uint128_t val) { 20 | uint128_t_gpu res; 21 | res.w = (val >> 96) & 0xFFFFFFFF; 22 | res.z = (val >> 64) & 0xFFFFFFFF; 23 | res.y = (val >> 32) & 0xFFFFFFFF; 24 | res.x = (val >> 0) & 0xFFFFFFFF; 25 | return res; 26 | } 27 | 28 | uint128_t uint128_from_gpu(uint128_t_gpu val) { 29 | uint128_t res = 0; 30 | return val.x + 31 | ((uint128_t)val.y << 32) + 32 | ((uint128_t)val.z << 64) + 33 | ((uint128_t)val.w << 96); 34 | } 35 | 36 | __device__ uint128_t_gpu uint128_from(uint64_t hi, 37 | uint64_t lo) { 38 | uint128_t_gpu res; 39 | res.w = (hi >> 32); 40 | res.z = hi & 0x00000000FFFFFFFF; 41 | res.y = (lo >> 32); 42 | res.x = lo & 0x00000000FFFFFFFF; 43 | return res; 44 | } 45 | 46 | __device__ uint128_t_gpu add_uint128(uint128_t_gpu addend, uint128_t_gpu augend) 47 | { 48 | uint128_t_gpu res; 49 | asm ("add.cc.u32 %0, %4, %8;\n\t" 50 | "addc.cc.u32 %1, %5, %9;\n\t" 51 | "addc.cc.u32 %2, %6, %10;\n\t" 52 | "addc.u32 %3, %7, %11;\n\t" 53 | : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) 54 | : "r"(addend.x), "r"(addend.y), "r"(addend.z), "r"(addend.w), 55 | "r"(augend.x), "r"(augend.y), "r"(augend.z), "r"(augend.w)); 56 | return res; 57 | } 58 | 59 | __device__ uint128_t_gpu mul_uint128(uint128_t_gpu a, uint128_t_gpu b) 60 | { 61 | uint128_t_gpu res; 62 | asm ("{\n\t" 63 | "mul.lo.u32 %0, %4, %8; \n\t" 64 | "mul.hi.u32 %1, %4, %8; \n\t" 65 | "mad.lo.cc.u32 %1, %4, %9, %1;\n\t" 66 | "madc.hi.u32 %2, %4, %9, 0;\n\t" 67 | "mad.lo.cc.u32 %1, %5, %8, %1;\n\t" 68 | "madc.hi.cc.u32 %2, %5, %8, %2;\n\t" 69 | "madc.hi.u32 %3, %4,%10, 0;\n\t" 70 | "mad.lo.cc.u32 %2, %4,%10, %2;\n\t" 71 | "madc.hi.u32 %3, %5, %9, %3;\n\t" 72 | "mad.lo.cc.u32 %2, %5, %9, %2;\n\t" 73 | "madc.hi.u32 %3, %6, %8, %3;\n\t" 74 | "mad.lo.cc.u32 %2, %6, %8, %2;\n\t" 75 | "madc.lo.u32 %3, %4,%11, %3;\n\t" 76 | "mad.lo.u32 %3, %5,%10, %3;\n\t" 77 | "mad.lo.u32 %3, %6, %9, %3;\n\t" 78 | "mad.lo.u32 %3, %7, %8, %3;\n\t" 79 | "}" 80 | : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w) 81 | : "r"(a.x), "r"(a.y), "r"(a.z), "r"(a.w), 82 | "r"(b.x), "r"(b.y), "r"(b.z), "r"(b.w)); 83 | return res; 84 | } 85 | 86 | // Error check functionality 87 | inline void error_check(cudaError_t err, const char* file, int line) { 88 | if(err != cudaSuccess) { 89 | ::fprintf(stderr, "CUDA ERROR at %s[%d] : %s\n", file, line, cudaGetErrorString(err)); 90 | abort(); 91 | } 92 | } 93 | #define CUDA_CHECK(err) do { error_check(err, __FILE__, __LINE__); } while(0) 94 | 95 | // SeedsCodewordsFlat for GPU (replaces uint128_t with vector version) 96 | struct SeedsCodewordsFlatGPU { 97 | int depth; 98 | uint128_t_gpu cw_1[64], cw_2[64]; 99 | uint128_t_gpu last_keys[1]; 100 | }; 101 | 102 | SeedsCodewordsFlatGPU SeedsCodewordsFlatGPUFromCPU(SeedsCodewordsFlat f) { 103 | SeedsCodewordsFlatGPU g; 104 | g.depth = f.depth; 105 | for (int i = 0; i < 64; i++) { 106 | g.cw_1[i] = uint128_gpu_from(f.cw_1[i]); 107 | g.cw_2[i] = uint128_gpu_from(f.cw_2[i]); 108 | } 109 | g.last_keys[0] = uint128_gpu_from(f.last_keys[0]); 110 | return g; 111 | } 112 | 113 | // Generates 
dummy codewords for testing 114 | std::vector<SeedsCodewordsFlat> GenCodewords(int k, int n, 115 | SeedsCodewordsFlatGPU **cw_gpu) { 116 | 117 | auto cw_cpu = std::vector<SeedsCodewordsFlat>(n); 118 | for (int i = 0; i < n; i++) { 119 | 120 | std::mt19937 g_gen(i); 121 | int alpha = (100+i) % k; 122 | int beta = 4242+i; 123 | 124 | SeedsCodewords *s = GenerateSeedsAndCodewordsLog(alpha, beta, k, g_gen); 125 | FlattenCodewords(s, 0, &cw_cpu[i]); 126 | FreeSeedsCodewords(s); 127 | } 128 | 129 | // Convert codewords to gpu rep 130 | SeedsCodewordsFlatGPU *cw_intermediate = (SeedsCodewordsFlatGPU *)malloc(sizeof(SeedsCodewordsFlatGPU)*n); 131 | for (int i = 0; i < n; i++) { 132 | cw_intermediate[i] = SeedsCodewordsFlatGPUFromCPU(cw_cpu[i]); 133 | } 134 | 135 | cudaMalloc((void **)cw_gpu, sizeof(SeedsCodewordsFlatGPU)*n); 136 | cudaMemcpy(*cw_gpu, cw_intermediate, sizeof(SeedsCodewordsFlatGPU)*(n), cudaMemcpyHostToDevice); 137 | free(cw_intermediate); 138 | 139 | return cw_cpu; 140 | } 141 | 142 | // https://stackoverflow.com/questions/9144800/c-reverse-bits-in-unsigned-integer 143 | uint32_t brev_cpu(uint32_t x) { 144 | x = ((x >> 1) & 0x55555555u) | ((x & 0x55555555u) << 1); 145 | x = ((x >> 2) & 0x33333333u) | ((x & 0x33333333u) << 2); 146 | x = ((x >> 4) & 0x0f0f0f0fu) | ((x & 0x0f0f0f0fu) << 4); 147 | x = ((x >> 8) & 0x00ff00ffu) | ((x & 0x00ff00ffu) << 8); 148 | x = ((x >> 16) & 0xffffu) | ((x & 0xffffu) << 16); 149 | return x; 150 | } 151 | 152 | // Correctness checks the output of GPU kernel code 153 | void check_correct(SeedsCodewordsFlat *cw, uint128_t_gpu *target, 154 | int batch_size, int num_entries, 155 | int permutated_ordering) { 156 | int zz = 0; 157 | for (int i = 0; i < batch_size; i++) { 158 | for (int j = 0; j < num_entries; j++) { 159 | 160 | uint128_t truth = EvaluateFlat(&cw[i], j); 161 | uint128_t_gpu truth_128_t_gpu = uint128_gpu_from(truth); 162 | 163 | // This is the "standard" ordering 164 | uint128_t_gpu got; 165 | if (!permutated_ordering) { 166 | got = target[j*batch_size+i]; 167 | } 168 | else { 169 | // This is the "permutated" ordering 170 | //int tgt_indx = brev_cpu(j) >> 32 - cw[0].depth; 171 | int tgt_indx = brev_cpu(j) >> (32 - (int)log2(num_entries)); 172 | got = target[tgt_indx + i*num_entries]; 173 | } 174 | 175 | // For debugging 176 | //printf("Got : %d %d %d %d\n", got.x, got.y, got.z, got.w); 177 | //printf("Expect: %d %d %d %d\n", truth_128_t_gpu.x, truth_128_t_gpu.y, truth_128_t_gpu.z, truth_128_t_gpu.w); 178 | //zz += 1; 179 | //if (zz >= 100) return; 180 | 181 | assert(got.x == truth_128_t_gpu.x && 182 | got.y == truth_128_t_gpu.y && 183 | got.z == truth_128_t_gpu.z && 184 | got.w == truth_128_t_gpu.w); 185 | } 186 | } 187 | printf("PASS\n"); 188 | } 189 | 190 | void check_correct_fused(SeedsCodewordsFlat *cw, uint128_t_gpu *target, uint128_t_gpu *table, 191 | int entry_size, int batch_size, int num_entries) { 192 | for (int i = 0; i < batch_size; i++) { 193 | for (int k = 0; k < entry_size; k++) { 194 | uint128_t accum = 0; 195 | for (int j = 0; j < num_entries; j++) { 196 | uint128_t truth = EvaluateFlat(&cw[i], j); 197 | accum += truth * uint128_from_gpu(table[j+k*num_entries]); 198 | } 199 | 200 | uint128_t_gpu cmp = uint128_gpu_from(accum); 201 | uint128_t_gpu got = target[i+k*batch_size]; 202 | 203 | assert(got.x == cmp.x && 204 | got.y == cmp.y && 205 | got.z == cmp.z && 206 | got.w == cmp.w); 207 | } 208 | } 209 | printf("PASS MATMUL CHECK\n"); 210 | } 211 | 212 | #endif 213 | --------------------------------------------------------------------------------
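check_correct() above accepts two output layouts: the naive kernel writes entry j of query i at target[j*batch_size + i], while the breadth-first, hybrid, and coop strategies emit entries in bit-reversed order, so the expected slot is brev_cpu(j) >> (32 - log2(num_entries)) within each query's contiguous block. The stand-alone illustration below prints that index map for a small table; it is a sketch (brev32 duplicates brev_cpu, the main() is purely illustrative) and assumes num_entries is a power of two.

// Bit-reversed index mapping used by the "permutated" output ordering (sketch).
// For num_entries = 8, logical entries 0..7 land at slots 0,4,2,6,1,5,3,7
// within each query's block of outputs.
#include <cstdint>
#include <cstdio>

static uint32_t brev32(uint32_t x) {          // same steps as brev_cpu() above
    x = ((x >> 1) & 0x55555555u) | ((x & 0x55555555u) << 1);
    x = ((x >> 2) & 0x33333333u) | ((x & 0x33333333u) << 2);
    x = ((x >> 4) & 0x0f0f0f0fu) | ((x & 0x0f0f0f0fu) << 4);
    x = ((x >> 8) & 0x00ff00ffu) | ((x & 0x00ff00ffu) << 8);
    return (x >> 16) | (x << 16);
}

int main() {
    const uint32_t num_entries = 8;           // must be a power of two
    const int log_n = 3;                      // log2(num_entries)
    for (uint32_t j = 0; j < num_entries; j++)
        printf("entry %u -> slot %u\n", j, brev32(j) >> (32 - log_n));
    return 0;
}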
/paper/kernel/gpu/scripts/scrape.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | import glob 4 | import sys 5 | 6 | def extract(fname): 7 | with open(fname) as f: 8 | lines = f.readlines() 9 | lines = [x for x in lines if x.strip() != ""] 10 | try: 11 | lines[-1] = lines[-1].replace("inf", "-1") 12 | z = eval(lines[-1]) 13 | if type(z) == dict: 14 | return z 15 | return None 16 | except: 17 | return None 18 | 19 | fs = glob.glob(sys.argv[1]+"/*") 20 | ds = [extract(x) for x in fs] 21 | ds = [x for x in ds if x is not None] 22 | 23 | ks = [str(k) for k,v in sorted(ds[0].items(), key=lambda x:x[0])] 24 | print(",".join(ks)) 25 | for d in ds: 26 | if d is None: 27 | continue 28 | kvs = [] 29 | for k,v in sorted(d.items(), key=lambda x: x[0]): 30 | kvs.append(str(v)) 31 | print(",".join(kvs)) 32 | -------------------------------------------------------------------------------- /paper/kernel/gpu/scripts/sweep.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | NUM_ENTRIES=( 8192 16384 32768 65536 131072 262144 524288 1048576 2097152 4194304 8388608 16777216 33554432 67108864 ) 4 | BATCH_SIZE=( 8 16 32 64 128 256 512 1024 2048 4096 ) 5 | 6 | mkdir -p sweep/sweep_entry_size=1 7 | 8 | for num_entries in "${NUM_ENTRIES[@]}"; do 9 | for batch_size in "${BATCH_SIZE[@]}"; do 10 | echo $num_entries $batch_size 11 | make ENTRY_SIZE=1 NUM_ENTRIES=$num_entries FUSES_MATMUL=1 DPF_STRATEGY="DPF_HYBRID" BATCH_SIZE=$batch_size benchmark 12 | ./dpf_benchmark > sweep/sweep_entry_size=1/entries=${num_entries},batch_size=${batch_size} 2>&1 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | # sample.py 4 | # ------------------------------------ 5 | # Example usage of DPF interface. 6 | # 7 | # Problem setting: 8 | # - Client wishes to retrieve an entry from a table held on two non-colluding servers 9 | # - Client does not wish to leak any information about the index they are retrieving 10 | # 11 | # Solution: 12 | # - Client constructs a DPF representing their secret index 13 | # - Client generates two keys k1, k2 from the DPF 14 | # - Client sends k1, k2 to non-colluding servers 1 and 2 respectively 15 | # - Servers 1 and 2 evaluate k1 and k2 returning the result 16 | # - Client adds the shares together to obtain the table entry 17 | 18 | import sys 19 | import dpf 20 | import torch 21 | 22 | # Table parameters 23 | table_size = 16384 24 | entry_size = 1 25 | 26 | # The actual table (replicated on 2 non-colluding servers) 27 | table = torch.randint(2**31, (table_size, entry_size)).int() 28 | table[42,:] = 42 29 | 30 | def server(k): 31 | 32 | # Server initializes DPF w/ table 33 | dpf_ = dpf.DPF() 34 | dpf_.eval_init(table) 35 | 36 | # Server evaluates DPF on table 37 | return dpf_.eval_gpu([k]) 38 | 39 | def client(): 40 | secret_indx = 42 41 | 42 | # Generate two keys that represents the secret indx 43 | dpf_ = dpf.DPF() 44 | k1, k2 = dpf_.gen(secret_indx, table_size) 45 | 46 | # Send one key to each server to evaluate. 47 | # 48 | # Assuming that these two servers do not collude, 49 | # the servers learn _nothing_ about secret_indx. 
50 | a = server(k1).item() 51 | b = server(k2).item() 52 | 53 | rec = a-b 54 | 55 | print(a, b, rec) 56 | assert(rec == 42) 57 | 58 | if __name__=="__main__": 59 | client() 60 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | from setuptools import setup, Extension 4 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 5 | from torch.utils import cpp_extension 6 | 7 | setup(name='dpf_cpp', 8 | ext_modules=[ 9 | CUDAExtension('dpf_cpp', sources=[ 10 | 'dpf_wrapper.cu', 11 | ], extra_compile_args=['-std=c++17'], 12 | ) 13 | ], 14 | cmdclass={'build_ext': cpp_extension.BuildExtension}, 15 | ) 16 | 17 | 18 | --------------------------------------------------------------------------------