├── README.md ├── config.json ├── datasets.py ├── distributed.py ├── main.py ├── models.py ├── modules.py ├── quant_utils.py ├── quantizers.py ├── range_trackers.py ├── srun.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # QuanTorch 2 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_root": "/mnt/lustre/sensetime/hiroki_sakuma/datasets/mnist/train", 3 | "train_meta": "/mnt/lustre/sensetime/hiroki_sakuma/datasets/mnist/train/meta.json", 4 | "val_root": "/mnt/lustre/sensetime/hiroki_sakuma/datasets/mnist/val", 5 | "val_meta": "/mnt/lustre/sensetime/hiroki_sakuma/datasets/mnist/val/meta.json", 6 | "num_workers": 4, 7 | "local_batch_size": 32, 8 | "global_batch_denom": 128, 9 | "num_training_epochs": 8, 10 | "num_quantization_epochs": 1, 11 | "optimizer": { 12 | "lr": 1e-3, 13 | "betas": [ 14 | 0.9, 15 | 0.999 16 | ], 17 | "eps": 1e-8 18 | } 19 | } -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | from torch import utils 2 | from PIL import Image 3 | import numpy as np 4 | import json 5 | 6 | 7 | class ImageDataset(utils.data.Dataset): 8 | 9 | def __init__(self, root, meta, transform=None): 10 | self.root = root 11 | self.transform = transform 12 | with open(meta) as file: 13 | self.meta = list(json.load(file).items()) 14 | 15 | def __len__(self): 16 | return len(self.meta) 17 | 18 | def __getitem__(self, idx): 19 | path, label = self.meta[idx] 20 | path = f'{self.root}/{path}' 21 | image = Image.open(path).convert('RGB') 22 | if self.transform is not None: 23 | image = self.transform(image) 24 | return image, label 25 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributed 3 | import socket 4 | import os 5 | 6 | 7 | def init_process_group(backend): 8 | 9 | from mpi4py import MPI 10 | 11 | comm = MPI.COMM_WORLD 12 | world_size = comm.Get_size() 13 | rank = comm.Get_rank() 14 | 15 | info = dict() 16 | if rank == 0: 17 | host = socket.gethostname() 18 | address = socket.gethostbyname(host) 19 | info.update(dict(MASTER_ADDR=address, MASTER_PORT='1234')) 20 | 21 | info = comm.bcast(info, root=0) 22 | info.update(dict(WORLD_SIZE=str(world_size), RANK=str(rank))) 23 | os.environ.update(info) 24 | 25 | distributed.init_process_group(backend=backend) 26 | 27 | 28 | def average_gradients(parameters): 29 | world_size = distributed.get_world_size() 30 | for parameter in parameters: 31 | if parameter.requires_grad: 32 | distributed.all_reduce(parameter.grad) 33 | parameter.grad /= world_size 34 | 35 | 36 | def average_tensors(tensors): 37 | world_size = distributed.get_world_size() 38 | for tensor in tensors: 39 | distributed.all_reduce(tensor) 40 | tensor /= world_size 41 | 42 | 43 | def broadcast_tensors(tensors, src_rank=0): 44 | for tensor in tensors: 45 | distributed.broadcast(tensor, src_rank) 46 | 47 | 48 | def all_gather(src_tensor): 49 | for rank in range(distributed.get_world_size()): 50 | if rank == distributed.get_rank(): 51 | distributed.broadcast(src_tensor, rank) 52 | yield src_tensor 53 | else: 54 | dst_tensor = torch.empty_like(src_tensor) 55 | distributed.broadcast(dst_tensor, 
rank) 56 | yield dst_tensor 57 | 58 | 59 | def dprint(*args, rank=0, **kwargs): 60 | if distributed.get_rank() == rank: 61 | print(*args, **kwargs) 62 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributed 3 | from torch import backends 4 | from torch import cuda 5 | from torch import utils 6 | from torch import optim 7 | from torch import nn 8 | from torchvision import transforms 9 | from datasets import ImageDataset 10 | from models import ConvNet 11 | from distributed import * 12 | from quant_utils import * 13 | from utils import * 14 | import numpy as np 15 | import argparse 16 | import json 17 | import os 18 | 19 | 20 | def main(args): 21 | 22 | init_process_group(backend='nccl') 23 | 24 | with open(args.config) as file: 25 | config = json.load(file) 26 | config.update(vars(args)) 27 | config = apply_dict(Dict, config) 28 | 29 | backends.cudnn.benchmark = True 30 | backends.cudnn.fastest = True 31 | 32 | cuda.set_device(distributed.get_rank() % cuda.device_count()) 33 | 34 | train_dataset = ImageDataset( 35 | root=config.train_root, 36 | meta=config.train_meta, 37 | transform=transforms.Compose([ 38 | transforms.ToTensor(), 39 | transforms.Normalize((0.5,) * 3, (0.5,) * 3) 40 | ]) 41 | ) 42 | val_dataset = ImageDataset( 43 | root=config.val_root, 44 | meta=config.val_meta, 45 | transform=transforms.Compose([ 46 | transforms.ToTensor(), 47 | transforms.Normalize((0.5,) * 3, (0.5,) * 3) 48 | ]) 49 | ) 50 | 51 | train_sampler = utils.data.distributed.DistributedSampler(train_dataset) 52 | val_sampler = utils.data.distributed.DistributedSampler(val_dataset) 53 | 54 | train_data_loader = utils.data.DataLoader( 55 | dataset=train_dataset, 56 | batch_size=config.local_batch_size, 57 | sampler=train_sampler, 58 | num_workers=config.num_workers, 59 | pin_memory=True 60 | ) 61 | val_data_loader = utils.data.DataLoader( 62 | dataset=val_dataset, 63 | batch_size=config.local_batch_size, 64 | sampler=val_sampler, 65 | num_workers=config.num_workers, 66 | pin_memory=True 67 | ) 68 | 69 | model = ConvNet( 70 | conv_params=[ 71 | Dict(in_channels=3, out_channels=32, kernel_size=5, padding=2, stride=2, bias=False), 72 | Dict(in_channels=32, out_channels=64, kernel_size=5, padding=2, stride=2, bias=False), 73 | ], 74 | linear_params=[ 75 | Dict(in_channels=3136, out_channels=1024, kernel_size=1, bias=False), 76 | Dict(in_channels=1024, out_channels=10, kernel_size=1, bias=True), 77 | ] 78 | ) 79 | 80 | config.global_batch_size = config.local_batch_size * distributed.get_world_size() 81 | config.optimizer.lr *= config.global_batch_size / config.global_batch_denom 82 | optimizer = optim.Adam(model.parameters(), **config.optimizer) 83 | 84 | epoch = 0 85 | global_step = 0 86 | if config.checkpoint: 87 | checkpoint = Dict(torch.load(config.checkpoint)) 88 | model.load_state_dict(checkpoint.model_state_dict) 89 | optimizer.load_state_dict(checkpoint.optimizer_state_dict) 90 | epoch = checkpoint.last_epoch + 1 91 | global_step = checkpoint.global_step 92 | 93 | def train(data_loader): 94 | nonlocal global_step 95 | model.train() 96 | for images, labels in data_loader: 97 | images = images.cuda() 98 | labels = labels.cuda() 99 | optimizer.zero_grad() 100 | logits = model(images) 101 | loss = nn.functional.cross_entropy(logits, labels) 102 | loss.backward(retain_graph=True) 103 | average_gradients(model.parameters()) 104 | optimizer.step() 105 
| predictions = logits.topk(k=1, dim=1)[1].squeeze() 106 | accuracy = torch.mean((predictions == labels).float()) 107 | average_tensors([loss, accuracy]) 108 | global_step += 1 109 | dprint(f'[training] epoch: {epoch} global_step: {global_step} ' 110 | f'loss: {loss:.4f} accuracy: {accuracy:.4f}') 111 | 112 | @torch.no_grad() 113 | def validate(data_loader): 114 | model.eval() 115 | losses = [] 116 | accuracies = [] 117 | for images, labels in data_loader: 118 | images = images.cuda() 119 | labels = labels.cuda() 120 | logits = model(images) 121 | loss = nn.functional.cross_entropy(logits, labels) 122 | predictions = logits.topk(k=1, dim=1)[1].squeeze() 123 | accuracy = torch.mean((predictions == labels).float()) 124 | average_tensors([loss, accuracy]) 125 | losses.append(loss) 126 | accuracies.append(accuracy) 127 | loss = torch.mean(torch.stack(losses)).item() 128 | accuracy = torch.mean(torch.stack(accuracies)).item() 129 | dprint(f'[validation] epoch: {epoch} global_step: {global_step} ' 130 | f'loss: {loss:.4f} accuracy: {accuracy:.4f}') 131 | 132 | @torch.no_grad() 133 | def feed(data_loader): 134 | model.eval() 135 | for images, _ in data_loader: 136 | images = images.cuda() 137 | logits = model(images) 138 | 139 | def save(): 140 | if not distributed.get_rank(): 141 | os.makedirs('checkpoints', exist_ok=True) 142 | torch.save(dict( 143 | model_state_dict=model.state_dict(), 144 | optimizer_state_dict=optimizer.state_dict(), 145 | last_epoch=epoch, 146 | global_step=global_step 147 | ), os.path.join('checkpoints', f'epoch_{epoch}')) 148 | 149 | if config.training: 150 | model.cuda() 151 | broadcast_tensors(model.state_dict().values()) 152 | for epoch in range(epoch, config.num_training_epochs): 153 | train_sampler.set_epoch(epoch) 154 | train(train_data_loader) 155 | validate(val_data_loader) 156 | save() 157 | 158 | if config.validation: 159 | model.cuda() 160 | broadcast_tensors(model.state_dict().values()) 161 | validate(val_data_loader) 162 | 163 | if config.quantization: 164 | model.cuda() 165 | broadcast_tensors(model.state_dict().values()) 166 | with QuantizationEnabler(model): 167 | with BatchStatsUser(model): 168 | for epoch in range(epoch, config.num_quantization_epochs): 169 | train_sampler.set_epoch(epoch) 170 | train(train_data_loader) 171 | validate(val_data_loader) 172 | save() 173 | with AverageStatsUser(model): 174 | for epoch in range(epoch, config.num_quantization_epochs): 175 | train_sampler.set_epoch(epoch) 176 | train(train_data_loader) 177 | validate(val_data_loader) 178 | save() 179 | 180 | 181 | if __name__ == '__main__': 182 | 183 | parser = argparse.ArgumentParser(description='QuanTorch MNIST Example') 184 | parser.add_argument('--config', type=str, default='') 185 | parser.add_argument('--checkpoint', type=str, default='') 186 | parser.add_argument('--training', action='store_true') 187 | parser.add_argument('--validation', action='store_true') 188 | parser.add_argument('--quantization', action='store_true') 189 | args = parser.parse_args() 190 | 191 | main(args) 192 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | from torch import nn 5 | from modules import * 6 | from collections import OrderedDict 7 | 8 | 9 | class ConvNet(nn.Module): 10 | 11 | def __init__(self, conv_params, linear_params): 12 | super().__init__() 13 | self.network = nn.Sequential(OrderedDict( 14 | conv_blocks=nn.Sequential(*[ 15 | 
nn.Sequential(OrderedDict( 16 | conv2d=BatchNormFoldedQuantizedConv2d(**conv_param), 17 | relu=nn.ReLU() 18 | )) for conv_param in conv_params 19 | ]), 20 | flatten=Flatten(), 21 | unflatten=Unflatten(), 22 | linear_blocks=nn.Sequential(*[ 23 | nn.Sequential(OrderedDict( 24 | conv2d=BatchNormFoldedQuantizedConv2d(**linear_param), 25 | relu=nn.ReLU() 26 | )) for linear_param in linear_params[:-1] 27 | ]), 28 | linear_block=nn.Sequential(OrderedDict( 29 | conv2d=QuantizedConv2d(**linear_params[-1]), 30 | flatten=Flatten() 31 | )) 32 | )) 33 | 34 | def forward(self, images): 35 | return self.network(images) 36 | -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import autograd 4 | from torch import distributed 5 | from distributed import * 6 | from quantizers import * 7 | from range_trackers import * 8 | from enum import * 9 | 10 | 11 | class Identity(nn.Module): 12 | def forward(self, inputs): 13 | return inputs 14 | 15 | 16 | class Reshape(nn.Module): 17 | def __init__(self, *shape): 18 | super().__init__() 19 | self.shape = shape 20 | 21 | def forward(self, inputs): 22 | return inputs.reshape(*self.shape) 23 | 24 | 25 | class Flatten(nn.Module): 26 | def forward(self, inputs): 27 | assert inputs.dim() == 4 28 | return inputs.reshape(inputs.size(0), -1) 29 | 30 | 31 | class Unflatten(nn.Module): 32 | def forward(self, inputs): 33 | assert inputs.dim() == 2 34 | return inputs.reshape(inputs.size(0), -1, 1, 1) 35 | 36 | 37 | class QuantizedConv2d(nn.Conv2d): 38 | 39 | def __init__( 40 | self, 41 | in_channels, 42 | out_channels, 43 | kernel_size, 44 | stride=1, 45 | padding=0, 46 | dilation=1, 47 | groups=1, 48 | bias=True, 49 | activation_quantizer=None, 50 | weight_quantizer=None 51 | ): 52 | super().__init__( 53 | in_channels=in_channels, 54 | out_channels=out_channels, 55 | kernel_size=kernel_size, 56 | stride=stride, 57 | padding=padding, 58 | dilation=dilation, 59 | groups=groups, 60 | bias=bias 61 | ) 62 | self.activation_quantizer = activation_quantizer or AsymmetricQuantizer( 63 | bits_precision=8, 64 | range_tracker=AveragedRangeTracker((1, 1, 1, 1)) 65 | ) 66 | self.weight_quantizer = weight_quantizer or AsymmetricQuantizer( 67 | bits_precision=8, 68 | range_tracker=GlobalRangeTracker((1, out_channels, 1, 1)) 69 | ) 70 | 71 | self.quantization = False 72 | 73 | def enable_quantization(self): 74 | self.quantization = True 75 | 76 | def disable_quantization(self): 77 | self.quantization = False 78 | 79 | def forward(self, inputs): 80 | 81 | weight = self.weight 82 | if self.quantization: 83 | inputs = self.activation_quantizer(inputs) 84 | weight = self.weight_quantizer(self.weight) 85 | 86 | outputs = nn.functional.conv2d( 87 | input=inputs, 88 | weight=weight, 89 | bias=self.bias, 90 | stride=self.stride, 91 | padding=self.padding, 92 | dilation=self.dilation, 93 | groups=self.groups 94 | ) 95 | 96 | return outputs 97 | 98 | 99 | class BatchNormFoldedQuantizedConv2d(QuantizedConv2d): 100 | 101 | def __init__( 102 | self, 103 | in_channels, 104 | out_channels, 105 | kernel_size, 106 | stride=1, 107 | padding=0, 108 | dilation=1, 109 | groups=1, 110 | bias=False, 111 | eps=1e-5, 112 | momentum=0.1, 113 | activation_quantizer=None, 114 | weight_quantizer=None 115 | ): 116 | assert bias is False 117 | 118 | super().__init__( 119 | in_channels=in_channels, 120 | out_channels=out_channels, 121 | 
kernel_size=kernel_size, 122 | stride=stride, 123 | padding=padding, 124 | dilation=dilation, 125 | groups=groups, 126 | bias=bias, 127 | activation_quantizer=activation_quantizer, 128 | weight_quantizer=weight_quantizer 129 | ) 130 | 131 | self.eps = eps 132 | self.momentum = momentum 133 | 134 | self.register_parameter('beta', nn.Parameter(torch.zeros(out_channels))) 135 | self.register_parameter('gamma', nn.Parameter(torch.ones(out_channels))) 136 | self.register_buffer('running_mean', torch.zeros(out_channels)) 137 | self.register_buffer('running_var', torch.ones(out_channels)) 138 | 139 | self.batch_stats = True 140 | 141 | def use_batch_stats(self): 142 | self.batch_stats = True 143 | 144 | def use_running_stats(self): 145 | self.batch_stats = False 146 | 147 | def forward(self, inputs): 148 | 149 | def reshape_to_activation(inputs): 150 | return inputs.reshape(1, -1, 1, 1) 151 | 152 | def reshape_to_weight(inputs): 153 | return inputs.reshape(-1, 1, 1, 1) 154 | 155 | def reshape_to_bias(inputs): 156 | return inputs.reshape(-1) 157 | 158 | if self.training: 159 | 160 | outputs = nn.functional.conv2d( 161 | input=inputs, 162 | weight=self.weight, 163 | bias=self.bias, 164 | stride=self.stride, 165 | padding=self.padding, 166 | dilation=self.dilation, 167 | groups=self.groups 168 | ) 169 | dims = [dim for dim in range(4) if dim != 1] 170 | batch_mean = torch.mean(outputs, dim=dims) 171 | batch_var = torch.var(outputs, dim=dims) 172 | batch_std = torch.sqrt(batch_var + self.eps) 173 | 174 | self.running_mean = self.running_mean * (1 - self.momentum) + batch_mean * self.momentum 175 | self.running_var = self.running_var * (1 - self.momentum) + batch_var * self.momentum 176 | 177 | running_mean = self.running_mean 178 | running_var = self.running_var 179 | running_std = torch.sqrt(running_var + self.eps) 180 | 181 | weight = self.weight * reshape_to_weight(self.gamma / running_std) 182 | bias = reshape_to_bias(self.beta - self.gamma * running_mean / running_std) 183 | 184 | if self.quantization: 185 | inputs = self.activation_quantizer(inputs) 186 | weight = self.weight_quantizer(weight) 187 | 188 | outputs = nn.functional.conv2d( 189 | input=inputs, 190 | weight=weight, 191 | bias=bias, 192 | stride=self.stride, 193 | padding=self.padding, 194 | dilation=self.dilation, 195 | groups=self.groups 196 | ) 197 | 198 | if self.training and self.batch_stats: 199 | outputs *= reshape_to_activation(running_std / batch_std) 200 | outputs += reshape_to_activation(self.gamma * (running_mean / running_std - batch_mean / batch_std)) 201 | 202 | return outputs 203 | -------------------------------------------------------------------------------- /quant_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from modules import * 4 | from distributed import * 5 | 6 | 7 | class QuantizationEnabler(object): 8 | 9 | def __init__(self, model): 10 | self.model = model 11 | 12 | def __enter__(self): 13 | for module in self.model.modules(): 14 | if isinstance(module, QuantizedConv2d): 15 | module.enable_quantization() 16 | 17 | def __exit__(self, exc_type, exc_value, traceback): 18 | for module in self.model.modules(): 19 | if isinstance(module, QuantizedConv2d): 20 | module.disable_quantization() 21 | 22 | 23 | class BatchStatsUser(object): 24 | 25 | def __init__(self, model): 26 | self.model = model 27 | 28 | def __enter__(self): 29 | for module in self.model.modules(): 30 | if isinstance(module, BatchNormFoldedQuantizedConv2d): 
31 | module.use_batch_stats() 32 | 33 | def __exit__(self, exc_type, exc_value, traceback): 34 | pass 35 | 36 | 37 | class AverageStatsUser(object): 38 | 39 | def __init__(self, model): 40 | self.model = model 41 | 42 | def __enter__(self): 43 | for module in self.model.modules(): 44 | if isinstance(module, BatchNormFoldedQuantizedConv2d): 45 | module.use_running_stats() 46 | 47 | def __exit__(self, exc_type, exc_value, traceback): 48 | pass 49 | -------------------------------------------------------------------------------- /quantizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import autograd 4 | 5 | 6 | class Round(autograd.Function): 7 | 8 | @staticmethod 9 | def forward(ctx, inputs): 10 | return torch.floor(inputs + 0.5) 11 | 12 | @staticmethod 13 | def backward(ctx, grads): 14 | return grads 15 | 16 | 17 | class Quantizer(nn.Module): 18 | 19 | def __init__(self, bits_precision, range_tracker): 20 | super().__init__() 21 | self.bits_precision = bits_precision 22 | self.range_tracker = range_tracker 23 | self.register_buffer('scale', None) 24 | self.register_buffer('zero_point', None) 25 | 26 | def update_params(self): 27 | raise NotImplementedError 28 | 29 | def quantize(self, inputs): 30 | outputs = inputs * self.scale - self.zero_point 31 | return outputs 32 | 33 | def round(self, inputs): 34 | # outputs = torch.round(inputs) + inputs - inputs.detach() 35 | outputs = Round.apply(inputs) 36 | return outputs 37 | 38 | def clamp(self, inputs): 39 | outputs = torch.clamp(inputs, self.min_val, self.max_val) 40 | return outputs 41 | 42 | def dequantize(self, inputs): 43 | outputs = (inputs + self.zero_point) / self.scale 44 | return outputs 45 | 46 | def forward(self, inputs): 47 | self.range_tracker(inputs) 48 | self.update_params() 49 | outputs = self.quantize(inputs) 50 | outputs = self.round(outputs) 51 | outputs = self.clamp(outputs) 52 | outputs = self.dequantize(outputs) 53 | return outputs 54 | 55 | 56 | class SignedQuantizer(Quantizer): 57 | 58 | def __init__(self, *args, **kwargs): 59 | super().__init__(*args, **kwargs) 60 | self.register_buffer('min_val', torch.tensor(-(1 << (self.bits_precision - 1)))) 61 | self.register_buffer('max_val', torch.tensor((1 << (self.bits_precision - 1)) - 1)) 62 | 63 | 64 | class UnsignedQuantizer(SignedQuantizer): 65 | 66 | def __init__(self, *args, **kwargs): 67 | super().__init__(*args, **kwargs) 68 | self.register_buffer('min_val', torch.tensor(0)) 69 | self.register_buffer('max_val', torch.tensor((1 << self.bits_precision) - 1)) 70 | 71 | 72 | class SymmetricQuantizer(SignedQuantizer): 73 | 74 | def update_params(self): 75 | quantized_range = torch.min(torch.abs(self.min_val), torch.abs(self.max_val)) 76 | float_range = torch.max(torch.abs(self.range_tracker.min_val), torch.abs(self.range_tracker.max_val)) 77 | self.scale = quantized_range / float_range 78 | self.zero_point = torch.zeros_like(self.scale) 79 | 80 | 81 | class AsymmetricQuantizer(UnsignedQuantizer): 82 | 83 | def update_params(self): 84 | quantized_range = self.max_val - self.min_val 85 | float_range = self.range_tracker.max_val - self.range_tracker.min_val 86 | self.scale = quantized_range / float_range 87 | self.zero_point = torch.round(self.range_tracker.min_val * self.scale) 88 | -------------------------------------------------------------------------------- /range_trackers.py: -------------------------------------------------------------------------------- 1 | import torch
2 | from torch import nn 3 | from distributed import * 4 | 5 | 6 | class RangeTracker(nn.Module): 7 | 8 | def __init__(self, shape): 9 | super().__init__() 10 | self.shape = shape 11 | self.register_buffer('min_val', None) 12 | self.register_buffer('max_val', None) 13 | 14 | def update_range(self, min_val, max_val): 15 | raise NotImplementedError 16 | 17 | @torch.no_grad() 18 | def forward(self, inputs): 19 | 20 | keep_dims = [dim for dim, size in enumerate(self.shape) if size != 1] 21 | reduce_dims = [dim for dim, size in enumerate(self.shape) if size == 1] 22 | permute_dims = [*keep_dims, *reduce_dims] 23 | repermute_dims = [permute_dims.index(dim) for dim, size in enumerate(self.shape)] 24 | 25 | inputs = inputs.permute(*permute_dims) 26 | inputs = inputs.reshape(*inputs.shape[:len(keep_dims)], -1) 27 | 28 | min_val = torch.min(inputs, dim=-1, keepdim=True)[0] 29 | min_val = min_val.reshape(*inputs.shape[:len(keep_dims)], *[1] * len(reduce_dims)) 30 | min_val = min_val.permute(*repermute_dims) 31 | 32 | max_val = torch.max(inputs, dim=-1, keepdim=True)[0] 33 | max_val = max_val.reshape(*inputs.shape[:len(keep_dims)], *[1] * len(reduce_dims)) 34 | max_val = max_val.permute(*repermute_dims) 35 | 36 | min_val = torch.min(torch.stack(list(all_gather(min_val))), dim=0)[0] 37 | max_val = torch.max(torch.stack(list(all_gather(max_val))), dim=0)[0] 38 | 39 | self.update_range(min_val, max_val) 40 | 41 | 42 | class GlobalRangeTracker(RangeTracker): 43 | 44 | def __init__(self, shape): 45 | super().__init__(shape) 46 | 47 | def update_range(self, min_val, max_val): 48 | self.min_val = torch.min(self.min_val, min_val) if self.min_val is not None else min_val 49 | self.max_val = torch.max(self.max_val, max_val) if self.max_val is not None else max_val 50 | 51 | 52 | class AveragedRangeTracker(RangeTracker): 53 | 54 | def __init__(self, shape, momentum=0.1): 55 | super().__init__(shape) 56 | self.momentum = momentum 57 | 58 | def update_range(self, min_val, max_val): 59 | self.min_val = self.min_val * (1 - self.momentum) + min_val * self.momentum if self.min_val is not None else min_val 60 | self.max_val = self.max_val * (1 - self.momentum) + max_val * self.momentum if self.max_val is not None else max_val 61 | -------------------------------------------------------------------------------- /srun.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import argparse 3 | import textwrap 4 | 5 | 6 | if __name__ == '__main__': 7 | 8 | parser = argparse.ArgumentParser(description='srun') 9 | parser.add_argument('--partition', type=str, default='16gV100') 10 | parser.add_argument('--num_nodes', type=int, default=1) 11 | parser.add_argument('--num_gpus', type=int, default=8) 12 | parser.add_argument('--nodelist', type=str, default='') 13 | known_args, unknown_args = parser.parse_known_args() 14 | 15 | command = textwrap.dedent(f'''\ 16 | srun \ 17 | --mpi=pmi2 \ 18 | --partition={known_args.partition} \ 19 | --nodes={known_args.num_nodes} \ 20 | --ntasks-per-node={known_args.num_gpus} \ 21 | --ntasks={known_args.num_nodes * known_args.num_gpus} \ 22 | --gres=gpu:{known_args.num_gpus} \ 23 | --nodelist={known_args.nodelist} \ 24 | python -u {' '.join(unknown_args)} 25 | ''') 26 | 27 | subprocess.call(command.split()) 28 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | class Dict(dict): 2 | def __getattr__(self, name): return self[name] 3 | def
__setattr__(self, name, value): self[name] = value 4 | def __delattr__(self, name): del self[name] 5 | 6 | 7 | def apply_dict(function, dictionary): 8 | if isinstance(dictionary, dict): 9 | for key, value in dictionary.items(): 10 | dictionary[key] = apply_dict(function, value) 11 | dictionary = function(dictionary) 12 | return dictionary 13 | --------------------------------------------------------------------------------