├── .gitignore ├── tensorboard.png ├── worker_requirements.txt ├── requirements.txt ├── dawn ├── logs │ └── single_machine │ │ ├── events.out.tfevents.1536151072.ip-192-168-84-137 │ │ └── event.log ├── prepare_dawn_is.py ├── prepare_dawn_lr.py ├── prepare_dawn_bs.py └── prepare_dawn_tsv.py ├── tools ├── launch_tensorboard.py ├── create_imagenet_snapshot.py └── replicate_imagenet.py ├── training ├── dist_utils.py ├── experimental_utils.py ├── meter.py ├── logger.py ├── fp16util.py ├── resnet.py ├── dataloader.py └── train_imagenet_nv.py ├── LICENSE ├── setup.sh ├── README.md ├── util.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | __pycache__ 3 | /.idea 4 | -------------------------------------------------------------------------------- /tensorboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybertronai/imagenet18/HEAD/tensorboard.png -------------------------------------------------------------------------------- /worker_requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | torch 3 | torchvision 4 | wandb 5 | ec2-metadata 6 | tensorboardX 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | awscli 2 | boto3 3 | ncluster 4 | paramiko 5 | portpicker 6 | tensorflow 7 | tzlocal 8 | tqdm 9 | tensorboardX 10 | torch 11 | wandb 12 | ec2-metadata 13 | -------------------------------------------------------------------------------- /dawn/logs/single_machine/events.out.tfevents.1536151072.ip-192-168-84-137: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cybertronai/imagenet18/HEAD/dawn/logs/single_machine/events.out.tfevents.1536151072.ip-192-168-84-137 
-------------------------------------------------------------------------------- /tools/launch_tensorboard.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Usage: 3 | # ./launch_tensorboard.py 4 | # 5 | # This will launch r5.large machine on AWS with tensoboard, and print URL 6 | # in the console 7 | import ncluster 8 | 9 | task = ncluster.make_task('tensorboard', 10 | instance_type='r5.large', 11 | run_name='tensorboard', 12 | image_name='Deep Learning AMI (Ubuntu) Version 23.0') 13 | task.run('source activate tensorflow_p36') 14 | task.run(f'tensorboard --logdir={task.logdir}/..', non_blocking=True) 15 | print(f"Tensorboard at http://{task.public_ip}:6006") 16 | -------------------------------------------------------------------------------- /training/dist_utils.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | from torch.nn.parallel import DistributedDataParallel 3 | import os 4 | 5 | class DDP(DistributedDataParallel): 6 | # Distributed wrapper. Supports asynchronous evaluation and model saving 7 | def forward(self, *args, **kwargs): 8 | # DDP has a sync point on forward. No need to do this for eval. 
This allows us to have different batch sizes 9 | if self.training: return super().forward(*args, **kwargs) 10 | else: return self.module(*args, **kwargs) 11 | 12 | def load_state_dict(self, *args, **kwargs): 13 | self.module.load_state_dict(*args, **kwargs) 14 | 15 | def state_dict(self, *args, **kwargs): 16 | return self.module.state_dict(*args, **kwargs) 17 | 18 | 19 | 20 | 21 | def reduce_tensor(tensor): return sum_tensor(tensor)/env_world_size() 22 | def sum_tensor(tensor): 23 | rt = tensor.clone() 24 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 25 | return rt 26 | 27 | def env_world_size(): return int(os.environ['WORLD_SIZE']) 28 | def env_rank(): return int(os.environ['RANK']) 29 | 30 | -------------------------------------------------------------------------------- /training/experimental_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Filter out batch norm parameters and remove them from weight decay - gets us higher accuracy 93.2 -> 93.48 4 | # https://arxiv.org/pdf/1807.11205.pdf 5 | def bnwd_optim_params(model, model_params, master_params): 6 | bn_params, remaining_params = split_bn_params(model, model_params, master_params) 7 | return [{'params':bn_params,'weight_decay':0}, {'params':remaining_params}] 8 | 9 | 10 | def split_bn_params(model, model_params, master_params): 11 | def get_bn_params(module): 12 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): return module.parameters() 13 | accum = set() 14 | for child in module.children(): [accum.add(p) for p in get_bn_params(child)] 15 | return accum 16 | 17 | mod_bn_params = get_bn_params(model) 18 | zipped_params = list(zip(model_params, master_params)) 19 | 20 | mas_bn_params = [p_mast for p_mod,p_mast in zipped_params if p_mod in mod_bn_params] 21 | mas_rem_params = [p_mast for p_mod,p_mast in zipped_params if p_mod not in mod_bn_params] 22 | return mas_bn_params, mas_rem_params 23 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 
23 | 24 | For more information, please refer to [http://unlicense.org] 25 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # ImageNet training setup script for pytorch.imagenet.source.v7 AMI 4 | # That image is a fork of Ubuntu DLAMI v12 with "pytorch_source" conda env added which is a clone of pytorch_p36 with 5 | # the following modifications: 6 | # - pip install tqdm 7 | # - pip install tensorboardX 8 | # - version of PyTorch built from master around August 9 | 10 | # source activate pytorch_source 11 | 12 | # index file used to speed up evaluation 13 | pushd ~/data/imagenet 14 | wget --no-clobber https://s3.amazonaws.com/yaroslavvb2/data/sorted_idxar.p 15 | popd 16 | 17 | pip uninstall pillow -y 18 | CC="cc -mavx2" pip install -U pillow-simd 19 | 20 | # setting network settings - 21 | # https://github.com/aws-samples/deep-learning-models/blob/5f00600ebd126410ee5a85ddc30ff2c4119681e4/hpc-cluster/prep_client.sh 22 | sudo sysctl -w net.core.rmem_max=16777216 23 | sudo sysctl -w net.core.wmem_max=16777216 24 | sudo sysctl -w net.ipv4.tcp_rmem='4096 87380 16777216' 25 | sudo sysctl -w net.ipv4.tcp_wmem='4096 65536 16777216' 26 | sudo sysctl -w net.core.netdev_max_backlog=30000 27 | sudo sysctl -w net.core.rmem_default=16777216 28 | sudo sysctl -w net.core.wmem_default=16777216 29 | sudo sysctl -w net.ipv4.tcp_mem='16777216 16777216 16777216' 30 | sudo sysctl -w net.ipv4.route.flush=1 31 | 32 | pip install -r worker_requirements.txt 33 | 34 | -------------------------------------------------------------------------------- /tools/create_imagenet_snapshot.py: -------------------------------------------------------------------------------- 1 | """ 2 | ncluster launch --instance_type=r5.16xlarge --name=imagenet-prep 3 | ncluster connect imagenet-prep 4 | 5 | # download dataset 6 | cd ~/ 7 | wget 
https://s3.amazonaws.com/yaroslavvb2/data/imagenet18.tar 8 | 9 | # create volume and attach it 10 | source activate pytorch_p36 11 | 12 | < add AWS credentials > 13 | 14 | pip install ec2-metadata ncluster 15 | python 16 | from ncluster import aws_util as u 17 | from ec2_metadata import ec2_metadata 18 | ec2 = u.get_ec2_resource() 19 | 20 | def create_tags(name): 21 | return [{ 22 | 'ResourceType': 'volume', 23 | 'Tags': [{ 24 | 'Key': 'Name', 25 | 'Value': name 26 | }] 27 | }] 28 | 29 | vol = ec2.create_volume(Size=400, TagSpecifications=create_tags('imagenet18'), AvailabilityZone=ec2_metadata.availability_zone, VolumeType='gp2') 30 | vol = ec2.Volume('vol-0fd20d716517c942d') 31 | 32 | instance = ec2.Instance(ec2_metadata.instance_id) 33 | device_name = '/dev/xvdh' # or /dev/nvme1n1 34 | instance.attach_volume(Device=device_name, VolumeId=vol.id) 35 | vol.reload() 36 | assert ec2_metadata.instance_id in str(vol.attachments) 37 | 38 | # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ebs-using-volumes.html 39 | lsblk # get name of device (look for one with 300 size like nvme1n1, then dev name is /dev/nvme1n1) 40 | sudo file -s /dev/nvme1n1 41 | 42 | sudo mkfs -t ext4 /dev/nvme1n1 43 | sudo umount data || echo skipping 44 | sudo mkdir -p /data 45 | sudo chown `whoami` /data 46 | sudo mount /dev/nvme1n1 /data 47 | 48 | 49 | cd /data 50 | tar xf ~/imagenet18.tar --strip 1 51 | 52 | 53 | # create snapshot 54 | snapshot = ec2.create_snapshot(Description=f'{u.get_name(vol)} snapshot', VolumeId=vol.id,) 55 | """ 56 | -------------------------------------------------------------------------------- /dawn/prepare_dawn_is.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Prepares DAWN TSV file from TensorBoard events url 4 | 5 | import sys, os, re 6 | from dateutil import parser 7 | 8 | events_url = 'https://s3.amazonaws.com/yaroslavvb/logs/release-sixteen.04.events' 9 | 10 | import os 11 | import 
glob 12 | import numpy as np 13 | import datetime as dt 14 | import pytz 15 | from tensorflow.python.summary import summary_iterator 16 | import argparse 17 | 18 | parser = argparse.ArgumentParser(description='launch') 19 | parser.add_argument('--ignore-eval', action='store_true', 20 | help='ignore eval time') 21 | args = parser.parse_args() 22 | 23 | def get_events(fname, x_axis='step'): 24 | """Returns event dictionary for given run, has form 25 | {tag1: {step1: val1}, tag2: ..} 26 | 27 | If x_axis is set to "time", step is replaced by timestamp 28 | """ 29 | result = {} 30 | 31 | events = summary_iterator.summary_iterator(fname) 32 | 33 | try: 34 | for event in events: 35 | if x_axis == 'step': 36 | x_val = event.step 37 | elif x_axis == 'time': 38 | x_val = event.wall_time 39 | else: 40 | assert False, f"Unknown x_axis ({x_axis})" 41 | 42 | vals = {val.tag: val.simple_value for val in event.summary.value} 43 | # step_time: value 44 | for tag in vals: 45 | event_dict = result.setdefault(tag, {}) 46 | if x_val in event_dict: 47 | print(f"Warning, overwriting {tag} for {x_axis}={x_val}") 48 | print(f"old val={event_dict[x_val]}") 49 | print(f"new val={vals[tag]}") 50 | 51 | event_dict[x_val] = vals[tag] 52 | except Exception as e: 53 | print(e) 54 | pass 55 | 56 | return result 57 | 58 | def datetime_from_seconds(seconds, timezone="US/Pacific"): 59 | """ 60 | timezone: pytz timezone name to use for conversion, ie, UTC or US/Pacific 61 | """ 62 | return dt.datetime.fromtimestamp(seconds, pytz.timezone(timezone)) 63 | 64 | 65 | def download_file(url): 66 | import urllib.request 67 | response = urllib.request.urlopen(url) 68 | data = response.read() 69 | return data 70 | 71 | def main(): 72 | with open('/tmp/events', 'wb') as f: 73 | f.write(download_file(events_url)) 74 | 75 | 76 | events_dict=get_events('/tmp/events', 'step') 77 | 78 | # build step->time dict for eval events 79 | lr = events_dict['losses/test1'] 80 | for step in lr: 81 | print('{"image_size": 
'+str(lr[step])+', "example": '+str(step)+"},") 82 | 83 | if __name__=='__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /dawn/prepare_dawn_lr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Prepares DAWN TSV file from TensorBoard events url 4 | 5 | import sys, os, re 6 | from dateutil import parser 7 | 8 | events_url = 'https://s3.amazonaws.com/yaroslavvb/logs/release-sixteen.04.events' 9 | 10 | import os 11 | import glob 12 | import numpy as np 13 | import datetime as dt 14 | import pytz 15 | from tensorflow.python.summary import summary_iterator 16 | import argparse 17 | 18 | parser = argparse.ArgumentParser(description='launch') 19 | parser.add_argument('--ignore-eval', action='store_true', 20 | help='ignore eval time') 21 | args = parser.parse_args() 22 | 23 | def get_events(fname, x_axis='step'): 24 | """Returns event dictionary for given run, has form 25 | {tag1: {step1: val1}, tag2: ..} 26 | 27 | If x_axis is set to "time", step is replaced by timestamp 28 | """ 29 | result = {} 30 | 31 | events = summary_iterator.summary_iterator(fname) 32 | 33 | try: 34 | for event in events: 35 | if x_axis == 'step': 36 | x_val = event.step 37 | elif x_axis == 'time': 38 | x_val = event.wall_time 39 | else: 40 | assert False, f"Unknown x_axis ({x_axis})" 41 | 42 | vals = {val.tag: val.simple_value for val in event.summary.value} 43 | # step_time: value 44 | for tag in vals: 45 | event_dict = result.setdefault(tag, {}) 46 | if x_val in event_dict: 47 | print(f"Warning, overwriting {tag} for {x_axis}={x_val}") 48 | print(f"old val={event_dict[x_val]}") 49 | print(f"new val={vals[tag]}") 50 | 51 | event_dict[x_val] = vals[tag] 52 | except Exception as e: 53 | print(e) 54 | pass 55 | 56 | return result 57 | 58 | def datetime_from_seconds(seconds, timezone="US/Pacific"): 59 | """ 60 | timezone: pytz timezone name to use for conversion, ie, UTC 
or US/Pacific 61 | """ 62 | return dt.datetime.fromtimestamp(seconds, pytz.timezone(timezone)) 63 | 64 | 65 | def download_file(url): 66 | import urllib.request 67 | response = urllib.request.urlopen(url) 68 | data = response.read() 69 | return data 70 | 71 | def main(): 72 | with open('/tmp/events', 'wb') as f: 73 | f.write(download_file(events_url)) 74 | 75 | 76 | events_dict=get_events('/tmp/events', 'step') 77 | 78 | # build step->time dict for eval events 79 | lr = events_dict['sizes/lr'] 80 | for step in lr: 81 | print('{"learning_rate": '+str(lr[step])+', "example": '+str(step)+"},") 82 | 83 | if __name__=='__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /dawn/prepare_dawn_bs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Prepares DAWN TSV file from TensorBoard events url 4 | 5 | import sys, os, re 6 | from dateutil import parser 7 | 8 | events_url = 'https://s3.amazonaws.com/yaroslavvb/logs/release-sixteen.04.events' 9 | 10 | import os 11 | import glob 12 | import numpy as np 13 | import datetime as dt 14 | import pytz 15 | from tensorflow.python.summary import summary_iterator 16 | import argparse 17 | 18 | parser = argparse.ArgumentParser(description='launch') 19 | parser.add_argument('--ignore-eval', action='store_true', 20 | help='ignore eval time') 21 | args = parser.parse_args() 22 | 23 | def get_events(fname, x_axis='step'): 24 | """Returns event dictionary for given run, has form 25 | {tag1: {step1: val1}, tag2: ..} 26 | 27 | If x_axis is set to "time", step is replaced by timestamp 28 | """ 29 | result = {} 30 | 31 | events = summary_iterator.summary_iterator(fname) 32 | 33 | try: 34 | for event in events: 35 | if x_axis == 'step': 36 | x_val = event.step 37 | elif x_axis == 'time': 38 | x_val = event.wall_time 39 | else: 40 | assert False, f"Unknown x_axis ({x_axis})" 41 | 42 | vals = {val.tag: val.simple_value 
for val in event.summary.value} 43 | # step_time: value 44 | for tag in vals: 45 | event_dict = result.setdefault(tag, {}) 46 | if x_val in event_dict: 47 | print(f"Warning, overwriting {tag} for {x_axis}={x_val}") 48 | print(f"old val={event_dict[x_val]}") 49 | print(f"new val={vals[tag]}") 50 | 51 | event_dict[x_val] = vals[tag] 52 | except Exception as e: 53 | print(e) 54 | pass 55 | 56 | return result 57 | 58 | def datetime_from_seconds(seconds, timezone="US/Pacific"): 59 | """ 60 | timezone: pytz timezone name to use for conversion, ie, UTC or US/Pacific 61 | """ 62 | return dt.datetime.fromtimestamp(seconds, pytz.timezone(timezone)) 63 | 64 | 65 | def download_file(url): 66 | import urllib.request 67 | response = urllib.request.urlopen(url) 68 | data = response.read() 69 | return data 70 | 71 | def main(): 72 | with open('/tmp/events', 'wb') as f: 73 | f.write(download_file(events_url)) 74 | 75 | 76 | events_dict=get_events('/tmp/events', 'step') 77 | 78 | # build step->time dict for eval events 79 | lr = events_dict['sizes/batch'] 80 | for step in lr: 81 | print('{"batch_size": '+str(16*8*lr[step])+', "example": '+str(step)+"},") 82 | 83 | if __name__=='__main__': 84 | main() 85 | -------------------------------------------------------------------------------- /training/meter.py: -------------------------------------------------------------------------------- 1 | 2 | import subprocess, time 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | def __init__(self, avg_mom=0.5): 7 | self.avg_mom = avg_mom 8 | self.reset() 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 # running average of whole epoch 13 | self.smooth_avg = 0 14 | self.sum = 0 15 | self.count = 0 16 | 17 | def update(self, val, n=1): 18 | self.val = val 19 | self.sum += val * n 20 | self.count += n 21 | self.smooth_avg = val if self.count == 0 else self.avg*self.avg_mom + val*(1-self.avg_mom) 22 | self.avg = self.sum / self.count 23 | 
24 | class NetworkMeter: 25 | def __init__(self): 26 | self.recv_meter = AverageMeter() 27 | self.transmit_meter = AverageMeter() 28 | self.last_recv_bytes, self.last_transmit_bytes = network_bytes() 29 | self.last_log_time = time.time() 30 | 31 | def update_bandwidth(self): 32 | time_delta = time.time()-self.last_log_time 33 | recv_bytes, transmit_bytes = network_bytes() 34 | 35 | recv_delta = recv_bytes - self.last_recv_bytes 36 | transmit_delta = transmit_bytes - self.last_transmit_bytes 37 | 38 | # turn into Gbps 39 | recv_gbit = 8*recv_delta/time_delta/1e9 40 | transmit_gbit = 8*transmit_delta/time_delta/1e9 41 | self.recv_meter.update(recv_gbit) 42 | self.transmit_meter.update(transmit_gbit) 43 | 44 | self.last_log_time = time.time() 45 | self.last_recv_bytes = recv_bytes 46 | self.last_transmit_bytes = transmit_bytes 47 | return recv_gbit, transmit_gbit 48 | 49 | class TimeMeter: 50 | def __init__(self): 51 | self.batch_time = AverageMeter() 52 | self.data_time = AverageMeter() 53 | self.start = time.time() 54 | 55 | def batch_start(self): 56 | self.data_time.update(time.time() - self.start) 57 | 58 | def batch_end(self): 59 | self.batch_time.update(time.time() - self.start) 60 | self.start = time.time() 61 | 62 | 63 | ################################################################################ 64 | # Generic utility methods, eventually refactor into separate file 65 | ################################################################################ 66 | def network_bytes(): 67 | """Returns received bytes, transmitted bytes.""" 68 | 69 | proc = subprocess.Popen(['cat', '/proc/net/dev'], stdout=subprocess.PIPE) 70 | stdout,stderr = proc.communicate() 71 | stdout=stdout.decode('ascii') 72 | 73 | recv_bytes = 0 74 | transmit_bytes = 0 75 | lines=stdout.strip().split('\n') 76 | lines = lines[2:] # strip header 77 | for line in lines: 78 | line = line.strip() 79 | # ignore loopback interface 80 | if line.startswith('lo'): 81 | continue 82 | toks = line.split() 
83 | 84 | recv_bytes += int(toks[1]) 85 | transmit_bytes += int(toks[9]) 86 | return recv_bytes, transmit_bytes 87 | 88 | ################################################################################ 89 | -------------------------------------------------------------------------------- /dawn/prepare_dawn_tsv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # Prepares DAWN TSV file from TensorBoard events url 4 | 5 | import sys, os, re 6 | from dateutil import parser 7 | 8 | events_url = 'https://s3.amazonaws.com/yaroslavvb/logs/release-sixteen.04.events' 9 | 10 | import os 11 | import glob 12 | import numpy as np 13 | import datetime as dt 14 | import pytz 15 | from tensorflow.python.summary import summary_iterator 16 | import argparse 17 | 18 | parser = argparse.ArgumentParser(description='launch') 19 | parser.add_argument('--ignore-eval', action='store_true', 20 | help='ignore eval time') 21 | args = parser.parse_args() 22 | 23 | def get_events(fname, x_axis='step'): 24 | """Returns event dictionary for given run, has form 25 | {tag1: {step1: val1}, tag2: ..} 26 | 27 | If x_axis is set to "time", step is replaced by timestamp 28 | """ 29 | result = {} 30 | 31 | events = summary_iterator.summary_iterator(fname) 32 | 33 | try: 34 | for event in events: 35 | if x_axis == 'step': 36 | x_val = event.step 37 | elif x_axis == 'time': 38 | x_val = event.wall_time 39 | else: 40 | assert False, f"Unknown x_axis ({x_axis})" 41 | 42 | vals = {val.tag: val.simple_value for val in event.summary.value} 43 | # step_time: value 44 | for tag in vals: 45 | event_dict = result.setdefault(tag, {}) 46 | if x_val in event_dict: 47 | print(f"Warning, overwriting {tag} for {x_axis}={x_val}") 48 | print(f"old val={event_dict[x_val]}") 49 | print(f"new val={vals[tag]}") 50 | 51 | event_dict[x_val] = vals[tag] 52 | except Exception as e: 53 | print(e) 54 | pass 55 | 56 | return result 57 | 58 | def 
datetime_from_seconds(seconds, timezone="US/Pacific"): 59 | """ 60 | timezone: pytz timezone name to use for conversion, ie, UTC or US/Pacific 61 | """ 62 | return dt.datetime.fromtimestamp(seconds, pytz.timezone(timezone)) 63 | 64 | 65 | def download_file(url): 66 | import urllib.request 67 | response = urllib.request.urlopen(url) 68 | data = response.read() 69 | return data 70 | 71 | def main(): 72 | with open('/tmp/events', 'wb') as f: 73 | f.write(download_file(events_url)) 74 | 75 | 76 | events_dict=get_events('/tmp/events', 'step') 77 | events_dict2 = get_events('/tmp/events', 'time') 78 | # starting time, "first" event gets logged in beginning of main() 79 | first = events_dict2['first'] 80 | start_time = list(first.keys())[0] 81 | 82 | # build step->time dict for eval events 83 | events_step = events_dict['losses/test_5'] 84 | steps = list(events_step.keys()) 85 | events_time = events_dict2['losses/test_5'] 86 | times = list(events_time.keys()) 87 | step_time = {v[0]:v[1] for v in zip(events_step, events_time)} 88 | print(step_time) 89 | 90 | 91 | # get ending time 92 | test_5 = events_dict['losses/test_5'] 93 | test_1 = events_dict['losses/test_1'] 94 | eval_sec = events_dict['times/eval_sec'] 95 | total_eval_sec = 0 96 | for (i, step) in enumerate(test_1): 97 | # subtract eval time, which is not required 98 | # https://github.com/stanford-futuredata/dawn-bench-entries/issues/12#issuecomment-381363792 99 | ts = step_time[step] 100 | elapsed = ts-start_time 101 | if args.ignore_eval: 102 | total_eval_sec+=eval_sec[step] 103 | elapsed -= total_eval_sec 104 | 105 | print(f"{i+1}\t{(elapsed/3600)}\t{test_1[step]}\t{test_5[step]}") 106 | if test_5[step]>=93: 107 | end_time = ts 108 | break 109 | 110 | 111 | if __name__=='__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Code to reproduce ImageNet in 18 
minutes, by Andrew Shaw, Yaroslav Bulatov, and Jeremy Howard. A high-level overview of the techniques used is available [here](https://docs.google.com/document/d/14v6elpz12Nm5VwSVYbQym0Sbsl0I6UbqV3-SYaLzyTc/edit#heading=h.b2wwp49hhjut) (Yaroslav) and [here](http://fast.ai/2018/08/10/fastai-diu-imagenet/) (fast.ai). 2 | 3 | 4 | Prerequisites: Python 3.6 or higher 5 | 6 | - Set your `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_DEFAULT_REGION` (example [instructions](https://docs.google.com/document/d/1Z8lCZVWXs7XORbiNmBAsBDtouV3KwrtH8-UL5M-zHus/edit)) 7 | 8 | ``` 9 | pip install -r requirements.txt 10 | 11 | ncluster spot_prices p3 # check spot prices for regions to find valid zone for p3 instances 12 | export NCLUSTER_ZONE=us-east-1 # set to a zone with cheap p3's 13 | python tools/replicate_imagenet.py --replicas=4 # configure 4 high performance disks 14 | python train.py --machines=4 15 | python tools/replicate_imagenet.py --replicas=4 --delete # delete high performance disks 16 | ``` 17 | 18 | To run with a different number of machines: 19 | 20 | ``` 21 | python train.py --machines=1 22 | python train.py --machines=2 23 | python train.py --machines=4 24 | python train.py --machines=8 25 | python train.py --machines=16 26 | ``` 27 | 28 | To run on spot instances, add the `--spot` argument, e.g. `python train.py --spot` 29 | 30 | Your AWS account needs a high enough instance limit to reserve this number of p3.16xlarge instances. The code will set up the necessary infrastructure like EFS, VPC, subnets, keypairs and placement groups, so permissions to create these resources are needed. Note that high performance disks cost about $1/hour, so make sure to delete them after use. 31 | 32 | 33 | # Checking progress 34 | 35 | Machines print progress to local stdout, log TensorBoard event files to EFS under a unique directory, and also send data to wandb if the WANDB_API_KEY env var is set to your API key (found at https://app.wandb.ai/settings). 36 | 37 | 38 | ## TensorBoard 39 | 1.
launch tensorboard using `python tools/launch_tensorboard.py` 40 | 41 | That will provide a link to tensorboard instance which has loss graph under "losses" group. You'll see something like this under "Losses" tab 42 | 43 | 44 | ## Console 45 | You can connect to one of the instances using instructions printed during launch. Look for something like this 46 | 47 | ``` 48 | 2019-07-29 15:58:10.653377 0.monday-quad: To connect to 0.monday-quad do "ncluster connect 0.monday-quad" or 49 | ssh ubuntu@184.73.100.7 50 | tmux a 51 | ``` 52 | 53 | This will connect you to tmux session and you will see something like this 54 | 55 | ``` 56 | .997 (65.102) Acc@5 85.854 (85.224) Data 0.004 (0.035) BW 2.444 2.445 57 | Epoch: [21][175/179] Time 0.318 (0.368) Loss 1.4276 (1.4767) Acc@1 66.169 (65.132) Acc@5 86.063 (85.244) Data 0.004 (0.035) BW 2.464 2.466 58 | Changing LR from 0.4012569832402235 to 0.40000000000000013 59 | Epoch: [21][179/179] Time 0.336 (0.367) Loss 1.4457 (1.4761) Acc@1 65.473 (65.152) Acc@5 86.061 (85.252) Data 0.004 (0.034) BW 2.393 2.397 60 | Test: [21][5/7] Time 0.106 (0.563) Loss 1.3254 (1.3187) Acc@1 67.508 (67.693) Acc@5 88.644 (88.315) 61 | Test: [21][7/7] Time 0.105 (0.432) Loss 1.4089 (1.3346) Acc@1 67.134 (67.462) Acc@5 87.257 (88.124) 62 | ~~21 0.31132 67.462 88.124 63 | ``` 64 | 65 | The last number indicates that at epoch 21 the run got 67.462 top-1 test accuracy and 88.124 top-5 test accuracy. 
66 | 67 | ## Weights and Biases 68 | 69 | Runs will show up under the "imagenet18" project in your Weights and Biases page, e.g. https://app.wandb.ai/yaroslavvb/imagenet18/runs/8fv3xosq 70 | 71 | # Other notes 72 | If you run locally, you may need to download ImageNet yourself from [here](https://s3.amazonaws.com/yaroslavvb2/data/imagenet18.tar) 73 | 74 | -------------------------------------------------------------------------------- /tools/replicate_imagenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Replicates an ImageNet EBS snapshot across multiple volumes 4 | # 5 | # Script to initialize a set of high-performance volumes with ImageNet data 6 | # 7 | # replicate_imagenet.py --replicas 8 8 | # replicate_imagenet.py --replicas 8 --volume_offset=8 9 | # 10 | # or 11 | # 12 | # NCLUSTER_ZONE=us-east-1b replicate_imagenet.py --replicas 16 (the zone comes from ncluster's get_zone(), not a flag) 13 | # Creates volumes: imagenet_1b_00, imagenet_1b_01, imagenet_1b_02, ..., imagenet_1b_15 (infix is the last 2 chars of the zone) 14 | # 15 | # ImageNet data should follow structure as in 16 | # https://github.com/diux-dev/cluster/tree/master/pytorch#data-preparation 17 | # (paths replace ~/data with /) 18 | # 19 | # steps to create snapshot: 20 | # create blank volume (ec2.create_volume()) 21 | # attach it to an existing instance with ImageNet under data0, then 22 | # sudo mkfs -t ext4 /dev/xvdf 23 | # mkdir data 24 | # sudo mount /dev/xvdf data 25 | # sudo chown `whoami` data 26 | # cp -R data0 data 27 | # snapshot = ec2.create_snapshot(Description=f'{u.get_name(vol)} snapshot', 28 | # VolumeId=vol.id,) 29 | 30 | import argparse 31 | 32 | from ncluster import aws_util as u 33 | 34 | parser = argparse.ArgumentParser(description='launch') 35 | parser.add_argument('--replicas', type=int, default=1) 36 | #parser.add_argument('--snapshot', type=str, default='imagenet18') 37 | parser.add_argument('--snapshot', type=str, default='imagenet18-backup') 38 | #parser.add_argument('--snapshot_account', type=str,
default='316880547378', 39 | # help='account id hosting this snapshot') 40 | 41 | parser.add_argument('--volume_offset', type=int, default=0, help='start numbering with this value') 42 | parser.add_argument('--size_gb', type=int, default=0, help="size in GBs") 43 | parser.add_argument('--delete', action='store_true', help="delete volumes instead of creating") 44 | 45 | args = parser.parse_args() 46 | 47 | 48 | def create_volume_tags(name): 49 | return [{ 50 | 'ResourceType': 'volume', 51 | 'Tags': [{ 52 | 'Key': 'Name', 53 | 'Value': name 54 | }] 55 | }] 56 | 57 | 58 | # TODO: switch to snap-03e6fc1ab6d2da3c5 59 | 60 | def main(): 61 | ec2 = u.get_ec2_resource() 62 | zone = u.get_zone() 63 | 64 | # use filtering by description since Name is not public 65 | # snapshots = list(ec2.snapshots.filter(Filters=[{'Name': 'description', 'Values': [args.snapshot]}, 66 | # {'Name': 'owner-id', 'Values': [args.snapshot_account]}])) 67 | 68 | snap = None 69 | if not args.delete: 70 | snapshots = list(ec2.snapshots.filter(Filters=[{'Name': 'description', 'Values': [args.snapshot]}])) 71 | 72 | assert len(snapshots) > 0, f"no snapshot matching {args.snapshot}" 73 | assert len(snapshots) < 2, f"multiple snapshots matching {args.snapshot}" 74 | snap = snapshots[0] 75 | if not args.size_gb: 76 | args.size_gb = snap.volume_size 77 | 78 | # list existing volumes 79 | vols = {} 80 | for vol in ec2.volumes.all(): 81 | vols[u.get_name(vol)] = vol 82 | 83 | print(f"{'Deleting' if args.delete else 'Making'} {args.replicas} {args.size_gb} GB replicas in {zone}") 84 | 85 | for i in range(args.volume_offset, args.replicas + args.volume_offset): 86 | vol_name = f'imagenet_{zone[-2:]}_{i:02d}' 87 | if args.delete: 88 | print(f"Deleting {vol_name}") 89 | if vol_name not in vols: 90 | print(" Not found") 91 | continue 92 | else: 93 | try: 94 | vols[vol_name].delete() 95 | except Exception as e: 96 | print(f"Deletion of {vol_name} failed with {e}") 97 | continue 98 | 99 | if vol_name in vols: 100 
| print(f"{vol_name} exists, skipping") 101 | else: 102 | vol = ec2.create_volume(Size=args.size_gb, 103 | TagSpecifications=create_volume_tags(vol_name), 104 | AvailabilityZone=zone, 105 | SnapshotId=snap.id, 106 | Iops=11500, VolumeType='io1') 107 | print(f"Creating {vol_name} {vol.id}") 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /training/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | import wandb 6 | from tensorboardX import SummaryWriter 7 | 8 | # util is one level up, so import that 9 | module_path = os.path.dirname(os.path.abspath(__file__)) 10 | sys.path.insert(0, os.path.abspath(f'{module_path}/..')) 11 | 12 | 13 | class TensorboardLogger: 14 | def __init__(self, output_dir, is_master=False): 15 | self.output_dir = output_dir 16 | self.current_step = 0 17 | self.is_master = is_master 18 | if is_master: 19 | self.writer = SummaryWriter(self.output_dir) 20 | 21 | else: self.writer = NoOp() 22 | # self.log('first', time.time()) 23 | 24 | def log(self, tag, val): 25 | """Log value to tensorboard (relies on global_example_count being set properly)""" 26 | if not self.writer: return 27 | self.writer.add_scalar(tag, val, self.current_step) 28 | try: 29 | wandb.log({tag: val}, step=int(self.current_step)) 30 | except: 31 | pass 32 | 33 | def update_step_count(self, batch_total): 34 | self.current_step += batch_total 35 | 36 | def close(self): 37 | self.writer.export_scalars_to_json(self.output_dir+'/scalars.json') 38 | self.writer.close() 39 | 40 | # Convenience logging methods 41 | def log_size(self, bs=None, sz=None): 42 | if bs: self.log('sizes/batch', bs) 43 | if sz: self.log('sizes/image', sz) 44 | 45 | def log_eval(self, top1, top5, time): 46 | self.log('losses/test_1', top1) 47 | self.log('losses/test_5', top5) 48 | self.log('times/eval_sec', time) 49 | 50 | 
def log_trn_loss(self, loss, top1, top5): 51 | self.log("losses/xent", loss) # cross_entropy 52 | self.log("losses/train_1", top1) # precision@1 53 | self.log("losses/train_5", top5) # precision@5 54 | 55 | def log_memory(self): 56 | if not self.writer: return 57 | self.log("memory/allocated_gb", torch.cuda.memory_allocated()/1e9) 58 | self.log("memory/max_allocated_gb", torch.cuda.max_memory_allocated()/1e9) 59 | self.log("memory/cached_gb", torch.cuda.memory_cached()/1e9) 60 | self.log("memory/max_cached_gb", torch.cuda.max_memory_cached()/1e9) 61 | 62 | def log_trn_times(self, batch_time, data_time, batch_size): 63 | if not self.writer: return 64 | self.log("times/step", 1000*batch_time) 65 | self.log("times/data", 1000*data_time) 66 | images_per_sec = batch_size/batch_time 67 | self.log("times/1gpu_images_per_sec", images_per_sec) 68 | self.log("times/8gpu_images_per_sec", 8*images_per_sec) 69 | 70 | 71 | import logging 72 | 73 | 74 | class FileLogger: 75 | def __init__(self, output_dir, is_master=False, is_rank0=False): 76 | self.output_dir = output_dir 77 | 78 | # Log to console if rank 0, Log to console and file if master 79 | if not is_rank0: self.logger = NoOp() 80 | else: self.logger = self.get_logger(output_dir, log_to_file=is_master) 81 | 82 | def get_logger(self, output_dir, log_to_file=True): 83 | logger = logging.getLogger('imagenet_training') 84 | logger.setLevel(logging.DEBUG) 85 | formatter = logging.Formatter('%(message)s') 86 | 87 | time_formatter = logging.Formatter('%(asctime)s - %(filename)s:%(lineno)d - %(message)s') 88 | 89 | if log_to_file: 90 | vlog = logging.FileHandler(output_dir+'/verbose.log') 91 | vlog.setLevel(logging.INFO) 92 | vlog.setFormatter(formatter) 93 | logger.addHandler(vlog) 94 | 95 | eventlog = logging.FileHandler(output_dir+'/event.log') 96 | eventlog.setLevel(logging.WARN) 97 | eventlog.setFormatter(formatter) 98 | logger.addHandler(eventlog) 99 | 100 | debuglog = logging.FileHandler(output_dir+'/debug.log') 101 | 
debuglog.setLevel(logging.DEBUG) 102 | debuglog.setFormatter(time_formatter) 103 | logger.addHandler(debuglog) 104 | 105 | console = logging.StreamHandler() 106 | console.setFormatter(time_formatter) 107 | console.setLevel(logging.DEBUG) 108 | logger.addHandler(console) 109 | return logger 110 | 111 | def console(self, *args): 112 | if args and args[0]: 113 | args0 = 'rank-'+os.environ.get('RANK', '0')+' '+str(args[0]) 114 | new_args = (args0,)+args[1:] 115 | self.logger.debug(*new_args) 116 | 117 | def event(self, *args): 118 | self.logger.warn(*args) 119 | 120 | def verbose(self, *args): 121 | self.logger.info(*args) 122 | 123 | # no_op method/object that accept every signature 124 | class NoOp: 125 | def __getattr__(self, *args): 126 | def no_op(*args, **kwargs): pass 127 | return no_op 128 | -------------------------------------------------------------------------------- /dawn/logs/single_machine/event.log: -------------------------------------------------------------------------------- 1 | ~~epoch hours top1 top5 2 | 3 | Dataset changed. 
4 | Image size: 128 5 | Batch size: 128 6 | Train Directory: /home/ubuntu/data/imagenet-sz/160/train 7 | Validation Directory: /home/ubuntu/data/imagenet-sz/160/validation 8 | Changing LR from None to 1.9220382165605094 9 | Changing LR from 2.2379617834394905 to 2.2399999999999998 10 | ~~0 0.01241 4.248 12.422 11 | 12 | Changing LR from 2.2399999999999998 to 2.2420382165605095 13 | Changing LR from 2.5579617834394903 to 2.5599999999999996 14 | ~~1 0.02056 13.744 31.890 15 | 16 | Changing LR from 2.5599999999999996 to 2.5620382165605093 17 | Changing LR from 2.87796178343949 to 2.88 18 | ~~2 0.02872 17.476 37.000 19 | 20 | Changing LR from 2.88 to 2.8820382165605096 21 | Changing LR from 3.1979617834394904 to 3.1999999999999997 22 | ~~3 0.03675 22.806 45.220 23 | 24 | Changing LR from 3.1999999999999997 to 3.2020382165605095 25 | Changing LR from 3.5179617834394903 to 3.5199999999999996 26 | ~~4 0.04485 23.948 46.612 27 | 28 | Changing LR from 3.5199999999999996 to 3.5220382165605093 29 | Changing LR from 3.83796178343949 to 3.84 30 | ~~5 0.05292 27.468 52.032 31 | 32 | Batch size changed: 256 33 | ~~6 0.06062 36.966 62.856 34 | 35 | ~~7 0.06630 32.990 57.666 36 | 37 | ~~8 0.07187 31.698 56.344 38 | 39 | ~~9 0.07749 18.858 38.408 40 | 41 | ~~10 0.08310 37.414 63.216 42 | 43 | Changing LR from 3.84 to 3.831898734177215 44 | Changing LR from 3.2081012658227848 to 3.1999999999999997 45 | ~~11 0.08867 35.538 60.328 46 | 47 | Changing LR from 3.1999999999999997 to 3.191898734177215 48 | Changing LR from 2.5681012658227846 to 2.5599999999999996 49 | ~~12 0.09431 42.848 69.076 50 | 51 | Changing LR from 2.5599999999999996 to 2.551898734177215 52 | Changing LR from 1.9281012658227845 to 1.9199999999999997 53 | ~~13 0.09992 51.678 76.764 54 | 55 | Dataset changed. 
56 | Image size: 224 57 | Batch size: 128 58 | Train Directory: /home/ubuntu/data/imagenet-sz/352/train 59 | Validation Directory: /home/ubuntu/data/imagenet-sz/352/validation 60 | Changing LR from 1.9199999999999997 to 1.92 61 | ~~14 0.11677 49.534 75.002 62 | 63 | ~~15 0.12934 51.472 77.124 64 | 65 | ~~16 0.14186 47.602 73.600 66 | 67 | Batch size changed: 224 68 | Changing LR from 1.92 to 1.9170666666666667 69 | Changing LR from 1.6589333333333334 to 1.656 70 | ~~17 0.15763 53.816 78.380 71 | 72 | Changing LR from 1.656 to 1.6530666666666667 73 | Changing LR from 1.3949333333333334 to 1.392 74 | ~~18 0.16902 60.346 83.514 75 | 76 | Changing LR from 1.392 to 1.3890666666666667 77 | Changing LR from 1.1309333333333333 to 1.1280000000000001 78 | ~~19 0.18002 60.764 83.586 79 | 80 | Changing LR from 1.1280000000000001 to 1.1250666666666667 81 | Changing LR from 0.8669333333333333 to 0.8640000000000001 82 | ~~20 0.19103 64.170 85.980 83 | 84 | Changing LR from 0.8640000000000001 to 0.8610666666666666 85 | Changing LR from 0.6029333333333335 to 0.6000000000000001 86 | ~~21 0.20206 67.700 88.448 87 | 88 | Changing LR from 0.6000000000000001 to 0.5970666666666669 89 | Changing LR from 0.33893333333333353 to 0.3360000000000001 90 | ~~22 0.21315 69.844 89.668 91 | 92 | Changing LR from 0.3360000000000001 to 0.33544 93 | Changing LR from 0.28616 to 0.2856 94 | ~~23 0.22425 70.372 89.810 95 | 96 | Changing LR from 0.2856 to 0.28504 97 | Changing LR from 0.23576000000000003 to 0.23520000000000002 98 | ~~24 0.23541 71.352 90.306 99 | 100 | Changing LR from 0.23520000000000002 to 0.23464000000000002 101 | Changing LR from 0.18536000000000002 to 0.18480000000000005 102 | ~~25 0.24650 71.172 90.228 103 | 104 | Changing LR from 0.18480000000000005 to 0.18424000000000004 105 | Changing LR from 0.13496000000000005 to 0.13440000000000005 106 | ~~26 0.25760 72.086 90.818 107 | 108 | Changing LR from 0.13440000000000005 to 0.13384000000000004 109 | Changing LR from 0.08456000000000002 
to 0.08400000000000002 110 | ~~27 0.26867 72.834 91.222 111 | 112 | Changing LR from 0.08400000000000002 to 0.08344000000000007 113 | Changing LR from 0.034160000000000024 to 0.033600000000000074 114 | ~~28 0.27976 73.198 91.338 115 | 116 | Dataset changed. 117 | Image size: 288 118 | Batch size: 128 119 | Train Directory: /home/ubuntu/data/imagenet/train 120 | Validation Directory: /home/ubuntu/data/imagenet/validation 121 | Changing LR from 0.033600000000000074 to 0.019186242038216558 122 | Changing LR from 0.017053757961783437 to 0.01704 123 | ~~29 0.31472 75.812 93.012 124 | 125 | Changing LR from 0.01704 to 0.01702624203821656 126 | Changing LR from 0.014893757961783438 to 0.014879999999999999 127 | ~~30 0.33507 75.850 93.116 128 | 129 | Changing LR from 0.014879999999999999 to 0.01486624203821656 130 | Changing LR from 0.01273375796178344 to 0.012719999999999999 131 | ~~31 0.35516 75.926 93.154 132 | 133 | Changing LR from 0.012719999999999999 to 0.01270624203821656 134 | Changing LR from 0.010573757961783439 to 0.01056 135 | ~~32 0.37524 75.916 93.134 136 | 137 | Changing LR from 0.01056 to 0.010546242038216561 138 | Changing LR from 0.00841375796178344 to 0.0084 139 | ~~33 0.39400 75.942 93.122 140 | 141 | Changing LR from 0.0084 to 0.00838624203821656 142 | Changing LR from 0.00625375796178344 to 0.006240000000000001 143 | ~~34 0.41283 75.928 93.168 144 | 145 | Changing LR from 0.006240000000000001 to 0.00622624203821656 146 | Changing LR from 0.004093757961783439 to 0.00408 147 | ~~35 0.43295 75.944 93.186 148 | 149 | Changing LR from 0.00408 to 0.004066242038216561 150 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import pickle 4 | import random 5 | import re 6 | import string 7 | import subprocess 8 | import threading 9 | from typing import Tuple 10 | 11 | 12 | def is_set(name: str) -> bool: 
13 | """Helper method to check if given property is set, anything except missing, 0 and false means set """ 14 | 15 | val = os.environ.get(name, '0').lower() 16 | return not (val == '0' or val == 'false') 17 | 18 | 19 | def extract_ec2_metadata(): 20 | """Returns dictionary of common ec2 metadata""" 21 | from ec2_metadata import ec2_metadata 22 | try: 23 | return { 24 | 'region': ec2_metadata.region, 25 | 'account_id': ec2_metadata.account_id, 26 | 'ami_id': ec2_metadata.ami_id, 27 | 'availability_zone': ec2_metadata.availability_zone, 28 | 'instance_type': ec2_metadata.instance_type, 29 | 'public_ipv4': ec2_metadata.public_ipv4, 30 | 'private_ipv4': ec2_metadata.private_ipv4 31 | } 32 | except: # may crash with requests.exceptions.ConnectTimeout when not on AWS 33 | return {} 34 | 35 | 36 | def random_id(k=3): 37 | """Random id to use for AWS identifiers.""" 38 | # https://stackoverflow.com/questions/2257441/random-string-generation-with-upper-case-letters-and-digits-in-python 39 | return ''.join(random.choices(string.ascii_lowercase + string.digits, k=k)) 40 | 41 | 42 | def log_environment(): 43 | """Logs AWS local machine environment to wandb config.""" 44 | import os 45 | import wandb 46 | import torch 47 | 48 | if not (hasattr(wandb, 'config') and wandb.config is not None): 49 | return 50 | 51 | for key in os.environ: 52 | if re.match(r"^NCCL|CUDA|PATH|^LD|USER|PWD|^OMP", key): 53 | wandb.config['env_'+key] = os.getenv(key) 54 | 55 | wandb.config['pytorch_version'] = torch.__version__ 56 | wandb.config.update(extract_ec2_metadata()) 57 | 58 | 59 | def ossystem(cmd, shell=True): 60 | """Like os.system, but returns output of command as string.""" 61 | p = subprocess.Popen(cmd, shell=shell, stdout=subprocess.PIPE, 62 | stderr=subprocess.STDOUT) 63 | (stdout, stderr) = p.communicate() 64 | return stdout.decode('ascii') 65 | 66 | 67 | def text_pickle(obj) -> str: 68 | """Pickles object into character string""" 69 | pickle_string = pickle.dumps(obj) 70 | 
pickle_string_encoded: bytes = base64.b64encode(pickle_string) 71 | s = pickle_string_encoded.decode('ascii') 72 | return s 73 | 74 | 75 | def text_unpickle(pickle_string_encoded: str): 76 | """Unpickles character string""" 77 | if not pickle_string_encoded: 78 | return None 79 | obj = pickle.loads(base64.b64decode(pickle_string_encoded)) 80 | return obj 81 | 82 | 83 | def format_env(**d): 84 | """Converts env var values into variable string, ie 85 | 'var1="val1" var2="val2" '""" 86 | args_ = [f'{key}="{d[key]}" ' for key in d] 87 | return ''.join(args_) 88 | 89 | 90 | def format_env_export(**d): 91 | """Converts env var values into variable string, ie 92 | 'export var1="val1" && export var2="val2" '""" 93 | args_ = [f'export {key}="{d[key]}" ' for key in d] 94 | return ' && '.join(args_) 95 | 96 | 97 | def format_env_x(**d): 98 | """Converts env var values into format suitable for mpirun, ie 99 | '-x var1="val1" -x var2="val2" '""" 100 | args_ = [f'-x {key}="{d[key]}" ' for key in sorted(d)] 101 | return ''.join(args_) 102 | 103 | 104 | def setup_mpi(job, skip_ssh_setup=False) -> Tuple[str, str]: 105 | """Sets up passwordless SSH between all tasks in the job.""" 106 | public_keys = {} 107 | if not skip_ssh_setup: 108 | for task in job.tasks: 109 | key_fn = '~/.ssh/id_rsa' # this fn is special, used by default by ssh 110 | task.run(f"yes | ssh-keygen -t rsa -f {key_fn} -N ''") 111 | 112 | public_keys[task] = task.read(key_fn + '.pub') 113 | 114 | keys = {} 115 | for i, task1 in enumerate(job.tasks): 116 | task1.run('echo "StrictHostKeyChecking no" >> /etc/ssh/ssh_config', 117 | sudo=True, non_blocking=True) 118 | for j, task2_ in enumerate(job.tasks): 119 | # task1 ->ssh-> task2 120 | # task2.run(f'echo "{public_keys[task1]}" >> ~/.ssh/authorized_keys', 121 | # non_blocking=True) 122 | keys.setdefault(j, []).append(public_keys[task1]) 123 | 124 | def setup_task_mpi(j2): 125 | task2 = job.tasks[j2] 126 | key_str = '\n'.join(keys[j2]) 127 | fn = f'task-{j2}' 128 | 
with open(fn, 'w') as f: 129 | f.write(key_str) 130 | task2.upload(fn) 131 | task2.run(f"""echo `cat {fn}` >> ~/.ssh/authorized_keys""", 132 | non_blocking=True) 133 | 134 | run_parallel(setup_task_mpi, range(len(job.tasks))) 135 | # for j, task2_ in enumerate(job.tasks): 136 | # setup_task_mpi(j) 137 | 138 | task0 = job.tasks[0] 139 | hosts = [task.ip for task in job.tasks] 140 | hosts_str = ','.join(hosts) 141 | hosts_file_lines = [f'{host} slots={task0.num_gpus} max-slots={task0.num_gpus}' for host in hosts] 142 | hosts_file_str = '\n'.join(hosts_file_lines) 143 | return hosts_str, hosts_file_str 144 | 145 | 146 | def run_parallel(f, args_): 147 | threads = [threading.Thread(name=f'run_parallel_{i}', target=f, args=[t]) for i, t in enumerate(args_)] 148 | for thread in threads: 149 | thread.start() 150 | for thread in threads: 151 | thread.join() 152 | 153 | -------------------------------------------------------------------------------- /training/fp16util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 5 | 6 | class tofp16(nn.Module): 7 | """ 8 | Model wrapper that implements:: 9 | 10 | def forward(self, input): 11 | return input.half() 12 | """ 13 | 14 | def __init__(self): 15 | super(tofp16, self).__init__() 16 | 17 | def forward(self, input): 18 | return input.half() 19 | 20 | 21 | def BN_convert_float(module): 22 | ''' 23 | Designed to work with network_to_half. 24 | BatchNorm layers need parameters in single precision. 25 | Find all layers and convert them back to float. This can't 26 | be done with built in .apply as that function will apply 27 | fn to all modules, parameters, and buffers. Thus we wouldn't 28 | be able to guard the float conversion based on the module type. 
29 | ''' 30 | if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): 31 | module.float() 32 | for child in module.children(): 33 | BN_convert_float(child) 34 | return module 35 | 36 | 37 | def network_to_half(network): 38 | """ 39 | Convert model to half precision in a batchnorm-safe way. 40 | """ 41 | # (AS) This is better as it does not change model structure 42 | return BN_convert_float(network.half()) 43 | # return nn.Sequential(tofp16(), BN_convert_float(network.half())) 44 | 45 | 46 | def backwards_debug_hook(grad): 47 | raise RuntimeError("master_params recieved a gradient in the backward pass!") 48 | 49 | def prep_param_lists(model, flat_master=False): 50 | """ 51 | Creates a list of FP32 master parameters for a given model, as in 52 | `Training Neural Networks with Mixed Precision: Real Examples`_. 53 | 54 | Args: 55 | model (torch.nn.Module): Existing Pytorch model 56 | flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. 57 | Returns: 58 | A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. 59 | 60 | Example:: 61 | 62 | model_params, master_params = prep_param_lists(model) 63 | 64 | .. warning:: 65 | Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. 66 | 67 | .. 
_`Training Neural Networks with Mixed Precision: Real Examples`: 68 | http://on-demand.gputechconf.com/gtc/2018/video/S81012/ 69 | """ 70 | model_params = [param for param in model.parameters() if param.requires_grad] 71 | 72 | if flat_master: 73 | # Give the user some more useful error messages 74 | try: 75 | # flatten_dense_tensors returns a contiguous flat array. 76 | # http://pytorch.org/docs/master/_modules/torch/_utils.html 77 | master_params = _flatten_dense_tensors([param.data for param in model_params]).float() 78 | except TypeError as instance: 79 | # This is brittle, and depends on how cat chooses to word its error message. 80 | if "cat received an invalid combination of arguments" not in instance.args[0]: 81 | raise 82 | else: 83 | # If you append a message to the exception instance, via 84 | # instance.args = instance.args + ("Error...",) 85 | # this messes up the terminal-formatted printing of the instance's original message. 86 | # Basic solution for now: 87 | print("Error in prep_param_lists: model likely contains a mixture of parameters " 88 | "of different types. Use flat_master=False, or use F16_Optimizer.") 89 | raise 90 | master_params = torch.nn.Parameter(master_params) 91 | master_params.requires_grad = True 92 | # master_params.register_hook(backwards_debug_hook) 93 | if master_params.grad is None: 94 | master_params.grad = master_params.new(*master_params.size()) 95 | return model_params, [master_params] 96 | else: 97 | master_params = [param.clone().float().detach() for param in model_params] 98 | for param in master_params: 99 | param.requires_grad = True 100 | return model_params, master_params 101 | 102 | 103 | def model_grads_to_master_grads(model_params, master_params, flat_master=False): 104 | """ 105 | Copy model gradients to master gradients. 106 | 107 | Args: 108 | model_params: List of model parameters created by :func:`prep_param_lists`. 109 | master_params: List of FP32 master parameters created by :func:`prep_param_lists`. 
If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. 110 | """ 111 | if flat_master: 112 | # The flattening may incur one more deep copy than is necessary. 113 | master_params[0].grad.data.copy_( 114 | _flatten_dense_tensors([p.grad.data for p in model_params])) 115 | else: 116 | for model, master in zip(model_params, master_params): 117 | if model.grad is not None: 118 | if master.grad is None: 119 | master.grad = Variable(master.data.new(*master.data.size())) 120 | master.grad.data.copy_(model.grad.data) 121 | else: 122 | master.grad = None 123 | 124 | 125 | def master_params_to_model_params(model_params, master_params, flat_master=False): 126 | """ 127 | Copy master parameters to model parameters. 128 | 129 | Args: 130 | model_params: List of model parameters created by :func:`prep_param_lists`. 131 | master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. 
132 | """ 133 | if flat_master: 134 | for model, master in zip(model_params, 135 | _unflatten_dense_tensors(master_params[0].data, model_params)): 136 | model.data.copy_(master) 137 | else: 138 | for model, master in zip(model_params, master_params): 139 | model.data.copy_(master.data) -------------------------------------------------------------------------------- /training/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 
49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 64 | padding=1, bias=False) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, num_classes=1000): 98 | self.inplanes = 64 99 | super(ResNet, self).__init__() 100 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 101 | bias=False) 102 | self.bn1 = nn.BatchNorm2d(64) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 105 | self.layer1 = self._make_layer(block, 64, layers[0]) 106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 109 | self.avgpool = nn.AdaptiveAvgPool2d(1) 110 | self.fc = nn.Linear(512 * block.expansion, num_classes) 111 | 112 | for m in self.modules(): 113 | if 
isinstance(m, nn.Conv2d): 114 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 115 | elif isinstance(m, nn.BatchNorm2d): 116 | nn.init.constant_(m.weight, 1) 117 | nn.init.constant_(m.bias, 0) 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1): 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | x = self.avgpool(x) 148 | x = x.view(x.size(0), -1) 149 | x = self.fc(x) 150 | 151 | return x 152 | 153 | 154 | def init_dist_weights(model): 155 | # https://arxiv.org/pdf/1706.02677.pdf 156 | # https://github.com/pytorch/examples/pull/262 157 | for m in model.modules(): 158 | if isinstance(m, BasicBlock): m.bn2.weight = nn.Parameter(torch.zeros_like(m.bn2.weight)) 159 | if isinstance(m, Bottleneck): m.bn3.weight = nn.Parameter(torch.zeros_like(m.bn3.weight)) 160 | if isinstance(m, nn.Linear): m.weight.data.normal_(0, 0.01) 161 | 162 | 163 | def resnet18(pretrained=False, **kwargs): 164 | """Constructs a ResNet-18 model. 
165 | 166 | Args: 167 | pretrained (bool): If True, returns a model pre-trained on ImageNet 168 | """ 169 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 170 | if pretrained: 171 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 172 | return model 173 | 174 | 175 | def resnet34(pretrained=False, **kwargs): 176 | """Constructs a ResNet-34 model. 177 | 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | """ 181 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 182 | if pretrained: 183 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 184 | return model 185 | 186 | 187 | def resnet50(pretrained=False, bn0=False, **kwargs): 188 | """Constructs a ResNet-50 model. 189 | 190 | Args: 191 | pretrained (bool): If True, returns a model pre-trained on ImageNet 192 | """ 193 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 194 | if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 195 | if bn0: init_dist_weights(model) 196 | return model 197 | 198 | 199 | def resnet101(pretrained=False, **kwargs): 200 | """Constructs a ResNet-101 model. 201 | 202 | Args: 203 | pretrained (bool): If True, returns a model pre-trained on ImageNet 204 | """ 205 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 206 | if pretrained: 207 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 208 | return model 209 | 210 | 211 | def resnet152(pretrained=False, **kwargs): 212 | """Constructs a ResNet-152 model. 
213 | 214 | Args: 215 | pretrained (bool): If True, returns a model pre-trained on ImageNet 216 | """ 217 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 218 | if pretrained: 219 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 220 | return model -------------------------------------------------------------------------------- /training/dataloader.py: -------------------------------------------------------------------------------- 1 | import argparse, os, shutil, time, warnings 2 | from pathlib import Path 3 | import numpy as np 4 | import sys 5 | import math 6 | 7 | import torch 8 | import torch.utils.data 9 | from torch.utils.data.distributed import DistributedSampler 10 | import torchvision.transforms as transforms 11 | import torchvision.datasets as datasets 12 | 13 | from torch.utils.data.sampler import Sampler 14 | import torchvision 15 | import pickle 16 | from tqdm import tqdm 17 | from dist_utils import env_world_size, env_rank 18 | 19 | 20 | class SyntheticBatchSampler(object): 21 | def __init__(self, batch_size): 22 | self.batch_size = batch_size 23 | 24 | 25 | class SyntheticDataLoader(object): 26 | def __init__(self, batch_size, input_shape): 27 | # Create two random tensors - one for data and one for label 28 | self.input_shape = input_shape 29 | data_shape = (batch_size,) + input_shape 30 | self.data = torch.randn(data_shape) 31 | self.labels = torch.from_numpy( 32 | np.random.randint(0, 1000, batch_size).astype(np.long) 33 | ) 34 | self.prefetchable = False 35 | self.data = self.data.cuda() 36 | self.labels = self.labels.cuda() 37 | self.finish = 0 38 | self.batch_size = batch_size 39 | self.batch_sampler = SyntheticBatchSampler(batch_size) 40 | self.batch_num = 1281167 // (env_world_size() * self.batch_sampler.batch_size) + 1 41 | 42 | def next(self): 43 | return (self.data, self.labels) 44 | 45 | def __iter__(self): 46 | return self 47 | 48 | def __len__(self): 49 | return 1281167 // (env_world_size() * 
self.batch_sampler.batch_size) + 1 50 | 51 | def __next__(self): 52 | # Support BatchTransformDataLoader.update_batch_size() 53 | if self.batch_size != self.batch_sampler.batch_size: 54 | self.batch_size = self.batch_sampler.batch_size 55 | data_shape = (self.batch_size,) + self.input_shape 56 | self.data = torch.randn(data_shape) 57 | self.labels = torch.from_numpy( 58 | np.random.randint(0, 1000, self.batch_size).astype(np.long) 59 | ) 60 | self.data = self.data.cuda() 61 | self.labels = self.labels.cuda() 62 | self.batch_num = 1281167 // (env_world_size() * self.batch_sampler.batch_size) + 1 63 | if self.finish >= self.batch_num: 64 | self.finish = 0 65 | raise StopIteration 66 | self.finish += 1 67 | return (self.data, self.labels) 68 | 69 | 70 | # util is one level up, so import that 71 | module_path = os.path.dirname(os.path.abspath(__file__)) 72 | sys.path.insert(0, os.path.abspath(f'{module_path}/..')) 73 | 74 | import util 75 | 76 | def get_loaders(traindir, valdir, sz, bs, fp16=True, val_bs=None, workers=8, rect_val=False, min_scale=0.08, distributed=False, synthetic=False): 77 | val_bs = val_bs or bs 78 | train_tfms = [ 79 | transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)), 80 | transforms.RandomHorizontalFlip() 81 | ] 82 | train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_tfms)) 83 | train_sampler = (DistributedSampler(train_dataset, num_replicas=env_world_size(), rank=env_rank()) if distributed else None) 84 | 85 | if synthetic: 86 | print("Using synthetic dataloader") 87 | train_loader = SyntheticDataLoader(bs, (3, sz, sz)) 88 | elif util.is_set('PYTORCH_USE_SPAWN'): 89 | print("Using SPAWN method for dataloader") 90 | train_loader = torch.utils.data.DataLoader( 91 | train_dataset, batch_size=bs, shuffle=(train_sampler is None), 92 | num_workers=workers, pin_memory=True, collate_fn=fast_collate, 93 | sampler=train_sampler, 94 | multiprocessing_context='spawn') 95 | else: 96 | train_loader = torch.utils.data.DataLoader( 97 
| train_dataset, batch_size=bs, shuffle=(train_sampler is None), 98 | num_workers=workers, pin_memory=True, collate_fn=fast_collate, 99 | sampler=train_sampler) 100 | 101 | val_dataset, val_sampler = create_validation_set(valdir, val_bs, sz, rect_val=rect_val, distributed=distributed) 102 | val_loader = torch.utils.data.DataLoader( 103 | val_dataset, 104 | num_workers=workers, pin_memory=True, collate_fn=fast_collate, 105 | batch_sampler=val_sampler) 106 | 107 | train_loader = BatchTransformDataLoader(train_loader, fp16=fp16) 108 | val_loader = BatchTransformDataLoader(val_loader, fp16=fp16) 109 | 110 | return train_loader, val_loader, train_sampler, val_sampler 111 | 112 | 113 | def create_validation_set(valdir, batch_size, target_size, rect_val, distributed): 114 | if rect_val: 115 | idx_ar_sorted = sort_ar(valdir) 116 | idx_sorted, _ = zip(*idx_ar_sorted) 117 | idx2ar = map_idx2ar(idx_ar_sorted, batch_size) 118 | 119 | ar_tfms = [transforms.Resize(int(target_size*1.14)), CropArTfm(idx2ar, target_size)] 120 | val_dataset = ValDataset(valdir, transform=ar_tfms) 121 | val_sampler = DistValSampler(idx_sorted, batch_size=batch_size, distributed=distributed) 122 | return val_dataset, val_sampler 123 | 124 | val_tfms = [transforms.Resize(int(target_size*1.14)), transforms.CenterCrop(target_size)] 125 | val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_tfms)) 126 | val_sampler = DistValSampler(list(range(len(val_dataset))), batch_size=batch_size, distributed=distributed) 127 | return val_dataset, val_sampler 128 | 129 | class BatchTransformDataLoader(): 130 | # Mean normalization on batch level instead of individual 131 | # https://github.com/NVIDIA/apex/blob/59bf7d139e20fb4fa54b09c6592a2ff862f3ac7f/examples/imagenet/main.py#L222 132 | def __init__(self, loader, fp16=True): 133 | self.loader = loader 134 | self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1) 135 | self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 
* 255]).cuda().view(1,3,1,1) 136 | self.fp16 = fp16 137 | if self.fp16: self.mean, self.std = self.mean.half(), self.std.half() 138 | 139 | def __len__(self): return len(self.loader) 140 | 141 | def process_tensors(self, input, target, non_blocking=True): 142 | input = input.cuda(non_blocking=non_blocking) 143 | if self.fp16: input = input.half() 144 | else: input = input.float() 145 | if len(input.shape) < 3: return input, target.cuda(non_blocking=non_blocking) 146 | return input.sub_(self.mean).div_(self.std), target.cuda(non_blocking=non_blocking) 147 | 148 | def update_batch_size(self, bs): 149 | self.loader.batch_sampler.batch_size = bs 150 | 151 | def __iter__(self): 152 | return (self.process_tensors(input, target, non_blocking=True) for input,target in self.loader) 153 | 154 | def fast_collate(batch): 155 | if not batch: return torch.tensor([]), torch.tensor([]) 156 | imgs = [img[0] for img in batch] 157 | targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) 158 | w = imgs[0].size[0] 159 | h = imgs[0].size[1] 160 | tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 ) 161 | for i, img in enumerate(imgs): 162 | nump_array = np.asarray(img, dtype=np.uint8) 163 | tens = torch.from_numpy(nump_array) 164 | if(nump_array.ndim < 3): 165 | nump_array = np.expand_dims(nump_array, axis=-1) 166 | nump_array = np.rollaxis(nump_array, 2) 167 | tensor[i] += torch.from_numpy(nump_array) 168 | return tensor, targets 169 | 170 | class ValDataset(datasets.ImageFolder): 171 | def __init__(self, root, transform=None, target_transform=None): 172 | super().__init__(root, transform, target_transform) 173 | def __getitem__(self, index): 174 | path, target = self.imgs[index] 175 | sample = self.loader(path) 176 | if self.transform is not None: 177 | for tfm in self.transform: 178 | if isinstance(tfm, CropArTfm): sample = tfm(sample, index) 179 | else: sample = tfm(sample) 180 | if self.target_transform is not None: 181 | target = 
self.target_transform(target) 182 | 183 | return sample, target 184 | 185 | 186 | class DistValSampler(Sampler): 187 | # DistValSampler distrbutes batches equally (based on batch size) to every gpu (even if there aren't enough images) 188 | # WARNING: Some baches will contain an empty array to signify there aren't enough images 189 | # Distributed=False - same validation happens on every single gpu 190 | def __init__(self, indices, batch_size, distributed=True): 191 | self.indices = indices 192 | self.batch_size = batch_size 193 | if distributed: 194 | self.world_size = env_world_size() 195 | self.global_rank = env_rank() 196 | else: 197 | self.global_rank = 0 198 | self.world_size = 1 199 | 200 | # expected number of batches per sample. Need this so each distributed gpu validates on same number of batches. 201 | # even if there isn't enough data to go around 202 | self.expected_num_batches = math.ceil(len(self.indices) / self.world_size / self.batch_size) 203 | 204 | # num_samples = total images / world_size. 
This is what we distribute to each gpu 205 | self.num_samples = self.expected_num_batches * self.batch_size 206 | 207 | def __iter__(self): 208 | offset = self.num_samples * self.global_rank 209 | sampled_indices = self.indices[offset:offset+self.num_samples] 210 | for i in range(self.expected_num_batches): 211 | offset = i*self.batch_size 212 | yield sampled_indices[offset:offset+self.batch_size] 213 | def __len__(self): return self.expected_num_batches 214 | def set_epoch(self, epoch): return 215 | 216 | 217 | class CropArTfm(object): 218 | def __init__(self, idx2ar, target_size): 219 | self.idx2ar, self.target_size = idx2ar, target_size 220 | def __call__(self, img, idx): 221 | target_ar = self.idx2ar[idx] 222 | if target_ar < 1: 223 | w = int(self.target_size/target_ar) 224 | size = (w//8*8, self.target_size) 225 | else: 226 | h = int(self.target_size*target_ar) 227 | size = (self.target_size, h//8*8) 228 | return torchvision.transforms.functional.center_crop(img, size) 229 | 230 | import os.path 231 | def sort_ar(valdir): 232 | idx2ar_file = valdir+'/../sorted_idxar.p' 233 | if os.path.isfile(idx2ar_file): return pickle.load(open(idx2ar_file, 'rb')) 234 | print('Creating AR indexes. 
Please be patient this may take a couple minutes...') 235 | val_dataset = datasets.ImageFolder(valdir) # AS: TODO: use Image.open instead of looping through dataset 236 | sizes = [img[0].size for img in tqdm(val_dataset, total=len(val_dataset))] 237 | idx_ar = [(i, round(s[0]/s[1], 5)) for i,s in enumerate(sizes)] 238 | sorted_idxar = sorted(idx_ar, key=lambda x: x[1]) 239 | pickle.dump(sorted_idxar, open(idx2ar_file, 'wb')) 240 | print('Done') 241 | return sorted_idxar 242 | 243 | def chunks(l, n): 244 | n = max(1, n) 245 | return (l[i:i+n] for i in range(0, len(l), n)) 246 | 247 | def map_idx2ar(idx_ar_sorted, batch_size): 248 | ar_chunks = list(chunks(idx_ar_sorted, batch_size)) 249 | idx2ar = {} 250 | for chunk in ar_chunks: 251 | idxs, ars = list(zip(*chunk)) 252 | mean = round(np.mean(ars), 5) 253 | for idx in idxs: idx2ar[idx] = mean 254 | return idx2ar 255 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import os 5 | import re 6 | import sys 7 | import time 8 | 9 | import ncluster 10 | from ncluster import aws_util as u 11 | 12 | # todo(y): change to AMI owned by me ie, pytorch.imagenet.source.v7-copy 13 | import util 14 | 15 | # IMAGE_NAME = 'pytorch.imagenet.source.v7' 16 | HOSTS_SLOTS_FN = 'hosts.slots' 17 | 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--name', type=str, default='imagenet', 20 | help="name of the current run, used for machine naming") 21 | parser.add_argument('--run_name', type=str, default='', 22 | help="name of run for loggin") 23 | parser.add_argument('--machines', type=int, default=1, 24 | help="how many machines to use") 25 | parser.add_argument('--num_tasks', type=int, default=1, 26 | help="same as machines for compatibility, don't use") 27 | parser.add_argument('--mount_imagenet', type=int, default=1, 28 | help="if set, mount 
imagenet disk rather than taking data from local image") 29 | parser.add_argument('--offset', type=int, default=0, 30 | help='offset for imagenet ebs numbering') 31 | parser.add_argument('--vmtouch', type=int, default=0, 32 | help="lock all examples into physical memory") 33 | parser.add_argument('--internal_config_fn', type=str, default='ncluster_config_dict', 34 | help='location of filename with extra info to log') 35 | parser.add_argument('--nproc_per_node', type=int, default=8, help="Processes per machine, must not exceed number of GPUS") 36 | parser.add_argument('--image_name', type=str, default='pytorch-efa01', 37 | help="Image to use for this run") 38 | parser.add_argument('--instance_type', type=str, default='p3.16xlarge', help="Image to use for this run") 39 | parser.add_argument('--conda_env', type=str, default='pytorch_p36', help="name of conda env") 40 | parser.add_argument('--efa', type=int, default=0, help="use AWS EFA network") 41 | parser.add_argument('--pseudo_efa', type=int, default=0, help="use sockets interface when launching under EFA") 42 | parser.add_argument('--no_op', type=int, default=0, help='just print environment/debug info and skip rest') 43 | parser.add_argument('--log_all_workers', type=int, default=0, help='log from each worker instead of just chief') 44 | parser.add_argument('--spot', action='store_true', help='use spot instead of regular instances') 45 | parser.add_argument('--cuda_debug', action='store_true', help='debug cuda errors') 46 | parser.add_argument('--pytorch_nightly', action='store_true', help='install nightly PyTorch') 47 | parser.add_argument('--pytorch_use_spawn', action='store_true', help='use spawn method in dataloaders') 48 | parser.add_argument('--simple_ring_setup', action='store_true', help='set 16 rings instead of manual ring order') 49 | parser.add_argument('--skip_setup', action='store_true', help='speed up relaunch by skipping some steps') 50 | args = parser.parse_args() 51 | args.num_tasks = 
args.machines 52 | if not args.run_name: 53 | args.run_name = args.name 54 | 55 | # 109:12 to 93.00 56 | # https://app.wandb.ai/yaroslavvb/imagenet18/runs/gxsdo6i0 57 | lr = 1.0 58 | scale_224 = 224 / 512 59 | scale_288 = 128 / 512 60 | one_machine = [ 61 | {'ep': 0, 'sz': 128, 'bs': 512, 'trndir': '-sz/160'}, 62 | {'ep': (0, 5), 'lr': (lr, lr * 2)}, # lr warmup is better with --init-bn0 63 | {'ep': 5, 'lr': lr}, 64 | {'ep': 14, 'sz': 224, 'bs': 224, 65 | 'lr': lr * scale_224}, 66 | {'ep': 16, 'lr': lr / 10 * scale_224}, 67 | {'ep': 27, 'lr': lr / 100 * scale_224}, 68 | {'ep': 32, 'sz': 288, 'bs': 128, 'min_scale': 0.5, 'rect_val': True, 69 | 'lr': lr / 100 * scale_288}, 70 | {'ep': (33, 35), 'lr': lr / 1000 * scale_288} 71 | ] 72 | 73 | # 54 minutes to 93.364 74 | # https://app.wandb.ai/yaroslavvb/imagenet18/runs/lhx5a053 75 | lr = 0.75 * 2 76 | bs = [256, 224, 128] # largest batch size that fits in memory for each image size 77 | bs_scale = [x / bs[0] for x in bs] # scale learning rate to batch size 78 | two_machines = [ 79 | {'ep': 0, 'sz': 128, 'bs': bs[0], 'trndir': '-sz/160'}, 80 | # bs = 256 * 4 * 8 = 8192 81 | {'ep': (0, 6), 'lr': (lr, lr * 2)}, 82 | {'ep': 6, 'sz': 128, 'bs': bs[0] * 2, 'keep_dl': True}, 83 | {'ep': 6, 'lr': lr * 2}, 84 | {'ep': (11, 13), 'lr': (lr * 2, lr)}, # trying one cycle 85 | {'ep': 13, 'sz': 224, 'bs': bs[1], 'trndir': '-sz/352', 'min_scale': 0.087}, 86 | {'ep': 13, 'lr': lr * bs_scale[1]}, 87 | {'ep': (16, 23), 'lr': (lr * bs_scale[1], lr / 10 * bs_scale[1])}, 88 | {'ep': (23, 28), 'lr': (lr / 10 * bs_scale[1], lr / 100 * bs_scale[1])}, 89 | {'ep': 28, 'sz': 288, 'bs': bs[2], 'min_scale': 0.5, 'rect_val': True}, 90 | {'ep': (28, 30), 'lr': (lr / 100 * bs_scale[2], lr / 1000 * bs_scale[2])} 91 | ] 92 | 93 | # 29:44 to 93.05 94 | # events: https://s3.amazonaws.com/yaroslavvb/logs/imagenet-4 95 | # p3dn: https://app.wandb.ai/yaroslavvb/imagenet18/runs/pp0g9k5c 96 | lr = 0.50 * 4 # 4 = num tasks 97 | bs = [256, 224, 98 | 128] # 
largest batch size that fits in memory for each image size 99 | bs_scale = [x / bs[0] for x in bs] # scale learning rate to batch size 100 | four_machines = [ 101 | {'ep': 0, 'sz': 128, 'bs': bs[0], 'trndir': '-sz/160'}, 102 | # bs = 256 * 4 * 8 = 8192 103 | {'ep': (0, 6), 'lr': (lr, lr * 2)}, 104 | {'ep': 6, 'sz': 128, 'bs': bs[0] * 2, 'keep_dl': True}, 105 | {'ep': 6, 'lr': lr * 2}, 106 | {'ep': (11, 13), 'lr': (lr * 2, lr)}, # trying one cycle 107 | {'ep': 13, 'sz': 224, 'bs': bs[1], 'trndir': '-sz/352', 'min_scale': 0.087}, 108 | {'ep': 13, 'lr': lr * bs_scale[1]}, 109 | {'ep': (16, 23), 'lr': (lr * bs_scale[1], lr / 10 * bs_scale[1])}, 110 | {'ep': (23, 28), 'lr': (lr / 10 * bs_scale[1], lr / 100 * bs_scale[1])}, 111 | {'ep': 28, 'sz': 288, 'bs': bs[2], 'min_scale': 0.5, 'rect_val': True}, 112 | {'ep': (28, 30), 'lr': (lr / 100 * bs_scale[2], lr / 1000 * bs_scale[2])} 113 | ] 114 | 115 | # 19:04 to 93.0 116 | lr = 0.235 * 8 117 | scale_224 = 224 / 128 118 | eight_machines = [ 119 | {'ep': 0, 'sz': 128, 'bs': 128, 'trndir': '-sz/160'}, 120 | {'ep': (0, 6), 'lr': (lr, lr * 2)}, 121 | {'ep': 6, 'bs': 256, 'keep_dl': True, 122 | 'lr': lr * 2}, 123 | {'ep': (11, 14), 'lr': (lr * 2, lr)}, # trying one cycle 124 | {'ep': 14, 'sz': 224, 'bs': 128, 'trndir': '-sz/352', 'min_scale': 0.087, 125 | 'lr': lr}, 126 | {'ep': 17, 'bs': 224, 'keep_dl': True}, 127 | {'ep': (17, 23), 'lr': (lr, lr / 10 * scale_224)}, 128 | {'ep': (23, 29), 'lr': (lr / 10 * scale_224, lr / 100 * scale_224)}, 129 | {'ep': 29, 'sz': 288, 'bs': 128, 'min_scale': 0.5, 'rect_val': True}, 130 | {'ep': (29, 35), 'lr': (lr / 100, lr / 1000)} 131 | ] 132 | 133 | # 16:08 to 93.04 (after prewarming) 134 | lr = 0.235 * 8 # 135 | bs = 64 136 | sixteen_machines = [ 137 | {'ep': 0, 'sz': 128, 'bs': 64, 'trndir': '-sz/160'}, 138 | {'ep': (0, 6), 'lr': (lr, lr * 2)}, 139 | {'ep': 6, 'bs': 128, 'keep_dl': True}, 140 | {'ep': 6, 'lr': lr * 2}, 141 | {'ep': 16, 'sz': 224, 'bs': 64}, # todo: increase this bs 142 | 
{'ep': 16, 'lr': lr}, 143 | {'ep': 19, 'bs': 192, 'keep_dl': True}, 144 | {'ep': 19, 'lr': 2 * lr / (10 / 1.5)}, 145 | {'ep': 31, 'lr': 2 * lr / (100 / 1.5)}, 146 | {'ep': 37, 'sz': 288, 'bs': 128, 'min_scale': 0.5, 'rect_val': True}, 147 | {'ep': 37, 'lr': 2 * lr / 100}, 148 | {'ep': (38, 50), 'lr': 2 * lr / 1000} 149 | ] 150 | 151 | schedules = {1: one_machine, 152 | 2: two_machines, 153 | 4: four_machines, 154 | 8: eight_machines, 155 | 16: sixteen_machines} 156 | 157 | 158 | # routines to build NCCL ring orders 159 | def get_nccl_params(num_tasks, nproc_per_node): 160 | if num_tasks <= 1: 161 | return 'NCCL_DEBUG=VERSION' 162 | nccl_rings = get_nccl_rings(num_tasks, nproc_per_node) 163 | env = f'NCCL_RINGS="{nccl_rings}" NCCL_SINGLE_RING_THRESHOLD=10 ' 164 | if args.simple_ring_setup: 165 | env = f'NCCL_MIN_NRINGS=16 NCCL_MAX_NRINGS=16 ' 166 | 167 | return env 168 | # return 'NCCL_MIN_NRINGS=2 NCCL_SINGLE_RING_THRESHOLD=10 NCCL_DEBUG=VERSION' 169 | 170 | 171 | def get_nccl_rings(num_tasks, num_gpus): 172 | ring = build_ring_order(range(num_tasks), range(num_gpus)) 173 | ring_rev = build_ring_order(reversed(range(num_tasks)), 174 | reversed(range(num_gpus))) 175 | rotated_gpu_order = [3, 2, 1, 0, 7, 6, 5, 4] 176 | skip_gpu_order = get_skip_order(num_gpus) 177 | if (num_tasks >= 4) and (num_gpus == 8): 178 | assert ((num_tasks % 4) == 0) 179 | skip_machine_order = get_skip_order(num_tasks) 180 | ring_skip = build_ring_order(skip_machine_order, rotated_gpu_order) 181 | ring_skip_rev = build_ring_order(reversed(skip_machine_order), 182 | skip_gpu_order) 183 | rings_arr = [ring, ring_rev, ring_skip, ring_skip_rev] 184 | # rings_arr = [ring, ring_rev, ring_skip] 185 | else: 186 | rings_arr = [ring, ring_rev] 187 | return ' | '.join(rings_arr) 188 | 189 | 190 | def build_ring_order(machine_order, gpu_order): 191 | gpu_order = list(gpu_order) 192 | machine_order = list(machine_order) 193 | ngpus = len(gpu_order) 194 | r_order = [(x * ngpus) + y for x in machine_order 
for y in gpu_order] 195 | return ' '.join(map(str, r_order)) 196 | 197 | 198 | def get_skip_order(size): 199 | if size == 4: 200 | return [0, 2, 1, 3] 201 | skip_step = 5 if size == 16 else 3 202 | # step size of 3 yields - [0,3,6,1,4,7,2,5] 203 | return [(i * skip_step) % size for i in range(size)] 204 | 205 | 206 | def format_params(arg): 207 | if isinstance(arg, list) or isinstance(arg, dict): 208 | return '\"' + str(arg) + '\"' 209 | else: 210 | return str(arg) 211 | 212 | 213 | def create_volume_tags(name): 214 | return [{ 215 | 'ResourceType': 'volume', 216 | 'Tags': [{ 217 | 'Key': 'Name', 218 | 'Value': name 219 | }] 220 | }] 221 | 222 | 223 | DEFAULT_UNIX_DEVICE = '/dev/xvdf' 224 | ATTACH_WAIT_INTERVAL_SEC = 5 225 | 226 | 227 | def mount_imagenet(job: ncluster.aws_backend.Job): 228 | """Attaches EBS disks with imagenet data to each task of the job.""" 229 | 230 | task0 = job.tasks[0] 231 | zone = u.get_zone() 232 | vols = {} 233 | ec2 = u.get_ec2_resource() 234 | for vol in ec2.volumes.all(): 235 | vols[u.get_name(vol)] = vol 236 | 237 | attach_attempted = False 238 | for i, t in enumerate(job.tasks): 239 | vol_name = f'imagenet_{zone[-2:]}_{i+args.offset:02d}' 240 | assert vol_name in vols, f"Volume {vol_name} not found, set your NCLUSTER_ZONE={zone} and run replicate_imagenet.py" 241 | vol = vols[vol_name] 242 | print(f"Attaching {vol_name} to {t.name}") 243 | if vol.attachments: 244 | instance = ec2.Instance(vol.attachments[0]['InstanceId']) 245 | if instance.id == t.instance.id: 246 | print(f"{vol_name} already attached") 247 | continue 248 | else: # attached to some other instance, detach 249 | print(f"detaching {vol_name} from {u.get_name(instance)}") 250 | vol.detach_from_instance() 251 | while vol.state != 'available': 252 | vol.reload() 253 | time.sleep(5) 254 | print(f"waiting for detachment from {u.get_name(instance)}") 255 | vol.attach_to_instance(InstanceId=t.instance.id, Device=DEFAULT_UNIX_DEVICE) 256 | attach_attempted = True 257 | 258 | 
else: 259 | vol.attach_to_instance(InstanceId=t.instance.id, Device=DEFAULT_UNIX_DEVICE) 260 | attach_attempted = True 261 | 262 | if attach_attempted: 263 | time.sleep(2) # wait for attachment to succeed 264 | i = 0 265 | vol_name = f'imagenet_{zone[-2:]}_{i+args.offset:02d}' 266 | vol = vols[vol_name] 267 | vol.reload() 268 | assert vol.attachments[0]['InstanceId'] == job.tasks[0].instance.id 269 | 270 | def strip_dev(d): 271 | return d[len('/dev/'):] 272 | 273 | # attach the volume if needed 274 | df_output = task0.run('df', return_output=True) 275 | actual_device = DEFAULT_UNIX_DEVICE 276 | if '/data' not in df_output: 277 | # hack for p3dn's ignoring device name during volume attachment 278 | lsblk_output = task0.run('lsblk', return_output=True) 279 | if strip_dev(DEFAULT_UNIX_DEVICE) not in lsblk_output: 280 | actual_device = '/dev/nvme3n1' 281 | assert strip_dev(actual_device) in lsblk_output, f"Hack for p3dn failed, {actual_device} not found, " \ 282 | f"available devices '{lsblk_output}'" 283 | 284 | job.run(f'sudo mkdir -p /data && sudo chown `whoami` /data && sudo mount {actual_device} /data') 285 | while '/data' not in task0.run('df', return_output=True): 286 | time.sleep(ATTACH_WAIT_INTERVAL_SEC) 287 | print(f"Waiting for attachment") 288 | 289 | 290 | def main(): 291 | if args.image_name == 'pytorch.imagenet.source.v7': 292 | supported_regions = ['us-west-2', 'us-east-1', 'us-east-2'] 293 | assert ncluster.get_region() in supported_regions, f"required AMI {args.image_name} has only been made available in regions {supported_regions}, but your current region is {ncluster.get_region()} (set $AWS_DEFAULT_REGION)" 294 | assert args.machines in schedules, f"{args.machines} not supported, only support {schedules.keys()}" 295 | 296 | if args.mount_imagenet: 297 | datadir = '/data/imagenet' 298 | else: 299 | datadir = '~/data/imagenet' 300 | os.environ['NCLUSTER_AWS_FAST_ROOTDISK'] = '1' # use io2 disk on AWS 301 | 302 | if args.num_tasks >= 16: 303 | assert 
args.simple_ring_setup, "must use --simple_ring_setup, otherwise NCCL_RINGS env var exceeds cmd-line limit" 304 | 305 | job = ncluster.make_job(name=args.name, 306 | run_name=args.run_name, 307 | num_tasks=args.machines, 308 | image_name=args.image_name, 309 | instance_type=args.instance_type, 310 | disk_size=500, 311 | spot=args.spot, 312 | skip_setup=args.skip_setup, 313 | ) 314 | 315 | task0 = job.tasks[0] 316 | _logdir = task0.logdir # workaround for race condition in creating logdir 317 | 318 | config = {} 319 | for key in os.environ: 320 | if re.match(r"^NCLUSTER", key): 321 | config['env_' + key] = os.getenv(key) 322 | config.update(vars(args)) 323 | 324 | CUDA_HOME = f'/usr/local/cuda' 325 | EFA_HOME = f'/opt/amazon/efa' 326 | MPI_HOME = EFA_HOME 327 | NPROC_PER_NODE = args.nproc_per_node 328 | assert NPROC_PER_NODE <= task0.num_gpus, f"requested {NPROC_PER_NODE} processes, but only {task0.num_gpus} gpus present" 329 | NUM_GPUS = NPROC_PER_NODE * args.num_tasks 330 | 331 | config['NUM_GPUS'] = NUM_GPUS 332 | 333 | config['internal_id'] = u.get_account_number() 334 | config['internal_alias'] = u.get_account_name() 335 | config['region'] = u.get_region() 336 | config['zone'] = u.get_zone() 337 | config['launch_user'] = os.environ.get('USER', '') 338 | config['cmd'] = ' '.join(sys.argv) 339 | config['launcher_conda'] = util.ossystem('echo ${CONDA_PREFIX:-"$(dirname $(which conda))/../"}') 340 | config['launcher_cmd'] = 'python ' + ' '.join(sys.argv) 341 | config['logdir'] = job.logdir 342 | 343 | pickled_config = util.text_pickle(config) 344 | if args.log_all_workers: 345 | job.write(args.internal_config_fn, pickled_config) 346 | else: 347 | job.tasks[0].write(args.internal_config_fn, pickled_config) 348 | 349 | if args.mount_imagenet: 350 | assert u.get_zone(), "Must specify zone when reusing EBS volumes" 351 | mount_imagenet(job) 352 | 353 | if not args.skip_setup: 354 | job.run('rm -f *.py') # remove files backed into imagenet18 release image 355 | 
job.run('conda init') # missing .bashrc 356 | job.run( 357 | f'{{ source activate {args.conda_env} && bash setup.sh && pip install -U protobuf ; }} && {{ killall python || echo hi ; }} ') 358 | if args.pytorch_nightly: 359 | job.run('conda install -y -c pytorch pytorch-nightly && bash setup.sh') 360 | else: 361 | job.run([f'source ~/.bashrc && conda activate {args.conda_env}', f'killall python || echo hi']) 362 | 363 | job.rsync('.') 364 | 365 | if args.efa: 366 | assert 'efa' in args.image_name # make sure we use EFA-enabled image 367 | hosts_str, hosts_file_str = util.setup_mpi(job, skip_ssh_setup=args.skip_setup) 368 | if not args.skip_setup: 369 | task0.write(HOSTS_SLOTS_FN, hosts_file_str) 370 | 371 | env_params = get_nccl_params(args.machines, args.nproc_per_node) 372 | if args.cuda_debug: 373 | env_params += 'CUDA_LAUNCH_BLOCKING=1 NCCL_DEBUG=INFO ' 374 | else: 375 | env_params += 'NCCL_DEBUG=INFO ' 376 | 377 | env_params += " OMP_NUM_THREADS=1 " 378 | if args.pytorch_use_spawn: 379 | assert args.pytorch_nightly 380 | env_params += " PYTORCH_USE_SPAWN=1 " 381 | if 'WANDB_API_KEY' in os.environ: 382 | env_params += f" WANDB_API_KEY={os.environ.get('WANDB_API_KEY')} " 383 | 384 | # Training script args 385 | default_params = [ 386 | datadir, 387 | '--fp16', 388 | '--logdir', job.logdir, 389 | '--name', f'{args.run_name}-{util.random_id()}', 390 | '--distributed', 391 | '--init-bn0', 392 | '--no-bn-wd', 393 | '--log_all_workers', args.log_all_workers, 394 | ] 395 | 396 | params = ['--phases', util.text_pickle(schedules[args.machines])] 397 | training_params = default_params + params 398 | training_params = ' '.join(map(format_params, training_params)) 399 | 400 | if not args.efa: 401 | # TODO: simplify args processing, or give link to actual commands run 402 | for i, task in enumerate(job.tasks): 403 | dist_params = f'--nproc_per_node={args.nproc_per_node} --nnodes={args.machines} --node_rank={i} --master_addr={job.tasks[0].ip} --master_port={6006}' 404 | cmd = 
f'{env_params} python -m torch.distributed.launch {dist_params} training/train_imagenet_nv.py {training_params}' 405 | task.run(f'echo {cmd} > {job.logdir}/task-{i}.cmd') # save command-line 406 | task.run(cmd, non_blocking=True) 407 | else: 408 | FI_PROVIDER = 'efa' 409 | if args.pseudo_efa: 410 | FI_PROVIDER = 'sockets' 411 | 412 | local_env = util.format_env_export(LOCAL_RANK='$OMPI_COMM_WORLD_LOCAL_RANK', 413 | RANK='$OMPI_COMM_WORLD_RANK', 414 | WORLD_SIZE='$OMPI_COMM_WORLD_SIZE', 415 | MASTER_ADDR=task0.ip, 416 | MASTER_PORT=6016) 417 | 418 | mpi_env = util.format_env_x(FI_PROVIDER=FI_PROVIDER, # Enables running nccl-tests using EFA provider. 419 | FI_OFI_RXR_RX_COPY_UNEXP=1, #  Disables using bounce buffers for unexpected messages. 420 | FI_OFI_RXR_RX_COPY_OOO=1, # Disables using bounce buffers for out of order messages. 421 | FI_EFA_MR_CACHE_ENABLE=1, # Enables memory region caching. 422 | FI_OFI_RXR_INLINE_MR_ENABLE=1, # Enables inline memory registration of data buffers. 423 | NCCL_TREE_THRESHOLD=10 * 4294967296, # force tree for everything under 40GB 424 | LD_LIBRARY_PATH=f'{CUDA_HOME}/lib:{CUDA_HOME}/lib64:{EFA_HOME}/lib64', 425 | NCCL_DEBUG='INFO', 426 | OMP_NUM_THREADS=1, 427 | WANDB_API_KEY=os.environ.get('WANDB_API_KEY', ''), 428 | PYTORCH_USE_SPAWN=args.pytorch_use_spawn, 429 | NO_WANDB=args.pytorch_use_spawn, 430 | ) 431 | if args.no_op: 432 | worker_script_fn = 'training/env_test.py' 433 | else: 434 | worker_script_fn = 'training/train_imagenet_nv.py' 435 | 436 | local_cmd = [f"{local_env} && source ~/.bashrc && conda activate {args.conda_env} && ", 437 | f'python {worker_script_fn} {training_params} --local_rank=$OMPI_COMM_WORLD_LOCAL_RANK'] 438 | local_cmd = ' '.join(local_cmd) 439 | 440 | cmd = [f"{MPI_HOME}/bin/mpirun -n {NUM_GPUS} -N {NPROC_PER_NODE} --hostfile {HOSTS_SLOTS_FN} ", 441 | f'{mpi_env} ', 442 | f'--mca btl tcp,self --mca btl_tcp_if_exclude lo,docker0 ', 443 | f'--bind-to none ', 444 | f"bash -c '{local_cmd}'"] 445 | cmd = ' 
'.join(cmd) 446 | 447 | task0.run(cmd, non_blocking=True) 448 | 449 | print(f"Logging to {job.logdir}") 450 | 451 | 452 | if __name__ == '__main__': 453 | main() 454 | -------------------------------------------------------------------------------- /training/train_imagenet_nv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import copy 4 | import os 5 | import shutil 6 | import sys # add path to util which is one level above 7 | import time 8 | import warnings 9 | from datetime import datetime 10 | 11 | import gc 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | 18 | import dataloader 19 | import dist_utils 20 | import experimental_utils 21 | import resnet 22 | 23 | from pprint import pprint as pp 24 | 25 | # util is one level up, so import that 26 | module_path = os.path.dirname(os.path.abspath(__file__)) 27 | sys.path.insert(0, os.path.abspath(f'{module_path}/..')) 28 | 29 | import util 30 | from fp16util import * 31 | from logger import TensorboardLogger, FileLogger 32 | from meter import AverageMeter, NetworkMeter, TimeMeter 33 | 34 | import wandb 35 | 36 | 37 | def get_parser(): 38 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 39 | parser.add_argument('data', metavar='DIR', help='path to dataset') 40 | parser.add_argument('--phases', type=str, 41 | help='Specify epoch order of data resize and learning rate schedule: [{"ep":0,"sz":128,"bs":64},{"ep":5,"lr":1e-2}]') 42 | # parser.add_argument('--save-dir', type=str, default=Path.cwd(), help='Directory to save logs and models.') 43 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', 44 | help='number of data loading workers (default: 8)') 45 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 46 | help='manual epoch number (useful 
on restarts)') 47 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') 48 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 49 | metavar='W', help='weight decay (default: 1e-4)') 50 | parser.add_argument('--init-bn0', action='store_true', help='Intialize running batch norm mean to 0') 51 | parser.add_argument('--print-freq', '-p', default=50, type=int, 52 | metavar='N', help='log/print every this many steps (default: 50)') 53 | parser.add_argument('--no-bn-wd', action='store_true', help='Remove batch norm from weight decay') 54 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 55 | help='path to latest checkpoint (default: none)') 56 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 57 | help='evaluate model on validation set') 58 | parser.add_argument('--fp16', action='store_true', help='Run model fp16 mode. Default True') 59 | parser.add_argument('--loss-scale', type=float, default=1024, 60 | help='Loss scaling, positive power of 2 values can improve fp16 convergence.') 61 | parser.add_argument('--distributed', action='store_true', help='Run distributed training. Default True') 62 | parser.add_argument('--dist-url', default='env://', type=str, 63 | help='url used to set up distributed training') 64 | parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend') 65 | parser.add_argument('--synthetic-data', action='store_true', help='Use synthetic training data') 66 | parser.add_argument('--local_rank', default=0, type=int, 67 | help='Used for multi-process training. 
Can either be manually set ' + 68 | 'or automatically set by using \'python -m multiproc\'.') 69 | parser.add_argument('--logdir', default='', type=str, 70 | help='where logs go') 71 | parser.add_argument('--skip-auto-shutdown', action='store_true', 72 | help='Shutdown instance at the end of training or failure') 73 | parser.add_argument('--auto-shutdown-success-delay-mins', default=10, type=int, 74 | help='how long to wait until shutting down on success') 75 | parser.add_argument('--auto-shutdown-failure-delay-mins', default=60, type=int, 76 | help='how long to wait before shutting down on error') 77 | 78 | parser.add_argument('--name', type=str, default='imagenet', 79 | help="name of the current run, used for machine naming and tensorboard visualization") 80 | parser.add_argument('--short-epoch', action='store_true', 81 | help='make epochs short (for debugging)') 82 | parser.add_argument('--internal_config_fn', type=str, default='ncluster_config_dict', help='location of filename with extra info to log') 83 | parser.add_argument('--log_all_workers', type=int, default=0, help='log from each worker instead of just chief') 84 | return parser 85 | 86 | 87 | cudnn.benchmark = True 88 | args = get_parser().parse_args() 89 | 90 | # print some debug info 91 | RANK = os.environ.get('RANK', '0') 92 | LOCAL_RANK = os.environ.get('LOCAL_RANK', '-1') 93 | OMPI_COMM_WORLD_LOCAL_RANK = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '-1') 94 | IS_CHIEF = (RANK == '0') 95 | print(f"*** Debug: {os.uname()[1]} RANK={RANK} local_rank_arg={args.local_rank} LOCAL_RANK={LOCAL_RANK}, OMPI_COMM_WORLD_LOCAL_RANK={OMPI_COMM_WORLD_LOCAL_RANK}, {' '.join(sys.argv)}") 96 | 97 | pp(dict(os.environ)) 98 | 99 | # Only want master rank logging to tensorboard 100 | is_master = os.environ.get('RANK', '0') == '0' 101 | 102 | 103 | # for mpirun the messages are propagated to main machine, so don't log in that case 104 | is_rank0 = (args.local_rank == 0) 105 | 106 | tb = TensorboardLogger(args.logdir, 
is_master=is_master) 107 | log = FileLogger(args.logdir, is_master=is_master, is_rank0=is_rank0) 108 | 109 | 110 | if args.log_all_workers: 111 | group_name=args.name 112 | run_name=args.name + '-' + os.environ.get("RANK", "0") 113 | wandb.init(project='imagenet18', group=group_name, name=run_name) 114 | log.console("initializing wandb logging to group "+args.name+" name ") 115 | else: 116 | if not is_master: 117 | os.environ['WANDB_MODE'] = 'dryrun' # all wandb.log are no-op 118 | log.console("local-only wandb logging for run "+args.name) 119 | wandb.init(project='imagenet18', name=args.name) 120 | log.console("initializing logging to run "+args.name) 121 | 122 | if hasattr(wandb, 'config') and wandb.config is not None: 123 | wandb.config['gpus'] = int(os.environ.get('WORLD_SIZE', 1)) 124 | 125 | 126 | try: 127 | config = util.text_unpickle(open(args.internal_config_fn).read()) 128 | except Exception as e: 129 | log.console(f'couldnt open wandb config file with {e}') 130 | config = {} 131 | 132 | config['worker_conda'] = os.path.basename(util.ossystem('echo ${CONDA_PREFIX:-"$(dirname $(which conda))/../"}')) 133 | if hasattr(wandb, 'config') and wandb.config is not None: 134 | wandb.config.update(config) 135 | util.log_environment() 136 | 137 | 138 | def main(): 139 | os.system('sudo shutdown -c') # cancel previous shutdown command 140 | log.console(args) 141 | tb.log('sizes/world', dist_utils.env_world_size()) 142 | 143 | assert os.path.exists(args.data) 144 | 145 | # need to index validation directory before we start counting the time 146 | dataloader.sort_ar(args.data + '/validation') 147 | 148 | if args.distributed: 149 | log.console('Distributed initializing process group') 150 | torch.cuda.set_device(args.local_rank) 151 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 152 | world_size=dist_utils.env_world_size()) 153 | assert (dist_utils.env_world_size() == dist.get_world_size()) 154 | # todo(y): use global_rank instead of 
local_rank here 155 | log.console("Distributed: success (%d/%d)" % (args.local_rank, dist.get_world_size())) 156 | 157 | log.console("Loading model") 158 | model = resnet.resnet50(bn0=args.init_bn0).cuda() 159 | if args.fp16: 160 | model = network_to_half(model) 161 | if args.distributed: 162 | model = dist_utils.DDP(model, device_ids=[args.local_rank], output_device=args.local_rank) 163 | best_top5 = 93 # only save models over 93%. Otherwise it stops to save every time 164 | 165 | global model_params, master_params 166 | if args.fp16: 167 | model_params, master_params = prep_param_lists(model) 168 | else: 169 | model_params = master_params = model.parameters() 170 | 171 | optim_params = experimental_utils.bnwd_optim_params(model, model_params, 172 | master_params) if args.no_bn_wd else master_params 173 | 174 | # define loss function (criterion) and optimizer 175 | criterion = nn.CrossEntropyLoss().cuda() 176 | optimizer = torch.optim.SGD(optim_params, 0, momentum=args.momentum, 177 | weight_decay=args.weight_decay) # start with 0 lr. 
Scheduler will change this later 178 | 179 | if args.resume: 180 | checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.local_rank)) 181 | model.load_state_dict(checkpoint['state_dict']) 182 | args.start_epoch = checkpoint['epoch'] 183 | best_top5 = checkpoint['best_top5'] 184 | optimizer.load_state_dict(checkpoint['optimizer']) 185 | 186 | # save script so we can reproduce from logs 187 | shutil.copy2(os.path.realpath(__file__), f'{args.logdir}') 188 | 189 | log.console("Creating data loaders (this could take up to 10 minutes if volume needs to be warmed up)") 190 | phases = util.text_unpickle(args.phases) 191 | dm = DataManager([copy.deepcopy(p) for p in phases if 'bs' in p]) 192 | scheduler = Scheduler(optimizer, [copy.deepcopy(p) for p in phases if 'lr' in p]) 193 | 194 | start_time = datetime.now() # Loading start to after everything is loaded 195 | if args.evaluate: 196 | return validate(dm.val_dl, model, criterion, 0, start_time) 197 | 198 | if args.distributed: 199 | log.console('Syncing machines before training') 200 | dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda()) 201 | 202 | log.event("~~epoch\thours\ttop1\ttop5\n") 203 | for epoch in range(args.start_epoch, scheduler.tot_epochs): 204 | dm.set_epoch(epoch) 205 | 206 | train(dm.trn_dl, model, criterion, optimizer, scheduler, epoch) 207 | top1, top5 = validate(dm.val_dl, model, criterion, epoch, start_time) 208 | 209 | time_diff = (datetime.now() - start_time).total_seconds() / 3600.0 210 | log.event(f'~~{epoch}\t{time_diff:.5f}\t\t{top1:.3f}\t\t{top5:.3f}\n') 211 | 212 | is_best = top5 > best_top5 213 | best_top5 = max(top5, best_top5) 214 | if args.local_rank == 0: 215 | if is_best: 216 | save_checkpoint(epoch, model, best_top5, optimizer, is_best=True, 217 | filename='model_best.pth.tar') 218 | phase = dm.get_phase(epoch) 219 | if phase: save_checkpoint(epoch, model, best_top5, optimizer, 220 | filename=f'sz{phase["bs"]}_checkpoint.path.tar') 221 | 222 | 
def train(trn_loader, model, criterion, optimizer, scheduler, epoch):
    """Run one training epoch over ``trn_loader``.

    For every batch this updates the learning rate via ``scheduler``, runs the forward and
    backward pass, steps ``optimizer`` and accumulates loss / top-1 / top-5 accuracy. When
    ``args.distributed`` is set the per-batch statistics are summed across workers first.
    Local rank 0 logs to tensorboard and the console every ``args.print_freq`` batches and
    on the last batch.

    The fp16 path reads the module-level ``model_params`` / ``master_params`` that ``main``
    declares ``global`` and fills in.
    """
    net_meter = NetworkMeter()
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.train()  # switch to train mode
    for i, (inputs, target) in enumerate(trn_loader):
        if args.short_epoch and (i > 10):
            break
        batch_num = i + 1

        # TODO(y): cuda is async so some time spent inside step is not inside batch_start/batch_end measurement
        timer.batch_start()
        scheduler.update_lr(epoch, batch_num, len(trn_loader))

        # compute output
        output = model(inputs)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        if args.fp16:
            # Static loss scaling: backprop a scaled loss so small fp16 gradients do not
            # underflow, move the grads onto the fp32 master copies, unscale them, step,
            # then copy the updated master weights back into the fp16 model.
            loss = loss * args.loss_scale
            model.zero_grad()
            loss.backward()
            model_grads_to_master_grads(model_params, master_params)
            for param in master_params:
                param.grad.data = param.grad.data / args.loss_scale
            optimizer.step()
            master_params_to_model_params(model_params, master_params)
            loss = loss / args.loss_scale  # undo scaling so the logged loss is the real one
        else:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Train batch done. Logging results
        timer.batch_end()
        corr1, corr5 = correct(output.data, target, topk=(1, 5))
        reduced_loss, batch_total = to_python_float(loss.data), to_python_float(inputs.size(0))
        if args.distributed:
            # Must keep track of global batch size, since not all machines are guaranteed
            # equal batches at the end of an epoch.
            metrics = torch.tensor([batch_total, reduced_loss, corr1, corr5]).float().cuda()
            batch_total, reduced_loss, corr1, corr5 = dist_utils.sum_tensor(metrics).cpu().numpy()
            reduced_loss = reduced_loss / dist_utils.env_world_size()
        top1acc = to_python_float(corr1) * (100.0 / batch_total)
        top5acc = to_python_float(corr5) * (100.0 / batch_total)

        losses.update(reduced_loss, batch_total)
        top1.update(top1acc, batch_total)
        top5.update(top5acc, batch_total)

        should_print = (batch_num % args.print_freq == 0) or (batch_num == len(trn_loader))
        if args.local_rank == 0 and should_print:
            tb.log_memory()
            tb.log_trn_times(timer.batch_time.val, timer.data_time.val, inputs.size(0))
            tb.log_trn_loss(losses.val, top1.val, top5.val)

            recv_gbit, transmit_gbit = net_meter.update_bandwidth()
            tb.log("sizes/batch_total", batch_total)
            tb.log('net/recv_gbit', recv_gbit)
            tb.log('net/transmit_gbit', transmit_gbit)

            # Separate name so the model-output tensor is not shadowed by the log line.
            msg = (f'Epoch: [{epoch}][{batch_num}/{len(trn_loader)}]\t'
                   f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                   f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                   f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                   f'Data {timer.data_time.val:.3f} ({timer.data_time.avg:.3f})\t'
                   f'BW {recv_gbit:.3f} {transmit_gbit:.3f}')
            log.verbose(msg)

        tb.update_step_count(batch_total)


def validate(val_loader, model, criterion, epoch, start_time):
    """Evaluate ``model`` on ``val_loader``; return ``(top1_avg, top5_avg)`` in percent.

    In distributed mode each batch is scored through ``distributed_predict`` so the
    statistics are summed over all workers. ``start_time`` is accepted by the signature
    but not used in this function. Logs eval accuracy/time and the epoch to tensorboard.
    """
    timer = TimeMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    eval_start_time = time.time()

    for i, (inputs, target) in enumerate(val_loader):
        if args.short_epoch and (i > 10):
            break
        batch_num = i + 1
        timer.batch_start()
        if args.distributed:
            top1acc, top5acc, loss, batch_total = distributed_predict(inputs, target, model, criterion)
        else:
            with torch.no_grad():
                output = model(inputs)
                loss = criterion(output, target).data
            batch_total = inputs.size(0)
            top1acc, top5acc = accuracy(output.data, target, topk=(1, 5))

        # Eval batch done. Logging results
        timer.batch_end()
        losses.update(to_python_float(loss), to_python_float(batch_total))
        top1.update(to_python_float(top1acc), to_python_float(batch_total))
        top5.update(to_python_float(top5acc), to_python_float(batch_total))
        should_print = (batch_num % args.print_freq == 0) or (batch_num == len(val_loader))
        if args.local_rank == 0 and should_print:
            msg = (f'Test: [{epoch}][{batch_num}/{len(val_loader)}]\t'
                   f'Time {timer.batch_time.val:.3f} ({timer.batch_time.avg:.3f})\t'
                   f'Loss {losses.val:.4f} ({losses.avg:.4f})\t'
                   f'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   f'Acc@5 {top5.val:.3f} ({top5.avg:.3f})')
            log.verbose(msg)

    tb.log_eval(top1.avg, top5.avg, time.time() - eval_start_time)
    tb.log('epoch', epoch)

    return top1.avg, top5.avg


def distributed_predict(input, target, model, criterion):
    """Score one validation batch and sum the statistics across all workers.

    Allows distributed prediction on uneven batches: the test set isn't always large
    enough for every GPU to get a batch, so a worker with an empty batch contributes
    zeros. Returns ``(top1, top5, reduced_loss, batch_total)`` where the accuracies are
    percentages over the global batch and the loss is averaged over workers that had data.
    """
    batch_size = input.size(0)
    output = loss = corr1 = corr5 = valid_batches = 0

    if batch_size:
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target).data
        # measure accuracy and record loss
        valid_batches = 1
        corr1, corr5 = correct(output.data, target, topk=(1, 5))

    metrics = torch.tensor([batch_size, valid_batches, loss, corr1, corr5]).float().cuda()
    batch_total, valid_batches, reduced_loss, corr1, corr5 = dist_utils.sum_tensor(metrics).cpu().numpy()

    # If every worker received an empty batch the sums are all zero; dividing would give
    # inf/nan that then poisons the running averages in validate(), so report zeros.
    reduced_loss = reduced_loss / valid_batches if valid_batches else 0.0
    if batch_total:
        top1 = corr1 * (100.0 / batch_total)
        top5 = corr5 * (100.0 / batch_total)
    else:
        top1 = top5 = 0.0
    return top1, top5, reduced_loss, batch_total


class DataManager:
    """Owns the per-phase train/validation data loaders and switches them by epoch.

    Each phase dict has ``'ep'`` (the epoch it starts at) and either data settings
    (``'sz'``, ``'bs'``, optional ``'trndir'``/``'valdir'``) whose loaders are built up
    front, or ``'keep_dl': True`` to only change the batch size of the current loader.
    """

    def __init__(self, phases):
        self.phases = self.preload_phase_data(phases)

    def set_epoch(self, epoch):
        """Switch to the phase that starts at ``epoch`` (if any) and reseed the samplers."""
        # NOTE(review): loaders are only assigned when a phase starts exactly at `epoch`.
        # Resuming at an epoch that is not a phase start leaves self.trn_smp unset and the
        # hasattr() calls below raise AttributeError — confirm resume always lands on a
        # phase boundary.
        cur_phase = self.get_phase(epoch)
        if cur_phase:
            self.set_data(cur_phase)
        if hasattr(self.trn_smp, 'set_epoch'):
            self.trn_smp.set_epoch(epoch)
        if hasattr(self.val_smp, 'set_epoch'):
            self.val_smp.set_epoch(epoch)

    def get_phase(self, epoch):
        """Return the first remaining phase whose ``'ep'`` equals ``epoch``, else ``None``."""
        return next((p for p in self.phases if p['ep'] == epoch), None)

    def set_data(self, phase):
        """Initializes data loader.

        A ``keep_dl`` phase only resizes the batches of the current train loader. Any other
        phase swaps in its preloaded loaders/samplers and is removed from ``self.phases``.
        """
        if phase.get('keep_dl', False):
            log.event(f'Batch size changed: {phase["bs"]}')
            tb.log_size(phase['bs'])
            self.trn_dl.update_batch_size(phase['bs'])
            return

        log.event(f'Dataset changed.\nImage size: {phase["sz"]}\nBatch size: {phase["bs"]}\nTrain Directory: {phase["trndir"]}\nValidation Directory: {phase["valdir"]}')
        tb.log_size(phase['bs'], phase['sz'])

        self.trn_dl, self.val_dl, self.trn_smp, self.val_smp = phase['data']
        self.phases.remove(phase)

        # clear memory before we begin training
        gc.collect()

    def preload_phase_data(self, phases):
        """Build loaders for every non-``keep_dl`` phase; store them in ``phase['data']``."""
        for phase in phases:
            if not phase.get('keep_dl', False):
                self.expand_directories(phase)
                phase['data'] = self.preload_data(**phase)
        return phases

    def expand_directories(self, phase):
        """Rewrite ``trndir``/``valdir`` in place to full ``args.data``-rooted paths.

        ``valdir`` defaults to ``trndir``; ``'/train'`` and ``'/validation'`` are appended.
        """
        trndir = phase.get('trndir', '')
        valdir = phase.get('valdir', trndir)
        phase['trndir'] = args.data + trndir + '/train'
        phase['valdir'] = args.data + valdir + '/validation'

    def preload_data(self, ep, sz, bs, trndir, valdir, **kwargs):  # dummy ep var to prevent error
        """Pre-initializes data-loaders. Use set_data to start using it.

        The validation batch size is at least 512 for ``sz == 128``, at least 256 for
        ``sz == 224`` and at least 128 otherwise (never smaller than ``bs``).
        """
        kwargs.pop('lr', None)  # in case we mix schedule and data phases
        if sz == 128:
            val_bs = max(bs, 512)
        elif sz == 224:
            val_bs = max(bs, 256)
        else:
            val_bs = max(bs, 128)
        return dataloader.get_loaders(trndir, valdir, bs=bs, val_bs=val_bs, sz=sz, workers=args.workers,
                                      distributed=args.distributed, synthetic=args.synthetic_data, **kwargs)


# ### Learning rate scheduler
class Scheduler:
    """Per-batch learning-rate schedule built from phase dicts.

    Each phase has ``'ep'`` (a start epoch, or ``[start, end]``) and ``'lr'`` (a constant
    ``lr`` or a linear ramp ``[lr_start, lr_end]``, which requires ``[start, end]``). An
    optional ``'epoch_step'`` key makes a linear phase change only at epoch boundaries.
    ``get_current_phase`` assumes the phases are ordered by start epoch.
    """

    def __init__(self, optimizer, phases):
        self.optimizer = optimizer
        self.current_lr = None  # last lr written to the optimizer; None before the first update
        self.phases = [self.format_phase(p) for p in phases]
        self.tot_epochs = max(max(p['ep']) for p in self.phases)

    def format_phase(self, phase):
        """Normalize ``phase['ep']`` and ``phase['lr']`` to lists in place; return ``phase``."""
        phase['ep'] = listify(phase['ep'])
        phase['lr'] = listify(phase['lr'])
        if len(phase['lr']) == 2:
            assert (len(phase['ep']) == 2), 'Linear learning rates must contain end epoch'
        return phase

    def linear_phase_lr(self, phase, epoch, batch_curr, batch_tot):
        """Interpolate the lr of a linear ``phase`` at (``epoch``, ``batch_curr``)."""
        lr_start, lr_end = phase['lr']
        ep_start, ep_end = phase['ep']
        if 'epoch_step' in phase:
            batch_curr = 0  # Optionally change learning rate through epoch step
        ep_relative = epoch - ep_start
        ep_tot = ep_end - ep_start
        return self.calc_linear_lr(lr_start, lr_end, ep_relative, batch_curr, ep_tot, batch_tot)

    def calc_linear_lr(self, lr_start, lr_end, epoch_curr, batch_curr, epoch_tot, batch_tot):
        """Linear interpolation from ``lr_start`` to ``lr_end`` over ``epoch_tot * batch_tot`` steps."""
        step_tot = epoch_tot * batch_tot
        step_curr = epoch_curr * batch_tot + batch_curr
        step_size = (lr_end - lr_start) / step_tot
        return lr_start + step_curr * step_size

    def get_current_phase(self, epoch):
        """Return the last phase whose start epoch is ``<= epoch``; raise if none qualifies."""
        for phase in reversed(self.phases):
            if epoch >= phase['ep'][0]:
                return phase
        raise Exception('Epoch out of range')

    def get_lr(self, epoch, batch_curr, batch_tot):
        """Return the learning rate for batch ``batch_curr`` of ``batch_tot`` in ``epoch``."""
        phase = self.get_current_phase(epoch)
        if len(phase['lr']) == 1:
            return phase['lr'][0]  # constant learning rate
        return self.linear_phase_lr(phase, epoch, batch_curr, batch_tot)

    def update_lr(self, epoch, batch_num, batch_tot):
        """Write the scheduled lr into every optimizer param group when it changes.

        Changes are logged to the event log only on the first and last batch of an epoch;
        every change is logged to tensorboard together with ``args.momentum``.
        """
        lr = self.get_lr(epoch, batch_num, batch_tot)
        if self.current_lr == lr:
            return
        if (batch_num == 1) or (batch_num == batch_tot):
            log.event(f'Changing LR from {self.current_lr} to {lr}')

        self.current_lr = lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        tb.log("sizes/lr", lr)
        tb.log("sizes/momentum", args.momentum)


# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
    """Convert a number, a one-element tensor, or a numpy scalar to a Python number.

    Plain ``int``/``float`` pass through unchanged. Objects with ``item()`` use it;
    anything else falls back to ``t[0]`` (``item()`` is missing on older PyTorch versions).
    """
    if isinstance(t, (float, int)):
        return t
    if hasattr(t, 'item'):
        return t.item()
    return t[0]


def save_checkpoint(epoch, model, best_top5, optimizer, is_best=False, filename='checkpoint.pth.tar'):
    """Save model and optimizer state to ``filename`` (a path relative to the CWD).

    The stored ``'epoch'`` is ``epoch + 1``, i.e. the epoch a resumed run should start at.
    When ``is_best`` is set the file is additionally copied into ``args.logdir``.
    """
    state = {
        'epoch': epoch + 1, 'state_dict': model.state_dict(),
        'best_top5': best_top5, 'optimizer': optimizer.state_dict(),
    }
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, f'{args.logdir}/{filename}')


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy@k (as a percentage tensor) for the specified values of k."""
    correct_ks = correct(output, target, topk)
    batch_size = target.size(0)
    return [correct_k.float().mul_(100.0 / batch_size) for correct_k in correct_ks]


def correct(output, target, topk=(1,)):
    """Count, for each k in ``topk``, the samples whose target is among the top-k predictions.

    ``output`` is scored along dim 1 and ``target`` is flattened to one label per row.
    Returns a list of 1-element tensors, one per k.
    """
    maxk = max(topk)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch)
    hits = pred.eq(target.view(1, -1).expand_as(pred))
    # `hits` inherits the transposed (non-contiguous) layout of `pred`, so `.view(-1)`
    # fails on slices with k > 1 in current PyTorch; `.reshape` copies when needed.
    return [hits[:k].reshape(-1).sum(0, keepdim=True) for k in topk]


def listify(p=None, q=None):
    """Coerce ``p`` into a list, repeating a single element to the length implied by ``q``.

    ``None`` becomes ``[]`` and a non-iterable value is wrapped in a list. The target length
    is 1 when ``q`` is ``None``, ``q`` itself when it is an int, otherwise ``len(q)``; only a
    one-element ``p`` is repeated.
    """
    from collections.abc import Iterable  # `collections.Iterable` was removed in Python 3.10

    if p is None:
        p = []
    elif not isinstance(p, Iterable):
        p = [p]
    if q is None:
        n = 1
    elif isinstance(q, int):
        n = q
    else:
        n = len(q)
    if len(p) == 1:
        p = p * n
    return p

# todo(y): pdb debug on error


if __name__ == '__main__':
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            main()
        if not args.skip_auto_shutdown:
            os.system(f'sudo shutdown -h -P +{args.auto_shutdown_success_delay_mins}')
    except Exception as e:
        import traceback

        traceback.print_tb(e.__traceback__, file=sys.stdout)
        print(str(e))
        log.event(e)
        # On failure, schedule the shutdown after args.auto_shutdown_failure_delay_mins
        # instead (presumably longer, to leave time for debugging).
        if not args.skip_auto_shutdown:
            os.system(f'sudo shutdown -h -P +{args.auto_shutdown_failure_delay_mins}')
    finally:
        # Flush/close the tensorboard writer even if the handler above raises or the
        # run is interrupted.
        tb.close()

# --------------------------------------------------------------------------------