├── .gitignore
├── LICENSE
├── README.md
├── dataloader
│   ├── __init__.py
│   └── cost2100.py
├── env.yaml
├── main.py
├── models
│   ├── TransNet.py
│   └── __init__.py
├── scripts.sh
└── utils
    ├── __init__.py
    ├── init.py
    ├── logger.py
    ├── parser.py
    ├── scheduler.py
    ├── solver.py
    └── statics.py

/.gitignore:
--------------------------------------------------------------------------------
1 | ## General
2 | 
3 | # Compiled Object files
4 | models/__pycache__/
5 | __pycache__/
6 | *.pyc
7 | 
8 | # Compiled Dynamic libraries
9 | *.so
10 | *.dylib
11 | 
12 | # Compiled Static libraries
13 | *.lai
14 | *.la
15 | *.a
16 | 
17 | # IPython notebook checkpoints
18 | .ipynb_checkpoints
19 | 
20 | # Editor temporaries
21 | *.swp
22 | *~
23 | 
24 | # Sublime Text settings
25 | *.sublime-workspace
26 | *.sublime-project
27 | 
28 | # Eclipse Project settings
29 | *.*project
30 | .settings
31 | 
32 | # QtCreator files
33 | *.user
34 | 
35 | # PyCharm files
36 | .idea
37 | 
38 | # OSX dir files
39 | .DS_Store
40 | 
41 | # Data and saved models
42 | data/*
43 | snapshots/
44 | checkpoints/
45 | *.pth
46 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Kylin Lu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Overview
2 | 
3 | This is the PyTorch implementation of the paper ["TransNet: Full Attention Network for CSI Feedback in FDD Massive MIMO System"](https://ieeexplore.ieee.org/document/9705497). You can cite our paper with:
4 | 
5 | ```
6 | @ARTICLE{9705497,
7 |   author={Cui, Yaodong and Guo, Aihuang and Song, Chunlin},
8 |   journal={IEEE Wireless Communications Letters},
9 |   title={TransNet: Full Attention Network for CSI Feedback in FDD Massive MIMO System},
10 |   year={2022},
11 |   volume={11},
12 |   number={5},
13 |   pages={903-907},
14 |   doi={10.1109/LWC.2022.3149416}}
15 | ```
16 | or
17 | ```
18 | Y. Cui, A. Guo and C. Song, "TransNet: Full Attention Network for CSI Feedback in FDD Massive MIMO System," in IEEE Wireless Communications Letters, vol. 11, no. 5, pp. 903-907, May 2022, doi: 10.1109/LWC.2022.3149416.
19 | ```
20 | ## Requirements
21 | 
22 | We provide an env.yaml in our project, so you can simply run
23 | ```
24 | conda env create -f env.yaml
25 | ```
26 | to get a usable environment. Alternatively, you can build your own environment manually; in that case, make sure the following main requirements are installed.
27 | 
28 | 
29 | - Python >= 3.7
30 | - scipy
31 | - [1.2 <= PyTorch <= 1.6](https://pytorch.org/get-started/locally/)
32 | - [thop==0.0.31-2005241907](https://github.com/Lyken17/pytorch-OpCounter) Note that the latest version leads to a bug.
33 | - [torchviz](https://github.com/szagoruyko/pytorchviz)
34 | - [tensorboardX](https://github.com/lanpa/tensorboardX)
35 | 
36 | ## Project Preparation
37 | 
38 | #### A. Data Preparation
39 | 
40 | The channel state information (CSI) matrices are generated with the [COST2100](https://ieeexplore.ieee.org/document/6393523) model. Chao-Kai Wen and Shi Jin's group provides a pre-processed version of the COST2100 dataset in [Google Drive](https://drive.google.com/drive/folders/1_lAMLk_5k1Z8zJQlTr5NRnSD6ACaNRtj?usp=sharing), which is easier to use for the CSI feedback task; you can also download it from [Baidu Netdisk](https://pan.baidu.com/s/1Ggr6gnsXNwzD4ULbwqCmjA).
41 | 
42 | You can generate your own dataset according to the [open source library of COST2100](https://github.com/cost2100/cost2100) as well. The details of the data pre-processing can be found in our paper.
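
After downloading, it is worth sanity-checking the `.mat` files before training. The snippet below is illustrative (adjust the path to your own layout); it loads the indoor training split the same way `dataloader/cost2100.py` does, where each row of the `'HT'` matrix is one truncated angular-delay CSI sample that reshapes to 2 x 32 x 32 (real/imaginary x Nt x Nc):

``` python
import scipy.io as sio

# adjust to wherever you placed the dataset
data = sio.loadmat('/home/COST2100/DATA_Htrainin.mat')['HT']
print(data.shape)                          # (N, 2048)
print(data.reshape(-1, 2, 32, 32).shape)   # (N, 2, 32, 32), the model input shape
```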
92 | 93 | Results of 400 epochs 94 | Scenario | Compression Ratio | NMSE | Flops 95 | :--: | :--: | :--: | :--: 96 | indoor | 1/4 | -29.22 | 35.72M 97 | indoor | 1/8 | -21.62 | 34.70M 98 | indoor | 1/16 | -14.98 | 34.14M 99 | indoor | 1/32 | -9.83 | 33.88M 100 | indoor | 1/64 | -5.77 | 33.75M 101 | outdoor | 1/4 | -13.99 | 35.72M 102 | outdoor | 1/8 | -9.57 | 34.70M 103 | outdoor | 1/16 | -6.90 | 34.14M 104 | outdoor | 1/32 | -3.77 | 33.88M 105 | outdoor | 1/64 | -2.20 | 33.75M 106 | 107 | Results of 1000 epochs 108 | Scenario | Compression Ratio | NMSE | Flops 109 | :--: | :--: | :--: | :--: 110 | indoor | 1/4 | -32.38 | 35.72M 111 | indoor | 1/8 | -22.91 | 34.70M 112 | indoor | 1/16 | -15.00 | 34.14M 113 | indoor | 1/32 | -10.49 | 33.88M 114 | indoor | 1/64 | -6.08 | 33.75M 115 | outdoor | 1/4 | -14.86 | 35.72M 116 | outdoor | 1/8 | -9.99 | 34.70M 117 | outdoor | 1/16 | -7.82 | 34.14M 118 | outdoor | 1/32 | -4.13 | 33.88M 119 | outdoor | 1/64 | -2.62 | 33.75M 120 | 121 | **To reproduce all these results, simplely add `--evaluate` to `scripts.sh` and pick the corresponding pre-trained model with `--pretrained`.** An example is shown as follows. 122 | 123 | ``` bash 124 | python /home/TransNet/main.py \ 125 | --data-dir '/home/COST2100' \ 126 | --scenario 'in' \ 127 | --pretrained './checkpoints/4_in.pth' \ 128 | --evaluate \ 129 | --batch-size 200 \ 130 | --workers 0 \ 131 | --cr 4\ # Note that cr should be same as checkpoints 132 | --cpu \ 133 | 2>&1 | tee test_log.out 134 | 135 | ``` 136 | 137 | 138 | 139 | 140 | ## Acknowledgment 141 | 142 | Thank Chao-Kai Wen and Shi Jin group again for providing the pre-processed COST2100 dataset, you can find their related work named CsiNet in [Github-Python_CsiNet](https://github.com/sydney222/Python_CsiNet) 143 | 144 | 145 | Thanks two open source works, CRNet and CLNet, that build on work above and advance the CSI feedback problem in DL, you can find their related work in [Github-Python-PyTorch CRNet](https://github.com/Kylin9511/CRNet) and [Github-Python-PyTorch CLNet](https://github.com/SIJIEJI/CLNet) 146 | 147 | Thanks the Github project members for the open source [Transformer tutorial](https://github.com/datawhalechina/Learn-NLP-with-Transformers), our base model for TransNet is based on their work. 
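
Besides the console log, the training pipeline also records tensorboardX scalars (training loss, test loss/NMSE/rho, and the learning rate) under a `data_vision/` folder in the working directory. If you have TensorBoard installed (it is not included in env.yaml), you can watch the curves live:

``` bash
tensorboard --logdir data_vision
```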
89 | ## Results and Reproduction
90 | 
91 | The main results reported in our paper are presented below. All of the listed results can be found in Table 1 of our paper. They are achieved by training TransNet under our two training schemes (a constant learning rate of 1e-4, for 400 or 1000 epochs).
92 | 
93 | Results of 400 epochs
94 | Scenario | Compression Ratio | NMSE (dB) | FLOPs
95 | :--: | :--: | :--: | :--:
96 | indoor | 1/4 | -29.22 | 35.72M
97 | indoor | 1/8 | -21.62 | 34.70M
98 | indoor | 1/16 | -14.98 | 34.14M
99 | indoor | 1/32 | -9.83 | 33.88M
100 | indoor | 1/64 | -5.77 | 33.75M
101 | outdoor | 1/4 | -13.99 | 35.72M
102 | outdoor | 1/8 | -9.57 | 34.70M
103 | outdoor | 1/16 | -6.90 | 34.14M
104 | outdoor | 1/32 | -3.77 | 33.88M
105 | outdoor | 1/64 | -2.20 | 33.75M
106 | 
107 | Results of 1000 epochs
108 | Scenario | Compression Ratio | NMSE (dB) | FLOPs
109 | :--: | :--: | :--: | :--:
110 | indoor | 1/4 | -32.38 | 35.72M
111 | indoor | 1/8 | -22.91 | 34.70M
112 | indoor | 1/16 | -15.00 | 34.14M
113 | indoor | 1/32 | -10.49 | 33.88M
114 | indoor | 1/64 | -6.08 | 33.75M
115 | outdoor | 1/4 | -14.86 | 35.72M
116 | outdoor | 1/8 | -9.99 | 34.70M
117 | outdoor | 1/16 | -7.82 | 34.14M
118 | outdoor | 1/32 | -4.13 | 33.88M
119 | outdoor | 1/64 | -2.62 | 33.75M
120 | 
121 | **To reproduce all these results, simply add `--evaluate` to `scripts.sh` and pick the corresponding pre-trained model with `--pretrained`; note that `--cr` must match the checkpoint you load.** An example is shown below.
122 | 
123 | ``` bash
124 | python /home/TransNet/main.py \
125 |   --data-dir '/home/COST2100' \
126 |   --scenario 'in' \
127 |   --pretrained './checkpoints/4_in.pth' \
128 |   --evaluate \
129 |   --batch-size 200 \
130 |   --workers 0 \
131 |   --cr 4 \
132 |   --cpu \
133 |   2>&1 | tee test_log.out
134 | 
135 | ```
136 | 
137 | 
138 | 
139 | 
140 | ## Acknowledgment
141 | 
142 | Thanks again to Chao-Kai Wen and Shi Jin's group for providing the pre-processed COST2100 dataset; you can find their related work, named CsiNet, at [Github-Python_CsiNet](https://github.com/sydney222/Python_CsiNet).
143 | 
144 | 
145 | Thanks to two open-source works, CRNet and CLNet, which build on the work above and advance DL-based CSI feedback; you can find them at [Github-Python-PyTorch CRNet](https://github.com/Kylin9511/CRNet) and [Github-Python-PyTorch CLNet](https://github.com/SIJIEJI/CLNet).
146 | 
147 | Thanks to the GitHub project members for the open-source [Transformer tutorial](https://github.com/datawhalechina/Learn-NLP-with-Transformers); our base model for TransNet builds on their work.
148 | 
149 | 
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | from .cost2100 import Cost2100DataLoader
2 | 
--------------------------------------------------------------------------------
/dataloader/cost2100.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import scipy.io as sio
4 | 
5 | import torch
6 | from torch.utils.data import DataLoader, TensorDataset
7 | 
8 | __all__ = ['Cost2100DataLoader', 'PreFetcher']
9 | 
10 | 
11 | class PreFetcher:
12 |     r""" Data pre-fetcher to accelerate the data loading
13 |     """
14 | 
15 |     def __init__(self, loader):
16 |         self.ori_loader = loader
17 |         self.len = len(loader)
18 |         self.stream = torch.cuda.Stream()
19 |         self.next_input = None
20 | 
21 |     def preload(self):
22 |         try:
23 |             self.next_input = next(self.loader)
24 |         except StopIteration:
25 |             self.next_input = None
26 |             return
27 | 
28 |         with torch.cuda.stream(self.stream):
29 |             for idx, tensor in enumerate(self.next_input):
30 |                 self.next_input[idx] = tensor.cuda(non_blocking=True)
31 | 
32 |     def __len__(self):
33 |         return self.len
34 | 
35 |     def __iter__(self):
36 |         self.loader = iter(self.ori_loader)
37 |         self.preload()
38 |         return self
39 | 
40 |     def __next__(self):
41 |         torch.cuda.current_stream().wait_stream(self.stream)
42 |         batch = self.next_input
43 |         if batch is None:
44 |             raise StopIteration
45 |         for tensor in batch:
46 |             tensor.record_stream(torch.cuda.current_stream())
47 |         self.preload()
48 |         return batch
49 | 
50 | 
51 | class Cost2100DataLoader(object):
52 |     r""" PyTorch DataLoader for the COST2100 dataset.
53 |     """
54 | 
55 |     def __init__(self, root, batch_size, num_workers, pin_memory, scenario):
56 |         assert os.path.isdir(root)
57 |         assert scenario in {"in", "out"}
58 |         self.batch_size = batch_size
59 |         self.num_workers = num_workers
60 |         self.pin_memory = pin_memory
61 | 
62 |         dir_train = os.path.join(root, f"DATA_Htrain{scenario}.mat")
63 |         dir_val = os.path.join(root, f"DATA_Hval{scenario}.mat")
64 |         dir_test = os.path.join(root, f"DATA_Htest{scenario}.mat")
65 |         dir_raw = os.path.join(root, f"DATA_HtestF{scenario}_all.mat")
66 |         channel, nt, nc, nc_expand = 2, 32, 32, 125
67 | 
68 |         # Training data loading
69 |         data_train = sio.loadmat(dir_train)['HT']
70 |         data_train = torch.tensor(data_train, dtype=torch.float32).view(
71 |             data_train.shape[0], channel, nt, nc)
72 |         self.train_dataset = TensorDataset(data_train)
73 | 
74 |         # Validation data loading
75 |         data_val = sio.loadmat(dir_val)['HT']
76 |         data_val = torch.tensor(data_val, dtype=torch.float32).view(
77 |             data_val.shape[0], channel, nt, nc)
78 |         self.val_dataset = TensorDataset(data_val)
79 | 
80 |         # Test data loading, including the sparse data and the raw data
81 |         data_test = sio.loadmat(dir_test)['HT']
82 |         data_test = torch.tensor(data_test, dtype=torch.float32).view(
83 |             data_test.shape[0], channel, nt, nc)
84 | 
85 |         raw_test = sio.loadmat(dir_raw)['HF_all']
86 |         real = torch.tensor(np.real(raw_test), dtype=torch.float32)
87 |         imag = torch.tensor(np.imag(raw_test), dtype=torch.float32)
88 |         raw_test = torch.cat((real.view(raw_test.shape[0], nt, nc_expand, 1),
89 |                               imag.view(raw_test.shape[0], nt, nc_expand, 1)), dim=3)
90 |         self.test_dataset = TensorDataset(data_test, raw_test)
91 | 
92 |     def __call__(self):
93 |         train_loader = DataLoader(self.train_dataset,
94 |                                   batch_size=self.batch_size,
95 | 
num_workers=self.num_workers, 96 | pin_memory=self.pin_memory, 97 | shuffle=True) 98 | val_loader = DataLoader(self.val_dataset, 99 | batch_size=self.batch_size, 100 | num_workers=self.num_workers, 101 | pin_memory=self.pin_memory, 102 | shuffle=False) 103 | test_loader = DataLoader(self.test_dataset, 104 | batch_size=self.batch_size, 105 | num_workers=self.num_workers, 106 | pin_memory=self.pin_memory, 107 | shuffle=False) 108 | 109 | # Accelerate CUDA data loading with pre-fetcher if GPU is used. 110 | if self.pin_memory is True: 111 | train_loader = PreFetcher(train_loader) 112 | val_loader = PreFetcher(val_loader) 113 | test_loader = PreFetcher(test_loader) 114 | 115 | return train_loader, val_loader, test_loader 116 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: transnet 2 | channels: 3 | - pytorch 4 | - https://mirrors.ustc.edu.cn/anaconda/pkgs/main 5 | - https://mirrors.ustc.edu.cn/anaconda/pkgs/main/ 6 | - https://mirrors.ustc.edu.cn/anaconda/pkgs/free/ 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=5.1=1_gnu 11 | - blas=1.0=mkl 12 | - ca-certificates=2023.01.10=h06a4308_0 13 | - certifi=2022.12.7=py38h06a4308_0 14 | - cudatoolkit=10.1.243=h6bb024c_0 15 | - freetype=2.12.1=h4a9f257_0 16 | - giflib=5.2.1=h5eee18b_1 17 | - intel-openmp=2021.4.0=h06a4308_3561 18 | - jpeg=9e=h7f8727e_0 19 | - lcms2=2.12=h3be6417_0 20 | - ld_impl_linux-64=2.38=h1181459_1 21 | - lerc=3.0=h295c915_0 22 | - libdeflate=1.8=h7f8727e_5 23 | - libffi=3.4.2=h6a678d5_6 24 | - libgcc-ng=11.2.0=h1234567_1 25 | - libgomp=11.2.0=h1234567_1 26 | - libpng=1.6.37=hbc83047_0 27 | - libstdcxx-ng=11.2.0=h1234567_1 28 | - libtiff=4.5.0=h6a678d5_1 29 | - libwebp=1.2.4=h11a3e52_0 30 | - libwebp-base=1.2.4=h5eee18b_0 31 | - lz4-c=1.9.4=h6a678d5_0 32 | - mkl=2021.4.0=h06a4308_640 33 | - mkl-service=2.4.0=py38h7f8727e_0 34 | - mkl_fft=1.3.1=py38hd3c417c_0 35 | - mkl_random=1.2.2=py38h51133e4_0 36 | - ncurses=6.4=h6a678d5_0 37 | - ninja=1.10.2=h06a4308_5 38 | - ninja-base=1.10.2=hd09550d_5 39 | - numpy=1.23.5=py38h14f4228_0 40 | - numpy-base=1.23.5=py38h31eccc5_0 41 | - openssl=1.1.1s=h7f8727e_0 42 | - pillow=9.3.0=py38h6a678d5_2 43 | - pip=22.3.1=py38h06a4308_0 44 | - python=3.8.16=h7a1cb2a_2 45 | - pytorch=1.6.0=py3.8_cuda10.1.243_cudnn7.6.3_0 46 | - readline=8.2=h5eee18b_0 47 | - setuptools=65.6.3=py38h06a4308_0 48 | - six=1.16.0=pyhd3eb1b0_1 49 | - sqlite=3.40.1=h5082296_0 50 | - tk=8.6.12=h1ccaba5_0 51 | - torchvision=0.7.0=py38_cu101 52 | - wheel=0.37.1=pyhd3eb1b0_0 53 | - xz=5.2.10=h5eee18b_1 54 | - zlib=1.2.13=h5eee18b_0 55 | - zstd=1.5.2=ha4553b6_0 56 | - pip: 57 | - future==0.18.2 58 | - protobuf==3.20.1 59 | - python-graphviz==0.20.1 60 | - scipy==1.9.3 61 | - tensorboardx==2.5.1 62 | - thop==0.0.31-2005241907 63 | - torchviz==0.0.2 64 | prefix: /root/miniconda3/envs/transnet 65 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils.parser import args 4 | from utils import logger, Trainer, Tester 5 | from utils import init_device, init_model, FakeLR, WarmUpCosineAnnealingLR 6 | from dataloader import Cost2100DataLoader 7 | from tensorboardX import SummaryWriter 8 | from torchviz import make_dot 9 | 10 | def main(): 11 | logger.info('=> PyTorch Version: 
{}'.format(torch.__version__)) 12 | # Environment initialization 13 | device, pin_memory = init_device(args.seed, args.cpu, args.gpu, args.cpu_affinity) 14 | 15 | # Create the data loader 16 | 17 | train_loader, val_loader, test_loader = Cost2100DataLoader( 18 | root=args.data_dir, 19 | batch_size=args.batch_size, 20 | num_workers=args.workers, 21 | pin_memory=pin_memory, 22 | scenario=args.scenario)() 23 | 24 | # Define model 25 | 26 | model = init_model(args) 27 | model.to(device) 28 | 29 | # Define loss function 30 | criterion = nn.MSELoss().to(device) 31 | 32 | # Inference mode 33 | if args.evaluate: 34 | Tester(model, device, criterion)(test_loader) 35 | return 36 | 37 | # Define optimizer and scheduler 38 | 39 | lr_init = 1e-4 if args.scheduler == 'const' else 2e-4 40 | optimizer = torch.optim.Adam(model.parameters(), lr_init) 41 | 42 | if args.scheduler == 'const': 43 | scheduler = FakeLR(optimizer=optimizer) 44 | 45 | else: 46 | scheduler = WarmUpCosineAnnealingLR(optimizer=optimizer, 47 | T_max=args.epochs * len(train_loader), 48 | T_warmup=30 * len(train_loader), 49 | eta_min=5e-5) 50 | 51 | # Define the training pipeline 52 | 53 | trainer = Trainer(model=model, 54 | device=device, 55 | optimizer=optimizer, 56 | criterion=criterion, 57 | scheduler=scheduler, 58 | resume=args.resume) 59 | 60 | # Start training 61 | trainer.loop(args.epochs, train_loader, val_loader, test_loader) 62 | 63 | # Final testing 64 | loss, rho, nmse = Tester(model, device, criterion)(test_loader) 65 | print(f"\n=! Final test loss: {loss:.3e}" 66 | f"\n test rho: {rho:.3e}" 67 | f"\n test NMSE: {nmse:.3e}\n") 68 | 69 | # Create images for loss, rho ,nmse 70 | 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /models/TransNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | from torch.nn.init import xavier_uniform_ 5 | from torch.nn.init import constant_ 6 | from torch.nn.init import xavier_normal_ 7 | import torch.nn.functional as F 8 | from typing import Optional, Tuple, Any 9 | from typing import List, Optional, Tuple 10 | from utils import logger 11 | import math 12 | import warnings 13 | 14 | __all__ = ["transnet"] 15 | 16 | Tensor = torch.Tensor 17 | 18 | 19 | def scale_dot_attention( 20 | q:Tensor, 21 | k:Tensor, 22 | v:Tensor, 23 | dropout_p:float = 0.0, 24 | attn_mask : Optional[Tensor] = None, 25 | )-> Tuple[Tensor,Tensor]: 26 | 27 | _,_,E = q.shape 28 | q = q / math.sqrt(E) 29 | attn = torch.bmm(q,k.transpose(-2,-1)) 30 | if attn_mask is not None: 31 | attn = attn + attn_mask 32 | attn = F.softmax(attn,dim =-1) 33 | if dropout_p: 34 | attn = F.dropout(attn,p = dropout_p) 35 | out = torch.bmm(attn,v) 36 | 37 | return out,attn 38 | 39 | 40 | def multi_head_attention_forward( 41 | query: Tensor, 42 | key: Tensor, 43 | value: Tensor, 44 | num_heads: int, 45 | in_proj_weight: Tensor, 46 | in_proj_bias: Optional[Tensor], 47 | dropout_p: float, 48 | out_proj_weight: Tensor, 49 | out_proj_bias: Optional[Tensor], 50 | training: bool = True, 51 | key_padding_mask: Optional[Tensor] = None, 52 | need_weights: bool = True, 53 | attn_mask: Optional[Tensor] = None, 54 | use_separate_proj_weight=None, 55 | q_proj_weight: Optional[Tensor] = None, 56 | k_proj_weight: Optional[Tensor] = None, 57 | v_proj_weight: Optional[Tensor] = None, 58 | ) -> Tuple[Tensor, Optional[Tensor]]: 59 | 60 | tgt_len, 
bsz, embed_dim = query.shape 61 | src_len, _, _ = key.shape 62 | head_dim = embed_dim // num_heads 63 | q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) 64 | 65 | if attn_mask is not None: 66 | if attn_mask.dtype == torch.uint8: 67 | warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") 68 | attn_mask = attn_mask.to(torch.bool) 69 | else: 70 | assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ 71 | f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}" 72 | 73 | if attn_mask.dim() == 2: 74 | correct_2d_size = (tgt_len, src_len) 75 | if attn_mask.shape != correct_2d_size: 76 | raise RuntimeError( 77 | f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") 78 | attn_mask = attn_mask.unsqueeze(0) 79 | elif attn_mask.dim() == 3: 80 | correct_3d_size = (bsz * num_heads, tgt_len, src_len) 81 | if attn_mask.shape != correct_3d_size: 82 | raise RuntimeError( 83 | f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") 84 | else: 85 | raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") 86 | 87 | if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: 88 | warnings.warn( 89 | "Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") 90 | key_padding_mask = key_padding_mask.to(torch.bool) 91 | 92 | 93 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 94 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 95 | v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 96 | if key_padding_mask is not None: 97 | assert key_padding_mask.shape == (bsz, src_len), \ 98 | f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" 99 | key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). 
\
100 |             expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
101 |         if attn_mask is None:
102 |             attn_mask = key_padding_mask
103 |         elif attn_mask.dtype == torch.bool:
104 |             attn_mask = attn_mask.logical_or(key_padding_mask)
105 |         else:
106 |             attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf"))
107 | 
108 |     if attn_mask is not None and attn_mask.dtype == torch.bool:
109 |         new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float)
110 |         new_attn_mask.masked_fill_(attn_mask, float("-inf"))
111 |         attn_mask = new_attn_mask
112 | 
113 | 
114 |     if not training:
115 |         dropout_p = 0.0
116 |     attn_output, attn_output_weights = scale_dot_attention(q, k, v, dropout_p=dropout_p, attn_mask=attn_mask)  # keyword arguments: scale_dot_attention takes dropout_p before attn_mask
117 |     attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
118 |     attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias)
119 |     if need_weights:
120 |         # average attention weights over heads
121 |         attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
122 |         return attn_output, attn_output_weights.sum(dim=1) / num_heads
123 |     else:
124 |         return attn_output, None
125 | 
126 | 
127 | def _in_projection_packed(
128 |     q: Tensor,
129 |     k: Tensor,
130 |     v: Tensor,
131 |     w: Tensor,
132 |     b: Optional[Tensor] = None,
133 | ) -> List[Tensor]:
134 |     E = q.size(-1)
135 |     if k is v:
136 |         if q is k:
137 |             return F.linear(q, w, b).chunk(3, dim=-1)
138 |         else:
139 |             w_q, w_kv = w.split([E, E * 2])
140 |             if b is None:
141 |                 b_q = b_kv = None
142 |             else:
143 |                 b_q, b_kv = b.split([E, E * 2])
144 |             return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
145 |     else:
146 |         w_q, w_k, w_v = w.chunk(3)
147 |         if b is None:
148 |             b_q = b_k = b_v = None
149 |         else:
150 |             b_q, b_k, b_v = b.chunk(3)
151 |         return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
152 | 
153 | 
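# Illustrative sanity check for the helpers above (an editorial sketch, not part of
# the original file; the shapes below are only an example):
#
#     x = torch.randn(10, 2, 64)                   # (tgt_len, bsz, embed_dim)
#     w = torch.randn(3 * 64, 64)                  # packed q/k/v projection weight
#     q, k, v = _in_projection_packed(x, x, x, w)  # each of shape (10, 2, 64)
#
#     # scale_dot_attention consumes batch-first 3-D tensors via torch.bmm,
#     # and each softmax row of the returned attention sums to one:
#     out, attn = scale_dot_attention(q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1))
#     assert out.shape == (2, 10, 64) and attn.shape == (2, 10, 10)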
154 | 
155 | class MultiheadAttention(nn.Module):
156 | 
157 |     def __init__(self, embed_dim, num_heads, dropout=0., bias=True,
158 |                  kdim=None, vdim=None, batch_first=False) -> None:
159 |         # factory_kwargs = {'device': device, 'dtype': dtype}
160 |         super(MultiheadAttention, self).__init__()
161 |         self.embed_dim = embed_dim
162 |         self.kdim = kdim if kdim is not None else embed_dim
163 |         self.vdim = vdim if vdim is not None else embed_dim
164 |         self._qkv_same_embed_dim = self.kdim == self.embed_dim and self.vdim == self.embed_dim
165 | 
166 |         self.num_heads = num_heads
167 |         self.dropout = dropout
168 |         self.batch_first = batch_first
169 |         self.head_dim = embed_dim // num_heads
170 |         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
171 | 
172 |         if self._qkv_same_embed_dim is False:
173 |             self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim)))
174 |             self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim)))
175 |             self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim)))
176 |             self.register_parameter('in_proj_weight', None)
177 |         else:
178 |             self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim)))
179 |             self.register_parameter('q_proj_weight', None)
180 |             self.register_parameter('k_proj_weight', None)
181 |             self.register_parameter('v_proj_weight', None)
182 | 
183 |         if bias:
184 |             self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
185 |         else:
186 |             self.register_parameter('in_proj_bias', None)
187 |         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
188 | 
189 |         self._reset_parameters()
190 | 
191 |     def _reset_parameters(self):
192 |         if self._qkv_same_embed_dim:
193 |             xavier_uniform_(self.in_proj_weight)
194 |         else:
195 |             xavier_uniform_(self.q_proj_weight)
196 |             xavier_uniform_(self.k_proj_weight)
197 |             xavier_uniform_(self.v_proj_weight)
198 | 
199 |         if self.in_proj_bias is not None:
200 |             constant_(self.in_proj_bias, 0.)
201 |             constant_(self.out_proj.bias, 0.)
202 | 
203 | 
204 | 
205 |     def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
206 |                 need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
207 |         if self.batch_first:
208 |             query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
209 | 
210 |         if not self._qkv_same_embed_dim:
211 |             attn_output, attn_output_weights = multi_head_attention_forward(
212 |                 query, key, value, self.num_heads,
213 |                 self.in_proj_weight, self.in_proj_bias,
214 |                 self.dropout, self.out_proj.weight, self.out_proj.bias,
215 |                 training=self.training,
216 |                 key_padding_mask=key_padding_mask, need_weights=need_weights,
217 |                 attn_mask=attn_mask, use_separate_proj_weight=True,
218 |                 q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
219 |                 v_proj_weight=self.v_proj_weight)
220 |         else:
221 |             attn_output, attn_output_weights = multi_head_attention_forward(
222 |                 query, key, value, self.num_heads,
223 |                 self.in_proj_weight, self.in_proj_bias,
224 |                 self.dropout, self.out_proj.weight, self.out_proj.bias,
225 |                 training=self.training,
226 |                 key_padding_mask=key_padding_mask, need_weights=need_weights,
227 |                 attn_mask=attn_mask)
228 |         if self.batch_first:
229 |             return attn_output.transpose(1, 0), attn_output_weights
230 |         else:
231 |             return attn_output, attn_output_weights
232 | 
233 | class TransformerEncoderLayer(nn.Module):
234 | 
235 |     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu,
236 |                  layer_norm_eps=1e-5, batch_first=False) -> None:
237 |         super(TransformerEncoderLayer, self).__init__()
238 |         self.self_attn = MultiheadAttention(d_model, nhead,
239 |                                             dropout=dropout, batch_first=batch_first)
240 | 
241 |         self.linear1 = nn.Linear(d_model, dim_feedforward)
242 |         self.dropout = nn.Dropout(dropout)
243 |         self.linear2 = nn.Linear(dim_feedforward, d_model)
244 | 
245 |         self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
246 |         self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
247 |         self.dropout1 = nn.Dropout(dropout)
248 |         self.dropout2 = nn.Dropout(dropout)
249 |         self.activation = activation
250 | 
251 |     def forward(self, src: Tensor, src_mask: Optional[Tensor] = None,
252 |                 src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
253 |         src2 = self.self_attn(src, src, src, attn_mask=src_mask,
254 |                               key_padding_mask=src_key_padding_mask)[0]
255 |         src = src + self.dropout1(src2)
256 |         src = self.norm1(src)
257 |         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
258 |         src = src + self.dropout2(src2)  # residual dropout for the feed-forward branch
259 |         src = self.norm2(src)
260 |         return src
261 | 
262 | 
263 | class TransformerEncoder(nn.Module):
264 | 
265 |     def __init__(self, encoder_layer, num_layers, norm=None):
266 |         super(TransformerEncoder, self).__init__()
267 |         self.layer = encoder_layer  # a single shared layer instance; weights are tied across all num_layers passes
268 |         self.num_layers = num_layers
269 |         self.norm = norm
270 | 
271 |     def forward(self, src: Tensor, mask: Optional[Tensor] = None,
272 |                 src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
273 |         output = src
274 |         for _ in range(self.num_layers):
275 |             output = self.layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
276 | 
277 |         if self.norm is not None:
278 |             output = self.norm(output)
279 | 
280 |         return output
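# Illustrative usage of the encoder stack above (an editorial sketch, not part of
# the original file). With the default batch_first=False, inputs are (seq, batch, d_model):
#
#     layer = TransformerEncoderLayer(d_model=64, nhead=2, dropout=0.)
#     encoder = TransformerEncoder(layer, num_layers=2)
#     y = encoder(torch.randn(32, 8, 64))   # -> torch.Size([32, 8, 64])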
281 | 
282 | 
283 | # Decoder layer:
284 | class TransformerDecoderLayer(nn.Module):
285 |     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu,
286 |                  layer_norm_eps=1e-5, batch_first=False) -> None:
287 |         super(TransformerDecoderLayer, self).__init__()
288 |         self.self_attn = MultiheadAttention(d_model, nhead,
289 |                                             dropout=dropout, batch_first=batch_first)
290 |         self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
291 | 
292 |         self.linear1 = nn.Linear(d_model, dim_feedforward)
293 |         self.dropout = nn.Dropout(dropout)
294 |         self.linear2 = nn.Linear(dim_feedforward, d_model)
295 | 
296 |         self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
297 |         self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
298 |         self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
299 |         self.dropout1 = nn.Dropout(dropout)
300 |         self.dropout2 = nn.Dropout(dropout)
301 |         self.dropout3 = nn.Dropout(dropout)
302 | 
303 |         self.activation = activation
304 | 
305 |     def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
306 |                 memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
307 |                 memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
308 |         tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
309 |                               key_padding_mask=tgt_key_padding_mask)[0]
310 |         tgt = tgt + self.dropout1(tgt2)
311 |         tgt = self.norm1(tgt)
312 |         tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
313 |                                    key_padding_mask=memory_key_padding_mask)[0]
314 |         tgt = tgt + self.dropout2(tgt2)
315 |         tgt = self.norm2(tgt)
316 |         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
317 |         tgt = tgt + self.dropout3(tgt2)
318 |         tgt = self.norm3(tgt)
319 |         return tgt
320 | 
321 | # Decoder
322 | class TransformerDecoder(nn.Module):
323 | 
324 |     def __init__(self, decoder_layer, num_layers, norm=None):
325 |         super(TransformerDecoder, self).__init__()
326 |         self.layer = decoder_layer  # shared instance, as in TransformerEncoder
327 |         self.num_layers = num_layers
328 |         self.norm = norm
329 | 
330 |     def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
331 |                 memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None,
332 |                 memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
333 |         output = tgt
334 |         for _ in range(self.num_layers):
335 |             output = self.layer(output, memory, tgt_mask=tgt_mask,
336 |                                 memory_mask=memory_mask,
337 |                                 tgt_key_padding_mask=tgt_key_padding_mask,
338 |                                 memory_key_padding_mask=memory_key_padding_mask)
339 |         if self.norm is not None:
340 |             output = self.norm(output)
341 | 
342 |         return output
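# Illustrative usage (an editorial sketch, not part of the original file): the decoder
# consumes a target and an encoder memory, both (seq, batch, d_model) by default:
#
#     dec_layer = TransformerDecoderLayer(d_model=64, nhead=2, dropout=0.)
#     decoder = TransformerDecoder(dec_layer, num_layers=2)
#     y = decoder(torch.randn(32, 8, 64), torch.randn(32, 8, 64))  # -> torch.Size([32, 8, 64])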
343 | 
344 | class Transformer(nn.Module):
345 | 
346 |     def __init__(self, d_model: int = 64, nhead: int = 8, num_encoder_layers: int = 6,
347 |                  num_decoder_layers: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,
348 |                  activation=F.relu, custom_encoder: Optional[Any] = None, custom_decoder: Optional[Any] = None,
349 |                  layer_norm_eps: float = 1e-5, batch_first: bool = False, reduction=64) -> None:
350 |         super(Transformer, self).__init__()
351 |         if custom_encoder is not None:
352 |             self.encoder = custom_encoder
353 |         else:
354 |             encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout,
355 |                                                     activation, layer_norm_eps, batch_first)
356 |             # No final LayerNorm on the encoder; the released checkpoints match this choice.
357 |             self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers)
358 | 
359 |         if custom_decoder is not None:
360 |             self.decoder = custom_decoder
361 |         else:
362 |             decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout,
363 |                                                     activation, layer_norm_eps, batch_first)
364 |             decoder_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
365 |             self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
366 | 
367 |         self.d_model = d_model
368 | 
369 |         assert not (2048 % self.d_model), 'the size of the entire CSI matrix (2048) needs to be divisible by d_model'
370 |         self.feature_shape = (2048 // self.d_model, self.d_model)
371 | 
372 |         self.nhead = nhead
373 | 
374 |         self.batch_first = batch_first
375 |         self.fc_encoder = nn.Linear(2048, 2048 // reduction)
376 |         self.fc_decoder = nn.Linear(2048 // reduction, 2048)
377 |         self._reset_parameters()
378 | 
379 | 
380 |     def forward(self, src: Tensor, tgt: Optional[Tensor] = None, src_mask: Optional[Tensor] = None,
381 |                 tgt_mask: Optional[Tensor] = None,
382 |                 memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
383 |                 tgt_key_padding_mask: Optional[Tensor] = None,
384 |                 memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
385 |         memory = self.encoder(src.view(-1, self.feature_shape[0], self.feature_shape[1]), mask=src_mask, src_key_padding_mask=src_key_padding_mask)
386 |         memory_encoder = self.fc_encoder(memory.view(memory.shape[0], -1))  # compress to the 2048 // reduction codeword
387 |         memory_decoder = self.fc_decoder(memory_encoder).view(-1, self.feature_shape[0], self.feature_shape[1])
388 |         output = self.decoder(memory_decoder, memory_decoder, tgt_mask=tgt_mask, memory_mask=memory_mask,
389 |                               tgt_key_padding_mask=tgt_key_padding_mask,
390 |                               memory_key_padding_mask=memory_key_padding_mask)
391 |         output = output.view(-1, 2, 32, 32)
392 |         return output
393 | 
394 |     def generate_square_subsequent_mask(self, sz: int) -> Tensor:
395 | 
396 |         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
397 | 
398 |         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
399 |         return mask
400 | 
401 |     def _reset_parameters(self):
402 | 
403 |         for p in self.parameters():
404 |             if p.dim() > 1:
405 |                 xavier_uniform_(p)
406 | 
407 | def transnet(reduction=64, d_model=64):
408 | 
409 |     r""" Create a proposed TransNet.
410 | 
411 |     :param reduction: the reciprocal of the compression ratio
412 |     :param d_model: the Transformer feature dimension
413 |     :return: an instance of TransNet
414 |     """
415 |     model = Transformer(d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, nhead=2, reduction=reduction, dropout=0.)
416 |     return model
417 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .TransNet import *
2 | 
--------------------------------------------------------------------------------
/scripts.sh:
--------------------------------------------------------------------------------
1 | # --data-dir: root dir path for the COST2100 dataset
2 | # --scenario: 'in' or 'out'
3 | # --d_model: dimension of the feature in the Transformer
4 | # --scheduler: 'const' or 'cosine'
5 | python ./main.py \
6 | --data-dir './COST2100' \
7 | --scenario 'in' \
8 | --epochs 1000 \
9 | --d_model 64 \
10 | --batch-size 200 \
11 | --workers 3 \
12 | --cr 4 \
13 | --scheduler const \
14 | --gpu 0 \
15 | 2>&1 | tee log.out
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .
import logger 2 | from .logger import log_level, line_seg 3 | from .init import * 4 | from .scheduler import * 5 | from .solver import * 6 | 7 | -------------------------------------------------------------------------------- /utils/init.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import thop 4 | import torch 5 | 6 | from models import transnet 7 | from utils import logger, line_seg 8 | 9 | __all__ = ["init_device", "init_model"] 10 | 11 | 12 | def init_device(seed=None, cpu=None, gpu=None, affinity=None): 13 | # set the CPU affinity 14 | if affinity is not None: 15 | os.system(f'taskset -p {affinity} {os.getpid()}') 16 | 17 | # Set the random seed 18 | if seed is not None: 19 | random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.backends.cudnn.deterministic = True 22 | 23 | # Set the GPU id you choose 24 | if gpu is not None: 25 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) 26 | 27 | # Env setup 28 | if not cpu and torch.cuda.is_available(): 29 | device = torch.device('cuda') 30 | torch.backends.cudnn.benchmark = True 31 | if seed is not None: 32 | torch.cuda.manual_seed(seed) 33 | pin_memory = True 34 | logger.info("Running on GPU%d" % (gpu if gpu else 0)) 35 | else: 36 | pin_memory = False 37 | device = torch.device('cpu') 38 | logger.info("Running on CPU") 39 | 40 | return device, pin_memory 41 | 42 | 43 | def init_model(args): 44 | # Model loading 45 | model = transnet(reduction=args.cr, d_model=args.d_model) 46 | 47 | if args.pretrained is not None: 48 | assert os.path.isfile(args.pretrained) 49 | state_dict = torch.load(args.pretrained, 50 | map_location=torch.device('cpu'))['state_dict'] 51 | model.load_state_dict(state_dict,strict=False) 52 | logger.info("pretrained model loaded from {}".format(args.pretrained)) 53 | 54 | # Model flops and params counting 55 | H_a = torch.randn([1,2,32,32]) 56 | flops, params = thop.profile(model, inputs=(H_a,), verbose=False) 57 | flops, params = thop.clever_format([flops, params], "%.3f") 58 | 59 | # Model info logging 60 | logger.info(f'=> Model Name: TransNet [pretrained: {args.pretrained}]') 61 | logger.info(f'=> Model Config: compression ratio=1/{args.cr}') 62 | logger.info(f'=> Model Flops: {flops}') 63 | logger.info(f'=> Model Params Num: {params}\n') 64 | logger.info(f'{line_seg}\n{model}\n{line_seg}\n') 65 | 66 | return model 67 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import sys 3 | import traceback 4 | 5 | DEBUG = -1 6 | INFO = 0 7 | EMPH = 1 8 | WARNING = 2 9 | ERROR = 3 10 | FATAL = 4 11 | 12 | log_level = INFO 13 | line_seg = ''.join(['*'] * 65) 14 | 15 | 16 | class LoggerFatalError(SystemExit): 17 | pass 18 | 19 | 20 | def _format(level, messages): 21 | timestr = datetime.strftime(datetime.now(), '%m.%d/%H:%M') 22 | father = traceback.extract_stack()[-4] 23 | func_info = f'{father[0].split("/")[-1]}:{str(father[1]).ljust(4, " ")}' 24 | m = ' '.join(map(str, messages)) 25 | msg = f'{level} {timestr} {func_info}] {m}' 26 | return msg 27 | 28 | 29 | _log_file = None 30 | _log_buffer = [] 31 | _RED = '\033[0;31m' 32 | _GREEN = '\033[1;32m' 33 | _LIGHT_RED = '\033[1;31m' 34 | _ORANGE = '\033[0;33m' 35 | _YELLOW = '\033[1;33m' 36 | _NC = '\033[0m' # No Color 37 | 38 | 39 | def set_file(fname): 40 | global _log_file 41 | global _log_buffer 42 | if _log_file is not None: 43 | 
warning("Change log file to %s" % fname) 44 | _log_file.close() 45 | _log_file = open(fname, 'w') 46 | if len(_log_buffer): 47 | for s in _log_buffer: 48 | _log_file.write(s) 49 | _log_file.flush() 50 | 51 | 52 | def debug(*messages, file=None): 53 | if log_level > DEBUG: 54 | return 55 | msg = _format('D', messages) 56 | 57 | if file is None: 58 | sys.stdout.write(_YELLOW + msg + _NC + '\n') 59 | sys.stdout.flush() 60 | else: 61 | with open(file, 'a+') as f: 62 | print(msg, file=f) 63 | 64 | 65 | def info(*messages, file=None): 66 | if log_level > INFO: 67 | return 68 | msg = _format('I', messages) 69 | if file is None: 70 | sys.stdout.write(msg + '\n') 71 | sys.stdout.flush() 72 | else: 73 | with open(file, 'a+') as f: 74 | print(msg, file=f) 75 | 76 | 77 | def emph(*messages, file=None): 78 | if log_level > EMPH: 79 | return 80 | msg = _format('EM', messages) 81 | if file is None: 82 | sys.stdout.write(_GREEN + msg + _NC + '\n') 83 | sys.stdout.flush() 84 | else: 85 | with open(file, 'a+') as f: 86 | print(msg, file=f) 87 | 88 | 89 | def warning(*messages, file=None): 90 | if log_level > WARNING: 91 | return 92 | msg = _format('W', messages) 93 | if file is None: 94 | sys.stderr.write(_ORANGE + msg + _NC + '\n') 95 | sys.stderr.flush() 96 | else: 97 | with open(file, 'a+') as f: 98 | print(msg, file=f) 99 | 100 | 101 | def error(*messages, file=None): 102 | if log_level > ERROR: 103 | return 104 | msg = _format('E', messages) 105 | if file is None: 106 | sys.stderr.write(_RED + msg + _NC + '\n') 107 | sys.stderr.flush() 108 | else: 109 | with open(file, 'a+') as f: 110 | print(msg, file=f) 111 | 112 | 113 | def fatal(*messages, file=None): 114 | if log_level > FATAL: 115 | return 116 | msg = _format('F', messages) 117 | if file is None: 118 | sys.stderr.write(_LIGHT_RED + msg + _NC + '\n') 119 | sys.stderr.flush() 120 | else: 121 | with open(file, 'a+') as f: 122 | print(msg, file=f) 123 | 124 | raise LoggerFatalError(-1) 125 | -------------------------------------------------------------------------------- /utils/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description='CRNet PyTorch Training') 4 | 5 | 6 | # ========================== Indispensable arguments ========================== 7 | 8 | parser.add_argument('--data-dir', type=str, required=True, 9 | help='the path of dataset.') 10 | parser.add_argument('--scenario', type=str, required=True, choices=["in", "out"], 11 | help="the channel scenario") 12 | parser.add_argument('-b', '--batch-size', type=int, required=True, metavar='N', 13 | help='mini-batch size') 14 | parser.add_argument('-j', '--workers', type=int, metavar='N', required=True, 15 | help='number of data loading workers') 16 | 17 | 18 | # ============================= Optical arguments ============================= 19 | 20 | # Working mode arguments 21 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 22 | help='evaluate model on validation set') 23 | parser.add_argument('--pretrained', type=str, default=None, 24 | help='using locally pre-trained model. The path of pre-trained model should be given') 25 | parser.add_argument('--resume', type=str, metavar='PATH', default=None, 26 | help='path to latest checkpoint (default: none)') 27 | parser.add_argument('--seed', default=None, type=int, 28 | help='seed for initializing training. 
') 29 | parser.add_argument('--gpu', default=None, type=int, 30 | help='GPU id to use.') 31 | parser.add_argument('--cpu', action='store_true', 32 | help='disable GPU training (default: False)') 33 | parser.add_argument('--cpu-affinity', default=None, type=str, 34 | help='CPU affinity, like "0xffff"') 35 | 36 | # Other arguments 37 | parser.add_argument('--epochs', type=int, metavar='N', 38 | help='number of total epochs to run') 39 | parser.add_argument('--cr', metavar='N', type=int, default=4, 40 | help='compression ratio') 41 | parser.add_argument('-d', '--d_model', type=int, default=64, metavar= 'N', help= 'number of Transformer feature dimension.' ) 42 | parser.add_argument('--scheduler', type=str, default='const', choices=['const', 'cosine'], 43 | help='learning rate scheduler') 44 | 45 | args = parser.parse_args() 46 | -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | __all__ = ['WarmUpCosineAnnealingLR', 'FakeLR'] 5 | 6 | 7 | class WarmUpCosineAnnealingLR(_LRScheduler): 8 | def __init__(self, optimizer, T_max, T_warmup, eta_min=0, last_epoch=-1): 9 | self.T_max = T_max 10 | self.T_warmup = T_warmup 11 | self.eta_min = eta_min 12 | super(WarmUpCosineAnnealingLR, self).__init__(optimizer, last_epoch) 13 | 14 | def get_lr(self): 15 | if self.last_epoch < self.T_warmup: 16 | return [base_lr * self.last_epoch / self.T_warmup for base_lr in self.base_lrs] 17 | else: 18 | k = 1 + math.cos(math.pi * (self.last_epoch - self.T_warmup) / (self.T_max - self.T_warmup)) 19 | return [self.eta_min + (base_lr - self.eta_min) * k / 2 for base_lr in self.base_lrs] 20 | 21 | 22 | class FakeLR(_LRScheduler): 23 | def __init__(self, optimizer): 24 | super(FakeLR, self).__init__(optimizer=optimizer) 25 | 26 | def get_lr(self): 27 | return self.base_lrs 28 | -------------------------------------------------------------------------------- /utils/solver.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import torch 4 | from collections import namedtuple 5 | from tensorboardX import SummaryWriter 6 | from utils import logger 7 | from utils.statics import AverageMeter, evaluator 8 | 9 | __all__ = ['Trainer', 'Tester'] 10 | 11 | 12 | field = ('nmse', 'rho', 'epoch') 13 | Result = namedtuple('Result', field, defaults=(None,) * len(field)) 14 | vision_test = SummaryWriter(log_dir="data_vision/test") 15 | vision_best = SummaryWriter(log_dir="data_vision/best") 16 | vision_every = SummaryWriter(log_dir="data_vision/every") 17 | 18 | class Trainer: 19 | r""" The training pipeline for encoder-decoder architecture 20 | """ 21 | 22 | def __init__(self, model, device, optimizer, criterion, scheduler, resume=None, 23 | save_path='./checkpoints', print_freq=20, val_freq=10, test_freq=10): 24 | 25 | # Basic arguments 26 | self.model = model 27 | self.optimizer = optimizer 28 | self.criterion = criterion 29 | self.scheduler = scheduler 30 | self.device = device 31 | 32 | # Verbose arguments 33 | self.resume_file = resume 34 | self.save_path = save_path 35 | self.print_freq = print_freq 36 | self.val_freq = val_freq 37 | self.test_freq = test_freq 38 | 39 | # Pipeline arguments 40 | self.cur_epoch = 1 41 | self.all_epoch = None 42 | self.train_loss = None 43 | self.val_loss = None 44 | self.test_loss = None 45 | self.best_rho = Result() 46 | 
self.best_nmse = Result()
47 | 
48 |         self.tester = Tester(model, device, criterion, print_freq)
49 |         self.test_loader = None
50 | 
51 |     def loop(self, epochs, train_loader, val_loader, test_loader):
52 |         r""" The main loop function which runs training and validation iteratively.
53 | 
54 |         Args:
55 |             epochs (int): The total epoch for training
56 |             train_loader (DataLoader): Data loader for training data.
57 |             val_loader (DataLoader): Data loader for validation data.
58 |             test_loader (DataLoader): Data loader for test data.
59 |         """
60 | 
61 |         self.all_epoch = epochs
62 |         self._resume()
63 | 
64 |         for ep in range(self.cur_epoch, epochs + 1):
65 |             self.cur_epoch = ep
66 | 
67 |             # conduct training, validation and test
68 |             self.train_loss = self.train(train_loader)
69 |             if ep % self.val_freq == 0:
70 |                 self.val_loss = self.val(val_loader)
71 | 
72 |             if ep % self.test_freq == 0:
73 |                 self.test_loss, rho, nmse = self.test(test_loader)
74 |                 vision_test.add_scalar("test loss", self.test_loss, global_step=ep)
75 |                 vision_test.add_scalar("test rho", rho, global_step=ep)
76 |                 vision_test.add_scalar("test nmse", nmse, global_step=ep)
77 |                 vision_test.add_scalar("train loss", self.train_loss, global_step=ep)
78 |             else:
79 |                 rho, nmse = None, None
80 | 
81 |             # conduct saving, visualization and log printing
82 |             self._loop_postprocessing(rho, nmse)
83 | 
84 |     def train(self, train_loader):
85 |         r""" Train the model on the given data loader for one epoch.
86 | 
87 |         Args:
88 |             train_loader (DataLoader): the training data loader
89 |         """
90 | 
91 |         self.model.train()
92 |         with torch.enable_grad():
93 |             return self._iteration(train_loader)
94 | 
95 |     def val(self, val_loader):
96 |         r""" Evaluate the model on the validation set.
97 | 
98 |         Args:
99 |             val_loader: (DataLoader): the validation data loader
100 |         """
101 | 
102 |         self.model.eval()
103 |         with torch.no_grad():
104 |             return self._iteration(val_loader)
105 | 
106 |     def test(self, test_loader):
107 |         r""" Test the model on the test dataset for one epoch.
108 | 109 | Args: 110 | test_loader (DataLoader): the test data loader 111 | """ 112 | 113 | self.model.eval() 114 | with torch.no_grad(): 115 | return self.tester(test_loader, verbose=False) 116 | 117 | def _iteration(self, data_loader): 118 | iter_loss = AverageMeter('Iter loss') 119 | iter_time = AverageMeter('Iter time') 120 | time_tmp = time.time() 121 | 122 | for batch_idx, (sparse_gt, ) in enumerate(data_loader): 123 | sparse_gt = sparse_gt.to(self.device) 124 | sparse_pred = self.model(sparse_gt) 125 | loss = self.criterion(sparse_pred, sparse_gt) 126 | 127 | # Scheduler update, backward pass and optimization 128 | if self.model.training: 129 | self.optimizer.zero_grad() 130 | loss.backward() 131 | self.optimizer.step() 132 | self.scheduler.step() 133 | 134 | # Log and visdom update 135 | iter_loss.update(loss) 136 | iter_time.update(time.time() - time_tmp) 137 | time_tmp = time.time() 138 | 139 | # plot progress 140 | if (batch_idx + 1) % self.print_freq == 0: 141 | logger.info(f'Epoch: [{self.cur_epoch}/{self.all_epoch}]' 142 | f'[{batch_idx + 1}/{len(data_loader)}] ' 143 | f'lr: {self.scheduler.get_lr()[0]:.2e} | ' 144 | f'MSE loss: {iter_loss.avg:.3e} | ' 145 | f'time: {iter_time.avg:.3f}') 146 | vision_every.add_scalar(" lr ",self.scheduler.get_lr()[0],global_step=self.cur_epoch) 147 | vision_every.add_scalar(" MSE loss",iter_loss.avg , self.cur_epoch) 148 | 149 | mode = 'Train' if self.model.training else 'Val' 150 | logger.info(f'=> {mode} Loss: {iter_loss.avg:.3e}\n') 151 | 152 | return iter_loss.avg 153 | 154 | def _save(self, state, name): 155 | if self.save_path is None: 156 | logger.warning('No path to save checkpoints.') 157 | return 158 | 159 | os.makedirs(self.save_path, exist_ok=True) 160 | torch.save(state, os.path.join(self.save_path, name)) 161 | 162 | def _resume(self): 163 | r""" protected function which resume from checkpoint at the beginning of training. 164 | """ 165 | 166 | if self.resume_file is None: 167 | return None 168 | assert os.path.isfile(self.resume_file) 169 | logger.info(f'=> loading checkpoint {self.resume_file}') 170 | checkpoint = torch.load(self.resume_file) 171 | self.cur_epoch = checkpoint['epoch'] 172 | self.model.load_state_dict(checkpoint['state_dict']) 173 | self.optimizer.load_state_dict(checkpoint['optimizer']) 174 | self.scheduler.load_state_dict(checkpoint['scheduler']) 175 | self.best_rho = checkpoint['best_rho'] 176 | self.best_nmse = checkpoint['best_nmse'] 177 | self.cur_epoch += 1 # start from the next epoch 178 | 179 | logger.info(f'=> successfully loaded checkpoint {self.resume_file} ' 180 | f'from epoch {checkpoint["epoch"]}.\n') 181 | 182 | def _loop_postprocessing(self, rho, nmse): 183 | r""" private function which makes loop() function neater. 
184 |         """
185 | 
186 |         # save state generate
187 |         state = {
188 |             'epoch': self.cur_epoch,
189 |             'state_dict': self.model.state_dict(),
190 |             'optimizer': self.optimizer.state_dict(),
191 |             'scheduler': self.scheduler.state_dict(),
192 |             'best_rho': self.best_rho,
193 |             'best_nmse': self.best_nmse
194 |         }
195 | 
196 |         # save model with best rho and nmse
197 |         if rho is not None:
198 |             if self.best_rho.rho is None or self.best_rho.rho < rho:
199 |                 self.best_rho = Result(rho=rho, nmse=nmse, epoch=self.cur_epoch)
200 |                 state['best_rho'] = self.best_rho
201 |                 self._save(state, name='best_rho.pth')
202 |             if self.best_nmse.nmse is None or self.best_nmse.nmse > nmse:
203 |                 self.best_nmse = Result(rho=rho, nmse=nmse, epoch=self.cur_epoch)
204 |                 state['best_nmse'] = self.best_nmse
205 |                 self._save(state, name='best_nmse.pth')
206 | 
207 |         self._save(state, name='last.pth')
208 | 
209 |         # print current best results
210 |         if self.best_rho.rho is not None:
211 |             print(f'\n=! Best rho: {self.best_rho.rho:.3e} ('
212 |                   f'Corresponding nmse={self.best_rho.nmse:.3e}; '
213 |                   f'epoch={self.best_rho.epoch})'
214 |                   f'\n   Best NMSE: {self.best_nmse.nmse:.3e} ('
215 |                   f'Corresponding rho={self.best_nmse.rho:.3e}; '
216 |                   f'epoch={self.best_nmse.epoch})\n')
217 |             vision_best.add_scalar(" best rho ", self.best_rho.rho, global_step=self.best_rho.epoch)
218 |             vision_best.add_scalar(" best NMSE ", self.best_nmse.nmse, global_step=self.best_nmse.epoch)
219 | 
220 | 
221 | 
222 | class Tester:
223 |     r""" The testing interface for the CSI feedback task
224 |     """
225 | 
226 |     def __init__(self, model, device, criterion, print_freq=20):
227 |         self.model = model
228 |         self.device = device
229 |         self.criterion = criterion
230 |         self.print_freq = print_freq
231 | 
232 |     def __call__(self, test_data, verbose=True):
233 |         r""" Runs the testing procedure.
234 | 
235 |         Args:
236 |             test_data (DataLoader): Data loader for test data.
237 |         """
238 | 
239 |         self.model.eval()
240 |         with torch.no_grad():
241 |             loss, rho, nmse = self._iteration(test_data)
242 |         if verbose:
243 |             print(f'\n=> Test result: \nloss: {loss:.3e}'
244 |                   f'    rho: {rho:.3e}    NMSE: {nmse:.3e}\n')
245 |         return loss, rho, nmse
246 | 
247 |     def _iteration(self, data_loader):
248 |         r""" protected function which tests the model on a given data loader for one epoch.
249 | """ 250 | 251 | iter_rho = AverageMeter('Iter rho') 252 | iter_nmse = AverageMeter('Iter nmse') 253 | iter_loss = AverageMeter('Iter loss') 254 | iter_time = AverageMeter('Iter time') 255 | time_tmp = time.time() 256 | 257 | for batch_idx, (sparse_gt, raw_gt) in enumerate(data_loader): 258 | sparse_gt = sparse_gt.to(self.device) 259 | sparse_pred = self.model(sparse_gt) 260 | loss = self.criterion(sparse_pred, sparse_gt) 261 | rho, nmse = evaluator(sparse_pred, sparse_gt, raw_gt) 262 | 263 | # Log and visdom update 264 | iter_loss.update(loss) 265 | iter_rho.update(rho) 266 | iter_nmse.update(nmse) 267 | iter_time.update(time.time() - time_tmp) 268 | time_tmp = time.time() 269 | 270 | # plot progress 271 | if (batch_idx + 1) % self.print_freq == 0: 272 | logger.info(f'[{batch_idx + 1}/{len(data_loader)}] ' 273 | f'loss: {iter_loss.avg:.3e} | rho: {iter_rho.avg:.3e} | ' 274 | f'NMSE: {iter_nmse.avg:.3e} | time: {iter_time.avg:.3f}') 275 | 276 | logger.info(f'=> Test rho:{iter_rho.avg:.3e} NMSE: {iter_nmse.avg:.3e}\n') 277 | 278 | 279 | return iter_loss.avg, iter_rho.avg, iter_nmse.avg 280 | -------------------------------------------------------------------------------- /utils/statics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | __all__ = ['AverageMeter', 'evaluator'] 4 | 5 | 6 | class AverageMeter(object): 7 | r"""Computes and stores the average and current value 8 | Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 9 | """ 10 | def __init__(self, name): 11 | self.reset() 12 | self.val = 0 13 | self.avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | self.name = name 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | 30 | def __repr__(self): 31 | return f"==> For {self.name}: sum={self.sum}; avg={self.avg}" 32 | 33 | 34 | def evaluator(sparse_pred, sparse_gt, raw_gt): 35 | r""" Evaluation of decoding implemented in PyTorch Tensor 36 | Computes normalized mean square error (NMSE) and rho. 
37 |     """
38 | 
39 |     with torch.no_grad():
40 |         # Basic params
41 |         nt = 32
42 |         nc = 32
43 |         nc_expand = 257
44 | 
45 |         # De-centralize (the dataset is scaled to [0, 1]; shift back to zero-centered values)
46 |         sparse_gt = sparse_gt - 0.5
47 |         sparse_pred = sparse_pred - 0.5
48 | 
49 |         # Calculate the NMSE
50 |         power_gt = sparse_gt[:, 0, :, :] ** 2 + sparse_gt[:, 1, :, :] ** 2
51 |         difference = sparse_gt - sparse_pred
52 |         mse = difference[:, 0, :, :] ** 2 + difference[:, 1, :, :] ** 2
53 |         nmse = 10 * torch.log10((mse.sum(dim=[1, 2]) / power_gt.sum(dim=[1, 2])).mean())
54 | 
55 |         # Calculate the rho: zero-pad the delay dimension, FFT back to the frequency domain, keep 125 subcarriers
56 |         n = sparse_pred.size(0)
57 |         sparse_pred = sparse_pred.permute(0, 2, 3, 1)  # Move the real/imaginary dim to the last
58 |         zeros = sparse_pred.new_zeros((n, nt, nc_expand - nc, 2))
59 |         sparse_pred = torch.cat((sparse_pred, zeros), dim=2)
60 |         raw_pred = torch.fft(sparse_pred, signal_ndim=1)[:, :, :125, :]  # torch.fft(..., signal_ndim=...) only exists in PyTorch <= 1.7; env.yaml pins 1.6
61 | 
62 |         norm_pred = raw_pred[..., 0] ** 2 + raw_pred[..., 1] ** 2
63 |         norm_pred = torch.sqrt(norm_pred.sum(dim=1))
64 | 
65 |         norm_gt = raw_gt[..., 0] ** 2 + raw_gt[..., 1] ** 2
66 |         norm_gt = torch.sqrt(norm_gt.sum(dim=1))
67 | 
68 |         real_cross = raw_pred[..., 0] * raw_gt[..., 0] + raw_pred[..., 1] * raw_gt[..., 1]
69 |         real_cross = real_cross.sum(dim=1)
70 |         imag_cross = raw_pred[..., 0] * raw_gt[..., 1] - raw_pred[..., 1] * raw_gt[..., 0]
71 |         imag_cross = imag_cross.sum(dim=1)
72 |         norm_cross = torch.sqrt(real_cross ** 2 + imag_cross ** 2)
73 | 
74 |         rho = (norm_cross / (norm_pred * norm_gt)).mean()
75 | 
76 |         return rho, nmse
77 | 
--------------------------------------------------------------------------------
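# Illustrative usage of evaluator (an editorial sketch, not part of the original file),
# with the tensor shapes produced by dataloader/cost2100.py for the test split:
#
#     sparse_gt: (N, 2, 32, 32) truncated angular-delay CSI, entries scaled to [0, 1]
#     raw_gt:    (N, 32, 125, 2) raw frequency-domain CSI (real/imag stacked last)
#
#     rho, nmse = evaluator(model(sparse_gt), sparse_gt, raw_gt)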