├── Branchy_GNN_Framework.png
├── util.py
├── README.md
├── data.py
├── model.py
├── edge_main.py
└── branchy_model.py

--------------------------------------------------------------------------------
/Branchy_GNN_Framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaojiawei07/Branchy-GNN/HEAD/Branchy_GNN_Framework.png

--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
"""
https://github.com/WangYueFt/dgcnn
"""


import numpy as np
import torch
import torch.nn.functional as F


def cal_loss(pred, gold, smoothing=True):
    ''' Calculate cross-entropy loss; apply label smoothing if requested. '''

    gold = gold.contiguous().view(-1)

    if smoothing:
        eps = 0.2
        n_class = pred.size(1)

        # Smooth the one-hot targets: the true class gets 1 - eps and the
        # remaining eps is spread evenly over the other n_class - 1 classes.
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction='mean')

    return loss


class IOStream():
    """Log to stdout and to a file at the same time."""

    def __init__(self, path):
        self.f = open(path, 'a')

    def cprint(self, text):
        print(text)
        self.f.write(text + '\n')
        self.f.flush()

    def close(self):
        self.f.close()

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Branchy-GNN

This code is for the [paper](https://arxiv.org/abs/2011.02422) "Branchy-GNN: a Device-Edge Co-Inference Framework for Efficient Point Cloud Processing", submitted to ICASSP 2021.


## Framework

We propose a branchy structure for GNN-based point cloud classification to speed up edge inference. Branch networks provide early exits from the main branch, which reduces the on-device computational cost, and joint source-channel coding (JSCC) reduces the communication overhead.

In the experiments, we use four exit points.

Note that the main branch in the framework is based on [DGCNN](https://github.com/WangYueFt/dgcnn).


### Dependencies

```
PyTorch
h5py
scikit-learn
```


### Dataset

```
ModelNet40
```


### How to run

1. Pretrain a DGCNN model with the [official code](https://github.com/WangYueFt/dgcnn/tree/master/pytorch) or download a pretrained checkpoint from [here](https://github.com/WangYueFt/dgcnn/tree/master/pytorch/pretrained), and place it at ``./pretrained/model.1024.t7``.
2. Train a branch network with ``python edge_main.py --num_points=1024 --use_sgd=True --model EXIT1 --channel_noise 0.1``.

Note that ``--model`` can be ``EXIT1``, ``EXIT2``, ``EXIT3``, or ``EXIT4``.

``--channel_noise`` is the standard deviation of the AWGN channel. The encoder output is normalized to unit l2-norm, so the average signal power is 1, and ``channel_noise = 0.1`` corresponds to SNR = 20 dB.
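Under this convention the mapping is simply SNR(dB) = -20 log10(channel_noise). A minimal sanity-check sketch of that relation (not part of the repo):

```python
import math

def snr_db(noise_std):
    # Unit signal power, noise power = noise_std ** 2:
    # SNR_dB = 10 * log10(1 / noise_std ** 2) = -20 * log10(noise_std)
    return -20 * math.log10(noise_std)

print(snr_db(0.1))    # 20.0  -> the training command above
print(snr_db(0.056))  # ~25.0 -> the evaluation command below
```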
### Test under different channel conditions

``python edge_main.py --num_points=1024 --k=20 --eval=True --model EXIT1 --model_path= (saved model path) --channel_noise 0.056``

Use the same ``--model`` as the checkpoint you are evaluating.


### Other

(2023-May-29) I have corrected some errors and updated Fig. 2 and Fig. 3 in the [arXiv version](https://arxiv.org/abs/2011.02422).

--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
"""
https://github.com/WangYueFt/dgcnn
"""


import os
import sys
import glob
import h5py
import numpy as np
from torch.utils.data import Dataset


def download():
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    DATA_DIR = os.path.join(BASE_DIR, 'data')
    if not os.path.exists(DATA_DIR):
        os.mkdir(DATA_DIR)
    if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
        www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
        zipfile = os.path.basename(www)
        os.system('wget %s --no-check-certificate; unzip %s' % (www, zipfile))
        os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
        os.system('rm %s' % (zipfile))


def load_data(partition):
    download()
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    DATA_DIR = os.path.join(BASE_DIR, 'data')
    all_data = []
    all_label = []
    for h5_name in glob.glob(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)):
        f = h5py.File(h5_name, 'r')  # open read-only; recent h5py requires an explicit mode
        data = f['data'][:].astype('float32')
        label = f['label'][:].astype('int64')
        f.close()
        all_data.append(data)
        all_label.append(label)
    all_data = np.concatenate(all_data, axis=0)
    all_label = np.concatenate(all_label, axis=0)
    return all_data, all_label


def translate_pointcloud(pointcloud):
    # Random anisotropic scaling in [2/3, 3/2] and translation in [-0.2, 0.2].
    xyz1 = np.random.uniform(low=2. / 3., high=3. / 2., size=[3])
    xyz2 = np.random.uniform(low=-0.2, high=0.2, size=[3])

    translated_pointcloud = np.add(np.multiply(pointcloud, xyz1), xyz2).astype('float32')
    return translated_pointcloud


def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
    N, C = pointcloud.shape
    pointcloud += np.clip(sigma * np.random.randn(N, C), -1 * clip, clip)
    return pointcloud


class ModelNet40(Dataset):
    def __init__(self, num_points, partition='train'):
        self.data, self.label = load_data(partition)
        self.num_points = num_points
        self.partition = partition

    def __getitem__(self, item):
        pointcloud = self.data[item][:self.num_points]
        label = self.label[item]
        if self.partition == 'train':
            pointcloud = translate_pointcloud(pointcloud)
            np.random.shuffle(pointcloud)
        return pointcloud, label

    def __len__(self):
        return self.data.shape[0]


if __name__ == '__main__':
    train = ModelNet40(1024)
    test = ModelNet40(1024, 'test')
    for data, label in train:
        print(data.shape)
        print(label.shape)
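The ``ModelNet40`` class above yields ``(num_points, 3)`` float32 clouds with int64 labels of shape ``(1,)``. A minimal usage sketch matching how ``edge_main.py`` consumes it (assuming the dataset has been downloaded; the batch size is illustrative):

```python
from torch.utils.data import DataLoader
from data import ModelNet40

loader = DataLoader(ModelNet40(num_points=1024, partition='train'),
                    batch_size=32, shuffle=True, drop_last=True)

for points, label in loader:
    points = points.permute(0, 2, 1)   # the models expect (batch, 3, num_points)
    print(points.shape, label.shape)   # torch.Size([32, 3, 1024]) torch.Size([32, 1])
    break
```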
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
"""
https://github.com/WangYueFt/dgcnn
"""


import os
import sys
import copy
import math
import numpy as np
import torch
import torch.nn as nn

import torch.nn.functional as F


def knn(x, k):
    # -||x_i - x_j||^2 = 2 x_i . x_j - ||x_i||^2 - ||x_j||^2, so the k largest
    # entries of pairwise_distance are the k nearest neighbors.
    inner = -2 * torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)

    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
    return idx


def get_graph_feature(x, k=20, idx=None):
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        idx = knn(x, k=k)  # (batch_size, num_points, k)
    device = x.device  # keep the index arithmetic on the same device as the input

    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points

    idx = idx + idx_base

    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    # Edge features: concatenate (neighbor - center, center) along the channels.
    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()

    return feature


class PointNet(nn.Module):
    def __init__(self, args, output_channels=40):
        super(PointNet, self).__init__()
        self.args = args
        self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv4 = nn.Conv1d(64, 128, kernel_size=1, bias=False)
        self.conv5 = nn.Conv1d(128, args.emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(args.emb_dims)
        self.linear1 = nn.Linear(args.emb_dims, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout()
        self.linear2 = nn.Linear(512, output_channels)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.adaptive_max_pool1d(x, 1).squeeze()
        x = F.relu(self.bn6(self.linear1(x)))
        x = self.dp1(x)
        x = self.linear2(x)

        return x


class DGCNN(nn.Module):
    def __init__(self, args, output_channels=40):
        super(DGCNN, self).__init__()
        self.args = args
        self.k = args.k

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm1d(args.emb_dims)

        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(128 * 2, 256, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv1d(512, args.emb_dims, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.linear1 = nn.Linear(args.emb_dims * 2, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=args.dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=args.dropout)
        self.linear3 = nn.Linear(256, output_channels)

    def forward(self, x):
        batch_size = x.size(0)
        x = get_graph_feature(x, k=self.k)
        x = self.conv1(x)
        x1 = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x1, k=self.k)
        x = self.conv2(x)
        x2 = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x2, k=self.k)
        x = self.conv3(x)
        x3 = x.max(dim=-1, keepdim=False)[0]

        x = get_graph_feature(x3, k=self.k)
        x = self.conv4(x)
        x4 = x.max(dim=-1, keepdim=False)[0]

        x = torch.cat((x1, x2, x3, x4), dim=1)

        x = self.conv5(x)
        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
        x = torch.cat((x1, x2), 1)

        x = F.leaky_relu(self.bn6(self.linear1(x)), negative_slope=0.2)
        x = self.dp1(x)
        x = F.leaky_relu(self.bn7(self.linear2(x)), negative_slope=0.2)
        x = self.dp2(x)
        x = self.linear3(x)
        return x
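As a quick check of the distance trick in ``knn`` above, its indices agree with an explicit ``torch.cdist`` computation; a small standalone sketch (the shapes are illustrative):

```python
import torch
from model import knn

x = torch.randn(2, 3, 16)  # (batch_size, dims, num_points), as knn expects
idx = knn(x, k=4)

# Reference: explicit pairwise distances computed on (batch, points, dims).
ref = torch.cdist(x.transpose(2, 1), x.transpose(2, 1))
ref_idx = ref.topk(k=4, dim=-1, largest=False)[1]
print(torch.equal(idx, ref_idx))  # expected True (up to floating-point ties)
```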
--------------------------------------------------------------------------------
/edge_main.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from data import ModelNet40
from branchy_model import DGCNN_exit4, DGCNN_exit3, DGCNN_exit2, DGCNN_exit1
import numpy as np
from torch.utils.data import DataLoader
from util import cal_loss, IOStream
import sklearn.metrics as metrics


def str2bool(v):
    # argparse's type=bool treats every non-empty string (including 'False')
    # as True, so parse the usual spellings explicitly.
    return str(v).lower() in ('true', '1', 'yes', 'y')


def _init_():
    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')
    if not os.path.exists('checkpoints/' + args.exp_name):
        os.makedirs('checkpoints/' + args.exp_name)
    if not os.path.exists('checkpoints/' + args.exp_name + '/' + 'models'):
        os.makedirs('checkpoints/' + args.exp_name + '/' + 'models')
    # Keep a backup of the sources next to the checkpoints.
    os.system('cp edge_main.py checkpoints' + '/' + args.exp_name + '/' + 'edge_main.py.backup')
    os.system('cp model.py checkpoints' + '/' + args.exp_name + '/' + 'model.py.backup')
    os.system('cp util.py checkpoints' + '/' + args.exp_name + '/' + 'util.py.backup')
    os.system('cp data.py checkpoints' + '/' + args.exp_name + '/' + 'data.py.backup')


def train(args, io):
    train_loader = DataLoader(ModelNet40(partition='train', num_points=args.num_points), num_workers=8,
                              batch_size=args.batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=8,
                             batch_size=args.test_batch_size, shuffle=True, drop_last=False)

    device = torch.device("cuda" if args.cuda else "cpu")

    # Build the requested branch model
    if args.model == 'EXIT1':
        model = DGCNN_exit1(args).to(device)
    elif args.model == 'EXIT2':
        model = DGCNN_exit2(args).to(device)
    elif args.model == 'EXIT3':
        model = DGCNN_exit3(args).to(device)
    elif args.model == 'EXIT4':
        model = DGCNN_exit4(args).to(device)

    print(str(model))

    model = nn.DataParallel(model)

    print("Let's use", torch.cuda.device_count(), "GPUs!")

    if args.use_sgd:
        print("Use SGD")
        # 100x the base learning rate (0.1 with the default --lr 0.001), following the DGCNN recipe.
        opt = optim.SGD(model.parameters(), lr=args.lr * 100, momentum=args.momentum, weight_decay=1e-4)
    else:
        print("Use Adam")
        opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-4)

    scheduler = CosineAnnealingLR(opt, args.epochs, eta_min=args.lr)

    criterion = cal_loss

    best_test_acc = 0
    for epoch in range(args.epochs):
        ####################
        # Train
        ####################
        train_loss = 0.0
        count = 0.0
        model.train()
        train_pred = []
        train_true = []
        for data, label in train_loader:
            data, label = data.to(device), label.to(device).squeeze()
            data = data.permute(0, 2, 1)
            batch_size = data.size()[0]
            opt.zero_grad()
            logits = model(data)
            loss = criterion(logits, label)
            loss.backward()
            opt.step()
            preds = logits.max(dim=1)[1]
            count += batch_size
            train_loss += loss.item() * batch_size
            train_true.append(label.cpu().numpy())
            train_pred.append(preds.detach().cpu().numpy())
        # Step the cosine schedule once per epoch, after the optimizer updates.
        scheduler.step()
        train_true = np.concatenate(train_true)
        train_pred = np.concatenate(train_pred)
        outstr = 'Train %d, loss: %.6f, train acc: %.6f, train avg acc: %.6f' % (epoch,
                                                                                 train_loss * 1.0 / count,
                                                                                 metrics.accuracy_score(
                                                                                     train_true, train_pred),
                                                                                 metrics.balanced_accuracy_score(
                                                                                     train_true, train_pred))
        io.cprint(outstr)

        ####################
        # Test
        ####################
        test_loss = 0.0
        count = 0.0
        model.eval()
        test_pred = []
        test_true = []
        with torch.no_grad():  # gradients are not needed for evaluation
            for data, label in test_loader:
                data, label = data.to(device), label.to(device).squeeze()
                data = data.permute(0, 2, 1)
                batch_size = data.size()[0]
                logits = model(data)
                loss = criterion(logits, label)
                preds = logits.max(dim=1)[1]
                count += batch_size
                test_loss += loss.item() * batch_size
                test_true.append(label.cpu().numpy())
                test_pred.append(preds.detach().cpu().numpy())
        test_true = np.concatenate(test_true)
        test_pred = np.concatenate(test_pred)
        test_acc = metrics.accuracy_score(test_true, test_pred)
        avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred)
        outstr = 'Test %d, loss: %.6f, test acc: %.6f, test avg acc: %.6f' % (epoch,
                                                                              test_loss * 1.0 / count,
                                                                              test_acc,
                                                                              avg_per_class_acc)
        io.cprint(outstr)
        if test_acc >= best_test_acc:
            best_test_acc = test_acc
            torch.save(model.state_dict(), 'checkpoints/%s/models/EXIT_model.t7' % args.exp_name)


def test(args, io):
    test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points),
                             batch_size=args.test_batch_size, shuffle=True, drop_last=False)

    device = torch.device("cuda" if args.cuda else "cpu")

    # Build the requested branch model
    if args.model == 'EXIT1':
        model = DGCNN_exit1(args).to(device)
    elif args.model == 'EXIT2':
        model = DGCNN_exit2(args).to(device)
    elif args.model == 'EXIT3':
        model = DGCNN_exit3(args).to(device)
    elif args.model == 'EXIT4':
        model = DGCNN_exit4(args).to(device)

    model = nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model_path))
    model = model.eval()
    test_true = []
    test_pred = []
    with torch.no_grad():  # gradients are not needed for evaluation
        for data, label in test_loader:
            data, label = data.to(device), label.to(device).squeeze()
            data = data.permute(0, 2, 1)
            logits = model(data)
            preds = logits.max(dim=1)[1]
            test_true.append(label.cpu().numpy())
            test_pred.append(preds.detach().cpu().numpy())
    test_true = np.concatenate(test_true)
    test_pred = np.concatenate(test_pred)
    test_acc = metrics.accuracy_score(test_true, test_pred)
    avg_per_class_acc = metrics.balanced_accuracy_score(test_true, test_pred)
    outstr = 'Test :: test acc: %.6f, test avg acc: %.6f' % (test_acc, avg_per_class_acc)
    io.cprint(outstr)


if __name__ == "__main__":
    # Training settings
    parser = argparse.ArgumentParser(description='Point Cloud Recognition')
    parser.add_argument('--exp_name', type=str, default='exp', metavar='N',
                        help='Name of the experiment')
    parser.add_argument('--model', type=str, default='EXIT1', metavar='N',
                        choices=['EXIT1', 'EXIT2', 'EXIT3', 'EXIT4'],
                        help='Branch exit to train, [EXIT1, EXIT2, EXIT3, EXIT4]')
    parser.add_argument('--dataset', type=str, default='modelnet40', metavar='N',
                        choices=['modelnet40'])
    parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size',
                        help='Size of training batch')
    parser.add_argument('--test_batch_size', type=int, default=16, metavar='batch_size',
                        help='Size of test batch')
    parser.add_argument('--epochs', type=int, default=250, metavar='N',
                        help='number of epochs to train')
    parser.add_argument('--use_sgd', type=str2bool, default=True,
                        help='Use SGD')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: 0.001, 0.1 if using sgd)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--no_cuda', type=str2bool, default=False,
                        help='disable CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--eval', type=str2bool, default=False,
                        help='evaluate the model')
    parser.add_argument('--num_points', type=int, default=1024,
                        help='num of points to use')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='dropout rate')
    parser.add_argument('--emb_dims', type=int, default=1024, metavar='N',
                        help='Dimension of embeddings')
    parser.add_argument('--k', type=int, default=20, metavar='N',
                        help='Num of nearest neighbors to use')
    parser.add_argument('--model_path', type=str, default='', metavar='N',
                        help='Pretrained model path')
    parser.add_argument('--channel_noise', type=float, default=0.1,
                        help='standard deviation of the AWGN channel')
    args = parser.parse_args()

    _init_()

    io = IOStream('checkpoints/' + args.exp_name + '/run.log')
    io.cprint(str(args))

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    torch.manual_seed(args.seed)
    if args.cuda:
        io.cprint(
            'Using GPU : ' + str(torch.cuda.current_device()) + ' from ' + str(torch.cuda.device_count()) + ' devices')
        torch.cuda.manual_seed(args.seed)
    else:
        io.cprint('Using CPU')

    if not args.eval:
        train(args, io)
    else:
        test(args, io)
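Both this file and ``branchy_model.py`` (next) follow the same checkpoint convention: weights are saved through ``nn.DataParallel``, whose state-dict keys carry a ``module.`` prefix, and the branch models strip that prefix when loading the pretrained DGCNN. A minimal sketch of the convention (requires the pretrained checkpoint from the README):

```python
import torch
from collections import OrderedDict

state = torch.load('./pretrained/model.1024.t7', map_location='cpu')

# Keys saved through nn.DataParallel look like 'module.conv1.weight';
# dropping the 7-character 'module.' prefix lets them load into a bare DGCNN.
clean = OrderedDict((k[len('module.'):] if k.startswith('module.') else k, v)
                    for k, v in state.items())
print(next(iter(clean)))  # e.g. conv1.weight
```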
--------------------------------------------------------------------------------
/branchy_model.py:
--------------------------------------------------------------------------------
import os
import sys
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from model import DGCNN
from collections import OrderedDict


# knn and get_graph_feature are duplicated from model.py so that this file
# stays self-contained.
def knn(x, k):
    inner = -2 * torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)

    idx = pairwise_distance.topk(k=k, dim=-1)[1]  # (batch_size, num_points, k)
    return idx


def get_graph_feature(x, k=20, idx=None):
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        idx = knn(x, k=k)  # (batch_size, num_points, k)
    device = x.device  # keep the index arithmetic on the same device as the input

    idx_base = torch.arange(0, batch_size, device=device).view(-1, 1, 1) * num_points

    idx = idx + idx_base

    idx = idx.view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    # Edge features: concatenate (neighbor - center, center) along the channels.
    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()

    return feature


def awgn_channel(x, noise_factor):
    # Normalize the encoded feature to unit l2-norm, then add white Gaussian
    # noise with standard deviation noise_factor.
    x = F.normalize(x, p=2, dim=1)
    return x + torch.randn_like(x) * noise_factor


class DGCNN_exit1(nn.Module):
    def __init__(self, args, output_channels=40):
        super(DGCNN_exit1, self).__init__()
        self.args = args
        self.DGCNN = DGCNN(args)
        dict_tmp = torch.load('./pretrained/model.1024.t7', map_location='cpu')
        new_state_dict = OrderedDict()
        for name, tensor in dict_tmp.items():
            # The pretrained DGCNN was saved through nn.DataParallel; strip the
            # leading 'module.' (7 characters) from every key before loading.
            name = name[7:]
            new_state_dict[name] = tensor

        self.DGCNN.load_state_dict(new_state_dict)
        self.k = 20

        # Freeze the pretrained main branch; only the exit layers are trained.
        for para in self.DGCNN.parameters():
            para.requires_grad = False

        self.exit1_conv = nn.Sequential(nn.Conv1d(64, 256, kernel_size=1, bias=False),
                                        nn.BatchNorm1d(256),
                                        nn.LeakyReLU(negative_slope=0.2),
                                        )
        self.exit1_fc2 = nn.Sequential(nn.Linear(512, 1536),
                                       nn.BatchNorm1d(1536),
                                       nn.LeakyReLU(negative_slope=0.2),
                                       )
        self.exit1_predict = nn.Sequential(nn.Linear(1536, 512),
                                           nn.BatchNorm1d(512),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           nn.Dropout(0.5),
                                           nn.Linear(512, 256),
                                           nn.BatchNorm1d(256),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           nn.Dropout(0.5),
                                           nn.Linear(256, 128),
                                           nn.BatchNorm1d(128),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           nn.Dropout(0.5),
                                           nn.Linear(128, 40),
                                           nn.BatchNorm1d(40),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           )

    def forward(self, x, noise_factor=0.1):  # noise_factor is unused; the noise level comes from args.channel_noise

        batch_size = x.size(0)
        x = get_graph_feature(x, k=self.k)    # [batch_size, 3 * 2, num_points, k]
        x = self.DGCNN.conv1(x)
        x1 = x.max(dim=-1, keepdim=False)[0]  # [batch_size, 64, num_points]
        x = x1  # nothing to concatenate at the first exit

        # exit 1
        x = self.exit1_conv(x)

        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)  # (batch_size, dimension)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)  # (batch_size, dimension)
        x = torch.cat((x1, x2), 1)
        x = self.exit1_fc2(x)

        # AWGN channel between the device and the edge server
        x = awgn_channel(x, self.args.channel_noise)

        x = self.exit1_predict(x)
        return x


class DGCNN_exit2(nn.Module):
    def __init__(self, args, output_channels=40):
        super(DGCNN_exit2, self).__init__()
        self.args = args
        self.DGCNN = DGCNN(args)
        dict_tmp = torch.load('./pretrained/model.1024.t7', map_location='cpu')
        new_state_dict = OrderedDict()
        for name, tensor in dict_tmp.items():
            name = name[7:]  # strip the 'module.' prefix
            new_state_dict[name] = tensor

        self.DGCNN.load_state_dict(new_state_dict)
        self.k = 20

        for para in self.DGCNN.parameters():
            para.requires_grad = False

        self.exit2_conv = nn.Sequential(nn.Conv1d(128, 256, kernel_size=1, bias=False),
                                        nn.BatchNorm1d(256),
                                        nn.LeakyReLU(negative_slope=0.2),
                                        )
        self.exit2_fc2 = nn.Sequential(nn.Linear(512, 1024),
                                       nn.BatchNorm1d(1024),
                                       nn.LeakyReLU(negative_slope=0.2),
                                       )
        self.exit2_predict = nn.Sequential(nn.Linear(1024, 512),
                                           nn.BatchNorm1d(512),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           nn.Dropout(0.5),
                                           nn.Linear(512, 256),
                                           nn.BatchNorm1d(256),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           nn.Dropout(0.5),
                                           nn.Linear(256, 128),
                                           nn.BatchNorm1d(128),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           nn.Dropout(0.5),
                                           nn.Linear(128, 40),
                                           nn.BatchNorm1d(40),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           )

    def forward(self, x, noise_factor=0.1):  # noise_factor is unused; the noise level comes from args.channel_noise

        batch_size = x.size(0)
        x = get_graph_feature(x, k=self.k)    # [batch_size, 3 * 2, num_points, k]
        x = self.DGCNN.conv1(x)
        x1 = x.max(dim=-1, keepdim=False)[0]  # [batch_size, 64, num_points]
        x = get_graph_feature(x1, k=self.k)
        x = self.DGCNN.conv2(x)
        x2 = x.max(dim=-1, keepdim=False)[0]  # [batch_size, 64, num_points]

        # exit 2
        x = torch.cat((x1, x2), dim=1)  # [batch_size, 128, num_points]
        x = self.exit2_conv(x)

        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)  # (batch_size, dimension)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)  # (batch_size, dimension)
        x = torch.cat((x1, x2), 1)
        x = self.exit2_fc2(x)

        # AWGN channel between the device and the edge server
        x = awgn_channel(x, self.args.channel_noise)

        x = self.exit2_predict(x)
        return x


class DGCNN_exit3(nn.Module):
    def __init__(self, args, output_channels=40):
        super(DGCNN_exit3, self).__init__()
        self.args = args
        self.DGCNN = DGCNN(args)
        dict_tmp = torch.load('./pretrained/model.1024.t7', map_location='cpu')
        new_state_dict = OrderedDict()
        for name, tensor in dict_tmp.items():
            name = name[7:]  # strip the 'module.' prefix
            new_state_dict[name] = tensor

        self.DGCNN.load_state_dict(new_state_dict)
        self.k = 20

        for para in self.DGCNN.parameters():
            para.requires_grad = False

        self.exit3_conv = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
                                        nn.BatchNorm1d(256),
                                        nn.LeakyReLU(negative_slope=0.2),
                                        )
        self.exit3_fc2 = nn.Sequential(nn.Linear(512, 512),
                                       nn.BatchNorm1d(512),
                                       nn.LeakyReLU(negative_slope=0.2),
                                       )
        self.exit3_predict = nn.Sequential(nn.Linear(512, 512),
                                           nn.BatchNorm1d(512),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           nn.Dropout(0.5),
                                           nn.Linear(512, 256),
                                           nn.BatchNorm1d(256),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           nn.Dropout(0.5),
                                           nn.Linear(256, 128),
                                           nn.BatchNorm1d(128),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           nn.Dropout(0.5),
                                           nn.Linear(128, 40),
                                           nn.BatchNorm1d(40),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           )

    def forward(self, x, noise_factor=0.1):  # noise_factor is unused; the noise level comes from args.channel_noise

        batch_size = x.size(0)
        x = get_graph_feature(x, k=self.k)    # [batch_size, 3 * 2, num_points, k]
        x = self.DGCNN.conv1(x)
        x1 = x.max(dim=-1, keepdim=False)[0]  # [batch_size, 64, num_points]
        x = get_graph_feature(x1, k=self.k)
        x = self.DGCNN.conv2(x)
        x2 = x.max(dim=-1, keepdim=False)[0]  # [batch_size, 64, num_points]
        x = get_graph_feature(x2, k=self.k)
        x = self.DGCNN.conv3(x)
        x3 = x.max(dim=-1, keepdim=False)[0]  # [batch_size, 128, num_points]

        # exit 3
        x = torch.cat((x1, x2, x3), dim=1)  # [batch_size, 256, num_points]
        x = self.exit3_conv(x)

        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)  # (batch_size, dimension)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)  # (batch_size, dimension)
        x = torch.cat((x1, x2), 1)
        x = self.exit3_fc2(x)

        # AWGN channel between the device and the edge server
        x = awgn_channel(x, self.args.channel_noise)

        x = self.exit3_predict(x)
        return x


class DGCNN_exit4(nn.Module):
    def __init__(self, args, output_channels=40):
        super(DGCNN_exit4, self).__init__()
        self.args = args
        self.DGCNN = DGCNN(args)
        dict_tmp = torch.load('./pretrained/model.1024.t7', map_location='cpu')
        new_state_dict = OrderedDict()
        for name, tensor in dict_tmp.items():
            name = name[7:]  # strip the 'module.' prefix
            new_state_dict[name] = tensor

        self.DGCNN.load_state_dict(new_state_dict)
        self.k = 20

        for para in self.DGCNN.parameters():
            para.requires_grad = False

        self.exit4_fc2 = nn.Sequential(nn.Linear(2048, 128),
                                       nn.BatchNorm1d(128),
                                       nn.LeakyReLU(negative_slope=0.2),
                                       )
        self.exit4_predict = nn.Sequential(nn.Linear(128, 512),
                                           nn.BatchNorm1d(512),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           nn.Dropout(0.5),
                                           nn.Linear(512, 256),
                                           nn.BatchNorm1d(256),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           nn.Dropout(0.5),
                                           nn.Linear(256, 128),
                                           nn.BatchNorm1d(128),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           nn.Dropout(0.5),
                                           nn.Linear(128, 40),
                                           nn.BatchNorm1d(40),
                                           nn.LeakyReLU(negative_slope=0.2),
                                           )

    def forward(self, x, noise_factor=0.1):  # noise_factor is unused; the noise level comes from args.channel_noise

        batch_size = x.size(0)
        x = get_graph_feature(x, k=self.k)    # [batch_size, 3 * 2, num_points, k]
        x = self.DGCNN.conv1(x)
        x1 = x.max(dim=-1, keepdim=False)[0]  # [batch_size, 64, num_points]
        x = get_graph_feature(x1, k=self.k)
        x = self.DGCNN.conv2(x)
        x2 = x.max(dim=-1, keepdim=False)[0]  # [batch_size, 64, num_points]
        x = get_graph_feature(x2, k=self.k)
        x = self.DGCNN.conv3(x)
        x3 = x.max(dim=-1, keepdim=False)[0]  # [batch_size, 128, num_points]
        x = get_graph_feature(x3, k=self.k)
        x = self.DGCNN.conv4(x)
        x4 = x.max(dim=-1, keepdim=False)[0]

        x = torch.cat((x1, x2, x3, x4), dim=1)

        x = self.DGCNN.conv5(x)
        x1 = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
        x2 = F.adaptive_avg_pool1d(x, 1).view(batch_size, -1)
        x = torch.cat((x1, x2), 1)

        x = self.exit4_fc2(x)

        # AWGN channel between the device and the edge server
        x = awgn_channel(x, self.args.channel_noise)

        x = self.exit4_predict(x)
        return x
--------------------------------------------------------------------------------
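This repo trains each exit branch separately; the exit-selection policy used at inference time is not included. Purely as an illustration of how the four branches could be combined under a confidence-threshold early-exit rule (the threshold, the softmax-confidence criterion, and the single-sample batch are assumptions, not part of this repo):

```python
import torch
import torch.nn.functional as F

def branchy_infer(branches, x, threshold=0.9):
    """Run exits in order of increasing cost; stop at the first confident one.

    branches: e.g. [exit1, exit2, exit3, exit4] models from branchy_model.py.
    x: a single point cloud with shape (1, 3, num_points).
    threshold: softmax confidence needed to stop early (illustrative value).
    """
    with torch.no_grad():
        for i, branch in enumerate(branches):
            logits = branch(x)
            conf, pred = F.softmax(logits, dim=1).max(dim=1)
            if conf.item() >= threshold or i == len(branches) - 1:
                return pred.item(), i  # predicted class and the exit used
```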