├── .github
│   └── stip_overview.png
├── LICENSE
├── README.md
├── model.py
├── preprint_Secure_Transformer_Inference.pdf
├── stip_arxiv_update.pdf
├── stip_llama.py
└── stip_original.py
/.github/stip_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanmu97/secure-transformer-inference/d8e8b876280979efd996ace5d9ed4b3a50bb271f/.github/stip_overview.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Mu Yuan (袁牧) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Secure Transformer Inference 2 | 3 | Secure Transformer Inference Protocol (STIP) is a three-party protocol that protects both the Transformer parameters and the user data during inference. 4 | For each feedforward inference pass, STIP adds only a permutation of the input and output data on the user side. 5 | STIP can be applied to real-world services such as ChatGPT. 6 | 7 | The figure below gives an overview of STIP. 8 | 9 | overview 10 | 11 | We consider three parties: 12 | 13 | * Party-1 ($P_1$): Model developer (e.g., OpenAI) that owns the original Transformer model $f_\theta$. 14 | * Party-2 ($P_2$): Cloud computing platform (e.g., Azure) that owns the computing hardware. 15 | * Party-3 ($P_3$): Users who own the private input (e.g., prompt token embeddings) and output (e.g., response token logits). 16 | 17 | Initialization phase: 18 | * $P_1$ randomly generates a permutation matrix $\pi \in \mathbb{R}^{d\times d}$ 19 | * $P_1$ transforms $f_\theta$ into $f_{\theta'}$ using $\pi$ 20 | * $P_1$ sends $f_{\theta'}$ to $P_2$ and $\pi$ to $P_3$ 21 | 22 | Inference phase: 23 | * $P_3$ transforms $x$ into $x'=x\pi$ and sends $x'$ to $P_2$ 24 | * $P_2$ computes $y'=f_{\theta'}(x')=y\pi$, where $y=f_\theta(x)$, and sends $y'$ to $P_3$ 25 | * $P_3$ recovers the output by computing $y'\pi^T=y\pi\pi^T=y$, since $\pi\pi^T=I$ for a permutation matrix 26 | 27 | For the detailed transformation of the model parameters, please refer to [our paper](./preprint_Secure_Transformer_Inference.pdf). 28 | 29 | ## Test Code 30 | 31 | We tested the original Transformer ([Vaswani, Ashish, et al. 2017](https://arxiv.org/abs/1706.03762)) and the Llama Transformer ([Touvron, Hugo, et al. 2023](https://arxiv.org/abs/2302.13971)) using PyTorch.
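As a minimal sketch of the initialization and inference phases (not the paper's full parameter transformation), the snippet below uses an explicit permutation matrix $\pi$ and a toy two-layer ReLU feedforward network as a stand-in for $f_\theta$; the function `f`, the sizes, and the random values are illustrative assumptions. The rules $W_1' = W_1\pi$ for the input-side weight and $W_2' = \pi^T W_2$, $b_2' = b_2\pi$ for the output side are the matrix form of the index permutations (`para.data[:, p]` and `para.data[p]`) applied in `stip_original.py`.

```python
import torch

def f(x, W1, b1, W2, b2):
    # Toy stand-in for f_theta: two-layer ReLU feedforward network on the last dimension
    return torch.relu(x @ W1.T + b1) @ W2.T + b2

torch.manual_seed(0)
d, d_ff, n = 4, 8, 3  # model width, hidden width, number of tokens (arbitrary)
W1, b1 = torch.randn(d_ff, d), torch.randn(d_ff)
W2, b2 = torch.randn(d, d_ff), torch.randn(d)

# Initialization (P1): random permutation matrix pi, so pi @ pi.T == I
pi = torch.eye(d)[:, torch.randperm(d)]
# Transformed parameters theta' (sent to P2); pi itself is sent to P3
W1_p, b1_p = W1 @ pi, b1
W2_p, b2_p = pi.T @ W2, b2 @ pi

# Inference: P3 permutes x, P2 evaluates f_theta', P3 undoes the permutation
x = torch.randn(n, d)
x_p = x @ pi
y_p = f(x_p, W1_p, b1_p, W2_p, b2_p)
y = y_p @ pi.T

# Compare with the original model, allowing for floating-point rounding
y_ref = f(x, W1, b1, W2, b2)
print(torch.allclose(y, y_ref, atol=1e-5))
```

The hidden bias $b_1$ and the ReLU act on the unpermuted hidden layer, so they are left unchanged; other layers (attention, normalization) need their own transformations, for which see the paper and the test scripts.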
32 | 33 | The test logic is simple: transform the model and the input, run inference, and compare the result with the original model's output after applying the same permutation. Because of floating-point rounding error, the tests report the sum of absolute differences instead of checking for exact equality. 34 | 35 | ## Citation 36 | 37 | If you find STIP helpful, please consider citing: 38 | ``` 39 | @misc{cryptoeprint:2023/1763, 40 | author = {Mu Yuan and Lan Zhang and Xiang-Yang Li}, 41 | title = {Secure Transformer Inference}, 42 | howpublished = {Cryptology ePrint Archive, Paper 2023/1763}, 43 | year = {2023}, 44 | note = {\url{https://eprint.iacr.org/2023/1763}}, 45 | url = {https://eprint.iacr.org/2023/1763} 46 | } 47 | ``` 48 | 49 | ## License 50 | 51 | Secure Transformer Inference Protocol (STIP) is licensed under the [MIT License](./LICENSE). 52 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3 | 4 | import math 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple 7 | 8 | import fairscale.nn.model_parallel.initialize as fs_init 9 | import torch 10 | import torch.nn.functional as F 11 | from fairscale.nn.model_parallel.layers import ( 12 | ColumnParallelLinear, 13 | ParallelEmbedding, 14 | RowParallelLinear, 15 | ) 16 | from torch import nn 17 | 18 | 19 | @dataclass 20 | class ModelArgs: 21 | dim: int = 4096 22 | n_layers: int = 32 23 | n_heads: int = 32 24 | n_kv_heads: Optional[int] = None 25 | vocab_size: int = -1 # defined later by tokenizer 26 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 27 | ffn_dim_multiplier: Optional[float] = None 28 | norm_eps: float = 1e-5 29 | 30 | max_batch_size: int = 32 31 | max_seq_len: int = 2048 32 | 33 | 34 | class RMSNorm(torch.nn.Module): 35 | def __init__(self, dim: int, eps: float = 1e-6): 36 | """ 37 | Initialize the RMSNorm normalization layer. 38 | 39 | Args: 40 | dim (int): The dimension of the input tensor. 41 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 42 | 43 | Attributes: 44 | eps (float): A small value added to the denominator for numerical stability. 45 | weight (nn.Parameter): Learnable scaling parameter. 46 | 47 | """ 48 | super().__init__() 49 | self.eps = eps 50 | self.weight = nn.Parameter(torch.ones(dim)) 51 | 52 | def _norm(self, x): 53 | """ 54 | Apply the RMSNorm normalization to the input tensor. 55 | 56 | Args: 57 | x (torch.Tensor): The input tensor. 58 | 59 | Returns: 60 | torch.Tensor: The normalized tensor. 61 | 62 | """ 63 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 64 | 65 | def forward(self, x): 66 | """ 67 | Forward pass through the RMSNorm layer. 68 | 69 | Args: 70 | x (torch.Tensor): The input tensor. 71 | 72 | Returns: 73 | torch.Tensor: The output tensor after applying RMSNorm. 
74 | 75 | """ 76 | output = self._norm(x.float()).type_as(x) 77 | return output * self.weight 78 | 79 | 80 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 81 | """ 82 | Precompute the frequency tensor for complex exponentials (cis) with given dimensions. 83 | 84 | This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' 85 | and the end index 'end'. The 'theta' parameter scales the frequencies. 86 | The returned tensor contains complex values in complex64 data type. 87 | 88 | Args: 89 | dim (int): Dimension of the frequency tensor. 90 | end (int): End index for precomputing frequencies. 91 | theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0. 92 | 93 | Returns: 94 | torch.Tensor: Precomputed frequency tensor with complex exponentials. 95 | 96 | 97 | 98 | 99 | """ 100 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 101 | t = torch.arange(end, device=freqs.device) # type: ignore 102 | freqs = torch.outer(t, freqs).float() # type: ignore 103 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 104 | return freqs_cis 105 | 106 | 107 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 108 | """ 109 | Reshape frequency tensor for broadcasting it with another tensor. 110 | 111 | This function reshapes the frequency tensor to have the same shape as the target tensor 'x' 112 | for the purpose of broadcasting the frequency tensor during element-wise operations. 113 | 114 | Args: 115 | freqs_cis (torch.Tensor): Frequency tensor to be reshaped. 116 | x (torch.Tensor): Target tensor for broadcasting compatibility. 117 | 118 | Returns: 119 | torch.Tensor: Reshaped frequency tensor. 120 | 121 | Raises: 122 | AssertionError: If the frequency tensor doesn't match the expected shape. 123 | AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. 
124 | """ 125 | ndim = x.ndim 126 | assert 0 <= 1 < ndim 127 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 128 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 129 | return freqs_cis.view(*shape) 130 | 131 | 132 | def apply_rotary_emb( 133 | xq: torch.Tensor, 134 | xk: torch.Tensor, 135 | freqs_cis: torch.Tensor, 136 | ) -> Tuple[torch.Tensor, torch.Tensor]: 137 | """ 138 | Apply rotary embeddings to input tensors using the given frequency tensor. 139 | 140 | This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided 141 | frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor 142 | is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are 143 | returned as real tensors. 144 | 145 | Args: 146 | xq (torch.Tensor): Query tensor to apply rotary embeddings. 147 | xk (torch.Tensor): Key tensor to apply rotary embeddings. 148 | freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. 149 | 150 | Returns: 151 | Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
152 | 153 | 154 | 155 | """ 156 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 157 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 158 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 159 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 160 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 161 | return xq_out.type_as(xq), xk_out.type_as(xk) 162 | 163 | 164 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 165 | """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" 166 | bs, slen, n_kv_heads, head_dim = x.shape 167 | if n_rep == 1: 168 | return x 169 | return ( 170 | x[:, :, :, None, :] 171 | .expand(bs, slen, n_kv_heads, n_rep, head_dim) 172 | .reshape(bs, slen, n_kv_heads * n_rep, head_dim) 173 | ) 174 | 175 | 176 | class Attention(nn.Module): 177 | """Multi-head attention module.""" 178 | def __init__(self, args: ModelArgs): 179 | """ 180 | Initialize the Attention module. 181 | 182 | Args: 183 | args (ModelArgs): Model configuration parameters. 184 | 185 | Attributes: 186 | n_kv_heads (int): Number of key and value heads. 187 | n_local_heads (int): Number of local query heads. 188 | n_local_kv_heads (int): Number of local key and value heads. 189 | n_rep (int): Number of repetitions for local heads. 190 | head_dim (int): Dimension size of each attention head. 191 | wq (ColumnParallelLinear): Linear transformation for queries. 192 | wk (ColumnParallelLinear): Linear transformation for keys. 193 | wv (ColumnParallelLinear): Linear transformation for values. 194 | wo (RowParallelLinear): Linear transformation for output. 195 | cache_k (torch.Tensor): Cached keys for attention. 196 | cache_v (torch.Tensor): Cached values for attention. 
197 | 198 | """ 199 | super().__init__() 200 | self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads 201 | model_parallel_size = fs_init.get_model_parallel_world_size() 202 | self.n_local_heads = args.n_heads // model_parallel_size 203 | self.n_local_kv_heads = self.n_kv_heads // model_parallel_size 204 | self.n_rep = self.n_local_heads // self.n_local_kv_heads 205 | self.head_dim = args.dim // args.n_heads 206 | 207 | self.wq = ColumnParallelLinear( 208 | args.dim, 209 | args.n_heads * self.head_dim, 210 | bias=False, 211 | gather_output=False, 212 | init_method=lambda x: x, 213 | ) 214 | self.wk = ColumnParallelLinear( 215 | args.dim, 216 | self.n_kv_heads * self.head_dim, 217 | bias=False, 218 | gather_output=False, 219 | init_method=lambda x: x, 220 | ) 221 | self.wv = ColumnParallelLinear( 222 | args.dim, 223 | self.n_kv_heads * self.head_dim, 224 | bias=False, 225 | gather_output=False, 226 | init_method=lambda x: x, 227 | ) 228 | self.wo = RowParallelLinear( 229 | args.n_heads * self.head_dim, 230 | args.dim, 231 | bias=False, 232 | input_is_parallel=True, 233 | init_method=lambda x: x, 234 | ) 235 | 236 | self.cache_k = torch.zeros( 237 | ( 238 | args.max_batch_size, 239 | args.max_seq_len, 240 | self.n_local_kv_heads, 241 | self.head_dim, 242 | ) 243 | ).cuda() 244 | self.cache_v = torch.zeros( 245 | ( 246 | args.max_batch_size, 247 | args.max_seq_len, 248 | self.n_local_kv_heads, 249 | self.head_dim, 250 | ) 251 | ).cuda() 252 | 253 | def forward( 254 | self, 255 | x: torch.Tensor, 256 | start_pos: int, 257 | freqs_cis: torch.Tensor, 258 | mask: Optional[torch.Tensor], 259 | ): 260 | """ 261 | Forward pass of the attention module. 262 | 263 | Args: 264 | x (torch.Tensor): Input tensor. 265 | start_pos (int): Starting position for caching. 266 | freqs_cis (torch.Tensor): Precomputed frequency tensor. 267 | mask (torch.Tensor, optional): Attention mask tensor. 268 | 269 | Returns: 270 | torch.Tensor: Output tensor after attention. 
271 | 272 | """ 273 | bsz, seqlen, _ = x.shape 274 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 275 | 276 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 277 | xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 278 | xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) 279 | 280 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 281 | 282 | self.cache_k = self.cache_k.to(xq) 283 | self.cache_v = self.cache_v.to(xq) 284 | 285 | self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk 286 | self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv 287 | 288 | keys = self.cache_k[:bsz, : start_pos + seqlen] 289 | values = self.cache_v[:bsz, : start_pos + seqlen] 290 | 291 | # repeat k/v heads if n_kv_heads < n_heads 292 | keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) 293 | values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) 294 | 295 | xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) 296 | keys = keys.transpose(1, 2) 297 | values = values.transpose(1, 2) 298 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) 299 | if mask is not None: 300 | scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) 301 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 302 | output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) 303 | output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) 304 | return self.wo(output) 305 | 306 | 307 | class FeedForward(nn.Module): 308 | def __init__( 309 | self, 310 | dim: int, 311 | hidden_dim: int, 312 | multiple_of: int, 313 | ffn_dim_multiplier: Optional[float], 314 | ): 315 | """ 316 | Initialize the FeedForward module. 317 | 318 | Args: 319 | dim (int): Input dimension. 320 | hidden_dim (int): Hidden dimension of the feedforward layer. 321 | multiple_of (int): Value to ensure hidden dimension is a multiple of this value. 
322 | ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. 323 | 324 | Attributes: 325 | w1 (ColumnParallelLinear): Linear transformation for the first layer. 326 | w2 (RowParallelLinear): Linear transformation for the second layer. 327 | w3 (ColumnParallelLinear): Linear transformation for the third layer. 328 | 329 | """ 330 | super().__init__() 331 | hidden_dim = int(2 * hidden_dim / 3) 332 | # custom dim factor multiplier 333 | if ffn_dim_multiplier is not None: 334 | hidden_dim = int(ffn_dim_multiplier * hidden_dim) 335 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 336 | 337 | self.w1 = ColumnParallelLinear( 338 | dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x 339 | ) 340 | self.w2 = RowParallelLinear( 341 | hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x 342 | ) 343 | self.w3 = ColumnParallelLinear( 344 | dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x 345 | ) 346 | 347 | def forward(self, x): 348 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 349 | 350 | 351 | class TransformerBlock(nn.Module): 352 | def __init__(self, layer_id: int, args: ModelArgs): 353 | """ 354 | Initialize a TransformerBlock. 355 | 356 | Args: 357 | layer_id (int): Identifier for the layer. 358 | args (ModelArgs): Model configuration parameters. 359 | 360 | Attributes: 361 | n_heads (int): Number of attention heads. 362 | dim (int): Dimension size of the model. 363 | head_dim (int): Dimension size of each attention head. 364 | attention (Attention): Attention module. 365 | feed_forward (FeedForward): FeedForward module. 366 | layer_id (int): Identifier for the layer. 367 | attention_norm (RMSNorm): Layer normalization for attention output. 368 | ffn_norm (RMSNorm): Layer normalization for feedforward output. 
369 | 370 | """ 371 | super().__init__() 372 | self.n_heads = args.n_heads 373 | self.dim = args.dim 374 | self.head_dim = args.dim // args.n_heads 375 | self.attention = Attention(args) 376 | self.feed_forward = FeedForward( 377 | dim=args.dim, 378 | hidden_dim=4 * args.dim, 379 | multiple_of=args.multiple_of, 380 | ffn_dim_multiplier=args.ffn_dim_multiplier, 381 | ) 382 | self.layer_id = layer_id 383 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 384 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 385 | 386 | def forward( 387 | self, 388 | x: torch.Tensor, 389 | start_pos: int, 390 | freqs_cis: torch.Tensor, 391 | mask: Optional[torch.Tensor], 392 | ): 393 | """ 394 | Perform a forward pass through the TransformerBlock. 395 | 396 | Args: 397 | x (torch.Tensor): Input tensor. 398 | start_pos (int): Starting position for attention caching. 399 | freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. 400 | mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None. 401 | 402 | Returns: 403 | torch.Tensor: Output tensor after applying attention and feedforward layers. 404 | 405 | """ 406 | h = x + self.attention.forward( 407 | self.attention_norm(x), start_pos, freqs_cis, mask 408 | ) 409 | out = h + self.feed_forward.forward(self.ffn_norm(h)) 410 | return out 411 | 412 | 413 | class Transformer(nn.Module): 414 | def __init__(self, params: ModelArgs): 415 | """ 416 | Initialize a Transformer model. 417 | 418 | Args: 419 | params (ModelArgs): Model configuration parameters. 420 | 421 | Attributes: 422 | params (ModelArgs): Model configuration parameters. 423 | vocab_size (int): Vocabulary size. 424 | n_layers (int): Number of layers in the model. 425 | tok_embeddings (ParallelEmbedding): Token embeddings. 426 | layers (torch.nn.ModuleList): List of Transformer blocks. 427 | norm (RMSNorm): Layer normalization for the model output. 428 | output (ColumnParallelLinear): Linear layer for final output. 
429 | freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. 430 | 431 | """ 432 | super().__init__() 433 | self.params = params 434 | self.vocab_size = params.vocab_size 435 | self.n_layers = params.n_layers 436 | 437 | self.tok_embeddings = ParallelEmbedding( 438 | params.vocab_size, params.dim, init_method=lambda x: x 439 | ) 440 | 441 | self.layers = torch.nn.ModuleList() 442 | for layer_id in range(params.n_layers): 443 | self.layers.append(TransformerBlock(layer_id, params)) 444 | 445 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 446 | self.output = ColumnParallelLinear( 447 | params.dim, params.vocab_size, bias=False, init_method=lambda x: x 448 | ) 449 | 450 | self.freqs_cis = precompute_freqs_cis( 451 | # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096. 452 | # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning. 453 | self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 454 | ) 455 | 456 | @torch.inference_mode() 457 | def forward(self, tokens: torch.Tensor, start_pos: int): 458 | """ 459 | Perform a forward pass through the Transformer model. 460 | 461 | Args: 462 | tokens (torch.Tensor): Input token indices. 463 | start_pos (int): Starting position for attention caching. 464 | 465 | Returns: 466 | torch.Tensor: Output logits after applying the Transformer model. 
467 | 468 | """ 469 | _bsz, seqlen = tokens.shape 470 | h = self.tok_embeddings(tokens) 471 | self.freqs_cis = self.freqs_cis.to(h.device) 472 | freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] 473 | 474 | mask = None 475 | if seqlen > 1: 476 | mask = torch.full( 477 | (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device 478 | ) 479 | mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) 480 | 481 | for layer in self.layers: 482 | h = layer(h, start_pos, freqs_cis, mask) 483 | h = self.norm(h) 484 | output = self.output(h).float() 485 | return output 486 | -------------------------------------------------------------------------------- /preprint_Secure_Transformer_Inference.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanmu97/secure-transformer-inference/d8e8b876280979efd996ace5d9ed4b3a50bb271f/preprint_Secure_Transformer_Inference.pdf -------------------------------------------------------------------------------- /stip_arxiv_update.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanmu97/secure-transformer-inference/d8e8b876280979efd996ace5d9ed4b3a50bb271f/stip_arxiv_update.pdf -------------------------------------------------------------------------------- /stip_llama.py: -------------------------------------------------------------------------------- 1 | """ 2 | reference: https://github.com/facebookresearch/codellama 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from stip_original import MultiHeadAttention 10 | import copy 11 | 12 | class FeedForward(nn.Module): 13 | def __init__( 14 | self, 15 | dim: int, 16 | hidden_dim: int, 17 | ): 18 | super(FeedForward, self).__init__() 19 | self.w1 = nn.Linear(dim, hidden_dim, bias=False) 20 | self.w2 = nn.Linear(hidden_dim, dim, bias=False) 21 | self.w3 = nn.Linear(dim, hidden_dim, 
bias=False) 22 | 23 | def forward(self, x): 24 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 25 | 26 | 27 | class RMSNorm(torch.nn.Module): 28 | def __init__(self, dim: int, eps: float = 1e-6): 29 | """ 30 | Initialize the RMSNorm normalization layer. 31 | 32 | Args: 33 | dim (int): The dimension of the input tensor. 34 | eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. 35 | 36 | Attributes: 37 | eps (float): A small value added to the denominator for numerical stability. 38 | weight (nn.Parameter): Learnable scaling parameter. 39 | 40 | """ 41 | super().__init__() 42 | self.eps = eps 43 | self.weight = nn.Parameter(torch.ones(dim)) 44 | 45 | def _norm(self, x): 46 | """ 47 | Apply the RMSNorm normalization to the input tensor. 48 | 49 | Args: 50 | x (torch.Tensor): The input tensor. 51 | 52 | Returns: 53 | torch.Tensor: The normalized tensor. 54 | 55 | """ 56 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 57 | 58 | def forward(self, x): 59 | """ 60 | Forward pass through the RMSNorm layer. 61 | 62 | Args: 63 | x (torch.Tensor): The input tensor. 64 | 65 | Returns: 66 | torch.Tensor: The output tensor after applying RMSNorm. 
67 | 68 | """ 69 | output = self._norm(x.float()).type_as(x) 70 | return output * self.weight 71 | 72 | 73 | class LlamaTransformerBlock(nn.Module): 74 | """ 75 | """ 76 | def __init__(self, d_model, num_heads, d_ff): 77 | super(LlamaTransformerBlock, self).__init__() 78 | self.attention_layer = MultiHeadAttention(d_model, num_heads) 79 | self.ff_layer = FeedForward(d_model, d_ff) 80 | self.rmsn1 = RMSNorm(d_model) 81 | self.rmsn2 = RMSNorm(d_model) 82 | 83 | def forward(self, x, mask): 84 | attention_input = self.rmsn1(x) 85 | attention_output = self.attention_layer(attention_input, attention_input, attention_input, mask) 86 | h = x + attention_output 87 | 88 | ff_input = self.rmsn2(h) 89 | ff_output = self.ff_layer(ff_input) 90 | res = h + ff_output 91 | return res 92 | 93 | 94 | def permute_block(blk, p): 95 | p_blk = copy.deepcopy(blk) 96 | with torch.no_grad(): 97 | for name, para in p_blk.named_parameters(): 98 | # input-side weights: permute columns; output-side weights/biases and RMSNorm scales: permute entries along d_model 99 | if name in ["attention_layer.w_q.weight", 100 | "attention_layer.w_k.weight", 101 | "attention_layer.w_v.weight", 102 | "ff_layer.w1.weight", 103 | "ff_layer.w3.weight"]: 104 | para.data = para.data[:, p] 105 | if name in ["attention_layer.w_o.weight", 106 | "attention_layer.w_o.bias", 107 | "ff_layer.w2.weight", "rmsn1.weight", "rmsn2.weight"]: 108 | para.data = para.data[p] 109 | return p_blk 110 | 111 | 112 | if __name__ == "__main__": 113 | BS = 2 114 | SEQLEN = 3 115 | DMODEL = 4 116 | DFF = 8 117 | NHEADS = 2 118 | 119 | TEST_BLK = 1 120 | 121 | if TEST_BLK: 122 | x = torch.from_numpy(np.random.rand(BS, SEQLEN, DMODEL)).float() 123 | block = LlamaTransformerBlock(DMODEL, NHEADS, DFF) 124 | p = np.random.permutation(DMODEL) 125 | 126 | print(f"Permutation={p}") 127 | xp = x[:, :, p] 128 | p_block = permute_block(block, p) 129 | mask = (1 - torch.triu(torch.ones(1, SEQLEN, SEQLEN), diagonal=1)).bool() 130 | 131 | with torch.no_grad(): 132 | print("Encoder Block:") 133 | y = block(x, None) 134 | yp = p_block(xp, None) 135 | diff = np.abs(y[:, :, p] - yp).sum()
136 | print("Original result:\n", y, "\nNew result:\n", yp) 137 | print("Diff=", diff) 138 | 139 | print("Decoder Block:") 140 | y = block(x, mask) 141 | yp = p_block(xp, mask) 142 | diff = np.abs(y[:, :, p] - yp).sum() 143 | print("Original result:\n", y, "\nNew result:\n", yp) 144 | print("Diff=", diff) 145 | -------------------------------------------------------------------------------- /stip_original.py: -------------------------------------------------------------------------------- 1 | """ 2 | reference: https://towardsdatascience.com/build-your-own-transformer-from-scratch-using-pytorch-84c850470dcb 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | import numpy as np 8 | import copy 9 | import torch.nn.functional as F 10 | 11 | 12 | class MultiHeadAttention(nn.Module): 13 | """ 14 | """ 15 | def __init__(self, d_model, num_heads): 16 | super(MultiHeadAttention, self).__init__() 17 | assert d_model % num_heads == 0 18 | 19 | self.d_model = d_model 20 | self.num_heads = num_heads 21 | self.d_k = d_model // num_heads 22 | 23 | self.w_q = nn.Linear(d_model, d_model, bias=True) 24 | self.w_k = nn.Linear(d_model, d_model, bias=True) 25 | self.w_v = nn.Linear(d_model, d_model, bias=True) 26 | self.w_o = nn.Linear(d_model, d_model, bias=True) 27 | 28 | 29 | def split_heads(self, x): 30 | batch_size, seq_length, d_model = x.size() 31 | # d_model = num_heads * d_k 32 | return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2) 33 | 34 | 35 | def scaled_dot_product_attention(self, q, k, v, mask): 36 | attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) 37 | # shape [BS, NUM_HEADS, SEQ_LEN, SEQ_LEN] 38 | 39 | if mask is not None: 40 | attention_scores = attention_scores.masked_fill(mask == 0, -1e9) 41 | 42 | attention_probs = torch.softmax(attention_scores, dim=-1) 43 | 44 | return torch.matmul(attention_probs, v) 45 | 46 | def combine_heads(self, x): 47 | batch_size, num_heads, seq_length, d_k =
x.size() 48 | return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model) 49 | 50 | 51 | def forward(self, q, k, v, mask): 52 | q = self.split_heads(self.w_q(q)) 53 | k = self.split_heads(self.w_k(k)) 54 | v = self.split_heads(self.w_v(v)) 55 | 56 | attention_output = self.scaled_dot_product_attention(q, k, v, mask) 57 | attention_output = self.combine_heads(attention_output) 58 | 59 | return self.w_o(attention_output) 60 | 61 | 62 | class FeedForward(nn.Module): 63 | """ 64 | """ 65 | def __init__(self, d_model, d_ff): 66 | super(FeedForward, self).__init__() 67 | self.fc1 = nn.Linear(d_model, d_ff, bias=True) 68 | self.fc2 = nn.Linear(d_ff, d_model, bias=True) 69 | self.relu = nn.ReLU() 70 | 71 | def forward(self, x): 72 | return self.fc2(self.relu(self.fc1(x))) 73 | 74 | 75 | class EncoderBlock(nn.Module): 76 | """ 77 | """ 78 | def __init__(self, d_model, num_heads, d_ff): 79 | super(EncoderBlock, self).__init__() 80 | self.attention_layer = MultiHeadAttention(d_model, num_heads) 81 | self.ff_layer = FeedForward(d_model, d_ff) 82 | self.ln1 = nn.LayerNorm(d_model) 83 | self.ln2 = nn.LayerNorm(d_model) 84 | 85 | def forward(self, x, mask=None): 86 | attention_output = self.attention_layer(x, x, x, mask) 87 | x0 = x + attention_output 88 | x1 = self.ln1(x0) 89 | 90 | ff_output = self.ff_layer(x1) 91 | x2 = self.ln2(x1 + ff_output) 92 | return x2 93 | 94 | 95 | class DecoderBlock(nn.Module): 96 | """ 97 | """ 98 | def __init__(self, d_model, num_heads, d_ff): 99 | super(DecoderBlock, self).__init__() 100 | self.self_attention = MultiHeadAttention(d_model, num_heads) 101 | self.cross_attention = MultiHeadAttention(d_model, num_heads) 102 | self.ff_layer = FeedForward(d_model, d_ff) 103 | self.ln1 = nn.LayerNorm(d_model) 104 | self.ln2 = nn.LayerNorm(d_model) 105 | self.ln3 = nn.LayerNorm(d_model) 106 | 107 | def forward(self, x, enc_output, src_mask, tgt_mask): 108 | self_attention_output = self.self_attention(x, x, x, tgt_mask) 109 | x = 
 self.ln1(x + self_attention_output) 110 | cross_attention_output = self.cross_attention(x, enc_output, enc_output, src_mask) 111 | x = self.ln2(x + cross_attention_output) 112 | ff_output = self.ff_layer(x) 113 | x = self.ln3(x + ff_output) 114 | return x 115 | 116 | 117 | class Transformer(nn.Module): 118 | """ 119 | """ 120 | def __init__(self, d_model, num_heads, d_ff, num_layers, output_dim): 121 | super(Transformer, self).__init__() 122 | self.enc_layers = nn.ModuleList([EncoderBlock(d_model, num_heads, d_ff) for _ in range(num_layers)]) 123 | self.dec_layers = nn.ModuleList([DecoderBlock(d_model, num_heads, d_ff) for _ in range(num_layers)]) 124 | self.fc = nn.Linear(d_model, output_dim) 125 | 126 | def generate_mask(self, src, tgt): 127 | src_mask = (src != 0).unsqueeze(1).unsqueeze(2) 128 | tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3) 129 | seq_length = tgt.size(1) 130 | nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool() 131 | tgt_mask = tgt_mask & nopeak_mask 132 | return src_mask, tgt_mask 133 | 134 | def forward(self, x, target): 135 | src_mask, tgt_mask = self.generate_mask(x, target) 136 | 137 | enc_output = x 138 | for enc_l in self.enc_layers: 139 | enc_output = enc_l(enc_output, src_mask) 140 | 141 | dec_output = target 142 | for dec_l in self.dec_layers: 143 | dec_output = dec_l(dec_output, enc_output, src_mask, tgt_mask) 144 | 145 | output = self.fc(dec_output) 146 | return output 147 | 148 | 149 | def inv_permutation(p): 150 | inv_p = [0]*len(p) 151 | for old_idx, new_idx in enumerate(p): 152 | inv_p[new_idx] = old_idx 153 | return inv_p 154 | 155 | 156 | def permute_attention(attn_layer, p): 157 | inv_p = inv_permutation(p) 158 | 159 | p_attn_layer = copy.deepcopy(attn_layer) 160 | with torch.no_grad(): 161 | for name, para in p_attn_layer.named_parameters(): 162 | if name in ["w_q.weight", "w_k.weight", "w_v.weight"]: 163 | para.data = para.data[:, p] 164 | if name in ["w_o.weight",
"w_o.bias"]: 165 | para.data = para.data[p] 166 | return p_attn_layer 167 | 168 | 169 | def permute_feedforward(ffn_layer, p): 170 | inv_p = inv_permutation(p) 171 | p_ffn_layer = copy.deepcopy(ffn_layer) 172 | with torch.no_grad(): 173 | for name, para in p_ffn_layer.named_parameters(): 174 | # print(name) 175 | if name in ["fc1.weight"]: 176 | para.data = para.data[:, p] 177 | if name in ["fc2.weight", "fc2.bias"]: 178 | para.data = para.data[p] 179 | return p_ffn_layer 180 | 181 | 182 | def permute_block(blk, p): 183 | inv_p = inv_permutation(p) 184 | p_blk = copy.deepcopy(blk) 185 | with torch.no_grad(): 186 | for name, para in p_blk.named_parameters(): 187 | if name in ["attention_layer.w_q.weight", 188 | "attention_layer.w_k.weight", 189 | "attention_layer.w_v.weight", 190 | "ff_layer.fc1.weight"]: 191 | para.data = para.data[:, p] 192 | if name in ["attention_layer.w_o.weight", 193 | "attention_layer.w_o.bias", 194 | "ff_layer.fc2.weight", 195 | "ff_layer.fc2.bias", "ln1.weight", "ln1.bias", "ln2.weight", "ln2.bias"]: 196 | para.data = para.data[p] 197 | 198 | return p_blk 199 | 200 | 201 | if __name__ == "__main__": 202 | BS = 2 203 | SEQLEN = 3 204 | DMODEL = 4 205 | DFF = 8 206 | NHEADS = 2 207 | 208 | TEST_ATT = 0 209 | TEST_FFN = 0 210 | TEST_BLK = 1 211 | 212 | if TEST_ATT: 213 | x = torch.from_numpy(np.random.rand(BS, SEQLEN, DMODEL)).float() 214 | 215 | attention_layer = MultiHeadAttention(d_model=DMODEL, num_heads=NHEADS) 216 | mask = (1 - torch.triu(torch.ones(1, SEQLEN, SEQLEN), diagonal=1)).bool() 217 | with torch.no_grad(): 218 | y = attention_layer(x, x, x, None) 219 | y_mask = attention_layer(x, x, x, mask) 220 | 221 | p = np.random.permutation(x.shape[2]) 222 | print(p) 223 | 224 | xp = x[:, :, p] 225 | p_attention_layer = permute_attention(attention_layer, p) 226 | 227 | with torch.no_grad(): 228 | print("Without mask:") 229 | yp = p_attention_layer(xp, xp, xp, None) 230 | diff = np.abs(y[:, :, p] - yp).sum() 231 | print("Original result:\n", y, "\nNew result:\n", yp) 232 | print("Diff=",
 diff) 233 | 234 | print("With mask:") 235 | yp_mask = p_attention_layer(xp, xp, xp, mask) 236 | diff_mask = np.abs(y_mask[:, :, p] - yp_mask).sum() 237 | print("Original result:\n", y_mask, "\nNew result:\n", yp_mask) 238 | print("Diff=", diff_mask) 239 | 240 | if TEST_FFN: 241 | x = torch.from_numpy(np.random.rand(BS, SEQLEN, DMODEL)).float() 242 | ffn_layer = FeedForward(DMODEL, DFF) 243 | 244 | p = np.random.permutation(DMODEL) 245 | print(f"Permutation={p}") 246 | 247 | xp = x[:, :, p] 248 | p_ffn_layer = permute_feedforward(ffn_layer, p) 249 | 250 | with torch.no_grad(): 251 | print("Feedforward Layer:") 252 | y = ffn_layer(x) 253 | yp = p_ffn_layer(xp) 254 | diff = np.abs(y[:, :, p] - yp).sum() 255 | print("Original result:\n", y, "\nNew result:\n", yp) 256 | print("Diff=", diff) 257 | 258 | 259 | if TEST_BLK: 260 | x = torch.from_numpy(np.random.rand(BS, SEQLEN, DMODEL)).float() 261 | block = EncoderBlock(DMODEL, NHEADS, DFF) 262 | p = np.random.permutation(DMODEL) 263 | 264 | print(f"Permutation={p}") 265 | xp = x[:, :, p] 266 | p_block = permute_block(block, p) 267 | mask = (1 - torch.triu(torch.ones(1, SEQLEN, SEQLEN), diagonal=1)).bool() 268 | 269 | with torch.no_grad(): 270 | print("Encoder Block (no mask):") 271 | y = block(x, None) 272 | yp = p_block(xp, None) 273 | diff = np.abs(y[:, :, p] - yp).sum() 274 | print("Original result:\n", y, "\nNew result:\n", yp) 275 | print("Diff=", diff) 276 | 277 | print("Encoder Block (causal mask):") 278 | y = block(x, mask) 279 | yp = p_block(xp, mask) 280 | diff = np.abs(y[:, :, p] - yp).sum() 281 | print("Original result:\n", y, "\nNew result:\n", yp) 282 | print("Diff=", diff) 283 | --------------------------------------------------------------------------------
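A permutation of the feature dimension commutes with LayerNorm and RMSNorm. This holds because the mean and root-mean-square over the last dimension do not depend on the order of the features. The elementwise affine parameters (`weight`, plus `bias` for LayerNorm) must still be permuted along with the input. Freshly constructed norm layers initialize these parameters to ones and zeros, so the block-level tests above give the same result whether or not they are permuted. The following self-contained sketch checks the property with non-trivial affine parameters. The helper `rms_norm`, the sizes, the permutation, and the parameter values are illustrative assumptions; `rms_norm` mirrors the `RMSNorm` formula used in `model.py` and `stip_llama.py`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 6
x = torch.randn(2, 3, d)
p = torch.tensor([1, 2, 3, 4, 5, 0])  # an arbitrary non-identity permutation of d_model

# Non-default affine parameters, standing in for trained weights
weight = torch.arange(1.0, d + 1)
bias = torch.linspace(-1.0, 1.0, d)

def rms_norm(x, weight, eps=1e-6):
    # Same formula as RMSNorm in model.py / stip_llama.py: x / rms(x) * weight
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

# LayerNorm: permuted output of the original layer vs. running on the permuted input
ln_ref = F.layer_norm(x, (d,), weight, bias)[..., p]
ln_perm = F.layer_norm(x[..., p], (d,), weight[p], bias[p])  # affine params permuted
ln_stale = F.layer_norm(x[..., p], (d,), weight, bias)       # affine params left as-is

# RMSNorm: same comparison
rms_ref = rms_norm(x, weight)[..., p]
rms_perm = rms_norm(x[..., p], weight[p])
rms_stale = rms_norm(x[..., p], weight)

print(torch.allclose(ln_ref, ln_perm, atol=1e-5), torch.allclose(ln_ref, ln_stale, atol=1e-5))
print(torch.allclose(rms_ref, rms_perm, atol=1e-5), torch.allclose(rms_ref, rms_stale, atol=1e-5))
```

Only the `*_perm` variants, whose affine parameters were permuted, are expected to match the permuted reference output. The `*_stale` variants show what would happen to a trained model whose norm parameters were left in their original order.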