├── README.md
├── feature_load.py
├── figures
│   ├── consistent_system_model.png
│   ├── inconsistent_system_model.png
│   └── sphere_packing_mcr2.png
├── main_train_phase1.py
├── main_train_phase2.py
├── mcr2_loss.py
├── models
│   ├── model.py
│   └── mvcnn.py
├── precoding_opt_matlab
│   ├── H_slot_rician_channel_gen.m
│   ├── MCR2_ModelNet10_statistics_complex_24.mat
│   ├── MCR2_ModelNet10_test_feature_label_complex_24.mat
│   ├── MCR2_obj_cplx.m
│   ├── bisection_search_bar.m
│   ├── f_bar.m
│   ├── gaussian_pdf.m
│   ├── main_precoding_opt.m
│   └── steering_vec.m
└── tools
    ├── evaluate_func.py
    ├── img_dataset.py
    ├── train_func.py
    ├── trainer.py
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # TaskCommMCR2
2 | 
3 | This repository is the official implementation of the paper:
4 | 
5 | - **Multi-Device Task-Oriented Communication via Maximal Coding Rate Reduction** [[IEEE TWC](https://ieeexplore.ieee.org/abstract/document/10689268)] [[arXiv](https://arxiv.org/abs/2309.02888)]
6 | - **Authors:** [Chang Cai](https://chang-cai.github.io/) (The Chinese University of Hong Kong), [Xiaojun Yuan](https://scholar.google.com/citations?user=o6W_m00AAAAJ&hl=en) (University of Electronic Science and Technology of China), and [Ying-Jun Angela Zhang](https://staff.ie.cuhk.edu.hk/~yjzhang/) (The Chinese University of Hong Kong)
7 | 
8 | ## Brief Introduction
9 | 
10 | ### Existing Studies: Inconsistent Objectives for Learning and Communication
11 | 
12 | <div align="center">

13 | <img src="figures/inconsistent_system_model.png">
14 | </div>
15 | 
16 | ### This Work: Synergistic Alignment of Learning and Communication Objectives
17 | 
18 | <div align="center">
19 | <img src="figures/consistent_system_model.png"> <img src="figures/sphere_packing_mcr2.png">
20 | </div>
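
The figures above contrast conventional designs, in which the learning and communication modules optimize inconsistent objectives, with the proposed design, in which both are aligned under the maximal coding rate reduction (MCR2) criterion. For reference, the MCR2 objective maximized by the training scripts (implemented in ```mcr2_loss.py```, whose ```eps``` argument stores $\epsilon^2$) is

$$
\Delta R(\mathbf{Z}, \mathbf{\Pi}) = \underbrace{\frac{1}{2}\log\det\Big(\mathbf{I} + \frac{d}{m\epsilon^2}\,\mathbf{Z}\mathbf{Z}^\top\Big)}_{\text{expand all features}} \; - \; \underbrace{\sum_{j=1}^{L}\frac{\operatorname{tr}(\mathbf{\Pi}_j)}{2m}\log\det\Big(\mathbf{I} + \frac{d}{\operatorname{tr}(\mathbf{\Pi}_j)\,\epsilon^2}\,\mathbf{Z}\mathbf{\Pi}_j\mathbf{Z}^\top\Big)}_{\text{compress each class}},
$$

where $\mathbf{Z} \in \mathbb{R}^{d \times m}$ stacks the $m$ encoded feature vectors, $\mathbf{\Pi}_j$ is the diagonal membership matrix of class $j$, and $L$ is the number of classes.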
21 | 
22 | 
23 | ## Usage
24 | 
25 | ### Feature Encoding
26 | - Download the images and put them under ```modelnet40_images_new_12x```: [Shaded Images (1.6GB)](http://supermoe.cs.umass.edu/shape_recog/shaded_images.tar.gz).
27 | If the link does not work, you can download the dataset from my [Google Drive](https://drive.google.com/file/d/1tghRek04_pVHCkYOTOQeYtGlENlgQzhM/view?usp=sharing) backup.
28 | 
29 | - Set up the environment: the code is tested on ```python 3.7.13``` and ```pytorch 1.12.1```.
30 | 
31 | - Run the script ```main_train_phase1.py``` for the first-phase training of feature encoding.
32 | Then, run the script ```main_train_phase2.py``` for the second-phase training of feature encoding. Check [mvcnn_pytorch](https://github.com/jongchyisu/mvcnn_pytorch) for the details of the two training phases.
33 | 
34 | - Alternatively, download the pretrained checkpoints from [Google Drive](https://drive.google.com/drive/folders/1bi2kMot2XI3H27MitiCE6ecxATnGIn5r?usp=drive_link). The checkpoints can be used for feature extraction by running the script ```feature_load.py```.
35 | 
36 | ### Precoding Optimization and Performance Evaluation
37 | 
38 | The code is located in the folder ```precoding_opt_matlab```.
39 | Run the script ```main_precoding_opt.m``` to compare the performance of the proposed MCR2 precoder and the LMMSE precoder.
40 | 
41 | ## Citation
42 | If you find our work interesting, please consider citing:
43 | 
44 | ```
45 | @ARTICLE{task_comm_mcr2,
46 |   author={Cai, Chang and Yuan, Xiaojun and Zhang, Ying-Jun Angela},
47 |   journal={IEEE Transactions on Wireless Communications},
48 |   title={Multi-Device Task-Oriented Communication via Maximal Coding Rate Reduction},
49 |   year={2024},
50 |   volume={23},
51 |   number={12},
52 |   pages={18096-18110}
53 | }
54 | ```
55 | [Our follow-up work](https://ieeexplore.ieee.org/abstract/document/10845817) provides an information-theoretic interpretation of the learning-communication separation, as well as an end-to-end learning framework:
56 | 
57 | ```
58 | @ARTICLE{info_theoretic_e2e,
59 |   author={Cai, Chang and Yuan, Xiaojun and Zhang, Ying-Jun Angela},
60 |   journal={IEEE Journal on Selected Areas in Communications},
61 |   title={End-to-End Learning for Task-Oriented Semantic Communications Over {MIMO} Channels: An Information-Theoretic Framework},
62 |   year={2025},
63 |   volume={},
64 |   number={},
65 |   pages={1-16}
66 | }
67 | ```
--------------------------------------------------------------------------------
/feature_load.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import argparse
4 | 
5 | from tools.img_dataset import MultiviewImgDataset
6 | from models.mvcnn import MVCNN, SVCNN
7 | 
8 | from tools import train_func as tf
9 | 
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument("--bs", "--batch_size", type=int, default=30,
12 |                     help="batch size")
13 | parser.add_argument("--cnn_name", type=str, default="vgg11",
14 |                     help="cnn model name")
15 | parser.add_argument("--num_views", type=int, default=3,
16 |                     help="number of views")
17 | parser.add_argument("--num_classes", type=int, default=10,
18 |                     help="number of classes")
19 | parser.add_argument("--train_path", type=str, default="modelnet40_images_new_12x/*/train")
20 | parser.add_argument("--val_path", type=str, default="modelnet40_images_new_12x/*/test")
21 | parser.add_argument("--num_workers", type=int, default=24,
22 |                     help="number of workers")
23 | parser.add_argument('--svcnn_model_dir', type=str,
default='./mvcnn/phase1_classes10_views3_fd1_32_bs1200_lr0.0001_wd0.001_eps0.5/checkpoints/svcnn/model-00029.pth', 24 | help='base directory for svcnn model') 25 | parser.add_argument('--mvcnn_model_dir', type=str, default='./mvcnn/phase2_classes10_views3_fd1_32_fd2_8_bs1200_lr0.0001_wd0.001_eps0.5/checkpoints/mvcnn/model-00199.pth', 26 | help='base directory for mvcnn model') 27 | parser.add_argument('--pretraining', type=bool, default=True, 28 | help='pretraining') 29 | parser.add_argument('--fd_phase1', type=int, default=32, 30 | help='dimension of feature dimension in phase 1') 31 | parser.add_argument('--fd_phase2', type=int, default=8, 32 | help='dimension of feature dimension per user in phase 2') 33 | args = parser.parse_args() 34 | 35 | ## CUDA 36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 37 | 38 | if __name__ == '__main__': 39 | 40 | cnet = SVCNN(name='svcnn', nclasses=args.num_classes, pretraining=args.pretraining, 41 | cnn_name=args.cnn_name, fd_phase1=args.fd_phase1) 42 | state_dict_1 = torch.load(args.svcnn_model_dir, map_location=torch.device('cpu')) 43 | cnet.load_state_dict(state_dict_1) 44 | cnet_2 = MVCNN(name='mvcnn', model=cnet, nclasses=args.num_classes, cnn_name=args.cnn_name, 45 | num_views=args.num_views, fd_phase1=args.fd_phase1, fd_per_user=args.fd_phase2) 46 | state_dict_2 = torch.load(args.mvcnn_model_dir, map_location=torch.device('cpu')) 47 | cnet_2.load_state_dict(state_dict_2) 48 | cnet_2.eval() 49 | del cnet 50 | 51 | train_dataset = MultiviewImgDataset(args.train_path, scale_aug=False, rot_aug=False, 52 | num_classes=args.num_classes, num_views=args.num_views, train_objects=9999) 53 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.bs, 54 | shuffle=False, num_workers=args.num_workers) 55 | # shuffle needs to be false! 
it's done within the trainer 56 | 57 | val_dataset = MultiviewImgDataset(args.val_path, scale_aug=False, rot_aug=False, test_mode=True, 58 | num_classes=args.num_classes, num_views=args.num_views, test_objects=9999) 59 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.bs, 60 | shuffle=False, num_workers=args.num_workers) 61 | print('num_train_files: '+str(len(train_dataset.filepaths))) 62 | print('num_val_files: '+str(len(val_dataset.filepaths))) 63 | 64 | train_features, train_labels = tf.get_features(cnet_2, train_loader, 'mcr2', 'mvcnn') 65 | feature = np.array(train_features.detach().cpu()) 66 | target = np.array(train_labels.detach().cpu()) 67 | mdic = {"train_feature": feature, "train_label": target} 68 | # savemat(f"MCR2_ModelNet10_train_feature_label_36.mat", mdic) 69 | 70 | test_features, test_labels = tf.get_features(cnet_2, val_loader, 'mcr2', 'mvcnn') 71 | feature = np.array(test_features.detach().cpu()) 72 | target = np.array(test_labels.detach().cpu()) 73 | mdic = {"test_feature": feature, "test_label": target} 74 | # savemat(f"MCR2_ModelNet10_test_feature_label_36.mat", mdic) 75 | 76 | 77 | -------------------------------------------------------------------------------- /figures/consistent_system_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chang-cai/TaskCommMCR2/dfbba3b636d3aba5c8844e2393ead6657acd3518/figures/consistent_system_model.png -------------------------------------------------------------------------------- /figures/inconsistent_system_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chang-cai/TaskCommMCR2/dfbba3b636d3aba5c8844e2393ead6657acd3518/figures/inconsistent_system_model.png -------------------------------------------------------------------------------- /figures/sphere_packing_mcr2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chang-cai/TaskCommMCR2/dfbba3b636d3aba5c8844e2393ead6657acd3518/figures/sphere_packing_mcr2.png -------------------------------------------------------------------------------- /main_train_phase1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import os 4 | import argparse 5 | 6 | from tools.trainer import ModelNetTrainer 7 | from tools.img_dataset import SingleImgDataset 8 | from models.mvcnn import SVCNN 9 | 10 | from mcr2_loss import MaximalCodingRateReduction 11 | from tools import utils 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--bs", "--batch_size", type=int, default=1200, 15 | help="batch size") 16 | parser.add_argument("--lr", type=float, default=1e-4, 17 | help="learning rate") 18 | parser.add_argument("--weight_decay", type=float, help="weight decay", default=0.001) 19 | parser.add_argument("--cnn_name", type=str, default="vgg11", 20 | help="cnn model name") 21 | parser.add_argument("--num_views", type=int, default=3, 22 | help="number of views") 23 | parser.add_argument("--num_classes", type=int, default=10, 24 | help="number of classes") 25 | parser.add_argument("--train_path", type=str, default="modelnet40_images_new_12x/*/train") 26 | parser.add_argument("--val_path", type=str, default="modelnet40_images_new_12x/*/test") 27 | parser.add_argument("--num_workers", type=int, default=32, 28 | help="number of workers") 29 | parser.add_argument('--save_dir', 
type=str, default='./mvcnn/', 30 | help='base directory for saving PyTorch model') 31 | parser.add_argument('--epoch', type=int, default=500, 32 | help='number of epochs for training') 33 | parser.add_argument('--eps', type=float, default=0.5, 34 | help='eps squared') 35 | parser.add_argument('--fd_phase1', type=int, default=32, 36 | help='dimension of feature dimension') 37 | parser.add_argument('--tail', type=str, default='', 38 | help='extra information to add to folder name') 39 | parser.add_argument('--pretraining', type=bool, default=True, 40 | help='pretraining') 41 | parser.add_argument('--mom', type=float, default=0.9, 42 | help='momentum') 43 | args = parser.parse_args() 44 | 45 | ## CUDA 46 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 47 | 48 | 49 | if __name__ == '__main__': 50 | ## Pipelines Setup 51 | model_dir = os.path.join(args.save_dir, 52 | 'phase1_classes{}_views{}_fd1_{}_bs{}_lr{}_wd{}_eps{}{}'.format( 53 | args.num_classes, args.num_views, args.fd_phase1, args.bs, args.lr, 54 | args.weight_decay, args.eps, args.tail)) 55 | utils.init_pipeline(model_dir) 56 | utils.save_params(model_dir, vars(args)) 57 | 58 | cnet = SVCNN(name='svcnn', nclasses=args.num_classes, pretraining=args.pretraining, 59 | cnn_name=args.cnn_name, fd_phase1=args.fd_phase1) 60 | cnet = cnet.to(device) 61 | optimizer = optim.Adam(cnet.parameters(), lr=args.lr, weight_decay=args.weight_decay) 62 | # optimizer = optim.SGD(cnet.parameters(), lr=args.lr, momentum=args.mom, weight_decay=args.weight_decay) 63 | 64 | train_dataset = SingleImgDataset(args.train_path, scale_aug=False, rot_aug=False, 65 | num_classes=args.num_classes, num_views=args.num_views, train_objects=9999) # 80 66 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.bs, 67 | shuffle=True, num_workers=args.num_workers) 68 | 69 | val_dataset = SingleImgDataset(args.val_path, scale_aug=False, rot_aug=False, test_mode=True, 70 | num_classes=args.num_classes, num_views=args.num_views, test_objects=9999) # 20 71 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.bs, 72 | shuffle=False, num_workers=args.num_workers) 73 | print('num_train_files: '+str(len(train_dataset.filepaths))) 74 | print('num_val_files: '+str(len(val_dataset.filepaths))) 75 | 76 | loss_fn = MaximalCodingRateReduction(gam1=1, gam2=1, eps=args.eps) 77 | trainer = ModelNetTrainer(cnet, device, train_loader, val_loader, optimizer, loss_fn, 'svcnn', 78 | model_dir, num_classes=args.num_classes, num_views=1) 79 | trainer.train(args.epoch) 80 | 81 | 82 | -------------------------------------------------------------------------------- /main_train_phase2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import os 4 | import argparse 5 | 6 | from tools.trainer import ModelNetTrainer 7 | from tools.img_dataset import MultiviewImgDataset 8 | from models.mvcnn import MVCNN, SVCNN 9 | 10 | from mcr2_loss import MaximalCodingRateReduction 11 | from tools import utils 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--bs", "--batch_size", type=int, default=1200, 15 | help="batch size") 16 | parser.add_argument("--lr", type=float, default=1e-4, 17 | help="learning rate") 18 | parser.add_argument("--weight_decay", type=float, help="weight decay", default=0.001) 19 | parser.add_argument("--cnn_name", type=str, default="vgg11", 20 | help="cnn model name") 21 | parser.add_argument("--num_views", type=int, 
default=3, 22 | help="number of views") 23 | parser.add_argument("--num_classes", type=int, default=10, 24 | help="number of classes") 25 | parser.add_argument("--train_path", type=str, default="modelnet40_images_new_12x/*/train") 26 | parser.add_argument("--val_path", type=str, default="modelnet40_images_new_12x/*/test") 27 | parser.add_argument("--num_workers", type=int, default=32, 28 | help="number of workers") 29 | parser.add_argument('--save_dir', type=str, default='./mvcnn/', 30 | help='base directory for saving PyTorch model') 31 | parser.add_argument('--svcnn_model_dir', type=str, default='./mvcnn/phase1_classes10_views3_fd1_32_bs1200_lr0.0001_wd0.001_eps0.5/checkpoints/svcnn/model-00029.pth', 32 | help='base directory for svcnn model') 33 | parser.add_argument('--epoch', type=int, default=500, 34 | help='number of epochs for training') 35 | parser.add_argument('--eps', type=float, default=0.5, 36 | help='eps squared') 37 | parser.add_argument('--fd_phase1', type=int, default=32, 38 | help='dimension of feature dimension in phase 1') 39 | parser.add_argument('--fd_phase2', type=int, default=8, 40 | help='dimension of feature dimension per user in phase 2') 41 | parser.add_argument('--tail', type=str, default='', 42 | help='extra information to add to folder name') 43 | parser.add_argument('--pretraining', type=bool, default=True, 44 | help='pretraining') 45 | parser.add_argument('--mom', type=float, default=0.9, 46 | help='momentum') 47 | args = parser.parse_args() 48 | 49 | ## CUDA 50 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 51 | 52 | 53 | if __name__ == '__main__': 54 | ## Pipelines Setup 55 | model_dir = os.path.join(args.save_dir, 56 | 'phase2_classes{}_views{}_fd1_{}_fd2_{}_bs{}_lr{}_wd{}_eps{}{}'.format( 57 | args.num_classes, args.num_views, args.fd_phase1, args.fd_phase2, args.bs, args.lr, 58 | args.weight_decay, args.eps, args.tail)) 59 | utils.init_pipeline(model_dir) 60 | utils.save_params(model_dir, vars(args)) 61 | 62 | cnet = SVCNN(name='svcnn', nclasses=args.num_classes, pretraining=args.pretraining, 63 | cnn_name=args.cnn_name, fd_phase1=args.fd_phase1) 64 | state_dict = torch.load(args.svcnn_model_dir, map_location=torch.device('cpu')) 65 | cnet.load_state_dict(state_dict) 66 | cnet_2 = MVCNN(name='mvcnn', model=cnet, nclasses=args.num_classes, cnn_name=args.cnn_name, 67 | num_views=args.num_views, fd_phase1=args.fd_phase1, fd_per_user=args.fd_phase2) 68 | del cnet 69 | 70 | cnet_2 = cnet_2.to(device) 71 | optimizer = optim.Adam(cnet_2.parameters(), lr=args.lr, weight_decay=args.weight_decay) 72 | 73 | train_dataset = MultiviewImgDataset(args.train_path, scale_aug=False, rot_aug=False, 74 | num_classes=args.num_classes, num_views=args.num_views, train_objects=9999) # 80 75 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.bs, 76 | shuffle=False, num_workers=args.num_workers) 77 | # shuffle needs to be false! 
it's done within the trainer 78 | 79 | val_dataset = MultiviewImgDataset(args.val_path, scale_aug=False, rot_aug=False, test_mode=True, 80 | num_classes=args.num_classes, num_views=args.num_views, test_objects=9999) # 20 81 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.bs, 82 | shuffle=False, num_workers=args.num_workers) 83 | print('num_train_files: '+str(len(train_dataset.filepaths))) 84 | print('num_val_files: '+str(len(val_dataset.filepaths))) 85 | 86 | loss_fn = MaximalCodingRateReduction(gam1=1, gam2=1, eps=args.eps) 87 | trainer = ModelNetTrainer(cnet_2, device, train_loader, val_loader, optimizer, loss_fn, 'mvcnn', 88 | model_dir, num_classes=args.num_classes, num_views=args.num_views) 89 | trainer.train(args.epoch) 90 | 91 | 92 | -------------------------------------------------------------------------------- /mcr2_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from tools import train_func as tf 4 | 5 | 6 | # import utils 7 | 8 | 9 | # def one_hot(labels_int, n_classes): 10 | # """Turn labels into one hot vector of K classes. """ 11 | # labels_onehot = torch.zeros(size=(len(labels_int), n_classes)).float() 12 | # for i, y in enumerate(labels_int): 13 | # labels_onehot[i, y] = 1. 14 | # return labels_onehot 15 | # 16 | # 17 | # def label_to_membership(targets, num_classes=None): 18 | # """Generate a true membership matrix, and assign value to current Pi. 19 | # 20 | # Parameters: 21 | # targets (np.ndarray): matrix with one hot labels 22 | # 23 | # Return: 24 | # Pi: membership matirx, shape (num_classes, num_samples, num_samples) 25 | # 26 | # """ 27 | # targets = one_hot(targets, num_classes) 28 | # num_samples, num_classes = targets.shape 29 | # Pi = np.zeros(shape=(num_classes, num_samples, num_samples)) 30 | # for j in range(len(targets)): 31 | # k = np.argmax(targets[j]) 32 | # Pi[k, j, j] = 1. 33 | # return Pi 34 | 35 | 36 | class MaximalCodingRateReduction(torch.nn.Module): 37 | def __init__(self, gam1=1.0, gam2=1.0, eps=0.01): 38 | super(MaximalCodingRateReduction, self).__init__() 39 | self.gam1 = gam1 40 | self.gam2 = gam2 41 | self.eps = eps 42 | 43 | def compute_discrimn_loss_empirical(self, W): 44 | """Empirical Discriminative Loss.""" 45 | p, m = W.shape 46 | I = torch.eye(p).to(W.device) 47 | scalar = p / (m * self.eps) 48 | logdet = self.logdet(I + self.gam1 * scalar * W.matmul(W.T)) 49 | return logdet / 2. 50 | 51 | def compute_compress_loss_empirical(self, W, Pi): 52 | """Empirical Compressive Loss.""" 53 | p, m = W.shape 54 | k, _, _ = Pi.shape 55 | I = torch.eye(p).to(W.device) 56 | compress_loss = 0. 57 | for j in range(k): 58 | trPi = torch.trace(Pi[j]) + 1e-8 59 | scalar = p / (trPi * self.eps) 60 | log_det = self.logdet(I + scalar * W.matmul(Pi[j]).matmul(W.T)) 61 | compress_loss += log_det * trPi / m 62 | return compress_loss / 2. 
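
    # Taken together, the two methods above implement the MCR2 objective
    #     Delta R(W; Pi) = R(W) - R_c(W; Pi),
    # where the discriminative term R(W) = 1/2 * logdet(I + p/(m*eps) * W W^T)
    # measures the coding rate (volume) spanned by all features, and the
    # compressive term R_c sums the per-class rates weighted by the class
    # sizes tr(Pi_j)/m. forward() below minimizes gam2 * (-R) + R_c, i.e., it
    # expands the span of the whole feature set while compressing each class
    # into a low-dimensional subspace. Note that eps stores epsilon^2
    # (cf. the '--eps' flag in the training scripts, described as "eps squared").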
63 | 64 | def logdet(self, X): 65 | sgn, logdet = torch.linalg.slogdet(X) 66 | return sgn * logdet 67 | 68 | def forward(self, X, Y, num_classes=None): 69 | if num_classes is None: 70 | num_classes = Y.max() + 1 71 | W = X.T 72 | Pi = tf.label_to_membership(Y.numpy(), num_classes) 73 | Pi = torch.tensor(Pi, dtype=torch.float32).to(X.device) 74 | 75 | discrimn_loss_empi = self.compute_discrimn_loss_empirical(W) 76 | compress_loss_empi = self.compute_compress_loss_empirical(W, Pi) 77 | 78 | total_loss_empi = self.gam2 * -discrimn_loss_empi + compress_loss_empi 79 | return (total_loss_empi, 80 | [discrimn_loss_empi.item(), compress_loss_empi.item()]) 81 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import glob 5 | 6 | 7 | class Model(nn.Module): 8 | 9 | def __init__(self, name): 10 | super(Model, self).__init__() 11 | self.name = name 12 | 13 | 14 | def save(self, path, epoch=0): 15 | complete_path = os.path.join(path, self.name) 16 | if not os.path.exists(complete_path): 17 | os.makedirs(complete_path) 18 | torch.save(self.state_dict(), 19 | os.path.join(complete_path, 20 | "model-{}.pth".format(str(epoch).zfill(5)))) 21 | 22 | 23 | def save_results(self, path, data): 24 | raise NotImplementedError("Model subclass must implement this method.") 25 | 26 | 27 | def load(self, path, modelfile=None): 28 | complete_path = os.path.join(path, self.name) 29 | if not os.path.exists(complete_path): 30 | raise IOError("{} directory does not exist in {}".format(self.name, path)) 31 | 32 | if modelfile is None: 33 | model_files = glob.glob(complete_path+"/*") 34 | mf = max(model_files) 35 | else: 36 | mf = os.path.join(complete_path, modelfile) 37 | 38 | self.load_state_dict(torch.load(mf)) 39 | 40 | 41 | -------------------------------------------------------------------------------- /models/mvcnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | from .model import Model 9 | import copy 10 | 11 | 12 | # mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False).cuda() 13 | # std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False).cuda() 14 | 15 | def flip(x, dim): 16 | xsize = x.size() 17 | dim = x.dim() + dim if dim < 0 else dim 18 | x = x.view(-1, *xsize[dim:]) 19 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1, 20 | -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :] 21 | return x.view(xsize) 22 | 23 | 24 | class SVCNN(Model): 25 | 26 | def __init__(self, name, nclasses=40, pretraining=True, cnn_name='vgg11', fd_phase1=32): 27 | super(SVCNN, self).__init__(name) 28 | 29 | if nclasses == 10: 30 | self.classnames = ['bathtub', 'bed', 'chair', 'desk', 'dresser', 31 | 'monitor', 'night_stand', 'sofa', 'table', 'toilet'] 32 | elif nclasses == 40: 33 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 34 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box', 35 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 36 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 37 | 'stool', 
'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'] 38 | 39 | self.nclasses = nclasses 40 | self.pretraining = pretraining 41 | self.cnn_name = cnn_name 42 | self.use_resnet = cnn_name.startswith('resnet') 43 | # self.mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False).cuda() 44 | # self.std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False).cuda() 45 | self.mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False) 46 | self.std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False) 47 | self.fd_phase1 = fd_phase1 48 | 49 | if self.use_resnet: 50 | if self.cnn_name == 'resnet18': 51 | self.net = models.resnet18(pretrained=self.pretraining) 52 | self.net.fc = nn.Linear(512, 40) 53 | elif self.cnn_name == 'resnet34': 54 | self.net = models.resnet34(pretrained=self.pretraining) 55 | self.net.fc = nn.Linear(512, 40) 56 | elif self.cnn_name == 'resnet50': 57 | self.net = models.resnet50(pretrained=self.pretraining) 58 | self.net.fc = nn.Linear(2048, 40) 59 | else: 60 | if self.cnn_name == 'alexnet': 61 | self.net_1 = models.alexnet(pretrained=self.pretraining).features 62 | self.net_2 = models.alexnet(pretrained=self.pretraining).classifier 63 | elif self.cnn_name == 'vgg11': 64 | self.net_1 = models.vgg11(pretrained=self.pretraining).features 65 | # self.net_2 = models.vgg11(pretrained=self.pretraining).classifier 66 | self.net_2 = torch.nn.Sequential( 67 | nn.Linear(25088, 4096, bias=True), # False 68 | nn.BatchNorm1d(4096), 69 | nn.ReLU(inplace=True), 70 | nn.Linear(4096, self.fd_phase1, bias=True) 71 | ) 72 | elif self.cnn_name == 'vgg16': 73 | self.net_1 = models.vgg16(pretrained=self.pretraining).features 74 | self.net_2 = models.vgg16(pretrained=self.pretraining).classifier 75 | 76 | # self.net_2._modules['6'] = nn.Linear(4096, self.fd) 77 | 78 | def forward(self, x): 79 | if self.use_resnet: 80 | return self.net(x) 81 | else: 82 | y = self.net_1(x) 83 | feature = F.normalize(self.net_2(y.view(y.shape[0], -1))) 84 | return feature 85 | 86 | 87 | class MVCNN(Model): 88 | 89 | def __init__(self, name, model, nclasses=40, cnn_name='vgg11', num_views=12, fd_phase1=32, fd_per_user=6): 90 | super(MVCNN, self).__init__(name) 91 | 92 | if nclasses == 10: 93 | self.classnames = ['bathtub', 'bed', 'chair', 'desk', 'dresser', 94 | 'monitor', 'night_stand', 'sofa', 'table', 'toilet'] 95 | elif nclasses == 40: 96 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 97 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box', 98 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 99 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 100 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'] 101 | 102 | self.nclasses = nclasses 103 | self.num_views = num_views 104 | self.fd_per_user = fd_per_user 105 | self.fd_phase1 = fd_phase1 106 | # self.mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False).cuda() 107 | # self.std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False).cuda() 108 | self.mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False) 109 | self.std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False) 110 | 111 | # self.use_resnet = cnn_name.startswith('resnet') 112 | # 113 | # if self.use_resnet: 114 | # self.net_1 = 
nn.Sequential(*list(model.net.children())[:-1]) 115 | # self.net_2 = model.net.fc 116 | # else: 117 | # self.net_1 = model.net_1 118 | # self.net_2 = model.net_2 119 | 120 | self.net_1 = model.net_1 121 | 122 | net_2_list = [] 123 | for _ in range(self.num_views): 124 | net_2 = copy.deepcopy(model.net_2) 125 | net_2._modules['4'] = nn.BatchNorm1d(self.fd_phase1) 126 | net_2._modules['5'] = nn.ReLU(inplace=True) 127 | net_2._modules['6'] = nn.Linear(self.fd_phase1, 32, bias=True) # False 128 | net_2._modules['7'] = nn.BatchNorm1d(32) 129 | net_2._modules['8'] = nn.ReLU(inplace=True) 130 | net_2._modules['9'] = nn.Linear(32, self.fd_per_user, bias=True) 131 | 132 | net_2_list.append(net_2) 133 | self.net_2 = nn.ModuleList(net_2_list) 134 | 135 | def forward(self, x): 136 | y = self.net_1(x) # (bs*views,512,7,7) 137 | y = y.view((int(x.shape[0] / self.num_views), self.num_views, -1)) # (bs,views,25088) 138 | 139 | feature_list = [] 140 | for i in range(self.num_views): 141 | feature_i = self.net_2[i](y[:, i]) # (bs, fd_i) 142 | feature_list.append(feature_i) 143 | 144 | feature = torch.cat(feature_list, dim=1) # (bs, fd_i*views) 145 | return F.normalize(feature) 146 | -------------------------------------------------------------------------------- /precoding_opt_matlab/H_slot_rician_channel_gen.m: -------------------------------------------------------------------------------- 1 | function H_slot = H_slot_rician_channel_gen(N_t_k, N_r, kappa) 2 | 3 | AoA = deg2rad(360*rand(1)); 4 | AoD = deg2rad(360*rand(1)); 5 | H_LoS = steering_vec(sin(AoA), N_r) * steering_vec(sin(AoD), N_t_k)'; 6 | H_NLoS = 1/sqrt(2*N_r*N_t_k) * (randn(N_r, N_t_k) + 1j*randn(N_r, N_t_k)); 7 | % H_NLoS = 1/sqrt(2) * (randn(N_r, N_t_k) + 1j*randn(N_r, N_t_k)); 8 | H_slot = sqrt(kappa/(1+kappa))*H_LoS + sqrt(1/(1+kappa))*H_NLoS; 9 | 10 | end -------------------------------------------------------------------------------- /precoding_opt_matlab/MCR2_ModelNet10_statistics_complex_24.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chang-cai/TaskCommMCR2/dfbba3b636d3aba5c8844e2393ead6657acd3518/precoding_opt_matlab/MCR2_ModelNet10_statistics_complex_24.mat -------------------------------------------------------------------------------- /precoding_opt_matlab/MCR2_ModelNet10_test_feature_label_complex_24.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chang-cai/TaskCommMCR2/dfbba3b636d3aba5c8844e2393ead6657acd3518/precoding_opt_matlab/MCR2_ModelNet10_test_feature_label_complex_24.mat -------------------------------------------------------------------------------- /precoding_opt_matlab/MCR2_obj_cplx.m: -------------------------------------------------------------------------------- 1 | function [mcr2, mcr2_1, mcr2_2] = MCR2_obj_cplx(V, H, C, CLS_cov, alpha, delta_0, p) 2 | 3 | Dim = size(H, 1); 4 | beta = 1 + alpha*delta_0^2; 5 | 6 | mcr2_1 = log( det( beta*eye(Dim) + alpha*H*V*C*V'*H' ) ); 7 | 8 | mcr2_2 = 0; 9 | class = size(p, 2); 10 | for j = 1:class 11 | C_j = CLS_cov(:, :, j); 12 | mcr2_2 = mcr2_2 + p(j) * ... 
13 |         log( det( beta*eye(Dim) + alpha*H*V*C_j*V'*H' ) );
14 | end
15 | % mcr2 = 0.5 * real(mcr2_1 - mcr2_2);
16 | % 
17 | % mcr2_1 = 0.5 * real(mcr2_1);
18 | % mcr2_2 = 0.5 * real(mcr2_2);
19 | mcr2 = real(mcr2_1 - mcr2_2);
20 | 
21 | mcr2_1 = real(mcr2_1);
22 | mcr2_2 = real(mcr2_2);
23 | end
--------------------------------------------------------------------------------
/precoding_opt_matlab/bisection_search_bar.m:
--------------------------------------------------------------------------------
1 | function lambda = bisection_search_bar(lambda_lb, lambda_ub, P, epsilon, g, Gamma, matrix_form)
2 | % Bisection search for the smallest dual variable lambda in [lambda_lb, lambda_ub] such that the power consumption f_bar(lambda) meets the budget P, up to tolerance epsilon.
3 | 
4 | lb = lambda_lb;
5 | ub = lambda_ub;
6 | lambda = lambda_ub;
7 | 
8 | while ub - lb >= epsilon
9 |     lambda = (lb + ub)/2;
10 |     if f_bar(lambda, g, Gamma, matrix_form) <= P
11 |         ub = lambda;
12 |     else
13 |         lb = lambda;
14 |     end
15 | end
16 | 
17 | end
--------------------------------------------------------------------------------
/precoding_opt_matlab/f_bar.m:
--------------------------------------------------------------------------------
1 | function fcn_bar = f_bar(lambda, g, Gamma, matrix_form)
2 | % Transmit power consumed as a function of the dual variable lambda; it is monotonically decreasing in lambda, which makes the bisection search in bisection_search_bar.m valid.
3 | 
4 | fcn_bar = 0;
5 | if matrix_form == false
6 |     for i = 1:size(Gamma, 1)
7 |         fcn_i = norm(g(i))^2/( (Gamma(i, i) + lambda)^2 );
8 |         fcn_bar = fcn_bar + fcn_i;
9 |     end
10 | else
11 |     for i = 1:size(Gamma, 1)
12 |         fcn_i = abs(g(i))/( (Gamma(i, i) + lambda)^2 );
13 |         fcn_bar = fcn_bar + fcn_i;
14 |     end
15 | end
16 | end
--------------------------------------------------------------------------------
/precoding_opt_matlab/gaussian_pdf.m:
--------------------------------------------------------------------------------
1 | function [f, log_f] = gaussian_pdf(x, mean, covariance, epsilon)
2 | % Multivariate Gaussian pdf (and its log) with diagonal loading epsilon for numerical stability
3 | 
4 | n = length(x);
5 | coef_1 = 1/sqrt((2*pi)^n);
6 | % det(covariance+epsilon*eye(n))
7 | coef_2 = 1/sqrt(det(covariance+epsilon*eye(n)));
8 | coef = coef_1 * coef_2;
9 | coef_3 = -0.5*(x-mean)'*inv(covariance+epsilon*eye(n))*(x-mean);
10 | expo = exp(coef_3);
11 | f = coef*expo;
12 | 
13 | %%
14 | log_f = log(coef)+coef_3;
15 | 
16 | end
--------------------------------------------------------------------------------
/precoding_opt_matlab/main_precoding_opt.m:
--------------------------------------------------------------------------------
1 | clc
2 | clear
3 | 
4 | load("MCR2_ModelNet10_statistics_complex_24.mat", "feature_cplx", "label",...
5 |     "C_xx_real", "mean_real", "C_xx", "CLS_mean", "CLS_cov", "CLS_rlt");
6 | load("MCR2_ModelNet10_test_feature_label_complex_24.mat", "p", "test_feature_cplx", "test_label");
7 | test_feature = test_feature_cplx.';
8 | 
9 | 
10 | B = 10*1e3; % 10 kHz
11 | K = 3; % num. of devices
12 | N_t_k = 4; % num. of transmit antennas per device
13 | N_t = N_t_k * K;
14 | N_r = 8; % num. of receive antennas
15 | T = 1; % num. of time slots
16 | 
17 | Dataset = size(test_feature, 2);
18 | % NMSE_factor = norm(test_feature, "fro")^2/Dataset;
19 | C = C_xx; % covariance matrix
20 | L = 10; % num.
of classes 21 | D = 2*size(test_feature, 1); 22 | 23 | delta_0_square = 1e-20*B; % noise power (W) -170 dBm/Hz density 24 | delta_0 = sqrt(delta_0_square); % noise 25 | P_k_dBm = 0; % dBm 26 | P_k = 1e-3*10.^(P_k_dBm./10); % W 27 | P = P_k * K; 28 | 29 | d = 240; % m 30 | PL = 32.6 + 36.7*log10(d); % dB 31 | PL = 10^(-PL/10); 32 | delta_0 = delta_0/sqrt(PL); % normalize path loss into the noise power 33 | 34 | eps = 1e-3; 35 | 36 | % channel generation 37 | kappa = 1; % Rician factor 38 | H = []; 39 | H_k_all = zeros(T*N_r, T*N_t_k, K); 40 | H_k_t_all = zeros(N_r, N_t_k, K); 41 | for k = 1:K 42 | H_k_t = H_slot_rician_channel_gen(N_t_k, N_r, kappa); 43 | H_k_t_all(:, :, k) = H_k_t; 44 | % svd(H_k_t) 45 | H_k = kron(eye(T), H_k_t); 46 | H_k_all(:, :, k) = H_k; 47 | H = [H, H_k]; 48 | end 49 | 50 | %% MCR2 Precoder 51 | alpha = T*N_r / eps^2; 52 | C_sr = sqrtm(C); 53 | 54 | % initialization of V 55 | V_init_k = rand(T*N_t_k, D/(2*K)) + 1j*rand(T*N_t_k, D/(2*K)); 56 | V_init = kron(eye(K), V_init_k); 57 | V_init = sqrt(P) * V_init ./ sqrt(trace(V_init*C*V_init')); 58 | V = V_init; 59 | 60 | Ite = 1000; 61 | mcr2_obj_last = 500; 62 | mcr2_all = zeros(1, Ite); 63 | for ii = 1:Ite 64 | ii; 65 | % U update 66 | U = (H*V*C*V'*H'+(1/alpha + delta_0^2)*eye(N_r*T))\(H*V*C_sr); 67 | 68 | % W update 69 | E_0 = eye(D/2)-U'*H*V*C_sr; 70 | E_0 = (E_0*E_0') + (1/alpha + delta_0^2)*(U'*U); 71 | W_0 = inv(E_0); 72 | W_j_all = zeros(N_r*T, N_r*T, L); 73 | for j = 1:L 74 | C_j = CLS_cov(:, :, j); 75 | W_j_all(:, :, j) = inv((1+alpha*delta_0^2)*eye(N_r*T) + alpha*H*V*C_j*V'*H'); 76 | end 77 | 78 | % V update 79 | for k = 1:K 80 | H_k = H_k_all(:, :, k); 81 | C_sr_k = C_sr(((k-1)*D/(2*K)+1):k*D/(2*K), :); 82 | C_k = C(((k-1)*D/(2*K)+1):k*D/(2*K), :); 83 | C_kk = C(((k-1)*D/(2*K)+1):k*D/(2*K), ((k-1)*D/(2*K)+1):k*D/(2*K)); 84 | V_k = V(((k-1)*T*N_t_k+1):k*T*N_t_k, ... 85 | ((k-1)*D/(2*K)+1):k*D/(2*K)); 86 | 87 | sum_term1_T_k = H*V*C_k' - H_k*V_k*C_kk; 88 | sum_term2_T_k = 0; % zeros(2*T*N_t_k, D/K); 89 | sum_term_M_k = 0; 90 | % I_kron_H_bar_k = kron(eye(D/(2*K)), H_k); 91 | for j = 1:L 92 | W_j = W_j_all(:, :, j); 93 | 94 | C_j_k = CLS_cov(((k-1)*D/(2*K)+1):k*D/(2*K), :, j); 95 | C_j_kk = CLS_cov(((k-1)*D/(2*K)+1):k*D/(2*K), ((k-1)*D/(2*K)+1):k*D/(2*K), j); 96 | sum_term2_j_T_k = H*V*C_j_k' - H_k*V_k*C_j_kk; 97 | sum_term2_T_k = sum_term2_T_k + ... 98 | alpha*p(j)*H_k'*W_j*sum_term2_j_T_k; 99 | 100 | % sum_term_M_k = sum_term_M_k + ... 101 | % alpha*p(j)*I_kron_H_bar_k'*kron(C_j_kk, W_j)*I_kron_H_bar_k; 102 | sum_term_M_k = sum_term_M_k + ... 103 | alpha*p(j)*kron(C_j_kk.', H_k'*W_j*H_k); 104 | end 105 | T_k = H_k'*U*W_0*C_sr_k' - ... 106 | H_k'*U*W_0*U'*sum_term1_T_k - sum_term2_T_k; 107 | % I_kron_U_H_bar_k = kron(eye(D/K), U'*H_k); 108 | M_k = kron(C_kk.', H_k'*U*W_0*U'*H_k) + ... 
109 | sum_term_M_k; 110 | 111 | [Q_kk, D_kk] = svd(C_kk.'); 112 | Q_kk_kron_I = kron(Q_kk, eye(N_t_k*T)); 113 | D_kk_kron_I = kron(D_kk, eye(N_t_k*T)); 114 | D_kk_kron_I_msr = diag( 1./sqrt(diag(D_kk_kron_I)) ); 115 | C_kk_kron_I_msr = Q_kk_kron_I * D_kk_kron_I_msr * Q_kk_kron_I'; 116 | 117 | t_k = C_kk_kron_I_msr*vec(T_k); 118 | M_k = C_kk_kron_I_msr*M_k*C_kk_kron_I_msr; 119 | 120 | 121 | [U_M_k, Gamma_k] = svd(M_k); 122 | g_k = U_M_k' * t_k; 123 | 124 | lambda_lb = 0; 125 | lambda_ub = sqrt(real(g_k'*g_k)/(P_k*T)); 126 | 127 | fcn_bar_0 = f_bar(0, g_k, Gamma_k, false); 128 | epsilon_bi_search = 1e-6; 129 | if fcn_bar_0 <= P_k 130 | lambda = 0; 131 | else 132 | lambda = bisection_search_bar(lambda_lb, lambda_ub, T*P_k, epsilon_bi_search, g_k, Gamma_k, false); 133 | end 134 | 135 | if lambda < 1e-20 136 | v_k = pinv( M_k + lambda*eye(N_t_k*T*D/(2*K)) ) * t_k; 137 | else 138 | v_k = ( M_k + lambda*eye(N_t_k*T*D/(2*K)) ) \ t_k; 139 | end 140 | 141 | v_k = C_kk_kron_I_msr * v_k; 142 | 143 | % retrive V_k from v_k 144 | V_k = zeros(T*N_t_k, D/(2*K)); 145 | for i = 1:D/(2*K) 146 | V_k(:, i) = v_k(T*N_t_k*(i-1)+1: T*N_t_k*i); 147 | end 148 | V((T*N_t_k*(k-1)+1):T*N_t_k*k, ((k-1)*D/(2*K)+1):k*D/(2*K)) = V_k; 149 | end 150 | [mcr2, ~, ~] = MCR2_obj_cplx(V, H, C, CLS_cov, alpha, delta_0, p); 151 | mcr2_all(ii) = mcr2; 152 | 153 | % terminate criterion 154 | if abs((mcr2 - mcr2_obj_last)/mcr2_obj_last) < 1e-5 155 | break; 156 | end 157 | mcr2_obj_last = mcr2; 158 | end 159 | figure; 160 | plot(mcr2_all(1:ii), 'r-', LineWidth=1.6); 161 | [mcr2, ~, ~] = MCR2_obj_cplx(V, H, C, CLS_cov, alpha, delta_0, p); 162 | 163 | %% MAP receiver and LMMSE estimator 164 | epsilon = 1e-6; % trick 165 | acc_mcr2 = 0; 166 | mse_by_mcr2 = 0; 167 | parfor j = 1:Dataset 168 | z_cplx = test_feature(:, j); 169 | % add noise 170 | n_cplx = delta_0 * sqrt(1/2) * (randn(N_r*T, 1) + 1j*randn(N_r*T, 1)); 171 | C_n_cplx = delta_0^2 * eye(N_r*T); 172 | y_cplx = H*V*z_cplx + n_cplx; 173 | 174 | %% Recover signals from complex to real 175 | z_real = [real(z_cplx); imag(z_cplx)]; 176 | y_real = [real(y_cplx); imag(y_cplx)]; 177 | HV_real = [real(H*V), -imag(H*V);... 178 | imag(H*V), real(H*V)]; 179 | C_n_real = 1/2 * [real(C_n_cplx), zeros(N_r*T, N_r*T); 180 | zeros(N_r*T, N_r*T), real(C_n_cplx)]; 181 | 182 | %% 183 | f = zeros(1, L); 184 | log_f = zeros(1, L); % numerical issue 185 | weighted_sum_element = zeros(1, L); 186 | for i = 1:L 187 | mu_i_cplx = CLS_mean(:, i); 188 | C_i = CLS_cov(:, :, i); 189 | R_i = CLS_rlt(:, :, i); 190 | 191 | mu_i = [real(mu_i_cplx); imag(mu_i_cplx)]; 192 | Cov_i = 1/2 * [real(C_i + R_i), imag(-C_i + R_i);... 193 | imag(C_i + R_i), real(C_i - R_i)]; 194 | 195 | [f(i), log_f(i)] = gaussian_pdf(y_real, HV_real*mu_i,... 196 | HV_real*Cov_i*HV_real' + C_n_real, epsilon); 197 | weighted_sum_element(i) = p(i) * f(i); 198 | end 199 | weighted_sum = sum(weighted_sum_element); 200 | 201 | % GM receiver 202 | % [~, pos] = max(log_f); 203 | [~, pos] = max(weighted_sum_element); 204 | if (pos-1)==test_label(j) 205 | acc_mcr2 = acc_mcr2 +1; 206 | end 207 | 208 | % LMMSE estimate 209 | z_lmmse_estimate = mean_real + C_xx_real*HV_real'*... 210 | inv(HV_real*C_xx_real*HV_real'+C_n_real+epsilon*eye(2*N_r*T))*... 
211 | (y_real-HV_real*mean_real); 212 | mse_by_mcr2 = mse_by_mcr2 + norm(z_lmmse_estimate-z_real, 2)^2; 213 | 214 | end 215 | acc_mcr2 = acc_mcr2/Dataset 216 | mse_by_mcr2 = mse_by_mcr2/Dataset 217 | 218 | 219 | %% LMMSE Precoder 220 | alpha = T*N_r / eps^2; 221 | 222 | % initialization of V 223 | V_init_k = rand(T*N_t_k, D/(2*K)) + 1j*rand(T*N_t_k, D/(2*K)); 224 | V_init = kron(eye(K), V_init_k); 225 | V_init = sqrt(P) * V_init ./ sqrt(trace(V_init*C*V_init')); 226 | V = V_init; 227 | 228 | Ite = 1000; 229 | mse_obj_last = 500; 230 | mse_all = zeros(1, Ite); 231 | for ii = 1:Ite 232 | ii; 233 | % R update 234 | R = (H*V*C*V'*H'+(delta_0^2)*eye(N_r*T))\(H*V*C); 235 | 236 | % V update 237 | for k = 1:K 238 | H_k = H_k_all(:, :, k); 239 | C_k = C(((k-1)*D/(2*K)+1):k*D/(2*K), :); 240 | C_kk = C(((k-1)*D/(2*K)+1):k*D/(2*K), ((k-1)*D/(2*K)+1):k*D/(2*K)); 241 | V_k = V(((k-1)*T*N_t_k+1):k*T*N_t_k, ... 242 | ((k-1)*D/(2*K)+1):k*D/(2*K)); 243 | 244 | sum_term_T_k = H*V*C_k' - H_k*V_k*C_kk; 245 | J_k = C_k - sum_term_T_k'*R; 246 | 247 | Q_k = H_k'*R*J_k'/C_kk; 248 | M_k = H_k'*R*R'*H_k; 249 | 250 | [U_M_k, Gamma_k] = svd(M_k); 251 | G_k = U_M_k'*Q_k*C_kk*Q_k'*U_M_k; 252 | 253 | lambda_lb = 0; 254 | lambda_ub = sqrt(real(sum(diag(G_k)))/(T*P_k)); 255 | 256 | fcn_bar_0 = f_bar(0, diag(G_k), Gamma_k, true); 257 | epsilon_bi_search = 1e-6; 258 | if fcn_bar_0 <= T*P_k 259 | lambda = 0; 260 | else 261 | lambda = bisection_search_bar(lambda_lb, lambda_ub, T*P_k, epsilon_bi_search, diag(G_k), Gamma_k, true); 262 | end 263 | V_k = (M_k + lambda*eye(size(M_k)))\Q_k; 264 | 265 | V((T*N_t_k*(k-1)+1):T*N_t_k*k, ((k-1)*D/(2*K)+1):k*D/(2*K)) = V_k; 266 | end 267 | matrix_inv = inv(H*V*C*V'*H'+delta_0^2*eye(N_r*T)); 268 | mse_obj = real( trace(C) - trace(C*V'*H'*matrix_inv*H*V*C) ); 269 | mse_all(ii) = mse_obj; 270 | 271 | % terminate criterion 272 | if abs((mse_obj - mse_obj_last)/mse_obj_last) < 1e-5 273 | break; 274 | end 275 | mse_obj_last = mse_obj; 276 | end 277 | figure; 278 | plot(mse_all(1:ii), 'r-', LineWidth=1.6) 279 | [mcr2, ~, ~] = MCR2_obj_cplx(V, H, C, CLS_cov, alpha, delta_0, p); 280 | 281 | %% MAP receiver and LMMSE estimator 282 | epsilon = 1e-6; % trick 283 | acc_lmmse = 0; 284 | mse_by_lmmse = 0; 285 | parfor j = 1:Dataset 286 | z_cplx = test_feature(:, j); 287 | % add noise 288 | n_cplx = delta_0 * sqrt(1/2) * (randn(N_r*T, 1) + 1j*randn(N_r*T, 1)); 289 | C_n_cplx = delta_0^2 * eye(N_r*T); 290 | y_cplx = H*V*z_cplx + n_cplx; 291 | 292 | %% Recover signals from complex to real 293 | z_real = [real(z_cplx); imag(z_cplx)]; 294 | y_real = [real(y_cplx); imag(y_cplx)]; 295 | HV_real = [real(H*V), -imag(H*V);... 296 | imag(H*V), real(H*V)]; 297 | C_n_real = 1/2 * [real(C_n_cplx), zeros(N_r*T, N_r*T); 298 | zeros(N_r*T, N_r*T), real(C_n_cplx)]; 299 | 300 | %% 301 | f = zeros(1, L); 302 | log_f = zeros(1, L); % numerical issue 303 | weighted_sum_element = zeros(1, L); 304 | for i = 1:L 305 | mu_i_cplx = CLS_mean(:, i); 306 | C_i = CLS_cov(:, :, i); 307 | R_i = CLS_rlt(:, :, i); 308 | 309 | mu_i = [real(mu_i_cplx); imag(mu_i_cplx)]; 310 | Cov_i = 1/2 * [real(C_i + R_i), imag(-C_i + R_i);... 311 | imag(C_i + R_i), real(C_i - R_i)]; 312 | 313 | [f(i), log_f(i)] = gaussian_pdf(y_real, HV_real*mu_i,... 
314 |             HV_real*Cov_i*HV_real' + C_n_real, epsilon);
315 |         weighted_sum_element(i) = p(i) * f(i);
316 |     end
317 |     weighted_sum = sum(weighted_sum_element);
318 | 
319 |     % GM receiver
320 |     % [~, pos] = max(log_f);
321 |     [~, pos] = max(weighted_sum_element);
322 |     if (pos-1)==test_label(j)
323 |         acc_lmmse = acc_lmmse +1;
324 |     end
325 | 
326 |     % LMMSE estimate
327 |     z_lmmse_estimate = mean_real + C_xx_real*HV_real'*...
328 |         inv(HV_real*C_xx_real*HV_real'+C_n_real+epsilon*eye(2*N_r*T))*...
329 |         (y_real-HV_real*mean_real);
330 |     mse_by_lmmse = mse_by_lmmse + norm(z_lmmse_estimate-z_real, 2)^2;
331 | 
332 | end
333 | acc_lmmse = acc_lmmse/Dataset
334 | mse_by_lmmse = mse_by_lmmse/Dataset
335 | 
--------------------------------------------------------------------------------
/precoding_opt_matlab/steering_vec.m:
--------------------------------------------------------------------------------
1 | function [e] = steering_vec(x,M)
2 | % steering vector function
3 | m = (1:M)';
4 | e = 1/sqrt(M) * exp(-1j*pi*(m-1)*x);
5 | % e = exp(-1j*pi*(m-1)*x);
6 | end
7 | 
8 | 
--------------------------------------------------------------------------------
/tools/evaluate_func.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.svm import LinearSVC
3 | from sklearn.decomposition import PCA
4 | from sklearn.decomposition import TruncatedSVD
5 | 
6 | from tools import utils  # package-qualified import so that tools.evaluate_func is importable from the repo root
7 | 
8 | 
9 | def svm(train_features, train_labels, test_features, test_labels):
10 |     svm = LinearSVC(verbose=0, random_state=10)
11 |     svm.fit(train_features, train_labels)
12 |     acc_train = svm.score(train_features, train_labels)
13 |     acc_test = svm.score(test_features, test_labels)
14 |     print("acc train SVM: {}".format(acc_train))
15 |     print("acc test SVM: {}".format(acc_test))
16 |     return acc_train, acc_test
17 | 
18 | 
19 | def knn(k, train_features, train_labels, test_features, test_labels):
20 |     """Perform k-Nearest Neighbor classification using cosine similarity as the metric.
21 | 
22 |     Options:
23 |         k (int): top k features for kNN
24 | 
25 |     """
26 |     sim_mat = train_features @ test_features.T
27 |     topk = sim_mat.topk(k=k, dim=0)
28 |     topk_pred = train_labels[topk.indices]
29 |     test_pred = topk_pred.mode(0).values.detach()
30 |     acc = utils.compute_accuracy(test_pred.numpy(), test_labels.numpy())
31 |     print("kNN: {}".format(acc))
32 |     return acc
33 | 
34 | 
35 | def nearsub(n_comp, train_features, train_labels, test_features, test_labels):
36 |     """Perform nearest subspace classification.
37 | 38 | Options: 39 | n_comp (int): number of components for PCA or SVD 40 | 41 | """ 42 | scores_pca = [] 43 | scores_svd = [] 44 | num_classes = train_labels.numpy().max() + 1 # should be correct most of the time 45 | features_sort, _ = utils.sort_dataset(train_features.numpy(), train_labels.numpy(), 46 | num_classes=num_classes, stack=False) 47 | fd = features_sort[0].shape[1] 48 | for j in range(num_classes): 49 | pca = PCA(n_components=n_comp).fit(features_sort[j]) 50 | pca_subspace = pca.components_.T 51 | mean = np.mean(features_sort[j], axis=0) 52 | pca_j = (np.eye(fd) - pca_subspace @ pca_subspace.T) \ 53 | @ (test_features.numpy() - mean).T 54 | score_pca_j = np.linalg.norm(pca_j, ord=2, axis=0) 55 | 56 | svd = TruncatedSVD(n_components=n_comp).fit(features_sort[j]) 57 | svd_subspace = svd.components_.T 58 | svd_j = (np.eye(fd) - svd_subspace @ svd_subspace.T) \ 59 | @ (test_features.numpy()).T 60 | score_svd_j = np.linalg.norm(svd_j, ord=2, axis=0) 61 | 62 | scores_pca.append(score_pca_j) 63 | scores_svd.append(score_svd_j) 64 | test_predict_pca = np.argmin(scores_pca, axis=0) 65 | test_predict_svd = np.argmin(scores_svd, axis=0) 66 | acc_pca = utils.compute_accuracy(test_predict_pca, test_labels.numpy()) 67 | acc_svd = utils.compute_accuracy(test_predict_svd, test_labels.numpy()) 68 | print('PCA: {}'.format(acc_pca)) 69 | print('SVD: {}'.format(acc_svd)) 70 | return acc_svd -------------------------------------------------------------------------------- /tools/img_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import torch.utils.data 4 | import os 5 | import math 6 | from skimage import io, transform 7 | from PIL import Image 8 | import torch 9 | import torchvision as vision 10 | from torchvision import transforms, datasets 11 | import random 12 | 13 | 14 | class MultiviewImgDataset(torch.utils.data.Dataset): 15 | 16 | def __init__(self, root_dir, scale_aug=False, rot_aug=False, test_mode=False, \ 17 | num_classes=40, num_views=12, train_objects=80, test_objects=20, shuffle=True): 18 | 19 | if num_classes == 10: 20 | self.classnames = ['bathtub', 'bed', 'chair', 'desk', 'dresser', 21 | 'monitor', 'night_stand', 'sofa', 'table', 'toilet'] 22 | elif num_classes == 40: 23 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 24 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box', 25 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 26 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 27 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'] 28 | 29 | self.num_classes = num_classes 30 | self.root_dir = root_dir 31 | self.scale_aug = scale_aug 32 | self.rot_aug = rot_aug 33 | self.test_mode = test_mode 34 | self.num_views = num_views 35 | self.train_objects = train_objects 36 | self.test_objects = test_objects 37 | self.train_file_number = self.train_objects * self.num_views 38 | self.test_file_number = self.test_objects * self.num_views 39 | 40 | set_ = root_dir.split('/')[-1] 41 | parent_dir = root_dir.rsplit('/', 2)[0] 42 | self.filepaths = [] 43 | for i in range(len(self.classnames)): 44 | all_files = sorted(glob.glob(parent_dir + '/' + self.classnames[i] + '/' + set_ + '/*.png')) 45 | ## Select subset for different number of views 46 | stride = int(12 / self.num_views) # 12 6 4 3 2 1 47 | all_files = all_files[::stride] 48 | 49 | 
if self.test_mode: 50 | self.filepaths.extend(all_files[:min(self.test_file_number, len(all_files))]) 51 | else: 52 | self.filepaths.extend(all_files[:min(self.train_file_number, len(all_files))]) 53 | 54 | if shuffle == True: 55 | # permute 56 | rand_idx = np.random.permutation(int(len(self.filepaths) / num_views)) 57 | filepaths_new = [] 58 | for i in range(len(rand_idx)): 59 | filepaths_new.extend(self.filepaths[rand_idx[i] * num_views:(rand_idx[i] + 1) * num_views]) 60 | self.filepaths = filepaths_new 61 | 62 | if self.test_mode: 63 | self.transform = transforms.Compose([ 64 | transforms.ToTensor(), 65 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 66 | std=[0.229, 0.224, 0.225]) 67 | ]) 68 | else: 69 | self.transform = transforms.Compose([ 70 | transforms.RandomHorizontalFlip(), 71 | transforms.ToTensor(), 72 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 73 | std=[0.229, 0.224, 0.225]) 74 | ]) 75 | 76 | def __len__(self): 77 | return int(len(self.filepaths) / self.num_views) 78 | 79 | def __getitem__(self, idx): 80 | path = self.filepaths[idx * self.num_views] 81 | class_name = path.split('/')[-3] # Linux 82 | # class_name = path.split('/')[-2] # Win 83 | class_id = self.classnames.index(class_name) 84 | # Use PIL instead 85 | imgs = [] 86 | for i in range(self.num_views): 87 | im = Image.open(self.filepaths[idx * self.num_views + i]).convert('RGB') 88 | if self.transform: 89 | im = self.transform(im) 90 | imgs.append(im) 91 | 92 | return (class_id, torch.stack(imgs), self.filepaths[idx * self.num_views:(idx + 1) * self.num_views]) 93 | 94 | 95 | class SingleImgDataset(torch.utils.data.Dataset): 96 | 97 | def __init__(self, root_dir, scale_aug=False, rot_aug=False, test_mode=False, 98 | num_classes=40, num_views=12, train_objects=80, test_objects=20): 99 | 100 | if num_classes == 10: 101 | self.classnames = ['bathtub', 'bed', 'chair', 'desk', 'dresser', 102 | 'monitor', 'night_stand', 'sofa', 'table', 'toilet'] 103 | elif num_classes == 40: 104 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair', 105 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box', 106 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand', 107 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 108 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox'] 109 | 110 | self.num_classes = num_classes 111 | self.root_dir = root_dir 112 | self.scale_aug = scale_aug 113 | self.rot_aug = rot_aug 114 | self.test_mode = test_mode 115 | self.num_views = num_views 116 | self.train_objects = train_objects 117 | self.test_objects = test_objects 118 | self.train_file_number = self.train_objects * self.num_views 119 | self.test_file_number = self.test_objects * self.num_views 120 | 121 | set_ = root_dir.split('/')[-1] 122 | parent_dir = root_dir.rsplit('/', 2)[0] 123 | self.filepaths = [] 124 | for i in range(len(self.classnames)): 125 | all_files = sorted(glob.glob(parent_dir + '/' + self.classnames[i] + '/' + set_ + '/*shaded*.png')) 126 | if self.test_mode: 127 | self.filepaths.extend(all_files[:min(self.test_file_number, len(all_files))]) 128 | else: 129 | self.filepaths.extend(all_files[:min(self.train_file_number, len(all_files))]) 130 | 131 | self.transform = transforms.Compose([ 132 | transforms.RandomHorizontalFlip(), 133 | transforms.ToTensor(), 134 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 135 | std=[0.229, 0.224, 0.225]) 136 | ]) 137 | 138 | def 
__len__(self):
139 |         return len(self.filepaths)
140 | 
141 |     def __getitem__(self, idx):
142 |         path = self.filepaths[idx]
143 |         class_name = path.split('/')[-3] # Linux
144 |         # class_name = path.split('/')[-2] # Win
145 |         class_id = self.classnames.index(class_name)
146 | 
147 |         # Use PIL instead
148 |         im = Image.open(self.filepaths[idx]).convert('RGB')
149 |         if self.transform:
150 |             im = self.transform(im)
151 | 
152 |         return (class_id, im, path)
153 | 
--------------------------------------------------------------------------------
/tools/train_func.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | 
4 | # import cv2
5 | import numpy as np
6 | import torch
7 | import torch.nn
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | from torch.utils.data import DataLoader
11 | 
12 | # from cluster import ElasticNetSubspaceClustering, clustering_accuracy
13 | from tools import utils  # package-qualified import so that tools.train_func is importable from the repo root
14 | 
15 | 
16 | def get_features(net, trainloader, loss_name, model_name, verbose=True):
17 |     '''Extract all features out into one single batch.
18 | 
19 |     Parameters:
20 |         net (torch.nn.Module): get features using this model
21 |         trainloader (torchvision.dataloader): dataloader for loading data
22 |         verbose (bool): shows loading status bar
23 | 
24 |     Returns:
25 |         features (torch.tensor): with dimension (num_samples, feature_dimension)
26 |         labels (torch.tensor): with dimension (num_samples, )
27 |     '''
28 |     features = []
29 |     labels = []
30 |     if verbose:
31 |         train_bar = tqdm(trainloader, desc="extracting all features from dataset")
32 |     else:
33 |         train_bar = trainloader
34 |     for step, data in enumerate(train_bar):
35 |         if model_name == 'mvcnn':
36 |             N, V, C, H, W = data[1].size()
37 |             batch_imgs = data[1].view(-1, C, H, W)
38 |         else:
39 |             batch_imgs = data[1]
40 |         batch_lbls = data[0].long()
41 | 
42 |         if loss_name == 'mcr2':
43 |             batch_features = net(batch_imgs.cpu())
44 |         else:  # 'ce'
45 |             _, batch_features = net(batch_imgs.cpu())
46 | 
47 |         features.append(batch_features.cpu().detach())
48 |         # features.append(batch_features)
49 |         labels.append(batch_lbls)
50 |     return torch.cat(features), torch.cat(labels)
51 | 
52 | 
53 | def label_to_membership(targets, num_classes=None):
54 |     """Generate a true membership matrix, and assign value to current Pi.
55 | 
56 |     Parameters:
57 |         targets (np.ndarray): array of integer class labels
58 | 
59 |     Return:
60 |         Pi: membership matrix, shape (num_classes, num_samples, num_samples)
61 | 
62 |     """
63 |     targets = one_hot(targets, num_classes)
64 |     num_samples, num_classes = targets.shape
65 |     Pi = np.zeros(shape=(num_classes, num_samples, num_samples))
66 |     for j in range(len(targets)):
67 |         k = np.argmax(targets[j])
68 |         Pi[k, j, j] = 1.
69 |     return Pi
70 | 
71 | 
72 | def membership_to_label(membership):
73 |     """Turn a membership matrix into a list of labels."""
74 |     _, num_classes, num_samples, _ = membership.shape
75 |     labels = np.zeros(num_samples)
76 |     for i in range(num_samples):
77 |         labels[i] = np.argmax(membership[:, i, i])
78 |     return labels
79 | 
80 | def one_hot(labels_int, n_classes):
81 |     """Turn labels into one hot vector of K classes. """
82 |     labels_onehot = torch.zeros(size=(len(labels_int), n_classes)).float()
83 |     for i, y in enumerate(labels_int):
84 |         labels_onehot[i, y] = 1.
85 | return labels_onehot 86 | -------------------------------------------------------------------------------- /tools/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import os 5 | from tensorboardX import SummaryWriter 6 | 7 | from tools import train_func as tf 8 | from tools.evaluate_func import svm 9 | 10 | 11 | class ModelNetTrainer(object): 12 | 13 | def __init__(self, model, device, train_loader, val_loader, optimizer, loss_fn, \ 14 | model_name, model_dir, num_classes, num_views=12): 15 | 16 | self.optimizer = optimizer 17 | self.model = model 18 | self.device = device 19 | self.train_loader = train_loader 20 | self.val_loader = val_loader 21 | self.loss_fn = loss_fn 22 | self.model_name = model_name 23 | self.num_views = num_views 24 | self.num_classes = num_classes 25 | 26 | self.model_dir = model_dir 27 | self.checkpoints_dir = os.path.join(model_dir, 'checkpoints') 28 | self.tensorboard_dir = os.path.join(model_dir, 'tensorboard') 29 | 30 | # self.model.cuda() 31 | self.model.to(self.device) 32 | 33 | self.writer = SummaryWriter(self.tensorboard_dir) 34 | 35 | def train(self, n_epochs): 36 | 37 | if self.model_name == 'mvcnn': 38 | for param in self.model.net_1.parameters(): 39 | param.requires_grad = False 40 | 41 | i_acc = 0 42 | self.model.train() 43 | for epoch in range(n_epochs): 44 | # permute data 45 | rand_idx = np.random.permutation(int(len(self.train_loader.dataset.filepaths)/self.num_views)) 46 | filepaths_new = [] 47 | for i in range(len(rand_idx)): 48 | filepaths_new.extend(self.train_loader.dataset.filepaths[rand_idx[i]*self.num_views:(rand_idx[i]+1)*self.num_views]) 49 | self.train_loader.dataset.filepaths = filepaths_new 50 | 51 | # plot learning rate 52 | lr = self.optimizer.state_dict()['param_groups'][0]['lr'] 53 | self.writer.add_scalar('params/lr', lr, epoch) 54 | 55 | for i, data in enumerate(self.train_loader): 56 | 57 | if self.model_name == 'mvcnn': 58 | N,V,C,H,W = data[1].size() 59 | in_data = Variable(data[1]).view(-1, C, H, W).to(self.device) 60 | else: 61 | in_data = Variable(data[1]).to(self.device) 62 | target = Variable(data[0]).to(self.device).long() 63 | 64 | self.optimizer.zero_grad() 65 | 66 | out_data = self.model(in_data) 67 | loss, [discrimn_loss, compress_loss] = self.loss_fn(out_data, target, num_classes=self.num_classes) 68 | 69 | self.writer.add_scalars('train_loss', 70 | {'loss': loss, 'discrimn': discrimn_loss, 'compress': compress_loss}, 71 | i_acc+i+1) 72 | 73 | loss.backward() 74 | self.optimizer.step() 75 | 76 | # for name in self.model.state_dict(): 77 | # print(name) 78 | # print(self.model.state_dict()['net_1.6.bias']) 79 | # print(self.model.state_dict()['net_2.2.9.weight']) 80 | 81 | log_str = 'epoch %d, step %d: train_loss %.3f' % (epoch+1, i+1, loss) 82 | if (i+1) % 1 == 0: 83 | print(log_str) 84 | i_acc += i+1 85 | 86 | # evaluation 87 | if (epoch+1) % 10 == 0: 88 | with torch.no_grad(): 89 | acc_train_svm, acc_test_svm = self.update_validation_accuracy() 90 | self.writer.add_scalars('acc', 91 | {'acc_train_svm': acc_train_svm, 'acc_test_svm': acc_test_svm}, 92 | epoch+1) 93 | 94 | # save the model 95 | if (epoch + 1) % 10 == 0: 96 | self.model.save(self.checkpoints_dir, epoch) 97 | 98 | # adjust learning rate manually 99 | if epoch > 0 and (epoch+1) % 40 == 0: 100 | for param_group in self.optimizer.param_groups: 101 | param_group['lr'] = param_group['lr']*0.5 102 | 103 | # export scalar data 
-------------------------------------------------------------------------------- /tools/utils.py: --------------------------------------------------------------------------------
import os
import json
import numpy as np
import torch


def sort_dataset(data, labels, num_classes=10, stack=False):
    """Sort dataset based on classes.

    Parameters:
        data (np.ndarray): data array
        labels (np.ndarray): one-dimensional array of class labels
        num_classes (int): number of classes
        stack (bool): combine sorted data into one numpy array

    Return:
        sorted_data (np.ndarray), sorted_labels (np.ndarray)

    """
    sorted_data = [[] for _ in range(num_classes)]
    for i, lbl in enumerate(labels):
        sorted_data[lbl].append(data[i])
    sorted_data = [np.stack(class_data) for class_data in sorted_data]
    sorted_labels = [np.repeat(i, (len(sorted_data[i]))) for i in range(num_classes)]
    if stack:
        sorted_data = np.vstack(sorted_data)
        sorted_labels = np.hstack(sorted_labels)
    return sorted_data, sorted_labels
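A minimal sketch of `sort_dataset` on toy data (not part of the repo):

```
import numpy as np
from tools.utils import sort_dataset

data = np.random.randn(6, 4)            # six samples, four features
labels = np.array([1, 0, 2, 0, 1, 2])

# default: a list of per-class arrays, one entry per class
per_class, per_class_lbls = sort_dataset(data, labels, num_classes=3)
print([x.shape for x in per_class])     # [(2, 4), (2, 4), (2, 4)]

# stack=True: one array with samples reordered class by class
stacked, stacked_lbls = sort_dataset(data, labels, num_classes=3, stack=True)
print(stacked.shape, stacked_lbls)      # (6, 4) [0 0 1 1 2 2]
```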
""" 68 | params = load_params(model_dir) 69 | old_params = load_params(pretrain_dir) 70 | params['arch'] = old_params["arch"] 71 | params['fd'] = old_params['fd'] 72 | save_params(model_dir, params) 73 | 74 | def load_params(model_dir): 75 | """Load params.json file in model directory and return dictionary.""" 76 | _path = os.path.join(model_dir, "params.json") 77 | with open(_path, 'r') as f: 78 | _dict = json.load(f) 79 | return _dict 80 | 81 | def save_state(model_dir, *entries, filename='losses.csv'): 82 | """Save entries to csv. Entries is list of numbers. """ 83 | csv_path = os.path.join(model_dir, filename) 84 | assert os.path.exists(csv_path), 'CSV file is missing in project directory.' 85 | with open(csv_path, 'a') as f: 86 | f.write('\n'+','.join(map(str, entries))) 87 | 88 | def save_ckpt(model_dir, net, epoch): 89 | """Save PyTorch checkpoint to ./checkpoints/ directory in model directory. """ 90 | torch.save(net.state_dict(), os.path.join(model_dir, 'checkpoints', 91 | 'model-epoch{}.pt'.format(epoch))) 92 | 93 | def save_labels(model_dir, labels, epoch): 94 | """Save labels of a certain epoch to directory. """ 95 | path = os.path.join(model_dir, 'plabels', f'epoch{epoch}.npy') 96 | np.save(path, labels) 97 | 98 | def compute_accuracy(y_pred, y_true): 99 | """Compute accuracy by counting correct classification. """ 100 | assert y_pred.shape == y_true.shape 101 | return 1 - np.count_nonzero(y_pred - y_true) / y_true.size 102 | 103 | def clustering_accuracy(labels_true, labels_pred): 104 | """Compute clustering accuracy.""" 105 | from sklearn.metrics.cluster import supervised 106 | from scipy.optimize import linear_sum_assignment 107 | labels_true, labels_pred = supervised.check_clusterings(labels_true, labels_pred) 108 | value = supervised.contingency_matrix(labels_true, labels_pred) 109 | [r, c] = linear_sum_assignment(-value) 110 | return value[r, c].sum() / len(labels_true) --------------------------------------------------------------------------------