├── README.md
├── feature_load.py
├── figures
│   ├── consistent_system_model.png
│   ├── inconsistent_system_model.png
│   └── sphere_packing_mcr2.png
├── main_train_phase1.py
├── main_train_phase2.py
├── mcr2_loss.py
├── models
│   ├── model.py
│   └── mvcnn.py
├── precoding_opt_matlab
│   ├── H_slot_rician_channel_gen.m
│   ├── MCR2_ModelNet10_statistics_complex_24.mat
│   ├── MCR2_ModelNet10_test_feature_label_complex_24.mat
│   ├── MCR2_obj_cplx.m
│   ├── bisection_search_bar.m
│   ├── f_bar.m
│   ├── gaussian_pdf.m
│   ├── main_precoding_opt.m
│   └── steering_vec.m
└── tools
    ├── evaluate_func.py
    ├── img_dataset.py
    ├── train_func.py
    ├── trainer.py
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # TaskCommMCR2
2 |
3 | This repository is the official implementation of the paper:
4 |
5 | - **Multi-Device Task-Oriented Communication via Maximal Coding Rate Reduction** [[IEEE TWC](https://ieeexplore.ieee.org/abstract/document/10689268)] [[arXiv](https://arxiv.org/abs/2309.02888)]
6 | - **Authors:** [Chang Cai](https://chang-cai.github.io/) (The Chinese University of Hong Kong), [Xiaojun Yuan](https://scholar.google.com/citations?user=o6W_m00AAAAJ&hl=en) (University of Electronic Science and Technology of China), and [Ying-Jun Angela Zhang](https://staff.ie.cuhk.edu.hk/~yjzhang/) (The Chinese University of Hong Kong)
7 |
8 | ## Brief Introduction
9 |
10 | ### Existing Studies: Inconsistent Objectives for Learning and Communication
11 |
12 | ![Inconsistent system model](figures/inconsistent_system_model.png)
13 |
16 | ### This Work: Synergistic Alignment of Learning and Communication Objectives
17 |
18 | ![Consistent system model](figures/consistent_system_model.png)
19 |
20 | ![Sphere packing interpretation of MCR2](figures/sphere_packing_mcr2.png)
21 |
23 | ## Usage
24 |
25 | ### Feature Encoding
26 | - Download the shaded images and put them under ```modelnet40_images_new_12x```: [Shaded Images (1.6GB)](http://supermoe.cs.umass.edu/shape_recog/shaded_images.tar.gz).
27 | If the link does not work, you can download them from my [Google Drive](https://drive.google.com/file/d/1tghRek04_pVHCkYOTOQeYtGlENlgQzhM/view?usp=sharing) backup.
28 |
29 | - Set up the environment: the code is tested on ```python 3.7.13``` and ```pytorch 1.12.1```.
30 |
31 | - Run the script ```main_train_phase1.py``` for the first-phase training of feature encoding.
32 | Then, run the script ```main_train_phase2.py``` for the second-phase training of feature encoding. Check [mvcnn_pytorch](https://github.com/jongchyisu/mvcnn_pytorch) for the details of the two training phases.
33 |
34 | - Alternatively, download the pretrained checkpoints at [Google Drive](https://drive.google.com/drive/folders/1bi2kMot2XI3H27MitiCE6ecxATnGIn5r?usp=drive_link). The checkpoints can be used for feature extraction by running the script ```feature_load.py```.
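
Both training phases minimize the maximal coding rate reduction (MCR2) objective implemented in ```mcr2_loss.py```. As a quick orientation, here is a minimal sketch of how the loss is invoked on a batch of features; the batch size, feature dimension, and class count below are illustrative placeholders, not the paper's settings:

```python
import torch
import torch.nn.functional as F

from mcr2_loss import MaximalCodingRateReduction

# Illustrative placeholders: 1200 samples of 24-dim unit-norm features, 10 classes.
features = F.normalize(torch.randn(1200, 24), dim=1)
labels = torch.randint(0, 10, (1200,))  # must live on CPU: the loss calls Y.numpy()

loss_fn = MaximalCodingRateReduction(gam1=1, gam2=1, eps=0.5)
loss, (discrimn, compress) = loss_fn(features, labels, num_classes=10)
# Minimizing `loss` expands the coding rate of the full feature set (discrimn)
# while compressing the features within each class (compress).
```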
35 |
36 | ### Precoding Optimization and Performance Evaluation
37 |
38 | The code is located in the folder ```precoding_opt_matlab```.
39 | Run the script ```main_precoding_opt.m``` to compare the performance of the proposed MCR2 precoder and the LMMSE precoder.
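
The evaluation stage in ```main_precoding_opt.m``` converts the complex received signal into an equivalent real-valued model before applying the Gaussian-mixture (MAP) receiver and the LMMSE estimator. For readers without MATLAB, here is a minimal NumPy sketch of that complex-to-real lifting; the matrix sizes are illustrative only, and noise is omitted for brevity:

```python
import numpy as np

rng = np.random.default_rng(0)
n_r, d = 8, 12  # illustrative sizes: receive antennas x transmitted feature dimension
HV = rng.standard_normal((n_r, d)) + 1j * rng.standard_normal((n_r, d))  # effective channel H*V
z = rng.standard_normal(d) + 1j * rng.standard_normal(d)  # complex feature vector

# Lift y = (H*V) z to real form: stack real and imaginary parts.
HV_real = np.block([[HV.real, -HV.imag],
                    [HV.imag, HV.real]])
z_real = np.concatenate([z.real, z.imag])
y_real = HV_real @ z_real

# Agrees with computing the complex product first, then stacking.
y = HV @ z
assert np.allclose(y_real, np.concatenate([y.real, y.imag]))
```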
40 |
41 | ## Citation
42 | If you find our work interesting, please consider citing:
43 |
44 | ```
45 | @ARTICLE{task_comm_mcr2,
46 | author={Cai, Chang and Yuan, Xiaojun and Zhang, Ying-Jun Angela},
47 | journal={IEEE Transactions on Wireless Communications},
48 | title={Multi-Device Task-Oriented Communication via Maximal Coding Rate Reduction},
49 | year={2024},
50 | volume={23},
51 | number={12},
52 | pages={18096-18110}
53 | }
54 | ```
55 | [Our follow-up work](https://ieeexplore.ieee.org/abstract/document/10845817) provides an information-theoretic interpretation of the learning-communication separation, as well as an end-to-end learning framework:
56 |
57 | ```
58 | @ARTICLE{info_theoretic_e2e,
59 | author={Cai, Chang and Yuan, Xiaojun and Zhang, Ying-Jun Angela},
60 | journal={IEEE Journal on Selected Areas in Communications},
61 | title={End-to-End Learning for Task-Oriented Semantic Communications Over {MIMO} Channels: An Information-Theoretic Framework},
62 | year={2025},
63 | volume={},
64 | number={},
65 | pages={1-16}
66 | }
67 | ```
--------------------------------------------------------------------------------
/feature_load.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import argparse
4 | from scipy.io import savemat
5 | from tools.img_dataset import MultiviewImgDataset
6 | from models.mvcnn import MVCNN, SVCNN
7 |
8 | from tools import train_func as tf
9 |
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument("--bs", "--batch_size", type=int, default=30,
12 | help="batch size")
13 | parser.add_argument("--cnn_name", type=str, default="vgg11",
14 | help="cnn model name")
15 | parser.add_argument("--num_views", type=int, default=3,
16 | help="number of views")
17 | parser.add_argument("--num_classes", type=int, default=10,
18 | help="number of classes")
19 | parser.add_argument("--train_path", type=str, default="modelnet40_images_new_12x/*/train")
20 | parser.add_argument("--val_path", type=str, default="modelnet40_images_new_12x/*/test")
21 | parser.add_argument("--num_workers", type=int, default=24,
22 | help="number of workers")
23 | parser.add_argument('--svcnn_model_dir', type=str, default='./mvcnn/phase1_classes10_views3_fd1_32_bs1200_lr0.0001_wd0.001_eps0.5/checkpoints/svcnn/model-00029.pth',
24 | help='base directory for svcnn model')
25 | parser.add_argument('--mvcnn_model_dir', type=str, default='./mvcnn/phase2_classes10_views3_fd1_32_fd2_8_bs1200_lr0.0001_wd0.001_eps0.5/checkpoints/mvcnn/model-00199.pth',
26 | help='base directory for mvcnn model')
27 | parser.add_argument('--pretraining', type=bool, default=True,
28 | help='pretraining')
29 | parser.add_argument('--fd_phase1', type=int, default=32,
30 | help='feature dimension in phase 1')
31 | parser.add_argument('--fd_phase2', type=int, default=8,
32 | help='feature dimension per user in phase 2')
33 | args = parser.parse_args()
34 |
35 | ## CUDA
36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37 |
38 | if __name__ == '__main__':
39 |
40 | cnet = SVCNN(name='svcnn', nclasses=args.num_classes, pretraining=args.pretraining,
41 | cnn_name=args.cnn_name, fd_phase1=args.fd_phase1)
42 | state_dict_1 = torch.load(args.svcnn_model_dir, map_location=torch.device('cpu'))
43 | cnet.load_state_dict(state_dict_1)
44 | cnet_2 = MVCNN(name='mvcnn', model=cnet, nclasses=args.num_classes, cnn_name=args.cnn_name,
45 | num_views=args.num_views, fd_phase1=args.fd_phase1, fd_per_user=args.fd_phase2)
46 | state_dict_2 = torch.load(args.mvcnn_model_dir, map_location=torch.device('cpu'))
47 | cnet_2.load_state_dict(state_dict_2)
48 | cnet_2.eval()
49 | del cnet
50 |
51 | train_dataset = MultiviewImgDataset(args.train_path, scale_aug=False, rot_aug=False,
52 | num_classes=args.num_classes, num_views=args.num_views, train_objects=9999)
53 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.bs,
54 | shuffle=False, num_workers=args.num_workers)
55 | # shuffle needs to be false! it's done within the trainer
56 |
57 | val_dataset = MultiviewImgDataset(args.val_path, scale_aug=False, rot_aug=False, test_mode=True,
58 | num_classes=args.num_classes, num_views=args.num_views, test_objects=9999)
59 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.bs,
60 | shuffle=False, num_workers=args.num_workers)
61 | print('num_train_files: '+str(len(train_dataset.filepaths)))
62 | print('num_val_files: '+str(len(val_dataset.filepaths)))
63 |
64 | train_features, train_labels = tf.get_features(cnet_2, train_loader, 'mcr2', 'mvcnn')
65 | feature = np.array(train_features.detach().cpu())
66 | target = np.array(train_labels.detach().cpu())
67 | mdic = {"train_feature": feature, "train_label": target}
68 | savemat("MCR2_ModelNet10_train_feature_label_36.mat", mdic)
69 |
70 | test_features, test_labels = tf.get_features(cnet_2, val_loader, 'mcr2', 'mvcnn')
71 | feature = np.array(test_features.detach().cpu())
72 | target = np.array(test_labels.detach().cpu())
73 | mdic = {"test_feature": feature, "test_label": target}
74 | savemat("MCR2_ModelNet10_test_feature_label_36.mat", mdic)
75 |
76 |
77 |
--------------------------------------------------------------------------------
/figures/consistent_system_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chang-cai/TaskCommMCR2/dfbba3b636d3aba5c8844e2393ead6657acd3518/figures/consistent_system_model.png
--------------------------------------------------------------------------------
/figures/inconsistent_system_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chang-cai/TaskCommMCR2/dfbba3b636d3aba5c8844e2393ead6657acd3518/figures/inconsistent_system_model.png
--------------------------------------------------------------------------------
/figures/sphere_packing_mcr2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chang-cai/TaskCommMCR2/dfbba3b636d3aba5c8844e2393ead6657acd3518/figures/sphere_packing_mcr2.png
--------------------------------------------------------------------------------
/main_train_phase1.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import os
4 | import argparse
5 |
6 | from tools.trainer import ModelNetTrainer
7 | from tools.img_dataset import SingleImgDataset
8 | from models.mvcnn import SVCNN
9 |
10 | from mcr2_loss import MaximalCodingRateReduction
11 | from tools import utils
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--bs", "--batch_size", type=int, default=1200,
15 | help="batch size")
16 | parser.add_argument("--lr", type=float, default=1e-4,
17 | help="learning rate")
18 | parser.add_argument("--weight_decay", type=float, help="weight decay", default=0.001)
19 | parser.add_argument("--cnn_name", type=str, default="vgg11",
20 | help="cnn model name")
21 | parser.add_argument("--num_views", type=int, default=3,
22 | help="number of views")
23 | parser.add_argument("--num_classes", type=int, default=10,
24 | help="number of classes")
25 | parser.add_argument("--train_path", type=str, default="modelnet40_images_new_12x/*/train")
26 | parser.add_argument("--val_path", type=str, default="modelnet40_images_new_12x/*/test")
27 | parser.add_argument("--num_workers", type=int, default=32,
28 | help="number of workers")
29 | parser.add_argument('--save_dir', type=str, default='./mvcnn/',
30 | help='base directory for saving PyTorch model')
31 | parser.add_argument('--epoch', type=int, default=500,
32 | help='number of epochs for training')
33 | parser.add_argument('--eps', type=float, default=0.5,
34 | help='eps squared')
35 | parser.add_argument('--fd_phase1', type=int, default=32,
36 | help='feature dimension in phase 1')
37 | parser.add_argument('--tail', type=str, default='',
38 | help='extra information to add to folder name')
39 | parser.add_argument('--pretraining', type=bool, default=True,
40 | help='pretraining')
41 | parser.add_argument('--mom', type=float, default=0.9,
42 | help='momentum')
43 | args = parser.parse_args()
44 |
45 | ## CUDA
46 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47 |
48 |
49 | if __name__ == '__main__':
50 | ## Pipelines Setup
51 | model_dir = os.path.join(args.save_dir,
52 | 'phase1_classes{}_views{}_fd1_{}_bs{}_lr{}_wd{}_eps{}{}'.format(
53 | args.num_classes, args.num_views, args.fd_phase1, args.bs, args.lr,
54 | args.weight_decay, args.eps, args.tail))
55 | utils.init_pipeline(model_dir)
56 | utils.save_params(model_dir, vars(args))
57 |
58 | cnet = SVCNN(name='svcnn', nclasses=args.num_classes, pretraining=args.pretraining,
59 | cnn_name=args.cnn_name, fd_phase1=args.fd_phase1)
60 | cnet = cnet.to(device)
61 | optimizer = optim.Adam(cnet.parameters(), lr=args.lr, weight_decay=args.weight_decay)
62 | # optimizer = optim.SGD(cnet.parameters(), lr=args.lr, momentum=args.mom, weight_decay=args.weight_decay)
63 |
64 | train_dataset = SingleImgDataset(args.train_path, scale_aug=False, rot_aug=False,
65 | num_classes=args.num_classes, num_views=args.num_views, train_objects=9999) # 80
66 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.bs,
67 | shuffle=True, num_workers=args.num_workers)
68 |
69 | val_dataset = SingleImgDataset(args.val_path, scale_aug=False, rot_aug=False, test_mode=True,
70 | num_classes=args.num_classes, num_views=args.num_views, test_objects=9999) # 20
71 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.bs,
72 | shuffle=False, num_workers=args.num_workers)
73 | print('num_train_files: '+str(len(train_dataset.filepaths)))
74 | print('num_val_files: '+str(len(val_dataset.filepaths)))
75 |
76 | loss_fn = MaximalCodingRateReduction(gam1=1, gam2=1, eps=args.eps)
77 | trainer = ModelNetTrainer(cnet, device, train_loader, val_loader, optimizer, loss_fn, 'svcnn',
78 | model_dir, num_classes=args.num_classes, num_views=1)
79 | trainer.train(args.epoch)
80 |
81 |
82 |
--------------------------------------------------------------------------------
/main_train_phase2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import os
4 | import argparse
5 |
6 | from tools.trainer import ModelNetTrainer
7 | from tools.img_dataset import MultiviewImgDataset
8 | from models.mvcnn import MVCNN, SVCNN
9 |
10 | from mcr2_loss import MaximalCodingRateReduction
11 | from tools import utils
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--bs", "--batch_size", type=int, default=1200,
15 | help="batch size")
16 | parser.add_argument("--lr", type=float, default=1e-4,
17 | help="learning rate")
18 | parser.add_argument("--weight_decay", type=float, help="weight decay", default=0.001)
19 | parser.add_argument("--cnn_name", type=str, default="vgg11",
20 | help="cnn model name")
21 | parser.add_argument("--num_views", type=int, default=3,
22 | help="number of views")
23 | parser.add_argument("--num_classes", type=int, default=10,
24 | help="number of classes")
25 | parser.add_argument("--train_path", type=str, default="modelnet40_images_new_12x/*/train")
26 | parser.add_argument("--val_path", type=str, default="modelnet40_images_new_12x/*/test")
27 | parser.add_argument("--num_workers", type=int, default=32,
28 | help="number of workers")
29 | parser.add_argument('--save_dir', type=str, default='./mvcnn/',
30 | help='base directory for saving PyTorch model')
31 | parser.add_argument('--svcnn_model_dir', type=str, default='./mvcnn/phase1_classes10_views3_fd1_32_bs1200_lr0.0001_wd0.001_eps0.5/checkpoints/svcnn/model-00029.pth',
32 | help='base directory for svcnn model')
33 | parser.add_argument('--epoch', type=int, default=500,
34 | help='number of epochs for training')
35 | parser.add_argument('--eps', type=float, default=0.5,
36 | help='eps squared')
37 | parser.add_argument('--fd_phase1', type=int, default=32,
38 | help='feature dimension in phase 1')
39 | parser.add_argument('--fd_phase2', type=int, default=8,
40 | help='feature dimension per user in phase 2')
41 | parser.add_argument('--tail', type=str, default='',
42 | help='extra information to add to folder name')
43 | parser.add_argument('--pretraining', type=bool, default=True,
44 | help='pretraining')
45 | parser.add_argument('--mom', type=float, default=0.9,
46 | help='momentum')
47 | args = parser.parse_args()
48 |
49 | ## CUDA
50 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51 |
52 |
53 | if __name__ == '__main__':
54 | ## Pipelines Setup
55 | model_dir = os.path.join(args.save_dir,
56 | 'phase2_classes{}_views{}_fd1_{}_fd2_{}_bs{}_lr{}_wd{}_eps{}{}'.format(
57 | args.num_classes, args.num_views, args.fd_phase1, args.fd_phase2, args.bs, args.lr,
58 | args.weight_decay, args.eps, args.tail))
59 | utils.init_pipeline(model_dir)
60 | utils.save_params(model_dir, vars(args))
61 |
62 | cnet = SVCNN(name='svcnn', nclasses=args.num_classes, pretraining=args.pretraining,
63 | cnn_name=args.cnn_name, fd_phase1=args.fd_phase1)
64 | state_dict = torch.load(args.svcnn_model_dir, map_location=torch.device('cpu'))
65 | cnet.load_state_dict(state_dict)
66 | cnet_2 = MVCNN(name='mvcnn', model=cnet, nclasses=args.num_classes, cnn_name=args.cnn_name,
67 | num_views=args.num_views, fd_phase1=args.fd_phase1, fd_per_user=args.fd_phase2)
68 | del cnet
69 |
70 | cnet_2 = cnet_2.to(device)
71 | optimizer = optim.Adam(cnet_2.parameters(), lr=args.lr, weight_decay=args.weight_decay)
72 |
73 | train_dataset = MultiviewImgDataset(args.train_path, scale_aug=False, rot_aug=False,
74 | num_classes=args.num_classes, num_views=args.num_views, train_objects=9999) # 80
75 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.bs,
76 | shuffle=False, num_workers=args.num_workers)
77 | # shuffle needs to be false! it's done within the trainer
78 |
79 | val_dataset = MultiviewImgDataset(args.val_path, scale_aug=False, rot_aug=False, test_mode=True,
80 | num_classes=args.num_classes, num_views=args.num_views, test_objects=9999) # 20
81 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.bs,
82 | shuffle=False, num_workers=args.num_workers)
83 | print('num_train_files: '+str(len(train_dataset.filepaths)))
84 | print('num_val_files: '+str(len(val_dataset.filepaths)))
85 |
86 | loss_fn = MaximalCodingRateReduction(gam1=1, gam2=1, eps=args.eps)
87 | trainer = ModelNetTrainer(cnet_2, device, train_loader, val_loader, optimizer, loss_fn, 'mvcnn',
88 | model_dir, num_classes=args.num_classes, num_views=args.num_views)
89 | trainer.train(args.epoch)
90 |
91 |
92 |
--------------------------------------------------------------------------------
/mcr2_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from tools import train_func as tf
4 |
35 |
36 | class MaximalCodingRateReduction(torch.nn.Module):
37 | def __init__(self, gam1=1.0, gam2=1.0, eps=0.01):
38 | super(MaximalCodingRateReduction, self).__init__()
39 | self.gam1 = gam1
40 | self.gam2 = gam2
41 | self.eps = eps
42 |
43 | def compute_discrimn_loss_empirical(self, W):
44 | """Empirical Discriminative Loss."""
45 | p, m = W.shape
46 | I = torch.eye(p).to(W.device)
47 | scalar = p / (m * self.eps)
48 | logdet = self.logdet(I + self.gam1 * scalar * W.matmul(W.T))
49 | return logdet / 2.
50 |
51 | def compute_compress_loss_empirical(self, W, Pi):
52 | """Empirical Compressive Loss."""
53 | p, m = W.shape
54 | k, _, _ = Pi.shape
55 | I = torch.eye(p).to(W.device)
56 | compress_loss = 0.
57 | for j in range(k):
58 | trPi = torch.trace(Pi[j]) + 1e-8
59 | scalar = p / (trPi * self.eps)
60 | log_det = self.logdet(I + scalar * W.matmul(Pi[j]).matmul(W.T))
61 | compress_loss += log_det * trPi / m
62 | return compress_loss / 2.
63 |
64 | def logdet(self, X):
65 | sgn, logdet = torch.linalg.slogdet(X)
66 | return sgn * logdet
67 |
68 | def forward(self, X, Y, num_classes=None):
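# X: (num_samples, feature_dim) features on CPU; Y: (num_samples,) integer class labels.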
69 | if num_classes is None:
70 | num_classes = int(Y.max()) + 1
71 | W = X.T
72 | Pi = tf.label_to_membership(Y.numpy(), num_classes)
73 | Pi = torch.tensor(Pi, dtype=torch.float32).to(X.device)
74 |
75 | discrimn_loss_empi = self.compute_discrimn_loss_empirical(W)
76 | compress_loss_empi = self.compute_compress_loss_empirical(W, Pi)
77 |
78 | total_loss_empi = self.gam2 * -discrimn_loss_empi + compress_loss_empi
79 | return (total_loss_empi,
80 | [discrimn_loss_empi.item(), compress_loss_empi.item()])
81 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | import glob
5 |
6 |
7 | class Model(nn.Module):
8 |
9 | def __init__(self, name):
10 | super(Model, self).__init__()
11 | self.name = name
12 |
13 |
14 | def save(self, path, epoch=0):
15 | complete_path = os.path.join(path, self.name)
16 | if not os.path.exists(complete_path):
17 | os.makedirs(complete_path)
18 | torch.save(self.state_dict(),
19 | os.path.join(complete_path,
20 | "model-{}.pth".format(str(epoch).zfill(5))))
21 |
22 |
23 | def save_results(self, path, data):
24 | raise NotImplementedError("Model subclass must implement this method.")
25 |
26 |
27 | def load(self, path, modelfile=None):
28 | complete_path = os.path.join(path, self.name)
29 | if not os.path.exists(complete_path):
30 | raise IOError("{} directory does not exist in {}".format(self.name, path))
31 |
32 | if modelfile is None:
33 | model_files = glob.glob(complete_path+"/*")
34 | mf = max(model_files)
35 | else:
36 | mf = os.path.join(complete_path, modelfile)
37 |
38 | self.load_state_dict(torch.load(mf))
39 |
40 |
41 |
--------------------------------------------------------------------------------
/models/mvcnn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 | import torchvision.models as models
8 | from .model import Model
9 | import copy
10 |
11 |
12 | # mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False).cuda()
13 | # std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False).cuda()
14 |
15 | def flip(x, dim):
16 | xsize = x.size()
17 | dim = x.dim() + dim if dim < 0 else dim
18 | x = x.view(-1, *xsize[dim:])
19 | x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1) - 1,
20 | -1, -1), ('cpu', 'cuda')[x.is_cuda])().long(), :]
21 | return x.view(xsize)
22 |
23 |
24 | class SVCNN(Model):
25 |
26 | def __init__(self, name, nclasses=40, pretraining=True, cnn_name='vgg11', fd_phase1=32):
27 | super(SVCNN, self).__init__(name)
28 |
29 | if nclasses == 10:
30 | self.classnames = ['bathtub', 'bed', 'chair', 'desk', 'dresser',
31 | 'monitor', 'night_stand', 'sofa', 'table', 'toilet']
32 | elif nclasses == 40:
33 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair',
34 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box',
35 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand',
36 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs',
37 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
38 |
39 | self.nclasses = nclasses
40 | self.pretraining = pretraining
41 | self.cnn_name = cnn_name
42 | self.use_resnet = cnn_name.startswith('resnet')
43 | # self.mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False).cuda()
44 | # self.std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False).cuda()
45 | self.mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False)
46 | self.std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False)
47 | self.fd_phase1 = fd_phase1
48 |
49 | if self.use_resnet:
50 | if self.cnn_name == 'resnet18':
51 | self.net = models.resnet18(pretrained=self.pretraining)
52 | self.net.fc = nn.Linear(512, 40)
53 | elif self.cnn_name == 'resnet34':
54 | self.net = models.resnet34(pretrained=self.pretraining)
55 | self.net.fc = nn.Linear(512, 40)
56 | elif self.cnn_name == 'resnet50':
57 | self.net = models.resnet50(pretrained=self.pretraining)
58 | self.net.fc = nn.Linear(2048, 40)
59 | else:
60 | if self.cnn_name == 'alexnet':
61 | self.net_1 = models.alexnet(pretrained=self.pretraining).features
62 | self.net_2 = models.alexnet(pretrained=self.pretraining).classifier
63 | elif self.cnn_name == 'vgg11':
64 | self.net_1 = models.vgg11(pretrained=self.pretraining).features
65 | # self.net_2 = models.vgg11(pretrained=self.pretraining).classifier
66 | self.net_2 = torch.nn.Sequential(
67 | nn.Linear(25088, 4096, bias=True), # False
68 | nn.BatchNorm1d(4096),
69 | nn.ReLU(inplace=True),
70 | nn.Linear(4096, self.fd_phase1, bias=True)
71 | )
72 | elif self.cnn_name == 'vgg16':
73 | self.net_1 = models.vgg16(pretrained=self.pretraining).features
74 | self.net_2 = models.vgg16(pretrained=self.pretraining).classifier
75 |
76 | # self.net_2._modules['6'] = nn.Linear(4096, self.fd)
77 |
78 | def forward(self, x):
79 | if self.use_resnet:
80 | return self.net(x)
81 | else:
82 | y = self.net_1(x)
83 | feature = F.normalize(self.net_2(y.view(y.shape[0], -1)))
84 | return feature
85 |
86 |
87 | class MVCNN(Model):
88 |
89 | def __init__(self, name, model, nclasses=40, cnn_name='vgg11', num_views=12, fd_phase1=32, fd_per_user=6):
90 | super(MVCNN, self).__init__(name)
91 |
92 | if nclasses == 10:
93 | self.classnames = ['bathtub', 'bed', 'chair', 'desk', 'dresser',
94 | 'monitor', 'night_stand', 'sofa', 'table', 'toilet']
95 | elif nclasses == 40:
96 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair',
97 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box',
98 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand',
99 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs',
100 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
101 |
102 | self.nclasses = nclasses
103 | self.num_views = num_views
104 | self.fd_per_user = fd_per_user
105 | self.fd_phase1 = fd_phase1
106 | # self.mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False).cuda()
107 | # self.std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False).cuda()
108 | self.mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False)
109 | self.std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False)
110 |
111 | # self.use_resnet = cnn_name.startswith('resnet')
112 | #
113 | # if self.use_resnet:
114 | # self.net_1 = nn.Sequential(*list(model.net.children())[:-1])
115 | # self.net_2 = model.net.fc
116 | # else:
117 | # self.net_1 = model.net_1
118 | # self.net_2 = model.net_2
119 |
120 | self.net_1 = model.net_1
121 |
122 | net_2_list = []
123 | for _ in range(self.num_views):
124 | net_2 = copy.deepcopy(model.net_2)
125 | net_2._modules['4'] = nn.BatchNorm1d(self.fd_phase1)
126 | net_2._modules['5'] = nn.ReLU(inplace=True)
127 | net_2._modules['6'] = nn.Linear(self.fd_phase1, 32, bias=True) # False
128 | net_2._modules['7'] = nn.BatchNorm1d(32)
129 | net_2._modules['8'] = nn.ReLU(inplace=True)
130 | net_2._modules['9'] = nn.Linear(32, self.fd_per_user, bias=True)
131 |
132 | net_2_list.append(net_2)
133 | self.net_2 = nn.ModuleList(net_2_list)
134 |
135 | def forward(self, x):
136 | y = self.net_1(x) # (bs*views,512,7,7)
137 | y = y.view((int(x.shape[0] / self.num_views), self.num_views, -1)) # (bs,views,25088)
138 |
139 | feature_list = []
140 | for i in range(self.num_views):
141 | feature_i = self.net_2[i](y[:, i]) # (bs, fd_i)
142 | feature_list.append(feature_i)
143 |
144 | feature = torch.cat(feature_list, dim=1) # (bs, fd_i*views)
145 | return F.normalize(feature)
146 |
--------------------------------------------------------------------------------
/precoding_opt_matlab/H_slot_rician_channel_gen.m:
--------------------------------------------------------------------------------
1 | function H_slot = H_slot_rician_channel_gen(N_t_k, N_r, kappa)
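% One time slot of a Rician channel: a rank-one LoS component with random
% AoA/AoD plus a normalized Rayleigh NLoS component, mixed by the Rician factor kappa.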
2 |
3 | AoA = deg2rad(360*rand(1));
4 | AoD = deg2rad(360*rand(1));
5 | H_LoS = steering_vec(sin(AoA), N_r) * steering_vec(sin(AoD), N_t_k)';
6 | H_NLoS = 1/sqrt(2*N_r*N_t_k) * (randn(N_r, N_t_k) + 1j*randn(N_r, N_t_k));
7 | % H_NLoS = 1/sqrt(2) * (randn(N_r, N_t_k) + 1j*randn(N_r, N_t_k));
8 | H_slot = sqrt(kappa/(1+kappa))*H_LoS + sqrt(1/(1+kappa))*H_NLoS;
9 |
10 | end
--------------------------------------------------------------------------------
/precoding_opt_matlab/MCR2_ModelNet10_statistics_complex_24.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chang-cai/TaskCommMCR2/dfbba3b636d3aba5c8844e2393ead6657acd3518/precoding_opt_matlab/MCR2_ModelNet10_statistics_complex_24.mat
--------------------------------------------------------------------------------
/precoding_opt_matlab/MCR2_ModelNet10_test_feature_label_complex_24.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chang-cai/TaskCommMCR2/dfbba3b636d3aba5c8844e2393ead6657acd3518/precoding_opt_matlab/MCR2_ModelNet10_test_feature_label_complex_24.mat
--------------------------------------------------------------------------------
/precoding_opt_matlab/MCR2_obj_cplx.m:
--------------------------------------------------------------------------------
1 | function [mcr2, mcr2_1, mcr2_2] = MCR2_obj_cplx(V, H, C, CLS_cov, alpha, delta_0, p)
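% MCR2 objective after precoding: the log-det coding rate of the overall feature
% covariance (mcr2_1) minus the prior-weighted per-class rates (mcr2_2).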
2 |
3 | Dim = size(H, 1);
4 | beta = 1 + alpha*delta_0^2;
5 |
6 | mcr2_1 = log( det( beta*eye(Dim) + alpha*H*V*C*V'*H' ) );
7 |
8 | mcr2_2 = 0;
9 | class = size(p, 2);
10 | for j = 1:class
11 | C_j = CLS_cov(:, :, j);
12 | mcr2_2 = mcr2_2 + p(j) * ...
13 | log( det( beta*eye(Dim) + alpha*H*V*C_j*V'*H' ) );
14 | end
15 | % mcr2 = 0.5 * real(mcr2_1 - mcr2_2);
16 | %
17 | % mcr2_1 = 0.5 * real(mcr2_1);
18 | % mcr2_2 = 0.5 * real(mcr2_2);
19 | mcr2 = real(mcr2_1 - mcr2_2);
20 |
21 | mcr2_1 = real(mcr2_1);
22 | mcr2_2 = real(mcr2_2);
23 | end
--------------------------------------------------------------------------------
/precoding_opt_matlab/bisection_search_bar.m:
--------------------------------------------------------------------------------
1 | function lambda = bisection_search_bar(lambda_lb, lambda_ub, P, epsilon, g, Gamma, matrix_form)
2 | % Bisection search for the Lagrange multiplier: finds the smallest lambda in [lambda_lb, lambda_ub] (to tolerance epsilon) whose transmit power f_bar(lambda, g, Gamma, matrix_form) meets the budget P.
3 |
4 | lb = lambda_lb;
5 | ub = lambda_ub;
6 | lambda = lambda_ub;
7 |
8 | while ub - lb >= epsilon
9 | lambda = (lb + ub)/2;
10 | if f_bar(lambda, g, Gamma, matrix_form) <= P
11 | ub = lambda;
12 | else
13 | lb = lambda;
14 | end
15 | end
16 |
17 | end
--------------------------------------------------------------------------------
/precoding_opt_matlab/f_bar.m:
--------------------------------------------------------------------------------
1 | function fcn_bar = f_bar(lambda, g, Gamma, matrix_form)
2 | % Transmit power as a function of the multiplier lambda: sums |g_i|^2/(Gamma_ii+lambda)^2 over the eigen-directions; with matrix_form true, g already holds squared magnitudes (the diagonal of G), so abs(g(i)) is used instead.
3 |
4 | fcn_bar = 0;
5 | if matrix_form == false
6 | for i = 1:size(Gamma, 1)
7 | fcn_i = norm(g(i))^2/( (Gamma(i, i) + lambda)^2 );
8 | fcn_bar = fcn_bar + fcn_i;
9 | end
10 | else
11 | for i = 1:size(Gamma, 1)
12 | fcn_i = abs(g(i))/( (Gamma(i, i) + lambda)^2 );
13 | fcn_bar = fcn_bar + fcn_i;
14 | end
15 | end
16 | end
--------------------------------------------------------------------------------
/precoding_opt_matlab/gaussian_pdf.m:
--------------------------------------------------------------------------------
1 | function [f, log_f] = gaussian_pdf(x, mean, covariance, epsilon)
2 | % multivariate gaussian pdf
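% epsilon*eye(n) regularizes a (near-)singular covariance; log_f returns the
% log-density directly to avoid numerical underflow for small probabilities.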
3 |
4 | n = length(x);
5 | coef_1 = 1/sqrt((2*pi)^n);
6 | % det(covariance+epsilon*eye(n))
7 | coef_2 = 1/sqrt(det(covariance+epsilon*eye(n)));
8 | coef = coef_1 * coef_2;
9 | coef_3 = -0.5*(x-mean)'*inv(covariance+epsilon*eye(n))*(x-mean);
10 | expo = exp(coef_3);
11 | f = coef*expo;
12 |
13 | %%
14 | log_f = log(coef)+coef_3;
15 |
16 | end
--------------------------------------------------------------------------------
/precoding_opt_matlab/main_precoding_opt.m:
--------------------------------------------------------------------------------
1 | clc
2 | clear
3 |
4 | load("MCR2_ModelNet10_statistics_complex_24.mat", "feature_cplx", "label",...
5 | "C_xx_real", "mean_real", "C_xx", "CLS_mean", "CLS_cov", "CLS_rlt");
6 | load("MCR2_ModelNet10_test_feature_label_complex_24.mat", "p", "test_feature_cplx", "test_label");
7 | test_feature = test_feature_cplx.';
8 |
9 |
10 | B = 10*1e3; % 10 kHz
11 | K = 3; % num. of devices
12 | N_t_k = 4; % num. of transmit antennas
13 | N_t = N_t_k * K;
14 | N_r = 8; % num. of receive antennas
15 | T = 1; % num. of time slots
16 |
17 | Dataset = size(test_feature, 2);
18 | % NMSE_factor = norm(test_feature, "fro")^2/Dataset;
19 | C = C_xx; % covariance matrix
20 | L = 10; % num. of classes
21 | D = 2*size(test_feature, 1);
22 |
23 | delta_0_square = 1e-20*B; % noise power (W) -170 dBm/Hz density
24 | delta_0 = sqrt(delta_0_square); % noise
25 | P_k_dBm = 0; % dBm
26 | P_k = 1e-3*10.^(P_k_dBm./10); % W
27 | P = P_k * K;
28 |
29 | d = 240; % m
30 | PL = 32.6 + 36.7*log10(d); % dB
31 | PL = 10^(-PL/10);
32 | delta_0 = delta_0/sqrt(PL); % normalize path loss into the noise power
33 |
34 | eps = 1e-3;
35 |
36 | % channel generation
37 | kappa = 1; % Rician factor
38 | H = [];
39 | H_k_all = zeros(T*N_r, T*N_t_k, K);
40 | H_k_t_all = zeros(N_r, N_t_k, K);
41 | for k = 1:K
42 | H_k_t = H_slot_rician_channel_gen(N_t_k, N_r, kappa);
43 | H_k_t_all(:, :, k) = H_k_t;
44 | % svd(H_k_t)
45 | H_k = kron(eye(T), H_k_t);
46 | H_k_all(:, :, k) = H_k;
47 | H = [H, H_k];
48 | end
49 |
50 | %% MCR2 Precoder
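% Alternating optimization: update the receive matrix U (MMSE form), the weight
% matrices W_0 and W_j, and then each device's precoder V_k, with the per-device
% power constraint handled by a bisection search on the Lagrange multiplier.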
51 | alpha = T*N_r / eps^2;
52 | C_sr = sqrtm(C);
53 |
54 | % initialization of V
55 | V_init_k = rand(T*N_t_k, D/(2*K)) + 1j*rand(T*N_t_k, D/(2*K));
56 | V_init = kron(eye(K), V_init_k);
57 | V_init = sqrt(P) * V_init ./ sqrt(trace(V_init*C*V_init'));
58 | V = V_init;
59 |
60 | Ite = 1000;
61 | mcr2_obj_last = 500;
62 | mcr2_all = zeros(1, Ite);
63 | for ii = 1:Ite
64 | ii;
65 | % U update
66 | U = (H*V*C*V'*H'+(1/alpha + delta_0^2)*eye(N_r*T))\(H*V*C_sr);
67 |
68 | % W update
69 | E_0 = eye(D/2)-U'*H*V*C_sr;
70 | E_0 = (E_0*E_0') + (1/alpha + delta_0^2)*(U'*U);
71 | W_0 = inv(E_0);
72 | W_j_all = zeros(N_r*T, N_r*T, L);
73 | for j = 1:L
74 | C_j = CLS_cov(:, :, j);
75 | W_j_all(:, :, j) = inv((1+alpha*delta_0^2)*eye(N_r*T) + alpha*H*V*C_j*V'*H');
76 | end
77 |
78 | % V update
79 | for k = 1:K
80 | H_k = H_k_all(:, :, k);
81 | C_sr_k = C_sr(((k-1)*D/(2*K)+1):k*D/(2*K), :);
82 | C_k = C(((k-1)*D/(2*K)+1):k*D/(2*K), :);
83 | C_kk = C(((k-1)*D/(2*K)+1):k*D/(2*K), ((k-1)*D/(2*K)+1):k*D/(2*K));
84 | V_k = V(((k-1)*T*N_t_k+1):k*T*N_t_k, ...
85 | ((k-1)*D/(2*K)+1):k*D/(2*K));
86 |
87 | sum_term1_T_k = H*V*C_k' - H_k*V_k*C_kk;
88 | sum_term2_T_k = 0; % zeros(2*T*N_t_k, D/K);
89 | sum_term_M_k = 0;
90 | % I_kron_H_bar_k = kron(eye(D/(2*K)), H_k);
91 | for j = 1:L
92 | W_j = W_j_all(:, :, j);
93 |
94 | C_j_k = CLS_cov(((k-1)*D/(2*K)+1):k*D/(2*K), :, j);
95 | C_j_kk = CLS_cov(((k-1)*D/(2*K)+1):k*D/(2*K), ((k-1)*D/(2*K)+1):k*D/(2*K), j);
96 | sum_term2_j_T_k = H*V*C_j_k' - H_k*V_k*C_j_kk;
97 | sum_term2_T_k = sum_term2_T_k + ...
98 | alpha*p(j)*H_k'*W_j*sum_term2_j_T_k;
99 |
100 | % sum_term_M_k = sum_term_M_k + ...
101 | % alpha*p(j)*I_kron_H_bar_k'*kron(C_j_kk, W_j)*I_kron_H_bar_k;
102 | sum_term_M_k = sum_term_M_k + ...
103 | alpha*p(j)*kron(C_j_kk.', H_k'*W_j*H_k);
104 | end
105 | T_k = H_k'*U*W_0*C_sr_k' - ...
106 | H_k'*U*W_0*U'*sum_term1_T_k - sum_term2_T_k;
107 | % I_kron_U_H_bar_k = kron(eye(D/K), U'*H_k);
108 | M_k = kron(C_kk.', H_k'*U*W_0*U'*H_k) + ...
109 | sum_term_M_k;
110 |
111 | [Q_kk, D_kk] = svd(C_kk.');
112 | Q_kk_kron_I = kron(Q_kk, eye(N_t_k*T));
113 | D_kk_kron_I = kron(D_kk, eye(N_t_k*T));
114 | D_kk_kron_I_msr = diag( 1./sqrt(diag(D_kk_kron_I)) );
115 | C_kk_kron_I_msr = Q_kk_kron_I * D_kk_kron_I_msr * Q_kk_kron_I';
116 |
117 | t_k = C_kk_kron_I_msr*vec(T_k);
118 | M_k = C_kk_kron_I_msr*M_k*C_kk_kron_I_msr;
119 |
120 |
121 | [U_M_k, Gamma_k] = svd(M_k);
122 | g_k = U_M_k' * t_k;
123 |
124 | lambda_lb = 0;
125 | lambda_ub = sqrt(real(g_k'*g_k)/(P_k*T));
126 |
127 | fcn_bar_0 = f_bar(0, g_k, Gamma_k, false);
128 | epsilon_bi_search = 1e-6;
129 | if fcn_bar_0 <= P_k
130 | lambda = 0;
131 | else
132 | lambda = bisection_search_bar(lambda_lb, lambda_ub, T*P_k, epsilon_bi_search, g_k, Gamma_k, false);
133 | end
134 |
135 | if lambda < 1e-20
136 | v_k = pinv( M_k + lambda*eye(N_t_k*T*D/(2*K)) ) * t_k;
137 | else
138 | v_k = ( M_k + lambda*eye(N_t_k*T*D/(2*K)) ) \ t_k;
139 | end
140 |
141 | v_k = C_kk_kron_I_msr * v_k;
142 |
143 | % retrieve V_k from v_k
144 | V_k = zeros(T*N_t_k, D/(2*K));
145 | for i = 1:D/(2*K)
146 | V_k(:, i) = v_k(T*N_t_k*(i-1)+1: T*N_t_k*i);
147 | end
148 | V((T*N_t_k*(k-1)+1):T*N_t_k*k, ((k-1)*D/(2*K)+1):k*D/(2*K)) = V_k;
149 | end
150 | [mcr2, ~, ~] = MCR2_obj_cplx(V, H, C, CLS_cov, alpha, delta_0, p);
151 | mcr2_all(ii) = mcr2;
152 |
153 | % termination criterion
154 | if abs((mcr2 - mcr2_obj_last)/mcr2_obj_last) < 1e-5
155 | break;
156 | end
157 | mcr2_obj_last = mcr2;
158 | end
159 | figure;
160 | plot(mcr2_all(1:ii), 'r-', LineWidth=1.6);
161 | [mcr2, ~, ~] = MCR2_obj_cplx(V, H, C, CLS_cov, alpha, delta_0, p);
162 |
163 | %% MAP receiver and LMMSE estimator
164 | epsilon = 1e-6; % small regularizer for numerical stability
165 | acc_mcr2 = 0;
166 | mse_by_mcr2 = 0;
167 | parfor j = 1:Dataset
168 | z_cplx = test_feature(:, j);
169 | % add noise
170 | n_cplx = delta_0 * sqrt(1/2) * (randn(N_r*T, 1) + 1j*randn(N_r*T, 1));
171 | C_n_cplx = delta_0^2 * eye(N_r*T);
172 | y_cplx = H*V*z_cplx + n_cplx;
173 |
174 | %% Recover signals from complex to real
175 | z_real = [real(z_cplx); imag(z_cplx)];
176 | y_real = [real(y_cplx); imag(y_cplx)];
177 | HV_real = [real(H*V), -imag(H*V);...
178 | imag(H*V), real(H*V)];
179 | C_n_real = 1/2 * [real(C_n_cplx), zeros(N_r*T, N_r*T);
180 | zeros(N_r*T, N_r*T), real(C_n_cplx)];
181 |
182 | %%
183 | f = zeros(1, L);
184 | log_f = zeros(1, L); % numerical issue
185 | weighted_sum_element = zeros(1, L);
186 | for i = 1:L
187 | mu_i_cplx = CLS_mean(:, i);
188 | C_i = CLS_cov(:, :, i);
189 | R_i = CLS_rlt(:, :, i);
190 |
191 | mu_i = [real(mu_i_cplx); imag(mu_i_cplx)];
192 | Cov_i = 1/2 * [real(C_i + R_i), imag(-C_i + R_i);...
193 | imag(C_i + R_i), real(C_i - R_i)];
194 |
195 | [f(i), log_f(i)] = gaussian_pdf(y_real, HV_real*mu_i,...
196 | HV_real*Cov_i*HV_real' + C_n_real, epsilon);
197 | weighted_sum_element(i) = p(i) * f(i);
198 | end
199 | weighted_sum = sum(weighted_sum_element);
200 |
201 | % GM receiver
202 | % [~, pos] = max(log_f);
203 | [~, pos] = max(weighted_sum_element);
204 | if (pos-1)==test_label(j)
205 | acc_mcr2 = acc_mcr2 +1;
206 | end
207 |
208 | % LMMSE estimate
209 | z_lmmse_estimate = mean_real + C_xx_real*HV_real'*...
210 | inv(HV_real*C_xx_real*HV_real'+C_n_real+epsilon*eye(2*N_r*T))*...
211 | (y_real-HV_real*mean_real);
212 | mse_by_mcr2 = mse_by_mcr2 + norm(z_lmmse_estimate-z_real, 2)^2;
213 |
214 | end
215 | acc_mcr2 = acc_mcr2/Dataset
216 | mse_by_mcr2 = mse_by_mcr2/Dataset
217 |
218 |
219 | %% LMMSE Precoder
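% Same alternating structure with the MSE objective: R is the LMMSE receive
% matrix, and each V_k solves a regularized least-squares subproblem, with the
% power constraint again handled by bisection on the multiplier.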
220 | alpha = T*N_r / eps^2;
221 |
222 | % initialization of V
223 | V_init_k = rand(T*N_t_k, D/(2*K)) + 1j*rand(T*N_t_k, D/(2*K));
224 | V_init = kron(eye(K), V_init_k);
225 | V_init = sqrt(P) * V_init ./ sqrt(trace(V_init*C*V_init'));
226 | V = V_init;
227 |
228 | Ite = 1000;
229 | mse_obj_last = 500;
230 | mse_all = zeros(1, Ite);
231 | for ii = 1:Ite
232 | ii;
233 | % R update
234 | R = (H*V*C*V'*H'+(delta_0^2)*eye(N_r*T))\(H*V*C);
235 |
236 | % V update
237 | for k = 1:K
238 | H_k = H_k_all(:, :, k);
239 | C_k = C(((k-1)*D/(2*K)+1):k*D/(2*K), :);
240 | C_kk = C(((k-1)*D/(2*K)+1):k*D/(2*K), ((k-1)*D/(2*K)+1):k*D/(2*K));
241 | V_k = V(((k-1)*T*N_t_k+1):k*T*N_t_k, ...
242 | ((k-1)*D/(2*K)+1):k*D/(2*K));
243 |
244 | sum_term_T_k = H*V*C_k' - H_k*V_k*C_kk;
245 | J_k = C_k - sum_term_T_k'*R;
246 |
247 | Q_k = H_k'*R*J_k'/C_kk;
248 | M_k = H_k'*R*R'*H_k;
249 |
250 | [U_M_k, Gamma_k] = svd(M_k);
251 | G_k = U_M_k'*Q_k*C_kk*Q_k'*U_M_k;
252 |
253 | lambda_lb = 0;
254 | lambda_ub = sqrt(real(sum(diag(G_k)))/(T*P_k));
255 |
256 | fcn_bar_0 = f_bar(0, diag(G_k), Gamma_k, true);
257 | epsilon_bi_search = 1e-6;
258 | if fcn_bar_0 <= T*P_k
259 | lambda = 0;
260 | else
261 | lambda = bisection_search_bar(lambda_lb, lambda_ub, T*P_k, epsilon_bi_search, diag(G_k), Gamma_k, true);
262 | end
263 | V_k = (M_k + lambda*eye(size(M_k)))\Q_k;
264 |
265 | V((T*N_t_k*(k-1)+1):T*N_t_k*k, ((k-1)*D/(2*K)+1):k*D/(2*K)) = V_k;
266 | end
267 | matrix_inv = inv(H*V*C*V'*H'+delta_0^2*eye(N_r*T));
268 | mse_obj = real( trace(C) - trace(C*V'*H'*matrix_inv*H*V*C) );
269 | mse_all(ii) = mse_obj;
270 |
271 | % termination criterion
272 | if abs((mse_obj - mse_obj_last)/mse_obj_last) < 1e-5
273 | break;
274 | end
275 | mse_obj_last = mse_obj;
276 | end
277 | figure;
278 | plot(mse_all(1:ii), 'r-', LineWidth=1.6)
279 | [mcr2, ~, ~] = MCR2_obj_cplx(V, H, C, CLS_cov, alpha, delta_0, p);
280 |
281 | %% MAP receiver and LMMSE estimator
282 | epsilon = 1e-6; % small regularizer for numerical stability
283 | acc_lmmse = 0;
284 | mse_by_lmmse = 0;
285 | parfor j = 1:Dataset
286 | z_cplx = test_feature(:, j);
287 | % add noise
288 | n_cplx = delta_0 * sqrt(1/2) * (randn(N_r*T, 1) + 1j*randn(N_r*T, 1));
289 | C_n_cplx = delta_0^2 * eye(N_r*T);
290 | y_cplx = H*V*z_cplx + n_cplx;
291 |
292 | %% Recover signals from complex to real
293 | z_real = [real(z_cplx); imag(z_cplx)];
294 | y_real = [real(y_cplx); imag(y_cplx)];
295 | HV_real = [real(H*V), -imag(H*V);...
296 | imag(H*V), real(H*V)];
297 | C_n_real = 1/2 * [real(C_n_cplx), zeros(N_r*T, N_r*T);
298 | zeros(N_r*T, N_r*T), real(C_n_cplx)];
299 |
300 | %%
301 | f = zeros(1, L);
302 | log_f = zeros(1, L); % numerical issue
303 | weighted_sum_element = zeros(1, L);
304 | for i = 1:L
305 | mu_i_cplx = CLS_mean(:, i);
306 | C_i = CLS_cov(:, :, i);
307 | R_i = CLS_rlt(:, :, i);
308 |
309 | mu_i = [real(mu_i_cplx); imag(mu_i_cplx)];
310 | Cov_i = 1/2 * [real(C_i + R_i), imag(-C_i + R_i);...
311 | imag(C_i + R_i), real(C_i - R_i)];
312 |
313 | [f(i), log_f(i)] = gaussian_pdf(y_real, HV_real*mu_i,...
314 | HV_real*Cov_i*HV_real' + C_n_real, epsilon);
315 | weighted_sum_element(i) = p(i) * f(i);
316 | end
317 | weighted_sum = sum(weighted_sum_element);
318 |
319 | % GM receiver
320 | % [~, pos] = max(log_f);
321 | [~, pos] = max(weighted_sum_element);
322 | if (pos-1)==test_label(j)
323 | acc_lmmse = acc_lmmse +1;
324 | end
325 |
326 | % LMMSE estimate
327 | z_lmmse_estimate = mean_real + C_xx_real*HV_real'*...
328 | inv(HV_real*C_xx_real*HV_real'+C_n_real+epsilon*eye(2*N_r*T))*...
329 | (y_real-HV_real*mean_real);
330 | mse_by_lmmse = mse_by_lmmse + norm(z_lmmse_estimate-z_real, 2)^2;
331 |
332 | end
333 | acc_lmmse = acc_lmmse/Dataset
334 | mse_by_lmmse = mse_by_lmmse/Dataset
335 |
--------------------------------------------------------------------------------
/precoding_opt_matlab/steering_vec.m:
--------------------------------------------------------------------------------
1 | function [e] = steering_vec(x,M)
2 | % steering vector function
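% m-th entry: exp(-1j*pi*(m-1)*x)/sqrt(M); for a half-wavelength ULA, x is the
% sine of the angle of arrival/departure.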
3 | m = (1:M)';
4 | e = 1/sqrt(M) * exp(-1j*pi*(m-1)*x);
5 | % e = exp(-1j*pi*(m-1)*x);
6 | end
7 |
8 |
--------------------------------------------------------------------------------
/tools/evaluate_func.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.svm import LinearSVC
3 | from sklearn.decomposition import PCA
4 | from sklearn.decomposition import TruncatedSVD
5 |
6 | from tools import utils
7 |
8 |
9 | def svm(train_features, train_labels, test_features, test_labels):
10 | svm = LinearSVC(verbose=0, random_state=10)
11 | svm.fit(train_features, train_labels)
12 | acc_train = svm.score(train_features, train_labels)
13 | acc_test = svm.score(test_features, test_labels)
14 | print("acc train SVM: {}".format(acc_train))
15 | print("acc test SVM: {}".format(acc_test))
16 | return acc_train, acc_test
17 |
18 |
19 | def knn(k, train_features, train_labels, test_features, test_labels):
20 | """Perform k-Nearest Neighbor classification using cosine similaristy as metric.
21 |
22 | Options:
23 | k (int): top k features for kNN
24 |
25 | """
26 | sim_mat = train_features @ test_features.T
27 | topk = sim_mat.topk(k=k, dim=0)
28 | topk_pred = train_labels[topk.indices]
29 | test_pred = topk_pred.mode(0).values.detach()
30 | acc = utils.compute_accuracy(test_pred.numpy(), test_labels.numpy())
31 | print("kNN: {}".format(acc))
32 | return acc
33 |
34 |
35 | def nearsub(n_comp, train_features, train_labels, test_features, test_labels):
36 | """Perform nearest subspace classification.
37 |
38 | Options:
39 | n_comp (int): number of components for PCA or SVD
40 |
41 | """
42 | scores_pca = []
43 | scores_svd = []
44 | num_classes = train_labels.numpy().max() + 1 # should be correct most of the time
45 | features_sort, _ = utils.sort_dataset(train_features.numpy(), train_labels.numpy(),
46 | num_classes=num_classes, stack=False)
47 | fd = features_sort[0].shape[1]
48 | for j in range(num_classes):
49 | pca = PCA(n_components=n_comp).fit(features_sort[j])
50 | pca_subspace = pca.components_.T
51 | mean = np.mean(features_sort[j], axis=0)
52 | pca_j = (np.eye(fd) - pca_subspace @ pca_subspace.T) \
53 | @ (test_features.numpy() - mean).T
54 | score_pca_j = np.linalg.norm(pca_j, ord=2, axis=0)
55 |
56 | svd = TruncatedSVD(n_components=n_comp).fit(features_sort[j])
57 | svd_subspace = svd.components_.T
58 | svd_j = (np.eye(fd) - svd_subspace @ svd_subspace.T) \
59 | @ (test_features.numpy()).T
60 | score_svd_j = np.linalg.norm(svd_j, ord=2, axis=0)
61 |
62 | scores_pca.append(score_pca_j)
63 | scores_svd.append(score_svd_j)
64 | test_predict_pca = np.argmin(scores_pca, axis=0)
65 | test_predict_svd = np.argmin(scores_svd, axis=0)
66 | acc_pca = utils.compute_accuracy(test_predict_pca, test_labels.numpy())
67 | acc_svd = utils.compute_accuracy(test_predict_svd, test_labels.numpy())
68 | print('PCA: {}'.format(acc_pca))
69 | print('SVD: {}'.format(acc_svd))
70 | return acc_svd
--------------------------------------------------------------------------------
/tools/img_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import glob
3 | import torch.utils.data
4 | import os
5 | import math
6 | from skimage import io, transform
7 | from PIL import Image
8 | import torch
9 | import torchvision as vision
10 | from torchvision import transforms, datasets
11 | import random
12 |
13 |
14 | class MultiviewImgDataset(torch.utils.data.Dataset):
15 |
16 | def __init__(self, root_dir, scale_aug=False, rot_aug=False, test_mode=False, \
17 | num_classes=40, num_views=12, train_objects=80, test_objects=20, shuffle=True):
18 |
19 | if num_classes == 10:
20 | self.classnames = ['bathtub', 'bed', 'chair', 'desk', 'dresser',
21 | 'monitor', 'night_stand', 'sofa', 'table', 'toilet']
22 | elif num_classes == 40:
23 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair',
24 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box',
25 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand',
26 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs',
27 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
28 |
29 | self.num_classes = num_classes
30 | self.root_dir = root_dir
31 | self.scale_aug = scale_aug
32 | self.rot_aug = rot_aug
33 | self.test_mode = test_mode
34 | self.num_views = num_views
35 | self.train_objects = train_objects
36 | self.test_objects = test_objects
37 | self.train_file_number = self.train_objects * self.num_views
38 | self.test_file_number = self.test_objects * self.num_views
39 |
40 | set_ = root_dir.split('/')[-1]
41 | parent_dir = root_dir.rsplit('/', 2)[0]
42 | self.filepaths = []
43 | for i in range(len(self.classnames)):
44 | all_files = sorted(glob.glob(parent_dir + '/' + self.classnames[i] + '/' + set_ + '/*.png'))
45 | ## Select subset for different number of views
46 | stride = int(12 / self.num_views) # 12 6 4 3 2 1
47 | all_files = all_files[::stride]
48 |
49 | if self.test_mode:
50 | self.filepaths.extend(all_files[:min(self.test_file_number, len(all_files))])
51 | else:
52 | self.filepaths.extend(all_files[:min(self.train_file_number, len(all_files))])
53 |
54 | if shuffle == True:
55 | # permute
56 | rand_idx = np.random.permutation(int(len(self.filepaths) / num_views))
57 | filepaths_new = []
58 | for i in range(len(rand_idx)):
59 | filepaths_new.extend(self.filepaths[rand_idx[i] * num_views:(rand_idx[i] + 1) * num_views])
60 | self.filepaths = filepaths_new
61 |
62 | if self.test_mode:
63 | self.transform = transforms.Compose([
64 | transforms.ToTensor(),
65 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
66 | std=[0.229, 0.224, 0.225])
67 | ])
68 | else:
69 | self.transform = transforms.Compose([
70 | transforms.RandomHorizontalFlip(),
71 | transforms.ToTensor(),
72 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
73 | std=[0.229, 0.224, 0.225])
74 | ])
75 |
76 | def __len__(self):
77 | return int(len(self.filepaths) / self.num_views)
78 |
79 | def __getitem__(self, idx):
80 | path = self.filepaths[idx * self.num_views]
81 | class_name = path.split('/')[-3] # Linux
82 | # class_name = path.split('/')[-2] # Win
83 | class_id = self.classnames.index(class_name)
84 | # Use PIL instead
85 | imgs = []
86 | for i in range(self.num_views):
87 | im = Image.open(self.filepaths[idx * self.num_views + i]).convert('RGB')
88 | if self.transform:
89 | im = self.transform(im)
90 | imgs.append(im)
91 |
92 | return (class_id, torch.stack(imgs), self.filepaths[idx * self.num_views:(idx + 1) * self.num_views])
93 |
94 |
95 | class SingleImgDataset(torch.utils.data.Dataset):
96 |
97 | def __init__(self, root_dir, scale_aug=False, rot_aug=False, test_mode=False,
98 | num_classes=40, num_views=12, train_objects=80, test_objects=20):
99 |
100 | if num_classes == 10:
101 | self.classnames = ['bathtub', 'bed', 'chair', 'desk', 'dresser',
102 | 'monitor', 'night_stand', 'sofa', 'table', 'toilet']
103 | elif num_classes == 40:
104 | self.classnames = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl', 'car', 'chair',
105 | 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser', 'flower_pot', 'glass_box',
106 | 'guitar', 'keyboard', 'lamp', 'laptop', 'mantel', 'monitor', 'night_stand',
107 | 'person', 'piano', 'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs',
108 | 'stool', 'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
109 |
110 | self.num_classes = num_classes
111 | self.root_dir = root_dir
112 | self.scale_aug = scale_aug
113 | self.rot_aug = rot_aug
114 | self.test_mode = test_mode
115 | self.num_views = num_views
116 | self.train_objects = train_objects
117 | self.test_objects = test_objects
118 | self.train_file_number = self.train_objects * self.num_views
119 | self.test_file_number = self.test_objects * self.num_views
120 |
121 | set_ = root_dir.split('/')[-1]
122 | parent_dir = root_dir.rsplit('/', 2)[0]
123 | self.filepaths = []
124 | for i in range(len(self.classnames)):
125 | all_files = sorted(glob.glob(parent_dir + '/' + self.classnames[i] + '/' + set_ + '/*shaded*.png'))
126 | if self.test_mode:
127 | self.filepaths.extend(all_files[:min(self.test_file_number, len(all_files))])
128 | else:
129 | self.filepaths.extend(all_files[:min(self.train_file_number, len(all_files))])
130 |
131 | self.transform = transforms.Compose([
132 | transforms.RandomHorizontalFlip(),
133 | transforms.ToTensor(),
134 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
135 | std=[0.229, 0.224, 0.225])
136 | ])
137 |
138 | def __len__(self):
139 | return len(self.filepaths)
140 |
141 | def __getitem__(self, idx):
142 | path = self.filepaths[idx]
143 | class_name = path.split('/')[-3] # Linux
144 | # class_name = path.split('/')[-2] # Win
145 | class_id = self.classnames.index(class_name)
146 |
147 | # Use PIL instead
148 | im = Image.open(self.filepaths[idx]).convert('RGB')
149 | if self.transform:
150 | im = self.transform(im)
151 |
152 | return (class_id, im, path)
153 |
--------------------------------------------------------------------------------
/tools/train_func.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 |
4 | # import cv2
5 | import numpy as np
6 | import torch
7 | import torch.nn
8 | import torchvision
9 | import torchvision.transforms as transforms
10 | from torch.utils.data import DataLoader
11 |
12 | # from cluster import ElasticNetSubspaceClustering, clustering_accuracy
14 |
15 |
16 | def get_features(net, trainloader, loss_name, model_name, verbose=True):
17 | '''Extract all features out into one single batch.
18 |
19 | Parameters:
20 | net (torch.nn.Module): get features using this model
21 | trainloader (torchvision.dataloader): dataloader for loading data
22 | verbose (bool): shows loading status bar
23 |
24 | Returns:
25 | features (torch.tensor): with dimension (num_samples, feature_dimension)
26 | labels (torch.tensor): with dimension (num_samples, )
27 | '''
28 | features = []
29 | labels = []
30 | if verbose:
31 | train_bar = tqdm(trainloader, desc="extracting all features from dataset")
32 | else:
33 | train_bar = trainloader
34 | for step, data in enumerate(train_bar):
35 | if model_name == 'mvcnn':
36 | N, V, C, H, W = data[1].size()
37 | batch_imgs = data[1].view(-1, C, H, W)
38 | else:
39 | batch_imgs = data[1]
40 | batch_lbls = data[0].long()
41 |
42 | if loss_name == 'mcr2':
43 | batch_features = net(batch_imgs.cpu())
44 | else: # 'ce'
45 | _, batch_features = net(batch_imgs.cpu())
46 |
47 | features.append(batch_features.cpu().detach())
48 | # features.append(batch_features)
49 | labels.append(batch_lbls)
50 | return torch.cat(features), torch.cat(labels)
51 |
52 |
53 | def label_to_membership(targets, num_classes=None):
54 | """Generate a true membership matrix, and assign value to current Pi.
55 |
56 | Parameters:
57 | targets (np.ndarray): matrix with one hot labels
58 |
59 | Return:
60 | Pi: membership matrix, shape (num_classes, num_samples, num_samples)
61 |
62 | """
63 | targets = one_hot(targets, num_classes)
64 | num_samples, num_classes = targets.shape
65 | Pi = np.zeros(shape=(num_classes, num_samples, num_samples))
66 | for j in range(len(targets)):
67 | k = np.argmax(targets[j])
68 | Pi[k, j, j] = 1.
69 | return Pi
70 |
71 |
72 | def membership_to_label(membership):
73 | """Turn a membership matrix into a list of labels."""
74 | num_classes, num_samples, _ = membership.shape
75 | labels = np.zeros(num_samples)
76 | for i in range(num_samples):
77 | labels[i] = np.argmax(membership[:, i, i])
78 | return labels
79 |
80 | def one_hot(labels_int, n_classes):
81 | """Turn labels into one hot vector of K classes. """
82 | labels_onehot = torch.zeros(size=(len(labels_int), n_classes)).float()
83 | for i, y in enumerate(labels_int):
84 | labels_onehot[i, y] = 1.
85 | return labels_onehot
86 |
--------------------------------------------------------------------------------
/tools/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import numpy as np
4 | import os
5 | from tensorboardX import SummaryWriter
6 |
7 | from tools import train_func as tf
8 | from tools.evaluate_func import svm
9 |
10 |
11 | class ModelNetTrainer(object):
12 |
13 | def __init__(self, model, device, train_loader, val_loader, optimizer, loss_fn, \
14 | model_name, model_dir, num_classes, num_views=12):
15 |
16 | self.optimizer = optimizer
17 | self.model = model
18 | self.device = device
19 | self.train_loader = train_loader
20 | self.val_loader = val_loader
21 | self.loss_fn = loss_fn
22 | self.model_name = model_name
23 | self.num_views = num_views
24 | self.num_classes = num_classes
25 |
26 | self.model_dir = model_dir
27 | self.checkpoints_dir = os.path.join(model_dir, 'checkpoints')
28 | self.tensorboard_dir = os.path.join(model_dir, 'tensorboard')
29 |
30 | # self.model.cuda()
31 | self.model.to(self.device)
32 |
33 | self.writer = SummaryWriter(self.tensorboard_dir)
34 |
35 | def train(self, n_epochs):
36 |
37 | if self.model_name == 'mvcnn':
38 | for param in self.model.net_1.parameters():
39 | param.requires_grad = False
40 |
41 | i_acc = 0
42 | self.model.train()
43 | for epoch in range(n_epochs):
44 | # permute data
45 | rand_idx = np.random.permutation(len(self.train_loader.dataset.filepaths) // self.num_views)
46 | filepaths_new = []
47 | for i in range(len(rand_idx)):
48 | filepaths_new.extend(self.train_loader.dataset.filepaths[rand_idx[i]*self.num_views:(rand_idx[i]+1)*self.num_views])
49 | self.train_loader.dataset.filepaths = filepaths_new
50 |
51 | # plot learning rate
52 | lr = self.optimizer.state_dict()['param_groups'][0]['lr']
53 | self.writer.add_scalar('params/lr', lr, epoch)
54 |
55 | for i, data in enumerate(self.train_loader):
56 |
57 | if self.model_name == 'mvcnn':
58 | N, V, C, H, W = data[1].size()
59 | in_data = Variable(data[1]).view(-1, C, H, W).to(self.device)
60 | else:
61 | in_data = Variable(data[1]).to(self.device)
62 | target = Variable(data[0]).to(self.device).long()
63 |
64 | self.optimizer.zero_grad()
65 |
66 | out_data = self.model(in_data)
67 | loss, [discrimn_loss, compress_loss] = self.loss_fn(out_data, target, num_classes=self.num_classes)
68 |
69 | self.writer.add_scalars('train_loss',
70 | {'loss': loss, 'discrimn': discrimn_loss, 'compress': compress_loss},
71 | i_acc+i+1)
72 |
73 | loss.backward()
74 | self.optimizer.step()
75 |
76 | # for name in self.model.state_dict():
77 | # print(name)
78 | # print(self.model.state_dict()['net_1.6.bias'])
79 | # print(self.model.state_dict()['net_2.2.9.weight'])
80 |
81 | log_str = 'epoch %d, step %d: train_loss %.3f' % (epoch+1, i+1, loss)
82 | if (i+1) % 1 == 0:  # print frequency knob; "% 1" prints every step
83 | print(log_str)
84 | i_acc += i+1
85 |
86 | # evaluation
87 | if (epoch+1) % 10 == 0:
88 | with torch.no_grad():
89 | acc_train_svm, acc_test_svm = self.update_validation_accuracy()
90 | self.writer.add_scalars('acc',
91 | {'acc_train_svm': acc_train_svm, 'acc_test_svm': acc_test_svm},
92 | epoch+1)
93 |
94 | # save the model
95 | if (epoch + 1) % 10 == 0:
96 | self.model.save(self.checkpoints_dir, epoch)
97 |
98 | # adjust learning rate manually
99 | if epoch > 0 and (epoch+1) % 40 == 0:
100 | for param_group in self.optimizer.param_groups:
101 | param_group['lr'] = param_group['lr']*0.5
102 |
103 | # export scalar data to JSON for external processing
104 | self.writer.export_scalars_to_json(self.checkpoints_dir+"/all_scalars.json")
105 | self.writer.close()
106 |
107 | def update_validation_accuracy(self):
108 |
109 | self.model.eval()
110 | train_features, train_labels = tf.get_features(self.model, self.train_loader, 'mcr2', self.model_name)
111 | test_features, test_labels = tf.get_features(self.model, self.val_loader, 'mcr2', self.model_name)
112 |
113 | # train & test accuracy
114 | acc_train_svm, acc_test_svm = svm(train_features, train_labels, test_features, test_labels)
115 |
116 | self.model.train()
117 |
118 | return acc_train_svm, acc_test_svm
119 |
120 |
--------------------------------------------------------------------------------
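The object-level shuffle at the top of `ModelNetTrainer.train()` is easy to get wrong, so here is a standalone sketch of the same logic (the file names and `num_views=3` are placeholders; the repo uses 12 views): shuffling whole objects rather than individual images keeps each object's views contiguous, which the `(N*V, C, H, W)` reshape in the training loop relies on.

```python
import numpy as np

# Stand-in file list: 4 objects x 3 views, stored contiguously per object.
num_views = 3
filepaths = [f'obj{o}_view{v}.png' for o in range(4) for v in range(num_views)]

# Same permutation logic as ModelNetTrainer.train(): permute at the object
# level, then copy each object's block of views over in one piece.
rand_idx = np.random.permutation(len(filepaths) // num_views)
filepaths_new = []
for i in rand_idx:
    filepaths_new.extend(filepaths[i * num_views:(i + 1) * num_views])

print(filepaths_new[:num_views])  # e.g. ['obj2_view0.png', 'obj2_view1.png', 'obj2_view2.png']
```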
/tools/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import json
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def sort_dataset(data, labels, num_classes=10, stack=False):
9 | """Sort dataset based on classes.
10 |
11 | Parameters:
12 | data (np.ndarray): data array
13 | labels (np.ndarray): one dimensional array of class labels
14 | num_classes (int): number of classes
15 | stack (bool): combine sorted data into one numpy array
16 |
17 | Return:
18 | sorted data (np.ndarray), sorted_labels (np.ndarray)
19 |
20 | """
21 | sorted_data = [[] for _ in range(num_classes)]
22 | for i, lbl in enumerate(labels):
23 | sorted_data[lbl].append(data[i])
24 | sorted_data = [np.stack(class_data) for class_data in sorted_data]
25 | sorted_labels = [np.repeat(i, (len(sorted_data[i]))) for i in range(num_classes)]
26 | if stack:
27 | sorted_data = np.vstack(sorted_data)
28 | sorted_labels = np.hstack(sorted_labels)
29 | return sorted_data, sorted_labels
30 |
31 | def init_pipeline(model_dir, headers=None):
32 | """Initialize folder and .csv logger."""
33 | # project folder
34 | os.makedirs(model_dir)
35 | os.makedirs(os.path.join(model_dir, 'checkpoints'))
36 | os.makedirs(os.path.join(model_dir, 'tensorboard'))
37 | # os.makedirs(os.path.join(model_dir, 'figures'))
38 | # os.makedirs(os.path.join(model_dir, 'plabels'))
39 | # os.makedirs(name=model_dir, exist_ok=True)
40 | # os.makedirs(os.path.join(model_dir, 'checkpoints'), exist_ok=True)
41 | # os.makedirs(os.path.join(model_dir, 'figures'), exist_ok=True)
42 | # os.makedirs(os.path.join(model_dir, 'plabels'), exist_ok=True)
43 | if headers is None:
44 | headers = ["epoch", "step", "loss", "discrimn_loss_e", "compress_loss_e",
45 | "discrimn_loss_t", "compress_loss_t"]
46 | create_csv(model_dir, 'losses.csv', headers)
47 | print("project dir: {}".format(model_dir))
48 |
49 | def create_csv(model_dir, filename, headers):
50 | """Create .csv file with filename in model_dir, with headers as the first line
51 | of the csv. """
52 | csv_path = os.path.join(model_dir, filename)
53 | if os.path.exists(csv_path):
54 | os.remove(csv_path)
55 | with open(csv_path, 'w+') as f:
56 | f.write(','.join(map(str, headers)))
57 | return csv_path
58 |
59 | def save_params(model_dir, params):
60 | """Save params to a .json file. Params is a dictionary of parameters."""
61 | path = os.path.join(model_dir, 'params.json')
62 | with open(path, 'w') as f:
63 | json.dump(params, f, indent=2, sort_keys=True)
64 |
65 | def update_params(model_dir, pretrain_dir):
66 | """Updates architecture and feature dimension from pretrain directory
67 | to new directoy. """
68 | params = load_params(model_dir)
69 | old_params = load_params(pretrain_dir)
70 | params['arch'] = old_params["arch"]
71 | params['fd'] = old_params['fd']
72 | save_params(model_dir, params)
73 |
74 | def load_params(model_dir):
75 | """Load params.json file in model directory and return dictionary."""
76 | _path = os.path.join(model_dir, "params.json")
77 | with open(_path, 'r') as f:
78 | _dict = json.load(f)
79 | return _dict
80 |
81 | def save_state(model_dir, *entries, filename='losses.csv'):
82 | """Save entries to csv. Entries is list of numbers. """
83 | csv_path = os.path.join(model_dir, filename)
84 | assert os.path.exists(csv_path), 'CSV file is missing in project directory.'
85 | with open(csv_path, 'a') as f:
86 | f.write('\n'+','.join(map(str, entries)))
87 |
88 | def save_ckpt(model_dir, net, epoch):
89 | """Save PyTorch checkpoint to ./checkpoints/ directory in model directory. """
90 | torch.save(net.state_dict(), os.path.join(model_dir, 'checkpoints',
91 | 'model-epoch{}.pt'.format(epoch)))
92 |
93 | def save_labels(model_dir, labels, epoch):
94 | """Save labels of a certain epoch to directory. """
95 | path = os.path.join(model_dir, 'plabels', f'epoch{epoch}.npy')
96 | np.save(path, labels)
97 |
98 | def compute_accuracy(y_pred, y_true):
99 | """Compute accuracy by counting correct classification. """
100 | assert y_pred.shape == y_true.shape
101 | return 1 - np.count_nonzero(y_pred - y_true) / y_true.size
102 |
103 | def clustering_accuracy(labels_true, labels_pred):
104 | """Compute clustering accuracy."""
105 | from sklearn.metrics.cluster import contingency_matrix  # sklearn.metrics.cluster.supervised was removed; use the public API
106 | from scipy.optimize import linear_sum_assignment
107 | labels_true, labels_pred = np.asarray(labels_true), np.asarray(labels_pred)
108 | value = contingency_matrix(labels_true, labels_pred)
109 | r, c = linear_sum_assignment(-value)
110 | return value[r, c].sum() / len(labels_true)
--------------------------------------------------------------------------------
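As a quick sanity check of the utilities above (toy arrays, not repo data): `sort_dataset` groups samples by class, and `clustering_accuracy` is invariant to a renaming of cluster ids thanks to the Hungarian assignment.

```python
import numpy as np
from tools.utils import sort_dataset, clustering_accuracy

# sort_dataset groups samples by class; with stack=True it returns one array.
data = np.arange(12).reshape(6, 2)
labels = np.array([1, 0, 2, 1, 0, 2])
sorted_data, sorted_labels = sort_dataset(data, labels, num_classes=3, stack=True)
print(sorted_labels)  # [0 0 1 1 2 2]

# clustering_accuracy matches predicted cluster ids to true classes, so a
# permuted labeling of the same grouping still scores 1.0.
labels_true = np.array([0, 0, 1, 1, 2, 2])
labels_pred = np.array([2, 2, 0, 0, 1, 1])  # same clusters, renamed ids
print(clustering_accuracy(labels_true, labels_pred))  # 1.0
```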