├── NOTICE ├── requirements.txt ├── extract.py ├── CODE_OF_CONDUCT.md ├── script.sh ├── scripts └── script.sh ├── main.py ├── README.md ├── test_random.py ├── test_saved_graphs.py ├── models.py ├── CONTRIBUTING.md ├── models_gcn.py ├── dense_sgc.py ├── utils.py ├── LICENSE └── graph_agent.py /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.20.1 2 | ogb==1.3.0 3 | pandas==1.2.3 4 | scikit_learn==1.1.1 5 | scipy==1.10.0 6 | torch==1.8.1 7 | torch_geometric==2.0.1 8 | tqdm==4.60.0 -------------------------------------------------------------------------------- /extract.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | for file in os.listdir('.'): 5 | new_file = file.replace('lr1', 'lra') 6 | new_file = new_file.replace('lr2', 'lrf') 7 | new_file = new_file.split('_') 8 | new_file = '_'.join(new_file[:3] + new_file[4:]) 9 | os.system(f'mv {file} {new_file}') 10 | 11 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 
5 | -------------------------------------------------------------------------------- /script.sh: -------------------------------------------------------------------------------- 1 | gpu_id=0 # use 0-th gpu to run the experiments 2 | 3 | dataset=DD; lr_adj=2; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=none --pool=mean --seed=0 --ipc=1 --save=0 4 | dataset=NCI1; lr_adj=1; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=none --pool=mean --seed=0 --ipc=1 --save=0 5 | dataset=ogbg-molhiv; lr_adj=0.01; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=none --pool=sum --seed=0 --ipc=1 --save=0 6 | dataset=ogbg-molbace; lr_adj=0.01; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=none --pool=mean --seed=0 --ipc=1 --save=0 7 | dataset=ogbg-molbbbp; lr_adj=2; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=none --pool=mean --seed=0 --ipc=1 --save=0 8 | dataset=MUTAG; lr_adj=0.1; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=none --pool=mean --seed=0 --ipc=1 --save=0 9 | dataset=CIFAR10; lr_adj=0.01; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 
--net_norm=instancenorm --pool=mean --seed=0 --ipc=1 --save=0 10 | -------------------------------------------------------------------------------- /scripts/script.sh: -------------------------------------------------------------------------------- 1 | gpu_id=0 # use 0-th gpu to run the experiments 2 | 3 | dataset=DD; lr_adj=2; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=none --pool=mean --seed=0 --ipc=1 --save=0 4 | dataset=NCI1; lr_adj=1; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=none --pool=mean --seed=0 --ipc=1 --save=0 5 | dataset=ogbg-molhiv; lr_adj=0.01; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=none --pool=sum --seed=0 --ipc=1 --save=0 6 | dataset=ogbg-molbace; lr_adj=0.01; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=none --pool=mean --seed=0 --ipc=1 --save=0 7 | dataset=ogbg-molbbbp; lr_adj=2; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=none --pool=mean --seed=0 --ipc=1 --save=0 8 | dataset=MUTAG; lr_adj=0.1; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} --lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=none --pool=mean --seed=0 --ipc=1 --save=0 9 | dataset=CIFAR10; lr_adj=0.01; lr_feat=0.01; python main.py --dataset ${dataset} --init real --gpu_id=${gpu_id} --nconvs=3 --dis=mse --lr_adj=${lr_adj} 
--lr_feat=${lr_feat} --epochs=1000 --eval_init=1 --net_norm=instancenorm --pool=mean --seed=0 --ipc=1 --save=0 10 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import torch 3 | import torch.nn.functional as F 4 | from graph_agent import GraphAgent 5 | import argparse 6 | import random 7 | import numpy as np 8 | from utils import * 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--gpu_id', type=int, default=0, help='gpu id') 12 | parser.add_argument('--dataset', type=str, default='DD') 13 | parser.add_argument('--epochs', type=int, default=1000) 14 | parser.add_argument('--hidden', type=int, default=128) 15 | parser.add_argument('--init', type=str, default='real') 16 | parser.add_argument('--lr_adj', type=float, default=0.01) 17 | parser.add_argument('--lr_feat', type=float, default=0.01) 18 | parser.add_argument('--dropout', type=float, default=0.) 
19 | parser.add_argument('--seed', type=int, default=0) 20 | parser.add_argument('--nconvs', type=int, default=3) 21 | parser.add_argument('--outer', type=int, default=1) 22 | parser.add_argument('--inner', type=int, default=0) 23 | parser.add_argument('--pooling', type=str, default='mean') 24 | parser.add_argument('--lr_model', type=float, default=0.005) 25 | parser.add_argument('--stru_discrete', type=int, default=1) 26 | parser.add_argument('--ipc', type=int, default=0, help='number of condensed samples per class') 27 | parser.add_argument('--reduction_rate', type=float, default=0.1, help='if ipc=0, this param will be enabled') 28 | parser.add_argument('--save', type=int, default=0, help='whether to save the condensed graphs') 29 | parser.add_argument('--dis_metric', type=str, default='mse', help='distance metric') 30 | parser.add_argument('--eval_init', type=int, default=1, help='whether to evaluate initialized graphs') 31 | parser.add_argument('--bs_cond', type=int, default=256, help='batch size for sampling graphs') 32 | parser.add_argument('--net_norm', type=str, default='none') 33 | parser.add_argument('--beta', type=float, default=0.1, help='coefficient for the regularization term') 34 | args = parser.parse_args() 35 | 36 | if args.dataset == 'ogbg-molhiv': 37 | args.pooling = 'sum' 38 | if args.dataset == 'CIFAR10': 39 | args.net_norm = 'instancenorm' 40 | if args.dataset == 'MUTAG' and args.ipc == 50: 41 | args.ipc = 20 42 | torch.cuda.set_device(args.gpu_id) 43 | 44 | # torch.set_num_threads(1) 45 | 46 | print(args) 47 | device = 'cuda' 48 | 49 | data = Dataset(args) 50 | packed_data = data.packed_data 51 | 52 | random.seed(args.seed) 53 | np.random.seed(args.seed) 54 | torch.manual_seed(args.seed) 55 | torch.cuda.manual_seed(args.seed) 56 | 57 | agent = GraphAgent(data=packed_data, args=args, device=device, nnodes_syn=get_mean_nodes(args)) 58 | agent.train() 59 | -------------------------------------------------------------------------------- 
/README.md: -------------------------------------------------------------------------------- 1 | # DosCond 2 | 3 | [KDD 2022] The implementation for ["Condensing Graphs via One-Step Gradient Matching"](https://arxiv.org/abs/2206.07746) on graph classification is shown below. For node classification, please refer to [link](https://github.com/ChandlerBang/GCond/tree/main/KDD22_DosCond). 4 | 5 | 6 | Abstract 7 | -- 8 | As training deep learning models on large dataset takes a lot of time and resources, it is desired to construct a small synthetic dataset with which we can train deep learning models sufficiently. There are recent works that have explored solutions on condensing image datasets through complex bi-level optimization. For instance, dataset condensation (DC) matches network gradients w.r.t. large-real data and small-synthetic data, where the network weights are optimized for multiple steps at each outer iteration. However, existing approaches have their inherent limitations: (1) they are not directly applicable to graphs where the data is discrete; and (2) the condensation process is computationally expensive due to the involved nested optimization. To bridge the gap, we investigate efficient dataset condensation tailored for graph datasets where we model the discrete graph structure as a probabilistic model. We further propose a one-step gradient matching scheme, which performs gradient matching for only one single step without training the network weights. 9 | 10 | 11 | ## Requirements 12 | All experiments are performed under `python=3.8.8` 13 | 14 | Please see [requirements.txt](https://github.com/amazon-research/doscond/blob/main/requirements.txt). 
15 | ``` 16 | numpy==1.20.1 17 | ogb==1.3.0 18 | pandas==1.2.3 19 | scikit_learn==1.1.1 20 | scipy==1.10.0 21 | torch==1.8.1 22 | torch_geometric==2.0.1 23 | tqdm==4.60.0 24 | ``` 25 | 26 | 27 | ## Run the code 28 | Use the following code to run the experiment 29 | ``` 30 | python main.py --dataset DD --init real --gpu_id=0 --nconvs=3 --dis=mse --lr_adj=2 --lr_feat=0.01 --epochs=1000 --eval_init=1 --net_norm=none --pool=mean --seed=0 --ipc=1 31 | ``` 32 | 33 | The hyper-parameter settings are listed in [`script.sh`](https://github.com/amazon-research/doscond/blob/main/script.sh). Run the following command to get the results. 34 | ``` 35 | bash script.sh 36 | ``` 37 | 38 | By specifying `save=1`, we can save the condensed graphs, e.g., 39 | ``` 40 | python main.py --dataset DD --init real --gpu_id=0 --nconvs=3 --dis=mse --lr_adj=2 --lr_feat=0.01 --epochs=1000 --eval_init=1 --net_norm=none --pool=mean --seed=0 --ipc=1 --save=1 41 | ``` 42 | The condensed graphs will be saved under the `saved` directory. Run multiple seeds to get multiple condensed datasets. Then use the following command to test each condensed dataset multiple times. 43 | ``` 44 | python test_saved_graphs.py --filename DD_ipc1_s0_lra2.0_lrf0.01.pt --dataset DD 45 | ``` 46 | 47 | ## Security 48 | 49 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 50 | 51 | ## License 52 | 53 | This project is licensed under the Apache-2.0 License.
54 | 55 | 56 | -------------------------------------------------------------------------------- /test_random.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from math import ceil 3 | import torch 4 | from torch_geometric.datasets import TUDataset 5 | import torch_geometric.transforms as T 6 | from torch_geometric.loader import DenseDataLoader 7 | import torch.nn.functional as F 8 | from graph_agent import GraphAgent 9 | import argparse 10 | import random 11 | from utils import * 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--gpu_id', type=int, default=0, help='gpu id') 17 | parser.add_argument('--dataset', type=str, default='DD') 18 | parser.add_argument('--epochs', type=int, default=1000) 19 | parser.add_argument('--nconvs', type=int, default=3) 20 | parser.add_argument('--hidden', type=int, default=128) 21 | parser.add_argument('--dropout', type=float, default=0) 22 | parser.add_argument('--init', type=str, default='noise') 23 | parser.add_argument('--lr_adj', type=float, default=1) 24 | parser.add_argument('--lr_feat', type=float, default=0.01) 25 | parser.add_argument('--seed', type=int, default=0) 26 | parser.add_argument('--debug', type=int, default=0) 27 | parser.add_argument('--reduction_rate', type=float, default=1) 28 | parser.add_argument('--save', type=int, default=0) 29 | parser.add_argument('--nruns', type=int, default=10) 30 | parser.add_argument('--ipc', type=int, default=0) 31 | parser.add_argument('--mlp', type=int, default=0) 32 | parser.add_argument('--pooling', type=str, default='mean') 33 | parser.add_argument('--net_norm', type=str, default='none') 34 | parser.add_argument('--stru_discrete', type=int, default=0) 35 | args = parser.parse_args() 36 | 37 | torch.cuda.set_device(args.gpu_id) 38 | torch.set_num_threads(1) 39 | 40 | # random seed setting 41 | data_seed = 0 42 | random.seed(data_seed) 43 | 
np.random.seed(data_seed) 44 | torch.manual_seed(data_seed) 45 | torch.cuda.manual_seed(data_seed) 46 | 47 | print(args) 48 | 49 | data = Dataset(args) 50 | packed_data = data.packed_data 51 | 52 | # random seed setting 53 | random.seed(args.seed) 54 | np.random.seed(args.seed) 55 | torch.manual_seed(args.seed) 56 | torch.cuda.manual_seed(args.seed) 57 | 58 | device = 'cuda' 59 | max_nodes = 1 60 | agent = GraphAgent(data=packed_data, args=args, device=device, nnodes_syn=max_nodes) 61 | train_dataset = packed_data[0] 62 | sampled = [] 63 | for c in range(train_dataset.num_classes): 64 | ind = agent.syn_class_indices[c] 65 | idx_shuffle = np.random.permutation(agent.real_indices_class[c])[:ind[1]-ind[0]] 66 | sampled.append(agent.data[4][idx_shuffle]) 67 | agent.adj_syn = np.hstack(sampled) 68 | 69 | 70 | runs = args.nruns 71 | res = [] 72 | for _ in tqdm(range(runs)): 73 | if args.dataset in ['ogbg-molhiv']: 74 | res.append(agent.test_pyg_data(save=args.save, epochs=100, verbose=0)) 75 | else: 76 | res.append(agent.test_pyg_data(save=args.save, epochs=500, verbose=0)) 77 | 78 | res = np.array(res) 79 | print('Mean Train/Val/TestAcc/TrainLoss:', repr(res.mean(0))) 80 | print('Std Train/Val/TestAcc/TrainLoss:', repr(res.std(0))) 81 | 82 | -------------------------------------------------------------------------------- /test_saved_graphs.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import torch 3 | import torch.nn.functional as F 4 | from graph_agent import GraphAgent 5 | import argparse 6 | import random 7 | import numpy as np 8 | from utils import * 9 | import sys 10 | import logging 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--gpu_id', type=int, default=0, help='gpu id') 14 | parser.add_argument('--dataset', type=str, default='PROTEINS') 15 | parser.add_argument('--epochs', type=int, default=10000) 16 | parser.add_argument('--hidden', type=int, default=128) 17 | 
parser.add_argument('--init', type=str, default='noise') 18 | parser.add_argument('--lr_adj', type=float, default=0.01) 19 | parser.add_argument('--lr_feat', type=float, default=0.01) 20 | parser.add_argument('--seed', type=int, default=0) 21 | parser.add_argument('--reduction_rate', type=float, default=0.1) 22 | parser.add_argument('--nconvs', type=int, default=3) 23 | parser.add_argument('--outer', type=int, default=1) 24 | parser.add_argument('--inner', type=int, default=0) 25 | parser.add_argument('--ipc', type=int, default=0) 26 | parser.add_argument('--nruns', type=int, default=10) 27 | parser.add_argument('--dis_metric', type=str, default='ours', help='distance metric') 28 | parser.add_argument('--num_blocks', type=int, default=1) 29 | parser.add_argument('--num_bases', type=int, default=0) 30 | parser.add_argument('--stru_discrete', type=int, default=1) 31 | parser.add_argument('--pooling', type=str, default='mean') 32 | parser.add_argument('--net_norm', type=str, default='none') 33 | parser.add_argument('--dropout', type=float, default=0) 34 | parser.add_argument('--filename', type=str) 35 | 36 | args = parser.parse_args() 37 | 38 | torch.cuda.set_device(args.gpu_id) 39 | 40 | torch.set_num_threads(1) 41 | 42 | args.ipc = int(args.filename.split('_')[1][3:]) 43 | args.seed = int((args.filename.split('_'))[2][1:]) 44 | 45 | LOG_FILENAME = f'logs/{args.dataset}_seeds.log' 46 | for handler in logging.root.handlers[:]: 47 | logging.root.removeHandler(handler) 48 | logging.basicConfig(filename=LOG_FILENAME,level=logging.DEBUG) 49 | 50 | print(args) 51 | device = 'cuda' 52 | 53 | data = Dataset(args) 54 | packed_data = data.packed_data 55 | 56 | random.seed(args.seed) 57 | np.random.seed(args.seed) 58 | torch.manual_seed(args.seed) 59 | torch.cuda.manual_seed(args.seed) 60 | 61 | if args.dataset == 'ogbg-molhiv': 62 | args.pooling = 'sum' 63 | 64 | if args.dataset == 'CIFAR10': 65 | args.nruns = 3 66 | args.net_norm = 'instancenorm' 67 | 68 | agent = 
GraphAgent(data=packed_data, args=args, device=device, nnodes_syn=get_mean_nodes(args)) 69 | assert args.stru_discrete == 1, 'must be discrete' 70 | 71 | if args.stru_discrete: 72 | agent.adj_syn, agent.feat_syn = torch.load(f'saved/{args.filename}', map_location='cuda') 73 | 74 | 75 | res = [] 76 | for _ in range(args.nruns): 77 | if args.dataset in ['ogbg-molhiv']: 78 | res.append(agent.test(epochs=100)) 79 | else: 80 | res.append(agent.test(epochs=500)) 81 | 82 | 83 | res = np.array(res) 84 | print('Mean Train/Val/TestAcc/TrainLoss:', res.mean(0)) 85 | print('Std Train/Val/TestAcc/TrainLoss:', res.std(0)) 86 | 87 | logging.info(str(args)+'\n'+f'Mean Train/Val/TestAcc/TrainLoss: {res.mean(0)}') 88 | 89 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from math import ceil 3 | import torch 4 | from torch_geometric.nn import DenseSAGEConv, DenseGCNConv, GCNConv 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn import BatchNorm1d 8 | from torch_geometric.nn import LayerNorm, InstanceNorm 9 | 10 | 11 | class DenseGCN(torch.nn.Module): 12 | def __init__(self, nfeat, nhid, nclass, nconvs=3, dropout=0, if_mlp=False, net_norm='none', pooling='mean', **kwargs): 13 | super(DenseGCN, self).__init__() 14 | 15 | self.molhiv = False 16 | if kwargs['args'].dataset in ['ogbg-molhiv', 'ogbg-molbbbp', 'ogbg-molbace']: 17 | nclass = 1 18 | self.molhiv = True 19 | 20 | if nconvs == 1: 21 | nhid = nclass 22 | 23 | self.mlp = if_mlp 24 | if self.mlp: 25 | DenseGCNConv = nn.Linear 26 | else: 27 | from torch_geometric.nn import DenseSAGEConv, DenseGCNConv 28 | self.convs = nn.ModuleList([]) 29 | self.convs.append(DenseGCNConv(nfeat, nhid)) 30 | for _ in range(nconvs-1): 31 | self.convs.append(DenseGCNConv(nhid, nhid)) 32 | 33 | self.norms = nn.ModuleList([]) 34 | 35 | for _ in range(nconvs): 36 | if 
nconvs == 1: norm = torch.nn.Identity() 37 | elif net_norm == 'none': 38 | norm = torch.nn.Identity() 39 | elif net_norm == 'batchnorm': 40 | norm = BatchNorm1d(nhid) 41 | elif net_norm == 'layernorm': 42 | norm = nn.LayerNorm([nhid], elementwise_affine=True) 43 | elif net_norm == 'instancenorm': 44 | norm = InstanceNorm(nhid, affine=False) #pyg 45 | elif net_norm == 'groupnorm': 46 | norm = nn.GroupNorm(4, nhid, affine=True) 47 | self.norms.append(norm) 48 | 49 | self.lin3 = torch.nn.Linear(nhid, nclass) if nconvs != 1 else lambda x: x 50 | self.dropout = dropout 51 | self.pooling = pooling 52 | 53 | def forward(self, x, adj, mask=None, if_embed=False): 54 | if self.dropout !=0: 55 | x_mask = torch.distributions.bernoulli.Bernoulli(self.dropout).sample([x.size(0), x.size(1)]).to('cuda').unsqueeze(-1) 56 | x = x_mask * x 57 | 58 | for i in range(len(self.convs)): 59 | if self.mlp: 60 | x = self.convs[i](x) 61 | else: 62 | x = self.convs[i](x, adj, mask) 63 | x = self.perform_norm(i, x) 64 | x = F.relu(x) 65 | 66 | if self.pooling == 'sum': 67 | x = x.sum(1) 68 | if self.pooling == 'mean': 69 | x = x.mean(1) 70 | if if_embed: 71 | return x 72 | if self.molhiv: 73 | x = self.lin3(x) 74 | else: 75 | x = F.log_softmax(self.lin3(x), dim=-1) 76 | 77 | return x 78 | 79 | 80 | def embed(self, x, adj, mask=None): 81 | return self.forward(x, adj, mask, if_embed=True) 82 | 83 | def perform_norm(self, i, x): 84 | batch_size, num_nodes, num_channels = x.size() 85 | x = x.view(-1, num_channels) 86 | x = self.norms[i](x) 87 | x = x.view(batch_size, num_nodes, num_channels) 88 | return x 89 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. 
Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 
38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 | -------------------------------------------------------------------------------- /models_gcn.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from math import ceil 3 | import torch 4 | from torch_geometric.nn import DenseSAGEConv, DenseGCNConv, GCNConv 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn import BatchNorm1d 8 | from torch_geometric.nn import LayerNorm, InstanceNorm 9 | from torch_geometric.nn import global_mean_pool, global_add_pool 10 | from ogb.graphproppred.mol_encoder import AtomEncoder 11 | 12 | 13 | 14 | 15 | class GCN(torch.nn.Module): 16 | def __init__(self, nfeat, nhid, nclass, nconvs=3, dropout=0, if_mlp=False, net_norm='none', pooling='mean', learn_adj=False, **kwargs): 17 | super(GCN, self).__init__() 18 | self.molhiv = False 19 | if kwargs['args'].dataset in ['ogbg-molhiv', 'ogbg-molbbbp', 'ogbg-molbace']: 20 | nclass = 1 21 | self.molhiv = True 22 | 23 | if nconvs ==1: 24 | nhid = nclass 25 | 26 | self.mlp = if_mlp 27 | if self.mlp: 28 | GCNConv = nn.Linear 29 | else: 30 | from torch_geometric.nn import GCNConv 31 | self.convs = nn.ModuleList([]) 32 | self.convs.append(GCNConv(nfeat, nhid, learn_adj=learn_adj)) 33 | for _ in range(nconvs-1): 34 | self.convs.append(GCNConv(nhid, nhid, learn_adj=learn_adj)) 35 | 36 | self.norms = nn.ModuleList([]) 37 | for _ in range(nconvs): 38 | if nconvs == 1: norm = torch.nn.Identity() 39 | elif net_norm == 'none': 40 | norm = torch.nn.Identity() 41 | elif net_norm == 'batchnorm': 42 | norm = BatchNorm1d(nhid) 43 | elif net_norm == 'layernorm': 44 | norm = nn.LayerNorm([nhid], elementwise_affine=True) 45 | elif net_norm == 'instancenorm': 46 | norm = InstanceNorm(nhid, affine=False) #pyg 47 | elif net_norm == 'groupnorm': 48 | norm = nn.GroupNorm(4, nhid, affine=True) 49 | self.norms.append(norm) 50 | 51 | self.lin3 = torch.nn.Linear(nhid, nclass) 52 | self.dropout = dropout 53 | self.pooling = pooling 
54 | 55 | def forward(self, data, if_embed=False): 56 | x, edge_index, batch = data.x, data.edge_index, data.batch 57 | if self.dropout !=0 and self.training: 58 | x_mask = torch.distributions.bernoulli.Bernoulli(self.dropout).sample([x.size(0)]).to('cuda').unsqueeze(-1) 59 | x = x_mask * x 60 | 61 | for i in range(len(self.convs)): 62 | if self.mlp: 63 | x = self.convs[i](x) #, edge_index) 64 | else: 65 | x = self.convs[i](x, edge_index) 66 | x = self.perform_norm(i, x) 67 | x = F.relu(x) 68 | 69 | if self.pooling == 'mean': 70 | x = global_mean_pool(x, batch=data.batch) 71 | if self.pooling == 'sum': 72 | x = global_add_pool(x, batch=data.batch) 73 | if if_embed: 74 | return x 75 | if self.molhiv: 76 | x = self.lin3(x) 77 | else: 78 | x = F.log_softmax(self.lin3(x), dim=-1) 79 | return x 80 | 81 | def forward_edgeweight(self, data, if_embed=False): 82 | x, edge_index, edge_weight, batch = data 83 | 84 | for i in range(len(self.convs)): 85 | if self.mlp: 86 | x = self.convs[i](x) #, edge_index) 87 | else: 88 | x = self.convs[i](x, edge_index, edge_weight) 89 | x = self.perform_norm(i, x) 90 | x = F.relu(x) 91 | x = F.dropout(x, self.dropout, training=self.training) 92 | 93 | if self.pooling == 'mean': 94 | x = global_mean_pool(x, batch=batch) 95 | if self.pooling == 'sum': 96 | x = global_add_pool(x, batch=batch) 97 | if if_embed: 98 | return x 99 | if self.molhiv: 100 | x = self.lin3(x) 101 | else: 102 | x = F.log_softmax(self.lin3(x), dim=-1) 103 | return x 104 | 105 | 106 | 107 | def embed(self, data): 108 | return self.forward(data, if_embed=True) 109 | 110 | 111 | def perform_norm(self, i, x): 112 | batch_size, num_channels = x.size() 113 | x = x.view(-1, num_channels) 114 | x = self.norms[i](x) 115 | x = x.view(batch_size, num_channels) 116 | return x 117 | -------------------------------------------------------------------------------- /dense_sgc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
torch.nn import Parameter 3 | from torch_geometric.nn.inits import zeros 4 | from torch_geometric.nn.dense.linear import Linear 5 | import os.path as osp 6 | from math import ceil 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn import BatchNorm1d 11 | from torch_geometric.nn import LayerNorm, InstanceNorm 12 | 13 | class AKX(torch.nn.Module): 14 | def __init__(self, K=3): 15 | super(AKX, self).__init__() 16 | self.conv = SGConv(1, 1, K=K) 17 | self.conv.lin = torch.nn.Identity() # TODO 18 | 19 | def forward(self, x, adj, pool): 20 | x = self.conv(x, adj) 21 | if pool == 'mean': 22 | x = x.mean(1) 23 | if pool == 'sum': 24 | x = x.sum(1) 25 | return torch.norm(x).item() 26 | 27 | def get_akx(x, adj, K, pool): 28 | conv = DenseSGConv(1, 1, K=K) 29 | conv.lin = torch.nn.Identity() # TODO 30 | x = conv(x, adj) 31 | if pool == 'mean': 32 | x = x.mean(1) 33 | if pool == 'sum': 34 | x = x.sum(1) 35 | return torch.norm(x) 36 | 37 | 38 | 39 | class DenseSGC(torch.nn.Module): 40 | def __init__(self, nfeat, nhid, nclass, ntrans=1, nconvs=3, dropout=0): 41 | super(DenseSGC, self).__init__() 42 | self.conv = DenseSGConv(nfeat, nhid, K=nconvs) 43 | 44 | self.norms = nn.ModuleList([]) 45 | net_norm = 'none' 46 | for _ in range(ntrans+1): 47 | if ntrans == 0: norm = torch.nn.Identity() 48 | elif net_norm == 'none': 49 | norm = torch.nn.Identity() 50 | elif net_norm == 'batchnorm': 51 | norm = BatchNorm1d(nhid) 52 | elif net_norm == 'layernorm': 53 | norm = nn.LayerNorm([nhid,111], elementwise_affine=True) 54 | elif net_norm == 'instancenorm': 55 | norm = InstanceNorm(nhid, affine=False) #pyg 56 | elif net_norm == 'groupnorm': 57 | norm = nn.GroupNorm(4, nhid, affine=True) 58 | self.norms.append(norm) 59 | 60 | self.lins = nn.ModuleList([]) 61 | self.ntrans = ntrans 62 | for _ in range(ntrans): 63 | self.lins.append(torch.nn.Linear(nhid, nhid)) 64 | self.lin_final = torch.nn.Linear(nhid, nclass) if ntrans == 0 else lambda x:x 65 | 
self.dropout = dropout 66 | 67 | self.dropout = dropout 68 | 69 | def forward(self, x, adj, mask=None): 70 | x = F.dropout(x, self.dropout, training=self.training) 71 | x = self.conv(x, adj, mask) 72 | x = self.perform_norm(0, x) 73 | x = F.relu(x) 74 | for ii, lin in enumerate(self.lins): 75 | x = lin(x) 76 | x = self.perform_norm(ii+1, x) 77 | x = F.relu(x) 78 | 79 | x = x.mean(1) 80 | x = F.log_softmax(self.lin_final(x), dim=-1) 81 | return x 82 | 83 | def perform_norm(self, i, x): 84 | batch_size, num_nodes, num_channels = x.size() 85 | x = x.view(-1, num_channels) 86 | x = self.norms[i](x) 87 | # x = self.norm(x) 88 | x = x.view(batch_size, num_nodes, num_channels) 89 | return x 90 | 91 | 92 | class DenseSGConv(torch.nn.Module): 93 | def __init__(self, in_channels, out_channels, K=2, improved=False, bias=True): 94 | super().__init__() 95 | 96 | self.in_channels = in_channels 97 | self.out_channels = out_channels 98 | self.improved = improved 99 | 100 | self.K = K 101 | self.lin = Linear(in_channels, out_channels, bias=bias, 102 | weight_initializer='glorot') 103 | 104 | self.reset_parameters() 105 | 106 | def reset_parameters(self): 107 | self.lin.reset_parameters() 108 | 109 | def forward(self, x, adj, mask=None, add_loop=True): 110 | x = x.unsqueeze(0) if x.dim() == 2 else x 111 | adj = adj.unsqueeze(0) if adj.dim() == 2 else adj 112 | B, N, _ = adj.size() 113 | 114 | if add_loop: 115 | adj = adj.clone() 116 | idx = torch.arange(N, dtype=torch.long, device=adj.device) 117 | adj[:, idx, idx] = 1 if not self.improved else 2 118 | 119 | deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5) 120 | 121 | adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2) 122 | 123 | for _ in range(self.K): 124 | x = torch.matmul(adj, x) 125 | 126 | out = self.lin(x) 127 | 128 | if mask is not None: 129 | out = out * mask.view(B, N, 1).to(x.dtype) 130 | 131 | return out 132 | 133 | def __repr__(self) -> str: 134 | return 
(f'{self.__class__.__name__}({self.in_channels}, ' 135 | f'{self.out_channels})') 136 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torch_geometric.data import InMemoryDataset 3 | from torch_geometric.datasets import TUDataset, GNNBenchmarkDataset 4 | import torch_geometric.transforms as T 5 | from torch_geometric.loader import DataLoader, DenseDataLoader 6 | import os.path as osp 7 | from torch_geometric.datasets import MNISTSuperpixels 8 | import numpy as np 9 | import torch 10 | from ogb.graphproppred import PygGraphPropPredDataset 11 | from torch_geometric.utils.convert import to_scipy_sparse_matrix 12 | import random 13 | 14 | class Complete(object): 15 | def __call__(self, data): 16 | if data.x is None: 17 | if hasattr(data, 'adj'): 18 | data.x = data.adj.sum(1).view(-1, 1) 19 | else: 20 | adj = to_scipy_sparse_matrix(data.edge_index).sum(1) 21 | data.x = torch.FloatTensor(adj.sum(1)).view(-1, 1) 22 | return data 23 | 24 | class RemoveEdgeAttr(object): 25 | def __call__(self, data): 26 | if data.edge_attr is not None: 27 | data.edge_attr = None 28 | if data.x is None: 29 | if hasattr(data, 'adj'): 30 | data.x = data.adj.sum(1).view(-1, 1) 31 | else: 32 | adj = to_scipy_sparse_matrix(data.edge_index).sum(1) 33 | data.x = torch.FloatTensor(adj.sum(1)).view(-1, 1) 34 | 35 | data.y = data.y.squeeze(0) 36 | data.x = data.x.float() 37 | return data 38 | 39 | class ConcatPos(object): 40 | def __call__(self, data): 41 | if data.edge_attr is not None: 42 | data.edge_attr = None 43 | data.x = torch.cat([data.x, data.pos], dim=1) 44 | data.pos = None 45 | return data 46 | 47 | class Dataset: 48 | 49 | def __init__(self, args): 50 | # random seed setting 51 | random.seed(0) 52 | np.random.seed(0) 53 | torch.manual_seed(0) 54 | torch.cuda.manual_seed(0) 55 | 56 | name = args.dataset 57 | path = 
osp.join(osp.dirname(osp.realpath(__file__)), 'data', f'{name}') 58 | 59 | if name in ['DD', 'MUTAG', 'NCI1']: 60 | dataset = TUDataset(path, name=name, transform=T.Compose([Complete()]), use_node_attr=True) 61 | dataset = dataset.shuffle() 62 | n = (len(dataset) + 9) // 10 63 | test_dataset = dataset[:n] 64 | val_dataset = dataset[n:2 * n] 65 | train_dataset = dataset[2 * n:] 66 | nnodes = [x.num_nodes for x in dataset] 67 | print('mean #nodes:', np.mean(nnodes), 'max #nodes:', np.max(nnodes)) 68 | 69 | if name in ['CIFAR10']: 70 | transform = T.Compose([ConcatPos()]) 71 | train_dataset= GNNBenchmarkDataset(path, name=name, split='train', transform=transform) 72 | val_dataset= GNNBenchmarkDataset(path, name=name, split='val', transform=transform) 73 | test_dataset= GNNBenchmarkDataset(path, name=name, split='test', transform=transform) 74 | train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True) 75 | val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False) 76 | test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False) 77 | nnodes = [x.num_nodes for x in train_dataset] 78 | print('mean #nodes:', np.mean(nnodes), 'max #nodes:', np.max(nnodes)) 79 | 80 | 81 | if name in ['ogbg-molhiv', 'ogbg-molbbbp', 'ogbg-molbace']: 82 | dataset = PygGraphPropPredDataset(name=name, transform=T.Compose([RemoveEdgeAttr()])) 83 | split_idx = dataset.get_idx_split() 84 | train_dataset = dataset[split_idx["train"]] 85 | nnodes = [x.num_nodes for x in train_dataset] 86 | print('mean #nodes:', np.mean(nnodes), 'max #nodes:', np.max(nnodes)) 87 | ### automatic evaluator. 
takes dataset name as input 88 | train_dataset = dataset[split_idx["train"]] 89 | val_dataset = dataset[split_idx["valid"]] 90 | test_dataset = dataset[split_idx["test"]] 91 | 92 | 93 | y_final = [g.y.item() for g in test_dataset] 94 | from collections import Counter; counter=Counter(y_final); print(counter) 95 | print("#Majority guessing:", sorted(counter.items())[-1][1]/len(y_final)) 96 | 97 | test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False) 98 | val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False) 99 | train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True) 100 | 101 | train_datalist = np.ndarray((len(train_dataset),), dtype=np.object) 102 | for ii in range(len(train_dataset)): 103 | train_datalist[ii] = train_dataset[ii] 104 | self.packed_data = [train_dataset, train_loader, val_loader, test_loader, train_datalist] 105 | 106 | 107 | class TensorDataset(Dataset): 108 | def __init__(self, feat, adj, labels): # images: n x c x h x w tensor 109 | self.x = feat.detach() 110 | self.adj = adj.detach() 111 | self.y = labels.detach() 112 | 113 | def __getitem__(self, index): 114 | return self.x[index], self.adj[index], self.y[index] 115 | 116 | def __len__(self): 117 | return self.x.shape[0] 118 | 119 | class SparseTensorDataset(Dataset): 120 | def __init__(self, data): # images: n x c x h x w tensor 121 | self.data = data 122 | 123 | def __getitem__(self, index): 124 | return self.data[index] 125 | 126 | def __len__(self): 127 | return len(self.data) 128 | 129 | 130 | def get_max_nodes(args): 131 | if args.dataset == 'CIFAR10': 132 | return 150 133 | if args.dataset == 'DD': 134 | return 5748 135 | if args.dataset == 'MUTAG': 136 | return 28 137 | if args.dataset == 'NCI1': 138 | return 111 139 | if args.dataset == 'ogbg-molhiv': 140 | return 222 141 | raise NotImplementedError 142 | 143 | def get_mean_nodes(args): 144 | if args.dataset == 'CIFAR10': 145 | return 118 146 | if args.dataset == 'DD': 147 | return 285 
148 | if args.dataset == 'MUTAG': 149 | return 18 150 | if args.dataset == 'NCI1': 151 | return 30 152 | if args.dataset == 'ogbg-molhiv': 153 | return 26 154 | if args.dataset == 'ogbg-molbbbp': 155 | return 24 156 | if args.dataset == 'ogbg-molbace': 157 | return 34 158 | 159 | raise NotImplementedError 160 | 161 | 162 | def match_loss(gw_syn, gw_real, args, device): 163 | dis = torch.tensor(0.0).to(device) 164 | 165 | if args.dis_metric == 'ours': 166 | for ig in range(len(gw_real)): 167 | gwr = gw_real[ig] 168 | gws = gw_syn[ig] 169 | dis += distance_wb(gwr, gws) 170 | 171 | elif args.dis_metric == 'mse': 172 | gw_real_vec = [] 173 | gw_syn_vec = [] 174 | for ig in range(len(gw_real)): 175 | gw_real_vec.append(gw_real[ig].reshape((-1))) 176 | gw_syn_vec.append(gw_syn[ig].reshape((-1))) 177 | gw_real_vec = torch.cat(gw_real_vec, dim=0) 178 | gw_syn_vec = torch.cat(gw_syn_vec, dim=0) 179 | dis = torch.sum((gw_syn_vec - gw_real_vec)**2) / torch.sum((gw_real_vec)**2) # I used this a lot 180 | 181 | elif args.dis_metric == 'cos': 182 | gw_real_vec = [] 183 | gw_syn_vec = [] 184 | for ig in range(len(gw_real)): 185 | gw_real_vec.append(gw_real[ig].reshape((-1))) 186 | gw_syn_vec.append(gw_syn[ig].reshape((-1))) 187 | gw_real_vec = torch.cat(gw_real_vec, dim=0) 188 | gw_syn_vec = torch.cat(gw_syn_vec, dim=0) 189 | dis = 1 - torch.sum(gw_real_vec * gw_syn_vec, dim=-1) / (torch.norm(gw_real_vec, dim=-1) * torch.norm(gw_syn_vec, dim=-1) + 0.000001) 190 | 191 | else: 192 | exit('unknown distance function: %s'%args.dis_metric) 193 | 194 | return dis 195 | 196 | 197 | def distance_wb(gwr, gws): 198 | shape = gwr.shape 199 | if len(shape) == 4: # conv, out*in*h*w 200 | gwr = gwr.reshape(shape[0], shape[1] * shape[2] * shape[3]) 201 | gws = gws.reshape(shape[0], shape[1] * shape[2] * shape[3]) 202 | elif len(shape) == 3: # layernorm, C*h*w 203 | gwr = gwr.reshape(shape[0], shape[1] * shape[2]) 204 | gws = gws.reshape(shape[0], shape[1] * shape[2]) 205 | elif len(shape) == 2: 
# linear, out*in 206 | tmp = 'do nothing' 207 | elif len(shape) == 1: # batchnorm/instancenorm, C; groupnorm x, bias 208 | gwr = gwr.reshape(1, shape[0]) 209 | gws = gws.reshape(1, shape[0]) 210 | return torch.tensor(0, dtype=torch.float, device=gwr.device) 211 | 212 | dis_weight = torch.sum(1 - torch.sum(gwr * gws, dim=-1) / (torch.norm(gwr, dim=-1) * torch.norm(gws, dim=-1) + 0.000001)) 213 | dis = dis_weight 214 | return dis 215 | 216 | 217 | def save_pyg_graphs(graphs, args): 218 | memory_dict = {} 219 | for d in graphs: 220 | y = d.y.item() 221 | if y not in memory_dict: 222 | memory_dict[y] = [d] 223 | else: 224 | memory_dict[y].append(d) 225 | 226 | for k, v in memory_dict.items(): 227 | graph_dict = {} 228 | d, slices = InMemoryDataset.collate(v) 229 | graph_dict['x'] = d.x 230 | graph_dict['edge_index'] = d.edge_index 231 | graph_dict['y'] = d.y 232 | memory_dict[k] = (graph_dict, slices) 233 | 234 | torch.save(memory_dict, f'saved/memory/{args.dataset}_ours_{args.seed}_ipc{args.ipc}.pt') 235 | 236 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 
176 | -------------------------------------------------------------------------------- /graph_agent.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | from math import ceil 4 | import torch 5 | import torch.nn.functional as F 6 | from models_gcn import GCN 7 | from models import DenseGCN 8 | from dense_sgc import get_akx 9 | from collections import Counter 10 | import numpy as np 11 | from utils import TensorDataset, SparseTensorDataset 12 | from utils import * 13 | from copy import deepcopy 14 | from torch_geometric.utils import to_dense_batch, to_dense_adj 15 | from torch_geometric.data import Batch 16 | from sklearn.metrics import roc_auc_score 17 | 18 | cls_criterion = torch.nn.BCEWithLogitsLoss() 19 | 20 | class GraphAgent: 21 | 22 | def __init__(self, data, args, device, nnodes_syn=75): 23 | self.data = data 24 | self.args = args 25 | self.device = device 26 | labels_train = [x.y.item() for x in data[0]] 27 | 28 | print('training size:', len(labels_train)) 29 | nfeat = data[0].num_features 30 | nclass = data[0].num_classes 31 | 32 | self.prepare_train_indices() 33 | 34 | # parametrize syn data 35 | self.labels_syn = self.get_labels_syn(labels_train) 36 | if args.ipc == 0: 37 | n = int(len(labels_train) * args.reduction_rate) 38 | else: 39 | self.labels_syn = torch.LongTensor([[i]*args.ipc for i in range(nclass)]).to(device).view(-1) 40 | self.syn_class_indices = {i: [i*args.ipc, (i+1)*args.ipc] for i in range(nclass)} 41 | n = args.ipc * nclass 42 | 43 | self.adj_syn = torch.rand(size=(n, nnodes_syn, nnodes_syn), dtype=torch.float, requires_grad=True, device=device) 44 | self.feat_syn = torch.rand(size=(n, nnodes_syn, nfeat), dtype=torch.float, requires_grad=True, device=device) 45 | 46 | if args.init == 'real': 47 | for c in range(nclass): 48 | ind = self.syn_class_indices[c] 49 | feat_real, adj_real = self.get_graphs(c, batch_size=ind[1]-ind[0], max_node_size=nnodes_syn, 
to_dense=True) 50 | self.feat_syn.data[ind[0]: ind[1]] = feat_real[:, :nnodes_syn].detach().data 51 | self.adj_syn.data[ind[0]: ind[1]] = adj_real[:, :nnodes_syn, :nnodes_syn].detach().data 52 | self.sparsity = self.adj_syn.mean().item() 53 | if args.stru_discrete: 54 | self.adj_syn.data.copy_(self.adj_syn*10-5) # max:5; min:-5 55 | else: 56 | if args.stru_discrete: 57 | adj_init = torch.log(self.adj_syn) - torch.log(1-self.adj_syn) 58 | adj_init = adj_init.clamp(-10, 10) 59 | self.adj_syn.data.copy_(adj_init) 60 | 61 | print('adj.shape:', self.adj_syn.shape, 'feat.shape:', self.feat_syn.shape) 62 | self.optimizer_adj = torch.optim.Adam([self.adj_syn], lr=args.lr_adj) 63 | self.optimizer_feat = torch.optim.Adam([self.feat_syn], lr=args.lr_feat) 64 | self.weights = [] 65 | 66 | def prepare_train_indices(self): 67 | dataset = self.data[0] 68 | indices_class = {} 69 | nnodes_all = [] 70 | for ix, single in enumerate(dataset): 71 | c = single.y.item() 72 | if c not in indices_class: 73 | indices_class[c] = [ix] 74 | else: 75 | indices_class[c].append(ix) 76 | nnodes_all.append(single.num_nodes) 77 | 78 | self.nnodes_all = np.array(nnodes_all) 79 | self.real_indices_class = indices_class 80 | 81 | def get_labels_syn(self, labels_train): 82 | counter = Counter(labels_train) 83 | num_class_dict = {} 84 | n = len(labels_train) 85 | 86 | sorted_counter = sorted(counter.items(), key=lambda x:x[1]) 87 | sum_ = 0 88 | labels_syn = [] 89 | self.syn_class_indices = {} 90 | 91 | for ix, (c, num) in enumerate(sorted_counter): 92 | if ix == len(sorted_counter) - 1: 93 | num_class_dict[c] = int(n * self.args.reduction_rate) - sum_ 94 | self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]] 95 | labels_syn += [c] * num_class_dict[c] 96 | else: 97 | num_class_dict[c] = max(int(num * self.args.reduction_rate), 1) 98 | sum_ += num_class_dict[c] 99 | self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]] 100 | labels_syn += [c] * 
num_class_dict[c] 101 | 102 | self.num_class_dict = num_class_dict 103 | return torch.LongTensor(labels_syn).to(self.device) 104 | 105 | def get_graphs(self, c, batch_size, max_node_size=None, to_dense=False, idx_selected=None): 106 | """get random n images from class c""" 107 | if idx_selected is None: 108 | if max_node_size is None: 109 | idx_shuffle = np.random.permutation(self.real_indices_class[c])[:batch_size] 110 | sampled = self.data[4][idx_shuffle] 111 | else: 112 | indices = np.array(self.real_indices_class[c])[self.nnodes_all[self.real_indices_class[c]] <= max_node_size] 113 | idx_shuffle = np.random.permutation(indices)[:batch_size] 114 | sampled = self.data[4][idx_shuffle] 115 | else: 116 | sampled = self.data[4][idx_selected] 117 | data = Batch.from_data_list(sampled) 118 | if to_dense: 119 | x, edge_index, batch = data.x, data.edge_index, data.batch 120 | x, mask = to_dense_batch(x, batch=batch, max_num_nodes=max_node_size) 121 | adj = to_dense_adj(edge_index, batch=batch, max_num_nodes=max_node_size) 122 | return x.to(self.device), adj.to(self.device) 123 | else: 124 | return data.to(self.device) 125 | 126 | def get_graphs_multiclass(self, batch_size, max_node_size=None, idx_herding=None): 127 | """get random n graphs from classes""" 128 | if idx_herding is None: 129 | if max_node_size is None: 130 | idx_shuffle = [] 131 | for c in range(self.data[0].num_classes): 132 | idx_shuffle.append(np.random.permutation(self.real_indices_class[c])[:batch_size]) 133 | idx_shuffle = np.hstack(idx_shuffle) 134 | sampled = self.data[4][idx_shuffle] 135 | else: 136 | idx_shuffle = [] 137 | for c in range(self.data[0].num_classes): 138 | indices = np.array(self.real_indices_class[c])[self.nnodes_all[self.real_indices_class[c]] <= max_node_size] 139 | idx_shuffle.append(np.random.permutation(indices)[:batch_size]) 140 | idx_shuffle = np.hstack(idx_shuffle) 141 | sampled = self.data[4][idx_shuffle] 142 | else: 143 | sampled = self.data[4][idx_herding] 144 | data = 
Batch.from_data_list(sampled) 145 | return data.to(self.device) 146 | 147 | def clip(self): 148 | self.adj_syn.data.clamp_(min=0, max=1) 149 | # self.feat_syn.data.clamp_(min=0, max=1) 150 | 151 | def train(self): 152 | dataset = self.data[0] 153 | train_loader = self.data[1] 154 | device = self.device 155 | args = self.args 156 | 157 | args.outer_loop, args.inner_loop = args.outer, args.inner 158 | 159 | sparsity = self.sparsity 160 | import time; st=time.time() 161 | for it in range(args.epochs): 162 | runs = 3 163 | if it == 0 and args.lr_adj!=0 and args.eval_init: 164 | print('=== performance before optimizing:') 165 | res = [] 166 | for _ in range(runs): 167 | if args.dataset in ['ogbg-molhiv', 'ogbg-molbbbp', 'ogbg-molbace' ]: 168 | res.append(self.test(epochs=500)) 169 | elif args.dataset in ['DD']: 170 | res.append(self.test(epochs=100)) 171 | else: 172 | res.append(self.test(epochs=500)) 173 | 174 | res = np.array(res) 175 | print('Mean Train/Val/TestAcc/TrainLoss:', res.mean(0)) 176 | print('Std Train/Val/TestAcc/TrainLoss:', res.std(0)) 177 | 178 | model_syn = DenseGCN(nfeat=dataset.num_features, nhid=args.hidden, net_norm=args.net_norm, pooling=args.pooling, 179 | dropout=0.0, nclass=dataset.num_classes, nconvs=args.nconvs, args=args).to(self.device) 180 | model_real = GCN(nfeat=dataset.num_features, nhid=args.hidden, net_norm=args.net_norm, pooling=args.pooling, 181 | dropout=0.0, nclass=dataset.num_classes, nconvs=args.nconvs, args=args).to(self.device) 182 | 183 | model_real.load_state_dict(model_syn.state_dict()) 184 | model_real_parameters = list(model_real.parameters()) 185 | model_syn_parameters = list(model_syn.parameters()) 186 | optimizer = torch.optim.Adam(model_syn.parameters(), lr=args.lr_model) 187 | 188 | loss_avg = 0 189 | for ol in range(args.outer_loop): 190 | 191 | BN_flag = False 192 | bn_real_state = [] 193 | for model in [model_real]: 194 | for module in model.modules(): 195 | if 'BatchNorm' in module._get_name(): #BatchNorm 196 | 
BN_flag = True 197 | if BN_flag: 198 | data_real = self.get_graphs_multiclass(batch_size=16) 199 | model.train() # for updating the mu, sigma of BatchNorm 200 | output_real = model(data_real) 201 | for module in model.modules(): 202 | if 'BatchNorm' in module._get_name(): #BatchNorm 203 | module.eval() # fix mu and sigma of every BatchNorm layer 204 | bn_real_state.append(module.state_dict()) 205 | 206 | if BN_flag: 207 | model_syn.train() # for updating the mu, sigma of BatchNorm 208 | for module in model_syn.modules(): 209 | ii = 0 210 | if 'BatchNorm' in module._get_name(): #BatchNorm 211 | module.eval() # fix mu and sigma of every BatchNorm layer 212 | module.load_state_dict(bn_real_state[ii]) 213 | ii += 1 214 | 215 | feat_syn = self.feat_syn 216 | adj_syn = self.adj_syn 217 | 218 | if args.stru_discrete: 219 | adj_syn = self.get_discrete_graphs(adj_syn, inference=False) 220 | loss = 0 221 | if args.dataset not in ['ogbg-molbace', 'CIFAR10']: 222 | for c in range(dataset.num_classes): 223 | data_real = self.get_graphs(c, batch_size=args.bs_cond) 224 | ind = self.syn_class_indices[c] 225 | feat_syn_c = feat_syn[ind[0]:ind[1]] 226 | adj_syn_c = adj_syn[ind[0]: ind[1]] 227 | 228 | labels_real = torch.ones((data_real.y.shape[0],), device=self.device, dtype=torch.long) * c 229 | 230 | labels_syn = self.labels_syn[ind[0]:ind[1]] 231 | output_real = model_real(data_real) 232 | if args.dataset in ['ogbg-molhiv', 'ogbg-molbbbp', 'ogbg-molbace']: 233 | loss_real = cls_criterion(output_real, labels_real.view(-1, 1).float()) 234 | else: 235 | loss_real = F.nll_loss(output_real, labels_real) 236 | gw_real = torch.autograd.grad(loss_real, model_real_parameters) 237 | gw_real = list((_.detach().clone() for _ in gw_real)) 238 | 239 | output_syn = model_syn(feat_syn_c, adj_syn_c) 240 | if args.dataset in ['ogbg-molhiv', 'ogbg-molbbbp', 'ogbg-molbace']: 241 | loss_syn = cls_criterion(output_syn, labels_syn.view(-1, 1).float()) 242 | else: 243 | loss_syn = F.nll_loss(output_syn, 
labels_syn) 244 | gw_syn = torch.autograd.grad(loss_syn, model_syn_parameters, create_graph=True) 245 | 246 | loss += match_loss(gw_syn, gw_real, args, self.device) 247 | else: 248 | data_real = self.get_graphs_multiclass(batch_size=args.bs_cond) 249 | selected = [] 250 | for c in range(dataset.num_classes): 251 | ind = self.syn_class_indices[c] 252 | ind = np.arange(ind[0], ind[1]) 253 | selected.append(ind) 254 | 255 | selected = np.hstack(selected) 256 | feat_syn_c = feat_syn[selected] 257 | adj_syn_c = adj_syn[selected] 258 | 259 | labels_real = data_real.y 260 | 261 | labels_syn = self.labels_syn[selected] 262 | output_real = model_real(data_real) 263 | if args.dataset in ['ogbg-molhiv', 'ogbg-molbbbp', 'ogbg-molbace']: 264 | loss_real = cls_criterion(output_real, labels_real.view(-1, 1).float()) 265 | else: 266 | loss_real = F.nll_loss(output_real, labels_real) 267 | gw_real = torch.autograd.grad(loss_real, model_real_parameters) 268 | gw_real = list((_.detach().clone() for _ in gw_real)) 269 | 270 | output_syn = model_syn(feat_syn_c, adj_syn_c) 271 | if args.dataset in ['ogbg-molhiv', 'ogbg-molbbbp', 'ogbg-molbace']: 272 | loss_syn = cls_criterion(output_syn, labels_syn.view(-1, 1).float()) 273 | else: 274 | loss_syn = F.nll_loss(output_syn, labels_syn) 275 | gw_syn = torch.autograd.grad(loss_syn, model_syn_parameters, create_graph=True) 276 | 277 | loss += 1e-0*match_loss(gw_syn, gw_real, args, self.device) 278 | 279 | loss_reg = F.relu(torch.sigmoid(self.adj_syn).mean() - sparsity) 280 | if args.dataset in ['ogbg-molhiv']: 281 | akx = get_akx(feat_syn, adj_syn, K=args.nconvs, pool=args.pooling) 282 | nclass = dataset.num_classes 283 | first = np.sqrt(2) * loss_avg * nclass 284 | second = 3/2/np.sqrt(100) * (nclass-1)/nclass / adj_syn.shape[0] * akx 285 | if it % 50==0: 286 | print('first:', first , 'second:', second) 287 | loss_avg += loss.item() 288 | loss = loss + self.args.beta*loss_reg + 1/np.sqrt(2)*second # + 1e-4* torch.norm(self.feat_syn) 289 | 
else: 290 | loss_avg += loss.item() 291 | loss = loss + self.args.beta*loss_reg 292 | self.optimizer_adj.zero_grad() 293 | self.optimizer_feat.zero_grad() 294 | 295 | loss.backward() 296 | 297 | self.optimizer_adj.step() 298 | self.optimizer_feat.step() 299 | if not self.args.stru_discrete: 300 | self.clip() 301 | 302 | if ol == args.outer_loop - 1: 303 | break 304 | 305 | self.train_inner(model_syn, model_real, optimizer, epochs=args.inner_loop) 306 | 307 | loss_avg /= (dataset.num_classes*args.outer_loop) 308 | 309 | if it % 20 == 0: 310 | print('Condensation - Iter:', it, 'loss:', loss_avg) 311 | print('sparsity loss', loss_reg.item()) 312 | 313 | if it == 400: 314 | self.optimizer_adj = torch.optim.Adam([self.adj_syn], lr=0.1*args.lr_adj) # optimizer for synthetic data 315 | self.optimizer_feat = torch.optim.Adam([self.feat_syn], lr=0.1*args.lr_feat) # optimizer for 316 | 317 | print_freq = 200 318 | if (it+1) % print_freq == 0: 319 | print('time consumed:', time.time()-st) 320 | adj_syn2 = self.adj_syn.detach().clone() 321 | 322 | if args.save: 323 | torch.save([self.adj_syn, self.feat_syn], f'saved/{args.dataset}_ipc{args.ipc}_s{args.seed}_lra{args.lr_adj}_lrf{args.lr_feat}.pt') 324 | 325 | res = [] 326 | for _ in range(runs): 327 | if args.dataset in ['ogbg-molhiv']: 328 | res.append(self.test(epochs=100)) 329 | else: 330 | res.append(self.test(epochs=500)) 331 | res = np.array(res) 332 | print('Mean Train/Val/TestAcc/TrainLoss:', res.mean(0)) 333 | print('Std Train/Val/TestAcc/TrainLoss:', res.std(0)) 334 | 335 | 336 | def test(self, epochs=500, save=False, verbose=False, new_data=None): 337 | dataset = self.data[0] 338 | args = self.args 339 | model_syn = DenseGCN(nfeat=dataset.num_features, nhid=args.hidden, dropout=args.dropout, net_norm=args.net_norm, 340 | nconvs=args.nconvs, nclass=dataset.num_classes, pooling=args.pooling, args=args).to(self.device) 341 | model_real = GCN(nfeat=dataset.num_features, dropout=0.0, net_norm=args.net_norm, 342 | 
nconvs=args.nconvs, nhid=args.hidden, nclass=dataset.num_classes, pooling=args.pooling, args=args).to(self.device) 343 | 344 | if new_data is None: 345 | feat_syn = self.feat_syn.detach() 346 | adj_syn = self.adj_syn.detach() 347 | if args.stru_discrete: 348 | adj_syn = self.get_discrete_graphs(adj_syn, inference=True) 349 | # print('Mean sparsity:', (adj_syn.sum(1).sum(1) / adj_syn.size(1) / adj_syn.size(1)).mean().item()) 350 | else: 351 | feat_syn, adj_syn = new_data 352 | feat_syn, adj_syn = feat_syn.detach(), adj_syn.detach() 353 | 354 | labels_syn = self.labels_syn 355 | 356 | # Convert adjancency matrix to edge_index stored as torch_geometric.data.Data 357 | sampled = [] 358 | sampled = np.ndarray((adj_syn.size(0),), dtype=np.object) 359 | from torch_geometric.data import Data 360 | for i in range(adj_syn.size(0)): 361 | x = feat_syn[i] 362 | adj = adj_syn[i] 363 | g = adj.nonzero().T 364 | y = self.labels_syn[i] 365 | single_data = Data(x=x, edge_index=g, y=y) 366 | sampled[i] = (single_data) 367 | return self.test_pyg_data(sampled, epochs=epochs) 368 | 369 | def test_pyg_data(self, syn_data=None, epochs=500, save=False, verbose=False): 370 | dataset = self.data[0] 371 | args = self.args 372 | use_val = True 373 | model = GCN(nfeat=dataset.num_features, nconvs=args.nconvs, nhid=args.hidden, nclass=dataset.num_classes, net_norm=args.net_norm, pooling=args.pooling, dropout=args.dropout, args=args).to(self.device) 374 | lr = 0.001 375 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 376 | if syn_data is None: 377 | data = self.adj_syn 378 | else: 379 | data = syn_data 380 | dst_syn_train = SparseTensorDataset(data) 381 | 382 | from torch_geometric.loader import DataLoader 383 | if args.dataset in ['CIFAR10']: 384 | train_loader = DataLoader(dst_syn_train, batch_size=512, shuffle=True, num_workers=0) 385 | else: 386 | train_loader = DataLoader(dst_syn_train, batch_size=128, shuffle=True, num_workers=0) 387 | 388 | @torch.no_grad() 389 | def 
test(loader, report_metric=False): 390 | model.eval() 391 | if self.args.dataset in ['ogbg-molhiv','ogbg-molbbbp', 'ogbg-molbace']: 392 | pred, y = [], [] 393 | for data in loader: 394 | data = data.to(self.device) 395 | pred.append(model(data)) 396 | y.append(data.y.view(-1,1)) 397 | from ogb.graphproppred import Evaluator; 398 | evaluator = Evaluator(self.args.dataset) 399 | return evaluator.eval({'y_pred': torch.cat(pred), 400 | 'y_true': torch.cat(y)})['rocauc'] 401 | else: 402 | correct = 0 403 | for data in loader: 404 | data = data.to(self.device) 405 | pred = model(data).max(dim=1)[1] 406 | correct += pred.eq(data.y.view(-1)).sum().item() 407 | if report_metric: 408 | nnodes_list = [(data.ptr[i]-data.ptr[i-1]).item() for i in range(1, len(data.ptr))] 409 | low = np.quantile(nnodes_list, 0.2) 410 | high = np.quantile(nnodes_list, 0.8) 411 | correct_low = pred.eq(data.y.view(-1))[nnodes_list<=low].sum().item() 412 | correct_medium = pred.eq(data.y.view(-1))[(nnodes_list>low)&(nnodes_list=high].sum().item() 414 | print(100*correct_low/(nnodes_list<=low).sum(), 415 | 100*correct_medium/((nnodes_list>low) & (nnodes_list=high).sum()) 417 | return 100*correct / len(loader.dataset) 418 | 419 | res = [] 420 | best_val_acc = 0 421 | 422 | for it in range(epochs): 423 | if it == epochs//2: 424 | optimizer = torch.optim.Adam(model.parameters(), lr=0.1*lr) 425 | 426 | model.train() 427 | loss_all = 0 428 | for data in train_loader: 429 | data = data.to(self.device) 430 | y = data.y 431 | optimizer.zero_grad() 432 | output = model(data) 433 | if args.dataset in ['ogbg-molhiv','ogbg-molbbbp', 'ogbg-molbace']: 434 | loss = cls_criterion(output, y.view(-1, 1).float()) 435 | else: 436 | loss = F.nll_loss(output, y.view(-1)) 437 | loss.backward() 438 | loss_all += y.size(0) * loss.item() 439 | optimizer.step() 440 | 441 | loss = loss_all / len(dst_syn_train) 442 | if verbose: 443 | if it % 100 == 0: 444 | print('Evaluation Stage - loss:', loss) 445 | 446 | if use_val: 447 | 
acc_val = test(self.data[2]) 448 | if acc_val > best_val_acc: 449 | best_val_acc = acc_val 450 | if verbose: 451 | acc_train = test(self.data[1]) 452 | acc_test = test(self.data[3], report_metric=False) 453 | print('acc_train:', acc_train, 'acc_val:', acc_val, 'acc_test:', acc_test) 454 | if save: 455 | torch.save(model.state_dict(), f'saved/{args.dataset}_{args.seed}.pt') 456 | weights = deepcopy(model.state_dict()) 457 | 458 | if use_val: 459 | model.load_state_dict(weights) 460 | else: 461 | best_val_acc = test(self.data[2]) 462 | acc_train = test(self.data[1]) 463 | acc_test = test(self.data[3], report_metric=False) 464 | # print([acc_train, best_val_acc, acc_test]) 465 | return [acc_train, best_val_acc, acc_test] 466 | 467 | def train_inner(self, model_syn, model_real, optimizer, epochs=500, save=False, verbose=False): 468 | if epochs == 0: 469 | return 470 | dataset = self.data[0] 471 | args = self.args 472 | feat_syn = self.feat_syn.detach() 473 | adj_syn = self.adj_syn 474 | adj_syn = adj_syn.detach() 475 | labels_syn = self.labels_syn 476 | dst_syn_train = TensorDataset(feat_syn, adj_syn, labels_syn) 477 | train_loader = torch.utils.data.DataLoader(dst_syn_train, batch_size=128, shuffle=True, num_workers=0) 478 | 479 | for it in range(epochs): 480 | model_syn.train() 481 | loss_all = 0 482 | for data in train_loader: 483 | x, adj, y = data 484 | x, adj, y = x.to(self.device), adj.to(self.device), y.to(self.device) 485 | optimizer.zero_grad() 486 | output = model_syn(x, adj, mask=None) 487 | if args.dataset in ['ogbg-molhiv', 'ogbg-molbbbp', 'ogbg-molbace']: 488 | loss = cls_criterion(output, y.view(-1, 1).float()) 489 | else: 490 | loss = F.nll_loss(output, y.view(-1)) 491 | loss.backward() 492 | optimizer.step() 493 | model_real.load_state_dict(model_syn.state_dict()) 494 | 495 | def test_full_train(self, epochs=500, save=False, verbose=False): 496 | dataset = self.data[0] 497 | use_val = True 498 | args = self.args 499 | model = 
GCN(nfeat=dataset.num_features, nconvs=args.nconvs, nhid=args.hidden, nclass=dataset.num_classes, net_norm=args.net_norm, pooling=args.pooling, dropout=args.dropout, args=args).to(self.device) 500 | 501 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 502 | train_loader = self.data[1] 503 | 504 | @torch.no_grad() 505 | def test(loader, report_metric=False): 506 | model.eval() 507 | if self.args.dataset in ['ogbg-molhiv', 'ogbg-molbbbp', 'ogbg-molbace']: 508 | pred, y = [], [] 509 | for data in loader: 510 | data = data.to(self.device) 511 | pred.append(model(data)) 512 | y.append(data.y.view(-1,1)) 513 | from ogb.graphproppred import Evaluator; 514 | evaluator = Evaluator(self.args.dataset) 515 | return evaluator.eval({'y_pred': torch.cat(pred), 516 | 'y_true': torch.cat(y)})['rocauc'] 517 | else: 518 | correct = 0 519 | for data in loader: 520 | data = data.to(self.device) 521 | pred = model(data).max(dim=1)[1] 522 | correct += pred.eq(data.y.view(-1)).sum().item() 523 | if report_metric: 524 | nnodes_list = [(data.ptr[i]-data.ptr[i-1]).item() for i in range(1, len(data.ptr))] 525 | low = np.quantile(nnodes_list, 0.2) 526 | high = np.quantile(nnodes_list, 0.8) 527 | correct_low = pred.eq(data.y.view(-1))[nnodes_list<=low].sum().item() 528 | correct_medium = pred.eq(data.y.view(-1))[(nnodes_list>low)&(nnodes_list=high].sum().item() 530 | print(100*correct_low/(nnodes_list<=low).sum(), 531 | 100*correct_medium/((nnodes_list>low) & (nnodes_list=high).sum()) 533 | return 100*correct / len(loader.dataset) 534 | 535 | res = [] 536 | best_val_acc = 0 537 | 538 | for it in range(epochs): 539 | if it == epochs//2: 540 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 541 | 542 | model.train() 543 | loss_all = 0 544 | for data in train_loader: 545 | data = data.to(self.device) 546 | y = data.y 547 | optimizer.zero_grad() 548 | if self.args.augment: 549 | x = F.dropout(x) 550 | output = model(data) 551 | if args.dataset in ['ogbg-molhiv', 
'ogbg-molbbbp', 'ogbg-molbace']: 552 | loss = cls_criterion(output, y.view(-1, 1).float()) 553 | else: 554 | loss = F.nll_loss(output, y.view(-1)) 555 | loss.backward() 556 | loss_all += y.size(0) * loss.item() 557 | optimizer.step() 558 | 559 | loss = loss_all / len(self.data[0]) 560 | if verbose: 561 | if it % 100 == 0: 562 | print('Evaluation Stage - loss:', loss) 563 | 564 | if use_val: 565 | acc_val = test(self.data[2]) 566 | if acc_val > best_val_acc: 567 | best_val_acc = acc_val 568 | if verbose: 569 | acc_train = test(self.data[1]) 570 | acc_test = test(self.data[3], report_metric=True) 571 | print('acc_train:', acc_train, 'acc_val:', acc_val, 'acc_test:', acc_test) 572 | if save: 573 | torch.save(model.state_dict(), f'saved/{args.dataset}_{args.seed}.pt') 574 | weights = deepcopy(model.state_dict()) 575 | 576 | if use_val: 577 | model.load_state_dict(weights) 578 | acc_train = test(self.data[1]) 579 | acc_test = test(self.data[3], report_metric=True) 580 | 581 | @torch.no_grad() 582 | def get_embeds(loader): 583 | model.eval() 584 | all_emb = [] 585 | for data in loader: 586 | data = data.to(self.device) 587 | emb = model.embed(data) 588 | all_emb.append(emb) 589 | return torch.cat(all_emb, dim=0) 590 | 591 | # don't shuffle training data 592 | new_train_loader = DataLoader(self.data[0], batch_size=1024, shuffle=False) 593 | embeds = get_embeds(new_train_loader) 594 | return embeds.cpu() 595 | 596 | def get_discrete_graphs(self, adj, inference): 597 | if not hasattr(self, 'cnt'): 598 | self.cnt = 0 599 | 600 | if self.args.dataset not in ['CIFAR10']: 601 | adj = (adj.transpose(1,2) + adj) / 2 602 | 603 | if not inference: 604 | N = adj.size()[1] 605 | vals = torch.rand(adj.size(0) * N * (N+1) // 2) 606 | vals = vals.view(adj.size(0), -1).to(self.device) 607 | i, j = torch.triu_indices(N, N) 608 | epsilon = torch.zeros_like(adj) 609 | epsilon[:, i, j] = vals 610 | epsilon.transpose(1,2)[:, i, j] = vals 611 | 612 | tmp = torch.log(epsilon) - 
torch.log(1-epsilon) 613 | self.tmp = tmp 614 | adj = tmp + adj 615 | t0 = 1 616 | tt = 0.01 617 | end_iter = 200 618 | t = t0*(tt/t0)**(self.cnt/end_iter) 619 | if self.cnt == end_iter: 620 | print('===reached the end of anealing...') 621 | self.cnt += 1 622 | 623 | t = max(t, tt) 624 | adj = torch.sigmoid(adj/t) 625 | adj = adj * (1-torch.eye(adj.size(1)).to(self.device)) 626 | else: 627 | adj = torch.sigmoid(adj) 628 | adj = adj * (1-torch.eye(adj.size(1)).to(self.device)) 629 | adj[adj> 0.5] = 1 630 | adj[adj<= 0.5] = 0 631 | return adj 632 | 633 | 634 | --------------------------------------------------------------------------------