├── README.md ├── doc ├── diffusion.gif └── network.png ├── fewshot ├── ablation.py ├── backbone │ ├── backbone_utils.py │ ├── network │ │ ├── __init__.py │ │ ├── resnet.py │ │ └── wideres.py │ └── train_backbone.py ├── data │ ├── cub │ │ └── split │ │ │ ├── test.csv │ │ │ ├── train.csv │ │ │ └── val.csv │ ├── mini │ │ └── split │ │ │ ├── test.csv │ │ │ ├── train.csv │ │ │ └── val.csv │ └── tiered │ │ └── split │ │ ├── test.csv │ │ ├── train.csv │ │ └── val.csv ├── diffresnet.py ├── saved_models │ └── Put downloaded pretrained models here.txt ├── train.py └── utils.py ├── graph ├── data │ ├── citeseer.npz │ ├── cora.npz │ └── pubmed.npz ├── data_process │ ├── __pycache__ │ │ ├── preprocess.cpython-37.pyc │ │ └── preprocess.cpython-39.pyc │ ├── io.py │ ├── make_dataset.py │ └── preprocess.py ├── model.py ├── train.py └── utils.py └── synthetic ├── two_circle_example.py ├── two_moon_example.py ├── two_spiral_example.py └── xor_example.py /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion Mechanism in Neural Network: Theory and Applications 2 | 3 | This repository contains the code for Diff-ResNet implemented with PyTorch. 4 | 5 | More details can be found in the paper: 6 | [**Diffusion Mechanism in Residual Neural Network: Theory and Applications**](https://ieeexplore.ieee.org/document/10114599) 7 | 8 | ## Introduction 9 | Inspired by diffusive ODEs, we propose a novel diffusion residual network (Diff-ResNet) to strengthen the interactions among data points. The diffusion mechanism decreases the distance-diameter ratio and improves the separability of data points. The figure below shows the evolution of points under diffusion. 10 |
11 | 12 |
13 | 14 | The figure below shows the architecture of our network. 15 |
16 | 17 |
18 | 19 | ## Synthetic Data 20 | We offer several toy examples to test the effect of the diffusion mechanism and to help users understand how to use diffusion in a **plug-and-play** manner. 21 | 22 | They can serve as minimal working examples of the diffusion mechanism. Simply run each Python file. 23 | 24 | ## Graph Learning 25 | The code is adapted from [**Pitfalls of graph neural network evaluation**](https://github.com/shchur/gnn-benchmark/tree/master/gnnbench). Users can test our Diff-ResNet on the cora, citeseer and pubmed datasets, with 100 random dataset splits and 20 random initializations each. You must provide step_size and layer_num. The specific parameter choices for reproducing the results in the paper are provided in the appendix. 26 | 27 | ``` 28 | python train.py --dataset cora --step_size 0.25 --layer_num 20 --dropout 0.25 29 | ``` 30 | 31 | ## Few-shot 32 | ### 1. Dataset 33 | Download [miniImageNet](https://mega.nz/file/2ldRWQ7Y#U_zhHOf0mxoZ_WQNdvv4mt1Ke3Ay9YPNmHl5TnOVuAU), [tieredImageNet](https://mega.nz/file/r1kmyAgR#uMx7x38RScStpTZARKL2DwTfkD1eVIgbilL4s20vLhI) and [CUB-100](https://mega.nz/file/axUDACZb#ve0NQdmdj_RhhQttONaZ8Tgaxdh4A__PASs_OCI6cSk). Unpack each dataset into the directory with the corresponding dataset name under [data/](./fewshot/data/). 34 | 35 | ### 2. Backbone Training 36 | You can download models pretrained on the base classes [here](https://mega.nz/file/f5lDUJSY#E6zdNonvpPP5nq7cx_heYgLSU6vxCrsbvy4SNr88MT4), and unpack them into fewshot/saved_models/. 37 | 38 | Or you can train from scratch by running [train_backbone.py](./fewshot/backbone/train_backbone.py). 39 | 40 | ``` 41 | python train_backbone.py --dataset mini --backbone resnet18 --silent --epochs 100 42 | ``` 43 | 44 | ### 3. Diff-ResNets Classification 45 | Run [train.py](./fewshot/train.py) with the specified arguments for few-shot classification. The specific parameter choices for reproducing the results in the paper are provided in the appendix. See the argument descriptions (`python train.py --help`) for details. 
46 | ``` 47 | python train.py --dataset mini --backbone resnet18 --shot 1 --method diffusion --step_size 0.5 --layer_num 6 48 | ``` 49 | 50 | ## Citation 51 | If you find Diff-ResNets useful in your research, please consider citing: 52 | ``` 53 | @article{wang2024diffusion, 54 | author={Wang, Tangjun and Dou, Zehao and Bao, Chenglong and Shi, Zuoqiang}, 55 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 56 | title={Diffusion Mechanism in Residual Neural Network: Theory and Applications}, 57 | year={2024}, 58 | volume={46}, 59 | number={2}, 60 | pages={667-680}, 61 | doi={10.1109/TPAMI.2023.3272341} 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /doc/diffusion.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shwangtangjun/Diff-ResNet/687f342dbd2ad5477d99fd8161ae04b5d17819dc/doc/diffusion.gif -------------------------------------------------------------------------------- /doc/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shwangtangjun/Diff-ResNet/687f342dbd2ad5477d99fd8161ae04b5d17819dc/doc/network.png -------------------------------------------------------------------------------- /fewshot/ablation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code to reproduce Table 2. 
3 | """ 4 | import os 5 | import random 6 | import argparse 7 | import numpy as np 8 | import torch 9 | import torch.optim as optim 10 | import torch.nn as nn 11 | import torch.backends.cudnn as cudnn 12 | from torch.optim.lr_scheduler import MultiStepLR 13 | from utils import get_tqdm, get_configuration, get_dataloader, get_embedded_feature, get_base_mean, calculate_weight 14 | from diffresnet import DiffusionResNet 15 | 16 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 17 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # specify which GPU(s) to be used 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--seed', default=1, type=int, help='seed for training') 21 | parser.add_argument("--dataset", choices=["mini", "tiered", "cub"], type=str) 22 | parser.add_argument("--backbone", choices=["resnet18", "wideres"], type=str) 23 | parser.add_argument("--query_per_class", default=15, type=int, help="number of unlabeled query sample per class") 24 | parser.add_argument("--way", default=5, type=int, help="5-way-k-shot") 25 | parser.add_argument("--test_iter", default=1000, type=int, help="test on 1000 tasks and output average accuracy") 26 | parser.add_argument("--shot", choices=[1, 5], type=int) 27 | parser.add_argument('--silent', action='store_true', help='call --silent to disable tqdm') 28 | 29 | parser.add_argument('--epochs', default=100, type=int, help='number of training epochs') 30 | parser.add_argument("--step_size", type=float, help='strength of each diffusion layer') 31 | parser.add_argument("--layer_num", type=int, help='number of diffusion layers, 0 means no diffusion') 32 | parser.add_argument("--n_top", type=int) 33 | parser.add_argument("--sigma", type=int) 34 | 35 | parser.add_argument("--lamda", help='parameter in LaplacianShot', default=0.5, type=float) 36 | parser.add_argument("--method", choices=['simple', 'laplacian', 'diffusion'], type=str) 37 | parser.add_argument("--mu", help='parameter for weighted sum of ce loss and laplacian 
loss', type=float, default=0.0) 38 | 39 | args = parser.parse_args() 40 | 41 | 42 | def main(): 43 | if args.seed is not None: 44 | random.seed(args.seed) 45 | torch.manual_seed(args.seed) 46 | cudnn.deterministic = True 47 | 48 | data_path, split_path, save_path, num_classes = get_configuration(args.dataset, args.backbone) 49 | 50 | # Get the output of embedding function (backbone) 51 | test_loader = get_dataloader(data_path, split_path, 'test') 52 | embedded_feature = get_embedded_feature(test_loader, save_path, args.silent) 53 | 54 | acc_list = [] 55 | tqdm_test_iter = get_tqdm(range(args.test_iter), args.silent) 56 | for _ in tqdm_test_iter: 57 | if args.method == 'simple': 58 | acc = simple_shot(embedded_feature) 59 | elif args.method == 'laplacian': 60 | acc = laplacian_shot(embedded_feature) 61 | elif args.method == 'diffusion': 62 | acc = single_trial(embedded_feature) 63 | else: 64 | raise NotImplementedError 65 | 66 | acc_list.append(acc) 67 | 68 | if not args.silent: 69 | tqdm_test_iter.set_description('Test on few-shot tasks. 
Accuracy:{:.2f}'.format(np.mean(acc_list))) 70 | 71 | if args.silent: 72 | print('Accuracy:{:.2f}'.format(np.mean(acc_list))) 73 | 74 | 75 | def sample_task(embedded_feature): 76 | """ 77 | Sample a single few-shot task from novel classes 78 | """ 79 | sample_class = random.sample(list(embedded_feature.keys()), args.way) 80 | train_data, test_data, test_label, train_label = [], [], [], [] 81 | 82 | for i, each_class in enumerate(sample_class): 83 | samples = random.sample(embedded_feature[each_class], args.shot + args.query_per_class) 84 | 85 | train_label += [i] * args.shot 86 | test_label += [i] * args.query_per_class 87 | train_data += samples[:args.shot] 88 | test_data += samples[args.shot:] 89 | 90 | return np.array(train_data), np.array(test_data), np.array(train_label), np.array(test_label) 91 | 92 | 93 | def single_trial(embedded_feature): 94 | train_data, test_data, train_label, test_label = sample_task(embedded_feature) 95 | 96 | train_data, test_data, train_label, test_label = torch.tensor(train_data), torch.tensor( 97 | test_data), torch.tensor(train_label), torch.tensor(test_label) 98 | 99 | inputs = torch.cat([train_data, test_data], dim=0) 100 | weight = calculate_weight(inputs, args.n_top, args.sigma) 101 | inputs, train_label, weight = inputs.cuda(), train_label.cuda(), weight.cuda() 102 | model = DiffusionResNet(n_dim=inputs.shape[1], step_size=args.step_size, layer_num=args.layer_num, 103 | weight=weight).cuda() 104 | 105 | optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) 106 | scheduler = MultiStepLR(optimizer, milestones=[int(.5 * args.epochs), int(.75 * args.epochs)], gamma=0.1) 107 | 108 | for epoch in range(args.epochs): 109 | train(model, inputs, train_label, optimizer, weight) 110 | scheduler.step() 111 | 112 | outputs = model(inputs) 113 | 114 | # get the accuracy only on query data 115 | pred = outputs.argmax(dim=1)[args.way * args.shot:].cpu() 116 | acc = torch.eq(pred, 
test_label).float().mean().cpu().numpy() * 100 117 | return acc 118 | 119 | 120 | def train(model, inputs, train_label, optimizer, weight): 121 | outputs = model(inputs) 122 | loss1 = nn.CrossEntropyLoss()(outputs[:args.way * args.shot], train_label) 123 | loss2 = torch.sum(weight * torch.linalg.norm(outputs.unsqueeze(0) - outputs.unsqueeze(1), dim=-1) ** 2) 124 | 125 | loss = loss1 + args.mu * loss2 126 | optimizer.zero_grad() 127 | loss.backward() 128 | optimizer.step() 129 | 130 | 131 | def simple_shot(embedded_feature): 132 | train_data, test_data, train_label, test_label = sample_task(embedded_feature) 133 | 134 | prototype = train_data.reshape((args.way, args.shot, -1)).mean(axis=1) 135 | distance = np.linalg.norm(prototype - test_data[:, None], axis=-1) 136 | 137 | idx = np.argmin(distance, axis=1) 138 | pred = np.take(np.unique(train_label), idx) 139 | acc = (pred == test_label).mean() * 100 140 | return acc 141 | 142 | 143 | def laplacian_shot(embedded_feature, knn=3, lamda=args.lamda, max_iter=20): 144 | train_data, test_data, train_label, test_label = sample_task(embedded_feature) 145 | 146 | # calculate weight 147 | n = test_data.shape[0] 148 | w = np.zeros((n, n)) 149 | distance = np.linalg.norm(test_data - test_data[:, None], axis=-1) 150 | knn_ind = np.argsort(distance, axis=1)[:, 1:knn] 151 | np.put_along_axis(w, knn_ind, 1.0, axis=1) 152 | 153 | # (8a) 154 | prototype = train_data.reshape((args.way, args.shot, -1)).mean(axis=1) 155 | a = np.linalg.norm(prototype - test_data[:, None], axis=-1) 156 | 157 | y = np.exp(-a) / np.sum(np.exp(-a), axis=1, keepdims=True) 158 | energy = np.sum(y * (np.log(y) + a - lamda * np.dot(w, y))) 159 | 160 | for i in range(max_iter): 161 | # (12) update 162 | out = - a + lamda * np.dot(w, y) 163 | y = np.exp(out) / np.sum(np.exp(out), axis=1, keepdims=True) 164 | 165 | # (7) check stopping criterion 166 | energy_new = np.sum(y * (np.log(y) + a - lamda * np.dot(w, y))) 167 | if abs((energy_new - energy) / energy) < 
1e-6: 168 | break 169 | energy = energy_new.copy() 170 | 171 | idx = np.argmax(y, axis=1) 172 | pred = np.take(np.unique(train_label), idx) 173 | acc = (pred == test_label).mean() * 100 174 | return acc 175 | 176 | 177 | if __name__ == '__main__': 178 | main() 179 | -------------------------------------------------------------------------------- /fewshot/backbone/backbone_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL.Image as Image 4 | import torch.utils.data as data 5 | from torchvision import transforms 6 | from tqdm import tqdm 7 | import torch 8 | import re 9 | 10 | 11 | def get_configuration(dataset, backbone): 12 | """ 13 | Get configuration according to dataset and backbone. 14 | """ 15 | 16 | data_path = '../data/' + dataset + '/images' 17 | split_path = '../data/' + dataset + '/split' 18 | save_path = '../saved_models/' + dataset + '_' + backbone + '.pt' 19 | 20 | if dataset == 'mini': 21 | num_classes = 64 22 | elif dataset == 'tiered': 23 | num_classes = 351 24 | elif dataset == 'cub': 25 | num_classes = 100 26 | else: 27 | raise NotImplementedError 28 | 29 | return data_path, split_path, save_path, num_classes 30 | 31 | 32 | class DatasetFolder(data.Dataset): 33 | def __init__(self, root, split_dir, split_type, transform): 34 | assert split_type in ['train', 'val', 'test'] 35 | split_file = os.path.join(split_dir, split_type + '.csv') 36 | assert os.path.isfile(split_file) 37 | 38 | with open(split_file, 'r') as f: 39 | split = [x.strip().split(',') for x in f.readlines()[1:] if x.strip() != ''] 40 | 41 | data, ori_labels = [x[0] for x in split], [x[1] for x in split] 42 | label_key = sorted(np.unique(np.array(ori_labels))) 43 | label_map = dict(zip(label_key, range(len(label_key)))) 44 | mapped_labels = [label_map[x] for x in ori_labels] 45 | 46 | self.root = root 47 | self.transform = transform 48 | self.data = data 49 | self.labels = mapped_labels 50 | 
self.length = len(self.data) 51 | 52 | def __len__(self): 53 | return self.length 54 | 55 | def __getitem__(self, index): 56 | filename = self.data[index] 57 | path_file = os.path.join(self.root, filename) 58 | assert os.path.isfile(path_file) 59 | img = Image.open(path_file).convert('RGB') 60 | label = self.labels[index] 61 | label = int(label) 62 | if self.transform: 63 | img = self.transform(img) 64 | 65 | return img, label 66 | 67 | 68 | def get_train_dataloader(data_path, split_path, batch_size): 69 | datasets = DatasetFolder(root=data_path, split_dir=split_path, split_type='train', 70 | transform=transforms.Compose([transforms.RandomResizedCrop(84), 71 | transforms.ColorJitter(brightness=0.4, contrast=0.4, 72 | saturation=0.4), 73 | transforms.RandomHorizontalFlip(), 74 | transforms.ToTensor(), 75 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 76 | std=[0.229, 0.224, 0.225])])) 77 | 78 | # Setting appropriate num_workers can significantly increase training speed 79 | loader = data.DataLoader(datasets, batch_size=batch_size, shuffle=True, num_workers=40, pin_memory=True) 80 | 81 | return loader 82 | 83 | 84 | def get_val_dataloader(data_path, split_path): 85 | dataset = re.split('[/_]', data_path)[-2] 86 | if dataset == "cub": 87 | resize = 120 88 | else: 89 | resize = 96 90 | datasets = DatasetFolder(root=data_path, split_dir=split_path, split_type='val', 91 | transform=transforms.Compose([transforms.Resize(resize), 92 | transforms.CenterCrop(84), 93 | transforms.ToTensor(), 94 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 95 | std=[0.229, 0.224, 0.225])])) 96 | loader = torch.utils.data.DataLoader(datasets, batch_size=100, shuffle=False, num_workers=40) 97 | return loader 98 | 99 | 100 | def get_tqdm(iters, silent): 101 | """ 102 | Wrap iters with tqdm if not --silent 103 | """ 104 | if silent: 105 | return iters 106 | else: 107 | return tqdm(iters) 108 | -------------------------------------------------------------------------------- 
/fewshot/backbone/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet18 2 | from .wideres import wideres 3 | -------------------------------------------------------------------------------- /fewshot/backbone/network/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | __all__ = ['resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 4 | 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | """3x3 convolution with padding""" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 9 | 10 | 11 | def conv1x1(in_planes, out_planes, stride=1): 12 | """1x1 convolution""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None): 20 | super(BasicBlock, self).__init__() 21 | self.conv1 = conv3x3(inplanes, planes, stride) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.conv2 = conv3x3(planes, planes) 25 | self.bn2 = nn.BatchNorm2d(planes) 26 | self.downsample = downsample 27 | self.stride = stride 28 | 29 | def forward(self, x): 30 | identity = x 31 | 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv2(out) 37 | out = self.bn2(out) 38 | 39 | if self.downsample is not None: 40 | identity = self.downsample(x) 41 | 42 | out += identity 43 | out = self.relu(out) 44 | 45 | return out 46 | 47 | 48 | class Bottleneck(nn.Module): 49 | expansion = 4 50 | 51 | def __init__(self, inplanes, planes, stride=1, downsample=None): 52 | super(Bottleneck, self).__init__() 53 | self.conv1 = conv1x1(inplanes, planes) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = conv3x3(planes, planes, stride) 56 | self.bn2 = 
nn.BatchNorm2d(planes) 57 | self.conv3 = conv1x1(planes, planes * self.expansion) 58 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 59 | self.relu = nn.ReLU(inplace=True) 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | identity = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv3(out) 75 | out = self.bn3(out) 76 | 77 | if self.downsample is not None: 78 | identity = self.downsample(x) 79 | 80 | out += identity 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class ResNet(nn.Module): 87 | 88 | def __init__(self, block, layers, num_classes=1000): 89 | super(ResNet, self).__init__() 90 | self.inplanes = 64 91 | self.conv1 = conv3x3(3, 64) 92 | self.bn1 = nn.BatchNorm2d(64) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.layer1 = self._make_layer(block, 64, layers[0]) 95 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 96 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 97 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 98 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 99 | 100 | self.fc = nn.Linear(512 * block.expansion, num_classes) 101 | 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 105 | elif isinstance(m, nn.BatchNorm2d): 106 | nn.init.constant_(m.weight, 1) 107 | nn.init.constant_(m.bias, 0) 108 | 109 | def _make_layer(self, block, planes, blocks, stride=1): 110 | downsample = None 111 | if stride != 1 or self.inplanes != planes * block.expansion: 112 | downsample = nn.Sequential( 113 | conv1x1(self.inplanes, planes * block.expansion, stride), 114 | nn.BatchNorm2d(planes * block.expansion), 115 | ) 116 | 117 | layers = [] 118 | layers.append(block(self.inplanes, planes, stride, downsample)) 119 | self.inplanes 
= planes * block.expansion 120 | for _ in range(1, blocks): 121 | layers.append(block(self.inplanes, planes)) 122 | 123 | return nn.Sequential(*layers) 124 | 125 | def forward(self, x, return_feature=False): 126 | x = self.conv1(x) 127 | x = self.bn1(x) 128 | x = self.relu(x) 129 | 130 | x = self.layer1(x) 131 | x = self.layer2(x) 132 | x = self.layer3(x) 133 | x = self.layer4(x) 134 | 135 | x = self.avgpool(x) 136 | feature = x.view(x.size(0), -1) 137 | out = self.fc(feature) 138 | 139 | if return_feature: 140 | return feature, out 141 | else: 142 | return out 143 | 144 | 145 | def resnet10(**kwargs): 146 | """Constructs a ResNet-10 model. 147 | """ 148 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 149 | return model 150 | 151 | 152 | def resnet18(**kwargs): 153 | """Constructs a ResNet-18 model. 154 | """ 155 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 156 | return model 157 | 158 | 159 | def resnet34(**kwargs): 160 | """Constructs a ResNet-34 model. 161 | """ 162 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 163 | return model 164 | 165 | 166 | def resnet50(**kwargs): 167 | """Constructs a ResNet-50 model. 168 | """ 169 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 170 | return model 171 | 172 | 173 | def resnet101(**kwargs): 174 | """Constructs a ResNet-101 model. 175 | """ 176 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 177 | return model 178 | 179 | 180 | def resnet152(**kwargs): 181 | """Constructs a ResNet-152 model. 
182 | """ 183 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 184 | return model 185 | -------------------------------------------------------------------------------- /fewshot/backbone/network/wideres.py: -------------------------------------------------------------------------------- 1 | """ 2 | https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py 3 | """ 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['wideres'] 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1): 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 12 | 13 | 14 | class wide_basic(nn.Module): 15 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 16 | super(wide_basic, self).__init__() 17 | self.bn1 = nn.BatchNorm2d(in_planes) 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 19 | self.dropout = nn.Dropout(p=dropout_rate) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride != 1 or in_planes != planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 27 | ) 28 | 29 | def forward(self, x): 30 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 31 | out = self.conv2(F.relu(self.bn2(out))) 32 | out += self.shortcut(x) 33 | 34 | return out 35 | 36 | 37 | class Wide_ResNet(nn.Module): 38 | def __init__(self, depth, widen_factor, dropout_rate, num_classes): 39 | super(Wide_ResNet, self).__init__() 40 | self.in_planes = 16 41 | 42 | assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4' 43 | n = (depth - 4) // 6 44 | k = widen_factor 45 | 46 | nStages = [16, 16 * k, 32 * k, 64 * k] 47 | 48 | self.conv1 = conv3x3(3, nStages[0]) 49 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 50 | 
self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 51 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 52 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 53 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 54 | 55 | self.linear = nn.Linear(nStages[3], num_classes) 56 | for m in self.modules(): 57 | if isinstance(m, nn.Conv2d): 58 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 59 | elif isinstance(m, nn.BatchNorm2d): 60 | nn.init.constant_(m.weight, 1) 61 | nn.init.constant_(m.bias, 0) 62 | 63 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 64 | strides = [stride] + [1] * (num_blocks - 1) 65 | layers = [] 66 | 67 | for stride in strides: 68 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 69 | self.in_planes = planes 70 | 71 | return nn.Sequential(*layers) 72 | 73 | def forward(self, x, return_feature=False): 74 | out = self.conv1(x) 75 | out = self.layer1(out) 76 | out = self.layer2(out) 77 | out = self.layer3(out) 78 | out = F.relu(self.bn1(out)) 79 | out = self.avgpool(out) 80 | feature = out.view(out.size(0), -1) 81 | out = self.linear(feature) 82 | 83 | if return_feature: 84 | return feature, out 85 | else: 86 | return out 87 | 88 | 89 | def wideres(num_classes): 90 | """Constructs a wideres-28-10 model without dropout. 
91 | """ 92 | return Wide_ResNet(28, 10, 0, num_classes) 93 | -------------------------------------------------------------------------------- /fewshot/backbone/train_backbone.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import argparse 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.optim.lr_scheduler import MultiStepLR 9 | import network 10 | import numpy as np 11 | import collections 12 | from backbone_utils import get_configuration, get_train_dataloader, get_tqdm, get_val_dataloader 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--seed', default=1, type=int, help='seed for training') 16 | parser.add_argument("--dataset", choices=['mini', 'tiered', 'cub'], type=str) 17 | parser.add_argument("--backbone", choices=['resnet18', 'wideres'], type=str, help='network architecture') 18 | parser.add_argument('--epochs', type=int, help='number of training epochs. 100 for mini and tiered. 
400 for cub') 19 | parser.add_argument('--batch_size', default=256, type=int) 20 | parser.add_argument('--silent', action='store_true', help='call --silent to disable tqdm') 21 | 22 | args = parser.parse_args() 23 | 24 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 25 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" # specify which GPU(s) to be used 26 | 27 | 28 | def main(): 29 | if args.seed is not None: 30 | random.seed(args.seed) 31 | torch.manual_seed(args.seed) 32 | cudnn.deterministic = True 33 | 34 | data_path, split_path, save_path, num_classes = get_configuration(args.dataset, args.backbone) 35 | train_loader = get_train_dataloader(data_path, split_path, args.batch_size) 36 | val_loader = get_val_dataloader(data_path, split_path) 37 | 38 | model = network.__dict__[args.backbone](num_classes=num_classes) 39 | model = torch.nn.DataParallel(model).cuda() 40 | 41 | optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) 42 | scheduler = MultiStepLR(optimizer, milestones=[int(.5 * args.epochs), int(.75 * args.epochs)], gamma=0.1) 43 | 44 | tqdm_epochs = get_tqdm(range(args.epochs), args.silent) 45 | if not args.silent: 46 | tqdm_epochs.set_description('Total Epochs') 47 | 48 | if not os.path.isdir('../saved_models'): 49 | os.makedirs('../saved_models') 50 | 51 | best_acc = 0 52 | for epoch in tqdm_epochs: 53 | train(train_loader, model, optimizer, epoch) 54 | scheduler.step() 55 | 56 | if epoch >= int(.75 * args.epochs): 57 | val_acc = validate(val_loader, model) 58 | if val_acc > best_acc: 59 | best_acc = val_acc 60 | torch.save(model.state_dict(), save_path) 61 | 62 | 63 | def train(train_loader, model, optimizer, epoch): 64 | model.train() 65 | 66 | correct_count = 0 67 | total_count = 0 68 | acc = 0 69 | tqdm_train_loader = get_tqdm(train_loader, args.silent) 70 | 71 | for batch_idx, (inputs, labels) in enumerate(tqdm_train_loader): 72 | inputs, labels = inputs.cuda(), labels.cuda() 73 | outputs = model(inputs) 74 | loss = 
nn.CrossEntropyLoss(label_smoothing=0.1)(outputs, labels) 75 | 76 | optimizer.zero_grad() 77 | loss.backward() 78 | optimizer.step() 79 | 80 | pred = outputs.argmax(dim=1) 81 | correct_count += pred.eq(labels).sum().item() 82 | total_count += len(inputs) 83 | acc = correct_count / total_count * 100 84 | 85 | if not args.silent: 86 | tqdm_train_loader.set_description('Acc {:.2f}'.format(acc)) 87 | 88 | if args.silent: 89 | print("Epoch={}, Accuracy={:.2f}".format(epoch + 1, acc)) 90 | 91 | 92 | # Below codes only used for validation. We save the models with the highest 1-shot nearest neighbor classification 93 | # accuracy. 94 | def validate(val_loader, model): 95 | input_dict = collections.defaultdict(list) 96 | for i, (inputs, labels) in enumerate(val_loader): 97 | for img, label in zip(inputs, labels): 98 | input_dict[label.item()].append(img) 99 | 100 | acc_list = [] 101 | tqdm_test_iter = get_tqdm(range(1000), args.silent) 102 | for _ in tqdm_test_iter: 103 | acc = nearest_prototype(input_dict, model) 104 | acc_list.append(acc) 105 | 106 | if not args.silent: 107 | tqdm_test_iter.set_description('Validate on few-shot tasks. 
Accuracy:{:.2f}'.format(np.mean(acc_list))) 108 | if args.silent: 109 | print("Validation Accuracy={:.2f}".format(np.mean(acc_list))) 110 | 111 | return np.mean(acc_list) 112 | 113 | 114 | def nearest_prototype(input_dict, model): 115 | sample_class = random.sample(list(input_dict.keys()), 5) 116 | train_img, test_img, test_label, train_label = [], [], [], [] 117 | for i, each_class in enumerate(sample_class): 118 | samples = random.sample(input_dict[each_class], 1 + 15) 119 | 120 | train_label += [i] * 1 # We only validate on 1-shot tasks, for simplicity 121 | test_label += [i] * 15 122 | train_img += samples[:1] 123 | test_img += samples[1:] 124 | 125 | train_img, test_img = torch.stack(train_img).cuda(), torch.stack(test_img).cuda() 126 | train_test_img = torch.cat([train_img, test_img]) 127 | 128 | train_label, test_label = np.array(train_label), np.array(test_label) 129 | 130 | model.eval() 131 | with torch.no_grad(): 132 | train_test_data, _ = model(train_test_img, return_feature=True) 133 | 134 | train_test_data = train_test_data.cpu().data.numpy() 135 | train_data, test_data = train_test_data[:5], train_test_data[5:] 136 | 137 | prototype = train_data.reshape((5, 1, -1)).mean(axis=1) 138 | distance = np.linalg.norm(prototype - test_data[:, None], axis=-1) 139 | 140 | idx = np.argmin(distance, axis=1) 141 | pred = np.take(np.unique(train_label), idx) 142 | acc = (pred == test_label).mean() * 100 143 | return acc 144 | 145 | 146 | if __name__ == '__main__': 147 | main() 148 | -------------------------------------------------------------------------------- /fewshot/diffresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class DiffusionLayer(nn.Module): 7 | def __init__(self, step_size, laplacian): 8 | super(DiffusionLayer, self).__init__() 9 | self.step_size = step_size 10 | self.laplacian = laplacian 11 | 12 | def forward(self, 
x): 13 | x = x - self.step_size * torch.matmul(self.laplacian, x.flatten(1)).view_as(x) 14 | return x 15 | 16 | 17 | class DiffusionResNet(nn.Module): 18 | def __init__(self, n_dim, step_size, layer_num, weight): 19 | super(DiffusionResNet, self).__init__() 20 | self.layer_num = layer_num 21 | diagonal = torch.diag(weight.sum(dim=1)) 22 | laplacian = diagonal - weight 23 | 24 | self.fc1 = nn.Linear(n_dim, n_dim) 25 | self.fc2 = nn.Linear(n_dim, n_dim) 26 | self.classifier = nn.Linear(n_dim, 5) # 5-way classification 27 | self.diffusion_layer = DiffusionLayer(step_size, laplacian) 28 | 29 | def forward(self, x): 30 | x = self.fc2(F.relu(self.fc1(x))) + x 31 | for _ in range(self.layer_num): 32 | x = self.diffusion_layer(x) 33 | out = self.classifier(x) 34 | return out 35 | -------------------------------------------------------------------------------- /fewshot/saved_models/Put downloaded pretrained models here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shwangtangjun/Diff-ResNet/687f342dbd2ad5477d99fd8161ae04b5d17819dc/fewshot/saved_models/Put downloaded pretrained models here.txt -------------------------------------------------------------------------------- /fewshot/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code to reproduce Table 3. 
3 | """ 4 | import os 5 | import random 6 | import argparse 7 | import numpy as np 8 | import torch 9 | import torch.optim as optim 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.backends.cudnn as cudnn 13 | from torch.optim.lr_scheduler import MultiStepLR 14 | from utils import get_tqdm, get_configuration, get_dataloader, get_embedded_feature, get_base_mean 15 | from utils import compute_confidence_interval, calculate_weight 16 | from diffresnet import DiffusionResNet 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--seed', default=1, type=int, help='seed for training') 20 | parser.add_argument("--dataset", choices=["mini", "tiered", "cub"], type=str) 21 | parser.add_argument("--backbone", choices=["resnet18", "wideres"], type=str) 22 | parser.add_argument("--query_per_class", default=15, type=int, help="number of unlabeled query sample per class") 23 | parser.add_argument("--way", default=5, type=int, help="5-way-k-shot") 24 | parser.add_argument("--test_iter", default=10000, type=int, help="test on 10000 tasks and output average accuracy") 25 | parser.add_argument("--shot", choices=[1, 5], type=int) 26 | parser.add_argument('--silent', action='store_true', help='call --silent to disable tqdm') 27 | 28 | parser.add_argument('--epochs', default=100, type=int, help='number of training epochs') 29 | parser.add_argument("--step_size", type=float, help='strength of each diffusion layer', default=0.5) 30 | parser.add_argument("--layer_num", type=int, help='number of diffusion layers, 0 means no diffusion') 31 | parser.add_argument("--n_top", type=int, default=8) 32 | parser.add_argument("--sigma", type=int, default=4) 33 | 34 | parser.add_argument("--lamda", help='parameter in LaplacianShot', default=0.5, type=float) 35 | parser.add_argument("--method", choices=['simple', 'laplacian', 'diffusion'], type=str) 36 | parser.add_argument("--alpha", help='parameter for weighted sum of ce loss and proto loss', 
type=float, default=0.0) 37 | 38 | args = parser.parse_args() 39 | 40 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 41 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' # specify which GPU(s) to be used 42 | 43 | 44 | def main(): 45 | if args.seed is not None: 46 | random.seed(args.seed) 47 | torch.manual_seed(args.seed) 48 | cudnn.deterministic = True 49 | 50 | data_path, split_path, save_path, num_classes = get_configuration(args.dataset, args.backbone) 51 | 52 | # On novel class: get the output of embedding function (backbone) 53 | # On base class: get the output average of embedding function (backbone), used for centering 54 | train_loader = get_dataloader(data_path, split_path, 'train') 55 | test_loader = get_dataloader(data_path, split_path, 'test') 56 | embedded_feature = get_embedded_feature(test_loader, save_path, args.silent) 57 | base_mean = get_base_mean(train_loader, save_path, args.silent) 58 | 59 | acc_list = [] 60 | tqdm_test_iter = get_tqdm(range(args.test_iter), args.silent) 61 | 62 | for _ in tqdm_test_iter: 63 | if args.method == 'simple': 64 | acc = simple_shot(embedded_feature, base_mean) 65 | elif args.method == 'laplacian': 66 | acc = laplacian_shot(embedded_feature, base_mean) 67 | elif args.method == 'diffusion': 68 | acc = single_trial(embedded_feature, base_mean) 69 | else: 70 | raise NotImplementedError 71 | 72 | acc_list.append(acc) 73 | 74 | if not args.silent: 75 | tqdm_test_iter.set_description('Test on few-shot tasks. 
def sample_task(embedded_feature, way=None, shot=None, query_per_class=None):
    """Sample a single few-shot task from novel classes.

    Args:
        embedded_feature: Mapping ``class id -> list of feature vectors``.
        way: Number of classes in the task. Defaults to ``args.way``.
        shot: Labelled (support) samples per class. Defaults to ``args.shot``.
        query_per_class: Unlabelled (query) samples per class. Defaults to
            ``args.query_per_class``.

    Returns:
        ``(train_data, test_data, train_label, test_label)`` as NumPy arrays.
        Labels are the positions ``0..way-1`` of the sampled classes, and
        samples are grouped class by class.

    Raises:
        ValueError: Propagated from ``random.sample`` when there are fewer than
            ``way`` classes or a class has fewer than ``shot + query_per_class``
            samples.
    """
    # Explicit arguments win; the command-line configuration is only the fallback.
    way = args.way if way is None else way
    shot = args.shot if shot is None else shot
    query_per_class = args.query_per_class if query_per_class is None else query_per_class

    sample_class = random.sample(list(embedded_feature.keys()), way)
    train_data, test_data, test_label, train_label = [], [], [], []

    for i, each_class in enumerate(sample_class):
        # One draw per class guarantees support and query samples are disjoint.
        samples = random.sample(embedded_feature[each_class], shot + query_per_class)

        train_label += [i] * shot
        test_label += [i] * query_per_class
        train_data += samples[:shot]
        test_data += samples[shot:]

    return np.array(train_data), np.array(test_data), np.array(train_label), np.array(test_label)
def train(model, inputs, train_label, optimizer, prototype, alpha=None):
    """Run one full-batch optimisation step on a few-shot task.

    The loss is cross entropy on the labelled rows plus, when ``alpha`` is
    non-zero, ``alpha`` times a prototype term: the softmax prediction of each
    unlabelled row weighted by its Euclidean distance to every prototype.

    Args:
        model: Module mapping ``inputs`` to per-row class logits.
        inputs: Features of all samples; the labelled rows come first.
        train_label: Class indices of the labelled rows. Its length defines how
            many leading rows of ``inputs`` are labelled.
        optimizer: Optimizer over ``model``'s parameters.
        prototype: One prototype per class, same feature dimension as ``inputs``.
        alpha: Weight of the prototype term. Defaults to ``args.alpha``.
    """
    if alpha is None:
        alpha = args.alpha
    # The labelled rows are exactly those that have a label.
    n_support = train_label.shape[0]

    outputs = model(inputs)
    loss = nn.CrossEntropyLoss()(outputs[:n_support], train_label)

    # With alpha == 0 the prototype term contributes nothing to the loss or
    # its gradient, so skip the pairwise distance computation.
    if alpha != 0:
        distance = torch.linalg.norm(inputs[n_support:].unsqueeze(1) - prototype.unsqueeze(0), dim=2)
        proto_loss = (F.softmax(outputs[n_support:], dim=1) * distance).sum()
        loss = loss + alpha * proto_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
train_data = train_data / np.linalg.norm(train_data, axis=1, keepdims=True) 167 | test_data = test_data - base_mean 168 | test_data = test_data / np.linalg.norm(test_data, axis=1, keepdims=True) 169 | 170 | prototype = train_data.reshape((args.way, args.shot, -1)).mean(axis=1) 171 | distance = np.linalg.norm(prototype - test_data[:, None], axis=-1) 172 | 173 | idx = np.argmin(distance, axis=1) 174 | pred = np.take(np.unique(train_label), idx) 175 | acc = (pred == test_label).mean() * 100 176 | return acc 177 | 178 | 179 | def laplacian_shot(embedded_feature, base_mean, knn=3, lamda=args.lamda, max_iter=20): 180 | train_data, test_data, train_label, test_label = sample_task(embedded_feature) 181 | 182 | # Centering and Normalization 183 | train_data = train_data - base_mean 184 | train_data = train_data / np.linalg.norm(train_data, axis=1, keepdims=True) 185 | test_data = test_data - base_mean 186 | test_data = test_data / np.linalg.norm(test_data, axis=1, keepdims=True) 187 | 188 | # Cross-Domain Shift 189 | eta = train_data.mean(axis=0, keepdims=True) - test_data.mean(axis=0, keepdims=True) 190 | test_data = test_data + eta 191 | 192 | # Prototype Rectification 193 | train_data, test_data = torch.tensor(train_data), torch.tensor(test_data) 194 | whole_data = torch.cat([train_data, test_data], dim=0) 195 | prototype = train_data.reshape(args.way, args.shot, -1).mean(dim=1) 196 | cos_sim = F.cosine_similarity(whole_data[:, None, :], prototype[None, :, :], dim=2) * 10 # 10 is a parameter 197 | pseudo_predict = torch.argmax(cos_sim, dim=1) 198 | cos_weight = F.softmax(cos_sim, dim=1) 199 | rectified_prototype = torch.cat( 200 | [(cos_weight[pseudo_predict == i, i].unsqueeze(1) * whole_data[pseudo_predict == i]).mean(0, keepdim=True) 201 | for i in range(args.way)], dim=0) 202 | 203 | # calculate weight 204 | n = test_data.shape[0] 205 | w = np.zeros((n, n)) 206 | distance = np.linalg.norm(test_data - test_data[:, None], axis=-1) 207 | knn_ind = np.argsort(distance, 
def compute_confidence_interval(data):
    """
    Return the mean of ``data`` and the half-width of its 95% confidence interval.

    The half-width is 1.96 times the standard error, computed from the
    population standard deviation (``np.std`` default) and ``len(data)``.
    """
    center = np.mean(data)
    standard_error = np.std(data) / np.sqrt(len(data))
    half_width = 1.96 * standard_error
    return center, half_width
def calculate_weight(inputs, n_top, sigma):
    """Build a symmetric, degree-normalised Gaussian affinity matrix.

    Args:
        inputs: 2-D tensor with one sample per row.
        n_top: For every row, only distances strictly below the row's
            ``n_top``-th smallest distance are kept. The zero self-distance
            counts as the smallest one.
        sigma: The ``sigma``-th smallest distance of each row is used as that
            row's Gaussian bandwidth.

    Returns:
        Detached ``(n, n)`` tensor ``(S + S^T) / 2`` with
        ``S = D^-1/2 W D^-1/2``, where ``W`` is the truncated Gaussian kernel
        and ``D`` its row sums.
    """
    distance = torch.norm(inputs.unsqueeze(0) - inputs.unsqueeze(1), dim=-1)
    dist_n_top = torch.kthvalue(distance, n_top, dim=1, keepdim=True)[0]
    dist_sigma = torch.kthvalue(distance, sigma, dim=1, keepdim=True)[0]

    # masked_fill keeps the dtype and device of `distance`, unlike mixing in a
    # freshly created CPU float32 scalar tensor; exp(-inf) zeroes the dropped entries.
    distance_truncated = distance.masked_fill(~(distance < dist_n_top), float("inf"))
    weight = torch.exp(-(distance_truncated / dist_sigma).pow(2))

    # Symmetrically normalize the weight matrix. Scaling rows and columns by
    # broadcasting equals D^-1/2 W D^-1/2 without two dense matmuls.
    d_inv_sqrt = weight.sum(dim=1).pow(-0.5)
    weight = weight * d_inv_sqrt.unsqueeze(1) * d_inv_sqrt.unsqueeze(0)
    weight = (weight + weight.t()) / 2
    return weight.detach()
def get_embedded_feature(test_loader, save_path, silent):
    """
    Return embedded features of data from novel classes

    The result maps each label to the list of feature vectors of its samples.
    It is computed only once per dataset+backbone: if the pickle next to the
    checkpoint exists it is loaded and returned, otherwise the pretrained
    backbone is run over ``test_loader`` and the result is pickled.
    """
    cache_file = save_path + '_embedded_feature.plk'
    if os.path.isfile(cache_file):
        return load_pickle(cache_file)

    model = load_pretrained_backbone(save_path)
    model.eval()

    embedded_feature = collections.defaultdict(list)
    progress = get_tqdm(test_loader, silent)
    if not silent:
        progress.set_description('Computing embedded features on test classes')

    with torch.no_grad():
        for inputs, labels in progress:
            features, _ = model(inputs, return_feature=True)
            features = features.cpu().data.numpy()
            # Group the per-sample features by their integer class label.
            for feature, label in zip(features, labels):
                embedded_feature[label.item()].append(feature)

    save_pickle(cache_file, embedded_feature)
    return embedded_feature
def get_configuration(dataset, backbone):
    """
    Get configuration according to dataset and backbone.

    Returns ``(data_path, split_path, save_path, num_classes)``. The paths are
    relative strings built from the two names; ``num_classes`` is looked up by
    dataset name. Raises NotImplementedError for an unknown dataset.
    """
    dataset_root = './data/' + dataset
    data_path = dataset_root + '/images'
    split_path = dataset_root + '/split'
    save_path = './saved_models/' + dataset + '_' + backbone

    class_count = {'mini': 64, 'tiered': 351, 'cub': 100}
    if dataset not in class_count:
        raise NotImplementedError

    return data_path, split_path, save_path, class_count[dataset]
class SparseGraph:
    """Attributed labeled graph stored in sparse matrix form.

    """

    def __init__(self, adj_matrix, attr_matrix=None, labels=None,
                 node_names=None, attr_names=None, class_names=None, metadata=None):
        """Create an attributed graph.

        Parameters
        ----------
        adj_matrix : sp.csr_matrix, shape [num_nodes, num_nodes]
            Adjacency matrix in CSR format.
        attr_matrix : sp.csr_matrix or np.ndarray, shape [num_nodes, num_attr], optional
            Attribute matrix in CSR or numpy format.
        labels : np.ndarray, shape [num_nodes], optional
            Array, where each entry represents respective node's label(s).
        node_names : np.ndarray, shape [num_nodes], optional
            Names of nodes (as strings).
        attr_names : np.ndarray, shape [num_attr]
            Names of the attributes (as strings). Requires `attr_matrix`.
        class_names : np.ndarray, shape [num_classes], optional
            Names of the class labels (as strings).
        metadata : object
            Additional metadata such as text.

        Raises
        ------
        ValueError
            If the adjacency matrix is not sparse or not square, if the
            attribute matrix has an unsupported type, if any first dimension
            disagrees with the number of nodes, or if `attr_names` is given
            without a matching `attr_matrix`.

        """
        # Make sure that the dimensions of matrices / arrays all agree
        if sp.isspmatrix(adj_matrix):
            adj_matrix = adj_matrix.tocsr().astype(np.float32)
        else:
            raise ValueError("Adjacency matrix must be in sparse format (got {0} instead)"
                             .format(type(adj_matrix)))

        if adj_matrix.shape[0] != adj_matrix.shape[1]:
            raise ValueError("Dimensions of the adjacency matrix don't agree")

        if attr_matrix is not None:
            if sp.isspmatrix(attr_matrix):
                attr_matrix = attr_matrix.tocsr().astype(np.float32)
            elif isinstance(attr_matrix, np.ndarray):
                attr_matrix = attr_matrix.astype(np.float32)
            else:
                raise ValueError("Attribute matrix must be a sp.spmatrix or a np.ndarray (got {0} instead)"
                                 .format(type(attr_matrix)))

            if attr_matrix.shape[0] != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency and attribute matrices don't agree")

        if labels is not None:
            if labels.shape[0] != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency matrix and the label vector don't agree")

        if node_names is not None:
            if len(node_names) != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency matrix and the node names don't agree")

        if attr_names is not None:
            # Without this guard a missing attribute matrix surfaced as an
            # AttributeError on None instead of a validation error.
            if attr_matrix is None:
                raise ValueError("Attribute names were given without an attribute matrix")
            if len(attr_names) != attr_matrix.shape[1]:
                raise ValueError("Dimensions of the attribute matrix and the attribute names don't agree")

        self.adj_matrix = adj_matrix
        self.attr_matrix = attr_matrix
        self.labels = labels
        self.node_names = node_names
        self.attr_names = attr_names
        self.class_names = class_names
        self.metadata = metadata

    def num_nodes(self):
        """Get the number of nodes in the graph."""
        return self.adj_matrix.shape[0]

    def num_edges(self):
        """Get the number of edges in the graph.

        For undirected graphs, (i, j) and (j, i) are counted as single edge.
        """
        if self.is_directed():
            return int(self.adj_matrix.nnz)
        else:
            return int(self.adj_matrix.nnz / 2)

    def get_neighbors(self, idx):
        """Get the indices of neighbors of a given node.

        Parameters
        ----------
        idx : int
            Index of the node whose neighbors are of interest.

        """
        return self.adj_matrix[idx].indices

    def is_directed(self):
        """Check if the graph is directed (adjacency matrix is not symmetric)."""
        return (self.adj_matrix != self.adj_matrix.T).sum() != 0

    def to_undirected(self):
        """Convert to an undirected graph (make adjacency matrix symmetric).

        Modifies the graph in place and returns it. Raises ValueError if the
        graph is weighted.
        """
        if self.is_weighted():
            raise ValueError("Convert to unweighted graph first.")
        else:
            self.adj_matrix = self.adj_matrix + self.adj_matrix.T
            # Edges present in both directions sum to 2; clamp back to 1.
            self.adj_matrix[self.adj_matrix != 0] = 1
        return self

    def is_weighted(self):
        """Check if the graph is weighted (edge weights other than 1)."""
        return np.any(np.unique(self.adj_matrix[self.adj_matrix != 0].A1) != 1)

    def to_unweighted(self):
        """Convert to an unweighted graph (set all edge weights to 1).

        Modifies the graph in place and returns it.
        """
        self.adj_matrix.data = np.ones_like(self.adj_matrix.data)
        return self

    # Quality of life (shortcuts)
    def standardize(self):
        """Select the LCC of the unweighted/undirected/no-self-loop graph.

        All changes are done inplace.

        """
        G = self.to_unweighted().to_undirected()
        G = eliminate_self_loops(G)
        G = largest_connected_components(G, 1)
        return G

    def unpack(self):
        """Return the (A, X, z) triplet."""
        return self.adj_matrix, self.attr_matrix, self.labels
def save_sparse_graph_to_npz(filepath, sparse_graph):
    """Save a SparseGraph to a Numpy binary file.

    Parameters
    ----------
    filepath : str
        Name of the output file; '.npz' is appended when missing.
    sparse_graph : SparseGraph
        Graph in sparse matrix format.

    """
    def csr_fields(prefix, matrix):
        # The four arrays that describe a CSR matrix, keyed by a common prefix.
        return {
            prefix + '_data': matrix.data,
            prefix + '_indices': matrix.indices,
            prefix + '_indptr': matrix.indptr,
            prefix + '_shape': matrix.shape,
        }

    payload = csr_fields('adj', sparse_graph.adj_matrix)

    attributes = sparse_graph.attr_matrix
    if sp.isspmatrix(attributes):
        payload.update(csr_fields('attr', attributes))
    elif isinstance(attributes, np.ndarray):
        payload['attr_matrix'] = attributes

    labels = sparse_graph.labels
    if sp.isspmatrix(labels):
        payload.update(csr_fields('labels', labels))
    elif isinstance(labels, np.ndarray):
        payload['labels'] = labels

    # Optional fields are written only when they are set.
    for field in ('node_names', 'attr_names', 'class_names', 'metadata'):
        value = getattr(sparse_graph, field)
        if value is not None:
            payload[field] = value

    target = filepath if filepath.endswith('.npz') else filepath + '.npz'
    np.savez(target, **payload)
def get_dataset(name, data_path, standardize, train_examples_per_class=None, val_examples_per_class=None):
    """Load a graph dataset and return ``(adjacency, binary features, binarized labels)``.

    Parameters
    ----------
    name : str
        Dataset name; only 'cora_full' triggers the class-pruning step.
    data_path : str
        Path passed to `load_dataset`.
    standardize : bool
        If true, call the graph's `standardize()`. Otherwise the graph is only
        made undirected and its self-loops are removed.
    train_examples_per_class, val_examples_per_class : int, optional
        When both are given and `name` is 'cora_full', underrepresented
        classes are removed and the graph is standardized again.

    Raises
    ------
    AssertionError
        If the resulting adjacency matrix is not symmetric or the features are
        not binary.

    """
    dataset_graph = load_dataset(data_path)

    # some standardization preprocessing
    if standardize:
        dataset_graph = dataset_graph.standardize()
    else:
        dataset_graph = dataset_graph.to_undirected()
        # The eliminate_self_loops imported from .preprocess converts its
        # argument with .tolil(), i.e. it expects the adjacency matrix rather
        # than the graph object, so apply it to the matrix.
        dataset_graph.adj_matrix = eliminate_self_loops(dataset_graph.adj_matrix)

    if train_examples_per_class is not None and val_examples_per_class is not None:
        if name == 'cora_full':
            # cora_full has some classes that have very few instances. We have to remove these in order for
            # split generation not to fail
            dataset_graph = remove_underrepresented_classes(dataset_graph,
                                                            train_examples_per_class, val_examples_per_class)
            dataset_graph = dataset_graph.standardize()
            # To avoid future bugs: the above two lines should be repeated to a fixpoint, otherwise code below might
            # fail. However, for cora_full the fixpoint is reached after one iteration, so leave it like this for now.

    graph_adj, node_features, labels = dataset_graph.unpack()
    labels = binarize_labels(labels)

    # convert to binary bag-of-words feature representation if necessary
    if not is_binary_bag_of_words(node_features):
        node_features = to_binary_bag_of_words(node_features)

    # some assertions that need to hold for all datasets
    # adj matrix needs to be symmetric
    assert (graph_adj != graph_adj.T).nnz == 0
    # features need to be binary bag-of-word vectors
    assert is_binary_bag_of_words(node_features), "Non-binary node_features entry!"

    return graph_adj, node_features, labels
def sample_per_class(random_state, labels, num_examples_per_class, forbidden_indices=None):
    """Draw the same number of sample indices from every class, without replacement.

    Parameters
    ----------
    random_state : np.random.RandomState
        Source of randomness; ``choice`` is called once per class, in class order.
    labels : array-like, shape [num_samples, num_classes]
        Binary label matrix; sample ``i`` belongs to class ``k`` when
        ``labels[i, k] > 0``.
    num_examples_per_class : int
        Number of indices to draw for each class.
    forbidden_indices : array-like of int, optional
        Sample indices that must not be drawn (e.g. already used by another split).

    Returns
    -------
    np.ndarray
        Concatenation of the per-class draws, class 0 first.

    Raises
    ------
    ValueError
        Propagated from ``random_state.choice`` when a class has fewer eligible
        samples than ``num_examples_per_class``.
    """
    num_samples, num_classes = labels.shape

    # Build the "may be drawn" mask once. The previous implementation ran
    # `sample_index not in forbidden_indices` (a linear scan) for every
    # (class, sample) pair.
    if forbidden_indices is None:
        allowed = np.ones(num_samples, dtype=bool)
    else:
        allowed = ~np.isin(np.arange(num_samples), forbidden_indices)

    if hasattr(labels, "toarray"):  # sparse label matrix
        labels = labels.toarray()
    membership = np.asarray(labels) > 0.0

    picks = []
    for class_index in range(num_classes):
        # flatnonzero yields ascending indices, the same candidate order as the
        # original nested loops, so the random draws are unchanged.
        candidates = np.flatnonzero(membership[:, class_index] & allowed)
        picks.append(random_state.choice(candidates, num_examples_per_class, replace=False))
    return np.concatenate(picks)
def renormalize_adj(A):
    """Renormalize the adjacency matrix (as in the GCN paper).

    Computes ``D~^-1/2 (A + I) D~^-1/2`` where ``D~`` is the degree of ``A + I``.

    Parameters
    ----------
    A : sp.spmatrix
        Adjacency matrix. It is not modified.

    Returns
    -------
    Normalized matrix with self-loops. Whether the result is dense or sparse
    follows SciPy's sparse-by-dense division, exactly as in the previous code.
    """
    A_tilde = A.tolil()  # tolil() gives a copy, so the caller's matrix is untouched
    A_tilde.setdiag(1)
    A_tilde = A_tilde.tocsr()
    A_tilde.eliminate_zeros()
    # Fix: the degree and the result must come from A_tilde. The previous code
    # used the original A here, so the self-loops were never applied and nodes
    # without neighbours were divided by zero.
    D = np.ravel(A_tilde.sum(1))
    D_sqrt = np.sqrt(D)
    return A_tilde / D_sqrt[:, None] / D_sqrt[None, :]
def create_subgraph(sparse_graph, _sentinel=None, nodes_to_remove=None, nodes_to_keep=None):
    """Create a graph with the specified subset of nodes.

    Exactly one of (nodes_to_remove, nodes_to_keep) should be provided, while the other stays None.
    Note that to avoid confusion, it is required to pass node indices as named arguments to this function.

    The graph is modified IN PLACE: ``adj_matrix``, ``attr_matrix``, ``labels`` and
    ``node_names`` of ``sparse_graph`` are overwritten, and the same object is returned.

    Parameters
    ----------
    sparse_graph : SparseGraph
        Input graph.
    _sentinel : None
        Internal, to prevent passing positional arguments. Do not use.
    nodes_to_remove : array-like of int
        Indices of nodes that have to removed.
    nodes_to_keep : array-like of int
        Indices of nodes that have to be kept. They are sorted before use.

    Returns
    -------
    sparse_graph : SparseGraph
        The input graph object, restricted to the kept nodes.

    Raises
    ------
    ValueError
        If a positional node argument is passed, or if not exactly one of
        ``nodes_to_remove`` / ``nodes_to_keep`` is given.
    """
    # Check that arguments are passed correctly
    if _sentinel is not None:
        raise ValueError("Only call `create_subgraph` with named arguments',"
                         " (nodes_to_remove=...) or (nodes_to_keep=...)")
    if nodes_to_remove is None and nodes_to_keep is None:
        raise ValueError("Either nodes_to_remove or nodes_to_keep must be provided.")
    if nodes_to_remove is not None and nodes_to_keep is not None:
        raise ValueError("Only one of nodes_to_remove or nodes_to_keep must be provided.")

    if nodes_to_remove is not None:
        # Build the set once so each membership test is O(1), not a list scan.
        removed = set(nodes_to_remove)
        nodes_to_keep = [i for i in range(sparse_graph.num_nodes()) if i not in removed]
    else:
        nodes_to_keep = sorted(nodes_to_keep)

    sparse_graph.adj_matrix = sparse_graph.adj_matrix[nodes_to_keep][:, nodes_to_keep]
    if sparse_graph.attr_matrix is not None:
        sparse_graph.attr_matrix = sparse_graph.attr_matrix[nodes_to_keep]
    if sparse_graph.labels is not None:
        sparse_graph.labels = sparse_graph.labels[nodes_to_keep]
    if sparse_graph.node_names is not None:
        sparse_graph.node_names = sparse_graph.node_names[nodes_to_keep]
    return sparse_graph
def remove_underrepresented_classes(g, train_examples_per_class, val_examples_per_class):
    """Remove the nodes of every class that is too small to be split.

    A class is kept only if it has strictly more than
    ``train_examples_per_class + val_examples_per_class`` nodes; all nodes of the
    remaining classes are dropped via ``create_subgraph``.

    Those classes would otherwise break the training procedure.

    Parameters
    ----------
    g : graph object exposing ``labels``
        ``g.labels`` is counted with ``collections.Counter`` and indexed by position.
    train_examples_per_class : int
        Number of training examples that will be sampled per class.
    val_examples_per_class : int
        Number of validation examples that will be sampled per class.

    Returns
    -------
    The value returned by ``create_subgraph(g, nodes_to_keep=...)``.
    """
    min_examples_per_class = train_examples_per_class + val_examples_per_class
    examples_counter = Counter(g.labels)
    # Strict ">" : a class with exactly train + val members is also removed.
    keep_classes = set(class_ for class_, count in examples_counter.items() if count > min_examples_per_class)
    keep_indices = [i for i in range(len(g.labels)) if g.labels[i] in keep_classes]

    return create_subgraph(g, nodes_to_keep=keep_indices)
def val(model, adj, features, labels, idx_val):
    """Evaluate the model on the validation nodes.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(features, adj)``; switched to eval mode here.
    adj, features, labels : torch.Tensor
        Graph inputs and integer class labels for all nodes.
    idx_val : torch.Tensor
        Indices of the validation nodes.

    Returns
    -------
    loss, acc : np.ndarray
        Cross-entropy loss and accuracy on ``idx_val`` as NumPy values on the CPU.
    """
    model.eval()
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time on every epoch.
    with torch.no_grad():
        output = model(features, adj)
        loss = nn.CrossEntropyLoss()(output[idx_val], labels[idx_val])
        acc = accuracy(output[idx_val], labels[idx_val])
    loss = loss.detach().cpu().numpy()
    acc = acc.cpu().numpy()

    return loss, acc
def run_single_trial_of_single_split(adj, features, labels, idx_train, idx_val, idx_test, diagonal, torch_seeds):
    """Train one DiffusionNet from a given seed and return its test accuracy.

    Training runs for at most ``args.max_epochs`` epochs and stops early once
    neither the validation loss nor the validation accuracy has improved (or
    tied) for ``args.patience`` consecutive epochs. The weights of the last
    improving epoch are restored before testing.

    Parameters
    ----------
    adj, features, labels : torch.Tensor
        Graph inputs and integer labels; moved to the GPU here.
    idx_train, idx_val, idx_test : torch.Tensor
        Node indices of the three splits; moved to the GPU here.
    diagonal : torch.Tensor
        Passed to ``DiffusionNet`` after ``.cuda()``.
    torch_seeds : int
        Seed for ``torch.manual_seed`` / ``torch.cuda.manual_seed``.

    Returns
    -------
    np.ndarray
        Test accuracy as a NumPy value on the CPU.
    """
    torch.manual_seed(torch_seeds)
    torch.cuda.manual_seed(torch_seeds)

    model = DiffusionNet(n_features=features.shape[1], num_classes=labels.max().item() + 1, step=args.step_size,
                         layer_num=args.layer_num, dropout=args.dropout, diagonal=diagonal.cuda())

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    model = model.cuda()
    features = features.cuda()
    adj = adj.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()

    val_loss_min = np.inf
    val_acc_max = 0
    patience_step = 0
    best_state_dict = None

    for _ in range(args.max_epochs):
        train(model, optimizer, adj, features, labels, idx_train)
        val_loss, val_acc = val(model, adj, features, labels, idx_val)

        if val_loss <= val_loss_min or val_acc >= val_acc_max:
            val_loss_min = np.min((val_loss, val_loss_min))
            val_acc_max = np.max((val_acc, val_acc_max))
            patience_step = 0
            best_state_dict = copy.deepcopy(model.state_dict())
        else:
            patience_step += 1

        if patience_step >= args.patience:
            break

    # Fix: restore the best checkpoint on every exit path. Previously it was only
    # loaded when early stopping fired, so a run that reached max_epochs was
    # tested with the weights of the final (possibly worse) epoch.
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

    acc = test(model, adj, features, labels, idx_test)
    acc = acc.cpu().numpy()
    return acc
1000000, args.num_inits) # 20 trials for each split 113 | acc_list = [] 114 | for i in range(args.num_inits): 115 | acc = run_single_trial_of_single_split(adj, features, labels, idx_train, idx_val, idx_test, diagonal, 116 | torch_seeds[i]) 117 | acc_list.append(acc) 118 | return np.array(acc_list) 119 | 120 | 121 | def main(): 122 | random_state = np.random.RandomState(args.seed) 123 | single_split_seed = random_state.randint(0, 1000000, args.num_splits) # 100 random splits 124 | 125 | total_acc_list = [] 126 | for i in range(args.num_splits): 127 | acc_of_single_split = run_single_split(single_split_seed[i]) 128 | print(acc_of_single_split) 129 | total_acc_list.append(acc_of_single_split) 130 | 131 | print(np.mean(total_acc_list) * 100) 132 | print(np.std(total_acc_list) * 100) 133 | print(args.dropout) 134 | print(args.step_size) 135 | print(args.layer_num) 136 | 137 | 138 | if __name__ == '__main__': 139 | main() 140 | -------------------------------------------------------------------------------- /graph/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import torch 4 | from data_process.make_dataset import get_dataset, get_train_val_test_split 5 | 6 | 7 | def load_data(dataset_str, random_state): 8 | data_path = "data/" + dataset_str + ".npz" 9 | 10 | adj, features, labels = get_dataset(dataset_str, data_path, standardize=True, train_examples_per_class=20, 11 | val_examples_per_class=30) 12 | idx_train, idx_val, idx_test = get_train_val_test_split(random_state, labels, train_examples_per_class=20, 13 | val_examples_per_class=30, test_size=None) 14 | 15 | features = normalize_features(features) 16 | adj = normalize_adj(adj + sp.eye(adj.shape[0]), normalization="symmetric") 17 | 18 | diagonal = sp.diags(adj.sum(1).A1) 19 | diagonal = sparse_mx_to_torch_sparse_tensor(diagonal) 20 | 21 | adj = sparse_mx_to_torch_sparse_tensor(adj) 22 | features = 
def normalize_features(features):
    """Row-normalize feature matrix

    Each row is divided by its sum so that non-empty rows sum to 1. Rows whose
    sum is zero are left as all zeros.

    Parameters
    ----------
    features : sp.spmatrix or np.ndarray, shape [num_samples, num_features]

    Returns
    -------
    Row-normalized matrix, computed as ``diag(1 / rowsum) . features``.
    """
    rowsum = np.asarray(features.sum(1)).flatten()
    # np.power(int_array, -1) raises ValueError, so integer-valued feature
    # matrices (e.g. raw counts) must be promoted to float first.
    if not np.issubdtype(rowsum.dtype, np.floating):
        rowsum = rowsum.astype(np.float64)
    # Empty rows produce 1/0 = inf, which is reset to 0 right below; silence the
    # expected divide-by-zero warning instead of letting it leak to the caller.
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1)
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    return r_mat_inv.dot(features)
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor.

    Parameters
    ----------
    sparse_mx : sp.spmatrix
        Any SciPy sparse matrix; it is converted to COO format and float32.

    Returns
    -------
    torch.Tensor
        Sparse COO tensor of dtype float32 with the same shape and entries.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor is a deprecated legacy constructor;
    # sparse_coo_tensor is the supported factory and keeps float32 from `values`.
    return torch.sparse_coo_tensor(indices, values, shape)
def calculate_weight(x, y, sigma=0.5, n_top=50):
    """Build a row-normalized, sparsified Gaussian affinity matrix for 2-D points.

    ``weight[i, j] = exp(-||p_i - p_j||^2 / sigma^2)``; in every row only the
    ``n_top`` largest entries are kept, and the row is rescaled to sum to 1.

    Parameters
    ----------
    x, y : array-like, shape [n]
        Point coordinates. The number of points is taken from the input
        (previously hard-coded to 1000).
    sigma : float
        Kernel width.
    n_top : int
        Number of neighbours kept per row; values above ``n`` keep the whole row.

    Returns
    -------
    np.ndarray, shape [n, n]
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = x.shape[0]
    if n == 0:
        return np.zeros((0, 0))

    # Pairwise squared distances via broadcasting instead of an n*n Python loop.
    dx = x[:, None] - x[None, :]
    dy = y[:, None] - y[None, :]
    weight = np.exp(-(dx ** 2 + dy ** 2) / sigma ** 2)

    # Sparse and Normalize
    keep = min(n_top, n)
    # After partitioning, the last `keep` columns hold each row's largest entries;
    # everything before them is zeroed.
    drop = np.argpartition(weight, -keep, axis=1)[:, :-keep]
    if drop.size:
        np.put_along_axis(weight, drop, 0., axis=1)
    weight /= weight.sum(axis=1, keepdims=True)

    return weight
def main():
    """Train the toy network on the two-circle data and save one scatter plot per epoch.

    Writes ``figures/two_circle/raw.png`` for the input data and
    ``figures/two_circle/without_diffusion_iter=<epoch>.png`` for the learned
    features before each of the 21 training steps, and prints the number of
    correctly classified points per epoch.
    """
    import os  # local import: only needed to create the output directory

    # savefig does not create missing directories.
    os.makedirs("figures/two_circle", exist_ok=True)

    x, y = make_circles()

    color = [i for i in ['red', 'blue'] for _ in range(500)]
    plt.scatter(x, y, c=color, marker='.')
    plt.xticks([])
    plt.yticks([])
    plt.savefig("figures/two_circle/raw.png", bbox_inches='tight')

    weight = calculate_weight(x, y)
    x, y, weight = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(weight)
    inputs = torch.stack([x, y], dim=1).float()
    weight = weight.float()
    labels = torch.cat([torch.zeros(500), torch.ones(500)]).long()

    acc_list = np.zeros(21)
    model = Net()
    for epoch in range(21):
        outputs, features = model(inputs, weight)
        x, y = features[:, 0].detach().numpy(), features[:, 1].detach().numpy()
        acc = test(model, inputs, weight, labels)
        print(epoch, acc)
        acc_list[epoch] = acc

        plt.cla()
        color = [i for i in ['red', 'blue'] for _ in range(500)]
        plt.scatter(x, y, c=color, marker='.')
        plt.xticks([])
        plt.yticks([])
        # Fix: the same frame used to be written a second time as
        # "with_diffusion_iter=...", although the diffusion loop in Net.forward is
        # commented out. Only the correctly labelled file is saved now; rename the
        # prefix when enabling diffusion.
        plt.savefig("figures/two_circle/without_diffusion_iter=" + str(epoch) + ".png", bbox_inches='tight')

        train(model, inputs, weight, labels)
    print(acc_list)
def make_moons(n_samples=500):
    """Make two interleaving half circles with Gaussian noise.

    Parameters
    ----------
    n_samples : int
        Number of points on EACH half circle.

    Returns
    -------
    x, y : np.ndarray, shape [2 * n_samples]
        Coordinates; the first ``n_samples`` entries are the upper moon, the
        rest the lower moon. Noise comes from the global NumPy random state.
    """
    angle = np.linspace(0, np.pi, n_samples)
    outer_circ_x = np.cos(angle)
    outer_circ_y = np.sin(angle)
    inner_circ_x = 1 - np.cos(angle)
    inner_circ_y = 0.5 - np.sin(angle)

    x = np.append(outer_circ_x, inner_circ_x)
    y = np.append(outer_circ_y, inner_circ_y)

    # Fix: the noise length was hard-coded to 1000, so any n_samples != 500
    # raised a broadcasting error. Draw order (x, then y) is unchanged.
    x += np.random.randn(x.size) * 0.05
    y += np.random.randn(y.size) * 0.05
    return x, y
class DiffusionLayer(nn.Module):
    """One explicit diffusion step ``x <- x - step * (I - adj) x`` across samples.

    ``adj`` mixes the rows (samples) of ``x``; trailing dimensions of ``x`` are
    flattened for the product and restored afterwards.
    """

    def __init__(self, step):
        super(DiffusionLayer, self).__init__()
        self.step = step  # diffusion step size

    def forward(self, x, adj):
        """Apply one diffusion step.

        Parameters
        ----------
        x : torch.Tensor, shape [n, ...]
        adj : torch.Tensor, shape [n, n]

        Returns
        -------
        torch.Tensor with the same shape as ``x``.
        """
        flat = x.flatten(1)
        # (I - A) X == X - A X. Computing it this way avoids allocating an n x n
        # identity on every call and does not force float32 through torch.eye.
        laplacian_x = flat - torch.matmul(adj, flat)
        return x - self.step * laplacian_x.view_as(x)
def make_spirals(n_samples=500):
    """Make two interleaving Archimedean spirals with Gaussian noise.

    The arms are ``r = 1 + theta`` and its point reflection ``r = -(1 + theta)``
    for ``theta`` in ``[0, 2*pi]``.

    Parameters
    ----------
    n_samples : int
        Number of points on EACH spiral arm.

    Returns
    -------
    x, y : np.ndarray, shape [2 * n_samples]
        Coordinates; the first ``n_samples`` entries belong to the first arm.
        Noise comes from the global NumPy random state.
    """
    a1, b1, a2, b2 = 1.0, 1.0, -1.0, -1.0
    theta = np.linspace(0, 2 * np.pi, n_samples)

    x1 = (a1 + b1 * theta) * np.cos(theta)
    y1 = (a1 + b1 * theta) * np.sin(theta)
    x2 = (a2 + b2 * theta) * np.cos(theta)
    y2 = (a2 + b2 * theta) * np.sin(theta)

    x = np.append(x1, x2)
    y = np.append(y1, y2)

    # Fix: the noise length was hard-coded to 1000, so any n_samples != 500
    # raised a broadcasting error. Draw order (x, then y) is unchanged.
    x += np.random.randn(x.size) * 0.1
    y += np.random.randn(y.size) * 0.1

    return x, y
model(inputs, weight) 88 | loss = nn.CrossEntropyLoss()(outputs, labels) 89 | loss.backward() 90 | optimizer.step() 91 | 92 | 93 | def test(model, inputs, weight, labels): 94 | outputs, features = model(inputs, weight) 95 | 96 | pred = outputs.argmax(1) 97 | acc = torch.eq(pred, labels).sum() 98 | return acc.item() 99 | 100 | 101 | def main(): 102 | x, y = make_spirals() 103 | 104 | color = [i for i in ['red', 'blue'] for _ in range(500)] 105 | plt.scatter(x, y, c=color, marker='.') 106 | plt.xticks([]) 107 | plt.yticks([]) 108 | plt.savefig("figures/two_spiral/raw.png", bbox_inches='tight') 109 | 110 | weight = calculate_weight(x, y) 111 | x, y, weight = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(weight) 112 | inputs = torch.stack([x, y], dim=1).float() 113 | weight = weight.float() 114 | labels = torch.cat([torch.zeros(500), torch.ones(500)]).long() 115 | 116 | acc_list = np.zeros(21) 117 | model = Net() 118 | for epoch in range(21): 119 | outputs, features = model(inputs, weight) 120 | x, y = features[:, 0].detach().numpy(), features[:, 1].detach().numpy() 121 | acc = test(model, inputs, weight, labels) 122 | print(epoch, acc) 123 | acc_list[epoch] = acc 124 | 125 | plt.cla() 126 | color = [i for i in ['red', 'blue'] for _ in range(500)] 127 | plt.scatter(x, y, c=color, marker='.') 128 | plt.title("accuracy=" + str(round(acc / 1000 * 100, 1)) + "%", fontsize=40) 129 | plt.xticks([]) 130 | plt.yticks([]) 131 | plt.savefig("figures/two_spiral/without_diffusion_iter=" + str(epoch) + ".png", bbox_inches='tight') 132 | 133 | train(model, inputs, weight, labels) 134 | print(acc_list) 135 | 136 | 137 | if __name__ == '__main__': 138 | main() 139 | -------------------------------------------------------------------------------- /synthetic/xor_example.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | np.random.seed(42) 5 | 6 | 7 | def 
def generate_points(samples_each=100):
    """
    Uniformly sample ``samples_each`` points from each of 4 discs in R^2.
    Discs are centered at (0,0),(2,2),(2,0),(0,2), respectively. Their diameters are all 1.5.
    Points from discs centered at (0,0) and (2,2) belong to class 1. Others belong to class 2.
    :param samples_each: number of points drawn from every disc
    :return: Two numpy array of size (4 * samples_each,), ordered disc by disc
    """
    diameter = 1.5
    radius = diameter / 2
    samples_total = 4 * samples_each
    # Why np.sqrt()? https://stats.stackexchange.com/questions/120527/simulate-a-uniform-distribution-on-a-disc
    r = np.sqrt(np.random.uniform(0, radius ** 2, samples_total))
    theta = np.pi * np.random.uniform(0, 2, samples_total)
    x = r * np.cos(theta)
    y = r * np.sin(theta)

    # Shift the 2nd, 3rd and 4th groups to (2,2), (2,0) and (0,2).
    x[samples_each:3 * samples_each] += 2.
    y[samples_each:2 * samples_each] += 2.
    y[3 * samples_each:4 * samples_each] += 2.

    return x, y


def calculate_weight(x, y, sigma=0.5, n_top=20, samples_total=None):
    """
    Build a sparse, row-normalized Gaussian weight matrix.
    Entry (i, j) starts as exp(-||p_i - p_j||^2 / sigma^2); per row only the n_top largest
    entries are kept and the row is divided by its sum.
    :param x: x coordinates of the points
    :param y: y coordinates of the points
    :param sigma: bandwidth of the Gaussian kernel
    :param n_top: entries kept per row (clamped to the number of points; non-positive disables)
    :param samples_total: number of leading points to use; defaults to all of them
    :return: numpy array of size (samples_total, samples_total)
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = len(x) if samples_total is None else samples_total
    if n > len(x) or n > len(y):
        raise IndexError("samples_total exceeds the number of points")
    x, y = x[:n], y[:n]

    # Pairwise squared distances via broadcasting instead of a Python double loop.
    sq_dist = (x[:, None] - x[None, :]) ** 2 + (y[:, None] - y[None, :]) ** 2
    weight = np.exp(-sq_dist / sigma ** 2)

    # Sparse and Normalize
    keep = min(n_top, n)
    if 0 < keep < n:
        drop = np.argpartition(weight, -keep, axis=1)[:, :-keep]
        np.put_along_axis(weight, drop, 0., axis=1)
    if n:
        # The diagonal entry is exp(0) = 1 and is never dropped, so sums are > 0.
        weight /= weight.sum(axis=1, keepdims=True)

    return weight


def diffusion(x, y, weight, step_size=1.0, samples_total=None):
    """
    One diffusion step. For every point i:
        new_x[i] = x[i] - step_size * sum_j weight[i, j] * (x[i] - x[j])
    and likewise for y. This is the vectorized form of that double sum.
    :param x: x coordinates (numpy array)
    :param y: y coordinates (numpy array)
    :param weight: weight matrix, e.g. from calculate_weight
    :param step_size: diffusion step size
    :param samples_total: number of leading points to update; defaults to all of them.
        Entries beyond it are left at zero in the outputs.
    :return: Two numpy arrays shaped like x and y
    """
    x = np.asarray(x)
    y = np.asarray(y)
    n = len(x) if samples_total is None else samples_total
    w = np.asarray(weight)[:n, :n]

    # sum_j w_ij * (x_i - x_j) = (sum_j w_ij) * x_i - (w @ x)_i
    row_sum = w.sum(axis=1)
    new_x = np.zeros_like(x)
    new_y = np.zeros_like(y)
    new_x[:n] = x[:n] - step_size * (row_sum * x[:n] - w @ x[:n])
    new_y[:n] = y[:n] - step_size * (row_sum * y[:n] - w @ y[:n])
    return new_x, new_y


def _cluster_pairs(x, y, samples_total):
    """Return (pairwise distance matrix, mask of pairs lying in the same quarter of the index range)."""
    n = len(x) if samples_total is None else samples_total
    x = np.asarray(x, dtype=float)[:n]
    y = np.asarray(y, dtype=float)[:n]
    dist = np.sqrt((x[:, None] - x[None, :]) ** 2 + (y[:, None] - y[None, :]) ** 2)
    # Points are grouped into 4 consecutive, equally sized index ranges.
    group = np.arange(n) // (n / 4) if n else np.arange(0)
    same = group[:, None] == group[None, :]
    return dist, same


def calculate_l(x, y, samples_total=None):
    """
    Smallest distance between two points that lie in different groups, where the groups are
    the 4 consecutive quarters of the index range.
    :param x: x coordinates
    :param y: y coordinates
    :param samples_total: number of leading points to use; defaults to all of them
    :return: the minimum distance, or inf when no such pair exists
    """
    dist, same = _cluster_pairs(x, y, samples_total)
    different = ~same
    if not different.any():
        return float('inf')
    return float(dist[different].min())


def calculate_d(x, y, samples_total=None):
    """
    Largest distance between two points that lie in the same group, where the groups are
    the 4 consecutive quarters of the index range.
    :param x: x coordinates
    :param y: y coordinates
    :param samples_total: number of leading points to use; defaults to all of them
    :return: the maximum distance, or 0 when there are no points
    """
    dist, same = _cluster_pairs(x, y, samples_total)
    if not same.any():
        return 0.
    return float(dist[same].max())


def main():
    """Run the diffusion on the XOR data and save one scatter plot per iteration."""
    import os  # local import: the file-level import block lies outside this section

    x, y = generate_points()
    weight = calculate_weight(x, y)

    # savefig does not create missing directories.
    os.makedirs("figures/xor", exist_ok=True)

    # First half of the points is class 1 (red), second half class 2 (blue).
    n_half = len(x) // 2
    color = ['red'] * n_half + ['blue'] * (len(x) - n_half)

    epochs = 201
    for i in range(epochs):
        plt.cla()

        plt.xticks([])
        plt.yticks([])
        plt.scatter(x, y, c=color, marker='.', animated=True)
        plt.savefig("figures/xor/iter=" + str(i) + ".png", bbox_inches='tight')

        # calculate_l(x, y) and calculate_d(x, y) can be evaluated here to
        # monitor the point set during the diffusion.

        x, y = diffusion(x, y, weight)


if __name__ == '__main__':
    main()