├── LICENSE ├── README.md ├── hashedEmbeddingBag.py ├── hashedEmbeddingCPU.py ├── hashed_embedding_bag_kernel.cu └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Aditya Desai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Description 2 | Code for Embeddings with Random Offset Block Embedding Array (ROBE) 3 | - The code was adapted from the original PyTorch EmbeddingBag code 4 | - The code has only been tested in embedding-style scenarios (plain per-index lookups), so be careful when using it as a true embedding bag 5 | 6 | # How to run 7 | After checkout, run the following to install ROBE/UMA: 8 | ``` 9 | python3 setup.py install 10 | ``` 11 | 12 | Usage: multiple embedding tables sharing the same underlying set of parameters: 13 | ``` 14 | import hashedEmbeddingBag 15 | import torch 16 | import torch.nn as nn 17 | import numpy as np 18 | 19 | 20 | robe_size = 1000 21 | _weight = nn.Parameter(torch.from_numpy(np.random.uniform(low=-0.001, high=0.001, size=(robe_size,)).astype(np.float32))) 22 | 23 | n1 = 100000 24 | m1 = 16 25 | E1 = hashedEmbeddingBag.HashedEmbeddingBag(n1, m1, _weight=_weight, val_offset=0).cuda(0) 26 | 27 | n2 = 200000 28 | m2 = 32 29 | E2 = hashedEmbeddingBag.HashedEmbeddingBag(n2, m2, _weight=_weight, val_offset=n1).cuda(0) # note the offset 30 | 31 | indices = torch.arange(5).cuda(0) 32 | embeddings1 = E1(indices) 33 | embeddings2 = E2(indices) 34 | ``` 35 | 36 | # While running with CPU 37 | 38 | hashedEmbeddingBag is written for the GPU. 39 | To run on CPU (or on GPU without the CUDA extension), you can use the 40 | hashedEmbeddingCPU module directly; a minimal sketch is shown below.
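The sketch below is based on the `HashedEmbeddingCPU` constructor defined in `hashedEmbeddingCPU.py`; the weight setup mirrors the GPU example above.
```
import hashedEmbeddingCPU
import torch
import torch.nn as nn
import numpy as np

robe_size = 1000
_weight = nn.Parameter(torch.from_numpy(np.random.uniform(low=-0.001, high=0.001, size=(robe_size,)).astype(np.float32)))

# no .cuda() call is required for CPU execution
E1 = hashedEmbeddingCPU.HashedEmbeddingCPU(100000, 16, _weight, val_offset=0)
embeddings1 = E1(torch.arange(5))  # shape (5, 16)
```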
Everything remains the same 41 | 42 | 43 | -------------------------------------------------------------------------------- /hashedEmbeddingBag.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | import numpy as np 7 | from torch.nn.parameter import Parameter 8 | import math 9 | 10 | import hashed_embedding_bag 11 | import pdb 12 | 13 | class HashedEmbeddingBagFunction(torch.autograd.Function): 14 | @staticmethod 15 | def forward(ctx, hashed_weights, indices, offsets, mode, embedding_dim, signature, random_numbers, hmode, keymode, val_offset, norm, key_bits, keys_to_use, uma_chunk_size): 16 | if indices.dim() == 2: 17 | if offsets is not None: 18 | raise ValueError("if indices is 2D, then offsets has to be None" 19 | ", as indices is treated is a mini-batch of" 20 | " fixed length sequences. However, found " 21 | "offsets of type {}".format(type(offsets))) 22 | offsets = torch.arange(0, indices.numel(), indices.size(1), 23 | dtype=torch.long, device=indices.device) 24 | indices = indices.reshape(-1) 25 | elif indices.dim() == 1: 26 | if offsets is None: 27 | raise ValueError("offsets has to be a 1D Tensor but got None") 28 | if offsets.dim() != 1: 29 | raise ValueError("offsets has to be a 1D Tensor") 30 | else: 31 | raise ValueError("indices has to be 1D or 2D Tensor," 32 | " but got Tensor of dimension {}".format(indices.dim())) 33 | 34 | if mode == 'sum': 35 | mode_enum = 0 36 | elif mode == 'mean': 37 | mode_enum = 1 38 | raise ValueError("mean mode not supported") 39 | elif mode == 'max': 40 | mode_enum = 2 41 | raise ValueError("max mode not supported") 42 | 43 | if hmode == "rand_hash": 44 | hmode_enum = 0 45 | elif hmode == "lma_hash": 46 | hmode_enum = 1 47 | else: 48 | raise ValueError("hmode not defined") 49 | 50 | 51 | if keymode == "keymode_hashweight": 52 | keymode_enum = 0; 53 | elif keymode == "keymode_static_pm": 54 | keymode_enum = 1; 55 | else: 56 | raise ValueError("keymode not defined") 57 | 58 | if val_offset is not None: 59 | indices = indices + val_offset 60 | 61 | 62 | hashed_weights_size = hashed_weights.size(0) 63 | output, offset2bag, bag_size, max_indices, hashed_idx = \ 64 | hashed_embedding_bag.forward(hashed_weights, indices, offsets, mode_enum, embedding_dim, signature, random_numbers, hmode_enum, keymode_enum, key_bits, keys_to_use, uma_chunk_size) 65 | if norm is not None: 66 | #assert(keymode_enum == 1) 67 | output = output/norm 68 | ctx.save_for_backward(indices, offsets, offset2bag, bag_size, max_indices, hashed_idx) 69 | ctx.mode_enum = mode_enum 70 | ctx.hashed_weights_size = hashed_weights_size 71 | ctx.keymode_enum = keymode_enum 72 | return output 73 | 74 | @staticmethod 75 | def backward(ctx, grad): 76 | indices, offsets, offset2bag, bag_size, max_indices, hashed_idx = ctx.saved_variables 77 | hashed_weights_size = ctx.hashed_weights_size 78 | mode_enum = ctx.mode_enum 79 | keymode_enum = ctx.keymode_enum 80 | embedding_dim = grad.size(1) 81 | if keymode_enum == 0: 82 | weight_grad = hashed_embedding_bag.backward( 83 | grad, indices, offsets, offset2bag, bag_size, max_indices, hashed_idx, hashed_weights_size, False, mode_enum, embedding_dim) 84 | elif keymode_enum == 1: 85 | weight_grad = None 86 | return weight_grad, None, None, None, None, None, None, None,None,None,None,None,None, None 87 | 88 | # use this when we just want the embedding and not the bag 89 | ''' 90 | @staticmethod 91 | def 
backward(ctx, grad): 92 | keymode_enum = ctx.keymode_enum 93 | if keymode_enum == 0: 94 | indices, offsets, offset2bag, bag_size, max_indices, hashed_idx = ctx.saved_variables 95 | hashed_weights_size = ctx.hashed_weights_size 96 | if hashed_idx.is_contiguous(): 97 | hashed_idx1 = hashed_idx.view(-1) 98 | else: 99 | hashed_idx1 = hashed_idx.reshape(-1) 100 | if grad.is_contiguous(): 101 | grad1 = grad.view(-1) 102 | else: 103 | grad1 = grad.reshape(-1) 104 | hashed_weight_grad = torch.zeros((hashed_weights_size,),dtype=torch.float32, device=indices.device) 105 | hashed_weight_grad.scatter_add_(0, hashed_idx1, grad1) 106 | elif keymode_enum == 1: 107 | weight_grad = None 108 | return hashed_weight_grad, None, None, None, None, None, None, None,None,None,None,None,None, None 109 | ''' 110 | class HashedEmbeddingBag(nn.Module): 111 | def __init__( 112 | self, 113 | num_embeddings: int, 114 | embedding_dim: int, 115 | compression:float = 1. / 64., 116 | mode:str = "sum", 117 | _weight: Optional[torch.Tensor] = None, 118 | signature: Optional[torch.Tensor] = None, 119 | key_bits=4, 120 | keys_to_use=8, 121 | hmode = "rand_hash", 122 | keymode = "keymode_hashweight", 123 | val_offset = None, 124 | seed = 1024, 125 | uma_chunk_size = 1, 126 | padding_idx = None)->None: 127 | super(HashedEmbeddingBag, self).__init__() 128 | self.num_embeddings = num_embeddings 129 | self.embedding_dim = embedding_dim 130 | memory = int(num_embeddings * embedding_dim * compression + 1) 131 | #memory = int(np.exp2(int(np.log2(memory)))) # make sure it is power of 2 132 | self.weight_size = memory 133 | if keymode != "keymode_hashweight": 134 | assert(_weight is None) 135 | self.val_offset = val_offset 136 | self.mode = mode 137 | self.hmode = hmode 138 | self.keymode = keymode 139 | self.signature = signature 140 | self.norm = None 141 | self.key_bits = key_bits 142 | self.keys_to_use = keys_to_use 143 | self.uma_chunk_size = uma_chunk_size 144 | self.padding_idx = padding_idx 145 | r = np.random.RandomState(seed) 146 | random_numbers = np.concatenate([np.array([2038074743]), r.randint(0, 2038074743, (50,))]) # set of 50 random numbers to use 147 | self.random_numbers = Parameter(torch.from_numpy(random_numbers.astype(np.int64)), requires_grad=False) 148 | print("RandomNumbers: ", self.random_numbers[:5]) 149 | 150 | if self.signature is None: 151 | val = np.zeros(shape=(2,)) 152 | self.signature = Parameter(torch.from_numpy(val.astype(np.int64)), requires_grad=False) 153 | if _weight is None : 154 | if keymode == "keymode_hashweight": 155 | low = -math.sqrt(1 / self.num_embeddings) 156 | high = math.sqrt(1 / self.num_embeddings) 157 | self.hashed_weight = Parameter(torch.rand(self.weight_size) * (high - low) + low) 158 | else: 159 | self.weight_size = 2 160 | val = np.random.uniform(low = -1, high = 1, size=(self.weight_size,)) 161 | self.hashed_weight = Parameter(torch.from_numpy(val.astype(np.float32)), requires_grad=False) 162 | self.hashed_weight.requires_grad = False 163 | self.norm = (self.embedding_dim / 32) 164 | #self.norm = np.sqrt(self.embedding_dim) 165 | 166 | self.central = False 167 | #self.reset_parameters() 168 | print("Inside HashedEmbeddingBag (after reset): ", num_embeddings, embedding_dim, compression, self.weight_size, self.hashed_weight.shape) 169 | else: 170 | #assert len(_weight.shape) == 1 and _weight.shape[0] == weight_size, \ 171 | # 'Shape of weight does not match num_embeddings and embedding_dim' 172 | print("Central weight", hmode, "val_offset", self.val_offset) 173 | 
self.hashed_weight = _weight 174 | self.weight_size = self.hashed_weight.numel() 175 | self.central = True 176 | assert(self.val_offset is not None) 177 | self.weight = self.hashed_weight 178 | print("HashedEmbeddingBag: ", num_embeddings, embedding_dim, "mode", mode, 179 | "hmode", hmode, "kmode", keymode, "central", self.central, "key_bits", self.key_bits, 180 | "keys_to_use", self.keys_to_use, 181 | "weight_size", self.weight_size, 182 | "uma_chunk_size", self.uma_chunk_size, 183 | "seed", seed) 184 | """ 185 | def reset_parameters(self) -> None: 186 | # init.normal_(self.weight) 187 | W = np.random.uniform( 188 | low=-np.sqrt(1 / self.num_embeddings), high=np.sqrt(1 / self.num_embeddings), size=(self.hashed_weight.shape[0], ) 189 | ).astype(np.float32) 190 | self.hashed_weight.data = torch.tensor(W, requires_grad=True) 191 | """ 192 | def forward(self, indices: torch.Tensor, offsets: Optional[torch.Tensor] = None, per_sample_weights=None) -> torch.Tensor: 193 | i_shape = indices.shape 194 | indices = indices.view(-1) 195 | if self.padding_idx is not None: 196 | original_count = indices.shape[0] 197 | indx_mask = (indices != self.padding_idx) 198 | indx_padd_mask = (indices == self.padding_idx) 199 | indices = indices[indx_mask] 200 | 201 | if offsets is None: 202 | offsets = torch.arange(len(indices)).to(indices.device) 203 | 204 | assert(per_sample_weights is None) 205 | embeddings = HashedEmbeddingBagFunction.apply( 206 | self.hashed_weight, 207 | indices, 208 | offsets, 209 | self.mode, 210 | self.embedding_dim, 211 | self.signature, 212 | self.random_numbers, 213 | self.hmode, 214 | self.keymode, 215 | self.val_offset, 216 | self.norm, 217 | self.key_bits, 218 | self.keys_to_use, 219 | self.uma_chunk_size 220 | ) 221 | if self.padding_idx is not None: 222 | Aembeddings = torch.zeros(original_count, self.embedding_dim, device=indices.device) 223 | Aembeddings[indx_mask,:] = embeddings[:,:] 224 | embeddings = Aembeddings 225 | embeddings = embeddings.view(*i_shape, embeddings.shape[-1]) 226 | return embeddings 227 | 228 | class SecondaryLearnedEmbedding(nn.Module): 229 | def __init__(self, underlying_embedding, learn_model): 230 | super(SecondaryLearnedEmbedding, self).__init__() 231 | self.underlying_embedding = underlying_embedding 232 | self.learn_model = learn_model 233 | self.weight = underlying_embedding.weight 234 | 235 | def forward(self, indices: torch.Tensor, offsets: Optional[torch.Tensor] = None) -> torch.Tensor: 236 | i_shape = indices.shape 237 | primary_embedding = self.underlying_embedding(indices, offsets) 238 | e_shape = primary_embedding.shape 239 | primary_embedding = primary_embedding.view(-1, e_shape[-1]) 240 | secondary_embedding = self.learn_model(primary_embedding) 241 | secondary_embedding = secondary_embedding.view(*i_shape, secondary_embedding.shape[-1]) 242 | return secondary_embedding 243 | 244 | 245 | def get_mlplearned_embedding(underlying_embedding, str_mlp, dev="cuda:0"): 246 | ls = [ int(x) for x in str_mlp.split('-')] 247 | mlp_model = nn.ModuleList() 248 | for i in range(0, len(ls) - 2): 249 | mlp_model.append(nn.Linear(ls[i], ls[i+1])) 250 | mlp_model.append(nn.ReLU()) 251 | mlp_model.append(nn.Linear(ls[len(ls)-2], ls[len(ls) - 1])) 252 | mlp_model = torch.nn.Sequential(*mlp_model).to(dev) 253 | return SecondaryLearnedEmbedding(underlying_embedding, mlp_model).to(dev) 254 | 255 | 256 | class FunctionalEmbedding(nn.Module): 257 | def __init__(self, embedding_dim, learn_model, val_offset): 258 | super(FunctionalEmbedding, self).__init__() 259 | 
self.embedding_dim = embedding_dim 260 | self.learn_model = learn_model 261 | self.val_offset = val_offset 262 | 263 | self.num_hashes = int((embedding_dim + 31) / 32) 264 | r = np.random.RandomState(1234) 265 | A = r.randint(0, 2**32-1, size=(1, self.num_hashes))*2-1 # odd 266 | B = r.randint(0, 2**32-1, size=(1, self.num_hashes))*2-1 # odd 267 | self.A = torch.from_numpy(A).to("cuda:0") 268 | self.B = torch.from_numpy(B).to("cuda:0") 269 | self.bits = 32 270 | mask = 2**torch.arange(32) 271 | self.mask = mask.to("cuda:0") 272 | 273 | def forward(self, indices: torch.Tensor) -> torch.Tensor: 274 | # indices are N x 1 275 | hashes = (indices * self.A + self.B) # no mod because a and b are odd like taking mod 2^32 276 | bithashes = hashes.unsqueeze(-1).bitwise_and(self.mask).ne(0) * 2.0 - 1.0 277 | bithashes = bithashes.view(hashes.shape[0], -1) 278 | input_mlp = bithashes[:,:self.embedding_dim] 279 | if self.learn_model is not None: 280 | return self.learn_model(input_mlp) 281 | else: 282 | return input_mlp 283 | 284 | 285 | def get_functional_embedding(embedding_dim, str_mlp, dev="cuda:0"): 286 | if str_mlp is None: 287 | return FunctionalEmbedding(embedding_dim, None, 0).to(dev) 288 | ls = [ int(x) for x in str_mlp.split('-')] 289 | mlp_model = nn.ModuleList() 290 | for i in range(0, len(ls) - 2): 291 | mlp_model.append(nn.Linear(ls[i], ls[i+1])) 292 | mlp_model.append(nn.ReLU()) 293 | mlp_model.append(nn.Linear(ls[len(ls)-2], ls[len(ls) - 1])) 294 | mlp_model = torch.nn.Sequential(*mlp_model).to(dev) 295 | return FunctionalEmbedding(embedding_dim, mlp_model, 0).to(dev) 296 | -------------------------------------------------------------------------------- /hashedEmbeddingCPU.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | import numpy as np 7 | from torch.nn.parameter import Parameter 8 | import math 9 | import pdb 10 | 11 | 12 | class UMAEmbeddingFunc(torch.autograd.Function): 13 | @staticmethod 14 | def forward(ctx, hashed_weights, indices, embedding_dim, val_offset, P, A, B, C, hashed_weights_size, helper_E1sR, helper_Eidx_base, helper_Eidx_offset, uma_chunk_size): 15 | assert(indices.dim() == 1) # indices has tobe a one dimensional array of integers. 
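        # ROBE/UMA lookup: each (index value v, embedding position j) pair is mapped into the shared
        # weight array with a universal hash ((v + val_offset) * R + (j // uma_chunk_size) * B + A) mod P,
        # reduced mod (hashed_weights_size - uma_chunk_size + 1) and shifted by (j % uma_chunk_size),
        # so that uma_chunk_size consecutive embedding positions read consecutive weights.
        # R is the fixed random multiplier baked into helper_E1sR; helper_Eidx_base / helper_Eidx_offset
        # hold j // uma_chunk_size and j % uma_chunk_size respectively.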
16 | 17 | #hashed_idx = (((((indices.view(-1,1) + val_offset) * helper_E1sR) % P + helper_Eidx * B)%P + A) % P) % hashed_weights_size 18 | hashed_idx = ((((((indices.view(-1,1) + val_offset) * helper_E1sR) + helper_Eidx_base * B) + A) % P) % (hashed_weights_size -uma_chunk_size +1) + helper_Eidx_offset) 19 | output = hashed_weights[hashed_idx] 20 | #output, hashed_idx = \ 21 | # rma.forward(hashed_weights, indices, embedding_dim, random_numbers, val_offset) 22 | ctx.save_for_backward(indices, hashed_idx) 23 | ctx.hashed_weights_size = hashed_weights_size 24 | return output 25 | 26 | 27 | @staticmethod 28 | def backward(ctx, grad): 29 | indices, hashed_idx = ctx.saved_variables 30 | hashed_weights_size = ctx.hashed_weights_size 31 | if hashed_idx.is_contiguous(): 32 | hashed_idx1 = hashed_idx.view(-1) 33 | else: 34 | hashed_idx1 = hashed_idx.reshape(-1) 35 | if grad.is_contiguous(): 36 | grad1 = grad.view(-1) 37 | else: 38 | grad1 = grad.reshape(-1) 39 | weight_grad = torch.zeros((hashed_weights_size,),dtype=torch.float32, device=indices.device) 40 | weight_grad.scatter_add_(0, hashed_idx1, grad1) 41 | #weight_grad = rma.backward( 42 | # grad, indices, hashed_idx, hashed_weights_size, embedding_dim) 43 | return weight_grad, None, None, None, None, None, None, None, None, None, None, None, None 44 | 45 | class HashedEmbeddingCPU(nn.Module): 46 | def __init__( 47 | self, 48 | num_embeddings: int, 49 | embedding_dim: int, 50 | _weight: torch.Tensor, 51 | val_offset: int, 52 | uma_chunk_size = 1, 53 | seed = 1024)->None: 54 | super(HashedEmbeddingCPU, self).__init__() 55 | self.num_embeddings = num_embeddings 56 | self.embedding_dim = embedding_dim 57 | self.val_offset = val_offset 58 | self.seed = seed 59 | self.weight = nn.Parameter(_weight, requires_grad = True) # add to parameter 60 | self.weights_size = self.weight.numel() 61 | self.uma_chunk_size = uma_chunk_size 62 | 63 | 64 | r = np.random.RandomState(seed) 65 | random_numbers = np.concatenate([np.array([2038074743]), r.randint(0, 2038074743, (10,))]) # 10 random numbers 66 | random_numbers = torch.from_numpy(random_numbers.astype(np.int64)) 67 | print("[Seed]", seed, "First 5 random numbers: ", random_numbers[:5]) 68 | print("UMA Embedding Object: num_embeddings:{} dim:{} val_offset:{} seed:{} weights_size:{} uma_chunk_size:{}".format(self.num_embeddings, self.embedding_dim, 69 | self.val_offset, self.seed, self.weights_size, self.uma_chunk_size), flush=True) 70 | 71 | # helpers to compute 72 | helper_Eidx_base = torch.LongTensor(np.arange(self.embedding_dim) / self.uma_chunk_size) 73 | helper_Eidx_offset = torch.LongTensor(np.arange(self.embedding_dim) % self.uma_chunk_size) 74 | helper_E1sR = torch.LongTensor(np.ones(self.embedding_dim) * int(random_numbers[3])) # A 75 | 76 | # adding to parameters 77 | self.random_numbers = nn.Parameter(random_numbers, requires_grad=False) 78 | self.helper_Eidx_base = nn.Parameter(helper_Eidx_base, requires_grad=False) 79 | self.helper_Eidx_offset = nn.Parameter(helper_Eidx_offset, requires_grad=False) 80 | self.helper_E1sR = nn.Parameter(helper_E1sR, requires_grad=False) 81 | 82 | 83 | def forward(self, indices: torch.Tensor, offsets=None, per_sample_weights=None) -> torch.Tensor: 84 | 85 | #def forward(ctx, hashed_weights, indices, embedding_dim, val_offset, P, A, B, hashed_weights_size, helper_E1sR, helper_Eidx): 86 | embeddings = UMAEmbeddingFunc.apply( 87 | self.weight, 88 | indices, 89 | self.embedding_dim, 90 | self.val_offset, 91 | self.random_numbers[0], 92 | self.random_numbers[1], 93 | 
self.random_numbers[2], 94 | self.random_numbers[3], 95 | self.weights_size, 96 | self.helper_E1sR, 97 | self.helper_Eidx_base, 98 | self.helper_Eidx_offset, 99 | self.uma_chunk_size 100 | ) 101 | return embeddings 102 | -------------------------------------------------------------------------------- /hashed_embedding_bag_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | #include 19 | 20 | 21 | constexpr int MODE_SUM = 0; 22 | constexpr int MODE_MEAN = 1; 23 | constexpr int MODE_MAX = 2; 24 | 25 | constexpr int NWEIGHT_PER_THREAD = 128; 26 | constexpr int BIT4MASK = 15; 27 | constexpr int64_t BIT32MASK = ((1u <<31u) - 1u); 28 | 29 | constexpr int HMODE_RANDOMHASH = 0; 30 | constexpr int HMODE_LMAHASH = 1; 31 | 32 | constexpr int KEYMODE_HASHWEIGHT = 0; 33 | constexpr int KEYMODE_STATIC_PM = 1; 34 | 35 | // Fast ceil division (no overflow checking) 36 | __host__ __device__ __forceinline__ 37 | int64_t ceil_div(int64_t x, int64_t y) { 38 | return (x + y - 1) / y; 39 | } 40 | 41 | __global__ 42 | void krn_partials_per_segment(int64_t *ret, const int64_t *segment_offsets, 43 | int64_t num_segments, int64_t numel) { 44 | const int id = blockIdx.x * blockDim.x + threadIdx.x; 45 | if(id < num_segments) { 46 | const int64_t idx_start = segment_offsets[id]; 47 | const int64_t idx_end = (id == num_segments-1)?numel:segment_offsets[id+1]; 48 | const int64_t size = idx_end - idx_start; 49 | ret[id] = ceil_div(size, NWEIGHT_PER_THREAD); 50 | } 51 | } 52 | 53 | __global__ 54 | void krn_partial_segment_offset( 55 | int64_t *ret, 56 | const int64_t *partials_per_segment, 57 | const int64_t *partials_per_segment_offset, 58 | const int64_t *segment_offsets, 59 | int64_t num_segments) { 60 | const int id = blockIdx.x * blockDim.x + threadIdx.x; 61 | if(id < num_segments) { 62 | int64_t idx = partials_per_segment_offset[id]; 63 | const int64_t num_partials = partials_per_segment[id]; 64 | const int64_t segment_offset = segment_offsets[id]; 65 | for (int64_t i=0; i 87 | __device__ __host__ scalar_t keymode_static_pm(int64_t a) { 88 | int64_t val = (a * 71371560971u + 46023704752u) % 100000004987u % 2u; 89 | if (val == 0) { 90 | return -1.0; 91 | } else { 92 | return 1.0; 93 | } 94 | 95 | //return a * 16 + b; 96 | } 97 | 98 | /* fast way to map to +1.-1 */ 99 | template 100 | __device__ __host__ scalar_t keymode_static_pm_parity(int64_t a) { 101 | int64_t val = (a * 71371560971u + 46023704752u) % 100000004987u; 102 | int64_t val1 = val ^ (val>>1); // get parity to decide between +1/-1 103 | val1 = val1 ^ (val1 >> 2); 104 | val1 = val1 ^ (val1 >> 4); 105 | val1 = val1 ^ (val1 >> 8); 106 | val1 = val1 ^ (val1 >> 16); 107 | val1 = val1 ^ (val1 >> 32); 108 | if (val1 & 1) { 109 | return 1.0; 110 | }else{ 111 | return -1.0; 112 | } 113 | //return a * 16 + b; 114 | } 115 | 116 | 117 | __device__ __host__ int64_t lma_hash_func(int64_t v, int64_t i, int64_t signature) { 118 | // input is value, embedding_location, signature 4x16 bit representation which is 119 | // drawn from signature array[value] 120 | // a and b for hashing % 17 % 15 to choose from 16 minhashes 121 | // 9, 1, 14, 2, 10, 10, 2, 8, 122 | // 14, 6, 10, 4, 1, 14, 12, 12]) 123 | // P = 100000004987 124 | // 27099547127, 2699391407, 46970219979, 16806825237, 74212261504, 93432047494, 16220329892, 82313724554, 125 | // 
50469911173, 52271898367, 98939193954, 94293094042, 96314459732, 2349378832, 1727459397, 48438134705 126 | 127 | int64_t extracted = ((signature >> (4*((82313724554*i+48438134705)% 100000004987 %15))) & BIT4MASK) // 4 bit number 128 | ^ (((signature >> (4*((27099547127*i+50469911173 )% 100000004987%15))) & BIT4MASK) << 4) 129 | ^ (((signature >> (4*((2699391407*i+52271898367)% 100000004987 %15))) & BIT4MASK) << 8) 130 | ^ (((signature >> (4*((46970219979*i+98939193954)% 100000004987 %15))) & BIT4MASK) << 12) 131 | ^ (((signature >> (4*((16806825237*i+94293094042)% 100000004987 %15))) & BIT4MASK) << 16) 132 | ^ (((signature >> (4*((74212261504*i+96314459732)% 100000004987 %15))) & BIT4MASK) << 20) 133 | ^ (((signature >> (4*((93432047494*i+2349378832)% 100000004987 %15))) & BIT4MASK) << 24) 134 | ^ (((signature >> (4*((16220329892*i+1727459397)% 100000004987 %15))) & BIT4MASK) << 28); 135 | return (int64_t) extracted; // extracted is a 32 bit number 136 | } 137 | 138 | 139 | __device__ __host__ int64_t lma_hash_func_e1(int64_t v, int64_t i, int64_t signature, // still assuming signature is 64 bit 140 | int64_t key_bits, int64_t keys_to_use, int64_t * random_numbers) { 141 | /* 142 | Memory based optimizations: 143 | put random_numbers into __constant__ memory 144 | 145 | code based 146 | make keys_to_use into template parameter and foward declare it with all different values 147 | 148 | 149 | 150 | */ 151 | CUDA_KERNEL_ASSERT(keys_to_use == 1 or keys_to_use == 2 or keys_to_use == 4 or keys_to_use == 6 or keys_to_use == 8 or keys_to_use == 12 or keys_to_use == 16); 152 | int64_t total_bits = key_bits * keys_to_use; 153 | CUDA_KERNEL_ASSERT(total_bits < 60); 154 | int64_t bitmask = (1 << key_bits) - 1; 155 | int64_t numkeys = 64/key_bits -1; 156 | int64_t extracted = ((signature >> (key_bits*((random_numbers[11]*i+random_numbers[12])% random_numbers[0] %numkeys))) & bitmask) ;// key_bits bit number 157 | if (keys_to_use >= 2) 158 | extracted ^= (((signature >> (key_bits*((random_numbers[13]*i+random_numbers[14] )% random_numbers[0]%numkeys))) & bitmask) << key_bits); 159 | if (keys_to_use >=4) { 160 | extracted ^= (((signature >> (key_bits*((random_numbers[15]*i+random_numbers[16])% random_numbers[0] %numkeys))) & bitmask) << 2*key_bits); 161 | extracted ^= (((signature >> (key_bits*((random_numbers[17]*i+random_numbers[18])% random_numbers[0] %numkeys))) & bitmask) << 3*key_bits); 162 | } 163 | if (keys_to_use >= 6) { 164 | extracted ^= (((signature >> (key_bits*((random_numbers[19]*i+random_numbers[20])% random_numbers[0] %numkeys))) & bitmask) << 4*key_bits); 165 | extracted ^= (((signature >> (key_bits*((random_numbers[21]*i+random_numbers[22])% random_numbers[0] %numkeys))) & bitmask) << 5*key_bits); 166 | } 167 | if (keys_to_use >= 8) { 168 | extracted ^= (((signature >> (key_bits*((random_numbers[23]*i+random_numbers[24])% random_numbers[0] %numkeys))) & bitmask) << 6*key_bits); 169 | extracted ^= (((signature >> (key_bits*((random_numbers[25]*i+random_numbers[26])% random_numbers[0] %numkeys))) & bitmask) << 7*key_bits); 170 | } 171 | /* 172 | array([[22406334177, 63792722443], 173 | [75791256117, 15202366190], 174 | [40623773873, 8640139384], 175 | [13655260797, 99959231757], 176 | [21577857905, 50989087799], 177 | [ 8043429682, 29709184765], 178 | [95200260355, 49014991094], 179 | [36941582829, 21960689983]]) 180 | */ 181 | if (keys_to_use >= 12) { 182 | extracted ^= (((signature >> (key_bits*((random_numbers[27]*i+random_numbers[28])% random_numbers[0] %numkeys))) & bitmask) << 
8*key_bits); 183 | extracted ^= (((signature >> (key_bits*((random_numbers[29]*i+random_numbers[30])% random_numbers[0] %numkeys))) & bitmask) << 9*key_bits); 184 | extracted ^= (((signature >> (key_bits*((random_numbers[31]*i+random_numbers[32])% random_numbers[0] %numkeys))) & bitmask) << 10*key_bits); 185 | extracted ^= (((signature >> (key_bits*((random_numbers[33]*i+random_numbers[34])% random_numbers[0] %numkeys))) & bitmask) << 11*key_bits); 186 | } 187 | if (keys_to_use >= 16) { 188 | extracted ^= (((signature >> (key_bits*((random_numbers[35]*i+random_numbers[36])% random_numbers[0] %numkeys))) & bitmask) << 12*key_bits); 189 | extracted ^= (((signature >> (key_bits*((random_numbers[37]*i+random_numbers[38])% random_numbers[0] %numkeys))) & bitmask) << 13*key_bits); 190 | extracted ^= (((signature >> (key_bits*((random_numbers[39]*i+random_numbers[40])% random_numbers[0] %numkeys))) & bitmask) << 14*key_bits); 191 | extracted ^= (((signature >> (key_bits*((random_numbers[41]*i+random_numbers[42])% random_numbers[0] %numkeys))) & bitmask) << 15*key_bits); 192 | } 193 | 194 | // extracted uses i for consistent usage from small storage offered by signature 195 | // return value has to be a hash of (extracted, i) 196 | int64_t hash = hash_func(extracted, i, random_numbers); 197 | //int64_t hash = (1<> (key_bits*((random_numbers[11]*i+random_numbers[12])% random_numbers[0] %numkeys))) & bitmask) ;// key_bits bit number 208 | for (int k=1; k < keys_to_use;k++) { 209 | extracted ^= (((signature >> (key_bits*((random_numbers[10 + 2*k+1]*i+random_numbers[10+2*k+2] )% random_numbers[0]%numkeys))) & bitmask) << k*key_bits); 210 | } 211 | // extracted uses i for consistent usage from small storage offered by signature 212 | // return value has to be a hash of (extracted, i) 213 | int64_t hash = hash_func(extracted, i, random_numbers); 214 | //int64_t hash = (1< 220 | __global__ void hashed_embedding_bag_update_output_kernel( 221 | const torch::PackedTensorAccessor32 input, 222 | const torch::PackedTensorAccessor32 offsets, 223 | const torch::PackedTensorAccessor32 hashed_weights, 224 | torch::PackedTensorAccessor32 output, 225 | torch::PackedTensorAccessor32 offset2bag, 226 | int64_t numIndices, 227 | int64_t numBags, 228 | int64_t embedding_dim, 229 | int64_t hashedWeightSize, 230 | int mode, 231 | torch::PackedTensorAccessor32 hashed_index, 232 | torch::PackedTensorAccessor32 bag_size, 233 | torch::PackedTensorAccessor32 max_indices, 234 | torch::PackedTensorAccessor32 signature, 235 | int64_t * random_numbers, 236 | int hmode, 237 | int keymode, 238 | int key_bits, 239 | int keys_to_use, 240 | int uma_chunk_size) 241 | { 242 | /* 243 | optimizations. modes into template paramters 244 | accessor to pointers? 245 | 246 | */ 247 | // the strategy here is that each bag x feature is handled by a single thread 248 | 249 | int64_t chunksPerBag = (embedding_dim + (int64_t)blockDim.x - 1) / (int64_t)blockDim.x; 250 | int64_t numChunks = numBags * chunksPerBag; 251 | int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y; 252 | int64_t chunkStride = gridDim.x * blockDim.y; 253 | 254 | for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) { 255 | int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x; 256 | if (featureDim < embedding_dim) { 257 | int64_t bag = chunk / chunksPerBag; 258 | int64_t begin = bag == 0 ? 0 : offsets[bag]; // forces first offset to be 0 instead of asserting on it 259 | int64_t end = (bag < numBags - 1) ? 
(offsets[bag + 1]) : numIndices; 260 | CUDA_KERNEL_ASSERT(end >= begin); 261 | 262 | scalar_t weightFeatSum = 0; 263 | scalar_t weightFeatMax; 264 | 265 | int64_t bag_size_ = 0; 266 | int64_t maxWord = -1; 267 | // from start of bag to end of bag. 268 | int64_t hfd = featureDim / uma_chunk_size; 269 | int64_t hfd_shift = featureDim % uma_chunk_size; 270 | for (int64_t emb = begin; emb < end; emb++) { 271 | const int64_t weightRow = input[emb]; 272 | 273 | int64_t hashKey = 0; 274 | int64_t hashedWeightIdx = 0; 275 | scalar_t weightValue = 0; 276 | 277 | switch (hmode) { 278 | case HMODE_LMAHASH: 279 | hashKey = lma_hash_func_e2(weightRow, hfd, signature[weightRow], key_bits, keys_to_use, random_numbers); // expects a val_offset + value 280 | break; 281 | default: // HMODE_RANDOMHASH 282 | // this will be recomputed within uma_chunk_size. But i think if we want to not do that we need a better grid layout 283 | hashKey = hash_func(weightRow, hfd, random_numbers); // expects a val_offset + value if central 284 | break; 285 | } 286 | 287 | switch (keymode) { 288 | case KEYMODE_STATIC_PM: 289 | weightValue = keymode_static_pm_parity(hashKey); 290 | break; 291 | default: // KEYMODE_HASHWEIGHT 292 | hashedWeightIdx = hashKey % (hashedWeightSize - uma_chunk_size + 1)+ hfd_shift; 293 | hashed_index[emb][featureDim] = hashedWeightIdx; 294 | weightValue = hashed_weights[hashedWeightIdx]; 295 | break; 296 | } 297 | 298 | 299 | if (mode == MODE_MAX) { 300 | if (emb == begin || weightValue > weightFeatMax) { 301 | weightFeatMax = weightValue; 302 | maxWord = input[emb]; 303 | } 304 | } else { 305 | weightFeatSum += static_cast(weightValue); 306 | } 307 | 308 | bag_size_++; 309 | if (featureDim == 0) { 310 | offset2bag[emb] = bag; 311 | } 312 | } 313 | if (mode == MODE_MEAN) { 314 | if (end == begin) { 315 | bag_size[bag] = 0; 316 | } else { 317 | weightFeatSum = weightFeatSum / static_cast(bag_size_); 318 | bag_size[bag] = bag_size_; 319 | } 320 | } 321 | 322 | if (mode == MODE_MEAN || mode == MODE_SUM) { 323 | output[bag][featureDim] = static_cast(weightFeatSum); 324 | } 325 | else if (mode == MODE_MAX) { 326 | if (end == begin) { 327 | // If bag is empty, set output to 0. 328 | weightFeatMax = 0; 329 | } 330 | max_indices[bag][featureDim] = maxWord; 331 | output[bag][featureDim] = weightFeatMax; 332 | } 333 | } 334 | } 335 | } 336 | 337 | template 338 | __global__ void compute_grad_weight_bags( 339 | torch::PackedTensorAccessor32 orig_hash_idx_idx, 340 | torch::PackedTensorAccessor32 output_grad, 341 | torch::PackedTensorAccessor32 offset2bag, 342 | int64_t embedding_dim, 343 | int64_t numel, 344 | torch::PackedTensorAccessor32 partial_segment_offset, 345 | int64_t num_of_partial_segments, 346 | torch::PackedTensorAccessor32 grad_weight_per_partial 347 | ) 348 | { 349 | const int partial_id = blockIdx.x * blockDim.x + threadIdx.x; 350 | if (partial_id >= num_of_partial_segments) { 351 | return; 352 | } 353 | const int idx_begin = partial_segment_offset[partial_id]; 354 | const int idx_end = (partial_id == num_of_partial_segments - 1) ? 
numel : partial_segment_offset[partial_id + 1]; 355 | 356 | scalar_t grad_acc = 0; 357 | for (int idx = idx_begin; idx < idx_end; ++idx) { 358 | const int orig_hash_idx = orig_hash_idx_idx[idx]; // orig_idx in range [0, |indices| x embedding_dim) 359 | const int orig_cat_idx = orig_hash_idx / embedding_dim; // in range [0, |indices|) 360 | const int feature_idx = orig_hash_idx % embedding_dim; // in range [0, embedding_dim) 361 | const int bag_idx = offset2bag[orig_cat_idx]; 362 | grad_acc += output_grad[bag_idx][feature_idx]; 363 | } 364 | grad_weight_per_partial[partial_id] = grad_acc; 365 | 366 | } 367 | 368 | template 369 | __global__ void sum_and_scatter( 370 | torch::PackedTensorAccessor32 sorted_unique_weight_idx, 371 | torch::PackedTensorAccessor32 grad_weight_per_segment, 372 | torch::PackedTensorAccessor32 partical_per_segment_offset, 373 | int64_t num_segments, 374 | int64_t num_of_partial_segments, 375 | torch::PackedTensorAccessor32 weight_grad 376 | ) 377 | { 378 | const int gid = blockIdx.x * blockDim.x + threadIdx.x; 379 | if (gid >= num_segments) { 380 | return; 381 | } 382 | const int weight_idx = sorted_unique_weight_idx[gid]; 383 | 384 | const int idx_begin = partical_per_segment_offset[gid]; 385 | const int idx_end = (gid == num_segments - 1) ? num_of_partial_segments : partical_per_segment_offset[gid + 1]; 386 | scalar_t grad_acc = 0; 387 | for (int idx = idx_begin; idx < idx_end; ++idx) { 388 | grad_acc += grad_weight_per_segment[idx]; 389 | } 390 | weight_grad[weight_idx] = grad_acc; 391 | } 392 | 393 | std::tuple hashed_embedding_bag_cuda_forward( 394 | const torch::Tensor& hashed_weights, 395 | const torch::Tensor& indices, 396 | const torch::Tensor& offsets, 397 | const int64_t mode, 398 | const int64_t embedding_dim, 399 | const torch::Tensor& signature, 400 | const torch::Tensor& random_numbers, 401 | const int64_t hmode, 402 | const int64_t keymode, 403 | const int64_t key_bits, 404 | const int64_t keys_to_use, 405 | const int64_t uma_chunk_size) 406 | { 407 | int64_t numIndices = indices.size(0); 408 | int64_t numBags = offsets.size(0); 409 | 410 | int64_t hashedWeightSize = 0; 411 | if (keymode == KEYMODE_HASHWEIGHT) { 412 | hashedWeightSize = hashed_weights.size(0); 413 | } 414 | auto bag_size = at::empty(offsets.sizes(), indices.options()); 415 | auto offset2bag = 416 | at::empty({indices.size(0)}, indices.options()); 417 | auto hashed_index = at::empty({indices.size(0), embedding_dim}, indices.options()); 418 | auto output = at::empty({numBags, embedding_dim}, hashed_weights.options()); // this gets initialized on CUDA:0 even if hashed_weights is on CUDA:1 why?? 419 | torch::Tensor max_indices; 420 | if (mode == MODE_MAX) { 421 | max_indices = at::empty({numBags, embedding_dim}, indices.options()); 422 | } else { 423 | max_indices = at::empty({0, 0}, indices.options()); 424 | } 425 | 426 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(indices.device().index()); 427 | 428 | #ifdef __HIP_PLATFORM_HCC__ 429 | dim3 block = dim3(64, 4); 430 | #else 431 | dim3 block = dim3(32, 8); 432 | #endif 433 | int grid = 1024; // TODO 2: fix grid size as per size of the index. maybe have a max cap. 
But having 1024 direclty will be sub-optimial 434 | 435 | AT_DISPATCH_FLOATING_TYPES(hashed_weights.type(), "hashed_embedding_bag_cuda", ([&] { 436 | hashed_embedding_bag_update_output_kernel<<>>( 437 | indices.packed_accessor32(), 438 | offsets.packed_accessor32(), 439 | hashed_weights.packed_accessor32(), 440 | output.packed_accessor32(), 441 | offset2bag.packed_accessor32(), 442 | numIndices, 443 | numBags, 444 | embedding_dim, 445 | hashedWeightSize, 446 | mode, 447 | hashed_index.packed_accessor32(), 448 | bag_size.packed_accessor32(), 449 | max_indices.packed_accessor32(), 450 | signature.packed_accessor32(), 451 | random_numbers.data_ptr(), 452 | hmode, 453 | keymode, 454 | key_bits, 455 | keys_to_use, 456 | uma_chunk_size); 457 | })); 458 | //cudaDeviceSynchronize(); // TODO 1: remove this. this will wait for all sreams to synchronize. we dont want that. 459 | // instead use cudaStreamSynchronize 460 | 461 | return std::tuple( 462 | output, offset2bag, bag_size, max_indices, hashed_index); 463 | } 464 | 465 | torch::Tensor hashed_embedding_bag_sum_backward( 466 | const torch::Tensor& output_grad, 467 | const torch::Tensor& indices, 468 | const torch::Tensor& offsets, 469 | const torch::Tensor& offset2bag, 470 | const torch::Tensor& hash_index, 471 | 472 | int64_t num_weights, 473 | int64_t embedding_dim) 474 | { 475 | int64_t numIndices = indices.size(0); 476 | int64_t numBags = offsets.size(0); 477 | torch::Tensor weight_grad = torch::zeros({num_weights}, output_grad.options()); 478 | 479 | if (numIndices == 0) { 480 | return weight_grad; 481 | } 482 | 483 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(indices.device().index()); 484 | torch::Tensor flattened_hash_index = hash_index.flatten(); 485 | int64_t numel = flattened_hash_index.size(0); 486 | 487 | // hash_index is a |indices| x embedding_dim Tensor, contains the index in hashed weight for each input indices x embedding dim. 488 | // the hash_index is flattened, and then we want to sort it, we use orig_hash_idx_idx to keep track of its orignal indices. 489 | auto sorted_hash_idx = at::empty_like(flattened_hash_index, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 490 | auto orig_hash_idx_idx = at::empty_like(flattened_hash_index, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 491 | using device_ptr = thrust::device_ptr; 492 | { 493 | sorted_hash_idx.copy_(flattened_hash_index); 494 | 495 | 496 | auto count_iter = thrust::counting_iterator(0); 497 | auto orig_hash_idx_idx_data = device_ptr(orig_hash_idx_idx.data_ptr()); 498 | thrust::copy(count_iter, count_iter + numel, orig_hash_idx_idx_data); 499 | 500 | auto sorted_hash_idx_data = device_ptr(sorted_hash_idx.data_ptr()); 501 | thrust::sort_by_key( 502 | sorted_hash_idx_data, 503 | sorted_hash_idx_data + numel, 504 | orig_hash_idx_idx_data); 505 | } 506 | 507 | // There may be many duplicates in the hash_index, now it's sorted, we find the start index for each hash_index value. 508 | // then we can get the count for each hash_index_value. 
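// e.g. for sorted hash indices [3, 3, 7, 7, 7, 9], the unique_by_key_copy below yields
// sorted_unique_weight_idx = [3, 7, 9], segment_offsets = [0, 2, 5] and num_segments = 3.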
509 | auto segment_offsets = at::empty({numel}, orig_hash_idx_idx.options()); 510 | auto sorted_unique_weight_idx = at::empty_like(sorted_hash_idx, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 511 | int64_t num_segments; 512 | { 513 | auto sorted_hash_idx_data = device_ptr(sorted_hash_idx.data_ptr()); 514 | auto sorted_unique_weight_idx_data = device_ptr(sorted_unique_weight_idx.data_ptr()); 515 | auto iter_end_pair = thrust::unique_by_key_copy( 516 | sorted_hash_idx_data, 517 | sorted_hash_idx_data + numel, 518 | thrust::make_counting_iterator(0), 519 | sorted_unique_weight_idx_data, 520 | thrust::device_ptr(segment_offsets.data_ptr()) 521 | ); 522 | num_segments = thrust::get<0>(iter_end_pair) - sorted_unique_weight_idx_data; 523 | } 524 | 525 | // We split the segments up into sizes of `NROWS_PER_THREAD` 526 | // Compute the number partial-segments per segment (some partial-segments 527 | // may not be the full `NROWS_PER_THREAD` number of rows) 528 | auto partials_per_segment = at::empty({num_segments}, orig_hash_idx_idx.options()); 529 | { 530 | krn_partials_per_segment<<>> ( 531 | partials_per_segment.data_ptr(), 532 | segment_offsets.data_ptr(), 533 | num_segments, 534 | numel); 535 | } 536 | 537 | 538 | // In order to compute `partial_segment_offset`, which is the start index 539 | // of each partial-segment in `sorted_indices`, we need to compute the 540 | // start position of each _segment_ in `partial_segment_offset`. 541 | // Unit: index in `partial_segment_offset` 542 | auto partials_per_segment_offset = at::empty({num_segments}, orig_hash_idx_idx.options()); 543 | thrust::exclusive_scan( 544 | device_ptr(partials_per_segment.data_ptr()), 545 | device_ptr(partials_per_segment.data_ptr() + num_segments), 546 | device_ptr(partials_per_segment_offset.data_ptr()) 547 | ); 548 | 549 | // The total number of partial-segments is the sum of `partials_per_segment_offset` 550 | const int num_of_partial_segments = partials_per_segment[num_segments - 1].item() + 551 | partials_per_segment_offset[num_segments - 1].item(); 552 | 553 | // Now we can compute the start position of each partial-segment 554 | // Unit: index in `sorted_indices` and `orig_indices` 555 | auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_hash_idx_idx.options()); 556 | { 557 | krn_partial_segment_offset<<>> ( 558 | partial_segment_offset.data_ptr(), 559 | partials_per_segment.data_ptr(), 560 | partials_per_segment_offset.data_ptr(), 561 | segment_offsets.data_ptr(), 562 | num_segments); 563 | } 564 | auto grad_weight_per_segment = at::empty({num_of_partial_segments}, weight_grad.options()); 565 | 566 | const int block = NWEIGHT_PER_THREAD; 567 | const int grid = ceil_div(num_of_partial_segments, block); 568 | AT_DISPATCH_ALL_TYPES(weight_grad.scalar_type(), "hashed_embedding_bag_backward_cuda", ([&] { 569 | compute_grad_weight_bags<<>>( 570 | orig_hash_idx_idx.packed_accessor32(), 571 | output_grad.packed_accessor32(), 572 | offset2bag.packed_accessor32(), 573 | embedding_dim, 574 | numel, 575 | partial_segment_offset.packed_accessor32(), 576 | num_of_partial_segments, 577 | grad_weight_per_segment.packed_accessor32() 578 | ); 579 | const int grid2 = ceil_div(num_segments, block); 580 | sum_and_scatter<<>>( 581 | sorted_unique_weight_idx.packed_accessor32(), 582 | grad_weight_per_segment.packed_accessor32(), 583 | partials_per_segment_offset.packed_accessor32(), 584 | num_segments, 585 | num_of_partial_segments, 586 | weight_grad.packed_accessor32() 587 | ); 588 | })); 589 | 590 | 591 | return weight_grad; 
592 | } 593 | 594 | torch::Tensor hashed_embedding_bag_cuda_backward( 595 | const torch::Tensor& grad_, 596 | const torch::Tensor& indices, 597 | const torch::Tensor& offsets, 598 | const torch::Tensor& offset2bag, 599 | const torch::Tensor& bag_size_, 600 | const torch::Tensor& max_indices_, 601 | const torch::Tensor& hashed_index, 602 | int64_t num_weights, 603 | bool scale_grad_by_freq, 604 | int64_t mode, 605 | int64_t embedding_dim) 606 | { 607 | torch::Tensor grad = grad_.contiguous(); 608 | switch (mode) { 609 | case MODE_SUM: 610 | return hashed_embedding_bag_sum_backward( 611 | grad_, 612 | indices, 613 | offsets, 614 | offset2bag, 615 | hashed_index, 616 | num_weights, 617 | embedding_dim); 618 | case MODE_MEAN: 619 | case MODE_MAX: 620 | //return hashed_embedding_bag_cuda_max() 621 | default: 622 | return torch::Tensor(); 623 | } 624 | } 625 | 626 | // C++ interface 627 | 628 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 629 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 630 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 631 | 632 | std::tuple hashed_embedding_bag_forward( 633 | const torch::Tensor& hashed_weights, 634 | const torch::Tensor& indices, 635 | const torch::Tensor& offsets, 636 | //const bool scale_grad_by_freq, 637 | const int64_t mode, 638 | const int64_t embedding_dim, 639 | const torch::Tensor& signature, 640 | const torch::Tensor& random_numbers, 641 | const int64_t hmode, 642 | const int64_t keymode, 643 | const int64_t key_bits, 644 | const int64_t keys_to_use, 645 | const int64_t uma_chunk_size) 646 | { 647 | 648 | if (keymode == KEYMODE_HASHWEIGHT) { 649 | CHECK_INPUT(hashed_weights); 650 | } 651 | CHECK_INPUT(indices); 652 | CHECK_INPUT(offsets); 653 | if(hmode == HMODE_LMAHASH) { 654 | CHECK_INPUT(signature); 655 | } 656 | 657 | return hashed_embedding_bag_cuda_forward(hashed_weights, indices, offsets, mode, embedding_dim, signature, random_numbers, hmode, keymode, key_bits, keys_to_use, uma_chunk_size); 658 | } 659 | 660 | 661 | torch::Tensor hashed_embedding_bag_backward( 662 | const torch::Tensor& grad, 663 | const torch::Tensor& indices, 664 | const torch::Tensor& offsets, 665 | const torch::Tensor& offset2bag, 666 | const torch::Tensor& bag_size_, 667 | const torch::Tensor& max_indices_, 668 | const torch::Tensor& hashed_index_, 669 | int64_t num_weights, 670 | bool scale_grad_by_freq, 671 | int64_t mode, 672 | int64_t embedding_dim) 673 | { 674 | CHECK_CUDA(grad); 675 | CHECK_INPUT(indices); 676 | CHECK_INPUT(offsets); 677 | CHECK_INPUT(offset2bag); 678 | CHECK_INPUT(bag_size_); 679 | CHECK_INPUT(max_indices_); 680 | return hashed_embedding_bag_cuda_backward( 681 | grad, 682 | indices, 683 | offsets, 684 | offset2bag, 685 | bag_size_, 686 | max_indices_, 687 | hashed_index_, 688 | num_weights, 689 | scale_grad_by_freq, 690 | mode, 691 | embedding_dim 692 | ); 693 | } 694 | 695 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 696 | m.def("forward", &hashed_embedding_bag_forward, "hash embedding forward (CUDA)"); 697 | m.def("backward", &hashed_embedding_bag_backward, "hash embedding backward (CUDA)"); 698 | m.def("hash", &hash_func, "hash function"); 699 | m.def("lma_hash", &lma_hash_func, "lma hash function"); 700 | } 701 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, Extension 2 | from 
torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='hashed_embedding_bag', 6 | ext_modules=[CUDAExtension( 7 | 'hashed_embedding_bag', 8 | [#'hashed_embedding_bag1.cpp', 9 | 'hashed_embedding_bag_kernel.cu'])], 10 | py_modules=['hashedEmbeddingBag', 'hashedEmbeddingCPU'], 11 | cmdclass={'build_ext': BuildExtension} 12 | ) 13 | --------------------------------------------------------------------------------
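After installation, a quick smoke test (a sketch, assuming a CUDA device is visible) is to import both the compiled extension and the Python wrapper and run a tiny lookup:
```
import torch
import hashed_embedding_bag   # the compiled CUDA extension
import hashedEmbeddingBag     # the Python wrapper

E = hashedEmbeddingBag.HashedEmbeddingBag(1000, 16).cuda(0)  # default compression of 1/64
out = E(torch.arange(4).cuda(0))
print(out.shape)  # expected: torch.Size([4, 16])
```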