├── Codes
│   ├── helper
│   │   ├── __init__.py
│   │   ├── timer
│   │   │   ├── __init__.py
│   │   │   ├── timer.py
│   │   │   └── comm_timer.py
│   │   ├── context.py
│   │   ├── intra_context.py
│   │   ├── reducer.py
│   │   ├── MongoManager.py
│   │   ├── sampler.py
│   │   ├── parser.py
│   │   ├── mapper.py
│   │   └── utils.py
│   ├── module
│   │   ├── __init__.py
│   │   ├── sync_bn.py
│   │   ├── others.py
│   │   ├── baseline_model.py
│   │   ├── model.py
│   │   └── layer.py
│   ├── brief_opt_baseline_test.sh
│   ├── brief_sampling_test.sh
│   ├── brief_masking_test.sh
│   ├── main.py
│   ├── our_main.py
│   ├── test.py
│   └── train.py
├── AE
│   ├── badges.png
│   ├── configs.py
│   ├── ae2_parser.py
│   ├── ae1_parser.py
│   ├── ae2_acc_baseline.py
│   ├── ae1_tpt_opt_baseline.py
│   ├── ae2_acc_eas.py
│   ├── ae1_tpt_eas.py
│   ├── ae1_tpt_cob.py
│   └── ae1_tpt_flx.py
├── parse_ae.sh
├── PACT24_GraNNDis_Author_Copy.pdf
├── run_ae.sh
├── pip-requirements.txt
├── LICENSE
├── conda-requirements.txt
├── .gitignore
└── README.md

/Codes/helper/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/Codes/module/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/Codes/helper/timer/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/AE/badges.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIS-SNU/GraNNDis_Artifact/HEAD/AE/badges.png
--------------------------------------------------------------------------------
/Codes/helper/timer/timer.py:
--------------------------------------------------------------------------------
1 | from helper.timer.comm_timer import *
2 | 
3 | comm_timer = CommTimer()
--------------------------------------------------------------------------------
/parse_ae.sh:
--------------------------------------------------------------------------------
1 | cd AE
2 | python ae1_parser.py > ../AE1_results.log
3 | python ae2_parser.py > ../AE2_results.log
4 | cd ..
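
(For reference: `comm_timer` in timer.py above is a module-level singleton of the `CommTimer` class defined in /Codes/helper/timer/comm_timer.py further below. A minimal usage sketch, assuming an already-initialized torch.distributed process group; the phase name and tensor here are illustrative, not taken from the training code:)

import torch
import torch.distributed as dist

from helper.timer.timer import comm_timer

feat = torch.randn(4, 16)                    # illustrative payload
with comm_timer.timer('feature_exchange'):   # hypothetical phase name
    dist.all_reduce(feat)                    # the communication being timed
print(f'measured comm time: {comm_timer.tot_time():.3f} s')
comm_timer.clear()                           # timer() raises if a name is reused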
-------------------------------------------------------------------------------- /PACT24_GraNNDis_Author_Copy.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIS-SNU/GraNNDis_Artifact/HEAD/PACT24_GraNNDis_Author_Copy.pdf -------------------------------------------------------------------------------- /Codes/helper/context.py: -------------------------------------------------------------------------------- 1 | from helper.feature_buffer import * 2 | from helper.reducer import * 3 | 4 | buffer = Buffer() 5 | reducer = Reducer() 6 | -------------------------------------------------------------------------------- /Codes/helper/intra_context.py: -------------------------------------------------------------------------------- 1 | from helper.feature_buffer import * 2 | from helper.reducer import * 3 | 4 | buffer = IntraBuffer() 5 | reducer = Reducer() 6 | -------------------------------------------------------------------------------- /run_ae.sh: -------------------------------------------------------------------------------- 1 | echo "Preparing Log Directory" 2 | rm -rf Logs 3 | mkdir Logs 4 | 5 | echo "Run AE1" 6 | cd AE 7 | python ae1_tpt_opt_baseline.py 8 | python ae1_tpt_flx.py 9 | python ae1_tpt_cob.py 10 | python ae1_tpt_eas.py 11 | 12 | echo "Run AE2" 13 | python ae2_acc_baseline.py 14 | python ae2_acc_eas.py 15 | 16 | cd .. -------------------------------------------------------------------------------- /AE/configs.py: -------------------------------------------------------------------------------- 1 | global_configs = { 2 | 'env_loc': '/nfs/home/ae/anaconda3/envs/granndis_ae/bin/python', 3 | 'runner_loc': '/nfs/home/ae/GraNNDis_Artifact/Codes/main.py', 4 | 'our_runner_loc': '/nfs/home/ae/GraNNDis_Artifact/Codes/our_main.py', 5 | 'workspace_loc': '/nfs/home/ae/GraNNDis_Artifact/', 6 | 'data_loc': '~/datasets/granndis_ae/', 7 | 'num_runners': 2, 8 | 'gpus_per_server': 4, 9 | 'hosts': ['192.168.0.5', '192.168.0.6'] 10 | } -------------------------------------------------------------------------------- /pip-requirements.txt: -------------------------------------------------------------------------------- 1 | bcrypt==4.1.3 2 | cffi==1.16.0 3 | cryptography==42.0.8 4 | paramiko==3.4.0 5 | pycparser==2.22 6 | pynacl==1.5.0 7 | packaging==24.1 8 | pytz==2024.1 9 | python-dateutil==2.9.0 10 | sshtunnel==0.4.0 11 | prettytable==3.10.0 12 | wcwidth==0.2.13 13 | fsspec==2024.6.1 14 | torchdata==0.7.1 15 | tzdata==2024.1 16 | pandas==2.2.2 17 | annotated-types==0.7.0 18 | pydantic==2.8.2 19 | pydantic-core==2.20.1 20 | joblib==1.4.2 21 | littleutils==0.2.2 22 | ogb==1.3.6 23 | outdated==0.2.2 24 | scikit-learn==1.5.1 25 | threadpoolctl==3.5.0 26 | dnspython==2.6.1 27 | pymongo==4.8.0 -------------------------------------------------------------------------------- /Codes/helper/timer/comm_timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch.distributed as dist 3 | from contextlib import contextmanager 4 | 5 | 6 | class CommTimer(object): 7 | 8 | def __init__(self): 9 | super(CommTimer, self).__init__() 10 | self._time = {} 11 | 12 | @contextmanager 13 | def timer(self, name): 14 | if name in self._time: 15 | raise Exception(name + " already exists") 16 | t0 = time.time() 17 | yield 18 | t1 = time.time() 19 | self._time[name] = (t0, t1) 20 | 21 | def tot_time(self): 22 | tot = 0 23 | for (t0, t1) in self._time.values(): 24 | tot += t1 - t0 25 | return tot 26 | 
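    # Each self._time entry is a (start, end) wall-clock pair: tot_time()
    # above sums the spans, and timer() raises on a duplicate name, so
    # clear() below must be called before re-timing the same phase.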
27 | def print_time(self): 28 | rank, size = dist.get_rank(), dist.get_world_size() 29 | for (k, (t0, t1)) in self._time.items(): 30 | print(f'(rank {rank}) Communication time of {k}: {t1 - t0} seconds.') 31 | 32 | def clear(self): 33 | self._time = {} 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 AIS-SNU 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Codes/brief_opt_baseline_test.sh: -------------------------------------------------------------------------------- 1 | rm -rf partitions/ 2 | 3 | python main.py \ 4 | --dataset ogbn-arxiv \ 5 | --dropout 0.5 \ 6 | --lr 0.01 \ 7 | --n-partitions 4 \ 8 | --n-epochs 100 \ 9 | --time-calc \ 10 | --model graphsage \ 11 | --n-layers 4 \ 12 | --n-linear 1 \ 13 | --n-hidden 64 \ 14 | --log-every 10 \ 15 | --backend nccl \ 16 | --dataset-path ~/datasets/granndis_ae/ \ 17 | --create-json 1 \ 18 | --json-path ./baseline_test_logs \ 19 | --project granndis_test \ 20 | --no-eval \ 21 | --inductive \ 22 | --total-nodes 1 \ 23 | 24 | python main.py \ 25 | --dataset reddit \ 26 | --dropout 0.5 \ 27 | --lr 0.01 \ 28 | --n-partitions 4 \ 29 | --n-epochs 100 \ 30 | --time-calc \ 31 | --model graphsage \ 32 | --n-layers 4 \ 33 | --n-linear 1 \ 34 | --n-hidden 64 \ 35 | --log-every 10 \ 36 | --backend nccl \ 37 | --dataset-path ~/datasets/granndis_ae/ \ 38 | --create-json 1 \ 39 | --json-path ./baseline_test_logs \ 40 | --project granndis_test \ 41 | --no-eval \ 42 | --inductive \ 43 | --total-nodes 1 \ 44 | 45 | python main.py \ 46 | --dataset ogbn-products \ 47 | --dropout 0.3 \ 48 | --lr 0.003 \ 49 | --n-partitions 4 \ 50 | --n-epochs 100 \ 51 | --time-calc \ 52 | --model graphsage \ 53 | --n-layers 4 \ 54 | --n-linear 1 \ 55 | --n-hidden 64 \ 56 | --log-every 10 \ 57 | --backend nccl \ 58 | --dataset-path ~/datasets/granndis_ae/ \ 59 | --create-json 1 \ 60 | --json-path ./baseline_test_logs \ 61 | --project granndis_test \ 62 | --no-eval \ 63 | --inductive \ 64 | --total-nodes 1 \ -------------------------------------------------------------------------------- /Codes/helper/reducer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing.pool import ThreadPool 3 | 4 | import torch 5 | import 
torch.distributed as dist 6 | 7 | 8 | class Reducer(object): 9 | 10 | def __init__(self): 11 | super(Reducer, self).__init__() 12 | self._data_cpu = {} 13 | self._pool = None 14 | self._handles = [] 15 | self._stream = None 16 | 17 | def init(self, model, world_size): 18 | 19 | num_params = len(list(model.named_parameters())) 20 | num_workers = int(os.cpu_count() / world_size) 21 | 22 | for i, (name, param) in enumerate(model.named_parameters()): 23 | cur_group = dist.new_group() 24 | self._data_cpu[name] = (torch.zeros_like(param.data, pin_memory=True, device='cpu'), cur_group) 25 | 26 | self._pool = ThreadPool(processes= num_workers) 27 | self._stream = torch.cuda.Stream() 28 | 29 | def reduce(self, param, name, data, n_train): 30 | def create_stream(): 31 | self._stream.wait_stream(torch.cuda.current_stream()) 32 | with torch.cuda.stream(self._stream): 33 | data.div_(n_train) 34 | data_cpu, group = self._data_cpu[name] 35 | data_cpu.copy_(data) 36 | dist.all_reduce(data_cpu, op=dist.ReduceOp.SUM, group=group) 37 | param.grad.copy_(data_cpu, non_blocking=True) 38 | 39 | self._handles.append(self._pool.apply_async(create_stream)) 40 | 41 | def synchronize(self): 42 | for handle in self._handles: 43 | handle.wait() 44 | self._handles.clear() 45 | torch.cuda.current_stream().wait_stream(self._stream) 46 | -------------------------------------------------------------------------------- /Codes/brief_sampling_test.sh: -------------------------------------------------------------------------------- 1 | rm -rf intra-partitions/ 2 | 3 | python our_main.py \ 4 | --dataset ogbn-arxiv \ 5 | --dropout 0.5 \ 6 | --lr 0.01 \ 7 | --n-partitions 4 \ 8 | --n-epochs 100 \ 9 | --time-calc \ 10 | --model graphsage \ 11 | --n-layers 4 \ 12 | --n-linear 1 \ 13 | --n-hidden 64 \ 14 | --log-every 10 \ 15 | --backend nccl \ 16 | --dataset-path ~/datasets/granndis_ae/ \ 17 | --create-json 1 \ 18 | --json-path ./test_logs \ 19 | --project granndis_test \ 20 | --no-eval \ 21 | --inductive \ 22 | --total-nodes 1 \ 23 | --bandwidth-aware \ 24 | --subgraph-hop 1 \ 25 | --fanout 15 \ 26 | --use-mask \ 27 | 28 | python our_main.py \ 29 | --dataset reddit \ 30 | --dropout 0.5 \ 31 | --lr 0.01 \ 32 | --n-partitions 4 \ 33 | --n-epochs 100 \ 34 | --time-calc \ 35 | --model graphsage \ 36 | --n-layers 4 \ 37 | --n-linear 1 \ 38 | --n-hidden 64 \ 39 | --log-every 10 \ 40 | --backend nccl \ 41 | --dataset-path ~/datasets/granndis_ae/ \ 42 | --create-json 1 \ 43 | --json-path ./test_logs \ 44 | --project granndis_test \ 45 | --no-eval \ 46 | --inductive \ 47 | --total-nodes 1 \ 48 | --bandwidth-aware \ 49 | --subgraph-hop 1 \ 50 | --fanout 15 \ 51 | --use-mask \ 52 | 53 | python our_main.py \ 54 | --dataset ogbn-products \ 55 | --dropout 0.3 \ 56 | --lr 0.003 \ 57 | --n-partitions 4 \ 58 | --n-epochs 100 \ 59 | --time-calc \ 60 | --model graphsage \ 61 | --n-layers 4 \ 62 | --n-linear 1 \ 63 | --n-hidden 64 \ 64 | --log-every 10 \ 65 | --backend nccl \ 66 | --dataset-path ~/datasets/granndis_ae/ \ 67 | --create-json 1 \ 68 | --json-path ./test_logs \ 69 | --project granndis_test \ 70 | --no-eval \ 71 | --inductive \ 72 | --total-nodes 1 \ 73 | --bandwidth-aware \ 74 | --subgraph-hop 1 \ 75 | --fanout 15 \ 76 | --use-mask \ -------------------------------------------------------------------------------- /Codes/brief_masking_test.sh: -------------------------------------------------------------------------------- 1 | rm -rf intra-partitions/ 2 | 3 | python our_main.py \ 4 | --dataset ogbn-arxiv \ 5 | --dropout 0.5 \ 6 | --lr 0.01 
\ 7 | --n-partitions 4 \ 8 | --n-epochs 100 \ 9 | --epoch-iter 1 \ 10 | --time-calc \ 11 | --model graphsage \ 12 | --n-layers 4 \ 13 | --n-linear 1 \ 14 | --n-hidden 64 \ 15 | --log-every 10 \ 16 | --backend nccl \ 17 | --dataset-path ~/datasets/granndis_ae/ \ 18 | --create-json 1 \ 19 | --json-path ./masking_test_logs \ 20 | --project granndis_test \ 21 | --no-eval \ 22 | --inductive \ 23 | --total-nodes 1 \ 24 | --bandwidth-aware \ 25 | --subgraph-hop 3 \ 26 | --fanout -1 \ 27 | --sampler sage \ 28 | --use-mask \ 29 | 30 | python our_main.py \ 31 | --dataset reddit \ 32 | --dropout 0.5 \ 33 | --lr 0.01 \ 34 | --n-partitions 4 \ 35 | --n-epochs 100 \ 36 | --epoch-iter 1 \ 37 | --time-calc \ 38 | --model graphsage \ 39 | --n-layers 4 \ 40 | --n-linear 1 \ 41 | --n-hidden 64 \ 42 | --log-every 10 \ 43 | --backend nccl \ 44 | --dataset-path ~/datasets/granndis_ae/ \ 45 | --create-json 1 \ 46 | --json-path ./masking_test_logs \ 47 | --project granndis_test \ 48 | --no-eval \ 49 | --inductive \ 50 | --total-nodes 1 \ 51 | --bandwidth-aware \ 52 | --subgraph-hop 3 \ 53 | --fanout -1 \ 54 | --sampler sage \ 55 | --use-mask \ 56 | 57 | python our_main.py \ 58 | --dataset ogbn-products \ 59 | --dropout 0.3 \ 60 | --lr 0.003 \ 61 | --n-partitions 4 \ 62 | --n-epochs 100 \ 63 | --epoch-iter 1 \ 64 | --time-calc \ 65 | --model graphsage \ 66 | --n-layers 4 \ 67 | --n-linear 1 \ 68 | --n-hidden 64 \ 69 | --log-every 10 \ 70 | --backend nccl \ 71 | --dataset-path ~/datasets/granndis_ae/ \ 72 | --create-json 1 \ 73 | --json-path ./masking_test_logs \ 74 | --project granndis_test \ 75 | --no-eval \ 76 | --inductive \ 77 | --total-nodes 1 \ 78 | --bandwidth-aware \ 79 | --subgraph-hop 3 \ 80 | --fanout -1 \ 81 | --sampler sage \ 82 | --use-mask \ -------------------------------------------------------------------------------- /AE/ae2_parser.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from os import listdir 3 | from os.path import isfile, join 4 | 5 | import prettytable 6 | import json 7 | 8 | def find_acc(dict_list: List, dataset_name: str, acc_type: str) -> float: 9 | r""" find proper duration from dictionary list of various datasets 10 | """ 11 | for element in dict_list: 12 | if element['dataset'] == dataset_name: 13 | if acc_type in element: 14 | return element[acc_type] 15 | else: 16 | raise NotImplementedError 17 | 18 | if __name__ == '__main__': 19 | # prepare list 20 | fb_list = [] 21 | eas_list = [] 22 | 23 | # parse fb 24 | fb_basedir = '../Logs/fb_acc' 25 | fb_files = [f for f in listdir(fb_basedir) \ 26 | if isfile(join(fb_basedir, f))] 27 | assert len(fb_files) == 3, 'we tested only three datasets.' 28 | for fb_file in fb_files: 29 | with open(join(fb_basedir, fb_file)) as json_file: 30 | fb_data = json.load(json_file) 31 | fb_list.append(fb_data) 32 | 33 | # parse eas 34 | eas_basedir = '../Logs/eas_acc' 35 | eas_files = [f for f in listdir(eas_basedir) \ 36 | if isfile(join(eas_basedir, f))] 37 | assert len(eas_files) == 3, 'we tested only three datasets.' 38 | for eas_file in eas_files: 39 | with open(join(eas_basedir, eas_file)) as json_file: 40 | eas_data = json.load(json_file) 41 | eas_list.append(eas_data) 42 | 43 | acc_table = prettytable.PrettyTable() 44 | acc_table.title = 'Accuracy Comparison (FB vs. 
FLX-EAS)' 45 | acc_table.field_names = ['Method', 'Arxiv', 'Reddit', 'Products'] 46 | acc_table.add_row(['FB', find_acc(fb_list, 'ogbn-arxiv', 'test_accuracy'), 47 | find_acc(fb_list, 'reddit', 'test_accuracy'), 48 | find_acc(fb_list, 'ogbn-products', 'test_accuracy')]) 49 | acc_table.add_row(['EAS', find_acc(eas_list, 'ogbn-arxiv', 'test_accuracy'), 50 | find_acc(eas_list, 'reddit', 'test_accuracy'), 51 | find_acc(eas_list, 'ogbn-products', 'test_accuracy')]) 52 | acc_table.float_format = ".2" 53 | print(acc_table) -------------------------------------------------------------------------------- /Codes/module/sync_bn.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | from torch import nn 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | class SyncBatchNormFunc(Function): 8 | 9 | @staticmethod 10 | def forward(ctx, x, weight, bias, whole_size, running_mean, running_var, training, momentum, eps): 11 | if not training: 12 | mean = running_mean 13 | var = running_var 14 | else: 15 | sum_x = x.sum(axis=0) 16 | sum_x2 = (x ** 2).sum(axis=0) 17 | dist.all_reduce(sum_x, op=dist.ReduceOp.SUM) 18 | dist.all_reduce(sum_x2, op=dist.ReduceOp.SUM) 19 | mean = sum_x / whole_size 20 | var = (sum_x2 - mean * sum_x) / whole_size 21 | running_mean.mul_(1 - momentum).add_(mean * momentum) 22 | running_var.mul_(1 - momentum).add_(var * momentum) 23 | std = torch.sqrt(var + eps) 24 | x_hat = (x - mean) / std 25 | if training: 26 | ctx.save_for_backward(x_hat, weight, std) 27 | ctx.whole_size = whole_size 28 | return x_hat * weight + bias 29 | 30 | @staticmethod 31 | def backward(ctx, grad): 32 | x_hat, weight, std = ctx.saved_tensors 33 | dbias = grad.sum(axis=0) 34 | dweight = (grad * x_hat).sum(axis=0) 35 | dist.all_reduce(dbias, op=dist.ReduceOp.SUM) 36 | dist.all_reduce(dweight, op=dist.ReduceOp.SUM) 37 | n = ctx.whole_size 38 | dx = (weight / n) / std * (n * grad - dbias - x_hat * dweight) 39 | return dx, dweight, dbias, None, None, None, None, None, None 40 | 41 | 42 | class SyncBatchNorm(nn.Module): 43 | 44 | def __init__(self, num_features, whole_size, eps=1e-5, momentum=0.1): 45 | super(SyncBatchNorm, self).__init__() 46 | self.register_buffer('running_mean', torch.zeros(num_features)) 47 | self.register_buffer('running_var', torch.ones(num_features)) 48 | self.whole_size = whole_size 49 | self.eps = eps 50 | self.momentum = momentum 51 | self.weight = nn.Parameter(torch.ones(num_features)) 52 | self.bias = nn.Parameter(torch.zeros(num_features)) 53 | 54 | def forward(self, x): 55 | return SyncBatchNormFunc.apply(x, self.weight, self.bias, self.whole_size, self.running_mean, self.running_var, 56 | self.training, self.momentum, self.eps) 57 | -------------------------------------------------------------------------------- /Codes/module/others.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def act_layer(act_type, inplace=False, neg_slope=0.2, n_prelu=1): 7 | act = act_type.lower() 8 | 9 | if act == 'relu': 10 | layer = nn.ReLU(inplace) 11 | elif act == 'leakyrelu': 12 | layer = nn.LeakyReLU(neg_slope, inplace) 13 | elif act == 'prelu': 14 | layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope) 15 | else: 16 | raise NotImplementedError('activation layer [%s] is not found' % act) 17 | 18 | return layer 19 | 20 | 21 | def norm_layer(norm_type, nc): 22 | norm = norm_type.lower() 23 | 24 
| if norm == 'batch': 25 | layer = nn.BatchNorm1d(nc, affine=True) 26 | elif norm == 'layer': 27 | layer = nn.LayerNorm(nc, elementwise_affine=True) 28 | elif norm == 'instance': 29 | layer = nn.InstanceNorm1d(nc, affine=False) 30 | else: 31 | raise NotImplementedError(f'Normalization layer {norm} is not supported.') 32 | 33 | return layer 34 | 35 | 36 | class MLP(nn.Sequential): 37 | r""" 38 | Description 39 | ----------- 40 | From equation (5) in `DeeperGCN: All You Need to Train Deeper GCNs `_ 41 | """ 42 | 43 | def __init__(self, 44 | channels, 45 | act='relu', 46 | norm=None, 47 | dropout=0., 48 | bias=True): 49 | layers = [] 50 | 51 | for i in range(1, len(channels)): 52 | layers.append(nn.Linear(channels[i - 1], channels[i], bias)) 53 | if i < len(channels) - 1: 54 | if norm is not None and norm.lower() != 'none': 55 | layers.append(norm_layer(norm, channels[i])) 56 | if act is not None and act.lower() != 'none': 57 | layers.append(act_layer(act)) 58 | layers.append(nn.Dropout(dropout)) 59 | 60 | super(MLP, self).__init__(*layers) 61 | 62 | 63 | class MessageNorm(nn.Module): 64 | r""" 65 | 66 | Description 67 | ----------- 68 | Message normalization was introduced in `DeeperGCN: All You Need to Train Deeper GCNs `_ 69 | Parameters 70 | ---------- 71 | learn_scale: bool 72 | Whether s is a learnable scaling factor or not. Default is False. 73 | """ 74 | 75 | def __init__(self, learn_scale=False): 76 | super(MessageNorm, self).__init__() 77 | self.scale = nn.Parameter(torch.FloatTensor([1.0]), requires_grad=learn_scale) 78 | 79 | def forward(self, feats, msg, p=2): 80 | msg = F.normalize(msg, p=2, dim=-1) 81 | feats_norm = feats.norm(p=p, dim=-1, keepdim=True) 82 | return msg * feats_norm * self.scale -------------------------------------------------------------------------------- /Codes/helper/MongoManager.py: -------------------------------------------------------------------------------- 1 | from sshtunnel import SSHTunnelForwarder 2 | from pymongo import MongoClient 3 | from pymongo.cursor import CursorType 4 | 5 | db_config = { 6 | 'mongo_host': 'yourhost', 7 | 'mongo_db': None, 8 | 'project': None, 9 | 'ssh_username': None, 10 | 'ssh_pwd': None, 11 | 'local_addr': '127.0.0.1', 12 | 'local_port': 27017 13 | } 14 | 15 | class DBHandler: 16 | def __init__(self): 17 | self.server = SSHTunnelForwarder( 18 | db_config['mongo_host'], 19 | ssh_username=db_config['ssh_username'], 20 | ssh_password=db_config['ssh_pwd'], 21 | remote_bind_address=(db_config['local_addr'], db_config['local_port']) 22 | ) 23 | self.server.start() 24 | self.client = MongoClient(db_config['local_addr'], self.server.local_bind_port) 25 | self.db = self.client[db_config['mongo_db']] 26 | self.collection = self.db[db_config['project']] 27 | 28 | def insert_item_one(self, data, db_name=db_config['mongo_db'], collection_name=db_config['project']): 29 | result = self.collection.insert_one(data).inserted_id 30 | return result 31 | 32 | def insert_item_many(self, datas, db_name=db_config['mongo_db'], collection_name=db_config['project']): 33 | result = self.collection.insert_many(datas).inserted_ids 34 | return result 35 | 36 | def find_item_one(self, condition=None, db_name=db_config['mongo_db'], collection_name=db_config['project']): 37 | result = self.collection.find_one(condition, {"_id": False}) 38 | return result 39 | 40 | def find_item(self, condition=None, db_name=db_config['mongo_db'], collection_name=db_config['project']): 41 | result = self.collection.find(condition, {"_id": False}, no_cursor_timeout=True, 
cursor_type=CursorType.EXHAUST) 42 | return result 43 | 44 | def delete_item_one(self, condition=None, db_name=db_config['mongo_db'], collection_name=db_config['project']): 45 | result = self.collection.delete_one(condition) 46 | return result 47 | 48 | def delete_item_many(self, condition=None, db_name=db_config['mongo_db'], collection_name=db_config['project']): 49 | result = self.collection.delete_many(condition) 50 | return result 51 | 52 | def update_item_one(self, condition=None, update_value=None, db_name=db_config['mongo_db'], collection_name=db_config['project']): 53 | result = self.collection.update_one(filter=condition, update=update_value) 54 | return result 55 | 56 | def update_item_many(self, condition=None, update_value=None, db_name=db_config['mongo_db'], collection_name=db_config['project']): 57 | result = self.collection.update_many(filter=condition, update=update_value) 58 | return result 59 | 60 | def text_search(self, text=None, db_name=db_config['mongo_db'], collection_name=db_config['project']): 61 | result = self.collection.find({"$text": {"$search": text}}) 62 | return result 63 | 64 | def close_connection(self): 65 | self.server.stop() -------------------------------------------------------------------------------- /Codes/main.py: -------------------------------------------------------------------------------- 1 | from helper.parser import * 2 | import random 3 | import torch.multiprocessing as mp 4 | from helper.utils import * 5 | import train 6 | import warnings 7 | 8 | if __name__ == '__main__': 9 | 10 | args = create_parser() 11 | if args.fix_seed is False: 12 | if args.parts_per_node < args.n_partitions: 13 | warnings.warn('Please enable `--fix-seed` for multi-node training.') 14 | args.seed = random.randint(0, 1 << 31) 15 | 16 | if args.graph_name == '': 17 | if args.inductive: 18 | args.graph_name = '%s-%d-%s-%s-induc' % (args.dataset, args.n_partitions, 19 | args.partition_method, args.partition_obj) 20 | else: 21 | args.graph_name = '%s-%d-%s-%s-trans' % (args.dataset, args.n_partitions, 22 | args.partition_method, args.partition_obj) 23 | 24 | if args.skip_partition: 25 | if args.n_feat == 0 or args.n_class == 0 or args.n_train == 0: 26 | warnings.warn('Specifying `--n-feat`, `--n-class` and `--n-train` saves data loading time.') 27 | g, n_feat, n_class = load_data(args, args.dataset) 28 | args.n_feat = n_feat 29 | args.n_class = n_class 30 | args.n_train = g.ndata['train_mask'].int().sum().item() 31 | del g 32 | else: 33 | g, n_feat, n_class = load_data(args, args.dataset) 34 | if args.node_rank == 0: 35 | if args.inductive: 36 | graph_partition(g.subgraph(g.ndata['train_mask']), args) 37 | else: 38 | graph_partition(g, args) 39 | args.n_class = n_class 40 | args.n_feat = n_feat 41 | args.n_train = g.ndata['train_mask'].int().sum().item() 42 | del g 43 | 44 | print(args) 45 | 46 | if args.backend == 'gloo': 47 | processes = [] 48 | 49 | 50 | 51 | 52 | 53 | 54 | devices = ['0', '1', '2', '3'] 55 | 56 | mp.set_start_method('spawn', force=True) 57 | start_id = args.node_rank * args.parts_per_node 58 | 59 | for i in range(start_id, min(start_id + args.parts_per_node, args.n_partitions)): 60 | 61 | os.environ['CUDA_VISIBLE_DEVICES'] = devices[i % len(devices)] 62 | 63 | p = mp.Process(target=train.init_processes, args=(i, args.n_partitions, args)) 64 | p.start() 65 | processes.append(p) 66 | for p in processes: 67 | p.join() 68 | elif args.backend == 'nccl': 69 | processes = [] 70 | 71 | mp.set_start_method('spawn', force=True) 72 | start_id = args.node_rank 
* args.parts_per_node 73 | os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3" 74 | 75 | for i in range(start_id, min(start_id + args.parts_per_node, args.n_partitions)): 76 | p = mp.Process(target=train.init_processes, args=(i, args.n_partitions, args)) 77 | p.start() 78 | processes.append(p) 79 | for p in processes: 80 | p.join() 81 | 82 | elif args.backend == 'mpi': 83 | raise NotImplementedError 84 | else: 85 | raise ValueError 86 | -------------------------------------------------------------------------------- /conda-requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | _openmp_mutex=5.1=1_gnu 6 | # bcrypt=4.1.3=pypi_0 7 | blas=1.0=mkl 8 | brotli-python=1.0.9=py310h6a678d5_8 9 | bzip2=1.0.8=h5eee18b_6 10 | ca-certificates=2024.3.11=h06a4308_0 11 | certifi=2024.6.2=py310h06a4308_0 12 | # cffi=1.16.0=pypi_0 13 | charset-normalizer=2.0.4=pyhd3eb1b0_0 14 | # cryptography=42.0.8=pypi_0 15 | cuda-cudart=11.8.89=0 16 | cuda-cupti=11.8.87=0 17 | cuda-libraries=11.8.0=0 18 | cuda-nvrtc=11.8.89=0 19 | cuda-nvtx=11.8.86=0 20 | cuda-runtime=11.8.0=0 21 | cuda-version=12.5=3 22 | dgl=2.3.0.th21.cu118=py310_0 23 | ffmpeg=4.3=hf484d3e_0 24 | filelock=3.13.1=py310h06a4308_0 25 | freetype=2.12.1=h4a9f257_0 26 | gmp=6.2.1=h295c915_3 27 | gmpy2=2.1.2=py310heeb90bb_0 28 | gnutls=3.6.15=he1e5248_0 29 | idna=3.7=py310h06a4308_0 30 | intel-openmp=2023.1.0=hdb19cb5_46306 31 | jinja2=3.1.4=py310h06a4308_0 32 | jpeg=9e=h5eee18b_1 33 | lame=3.100=h7b6447c_0 34 | lcms2=2.12=h3be6417_0 35 | ld_impl_linux-64=2.38=h1181459_1 36 | lerc=3.0=h295c915_0 37 | libcublas=11.11.3.6=0 38 | libcufft=10.9.0.58=0 39 | libcufile=1.10.0.4=0 40 | libcurand=10.3.6.39=0 41 | libcusolver=11.4.1.48=0 42 | libcusparse=11.7.5.86=0 43 | libdeflate=1.17=h5eee18b_1 44 | libffi=3.4.4=h6a678d5_1 45 | libgcc-ng=11.2.0=h1234567_1 46 | libgfortran-ng=11.2.0=h00389a5_1 47 | libgfortran5=11.2.0=h1234567_1 48 | libgomp=11.2.0=h1234567_1 49 | libiconv=1.16=h5eee18b_3 50 | libidn2=2.3.4=h5eee18b_0 51 | libjpeg-turbo=2.0.0=h9bf148f_0 52 | libnpp=11.8.0.86=0 53 | libnvjpeg=11.9.0.86=0 54 | libpng=1.6.39=h5eee18b_0 55 | libstdcxx-ng=11.2.0=h1234567_1 56 | libtasn1=4.19.0=h5eee18b_0 57 | libtiff=4.5.1=h6a678d5_0 58 | libunistring=0.9.10=h27cfd23_0 59 | libuuid=1.41.5=h5eee18b_0 60 | libwebp-base=1.3.2=h5eee18b_0 61 | llvm-openmp=14.0.6=h9e868ea_0 62 | lz4-c=1.9.4=h6a678d5_1 63 | markupsafe=2.1.3=py310h5eee18b_0 64 | mkl=2023.1.0=h213fc3f_46344 65 | mkl-service=2.4.0=py310h5eee18b_1 66 | mkl_fft=1.3.8=py310h5eee18b_0 67 | mkl_random=1.2.4=py310hdb19cb5_0 68 | mpc=1.1.0=h10f8cd9_1 69 | mpfr=4.0.2=hb69a4c5_1 70 | mpmath=1.3.0=py310h06a4308_0 71 | ncurses=6.4=h6a678d5_0 72 | nettle=3.7.3=hbbd107a_1 73 | networkx=3.2.1=py310h06a4308_0 74 | numpy=1.26.4=py310h5f9d8c6_0 75 | numpy-base=1.26.4=py310hb5e798b_0 76 | openh264=2.1.1=h4ff587b_0 77 | openjpeg=2.4.0=h3ad879b_0 78 | openssl=3.0.14=h5eee18b_0 79 | # paramiko=3.4.0=pypi_0 80 | pillow=10.3.0=py310h5eee18b_0 81 | pip=24.0=py310h06a4308_0 82 | psutil=5.9.0=py310h5eee18b_0 83 | pybind11-abi=4=hd3eb1b0_1 84 | # pycparser=2.22=pypi_0 85 | # pynacl=1.5.0=pypi_0 86 | pysocks=1.7.1=py310h06a4308_0 87 | python=3.10.14=h955ad1f_1 88 | pytorch=2.1.0=py3.10_cuda11.8_cudnn8.7.0_0 89 | pytorch-cuda=11.8=h7e8668a_5 90 | pytorch-mutex=1.0=cuda 91 | pyyaml=6.0.1=py310h5eee18b_0 92 | readline=8.2=h5eee18b_0 93 | 
requests=2.32.2=py310h06a4308_0 94 | scipy=1.13.1=py310h5f9d8c6_0 95 | setuptools=69.5.1=py310h06a4308_0 96 | sqlite=3.45.3=h5eee18b_0 97 | sympy=1.12=py310h06a4308_0 98 | tbb=2021.8.0=hdb19cb5_0 99 | tk=8.6.14=h39e8969_0 100 | torchaudio=2.1.0=py310_cu118 101 | torchtriton=2.1.0=py310 102 | torchvision=0.16.0=py310_cu118 103 | tqdm=4.66.4=py310h2f386ee_0 104 | typing_extensions=4.11.0=py310h06a4308_0 105 | tzdata=2024a=h04d1e81_0 106 | urllib3=2.2.2=py310h06a4308_0 107 | wheel=0.43.0=py310h06a4308_0 108 | xz=5.4.6=h5eee18b_1 109 | yaml=0.2.5=h7b6447c_0 110 | zlib=1.2.13=h5eee18b_1 111 | zstd=1.5.5=hc292b87_2 112 | -------------------------------------------------------------------------------- /Codes/our_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import random 4 | import warnings 5 | from typing import TYPE_CHECKING 6 | 7 | import torch.multiprocessing as mp 8 | 9 | from helper.parser import * 10 | from helper.utils import * 11 | import our_train 12 | 13 | 14 | if __name__ == '__main__': 15 | 16 | args = create_parser() 17 | 18 | assert args.bandwidth_aware, 'This main file use bandwidth aware method!!' 19 | if args.use_flexible: 20 | assert args.flexible_hop <= (args.n_layers - 1), 'flexible preloading allows max n_conv layers' 21 | if args.fix_seed is False: 22 | warnings.warn('Please enable `--fix-seed` for multi-node training.') 23 | args.seed = random.randint(0, 1 << 31) 24 | 25 | if args.graph_name == '': 26 | if args.inductive: 27 | args.graph_name = '%s-%s-induc' % (args.partition_method, args.partition_obj) 28 | else: 29 | args.graph_name = '%s-%s-trans' % (args.partition_method, args.partition_obj) 30 | if args.use_flexible: 31 | if args.inductive: 32 | args.flexible_graph_name = '%s-%s-flex-induc' % (args.partition_method, args.partition_obj) 33 | else: 34 | args.flexible_graph_name = '%s-%s-flex-trans' % (args.partition_method, args.partition_obj) 35 | 36 | if args.n_feat == 0 or args.n_class == 0 or args.n_train == 0: 37 | warnings.warn('Specifying `--n-feat`, `--n-class` and `--n-train` saves data loading time.') 38 | 39 | g, n_feat, n_class = load_data(args, args.dataset) 40 | if args.node_rank == 0 and args.use_flexible: 41 | if args.inductive: 42 | graph_partition(g.node_subgraph(g.ndata['train_mask']), args) 43 | else: 44 | graph_partition(g, args) 45 | args.n_feat = n_feat 46 | args.n_class = n_class 47 | args.n_train = g.ndata['train_mask'].int().sum().item() 48 | del g 49 | 50 | print(args) 51 | 52 | 53 | 54 | 55 | if args.node_rank == 0 and args.remove_tmp: 56 | if os.path.exists('./intra-partitions'): 57 | shutil.rmtree('./intra-partitions') 58 | 59 | if args.backend == 'gloo': 60 | processes = [] 61 | args.local_device_cnt = torch.cuda.device_count() 62 | if 'CUDA_VISIBLE_DEVICES' in os.environ: 63 | devices = os.environ['CUDA_VISIBLE_DEVICES'].split(',') 64 | else: 65 | n = args.local_device_cnt 66 | devices = [f'{i}' for i in range(n)] 67 | 68 | mp.set_start_method('spawn', force=True) 69 | start_id = args.node_rank * args.local_device_cnt 70 | for i in range(start_id, start_id + args.local_device_cnt): 71 | os.environ['CUDA_VISIBLE_DEVICES'] = devices[i % len(devices)] 72 | p = mp.Process(target=our_train.init_processes, args=(i, args)) 73 | p.start() 74 | processes.append(p) 75 | for p in processes: 76 | p.join() 77 | elif args.backend == 'nccl': 78 | processes = [] 79 | args.local_device_cnt = torch.cuda.device_count() 80 | if 'CUDA_VISIBLE_DEVICES' in os.environ: 81 | devices = 
os.environ['CUDA_VISIBLE_DEVICES'].split(',') 82 | else: 83 | n = args.local_device_cnt 84 | devices = [f'{i}' for i in range(n)] 85 | 86 | mp.set_start_method('spawn', force=True) 87 | start_id = args.node_rank * args.local_device_cnt 88 | for i in range(start_id, start_id + args.local_device_cnt): 89 | 90 | p = mp.Process(target=our_train.init_processes, args=(i, args)) 91 | p.start() 92 | processes.append(p) 93 | for p in processes: 94 | p.join() 95 | elif args.backend == 'mpi': 96 | raise NotImplementedError 97 | else: 98 | raise ValueError 99 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # dataset folder 2 | dataset 3 | datasets 4 | partitions 5 | intra-partitions 6 | results 7 | metis* 8 | checkpoint 9 | masks 10 | *_logs 11 | model 12 | Logs 13 | 14 | # checkpoint 15 | *.pt 16 | model/ 17 | DepCache/checkpoints/ 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | share/python-wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | MANIFEST 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .nox/ 61 | .coverage 62 | .coverage.* 63 | .cache 64 | nosetests.xml 65 | coverage.xml 66 | *.cover 67 | *.py,cover 68 | .hypothesis/ 69 | .pytest_cache/ 70 | cover/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | .pybuilder/ 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # IPython 100 | profile_default/ 101 | ipython_config.py 102 | 103 | # pyenv 104 | # For a library or package, you might want to ignore these files since the code is 105 | # intended to run in multiple environments; otherwise, check them in: 106 | # .python-version 107 | 108 | # pipenv 109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 112 | # install all needed dependencies. 113 | #Pipfile.lock 114 | 115 | # poetry 116 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 117 | # This is especially recommended for binary packages to ensure reproducibility, and is more 118 | # commonly ignored for libraries. 119 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 120 | #poetry.lock 121 | 122 | # pdm 123 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
124 | #pdm.lock 125 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 126 | # in version control. 127 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 128 | .pdm.toml 129 | .pdm-python 130 | .pdm-build/ 131 | 132 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 133 | __pypackages__/ 134 | 135 | # Celery stuff 136 | celerybeat-schedule 137 | celerybeat.pid 138 | 139 | # SageMath parsed files 140 | *.sage.py 141 | 142 | # Environments 143 | .env 144 | .venv 145 | env/ 146 | venv/ 147 | ENV/ 148 | env.bak/ 149 | venv.bak/ 150 | 151 | # Spyder project settings 152 | .spyderproject 153 | .spyproject 154 | 155 | # Rope project settings 156 | .ropeproject 157 | 158 | # mkdocs documentation 159 | /site 160 | 161 | # mypy 162 | .mypy_cache/ 163 | .dmypy.json 164 | dmypy.json 165 | 166 | # Pyre type checker 167 | .pyre/ 168 | 169 | # pytype static type analyzer 170 | .pytype/ 171 | 172 | # Cython debug symbols 173 | cython_debug/ 174 | 175 | # PyCharm 176 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 177 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 178 | # and can be added to the global gitignore or merged into this file. For a more nuclear 179 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 180 | #.idea/ 181 | -------------------------------------------------------------------------------- /Codes/module/baseline_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from module.layer import * 4 | from module.sync_bn import SyncBatchNorm 5 | from helper import context as ctx 6 | 7 | 8 | class GNNBase(nn.Module): 9 | 10 | def __init__(self, layer_size, activation, use_pp=False, dropout=0.5, norm='layer', n_linear=0): 11 | super(GNNBase, self).__init__() 12 | self.n_layers = len(layer_size) - 1 13 | self.layers = nn.ModuleList() 14 | self.activation = activation 15 | self.use_pp = use_pp 16 | self.n_linear = n_linear 17 | 18 | if norm is None: 19 | self.use_norm = False 20 | else: 21 | self.use_norm = True 22 | self.norm = nn.ModuleList() 23 | self.dropout = nn.Dropout(p=dropout) 24 | 25 | class DeeperGCN(GNNBase): 26 | def __init__(self, layer_size, activation, use_pp, n_feat, n_class, \ 27 | dropout=0.5, norm='layer', train_size=None, n_linear=0): 28 | super(DeeperGCN, self).__init__(layer_size, activation, use_pp, dropout, norm, n_linear) 29 | 30 | 31 | self.node_encoder = nn.Linear(n_feat, layer_size[0]) 32 | 33 | 34 | 35 | 36 | use_pp = False 37 | 38 | for i in range(self.n_layers): 39 | if i < self.n_layers - self.n_linear: 40 | 41 | conv = GraphSAGELayer(layer_size[i], layer_size[i+1], use_pp=use_pp) 42 | if norm == 'layer': 43 | norm_fn = nn.LayerNorm(layer_size[i+1], elementwise_affine=True) 44 | elif norm == 'batch': 45 | norm_fn = SyncBatchNorm(layer_size[i+1], train_size) 46 | act = activation 47 | layer = DeeperGCNLayer(conv=conv, norm=norm_fn, act=act, block='res+', 48 | 49 | dropout=dropout, ckpt_grad=False) 50 | self.layers.append(layer) 51 | else: 52 | self.layers.append(nn.Linear(layer_size[i], layer_size[i+1])) 53 | 54 | def forward(self, g, feat, in_deg=None, masks=None): 55 | 56 | h = self.node_encoder(feat) 57 | 58 | 59 | 60 | 61 | for i in range(self.n_layers): 62 | if i < self.n_layers - self.n_linear: 63 | if self.training and (i > 0 or not self.use_pp): 64 | h = 
ctx.buffer.update(i, h) 65 | if masks is not None: 66 | h[masks[i]] = 0.0 67 | h = self.layers[i](g, h, in_deg) 68 | else: 69 | h = self.layers[0].act(self.layers[0].norm(h)) 70 | h = self.dropout(h) 71 | h = self.layers[i](h) 72 | return h 73 | 74 | 75 | class GraphSAGE(GNNBase): 76 | def __init__(self, layer_size, activation, use_pp, dropout=0.5, norm='layer', train_size=None, n_linear=0): 77 | super(GraphSAGE, self).__init__(layer_size, activation, use_pp, dropout, norm, n_linear) 78 | for i in range(self.n_layers): 79 | if i < self.n_layers - self.n_linear: 80 | self.layers.append(GraphSAGELayer(layer_size[i], layer_size[i + 1], use_pp=use_pp)) 81 | else: 82 | self.layers.append(nn.Linear(layer_size[i], layer_size[i + 1])) 83 | if i < self.n_layers - 1 and self.use_norm: 84 | if norm == 'layer': 85 | self.norm.append(nn.LayerNorm(layer_size[i + 1], elementwise_affine=True)) 86 | elif norm == 'batch': 87 | self.norm.append(SyncBatchNorm(layer_size[i + 1], train_size)) 88 | use_pp = False 89 | 90 | def forward(self, g, feat, in_deg=None, masks=None): 91 | h = feat 92 | 93 | for i in range(self.n_layers): 94 | if i < self.n_layers - self.n_linear: 95 | if self.training and (i > 0 or not self.use_pp): 96 | h = ctx.buffer.update(i, h) 97 | if masks is not None: 98 | h[masks[i]] = 0.0 99 | h = self.dropout(h) 100 | h = self.layers[i](g, h, in_deg) 101 | else: 102 | h = self.dropout(h) 103 | h = self.layers[i](h) 104 | 105 | if i < self.n_layers - 1: 106 | if self.use_norm: 107 | h = self.norm[i](h) 108 | h = self.activation(h) 109 | return h 110 | -------------------------------------------------------------------------------- /Codes/module/model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from module.layer import * 4 | from module.sync_bn import SyncBatchNorm 5 | from helper import intra_context as ctx 6 | 7 | 8 | class GNNBase(nn.Module): 9 | 10 | def __init__(self, layer_size, activation, use_pp=False, dropout=0.5, norm='layer', n_linear=0): 11 | super(GNNBase, self).__init__() 12 | self.n_layers = len(layer_size) - 1 13 | self.layers = nn.ModuleList() 14 | self.activation = activation 15 | self.use_pp = use_pp 16 | self.n_linear = n_linear 17 | 18 | if norm is None: 19 | self.use_norm = False 20 | else: 21 | self.use_norm = True 22 | self.norm = nn.ModuleList() 23 | self.dropout = nn.Dropout(p=dropout) 24 | 25 | class DeeperGCN(GNNBase): 26 | def __init__(self, layer_size, activation, use_pp, n_feat, n_class, \ 27 | dropout=0.5, norm='layer', train_size=None, n_linear=0): 28 | super(DeeperGCN, self).__init__(layer_size, activation, use_pp, dropout, norm, n_linear) 29 | 30 | 31 | self.node_encoder = nn.Linear(n_feat, layer_size[0]) 32 | 33 | 34 | 35 | 36 | use_pp = False 37 | 38 | for i in range(self.n_layers): 39 | if i < self.n_layers - self.n_linear: 40 | 41 | conv = GraphSAGELayer(layer_size[i], layer_size[i+1], use_pp=use_pp) 42 | if norm == 'layer': 43 | norm_fn = nn.LayerNorm(layer_size[i+1], elementwise_affine=True) 44 | elif norm == 'batch': 45 | norm_fn = SyncBatchNorm(layer_size[i+1], train_size) 46 | act = activation 47 | layer = DeeperGCNLayer(conv=conv, norm=norm_fn, act=act, block='res+', 48 | 49 | dropout=dropout, ckpt_grad=False) 50 | self.layers.append(layer) 51 | else: 52 | self.layers.append(nn.Linear(layer_size[i], layer_size[i+1])) 53 | 54 | def forward(self, g, feat, in_deg=None, masks=None): 55 | 56 | h = self.node_encoder(feat) 57 | 58 | 59 | 60 | 61 | for i in range(self.n_layers): 62 | if 
i < self.n_layers - self.n_linear: 63 | if self.training and (i > 0 or not self.use_pp): 64 | h = ctx.buffer.update(i, h) 65 | if masks is not None: 66 | h[masks[i]] = 0.0 67 | h = self.layers[i](g, h, in_deg) 68 | else: 69 | h = self.layers[0].act(self.layers[0].norm(h)) 70 | h = self.dropout(h) 71 | h = self.layers[i](h) 72 | return h 73 | 74 | 75 | class GraphSAGE(GNNBase): 76 | def __init__(self, layer_size, activation, use_pp, dropout=0.5, norm='layer', train_size=None, n_linear=0): 77 | super(GraphSAGE, self).__init__(layer_size, activation, use_pp, dropout, norm, n_linear) 78 | for i in range(self.n_layers): 79 | if i < self.n_layers - self.n_linear: 80 | self.layers.append(GraphSAGELayer(layer_size[i], layer_size[i + 1], use_pp=use_pp)) 81 | else: 82 | self.layers.append(nn.Linear(layer_size[i], layer_size[i + 1])) 83 | if i < self.n_layers - 1 and self.use_norm: 84 | if norm == 'layer': 85 | self.norm.append(nn.LayerNorm(layer_size[i + 1], elementwise_affine=True)) 86 | elif norm == 'batch': 87 | self.norm.append(SyncBatchNorm(layer_size[i + 1], train_size)) 88 | use_pp = False 89 | 90 | def forward(self, g, feat, in_deg=None, masks=None): 91 | h = feat 92 | 93 | for i in range(self.n_layers): 94 | if i < self.n_layers - self.n_linear: 95 | if self.training and (i > 0 or not self.use_pp): 96 | h = ctx.buffer.update(i, h) 97 | if masks is not None: 98 | h[masks[i]] = 0.0 99 | h = self.dropout(h) 100 | h = self.layers[i](g, h, in_deg) 101 | else: 102 | h = self.dropout(h) 103 | h = self.layers[i](h) 104 | 105 | if i < self.n_layers - 1: 106 | if self.use_norm: 107 | h = self.norm[i](h) 108 | h = self.activation(h) 109 | return h 110 | -------------------------------------------------------------------------------- /Codes/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | from module.model import * 4 | from helper.utils import * 5 | from sklearn.metrics import f1_score 6 | import argparse 7 | 8 | """ 9 | Tester for Hyper-scale Datasets (e.g., Papers-100M) 10 | 11 | e.g. 
12 | python test.py --dataset yelp --ckpt ./model/yelp_granndis_2023_05_14__17_34_57.pth.tar --n-hidden 512 --n-layers 15 --n-linear 1 13 | """ 14 | 15 | def get_layer_size(args, n_feat, n_hidden, n_class, n_layers): 16 | 17 | if args.model in ['deepgcn']: 18 | layer_size = [n_hidden] 19 | else: 20 | layer_size = [n_feat] 21 | 22 | 23 | layer_size.extend([n_hidden] * (n_layers - 1)) 24 | 25 | layer_size.append(n_class) 26 | return layer_size 27 | 28 | def create_model(layer_size, args): 29 | if args.model == 'graphsage': 30 | return GraphSAGE(layer_size, F.relu, args.use_pp, norm=args.norm, dropout=args.dropout, 31 | n_linear=args.n_linear, train_size=args.n_train) 32 | elif args.model == 'deepgcn': 33 | return DeeperGCN(layer_size, nn.ReLU(inplace=True), args.use_pp, args.n_feat, args.n_class, norm=args.norm, dropout=args.dropout, 34 | n_linear=args.n_linear, train_size=args.n_train) 35 | else: 36 | raise NotImplementedError 37 | 38 | def inductive_split(g): 39 | g_train = g.subgraph(g.ndata['train_mask']) 40 | g_val = g.subgraph(g.ndata['train_mask'] | g.ndata['val_mask']) 41 | g_test = g 42 | return g_train, g_val, g_test 43 | 44 | 45 | def calc_acc(logits, labels): 46 | if labels.dim() == 1: 47 | _, indices = torch.max(logits, dim=1) 48 | correct = torch.sum(indices == labels) 49 | return correct.item() / labels.shape[0] 50 | else: 51 | return f1_score(labels, logits > 0, average='micro') 52 | 53 | @torch.no_grad() 54 | def evaluate_induc(name, model, g, mode, result_file_name=None): 55 | """ 56 | mode: 'val' or 'test' 57 | """ 58 | model.eval() 59 | feat, labels = g.ndata['feat'], g.ndata['label'] 60 | mask = g.ndata[mode + '_mask'] 61 | logits = model(g, feat) 62 | logits = logits[mask] 63 | labels = labels[mask] 64 | acc = calc_acc(logits, labels) 65 | del logits 66 | del labels 67 | buf = "{:s} | Accuracy {:.2%}".format(name, acc) 68 | if result_file_name is not None: 69 | with open(result_file_name, 'a+') as f: 70 | f.write(buf + '\n') 71 | print(buf) 72 | else: 73 | print(buf) 74 | 75 | return model, acc 76 | 77 | 78 | if __name__ == '__main__': 79 | parser = argparse.ArgumentParser(description='Test Code for Hyper-scale Datasets') 80 | 81 | parser.add_argument('--dataset', type=str, required=True, help='dataset') 82 | parser.add_argument('--inductive', action='store_true', help='inductive learning setting') 83 | parser.add_argument('--ckpt', type=str, required=True, help='ckpt path') 84 | 85 | parser.add_argument('--model', type=str, default='deepgcn', help='model type') 86 | parser.add_argument("--dropout", type=float, default=0.5, 87 | help="dropout probability") 88 | parser.add_argument("--n-hidden", "--n_hidden", type=int, default=16, 89 | help="the number of hidden units") 90 | parser.add_argument("--n-layers", "--n_layers", type=int, default=2, 91 | help="the number of GCN layers") 92 | parser.add_argument("--n-linear", "--n_linear", type=int, default=0, 93 | help="the number of linear layers") 94 | parser.add_argument("--norm", choices=['layer', 'batch'], default='layer', 95 | help="normalization method") 96 | parser.add_argument("--use-pp", "--use_pp", action='store_true', 97 | help="whether to use precomputation") 98 | 99 | parser.add_argument('--dataset-path', '--dataset_path', default='/datasets/atc23/', type=str, \ 100 | help='dataset path') 101 | 102 | args = parser.parse_args() 103 | 104 | g, n_feat, n_class = load_data(args, args.dataset) 105 | args.n_feat = n_feat 106 | args.n_class = n_class 107 | args.n_train = g.ndata['train_mask'].int().sum().item() 108 | 109 
| if args.inductive: 110 | train_g, val_g, test_g = inductive_split(g) 111 | del train_g 112 | del val_g 113 | del g 114 | else: 115 | test_g = g 116 | 117 | 118 | layer_size = get_layer_size(args, args.n_feat, args.n_hidden, args.n_class, args.n_layers) 119 | model = create_model(layer_size, args) 120 | 121 | loaded = torch.load(args.ckpt) 122 | 123 | 124 | for key in list(loaded.keys()): 125 | key_split = key.split('.') 126 | loaded['.'.join(key_split[1:])] = loaded.pop(key) 127 | 128 | model.load_state_dict(loaded) 129 | 130 | _, acc = evaluate_induc('Test Result', model, test_g, 'test') 131 | 132 | print('Testing Finished.... ACC. ' + str(acc*100) + ' (%) !!!') 133 | 134 | 135 | -------------------------------------------------------------------------------- /Codes/module/layer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from torch.utils.checkpoint import checkpoint 11 | 12 | import dgl.function as fn 13 | 14 | from ogb.graphproppred.mol_encoder import BondEncoder 15 | from dgl.nn.functional import edge_softmax 16 | 17 | from module.others import MLP, MessageNorm 18 | 19 | 20 | 21 | class DeeperGCNLayer(nn.Module): 22 | def __init__(self, 23 | conv: Optional[nn.Module] = None, 24 | norm: Optional[nn.Module] = None, 25 | act: Optional[nn.Module] = None, 26 | block: str = 'res+', 27 | dropout: float = 0., 28 | ckpt_grad: bool = False) -> None: 29 | super(DeeperGCNLayer, self).__init__() 30 | self.conv = conv 31 | self.norm = norm 32 | self.act = act 33 | self.block = block.lower() 34 | 35 | assert self.block in ['res+', 'res', 'dense', 'plain'] 36 | self.dropout = dropout 37 | self.ckpt_grad = ckpt_grad 38 | self.num_inner = None 39 | 40 | def reset_parameters(self): 41 | self.conv.reset_paramters() 42 | self.norm.reset_parameters() 43 | 44 | def forward(self, graph, feat, in_deg) -> torch.Tensor: 45 | 46 | with graph.local_scope(): 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | res_feat = feat 59 | 60 | 61 | 62 | if self.block == 'res+': 63 | h = feat 64 | if self.norm is not None: 65 | h = self.norm(h) 66 | if self.act is not None: 67 | h = self.act(h) 68 | h = F.dropout(h, p=self.dropout, training=self.training) 69 | if self.conv is not None and self.ckpt_grad and h.requires_grad: 70 | h = checkpoint(self.conv, graph, h, in_deg) 71 | else: 72 | h = self.conv(graph, h, in_deg) 73 | 74 | return res_feat[:h.shape[0]] + h 75 | else: 76 | 77 | 78 | if self.conv is not None and self.ckpt_grad and feat.requires_grad: 79 | h = checkpoint(self.conv, graph, feat, in_deg) 80 | else: 81 | h = self.conv(graph, feat, in_deg) 82 | if self.norm is not None: 83 | h = self.norm(h) 84 | if self.act is not None: 85 | h = self.act(h) 86 | 87 | 88 | if self.block == 'res': 89 | h = res_feat[:h.shape[0]] + h 90 | elif self.block == 'dense': 91 | h = torch.cat([res_feat[:h.shape[0]], h], dim=-1) 92 | elif self.block == 'plain': 93 | pass 94 | return F.dropout(h, p=self.dropout, training=self.training) 95 | def __repr__(self) -> str: 96 | return f'{self.__class__.__name__}(block={self.block})' 97 | 98 | 99 | class GraphSAGELayer(nn.Module): 100 | 101 | def __init__(self, 102 | in_feats, 103 | out_feats, 104 | bias=True, 105 | use_pp=False): 106 | super(GraphSAGELayer, self).__init__() 107 | self.use_pp = use_pp 108 | if self.use_pp: 109 | self.linear = nn.Linear(2 * in_feats, out_feats, 
bias=bias) 110 | else: 111 | self.linear1 = nn.Linear(in_feats, out_feats, bias=bias) 112 | self.linear2 = nn.Linear(in_feats, out_feats, bias=bias) 113 | self.reset_parameters() 114 | 115 | def reset_parameters(self): 116 | if self.use_pp: 117 | stdv = 1. / math.sqrt(self.linear.weight.size(1)) 118 | self.linear.weight.data.uniform_(-stdv, stdv) 119 | if self.linear.bias is not None: 120 | self.linear.bias.data.uniform_(-stdv, stdv) 121 | else: 122 | stdv = 1. / math.sqrt(self.linear1.weight.size(1)) 123 | self.linear1.weight.data.uniform_(-stdv, stdv) 124 | self.linear2.weight.data.uniform_(-stdv, stdv) 125 | if self.linear1.bias is not None: 126 | self.linear1.bias.data.uniform_(-stdv, stdv) 127 | self.linear2.bias.data.uniform_(-stdv, stdv) 128 | 129 | def forward(self, graph, feat, in_deg): 130 | with graph.local_scope(): 131 | if self.training: 132 | if self.use_pp: 133 | feat = self.linear(feat) 134 | else: 135 | degs = in_deg.unsqueeze(1) 136 | num_dst = graph.num_nodes('_V') 137 | graph.nodes['_U'].data['h'] = feat 138 | graph['_E'].update_all(fn.copy_u(u='h', out='m'), 139 | fn.sum(msg='m', out='h'), 140 | etype='_E') 141 | ah = graph.nodes['_V'].data['h'] / degs 142 | feat = self.linear1(feat[0:num_dst]) + self.linear2(ah) 143 | else: 144 | assert in_deg is None 145 | degs = graph.in_degrees().unsqueeze(1) 146 | graph.ndata['h'] = feat 147 | graph.update_all(fn.copy_u(u='h', out='m'), 148 | fn.sum(msg='m', out='h')) 149 | ah = graph.ndata.pop('h') / degs 150 | if self.use_pp: 151 | feat = self.linear(torch.cat((feat, ah), dim=1)) 152 | else: 153 | feat = self.linear1(feat) + self.linear2(ah) 154 | return feat 155 | -------------------------------------------------------------------------------- /AE/ae1_parser.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from os import listdir 3 | from os.path import isfile, join 4 | 5 | import prettytable 6 | import json 7 | 8 | def find_dur(dict_list: List, dataset_name: str, dur_type: str) -> float: 9 | r""" find proper duration from dictionary list of various datasets 10 | """ 11 | for element in dict_list: 12 | if element['dataset'] == dataset_name: 13 | if dur_type in element: 14 | return element[dur_type] 15 | else: 16 | raise NotImplementedError 17 | 18 | 19 | if __name__ == '__main__': 20 | # prepare list 21 | opt_baseline_list = [] 22 | flx_list = [] 23 | cob_list = [] 24 | eas_list = [] 25 | 26 | # parse opt_baseline 27 | opt_baseline_basedir = '../Logs/granndis_opt_baseline' 28 | opt_baseline_files = [f for f in listdir(opt_baseline_basedir) \ 29 | if isfile(join(opt_baseline_basedir, f))] 30 | assert len(opt_baseline_files) == 3, 'we tested only three datasets.' 31 | for opt_baseline_file in opt_baseline_files: 32 | with open(join(opt_baseline_basedir, opt_baseline_file)) as json_file: 33 | opt_baseline_data = json.load(json_file) 34 | opt_baseline_list.append(opt_baseline_data) 35 | 36 | # parse flx (flexible preloading) 37 | flx_basedir = '../Logs/granndis_flx' 38 | flx_files = [f for f in listdir(flx_basedir) \ 39 | if isfile(join(flx_basedir, f))] 40 | assert len(flx_files) == 3, 'we tested only three datasets.' 
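    # These per-dataset JSON logs are expected to be produced beforehand by
    # run_ae.sh (via the ae1_tpt_*.py drivers), one file per tested dataset.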
41 |     for flx_file in flx_files: 42 | with open(join(flx_basedir, flx_file)) as json_file: 43 | flx_data = json.load(json_file) 44 | flx_list.append(flx_data) 45 | 46 | # parse cob (cooperative batching) 47 | cob_basedir = '../Logs/granndis_cob' 48 | cob_files = [f for f in listdir(cob_basedir) \ 49 | if isfile(join(cob_basedir, f))] 50 | assert len(cob_files) == 3, 'expected logs for exactly three datasets.' 51 | for cob_file in cob_files: 52 | with open(join(cob_basedir, cob_file)) as json_file: 53 | cob_data = json.load(json_file) 54 | cob_list.append(cob_data) 55 | 56 | # parse eas (expansion-aware sampling) 57 | eas_basedir = '../Logs/granndis_eas' 58 | eas_files = [f for f in listdir(eas_basedir) \ 59 | if isfile(join(eas_basedir, f))] 60 | assert len(eas_files) == 3, 'expected logs for exactly three datasets.' 61 | for eas_file in eas_files: 62 | with open(join(eas_basedir, eas_file)) as json_file: 63 | eas_data = json.load(json_file) 64 | eas_list.append(eas_data) 65 | 66 | ae1_arxiv = prettytable.PrettyTable() 67 | ae1_arxiv.title = 'Throughput Results for Arxiv' 68 | ae1_arxiv.field_names = ['Method', 'Total Time (ms)', 'Comm Time (ms)', 'Speedup'] 69 | base_tot = find_dur(opt_baseline_list, 'ogbn-arxiv', 'train_dur_aggregated')*1000 70 | ae1_arxiv.add_row(['Opt_FB', base_tot, 71 | find_dur(opt_baseline_list, 'ogbn-arxiv', 'comm_dur_aggregated')*1000, 1.0]) 72 | cur_tot = find_dur(flx_list, 'ogbn-arxiv', 'train_dur_aggregated')*1000 73 | ae1_arxiv.add_row(['FLX', cur_tot, 74 | find_dur(flx_list, 'ogbn-arxiv', 'comm_dur_aggregated')*1000, base_tot/cur_tot]) 75 | cur_tot = find_dur(cob_list, 'ogbn-arxiv', 'train_dur_aggregated')*1000 76 | ae1_arxiv.add_row(['CoB', cur_tot, 77 | find_dur(cob_list, 'ogbn-arxiv', 'comm_dur_aggregated')*1000, base_tot/cur_tot]) 78 | cur_tot = find_dur(eas_list, 'ogbn-arxiv', 'train_dur_aggregated')*1000 79 | ae1_arxiv.add_row(['EAS', cur_tot, 80 | find_dur(eas_list, 'ogbn-arxiv', 'comm_dur_aggregated')*1000, base_tot/cur_tot]) 81 | ae1_arxiv.float_format = ".2" 82 | print(ae1_arxiv) 83 | 84 | ae1_reddit = prettytable.PrettyTable() 85 | ae1_reddit.title = 'Throughput Results for Reddit' 86 | ae1_reddit.field_names = ['Method', 'Total Time (ms)', 'Comm Time (ms)', 'Speedup'] 87 | base_tot = find_dur(opt_baseline_list, 'reddit', 'train_dur_aggregated')*1000 88 | ae1_reddit.add_row(['Opt_FB', base_tot, 89 | find_dur(opt_baseline_list, 'reddit', 'comm_dur_aggregated')*1000, 1.0]) 90 | cur_tot = find_dur(flx_list, 'reddit', 'train_dur_aggregated')*1000 91 | ae1_reddit.add_row(['FLX', cur_tot, 92 | find_dur(flx_list, 'reddit', 'comm_dur_aggregated')*1000, base_tot/cur_tot]) 93 | cur_tot = find_dur(cob_list, 'reddit', 'train_dur_aggregated')*1000 94 | ae1_reddit.add_row(['CoB', cur_tot, 95 | find_dur(cob_list, 'reddit', 'comm_dur_aggregated')*1000, base_tot/cur_tot]) 96 | cur_tot = find_dur(eas_list, 'reddit', 'train_dur_aggregated')*1000 97 | ae1_reddit.add_row(['EAS', cur_tot, 98 | find_dur(eas_list, 'reddit', 'comm_dur_aggregated')*1000, base_tot/cur_tot]) 99 | ae1_reddit.float_format = ".2" 100 | print(ae1_reddit) 101 | 102 | ae1_products = prettytable.PrettyTable() 103 | ae1_products.title = 'Throughput Results for Products' 104 | ae1_products.field_names = ['Method', 'Total Time (ms)', 'Comm Time (ms)', 'Speedup'] 105 | base_tot = find_dur(opt_baseline_list, 'ogbn-products', 'train_dur_aggregated')*1000 106 | ae1_products.add_row(['Opt_FB', base_tot, 107 | find_dur(opt_baseline_list, 'ogbn-products', 'comm_dur_aggregated')*1000, 1.0]) 108 | cur_tot = find_dur(flx_list,
'ogbn-products', 'train_dur_aggregated')*1000 109 | ae1_products.add_row(['FLX', cur_tot, 110 | find_dur(flx_list, 'ogbn-products', 'comm_dur_aggregated')*1000, base_tot/cur_tot]) 111 | cur_tot = find_dur(cob_list, 'ogbn-products', 'train_dur_aggregated')*1000 112 | ae1_products.add_row(['CoB', cur_tot, 113 | find_dur(cob_list, 'ogbn-products', 'comm_dur_aggregated')*1000, base_tot/cur_tot]) 114 | cur_tot = find_dur(eas_list, 'ogbn-products', 'train_dur_aggregated')*1000 115 | ae1_products.add_row(['EAS', cur_tot, 116 | find_dur(eas_list, 'ogbn-products', 'comm_dur_aggregated')*1000, base_tot/cur_tot]) 117 | ae1_products.float_format = ".2" 118 | print(ae1_products) -------------------------------------------------------------------------------- /AE/ae2_acc_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import torch.multiprocessing as mp 6 | import select 7 | import paramiko 8 | 9 | import configs 10 | 11 | class Commands: 12 | def __init__(self, retry_time=0): 13 | self.retry_time = retry_time 14 | pass 15 | 16 | def run_cmd(self, host_ip, command): 17 | i = 0 18 | while True: 19 | try: 20 | ssh = paramiko.SSHClient() 21 | ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) 22 | ssh.connect(host_ip) 23 | break 24 | except paramiko.AuthenticationException: 25 | print("Authentication failed when connecting to %s" % host_ip) 26 | sys.exit(1) 27 | except: 28 | print("Could not SSH to %s, waiting for it to start" % host_ip) 29 | i += 1 30 | time.sleep(2) 31 | 32 | # If we could not connect within time limit 33 | if i >= self.retry_time: 34 | print("Could not connect to %s. Giving up" % host_ip) 35 | sys.exit(1) 36 | # After connection is successful 37 | # Send the command 38 | 39 | # print command 40 | print('> ' + command) 41 | # execute commands 42 | stdin, stdout, stderr = ssh.exec_command(command) 43 | 44 | print(stderr.read().decode("euc-kr")) 45 | stdin.close() 46 | 47 | # TODO() : if an error is thrown, stop further rules and revert back changes 48 | # Wait for the command to terminate 49 | while not stdout.channel.exit_status_ready(): 50 | # Only print data if there is data to read in the channel 51 | if stdout.channel.recv_ready(): 52 | rl, wl, xl = select.select([ stdout.channel ], [ ], [ ], 0.0) 53 | if len(rl) > 0: 54 | tmp = stdout.channel.recv(1024) 55 | output = tmp.decode() 56 | print(output) 57 | 58 | # Close SSH connection 59 | ssh.close() 60 | return 61 | 62 | if __name__ == '__main__': 63 | 64 | """ 65 | Set Server Environment 66 | """ 67 | env_loc = configs.global_configs['env_loc'] 68 | runner_loc = configs.global_configs['runner_loc'] 69 | workspace_loc = configs.global_configs['workspace_loc'] 70 | data_loc = configs.global_configs['data_loc'] 71 | num_runners = configs.global_configs['num_runners'] 72 | gpus_per_server = configs.global_configs['gpus_per_server'] 73 | hosts = configs.global_configs['hosts'] 74 | assert len(hosts) == num_runners, 'our script requires a host per a runner' 75 | 76 | """ 77 | SSH Connection Class 78 | """ 79 | runners = list() 80 | for i in range(num_runners): 81 | runners.append(Commands()) 82 | 83 | 84 | """ 85 | Set Common Dataset Information 86 | """ 87 | dataset_list = list() 88 | arxiv_dict = { 89 | 'dataset': 'ogbn-arxiv', 90 | 'dropout': 0.5, 91 | 'lr': 0.01 92 | } 93 | reddit_dict = { 94 | 'dataset': 'reddit', 95 | 'dropout': 0.5, 96 | 'lr': 0.01 97 | } 98 | product_dict = { 99 | 'dataset': 'ogbn-products', 100 | 'dropout': 
0.3, 101 | 'lr': 0.003 102 | } 103 | 104 | 105 | dataset_list = [arxiv_dict, reddit_dict, product_dict] 106 | sampler_list = ['sage'] 107 | 108 | """ 109 | Iteration 110 | """ 111 | 112 | for dataset_dict in dataset_list: 113 | for n_layer in [3]: 114 | for sampler in sampler_list: 115 | remove_tmp = True 116 | for num_server in [2]: 117 | for check_intra_only in [False]: 118 | for hidden_size in [64]: 119 | """ 120 | Make an Experiment 121 | """ 122 | if n_layer > 5: 123 | model_type = 'deepgcn' 124 | else: 125 | model_type = 'graphsage' 126 | shared_cmd = """{env_loc} {runner_loc} \ 127 | --dataset {dataset} \ 128 | --dropout {dropout} \ 129 | --sampler {sampler} \ 130 | --lr {lr} \ 131 | --parts-per-node {gpus_per_server} \ 132 | --n-partitions {num_parts} \ 133 | --n-epochs 1000 \ 134 | --model {model_type} \ 135 | --n-layers {n_layers} \ 136 | --n-linear 1 \ 137 | --n-hidden {hidden_size} \ 138 | --log-every 10 \ 139 | --master-addr {master_addr} \ 140 | --port 7524 \ 141 | --fix-seed \ 142 | --seed 7524 \ 143 | --backend nccl \ 144 | --dataset-path {data_loc} \ 145 | --exp-id 1 \ 146 | --create-json 1 \ 147 | --json-path {workspace_loc}/Logs/fb_acc \ 148 | --project fb_acc \ 149 | """.format( 150 | env_loc = env_loc, 151 | runner_loc = runner_loc, 152 | data_loc = data_loc, 153 | workspace_loc = workspace_loc, 154 | gpus_per_server = gpus_per_server, 155 | num_parts = gpus_per_server * num_server, 156 | dataset = dataset_dict['dataset'], 157 | sampler = sampler, 158 | model_type = model_type, 159 | dropout = dataset_dict['dropout'], 160 | lr = dataset_dict['lr'], 161 | n_layers = (n_layer + 1), # for deepgcn... we need to plus 1 162 | hidden_size = hidden_size, 163 | master_addr = hosts[0], 164 | ) 165 | 166 | if remove_tmp: 167 | shared_cmd = shared_cmd + """--remove-tmp """ 168 | remove_tmp = False 169 | 170 | # if dataset_dict['dataset'] == 'ogbn-papers100m': 171 | # shared_cmd = shared_cmd + """--partition-method random """ 172 | 173 | if check_intra_only: 174 | shared_cmd = shared_cmd + """--check-intra-only """ 175 | 176 | def __is_list_in_target(list, target): 177 | is_in = False 178 | for element in list: 179 | if element in target: 180 | is_in = True 181 | break 182 | return is_in 183 | 184 | if not __is_list_in_target(list=['papers'], target=dataset_dict['dataset']): 185 | shared_cmd = shared_cmd + """--inductive """ 186 | 187 | """ 188 | Pre-Cleaning 189 | """ 190 | kill_cmds = list() 191 | for i in range(num_server): 192 | kill_cmds.append('pkill -ef spawn && rm -rf ~/partitions') 193 | 194 | processes = [] 195 | mp.set_start_method('spawn', force=True) 196 | for host, runner, cmd in zip(hosts, runners, kill_cmds): 197 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 198 | p.start() 199 | processes.append(p) 200 | 201 | for p in processes: 202 | p.join() 203 | 204 | """ 205 | Run an Experiment :) 206 | """ 207 | 208 | cmds = list() 209 | for i in range(num_server): 210 | runner_cmd = shared_cmd + """--total-nodes %d """ % num_server 211 | runner_cmd += """--node-rank %d """ % i 212 | runner_cmd += "\n" 213 | cmds.append(runner_cmd) 214 | 215 | processes = [] 216 | mp.set_start_method('spawn', force=True) 217 | # Note that python zip only iterates for shorter list!!! 
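# (concretely: zip(hosts, runners, cmds) yields min(len(hosts), len(runners), len(cmds))
#  tuples, so surplus entries in the longer lists are silently ignored)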
218 | # so do not worry for running on not dedicated runners 219 | for host, runner, cmd in zip(hosts, runners, cmds): 220 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 221 | p.start() 222 | processes.append(p) 223 | 224 | for p in processes: 225 | p.join() 226 | 227 | """ 228 | Post-Cleaning 229 | """ 230 | kill_cmds = list() 231 | for i in range(num_server): 232 | kill_cmds.append('pkill -ef spawn') 233 | 234 | processes = [] 235 | mp.set_start_method('spawn', force=True) 236 | for host, runner, cmd in zip(hosts, runners, kill_cmds): 237 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 238 | p.start() 239 | processes.append(p) 240 | 241 | for p in processes: 242 | p.join() -------------------------------------------------------------------------------- /AE/ae1_tpt_opt_baseline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import torch.multiprocessing as mp 6 | import select 7 | import paramiko 8 | 9 | import configs 10 | 11 | class Commands: 12 | def __init__(self, retry_time=0): 13 | self.retry_time = retry_time 14 | pass 15 | 16 | def run_cmd(self, host_ip, command): 17 | i = 0 18 | while True: 19 | try: 20 | ssh = paramiko.SSHClient() 21 | ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) 22 | ssh.connect(host_ip) 23 | break 24 | except paramiko.AuthenticationException: 25 | print("Authentication failed when connecting to %s" % host_ip) 26 | sys.exit(1) 27 | except: 28 | print("Could not SSH to %s, waiting for it to start" % host_ip) 29 | i += 1 30 | time.sleep(2) 31 | 32 | # If we could not connect within time limit 33 | if i >= self.retry_time: 34 | print("Could not connect to %s. Giving up" % host_ip) 35 | sys.exit(1) 36 | # After connection is successful 37 | # Send the command 38 | 39 | # print command 40 | print('> ' + command) 41 | # execute commands 42 | stdin, stdout, stderr = ssh.exec_command(command) 43 | 44 | print(stderr.read().decode("euc-kr")) 45 | stdin.close() 46 | 47 | # TODO() : if an error is thrown, stop further rules and revert back changes 48 | # Wait for the command to terminate 49 | while not stdout.channel.exit_status_ready(): 50 | # Only print data if there is data to read in the channel 51 | if stdout.channel.recv_ready(): 52 | rl, wl, xl = select.select([ stdout.channel ], [ ], [ ], 0.0) 53 | if len(rl) > 0: 54 | tmp = stdout.channel.recv(1024) 55 | output = tmp.decode() 56 | print(output) 57 | 58 | # Close SSH connection 59 | ssh.close() 60 | return 61 | 62 | if __name__ == '__main__': 63 | 64 | """ 65 | Set Server Environment 66 | """ 67 | env_loc = configs.global_configs['env_loc'] 68 | runner_loc = configs.global_configs['runner_loc'] 69 | workspace_loc = configs.global_configs['workspace_loc'] 70 | data_loc = configs.global_configs['data_loc'] 71 | num_runners = configs.global_configs['num_runners'] 72 | gpus_per_server = configs.global_configs['gpus_per_server'] 73 | hosts = configs.global_configs['hosts'] 74 | assert len(hosts) == num_runners, 'our script requires a host per a runner' 75 | 76 | """ 77 | SSH Connection Class 78 | """ 79 | runners = list() 80 | for i in range(num_runners): 81 | runners.append(Commands()) 82 | 83 | 84 | """ 85 | Set Common Dataset Information 86 | """ 87 | dataset_list = list() 88 | arxiv_dict = { 89 | 'dataset': 'ogbn-arxiv', 90 | 'dropout': 0.5, 91 | 'lr': 0.01 92 | } 93 | reddit_dict = { 94 | 'dataset': 'reddit', 95 | 'dropout': 0.5, 96 | 'lr': 0.01 97 | } 98 | product_dict = { 99 | 'dataset': 
'ogbn-products', 100 | 'dropout': 0.3, 101 | 'lr': 0.003 102 | } 103 | 104 | 105 | dataset_list = [arxiv_dict, reddit_dict, product_dict] 106 | sampler_list = ['sage'] 107 | 108 | """ 109 | Iteration 110 | """ 111 | 112 | for dataset_dict in dataset_list: 113 | for n_layer in [3]: 114 | for sampler in sampler_list: 115 | remove_tmp = True 116 | for num_server in [2]: 117 | for check_intra_only in [False]: 118 | for hidden_size in [64]: 119 | """ 120 | Make an Experiment 121 | """ 122 | if n_layer > 5: 123 | model_type = 'deepgcn' 124 | else: 125 | model_type = 'graphsage' 126 | shared_cmd = """{env_loc} {runner_loc} \ 127 | --dataset {dataset} \ 128 | --dropout {dropout} \ 129 | --sampler {sampler} \ 130 | --lr {lr} \ 131 | --parts-per-node {gpus_per_server} \ 132 | --n-partitions {num_parts} \ 133 | --n-epochs 100 \ 134 | --model {model_type} \ 135 | --n-layers {n_layers} \ 136 | --n-linear 1 \ 137 | --n-hidden {hidden_size} \ 138 | --log-every 10 \ 139 | --master-addr {master_addr} \ 140 | --port 7524 \ 141 | --debug \ 142 | --time-calc \ 143 | --no-eval \ 144 | --fix-seed \ 145 | --seed 7524 \ 146 | --backend nccl \ 147 | --dataset-path {data_loc} \ 148 | --exp-id 1 \ 149 | --create-json 1 \ 150 | --json-path {workspace_loc}/Logs/granndis_opt_baseline \ 151 | --project granndis_ae_opt_baseline \ 152 | """.format( 153 | env_loc = env_loc, 154 | runner_loc = runner_loc, 155 | data_loc = data_loc, 156 | workspace_loc = workspace_loc, 157 | gpus_per_server = gpus_per_server, 158 | num_parts = gpus_per_server * num_server, 159 | dataset = dataset_dict['dataset'], 160 | sampler = sampler, 161 | model_type = model_type, 162 | dropout = dataset_dict['dropout'], 163 | lr = dataset_dict['lr'], 164 | n_layers = (n_layer + 1), # for deepgcn... we need to plus 1 165 | hidden_size = hidden_size, 166 | master_addr = hosts[0], 167 | ) 168 | 169 | if remove_tmp: 170 | shared_cmd = shared_cmd + """--remove-tmp """ 171 | remove_tmp = False 172 | 173 | # if dataset_dict['dataset'] == 'ogbn-papers100m': 174 | # shared_cmd = shared_cmd + """--partition-method random """ 175 | 176 | if check_intra_only: 177 | shared_cmd = shared_cmd + """--check-intra-only """ 178 | 179 | def __is_list_in_target(list, target): 180 | is_in = False 181 | for element in list: 182 | if element in target: 183 | is_in = True 184 | break 185 | return is_in 186 | 187 | if not __is_list_in_target(list=['papers'], target=dataset_dict['dataset']): 188 | shared_cmd = shared_cmd + """--inductive """ 189 | 190 | """ 191 | Pre-Cleaning 192 | """ 193 | kill_cmds = list() 194 | for i in range(num_server): 195 | kill_cmds.append('pkill -ef spawn && rm -rf ~/partitions') 196 | 197 | processes = [] 198 | mp.set_start_method('spawn', force=True) 199 | for host, runner, cmd in zip(hosts, runners, kill_cmds): 200 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 201 | p.start() 202 | processes.append(p) 203 | 204 | for p in processes: 205 | p.join() 206 | 207 | """ 208 | Run an Experiment :) 209 | """ 210 | 211 | cmds = list() 212 | for i in range(num_server): 213 | runner_cmd = shared_cmd + """--total-nodes %d """ % num_server 214 | runner_cmd += """--node-rank %d """ % i 215 | runner_cmd += "\n" 216 | cmds.append(runner_cmd) 217 | 218 | processes = [] 219 | mp.set_start_method('spawn', force=True) 220 | # Note that python zip only iterates for shorter list!!! 
221 | # so any hosts beyond the launched commands are simply left idle 222 | for host, runner, cmd in zip(hosts, runners, cmds): 223 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 224 | p.start() 225 | processes.append(p) 226 | 227 | for p in processes: 228 | p.join() 229 | 230 | """ 231 | Post-Cleaning 232 | """ 233 | kill_cmds = list() 234 | for i in range(num_server): 235 | kill_cmds.append('pkill -ef spawn') 236 | 237 | processes = [] 238 | mp.set_start_method('spawn', force=True) 239 | for host, runner, cmd in zip(hosts, runners, kill_cmds): 240 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 241 | p.start() 242 | processes.append(p) 243 | 244 | for p in processes: 245 | p.join() -------------------------------------------------------------------------------- /Codes/helper/sampler.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | from dgl.data import RedditDataset 3 | from dgl import transforms 4 | from dgl.dataloading.base import set_node_lazy_features, set_edge_lazy_features, Sampler 5 | from dgl.sampling.utils import EidExcluder 6 | from helper.utils import * 7 | import numpy as np 8 | 9 | def _load_data(dataset): 10 | if dataset == 'reddit': 11 | data = RedditDataset(raw_dir='./dataset/') 12 | g = data[0] 13 | elif dataset == 'ogbn-products': 14 | g = load_ogb_dataset('ogbn-products') 15 | elif dataset == 'ogbn-papers100m': 16 | g = load_ogb_dataset('ogbn-papers100M') 17 | elif dataset == 'yelp': 18 | g = load_yelp() 19 | else: 20 | raise ValueError('Unknown dataset: {}'.format(dataset)) 21 | 22 | n_feat = g.ndata['feat'].shape[1] 23 | if g.ndata['label'].dim() == 1: 24 | n_class = g.ndata['label'].max().item() + 1 25 | else: 26 | n_class = g.ndata['label'].shape[1] 27 | 28 | g.edata.clear() 29 | g = dgl.remove_self_loop(g) 30 | g = dgl.add_self_loop(g) 31 | return g, n_feat, n_class 32 | 33 | """ 34 | Use default sampler 35 | """ 36 | 37 | class MultiLayerNeighborSampler(dgl.dataloading.BlockSampler): 38 | def __init__(self, fanouts): 39 | super().__init__(len(fanouts)) 40 | 41 | self.fanouts = fanouts 42 | 43 | def sample_frontier(self, block_id, g, seed_nodes): 44 | fanout = self.fanouts[block_id] 45 | if fanout is None: 46 | frontier = dgl.in_subgraph(g, seed_nodes) 47 | else: 48 | frontier = dgl.sampling.sample_neighbors(g, seed_nodes, fanout) 49 | return frontier 50 | 51 | """ 52 | ShaDow-GNN-style subgraph samplers 53 | """ 54 | class ShaDowKHopSampler(Sampler): 55 | def __init__(self, fanouts, use_mask=False, replace=False, prob=None, prefetch_node_feats=None, 56 | prefetch_edge_feats=None, output_device=None): 57 | super().__init__() 58 | self.fanouts = fanouts 59 | self.use_mask = use_mask 60 | self.replace = replace 61 | self.prob = prob 62 | self.prefetch_node_feats = prefetch_node_feats 63 | self.prefetch_edge_feats = prefetch_edge_feats 64 | self.output_device = output_device 65 | 66 | 67 | def sample(self, g, seed_nodes, relabel_nodes=True, exclude_eids=None): 68 | """Sampling function. 69 | 70 | Parameters 71 | ---------- 72 | g : DGLGraph 73 | The graph to sample from. 74 | seed_nodes : Tensor or dict[str, Tensor] 75 | The nodes sampled in the current minibatch. 76 | exclude_eids : Tensor or dict[etype, Tensor], optional 77 | The edges to exclude from neighborhood expansion. 78 | 79 | Returns 80 | ------- 81 | input_nodes, output_nodes, subg, masks 82 | A 4-tuple containing (1) the node IDs inducing the subgraph, (2) the node 83 | IDs that are sampled in this minibatch, (3) the subgraph itself, and
84 |         (4) the per-hop lists of node IDs that appear only as sources (newly expanded nodes). 85 | """ 86 | output_nodes = seed_nodes 87 | masks = [] 88 | 89 | for fanout in reversed(self.fanouts): 90 | frontier = g.sample_neighbors( 91 | seed_nodes, fanout, output_device=self.output_device, 92 | replace=self.replace, prob=self.prob, exclude_edges=exclude_eids) 93 | block = transforms.to_block(frontier, seed_nodes) 94 | 95 | # set difference of src and dst node IDs, adapted from: 96 | # https://stackoverflow.com/questions/55110047/finding-non-intersection-of-two-pytorch-tensors 97 | # (IDs counted exactly once appear only among the sources) 98 | combined = torch.cat((block.srcdata[dgl.NID], block.dstdata[dgl.NID], block.dstdata[dgl.NID])) 99 | uniques, counts = combined.unique(return_counts=True) 100 | diff = uniques[counts == 1] 101 | masks.insert(0, diff) 102 | 103 | seed_nodes = block.srcdata[dgl.NID] 104 | 105 | subg = dgl.node_subgraph(g, seed_nodes, relabel_nodes=relabel_nodes, output_device=self.output_device) 106 | if exclude_eids is not None: 107 | subg = EidExcluder(exclude_eids)(subg) 108 | 109 | set_node_lazy_features(subg, self.prefetch_node_feats) 110 | set_edge_lazy_features(subg, self.prefetch_edge_feats) 111 | 112 | return seed_nodes, output_nodes, subg, masks 113 | 114 | 115 | class SAINTSampler(Sampler): 116 | """Random node/edge/walk sampler from 117 | `GraphSAINT: Graph Sampling Based Inductive Learning Method 118 | <https://arxiv.org/abs/1907.04931>`__ 119 | 120 | For each call, the sampler samples a node subset and then returns a node induced subgraph. 121 | There are three options for sampling node subsets: 122 | 123 | - For :attr:`'node'` sampler, the probability to sample a node is in proportion 124 | to its out-degree. 125 | - The :attr:`'edge'` sampler first samples an edge subset and then uses the 126 | end nodes of the edges. 127 | - The :attr:`'walk'` sampler uses the nodes visited by random walks. It uniformly selects 128 | a number of root nodes and then performs a fixed-length random walk from each root node. 129 | 130 | Parameters 131 | ---------- 132 | mode : str 133 | The sampler to use, which can be :attr:`'node'`, :attr:`'edge'`, or :attr:`'walk'`. 134 | budget : int or tuple[int] 135 | Sampler configuration. 136 | 137 | - For :attr:`'node'` sampler, budget specifies the number of nodes 138 | in each sampled subgraph. 139 | - For :attr:`'edge'` sampler, budget specifies the number of edges 140 | to sample for inducing a subgraph. 141 | - For :attr:`'walk'` sampler, budget is a tuple. budget[0] specifies 142 | the number of root nodes to generate random walks. budget[1] specifies 143 | the length of a random walk. 144 | 145 | cache : bool, optional 146 | If False, it will not cache the probability arrays for sampling. Setting 147 | it to False is required if you want to use the sampler across different graphs. 148 | prefetch_ndata : list[str], optional 149 | The node data to prefetch for the subgraph. 150 | 151 | See :ref:`guide-minibatch-prefetching` for a detailed explanation of prefetching. 152 | prefetch_edata : list[str], optional 153 | The edge data to prefetch for the subgraph. 154 | 155 | See :ref:`guide-minibatch-prefetching` for a detailed explanation of prefetching. 156 | output_device : device, optional 157 | The device of the output subgraphs. 158 | 159 | Examples 160 | -------- 161 | 162 | >>> import torch 163 | >>> from dgl.dataloading import SAINTSampler, DataLoader 164 | >>> num_iters = 1000 165 | >>> sampler = SAINTSampler(mode='node', budget=6000) 166 | >>> 167 | >>> dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4) 168 | >>> for subg in dataloader: 169 | ...
train_on(subg) 170 | """ 171 | def __init__(self, mode, cache=True, prefetch_ndata=None, 172 | prefetch_edata=None, output_device='cpu'): 173 | super().__init__() 174 | 175 | if mode == 'node': 176 | self.sampler = self.node_sampler 177 | else: 178 | raise dgl.DGLError(f"Only the 'node' mode is implemented here, got {mode}.") 179 | 180 | self.cache = cache 181 | self.prob = None 182 | self.prefetch_ndata = prefetch_ndata or [] 183 | self.prefetch_edata = prefetch_edata or [] 184 | self.output_device = output_device 185 | 186 | def node_sampler(self, g, budget): 187 | """Sample node IDs with probability proportional to out-degree.""" 188 | 189 | 190 | 191 | if self.cache and self.prob is not None: 192 | prob = self.prob 193 | else: 194 | prob = g.out_degrees().float().clamp(min=1) 195 | if self.cache: 196 | self.prob = prob 197 | return torch.multinomial(prob, num_samples=budget, 198 | replacement=True).unique().type(g.idtype) 199 | 200 | def sample(self, g, indices, budget): 201 | """Sampling function 202 | 203 | Parameters 204 | ---------- 205 | g : DGLGraph 206 | The graph to sample from. 207 | indices : Tensor 208 | Placeholder, not used. 209 | 210 | Returns 211 | ------- 212 | tuple 213 | (None, sampled node IDs, node-induced subgraph, None), matching the interface of the samplers above. 214 | """ 215 | node_ids = self.sampler(g, budget) 216 | sg = dgl.node_subgraph(g, node_ids, relabel_nodes=True, output_device=self.output_device) 217 | 218 | set_node_lazy_features(sg, self.prefetch_ndata) 219 | set_edge_lazy_features(sg, self.prefetch_edata) 220 | return None, node_ids, sg, None 221 | 222 | 223 | """ 224 | Merge subgraphs into a single graph via dgl.merge 225 | """ 226 | 227 | def merge_subgraphs(subgs): 228 | return dgl.merge(subgs) 229 | 230 | if __name__ == '__main__': 231 | print('Mini-batch Partitioning Unit Test') 232 | 233 | print('[Step 0] Load Reddit Dataset') 234 | g, n_feat, n_class = _load_data('reddit') 235 | 236 | 237 | print('[Step 1] Get Dedicated IDs') 238 | print(' '*4, '> Train Mask') 239 | print(' '*8, g.ndata['train_mask'].shape) 240 | 241 | print(' '*4, '> Train NIDs') 242 | train_nids = g.ndata['train_mask'].nonzero().squeeze() 243 | test_nids = g.ndata['test_mask'].nonzero().squeeze() 244 | val_nids = g.ndata['val_mask'].nonzero().squeeze() 245 | print(len(train_nids)) 246 | print(len(test_nids)) 247 | print(len(val_nids)) 248 | print(len(train_nids)+len(test_nids)+len(val_nids)) 249 | print(len(train_nids)/(len(train_nids)+len(test_nids)+len(val_nids))) 250 | 251 | print(' '*4, '> Split NIDs') 252 | seed_node_array = np.array_split(np.array(train_nids), 4) 253 | print(len(seed_node_array[0])) 254 | 255 | print('[Step 2] Make Sampler and Make Subgraph') 256 | sampler = dgl.dataloading.NeighborSampler([10, 10]) 257 | 258 | 259 | 260 | # note: 'subgs' is overwritten on every call, so Step 3 below merges only the blocks of the last sample 261 | print(' '*4, '> Sample Subgraphs from All NID Splits') 262 | train_sum = 0 263 | input_nodes, output_nodes, subgs = sampler.sample(g, seed_node_array[3]) 264 | train_sum = train_sum + len(output_nodes) 265 | input_nodes, output_nodes, subgs = sampler.sample(g, seed_node_array[1]) 266 | train_sum = train_sum + len(output_nodes) 267 | input_nodes, output_nodes, subgs = sampler.sample(g, seed_node_array[2]) 268 | train_sum = train_sum + len(output_nodes) 269 | input_nodes, output_nodes, subgs = sampler.sample(g, seed_node_array[0]) 270 | train_sum = train_sum + len(output_nodes) 271 | print(len(input_nodes)) 272 | print(train_sum) 273 | print(len(output_nodes)) 274 | 275 | 276 | 277 | print('[Step 3] Merge Subgraphs') 278 | full_g_0 = dgl.merge(subgs) 279 | print(full_g_0.nodes()) 280 | 
print(len(full_g_0.ndata['train_mask'].nonzero().squeeze()) / 281 | (len(full_g_0.ndata['train_mask'].nonzero().squeeze()) +\ 282 | len(full_g_0.ndata['test_mask'].nonzero().squeeze()) +\ 283 | len(full_g_0.ndata['val_mask'].nonzero().squeeze()))) 284 | print(len(full_g_0.ndata['train_mask'].nonzero().squeeze())) 285 | 286 | print(len(full_g_0.ndata['train_mask'].nonzero().squeeze()) +\ 287 | len(full_g_0.ndata['test_mask'].nonzero().squeeze()) +\ 288 | len(full_g_0.ndata['val_mask'].nonzero().squeeze())) 289 | 290 | print('[Step 4] Print Graph per Node Info') 291 | print(full_g_0) 292 | 293 | 294 | -------------------------------------------------------------------------------- /Codes/helper/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import TYPE_CHECKING, Literal 3 | 4 | if TYPE_CHECKING: 5 | import torch.distributed as dist 6 | from helper.mapper import Mapper 7 | 8 | 9 | class Args: 10 | dataset: str = 'reddit' 11 | "The input dataset" 12 | dataset_scale: int = 1 13 | "Scaling factor for dataset" 14 | dataset_path: str = "/datasets/atc23/" 15 | "Dataset path" 16 | graph_name: str = '' 17 | 18 | sampler: str = 'sage' 19 | "Sampler type" 20 | 21 | model: str = 'graphsage' 22 | "Model for training" 23 | dropout: float = 0.5 24 | "Dropout probability" 25 | lr: float = 1e-2 26 | "Learning rate" 27 | n_epochs: int = 200 28 | "The number of training epochs" 29 | n_partitions: int = 2 30 | "The number of partitions" 31 | n_hidden: int = 16 32 | "The number of hidden units" 33 | n_layers: int = 2 34 | "The number of GCN layers" 35 | n_linear: int = 0 36 | "The number of linear layers" 37 | norm: Literal['layer', 'batch', 'none'] = 'layer' 38 | "Normalization method" 39 | weight_decay: float = 0 40 | "Weight for L2 loss" 41 | 42 | n_feat: int = 0 43 | n_class: int = 0 44 | n_train: int = 0 45 | skip_partition: bool = False 46 | "Skip graph partition" 47 | 48 | partition_obj: Literal['vol', 'cut'] = 'vol' 49 | "Partition objective function ('vol' or 'cut')" 50 | partition_method: Literal['metis', 'random'] = 'metis' 51 | "The method for graph partition ('metis' or 'random')" 52 | 53 | enable_pipeline: bool = False 54 | feat_corr: bool = False 55 | grad_corr: bool = False 56 | corr_momentum: float = 0.95 57 | 58 | use_mask: bool = False 59 | bandwidth_aware: bool = False 60 | "Whether to use the intra-/inter-node bandwidth-aware method" 61 | remove_tmp: bool = False 62 | "Remove cached intra-node partition files" 63 | time_calc: bool = False 64 | "Measure and report timing breakdowns" 65 | epoch_iter: int = 1 66 | fanout: int = 5 67 | subgraph_hop: int = 1 68 | "Number of hops used for subgraph expansion" 69 | flexible_hop: int = 0 70 | "Flexible preloading hop count" 71 | use_flexible: bool = False 72 | check_intra_only: bool = False 73 | "Measure intra-node communication time only (for evaluation)"
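# Usage sketch (illustrative; create_parser() is defined at the bottom of this file):
#     args: 'Args' = create_parser()
#     print(args.n_hidden)   # typed access to the value passed via --n-hidden
# Args only mirrors the argparse flags for static typing; at runtime create_parser()
# returns an argparse.Namespace, and Args itself is never instantiated.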
74 | full_neighbor: bool = False 75 | "Whether to use full neighbor mini-batch generation" 76 | skip_minibatch_partition: bool = False 77 | "Skip minibatch generation and partition if already exists" 78 | 79 | use_pp: bool = False 80 | "Whether to use precomputation" 81 | inductive: bool = False 82 | "Inductive or transductive learning setting" 83 | fix_seed: bool = False 84 | "Fix random seed" 85 | seed: int = 0 86 | log_every: int = 10 87 | 88 | backend: Literal['gloo', 'nccl', 'mpi'] = 'nccl' 89 | master_addr: str = "127.0.0.1" 90 | port: int = 18118 91 | "The network port for communication" 92 | key_value_port: int = 18118 93 | "The network port for communication" 94 | node_rank : int = 0 95 | total_nodes: int = 2 96 | parts_per_node: int = 10 97 | rank: int 98 | "[Warning] This is not a part of the parser." 99 | world_size: int 100 | "[Warning] This is not a part of the parser." 101 | local_rank: int 102 | "[Warning] This is not a part of the parser." 103 | local_device_cnt: int 104 | "Number of local devices. [Warning] This is not a part of the parser." 105 | 106 | head_group: 'dist.ProcessGroup' 107 | "Group consists of the head of the nodes. (ex: [0, 4, 8, ...]) [Warning] This is not a part of the parser." 108 | device_group: 'dist.ProcessGroup' 109 | "Group consists of the local devices. (ex: [4, 5, 6, 7]) [Warning] This is not a part of the parser." 110 | mapper: 'Mapper' 111 | "Mapper class between full-graph, server-graph, gpu-graph. [Warning] This is not a part of the parser." 112 | 113 | eval: bool = True 114 | "Enable evaluation" 115 | sequential_eval: bool = False 116 | "If true, evaluate sequentially. Else, make another thread and do evaluation while training next epoch." 117 | 118 | create_json: int = 0 119 | send_db: int = 0 120 | db_name: str 121 | "DB name... 
normally a nickname" 122 | project: str 123 | "Project (experiment) name" 124 | ssh_user: str 125 | "SSH username" 126 | ssh_pwd: str 127 | "SSH password" 128 | json_path: str = "./json_logs" 129 | 130 | debug: bool = False 131 | exp_id: int = 0 132 | 133 | 134 | def create_parser() -> Args: 135 | 136 | parser = argparse.ArgumentParser(description='PipeGCN') 137 | 138 | parser.add_argument("--dataset", type=str, default='reddit', 139 | help="the input dataset") 140 | parser.add_argument("--dataset-scale", type=int, default=1, help='scaling factor for dataset') 141 | parser.add_argument("--graph-name", "--graph_name", type=str, default='') 142 | 143 | parser.add_argument("--sampler", type=str, default='sage', 144 | help='sampler type') 145 | 146 | parser.add_argument("--model", type=str, default='graphsage', 147 | help="model for training") 148 | parser.add_argument("--dropout", type=float, default=0.5, 149 | help="dropout probability") 150 | parser.add_argument("--lr", type=float, default=1e-2, 151 | help="learning rate") 152 | parser.add_argument("--n-epochs", "--n_epochs", type=int, default=200, 153 | help="the number of training epochs") 154 | parser.add_argument("--n-partitions", "--n_partitions", type=int, default=2, 155 | help="the number of partitions") 156 | parser.add_argument("--n-hidden", "--n_hidden", type=int, default=16, 157 | help="the number of hidden units") 158 | parser.add_argument("--n-layers", "--n_layers", type=int, default=2, 159 | help="the number of GCN layers") 160 | parser.add_argument("--n-linear", "--n_linear", type=int, default=0, 161 | help="the number of linear layers") 162 | parser.add_argument("--norm", choices=['layer', 'batch'], default='layer', 163 | help="normalization method") 164 | parser.add_argument("--weight-decay", "--weight_decay", type=float, default=0, 165 | help="weight for L2 loss") 166 | 167 | parser.add_argument("--n-feat", "--n_feat", type=int, default=0) 168 | parser.add_argument("--n-class", "--n_class", type=int, default=0) 169 | parser.add_argument("--n-train", "--n_train", type=int, default=0) 170 | parser.add_argument('--skip-partition', action='store_true', 171 | help="skip graph partition") 172 | 173 | parser.add_argument("--partition-obj", "--partition_obj", choices=['vol', 'cut'], default='vol', 174 | help="partition objective function ('vol' or 'cut')") 175 | parser.add_argument("--partition-method", "--partition_method", choices=['metis', 'random'], default='metis', 176 | help="the method for graph partition ('metis' or 'random')") 177 | 178 | parser.add_argument("--enable-pipeline", "--enable_pipeline", action='store_true') 179 | parser.add_argument("--feat-corr", "--feat_corr", action='store_true') 180 | parser.add_argument("--grad-corr", "--grad_corr", action='store_true') 181 | parser.add_argument("--corr-momentum", "--corr_momentum", type=float, default=0.95) 182 | 183 | parser.add_argument("--use-mask", "--use_mask", action='store_true', 184 | help='whether to use mask-based unified batching with the bandwidth-aware method.') 185 | parser.add_argument("--bandwidth-aware", "--bandwidth_aware", action='store_true', 186 | help='whether to use the intra-/inter-node bandwidth-aware method.') 187 | parser.add_argument('--remove-tmp', '--remove_tmp', action='store_true', 188 | help='remove cached intra-node partition files') 189 | parser.add_argument('--time-calc', '--time_calc', action='store_true', 190 | help='measure and report timing breakdowns') 191 | 192 | parser.add_argument("--epoch-iter", "--epoch_iter", type=int, default=1) 193 | parser.add_argument("--fanout",
type=int, default=5) 194 | parser.add_argument("--subgraph-hop", "--subgraph_hop", type=int, default=1) 195 | parser.add_argument("--use-flexible", "--use_flexible", action='store_true', \ 196 | help='use flexible preloading') 197 | parser.add_argument("--flexible-hop", "--flexible_hop", type=int, default=0) 198 | parser.add_argument("--check-intra-only", "--check_intra_only", action='store_true', \ 199 | help='check intra only comm time for evaluation....') 200 | parser.add_argument("--full-neighbor", "--full_neighbor", action='store_true', 201 | help="whether use full neighbor mini-batch generation") 202 | parser.add_argument("--skip-minibatch-partition", "--skip_minibatch_partition", action='store_true', 203 | help='skip minibatch generation and partition if already exists.') 204 | 205 | parser.add_argument("--use-pp", "--use_pp", action='store_true', 206 | help="whether to use precomputation") 207 | parser.add_argument("--inductive", action='store_true', 208 | help="inductive learning setting") 209 | parser.add_argument("--fix-seed", "--fix_seed", action='store_true', 210 | help="fix random seed") 211 | parser.add_argument("--seed", type=int, default=0) 212 | parser.add_argument("--log-every", "--log_every", type=int, default=10) 213 | 214 | parser.add_argument("--backend", type=str, default='nccl') 215 | parser.add_argument("--port", type=int, default=18118, 216 | help="the network port for communication") 217 | parser.add_argument("--key-value-port", "--key_value_port", type=int, default=18118, 218 | help="the network port for communication") 219 | parser.add_argument("--master-addr", "--master_addr", type=str, default="127.0.0.1") 220 | parser.add_argument("--node-rank", "--node_rank", type=int, default=0) 221 | parser.add_argument('--total-nodes', '--total_nodes', type=int, default=2) 222 | parser.add_argument("--parts-per-node", "--parts_per_node", type=int, default=10) 223 | 224 | parser.add_argument('--eval', action='store_true', 225 | help="enable evaluation") 226 | parser.add_argument('--no-eval', action='store_false', dest='eval', 227 | help="disable evaluation") 228 | parser.add_argument('--sequential-eval', '--sequential_eval', action='store_true', 229 | help="if true, evaluate sequentially; else, make another thread and do evaluation while training next epoch.") 230 | 231 | parser.add_argument('--dataset-path', '--dataset_path', default='/datasets/atc23/', type=str, \ 232 | help='dataset path') 233 | 234 | parser.add_argument("--create-json", "--create_json", type=int, default=0) 235 | parser.add_argument("--send-db", "--send_db", type=int, default=0) 236 | parser.add_argument('--db-name', '--db_name', help='db name... 
normally nickname') 237 | parser.add_argument('--project', help='project(experiment) name') 238 | parser.add_argument('--ssh-user', '--ssh_user', help='ssh username') 239 | parser.add_argument('--ssh-pwd', '--ssh_pwd', help='ssh password') 240 | parser.add_argument('--json-path', '--json_path', type=str, default='./json_logs') 241 | 242 | parser.add_argument("--debug", action='store_true') 243 | parser.add_argument("--exp-id", "--exp_id", type=int, default=0) 244 | 245 | parser.set_defaults(eval=True) 246 | parser.set_defaults(sequential_eval=False) 247 | 248 | return parser.parse_args() 249 | -------------------------------------------------------------------------------- /AE/ae2_acc_eas.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import torch.multiprocessing as mp 6 | import select 7 | import paramiko 8 | 9 | import configs 10 | 11 | class Commands: 12 | def __init__(self, retry_time=0): 13 | self.retry_time = retry_time 14 | pass 15 | 16 | def run_cmd(self, host_ip, command): 17 | i = 0 18 | while True: 19 | try: 20 | ssh = paramiko.SSHClient() 21 | ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) 22 | ssh.connect(host_ip) 23 | break 24 | except paramiko.AuthenticationException: 25 | print("Authentication failed when connecting to %s" % host_ip) 26 | sys.exit(1) 27 | except: 28 | print("Could not SSH to %s, waiting for it to start" % host_ip) 29 | i += 1 30 | time.sleep(2) 31 | 32 | # If we could not connect within time limit 33 | if i >= self.retry_time: 34 | print("Could not connect to %s. Giving up" % host_ip) 35 | sys.exit(1) 36 | # After connection is successful 37 | # Send the command 38 | 39 | # print command 40 | print('> ' + command) 41 | # execute commands 42 | stdin, stdout, stderr = ssh.exec_command(command) 43 | 44 | print(stderr.read().decode("euc-kr")) 45 | stdin.close() 46 | 47 | # TODO() : if an error is thrown, stop further rules and revert back changes 48 | # Wait for the command to terminate 49 | while not stdout.channel.exit_status_ready(): 50 | # Only print data if there is data to read in the channel 51 | if stdout.channel.recv_ready(): 52 | rl, wl, xl = select.select([ stdout.channel ], [ ], [ ], 0.0) 53 | if len(rl) > 0: 54 | tmp = stdout.channel.recv(1024) 55 | output = tmp.decode() 56 | print(output) 57 | 58 | # Close SSH connection 59 | ssh.close() 60 | return 61 | 62 | if __name__ == '__main__': 63 | 64 | """ 65 | Set Server Environment 66 | """ 67 | env_loc = configs.global_configs['env_loc'] 68 | runner_loc = configs.global_configs['our_runner_loc'] 69 | workspace_loc = configs.global_configs['workspace_loc'] 70 | data_loc = configs.global_configs['data_loc'] 71 | num_runners = configs.global_configs['num_runners'] 72 | gpus_per_server = configs.global_configs['gpus_per_server'] 73 | hosts = configs.global_configs['hosts'] 74 | assert len(hosts) == num_runners, 'our script requires a host per a runner' 75 | 76 | """ 77 | SSH Connection Class 78 | """ 79 | runners = list() 80 | for i in range(num_runners): 81 | runners.append(Commands()) 82 | 83 | 84 | """ 85 | Set Common Dataset Information 86 | """ 87 | dataset_list = list() 88 | arxiv_dict = { 89 | 'dataset': 'ogbn-arxiv', 90 | 'dropout': 0.5, 91 | 'lr': 0.01 92 | } 93 | reddit_dict = { 94 | 'dataset': 'reddit', 95 | 'dropout': 0.5, 96 | 'lr': 0.01 97 | } 98 | product_dict = { 99 | 'dataset': 'ogbn-products', 100 | 'dropout': 0.3, 101 | 'lr': 0.003 102 | } 103 | 104 | dataset_list = [arxiv_dict, 
reddit_dict, product_dict] 105 | sampler_list = ['sage'] 106 | 107 | """ 108 | Iteration 109 | """ 110 | 111 | for dataset_dict in dataset_list: 112 | for n_layer in [3]: 113 | for sampler in sampler_list: 114 | for subgraph_hop in [1]: 115 | if subgraph_hop == 0: 116 | fanout_list = [0] 117 | else: 118 | fanout_list = [15] 119 | for fanout in fanout_list: 120 | remove_tmp = True 121 | for num_server in [2]: 122 | # for each new layer num... we need to repartition 123 | for check_intra_only in [False]: # we do not have scheme 2... yet :( 124 | for hidden_size in [64]: # 32, 128 125 | """ 126 | Make an Experiment 127 | """ 128 | if n_layer > 5: 129 | model_type = 'deepgcn' 130 | else: 131 | model_type = 'graphsage' 132 | shared_cmd = """{env_loc} {runner_loc} \ 133 | --dataset {dataset} \ 134 | --dropout {dropout} \ 135 | --sampler {sampler} \ 136 | --lr {lr} \ 137 | --n-partitions {gpus_per_server} \ 138 | --n-epochs 1000 \ 139 | --model {model_type} \ 140 | --n-layers {n_layers} \ 141 | --n-linear 1 \ 142 | --n-hidden {hidden_size} \ 143 | --log-every 10 \ 144 | --bandwidth-aware \ 145 | --use-mask \ 146 | --epoch-iter 1 \ 147 | --master-addr {master_addr} \ 148 | --port 7524 \ 149 | --fix-seed \ 150 | --seed 7524 \ 151 | --backend nccl \ 152 | --fanout {fanout} \ 153 | --subgraph-hop {subgraph_hop} \ 154 | --dataset-path {data_loc} \ 155 | --exp-id 1 \ 156 | --create-json 1 \ 157 | --json-path {workspace_loc}/Logs/eas_acc \ 158 | --project eas_acc \ 159 | """.format( 160 | env_loc = env_loc, 161 | runner_loc = runner_loc, 162 | data_loc = data_loc, 163 | workspace_loc = workspace_loc, 164 | gpus_per_server = gpus_per_server, 165 | dataset = dataset_dict['dataset'], 166 | sampler = sampler, 167 | model_type = model_type, 168 | dropout = dataset_dict['dropout'], 169 | lr = dataset_dict['lr'], 170 | n_layers = (n_layer + 1), 171 | hidden_size = hidden_size, 172 | master_addr = hosts[0], 173 | fanout = fanout, 174 | subgraph_hop = subgraph_hop 175 | ) 176 | 177 | if remove_tmp: 178 | shared_cmd = shared_cmd + """--remove-tmp """ 179 | remove_tmp = False 180 | 181 | # if dataset_dict['dataset'] == 'ogbn-papers100m': 182 | # shared_cmd = shared_cmd + """--partition-method random """ 183 | 184 | if check_intra_only: 185 | shared_cmd = shared_cmd + """--check-intra-only """ 186 | 187 | def __is_list_in_target(list, target): 188 | is_in = False 189 | for element in list: 190 | if element in target: 191 | is_in = True 192 | break 193 | return is_in 194 | 195 | if not __is_list_in_target(list=['papers'], target=dataset_dict['dataset']): 196 | shared_cmd = shared_cmd + """--inductive """ 197 | 198 | """ 199 | Pre-Cleaning 200 | """ 201 | kill_cmds = list() 202 | for i in range(num_server): 203 | kill_cmds.append('pkill -ef spawn') 204 | 205 | processes = [] 206 | mp.set_start_method('spawn', force=True) 207 | for host, runner, cmd in zip(hosts, runners, kill_cmds): 208 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 209 | p.start() 210 | processes.append(p) 211 | 212 | for p in processes: 213 | p.join() 214 | 215 | """ 216 | Run an Experiment :) 217 | """ 218 | 219 | cmds = list() 220 | for i in range(num_server): 221 | runner_cmd = shared_cmd + """--total-nodes %d """ % num_server 222 | runner_cmd += """--node-rank %d """ % i 223 | runner_cmd += "\n" 224 | cmds.append(runner_cmd) 225 | 226 | processes = [] 227 | mp.set_start_method('spawn', force=True) 228 | # Note that python zip only iterates for shorter list!!! 
229 | # so do not worry for running on not dedicated runners 230 | for host, runner, cmd in zip(hosts, runners, cmds): 231 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 232 | p.start() 233 | processes.append(p) 234 | 235 | for p in processes: 236 | p.join() 237 | 238 | """ 239 | Post-Cleaning 240 | """ 241 | kill_cmds = list() 242 | for i in range(num_server): 243 | kill_cmds.append('pkill -ef spawn') 244 | 245 | processes = [] 246 | mp.set_start_method('spawn', force=True) 247 | for host, runner, cmd in zip(hosts, runners, kill_cmds): 248 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 249 | p.start() 250 | processes.append(p) 251 | 252 | for p in processes: 253 | p.join() -------------------------------------------------------------------------------- /AE/ae1_tpt_eas.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import torch.multiprocessing as mp 6 | import select 7 | import paramiko 8 | 9 | import configs 10 | 11 | class Commands: 12 | def __init__(self, retry_time=0): 13 | self.retry_time = retry_time 14 | pass 15 | 16 | def run_cmd(self, host_ip, command): 17 | i = 0 18 | while True: 19 | try: 20 | ssh = paramiko.SSHClient() 21 | ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) 22 | ssh.connect(host_ip) 23 | break 24 | except paramiko.AuthenticationException: 25 | print("Authentication failed when connecting to %s" % host_ip) 26 | sys.exit(1) 27 | except: 28 | print("Could not SSH to %s, waiting for it to start" % host_ip) 29 | i += 1 30 | time.sleep(2) 31 | 32 | # If we could not connect within time limit 33 | if i >= self.retry_time: 34 | print("Could not connect to %s. Giving up" % host_ip) 35 | sys.exit(1) 36 | # After connection is successful 37 | # Send the command 38 | 39 | # print command 40 | print('> ' + command) 41 | # execute commands 42 | stdin, stdout, stderr = ssh.exec_command(command) 43 | 44 | print(stderr.read().decode("euc-kr")) 45 | stdin.close() 46 | 47 | # TODO() : if an error is thrown, stop further rules and revert back changes 48 | # Wait for the command to terminate 49 | while not stdout.channel.exit_status_ready(): 50 | # Only print data if there is data to read in the channel 51 | if stdout.channel.recv_ready(): 52 | rl, wl, xl = select.select([ stdout.channel ], [ ], [ ], 0.0) 53 | if len(rl) > 0: 54 | tmp = stdout.channel.recv(1024) 55 | output = tmp.decode() 56 | print(output) 57 | 58 | # Close SSH connection 59 | ssh.close() 60 | return 61 | 62 | if __name__ == '__main__': 63 | 64 | """ 65 | Set Server Environment 66 | """ 67 | env_loc = configs.global_configs['env_loc'] 68 | runner_loc = configs.global_configs['our_runner_loc'] 69 | workspace_loc = configs.global_configs['workspace_loc'] 70 | data_loc = configs.global_configs['data_loc'] 71 | num_runners = configs.global_configs['num_runners'] 72 | gpus_per_server = configs.global_configs['gpus_per_server'] 73 | hosts = configs.global_configs['hosts'] 74 | assert len(hosts) == num_runners, 'our script requires a host per a runner' 75 | 76 | """ 77 | SSH Connection Class 78 | """ 79 | runners = list() 80 | for i in range(num_runners): 81 | runners.append(Commands()) 82 | 83 | 84 | """ 85 | Set Common Dataset Information 86 | """ 87 | dataset_list = list() 88 | arxiv_dict = { 89 | 'dataset': 'ogbn-arxiv', 90 | 'dropout': 0.5, 91 | 'lr': 0.01 92 | } 93 | reddit_dict = { 94 | 'dataset': 'reddit', 95 | 'dropout': 0.5, 96 | 'lr': 0.01 97 | } 98 | product_dict = { 99 | 'dataset': 
'ogbn-products', 100 | 'dropout': 0.3, 101 | 'lr': 0.003 102 | } 103 | 104 | dataset_list = [arxiv_dict, reddit_dict, product_dict] 105 | sampler_list = ['sage'] 106 | 107 | """ 108 | Iteration 109 | """ 110 | 111 | for dataset_dict in dataset_list: 112 | for n_layer in [3]: 113 | for sampler in sampler_list: 114 | for subgraph_hop in [1]: 115 | if subgraph_hop == 0: 116 | fanout_list = [0] 117 | else: 118 | fanout_list = [15] 119 | for fanout in fanout_list: 120 | remove_tmp = True 121 | for num_server in [2]: 122 | # for each new layer num... we need to repartition 123 | for check_intra_only in [False]: # we do not have scheme 2... yet :( 124 | for hidden_size in [64]: # 32, 128 125 | """ 126 | Make an Experiment 127 | """ 128 | if n_layer > 5: 129 | model_type = 'deepgcn' 130 | else: 131 | model_type = 'graphsage' 132 | shared_cmd = """{env_loc} {runner_loc} \ 133 | --dataset {dataset} \ 134 | --dropout {dropout} \ 135 | --sampler {sampler} \ 136 | --lr {lr} \ 137 | --n-partitions {gpus_per_server} \ 138 | --n-epochs 100 \ 139 | --model {model_type} \ 140 | --n-layers {n_layers} \ 141 | --n-linear 1 \ 142 | --n-hidden {hidden_size} \ 143 | --log-every 10 \ 144 | --bandwidth-aware \ 145 | --use-mask \ 146 | --epoch-iter 1 \ 147 | --master-addr {master_addr} \ 148 | --port 7524 \ 149 | --fix-seed \ 150 | --seed 7524 \ 151 | --backend nccl \ 152 | --fanout {fanout} \ 153 | --subgraph-hop {subgraph_hop} \ 154 | --dataset-path {data_loc} \ 155 | --exp-id 1 \ 156 | --create-json 1 \ 157 | --json-path {workspace_loc}/Logs/granndis_eas \ 158 | --project granndis_ae_eas \ 159 | --debug \ 160 | --time-calc \ 161 | --no-eval \ 162 | """.format( 163 | env_loc = env_loc, 164 | runner_loc = runner_loc, 165 | data_loc = data_loc, 166 | workspace_loc = workspace_loc, 167 | gpus_per_server = gpus_per_server, 168 | dataset = dataset_dict['dataset'], 169 | sampler = sampler, 170 | model_type = model_type, 171 | dropout = dataset_dict['dropout'], 172 | lr = dataset_dict['lr'], 173 | n_layers = (n_layer + 1), 174 | hidden_size = hidden_size, 175 | master_addr = hosts[0], 176 | fanout = fanout, 177 | subgraph_hop = subgraph_hop 178 | ) 179 | 180 | if remove_tmp: 181 | shared_cmd = shared_cmd + """--remove-tmp """ 182 | remove_tmp = False 183 | 184 | # if dataset_dict['dataset'] == 'ogbn-papers100m': 185 | # shared_cmd = shared_cmd + """--partition-method random """ 186 | 187 | if check_intra_only: 188 | shared_cmd = shared_cmd + """--check-intra-only """ 189 | 190 | def __is_list_in_target(list, target): 191 | is_in = False 192 | for element in list: 193 | if element in target: 194 | is_in = True 195 | break 196 | return is_in 197 | 198 | if not __is_list_in_target(list=['papers'], target=dataset_dict['dataset']): 199 | shared_cmd = shared_cmd + """--inductive """ 200 | 201 | """ 202 | Pre-Cleaning 203 | """ 204 | kill_cmds = list() 205 | for i in range(num_server): 206 | kill_cmds.append('pkill -ef spawn') 207 | 208 | processes = [] 209 | mp.set_start_method('spawn', force=True) 210 | for host, runner, cmd in zip(hosts, runners, kill_cmds): 211 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 212 | p.start() 213 | processes.append(p) 214 | 215 | for p in processes: 216 | p.join() 217 | 218 | """ 219 | Run an Experiment :) 220 | """ 221 | 222 | cmds = list() 223 | for i in range(num_server): 224 | runner_cmd = shared_cmd + """--total-nodes %d """ % num_server 225 | runner_cmd += """--node-rank %d """ % i 226 | runner_cmd += "\n" 227 | cmds.append(runner_cmd) 228 | 229 | processes = [] 230 | 
mp.set_start_method('spawn', force=True) 231 | # Note that python zip only iterates for shorter list!!! 232 | # so do not worry for running on not dedicated runners 233 | for host, runner, cmd in zip(hosts, runners, cmds): 234 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 235 | p.start() 236 | processes.append(p) 237 | 238 | for p in processes: 239 | p.join() 240 | 241 | """ 242 | Post-Cleaning 243 | """ 244 | kill_cmds = list() 245 | for i in range(num_server): 246 | kill_cmds.append('pkill -ef spawn') 247 | 248 | processes = [] 249 | mp.set_start_method('spawn', force=True) 250 | for host, runner, cmd in zip(hosts, runners, kill_cmds): 251 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 252 | p.start() 253 | processes.append(p) 254 | 255 | for p in processes: 256 | p.join() -------------------------------------------------------------------------------- /AE/ae1_tpt_cob.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import torch.multiprocessing as mp 6 | import select 7 | import paramiko 8 | 9 | import configs 10 | 11 | class Commands: 12 | def __init__(self, retry_time=0): 13 | self.retry_time = retry_time 14 | pass 15 | 16 | def run_cmd(self, host_ip, command): 17 | i = 0 18 | while True: 19 | try: 20 | ssh = paramiko.SSHClient() 21 | ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) 22 | ssh.connect(host_ip) 23 | break 24 | except paramiko.AuthenticationException: 25 | print("Authentication failed when connecting to %s" % host_ip) 26 | sys.exit(1) 27 | except: 28 | print("Could not SSH to %s, waiting for it to start" % host_ip) 29 | i += 1 30 | time.sleep(2) 31 | 32 | # If we could not connect within time limit 33 | if i >= self.retry_time: 34 | print("Could not connect to %s. 
Giving up" % host_ip) 35 | sys.exit(1) 36 | # After connection is successful 37 | # Send the command 38 | 39 | # print command 40 | print('> ' + command) 41 | # execute commands 42 | stdin, stdout, stderr = ssh.exec_command(command) 43 | 44 | print(stderr.read().decode("euc-kr")) 45 | stdin.close() 46 | 47 | # TODO() : if an error is thrown, stop further rules and revert back changes 48 | # Wait for the command to terminate 49 | while not stdout.channel.exit_status_ready(): 50 | # Only print data if there is data to read in the channel 51 | if stdout.channel.recv_ready(): 52 | rl, wl, xl = select.select([ stdout.channel ], [ ], [ ], 0.0) 53 | if len(rl) > 0: 54 | tmp = stdout.channel.recv(1024) 55 | output = tmp.decode() 56 | print(output) 57 | 58 | # Close SSH connection 59 | ssh.close() 60 | return 61 | 62 | if __name__ == '__main__': 63 | 64 | """ 65 | Set Server Environment 66 | """ 67 | env_loc = configs.global_configs['env_loc'] 68 | runner_loc = configs.global_configs['our_runner_loc'] 69 | workspace_loc = configs.global_configs['workspace_loc'] 70 | data_loc = configs.global_configs['data_loc'] 71 | num_runners = configs.global_configs['num_runners'] 72 | gpus_per_server = configs.global_configs['gpus_per_server'] 73 | hosts = configs.global_configs['hosts'] 74 | assert len(hosts) == num_runners, 'our script requires a host per a runner' 75 | 76 | """ 77 | SSH Connection Class 78 | """ 79 | runners = list() 80 | for i in range(num_runners): 81 | runners.append(Commands()) 82 | 83 | 84 | """ 85 | Set Common Dataset Information 86 | """ 87 | dataset_list = list() 88 | arxiv_dict = { 89 | 'dataset': 'ogbn-arxiv', 90 | 'dropout': 0.5, 91 | 'lr': 0.01 92 | } 93 | reddit_dict = { 94 | 'dataset': 'reddit', 95 | 'dropout': 0.5, 96 | 'lr': 0.01 97 | } 98 | product_dict = { 99 | 'dataset': 'ogbn-products', 100 | 'dropout': 0.3, 101 | 'lr': 0.003 102 | } 103 | 104 | dataset_list = [arxiv_dict, reddit_dict, product_dict] 105 | sampler_list = ['sage'] 106 | 107 | """ 108 | Iteration 109 | """ 110 | 111 | for dataset_dict in dataset_list: 112 | for n_layer in [3]: 113 | for sampler in sampler_list: 114 | for subgraph_hop in [n_layer]: 115 | if subgraph_hop == 0: 116 | fanout_list = [0] 117 | else: 118 | fanout_list = [25] 119 | for fanout in fanout_list: 120 | remove_tmp = True 121 | for num_server in [2]: 122 | # for each new layer num... we need to repartition 123 | for check_intra_only in [False]: # we do not have scheme 2... 
yet :( 124 | for hidden_size in [64]: # 32, 128 125 | """ 126 | Make an Experiment 127 | """ 128 | if n_layer > 5: 129 | model_type = 'deepgcn' 130 | else: 131 | model_type = 'graphsage' 132 | shared_cmd = """{env_loc} {runner_loc} \ 133 | --dataset {dataset} \ 134 | --dropout {dropout} \ 135 | --sampler {sampler} \ 136 | --lr {lr} \ 137 | --n-partitions {gpus_per_server} \ 138 | --n-epochs 100 \ 139 | --model {model_type} \ 140 | --n-layers {n_layers} \ 141 | --n-linear 1 \ 142 | --n-hidden {hidden_size} \ 143 | --log-every 10 \ 144 | --bandwidth-aware \ 145 | --use-mask \ 146 | --epoch-iter 1 \ 147 | --master-addr {master_addr} \ 148 | --port 7524 \ 149 | --fix-seed \ 150 | --seed 7524 \ 151 | --backend nccl \ 152 | --fanout {fanout} \ 153 | --subgraph-hop {subgraph_hop} \ 154 | --dataset-path {data_loc} \ 155 | --exp-id 1 \ 156 | --create-json 1 \ 157 | --json-path {workspace_loc}/Logs/granndis_cob \ 158 | --project granndis_ae_cob \ 159 | --debug \ 160 | --time-calc \ 161 | --no-eval \ 162 | """.format( 163 | env_loc = env_loc, 164 | runner_loc = runner_loc, 165 | data_loc = data_loc, 166 | workspace_loc = workspace_loc, 167 | gpus_per_server = gpus_per_server, 168 | dataset = dataset_dict['dataset'], 169 | sampler = sampler, 170 | model_type = model_type, 171 | dropout = dataset_dict['dropout'], 172 | lr = dataset_dict['lr'], 173 | n_layers = (n_layer + 1), 174 | hidden_size = hidden_size, 175 | master_addr = hosts[0], 176 | fanout = fanout, 177 | subgraph_hop = subgraph_hop 178 | ) 179 | 180 | if remove_tmp: 181 | shared_cmd = shared_cmd + """--remove-tmp """ 182 | remove_tmp = False 183 | 184 | # if dataset_dict['dataset'] == 'ogbn-papers100m': 185 | # shared_cmd = shared_cmd + """--partition-method random """ 186 | 187 | if check_intra_only: 188 | shared_cmd = shared_cmd + """--check-intra-only """ 189 | 190 | def __is_list_in_target(list, target): 191 | is_in = False 192 | for element in list: 193 | if element in target: 194 | is_in = True 195 | break 196 | return is_in 197 | 198 | if not __is_list_in_target(list=['papers'], target=dataset_dict['dataset']): 199 | shared_cmd = shared_cmd + """--inductive """ 200 | 201 | """ 202 | Pre-Cleaning 203 | """ 204 | kill_cmds = list() 205 | for i in range(num_server): 206 | kill_cmds.append('pkill -ef spawn') 207 | 208 | processes = [] 209 | mp.set_start_method('spawn', force=True) 210 | for host, runner, cmd in zip(hosts, runners, kill_cmds): 211 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 212 | p.start() 213 | processes.append(p) 214 | 215 | for p in processes: 216 | p.join() 217 | 218 | """ 219 | Run an Experiment :) 220 | """ 221 | 222 | cmds = list() 223 | for i in range(num_server): 224 | runner_cmd = shared_cmd + """--total-nodes %d """ % num_server 225 | runner_cmd += """--node-rank %d """ % i 226 | runner_cmd += "\n" 227 | cmds.append(runner_cmd) 228 | 229 | processes = [] 230 | mp.set_start_method('spawn', force=True) 231 | # Note that python zip only iterates for shorter list!!! 
232 | # so do not worry for running on not dedicated runners 233 | for host, runner, cmd in zip(hosts, runners, cmds): 234 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 235 | p.start() 236 | processes.append(p) 237 | 238 | for p in processes: 239 | p.join() 240 | 241 | """ 242 | Post-Cleaning 243 | """ 244 | kill_cmds = list() 245 | for i in range(num_server): 246 | kill_cmds.append('pkill -ef spawn') 247 | 248 | processes = [] 249 | mp.set_start_method('spawn', force=True) 250 | for host, runner, cmd in zip(hosts, runners, kill_cmds): 251 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 252 | p.start() 253 | processes.append(p) 254 | 255 | for p in processes: 256 | p.join() -------------------------------------------------------------------------------- /AE/ae1_tpt_flx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import torch.multiprocessing as mp 6 | import select 7 | import paramiko 8 | 9 | import configs 10 | 11 | class Commands: 12 | def __init__(self, retry_time=0): 13 | self.retry_time = retry_time 14 | pass 15 | 16 | def run_cmd(self, host_ip, command): 17 | i = 0 18 | while True: 19 | try: 20 | ssh = paramiko.SSHClient() 21 | ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) 22 | ssh.connect(host_ip) 23 | break 24 | except paramiko.AuthenticationException: 25 | print("Authentication failed when connecting to %s" % host_ip) 26 | sys.exit(1) 27 | except: 28 | print("Could not SSH to %s, waiting for it to start" % host_ip) 29 | i += 1 30 | time.sleep(2) 31 | 32 | # If we could not connect within time limit 33 | if i >= self.retry_time: 34 | print("Could not connect to %s. Giving up" % host_ip) 35 | sys.exit(1) 36 | # After connection is successful 37 | # Send the command 38 | 39 | # print command 40 | print('> ' + command) 41 | # execute commands 42 | stdin, stdout, stderr = ssh.exec_command(command) 43 | 44 | print(stderr.read().decode("euc-kr")) 45 | stdin.close() 46 | 47 | # TODO() : if an error is thrown, stop further rules and revert back changes 48 | # Wait for the command to terminate 49 | while not stdout.channel.exit_status_ready(): 50 | # Only print data if there is data to read in the channel 51 | if stdout.channel.recv_ready(): 52 | rl, wl, xl = select.select([ stdout.channel ], [ ], [ ], 0.0) 53 | if len(rl) > 0: 54 | tmp = stdout.channel.recv(1024) 55 | output = tmp.decode() 56 | print(output) 57 | 58 | # Close SSH connection 59 | ssh.close() 60 | return 61 | 62 | if __name__ == '__main__': 63 | 64 | """ 65 | Set Server Environment 66 | """ 67 | env_loc = configs.global_configs['env_loc'] 68 | runner_loc = configs.global_configs['our_runner_loc'] 69 | workspace_loc = configs.global_configs['workspace_loc'] 70 | data_loc = configs.global_configs['data_loc'] 71 | num_runners = configs.global_configs['num_runners'] 72 | gpus_per_server = configs.global_configs['gpus_per_server'] 73 | hosts = configs.global_configs['hosts'] 74 | assert len(hosts) == num_runners, 'our script requires a host per a runner' 75 | 76 | """ 77 | SSH Connection Class 78 | """ 79 | runners = list() 80 | for i in range(num_runners): 81 | runners.append(Commands()) 82 | 83 | 84 | """ 85 | Set Common Dataset Information 86 | """ 87 | dataset_list = list() 88 | arxiv_dict = { 89 | 'dataset': 'ogbn-arxiv', 90 | 'dropout': 0.5, 91 | 'lr': 0.01 92 | } 93 | reddit_dict = { 94 | 'dataset': 'reddit', 95 | 'dropout': 0.5, 96 | 'lr': 0.01 97 | } 98 | product_dict = { 99 | 'dataset': 
'ogbn-products', 100 | 'dropout': 0.3, 101 | 'lr': 0.003 102 | } 103 | 104 | dataset_list = [arxiv_dict, reddit_dict, product_dict] 105 | sampler_list = ['sage'] 106 | 107 | """ 108 | Iteration 109 | """ 110 | 111 | for dataset_dict in dataset_list: 112 | for n_layer in [3]: 113 | for sampler in sampler_list: 114 | for subgraph_hop in [n_layer]: 115 | if subgraph_hop == 0: 116 | fanout_list = [0] 117 | else: 118 | fanout_list = [-1] 119 | for fanout in fanout_list: 120 | remove_tmp = True 121 | for num_server in [2]: 122 | # for each new layer num... we need to repartition 123 | for check_intra_only in [False]: # we do not have scheme 2... yet :( 124 | for hidden_size in [64]: # 32, 128 125 | """ 126 | Make an Experiment 127 | """ 128 | if n_layer > 5: 129 | model_type = 'deepgcn' 130 | else: 131 | model_type = 'graphsage' 132 | shared_cmd = """{env_loc} {runner_loc} \ 133 | --dataset {dataset} \ 134 | --dropout {dropout} \ 135 | --sampler {sampler} \ 136 | --lr {lr} \ 137 | --n-partitions {gpus_per_server} \ 138 | --n-epochs 100 \ 139 | --model {model_type} \ 140 | --n-layers {n_layers} \ 141 | --n-linear 1 \ 142 | --n-hidden {hidden_size} \ 143 | --log-every 10 \ 144 | --bandwidth-aware \ 145 | --epoch-iter 1 \ 146 | --use-mask \ 147 | --master-addr {master_addr} \ 148 | --port 7524 \ 149 | --fix-seed \ 150 | --seed 7524 \ 151 | --backend nccl \ 152 | --fanout {fanout} \ 153 | --subgraph-hop {subgraph_hop} \ 154 | --dataset-path {data_loc} \ 155 | --exp-id 1 \ 156 | --create-json 1 \ 157 | --json-path {workspace_loc}/Logs/granndis_flx \ 158 | --project granndis_ae_flx \ 159 | --debug \ 160 | --time-calc \ 161 | --no-eval \ 162 | """.format( 163 | env_loc = env_loc, 164 | runner_loc = runner_loc, 165 | data_loc = data_loc, 166 | workspace_loc = workspace_loc, 167 | gpus_per_server = gpus_per_server, 168 | dataset = dataset_dict['dataset'], 169 | sampler = sampler, 170 | model_type = model_type, 171 | dropout = dataset_dict['dropout'], 172 | lr = dataset_dict['lr'], 173 | n_layers = (n_layer + 1), 174 | hidden_size = hidden_size, 175 | master_addr = hosts[0], 176 | fanout = fanout, 177 | subgraph_hop = subgraph_hop 178 | ) 179 | 180 | if remove_tmp: 181 | shared_cmd = shared_cmd + """--remove-tmp """ 182 | remove_tmp = False 183 | 184 | # if dataset_dict['dataset'] == 'ogbn-papers100m': 185 | # shared_cmd = shared_cmd + """--partition-method random """ 186 | 187 | if check_intra_only: 188 | shared_cmd = shared_cmd + """--check-intra-only """ 189 | 190 | def __is_list_in_target(list, target): 191 | is_in = False 192 | for element in list: 193 | if element in target: 194 | is_in = True 195 | break 196 | return is_in 197 | 198 | if not __is_list_in_target(list=['papers'], target=dataset_dict['dataset']): 199 | shared_cmd = shared_cmd + """--inductive """ 200 | 201 | """ 202 | Pre-Cleaning 203 | """ 204 | kill_cmds = list() 205 | for i in range(num_server): 206 | kill_cmds.append('pkill -ef spawn') 207 | 208 | processes = [] 209 | mp.set_start_method('spawn', force=True) 210 | for host, runner, cmd in zip(hosts, runners, kill_cmds): 211 | p = mp.Process(target=runner.run_cmd, args=(host, cmd)) 212 | p.start() 213 | processes.append(p) 214 | 215 | for p in processes: 216 | p.join() 217 | 218 | """ 219 | Run an Experiment :) 220 | """ 221 | 222 | cmds = list() 223 | for i in range(num_server): 224 | runner_cmd = shared_cmd + """--total-nodes %d """ % num_server 225 | runner_cmd += """--node-rank %d """ % i 226 | runner_cmd += "\n" 227 | cmds.append(runner_cmd) 228 | 229 | processes = [] 230 | 
mp.set_start_method('spawn', force=True)
231 | # Note that Python's zip stops at the end of the shorter list,
232 | # so it is safe even when not all runners are dedicated to this run
233 | for host, runner, cmd in zip(hosts, runners, cmds):
234 | p = mp.Process(target=runner.run_cmd, args=(host, cmd))
235 | p.start()
236 | processes.append(p)
237 | 
238 | for p in processes:
239 | p.join()
240 | 
241 | """
242 | Post-Cleaning
243 | """
244 | kill_cmds = list()
245 | for i in range(num_server):
246 | kill_cmds.append('pkill -ef spawn')
247 | 
248 | processes = []
249 | mp.set_start_method('spawn', force=True)
250 | for host, runner, cmd in zip(hosts, runners, kill_cmds):
251 | p = mp.Process(target=runner.run_cmd, args=(host, cmd))
252 | p.start()
253 | processes.append(p)
254 | 
255 | for p in processes:
256 | p.join()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [PACT'24] GraNNDis: Fast Distributed Graph Neural Network Training Framework for Multi-Server Clusters [![DOI](https://zenodo.org/badge/821791396.svg)](https://zenodo.org/doi/10.5281/zenodo.12677841)
2 | This repository is the artifact of GraNNDis for PACT'24 artifact evaluation (AE).
3 | 
4 | Note that this repo provides state-of-the-art performance on distributed full-batch (full-graph) GNN training even without the GraNNDis schemes, thanks to our own NCCL-based optimizations.
5 | Our implementation is mainly based on the original code of [PipeGCN](https://github.com/GATECH-EIC/PipeGCN).
6 | For details, please refer to our PACT'24 paper ([Author Copy](PACT24_GraNNDis_Author_Copy.pdf), [Proceeding](https://doi.org/10.1145/3656019.3676892)).
7 | 
8 | This artifact earned the following badges:
9 | 
10 | - Artifact Available
11 | - Artifact Evaluated - Reusable
12 | - Results Reproduced
13 | 
14 | ## Getting Started
15 | ### 1. SW Dependencies and Setup
16 | - Prerequisite
17 | - CUDA/cuDNN 11.8 setting (make sure the CUDA paths are included)
18 | - Anaconda setting
19 | - NFS environment with two or more servers, each server having multiple GPUs.
20 | - Servers must be accessible by SSH connection (e.g., ssh [user]@[server]).
21 | ```
22 | # adding the following two lines to ~/.bashrc will include the CUDA paths
23 | export PATH="/usr/local/cuda-11.8/bin:$PATH"
24 | export LD_LIBRARY_PATH="/usr/local/cuda-11.8/lib64:$LD_LIBRARY_PATH"
25 | nvcc -V # check the CUDA version using nvcc
26 | ```
27 | - Main SW Dependencies
28 | - Python 3.10
29 | - PyTorch 2.1.0
30 | - CUDA 11.8
31 | - DGL 2.1.x
32 | 
33 | - Setup
34 | ```
35 | conda update -y conda # update conda
36 | conda create -n granndis_ae python=3.10 -y # create conda env
37 | conda activate granndis_ae # activate conda env
38 | conda install -c conda-forge -c pytorch -c nvidia -c dglteam/label/th21_cu118 --file conda-requirements.txt -y # install conda packages
39 | pip install -r pip-requirements.txt # install pip packages
40 | ```
41 | 
42 | ### 2. HW Dependencies
43 | - Multi-server environment in which each server is equipped with multiple GPUs.
44 | - Enough system memory (e.g., 256GB) is needed for the artifact evaluation.
45 | - The internal server interconnect (e.g., NVLink) should be much faster than the external server interconnect (e.g., 10G Ethernet).
46 | 
47 | ### 3. Sample Dataset Preparation & Single Server Test
48 | For the artifact evaluation, we will use three sample datasets (i.e., Arxiv, Reddit, and Products), which are widely used and easily accessible.
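Before running the test script below, it can help to sanity-check the freshly created environment. The following is a minimal sketch (assuming the `granndis_ae` conda env above is active; exact patch versions may differ):
```
# quick sanity check of the granndis_ae environment
import torch
import dgl

print('PyTorch :', torch.__version__)         # expect 2.1.x
print('DGL     :', dgl.__version__)           # expect 2.1.x
print('CUDA    :', torch.version.cuda)        # expect 11.8
print('GPUs    :', torch.cuda.device_count())
assert torch.cuda.is_available(), 'CUDA must be visible for the AE runs'
```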
49 | We provide a test script (`Codes/brief_masking_test.sh`) to download the datasets and test the created environment.
50 | It tests GraNNDis with the default settings.
51 | ```
52 | cd Codes
53 | chmod +x brief_masking_test.sh
54 | ./brief_masking_test.sh
55 | ```
56 | While running the script, you may be asked to type `y` to approve downloading a dataset.
57 | If the tests finish successfully, the logs will be saved in `Codes/masking_test_logs/`.
58 | The expected shell results are as follows:
59 | ```
60 | ... (Omitted) ....
61 | Process 000 | Epoch 00099 | Time(s) 0.0145 | Comm(s) 0.0049 | Reduce(s) 0.0000 | Loss 0.1598
62 | json logs will be saved in ./masking_test_logs
63 | ... (Omitted) ....
64 | ============================== Speed result summary ==============================
65 | train_duration : 0.014522718265652657
66 | communication_duration : 0.005530537571758032
67 | reduce duration : 4.990156412532087e-07
68 | loss : 0.5080833435058594
69 | ============================================================
70 | ... (Omitted) ....
71 | Rank 0 successfully created json log file.
72 | Rank 0 successfully created summary json log file.
73 | ... (Omitted) ....
74 | ```
75 | 
76 | ## Fast Reproducing of Main Results
77 | 
78 | Some users may not be familiar with the distributed training procedure, so we provide distributed experiment launchers at `AE/*.py`.
79 | Before reproducing, users must change the configuration fields in the config file (`AE/configs.py`) to match their own cluster.
80 | ```
81 | global_configs = {
82 |     'env_loc': '/nfs/home/ae/anaconda3/envs/granndis_ae/bin/python',
83 |     'runner_loc': '/nfs/home/ae/GraNNDis_Artifact/Codes/main.py',
84 |     'our_runner_loc': '/nfs/home/ae/GraNNDis_Artifact/Codes/our_main.py',
85 |     'workspace_loc': '/nfs/home/ae/GraNNDis_Artifact/',
86 |     'data_loc': '~/datasets/granndis_ae/',
87 |     'num_runners': 2,
88 |     'gpus_per_server': 4,
89 |     'hosts': ['192.168.0.5', '192.168.0.6']
90 | }
91 | ```
92 | 
93 | After modification, just run the following commands, which produce the artifact evaluation results.
94 | ```
95 | sh run_ae.sh # run AE scripts
96 | sh parse_ae.sh # parse AE results
97 | ```
98 | Then, the results will be saved at `AE*_results.log`.
99 | 
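For readers who want to see what the launchers actually do before running them: each `AE/ae1_tpt_*.py` script builds one shared command string and then appends per-server rank arguments before dispatching it over SSH. A condensed sketch of that pattern (the abbreviated `shared_cmd` here is illustrative; see the AE scripts for the full argument list):
```
# condensed per-server command construction, mirroring AE/ae1_tpt_*.py
num_server = 2
hosts = ['192.168.0.5', '192.168.0.6']
shared_cmd = 'python Codes/our_main.py --dataset ogbn-arxiv '  # abbreviated

cmds = []
for i in range(num_server):
    # every server receives the same shared command plus its own rank arguments
    runner_cmd = shared_cmd + '--total-nodes %d ' % num_server
    runner_cmd += '--node-rank %d ' % i
    cmds.append(runner_cmd)

for host, cmd in zip(hosts, cmds):  # zip stops at the shorter list
    print('launch on %s: %s' % (host, cmd))
```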
100 | ### Example Test Environment
101 | The example test cluster has two servers, each with four NVIDIA RTX A6000 GPUs.
102 | GPUs within a server are connected via NVLink Bridge, and the servers are connected via 10GbE.
103 | 
104 | ### Expected Trend for the Artifact
105 | As the example test cluster has a 10GbE inter-server connection, the overall speedup can be higher (up to around 6x) than in the paper (up to around 3x), which used 32Gbps InfiniBand.
106 | 
107 | FLX, CoB (with SAGE sampling), and EAS all generally show a significant speedup over the baseline optimized full-batch training because GraNNDis minimizes the slow external server communication (AE1). EAS (FLX-EAS) is expected to show more speedup than FLX, especially on larger datasets such as Products. EAS usually shows a higher speedup than CoB (especially on larger datasets) while providing comparable accuracy, as shown in AE2 (accuracy result).
108 | 
109 | Please note that the results can fluctuate when the inter-server connection is shared with the cluster's NFS file system. In this case, running multiple trials will reveal the trend mentioned above.
110 | 
111 | The following are example results from running the above procedure on the authors' remote machine.
112 | 
113 | ### AE1. Throughput Results (Flexible Preloading (FLX), Cooperative Batching (CoB), and Expansion-Aware Sampling (EAS))
114 | The results show that the optimized full-batch training baseline (Opt_FB) suffers from communication overhead, while FLX/CoB address this issue through server-wise preloading.
115 | EAS further accelerates the training through server boundary-aware sampling.
116 | This trend becomes more pronounced on larger datasets (Reddit and Products).
117 | 
118 | ```
119 | +-------------------------------------------------------+
120 | |              Throughput Results for Arxiv             |
121 | +--------+------------------+-----------------+---------+
122 | | Method | Total Time (sec) | Comm Time (sec) | Speedup |
123 | +--------+------------------+-----------------+---------+
124 | | Opt_FB |      15.40       |       9.85      |   1.00  |
125 | |  FLX   |       8.60       |       2.37      |   1.79  |
126 | |  CoB   |       8.78       |       2.58      |   1.75  |
127 | |  EAS   |      11.67       |       3.70      |   1.32  |
128 | +--------+------------------+-----------------+---------+
129 | +-------------------------------------------------------+
130 | |             Throughput Results for Reddit             |
131 | +--------+------------------+-----------------+---------+
132 | | Method | Total Time (sec) | Comm Time (sec) | Speedup |
133 | +--------+------------------+-----------------+---------+
134 | | Opt_FB |      449.27      |      422.35     |   1.00  |
135 | |  FLX   |      87.55       |      49.67      |   5.13  |
136 | |  CoB   |      90.44       |      49.98      |   4.97  |
137 | |  EAS   |      75.16       |      40.34      |   5.98  |
138 | +--------+------------------+-----------------+---------+
139 | +-------------------------------------------------------+
140 | |            Throughput Results for Products            |
141 | +--------+------------------+-----------------+---------+
142 | | Method | Total Time (sec) | Comm Time (sec) | Speedup |
143 | +--------+------------------+-----------------+---------+
144 | | Opt_FB |      79.67       |      69.15      |   1.00  |
145 | |  FLX   |      20.03       |       6.36      |   3.98  |
146 | |  CoB   |      21.85       |       8.33      |   3.65  |
147 | |  EAS   |      18.23       |       5.78      |   4.37  |
148 | +--------+------------------+-----------------+---------+
149 | ```
150 | 
151 | ### AE2. Accuracy Results (Expansion-Aware Sampling (EAS))
152 | As EAS samples only the server-boundary vertices (which is where its acceleration comes from), it successfully achieves accuracy comparable to the original full-batch training.
153 | ```
154 | +--------------------------------------+
155 | | Accuracy Comparison (FB vs. FLX-EAS) |
156 | +--------+-------+--------+------------+
157 | | Method | Arxiv | Reddit |  Products  |
158 | +--------+-------+--------+------------+
159 | |   FB   |  0.69 |  0.96  |    0.76    |
160 | |  EAS   |  0.69 |  0.96  |    0.76    |
161 | +--------+-------+--------+------------+
162 | ```
163 | 
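For clarity, the Speedup column in the AE1 tables above is the baseline's total time divided by each method's total time; for Reddit, for example, 449.27 / 87.55 ≈ 5.13. A quick recomputation of the Reddit rows:
```
# recomputing the Speedup column from the Reddit throughput table
opt_fb_total = 449.27
for method, total in [('FLX', 87.55), ('CoB', 90.44), ('EAS', 75.16)]:
    print('%s speedup: %.2f' % (method, opt_fb_total / total))
# FLX speedup: 5.13 / CoB speedup: 4.97 / EAS speedup: 5.98
```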
164 | ## Additional) GraNNDis Arguments & Distributed Launch
165 | 
166 | For distributed training, we need to set the following common arguments.
167 | ```
168 | --n-partitions 4 # set n-partitions to the #GPUs per server (for GraNNDis only)
169 | --total-nodes 1 # #servers to conduct training on
170 | ```
171 | 
172 | ### 1. Optimized Baseline Full-Batch (Full-Graph) Training
173 | The sample argument script for running optimized full-batch training is provided at `Codes/brief_opt_baseline_test.sh`.
174 | The main arguments are as follows:
175 | ```
176 | --n-layers 4 # (#conv layers + #linear layers)
177 | --n-linear 1 # (#linear layers)
178 | --model graphsage # model type
179 | --dataset-path /dataset/granndis_ae/ # dataset path
180 | ```
181 | 
182 | ### 2. GraNNDis Options
183 | 
184 | #### Flexible Preloading
185 | The sample argument script for running flexible preloading is provided at `Codes/brief_masking_test.sh`.
186 | The main arguments are as follows:
187 | ```
188 | --bandwidth-aware # turn on server-wise preloading
189 | --subgraph-hop 3 # (#conv layers)
190 | --fanout -1 # do not apply sampling and utilize the whole information
191 | --sampler sage # use node-wise sampling
192 | --use-mask # use 1-hop graph masking to support the intact full-batch/mini-batch training algorithm
193 | ```
194 | 
195 | #### Cooperative Batching
196 | The sample argument script for running cooperative batching is provided at `Codes/brief_masking_test.sh`.
197 | The main arguments are as follows:
198 | ```
199 | --bandwidth-aware # turn on server-wise preloading
200 | --subgraph-hop 3 # (#conv layers)
201 | --fanout 25 # set the sage sampling fanout (default: 25)
202 | --sampler sage # use node-wise sampling
203 | --epoch-iter 1 # #iters per epoch; use a larger value if you need finer-grained mini-batches
204 | --use-mask # use 1-hop graph masking to support the intact full-batch/mini-batch training algorithm
205 | ```
206 | 
207 | #### Expansion-Aware Sampling
208 | The sample argument script for running expansion-aware sampling is provided at `Codes/brief_sampling_test.sh`.
209 | The main arguments are as follows:
210 | ```
211 | --bandwidth-aware # turn on server-wise preloading
212 | --subgraph-hop 1 # sampling hop
213 | --fanout 15 # sampling fanout
214 | --sampler sage # use node-wise sampling
215 | --use-mask # use 1-hop graph masking to express dependency
216 | ```
217 | 
218 | ### 3. Distributed Launch
219 | We provide a simple distributed experiment runner interface for users who are unfamiliar with launching distributed training.
220 | The interface utilizes SSH for the distributed launch, as sketched below.
221 | The launchers using this interface are located at `AE/*.py`.
222 | Users can modify these launchers for their own use.
223 | 
224 | 
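The core of that interface is a small paramiko-based SSH client that runs one command per host. A stripped-down sketch of the pattern (retry and streaming-output logic omitted; see the `Commands` class in the AE scripts for the full version):
```
# minimal sketch of the SSH launch pattern used by the AE/*.py launchers
import paramiko

def run_remote(host_ip, command):
    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    ssh.connect(host_ip)  # assumes key-based SSH auth, per the prerequisites
    print('> ' + command)
    stdin, stdout, stderr = ssh.exec_command(command)
    print(stdout.read().decode())
    print(stderr.read().decode())
    ssh.close()
```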
225 | ## Citation
226 | ```
227 | @inproceedings{song2024granndis,
228 |   title={{GraNNDis}: Fast Distributed Graph Neural Network Training Framework for Multi-Server Clusters},
229 |   author={Song, Jaeyong and Jang, Hongsun and Lim, Hunseong and Jung, Jaewon and Kim, Youngsok and Lee, Jinho},
230 |   booktitle={The 33rd International Conference on Parallel Architectures and Compilation Techniques (PACT 2024)},
231 |   year={2024}
232 | }
233 | ```
234 | 
235 | ## License
236 | For the code taken from PipeGCN, we follow its original license (MIT).
237 | All other code is also released under the MIT license.
238 | 
239 | ## MISC
240 | 
241 | For a further breakdown of internal/external communication time, users can utilize the `--check-intra-only` option. This option ignores external server communication, so users can measure the internal server communication time in isolation. Users can also further reduce the one-hop graph masking overhead by removing the `--use-mask` option, but training then no longer follows the intact algorithm.
--------------------------------------------------------------------------------
/Codes/helper/mapper.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING
2 | 
3 | import torch
4 | import torch.distributed as dist
5 | import dgl
6 | 
7 | if TYPE_CHECKING:
8 | from helper.parser import Args
9 | 
10 | class Mapper:
11 | """Original full-graph id <-> Server-wise macrobatch graph id <-> GPU-wise minibatch graph id
12 | Once you register the mapping, you can find full_id or gpu_id in O(1) time.
13 | 
14 | This scheme has a total of 5 id types - full-id, server-id, pseudo-server-id, gpu-id, and remapped gpu-id.
15 | I'll refer to full-id as fid, to server-id and pseudo-server-id as sid, and to gpu-id and remapped-gpu-id as gid.
16 | When a sampler samples the full-graph, full-ids are translated into server-ids.
17 | When the server-graph is partitioned by METIS, server-ids are first translated into pseudo-server-ids,
18 | then the pseudo-server-graph is partitioned into many gpu-graphs.
19 | Finally, when the gpu-graph is transformed into the intra_graph, the gpu-graph remaps itself.
20 | 
21 | To communicate which vertex should be shared (to implement historical embeddings), we have to know the full-ids of
22 | the given boundary nodes, then send/recv to/from the corresponding rank,
23 | then map the full-ids back to the corresponding gpu's (remapped) gpu_ids.
24 | """
25 | def __init__(self, args: 'Args') -> None:
26 | self._full_graph_size: int = -1
27 | self.node_rank = args.node_rank
28 | self.local_rank = args.local_rank
29 | self.parts_per_node = args.parts_per_node
30 | 
31 | self.args = args
32 | 
33 | self._sid_from_fid: list[torch.Tensor] = [None] * args.epoch_iter
34 | self._fid_from_sid: list[torch.Tensor] = [None] * args.epoch_iter
35 | self._gid_from_fid: list[torch.Tensor] = [None] * args.epoch_iter
36 | self._fid_from_gid: list[torch.Tensor] = [None] * args.epoch_iter
37 | 
38 | self._rank_from_fid: torch.Tensor = None
39 | 
40 | if __debug__:
41 | if self.local_rank == 0:
42 | self.seeds: list[torch.Tensor]
43 | 
44 | 
45 | def synchronize(self):
46 | """Broadcast the 0-th process's full graph size, server_from_full, and full_from_server.
47 | """
48 | args = self.args
49 | full_graph_size = torch.tensor(self._full_graph_size, device='cuda')
50 | dist.broadcast(full_graph_size, src=0, group=args.device_group)
51 | self._full_graph_size = full_graph_size.item()
52 | self._rank_from_fid = torch.zeros(self._full_graph_size, dtype=torch.int32, device='cuda') - 1
53 | 
54 | for iter in range(args.epoch_iter):
55 | sid_from_fid_size = torch.tensor(self._sid_from_fid[iter].shape[0] if self.local_rank == 0 else -1, device='cuda')
56 | dist.broadcast(sid_from_fid_size, src=0, group=args.device_group)
57 | print(self.local_rank, sid_from_fid_size)
58 | if self.local_rank != 0:
59 | self._sid_from_fid[iter] = torch.empty(sid_from_fid_size, dtype=torch.int32, device='cuda')
60 | print(self.local_rank, self._sid_from_fid[iter].shape, self._sid_from_fid[iter])
61 | dist.broadcast(self._sid_from_fid[iter], src=0, group=args.device_group)
62 | dist.barrier(args.device_group)
63 | 
64 | fid_from_sid_size = torch.tensor(self._fid_from_sid[iter].shape[0] if self.local_rank == 0 else -1, device='cuda')
65 | dist.broadcast(fid_from_sid_size, src=0, group=args.device_group)
66 | print(self.local_rank, fid_from_sid_size)
67 | if self.local_rank != 0:
68 | self._fid_from_sid[iter] = torch.empty(fid_from_sid_size, dtype=torch.int32, device='cuda')
69 | print(self.local_rank, self._fid_from_sid[iter].shape, self._fid_from_sid[iter])
70 | dist.broadcast(self._fid_from_sid[iter], src=0, group=args.device_group)
71 | dist.barrier(args.device_group)
72 | 
73 | 
74 | def register_full_graph_size(self, size: int):
75 | self._full_graph_size = size
76 | 
77 | 
78 | def register_fid_sid(self, iter: int, fid_from_sid: torch.Tensor, sid_from_psid: torch.Tensor) -> None:
79 | """Register mapping between full_ids and server_ids. This roughly takes O(V) time.
80 | 
81 | Parameters
82 | ----------
83 | iter : int
84 | Current iter value.
85 | fid_from_sid : Tensor
86 | Ids of the original full-graph. You can find this tensor by g.ndata[dgl.NID].
87 | sid_from_psid : Tensor 88 | """ 89 | fid_from_sid = fid_from_sid.to(torch.int32) 90 | fid_from_psid = _chain(fid_from_sid, sid_from_psid) 91 | self._fid_from_sid[iter] = fid_from_psid 92 | self._sid_from_fid[iter] = _invert(fid_from_psid) 93 | 94 | self._fid_from_sid[iter] = self._fid_from_sid[iter].cuda() 95 | self._sid_from_fid[iter] = self._sid_from_fid[iter].cuda() 96 | 97 | 98 | def register_sid_gid(self, iter: int, node_dict: dict[str, torch.Tensor]) -> None: 99 | """Register mapping between server_ids and current gpu's gpu_ids. 100 | Before running this method, please first register a mapping between full_ids and server_ids. 101 | 102 | Parameters 103 | ---------- 104 | iter : int 105 | Current iter value. 106 | node_dict : dict[str, Tensor] 107 | """ 108 | assert self._sid_from_fid[iter] is not None 109 | assert self._fid_from_sid[iter] is not None 110 | 111 | sid_from_gid = node_dict[dgl.NID] 112 | self._fid_from_gid[iter] = _chain(self._fid_from_sid[iter], sid_from_gid) 113 | 114 | gid_from_sid = _invert(sid_from_gid) 115 | self._gid_from_fid[iter] = _chain(gid_from_sid, self._sid_from_fid[iter], naive=True) 116 | 117 | self._fid_from_gid[iter] = self._fid_from_gid[iter].cuda() 118 | self._gid_from_fid[iter] = self._gid_from_fid[iter].cuda() 119 | 120 | 121 | def register_and_share_core_nodes(self, iter: int, seed_node: torch.Tensor) -> None: 122 | """Register core nodes and all-gather to all workers (including other server's gpus). 123 | GPU's core nodes are (seed nodes of the GPU's partition) ∩ (inner nodes of the GPU). 124 | In other words, every full nodes are uniquely assigned as core to each GPUs. 125 | """ 126 | args = self.args 127 | rank, size = dist.get_rank(), dist.get_world_size() 128 | 129 | core_gids = seed_node.nonzero().flatten() 130 | core_fids = self._fid_from_gid[iter][core_gids] 131 | assert torch.all(core_fids != -1), f"{sum(core_fids == -1)=}" 132 | core_nodes_num = torch.tensor(len(core_fids), dtype=torch.int32, device='cuda') 133 | 134 | core_nodes_num_list = [torch.empty(1, dtype=torch.int32, device='cuda') for _ in range(size)] 135 | dist.all_gather(core_nodes_num_list, core_nodes_num) 136 | dist.barrier() 137 | 138 | core_full_ids_list = [torch.empty(core_nodes_num_list[r], dtype=torch.int32, device='cuda') for r in range(size)] 139 | dist.all_gather(core_full_ids_list, core_fids) 140 | dist.barrier() 141 | 142 | 143 | if __debug__: 144 | print(self.local_rank, f"{core_nodes_num=}", f"{core_fids=}") 145 | 146 | 147 | if self.local_rank == 0: 148 | assert self.seeds[iter] == set(torch.cat(core_full_ids_list).tolist()), f"{len(self.seeds[iter] & set(torch.cat(core_full_ids_list).tolist()))}" 149 | 150 | for r in range(size): 151 | assert torch.all(self._rank_from_fid[core_full_ids_list[r]] == -1), f"{r} {sum(self._rank_from_fid[core_full_ids_list[r]] != -1)=}" 152 | self._rank_from_fid[core_full_ids_list[r]] = r 153 | 154 | if self.local_rank == 0 and iter == args.epoch_iter - 1: 155 | assert sum(self._rank_from_fid != -1) == self.train_size, f"{sum(self._rank_from_fid != -1)=}, {self.train_size=}" 156 | 157 | 158 | def remap_gpu_and_gpu(self, iter: int, orig_from_new: torch.Tensor) -> None: 159 | assert self._gid_from_fid[iter] is not None 160 | assert self._fid_from_gid[iter] is not None 161 | 162 | self._fid_from_gid[iter] = _chain(self._fid_from_gid[iter], orig_from_new) 163 | 164 | new_from_orig = _invert(orig_from_new) 165 | self._gid_from_fid[iter] = _chain(new_from_orig, self._gid_from_fid[iter], naive=True) 166 | 167 | 168 | def 
get_full_ids(self, iter: int, gpu_ids: torch.Tensor) -> torch.Tensor: 169 | """Translates gpu_ids into full_ids. 170 | This method takes O(1) time. 171 | Make sure to register and remap the mapping between full_ids and gpu_ids. 172 | 173 | Parameters 174 | ---------- 175 | iter : int 176 | Current iter value. 177 | gpu_ids : int | list[int] | IntTensor 178 | 179 | Returns 180 | ------- 181 | full_ids : Tensor 182 | """ 183 | return self._fid_from_gid[iter][gpu_ids] 184 | 185 | 186 | def get_gpu_ids(self, iter: int, full_ids: torch.Tensor): 187 | """Translates full_ids into gpu_ids. 188 | This method takes O(1) time. 189 | Translating non-gpu full_ids into gpu_ids is undefined behavior (this could fire IndexError or return -1 tensor). 190 | So make sure full_ids are valid ids in the current gpu. 191 | Also make sure to register and remap the mapping between full_ids and gpu_ids. 192 | 193 | Parameters 194 | ---------- 195 | iter : int 196 | Current iter value. 197 | full_ids : int | list[int] | IntTensor 198 | full_ids must be valid ids in the current gpu. 199 | 200 | Returns 201 | ------- 202 | gpu_ids : Tensor 203 | """ 204 | return self._gid_from_fid[iter][full_ids] 205 | 206 | 207 | def get_rank(self, full_ids: torch.Tensor): 208 | return self._rank_from_fid[full_ids] 209 | 210 | 211 | def _chain(c_from_b: torch.Tensor, b_from_a: torch.Tensor, naive: bool = False) -> torch.Tensor: 212 | """Basically same as c_from_b[b_from_a], but this method ignores -1 from b_from_a. 213 | 214 | Parameters 215 | ---------- 216 | c_from_b : Tensor 217 | b_from_a : Tensor 218 | naive : bool 219 | If naive is true, then error does not occur when b_from_a is out of bounds from c_from_b. 220 | This option is used when the domain of b_from_a is bigger than the domain of c_from_b. 221 | 222 | Returns 223 | ------- 224 | c_from_a : Tensor 225 | 226 | Examples 227 | -------- 228 | When |B| > |A|: (ex: full <- server <- gpu) 229 | 230 | >>> c_b = torch.tensor([3, 2, 5, 4, 0]) 231 | >>> b_a = torch.tensor([1, 3, 0]) 232 | >>> c_a = _chain(c_b, b_a) 233 | >>> c_a 234 | tensor([2, 4, 3]) 235 | 236 | When |B| < |A|: (like gpu <- server <- full) 237 | 238 | >>> c_b = torch.tensor([1, 3, 0]) 239 | >>> b_a = torch.tensor([3, 2, -1, 4, 0]) 240 | >>> c_a = _chain(c_b, b_a, naive=True) 241 | >>> c_a 242 | tensor([-1, 0, -1, -1, 1]) 243 | 244 | In this case, c_a[0] is undefined because b_a[0] is 3 but 3 is not in the domain of c_b. 245 | c_a[1] is 0 because b_a[1] is 2 and c_b[2] is 0. 246 | c_a[2] is undefined because b_a[2] is undefined. 247 | """ 248 | len_a = len(b_from_a) 249 | dtype = c_from_b.dtype 250 | c_from_a = torch.zeros(len_a, dtype=dtype, device=c_from_b.device) - 1 251 | 252 | b_mask = b_from_a != -1 253 | if naive: 254 | safe_b_mask = b_from_a < len(c_from_b) 255 | b_mask &= safe_b_mask 256 | bs = b_from_a[b_mask] 257 | c_from_a[b_mask] = c_from_b[bs] 258 | 259 | """Above codes are equivalent to: 260 | for a, b in enumerate(b_from_a): 261 | if b != -1 (and b < len(c_from_b)): 262 | c_from_a[a] = c_from_b[b] 263 | """ 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | return c_from_a 273 | 274 | 275 | def _map_one_by_one(in_ids: torch.Tensor, out_ids: torch.Tensor) -> torch.Tensor: 276 | """Returns mapper tensor that behaves like mapper[in_ids] = out_ids. 277 | If in_ids has a hole(e.g. [0, 1, 2, 4, 5]), then out_id corresponds to the hole(e.g. mapper[3]) becomes -1. 278 | 279 | Parameters 280 | ---------- 281 | in_ids : Tensor 282 | Each ids must be unique. 
283 | out_ids : Tensor 284 | Each ids must be unique. 285 | 286 | Returns 287 | ------- 288 | mapper : Tensor 289 | 290 | Examples 291 | -------- 292 | >>> t = torch.tensor([3, 2, 6, 4, 0]) 293 | >>> s = torch.tensor([4, 2, 1, 0, 5]) 294 | >>> mapper = _map_one_by_one(t, s) 295 | >>> mapper 296 | tensor([ 5, -1, 2, 4, 0, -1, 1]) 297 | >>> mapper[t] 298 | tensor([4, 2, 1, 0, 5]) 299 | """ 300 | assert in_ids.ndim == out_ids.ndim == 1 301 | assert in_ids.dtype == out_ids.dtype 302 | assert len(in_ids) == len(out_ids) 303 | l = in_ids.max().item() + 1 304 | dtype = out_ids.dtype 305 | mapper = torch.zeros(l, dtype=dtype, device=in_ids.device) - 1 306 | for i, o in zip(in_ids, out_ids): 307 | if i != -1: 308 | mapper[i] = o 309 | return mapper 310 | 311 | 312 | def _invert(mapper: torch.Tensor) -> torch.Tensor: 313 | """Returns inverted mapper. 314 | 315 | Parameters 316 | ---------- 317 | mapper : Tensor 318 | Each ids must be unique. 319 | 320 | Returns 321 | ------- 322 | mapper : Tensor 323 | 324 | Examples 325 | -------- 326 | >>> t = torch.tensor([1, 3, 4, 0]) 327 | >>> s = _invert(t) 328 | >>> s 329 | tensor([3, 0, -1, 1, 2]) 330 | >>> _invert(s) 331 | tensor([1, 3, 4, 0]) 332 | """ 333 | in_ids = torch.arange(mapper.shape[0], dtype=mapper.dtype, device=mapper.device) 334 | return _map_one_by_one(mapper, in_ids) 335 | -------------------------------------------------------------------------------- /Codes/helper/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import warnings 4 | import argparse 5 | warnings.filterwarnings("ignore") 6 | 7 | import scipy 8 | import torch 9 | import dgl 10 | from dgl.data import DGLDataset 11 | from dgl.data import RedditDataset, YelpDataset 12 | from dgl.distributed import partition_graph 13 | import torch.distributed as dist 14 | import time 15 | from contextlib import contextmanager 16 | from ogb.nodeproppred import DglNodePropPredDataset 17 | import json 18 | import numpy as np 19 | from sklearn.preprocessing import StandardScaler 20 | 21 | class IGB260M(object): 22 | def __init__(self, root: str, size: str, in_memory: int, \ 23 | classes: int, synthetic: int): 24 | self.dir = root 25 | self.size = size 26 | self.synthetic = synthetic 27 | self.in_memory = in_memory 28 | self.num_classes = classes 29 | 30 | def num_nodes(self): 31 | if self.size == 'experimental': 32 | return 100000 33 | elif self.size == 'small': 34 | return 1000000 35 | elif self.size == 'medium': 36 | return 10000000 37 | elif self.size == 'large': 38 | return 100000000 39 | elif self.size == 'full': 40 | return 269346174 41 | 42 | @property 43 | def paper_feat(self) -> np.ndarray: 44 | num_nodes = self.num_nodes() 45 | 46 | if self.size == 'large' or self.size == 'full': 47 | path = osp.join(self.dir, 'full', 'processed', 'paper', 'node_feat.npy') 48 | emb = np.memmap(path, dtype='float32', mode='r', shape=(num_nodes,1024)) 49 | else: 50 | path = osp.join(self.dir, self.size, 'processed', 'paper', 'node_feat.npy') 51 | if self.synthetic: 52 | emb = np.random.rand(num_nodes, 1024).astype('f') 53 | else: 54 | if self.in_memory: 55 | emb = np.load(path) 56 | else: 57 | emb = np.load(path, mmap_mode='r') 58 | 59 | return emb 60 | 61 | @property 62 | def paper_label(self) -> np.ndarray: 63 | 64 | if self.size == 'large' or self.size == 'full': 65 | num_nodes = self.num_nodes() 66 | if self.num_classes == 19: 67 | path = osp.join(self.dir, 'full', 'processed', 'paper', 'node_label_19.npy') 68 | 
node_labels = np.memmap(path, dtype='float32', mode='r', shape=(num_nodes)) 69 | 70 | else: 71 | path = osp.join(self.dir, 'full', 'processed', 'paper', 'node_label_2K.npy') 72 | node_labels = np.memmap(path, dtype='float32', mode='r', shape=(num_nodes)) 73 | 74 | 75 | else: 76 | if self.num_classes == 19: 77 | path = osp.join(self.dir, self.size, 'processed', 'paper', 'node_label_19.npy') 78 | else: 79 | path = osp.join(self.dir, self.size, 'processed', 'paper', 'node_label_2K.npy') 80 | if self.in_memory: 81 | node_labels = np.load(path) 82 | else: 83 | node_labels = np.load(path, mmap_mode='r') 84 | return node_labels 85 | 86 | @property 87 | def paper_edge(self) -> np.ndarray: 88 | path = osp.join(self.dir, self.size, 'processed', 'paper__cites__paper', 'edge_index.npy') 89 | 90 | 91 | if self.in_memory: 92 | return np.load(path) 93 | else: 94 | return np.load(path, mmap_mode='r') 95 | 96 | class IGB260MDGLDataset(DGLDataset): 97 | def __init__(self, args): 98 | self.dir = args.path 99 | self.args = args 100 | super().__init__(name='IGB260MDGLDataset') 101 | 102 | def process(self): 103 | dataset = IGB260M(root=self.dir, size=self.args.dataset_size, in_memory=self.args.in_memory, \ 104 | classes=self.args.num_classes, synthetic=self.args.synthetic) 105 | 106 | node_features = torch.from_numpy(dataset.paper_feat) 107 | node_edges = torch.from_numpy(dataset.paper_edge) 108 | node_labels = torch.from_numpy(dataset.paper_label).to(torch.long) 109 | 110 | self.graph = dgl.graph((node_edges[:, 0],node_edges[:, 1]), num_nodes=node_features.shape[0]) 111 | self.graph.ndata['feat'] = node_features 112 | self.graph.ndata['label'] = node_labels 113 | 114 | self.graph = dgl.remove_self_loop(self.graph) 115 | self.graph = dgl.add_self_loop(self.graph) 116 | 117 | if self.args.dataset_size == 'full': 118 | 119 | if self.args.num_classes == 19: 120 | n_labeled_idx = 227130858 121 | else: 122 | n_labeled_idx = 157675969 123 | 124 | n_nodes = node_features.shape[0] 125 | n_train = int(n_labeled_idx * 0.6) 126 | n_val = int(n_labeled_idx * 0.2) 127 | 128 | train_mask = torch.zeros(n_nodes, dtype=torch.bool) 129 | val_mask = torch.zeros(n_nodes, dtype=torch.bool) 130 | test_mask = torch.zeros(n_nodes, dtype=torch.bool) 131 | 132 | train_mask[:n_train] = True 133 | val_mask[n_train:n_train + n_val] = True 134 | test_mask[n_train + n_val:n_labeled_idx] = True 135 | 136 | self.graph.ndata['train_mask'] = train_mask 137 | self.graph.ndata['val_mask'] = val_mask 138 | self.graph.ndata['test_mask'] = test_mask 139 | else: 140 | n_nodes = node_features.shape[0] 141 | n_train = int(n_nodes * 0.6) 142 | n_val = int(n_nodes * 0.2) 143 | 144 | train_mask = torch.zeros(n_nodes, dtype=torch.bool) 145 | val_mask = torch.zeros(n_nodes, dtype=torch.bool) 146 | test_mask = torch.zeros(n_nodes, dtype=torch.bool) 147 | 148 | train_mask[:n_train] = True 149 | val_mask[n_train:n_train + n_val] = True 150 | test_mask[n_train + n_val:] = True 151 | 152 | self.graph.ndata['train_mask'] = train_mask 153 | self.graph.ndata['val_mask'] = val_mask 154 | self.graph.ndata['test_mask'] = test_mask 155 | 156 | def __getitem__(self, i): 157 | return self.graph 158 | 159 | def __len__(self): 160 | return len(self.graphs) 161 | 162 | def load_igb_dataset(args, name): 163 | igb_args = args 164 | igb_args.path = args.dataset_path + 'igb/' 165 | 166 | def _select_size(name): 167 | if 'tiny' in name: 168 | return 'tiny' 169 | elif 'small' in name: 170 | return 'small' 171 | elif 'medium' in name: 172 | return 'medium' 173 | elif 'large' in 
name: 174 | return 'large' 175 | elif 'full' in name: 176 | return 'full' 177 | 178 | igb_args.dataset_size = _select_size(name) 179 | 180 | 181 | 182 | igb_args.num_classes = 19 183 | igb_args.in_memory = 0 184 | igb_args.synthetic = 0 185 | dataset = IGB260MDGLDataset(args=igb_args) 186 | g = dataset[0] 187 | 188 | g = dgl.add_reverse_edges(g) 189 | return g 190 | 191 | def load_ogb_dataset(args, name): 192 | dataset = DglNodePropPredDataset(name=name, root=args.dataset_path) 193 | split_idx = dataset.get_idx_split() 194 | g, label = dataset[0] 195 | 196 | if name == 'ogbn-arxiv' or name == 'ogbn-papers100M': 197 | g = dgl.add_reverse_edges(g) 198 | n_node = g.num_nodes() 199 | node_data = g.ndata 200 | node_data['label'] = label.view(-1).long() 201 | node_data['train_mask'] = torch.zeros(n_node, dtype=torch.bool) 202 | node_data['val_mask'] = torch.zeros(n_node, dtype=torch.bool) 203 | node_data['test_mask'] = torch.zeros(n_node, dtype=torch.bool) 204 | node_data['train_mask'][split_idx["train"]] = True 205 | node_data['val_mask'][split_idx["valid"]] = True 206 | node_data['test_mask'][split_idx["test"]] = True 207 | return g 208 | 209 | 210 | def load_yelp(args): 211 | prefix = args.dataset_path + 'yelp/' 212 | 213 | with open(prefix + 'class_map.json') as f: 214 | class_map = json.load(f) 215 | with open(prefix + 'role.json') as f: 216 | role = json.load(f) 217 | 218 | adj_full = scipy.sparse.load_npz(prefix + 'adj_full.npz') 219 | feats = np.load(prefix + 'feats.npy') 220 | n_node = feats.shape[0] 221 | 222 | g = dgl.from_scipy(adj_full) 223 | node_data = g.ndata 224 | 225 | label = list(class_map.values()) 226 | node_data['label'] = torch.tensor(label) 227 | 228 | node_data['train_mask'] = torch.zeros(n_node, dtype=torch.bool) 229 | node_data['val_mask'] = torch.zeros(n_node, dtype=torch.bool) 230 | node_data['test_mask'] = torch.zeros(n_node, dtype=torch.bool) 231 | node_data['train_mask'][role['tr']] = True 232 | node_data['val_mask'][role['va']] = True 233 | node_data['test_mask'][role['te']] = True 234 | 235 | assert torch.all(torch.logical_not(torch.logical_and(node_data['train_mask'], node_data['val_mask']))) 236 | assert torch.all(torch.logical_not(torch.logical_and(node_data['train_mask'], node_data['test_mask']))) 237 | assert torch.all(torch.logical_not(torch.logical_and(node_data['val_mask'], node_data['test_mask']))) 238 | assert torch.all( 239 | torch.logical_or(torch.logical_or(node_data['train_mask'], node_data['val_mask']), node_data['test_mask'])) 240 | 241 | train_feats = feats[node_data['train_mask']] 242 | scaler = StandardScaler() 243 | scaler.fit(train_feats) 244 | feats = scaler.transform(feats) 245 | 246 | node_data['feat'] = torch.tensor(feats, dtype=torch.float) 247 | 248 | return g 249 | 250 | def graph_scaler(orig_g, orig_n, orig_e, scale=1.0): 251 | scale = int(scale) 252 | scaled_g = dgl.rand_graph(int(orig_n*scale), int(orig_e*scale), device='cpu') 253 | scaled_g.ndata['train_mask'] = orig_g.ndata['train_mask'].repeat_interleave(scale) 254 | scaled_g.ndata['val_mask'] = orig_g.ndata['val_mask'].repeat_interleave(scale) 255 | scaled_g.ndata['test_mask'] = orig_g.ndata['test_mask'].repeat_interleave(scale) 256 | scaled_g.ndata['feat'] = orig_g.ndata['feat'].repeat_interleave(scale, dim=0) 257 | scaled_g.ndata['label'] = orig_g.ndata['label'].repeat_interleave(scale) 258 | scaled_g.edata['__orig__'] = orig_g.edata['__orig__'].repeat_interleave(scale) 259 | return scaled_g 260 | 261 | def load_data(args, dataset): 262 | if dataset == 'reddit': 263 | data = 
RedditDataset(raw_dir=args.dataset_path) 264 | reddit_num_nodes = 232_965 265 | reddit_num_edges = 114_615_892 266 | g = data[0] 267 | if args.dataset_scale != 1: 268 | g = graph_scaler(g, reddit_num_nodes, reddit_num_edges, scale=args.dataset_scale) 269 | elif 'igb' in dataset: 270 | g = load_igb_dataset(args, dataset) 271 | elif dataset == 'ogbn-products': 272 | g = load_ogb_dataset(args, 'ogbn-products') 273 | elif dataset == 'ogbn-papers100m': 274 | g = load_ogb_dataset(args, 'ogbn-papers100M') 275 | elif dataset == 'ogbn-arxiv': 276 | g = load_ogb_dataset(args, 'ogbn-arxiv') 277 | elif dataset == 'yelp': 278 | 279 | g = load_yelp(args) 280 | 281 | 282 | else: 283 | raise ValueError('Unknown dataset: {}'.format(dataset)) 284 | 285 | n_feat = g.ndata['feat'].shape[1] 286 | if g.ndata['label'].dim() == 1: 287 | n_class = g.ndata['label'].max().item() + 1 288 | else: 289 | n_class = g.ndata['label'].shape[1] 290 | 291 | g.edata.clear() 292 | g = dgl.remove_self_loop(g) 293 | g = dgl.add_self_loop(g) 294 | return g, n_feat, n_class 295 | 296 | 297 | def load_partition(args, rank): 298 | graph_dir = 'partitions/' + args.graph_name + '/' 299 | part_config = graph_dir + args.graph_name + '.json' 300 | 301 | print('loading partitions') 302 | 303 | subg, node_feat, _, gpb, _, node_type, _ = dgl.distributed.load_partition(part_config, rank) 304 | node_type = node_type[0] 305 | node_feat[dgl.NID] = subg.ndata[dgl.NID] 306 | if 'part_id' in subg.ndata: 307 | node_feat['part_id'] = subg.ndata['part_id'] 308 | node_feat['inner_node'] = subg.ndata['inner_node'].bool() 309 | node_feat['label'] = node_feat[node_type + '/label'] 310 | node_feat['feat'] = node_feat[node_type + '/feat'] 311 | node_feat['in_degree'] = node_feat[node_type + '/in_degree'] 312 | node_feat['train_mask'] = node_feat[node_type + '/train_mask'].bool() 313 | node_feat.pop(node_type + '/label') 314 | node_feat.pop(node_type + '/feat') 315 | node_feat.pop(node_type + '/in_degree') 316 | node_feat.pop(node_type + '/train_mask') 317 | if not args.inductive: 318 | node_feat['val_mask'] = node_feat[node_type + '/val_mask'].bool() 319 | node_feat['test_mask'] = node_feat[node_type + '/test_mask'].bool() 320 | node_feat.pop(node_type + '/val_mask') 321 | node_feat.pop(node_type + '/test_mask') 322 | if args.dataset == 'ogbn-papers100m': 323 | node_feat.pop(node_type + '/year') 324 | subg.ndata.clear() 325 | subg.edata.clear() 326 | 327 | return subg, node_feat, gpb 328 | 329 | 330 | def graph_partition(g, args): 331 | graph_dir = 'partitions/' + args.graph_name + '/' 332 | part_config = graph_dir + args.graph_name + '.json' 333 | 334 | 335 | if not os.path.exists(part_config): 336 | with g.local_scope(): 337 | if args.inductive: 338 | g.ndata.pop('val_mask') 339 | g.ndata.pop('test_mask') 340 | g.ndata['in_degree'] = g.in_degrees() 341 | partition_graph(g, args.graph_name, args.n_partitions, graph_dir, part_method=args.partition_method, 342 | balance_edges=False, objtype=args.partition_obj) 343 | 344 | 345 | def load_intra_partition(args, local_rank, iter): 346 | graph_dir = 'intra-partitions/%s_%dlayers/total_%d_nodes_local_%d_gpus/iter%d/rank%d/' % (args.dataset, (args.n_layers - args.n_linear), args.total_nodes, 347 | args.local_device_cnt, iter, args.node_rank) 348 | part_config = graph_dir + args.graph_name + '.json' 349 | 350 | 351 | 352 | 353 | subg, node_feat, _, gpb, _, node_type, _ = dgl.distributed.load_partition(part_config, local_rank) 354 | node_type = node_type[0] 355 | node_feat[dgl.NID] = subg.ndata[dgl.NID] 356 | if 
'part_id' in subg.ndata: 357 | node_feat['part_id'] = subg.ndata['part_id'] 358 | 359 | 360 | 361 | 362 | node_feat['inner_node'] = subg.ndata['inner_node'].bool() 363 | node_feat['seed_node'] = node_feat[node_type + '/seed_node'].bool() 364 | node_feat['label'] = node_feat[node_type + '/label'] 365 | node_feat['feat'] = node_feat[node_type + '/feat'] 366 | node_feat['in_degree'] = node_feat[node_type + '/in_degree'] 367 | node_feat['train_mask'] = node_feat[node_type + '/train_mask'].bool() 368 | node_feat.pop(node_type + '/seed_node') 369 | node_feat.pop(node_type + '/label') 370 | node_feat.pop(node_type + '/feat') 371 | node_feat.pop(node_type + '/in_degree') 372 | node_feat.pop(node_type + '/train_mask') 373 | if not args.inductive: 374 | node_feat['val_mask'] = node_feat[node_type + '/val_mask'].bool() 375 | node_feat['test_mask'] = node_feat[node_type + '/test_mask'].bool() 376 | node_feat.pop(node_type + '/val_mask') 377 | node_feat.pop(node_type + '/test_mask') 378 | if args.dataset == 'ogbn-papers100m': 379 | node_feat.pop(node_type + '/year') 380 | subg.ndata.clear() 381 | subg.edata.clear() 382 | 383 | return subg, node_feat, gpb 384 | 385 | def server_node_graph_partition(g, iter, args, map_dict=None, my_seed=None): 386 | """Does not manipulate anything; just save the partitioned graphs in the corresponding directory. 387 | g: subgraph for current server.""" 388 | sid_from_psid = None 389 | graph_dir = 'intra-partitions/%s_%dlayers/total_%d_nodes_local_%d_gpus/iter%d/rank%d/' % (args.dataset, (args.n_layers - args.n_linear), args.total_nodes, 390 | args.local_device_cnt, iter, args.node_rank) 391 | part_config = graph_dir + args.graph_name + '.json' 392 | 393 | 394 | 395 | 396 | print('!!!PARTITIONING GRAPH PER SERVER.... BE PATIENT!!!') 397 | 398 | 399 | if not os.path.exists(part_config): 400 | with g.local_scope(): 401 | if args.inductive: 402 | g.ndata.pop('val_mask') 403 | g.ndata.pop('test_mask') 404 | g.ndata['in_degree'] = g.in_degrees() 405 | sid_from_psid, _ = partition_graph(g, args.graph_name, args.local_device_cnt, graph_dir, part_method=args.partition_method, 406 | balance_edges=False, objtype=args.partition_obj, return_mapping=True) 407 | 408 | print(g.nodes(), ' ', g.nodes().shape) 409 | print(g.ndata[dgl.NID], 'MID ID MAPPER') 410 | return sid_from_psid 411 | 412 | 413 | def get_layer_size(args, n_feat, n_hidden, n_class, n_layers): 414 | 415 | if args.model in ['deepgcn']: 416 | layer_size = [n_hidden] 417 | else: 418 | layer_size = [n_feat] 419 | 420 | 421 | layer_size.extend([n_hidden] * (n_layers - 1)) 422 | 423 | layer_size.append(n_class) 424 | return layer_size 425 | 426 | 427 | def get_intra_boundary(args, node_dict, gpb): 428 | """Get boundary of the intra graph of this server. 429 | 430 | Parameters 431 | ---------- 432 | args : Args 433 | node_dict : NodeDict 434 | gpb : GraphPartitionBook 435 | 436 | Returns 437 | ------- 438 | boundary : list[Tensor | None] 439 | boundary[i] is the gpu_ids of inner vertices on my GPU partition that also exist on the i-th GPU. 440 | In other words, boundary vertex has at least one inner vertex of the i-th GPU partition as a neighbor. 441 | Also, boundary[local_rank] is None. 442 | """ 443 | local_rank, local_size = args.rank % args.local_device_cnt, args.local_device_cnt 444 | start_id = args.local_device_cnt * args.node_rank 445 | 446 | boundary = [None] * local_size 447 | buffer_size = list() 448 | 449 | """ 450 | Send how much i need from other rank.... 
451 | Therefore, finally each worker has an array, 452 | which contains information about how much i need to send to other ranks. 453 | (So, array[my_rank_id] is None) 454 | """ 455 | 456 | for i in range(1, local_size): 457 | left = (local_rank - i + local_size) % local_size 458 | right = (local_rank + i) % local_size 459 | belong_right = (node_dict['part_id'] == right) 460 | num_right = belong_right.sum().view(-1) 461 | 462 | if dist.get_backend() == 'gloo': 463 | num_right = num_right.cpu() 464 | num_left = torch.tensor([0]) 465 | else: 466 | num_left = torch.tensor([0], device='cuda') 467 | if local_rank < i: 468 | dist.send(num_right, dst=start_id + right) 469 | dist.recv(num_left, src=start_id + left) 470 | if local_rank >= i: 471 | dist.send(num_right, dst=start_id + right) 472 | buffer_size.append(num_left) 473 | 474 | for i in range(1, local_size): 475 | left = (local_rank - i + local_size) % local_size 476 | right = (local_rank + i) % local_size 477 | belong_right = (node_dict['part_id'] == right) 478 | start = gpb.partid2nids(right)[0].item() 479 | v = node_dict[dgl.NID][belong_right] - start 480 | 481 | if dist.get_backend() == 'gloo': 482 | v = v.cpu() 483 | u = torch.zeros(buffer_size[i-1], dtype=torch.long) 484 | else: 485 | u = torch.zeros(buffer_size[i-1], dtype=torch.long, device='cuda') 486 | 487 | if local_rank < i: 488 | dist.send(v, dst=start_id + right) 489 | 490 | dist.recv(u, src=start_id + left) 491 | 492 | if local_rank >= i: 493 | dist.send(v, dst=start_id + right) 494 | 495 | u, _ = torch.sort(u) 496 | 497 | if dist.get_backend() == 'gloo': 498 | boundary[left] = u.to('cuda') 499 | else: 500 | boundary[left] = u 501 | return boundary 502 | 503 | 504 | def get_boundary(node_dict, gpb, device): 505 | rank, size = dist.get_rank(), dist.get_world_size() 506 | boundary = [None] * size 507 | buffer_size = list() 508 | for i in range(1, size): 509 | left = (rank - i + size) % size 510 | right = (rank + i) % size 511 | belong_right = (node_dict['part_id'] == right) 512 | num_right = belong_right.sum().view(-1) 513 | if dist.get_backend() == 'gloo': 514 | num_right = num_right.cpu() 515 | num_left = torch.tensor([0]) 516 | else: 517 | num_left = torch.tensor([0], device=device) 518 | if rank < i: 519 | dist.send(num_right, dst=right) 520 | dist.recv(num_left, src=left) 521 | if rank >= i: 522 | dist.send(num_right, dst=right) 523 | buffer_size.append(num_left) 524 | for i in range(1, size): 525 | left = (rank - i + size) % size 526 | right = (rank + i) % size 527 | belong_right = (node_dict['part_id'] == right) 528 | start = gpb.partid2nids(right)[0].item() 529 | v = node_dict[dgl.NID][belong_right] - start 530 | if dist.get_backend() == 'gloo': 531 | v = v.cpu() 532 | u = torch.zeros(buffer_size[i-1], dtype=torch.long) 533 | else: 534 | u = torch.zeros(buffer_size[i-1], dtype=torch.long, device=device) 535 | if rank < i: 536 | dist.send(v, dst=right) 537 | dist.recv(u, src=left) 538 | if rank >= i: 539 | dist.send(v, dst=right) 540 | u, _ = torch.sort(u) 541 | if dist.get_backend() == 'gloo': 542 | boundary[left] = u.to(device) 543 | else: 544 | boundary[left] = u 545 | return boundary 546 | 547 | 548 | 549 | def data_transfer(data, recv_shape, backend, dtype=torch.float, tag=0): 550 | rank, size = dist.get_rank(), dist.get_world_size() 551 | res = [None] * size 552 | 553 | for i in range(1, size): 554 | left = (rank - i + size) % size 555 | if backend == 'gloo': 556 | res[left] = torch.zeros(torch.Size([recv_shape[left], data[left].shape[1]]), dtype=dtype) 557 | else: 
558 | res[left] = torch.zeros(torch.Size([recv_shape[left], data[left].shape[1]]), dtype=dtype, device='cuda') 559 | 560 | for i in range(1, size): 561 | left = (rank - i + size) % size 562 | right = (rank + i) % size 563 | if backend == 'gloo': 564 | req = dist.isend(data[right].cpu(), dst=right, tag=tag) 565 | else: 566 | req = dist.isend(data[right], dst=right, tag=tag) 567 | dist.recv(res[left], src=left, tag=tag) 568 | res[left] = res[left].cuda() 569 | req.wait() 570 | 571 | return res 572 | 573 | 574 | def merge_feature(feat, recv): 575 | size = len(recv) 576 | for i in range(size - 1, 0, -1): 577 | if recv[i] is None: 578 | recv[i] = recv[i - 1] 579 | recv[i - 1] = None 580 | recv[0] = feat 581 | return torch.cat(recv) 582 | 583 | 584 | def inductive_split(g): 585 | g_train = g.subgraph(g.ndata['train_mask']) 586 | g_val = g.subgraph(g.ndata['train_mask'] | g.ndata['val_mask']) 587 | g_test = g 588 | return g_train, g_val, g_test 589 | 590 | 591 | def minus_one_tensor(size, device=None): 592 | if device is not None: 593 | return torch.zeros(size, dtype=torch.long, device=device) - 1 594 | else: 595 | return torch.zeros(size, dtype=torch.long) - 1 596 | 597 | 598 | def nonzero_idx(x): 599 | return torch.nonzero(x, as_tuple=True)[0] 600 | 601 | 602 | def print_memory(s): 603 | torch.cuda.synchronize() 604 | print(s + ': current {:.2f}MB, peak {:.2f}MB, reserved {:.2f}MB'.format( 605 | torch.cuda.memory_allocated() / 1024 / 1024, 606 | torch.cuda.max_memory_allocated() / 1024 / 1024, 607 | torch.cuda.memory_reserved() / 1024 / 1024 608 | )) 609 | 610 | 611 | @contextmanager 612 | def timer(s): 613 | rank, size = dist.get_rank(), dist.get_world_size() 614 | t = time.time() 615 | yield 616 | print('(rank %d) running time of %s: %.3f seconds' % (rank, s, time.time() - t)) 617 | -------------------------------------------------------------------------------- /Codes/train.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from module.baseline_model import * 3 | from helper.utils import * 4 | import torch.distributed as dist 5 | import time 6 | import copy 7 | from multiprocessing.pool import ThreadPool 8 | from sklearn.metrics import f1_score 9 | 10 | from datetime import timedelta 11 | from torch.nn.parallel import DistributedDataParallel as DDP 12 | import traceback 13 | from helper.MongoManager import * 14 | 15 | import datetime 16 | 17 | def calc_acc(logits, labels): 18 | if labels.dim() == 1: 19 | _, indices = torch.max(logits, dim=1) 20 | correct = torch.sum(indices == labels) 21 | return correct.item() / labels.shape[0] 22 | else: 23 | return f1_score(labels, logits > 0, average='micro') 24 | 25 | 26 | @torch.no_grad() 27 | def evaluate_induc(name, model, g, mode, result_file_name=None): 28 | """ 29 | mode: 'val' or 'test' 30 | """ 31 | model.eval() 32 | model.cpu() 33 | feat, labels = g.ndata['feat'], g.ndata['label'] 34 | mask = g.ndata[mode + '_mask'] 35 | logits = model(g, feat) 36 | logits = logits[mask] 37 | labels = labels[mask] 38 | acc = calc_acc(logits, labels) 39 | buf = "{:s} | Accuracy {:.2%}".format(name, acc) 40 | if result_file_name is not None: 41 | with open(result_file_name, 'a+') as f: 42 | f.write(buf + '\n') 43 | print(buf) 44 | else: 45 | print(buf) 46 | return model, acc 47 | 48 | 49 | @torch.no_grad() 50 | def evaluate_trans(name, model, g, result_file_name=None): 51 | model.eval() 52 | model.cpu() 53 | feat, labels = g.ndata['feat'], g.ndata['label'] 54 | val_mask, test_mask = 
g.ndata['val_mask'], g.ndata['test_mask']
55 |     logits = model(g, feat)
56 |     val_logits, test_logits = logits[val_mask], logits[test_mask]
57 |     val_labels, test_labels = labels[val_mask], labels[test_mask]
58 |     val_acc = calc_acc(val_logits, val_labels)
59 |     test_acc = calc_acc(test_logits, test_labels)
60 |     buf = "{:s} | Validation Accuracy {:.2%} | Test Accuracy {:.2%}".format(name, val_acc, test_acc)
61 |     if result_file_name is not None:
62 |         with open(result_file_name, 'a+') as f:
63 |             f.write(buf + '\n')
64 |             print(buf)
65 |     else:
66 |         print(buf)
67 |     return model, val_acc
68 | 
69 | 
70 | def average_gradients(model, n_train):
71 |     reduce_time = 0
72 |     for i, (name, param) in enumerate(model.named_parameters()):
73 |         t0 = time.time()
74 |         dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
75 |         param.grad.data /= n_train  # the loss uses reduction='sum', so dividing the summed gradients by the total train count yields a mean gradient
76 |         reduce_time += time.time() - t0
77 |     return reduce_time
78 | 
79 | 
80 | def move_to_cuda(graph, part, node_dict, device):
81 | 
82 |     for key in node_dict.keys():
83 |         node_dict[key] = node_dict[key].to(device)
84 |     graph = graph.int().to(device)
85 |     part = part.int().to(device)
86 | 
87 |     return graph, part, node_dict
88 | 
89 | 
90 | def get_pos(node_dict, gpb, device):  # pos[i] maps partition-i-local node IDs to this partition's halo slots (-1 where absent)
91 |     pos = []
92 |     rank, size = dist.get_rank(), dist.get_world_size()
93 |     for i in range(size):
94 |         if i == rank:
95 |             pos.append(None)
96 |         else:
97 |             part_size = gpb.partid2nids(i).shape[0]
98 |             start = gpb.partid2nids(i)[0].item()
99 |             p = minus_one_tensor(part_size, device=device)
100 |             in_idx = nonzero_idx(node_dict['part_id'] == i)
101 |             out_idx = node_dict[dgl.NID][in_idx] - start
102 |             p[out_idx] = in_idx
103 |             pos.append(p)
104 |     return pos
105 | 
106 | 
107 | def get_recv_shape(node_dict):
108 |     rank, size = dist.get_rank(), dist.get_world_size()
109 |     recv_shape = []
110 |     for i in range(size):
111 |         if i == rank:
112 |             recv_shape.append(None)
113 |         else:
114 |             t = (node_dict['part_id'] == i).int().sum().item()
115 |             recv_shape.append(t)
116 |     return recv_shape
117 | 
118 | 
119 | def create_inner_graph(graph, node_dict):
120 |     u, v = graph.edges()
121 |     sel = torch.logical_and(node_dict['inner_node'].bool()[u], node_dict['inner_node'].bool()[v])
122 |     u, v = u[sel], v[sel]
123 |     return dgl.graph((u, v))
124 | 
125 | 
126 | def order_graph(part, graph, gpb, node_dict, pos):
127 |     rank, size = dist.get_rank(), dist.get_world_size()
128 |     one_hops = []
129 |     for i in range(size):
130 |         if i == rank:
131 |             one_hops.append(None)
132 |             continue
133 |         start = gpb.partid2nids(i)[0].item()
134 |         nodes = node_dict[dgl.NID][node_dict['part_id'] == i] - start
135 |         nodes, _ = torch.sort(nodes)
136 |         one_hops.append(nodes)
137 |     return construct(part, graph, pos, one_hops)
138 | 
139 | 
140 | def move_train_first(graph, node_dict, boundary, device):
141 |     train_mask = node_dict['train_mask']
142 |     num_train = torch.count_nonzero(train_mask).item()
143 |     num_tot = graph.num_nodes('_V')
144 | 
145 |     new_id = torch.zeros(num_tot, dtype=torch.int, device=device)
146 |     new_id[train_mask] = torch.arange(num_train, dtype=torch.int, device=device)
147 |     new_id[torch.logical_not(train_mask)] = torch.arange(num_train, num_tot, dtype=torch.int, device=device)
148 | 
149 |     u, v = graph.edges()
150 |     u[u < num_tot] = new_id[u[u < num_tot].long()]
151 |     v = new_id[v.long()]
152 |     graph = dgl.heterograph({('_U', '_E', '_V'): (u, v)})
153 | 
154 |     for key in node_dict:
155 |         node_dict[key][new_id.long()] = node_dict[key][0:num_tot].clone()
156 | 
157 |     for i in range(len(boundary)):
158 |         if boundary[i] is not None:
159 | 
boundary[i] = new_id[boundary[i]].long() 160 | 161 | return graph, node_dict, boundary 162 | 163 | 164 | def create_graph_train(graph, node_dict): 165 | u, v = graph.edges() 166 | num_u = graph.num_nodes('_U') 167 | sel = nonzero_idx(node_dict['train_mask'][v.long()]) 168 | u, v = u[sel], v[sel] 169 | graph = dgl.heterograph({('_U', '_E', '_V'): (u, v)}) 170 | if graph.num_nodes('_U') < num_u: 171 | graph.add_nodes(num_u - graph.num_nodes('_U'), ntype='_U') 172 | return graph, node_dict['in_degree'][node_dict['train_mask']] 173 | 174 | 175 | def precompute(graph, node_dict, boundary, recv_shape, args): 176 | rank, size = dist.get_rank(), dist.get_world_size() 177 | in_size = node_dict['inner_node'].bool().sum() 178 | feat = node_dict['feat'] 179 | send_info = [] 180 | for i, b in enumerate(boundary): 181 | if i == rank: 182 | send_info.append(None) 183 | else: 184 | send_info.append(feat[b]) 185 | recv_feat = data_transfer(send_info, recv_shape, args.backend, dtype=torch.float) 186 | if args.model == 'graphsage': 187 | with graph.local_scope(): 188 | graph.nodes['_U'].data['h'] = merge_feature(feat, recv_feat) 189 | graph['_E'].update_all(fn.copy_u(u='h', out='m'), 190 | fn.sum(msg='m', out='h'), 191 | etype='_E') 192 | mean_feat = graph.nodes['_V'].data['h'] / node_dict['in_degree'][0:in_size].unsqueeze(1) 193 | return torch.cat([feat, mean_feat[0:in_size]], dim=1) 194 | else: 195 | raise Exception 196 | 197 | 198 | def create_model(layer_size, args): 199 | if args.model == 'graphsage': 200 | return GraphSAGE(layer_size, F.relu, args.use_pp, norm=args.norm, dropout=args.dropout, 201 | n_linear=args.n_linear, train_size=args.n_train) 202 | elif args.model == 'deepgcn': 203 | return DeeperGCN(layer_size, nn.ReLU(inplace=True), args.use_pp, args.n_feat, args.n_class, norm=args.norm, dropout=args.dropout, 204 | n_linear=args.n_linear, train_size=args.n_train) 205 | else: 206 | raise NotImplementedError 207 | 208 | 209 | def reduce_hook(param, name, n_train): 210 | def fn(grad): 211 | ctx.reducer.reduce(param, name, grad, n_train) 212 | return fn 213 | 214 | 215 | def construct(part, graph, pos, one_hops): 216 | rank, size = dist.get_rank(), dist.get_world_size() 217 | tot = part.num_nodes() 218 | u, v = part.edges() 219 | u_list, v_list = [u], [v] 220 | for i in range(size): 221 | if i == rank: 222 | continue 223 | else: 224 | u = one_hops[i] 225 | if u.shape[0] == 0: 226 | continue 227 | u = pos[i][u] 228 | u_ = torch.repeat_interleave(graph.out_degrees(u.int()).long()) + tot 229 | tot += u.shape[0] 230 | _, v = graph.out_edges(u.int()) 231 | u_list.append(u_.int()) 232 | v_list.append(v) 233 | u = torch.cat(u_list) 234 | v = torch.cat(v_list) 235 | g = dgl.heterograph({('_U', '_E', '_V'): (u, v)}) 236 | if g.num_nodes('_U') < tot: 237 | g.add_nodes(tot - g.num_nodes('_U'), ntype='_U') 238 | return g 239 | 240 | 241 | def extract(graph, node_dict): 242 | rank, size = dist.get_rank(), dist.get_world_size() 243 | sel = (node_dict['part_id'] < size) 244 | for key in node_dict.keys(): 245 | if node_dict[key].shape[0] == sel.shape[0]: 246 | node_dict[key] = node_dict[key][sel] 247 | graph = dgl.node_subgraph(graph, sel, store_ids=False) 248 | return graph, node_dict 249 | 250 | 251 | def run(graph, node_dict, gpb, device, args): 252 | 253 | rank, size = dist.get_rank(), dist.get_world_size() 254 | 255 | torch.autograd.set_detect_anomaly(False) 256 | torch.autograd.profiler.profile(False) 257 | torch.autograd.profiler.emit_nvtx(False) 258 | 259 | if rank == 0 and args.eval: 260 | full_g, n_feat, 
n_class = load_data(args, args.dataset)
261 |         if args.inductive:
262 |             _, val_g, test_g = inductive_split(full_g)
263 |         else:
264 |             val_g, test_g = full_g.clone(), full_g.clone()
265 |         del full_g
266 | 
267 |     if rank == 0:
268 |         os.makedirs('checkpoint/', exist_ok=True)
269 |         os.makedirs('results/', exist_ok=True)
270 | 
271 |     part = create_inner_graph(graph.clone(), node_dict)
272 | 
273 | 
274 |     num_in = node_dict['inner_node'].bool().sum().item()
275 |     part.ndata.clear()
276 |     part.edata.clear()
277 | 
278 |     print(f'Process {rank} has {graph.num_nodes()} nodes, {graph.num_edges()} edges, '
279 |           f'{part.num_nodes()} inner nodes, and {part.num_edges()} inner edges.')
280 | 
281 |     graph, part, node_dict = move_to_cuda(graph, part, node_dict, device)
282 | 
283 |     boundary = get_boundary(node_dict, gpb, device)
284 | 
285 | 
286 |     layer_size = get_layer_size(args, args.n_feat, args.n_hidden, args.n_class, args.n_layers)
287 | 
288 |     pos = get_pos(node_dict, gpb, device)
289 |     graph = order_graph(part, graph, gpb, node_dict, pos)
290 |     in_deg = node_dict['in_degree']
291 | 
292 |     graph, node_dict, boundary = move_train_first(graph, node_dict, boundary, device)  # relabel nodes so training nodes occupy the lowest IDs
293 | 
294 |     recv_shape = get_recv_shape(node_dict)
295 | 
296 |     ctx.buffer.init_buffer(num_in, graph.num_nodes('_U'), boundary, recv_shape, layer_size[:args.n_layers-args.n_linear],
297 |                            args.model, use_pp=args.use_pp, backend=args.backend, pipeline=args.enable_pipeline,
298 |                            corr_feat=args.feat_corr, corr_grad=args.grad_corr, corr_momentum=args.corr_momentum, debug=args.debug, check_intra_only=args.check_intra_only)
299 | 
300 |     if args.use_pp:
301 |         node_dict['feat'] = precompute(graph, node_dict, boundary, recv_shape, args)
302 | 
303 |     labels = node_dict['label'][node_dict['train_mask']]
304 |     train_mask = node_dict['train_mask']
305 |     part_train = train_mask.int().sum().item()
306 | 
307 | 
308 |     del boundary
309 |     del part
310 |     del pos
311 | 
312 |     model = create_model(layer_size, args)
313 | 
314 |     model.to(device)
315 | 
316 |     model = DDP(model)
317 | 
318 | 
319 | 
320 | 
321 | 
322 | 
323 | 
324 | 
325 | 
326 | 
327 |     best_model, best_acc = None, 0
328 | 
329 |     if args.grad_corr and args.feat_corr:
330 |         result_file_name = 'results/%s_n%d_p%d_grad_feat.txt' % (args.dataset, args.n_partitions, int(args.enable_pipeline))
331 |     elif args.grad_corr:
332 |         result_file_name = 'results/%s_n%d_p%d_grad.txt' % (args.dataset, args.n_partitions, int(args.enable_pipeline))
333 |     elif args.feat_corr:
334 |         result_file_name = 'results/%s_n%d_p%d_feat.txt' % (args.dataset, args.n_partitions, int(args.enable_pipeline))
335 |     else:
336 |         result_file_name = 'results/%s_n%d_p%d.txt' % (args.dataset, args.n_partitions, int(args.enable_pipeline))
337 |     if args.dataset == 'yelp':
338 |         loss_fcn = torch.nn.BCEWithLogitsLoss(reduction='sum')
339 |     else:
340 |         loss_fcn = torch.nn.CrossEntropyLoss(reduction='sum')
341 |     optimizer = torch.optim.Adam(model.parameters(),
342 |                                  lr=args.lr,
343 |                                  weight_decay=args.weight_decay)
344 | 
345 |     train_dur, comm_dur, reduce_dur = [], [], []
346 |     val_accs = []
347 |     torch.cuda.reset_peak_memory_stats()
348 |     thread = None
349 |     pool = ThreadPool(processes=1)
350 | 
351 |     feat = node_dict['feat']
352 | 
353 |     node_dict.pop('train_mask')
354 |     node_dict.pop('inner_node')
355 |     node_dict.pop('part_id')
356 |     node_dict.pop(dgl.NID)
357 | 
358 |     if not args.eval:
359 |         if 'val_mask' in node_dict.keys():
360 |             node_dict.pop('val_mask')
361 |         if 'test_mask' in node_dict.keys():
362 |             node_dict.pop('test_mask')
363 | 
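    # Training loop: one full-graph forward/backward per epoch. Halo features
    # are exchanged through ctx.buffer, gradients are averaged by DDP inside
    # backward(), and rank 0 evaluates asynchronously on a single worker thread
    # so training is never blocked on validation.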
364 |     for epoch in range(args.n_epochs):
365 |         t0 = time.time()
366 |         model.train()
367 |         if args.model in ['graphsage', 'deepgcn']:
368 |             logits = model(graph, feat, in_deg)
369 |         else:
370 |             raise Exception
371 |         if args.inductive:
372 |             loss = loss_fcn(logits, labels)
373 |         else:
374 |             loss = loss_fcn(logits[train_mask], labels)
375 |         del logits
376 |         optimizer.zero_grad(set_to_none=True)
377 | 
378 |         loss.backward()
379 | 
380 |         ctx.buffer.next_epoch()
381 | 
382 |         pre_reduce = time.time()
383 | 
384 |         reduce_time = time.time() - pre_reduce  # gradient averaging happens inside DDP's backward(), so this window is effectively empty; kept for log-format compatibility
385 |         optimizer.step()
386 | 
387 |         if epoch >= 5 and epoch % args.log_every != 0:  # skip warm-up epochs and evaluation epochs when accumulating timings
388 |             if args.debug:
389 |                 torch.cuda.synchronize()
390 |             train_dur.append(time.time() - t0)
391 |             comm_dur.append(ctx.comm_timer.tot_time())
392 |             reduce_dur.append(reduce_time)
393 | 
394 |         if (epoch + 1) % 10 == 0:
395 |             print("Process {:03d} | Epoch {:05d} | Time(s) {:.4f} | Comm(s) {:.4f} | Reduce(s) {:.4f} | Loss {:.4f}".format(
396 |                   rank, epoch, np.mean(train_dur), np.mean(comm_dur), np.mean(reduce_dur), loss.item() / part_train))
397 | 
398 |         ctx.comm_timer.clear()
399 | 
400 | 
401 |         loss_scalar = loss.item()
402 |         del loss
403 | 
404 |         if rank == 0 and args.eval and (epoch + 1) % args.log_every == 0:
405 |             if thread is not None:
406 |                 model_copy, val_acc = thread.get()
407 |                 val_accs.append(val_acc)
408 |                 if val_acc > best_acc:
409 |                     best_acc = val_acc
410 |                     best_model = model_copy
411 |             model_copy = copy.deepcopy(model)
412 |             if not args.inductive:
413 |                 thread = pool.apply_async(evaluate_trans, args=('Epoch %05d' % epoch, model_copy,
414 |                                                                 val_g, result_file_name))
415 |             else:
416 |                 thread = pool.apply_async(evaluate_induc, args=('Epoch %05d' % epoch, model_copy,
417 |                                                                 val_g, 'val', result_file_name))
418 | 
419 | 
420 |     if args.create_json:
421 |         info_dict = {}
422 |         test_acc = 0.0
423 | 
424 |     if args.eval and rank == 0 and not args.time_calc:
425 |         if thread is not None:
426 |             model_copy, val_acc = thread.get()
427 |             val_accs.append(val_acc)
428 |             if val_acc > best_acc:
429 |                 best_acc = val_acc
430 |                 best_model = model_copy
431 |         ckpt_path = 'model/'
432 |         os.makedirs(ckpt_path, exist_ok=True)
433 | 
434 |         args.ckpt_str = ckpt_path + args.dataset + '_fullgraph_' + datetime.datetime.now().strftime('%Y_%m_%d__%H_%M_%S') + '.pth.tar'
435 | 
436 |         torch.save(best_model.state_dict(), args.ckpt_str)
437 |         print('model saved')
438 |         print("Validation accuracy {:.2%}".format(best_acc))
439 |         _, acc = evaluate_induc('Test Result', best_model, test_g, 'test')
440 |         test_acc = acc
441 | 
442 |     if args.create_json:
443 |         info_dict['best_accuracy'] = best_acc
444 |         info_dict['test_accuracy'] = test_acc
445 |         log_path = args.json_path
446 |         os.makedirs(log_path, exist_ok=True)
447 |         print('json logs will be saved in ', log_path)
448 | 
449 |     if not args.eval and rank == 0:
450 |         ckpt_path = 'model/'
451 |         os.makedirs(ckpt_path, exist_ok=True)
452 | 
453 |         model_copy = copy.deepcopy(model)
454 |         model_copy.cpu()
455 | 
456 |         args.ckpt_str = ckpt_path + args.dataset + '_fullgraph_model_optimizer_' + datetime.datetime.now().strftime('%Y_%m_%d__%H_%M_%S') + '.pth.tar'
457 | 
458 |         torch.save({
459 |             'epoch': args.n_epochs,
460 |             'model_state_dict': model_copy.state_dict(),
461 |             'optimizer_state_dict': optimizer.state_dict()
462 |         }, args.ckpt_str)
463 |         print('model saved')
464 | 
465 |     summary_buffer = torch.zeros(4).to(device)
466 |     summary_buffer[0] = np.mean(train_dur)
467 |     summary_buffer[1] = np.mean(comm_dur)
468 |     summary_buffer[2] = np.mean(reduce_dur)
469 |     summary_buffer[3] = loss_scalar/part_train
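    # Per-rank mean timings are summed across workers and divided by the world
    # size below, giving cluster-wide averages for the speed summary.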
470 | 
471 |     dist.all_reduce(summary_buffer, op=dist.ReduceOp.SUM)
472 | 
473 |     summary_buffer = summary_buffer/dist.get_world_size()
474 | 
475 |     if dist.get_rank() == 0:
476 | 
477 |         print("="*30, "Speed result summary", "="*30)
478 | 
479 |         print("train_duration : ", summary_buffer[0].item())
480 |         print("communication_duration : ", summary_buffer[1].item())
481 |         print("reduce duration : ", summary_buffer[2].item())
482 |         print("loss : ", summary_buffer[3].item())
483 |         print("="*60)
484 | 
485 | 
486 | 
487 |     if args.create_json:
488 | 
489 |         log_path = args.json_path
490 | 
491 |         info_dict['dataset'] = args.dataset
492 |         info_dict['rank'] = rank
493 |         info_dict['train_dur_mean'] = np.mean(train_dur)
494 |         info_dict['comm_dur_mean'] = np.mean(comm_dur)
495 |         info_dict['reduce_dur_mean'] = np.mean(reduce_dur)
496 |         info_dict['loss'] = loss_scalar/part_train
497 |         info_dict['eval_epoch_interval'] = epoch
498 |         info_dict['train_dur_array'] = train_dur
499 |         info_dict['comm_dur_array'] = comm_dur
500 |         info_dict['reduce_dur_array'] = reduce_dur
501 | 
502 |         for k, v in vars(args).items():
503 |             if 'group' in k:
504 |                 continue
505 | 
506 |             info_dict[k] = v
507 | 
508 | 
509 |         info_dict.pop('ssh_username', None)
510 |         info_dict.pop('ssh_pwd', None)
511 | 
512 |         timestr = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
513 |         raw_log_path = log_path + '/json_raws'
514 |         os.makedirs(raw_log_path, exist_ok=True)
515 | 
516 |         file_name = raw_log_path + ('/') + 'exp_id_' + str(args.exp_id) + '_' + timestr + '_rank_' + str(dist.get_rank())
517 | 
518 |         info_dict['timestr'] = timestr
519 | 
520 | 
521 |         with open(file_name, "w") as outfile:
522 |             json.dump(info_dict, outfile)
523 |         print("Rank ", dist.get_rank(), "successfully created json log file.")
524 | 
525 | 
526 |         summary_buffer = torch.zeros(4).to(device)  # rebuilt from the raw means; the tensor above was already all-reduced and averaged
527 |         summary_buffer[0] = np.mean(train_dur)
528 |         summary_buffer[1] = np.mean(comm_dur)
529 |         summary_buffer[2] = np.mean(reduce_dur)
530 |         summary_buffer[3] = loss_scalar/part_train
531 | 
532 |         dist.all_reduce(summary_buffer, op=dist.ReduceOp.SUM)
533 | 
534 |         summary_buffer = summary_buffer/dist.get_world_size()
535 | 
536 |         memory_buffer = torch.tensor([torch.cuda.max_memory_reserved(device)]).to(device)
537 |         dist.all_reduce(memory_buffer, op=dist.ReduceOp.SUM)
538 | 
539 | 
540 |         if dist.get_rank() == 0:
541 | 
542 |             best_acc_tensor = torch.tensor([best_acc]).to(device)
543 |             info_dict_summary = {}
544 | 
545 |             summary_buffer = torch.cat((summary_buffer, best_acc_tensor))
546 |             summary_buffer_list = summary_buffer.tolist()
547 | 
548 |             info_dict_summary['dataset'] = args.dataset
549 |             info_dict_summary['train_dur_aggregated'] = summary_buffer[0].item()
550 |             info_dict_summary['comm_dur_aggregated'] = summary_buffer[1].item()
551 |             info_dict_summary['reduce_dur_aggregated'] = summary_buffer[2].item()
552 |             info_dict_summary['loss_aggregated'] = summary_buffer[3].item()
553 |             info_dict_summary['best_accuracy'] = summary_buffer[4].item()
554 |             info_dict_summary['avg_memory_reserved'] = memory_buffer[0].item() / dist.get_world_size() / (1024*1024*1024)
555 |             info_dict_summary['val_accs'] = val_accs
556 | 
557 |             info_dict_summary['test_accuracy'] = test_acc
558 | 
559 |             os.makedirs(log_path, exist_ok=True)
560 | 
561 | 
562 |             summary_file_name = log_path + ('/') + 'exp_id_' + str(args.exp_id) + '_' + timestr + '_summary'
563 | 
564 |             with open(summary_file_name, "w") as outfile:
565 |                 json.dump(info_dict_summary, outfile)
566 |             print("Rank ", dist.get_rank(), "successfully created summary json log file.")
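            # Attach the full CLI configuration to the summary (dropping argparse
            # group objects and SSH credentials) before optionally shipping it to
            # MongoDB below.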
567 | 
568 |             for k, v in vars(args).items():
569 |                 if 'group' in k:
570 |                     continue
571 | 
572 |                 info_dict_summary[k] = v
573 | 
574 |             info_dict_summary.pop('ssh_username', None)
575 |             info_dict_summary.pop('ssh_pwd', None)
576 | 
577 |             if args.send_db:
578 | 
579 |                 try:
580 | 
581 |                     db_config['mongo_db'] = args.db_name
582 |                     db_config['project'] = args.project
583 |                     db_config['ssh_username'] = args.ssh_user
584 |                     db_config['ssh_pwd'] = args.ssh_pwd
585 | 
586 |                     try:  # retry once: the first SSH-tunneled connection attempt can fail transiently
587 | 
588 |                         mongo = DBHandler()
589 |                     except Exception as e:
590 |                         mongo = DBHandler()
591 | 
592 |                     try:  # same one-retry pattern for the insert
593 |                         mongo.insert_item_one(info_dict_summary)
594 |                     except Exception as e:
595 |                         mongo.insert_item_one(info_dict_summary)
596 | 
597 |                     print('Rank ', dist.get_rank(), "successfully sent a log to DB.")
598 | 
599 |                     try:
600 | 
601 |                         mongo.close_connection()
602 |                     except Exception as e:
603 |                         mongo.close_connection()
604 | 
605 |                 except Exception as e:
606 |                     print("Sending logs to DB failed. The traceback follows.")
607 |                     traceback.print_exc()
608 | 
609 | 
610 |             print("="*30, "Training result summary", "="*30)
611 |             print("train_duration : ", info_dict_summary['train_dur_aggregated'])
612 |             print("communication_duration : ", info_dict_summary['comm_dur_aggregated'])
613 |             print("reduce duration : ", info_dict_summary['reduce_dur_aggregated'])
614 |             print("loss : ", info_dict_summary['loss_aggregated'])
615 |             print("best accuracy : ", info_dict_summary['best_accuracy'])
616 |             print("avg memory : ", info_dict_summary['avg_memory_reserved'])
617 |             print("="*60)
618 | 
619 | 
620 | 
621 | def check_parser(args):
622 |     if args.norm == 'none':
623 |         args.norm = None
624 | 
625 | 
626 | def init_processes(rank, size, args):
627 |     """ Initialize the distributed environment. """
628 |     os.environ['GLOO_SOCKET_IFNAME'] = 'ib0'  # gloo binds to the InfiniBand interface; adjust to the cluster's NIC name
629 |     os.environ['MASTER_ADDR'] = args.master_addr
630 |     os.environ['MASTER_PORT'] = '%d' % args.port
631 | 
632 |     dist.init_process_group(args.backend, rank=rank, world_size=size, timeout=datetime.timedelta(seconds=10800))
633 | 
634 |     if args.backend == 'gloo':
635 |         device = 'cuda:0'  # gloo runs assume each process sees a single GPU (e.g., pinned via CUDA_VISIBLE_DEVICES)
636 |         torch.cuda.set_device(device)
637 |     else:
638 |         device = str(rank % torch.cuda.device_count())
639 |         device = "cuda:" + device
640 |         torch.cuda.set_device(device)
641 | 
642 | 
643 |     print("My device:", device)
644 | 
645 |     rank, size = dist.get_rank(), dist.get_world_size()
646 |     check_parser(args)
647 |     g, node_dict, gpb = load_partition(args, rank)
648 | 
649 |     torch.manual_seed(args.seed)
650 | 
651 |     run(g, node_dict, gpb, device, args)
652 | 
--------------------------------------------------------------------------------
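Two idioms in the code above are easy to misread, so here are two small, self-contained sketches. Both are illustrative toys written for this document, not files from the artifact; the file name `ring_exchange_sketch.py`, the world size, and all payload values are invented.

First, `get_boundary()` and `data_transfer()` in `Codes/helper/utils.py` keep their all-to-all exchange deadlock-free by ordering the blocking point-to-point calls by rank: at ring step `i`, ranks below `i` send before receiving while the remaining ranks receive before sending, so the blocking calls can never form a closed wait cycle. The sketch below reproduces just that pattern on CPU with the gloo backend:

# ring_exchange_sketch.py -- minimal sketch, not part of the artifact.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def ring_exchange(rank, size):
    received = [None] * size
    for i in range(1, size):
        left = (rank - i + size) % size
        right = (rank + i) % size
        payload = torch.tensor([rank * 100 + right])  # toy message for `right`
        buf = torch.zeros(1, dtype=torch.long)
        # Ranks below i send first; the rest receive first. Every blocking
        # send therefore meets a posted recv, so no ring step can deadlock.
        if rank < i:
            dist.send(payload, dst=right)
        dist.recv(buf, src=left)
        if rank >= i:
            dist.send(payload, dst=right)
        received[left] = buf.item()
    return received


def worker(rank, size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group('gloo', rank=rank, world_size=size)
    print(f'rank {rank} received {ring_exchange(rank, size)}')
    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = 4  # invented for the demo
    mp.spawn(worker, args=(world_size,), nprocs=world_size)

Second, the line `u_ = torch.repeat_interleave(graph.out_degrees(u.int()).long()) + tot` in `construct()` uses the single-argument form of `torch.repeat_interleave`, which repeats each index by the count stored at that index. Assuming DGL's `out_edges` returns edges grouped by the queried source nodes (which is what the code relies on), this yields one relabeled source ID per out-edge, aligned with the destination list `v`:

import torch

deg = torch.tensor([2, 0, 3])            # toy out-degrees of three halo nodes
tot = 10                                 # new IDs for halo nodes start at `tot`
u_ = torch.repeat_interleave(deg) + tot  # index k repeated deg[k] times
print(u_)                                # tensor([10, 10, 12, 12, 12])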