├── README.md
└── stnet.py

/README.md:
--------------------------------------------------------------------------------
# PyTorch code for "A Spatially Separable Attention Mechanism For Massive MIMO CSI Feedback"
(c) Sharan Mourya, email: sharanmourya7@gmail.com
## Introduction
This repository holds the PyTorch implementation of the original models described in the paper

[Sharan Mourya, Sai Dhiraj Amuru, "A Spatially Separable Attention Mechanism For Massive MIMO CSI Feedback"](https://arxiv.org/abs/2208.03369)

## Requirements
- Python >= 3.7
- [PyTorch >= 1.2](https://pytorch.org/get-started/locally/)
- [SciPy >= 1.8.0](https://scipy.org/install/)

## Steps to follow

#### 1) Download Dataset

For simulation purposes, we generate channel matrices from the [COST2100](https://ieeexplore.ieee.org/document/6393523) model. Chao-Kai Wen and Shi Jin's group provides a ready-made version of the COST2100 dataset on [Dropbox](https://www.dropbox.com/sh/edla5dodnn2ocwi/AADtPCCALXPOsOwYi_rjv3bda?dl=0).

#### 2) Organize Dataset
Once the dataset is downloaded, we recommend organizing the folders as follows:
```
├── STNet               # The cloned STNet repository
│   ├── stnet.py
├── data                # The data folder
│   ├── DATA_Htestin.mat
│   ├── ...
```
#### 3) Training STNet
First, choose the compression ratio 1/4, 1/8, 1/16, 1/32, or 1/64 by setting the variable **encoded_dim** to 512, 256, 128, 64, or 32 respectively.

Second, choose the scenario, "indoor" or "outdoor", by assigning it to the variable **envir**.

Finally, run the file **stnet.py** to begin training, as sketched below.
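For example, to train the indoor model at a 1/4 compression ratio, the two configuration variables near the top of **stnet.py** would read as follows (a sketch of the relevant lines; these happen to be the defaults shipped in this repository):
```
encoded_dim = 512   # 1/4 compression: 2*32*32 = 2048 real values -> 512
envir = 'indoor'    # or 'outdoor'
```
Training then starts with `python stnet.py`, run from inside the STNet folder so that the relative data paths (`../data/...`) resolve against the layout above.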
## Results
The Normalized Mean Square Error (NMSE) and the number of floating-point operations (FLOPs) achieved by STNet for different compression ratios under both scenarios are tabulated below. NMSE is reported in dB; at the end of **stnet.py** it is computed as 10·log10 of the mean of ‖H − Ĥ‖² / ‖H‖² over the test set.

S.No | Compression Ratio | Indoor NMSE (dB) | Outdoor NMSE (dB) | FLOPs
:--: | :--: | :--: | :--: | :--:
1 | 1/4 | -31.81 | -12.91 | 5.22M
2 | 1/8 | -21.28 | -8.53 | 4.38M
3 | 1/16 | -15.43 | -5.72 | 3.96M
4 | 1/32 | -9.42 | -3.51 | 3.75M
5 | 1/64 | -7.81 | -2.46 | 3.65M

--------------------------------------------------------------------------------
/stnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F  # required by hsigmoid below (missing from the original imports)
from torch.optim import Adam
import scipy.io as sio
import numpy as np
import math
import matplotlib.pyplot as plt
from collections import OrderedDict

img_height = 32
img_width = 32
img_channels = 2
img_total = img_height * img_width * img_channels
# compression ratio 1/4 -> 512, 1/8 -> 256, 1/16 -> 128, 1/32 -> 64, 1/64 -> 32
encoded_dim = 512
img_size = 32
num_heads = 4  # for multi-head attention
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
depth = 1  # number of STBs
qkv_bias = True
window = 8  # window size for LSA
envir = 'indoor'  # 'indoor' or 'outdoor'


class GroupAttention(nn.Module):
    """Windowed multi-head self-attention over non-overlapping spatial groups."""

    def __init__(self, num_heads=4, qkv_bias=False):
        super(GroupAttention, self).__init__()
        self.num_heads = num_heads
        head_dim = img_size // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(img_size, img_size * 3, bias=qkv_bias)
        self.proj = nn.Linear(img_size, img_size)
        self.ws = window

    def forward(self, x):
        B, C, H, W = x.shape
        h_group, w_group = H // self.ws, W // self.ws
        total_groups = h_group * w_group

        x = x.reshape(B, C, h_group, self.ws, W)
        qkv = self.qkv(x).reshape(B, C, total_groups, -1, 3, self.num_heads,
                                  self.ws // self.num_heads).permute(4, 0, 1, 2, 5, 3, 6)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, C)
        x = attn.transpose(2, 3).reshape(B, C, H, W)
        x = self.proj(x)
        return x


class GlobalAttention(nn.Module):
    """Multi-head self-attention against a spatially downsampled key/value map."""

    def __init__(self, num_heads=4, qkv_bias=False):
        super().__init__()
        self.dim = img_size
        self.num_heads = num_heads
        head_dim = self.dim // num_heads
        self.scale = head_dim ** -0.5
        self.q = nn.Linear(self.dim, self.dim, bias=qkv_bias)
        self.kv = nn.Linear(self.dim // window, self.dim // window * 2, bias=qkv_bias)
        self.proj = nn.Linear(self.dim, self.dim)
        self.sr = nn.Conv2d(2, 2, kernel_size=window, stride=window)  # spatial reduction
        self.norm = nn.LayerNorm(self.dim // window)

    def forward(self, x):
        B, C, H, W = x.shape
        q = self.q(x).reshape(B, C, -1, self.dim // window, self.dim // window).permute(0, 1, 3, 2, 4)
        x_ = self.sr(x).reshape(B, C, -1, self.dim // window, self.dim // window)
        x_ = self.norm(x_)
        kv = self.kv(x_).reshape(B, C, -1, 2, self.dim // window,
                                 self.dim // window).permute(3, 0, 1, 4, 2, 5)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, C, H, W)
        x = self.proj(x)
        return x
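
# Explanatory note (added comment, not in the original code): the two attention
# modules above factorize spatial self-attention. GroupAttention restricts
# multi-head attention to non-overlapping `window` x `window` groups (local
# spatial attention), while GlobalAttention attends from every position to a
# key/value map that self.sr has downsampled by a factor of `window` (global
# spatial attention). The WTL block below chains the two with pre-norm residual
# connections and MLPs, forming the spatially separable attention of the paper.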

class MLP(nn.Module):

    def __init__(self):
        super().__init__()
        self.cc1 = nn.Linear(img_size, img_size)
        self.cc2 = nn.Linear(img_size, img_size)
        self.act = nn.GELU()

    def forward(self, x):
        x = self.cc1(x)
        x = self.act(x)
        x = self.cc2(x)
        return x


class WTL(nn.Module):
    """One transformer layer: local attention, MLP, global attention, MLP."""

    def __init__(self, num_heads, qkv_bias):
        super().__init__()
        self.norm1 = nn.LayerNorm(img_size, eps=1e-6)
        self.attn1 = GroupAttention(num_heads=num_heads, qkv_bias=qkv_bias)
        self.attn2 = GlobalAttention(num_heads=num_heads, qkv_bias=qkv_bias)
        self.norm2 = nn.LayerNorm(img_size, eps=1e-6)
        self.norm3 = nn.LayerNorm(img_size, eps=1e-6)
        self.norm4 = nn.LayerNorm(img_size, eps=1e-6)
        self.mlp1 = MLP()
        self.mlp2 = MLP()

    def forward(self, x):
        x = x + self.attn1(self.norm1(x))
        x = x + self.mlp1(self.norm2(x))
        x = x + self.attn2(self.norm3(x))
        x = x + self.mlp2(self.norm4(x))
        return x


class ConvBN(nn.Sequential):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1):
        if not isinstance(kernel_size, int):
            padding = [(i - 1) // 2 for i in kernel_size]
        else:
            padding = (kernel_size - 1) // 2
        super(ConvBN, self).__init__(OrderedDict([
            ('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride,
                               padding=padding, groups=groups, bias=False)),
            ('bn', nn.BatchNorm2d(out_planes))
        ]))


class hsigmoid(nn.Module):
    # Note: defined here but never used in this script.
    def forward(self, x):
        return F.relu6(x + 3, inplace=True) / 6


class CRBlock(nn.Module):
    def __init__(self):
        super(CRBlock, self).__init__()
        self.path1 = nn.Sequential(OrderedDict([
            ('conv3x3', ConvBN(2, 7, 3)),
            ('relu1', nn.LeakyReLU(negative_slope=0.3, inplace=True)),
            ('conv1x9', ConvBN(7, 7, [1, 9])),
            ('relu2', nn.LeakyReLU(negative_slope=0.3, inplace=True)),
            ('conv9x1', ConvBN(7, 7, [9, 1])),
        ]))
        self.path2 = nn.Sequential(OrderedDict([
            ('conv1x5', ConvBN(2, 7, [1, 5])),
            ('relu', nn.LeakyReLU(negative_slope=0.3, inplace=True)),
            ('conv5x1', ConvBN(7, 7, [5, 1])),
        ]))
        self.conv1x1 = ConvBN(7 * 2, 2, 1)
        self.identity = nn.Identity()
        self.relu = nn.LeakyReLU(negative_slope=0.3, inplace=True)

    def forward(self, x):
        identity = self.identity(x)

        out1 = self.path1(x)
        out2 = self.path2(x)
        out = torch.cat((out1, out2), dim=1)
        out = self.relu(out)
        out = self.conv1x1(out)

        out = self.relu(out + identity)
        return out


class Encoder(nn.Module):
    def __init__(self, img_size=img_size, depth=depth, num_heads=num_heads, qkv_bias=qkv_bias):
        super().__init__()
        self.blocks = nn.ModuleList([
            WTL(num_heads=num_heads, qkv_bias=qkv_bias)
            for _ in range(depth)
        ])
        self.norm2 = nn.LayerNorm(img_size, eps=1e-6)
        self.norm3 = nn.LayerNorm(img_size, eps=1e-6)
        self.conv1 = nn.Conv2d(2, 16, kernel_size=1, stride=1)
        self.conv5 = nn.Conv2d(16, 2, kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(2, 2, kernel_size=4, stride=2, padding=1)
        self.convT = nn.ConvTranspose2d(2, 2, kernel_size=4, stride=2, padding=1)
        self.fc = nn.Linear(2 * img_size * img_size, encoded_dim)

    def forward(self, x):
        n_samples = x.shape[0]
        x = self.conv1(x)
        x = self.conv5(x)
        X = x  # skip connection around the transformer blocks

        for block in self.blocks:
            x = block(x)
        x = self.norm3(x)
        x = self.convT(x)
        x = X + self.conv4(x)
        x = self.norm2(x)
        x = x.reshape(n_samples, 2 * img_size * img_size)
        x = self.fc(x)
        return x
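
# Minimal shape sanity check (an added sketch, not part of the original
# training flow). Uncomment and run after the encoder/decoder are
# instantiated further below:
#
#     with torch.no_grad():
#         dummy = torch.randn(4, img_channels, img_height, img_width)
#         code = Encoder()(dummy)     # -> (4, encoded_dim)
#         recon = Decoder()(code)     # -> (4, 2, 32, 32)
#         assert recon.shape == dummy.shape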

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(encoded_dim, img_channels * img_size * img_size)  # unused; dense_layers below does the expansion
        self.act = nn.Sigmoid()
        self.conv5 = nn.Conv2d(2, 2, kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(2, 2, kernel_size=4, stride=2, padding=1)
        self.convT = nn.ConvTranspose2d(2, 2, kernel_size=4, stride=2, padding=1)
        self.blocks = nn.ModuleList([
            WTL(num_heads=num_heads, qkv_bias=qkv_bias)
            for _ in range(depth)
        ])
        self.norm2 = nn.LayerNorm(img_size, eps=1e-6)
        self.norm3 = nn.LayerNorm(img_size, eps=1e-6)

        self.dense_layers = nn.Sequential(
            nn.Linear(encoded_dim, img_total)
        )

        decoder = OrderedDict([
            ("conv5x5_bn", ConvBN(2, 2, 5)),
            ("relu", nn.LeakyReLU(negative_slope=0.3, inplace=True)),
            ("CRBlock1", CRBlock())
        ])
        self.decoder_feature = nn.Sequential(decoder)

    def forward(self, x):
        img = self.dense_layers(x)
        img = img.view(-1, img_channels, img_height, img_width)

        out = self.decoder_feature(img)
        x = self.conv5(img)

        for block in self.blocks:
            x = block(x + out)
        x = self.norm2(x)
        x = self.convT(x)
        x = self.conv4(x)

        for block in self.blocks:
            x = block(x + out)
        x = self.norm3(x)
        x = self.act(x)
        return x


encoder = Encoder()
encoder.to(device)
print(encoder)

decoder = Decoder()
decoder.to(device)
print(decoder)

print('Data loading begins.....')

if envir == 'indoor':
    mat = sio.loadmat('../data/DATA_Htrainin.mat')
    x_train = mat['HT']
    mat = sio.loadmat('../data/DATA_Hvalin.mat')
    x_val = mat['HT']
    mat = sio.loadmat('../data/DATA_Htestin.mat')
    x_test = mat['HT']
elif envir == 'outdoor':
    mat = sio.loadmat('../data/DATA_Htrainout.mat')
    x_train = mat['HT']
    mat = sio.loadmat('../data/DATA_Hvalout.mat')
    x_val = mat['HT']
    mat = sio.loadmat('../data/DATA_Htestout.mat')
    x_test = mat['HT']

x_train = torch.from_numpy(x_train)
x_test = torch.from_numpy(x_test)
x_train = x_train.reshape(len(x_train), img_channels, img_height, img_width)
x_test = x_test.reshape(len(x_test), img_channels, img_height, img_width)

print('Data loading done!')

x_train = x_train.to(device, dtype=torch.float)
x_test = x_test.to(device, dtype=torch.float)


def train_autoencoder(uncompressed_images, opt_enc, opt_dec):
    """One optimization step on a batch: encode, decode, backpropagate MSE."""
    opt_enc.zero_grad()
    opt_dec.zero_grad()

    compressed_data = encoder(uncompressed_images)
    reconstructed_images = decoder(compressed_data)

    criterion = nn.MSELoss()
    loss = criterion(uncompressed_images, reconstructed_images)
    loss.backward()
    opt_enc.step()
    opt_dec.step()
    return loss.item()
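
# Added note (not in the original script): training minimizes plain MSE between
# the normalized input and reconstructed channel matrices. The figure of merit
# reported at the end of this script is instead NMSE in dB over the test set,
#
#     NMSE_dB = 10 * log10( mean( ||H - H_hat||^2 / ||H||^2 ) ),
#
# where H is recovered as a complex matrix by removing the 0.5 offset from the
# real and imaginary channels (see the x_test_C / x_hat_C computation below).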

def fit(epochs, lr):
    losses_auto = []

    opt_enc = Adam(encoder.parameters(), lr, betas=(0.5, 0.999))
    opt_dec = Adam(decoder.parameters(), lr, betas=(0.5, 0.999))

    reps = int(len(x_train) / batch_size)

    for epoch in range(epochs):
        x_train_idx = torch.randperm(x_train.size()[0])  # reshuffle every epoch
        for i in range(reps):
            loss_auto = train_autoencoder(
                x_train[x_train_idx[i * batch_size:(i + 1) * batch_size]], opt_enc, opt_dec)
            if i % 600 == 0:
                print('epoch', epoch + 1, '/', epochs, 'batch:', i + 1, '/', reps,
                      "loss_auto: {:.12f}".format(loss_auto))
            losses_auto.append(loss_auto)

    return losses_auto


epochs = 1000
lr = 0.001
batch_size = 200
print('training starts.....')
losses_auto = fit(epochs, lr)

plt.figure(figsize=(10, 5))
plt.title("Autoencoder Loss During Training")
plt.plot(losses_auto, label="AE")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

del losses_auto
del x_train

n = 10  # number of test samples to visualize
with torch.no_grad():
    temp = encoder(x_test[0:1000, :, :, :])
    x_hat = decoder(temp)

x_test_in = x_test[0:1000, :, :, :].to('cpu').numpy()
x_hat_in = x_hat.to('cpu').numpy()

plt.figure(figsize=(20, 4))
for i in range(n):
    # display the original channel magnitude
    ax = plt.subplot(2, n, i + 1)
    x_testplo = abs(x_test_in[i, 0, :, :] - 0.5 + 1j * (x_test_in[i, 1, :, :] - 0.5))
    plt.imshow(np.max(x_testplo) - x_testplo.T)
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.invert_yaxis()
    # display the reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    decoded_imgsplo = abs(x_hat_in[i, 0, :, :] - 0.5 + 1j * (x_hat_in[i, 1, :, :] - 0.5))
    plt.imshow(np.max(decoded_imgsplo) - decoded_imgsplo.T)
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax.invert_yaxis()
plt.show()

x_test_real = np.reshape(x_test_in[:, 0, :, :], (len(x_test_in), -1))
x_test_imag = np.reshape(x_test_in[:, 1, :, :], (len(x_test_in), -1))
x_test_C = x_test_real - 0.5 + 1j * (x_test_imag - 0.5)
x_hat_real = np.reshape(x_hat_in[:, 0, :, :], (len(x_hat_in), -1))
x_hat_imag = np.reshape(x_hat_in[:, 1, :, :], (len(x_hat_in), -1))
x_hat_C = x_hat_real - 0.5 + 1j * (x_hat_imag - 0.5)

power = np.sum(abs(x_test_C) ** 2, axis=1)
mse = np.sum(abs(x_test_C - x_hat_C) ** 2, axis=1)

print("NMSE is ", 10 * math.log10(np.mean(mse / power)))
--------------------------------------------------------------------------------