├── README.md
├── doc
│   ├── diffusion.gif
│   └── network.png
├── fewshot
│   ├── ablation.py
│   ├── backbone
│   │   ├── backbone_utils.py
│   │   ├── network
│   │   │   ├── __init__.py
│   │   │   ├── resnet.py
│   │   │   └── wideres.py
│   │   └── train_backbone.py
│   ├── data
│   │   ├── cub
│   │   │   └── split
│   │   │       ├── test.csv
│   │   │       ├── train.csv
│   │   │       └── val.csv
│   │   ├── mini
│   │   │   └── split
│   │   │       ├── test.csv
│   │   │       ├── train.csv
│   │   │       └── val.csv
│   │   └── tiered
│   │       └── split
│   │           ├── test.csv
│   │           ├── train.csv
│   │           └── val.csv
│   ├── diffresnet.py
│   ├── saved_models
│   │   └── Put downloaded pretrained models here.txt
│   ├── train.py
│   └── utils.py
├── graph
│   ├── data
│   │   ├── citeseer.npz
│   │   ├── cora.npz
│   │   └── pubmed.npz
│   ├── data_process
│   │   ├── __pycache__
│   │   │   ├── preprocess.cpython-37.pyc
│   │   │   └── preprocess.cpython-39.pyc
│   │   ├── io.py
│   │   ├── make_dataset.py
│   │   └── preprocess.py
│   ├── model.py
│   ├── train.py
│   └── utils.py
└── synthetic
    ├── two_circle_example.py
    ├── two_moon_example.py
    ├── two_spiral_example.py
    └── xor_example.py
/README.md:
--------------------------------------------------------------------------------
1 | # Diffusion Mechanism in Neural Network: Theory and Applications
2 |
3 | This repository contains the code for Diff-ResNet implemented with PyTorch.
4 |
5 | More details in paper:
6 | [**Diffusion Mechanism in Residual Neural Network: Theory and Applications**](https://ieeexplore.ieee.org/document/10114599)
7 |
8 | ## Introduction
9 | Inspired by diffusive ODEs, we propose a novel diffusion residual network (Diff-ResNet) to strengthen the interactions among data points. The diffusion mechanism can decrease the distance-diameter ratio and improve the separability of data points. The figure below shows the evolution of points under diffusion.
10 |
11 |

12 |
13 |
14 | The figure below shows the architecture of our network.
15 |
16 |

17 |
18 |
19 | ## Synthetic Data
20 | We offer several toy examples that test the effect of the diffusion mechanism and show users how to apply diffusion in a **plug-and-play** manner.
21 |
22 | They serve as minimal working examples of the diffusion mechanism. Simply run each Python file.
23 |
24 | ## Graph Learning
25 | Code is adapted from [**Pitfalls of graph neural network evaluation**](https://github.com/shchur/gnn-benchmark/tree/master/gnnbench). Users can test our Diff-ResNet on the cora, citeseer and pubmed datasets over 100 random dataset splits, with 20 random initializations for each split. Users should provide `step_size` and `layer_num`. The specific parameter choices for reproducing the results in the paper are provided in the appendix.
26 |
27 | ```
28 | python train.py --dataset cora --step_size 0.25 --layer_num 20 --dropout 0.25
29 | ```
30 |
31 | ## Few-shot
32 | ### 1. Dataset
33 | Download [miniImageNet](https://mega.nz/file/2ldRWQ7Y#U_zhHOf0mxoZ_WQNdvv4mt1Ke3Ay9YPNmHl5TnOVuAU), [tieredImageNet](https://mega.nz/file/r1kmyAgR#uMx7x38RScStpTZARKL2DwTfkD1eVIgbilL4s20vLhI) and [CUB-100](https://mega.nz/file/axUDACZb#ve0NQdmdj_RhhQttONaZ8Tgaxdh4A__PASs_OCI6cSk). Unpack each dataset into the directory named after it (`mini`, `tiered` or `cub`) in [data/](./fewshot/data/).
34 |
35 | ### 2. Backbone Training
36 | You can download models pretrained on the base classes [here](https://mega.nz/file/f5lDUJSY#E6zdNonvpPP5nq7cx_heYgLSU6vxCrsbvy4SNr88MT4) and unpack them into fewshot/saved_models/.
37 |
38 | Or you can train from scratch by running [train_backbone.py](./fewshot/backbone/train_backbone.py).
39 |
40 | ```
41 | python train_backbone.py --dataset mini --backbone resnet18 --silent --epochs 100
42 | ```
43 |
44 | ### 3. Diff-ResNets Classification
45 | Run [train.py](./fewshot/train.py) with the desired arguments for few-shot classification. The specific parameter choices for reproducing the results in the paper are provided in the appendix. See the argument descriptions for help.
46 | ```
47 | python train.py --dataset mini --backbone resnet18 --shot 1 --method diffusion --step_size 0.5 --layer_num 6
48 | ```
49 |
50 | ## Citation
51 | If you find Diff-ResNets useful in your research, please consider citing:
52 | ```
53 | @article{wang2024diffusion,
54 | author={Wang, Tangjun and Dou, Zehao and Bao, Chenglong and Shi, Zuoqiang},
55 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
56 | title={Diffusion Mechanism in Residual Neural Network: Theory and Applications},
57 | year={2024},
58 | volume={46},
59 | number={2},
60 | pages={667-680},
61 | doi={10.1109/TPAMI.2023.3272341}
62 | }
63 | ```
64 |
--------------------------------------------------------------------------------
/doc/diffusion.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shwangtangjun/Diff-ResNet/687f342dbd2ad5477d99fd8161ae04b5d17819dc/doc/diffusion.gif
--------------------------------------------------------------------------------
/doc/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shwangtangjun/Diff-ResNet/687f342dbd2ad5477d99fd8161ae04b5d17819dc/doc/network.png
--------------------------------------------------------------------------------
/fewshot/ablation.py:
--------------------------------------------------------------------------------
1 | """
2 | Code to reproduce Table 2.
3 | """
4 | import os
5 | import random
6 | import argparse
7 | import numpy as np
8 | import torch
9 | import torch.optim as optim
10 | import torch.nn as nn
11 | import torch.backends.cudnn as cudnn
12 | from torch.optim.lr_scheduler import MultiStepLR
13 | from utils import get_tqdm, get_configuration, get_dataloader, get_embedded_feature, get_base_mean, calculate_weight
14 | from diffresnet import DiffusionResNet
15 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # specify which GPU(s) to be used

# Options without a usable default are required: omitting them used to fail only later
# with an opaque error (e.g. ``args.shot + args.query_per_class`` with shot=None, or
# NotImplementedError for method=None).
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=1, type=int, help='seed for training')
parser.add_argument("--dataset", choices=["mini", "tiered", "cub"], type=str, required=True)
parser.add_argument("--backbone", choices=["resnet18", "wideres"], type=str, required=True)
parser.add_argument("--query_per_class", default=15, type=int, help="number of unlabeled query sample per class")
parser.add_argument("--way", default=5, type=int, help="5-way-k-shot")
parser.add_argument("--test_iter", default=1000, type=int, help="test on 1000 tasks and output average accuracy")
parser.add_argument("--shot", choices=[1, 5], type=int, required=True)
parser.add_argument('--silent', action='store_true', help='call --silent to disable tqdm')

# Options read only by the diffusion path (single_trial / train).
parser.add_argument('--epochs', default=100, type=int, help='number of training epochs')
parser.add_argument("--step_size", type=float, help='strength of each diffusion layer')
parser.add_argument("--layer_num", type=int, help='number of diffusion layers, 0 means no diffusion')
parser.add_argument("--n_top", type=int)
parser.add_argument("--sigma", type=int)

parser.add_argument("--lamda", help='parameter in LaplacianShot', default=0.5, type=float)
parser.add_argument("--method", choices=['simple', 'laplacian', 'diffusion'], type=str, required=True)
parser.add_argument("--mu", help='parameter for weighted sum of ce loss and laplacian loss', type=float, default=0.0)

args = parser.parse_args()
40 |
41 |
def main():
    """
    Evaluate the chosen few-shot method on ``args.test_iter`` randomly sampled tasks.

    Seeds ``random``/``torch`` when ``args.seed`` is set, embeds the test split with the
    backbone checkpoint at ``save_path``, then reports the mean query accuracy: live on
    the progress bar, or printed once at the end under ``--silent``.
    """
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    data_path, split_path, save_path, num_classes = get_configuration(args.dataset, args.backbone)

    # Class -> list of embedded test features; every task below is sampled from it.
    loader = get_dataloader(data_path, split_path, 'test')
    features = get_embedded_feature(loader, save_path, args.silent)

    evaluators = {
        'simple': simple_shot,
        'laplacian': laplacian_shot,
        'diffusion': single_trial,
    }

    accuracies = []
    progress = get_tqdm(range(args.test_iter), args.silent)
    for _ in progress:
        evaluate = evaluators.get(args.method)
        if evaluate is None:
            raise NotImplementedError
        accuracies.append(evaluate(features))

        if not args.silent:
            progress.set_description('Test on few-shot tasks. Accuracy:{:.2f}'.format(np.mean(accuracies)))

    if args.silent:
        print('Accuracy:{:.2f}'.format(np.mean(accuracies)))
73 |
74 |
def sample_task(embedded_feature):
    """
    Sample a single ``args.way``-way ``args.shot``-shot task from novel classes.

    ``embedded_feature`` maps each class to a sequence of feature vectors. Classes are
    relabelled 0..way-1 in sampling order. Returns numpy arrays
    ``(train_data, test_data, train_label, test_label)``: ``args.shot`` support samples
    and ``args.query_per_class`` query samples per class, grouped class by class.
    """
    classes = random.sample(list(embedded_feature.keys()), args.way)
    support, query, support_labels, query_labels = [], [], [], []

    for new_label, cls in enumerate(classes):
        picked = random.sample(embedded_feature[cls], args.shot + args.query_per_class)
        support.extend(picked[:args.shot])
        query.extend(picked[args.shot:])
        support_labels.extend([new_label] * args.shot)
        query_labels.extend([new_label] * args.query_per_class)

    return np.array(support), np.array(query), np.array(support_labels), np.array(query_labels)
91 |
92 |
def single_trial(embedded_feature):
    """
    Train a Diff-ResNet head on one sampled task and return its query accuracy (%).

    Support and query features are stacked into a single input so the diffusion layers
    (whose Laplacian is built from ``calculate_weight`` on all of them) couple both sets;
    only the support labels are used in the loss.
    """
    train_data, test_data, train_label, test_label = sample_task(embedded_feature)

    train_data, test_data = torch.tensor(train_data), torch.tensor(test_data)
    train_label, test_label = torch.tensor(train_label), torch.tensor(test_label)

    inputs = torch.cat([train_data, test_data], dim=0)
    weight = calculate_weight(inputs, args.n_top, args.sigma)
    inputs, train_label, weight = inputs.cuda(), train_label.cuda(), weight.cuda()
    model = DiffusionResNet(n_dim=inputs.shape[1], step_size=args.step_size, layer_num=args.layer_num,
                            weight=weight).cuda()

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    # LR drops 10x at 50% and 75% of the epochs.
    scheduler = MultiStepLR(optimizer, milestones=[int(.5 * args.epochs), int(.75 * args.epochs)], gamma=0.1)

    for _ in range(args.epochs):
        train(model, inputs, train_label, optimizer, weight)
        scheduler.step()

    # Inference only: no autograd graph is needed for the final forward pass.
    with torch.no_grad():
        outputs = model(inputs)

    # Support rows come first, so query predictions start at index way * shot.
    pred = outputs.argmax(dim=1)[args.way * args.shot:].cpu()
    acc = torch.eq(pred, test_label).float().mean().cpu().numpy() * 100
    return acc
118 |
119 |
def train(model, inputs, train_label, optimizer, weight):
    """
    Run one full-batch optimisation step on a single task.

    Loss = cross-entropy on the support rows (the first ``way * shot`` outputs)
    + ``args.mu`` * sum_ij weight_ij * ||out_i - out_j||^2 (weighted pairwise smoothness).
    """
    logits = model(inputs)
    n_support = args.way * args.shot
    ce_loss = nn.CrossEntropyLoss()(logits[:n_support], train_label)

    # Broadcasting gives an (N, N, C) tensor of pairwise output differences.
    pairwise_sq_dist = torch.linalg.norm(logits.unsqueeze(0) - logits.unsqueeze(1), dim=-1) ** 2
    smooth_loss = torch.sum(weight * pairwise_sq_dist)

    total = ce_loss + args.mu * smooth_loss
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
129 |
130 |
def simple_shot(embedded_feature):
    """
    Nearest-prototype baseline: assign each query to the class whose mean support
    feature is closest in Euclidean distance. Returns the query accuracy in percent.
    """
    train_data, test_data, train_label, test_label = sample_task(embedded_feature)

    # One prototype per class: the mean of its ``shot`` support features.
    prototypes = np.mean(train_data.reshape((args.way, args.shot, -1)), axis=1)
    # (n_query, way) distances via broadcasting.
    dists = np.linalg.norm(test_data[:, None] - prototypes, axis=-1)

    classes = np.unique(train_label)
    predictions = classes[dists.argmin(axis=1)]
    return np.mean(predictions == test_label) * 100
141 |
142 |
def _row_softmax(logits):
    """Row-wise softmax of a 2-D array, shifted by each row's max so ``exp`` cannot underflow to 0/0."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def _laplacian_energy(y, a, w, lamda):
    """
    Objective ``sum(y * (log y + a - lamda * W y))`` used for the stopping test.

    ``log`` is taken of ``y`` clamped to the smallest positive float, so entries that
    underflowed to exactly 0 contribute 0 (the limit of y log y) instead of NaN.
    """
    log_y = np.log(np.maximum(y, np.finfo(y.dtype).tiny))
    return np.sum(y * (log_y + a - lamda * np.dot(w, y)))


def laplacian_shot(embedded_feature, knn=3, lamda=args.lamda, max_iter=20):
    """
    LaplacianShot baseline on one sampled task; returns the query accuracy in percent.

    Queries are linked in a binary kNN graph (``knn`` counts the point itself, so each
    query gets ``knn - 1`` neighbours). Soft assignments ``y`` start as a softmax over
    negative prototype distances and are refined with ``y = softmax(-a + lamda * W y)``
    for at most ``max_iter`` iterations, stopping early once the relative change of the
    objective falls below 1e-6.
    """
    train_data, test_data, train_label, test_label = sample_task(embedded_feature)

    # calculate weight: column 0 of the argsort is normally the query itself (distance 0),
    # hence the slice starts at 1.
    n = test_data.shape[0]
    w = np.zeros((n, n))
    distance = np.linalg.norm(test_data - test_data[:, None], axis=-1)
    knn_ind = np.argsort(distance, axis=1)[:, 1:knn]
    np.put_along_axis(w, knn_ind, 1.0, axis=1)

    # (8a) distance of every query to every class prototype
    prototype = train_data.reshape((args.way, args.shot, -1)).mean(axis=1)
    a = np.linalg.norm(prototype - test_data[:, None], axis=-1)

    # Stable softmax: the former exp(-a) / sum(exp(-a)) produced NaN (0/0) whenever all
    # prototype distances of a query were large enough for exp to underflow.
    y = _row_softmax(-a)
    energy = _laplacian_energy(y, a, w, lamda)

    for _ in range(max_iter):
        # (12) update
        y = _row_softmax(-a + lamda * np.dot(w, y))

        # (7) check stopping criterion
        energy_new = _laplacian_energy(y, a, w, lamda)
        if abs((energy_new - energy) / energy) < 1e-6:
            break
        energy = energy_new

    idx = np.argmax(y, axis=1)
    pred = np.take(np.unique(train_label), idx)
    acc = (pred == test_label).mean() * 100
    return acc
175 |
176 |
if __name__ == "__main__":
    # Run the evaluation only when executed as a script.
    main()
179 |
--------------------------------------------------------------------------------
/fewshot/backbone/backbone_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL.Image as Image
4 | import torch.utils.data as data
5 | from torchvision import transforms
6 | from tqdm import tqdm
7 | import torch
8 | import re
9 |
10 |
def get_configuration(dataset, backbone):
    """
    Resolve file locations and the base-class count for a dataset/backbone pair.

    Returns ``(data_path, split_path, save_path, num_classes)``. The paths are relative
    (``../``-prefixed), so they resolve against the current working directory.
    ``num_classes`` is the size of the backbone's training classifier.
    Raises NotImplementedError for an unknown ``dataset``.
    """
    dataset_root = '../data/' + dataset
    data_path = dataset_root + '/images'
    split_path = dataset_root + '/split'
    save_path = '../saved_models/' + dataset + '_' + backbone + '.pt'

    base_class_counts = {'mini': 64, 'tiered': 351, 'cub': 100}
    if dataset not in base_class_counts:
        raise NotImplementedError
    return data_path, split_path, save_path, base_class_counts[dataset]
30 |
31 |
class DatasetFolder(data.Dataset):
    """
    Image dataset described by ``<split_dir>/<split_type>.csv``.

    Every non-empty line after the header is ``filename,label``; images are read from
    ``os.path.join(root, filename)``. Label strings are mapped to consecutive integers
    0..K-1 in sorted order.

    Raises ValueError for an unknown ``split_type`` and FileNotFoundError when the split
    file (or, on access, an image file) is missing. Explicit exceptions are used instead
    of ``assert``, which ``python -O`` strips.
    """

    def __init__(self, root, split_dir, split_type, transform):
        if split_type not in ('train', 'val', 'test'):
            raise ValueError("split_type must be 'train', 'val' or 'test', got {!r}".format(split_type))
        split_file = os.path.join(split_dir, split_type + '.csv')
        if not os.path.isfile(split_file):
            raise FileNotFoundError(split_file)

        with open(split_file, 'r', encoding='utf-8') as f:
            # Skip the header row and blank lines.
            rows = [line.strip().split(',') for line in f.readlines()[1:] if line.strip() != '']

        # ``files`` (not ``data``) so the ``torch.utils.data`` module alias is not shadowed.
        files = [row[0] for row in rows]
        ori_labels = [row[1] for row in rows]
        label_key = sorted(set(ori_labels))
        label_map = dict(zip(label_key, range(len(label_key))))

        self.root = root
        self.transform = transform
        self.data = files
        self.labels = [label_map[x] for x in ori_labels]
        self.length = len(self.data)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return ``(image, label)``; the RGB image is passed through ``transform`` if one is set."""
        path_file = os.path.join(self.root, self.data[index])
        if not os.path.isfile(path_file):
            raise FileNotFoundError(path_file)
        # Close the underlying file handle once the RGB copy has been made.
        with Image.open(path_file) as raw:
            img = raw.convert('RGB')
        label = int(self.labels[index])
        if self.transform:
            img = self.transform(img)

        return img, label
66 |
67 |
def get_train_dataloader(data_path, split_path, batch_size, num_workers=40):
    """
    Build the shuffled training loader over the 'train' split.

    Augmentation: random resized 84x84 crop, colour jitter (brightness/contrast/saturation
    0.4) and horizontal flip, followed by ImageNet mean/std normalisation.

    ``num_workers`` defaults to the previously hard-coded 40. Setting an appropriate value
    for the machine can significantly increase training speed.
    """
    train_transform = transforms.Compose([transforms.RandomResizedCrop(84),
                                          transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                                          transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor(),
                                          transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                               std=[0.229, 0.224, 0.225])])
    datasets = DatasetFolder(root=data_path, split_dir=split_path, split_type='train', transform=train_transform)

    return data.DataLoader(datasets, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
82 |
83 |
def get_val_dataloader(data_path, split_path, num_workers=40):
    """
    Build the unshuffled validation loader (batch size 100) over the 'val' split.

    Images are resized (shorter side 120 for CUB, 96 otherwise), centre-cropped to 84x84
    and normalised with the same ImageNet mean/std as training. The dataset name is the
    parent directory of ``data_path`` (``'../data/cub/images'`` -> ``'cub'``).
    """
    # normpath drops a trailing separator; the former ``re.split('[/_]', ...)[-2]`` picked
    # 'images' instead of the dataset name for paths such as '../data/cub/images/'.
    dataset = os.path.basename(os.path.dirname(os.path.normpath(data_path)))
    resize = 120 if dataset == "cub" else 96

    val_transform = transforms.Compose([transforms.Resize(resize),
                                        transforms.CenterCrop(84),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                             std=[0.229, 0.224, 0.225])])
    datasets = DatasetFolder(root=data_path, split_dir=split_path, split_type='val', transform=val_transform)
    return data.DataLoader(datasets, batch_size=100, shuffle=False, num_workers=num_workers)
98 |
99 |
def get_tqdm(iters, silent):
    """Return ``iters`` untouched under ``--silent``; otherwise wrap it in a tqdm progress bar."""
    return iters if silent else tqdm(iters)
108 |
--------------------------------------------------------------------------------
/fewshot/backbone/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import resnet18
2 | from .wideres import wideres
3 |
--------------------------------------------------------------------------------
/fewshot/backbone/network/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | __all__ = ['resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
4 |
5 |
def conv3x3(in_planes, out_planes, stride=1):
    """Bias-free 3x3 convolution; padding=1 keeps the spatial size when ``stride == 1``."""
    return nn.Conv2d(in_planes, out_planes, 3, stride=stride, padding=1, bias=False)
9 |
10 |
def conv1x1(in_planes, out_planes, stride=1):
    """Bias-free 1x1 (pointwise) convolution, optionally strided."""
    return nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=1,
                     stride=stride, bias=False)
14 |
15 |
class BasicBlock(nn.Module):
    """
    Two-convolution residual block: relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut).

    ``stride`` is applied by the first 3x3 convolution. ``downsample``, when given, maps
    the input to the output shape for the residual addition; otherwise ``x`` is added as is.
    """
    expansion = 1  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Submodules are registered in the original order (fixes init order and state_dict keys).
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)
46 |
47 |
class Bottleneck(nn.Module):
    """
    1x1 -> 3x3 -> 1x1 residual bottleneck whose output has ``planes * expansion`` channels.

    ``stride`` is applied by the middle 3x3 convolution; ``downsample`` (optional) projects
    the input to the output shape for the residual addition.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        out_planes = planes * self.expansion
        # Registration order matches the original layout.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        h += x if self.downsample is None else self.downsample(x)
        return self.relu(h)
84 |
85 |
class ResNet(nn.Module):
    """
    ResNet with a single stride-1 3x3 stem (no max-pool), four stages of ``block`` with
    64/128/256/512 base channels (stages 2-4 use stride 2), global average pooling and a
    linear classifier over ``num_classes``.

    ``forward(x, return_feature=True)`` returns ``(pooled_feature, logits)``.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # He init for convolutions; BatchNorm starts as identity (weight 1, bias 0).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` blocks; only the first may change stride/width (via a 1x1 projection)."""
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward(self, x, return_feature=False):
        x = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        pooled = self.avgpool(x)
        feature = pooled.view(pooled.size(0), -1)
        logits = self.fc(feature)
        return (feature, logits) if return_feature else logits
143 |
144 |
def resnet10(**kwargs):
    """Constructs a ResNet-10 model (BasicBlock, one block per stage)."""
    return ResNet(BasicBlock, [1, 1, 1, 1], **kwargs)


def resnet18(**kwargs):
    """Constructs a ResNet-18 model (BasicBlock, [2, 2, 2, 2])."""
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)


def resnet34(**kwargs):
    """Constructs a ResNet-34 model (BasicBlock, [3, 4, 6, 3])."""
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)


def resnet50(**kwargs):
    """Constructs a ResNet-50 model (Bottleneck, [3, 4, 6, 3])."""
    return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)


def resnet101(**kwargs):
    """Constructs a ResNet-101 model (Bottleneck, [3, 4, 23, 3])."""
    return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)


def resnet152(**kwargs):
    """Constructs a ResNet-152 model (Bottleneck, [3, 8, 36, 3])."""
    return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
185 |
--------------------------------------------------------------------------------
/fewshot/backbone/network/wideres.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/meliketoy/wide-resnet.pytorch/blob/master/networks/wide_resnet.py
3 | """
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | __all__ = ['wideres']
8 |
9 |
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding=1 and a bias term."""
    return nn.Conv2d(in_channels=in_planes, out_channels=out_planes, kernel_size=3,
                     stride=stride, padding=1, bias=True)
12 |
13 |
class wide_basic(nn.Module):
    """
    Pre-activation wide residual block:
    conv2(relu(bn2(dropout(conv1(relu(bn1(x))))))) + shortcut(x).

    ``stride`` is applied by the second convolution. The shortcut is the identity unless
    the stride or channel count changes, in which case a biased 1x1 convolution projects ``x``.
    """

    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super(wide_basic, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )
        else:
            self.shortcut = nn.Sequential()  # empty Sequential == identity

    def forward(self, x):
        h = self.conv1(F.relu(self.bn1(x)))
        h = self.conv2(F.relu(self.bn2(self.dropout(h))))
        h += self.shortcut(x)
        return h
35 |
36 |
class Wide_ResNet(nn.Module):
    """
    WRN-``depth``-``widen_factor``: a 3x3 stem, three stages of ``(depth - 4) // 6``
    ``wide_basic`` blocks with 16k/32k/64k channels (k = widen_factor; stages 2 and 3 use
    stride 2), a final BatchNorm + ReLU, global average pooling and a linear classifier.

    ``forward(x, return_feature=True)`` returns ``(pooled_feature, logits)``.
    """

    def __init__(self, depth, widen_factor, dropout_rate, num_classes):
        super(Wide_ResNet, self).__init__()
        self.in_planes = 16

        assert ((depth - 4) % 6 == 0), 'Wide-resnet depth should be 6n+4'
        blocks_per_stage = (depth - 4) // 6
        widths = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]

        self.conv1 = conv3x3(3, widths[0])
        self.layer1 = self._wide_layer(wide_basic, widths[1], blocks_per_stage, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(wide_basic, widths[2], blocks_per_stage, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(wide_basic, widths[3], blocks_per_stage, dropout_rate, stride=2)
        # NOTE(review): PyTorch's BatchNorm ``momentum`` weights the *new* batch statistic, so
        # 0.9 makes the running stats follow the latest batch closely; the value looks carried
        # over from a TF-style convention. Kept unchanged to preserve current training behaviour.
        self.bn1 = nn.BatchNorm2d(widths[3], momentum=0.9)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.linear = nn.Linear(widths[3], num_classes)
        # He init for convolutions; BatchNorm starts as identity (weight 1, bias 0).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``, the rest use stride 1."""
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, dropout_rate, block_stride))
            self.in_planes = planes
        return nn.Sequential(*blocks)

    def forward(self, x, return_feature=False):
        h = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3):
            h = stage(h)
        h = self.avgpool(F.relu(self.bn1(h)))

        feature = h.view(h.size(0), -1)
        logits = self.linear(feature)
        return (feature, logits) if return_feature else logits
87 |
88 |
def wideres(num_classes):
    """Constructs a WRN-28-10 (depth 28, widen factor 10) with dropout disabled."""
    return Wide_ResNet(depth=28, widen_factor=10, dropout_rate=0, num_classes=num_classes)
93 |
--------------------------------------------------------------------------------
/fewshot/backbone/train_backbone.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import argparse
4 | import torch
5 | import torch.backends.cudnn as cudnn
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | from torch.optim.lr_scheduler import MultiStepLR
9 | import network
10 | import numpy as np
11 | import collections
12 | from backbone_utils import get_configuration, get_train_dataloader, get_tqdm, get_val_dataloader
13 |
# Options without a usable default are required; e.g. a missing --epochs used to fail
# only later inside MultiStepLR / range(None).
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=1, type=int, help='seed for training')
parser.add_argument("--dataset", choices=['mini', 'tiered', 'cub'], type=str, required=True)
parser.add_argument("--backbone", choices=['resnet18', 'wideres'], type=str, help='network architecture',
                    required=True)
parser.add_argument('--epochs', type=int, required=True,
                    help='number of training epochs. 100 for mini and tiered. 400 for cub')
parser.add_argument('--batch_size', default=256, type=int)
parser.add_argument('--silent', action='store_true', help='call --silent to disable tqdm')

args = parser.parse_args()

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # specify which GPU(s) to be used
26 |
27 |
def main():
    """
    Train the backbone on the base classes and keep the best checkpoint.

    SGD (lr 0.1, 10x decay at 50% and 75% of the epochs) with a DataParallel model.
    Validation (1-shot nearest-prototype accuracy on the val split) runs only during the
    last quarter of training; whenever it improves, ``model.state_dict()`` is saved to
    ``save_path``.
    """
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    data_path, split_path, save_path, num_classes = get_configuration(args.dataset, args.backbone)
    train_loader = get_train_dataloader(data_path, split_path, args.batch_size)
    val_loader = get_val_dataloader(data_path, split_path)

    model = network.__dict__[args.backbone](num_classes=num_classes)
    model = torch.nn.DataParallel(model).cuda()

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer, milestones=[int(.5 * args.epochs), int(.75 * args.epochs)], gamma=0.1)

    tqdm_epochs = get_tqdm(range(args.epochs), args.silent)
    if not args.silent:
        tqdm_epochs.set_description('Total Epochs')

    # Create the checkpoint directory taken from save_path itself (previously a separately
    # hard-coded '../saved_models'); exist_ok avoids the isdir-then-makedirs race.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_acc = 0
    for epoch in tqdm_epochs:
        train(train_loader, model, optimizer, epoch)
        scheduler.step()

        if epoch >= int(.75 * args.epochs):
            val_acc = validate(val_loader, model)
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), save_path)
61 |
62 |
def train(train_loader, model, optimizer, epoch):
    """
    Run one epoch of cross-entropy training with label smoothing 0.1.

    Tracks the running top-1 training accuracy. It is shown on the progress bar, or
    printed once at the end of the epoch under ``--silent``.
    """
    model.train()
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # stateless, so built once per epoch

    n_correct, n_seen = 0, 0
    acc = 0
    batches = get_tqdm(train_loader, args.silent)

    for inputs, labels in batches:
        inputs, labels = inputs.cuda(), labels.cuda()
        logits = model(inputs)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        n_correct += logits.argmax(dim=1).eq(labels).sum().item()
        n_seen += len(inputs)
        acc = n_correct / n_seen * 100

        if not args.silent:
            batches.set_description('Acc {:.2f}'.format(acc))

    if args.silent:
        print("Epoch={}, Accuracy={:.2f}".format(epoch + 1, acc))
90 |
91 |
def validate(val_loader, model):
    """
    Validation used only for checkpoint selection: the model with the highest mean
    accuracy over 1000 random 1-shot tasks, classified with ``nearest_prototype``, is saved.

    All validation images are first grouped by label in memory.
    """
    images_by_label = collections.defaultdict(list)
    for inputs, labels in val_loader:
        for img, label in zip(inputs, labels):
            images_by_label[label.item()].append(img)

    accuracies = []
    progress = get_tqdm(range(1000), args.silent)
    for _ in progress:
        accuracies.append(nearest_prototype(images_by_label, model))

        if not args.silent:
            progress.set_description('Validate on few-shot tasks. Accuracy:{:.2f}'.format(np.mean(accuracies)))

    mean_acc = np.mean(accuracies)
    if args.silent:
        print("Validation Accuracy={:.2f}".format(mean_acc))

    return mean_acc
112 |
113 |
def nearest_prototype(input_dict, model, way=5, shot=1, query=15, device=None):
    """
    Sample one few-shot task from ``input_dict`` and classify its queries by nearest prototype.

    ``input_dict`` maps a label to a list of image tensors. ``way`` classes are drawn, each
    contributing ``shot`` support and ``query`` query images. Features come from
    ``model(x, return_feature=True)``, and each query is assigned the class whose mean
    support feature is closest in Euclidean distance. Returns the query accuracy in percent.

    The defaults reproduce the original 5-way 1-shot, 15-query validation. ``device``
    defaults to the device of the model's first parameter, or CUDA if it has none.
    """
    sample_class = random.sample(list(input_dict.keys()), way)
    support_img, query_img, support_label, query_label = [], [], [], []
    for i, each_class in enumerate(sample_class):
        samples = random.sample(input_dict[each_class], shot + query)

        support_label += [i] * shot
        query_label += [i] * query
        support_img += samples[:shot]
        query_img += samples[shot:]

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cuda')

    # Support images first, then queries; features are split at way * shot below.
    images = torch.cat([torch.stack(support_img), torch.stack(query_img)]).to(device)
    support_label, query_label = np.array(support_label), np.array(query_label)

    model.eval()
    with torch.no_grad():
        features, _ = model(images, return_feature=True)

    features = features.cpu().numpy()
    n_support = way * shot
    support_feat, query_feat = features[:n_support], features[n_support:]

    prototype = support_feat.reshape((way, shot, -1)).mean(axis=1)
    distance = np.linalg.norm(prototype - query_feat[:, None], axis=-1)

    idx = np.argmin(distance, axis=1)
    pred = np.take(np.unique(support_label), idx)
    acc = (pred == query_label).mean() * 100
    return acc
144 |
145 |
if __name__ == "__main__":
    # Train only when executed as a script.
    main()
148 |
--------------------------------------------------------------------------------
/fewshot/diffresnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
class DiffusionLayer(nn.Module):
    """
    One explicit diffusion step ``x <- x - step_size * L @ x`` over the sample (row) axis.

    ``laplacian`` multiplies the N rows of ``x``. Trailing dimensions of ``x`` are flattened
    for the product and restored afterwards.
    """

    def __init__(self, step_size, laplacian):
        super(DiffusionLayer, self).__init__()
        self.step_size = step_size
        # Non-persistent buffer: moves with the module on .to()/.cuda() (a plain attribute
        # did not), but stays out of state_dict because it is not a learned parameter.
        self.register_buffer('laplacian', laplacian, persistent=False)

    def forward(self, x):
        x = x - self.step_size * torch.matmul(self.laplacian, x.flatten(1)).view_as(x)
        return x
15 |
16 |
class DiffusionResNet(nn.Module):
    """
    Few-shot head with graph diffusion.

    The network applies a residual MLP block ``x + fc2(relu(fc1(x)))``. It then applies
    ``layer_num`` steps of one shared ``DiffusionLayer``, whose Laplacian is ``D - W`` with
    ``D = diag(W.sum(1))`` built from the affinity ``weight``. A linear classifier follows.

    ``num_classes`` defaults to 5, the previously hard-coded 5-way setting.
    """

    def __init__(self, n_dim, step_size, layer_num, weight, num_classes=5):
        super(DiffusionResNet, self).__init__()
        self.layer_num = layer_num
        diagonal = torch.diag(weight.sum(dim=1))
        laplacian = diagonal - weight

        self.fc1 = nn.Linear(n_dim, n_dim)
        self.fc2 = nn.Linear(n_dim, n_dim)
        self.classifier = nn.Linear(n_dim, num_classes)
        self.diffusion_layer = DiffusionLayer(step_size, laplacian)

    def forward(self, x):
        x = self.fc2(F.relu(self.fc1(x))) + x
        # The same layer (same Laplacian and step size) is applied layer_num times; 0 = no diffusion.
        for _ in range(self.layer_num):
            x = self.diffusion_layer(x)
        out = self.classifier(x)
        return out
35 |
--------------------------------------------------------------------------------
/fewshot/saved_models/Put downloaded pretrained models here.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shwangtangjun/Diff-ResNet/687f342dbd2ad5477d99fd8161ae04b5d17819dc/fewshot/saved_models/Put downloaded pretrained models here.txt
--------------------------------------------------------------------------------
/fewshot/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Code to reproduce Table 3.
3 | """
4 | import os
5 | import random
6 | import argparse
7 | import numpy as np
8 | import torch
9 | import torch.optim as optim
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.backends.cudnn as cudnn
13 | from torch.optim.lr_scheduler import MultiStepLR
14 | from utils import get_tqdm, get_configuration, get_dataloader, get_embedded_feature, get_base_mean
15 | from utils import compute_confidence_interval, calculate_weight
16 | from diffresnet import DiffusionResNet
17 |
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=1, type=int, help='seed for training')
# --dataset/--backbone/--shot/--method have no usable default: a missing value
# used to surface much later (e.g. string concatenation with None in
# get_configuration), so argparse now rejects it up front.
parser.add_argument("--dataset", choices=["mini", "tiered", "cub"], type=str, required=True,
                    help='few-shot benchmark to evaluate on')
parser.add_argument("--backbone", choices=["resnet18", "wideres"], type=str, required=True,
                    help='pretrained backbone used as feature extractor')
parser.add_argument("--query_per_class", default=15, type=int, help="number of unlabeled query sample per class")
parser.add_argument("--way", default=5, type=int, help="5-way-k-shot")
parser.add_argument("--test_iter", default=10000, type=int, help="test on 10000 tasks and output average accuracy")
parser.add_argument("--shot", choices=[1, 5], type=int, required=True, help='number of labeled support samples per class')
parser.add_argument('--silent', action='store_true', help='call --silent to disable tqdm')

parser.add_argument('--epochs', default=100, type=int, help='number of training epochs')
parser.add_argument("--step_size", type=float, help='strength of each diffusion layer', default=0.5)
parser.add_argument("--layer_num", type=int, help='number of diffusion layers, 0 means no diffusion')
parser.add_argument("--n_top", type=int, default=8,
                    help='rank of the neighbour distance used to truncate the affinity graph (see calculate_weight)')
parser.add_argument("--sigma", type=int, default=4,
                    help='rank of the neighbour distance used as Gaussian kernel bandwidth (see calculate_weight)')

parser.add_argument("--lamda", help='parameter in LaplacianShot', default=0.5, type=float)
parser.add_argument("--method", choices=['simple', 'laplacian', 'diffusion'], type=str, required=True,
                    help='few-shot inference method')
parser.add_argument("--alpha", help='parameter for weighted sum of ce loss and proto loss', type=float, default=0.0)

args = parser.parse_args()
# --layer_num is only needed by the diffusion method, where a missing value
# would make range(None) fail inside DiffusionResNet.forward mid-run.
if args.method == 'diffusion' and args.layer_num is None:
    parser.error('--layer_num is required when --method is diffusion')

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'  # specify which GPU(s) to be used
42 |
43 |
def main():
    """Evaluate the selected few-shot method on ``args.test_iter`` random tasks.

    Prints the mean accuracy and the half-width of its 95% confidence interval.
    """
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    data_path, split_path, save_path, _num_classes = get_configuration(args.dataset, args.backbone)

    # Novel (test) classes: per-class backbone embeddings.
    # Base (train) classes: mean backbone embedding, later used for centering.
    train_loader = get_dataloader(data_path, split_path, 'train')
    test_loader = get_dataloader(data_path, split_path, 'test')
    embedded_feature = get_embedded_feature(test_loader, save_path, args.silent)
    base_mean = get_base_mean(train_loader, save_path, args.silent)

    runners = {
        'simple': simple_shot,
        'laplacian': laplacian_shot,
        'diffusion': single_trial,
    }
    run_task = runners.get(args.method)

    accuracies = []
    progress = get_tqdm(range(args.test_iter), args.silent)
    for _ in progress:
        # Unknown methods fail on the first task, exactly as before.
        if run_task is None:
            raise NotImplementedError
        accuracies.append(run_task(embedded_feature, base_mean))

        if not args.silent:
            progress.set_description('Test on few-shot tasks. Accuracy:{:.2f}'.format(np.mean(accuracies)))

    mean_acc, half_width = compute_confidence_interval(accuracies)
    print('Accuracy:{:.2f}'.format(mean_acc))
    print('Conf:{:.2f}'.format(half_width))
80 |
81 |
def sample_task(embedded_feature, way=None, shot=None, query_per_class=None):
    """
    Sample a single few-shot task from novel classes.

    Args:
        embedded_feature: mapping from class id to a list of feature vectors.
        way: number of classes in the task (defaults to ``args.way``).
        shot: labeled support samples per class (defaults to ``args.shot``).
        query_per_class: query samples per class (defaults to ``args.query_per_class``).

    Returns:
        ``(train_data, test_data, train_label, test_label)`` as numpy arrays.
        Support and query samples are grouped class by class. Labels are
        task-local indices ``0..way-1`` in the order the classes were drawn.
        Within a class, support and query samples are disjoint.
    """
    # Fall back to the command-line configuration only when not given explicitly.
    way = args.way if way is None else way
    shot = args.shot if shot is None else shot
    query_per_class = args.query_per_class if query_per_class is None else query_per_class

    sample_class = random.sample(list(embedded_feature.keys()), way)
    train_data, test_data, train_label, test_label = [], [], [], []

    for i, each_class in enumerate(sample_class):
        # One draw without replacement, split into support and query parts.
        samples = random.sample(embedded_feature[each_class], shot + query_per_class)

        train_label += [i] * shot
        test_label += [i] * query_per_class
        train_data += samples[:shot]
        test_data += samples[shot:]

    return np.array(train_data), np.array(test_data), np.array(train_label), np.array(test_label)
98 |
99 |
def single_trial(embedded_feature, base_mean):
    """Solve one sampled few-shot task with a freshly trained DiffusionResNet.

    Steps:
    1. Center features on ``base_mean``, L2-normalise them, and shift the
       queries so their mean equals the support mean.
    2. Build the affinity graph over support+query with ``calculate_weight``.
    3. Train the model full-batch for ``args.epochs`` epochs: SGD with lr 0.1,
       momentum 0.9 and weight decay 1e-4, with the LR multiplied by 0.1 at 50%
       and 75% of the epochs.
    4. Predict the queries.

    Returns:
        Query-set accuracy in percent.
    """
    train_data, test_data, train_label, test_label = sample_task(embedded_feature)

    train_data, test_data, train_label, test_label, base_mean = torch.tensor(train_data), torch.tensor(
        test_data), torch.tensor(train_label), torch.tensor(test_label), torch.tensor(base_mean)

    # Centering and Normalization
    train_data = train_data - base_mean
    train_data = train_data / torch.norm(train_data, dim=1, keepdim=True)
    test_data = test_data - base_mean
    test_data = test_data / torch.norm(test_data, dim=1, keepdim=True)

    # Cross-Domain Shift: move the query mean onto the support mean
    eta = train_data.mean(dim=0, keepdim=True) - test_data.mean(dim=0, keepdim=True)
    test_data = test_data + eta

    inputs = torch.cat([train_data, test_data], dim=0)
    weight = calculate_weight(inputs, args.n_top, args.sigma)
    inputs, train_label, weight = inputs.cuda(), train_label.cuda(), weight.cuda()
    model = DiffusionResNet(n_dim=inputs.shape[1], step_size=args.step_size, layer_num=args.layer_num,
                            weight=weight).cuda()

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[int(.5 * args.epochs), int(.75 * args.epochs)], gamma=0.1)

    # Prototype Rectification: re-estimate each class centre from every sample
    # pseudo-labelled as that class, weighted by its softmax confidence.
    whole_data = torch.cat([train_data, test_data], dim=0)
    prototype = train_data.reshape(args.way, args.shot, -1).mean(dim=1)
    cos_sim = F.cosine_similarity(whole_data[:, None, :], prototype[None, :, :], dim=2) * 10  # 10 is a parameter
    pseudo_predict = torch.argmax(cos_sim, dim=1)
    cos_weight = F.softmax(cos_sim, dim=1)
    rectified_rows = []
    for i in range(args.way):
        members = pseudo_predict == i
        if members.any():
            rectified_rows.append(
                (cos_weight[members, i].unsqueeze(1) * whole_data[members]).mean(0, keepdim=True))
        else:
            # No sample was pseudo-labelled as class i. A mean over an empty
            # selection is NaN and would poison the training loss (even with
            # alpha == 0, since 0 * NaN == NaN), so keep the support prototype.
            rectified_rows.append(prototype[i:i + 1])
    rectified_prototype = torch.cat(rectified_rows, dim=0).cuda()

    for _ in range(args.epochs):
        train(model, inputs, train_label, optimizer, rectified_prototype)
        scheduler.step()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(inputs)

    # get the accuracy only on query data
    pred = outputs.argmax(dim=1)[args.way * args.shot:].cpu()
    acc = torch.eq(pred, test_label).float().mean().cpu().numpy() * 100
    return acc
146 |
147 |
def train(model, inputs, train_label, optimizer, prototype, n_support=None, alpha=None):
    """Run one full-batch optimisation step of ``model`` on a few-shot task.

    The first ``n_support`` rows of ``inputs`` are the labeled support samples;
    the remaining rows are queries. The loss has two parts:
    - cross-entropy on the support rows;
    - ``alpha`` times a prototype term: the sum over queries of
      softmax(query outputs) · distance(query input, each prototype).

    Args:
        model: module mapping ``inputs`` to per-class outputs.
        inputs: all task samples, support rows first.
        train_label: labels of the support rows.
        optimizer: optimizer over ``model``'s parameters.
        prototype: (way, n_dim) class prototypes in input space.
        n_support: number of support rows (defaults to ``args.way * args.shot``).
        alpha: weight of the prototype term (defaults to ``args.alpha``).
    """
    n_support = args.way * args.shot if n_support is None else n_support
    alpha = args.alpha if alpha is None else alpha

    outputs = model(inputs)
    loss = nn.CrossEntropyLoss()(outputs[:n_support], train_label)

    # Only build the prototype term when it is enabled: multiplying a NaN/inf
    # term by 0 still yields NaN and would corrupt every gradient.
    if alpha != 0:
        distance = torch.linalg.norm(inputs[n_support:].unsqueeze(1) - prototype.unsqueeze(0), dim=2)
        proto_loss = (F.softmax(outputs[n_support:], dim=1) * distance).sum()
        loss = loss + alpha * proto_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
159 |
160 |
def simple_shot(embedded_feature, base_mean):
    """SimpleShot baseline: nearest class mean on centred, L2-normalised features.

    Returns:
        Query-set accuracy in percent.
    """
    support, query, support_label, query_label = sample_task(embedded_feature)

    def _center_and_normalize(feats):
        # Subtract the base-class mean, then scale each row to unit L2 norm.
        shifted = feats - base_mean
        return shifted / np.linalg.norm(shifted, axis=1, keepdims=True)

    support = _center_and_normalize(support)
    query = _center_and_normalize(query)

    class_means = support.reshape((args.way, args.shot, -1)).mean(axis=1)
    dists = np.linalg.norm(class_means - query[:, None], axis=-1)

    nearest = dists.argmin(axis=1)
    predicted = np.unique(support_label)[nearest]
    return (predicted == query_label).mean() * 100
177 |
178 |
def laplacian_shot(embedded_feature, base_mean, knn=3, lamda=args.lamda, max_iter=20):
    """Transductive LaplacianShot-style inference on one sampled few-shot task.

    Steps:
    1. Center features on ``base_mean`` and L2-normalise them.
    2. Shift the queries so their mean equals the support mean.
    3. Rectify the class prototypes with cosine-similarity pseudo-labels.
    4. Refine the query soft assignments ``y`` by iterating
       ``y = softmax(-a + lamda * w @ y)``. Here ``a`` holds
       query-to-prototype distances and ``w`` is a binary kNN affinity
       between queries.

    Args:
        embedded_feature: per-class feature lists, passed to ``sample_task``.
        base_mean: mean base-class feature used for centering.
        knn: each query is linked to its ``knn - 1`` nearest other queries in ``w``.
        lamda: weight of the Laplacian term. The default is read from
            ``args.lamda`` once, when this function is defined.
        max_iter: maximum number of update iterations.

    Returns:
        Query-set accuracy in percent.
    """
    train_data, test_data, train_label, test_label = sample_task(embedded_feature)

    # Centering and Normalization
    train_data = train_data - base_mean
    train_data = train_data / np.linalg.norm(train_data, axis=1, keepdims=True)
    test_data = test_data - base_mean
    test_data = test_data / np.linalg.norm(test_data, axis=1, keepdims=True)

    # Cross-Domain Shift: move the query mean onto the support mean
    eta = train_data.mean(axis=0, keepdims=True) - test_data.mean(axis=0, keepdims=True)
    test_data = test_data + eta

    # Prototype Rectification
    # From here on train_data/test_data are torch tensors. The np.linalg.norm
    # calls below convert them back to numpy arrays implicitly.
    train_data, test_data = torch.tensor(train_data), torch.tensor(test_data)
    whole_data = torch.cat([train_data, test_data], dim=0)
    prototype = train_data.reshape(args.way, args.shot, -1).mean(dim=1)
    cos_sim = F.cosine_similarity(whole_data[:, None, :], prototype[None, :, :], dim=2) * 10  # 10 is a parameter
    pseudo_predict = torch.argmax(cos_sim, dim=1)
    cos_weight = F.softmax(cos_sim, dim=1)
    # NOTE(review): if no sample is pseudo-labelled as class i, the mean below
    # is over an empty selection and that prototype row becomes NaN.
    rectified_prototype = torch.cat(
        [(cos_weight[pseudo_predict == i, i].unsqueeze(1) * whole_data[pseudo_predict == i]).mean(0, keepdim=True)
         for i in range(args.way)], dim=0)

    # calculate weight: binary kNN affinity between queries. Column 0 of the
    # argsort (normally the query itself, at distance 0) is skipped, so each
    # row gets knn - 1 ones; w is not symmetrised.
    n = test_data.shape[0]
    w = np.zeros((n, n))
    distance = np.linalg.norm(test_data - test_data[:, None], axis=-1)
    knn_ind = np.argsort(distance, axis=1)[:, 1:knn]
    np.put_along_axis(w, knn_ind, 1.0, axis=1)

    # (8a) -- equation numbers presumably refer to the LaplacianShot paper
    # prototype = train_data.reshape((args.way, args.shot, -1)).mean(axis=1)
    # a[q, c]: Euclidean distance from query q to rectified prototype c.
    a = np.linalg.norm(rectified_prototype - test_data[:, None], axis=-1)

    # Initial assignments: softmax of negative distances.
    y = np.exp(-a) / np.sum(np.exp(-a), axis=1, keepdims=True)
    energy = np.sum(y * (np.log(y) + a - lamda * np.dot(w, y)))

    for i in range(max_iter):
        # (12) update
        out = - a + lamda * np.dot(w, y)
        y = np.exp(out) / np.sum(np.exp(out), axis=1, keepdims=True)

        # (7) check stopping criterion: relative change of the energy
        energy_new = np.sum(y * (np.log(y) + a - lamda * np.dot(w, y)))
        if abs((energy_new - energy) / energy) < 1e-6:
            break
        energy = energy_new.copy()

    idx = np.argmax(y, axis=1)
    pred = np.take(np.unique(train_label), idx)
    acc = (pred == test_label).mean() * 100
    return acc
232 |
233 |
if __name__ == "__main__":
    # Script entry point: evaluate only when executed directly.
    main()
236 |
--------------------------------------------------------------------------------
/fewshot/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 | import pickle
4 | import numpy as np
5 | import torch
6 | import torch.utils.data as data
7 | from torchvision import transforms
8 | import backbone.network as network
9 | from tqdm import tqdm
10 | import PIL.Image as Image
11 | import re
12 |
13 |
def get_tqdm(iters, silent):
    """
    Wrap ``iters`` in a tqdm progress bar unless ``silent`` is truthy.

    When silent, the very same iterable object is returned unchanged.
    """
    return iters if silent else tqdm(iters)
22 |
23 |
def save_pickle(file, data):
    """Serialise ``data`` with pickle into the path ``file``, overwriting it."""
    with open(file, mode='wb') as handle:
        pickle.dump(data, handle)
27 |
28 |
def load_pickle(file):
    """Return the object pickled at path ``file``.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    with open(file, mode='rb') as handle:
        obj = pickle.load(handle)
    return obj
32 |
33 |
def compute_confidence_interval(data):
    """
    Compute 95% confidence interval.

    Returns ``(mean, half_width)``, where ``half_width = 1.96 * std / sqrt(n)``
    and ``std`` is numpy's default population standard deviation (ddof=0).
    """
    n = len(data)
    std_err = np.std(data) / np.sqrt(n)
    return np.mean(data), 1.96 * std_err
39 |
40 |
def calculate_weight(inputs, n_top, sigma):
    """Build a symmetric, normalised Gaussian affinity matrix over the rows of ``inputs``.

    For each row i:
    - Only entries whose distance is strictly smaller than the ``n_top``-th
      smallest distance in that row are kept; the row's own zero distance is
      counted. All other entries get weight 0.
    - Kept entries get weight ``exp(-(d_ij / s_i) ** 2)``, where ``s_i`` is the
      ``sigma``-th smallest distance of row i.

    The matrix is then normalised as ``D^-1/2 W D^-1/2`` (``D`` = row sums) and
    symmetrised by averaging with its transpose.

    Args:
        inputs: 2-D tensor of shape (n_samples, n_features).
        n_top: 1-based rank (as in ``torch.kthvalue``) of the truncation distance.
        sigma: 1-based rank of the distance used as per-row kernel bandwidth.

    Returns:
        Detached (n_samples, n_samples) tensor on the dtype/device of ``inputs``.
    """
    distance = torch.norm(inputs.unsqueeze(0) - inputs.unsqueeze(1), dim=-1)
    dist_n_top = torch.kthvalue(distance, n_top, dim=1, keepdim=True)[0]
    dist_sigma = torch.kthvalue(distance, sigma, dim=1, keepdim=True)[0]

    # Fill truncated entries with an inf of matching dtype/device so that
    # exp(-inf) == 0 also works for float64 and CUDA inputs.
    far = torch.full_like(distance, float("inf"))
    distance_truncated = torch.where(distance < dist_n_top, distance, far)
    weight = torch.exp(-(distance_truncated / dist_sigma).pow(2))

    # Symmetrically normalize the weight matrix: D^-1/2 W D^-1/2 computed by
    # broadcasting rather than two dense matmuls with a diagonal matrix.
    d_inv_sqrt = weight.sum(dim=1).pow(-0.5)
    weight = d_inv_sqrt.unsqueeze(1) * weight * d_inv_sqrt.unsqueeze(0)
    weight = (weight + weight.t()) / 2
    return weight.detach()
55 |
56 |
class DatasetFolder(data.Dataset):
    """Image dataset described by a CSV split file.

    The file ``<split_dir>/<split_type>.csv`` starts with a header line,
    followed by ``filename,label`` rows; blank lines are ignored. Filenames are
    relative to ``root``. The original labels are mapped to contiguous integers
    ``0..C-1`` in sorted label order.
    """

    def __init__(self, root, split_dir, split_type, transform):
        # Explicit exceptions instead of ``assert``: asserts vanish under ``python -O``.
        if split_type not in ['train', 'val', 'test']:
            raise ValueError("split_type must be 'train', 'val' or 'test', got {!r}".format(split_type))
        split_file = os.path.join(split_dir, split_type + '.csv')
        if not os.path.isfile(split_file):
            raise FileNotFoundError(split_file)

        with open(split_file, 'r') as f:
            # Skip the header line and any blank lines.
            split = [x.strip().split(',') for x in f.readlines()[1:] if x.strip() != '']

        # Named ``filenames`` (not ``data``) to avoid shadowing the torch ``data`` module.
        filenames, ori_labels = [x[0] for x in split], [x[1] for x in split]
        label_key = sorted(np.unique(np.array(ori_labels)))
        label_map = dict(zip(label_key, range(len(label_key))))

        self.root = root
        self.transform = transform
        self.data = filenames
        self.labels = [label_map[x] for x in ori_labels]
        self.length = len(self.data)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is RGB and passed through ``transform`` if one is set."""
        path_file = os.path.join(self.root, self.data[index])
        if not os.path.isfile(path_file):
            raise FileNotFoundError(path_file)
        # Context manager releases the file handle as soon as the pixels are decoded.
        with Image.open(path_file) as raw:
            img = raw.convert('RGB')
        label = int(self.labels[index])
        if self.transform:
            img = self.transform(img)
        return img, label
90 |
91 |
def get_dataloader(data_path, split_path, split_type, *, batch_size=1000, num_workers=40):
    """Build a non-shuffling DataLoader over one split of a few-shot image dataset.

    Images are processed as follows:
    - Resize with ``transforms.Resize(120)`` for CUB, ``transforms.Resize(96)``
      otherwise.
    - Center-crop to 84.
    - Convert to tensors and normalise with the ImageNet mean/std.

    Args:
        data_path: image root. The dataset name is the second-to-last component
            when splitting on '/' and '_' (e.g. './data/cub/images' -> 'cub',
            the layout built by get_configuration).
        split_path: directory holding the '<split_type>.csv' split files.
        split_type: 'train', 'val' or 'test'.
        batch_size: samples per batch (default 1000, the previous hard-coded value).
        num_workers: worker processes (default 40, the previous hard-coded value).
            Lower it on machines with fewer CPU cores.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches
        in split-file order.
    """
    dataset_name = re.split('[/_]', data_path)[-2]
    # First resize larger than 84, then center crop, achieve better result
    resize = 120 if dataset_name == "cub" else 96
    dataset = DatasetFolder(root=data_path, split_dir=split_path, split_type=split_type,
                            transform=transforms.Compose([transforms.Resize(resize),
                                                          transforms.CenterCrop(84),
                                                          transforms.ToTensor(),
                                                          transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                               std=[0.229, 0.224, 0.225])]))

    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return loader
108 |
109 |
def get_embedded_feature(test_loader, save_path, silent):
    """
    Return embedded features of data from novel classes.

    The result maps each label (converted with ``.item()``) to the list of
    backbone feature vectors (numpy arrays) of its samples. It is computed once
    per dataset+backbone and cached with pickle at
    ``save_path + '_embedded_feature.plk'``; later calls only load that file.
    """
    cache_file = save_path + '_embedded_feature.plk'
    if os.path.isfile(cache_file):
        return load_pickle(cache_file)

    backbone = load_pretrained_backbone(save_path)
    backbone.eval()

    with torch.no_grad():
        features_by_label = collections.defaultdict(list)

        batches = get_tqdm(test_loader, silent)
        if not silent:
            batches.set_description('Computing embedded features on test classes')

        for inputs, labels in batches:
            batch_features, _ = backbone(inputs, return_feature=True)
            for feature, label in zip(batch_features.cpu().data.numpy(), labels):
                features_by_label[label.item()].append(feature)
        save_pickle(cache_file, features_by_label)

    return features_by_label
137 |
138 |
def get_base_mean(train_loader, save_path, silent):
    """
    Return average of data from base classes.

    The mean backbone feature over every sample yielded by ``train_loader`` is
    computed once per dataset+backbone and cached with pickle at
    ``save_path + '_base_mean.plk'``; later calls only load that file.
    """
    cache_file = save_path + '_base_mean.plk'
    if os.path.isfile(cache_file):
        return load_pickle(cache_file)

    backbone = load_pretrained_backbone(save_path)
    backbone.eval()

    with torch.no_grad():
        batches = get_tqdm(train_loader, silent)
        if not silent:
            batches.set_description('Computing average on base classes')

        per_batch = []
        for inputs, _ in batches:
            outputs, _ = backbone(inputs, return_feature=True)
            per_batch.append(outputs.cpu().data.numpy())
        base_mean = np.concatenate(per_batch, axis=0).mean(axis=0)

    save_pickle(cache_file, base_mean)
    return base_mean
165 |
166 |
def get_configuration(dataset, backbone):
    """
    Get configuration according to dataset and backbone.

    Returns:
        ``(data_path, split_path, save_path, num_classes)``. ``num_classes`` is
        the output size of the pretrained backbone for that dataset (the same
        table as in load_pretrained_backbone).

    Raises:
        NotImplementedError: for a dataset other than 'mini', 'tiered' or 'cub'.
    """
    # Paths are built first, matching the original evaluation order.
    data_path = './data/' + dataset + '/images'
    split_path = './data/' + dataset + '/split'
    save_path = './saved_models/' + dataset + '_' + backbone

    num_classes = {'mini': 64, 'tiered': 351, 'cub': 100}.get(dataset)
    if num_classes is None:
        raise NotImplementedError

    return data_path, split_path, save_path, num_classes
186 |
187 |
def load_pretrained_backbone(save_path):
    """Build the backbone named in ``save_path`` and load its weights from ``save_path + '.pt'``.

    ``save_path`` must end in ``<dataset>_<backbone>`` (as produced by
    get_configuration); both parts are recovered by splitting on '/' and '_'.
    The model is wrapped in ``torch.nn.DataParallel`` and moved to CUDA before
    the state dict is loaded.

    Raises:
        NotImplementedError: for a dataset other than 'mini', 'tiered' or 'cub'.
    """
    parts = re.split('[/_]', save_path)
    dataset, backbone = parts[-2], parts[-1]

    classes_per_dataset = {'mini': 64, 'tiered': 351, 'cub': 100}
    if dataset not in classes_per_dataset:
        raise NotImplementedError

    model = network.__dict__[backbone](num_classes=classes_per_dataset[dataset])
    # Wrapped before loading, presumably because the checkpoint was saved from a
    # DataParallel model ('module.'-prefixed keys) -- confirm against train_backbone.
    model = torch.nn.DataParallel(model).cuda()
    # NOTE: torch.load unpickles the checkpoint; only load trusted files.
    model.load_state_dict(torch.load(save_path + '.pt'))

    return model
206 |
--------------------------------------------------------------------------------
/graph/data/citeseer.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shwangtangjun/Diff-ResNet/687f342dbd2ad5477d99fd8161ae04b5d17819dc/graph/data/citeseer.npz
--------------------------------------------------------------------------------
/graph/data/cora.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shwangtangjun/Diff-ResNet/687f342dbd2ad5477d99fd8161ae04b5d17819dc/graph/data/cora.npz
--------------------------------------------------------------------------------
/graph/data/pubmed.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shwangtangjun/Diff-ResNet/687f342dbd2ad5477d99fd8161ae04b5d17819dc/graph/data/pubmed.npz
--------------------------------------------------------------------------------
/graph/data_process/__pycache__/preprocess.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shwangtangjun/Diff-ResNet/687f342dbd2ad5477d99fd8161ae04b5d17819dc/graph/data_process/__pycache__/preprocess.cpython-37.pyc
--------------------------------------------------------------------------------
/graph/data_process/__pycache__/preprocess.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shwangtangjun/Diff-ResNet/687f342dbd2ad5477d99fd8161ae04b5d17819dc/graph/data_process/__pycache__/preprocess.cpython-39.pyc
--------------------------------------------------------------------------------
/graph/data_process/io.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import scipy.sparse as sp
4 | from .preprocess import eliminate_self_loops as eliminate_self_loops_adj, largest_connected_components
5 |
6 |
class SparseGraph:
    """Attributed labeled graph stored in sparse matrix form.

    The adjacency matrix is stored as a float32 CSR matrix. Attributes, labels
    and the name arrays are optional and are validated against the number of nodes.
    """

    def __init__(self, adj_matrix, attr_matrix=None, labels=None,
                 node_names=None, attr_names=None, class_names=None, metadata=None):
        """Create an attributed graph.

        Parameters
        ----------
        adj_matrix : sp.csr_matrix, shape [num_nodes, num_nodes]
            Adjacency matrix in CSR format.
        attr_matrix : sp.csr_matrix or np.ndarray, shape [num_nodes, num_attr], optional
            Attribute matrix in CSR or numpy format.
        labels : np.ndarray, shape [num_nodes], optional
            Array, where each entry represents respective node's label(s).
        node_names : np.ndarray, shape [num_nodes], optional
            Names of nodes (as strings).
        attr_names : np.ndarray, shape [num_attr]
            Names of the attributes (as strings). Requires ``attr_matrix``.
        class_names : np.ndarray, shape [num_classes], optional
            Names of the class labels (as strings).
        metadata : object
            Additional metadata such as text.

        Raises
        ------
        ValueError
            Raised in any of these cases:
            - ``adj_matrix`` is not sparse or not square;
            - ``attr_matrix`` has an unsupported type;
            - any dimensions disagree;
            - ``attr_names`` is given without an ``attr_matrix``.

        """
        # Make sure that the dimensions of matrices / arrays all agree
        if sp.isspmatrix(adj_matrix):
            adj_matrix = adj_matrix.tocsr().astype(np.float32)
        else:
            raise ValueError("Adjacency matrix must be in sparse format (got {0} instead)"
                             .format(type(adj_matrix)))

        if adj_matrix.shape[0] != adj_matrix.shape[1]:
            raise ValueError("Dimensions of the adjacency matrix don't agree")

        if attr_matrix is not None:
            if sp.isspmatrix(attr_matrix):
                attr_matrix = attr_matrix.tocsr().astype(np.float32)
            elif isinstance(attr_matrix, np.ndarray):
                attr_matrix = attr_matrix.astype(np.float32)
            else:
                raise ValueError("Attribute matrix must be a sp.spmatrix or a np.ndarray (got {0} instead)"
                                 .format(type(attr_matrix)))

            if attr_matrix.shape[0] != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency and attribute matrices don't agree")

        if labels is not None:
            if labels.shape[0] != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency matrix and the label vector don't agree")

        if node_names is not None:
            if len(node_names) != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency matrix and the node names don't agree")

        if attr_names is not None:
            # Without this guard the shape check below fails with an
            # AttributeError on None instead of a meaningful error.
            if attr_matrix is None:
                raise ValueError("Attribute names were given but there is no attribute matrix")
            if len(attr_names) != attr_matrix.shape[1]:
                raise ValueError("Dimensions of the attribute matrix and the attribute names don't agree")

        self.adj_matrix = adj_matrix
        self.attr_matrix = attr_matrix
        self.labels = labels
        self.node_names = node_names
        self.attr_names = attr_names
        self.class_names = class_names
        self.metadata = metadata

    def num_nodes(self):
        """Get the number of nodes in the graph."""
        return self.adj_matrix.shape[0]

    def num_edges(self):
        """Get the number of edges in the graph.

        For undirected graphs, (i, j) and (j, i) are counted as single edge
        (a self-loop contributes a single stored entry, so it is counted as half).
        """
        if self.is_directed():
            return int(self.adj_matrix.nnz)
        else:
            return int(self.adj_matrix.nnz // 2)

    def get_neighbors(self, idx):
        """Get the indices of neighbors of a given node.

        Parameters
        ----------
        idx : int
            Index of the node whose neighbors are of interest.

        """
        return self.adj_matrix[idx].indices

    def is_directed(self):
        """Check if the graph is directed (adjacency matrix is not symmetric)."""
        return (self.adj_matrix != self.adj_matrix.T).sum() != 0

    def to_undirected(self):
        """Convert to an undirected graph (make adjacency matrix symmetric).

        Modifies this instance and returns it.

        Raises
        ------
        ValueError
            If the graph is weighted.
        """
        if self.is_weighted():
            raise ValueError("Convert to unweighted graph first.")
        else:
            self.adj_matrix = self.adj_matrix + self.adj_matrix.T
            # Entries present in both directions sum to 2; reset them to 1.
            self.adj_matrix[self.adj_matrix != 0] = 1
        return self

    def is_weighted(self):
        """Check if the graph is weighted (edge weights other than 1)."""
        return np.any(np.unique(self.adj_matrix[self.adj_matrix != 0].A1) != 1)

    def to_unweighted(self):
        """Convert to an unweighted graph (set all edge weights to 1). Modifies and returns this instance."""
        self.adj_matrix.data = np.ones_like(self.adj_matrix.data)
        return self

    # Quality of life (shortcuts)
    def standardize(self):
        """Select the LCC of the unweighted/undirected/no-self-loop graph.

        The unweighting, symmetrisation and self-loop removal modify this
        instance. The returned graph is whatever ``largest_connected_components``
        produces.

        """
        G = self.to_unweighted().to_undirected()
        G = eliminate_self_loops(G)
        G = largest_connected_components(G, 1)
        return G

    def unpack(self):
        """Return the (A, X, z) triplet."""
        return self.adj_matrix, self.attr_matrix, self.labels
138 |
139 |
def eliminate_self_loops(G):
    """Replace ``G.adj_matrix`` with ``eliminate_self_loops_adj(G.adj_matrix)`` and return ``G``.

    ``eliminate_self_loops_adj`` is ``preprocess.eliminate_self_loops``, which
    presumably returns the adjacency with its diagonal removed. ``G`` itself is
    modified.
    """
    cleaned_adj = eliminate_self_loops_adj(G.adj_matrix)
    G.adj_matrix = cleaned_adj
    return G
143 |
144 |
def load_dataset(data_path):
    """Load a dataset.

    Parameters
    ----------
    data_path : str
        Path of the dataset file; '.npz' is appended when missing.

    Returns
    -------
    sparse_graph : SparseGraph
        The requested dataset in sparse format.

    Raises
    ------
    ValueError
        If the resolved '.npz' file does not exist.

    """
    npz_path = data_path if data_path.endswith('.npz') else data_path + '.npz'
    if not os.path.isfile(npz_path):
        raise ValueError(f"{npz_path} doesn't exist.")
    return load_npz_to_sparse_graph(npz_path)
165 |
166 |
def load_npz_to_sparse_graph(file_name):
    """Load a SparseGraph from a Numpy binary file.

    The file layout is:
    - The adjacency is always stored as a CSR triplet plus shape under the
      keys ``adj_data``, ``adj_indices``, ``adj_indptr`` and ``adj_shape``.
    - Attributes are either a CSR triplet (``attr_*``) or a dense ``attr_matrix``.
    - Labels are either a CSR triplet (``labels_*``) or a dense ``labels`` array.
    - Any component that is absent becomes None.

    Parameters
    ----------
    file_name : str
        Name of the file to load.

    Returns
    -------
    sparse_graph : SparseGraph
        Graph in sparse matrix format.

    """
    # Materialise every array so the file can be closed right away.
    with np.load(file_name) as npz:
        arrays = dict(npz)

    def _csr(prefix):
        # Rebuild a CSR matrix stored as <prefix>_data/_indices/_indptr/_shape.
        return sp.csr_matrix((arrays[prefix + '_data'], arrays[prefix + '_indices'], arrays[prefix + '_indptr']),
                             shape=arrays[prefix + '_shape'])

    adj_matrix = _csr('adj')
    attr_matrix = _csr('attr') if 'attr_data' in arrays else arrays.get('attr_matrix')
    labels = _csr('labels') if 'labels_data' in arrays else arrays.get('labels')

    return SparseGraph(adj_matrix, attr_matrix, labels,
                       arrays.get('node_names'), arrays.get('attr_names'),
                       arrays.get('class_names'), arrays.get('metadata'))
212 |
213 |
def save_sparse_graph_to_npz(filepath, sparse_graph):
    """Save a SparseGraph to a Numpy binary file.

    Components are written as follows:
    - Sparse components are stored as CSR triplets plus shape under
      ``<prefix>_data/_indices/_indptr/_shape``.
    - A dense attribute matrix goes under ``attr_matrix`` and dense labels
      under ``labels``.
    - Name arrays and metadata are written only when not None.
    - '.npz' is appended to ``filepath`` when missing.

    Parameters
    ----------
    filepath : str
        Name of the output file.
    sparse_graph : gust.SparseGraph
        Graph in sparse matrix format.

    """
    def _csr_fields(prefix, matrix):
        # CSR components in the key layout expected by load_npz_to_sparse_graph.
        return {
            prefix + '_data': matrix.data,
            prefix + '_indices': matrix.indices,
            prefix + '_indptr': matrix.indptr,
            prefix + '_shape': matrix.shape,
        }

    data_dict = _csr_fields('adj', sparse_graph.adj_matrix)

    optional_matrices = (
        ('attr', 'attr_matrix', sparse_graph.attr_matrix),
        ('labels', 'labels', sparse_graph.labels),
    )
    for prefix, dense_key, value in optional_matrices:
        if sp.isspmatrix(value):
            data_dict.update(_csr_fields(prefix, value))
        elif isinstance(value, np.ndarray):
            data_dict[dense_key] = value

    for key in ('node_names', 'attr_names', 'class_names', 'metadata'):
        value = getattr(sparse_graph, key)
        if value is not None:
            data_dict[key] = value

    if not filepath.endswith('.npz'):
        filepath += '.npz'

    np.savez(filepath, **data_dict)
263 |
--------------------------------------------------------------------------------
/graph/data_process/make_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .io import load_dataset
3 | from .preprocess import to_binary_bag_of_words, remove_underrepresented_classes, eliminate_self_loops, binarize_labels
4 |
5 |
def get_dataset(name, data_path, standardize, train_examples_per_class=None, val_examples_per_class=None):
    """Load a graph dataset and return ``(graph_adj, node_features, labels)``.

    Args:
        name: dataset name. Only 'cora_full' gets special treatment, and only
            when both per-class counts are given: classes too small for the
            requested split are removed.
        data_path: path passed to ``load_dataset``.
        standardize: if true, use ``SparseGraph.standardize`` (unweighted,
            undirected, no self-loops, largest connected component). Otherwise
            only symmetrise and drop self-loops.
        train_examples_per_class, val_examples_per_class: planned split sizes
            per class.

    Returns:
        The adjacency, the features converted to binary bag-of-words if needed,
        and the labels after ``binarize_labels``.
    """
    graph = load_dataset(data_path)

    # some standardization preprocessing
    if standardize:
        graph = graph.standardize()
    else:
        # NOTE(review): io.py imports preprocess.eliminate_self_loops under the
        # alias eliminate_self_loops_adj and wraps it for SparseGraph objects,
        # suggesting it expects an adjacency matrix; here it receives a
        # SparseGraph. Confirm against preprocess.py.
        graph = eliminate_self_loops(graph.to_undirected())

    prune_rare_classes = (train_examples_per_class is not None
                          and val_examples_per_class is not None
                          and name == 'cora_full')
    if prune_rare_classes:
        # cora_full has some classes that have very few instances. We have to remove these in order for
        # split generation not to fail
        graph = remove_underrepresented_classes(graph, train_examples_per_class, val_examples_per_class)
        graph = graph.standardize()
        # To avoid future bugs: the above two lines should be repeated to a fixpoint, otherwise code below might
        # fail. However, for cora_full the fixpoint is reached after one iteration, so leave it like this for now.

    adjacency, features, labels = graph.unpack()
    labels = binarize_labels(labels)

    # convert to binary bag-of-words feature representation if necessary
    if not is_binary_bag_of_words(features):
        features = to_binary_bag_of_words(features)

    # Invariants for all datasets: symmetric adjacency and binary bag-of-words features.
    assert (adjacency != adjacency.T).nnz == 0
    assert is_binary_bag_of_words(features), "Non-binary node_features entry!"

    return adjacency, features, labels
40 |
41 |
def get_train_val_test_split(random_state,
                             labels,
                             train_examples_per_class=None, val_examples_per_class=None,
                             test_examples_per_class=None,
                             train_size=None, val_size=None, test_size=None):
    """Randomly split node indices into disjoint train/validation/test index sets.

    Each split is drawn either per class (``*_examples_per_class``, via
    ``sample_per_class``) or as ``*_size`` indices regardless of class. When
    neither ``test_examples_per_class`` nor ``test_size`` is given, every index
    not used for train/validation becomes a test index.

    Args:
        random_state: object providing ``choice`` (e.g. ``np.random.RandomState``).
        labels: 2-D label matrix of shape (num_samples, num_classes).

    Returns:
        ``(train_indices, val_indices, test_indices)``.
    """
    num_samples, _num_classes = labels.shape
    all_indices = list(range(num_samples))

    if train_examples_per_class is not None:
        train_indices = sample_per_class(random_state, labels, train_examples_per_class)
    else:
        # select train examples with no respect to class distribution
        train_indices = random_state.choice(all_indices, train_size, replace=False)

    if val_examples_per_class is not None:
        val_indices = sample_per_class(random_state, labels, val_examples_per_class, forbidden_indices=train_indices)
    else:
        val_pool = np.setdiff1d(all_indices, train_indices)
        val_indices = random_state.choice(val_pool, val_size, replace=False)

    forbidden_indices = np.concatenate((train_indices, val_indices))
    if test_examples_per_class is not None:
        test_indices = sample_per_class(random_state, labels, test_examples_per_class,
                                        forbidden_indices=forbidden_indices)
    else:
        # Everything not already used; either subsampled or taken whole.
        leftover = np.setdiff1d(all_indices, forbidden_indices)
        if test_size is not None:
            test_indices = random_state.choice(leftover, test_size, replace=False)
        else:
            test_indices = leftover

    # assert that there are no duplicates in sets
    for split in (train_indices, val_indices, test_indices):
        assert len(set(split)) == len(split)
    # assert sets are mutually exclusive
    assert set(train_indices).isdisjoint(val_indices)
    assert set(train_indices).isdisjoint(test_indices)
    assert set(val_indices).isdisjoint(test_indices)
    if test_size is None and test_examples_per_class is None:
        # all indices must be part of the split
        assert len(np.concatenate((train_indices, val_indices, test_indices))) == num_samples

    # assert all classes have equal cardinality in every per-class split
    per_class_splits = ((train_examples_per_class, train_indices),
                        (val_examples_per_class, val_indices),
                        (test_examples_per_class, test_indices))
    for per_class, split in per_class_splits:
        if per_class is not None:
            class_counts = np.sum(labels[split, :], axis=0)
            assert np.unique(class_counts).size == 1

    return train_indices, val_indices, test_indices
103 |
104 |
def sample_per_class(random_state, labels, num_examples_per_class, forbidden_indices=None):
    """Draw ``num_examples_per_class`` distinct sample indices from every class.

    Args:
        random_state: np.random.RandomState used for sampling.
        labels: (num_samples, num_classes) label matrix; sample i is a candidate for class c
            iff labels[i, c] > 0.
        num_examples_per_class: number of indices drawn (without replacement) per class.
        forbidden_indices: optional array-like of indices that must never be drawn.

    Returns:
        1-D array: the indices drawn for class 0, then class 1, ... concatenated.

    Raises:
        ValueError: from ``RandomState.choice`` if a class has fewer eligible samples than
            ``num_examples_per_class``.
    """
    num_samples, num_classes = labels.shape
    # A set gives O(1) exclusion checks; testing ``in`` against the raw array scanned it
    # once per candidate node.
    if forbidden_indices is None:
        forbidden = set()
    else:
        forbidden = set(np.asarray(forbidden_indices).ravel().tolist())

    sampled = []
    for class_index in range(num_classes):
        column = labels[:, class_index]
        if hasattr(column, "toarray"):  # scipy sparse label matrix
            column = column.toarray()
        # Candidates are kept in ascending index order, exactly like the original per-sample
        # scan, so the draws are unchanged for a given random_state.
        candidates = [i for i in np.flatnonzero(np.asarray(column).ravel() > 0.0).tolist()
                      if i not in forbidden]
        sampled.append(random_state.choice(candidates, num_examples_per_class, replace=False))
    return np.concatenate(sampled)
120 |
121 |
def is_binary_bag_of_words(features):
    """Return True iff every explicitly stored entry of the sparse ``features`` matrix is 1.0.

    A matrix with no stored entries counts as binary. The check is vectorised over the stored
    values instead of looping over them in Python.
    """
    return bool(np.all(features.tocoo().data == 1.0))
125 |
126 |
--------------------------------------------------------------------------------
/graph/data_process/preprocess.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sp
3 | from collections import Counter
4 | from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer, normalize
5 |
6 |
def to_binary_bag_of_words(features):
    """Converts TF/IDF features to binary bag-of-words features.

    Every stored non-zero entry becomes 1.0 in a new CSR matrix; ``features`` itself is left
    untouched. ``tocsr()`` without ``copy=True`` returns the very same object when the input
    is already CSR, so writing into its ``.data`` would overwrite the caller's matrix.
    """
    features_copy = features.tocsr(copy=True)
    # Explicitly stored zeros are absent words; drop them so they do not become ones.
    features_copy.eliminate_zeros()
    features_copy.data[:] = 1.0
    return features_copy
12 |
13 |
def normalize_adj(A):
    """Compute D^-1/2 * A * D^-1/2 after removing self-loops from A.

    D holds the row sums of the loop-free matrix. Zero-degree nodes are treated as having
    degree 1 to avoid dividing by zero. The result comes from dividing a scipy sparse matrix
    by dense arrays, which (in the scipy versions this targets) yields a dense matrix;
    NOTE(review): confirm if callers need sparsity.
    """
    loop_free = eliminate_self_loops(A)
    degrees = np.ravel(loop_free.sum(1))
    degrees[degrees == 0] = 1
    sqrt_degrees = np.sqrt(degrees)
    row_scaled = loop_free / sqrt_degrees[:, None]
    return row_scaled / sqrt_degrees[None, :]
22 |
23 |
def renormalize_adj(A):
    """Renormalize the adjacency matrix (as in the GCN paper): D~^-1/2 A~ D~^-1/2.

    A~ is a copy of A with every diagonal entry set to 1. Existing self-loop weights are
    overwritten, not added to. D~ holds the row sums of A~. Both the degrees and the result
    must come from A~; using A leaves self-loops out and divides by zero for isolated nodes.

    The division of a sparse matrix by dense arrays typically returns a dense matrix;
    NOTE(review): confirm if callers need a sparse result.
    """
    A_tilde = A.tolil(copy=True)  # tolil() alone returns A itself when A is already LIL
    A_tilde.setdiag(1)
    A_tilde = A_tilde.tocsr()
    A_tilde.eliminate_zeros()
    D = np.ravel(A_tilde.sum(1))
    D_sqrt = np.sqrt(D)
    return A_tilde / D_sqrt[:, None] / D_sqrt[None, :]
33 |
34 |
def row_normalize(matrix):
    """Scale every row by its L1 norm so that (non-negative) rows sum to 1.

    Delegates to ``sklearn.preprocessing.normalize``; rows that are entirely zero stay zero.
    """
    return normalize(matrix, axis=1, norm='l1')
38 |
39 |
def add_self_loops(A, value=1.0):
    """Return a CSR copy of A whose diagonal is set (not incremented) to ``value``.

    A is never modified. ``tolil()`` alone returns A itself when A is already a LIL matrix,
    so an explicit copy is requested. When ``value`` is 0 the cleared diagonal entries are
    also dropped from the sparsity structure.
    """
    A = A.tolil(copy=True)
    A.setdiag(value)
    A = A.tocsr()
    if value == 0:
        A.eliminate_zeros()
    return A
48 |
49 |
def eliminate_self_loops(A):
    """Remove self-loops from the adjacency matrix.

    Returns a new CSR matrix whose diagonal is zero. ``eliminate_zeros`` drops the cleared
    entries, and any other explicit zeros, from the sparsity structure. A itself is not
    modified: an explicit copy is requested because ``tolil()`` returns A itself when A is
    already LIL.
    """
    A = A.tolil(copy=True)
    A.setdiag(0)
    A = A.tocsr()
    A.eliminate_zeros()
    return A
57 |
58 |
def largest_connected_components(sparse_graph, n_components=1):
    """Select the largest connected components in the graph.

    Parameters
    ----------
    sparse_graph : SparseGraph
        Input graph.
    n_components : int, default 1
        Number of largest connected components to keep.

    Returns
    -------
    sparse_graph : SparseGraph
        Subgraph of the input graph where only the nodes in largest n_components are kept.
        (Produced by ``create_subgraph``, which updates the given graph object.)

    """
    _, component_of_node = sp.csgraph.connected_components(sparse_graph.adj_matrix)
    sizes = np.bincount(component_of_node)
    # Component ids from largest to smallest; keep the first n_components of them.
    biggest = np.argsort(sizes)[::-1][:n_components]
    keep = np.flatnonzero(np.isin(component_of_node, biggest)).tolist()
    return create_subgraph(sparse_graph, nodes_to_keep=keep)
82 |
83 |
def create_subgraph(sparse_graph, _sentinel=None, nodes_to_remove=None, nodes_to_keep=None):
    """Create a graph with the specified subset of nodes.

    Exactly one of (nodes_to_remove, nodes_to_keep) should be provided, while the other stays None.
    Note that to avoid confusion, it is required to pass node indices as named arguments to this function.

    NOTE: the input object is modified in place (its adj_matrix, attr_matrix, labels and
    node_names are replaced) and then returned; it is not a copy.

    Parameters
    ----------
    sparse_graph : SparseGraph
        Input graph.
    _sentinel : None
        Internal, to prevent passing positional arguments. Do not use.
    nodes_to_remove : array-like of int
        Indices of nodes that have to removed.
    nodes_to_keep : array-like of int
        Indices of nodes that have to be kept.

    Returns
    -------
    sparse_graph : SparseGraph
        Graph with specified nodes removed (the same object that was passed in).

    Raises
    ------
    ValueError
        If arguments are passed positionally, or not exactly one node set is given.

    """
    # Check that arguments are passed correctly
    if _sentinel is not None:
        raise ValueError("Only call `create_subgraph` with named arguments,"
                         " (nodes_to_remove=...) or (nodes_to_keep=...)")
    if nodes_to_remove is None and nodes_to_keep is None:
        raise ValueError("Either nodes_to_remove or nodes_to_keep must be provided.")
    elif nodes_to_remove is not None and nodes_to_keep is not None:
        raise ValueError("Only one of nodes_to_remove or nodes_to_keep must be provided.")
    elif nodes_to_remove is not None:
        # Set membership is O(1); testing against a list/array was O(len(nodes_to_remove)) per node.
        removed = set(nodes_to_remove)
        nodes_to_keep = [i for i in range(sparse_graph.num_nodes()) if i not in removed]
    else:
        nodes_to_keep = sorted(nodes_to_keep)

    sparse_graph.adj_matrix = sparse_graph.adj_matrix[nodes_to_keep][:, nodes_to_keep]
    if sparse_graph.attr_matrix is not None:
        sparse_graph.attr_matrix = sparse_graph.attr_matrix[nodes_to_keep]
    if sparse_graph.labels is not None:
        sparse_graph.labels = sparse_graph.labels[nodes_to_keep]
    if sparse_graph.node_names is not None:
        sparse_graph.node_names = sparse_graph.node_names[nodes_to_keep]
    return sparse_graph
130 |
131 |
def binarize_labels(labels, sparse_output=False, return_classes=False):
    """Convert labels vector to a binary label matrix.

    In the default single-label case, labels look like
    labels = [y1, y2, y3, ...].
    Also supports the multi-label format.
    In this case, labels should look something like
    labels = [[y11, y12], [y21, y22, y23], [y31], ...].
    The format is detected from the first entry: if it is iterable, multi-label is assumed.

    Parameters
    ----------
    labels : array-like, shape [num_samples]
        Array of node labels in categorical single- or multi-label format.
    sparse_output : bool, default False
        Whether return the label_matrix in CSR format.
    return_classes : bool, default False
        Whether return the classes corresponding to the columns of the label matrix.

    Returns
    -------
    label_matrix : np.ndarray or sp.csr_matrix, shape [num_samples, num_classes]
        Binary float32 matrix of class labels.
        label_matrix[i, k] = 1 <=> node i belongs to class k.
        Caveat: in the single-label case with exactly two classes, sklearn's LabelBinarizer
        returns a single column rather than two.
    classes : np.array, shape [num_classes], optional
        Classes that correspond to each column of the label_matrix.

    """
    multilabel = hasattr(labels[0], '__iter__')
    binarizer_cls = MultiLabelBinarizer if multilabel else LabelBinarizer
    binarizer = binarizer_cls(sparse_output=sparse_output)
    label_matrix = binarizer.fit_transform(labels).astype(np.float32)
    if return_classes:
        return label_matrix, binarizer.classes_
    return label_matrix
166 |
167 |
def remove_underrepresented_classes(g, train_examples_per_class, val_examples_per_class):
    """Drop every node whose class has at most
    ``train_examples_per_class + val_examples_per_class`` nodes.

    Only classes with strictly more nodes than that threshold are kept. Smaller classes would
    otherwise break per-class sampling of the train/validation nodes; presumably the strict
    ``>`` also leaves at least one node per class for testing (NOTE(review): confirm intent).
    The result is produced by ``create_subgraph``, which as defined in this module updates
    ``g`` in place.
    """
    threshold = train_examples_per_class + val_examples_per_class
    class_sizes = Counter(g.labels)
    kept_classes = {cls for cls, size in class_sizes.items() if size > threshold}
    keep_indices = [idx for idx, label in enumerate(g.labels) if label in kept_classes]
    return create_subgraph(g, nodes_to_keep=keep_indices)
180 |
--------------------------------------------------------------------------------
/graph/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
class DiffusionLayer(nn.Module):
    """One diffusion update x <- x - step * (diagonal - adj) @ x.

    This is a forward-Euler step of dx/dt = -(diagonal - adj) x with step size ``step``.
    """

    def __init__(self, step):
        super().__init__()
        self.step = step

    def forward(self, x, adj, diagonal):
        laplacian = diagonal - adj
        return x - self.step * torch.matmul(laplacian, x)
14 |
15 |
class DiffusionNet(nn.Module):
    """Residual feature layer, then ``layer_num`` rounds of dropout + diffusion, then a linear classifier.

    Args:
        n_features: input feature dimension (kept through the diffusion layers).
        num_classes: number of output logits.
        step: step size of every diffusion update.
        layer_num: number of diffusion updates.
        dropout: dropout probability applied before each diffusion update.
        diagonal: matrix combined with ``adj`` in each update as ``diagonal - adj``.
    """

    def __init__(self, n_features, num_classes, step, layer_num, dropout, diagonal):
        super().__init__()
        self.linear = nn.Linear(n_features, n_features)
        self.diffusion_layer = DiffusionLayer(step)
        self.classifier = nn.Linear(n_features, num_classes)
        self.dropout = dropout
        self.layer_num = layer_num
        # Plain attribute, not a parameter/buffer: model.cuda()/.to() will not move it, so
        # it must already be on the right device.
        self.diagonal = diagonal

    def forward(self, x, adj):
        hidden = x + F.relu(self.linear(x))
        for _ in range(self.layer_num):
            dropped = F.dropout(hidden, self.dropout, training=self.training)
            hidden = self.diffusion_layer(dropped, adj, self.diagonal)
        return self.classifier(hidden)
36 |
--------------------------------------------------------------------------------
/graph/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 | import torch.optim as optim
6 | import torch.nn as nn
7 | from utils import load_data, accuracy
8 | from model import DiffusionNet
9 | import copy
10 |
# Training settings
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
parser.add_argument('--num_splits', type=int, default=100, help='Number of different splits.')
parser.add_argument('--num_inits', type=int, default=20, help='Number of different initializations.')
parser.add_argument('--device', type=str, default='0', help='Value for CUDA_VISIBLE_DEVICES.')

parser.add_argument('--max_epochs', type=int, default=10000, help='Max number of epochs to train.')
parser.add_argument('--lr', type=float, default=0.01, help='Initial learning rate.')
parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay (L2 loss on parameters).')
parser.add_argument('--patience', type=int, default=50, help='Early Stop Patience.')

parser.add_argument('--dataset', type=str, default="cora")
# No sensible defaults exist for the diffusion hyper-parameters. Require them instead of
# silently passing None into DiffusionNet and failing later with an obscure error.
parser.add_argument('--step_size', type=float, required=True, help='Step size of each diffusion update.')
parser.add_argument('--layer_num', type=int, required=True, help='Number of diffusion updates.')
parser.add_argument('--dropout', type=float, default=0.5, help='Dropout rate (1 - keep probability).')

args = parser.parse_args()

# Device selection only takes effect if CUDA has not been initialised yet in this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # specify which GPU(s) to be used
32 |
33 |
def train(model, optimizer, adj, features, labels, idx_train):
    """Take one full-batch optimisation step using the loss on the training nodes only."""
    model.train()
    optimizer.zero_grad()
    logits = model(features, adj)
    criterion = nn.CrossEntropyLoss()
    criterion(logits[idx_train], labels[idx_train]).backward()
    optimizer.step()
42 |
43 |
def val(model, adj, features, labels, idx_val):
    """Evaluate ``model`` on the validation nodes.

    Runs in eval mode under ``torch.no_grad()``. No backward pass follows, so building an
    autograd graph would only waste memory and time.

    Returns:
        (loss, acc): validation cross-entropy loss and accuracy as 0-d numpy arrays on the CPU.
    """
    model.eval()
    with torch.no_grad():
        output = model(features, adj)
        loss = nn.CrossEntropyLoss()(output[idx_val], labels[idx_val])
        acc = accuracy(output[idx_val], labels[idx_val])
    loss = loss.detach().cpu().numpy()
    acc = acc.cpu().numpy()

    return loss, acc
53 |
54 |
def test(model, adj, features, labels, idx_test):
    """Return the accuracy of ``model`` on the test nodes, as a tensor.

    Runs in eval mode under ``torch.no_grad()``, since no gradients are needed.
    """
    model.eval()
    with torch.no_grad():
        output = model(features, adj)
    acc_test = accuracy(output[idx_test], labels[idx_test])
    return acc_test
60 |
61 |
def run_single_trial_of_single_split(adj, features, labels, idx_train, idx_val, idx_test, diagonal, torch_seeds):
    """Train one freshly initialised DiffusionNet on a fixed split and return its test accuracy.

    Args:
        adj, diagonal: sparse tensors from ``load_data`` (normalised adjacency and the
            diagonal matrix of its row sums).
        features, labels: node features and integer class labels.
        idx_train, idx_val, idx_test: index tensors of the split.
        torch_seeds: single integer seed for torch's CPU and CUDA RNGs (despite the plural).

    Early stopping: a checkpoint is taken whenever the validation loss is <= its best so far
    or the validation accuracy is >= its best so far. Training stops after ``args.patience``
    consecutive epochs without a new checkpoint, or after ``args.max_epochs`` epochs. In
    both cases the last checkpoint is restored before testing.

    Returns:
        Test accuracy as a 0-d numpy array.
    """
    torch.manual_seed(torch_seeds)
    torch.cuda.manual_seed(torch_seeds)

    model = DiffusionNet(n_features=features.shape[1], num_classes=labels.max().item() + 1, step=args.step_size,
                         layer_num=args.layer_num, dropout=args.dropout, diagonal=diagonal.cuda())

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    model = model.cuda()
    features = features.cuda()
    adj = adj.cuda()
    labels = labels.cuda()
    idx_train = idx_train.cuda()
    idx_val = idx_val.cuda()
    idx_test = idx_test.cuda()

    val_loss_min = np.inf
    val_acc_max = 0
    patience_step = 0
    best_state_dict = None

    for _ in range(args.max_epochs):
        train(model, optimizer, adj, features, labels, idx_train)
        val_loss, val_acc = val(model, adj, features, labels, idx_val)

        if val_loss <= val_loss_min or val_acc >= val_acc_max:
            val_loss_min = np.min((val_loss, val_loss_min))
            val_acc_max = np.max((val_acc, val_acc_max))
            patience_step = 0
            best_state_dict = copy.deepcopy(model.state_dict())
        else:
            patience_step += 1

        if patience_step >= args.patience:
            break

    # Restore the best checkpoint whether training stopped early or ran out of epochs;
    # previously the max_epochs exit evaluated the last weights instead.
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

    acc = test(model, adj, features, labels, idx_test)
    acc = acc.cpu().numpy()
    return acc
107 |
108 |
def run_single_split(seed):
    """Load one random split (seeded by ``seed``) and train ``args.num_inits`` models on it.

    The per-initialisation torch seeds are drawn from the same RNG, after the split itself.

    Returns:
        np.ndarray with one test accuracy per initialisation.
    """
    split_rng = np.random.RandomState(seed)
    adj, features, labels, idx_train, idx_val, idx_test, diagonal = load_data(args.dataset, split_rng)
    init_seeds = split_rng.randint(0, 1000000, args.num_inits)
    return np.array([
        run_single_trial_of_single_split(adj, features, labels, idx_train, idx_val, idx_test, diagonal,
                                         init_seed)
        for init_seed in init_seeds
    ])
119 |
120 |
def main():
    """Evaluate over ``args.num_splits`` random splits and print the aggregate results.

    Prints each split's accuracies, then the mean and standard deviation (in percent) over
    all (split, initialisation) pairs, then the dropout, step size and layer count used.
    """
    rng = np.random.RandomState(args.seed)
    split_seeds = rng.randint(0, 1000000, args.num_splits)

    total_acc_list = []
    for split_seed in split_seeds:
        split_accs = run_single_split(split_seed)
        print(split_accs)
        total_acc_list.append(split_accs)

    print(np.mean(total_acc_list) * 100)
    print(np.std(total_acc_list) * 100)
    for hyper_param in (args.dropout, args.step_size, args.layer_num):
        print(hyper_param)


if __name__ == '__main__':
    main()
140 |
--------------------------------------------------------------------------------
/graph/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sp
3 | import torch
4 | from data_process.make_dataset import get_dataset, get_train_val_test_split
5 |
6 |
def load_data(dataset_str, random_state):
    """Load a graph dataset from ``data/<name>.npz`` plus one random split, as torch tensors.

    The split uses 20 training and 30 validation nodes per class; all remaining nodes become
    test nodes. Features are row-normalised. The adjacency gets self-loops added and is then
    symmetrically normalised.

    Returns:
        adj, features, labels, idx_train, idx_val, idx_test, diagonal.
        adj and diagonal are torch sparse tensors; diagonal is the diagonal matrix of the
        normalised adjacency's row sums. features is a dense FloatTensor, labels a
        LongTensor of class indices, and idx_* are LongTensors.
    """
    train_per_class, val_per_class = 20, 30
    adj, features, labels = get_dataset(dataset_str, "data/" + dataset_str + ".npz", standardize=True,
                                        train_examples_per_class=train_per_class,
                                        val_examples_per_class=val_per_class)
    split = get_train_val_test_split(random_state, labels, train_examples_per_class=train_per_class,
                                     val_examples_per_class=val_per_class, test_size=None)

    features = normalize_features(features)
    adj = normalize_adj(adj + sp.eye(adj.shape[0]), normalization="symmetric")
    diagonal = sparse_mx_to_torch_sparse_tensor(sp.diags(adj.sum(1).A1))

    idx_train, idx_val, idx_test = (torch.LongTensor(idx) for idx in split)
    return (sparse_mx_to_torch_sparse_tensor(adj),
            torch.FloatTensor(features.todense()),
            torch.LongTensor(labels.argmax(axis=-1)),
            idx_train, idx_val, idx_test, diagonal)
30 |
31 |
def normalize_features(features):
    """Row-normalize a sparse feature matrix so every non-empty row sums to 1.

    Rows summing to 0 get a scale factor of 0 (instead of inf) and so stay all-zero.
    """
    reciprocal = np.power(np.asarray(features.sum(axis=1)).ravel(), -1)
    reciprocal[np.isinf(reciprocal)] = 0.
    scaling = sp.diags(reciprocal)
    return scaling.dot(features)
40 |
41 |
def normalize_adj(adj, normalization="symmetric"):
    """Symmetrically or row normalize a sparse adjacency matrix.

    Args:
        adj: scipy sparse matrix.
        normalization: "symmetric" for D^-1/2 A D^-1/2, or "row" for D^-1 A, where D is the
            diagonal matrix of row sums of A. Rows summing to zero get a factor of 0, not inf.

    Returns:
        The normalized scipy sparse matrix.

    Raises:
        NotImplementedError: for any other ``normalization``.
    """
    if normalization == "symmetric":
        exponent = -0.5
    elif normalization == "row":
        exponent = -1
    else:
        raise NotImplementedError

    rowsum = np.asarray(adj.sum(1)).flatten()
    if not np.issubdtype(rowsum.dtype, np.floating):
        # numpy rejects negative integer powers of integer arrays.
        rowsum = rowsum.astype(np.float64)
    with np.errstate(divide="ignore"):  # zero rows give inf here; replaced by 0 below
        scale = np.power(rowsum, exponent)
    scale[np.isinf(scale)] = 0.
    scale_mat = sp.diags(scale)

    if normalization == "symmetric":
        # Scale rows and columns directly. The former (A D)^T D evaluated to D A^T D, which
        # equals D A D only for symmetric A.
        return scale_mat.dot(adj).dot(scale_mat)
    return scale_mat.dot(adj)
59 |
60 |
def accuracy(output, labels):
    """Return, as a float64 tensor, the fraction of rows whose arg-max matches ``labels``."""
    _, predicted = output.max(1)
    hits = (predicted.type_as(labels) == labels).double().sum()
    return hits / len(labels)
66 |
67 |
def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float32 torch sparse COO tensor.

    Entries keep the order produced by ``tocoo()``; the result is not coalesced.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # torch.sparse.FloatTensor is a deprecated legacy constructor; sparse_coo_tensor replaces it.
    return torch.sparse_coo_tensor(indices, values, shape, dtype=torch.float32)
75 |
--------------------------------------------------------------------------------
/synthetic/two_circle_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import random
8 |
# Seed every RNG the script draws from (Python, NumPy, torch) so runs are reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)
del _seed_fn
12 |
13 |
def make_circles(n_samples=500, noise=0.05):
    """Make two concentric noisy circles for a toy classification problem.

    The outer circle has radius 2 and the inner circle radius 1. Each holds ``n_samples``
    points evenly spaced in angle over [0, 2*pi]. Gaussian noise with standard deviation
    ``noise`` is added to both coordinates of every point.

    Returns:
        x, y: arrays of shape (2 * n_samples,). The first n_samples points are the outer
        circle and the rest the inner circle.
    """
    angles = np.linspace(0, 2 * np.pi, n_samples)
    inner_circ_x = np.cos(angles)
    inner_circ_y = np.sin(angles)
    outer_circ_x = 2.0 * np.cos(angles)
    outer_circ_y = 2.0 * np.sin(angles)

    x = np.append(outer_circ_x, inner_circ_x)
    y = np.append(outer_circ_y, inner_circ_y)

    # Noise length follows n_samples; a hard-coded 1000 broke every n_samples != 500.
    x += np.random.randn(2 * n_samples) * noise
    y += np.random.randn(2 * n_samples) * noise
    return x, y
31 |
32 |
def calculate_weight(x, y, sigma=0.5, n_top=50):
    """Build a sparse, row-normalised Gaussian affinity matrix between 2-D points.

    weight[i, j] = exp(-||p_i - p_j||^2 / sigma^2). In each row only the ``n_top`` largest
    entries are kept (clamped to the number of points); the row is then divided by its sum.

    Args:
        x, y: 1-D coordinate arrays of equal length N.
        sigma: Gaussian bandwidth.
        n_top: number of neighbours (including the point itself) kept per row.

    Returns:
        (N, N) float64 array whose rows sum to 1.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    num_points = x.shape[0]
    # Vectorised pairwise squared distances, replacing the O(N^2) Python double loop.
    dist_sq = (x[:, None] - x[None, :]) ** 2 + (y[:, None] - y[None, :]) ** 2
    weight = np.exp(-dist_sq / sigma ** 2).astype(np.float64, copy=False)

    # Sparse and Normalize
    n_top = min(n_top, num_points)
    for row in weight:  # each row is a view, so edits land in `weight`
        drop = np.argpartition(row, -n_top)[:-n_top]
        row[drop] = 0.
        row /= row.sum()

    return weight
46 |
47 |
class DiffusionLayer(nn.Module):
    """A single diffusion step x <- x - step * (I - adj) x over a set of N points.

    Args:
        step: step size of the update.
    """

    def __init__(self, step):
        super().__init__()
        self.step = step

    def forward(self, x, adj):
        # x: at least 2-D with N points along dim 0; trailing dims are flattened for the
        # product and restored afterwards. adj: (N, N) weight matrix.
        num_points = x.size(0)
        laplacian = torch.eye(num_points, device=x.device) - adj
        flow = torch.matmul(laplacian, x.flatten(1)).view_as(x)
        return x - self.step * flow
57 |
58 |
class Net(nn.Module):
    """Residual 2-D MLP block, optional diffusion, and a linear 2-class classifier.

    Args:
        use_diffusion: apply ``layer_num`` diffusion steps to the features before
            classification. Defaults to False, which matches the original script, where the
            diffusion loop was commented out.
        layer_num: number of diffusion steps (default 200).
        step: step size of each diffusion step (default 1.0).

    forward returns (logits, features), where features are the 2-D representations fed to
    the classifier.
    """

    def __init__(self, use_diffusion=False, layer_num=200, step=1.0):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)
        self.classifier = nn.Linear(2, 2)
        self.layer_num = layer_num
        self.use_diffusion = use_diffusion
        self.diffusion_layer = DiffusionLayer(step=step)

    def forward(self, x, weight):
        out = self.fc2(F.relu(self.fc1(x))) + x
        if self.use_diffusion:
            for _ in range(self.layer_num):
                out = self.diffusion_layer(out, weight)
        res = self.classifier(out)
        return res, out
76 |
77 |
def train(model, inputs, weight, labels, optimizer=None):
    """Take one full-batch gradient step on ``model`` and return the loss before the step.

    Args:
        optimizer: optimizer to step. If None (the default, as before), a fresh
            ``SGD(lr=1.0, momentum=0.9, weight_decay=5e-4)`` is built for this single step.
            A fresh optimizer starts with an empty momentum buffer, so momentum has no effect
            across calls. Pass one persistent optimizer to get real momentum.

    Returns:
        The cross-entropy loss as a Python float.
    """
    if optimizer is None:
        optimizer = optim.SGD(model.parameters(), lr=1.0, momentum=0.9, weight_decay=5e-4)
    optimizer.zero_grad()

    outputs, _ = model(inputs, weight)
    loss = nn.CrossEntropyLoss()(outputs, labels)
    loss.backward()
    optimizer.step()
    return loss.item()
86 |
87 |
88 | def test(model, inputs, weight, labels):
89 | outputs, features = model(inputs, weight)
90 |
91 | pred = outputs.argmax(1)
92 | acc = torch.eq(pred, labels).sum()
93 | return acc.item()
94 |
95 |
def main():
    """Plot the two-circle data, then train ``Net`` and plot its learned 2-D features.

    For each of 21 iterations: plot and save the current features, print the number of
    correctly classified points (out of 1000), then take one training step. Figures are
    written to figures/two_circle/, which is created if missing.
    """
    import os

    # plt.savefig does not create missing directories.
    os.makedirs("figures/two_circle", exist_ok=True)
    x, y = make_circles()

    color = [i for i in ['red', 'blue'] for _ in range(500)]
    plt.scatter(x, y, c=color, marker='.')
    plt.xticks([])
    plt.yticks([])
    plt.savefig("figures/two_circle/raw.png", bbox_inches='tight')

    weight = calculate_weight(x, y)
    x, y, weight = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(weight)
    inputs = torch.stack([x, y], dim=1).float()
    weight = weight.float()
    labels = torch.cat([torch.zeros(500), torch.ones(500)]).long()

    n_iters = 21
    acc_list = np.zeros(n_iters)
    model = Net()
    # Save each snapshot once, named after whether diffusion is enabled. Saving it under
    # both names mislabels one of them. A Net without a `use_diffusion` attribute is
    # treated as diffusion-free.
    tag = "with_diffusion" if getattr(model, "use_diffusion", False) else "without_diffusion"
    for epoch in range(n_iters):
        _, features = model(inputs, weight)
        x, y = features[:, 0].detach().numpy(), features[:, 1].detach().numpy()
        acc = test(model, inputs, weight, labels)
        print(epoch, acc)
        acc_list[epoch] = acc

        plt.cla()
        plt.scatter(x, y, c=color, marker='.')
        plt.xticks([])
        plt.yticks([])
        plt.savefig("figures/two_circle/" + tag + "_iter=" + str(epoch) + ".png", bbox_inches='tight')

        train(model, inputs, weight, labels)
    print(acc_list)


if __name__ == '__main__':
    main()
134 |
--------------------------------------------------------------------------------
/synthetic/two_moon_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import random
8 |
# Seed every RNG the script draws from (Python, NumPy, torch) so runs are reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(40)
del _seed_fn
12 |
13 |
def make_moons(n_samples=500, noise=0.05):
    """Make two interleaving half circles ("moons") for a toy classification problem.

    Moon 1 is the upper unit half circle centred at (0, 0). Moon 2 is a lower unit half
    circle centred at (1, 0.5). Each moon holds ``n_samples`` points. Gaussian noise with
    standard deviation ``noise`` is added to both coordinates.

    Returns:
        x, y: arrays of shape (2 * n_samples,). The first n_samples points form moon 1.
    """
    angles = np.linspace(0, np.pi, n_samples)
    outer_circ_x = np.cos(angles)
    outer_circ_y = np.sin(angles)
    inner_circ_x = 1 - np.cos(angles)
    inner_circ_y = 0.5 - np.sin(angles)

    x = np.append(outer_circ_x, inner_circ_x)
    y = np.append(outer_circ_y, inner_circ_y)

    # Noise length follows n_samples; a hard-coded 1000 broke every n_samples != 500.
    x += np.random.randn(2 * n_samples) * noise
    y += np.random.randn(2 * n_samples) * noise
    return x, y
31 |
32 |
def calculate_weight(x, y, sigma=0.5, n_top=25):
    """Build a sparse, row-normalised Gaussian affinity matrix between 2-D points.

    weight[i, j] = exp(-||p_i - p_j||^2 / sigma^2). In each row only the ``n_top`` largest
    entries are kept (clamped to the number of points); the row is then divided by its sum.

    Args:
        x, y: 1-D coordinate arrays of equal length N.
        sigma: Gaussian bandwidth.
        n_top: number of neighbours (including the point itself) kept per row.

    Returns:
        (N, N) float64 array whose rows sum to 1.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    num_points = x.shape[0]
    # Vectorised pairwise squared distances, replacing the O(N^2) Python double loop.
    dist_sq = (x[:, None] - x[None, :]) ** 2 + (y[:, None] - y[None, :]) ** 2
    weight = np.exp(-dist_sq / sigma ** 2).astype(np.float64, copy=False)

    # Sparse and Normalize
    n_top = min(n_top, num_points)
    for row in weight:  # each row is a view, so edits land in `weight`
        drop = np.argpartition(row, -n_top)[:-n_top]
        row[drop] = 0.
        row /= row.sum()

    return weight
46 |
47 |
class DiffusionLayer(nn.Module):
    """A single diffusion step x <- x - step * (I - adj) x over a set of N points.

    Args:
        step: step size of the update.
    """

    def __init__(self, step):
        super().__init__()
        self.step = step

    def forward(self, x, adj):
        # x: at least 2-D with N points along dim 0; trailing dims are flattened for the
        # product and restored afterwards. adj: (N, N) weight matrix.
        num_points = x.size(0)
        laplacian = torch.eye(num_points, device=x.device) - adj
        flow = torch.matmul(laplacian, x.flatten(1)).view_as(x)
        return x - self.step * flow
57 |
58 |
class Net(nn.Module):
    """Residual 2-D MLP block, optional diffusion, and a linear 2-class classifier.

    Args:
        use_diffusion: apply ``layer_num`` diffusion steps to the features before
            classification. Defaults to False, which matches the original script, where the
            diffusion loop was commented out.
        layer_num: number of diffusion steps (default 60).
        step: step size of each diffusion step (default 1.0).

    forward returns (logits, features), where features are the 2-D representations fed to
    the classifier.
    """

    def __init__(self, use_diffusion=False, layer_num=60, step=1.0):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)
        self.classifier = nn.Linear(2, 2)
        self.layer_num = layer_num
        self.use_diffusion = use_diffusion
        self.diffusion_layer = DiffusionLayer(step=step)

    def forward(self, x, weight):
        out = self.fc2(F.relu(self.fc1(x))) + x
        if self.use_diffusion:
            for _ in range(self.layer_num):
                out = self.diffusion_layer(out, weight)
        res = self.classifier(out)
        return res, out
76 |
77 |
def train(model, inputs, weight, labels, optimizer=None):
    """Take one full-batch gradient step on ``model`` and return the loss before the step.

    Args:
        optimizer: optimizer to step. If None (the default, as before), a fresh
            ``SGD(lr=1.0, momentum=0.9, weight_decay=5e-4)`` is built for this single step.
            A fresh optimizer starts with an empty momentum buffer, so momentum has no effect
            across calls. Pass one persistent optimizer to get real momentum.

    Returns:
        The cross-entropy loss as a Python float.
    """
    if optimizer is None:
        optimizer = optim.SGD(model.parameters(), lr=1.0, momentum=0.9, weight_decay=5e-4)
    optimizer.zero_grad()

    outputs, _ = model(inputs, weight)
    loss = nn.CrossEntropyLoss()(outputs, labels)
    loss.backward()
    optimizer.step()
    return loss.item()
86 |
87 |
88 | def test(model, inputs, weight, labels):
89 | outputs, features = model(inputs, weight)
90 |
91 | pred = outputs.argmax(1)
92 | acc = torch.eq(pred, labels).sum()
93 | return acc.item()
94 |
95 |
def main():
    """Plot the two-moon data, then train ``Net`` and plot its learned 2-D features.

    For each of 21 iterations: plot the current features (titled with the accuracy), save
    the figure, print the number of correctly classified points (out of 1000), then take one
    training step. Figures are written to figures/two_moon/, which is created if missing.
    """
    import os

    # plt.savefig does not create missing directories.
    os.makedirs("figures/two_moon", exist_ok=True)
    x, y = make_moons()

    color = [i for i in ['red', 'blue'] for _ in range(500)]
    plt.scatter(x, y, c=color, marker='.')
    plt.xticks([])
    plt.yticks([])
    plt.savefig("figures/two_moon/raw.png", bbox_inches='tight')

    weight = calculate_weight(x, y)
    x, y, weight = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(weight)
    inputs = torch.stack([x, y], dim=1).float()
    weight = weight.float()
    labels = torch.cat([torch.zeros(500), torch.ones(500)]).long()

    n_iters = 21
    acc_list = np.zeros(n_iters)
    model = Net()
    # Name figures after whether diffusion is enabled. A Net without a `use_diffusion`
    # attribute is treated as diffusion-free.
    tag = "with_diffusion" if getattr(model, "use_diffusion", False) else "without_diffusion"
    for epoch in range(n_iters):
        _, features = model(inputs, weight)
        x, y = features[:, 0].detach().numpy(), features[:, 1].detach().numpy()
        acc = test(model, inputs, weight, labels)
        print(epoch, acc)
        acc_list[epoch] = acc

        plt.cla()
        plt.scatter(x, y, c=color, marker='.')
        plt.title("accuracy=" + str(round(acc / 1000 * 100, 1)) + "%", fontsize=40)
        plt.xticks([])
        plt.yticks([])
        plt.savefig("figures/two_moon/" + tag + "_iter=" + str(epoch) + ".png", bbox_inches='tight')

        train(model, inputs, weight, labels)
    print(acc_list)


if __name__ == '__main__':
    main()
134 |
--------------------------------------------------------------------------------
/synthetic/two_spiral_example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import random
8 |
# Seed every RNG the script draws from (Python, NumPy, torch) so runs are reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(42)
del _seed_fn
12 |
13 |
def make_spirals(n_samples=500, noise=0.1):
    """Make two interleaved Archimedean spirals for a toy classification problem.

    Spiral 1 follows r = 1 + theta and spiral 2 follows r = -1 - theta (spiral 1 rotated by
    pi), for ``n_samples`` values of theta evenly spaced over [0, 2*pi]. Gaussian noise with
    standard deviation ``noise`` is added to both coordinates.

    Returns:
        x, y: arrays of shape (2 * n_samples,). The first n_samples points form spiral 1.
    """
    a1, b1, a2, b2 = 1.0, 1.0, -1.0, -1.0
    theta = np.linspace(0, 2 * np.pi, n_samples)

    x1 = (a1 + b1 * theta) * np.cos(theta)
    y1 = (a1 + b1 * theta) * np.sin(theta)
    x2 = (a2 + b2 * theta) * np.cos(theta)
    y2 = (a2 + b2 * theta) * np.sin(theta)

    x = np.append(x1, x2)
    y = np.append(y1, y2)

    # Noise length follows n_samples; a hard-coded 1000 broke every n_samples != 500.
    x += np.random.randn(2 * n_samples) * noise
    y += np.random.randn(2 * n_samples) * noise

    return x, y
36 |
37 |
def calculate_weight(x, y, sigma=0.5, n_top=25):
    """Build a sparse, row-normalised Gaussian affinity matrix between 2-D points.

    weight[i, j] = exp(-||p_i - p_j||^2 / sigma^2). In each row only the ``n_top`` largest
    entries are kept (clamped to the number of points); the row is then divided by its sum.

    Args:
        x, y: 1-D coordinate arrays of equal length N.
        sigma: Gaussian bandwidth.
        n_top: number of neighbours (including the point itself) kept per row.

    Returns:
        (N, N) float64 array whose rows sum to 1.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    num_points = x.shape[0]
    # Vectorised pairwise squared distances, replacing the O(N^2) Python double loop.
    dist_sq = (x[:, None] - x[None, :]) ** 2 + (y[:, None] - y[None, :]) ** 2
    weight = np.exp(-dist_sq / sigma ** 2).astype(np.float64, copy=False)

    # Sparse and Normalize
    n_top = min(n_top, num_points)
    for row in weight:  # each row is a view, so edits land in `weight`
        drop = np.argpartition(row, -n_top)[:-n_top]
        row[drop] = 0.
        row /= row.sum()

    return weight
51 |
52 |
class DiffusionLayer(nn.Module):
    """A single diffusion step x <- x - step * (I - adj) x over a set of N points.

    Args:
        step: step size of the update.
    """

    def __init__(self, step):
        super().__init__()
        self.step = step

    def forward(self, x, adj):
        # x: at least 2-D with N points along dim 0; trailing dims are flattened for the
        # product and restored afterwards. adj: (N, N) weight matrix.
        num_points = x.size(0)
        laplacian = torch.eye(num_points, device=x.device) - adj
        flow = torch.matmul(laplacian, x.flatten(1)).view_as(x)
        return x - self.step * flow
62 |
63 |
class Net(nn.Module):
    """Residual 2-D MLP block, optional diffusion, and a linear 2-class classifier.

    Args:
        use_diffusion: apply ``layer_num`` diffusion steps to the features before
            classification. Defaults to False, which matches the original script, where the
            diffusion loop was commented out.
        layer_num: number of diffusion steps (default 900).
        step: step size of each diffusion step (default 1.0).

    forward returns (logits, features), where features are the 2-D representations fed to
    the classifier.
    """

    def __init__(self, use_diffusion=False, layer_num=900, step=1.0):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 2)
        self.fc2 = nn.Linear(2, 2)
        self.classifier = nn.Linear(2, 2)
        self.layer_num = layer_num
        self.use_diffusion = use_diffusion
        self.diffusion_layer = DiffusionLayer(step=step)

    def forward(self, x, weight):
        out = self.fc2(F.relu(self.fc1(x))) + x
        if self.use_diffusion:
            for _ in range(self.layer_num):
                out = self.diffusion_layer(out, weight)
        res = self.classifier(out)
        return res, out
81 |
82 |
def train(model, inputs, weight, labels, optimizer=None):
    """Take one full-batch gradient step on ``model`` and return the loss before the step.

    Args:
        optimizer: optimizer to step. If None (the default, as before), a fresh
            ``SGD(lr=0.8, momentum=0.9, weight_decay=5e-4)`` is built for this single step.
            A fresh optimizer starts with an empty momentum buffer, so momentum has no effect
            across calls. Pass one persistent optimizer to get real momentum.

    Returns:
        The cross-entropy loss as a Python float.
    """
    if optimizer is None:
        optimizer = optim.SGD(model.parameters(), lr=0.8, momentum=0.9, weight_decay=5e-4)
    optimizer.zero_grad()

    outputs, _ = model(inputs, weight)
    loss = nn.CrossEntropyLoss()(outputs, labels)
    loss.backward()
    optimizer.step()
    return loss.item()
91 |
92 |
93 | def test(model, inputs, weight, labels):
94 | outputs, features = model(inputs, weight)
95 |
96 | pred = outputs.argmax(1)
97 | acc = torch.eq(pred, labels).sum()
98 | return acc.item()
99 |
100 |
def main():
    """Plot the two-spiral data, then train ``Net`` and snapshot its features.

    Steps:
      1. Save the raw data to ``figures/two_spiral/raw.png``.
      2. For each of 21 epochs, save a scatter of the learned 2-D features,
         titled with the current accuracy, then take one training step.
      3. Print the per-epoch correct-prediction counts.
    """
    import os

    n_per_class = 500
    n_epochs = 21

    # plt.savefig does not create missing directories.
    os.makedirs("figures/two_spiral", exist_ok=True)

    x, y = make_spirals()

    # First n_per_class points are class 0 (red), the rest class 1 (blue).
    color = [c for c in ['red', 'blue'] for _ in range(n_per_class)]
    plt.scatter(x, y, c=color, marker='.')
    plt.xticks([])
    plt.yticks([])
    plt.savefig("figures/two_spiral/raw.png", bbox_inches='tight')

    weight = calculate_weight(x, y)
    x, y, weight = torch.from_numpy(x), torch.from_numpy(y), torch.from_numpy(weight)
    inputs = torch.stack([x, y], dim=1).float()
    weight = weight.float()
    labels = torch.cat([torch.zeros(n_per_class), torch.ones(n_per_class)]).long()
    n_total = len(labels)

    acc_list = np.zeros(n_epochs)
    model = Net()
    for epoch in range(n_epochs):
        # A single evaluation pass provides both the plotted features and the
        # accuracy; no autograd graph is needed for it.
        with torch.no_grad():
            outputs, features = model(inputs, weight)
        fx, fy = features[:, 0].numpy(), features[:, 1].numpy()
        acc = outputs.argmax(1).eq(labels).sum().item()
        print(epoch, acc)
        acc_list[epoch] = acc

        plt.cla()
        plt.scatter(fx, fy, c=color, marker='.')
        plt.title("accuracy=" + str(round(acc / n_total * 100, 1)) + "%", fontsize=40)
        plt.xticks([])
        plt.yticks([])
        plt.savefig("figures/two_spiral/without_diffusion_iter=" + str(epoch) + ".png", bbox_inches='tight')

        train(model, inputs, weight, labels)
    print(acc_list)
135 |
136 |
# Run the training demo only when this file is executed as a script.
if __name__ == '__main__':
    main()
139 |
--------------------------------------------------------------------------------
/synthetic/xor_example.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
# Seed NumPy's global RNG once, at import time, so that the points drawn by
# generate_points (via np.random.uniform) are the same on every run.
np.random.seed(seed=42)
5 |
6 |
def generate_points(samples_each=100, diameter=1.5):
    """
    Uniformly sample ``samples_each`` points from each of 4 discs in R^2 (XOR layout).

    The discs are centered at (0,0), (2,2), (2,0) and (0,2), in that order along
    the returned arrays, and all have the given ``diameter``. Points from the
    discs at (0,0) and (2,2), i.e. the first ``2 * samples_each`` entries, belong
    to class 1; the remaining points belong to class 2.

    :param samples_each: number of points sampled from each disc.
    :param diameter: diameter shared by all four discs.
    :return: Two numpy arrays ``x, y``, each of shape ``(4 * samples_each,)``.
    """
    radius = diameter / 2
    samples_total = 4 * samples_each
    # Why np.sqrt()? https://stats.stackexchange.com/questions/120527/simulate-a-uniform-distribution-on-a-disc
    r = np.sqrt(np.random.uniform(0, radius ** 2, samples_total))
    theta = np.pi * np.random.uniform(0, 2, samples_total)
    x = r * np.cos(theta)
    y = r * np.sin(theta)

    # Move the 2nd, 3rd and 4th blocks of points onto their disc centers.
    x[samples_each:3 * samples_each] += 2.                 # discs at (2,2) and (2,0)
    y[samples_each:2 * samples_each] += 2.                 # disc at (2,2)
    y[3 * samples_each:4 * samples_each] += 2.             # disc at (0,2)

    return x, y
31 |
32 |
def calculate_weight(x, y, sigma=0.5, n_top=20, samples_total=400):
    """Build a sparsified, row-normalized Gaussian affinity matrix.

    For the first ``samples_total`` points ``p_i = (x[i], y[i])``:
      1. Compute ``w_ij = exp(-||p_i - p_j||^2 / sigma^2)``.
      2. Keep only the ``n_top`` largest entries of each row, zeroing the rest.
      3. Divide each row by its sum, so every row sums to 1.

    :param x: x-coordinates (array-like, length >= samples_total).
    :param y: y-coordinates (array-like, length >= samples_total).
    :param sigma: Gaussian kernel bandwidth.
    :param n_top: number of entries kept per row. The diagonal (w_ii = 1) is a
        row maximum, so each point's own weight is normally among them. Values
        larger than ``samples_total`` keep every entry.
    :param samples_total: number of points used.
    :return: ``(samples_total, samples_total)`` float array.
    """
    px = np.asarray(x, dtype=float)[:samples_total]
    py = np.asarray(y, dtype=float)[:samples_total]

    # Pairwise squared distances via broadcasting (instead of a Python double loop).
    sq_dist = (px[:, None] - px[None, :]) ** 2 + (py[:, None] - py[None, :]) ** 2
    weight = np.exp(-sq_dist / sigma ** 2)

    # Sparse: zero all but the n_top largest entries of every row.
    n_keep = min(n_top, samples_total)
    if n_keep < samples_total:
        drop = np.argpartition(weight, -n_keep, axis=1)[:, :-n_keep]
        np.put_along_axis(weight, drop, 0., axis=1)

    # Normalize each row to sum to 1.
    weight /= weight.sum(axis=1, keepdims=True)
    return weight
46 |
47 |
# Deliberately un-vectorized diffusion step, kept explicit so the update rule is easy to read.
def diffusion(x, y, weight, step_size=1.0, samples_total=400):
    """Apply one explicit diffusion step to the 2-D point cloud ``(x, y)``.

    Each point moves against its weighted offset from the other points:
    ``p_i <- p_i - step_size * sum_j weight[i, j] * (p_i - p_j)``.
    Only the first ``samples_total`` points are updated. Any remaining
    entries of the returned arrays stay zero.

    :return: the new ``(x, y)`` arrays (the inputs are not modified).
    """
    out_x = np.zeros_like(x)
    out_y = np.zeros_like(y)
    for i in range(samples_total):
        xi, yi = x[i], y[i]
        # Accumulate in the same j order as a plain running-sum loop.
        drift_x = sum(weight[i, j] * (xi - x[j]) for j in range(samples_total))
        drift_y = sum(weight[i, j] * (yi - y[j]) for j in range(samples_total))
        out_x[i] = xi - step_size * drift_x
        out_y[i] = yi - step_size * drift_y
    return out_x, out_y
61 |
62 |
def calculate_l(x, y, samples_total=400):
    """Return the smallest distance between two points lying in different discs.

    The first ``samples_total`` points are split into 4 equal consecutive
    groups (one per disc, matching the generate_points layout). The minimum
    Euclidean distance over all pairs taken from *different* groups is returned.

    :return: the minimum inter-group distance, or ``inf`` when no such pair
        exists (e.g. a single point).
    """
    if samples_total <= 0:
        return np.inf
    px = np.asarray(x, dtype=float)[:samples_total]
    py = np.asarray(y, dtype=float)[:samples_total]

    samples_each = samples_total / 4
    group = np.arange(samples_total) // samples_each
    dist = np.sqrt((px[:, None] - px[None, :]) ** 2 + (py[:, None] - py[None, :]) ** 2)
    between = group[:, None] != group[None, :]
    if not between.any():
        return np.inf
    return dist[between].min()
73 |
74 |
def calculate_d(x, y, samples_total=400):
    """Return the largest distance between two points of the same disc.

    The first ``samples_total`` points are split into 4 equal consecutive
    groups (one per disc, matching the generate_points layout). The maximum
    Euclidean distance over all pairs inside a single group is returned.

    :return: the maximum intra-group distance (0. when there are no points).
    """
    if samples_total <= 0:
        return 0.
    px = np.asarray(x, dtype=float)[:samples_total]
    py = np.asarray(y, dtype=float)[:samples_total]

    samples_each = samples_total / 4
    group = np.arange(samples_total) // samples_each
    dist = np.sqrt((px[:, None] - px[None, :]) ** 2 + (py[:, None] - py[None, :]) ** 2)
    # The diagonal is always "within", so this selection is never empty here.
    within = group[:, None] == group[None, :]
    return dist[within].max()
85 |
86 |
def main():
    """Run pure diffusion on the XOR point cloud, saving one scatter plot per step.

    The points are generated once, their affinity matrix is computed once, and
    then 201 frames are written to ``figures/xor/iter=<k>.png``. Each frame
    shows the cloud before the k-th diffusion step.
    """
    import os

    x, y = generate_points()
    weight = calculate_weight(x, y)

    # First half of the points (discs at (0,0) and (2,2)) is one class (red),
    # the second half (discs at (2,0) and (0,2)) the other (blue).
    half = len(x) // 2
    color = [c for c in ['red', 'blue'] for _ in range(half)]

    # plt.savefig does not create missing directories.
    os.makedirs("figures/xor", exist_ok=True)

    epochs = 201
    for it in range(epochs):
        plt.cla()
        plt.xticks([])
        plt.yticks([])
        plt.scatter(x, y, c=color, marker='.', animated=True)
        plt.savefig("figures/xor/iter=" + str(it) + ".png", bbox_inches='tight')

        # calculate_l(x, y) and calculate_d(x, y) can be called here to track the
        # minimum inter-disc distance and the maximum intra-disc diameter.
        x, y = diffusion(x, y, weight)
105 |
106 |
# Run the diffusion demo only when this file is executed as a script.
if __name__ == '__main__':
    main()
109 |
--------------------------------------------------------------------------------