├── __init__.py
├── cluster4
├── audio_data
│   ├── __init__.py
│   ├── labels.json
│   ├── utils.py
│   ├── distributed.py
│   ├── an4.py
│   ├── librispeech.py
│   ├── data_loader.py
│   └── an4_val_manifest.csv
├── requirements.txt
├── exp_configs
│   ├── vgg16.conf
│   ├── resnet20.conf
│   ├── lstm.conf
│   ├── lstman4.conf
│   ├── alexnet.conf
│   └── resnet50.conf
├── single.sh
├── labels.json
├── horovod_mpi.sh
├── gtopk_mpi.sh
├── scripts
│   ├── killp2p.sh
│   ├── ijcai2019
│   │   ├── utils.py
│   │   └── plot_loss.py
│   ├── test_read_hdf5.py
│   ├── eval.sh
│   ├── icdcs2019
│   │   ├── utils.py
│   │   └── plot_sth.py
│   └── create_hdf5.py
├── model_builder.py
├── settings.py
├── models
│   ├── res_utils.py
│   ├── __init__.py
│   ├── lstman4.py
│   ├── vgg.py
│   ├── caffe_cifar.py
│   ├── lstm.py
│   ├── alexnet.py
│   ├── densenet.py
│   ├── resnext.py
│   ├── resnet.py
│   ├── preresnet.py
│   ├── resnet_mod.py
│   ├── imagenet_resnet.py
│   └── lstm_models.py
├── utils.py
├── datasets.py
├── .gitignore
├── README.md
├── evaluate.py
├── compression.py
├── ptb_reader.py
├── horovod_trainer.py
├── gtopk_trainer.py
├── decoder.py
├── distributed_optimizer.py
└── LICENSE

/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cluster4:
--------------------------------------------------------------------------------
1 | localhost slots=4
2 |
--------------------------------------------------------------------------------
/audio_data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import data_loader
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | simplejson
2 | tensorboardX
3 | ujson
4 | coloredlogs
5 | tqdm
6 | h5py
7 | mpi4py
8 | psutil
9 |
--------------------------------------------------------------------------------
/exp_configs/vgg16.conf:
--------------------------------------------------------------------------------
1 | lr="${lr:-0.1}"
2 | batch_size="${batch_size:-32}"
3 | dnn=vgg16
4 | dataset=cifar10
5 | max_epochs="${max_epochs:-141}"
6 | nstepsupdate=1
7 | data_dir=./data
8 |
--------------------------------------------------------------------------------
/exp_configs/resnet20.conf:
--------------------------------------------------------------------------------
1 | lr="${lr:-0.1}"
2 | batch_size="${batch_size:-32}"
3 | dnn=resnet20
4 | dataset=cifar10
5 | max_epochs="${max_epochs:-141}"
6 | nstepsupdate=1
7 | data_dir=./data
8 |
--------------------------------------------------------------------------------
/exp_configs/lstm.conf:
--------------------------------------------------------------------------------
1 | lr="${lr:-1.0}"
2 | batch_size="${batch_size:-20}"
3 | dnn=lstm
4 | dataset=ptb
5 | max_epochs="${max_epochs:-50}"
6 | nstepsupdate=1
7 | data_dir=/home/comp/csshshi/data/PennTreeBank
8 |
--------------------------------------------------------------------------------
/exp_configs/lstman4.conf:
--------------------------------------------------------------------------------
1 | lr="${lr:-0.0003}"
2 | batch_size="${batch_size:-8}"
3 | dnn=lstman4
4 | dataset=an4
5 | max_epochs="${max_epochs:-100}"
6 | nstepsupdate=1
7 | data_dir=/home/comp/csshshi/data/an4data
8 |
--------------------------------------------------------------------------------
/exp_configs/alexnet.conf:
--------------------------------------------------------------------------------
1 | lr=0.01
2 | batch_size="${batch_size:-256}"
3 | dnn=alexnet
4 | dataset=imagenet
5 | max_epochs="${max_epochs:-95}"
6 | nstepsupdate=1
7 | data_dir=/home/comp/csshshi/data/imagenet/imagenet_hdf5
8 |
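Most values in the .conf files above use bash's `${var:-default}` expansion, so they can be overridden from the environment of whichever launcher script sources them (note that `lr` is hard-coded in alexnet.conf and resnet50.conf). A hypothetical override for a single-GPU run through single.sh, which appears below:

```
# Illustrative values only: train ResNet-20 with a smaller learning rate and a
# larger batch size than the defaults in exp_configs/resnet20.conf.
lr=0.05 batch_size=64 max_epochs=100 dnn=resnet20 ./single.sh
```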
--------------------------------------------------------------------------------
/exp_configs/resnet50.conf:
--------------------------------------------------------------------------------
1 | lr=0.01
2 | batch_size="${batch_size:-64}"
3 | dnn=resnet50
4 | dataset=imagenet
5 | max_epochs="${max_epochs:-95}"
6 | nstepsupdate=1
7 | data_dir=/home/comp/csshshi/data/imagenet/imagenet_hdf5
8 |
--------------------------------------------------------------------------------
/single.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | dnn="${dnn:-resnet20}"
3 | source exp_configs/$dnn.conf
4 | nstepsupdate=1
5 | python dl_trainer.py --dnn $dnn --dataset $dataset --max-epochs $max_epochs --batch-size $batch_size --data-dir $data_dir --lr $lr --nsteps-update $nstepsupdate
6 |
--------------------------------------------------------------------------------
/labels.json:
--------------------------------------------------------------------------------
1 | [
2 |   "_",
3 |   "'",
4 |   "A",
5 |   "B",
6 |   "C",
7 |   "D",
8 |   "E",
9 |   "F",
10 |   "G",
11 |   "H",
12 |   "I",
13 |   "J",
14 |   "K",
15 |   "L",
16 |   "M",
17 |   "N",
18 |   "O",
19 |   "P",
20 |   "Q",
21 |   "R",
22 |   "S",
23 |   "T",
24 |   "U",
25 |   "V",
26 |   "W",
27 |   "X",
28 |   "Y",
29 |   "Z",
30 |   " "
31 | ]
--------------------------------------------------------------------------------
/audio_data/labels.json:
--------------------------------------------------------------------------------
1 | [
2 |   "_",
3 |   "'",
4 |   "A",
5 |   "B",
6 |   "C",
7 |   "D",
8 |   "E",
9 |   "F",
10 |   "G",
11 |   "H",
12 |   "I",
13 |   "J",
14 |   "K",
15 |   "L",
16 |   "M",
17 |   "N",
18 |   "O",
19 |   "P",
20 |   "Q",
21 |   "R",
22 |   "S",
23 |   "T",
24 |   "U",
25 |   "V",
26 |   "W",
27 |   "X",
28 |   "Y",
29 |   "Z",
30 |   " "
31 | ]
--------------------------------------------------------------------------------
/horovod_mpi.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | dnn="${dnn:-resnet20}"
3 | source exp_configs/$dnn.conf
4 | nworkers="${nworkers:-4}"
5 | nwpernode=4
6 | nstepsupdate=1
7 | PY=python
8 | mpirun -np $nworkers -hostfile cluster$nworkers -bind-to none -map-by slot \
9 |     -x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH \
10 |     -mca pml ob1 -mca btl ^openib \
11 |     -x NCCL_P2P_DISABLE=1 \
12 |     $PY horovod_trainer.py --dnn $dnn --dataset $dataset --max-epochs $max_epochs --batch-size $batch_size --nworkers $nworkers --data-dir $data_dir --lr $lr --nsteps-update $nstepsupdate --nwpernode $nwpernode
13 |
--------------------------------------------------------------------------------
/gtopk_mpi.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | dnn="${dnn:-resnet20}"
3 | density="${density:-0.001}"
4 | source exp_configs/$dnn.conf
5 | compressor="${compressor:-gtopk}"
6 | nworkers="${nworkers:-4}"
7 | nwpernode=4
8 | sigmascale=2.5
9 | PY=python
10 | mpirun --prefix $MPIPATH -np $nworkers -hostfile cluster$nworkers --bind-to none -map-by slot \
11 |     -x LD_LIBRARY_PATH \
12 |     $PY -m mpi4py gtopk_trainer.py --dnn $dnn --dataset $dataset --max-epochs $max_epochs --batch-size $batch_size --nworkers $nworkers --data-dir $data_dir --lr $lr --nwpernode $nwpernode --nsteps-update $nstepsupdate --compression --sigma-scale $sigmascale --density $density --compressor $compressor
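Because gtopk_mpi.sh reads its hyper-parameters from the environment in the same way, a sparsified multi-worker run needs only a one-line override. The MPIPATH prefix below is an assumed OpenMPI install location, not something the repo sets:

```
# Illustrative: 4 workers, VGG-16, density 0.001 (i.e. k = 0.001d), gtopk compressor.
export MPIPATH=/usr/local/openmpi
dnn=vgg16 density=0.001 nworkers=4 compressor=gtopk ./gtopk_mpi.sh
```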
| -------------------------------------------------------------------------------- /scripts/killp2p.sh: -------------------------------------------------------------------------------- 1 | #kill -9 `ps aux|grep 'python client.py' | awk '{print $2}'` 2 | #kill -9 `ps aux|grep 'python client_mp.py' | awk '{print $2}'` 3 | #kill -9 `ps aux|grep 'python psclient.py' | awk '{print $2}'` 4 | #kill -9 `ps aux|grep 'python dl_trainer.py' | awk '{print $2}'` 5 | kill -9 `ps aux|grep 'python -m mpi4py robust_trainer.py' | awk '{print $2}'` 6 | #kill -9 `ps aux|grep 'python robust_trainer.py' | awk '{print $2}'` 7 | kill -9 `ps aux|grep 'python horovod_trainer.py' | awk '{print $2}'` 8 | kill -9 `ps aux|grep 'python -m mpi4py horovod_trainer.py' | awk '{print $2}'` 9 | kill -9 `ps aux|grep 'python -m mpi4py hovorod_trainer.py' | awk '{print $2}'` 10 | -------------------------------------------------------------------------------- /scripts/ijcai2019/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | 4 | def update_fontsize(ax, fontsize=12.): 5 | for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] + 6 | ax.get_xticklabels() + ax.get_yticklabels()): 7 | item.set_fontsize(fontsize) 8 | 9 | def autolabel(rects, ax, label, rotation=90): 10 | """ 11 | Attach a text label above each bar displaying its height 12 | """ 13 | for rect in rects: 14 | height = rect.get_y() + rect.get_height() 15 | ax.text(rect.get_x() + rect.get_width()/2., 1.03*height, 16 | label, 17 | ha='center', va='bottom', rotation=rotation) 18 | -------------------------------------------------------------------------------- /model_builder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class MnistNet(nn.Module): 5 | def __init__(self): 6 | super(MnistNet, self).__init__() 7 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 8 | self.conv2 = nn.Conv2d(10, 20, 5) 9 | self.conv2_drop = nn.Dropout2d() 10 | self.fc1 = nn.Linear(320, 50) 11 | self.fc2 = nn.Linear(50, 10) 12 | self.name = 'mnistnet' 13 | 14 | def forward(self, x): 15 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 16 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 17 | x = x.view(-1, 320) 18 | x = F.relu(self.fc1(x)) 19 | x = F.dropout(x, training=self.training) 20 | x = self.fc2(x) 21 | return x 22 | 23 | 24 | def model_builder(dnn): 25 | pass 26 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import logging 4 | import socket 5 | 6 | DEBUG =0 7 | 8 | SPARSE=False 9 | WARMUP=True 10 | DELAY_COMM=1 11 | 12 | if SPARSE: 13 | PREFIX='compression' 14 | else: 15 | PREFIX='baseline' 16 | if WARMUP: 17 | PREFIX=PREFIX+'-gwarmup' 18 | 19 | PREFIX=PREFIX+'-dc'+str(DELAY_COMM) 20 | PREFIX=PREFIX+'-model'+'-ijcai2019' 21 | TENSORBOARD=True 22 | PROFILING_NORM=False 23 | 24 | hostname = socket.gethostname() 25 | logger = logging.getLogger(hostname) 26 | 27 | if DEBUG: 28 | logger.setLevel(logging.DEBUG) 29 | else: 30 | logger.setLevel(logging.INFO) 31 | 32 | strhdlr = logging.StreamHandler() 33 | logger.addHandler(strhdlr) 34 | formatter = logging.Formatter('%(asctime)s [%(filename)s:%(lineno)d] %(levelname)s %(message)s') 35 | strhdlr.setFormatter(formatter) 36 
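The settings.py module just above mainly exposes a process-wide `logger` together with its `formatter`. A minimal sketch of how other modules consume them, mirroring what evaluate.py does later in this listing (the log-file path is illustrative):

```
import logging
from settings import logger, formatter

# Attach a file handler next to the default stream handler.
hdlr = logging.FileHandler('/tmp/gtopk-example.log')  # hypothetical path
hdlr.setFormatter(formatter)
logger.addHandler(hdlr)
logger.info('worker started on %s', 'localhost')
```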
| 37 | -------------------------------------------------------------------------------- /scripts/test_read_hdf5.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import argparse, os 4 | import glob 5 | import h5py 6 | import numpy as np 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | 10 | OUTPUTDIR='/tmp/imagenet_hdf5' 11 | 12 | def test_read(): 13 | h5file = os.path.join(OUTPUTDIR, 'imagenet-shuffled.hdf5') 14 | with h5py.File(h5file, 'r') as hf: 15 | imgs = hf['train_img'][0:10, ...] 16 | labels = hf["train_labels"][0:10] 17 | img = imgs[3] 18 | #img = img.transpose((1, 2, 0)) #np.moveaxis(imgs[0], 2, 0) 19 | #img = np.moveaxis(img, 2, 0) 20 | #img = img[...,[2,0,1]] 21 | print('labels: ', labels) 22 | print('image shape: ', img.shape) 23 | #cv2.imshow('h', img) 24 | plt.imshow(img) 25 | #cv2.waitKey(0) 26 | plt.show() 27 | 28 | if __name__ == '__main__': 29 | test_read() 30 | -------------------------------------------------------------------------------- /models/res_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DownsampleA(nn.Module): 5 | 6 | def __init__(self, nIn, nOut, stride): 7 | super(DownsampleA, self).__init__() 8 | assert stride == 2 9 | self.avg = nn.AvgPool2d(kernel_size=1, stride=stride) 10 | 11 | def forward(self, x): 12 | x = self.avg(x) 13 | return torch.cat((x, x.mul(0)), 1) 14 | 15 | class DownsampleC(nn.Module): 16 | 17 | def __init__(self, nIn, nOut, stride): 18 | super(DownsampleC, self).__init__() 19 | assert stride != 1 or nIn != nOut 20 | self.conv = nn.Conv2d(nIn, nOut, kernel_size=1, stride=stride, padding=0, bias=False) 21 | 22 | def forward(self, x): 23 | x = self.conv(x) 24 | return x 25 | 26 | class DownsampleD(nn.Module): 27 | 28 | def __init__(self, nIn, nOut, stride): 29 | super(DownsampleD, self).__init__() 30 | assert stride == 2 31 | self.conv = nn.Conv2d(nIn, nOut, kernel_size=2, stride=stride, padding=0, bias=False) 32 | self.bn = nn.BatchNorm2d(nOut) 33 | 34 | def forward(self, x): 35 | x = self.conv(x) 36 | x = self.bn(x) 37 | return x 38 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """The models subpackage contains definitions for the following model 2 | architectures: 3 | - `ResNeXt` for CIFAR10 CIFAR100 4 | You can construct a model with random weights by calling its constructor: 5 | .. code:: python 6 | import models 7 | resnext29_16_64 = models.ResNeXt29_16_64(num_classes) 8 | resnext29_8_64 = models.ResNeXt29_8_64(num_classes) 9 | resnet20 = models.ResNet20(num_classes) 10 | resnet32 = models.ResNet32(num_classes) 11 | 12 | 13 | .. 
ResNext: https://arxiv.org/abs/1611.05431
14 | """
15 |
16 | from .resnext import resnext29_8_64, resnext29_16_64
17 | from .resnet import resnet20, resnet32, resnet44, resnet56, resnet110
18 | from .preresnet import preresnet20, preresnet32, preresnet44, preresnet56, preresnet110
19 | from .caffe_cifar import caffe_cifar
20 | from .densenet import densenet100_12
21 | from .resnet_mod import resnet_mod20, resnet_mod32, resnet_mod44, resnet_mod56, resnet_mod110
22 |
23 | from .imagenet_resnet import resnet18, resnet34, resnet50, resnet101, resnet152
24 | from .vgg import VGG
25 | from .alexnet import AlexNet
26 | from .lstman4 import create_net as LSTMAN4
27 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import time
3 | import os
4 | import numpy as np
5 |
6 |
7 | def gen_random_id():
8 |     id_ = hashlib.sha256()
9 |     id_.update(str(time.time()).encode('utf-8'))  # hashlib requires bytes under Python 3
10 |     return id_.hexdigest()
11 |
12 | def create_path(relative_path):
13 |     dirname = os.path.dirname(__file__)
14 |     filename = os.path.join(dirname, relative_path)
15 |     if not os.path.isdir(filename):
16 |         try:
17 |             #os.mkdir(filename)
18 |             os.makedirs(filename)
19 |         except OSError:  # the directory may already have been created by another worker
20 |             pass
21 |
22 | def update_fontsize(ax, fontsize=12.):
23 |     for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
24 |                  ax.get_xticklabels() + ax.get_yticklabels()):
25 |         item.set_fontsize(fontsize)
26 |
27 | def autolabel(rects, ax, label, rotation=90):
28 |     """
29 |     Attach a text label above each bar displaying its height
30 |     """
31 |     for rect in rects:
32 |         height = rect.get_y() + rect.get_height()
33 |         ax.text(rect.get_x() + rect.get_width()/2., 1.03*height,
34 |                 label,
35 |                 ha='center', va='bottom', rotation=rotation)
36 |
37 | def topk(tensor, k):
38 |     indexes = np.abs(tensor).argsort()[-k:]
39 |     return indexes, tensor[indexes]
40 |
41 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function
3 |
4 | import torch
5 | import h5py
6 | import numpy as np
7 |
8 | class DatasetHDF5(torch.utils.data.Dataset):
9 |     def __init__(self, hdf5fn, t, transform=None, target_transform=None):
10 |         """
11 |         t: 'train' or 'val'
12 |         """
13 |         super(DatasetHDF5, self).__init__()
14 |         self.hf = h5py.File(hdf5fn, 'r', libver='latest', swmr=True)
15 |         self.t = t
16 |         self.n_images = self.hf['%s_img' % self.t].shape[0]
17 |         self.dlabel = self.hf['%s_labels' % self.t][...]
18 |         self.d = self.hf['%s_img' % self.t]
19 |         self.transform = transform
20 |         self.target_transform = target_transform
21 |
22 |     def _get_dataset_x_and_target(self, index):
23 |         img = self.d[index, ...]
24 | target = self.dlabel[index] 25 | return img, np.int64(target) 26 | 27 | def __getitem__(self, index): 28 | img, target = self._get_dataset_x_and_target(index) 29 | if self.transform is not None: 30 | img = self.transform(img) 31 | if self.target_transform is not None: 32 | target = self.target_transform(target) 33 | return img, target 34 | 35 | def __len__(self): 36 | return self.n_images 37 | -------------------------------------------------------------------------------- /models/lstman4.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import os 4 | import json 5 | import models.lstm_models as lm 6 | 7 | 8 | def create_net(nb_labers=5, labels=None, rnn_type='lstm', bidirectional=False, datapath=None, hidden_size=800, hidden_layers=5, sample_rate=16000, window_size=0.02, window_stride=0.01, window='hamming', noise_dir=None, noise_prob=0.4, noise_min=0.0, noise_max=0.5): 9 | if datapath is None: 10 | datapath = '/home/comp/zhtang/p2p-sync/p2p-dl/audio_data' 11 | if labels is None: 12 | with open(os.path.join(datapath, 'labels.json')) as label_file: 13 | labels = str(''.join(json.load(label_file))) 14 | 15 | print(" ============= audio_conf preparing =================") 16 | audio_conf = dict(sample_rate=sample_rate, 17 | window_size=window_size, 18 | window_stride=window_stride, 19 | window=window, 20 | noise_dir=noise_dir, 21 | noise_prob=noise_prob, 22 | noise_levels=(noise_min, noise_max)) 23 | 24 | print(" ============= net preparing =================") 25 | net = lm.DeepSpeech(rnn_hidden_size=hidden_size, 26 | nb_layers=hidden_layers, 27 | labels=labels, 28 | rnn_type=lm.supported_rnns[rnn_type], 29 | audio_conf=audio_conf, 30 | bidirectional=bidirectional) 31 | ext = {'audio_conf': audio_conf, 32 | 'labels': labels} 33 | return net, ext 34 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | cfg = { 7 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | 14 | class VGG(nn.Module): 15 | def __init__(self, vgg_name): 16 | super(VGG, self).__init__() 17 | self.features = self._make_layers(cfg[vgg_name]) 18 | self.fc = nn.Linear(512, 10) 19 | 20 | def forward(self, x): 21 | out = self.features(x) 22 | out = out.view(out.size(0), -1) 23 | out = self.fc(out) 24 | return out 25 | 26 | def _make_layers(self, cfg): 27 | layers = [] 28 | in_channels = 3 29 | for x in cfg: 30 | if x == 'M': 31 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 32 | else: 33 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 34 | nn.BatchNorm2d(x), 35 | nn.ReLU(inplace=True)] 36 | in_channels = x 37 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 38 | return nn.Sequential(*layers) 39 | 40 | 41 | def test(): 42 | net = VGG('VGG11') 43 | x = torch.randn(2,3,32,32) 44 | y = net(x) 45 | print(y.size()) 46 | 47 | # test() 48 | -------------------------------------------------------------------------------- /audio_data/utils.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import fnmatch 4 | import io 5 | import os 6 | from tqdm import tqdm 7 | import subprocess 8 | import torch.distributed as dist 9 | 10 | 11 | def create_manifest(data_path, output_path, min_duration=None, max_duration=None): 12 | file_paths = [os.path.join(dirpath, f) 13 | for dirpath, dirnames, files in os.walk(data_path) 14 | for f in fnmatch.filter(files, '*.wav')] 15 | file_paths = order_and_prune_files(file_paths, min_duration, max_duration) 16 | with io.FileIO(output_path, "w") as file: 17 | for wav_path in tqdm(file_paths, total=len(file_paths)): 18 | transcript_path = wav_path.replace('/wav/', '/txt/').replace('.wav', '.txt') 19 | sample = os.path.abspath(wav_path) + ',' + os.path.abspath(transcript_path) + '\n' 20 | file.write(sample.encode('utf-8')) 21 | print('\n') 22 | 23 | 24 | def order_and_prune_files(file_paths, min_duration, max_duration): 25 | print("Sorting manifests...") 26 | duration_file_paths = [(path, float(subprocess.check_output( 27 | ['soxi -D \"%s\"' % path.strip()], shell=True))) for path in file_paths] 28 | if min_duration and max_duration: 29 | print("Pruning manifests between %d and %d seconds" % (min_duration, max_duration)) 30 | duration_file_paths = [(path, duration) for path, duration in duration_file_paths if 31 | min_duration <= duration <= max_duration] 32 | 33 | def func(element): 34 | return element[1] 35 | 36 | duration_file_paths.sort(key=func) 37 | return [x[0] for x in duration_file_paths] # Remove durations 38 | 39 | def reduce_tensor(tensor, world_size): 40 | rt = tensor.clone() 41 | dist.all_reduce(rt, op=dist.reduce_op.SUM) 42 | rt /= world_size 43 | return rt 44 | 45 | -------------------------------------------------------------------------------- /scripts/eval.sh: -------------------------------------------------------------------------------- 1 | #python evaluate.py --model-path weights/allreduce/resnet20-n4-bs32-lr0.1000 --dnn resnet20 --dataset cifar10 --nepochs 140 --data-dir ./data 2 | #python evaluate.py --model-path weights/comp-gtopk-baseline-gwarmup-dc1-model-debug2-ds0.001/resnet110-n4-bs128-lr0.1000 --dnn resnet110 --dataset cifar10 --nepochs 140 --data-dir ./data 3 | python evaluate.py --model-path weights/comp-gtopk-baseline-gwarmup-dc1-model-debug2-ds0.001/resnet110-n4-bs128-lr0.1000 --dnn resnet110 --dataset cifar10 --nepochs 140 --data-dir ./data 4 | #python evaluate.py --model-path weights/comp-gtopk-baseline-gwarmup-dc1-model-ijcai-wu1-ds0.001/resnet20-n4-bs32-lr0.1000 --dnn resnet20 --dataset cifar10 --nepochs 140 --data-dir ./data 5 | #python evaluate.py --model-path weights/allreduce/vgg16-n4-bs128-lr0.1000 --dnn vgg16 --dataset cifar10 --nepochs 140 --data-dir ./data 6 | #python evaluate.py --model-path weights/comp-gtopk-baseline-gwarmup-dc1-model-ijcai-wu1-ds0.001/vgg16-n4-bs128-lr0.1000 --dnn vgg16 --dataset cifar10 --nepochs 140 --data-dir ./data 7 | 8 | #python evaluate.py --model-path weights/comp-gtopk-baseline-gwarmup-dc1-model-ijcai-wu1-ds0.001/lstm-n4-bs32-lr1.0000 --dnn lstm --dataset ptb --nepochs 40 --data-dir /home/shshi/data/PennTreeBank 9 | #python evaluate.py --model-path weights/allreduce/lstm-n4-bs5-lr20.0000 --dnn lstm --dataset ptb --nepochs 40 --data-dir /home/shshi/data/PennTreeBank 10 | #python evaluate.py --model-path weights/allreduce/lstm-n4-bs5-lr20.0000 --dnn lstm --dataset ptb --nepochs 40 --data-dir /home/comp/csshshi/data/PennTreeBank 11 | 12 | #python 
evaluate.py --model-path weights/allreduce/lstman4-n4-bs8-lr0.0003 --dnn lstman4 --dataset an4 --nepochs 90 --data-dir /home/comp/csshshi/data/an4data 13 | #python evaluate.py --model-path weights/allreduce/lstman4-n4-bs32-lr0.0003 --dnn lstman4 --dataset an4 --nepochs 90 --data-dir /home/comp/csshshi/data/an4data 14 | -------------------------------------------------------------------------------- /models/caffe_cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | import math 8 | 9 | ## http://torch.ch/blog/2015/07/30/cifar.html 10 | class CifarCaffeNet(nn.Module): 11 | def __init__(self, num_classes): 12 | super(CifarCaffeNet, self).__init__() 13 | 14 | self.num_classes = num_classes 15 | 16 | self.block_1 = nn.Sequential( 17 | nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), 18 | nn.MaxPool2d(kernel_size=3, stride=2), 19 | nn.ReLU(), 20 | nn.BatchNorm2d(32)) 21 | 22 | self.block_2 = nn.Sequential( 23 | nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1), 24 | nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), 25 | nn.ReLU(), 26 | nn.AvgPool2d(kernel_size=3, stride=2), 27 | nn.BatchNorm2d(64)) 28 | 29 | self.block_3 = nn.Sequential( 30 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 31 | nn.Conv2d(64,128, kernel_size=3, stride=1, padding=1), 32 | nn.ReLU(), 33 | nn.AvgPool2d(kernel_size=3, stride=2), 34 | nn.BatchNorm2d(128)) 35 | 36 | self.classifier = nn.Linear(128*9, self.num_classes) 37 | 38 | for m in self.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 41 | m.weight.data.normal_(0, math.sqrt(2. / n)) 42 | elif isinstance(m, nn.BatchNorm2d): 43 | m.weight.data.fill_(1) 44 | m.bias.data.zero_() 45 | elif isinstance(m, nn.Linear): 46 | init.kaiming_normal(m.weight) 47 | m.bias.data.zero_() 48 | 49 | def forward(self, x): 50 | x = self.block_1.forward(x) 51 | x = self.block_2.forward(x) 52 | x = self.block_3.forward(x) 53 | x = x.view(x.size(0), -1) 54 | #print ('{}'.format(x.size())) 55 | return self.classifier(x) 56 | 57 | def caffe_cifar(num_classes=10): 58 | model = CifarCaffeNet(num_classes) 59 | return model 60 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## ginore this file ## 2 | *.pth 3 | *.log 4 | *.pyc 5 | *.swp 6 | *.swo 7 | *.pfm 8 | *_local.sh 9 | tmp 10 | data* 11 | weights 12 | .nfs* 13 | runs 14 | *.ipynb 15 | *.npy 16 | bk/ 17 | *.npy* 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
49 | *.manifest
50 | *.spec
51 |
52 | # Installer logs
53 | pip-log.txt
54 | pip-delete-this-directory.txt
55 |
56 | # Unit test / coverage reports
57 | htmlcov/
58 | .tox/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *.cover
65 | .hypothesis/
66 | .pytest_cache/
67 |
68 | # Translations
69 | *.mo
70 | *.pot
71 |
72 | # Django stuff:
73 | *.log
74 | local_settings.py
75 | db.sqlite3
76 |
77 | # Flask stuff:
78 | instance/
79 | .webassets-cache
80 |
81 | # Scrapy stuff:
82 | .scrapy
83 |
84 | # Sphinx documentation
85 | docs/_build/
86 |
87 | # PyBuilder
88 | target/
89 |
90 | # Jupyter Notebook
91 | .ipynb_checkpoints
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # celery beat schedule file
97 | celerybeat-schedule
98 |
99 | # SageMath parsed files
100 | *.sage.py
101 |
102 | # Environments
103 | .env
104 | .venv
105 | env/
106 | venv/
107 | ENV/
108 | env.bak/
109 | venv.bak/
110 |
111 | # Spyder project settings
112 | .spyderproject
113 | .spyproject
114 |
115 | # Rope project settings
116 | .ropeproject
117 |
118 | # mkdocs documentation
119 | /site
120 |
121 | # mypy
122 | .mypy_cache/
123 |
--------------------------------------------------------------------------------
/models/lstm.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.autograd import Variable
3 | from torch import Tensor
4 |
5 | class lstm(nn.Module):
6 |     def __init__(self, vocab_size, embedding_dim=1500, num_steps=35, batch_size=20, num_layers=2, dp_keep_prob=0.35):
7 |         super(lstm, self).__init__()
8 |         self.embedding_dim = embedding_dim
9 |         self.num_steps = num_steps
10 |         self.batch_size = batch_size
11 |         self.vocab_size = vocab_size
12 |         self.dp_keep_prob = dp_keep_prob
13 |         self.num_layers = num_layers
14 |         self.dropout = nn.Dropout(1 - dp_keep_prob)
15 |         self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
16 |         self.lstm = nn.LSTM(input_size=embedding_dim,
17 |                             hidden_size=embedding_dim,
18 |                             num_layers=num_layers,
19 |                             dropout=1 - dp_keep_prob)
20 |         self.sm_fc = nn.Linear(in_features=embedding_dim,
21 |                                out_features=vocab_size)
22 |         self.init_weights()
23 |
24 |     def init_weights(self):
25 |         init_range = 0.1
26 |         self.word_embeddings.weight.data.uniform_(-init_range, init_range)
27 |         self.sm_fc.bias.data.fill_(0.0)
28 |         self.sm_fc.weight.data.uniform_(-init_range, init_range)
29 |
30 |     def init_hidden(self):
31 |         weight = next(self.parameters()).data
32 |         return (Variable(weight.new(self.num_layers, self.batch_size, self.embedding_dim).zero_()),
33 |                 Variable(weight.new(self.num_layers, self.batch_size, self.embedding_dim).zero_()))
34 |
35 |     def forward(self, inputs, hidden):
36 |         embeds = self.dropout(self.word_embeddings(inputs))
37 |         lstm_out, hidden = self.lstm(embeds, hidden)
38 |         lstm_out = self.dropout(lstm_out)
39 |         logits = self.sm_fc(lstm_out.view(-1, self.embedding_dim))
40 |         return logits.view(self.num_steps, self.batch_size, self.vocab_size), hidden
41 |
42 | def repackage_hidden(h):
43 |     """Wraps hidden states in new Variables, to detach them from their history."""
44 |     if isinstance(h, Variable):
45 |         return h.detach()
46 |     else:
47 |         return tuple(repackage_hidden(v) for v in h)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # gTop-*k* S-SGD
2 | ## Introduction
3 | This repository contains the code of the gTop-k S-SGD (Synchronous Stochastic Gradient Descent) papers that appeared at *ICDCS 2019* (an empirical study) and *IJCAI 2019* (a theoretical study). gTop-k S-SGD is a communication-efficient distributed training algorithm for deep learning. The key idea of gTop-k is that each worker only sends/receives its top-k gradient elements (k can be as small as 0.1% of the gradient dimension d, i.e., k = 0.001d) along a tree structure (recursive doubling), so that the communication complexity is O(k logP), where P is the number of workers. The convergence of gTop-k S-SGD is provable under some weak analytical assumptions. The communication complexity comparison with traditional ring-based all-reduce (Dense) and Top-k sparsification is shown as follows:
4 |
5 | | S-SGD | Complexity | Time Cost |
6 | | ------------- |:-------------:| -----:|
7 | | Dense | O(d) | 2α(P-1) + 2(P-1)dβ/P |
8 | | Top-k | O(kP) | αlogP + 2(P-1)kβ |
9 | | **gTop-k** | **O(k logP)** | **αlogP + 2klogPβ** |
10 |
11 | For more details about the algorithm, please refer to our papers.
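To make the tree-based selection above concrete, here is a minimal NumPy sketch of one merge step of gTop-k: two workers exchange their local top-k (index, value) pairs, contributions landing on the same index are summed, and only the global top-k of the merged result is kept. The local selection mirrors `topk()` in this repo's utils.py; the merge itself is an illustrative sketch of the idea, not the MPI implementation in gtopk_trainer.py.

```
import numpy as np

def local_topk(grad, k):
    # Same selection rule as utils.topk: the k largest-magnitude entries.
    indexes = np.abs(grad).argsort()[-k:]
    return indexes, grad[indexes]

def gtopk_merge(idx_a, val_a, idx_b, val_b, k):
    # Sum values that fall on the same gradient index, then re-select top-k,
    # so the result stays k-sparse after every merge step of the binary tree.
    merged = {}
    for i, v in zip(np.concatenate([idx_a, idx_b]), np.concatenate([val_a, val_b])):
        merged[i] = merged.get(i, 0.0) + v
    idx = np.fromiter(merged.keys(), dtype=np.int64)
    val = np.fromiter(merged.values(), dtype=np.float64)
    keep = np.abs(val).argsort()[-k:]
    return idx[keep], val[keep]

# Toy example: d = 10, k = 2, two "workers".
g1, g2 = np.random.randn(10), np.random.randn(10)
(ia, va), (ib, vb) = local_topk(g1, 2), local_topk(g2, 2)
print(gtopk_merge(ia, va, ib, vb, 2))
```

Repeating this pairwise merge along a recursive-doubling tree takes logP rounds, which is where the O(k logP) entry in the table comes from.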
12 |
13 | ## Installation
14 | ### Prerequisites
15 | - Python 2 or 3
16 | - PyTorch-0.4.+
17 | - [OpenMPI-3.1.+](https://www.open-mpi.org/software/ompi/v3.1/)
18 | - [Horovod-0.14.+](https://github.com/horovod/horovod): optional if you do not run the dense version
19 | ### Quick Start
20 | ```
21 | git clone https://github.com/hclhkbu/gtopkssgd.git
22 | cd gtopkssgd
23 | pip install -r requirements.txt
24 | dnn=resnet20 nworkers=4 ./gtopk_mpi.sh
25 | ```
26 | Assuming you have 4 GPUs on a single node and everything works well, you will see 4 workers running on that node, training the ResNet-20 model on the CIFAR-10 dataset with the gTop-k S-SGD algorithm.
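For comparison with the Dense row of the complexity table above, the dense S-SGD baseline can be launched the same way through horovod_mpi.sh (shown earlier in this listing); this assumes Horovod is installed and the `cluster4` hostfile matches your machines:

```
# Illustrative: dense ring all-reduce baseline on the same 4 workers.
dnn=resnet20 nworkers=4 ./horovod_mpi.sh
```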
27 | ## Papers
28 | - S. Shi, Q. Wang, K. Zhao, Z. Tang, Y. Wang, X. Huang, and X.-W. Chu, “A Distributed Synchronous SGD Algorithm with Global Top-k Sparsification for Low Bandwidth Networks,” *IEEE ICDCS 2019*, Dallas, Texas, USA, July 2019. [PDF](https://arxiv.org/pdf/1901.04359.pdf)
29 | - S. Shi, K. Zhao, Q. Wang, Z. Tang, and X.-W. Chu, “A Convergence Analysis of Distributed SGD with Communication-Efficient Gradient Sparsification,” *IJCAI 2019*, Macau, P.R.C., August 2019. [PDF](https://www.ijcai.org/proceedings/2019/0473.pdf)
30 | ## Referred Models
31 | - Deep speech: [https://github.com/SeanNaren/deepspeech.pytorch](https://github.com/SeanNaren/deepspeech.pytorch)
32 | - PyTorch examples: [https://github.com/pytorch/examples](https://github.com/pytorch/examples)
33 |
--------------------------------------------------------------------------------
/scripts/icdcs2019/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 | def read_log(filename):
5 |     f = open(filename, 'r')
6 |     sizes = []
7 |     computes = []
8 |     comms = []
9 |     merged_comms = []
10 |     for l in f.readlines():
11 |         items = l.split('[')[1][0:-2].split(',')
12 |         items = [float(it.strip()) for it in items]
13 |         if int(items[2]) == 0 or int(items[3]) == 0:# or int(items[1]) > 1000000:
14 |             continue
15 |         #sizes.append(float(items[1])*4)
16 |         sizes.append(float(items[1]))
17 |         computes.append(items[2])
18 |         #comms.append(items[3])
19 |         comms.append(float(items[4]))
20 |         merged_comms.append(items[4])
21 |     f.close()
22 |     #print('filename: ', filename)
23 |     #print('sizes: ', sizes)
24 |     #print('total sizes: ', np.sum(sizes))
25 |     #print('sizes len: ', len(sizes))
26 |     #print('computes: ', computes)
27 |     #print('communications: ', comms)
28 |     return sizes, comms, computes, merged_comms
29 |
30 | def read_p100_log(filename):
31 |     f = open(filename, 'r')
32 |     computes = []
33 |     sizes = []
34 |     for l in f.readlines():
35 |         items = l.split(',')
36 |         sizes.append(float(items[-2]))
37 |         computes.append(float(items[-1]))
38 |         #comms.append(items[3])
39 |     # remove duplicate
40 |     reals = []
41 |     realc = []
42 |     pre = -1
43 |
44 |     for i, comp in enumerate(computes):
45 |         if pre != comp:
46 |             reals.append(sizes[i])
47 |             realc.append(comp)
48 |         else:
49 |             reals[-1] += sizes[i]
50 |         pre = comp
51 |     f.close()
52 |     return sizes, realc
53 |
54 |
55 | def plot_hist(d):
56 |     d = np.array(d)
57 |     flatten = d.ravel()
58 |     mean = np.mean(flatten)
59 |     std = np.std(flatten)
60 |     count, bins, ignored = plt.hist(flatten, 100, normed=True)
61 |     print('mean: %.3f, std: %.3f' % (mean, std))  # print() so this also runs under Python 3
62 |     n_neg = flatten[np.where(flatten<=0.0)].size
63 |     print('# of zero: %d' % n_neg)
64 |     print('# of total: %d' % flatten.size)
65 |     #return n_neg, flatten.size # return #negative, total
66 |     plt.ylabel('Probability')
67 |     plt.xlabel('Nodule Size')
68 |     return flatten
69 |
70 |
71 | def update_fontsize(ax, fontsize=12.):
72 |     for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
73 |                  ax.get_xticklabels() + ax.get_yticklabels()):
74 |         item.set_fontsize(fontsize)
75 |
76 | def autolabel(rects, ax, label, rotation=90):
77 |     """
78 |     Attach a text label above each bar displaying its height
79 |     """
80 |     for rect in rects:
81 |         height = rect.get_y() + rect.get_height()
82 |         ax.text(rect.get_x() + rect.get_width()/2., 1.03*height,
83 |                 label,
84 |                 ha='center', va='bottom', rotation=rotation)
85 |
--------------------------------------------------------------------------------
/audio_data/distributed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
3 | import torch.distributed as dist
4 | from torch.nn.modules import Module
5 |
6 | '''
7 | This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py
It assumes that your run is using multiprocess with 1 9 | GPU/process, that the model is on the correct device, and that torch.set_device has been 10 | used to set the device. 11 | 12 | Parameters are broadcasted to the other processes on initialization of DistributedDataParallel, 13 | and will be allreduced at the finish of the backward pass. 14 | ''' 15 | 16 | 17 | class DistributedDataParallel(Module): 18 | def __init__(self, module): 19 | super(DistributedDataParallel, self).__init__() 20 | self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False 21 | 22 | self.module = module 23 | 24 | for p in self.module.state_dict().values(): 25 | if not torch.is_tensor(p): 26 | continue 27 | if dist._backend == dist.dist_backend.NCCL: 28 | assert p.is_cuda, "NCCL backend only supports model parameters to be on GPU." 29 | dist.broadcast(p, 0) 30 | 31 | def allreduce_params(): 32 | if (self.needs_reduction): 33 | self.needs_reduction = False 34 | buckets = {} 35 | for param in self.module.parameters(): 36 | if param.requires_grad and param.grad is not None: 37 | tp = type(param.data) 38 | if tp not in buckets: 39 | buckets[tp] = [] 40 | buckets[tp].append(param) 41 | if self.warn_on_half: 42 | if torch.cuda.HalfTensor in buckets: 43 | print("WARNING: gloo dist backend for half parameters may be extremely slow." + 44 | " It is recommended to use the NCCL backend in this case.") 45 | self.warn_on_half = False 46 | 47 | for tp in buckets: 48 | bucket = buckets[tp] 49 | grads = [param.grad.data for param in bucket] 50 | coalesced = _flatten_dense_tensors(grads) 51 | dist.all_reduce(coalesced) 52 | coalesced /= dist.get_world_size() 53 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 54 | buf.copy_(synced) 55 | 56 | for param in list(self.module.parameters()): 57 | def allreduce_hook(*unused): 58 | param._execution_engine.queue_callback(allreduce_params) 59 | 60 | if param.requires_grad: 61 | param.register_hook(allreduce_hook) 62 | 63 | def forward(self, *inputs, **kwargs): 64 | self.needs_reduction = True 65 | return self.module(*inputs, **kwargs) 66 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import os 4 | import logging 5 | import torch 6 | from dl_trainer import DLTrainer 7 | import argparse 8 | from settings import logger, formatter 9 | 10 | def model_average(trainers): 11 | trainer = trainers[0] 12 | own_state = trainer.net.state_dict() 13 | for name, param in own_state.items(): 14 | for t in trainers[1:]: 15 | own_state[name] = own_state[name]+t.net.state_dict()[name] 16 | for name, param in own_state.items(): 17 | own_state[name] = own_state[name]/len(trainers) 18 | trainer.net.load_state_dict(own_state) 19 | 20 | def evaluate(model_path, dnn, dataset, data_dir, nepochs, allreduce=False): 21 | items = model_path.split('/')[-1].split('-') 22 | dnn = items[0] 23 | lr = float(items[-1][2:]) 24 | batch_size = int(items[2][2:]) 25 | #batch_size = 1 #int(items[2][2:]) 26 | rank = 0 27 | nworkers=1 28 | 29 | trainer = DLTrainer(rank, 1, dist=False, ngpus=1, batch_size=batch_size, is_weak_scaling=True, dataset=dataset, dnn=dnn, data_dir=data_dir, lr=lr, nworkers=nworkers) 30 | best_acc = 0.0 31 | start_epoch = 1 32 | for i in range(start_epoch, nepochs+1): 33 | filename = '%s-rank%d-epoch%d.pth' % (dnn, rank, i) 34 | fn = os.path.join(model_path, 
filename) 35 | if i == nepochs and not allreduce and False: 36 | trainers = [] 37 | for j in range(nworkers): 38 | filename = '%s-rank%d-epoch%d.pth' % (dnn, j, i) 39 | fn = os.path.join(model_path, filename) 40 | tr = DLTrainer(rank, 1, dist=False, ngpus=1, batch_size=batch_size, is_weak_scaling=True, dataset=dataset, dnn=dnn, data_dir=data_dir, lr=lr, nworkers=nworkers) 41 | tr.load_model_from_file(fn) 42 | trainers.append(tr) 43 | model_average(trainers) 44 | trainer = trainers[0] 45 | else: 46 | trainer.load_model_from_file(fn) 47 | acc = trainer.test(i) 48 | if i == start_epoch: 49 | best_acc = acc 50 | else: 51 | if dnn in ['lstm', 'lstman4']: # the lower the better 52 | if best_acc > acc: 53 | best_acc = acc 54 | else: 55 | if best_acc < acc: 56 | best_acc = acc 57 | logger.info('Best validation accuracy or perprexity: %f', best_acc) 58 | 59 | 60 | if __name__ == '__main__': 61 | parser = argparse.ArgumentParser(description="p2pdl model evaluater") 62 | parser.add_argument('--model-path', type=str, help='Saved model path') 63 | parser.add_argument('--dnn', type=str, default='resnet20') 64 | parser.add_argument('--dataset', type=str, default='cifar10') 65 | parser.add_argument('--data-dir', type=str, default='./data', help='Specify the data root path') 66 | parser.add_argument('--nepochs', type=int, default=90, help='Number of epochs to evaluate') 67 | args = parser.parse_args() 68 | logfile = '%s/evaluate.log' % args.model_path 69 | hdlr = logging.FileHandler(logfile) 70 | hdlr.setFormatter(formatter) 71 | logger.addHandler(hdlr) 72 | evaluate(args.model_path, args.dnn, args.dataset, args.data_dir, args.nepochs) 73 | 74 | 75 | -------------------------------------------------------------------------------- /models/alexnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = ['AlexNet', 'alexnet'] 8 | 9 | class LRN(nn.Module): 10 | def __init__(self, local_size=1, alpha=1.0, beta=0.75, ACROSS_CHANNELS=True): 11 | super(LRN, self).__init__() 12 | self.ACROSS_CHANNELS = ACROSS_CHANNELS 13 | if ACROSS_CHANNELS: 14 | self.average=nn.AvgPool3d(kernel_size=(local_size, 1, 1), 15 | stride=1, 16 | padding=(int((local_size-1.0)/2), 0, 0)) 17 | else: 18 | self.average=nn.AvgPool2d(kernel_size=local_size, 19 | stride=1, 20 | padding=int((local_size-1.0)/2)) 21 | self.alpha = alpha 22 | self.beta = beta 23 | 24 | 25 | def forward(self, x): 26 | if self.ACROSS_CHANNELS: 27 | div = x.pow(2).unsqueeze(1) 28 | div = self.average(div).squeeze(1) 29 | div = div.mul(self.alpha).add(1.0).pow(self.beta) 30 | else: 31 | div = x.pow(2) 32 | div = self.average(div) 33 | div = div.mul(self.alpha).add(1.0).pow(self.beta) 34 | x = x.div(div) 35 | return x 36 | 37 | class AlexNet(nn.Module): 38 | 39 | def __init__(self, num_classes=1000): 40 | super(AlexNet, self).__init__() 41 | self.features = nn.Sequential( 42 | nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0), 43 | nn.ReLU(inplace=True), 44 | LRN(local_size=5, alpha=0.0001, beta=0.75), 45 | nn.MaxPool2d(kernel_size=3, stride=2), 46 | nn.Conv2d(96, 256, kernel_size=5, padding=2, groups=2), 47 | nn.ReLU(inplace=True), 48 | LRN(local_size=5, alpha=0.0001, beta=0.75), 49 | nn.MaxPool2d(kernel_size=3, stride=2), 50 | nn.Conv2d(256, 384, kernel_size=3, padding=1), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(384, 384, kernel_size=3, padding=1, groups=2), 53 | nn.ReLU(inplace=True), 54 | nn.Conv2d(384, 256, 
kernel_size=3, padding=1, groups=2), 55 | nn.ReLU(inplace=True), 56 | nn.MaxPool2d(kernel_size=3, stride=2), 57 | ) 58 | self.classifier = nn.Sequential( 59 | nn.Linear(256 * 6 * 6, 4096), 60 | nn.ReLU(inplace=True), 61 | nn.Dropout(), 62 | nn.Linear(4096, 4096), 63 | nn.ReLU(inplace=True), 64 | nn.Dropout(), 65 | nn.Linear(4096, num_classes), 66 | ) 67 | 68 | def forward(self, x): 69 | x = self.features(x) 70 | x = x.view(x.size(0), 256 * 6 * 6) 71 | x = self.classifier(x) 72 | return x 73 | 74 | 75 | def alexnet(pretrained=False, **kwargs): 76 | r"""AlexNet model architecture from the 77 | `"One weird trick..." `_ paper. 78 | 79 | Args: 80 | pretrained (bool): If True, returns a model pre-trained on ImageNet 81 | """ 82 | model = AlexNet(**kwargs) 83 | if pretrained: 84 | model_path = 'model_list/alexnet.pth.tar' 85 | pretrained_model = torch.load(model_path) 86 | model.load_state_dict(pretrained_model['state_dict']) 87 | return model 88 | -------------------------------------------------------------------------------- /compression.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | import time 6 | 7 | 8 | class NoneCompressor(): 9 | @staticmethod 10 | def compress(tensor, name=None): 11 | return tensor, tensor.dtype 12 | 13 | @staticmethod 14 | def decompress(tensor, ctc, name=None): 15 | z = tensor 16 | return z 17 | 18 | 19 | class TopKCompressor(): 20 | """ 21 | Sparse Communication for Distributed Gradient Descent, Alham Fikri Aji et al., 2017 22 | """ 23 | residuals = {} 24 | c = 0 25 | sparsities = [] 26 | t = 0. 27 | zero_conditions = {} 28 | values = {} 29 | indexes = {} 30 | name = 'topk' 31 | @staticmethod 32 | def compress(tensor, name=None, sigma_scale=2.5, ratio=0.05): 33 | start = time.time() 34 | with torch.no_grad(): 35 | if name not in TopKCompressor.residuals: 36 | TopKCompressor.residuals[name] = torch.zeros_like(tensor.data) 37 | # top-k solution 38 | numel = tensor.numel() 39 | k = max(int(numel * ratio), 1) 40 | 41 | tensor.data.add_(TopKCompressor.residuals[name].data) 42 | 43 | values, indexes = torch.topk(torch.abs(tensor.data), k=k) 44 | values = tensor.data[indexes] 45 | if name not in TopKCompressor.zero_conditions: 46 | TopKCompressor.zero_conditions[name] = torch.ones(numel, dtype=torch.float32, device=tensor.device) 47 | zero_condition = TopKCompressor.zero_conditions[name] 48 | zero_condition.fill_(1.0) 49 | zero_condition[indexes] = 0.0 50 | 51 | TopKCompressor.residuals[name].data.fill_(0.) 
52 | TopKCompressor.residuals[name].data = tensor.data * zero_condition 53 | tensor.data.sub_(TopKCompressor.residuals[name].data) 54 | 55 | TopKCompressor.values[name] = values 56 | TopKCompressor.indexes[name] = indexes 57 | return tensor, indexes 58 | 59 | @staticmethod 60 | def get_residuals(name, like_tensor): 61 | if name not in TopKCompressor.residuals: 62 | TopKCompressor.residuals[name] = torch.zeros_like(like_tensor.data) 63 | return TopKCompressor.residuals[name] 64 | 65 | @staticmethod 66 | def add_residuals(included_indexes, name): 67 | with torch.no_grad(): 68 | residuals = TopKCompressor.residuals[name] 69 | if type(included_indexes) is np.ndarray: 70 | indexes_t = torch.from_numpy(included_indexes).cuda(residuals.device).long() 71 | else: 72 | indexes_t = included_indexes 73 | values = TopKCompressor.values[name] 74 | values.data[indexes_t] = 0.0 75 | residuals.data[TopKCompressor.indexes[name]] += values.data 76 | 77 | @staticmethod 78 | def decompress(tensor, ctc, name=None): 79 | z = tensor 80 | return z 81 | 82 | class TopKCompressor2(TopKCompressor): 83 | name = 'topk2' 84 | 85 | class gTopKCompressor(TopKCompressor): 86 | name = 'gtopk' 87 | 88 | 89 | compressors = { 90 | 'topk': TopKCompressor, 91 | 'topk2': TopKCompressor2, 92 | 'gtopk': gTopKCompressor, 93 | 'none': NoneCompressor 94 | } 95 | -------------------------------------------------------------------------------- /audio_data/an4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import io 4 | import shutil 5 | import tarfile 6 | import wget 7 | 8 | from utils import create_manifest 9 | 10 | parser = argparse.ArgumentParser(description='Processes and downloads an4.') 11 | parser.add_argument('--target-dir', default='an4_dataset/', help='Path to save dataset') 12 | parser.add_argument('--min-duration', default=1, type=int, 13 | help='Prunes training samples shorter than the min duration (given in seconds, default 1)') 14 | parser.add_argument('--max-duration', default=15, type=int, 15 | help='Prunes training samples longer than the max duration (given in seconds, default 15)') 16 | args = parser.parse_args() 17 | 18 | 19 | def _format_data(root_path, data_tag, name, wav_folder): 20 | data_path = args.target_dir + data_tag + '/' + name + '/' 21 | new_transcript_path = data_path + '/txt/' 22 | new_wav_path = data_path + '/wav/' 23 | 24 | os.makedirs(new_transcript_path) 25 | os.makedirs(new_wav_path) 26 | 27 | wav_path = root_path + 'wav/' 28 | file_ids = root_path + 'etc/an4_%s.fileids' % data_tag 29 | transcripts = root_path + 'etc/an4_%s.transcription' % data_tag 30 | train_path = wav_path + wav_folder 31 | 32 | _convert_audio_to_wav(train_path) 33 | _format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path) 34 | 35 | 36 | def _convert_audio_to_wav(train_path): 37 | with os.popen('find %s -type f -name "*.raw"' % train_path) as pipe: 38 | for line in pipe: 39 | raw_path = line.strip() 40 | new_path = line.replace('.raw', '.wav').strip() 41 | cmd = 'sox -t raw -r %d -b 16 -e signed-integer -B -c 1 \"%s\" \"%s\"' % ( 42 | 16000, raw_path, new_path) 43 | os.system(cmd) 44 | 45 | 46 | def _format_files(file_ids, new_transcript_path, new_wav_path, transcripts, wav_path): 47 | with open(file_ids, 'r') as f: 48 | with open(transcripts, 'r') as t: 49 | paths = f.readlines() 50 | transcripts = t.readlines() 51 | for x in range(len(paths)): 52 | path = wav_path + paths[x].strip() + '.wav' 53 | filename = path.split('/')[-1] 
54 | extracted_transcript = _process_transcript(transcripts, x) 55 | current_path = os.path.abspath(path) 56 | new_path = new_wav_path + filename 57 | text_path = new_transcript_path + filename.replace('.wav', '.txt') 58 | with io.FileIO(text_path, "w") as file: 59 | file.write(extracted_transcript.encode('utf-8')) 60 | os.rename(current_path, new_path) 61 | 62 | 63 | def _process_transcript(transcripts, x): 64 | extracted_transcript = transcripts[x].split('(')[0].strip("").split('<')[0].strip().upper() 65 | return extracted_transcript 66 | 67 | 68 | def main(): 69 | root_path = 'an4/' 70 | name = 'an4' 71 | wget.download('http://www.speech.cs.cmu.edu/databases/an4/an4_raw.bigendian.tar.gz') 72 | tar = tarfile.open('an4_raw.bigendian.tar.gz') 73 | tar.extractall() 74 | os.makedirs(args.target_dir) 75 | _format_data(root_path, 'train', name, 'an4_clstk') 76 | _format_data(root_path, 'test', name, 'an4test_clstk') 77 | shutil.rmtree(root_path) 78 | os.remove('an4_raw.bigendian.tar.gz') 79 | train_path = args.target_dir + '/train/' 80 | test_path = args.target_dir + '/test/' 81 | print ('\n', 'Creating manifests...') 82 | create_manifest(train_path, 'an4_train_manifest.csv', args.min_duration, args.max_duration) 83 | create_manifest(test_path, 'an4_val_manifest.csv') 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | import math, torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class Bottleneck(nn.Module): 6 | def __init__(self, nChannels, growthRate): 7 | super(Bottleneck, self).__init__() 8 | interChannels = 4*growthRate 9 | self.bn1 = nn.BatchNorm2d(nChannels) 10 | self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False) 11 | self.bn2 = nn.BatchNorm2d(interChannels) 12 | self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False) 13 | 14 | def forward(self, x): 15 | out = self.conv1(F.relu(self.bn1(x))) 16 | out = self.conv2(F.relu(self.bn2(out))) 17 | out = torch.cat((x, out), 1) 18 | return out 19 | 20 | class SingleLayer(nn.Module): 21 | def __init__(self, nChannels, growthRate): 22 | super(SingleLayer, self).__init__() 23 | self.bn1 = nn.BatchNorm2d(nChannels) 24 | self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False) 25 | 26 | def forward(self, x): 27 | out = self.conv1(F.relu(self.bn1(x))) 28 | out = torch.cat((x, out), 1) 29 | return out 30 | 31 | class Transition(nn.Module): 32 | def __init__(self, nChannels, nOutChannels): 33 | super(Transition, self).__init__() 34 | self.bn1 = nn.BatchNorm2d(nChannels) 35 | self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False) 36 | 37 | def forward(self, x): 38 | out = self.conv1(F.relu(self.bn1(x))) 39 | out = F.avg_pool2d(out, 2) 40 | return out 41 | 42 | class DenseNet(nn.Module): 43 | def __init__(self, growthRate, depth, reduction, nClasses, bottleneck): 44 | super(DenseNet, self).__init__() 45 | 46 | if bottleneck: nDenseBlocks = int( (depth-4) / 6 ) 47 | else : nDenseBlocks = int( (depth-4) / 3 ) 48 | 49 | nChannels = 2*growthRate 50 | self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, bias=False) 51 | 52 | self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 53 | nChannels += nDenseBlocks*growthRate 54 | nOutChannels = int(math.floor(nChannels*reduction)) 55 | self.trans1 = 
Transition(nChannels, nOutChannels) 56 | 57 | nChannels = nOutChannels 58 | self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 59 | nChannels += nDenseBlocks*growthRate 60 | nOutChannels = int(math.floor(nChannels*reduction)) 61 | self.trans2 = Transition(nChannels, nOutChannels) 62 | 63 | nChannels = nOutChannels 64 | self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) 65 | nChannels += nDenseBlocks*growthRate 66 | 67 | self.bn1 = nn.BatchNorm2d(nChannels) 68 | self.fc = nn.Linear(nChannels, nClasses) 69 | 70 | for m in self.modules(): 71 | if isinstance(m, nn.Conv2d): 72 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 73 | m.weight.data.normal_(0, math.sqrt(2. / n)) 74 | elif isinstance(m, nn.BatchNorm2d): 75 | m.weight.data.fill_(1) 76 | m.bias.data.zero_() 77 | elif isinstance(m, nn.Linear): 78 | m.bias.data.zero_() 79 | 80 | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): 81 | layers = [] 82 | for i in range(int(nDenseBlocks)): 83 | if bottleneck: 84 | layers.append(Bottleneck(nChannels, growthRate)) 85 | else: 86 | layers.append(SingleLayer(nChannels, growthRate)) 87 | nChannels += growthRate 88 | return nn.Sequential(*layers) 89 | 90 | def forward(self, x): 91 | out = self.conv1(x) 92 | out = self.trans1(self.dense1(out)) 93 | out = self.trans2(self.dense2(out)) 94 | out = self.dense3(out) 95 | out = torch.squeeze(F.avg_pool2d(F.relu(self.bn1(out)), 8)) 96 | out = F.log_softmax(self.fc(out)) 97 | return out 98 | 99 | def densenet100_12(num_classes=10): 100 | model = DenseNet(12, 100, 0.5, num_classes, False) 101 | return model 102 | -------------------------------------------------------------------------------- /ptb_reader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from torch.utils.data import Dataset 5 | import collections 6 | import os 7 | import numpy as np 8 | 9 | def _read_words(filename): 10 | with open(filename, "r") as f: 11 | return f.read().replace("\n", "").split() 12 | 13 | 14 | def _build_vocab(filename): 15 | data = _read_words(filename) 16 | 17 | counter = collections.Counter(data) 18 | count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0])) 19 | 20 | words, _ = list(zip(*count_pairs)) 21 | word_to_id = dict(zip(words, range(len(words)))) 22 | id_to_word = dict((v, k) for k, v in word_to_id.items()) 23 | 24 | return word_to_id, id_to_word 25 | 26 | 27 | def _file_to_word_ids(filename, word_to_id): 28 | data = _read_words(filename) 29 | return [word_to_id[word] for word in data if word in word_to_id] 30 | 31 | 32 | def ptb_raw_data(data_path=None, prefix="ptb"): 33 | """Load PTB raw data from data directory "data_path". 34 | Reads PTB text files, converts strings to integer ids, 35 | and performs mini-batching of the inputs. 36 | The PTB dataset comes from Tomas Mikolov's webpage: 37 | http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz 38 | Args: 39 | data_path: string path to the directory where simple-examples.tgz has 40 | been extracted. 41 | Returns: 42 | tuple (train_data, valid_data, test_data, vocabulary) 43 | where each of the data objects can be passed to PTBIterator. 
44 | """ 45 | 46 | train_path = os.path.join(data_path, prefix + ".train.txt") 47 | valid_path = os.path.join(data_path, prefix + ".valid.txt") 48 | test_path = os.path.join(data_path, prefix + ".test.txt") 49 | 50 | word_to_id, id_2_word = _build_vocab(train_path) 51 | train_data = _file_to_word_ids(train_path, word_to_id) 52 | valid_data = _file_to_word_ids(valid_path, word_to_id) 53 | test_data = _file_to_word_ids(test_path, word_to_id) 54 | return train_data, valid_data, test_data, word_to_id, id_2_word 55 | 56 | class TrainDataset(Dataset): 57 | def __init__(self, raw_data, batch_size, num_steps): 58 | self.raw_data = np.array(raw_data, dtype=np.int64) 59 | self.num_steps = num_steps 60 | self.batch_size = batch_size 61 | self.num_steps = num_steps 62 | self.data_len = len(self.raw_data) 63 | self.sample_len = self.data_len // self.num_steps 64 | 65 | def __getitem__(self, idx): 66 | 67 | num_steps_begin_index = self.num_steps * idx 68 | 69 | num_steps_end_index = self.num_steps * (idx + 1) 70 | 71 | # print("num_steps_end_index : %d== ",num_steps_end_index) 72 | x = self.raw_data[num_steps_begin_index : num_steps_end_index] 73 | y = self.raw_data[num_steps_begin_index + 1 : num_steps_end_index + 1] 74 | 75 | return (x, y) 76 | 77 | def __len__(self): 78 | return self.sample_len - self.sample_len % self.batch_size 79 | 80 | class TestDataset(Dataset): 81 | def __init__(self, raw_data, batch_size, num_steps): 82 | self.raw_data = np.array(raw_data, dtype=np.int64) 83 | self.num_steps = num_steps 84 | self.batch_size = batch_size 85 | self.num_steps = num_steps 86 | self.data_len = len(self.raw_data) 87 | self.sample_len = self.data_len // self.num_steps 88 | # self.batch_len = self.sample_len // self.batch_size - 1 89 | 90 | def __getitem__(self, idx): 91 | num_steps_begin_index = self.num_steps * idx 92 | 93 | num_steps_end_index = self.num_steps * (idx + 1) 94 | 95 | # print("num_steps_end_index : %d== ",num_steps_end_index) 96 | x = self.raw_data[num_steps_begin_index: num_steps_end_index] 97 | y = self.raw_data[num_steps_begin_index + 1: num_steps_end_index + 1] 98 | 99 | return (x, y) 100 | 101 | def __len__(self): 102 | return self.sample_len - self.sample_len % self.batch_size 103 | -------------------------------------------------------------------------------- /models/resnext.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch.nn import init 4 | import math 5 | 6 | class ResNeXtBottleneck(nn.Module): 7 | expansion = 4 8 | """ 9 | RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua) 10 | """ 11 | def __init__(self, inplanes, planes, cardinality, base_width, stride=1, downsample=None): 12 | super(ResNeXtBottleneck, self).__init__() 13 | 14 | D = int(math.floor(planes * (base_width/64.0))) 15 | C = cardinality 16 | 17 | self.conv_reduce = nn.Conv2d(inplanes, D*C, kernel_size=1, stride=1, padding=0, bias=False) 18 | self.bn_reduce = nn.BatchNorm2d(D*C) 19 | 20 | self.conv_conv = nn.Conv2d(D*C, D*C, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 21 | self.bn = nn.BatchNorm2d(D*C) 22 | 23 | self.conv_expand = nn.Conv2d(D*C, planes*4, kernel_size=1, stride=1, padding=0, bias=False) 24 | self.bn_expand = nn.BatchNorm2d(planes*4) 25 | 26 | self.downsample = downsample 27 | 28 | def forward(self, x): 29 | residual = x 30 | 31 | bottleneck = self.conv_reduce(x) 32 | bottleneck = 
F.relu(self.bn_reduce(bottleneck), inplace=True) 33 | 34 | bottleneck = self.conv_conv(bottleneck) 35 | bottleneck = F.relu(self.bn(bottleneck), inplace=True) 36 | 37 | bottleneck = self.conv_expand(bottleneck) 38 | bottleneck = self.bn_expand(bottleneck) 39 | 40 | if self.downsample is not None: 41 | residual = self.downsample(x) 42 | 43 | return F.relu(residual + bottleneck, inplace=True) 44 | 45 | 46 | class CifarResNeXt(nn.Module): 47 | """ 48 | ResNext optimized for the Cifar dataset, as specified in 49 | https://arxiv.org/pdf/1611.05431.pdf 50 | """ 51 | def __init__(self, block, depth, cardinality, base_width, num_classes): 52 | super(CifarResNeXt, self).__init__() 53 | 54 | #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model 55 | assert (depth - 2) % 9 == 0, 'depth should be one of 29, 38, 47, 56, 101' 56 | layer_blocks = (depth - 2) // 9 57 | 58 | self.cardinality = cardinality 59 | self.base_width = base_width 60 | self.num_classes = num_classes 61 | 62 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False) 63 | self.bn_1 = nn.BatchNorm2d(64) 64 | 65 | self.inplanes = 64 66 | self.stage_1 = self._make_layer(block, 64 , layer_blocks, 1) 67 | self.stage_2 = self._make_layer(block, 128, layer_blocks, 2) 68 | self.stage_3 = self._make_layer(block, 256, layer_blocks, 2) 69 | self.avgpool = nn.AvgPool2d(8) 70 | self.classifier = nn.Linear(256*block.expansion, num_classes) 71 | 72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 75 | m.weight.data.normal_(0, math.sqrt(2. / n)) 76 | elif isinstance(m, nn.BatchNorm2d): 77 | m.weight.data.fill_(1) 78 | m.bias.data.zero_() 79 | elif isinstance(m, nn.Linear): 80 | init.kaiming_normal(m.weight) 81 | m.bias.data.zero_() 82 | 83 | def _make_layer(self, block, planes, blocks, stride=1): 84 | downsample = None 85 | if stride != 1 or self.inplanes != planes * block.expansion: 86 | downsample = nn.Sequential( 87 | nn.Conv2d(self.inplanes, planes * block.expansion, 88 | kernel_size=1, stride=stride, bias=False), 89 | nn.BatchNorm2d(planes * block.expansion), 90 | ) 91 | 92 | layers = [] 93 | layers.append(block(self.inplanes, planes, self.cardinality, self.base_width, stride, downsample)) 94 | self.inplanes = planes * block.expansion 95 | for i in range(1, blocks): 96 | layers.append(block(self.inplanes, planes, self.cardinality, self.base_width)) 97 | 98 | return nn.Sequential(*layers) 99 | 100 | def forward(self, x): 101 | x = self.conv_1_3x3(x) 102 | x = F.relu(self.bn_1(x), inplace=True) 103 | x = self.stage_1(x) 104 | x = self.stage_2(x) 105 | x = self.stage_3(x) 106 | x = self.avgpool(x) 107 | x = x.view(x.size(0), -1) 108 | return self.classifier(x) 109 | 110 | def resnext29_16_64(num_classes=10): 111 | """Constructs a ResNeXt-29, 16*64d model for CIFAR-10 (by default) 112 | 113 | Args: 114 | num_classes (uint): number of classes 115 | """ 116 | model = CifarResNeXt(ResNeXtBottleneck, 29, 16, 64, num_classes) 117 | return model 118 | 119 | def resnext29_8_64(num_classes=10): 120 | """Constructs a ResNeXt-29, 8*64d model for CIFAR-10 (by default) 121 | 122 | Args: 123 | num_classes (uint): number of classes 124 | """ 125 | model = CifarResNeXt(ResNeXtBottleneck, 29, 8, 64, num_classes) 126 | return model 127 | -------------------------------------------------------------------------------- /scripts/create_hdf5.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import 
print_function 3 | import argparse, os 4 | import glob 5 | import h5py 6 | import numpy as np 7 | import cv2 8 | import csv 9 | import time 10 | from random import shuffle 11 | 12 | DATADIR='/tmp/imagenet/ILSVRC2012_dataset' 13 | OUTPUTDIR='/tmp/imagenet_hdf5' 14 | 15 | def get_list(listfile): 16 | files = [] 17 | with open(listfile) as f: 18 | for l in f.readlines(): 19 | files.append(l) 20 | return files 21 | 22 | def gen_class_maps(trainfiles, valfiles): 23 | def _gen_maps(fl): 24 | map = {} 25 | label = 0 26 | for f in fl: 27 | cls = f.split('/')[-2] 28 | if not cls in map: 29 | map[cls] = label 30 | label += 1 31 | print('num of classes: ', len(map.keys())) 32 | return map 33 | class_labels = _gen_maps(trainfiles) 34 | train_labels = [] 35 | val_labels = [] 36 | for f in trainfiles: 37 | cls = f.split('/')[-2] 38 | label = class_labels[cls] 39 | train_labels.append(label) 40 | for f in valfiles: 41 | cls = f.split('/')[-2] 42 | label = class_labels[cls] 43 | val_labels.append(label) 44 | return train_labels, val_labels, class_labels 45 | 46 | def convert(outputpath, output, nworkers, datadir): 47 | trainfiles = gen_list_from_folder(datadir, folder='train')#get_list(trainlist) 48 | shuffle(trainfiles) # shuffle list 49 | valfiles = gen_list_from_folder(datadir, folder='val') #get_list(vallist) 50 | print('Read train number of lines: %d' % len(trainfiles)) 51 | print('Read val number of lines: %d' % len(valfiles)) 52 | train_labels, val_labels, class_labels = gen_class_maps(trainfiles, valfiles) 53 | with open(os.path.join(outputpath, 'imagenet_label_mapping.csv'), 'w') as csvfile: 54 | writer = csv.writer(csvfile, delimiter=' ') 55 | for k in class_labels: 56 | l = str(class_labels[k]) 57 | writer.writerow([k, l]) 58 | 59 | h5file = os.path.join(outputpath, output) 60 | ntrains = len(trainfiles) 61 | nvals = len(valfiles) 62 | #train_shape = (ntrains, 256, 256, 3) 63 | #val_shape = (nvals, 256, 256, 3) 64 | train_shape = (ntrains, 224, 224, 3) 65 | val_shape = (nvals, 224, 224, 3) 66 | with h5py.File(h5file, 'w') as hf: 67 | hf.create_dataset("train_img", train_shape, np.uint8) 68 | hf.create_dataset("val_img", val_shape, np.uint8) 69 | hf.create_dataset("train_labels", (train_shape[0],), np.int16) 70 | hf["train_labels"][...] = train_labels 71 | hf.create_dataset("val_labels", (val_shape[0],), np.int16) 72 | hf["val_labels"][...] = val_labels 73 | 74 | s = time.time() 75 | for i in range(ntrains): 76 | if i % 1000 == 0 and i > 1: 77 | print('Train data: {}/{}, time used: {}'.format(i, ntrains, (time.time()-s))) 78 | s = time.time() 79 | 80 | f = trainfiles[i] 81 | img = cv2.imread(f) 82 | img = cv2.resize(img, (train_shape[1], train_shape[2]), interpolation=cv2.INTER_CUBIC) 83 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 84 | #img = np.rollaxis(img, 2) 85 | hf["train_img"][i, ...] = img[None] 86 | 87 | s = time.time() 88 | for i in range(nvals): 89 | if i % 1000 == 0 and i > 1: 90 | print('val data: {}/{}, time used: {}'.format(i, ntrains, (time.time()-s))) 91 | s = time.time() 92 | 93 | f = valfiles[i] 94 | img = cv2.imread(f) 95 | img = cv2.resize(img, (train_shape[1], train_shape[2]), interpolation=cv2.INTER_CUBIC) 96 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 97 | #img = np.rollaxis(img, 2) 98 | hf["val_img"][i, ...] 
= img[None] 99 | 100 | 101 | def gen_list_from_folder(path, folder='train'): 102 | f = '%s/%s/*/*.JPEG'%(path, folder) 103 | print(f) 104 | addrs = glob.glob(f) 105 | return addrs 106 | 107 | if __name__ == '__main__': 108 | parser = argparse.ArgumentParser(description="Convert the ImageNet2012 dataset to the HDF5 format") 109 | parser.add_argument('--outputpath', type=str, default=OUTPUTDIR) 110 | parser.add_argument('--output', type=str, default='imagenet-shuffled-224.hdf5') 111 | parser.add_argument('--nworkers', type=int, default=1, help='Multiple threads supported') 112 | parser.add_argument('--datadir', type=str, default=DATADIR, help='Specify the dataset path, e.g., %s' % DATADIR) 113 | args = parser.parse_args() 114 | convert(args.outputpath, args.output, args.nworkers, args.datadir) 115 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | from .res_utils import DownsampleA, DownsampleC, DownsampleD 6 | import math 7 | 8 | 9 | class ResNetBasicblock(nn.Module): 10 | expansion = 1 11 | """ 12 | RexNet basicblock (https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua) 13 | """ 14 | def __init__(self, inplanes, planes, stride=1, downsample=None): 15 | super(ResNetBasicblock, self).__init__() 16 | 17 | self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn_a = nn.BatchNorm2d(planes) 19 | 20 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.bn_b = nn.BatchNorm2d(planes) 22 | 23 | self.downsample = downsample 24 | 25 | def forward(self, x): 26 | residual = x 27 | 28 | basicblock = self.conv_a(x) 29 | basicblock = self.bn_a(basicblock) 30 | basicblock = F.relu(basicblock, inplace=True) 31 | 32 | basicblock = self.conv_b(basicblock) 33 | basicblock = self.bn_b(basicblock) 34 | 35 | if self.downsample is not None: 36 | residual = self.downsample(x) 37 | 38 | return F.relu(residual + basicblock, inplace=True) 39 | 40 | class CifarResNet(nn.Module): 41 | """ 42 | ResNet optimized for the Cifar dataset, as specified in 43 | https://arxiv.org/abs/1512.03385.pdf 44 | """ 45 | def __init__(self, block, depth, num_classes): 46 | """ Constructor 47 | Args: 48 | depth: number of layers. 
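# A minimal reader sketch for the HDF5 file produced by scripts/create_hdf5.py
# above (dataset names as created there; the file path is illustrative):
import h5py
with h5py.File('/tmp/imagenet_hdf5/imagenet-shuffled-224.hdf5', 'r') as hf:
    img = hf['train_img'][0]            # uint8 array of shape (224, 224, 3), RGB order
    label = int(hf['train_labels'][0])  # int16 id, mapped in imagenet_label_mapping.csv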
49 | num_classes: number of classes 50 | base_width: base width 51 | """ 52 | super(CifarResNet, self).__init__() 53 | 54 | #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model 55 | assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' 56 | self.name = 'resnet%d'%depth 57 | layer_blocks = (depth - 2) // 6 58 | print ('CifarResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks)) 59 | 60 | self.num_classes = num_classes 61 | 62 | self.conv_1_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 63 | self.bn_1 = nn.BatchNorm2d(16) 64 | 65 | self.inplanes = 16 66 | self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) 67 | self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) 68 | self.stage_3 = self._make_layer(block, 64, layer_blocks, 2) 69 | self.avgpool = nn.AvgPool2d(8) 70 | self.classifier = nn.Linear(64*block.expansion, num_classes) 71 | 72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 75 | m.weight.data.normal_(0, math.sqrt(2. / n)) 76 | #m.bias.data.zero_() 77 | elif isinstance(m, nn.BatchNorm2d): 78 | m.weight.data.fill_(1) 79 | m.bias.data.zero_() 80 | elif isinstance(m, nn.Linear): 81 | init.kaiming_normal_(m.weight) 82 | m.bias.data.zero_() 83 | 84 | def _make_layer(self, block, planes, blocks, stride=1): 85 | downsample = None 86 | if stride != 1 or self.inplanes != planes * block.expansion: 87 | downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) 88 | 89 | layers = [] 90 | layers.append(block(self.inplanes, planes, stride, downsample)) 91 | self.inplanes = planes * block.expansion 92 | for i in range(1, blocks): 93 | layers.append(block(self.inplanes, planes)) 94 | 95 | return nn.Sequential(*layers) 96 | 97 | def forward(self, x): 98 | x = self.conv_1_3x3(x) 99 | x = F.relu(self.bn_1(x), inplace=True) 100 | x = self.stage_1(x) 101 | x = self.stage_2(x) 102 | x = self.stage_3(x) 103 | x = self.avgpool(x) 104 | x = x.view(x.size(0), -1) 105 | return self.classifier(x) 106 | 107 | def resnet20(num_classes=10): 108 | """Constructs a ResNet-20 model for CIFAR-10 (by default) 109 | Args: 110 | num_classes (uint): number of classes 111 | """ 112 | model = CifarResNet(ResNetBasicblock, 20, num_classes) 113 | return model 114 | 115 | def resnet32(num_classes=10): 116 | """Constructs a ResNet-32 model for CIFAR-10 (by default) 117 | Args: 118 | num_classes (uint): number of classes 119 | """ 120 | model = CifarResNet(ResNetBasicblock, 32, num_classes) 121 | return model 122 | 123 | def resnet44(num_classes=10): 124 | """Constructs a ResNet-44 model for CIFAR-10 (by default) 125 | Args: 126 | num_classes (uint): number of classes 127 | """ 128 | model = CifarResNet(ResNetBasicblock, 44, num_classes) 129 | return model 130 | 131 | def resnet56(num_classes=10): 132 | """Constructs a ResNet-56 model for CIFAR-10 (by default) 133 | Args: 134 | num_classes (uint): number of classes 135 | """ 136 | model = CifarResNet(ResNetBasicblock, 56, num_classes) 137 | return model 138 | 139 | def resnet110(num_classes=10): 140 | """Constructs a ResNet-110 model for CIFAR-10 (by default) 141 | Args: 142 | num_classes (uint): number of classes 143 | """ 144 | model = CifarResNet(ResNetBasicblock, 110, num_classes) 145 | return model 146 | -------------------------------------------------------------------------------- /horovod_trainer.py: -------------------------------------------------------------------------------- 1 
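# Depth bookkeeping for the CIFAR ResNets defined above: depth = 6n + 2, i.e. one
# stem conv, three stages of n basic blocks with two convs each, and one linear
# classifier; this is exactly what the assert in CifarResNet enforces. A quick self-check:
for depth in (20, 32, 44, 56, 110):
    n = (depth - 2) // 6
    assert 6 * n + 2 == depth
    print(depth, '->', n, 'blocks per stage')  # e.g. 20 -> 3, 110 -> 18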
| # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import time 4 | import torch 5 | import numpy as np 6 | import argparse, os 7 | import settings 8 | import utils 9 | import logging 10 | from mpi4py import MPI 11 | comm = MPI.COMM_WORLD 12 | comm.Set_errhandler(MPI.ERRORS_RETURN) 13 | 14 | from dl_trainer import DLTrainer, _support_datasets, _support_dnns 15 | import horovod.torch as hvd 16 | from tensorboardX import SummaryWriter 17 | writer = None 18 | 19 | from settings import logger, formatter 20 | 21 | 22 | def ssgd_with_horovod(dnn, dataset, data_dir, nworkers, lr, batch_size, nsteps_update, max_epochs, nwpernode, pretrain, num_steps = 1): 23 | rank = hvd.rank() 24 | torch.cuda.set_device(rank%nwpernode) 25 | if rank != 0: 26 | pretrain = None 27 | trainer = DLTrainer(rank, nworkers, dist=False, batch_size=batch_size, is_weak_scaling=True, ngpus=1, data_dir=data_dir, dataset=dataset, dnn=dnn, lr=lr, nworkers=nworkers, prefix='allreduce', pretrain=pretrain, num_steps=num_steps, tb_writer=writer) 28 | 29 | init_epoch = torch.ones(1) * trainer.get_train_epoch() 30 | init_iter = torch.ones(1) * trainer.get_train_iter() 31 | trainer.set_train_epoch(int(hvd.broadcast(init_epoch, root_rank=0)[0])) 32 | trainer.set_train_iter(int(hvd.broadcast(init_iter, root_rank=0)[0])) 33 | 34 | optimizer = hvd.DistributedOptimizer(trainer.optimizer, named_parameters=trainer.net.named_parameters()) 35 | hvd.broadcast_parameters(trainer.net.state_dict(), root_rank=0) 36 | trainer.update_optimizer(optimizer) 37 | iters_per_epoch = trainer.get_num_of_training_samples() // (nworkers * batch_size * nsteps_update) 38 | 39 | times = [] 40 | display = 20 if iters_per_epoch > 20 else iters_per_epoch-1 41 | for epoch in range(max_epochs): 42 | hidden = None 43 | if dnn == 'lstm': 44 | hidden = trainer.net.init_hidden() 45 | for i in range(iters_per_epoch): 46 | s = time.time() 47 | optimizer.zero_grad() 48 | for j in range(nsteps_update): 49 | if j < nsteps_update - 1 and nsteps_update > 1: 50 | optimizer.local = True 51 | else: 52 | optimizer.local = False 53 | if dnn == 'lstm': 54 | _, hidden = trainer.train(1, hidden=hidden) 55 | else: 56 | trainer.train(1) 57 | if dnn == 'lstm': 58 | optimizer.synchronize() 59 | torch.nn.utils.clip_grad_norm_(trainer.net.parameters(), 0.25) 60 | elif dnn == 'lstman4': 61 | optimizer.synchronize() 62 | torch.nn.utils.clip_grad_norm_(trainer.net.parameters(), 400) 63 | trainer.update_model() 64 | times.append(time.time()-s) 65 | if i % display == 0 and i > 0: 66 | time_per_iter = np.mean(times) 67 | logger.info('Time per iteration including communication: %f. 
Speed: %f images/s', time_per_iter, batch_size * nsteps_update / time_per_iter) 68 | times = [] 69 | 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser(description="AllReduce trainer") 73 | parser.add_argument('--batch-size', type=int, default=32) 74 | parser.add_argument('--nsteps-update', type=int, default=1) 75 | parser.add_argument('--nworkers', type=int, default=1, help='Just for experiments, and it cannot be used in production') 76 | parser.add_argument('--nwpernode', type=int, default=1, help='Number of workers per node') 77 | parser.add_argument('--dataset', type=str, default='imagenet', choices=_support_datasets, help='Specify the dataset for training') 78 | parser.add_argument('--dnn', type=str, default='resnet50', choices=_support_dnns, help='Specify the neural network for training') 79 | parser.add_argument('--data-dir', type=str, default='./data', help='Specify the data root path') 80 | parser.add_argument('--lr', type=float, default=0.1, help='Default learning rate') 81 | parser.add_argument('--max-epochs', type=int, default=90, help='Default maximum epochs to train') 82 | parser.add_argument('--pretrain', type=str, default=None, help='Specify the pretrain path') 83 | parser.add_argument('--num-steps', type=int, default=35) 84 | parser.set_defaults(compression=False) 85 | args = parser.parse_args() 86 | batch_size = args.batch_size * args.nsteps_update 87 | prefix = settings.PREFIX 88 | logdir = 'allreduce-%s/%s-n%d-bs%d-lr%.4f-ns%d' % (prefix, args.dnn, args.nworkers, batch_size, args.lr, args.nsteps_update) 89 | relative_path = './logs/%s'%logdir 90 | utils.create_path(relative_path) 91 | rank = 0 92 | if args.nworkers > 1: 93 | hvd.init() 94 | rank = hvd.rank() 95 | if rank == 0: 96 | tb_runs = './runs/%s'%logdir 97 | writer = SummaryWriter(tb_runs) 98 | logfile = os.path.join(relative_path, settings.hostname+'-'+str(rank)+'.log') 99 | hdlr = logging.FileHandler(logfile) 100 | hdlr.setFormatter(formatter) 101 | logger.addHandler(hdlr) 102 | logger.info('Configurations: %s', args) 103 | ssgd_with_horovod(args.dnn, args.dataset, args.data_dir, args.nworkers, args.lr, args.batch_size, args.nsteps_update, args.max_epochs, args.nwpernode, args.pretrain, args.num_steps) 104 | -------------------------------------------------------------------------------- /models/preresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | from .res_utils import DownsampleA, DownsampleC 6 | import math 7 | 8 | 9 | class ResNetBasicblock(nn.Module): 10 | expansion = 1 11 | def __init__(self, inplanes, planes, stride, downsample, Type): 12 | super(ResNetBasicblock, self).__init__() 13 | 14 | self.Type = Type 15 | 16 | self.bn_a = nn.BatchNorm2d(inplanes) 17 | self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | 19 | self.bn_b = nn.BatchNorm2d(planes) 20 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 21 | 22 | self.relu = nn.ReLU(inplace=True) 23 | self.downsample = downsample 24 | 25 | def forward(self, x): 26 | residual = x 27 | 28 | basicblock = self.bn_a(x) 29 | basicblock = self.relu(basicblock) 30 | 31 | if self.Type == 'both_preact': 32 | residual = basicblock 33 | elif self.Type != 'normal': 34 | assert False, 'Unknow type : {}'.format(self.Type) 35 | 36 | basicblock = self.conv_a(basicblock) 37 | 38 | basicblock = 
self.bn_b(basicblock) 39 | basicblock = self.relu(basicblock) 40 | basicblock = self.conv_b(basicblock) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(residual) 44 | 45 | return residual + basicblock 46 | 47 | class CifarPreResNet(nn.Module): 48 | """ 49 | ResNet optimized for the Cifar dataset, as specified in 50 | https://arxiv.org/abs/1512.03385.pdf 51 | """ 52 | def __init__(self, block, depth, num_classes): 53 | """ Constructor 54 | Args: 55 | depth: number of layers. 56 | num_classes: number of classes 57 | base_width: base width 58 | """ 59 | super(CifarPreResNet, self).__init__() 60 | 61 | #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model 62 | assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' 63 | layer_blocks = (depth - 2) // 6 64 | print ('CifarPreResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks)) 65 | 66 | self.num_classes = num_classes 67 | 68 | self.conv_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 69 | 70 | self.inplanes = 16 71 | self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) 72 | self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) 73 | self.stage_3 = self._make_layer(block, 64, layer_blocks, 2) 74 | self.lastact = nn.Sequential(nn.BatchNorm2d(64*block.expansion), nn.ReLU(inplace=True)) 75 | self.avgpool = nn.AvgPool2d(8) 76 | self.classifier = nn.Linear(64*block.expansion, num_classes) 77 | 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 81 | m.weight.data.normal_(0, math.sqrt(2. / n)) 82 | #m.bias.data.zero_() 83 | elif isinstance(m, nn.BatchNorm2d): 84 | m.weight.data.fill_(1) 85 | m.bias.data.zero_() 86 | elif isinstance(m, nn.Linear): 87 | init.kaiming_normal(m.weight) 88 | m.bias.data.zero_() 89 | 90 | def _make_layer(self, block, planes, blocks, stride=1): 91 | downsample = None 92 | if stride != 1 or self.inplanes != planes * block.expansion: 93 | downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) 94 | 95 | layers = [] 96 | layers.append(block(self.inplanes, planes, stride, downsample, 'both_preact')) 97 | self.inplanes = planes * block.expansion 98 | for i in range(1, blocks): 99 | layers.append(block(self.inplanes, planes, 1, None, 'normal')) 100 | 101 | return nn.Sequential(*layers) 102 | 103 | def forward(self, x): 104 | x = self.conv_3x3(x) 105 | x = self.stage_1(x) 106 | x = self.stage_2(x) 107 | x = self.stage_3(x) 108 | x = self.lastact(x) 109 | x = self.avgpool(x) 110 | x = x.view(x.size(0), -1) 111 | return self.classifier(x) 112 | 113 | def preresnet20(num_classes=10): 114 | """Constructs a ResNet-20 model for CIFAR-10 (by default) 115 | Args: 116 | num_classes (uint): number of classes 117 | """ 118 | model = CifarPreResNet(ResNetBasicblock, 20, num_classes) 119 | return model 120 | 121 | def preresnet32(num_classes=10): 122 | """Constructs a ResNet-32 model for CIFAR-10 (by default) 123 | Args: 124 | num_classes (uint): number of classes 125 | """ 126 | model = CifarPreResNet(ResNetBasicblock, 32, num_classes) 127 | return model 128 | 129 | def preresnet44(num_classes=10): 130 | """Constructs a ResNet-44 model for CIFAR-10 (by default) 131 | Args: 132 | num_classes (uint): number of classes 133 | """ 134 | model = CifarPreResNet(ResNetBasicblock, 44, num_classes) 135 | return model 136 | 137 | def preresnet56(num_classes=10): 138 | """Constructs a ResNet-56 model for CIFAR-10 (by default) 139 | Args: 140 | 
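# The block above uses the pre-activation ordering (BN -> ReLU -> conv, twice),
# in contrast to the post-activation blocks in models/resnet.py. A condensed
# sketch of the two residual branches:
#   post-activation (resnet.py):  y = ReLU(x + BN(conv(ReLU(BN(conv(x))))))
#   pre-activation (this file):   y = x + conv(ReLU(BN(conv(ReLU(BN(x))))))
# The pre-activation sum is left un-activated, and the first block of each stage
# ('both_preact') applies the shared BN/ReLU before both branches, as in forward() above.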
num_classes (uint): number of classes 141 | """ 142 | model = CifarPreResNet(ResNetBasicblock, 56, num_classes) 143 | return model 144 | 145 | def preresnet110(num_classes=10): 146 | """Constructs a ResNet-110 model for CIFAR-10 (by default) 147 | Args: 148 | num_classes (uint): number of classes 149 | """ 150 | model = CifarPreResNet(ResNetBasicblock, 110, num_classes) 151 | return model 152 | -------------------------------------------------------------------------------- /models/resnet_mod.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | from .res_utils import DownsampleA, DownsampleC, DownsampleD 6 | import math 7 | 8 | 9 | class ResNetBasicblock(nn.Module): 10 | expansion = 1 11 | """ 12 | RexNet basicblock (https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua) 13 | """ 14 | def __init__(self, inplanes, planes, stride=1, downsample=None): 15 | super(ResNetBasicblock, self).__init__() 16 | 17 | self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 18 | self.bn_a = nn.BatchNorm2d(planes) 19 | 20 | self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.bn_b = nn.BatchNorm2d(planes) 22 | 23 | self.downsample = downsample 24 | 25 | def forward(self, x): 26 | if isinstance(x, list): 27 | x, is_list, features = x[0], True, x[1:] 28 | else: 29 | is_list, features = False, None 30 | residual = x 31 | 32 | conv_a = self.conv_a(x) 33 | bn_a = self.bn_a(conv_a) 34 | relu_a = F.relu(bn_a, inplace=True) 35 | 36 | conv_b = self.conv_b(relu_a) 37 | bn_b = self.bn_b(conv_b) 38 | 39 | if self.downsample is not None: 40 | residual = self.downsample(x) 41 | 42 | output = F.relu(residual + bn_b, inplace=True) 43 | 44 | if is_list: 45 | return [output] + features + [bn_a, bn_b] 46 | else: 47 | return output 48 | 49 | class CifarResNet(nn.Module): 50 | """ 51 | ResNet optimized for the Cifar dataset, as specified in 52 | https://arxiv.org/abs/1512.03385.pdf 53 | """ 54 | def __init__(self, block, depth, num_classes): 55 | """ Constructor 56 | Args: 57 | depth: number of layers. 58 | num_classes: number of classes 59 | base_width: base width 60 | """ 61 | super(CifarResNet, self).__init__() 62 | 63 | #Model type specifies number of layers for CIFAR-10 and CIFAR-100 model 64 | assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' 65 | layer_blocks = (depth - 2) // 6 66 | print ('CifarResNet : Depth : {} , Layers for each block : {}'.format(depth, layer_blocks)) 67 | 68 | self.num_classes = num_classes 69 | 70 | self.conv_1_3x3 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 71 | self.bn_1 = nn.BatchNorm2d(16) 72 | 73 | self.inplanes = 16 74 | self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) 75 | self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) 76 | self.stage_3 = self._make_layer(block, 64, layer_blocks, 2) 77 | self.avgpool = nn.AvgPool2d(8) 78 | self.classifier = nn.Linear(64*block.expansion, num_classes) 79 | 80 | for m in self.modules(): 81 | if isinstance(m, nn.Conv2d): 82 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 83 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 84 | #m.bias.data.zero_() 85 | elif isinstance(m, nn.BatchNorm2d): 86 | m.weight.data.fill_(1) 87 | m.bias.data.zero_() 88 | elif isinstance(m, nn.Linear): 89 | init.kaiming_normal(m.weight) 90 | m.bias.data.zero_() 91 | 92 | def _make_layer(self, block, planes, blocks, stride=1): 93 | downsample = None 94 | if stride != 1 or self.inplanes != planes * block.expansion: 95 | downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) 96 | 97 | layers = [] 98 | layers.append(block(self.inplanes, planes, stride, downsample)) 99 | self.inplanes = planes * block.expansion 100 | for i in range(1, blocks): 101 | layers.append(block(self.inplanes, planes)) 102 | 103 | return nn.Sequential(*layers) 104 | 105 | def forward(self, x): 106 | if isinstance(x, list): 107 | assert len(x) == 1, 'The length of inputs must be one vs {}'.format(len(x)) 108 | x, is_list = x[0], True 109 | else: 110 | x, is_list = x, False 111 | x = self.conv_1_3x3(x) 112 | x = F.relu(self.bn_1(x), inplace=True) 113 | 114 | if is_list: x = [x] 115 | x = self.stage_1(x) 116 | x = self.stage_2(x) 117 | x = self.stage_3(x) 118 | if is_list: 119 | x, features = x[0], x[1:] 120 | else: 121 | features = None 122 | x = self.avgpool(x) 123 | x = x.view(x.size(0), -1) 124 | cls = self.classifier(x) 125 | 126 | if is_list: return cls, features 127 | else: return cls 128 | 129 | def resnet_mod20(num_classes=10): 130 | """Constructs a ResNet-20 model for CIFAR-10 (by default) 131 | Args: 132 | num_classes (uint): number of classes 133 | """ 134 | model = CifarResNet(ResNetBasicblock, 20, num_classes) 135 | return model 136 | 137 | def resnet_mod32(num_classes=10): 138 | """Constructs a ResNet-32 model for CIFAR-10 (by default) 139 | Args: 140 | num_classes (uint): number of classes 141 | """ 142 | model = CifarResNet(ResNetBasicblock, 32, num_classes) 143 | return model 144 | 145 | def resnet_mod44(num_classes=10): 146 | """Constructs a ResNet-44 model for CIFAR-10 (by default) 147 | Args: 148 | num_classes (uint): number of classes 149 | """ 150 | model = CifarResNet(ResNetBasicblock, 44, num_classes) 151 | return model 152 | 153 | def resnet_mod56(num_classes=10): 154 | """Constructs a ResNet-56 model for CIFAR-10 (by default) 155 | Args: 156 | num_classes (uint): number of classes 157 | """ 158 | model = CifarResNet(ResNetBasicblock, 56, num_classes) 159 | return model 160 | 161 | def resnet_mod110(num_classes=10): 162 | """Constructs a ResNet-110 model for CIFAR-10 (by default) 163 | Args: 164 | num_classes (uint): number of classes 165 | """ 166 | model = CifarResNet(ResNetBasicblock, 110, num_classes) 167 | return model 168 | -------------------------------------------------------------------------------- /audio_data/librispeech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import wget 3 | import tarfile 4 | import argparse 5 | import subprocess 6 | from utils import create_manifest 7 | from tqdm import tqdm 8 | import shutil 9 | 10 | parser = argparse.ArgumentParser(description='Processes and downloads LibriSpeech dataset.') 11 | parser.add_argument("--target-dir", default='LibriSpeech_dataset/', type=str, help="Directory to store the dataset.") 12 | parser.add_argument('--sample-rate', default=16000, type=int, help='Sample rate') 13 | parser.add_argument('--files-to-use', default="train-clean-100.tar.gz," 14 | "train-clean-360.tar.gz,train-other-500.tar.gz," 15 | "dev-clean.tar.gz,dev-other.tar.gz," 16 | "test-clean.tar.gz,test-other.tar.gz", type=str, 17 | 
help='list of file names to download') 18 | parser.add_argument('--min-duration', default=1, type=int, 19 | help='Prunes training samples shorter than the min duration (given in seconds, default 1)') 20 | parser.add_argument('--max-duration', default=15, type=int, 21 | help='Prunes training samples longer than the max duration (given in seconds, default 15)') 22 | args = parser.parse_args() 23 | 24 | LIBRI_SPEECH_URLS = { 25 | "train": ["http://www.openslr.org/resources/12/train-clean-100.tar.gz", 26 | "http://www.openslr.org/resources/12/train-clean-360.tar.gz", 27 | "http://www.openslr.org/resources/12/train-other-500.tar.gz"], 28 | 29 | "val": ["http://www.openslr.org/resources/12/dev-clean.tar.gz", 30 | "http://www.openslr.org/resources/12/dev-other.tar.gz"], 31 | 32 | "test_clean": ["http://www.openslr.org/resources/12/test-clean.tar.gz"], 33 | "test_other": ["http://www.openslr.org/resources/12/test-other.tar.gz"] 34 | } 35 | 36 | 37 | def _preprocess_transcript(phrase): 38 | return phrase.strip().upper() 39 | 40 | 41 | def _process_file(wav_dir, txt_dir, base_filename, root_dir): 42 | full_recording_path = os.path.join(root_dir, base_filename) 43 | assert os.path.exists(full_recording_path) and os.path.exists(root_dir) 44 | wav_recording_path = os.path.join(wav_dir, base_filename.replace(".flac", ".wav")) 45 | subprocess.call(["sox {} -r {} -b 16 -c 1 {}".format(full_recording_path, str(args.sample_rate), 46 | wav_recording_path)], shell=True) 47 | # process transcript 48 | txt_transcript_path = os.path.join(txt_dir, base_filename.replace(".flac", ".txt")) 49 | transcript_file = os.path.join(root_dir, "-".join(base_filename.split('-')[:-1]) + ".trans.txt") 50 | assert os.path.exists(transcript_file), "Transcript file {} does not exist.".format(transcript_file) 51 | transcriptions = open(transcript_file).read().strip().split("\n") 52 | transcriptions = {t.split()[0].split("-")[-1]: " ".join(t.split()[1:]) for t in transcriptions} 53 | with open(txt_transcript_path, "w") as f: 54 | key = base_filename.replace(".flac", "").split("-")[-1] 55 | assert key in transcriptions, "{} is not in the transcriptions".format(key) 56 | f.write(_preprocess_transcript(transcriptions[key])) 57 | f.flush() 58 | 59 | 60 | def main(): 61 | target_dl_dir = args.target_dir 62 | if not os.path.exists(target_dl_dir): 63 | os.makedirs(target_dl_dir) 64 | files_to_dl = args.files_to_use.strip().split(',') 65 | for split_type, lst_libri_urls in LIBRI_SPEECH_URLS.items(): 66 | split_dir = os.path.join(target_dl_dir, split_type) 67 | if not os.path.exists(split_dir): 68 | os.makedirs(split_dir) 69 | split_wav_dir = os.path.join(split_dir, "wav") 70 | if not os.path.exists(split_wav_dir): 71 | os.makedirs(split_wav_dir) 72 | split_txt_dir = os.path.join(split_dir, "txt") 73 | if not os.path.exists(split_txt_dir): 74 | os.makedirs(split_txt_dir) 75 | extracted_dir = os.path.join(split_dir, "LibriSpeech") 76 | if os.path.exists(extracted_dir): 77 | shutil.rmtree(extracted_dir) 78 | for url in lst_libri_urls: 79 | # check if we want to dl this file 80 | dl_flag = False 81 | for f in files_to_dl: 82 | if url.find(f) != -1: 83 | dl_flag = True 84 | if not dl_flag: 85 | print("Skipping url: {}".format(url)) 86 | continue 87 | filename = url.split("/")[-1] 88 | target_filename = os.path.join(split_dir, filename) 89 | if not os.path.exists(target_filename): 90 | wget.download(url, split_dir) 91 | print("Unpacking {}...".format(filename)) 92 | tar = tarfile.open(target_filename) 93 | tar.extractall(split_dir) 94 | 
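# Sketch of the .trans.txt handling done in _process_file above: each line pairs an
# utterance id with its uppercase transcript, keyed by the trailing index of the id
# (the sample line is illustrative of the LibriSpeech format):
line = "84-121123-0000 GO DO YOU HEAR"
parts = line.split()
key = parts[0].split("-")[-1]     # "0000", the per-file utterance index
transcript = " ".join(parts[1:])  # "GO DO YOU HEAR"
print(key, '->', _preprocess_transcript(transcript))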
tar.close() 95 | os.remove(target_filename) 96 | print("Converting flac files to wav and extracting transcripts...") 97 | assert os.path.exists(extracted_dir), "Archive {} was not properly uncompressed.".format(filename) 98 | for root, subdirs, files in tqdm(os.walk(extracted_dir)): 99 | for f in files: 100 | if f.find(".flac") != -1: 101 | _process_file(wav_dir=split_wav_dir, txt_dir=split_txt_dir, 102 | base_filename=f, root_dir=root) 103 | 104 | print("Finished {}".format(url)) 105 | shutil.rmtree(extracted_dir) 106 | if split_type == 'train': # Prune to min/max duration 107 | create_manifest(split_dir, 'libri_' + split_type + '_manifest.csv', args.min_duration, args.max_duration) 108 | else: 109 | create_manifest(split_dir, 'libri_' + split_type + '_manifest.csv') 110 | 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /models/imagenet_resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | "3x3 convolution with padding" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=1, bias=False) 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, inplanes, planes, stride=1, downsample=None): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = conv3x3(inplanes, planes, stride) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.downsample = downsample 22 | self.stride = stride 23 | 24 | def forward(self, x): 25 | residual = x 26 | 27 | out = self.conv1(x) 28 | out = self.bn1(out) 29 | out = self.relu(out) 30 | 31 | out = self.conv2(out) 32 | out = self.bn2(out) 33 | 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | 37 | out += residual 38 | out = self.relu(out) 39 | 40 | return out 41 | 42 | 43 | class Bottleneck(nn.Module): 44 | expansion = 4 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None): 47 | super(Bottleneck, self).__init__() 48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 51 | padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(planes * 4) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv3(out) 71 | out = self.bn3(out) 72 | 73 | if self.downsample is not None: 74 | residual = self.downsample(x) 75 | 76 | out += residual 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class ResNet(nn.Module): 83 | 84 | def __init__(self, block, layers, num_classes=1000): 85 | self.inplanes = 64 86 | super(ResNet, self).__init__() 87 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 88 | bias=False) 89 | self.bn1 = nn.BatchNorm2d(64) 90 | self.relu = nn.ReLU(inplace=True) 91 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, 
padding=1) 92 | self.layer1 = self._make_layer(block, 64, layers[0]) 93 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 94 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 95 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 96 | self.avgpool = nn.AvgPool2d(7) 97 | self.fc = nn.Linear(512 * block.expansion, num_classes) 98 | 99 | for m in self.modules(): 100 | if isinstance(m, nn.Conv2d): 101 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 102 | m.weight.data.normal_(0, math.sqrt(2. / n)) 103 | elif isinstance(m, nn.BatchNorm2d): 104 | m.weight.data.fill_(1) 105 | m.bias.data.zero_() 106 | 107 | def _make_layer(self, block, planes, blocks, stride=1): 108 | downsample = None 109 | if stride != 1 or self.inplanes != planes * block.expansion: 110 | downsample = nn.Sequential( 111 | nn.Conv2d(self.inplanes, planes * block.expansion, 112 | kernel_size=1, stride=stride, bias=False), 113 | nn.BatchNorm2d(planes * block.expansion), 114 | ) 115 | 116 | layers = [] 117 | layers.append(block(self.inplanes, planes, stride, downsample)) 118 | self.inplanes = planes * block.expansion 119 | for i in range(1, blocks): 120 | layers.append(block(self.inplanes, planes)) 121 | 122 | return nn.Sequential(*layers) 123 | 124 | def forward(self, x): 125 | x = self.conv1(x) 126 | x = self.bn1(x) 127 | x = self.relu(x) 128 | x = self.maxpool(x) 129 | 130 | x = self.layer1(x) 131 | x = self.layer2(x) 132 | x = self.layer3(x) 133 | x = self.layer4(x) 134 | 135 | x = self.avgpool(x) 136 | x = x.view(x.size(0), -1) 137 | x = self.fc(x) 138 | 139 | return x 140 | 141 | 142 | def resnet18(num_classes=1000): 143 | """Constructs a ResNet-18 model. 144 | 145 | Args: 146 | pretrained (bool): If True, returns a model pre-trained on ImageNet 147 | """ 148 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes) 149 | return model 150 | 151 | 152 | def resnet34(num_classes=1000): 153 | """Constructs a ResNet-34 model. 154 | 155 | Args: 156 | pretrained (bool): If True, returns a model pre-trained on ImageNet 157 | """ 158 | model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes) 159 | model.name = 'resnet34' 160 | return model 161 | 162 | 163 | def resnet50(num_classes=1000): 164 | """Constructs a ResNet-50 model. 165 | 166 | Args: 167 | pretrained (bool): If True, returns a model pre-trained on ImageNet 168 | """ 169 | model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes) 170 | model.name = 'resnet50' 171 | return model 172 | 173 | 174 | def resnet101(num_classes=1000): 175 | """Constructs a ResNet-101 model. 176 | 177 | Args: 178 | pretrained (bool): If True, returns a model pre-trained on ImageNet 179 | """ 180 | model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes) 181 | model.name = 'resnet101' 182 | return model 183 | 184 | 185 | def resnet152(num_classes=1000): 186 | """Constructs a ResNet-152 model. 
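# Depth bookkeeping for the Bottleneck constructors in this file: the nominal depth
# is 1 (stem conv) + 3 * sum(layers) convs inside the bottleneck blocks + 1 (fc).
# A quick self-check of the configs used here:
for name, cfg in (('resnet50', [3, 4, 6, 3]), ('resnet101', [3, 4, 23, 3]), ('resnet152', [3, 8, 36, 3])):
    print(name, 1 + 3 * sum(cfg) + 1)  # 50, 101, 152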
187 | 188 | Args: 189 | pretrained (bool): If True, returns a model pre-trained on ImageNet 190 | """ 191 | model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes) 192 | return model 193 | -------------------------------------------------------------------------------- /scripts/icdcs2019/plot_sth.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import matplotlib.pyplot as plt 3 | from matplotlib.patches import Rectangle 4 | from matplotlib.collections import PatchCollection 5 | import numpy as np 6 | import os 7 | import sys 8 | import argparse 9 | import math 10 | from utils import read_log, plot_hist, update_fontsize 11 | 12 | OUTPUT_PATH = '/media/sf_Shared_Data/tmp/infocom19' 13 | 14 | comp_color = '#2e75b6' 15 | compression_color = 'g' 16 | comm_color = '#c00000' 17 | synceasgd_color = '#3F5EBA' 18 | opt_comm_color = '#c55a11' 19 | 20 | 21 | class Bar: 22 | initialized = False 23 | def __init__(self, start_time, duration, max_time, ax, type='p', index=1, is_optimal=False): 24 | """ 25 | type: p for compute, m for communication 26 | """ 27 | self.start_time_ = start_time 28 | self.ax_ = ax 29 | self.max_time_ = max_time 30 | self.duration_ = duration/max_time 31 | self.type_ = type 32 | self.height_ = 0.1 33 | self.start_ = 0.3 34 | self.index_ = index 35 | self.is_optimal_ = is_optimal 36 | self.y_ = self.start_+self.height_ if self.type_ is 'p' else self.start_ 37 | if self.type_ == 'p': 38 | self.y_ = self.start_+self.height_ 39 | self.color_ = comp_color 40 | elif self.type_ == 'wc': # WFBP 41 | self.y_ = self.start_ 42 | self.color_ = comm_color 43 | elif self.type_ == 'sc': # SyncEASGD 44 | self.y_ = self.start_-self.height_ 45 | self.color_ = synceasgd_color 46 | elif self.type_ == 'mc': # MG-WFGP 47 | self.y_ = self.start_-2*self.height_-self.height_*0.2 48 | self.color_ = opt_comm_color 49 | #self.color_ = comp_color if self.type_ is 'p' else comm_color 50 | #if self.is_optimal_: 51 | #self.y_ = self.start_-self.height_ - self.height_ * 0.2 52 | #self.color_ = opt_comm_color 53 | if not Bar.initialized: 54 | self.ax_.set_xlim(right=1.05) 55 | self.ax_.spines['top'].set_visible(False) 56 | self.ax_.spines['right'].set_visible(False) 57 | bottom = 0.00#self.start_-self.height_/2 58 | self.ax_.set_ylim(bottom=bottom, top=0.5) 59 | self.ax_.get_yaxis().set_ticks([]) 60 | self.ax_.xaxis.set_ticks_position('bottom') 61 | self.ax_.set_xticks([i*0.1 for i in range(0,11)]) 62 | self.ax_.set_xticklabels([str(int(self.max_time_*(i*0.1)/1000)) for i in range(0,11)]) 63 | self.ax_.arrow(0, 0.05, 1.01, 0., fc='k', ec='k', lw=0.1, color='black', length_includes_head= True, clip_on = False, overhang=0, width=0.0004) 64 | self.ax_.annotate(r'$t$ $(ms)$', (1.015, 0.07), color='black', 65 | fontsize=20, ha='center', va='center') 66 | fontsize = 18 67 | left_margin = 0. 
68 | self.ax_.text(left_margin, self.start_+3*self.height_/2, 'Comp.',horizontalalignment='right', color='black', va='center', size=fontsize) 69 | self.ax_.text(left_margin, self.start_+self.height_/2, 'Comm.(WF.)',horizontalalignment='right', color='black', va='center',size=fontsize) 70 | self.ax_.text(left_margin, self.start_-self.height_/2, 'Comm.(S.E.)',horizontalalignment='right', color='black', va='center',size=fontsize) 71 | self.ax_.text(left_margin, self.start_-self.height_-self.height_/2, 'Comm.(M.W.)',horizontalalignment='right', color='black', va='center',size=fontsize) 72 | Bar.initialized = True 73 | 74 | def render(self): 75 | x = self.start_time_ / self.max_time_ 76 | y = self.y_ 77 | if self.duration_ > 0.0: 78 | rect = Rectangle((x, y), self.duration_, self.height_, axes=self.ax_, color=self.color_, ec='black', alpha=0.8) 79 | self.ax_.add_patch(rect) 80 | fz = 16 81 | if str(self.index_).find(',') > 0: 82 | fz = 16 83 | self.index_ = self.index_.split(',')[0]+'-'+self.index_.split(',')[-1] 84 | self.ax_.annotate(str(self.index_), (x, y+0.02), color='black', 85 | fontsize=fz, ha='left', va='center') 86 | #return rect 87 | 88 | def render_log(filename): 89 | sizes = [] 90 | computes = [] 91 | comms = [] 92 | sizes, comms, computes, merged_comms = read_log(filename) 93 | #sizes = sizes[::-1] 94 | #computes = computes[::-1] 95 | #comms = comms[::-1] 96 | start_time = 0.0 97 | comm_start_time = 0.0 98 | comm = 0.0 99 | max_time = max(np.sum(computes), np.sum(comms)+computes[0]) 100 | fig, ax = plt.subplots(1) 101 | print('sizes: ', sizes) 102 | print('computes: ', computes) 103 | print('communications: ', comms) 104 | for i in range(len(computes)): 105 | comp = computes[i] 106 | bar = Bar(start_time, comp, max_time, ax, type='p') 107 | bar.render() 108 | if comm_start_time + comm > start_time + comp: 109 | comm_start_time = comm_start_time + comm 110 | else: 111 | comm_start_time = start_time + comp 112 | comm = comms[i] 113 | bar_m = Bar(comm_start_time, comm, max_time, ax, type='m') 114 | bar_m.render() 115 | start_time += comp 116 | plt.show() 117 | plt.clf() 118 | plt.scatter(sizes, comms, c='blue') 119 | plt.scatter(sizes, computes, c='red') 120 | plt.show() 121 | 122 | def allreduce_log(filename): 123 | f = open(filename, 'r') 124 | num_of_nodes = 2 125 | sizes = [] 126 | comms = [] 127 | for l in f.readlines(): 128 | if l[0] == '#' or len(l)<10: 129 | continue 130 | items = ' '.join(l.split()).split() 131 | comm = float(items[-1]) 132 | size = int(items[0].split(',')[1]) 133 | num_of_nodes = int(items[0].split(',')[0]) 134 | comms.append(comm) 135 | sizes.append(size) 136 | f.close() 137 | #print('num_of_nodes: ', num_of_nodes) 138 | #print('sizes: ', sizes) 139 | #print('comms: ', comms) 140 | return num_of_nodes, sizes, comms 141 | 142 | 143 | def plot_allreduce_log(filenames): 144 | markers=['-ro', '-go', '-bo'] 145 | for index, fn in enumerate(filenames): 146 | num_of_nodes, sizes, comms = allreduce_log(fn) 147 | line1, = plt.plot(sizes, comms, markers[index]) 148 | plt.show() 149 | plt.clf() 150 | 151 | ax = None 152 | def statastic_gradient_size(filename, label, color, marker): 153 | global ax 154 | sizes, comms, computes, merged_comms = read_log(filename) 155 | if ax is None: 156 | fig, ax = plt.subplots(figsize=(5,4.5)) 157 | fontsize = 14 158 | ax.scatter(range(1, len(sizes)+1), sizes, c=color, label=label, marker=marker, s=40, facecolors='none', edgecolors=color) 159 | #plot_hist(sizes) 160 | ax.set_xlim(left=0) 161 | ax.set_xlabel('Learnable layer ID') 162 | 
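# The scheduling rule in render_log above: layer i's communication can start only
# after its own backward compute has finished AND the previous communication has
# drained, i.e. comm_start[i] = max(comm_start[i-1] + comm[i-1], compute_end[i]).
# A tiny numeric check of that recurrence (durations are illustrative):
computes, comms = [4.0, 3.0, 2.0], [5.0, 2.0, 2.0]
start, comm_start, prev_comm = 0.0, 0.0, 0.0
for comp, comm in zip(computes, comms):
    comm_start = max(comm_start + prev_comm, start + comp)
    start += comp
    prev_comm = comm
    print('comm starts at', comm_start)  # 4.0, 9.0, 11.0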
#plt.ylim(bottom=1e3, top=1e7) 163 | #plt.ylabel('Message size (bytes)') 164 | ax.set_ylabel('# of parameters') 165 | ax.set_yscale("log", nonposy='clip') 166 | ax.legend() 167 | update_fontsize(ax, fontsize) 168 | print('total size: ', np.sum(sizes)) 169 | return sizes 170 | 171 | def statastic_gradient_size_all_cnns(): 172 | filenames = [] 173 | for nn in ['googlenet', 'resnet', 'densenet']: 174 | f = '/media/sf_Shared_Data/gpuhome/repositories/dpBenchmark/tools/caffe/cnn/%s/tmp8comm.log' % nn 175 | filenames.append(f) 176 | #cnns = ['GoogleNet', 'ResNet-50', 'DenseNet'] 177 | cnns = ['GoogleNet', 'ResNet-50'] 178 | colors = ['r', 'g', 'b'] 179 | markers = ['+', 'x', 'd'] 180 | sizes = [] 181 | for i, f in enumerate(filenames): 182 | if i >= len(cnns): 183 | break 184 | s = statastic_gradient_size(f, cnns[i], colors[i], markers[i]) 185 | sizes.extend(s) 186 | #plt.subplots_adjust(left=0.16, bottom=0.13, top=0.93, right=0.96) 187 | plt.subplots_adjust(left=0.18, bottom=0.13, top=0.91, right=0.92) 188 | plt.show() 189 | #plt.savefig('%s/%s.pdf' % (OUTPUT_PATH, 'gradient_distribution')) 190 | sizes = np.array(sizes) 191 | print(np.max(sizes), np.min(sizes)) 192 | 193 | 194 | if __name__ == '__main__': 195 | #render_log(test_file) 196 | #statastic_gradient_size(test_file) 197 | statastic_gradient_size_all_cnns() 198 | -------------------------------------------------------------------------------- /gtopk_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | import time 4 | import torch 5 | import numpy as np 6 | import sys 7 | import argparse, os 8 | import settings 9 | import utils 10 | import logging 11 | import distributed_optimizer as dopt 12 | from mpi4py import MPI 13 | comm = MPI.COMM_WORLD 14 | comm.Set_errhandler(MPI.ERRORS_RETURN) 15 | 16 | from dl_trainer import DLTrainer, _support_datasets, _support_dnns 17 | from compression import compressors 18 | 19 | from settings import logger, formatter 20 | import horovod.torch as hvd 21 | from tensorboardX import SummaryWriter 22 | writer = None 23 | relative_path = None 24 | 25 | 26 | def robust_ssgd(dnn, dataset, data_dir, nworkers, lr, batch_size, nsteps_update, max_epochs, compression=False, compressor='topk', nwpernode=1, sigma_scale=2.5, pretrain=None, density=0.01, prefix=None): 27 | global relative_path 28 | 29 | torch.cuda.set_device(dopt.rank()%nwpernode) 30 | rank = dopt.rank() 31 | if rank != 0: 32 | pretrain = None 33 | 34 | trainer = DLTrainer(rank, nworkers, dist=False, batch_size=batch_size, is_weak_scaling=True, ngpus=1, data_dir=data_dir, dataset=dataset, dnn=dnn, lr=lr, nworkers=nworkers, prefix=prefix+'-ds%s'%str(density), pretrain=pretrain, tb_writer=writer) 35 | 36 | init_epoch = trainer.get_train_epoch() 37 | init_iter = trainer.get_train_iter() 38 | 39 | trainer.set_train_epoch(comm.bcast(init_epoch)) 40 | trainer.set_train_iter(comm.bcast(init_iter)) 41 | 42 | def _error_handler(new_num_workers, new_rank): 43 | logger.info('Error info catched by trainer') 44 | trainer.update_nworker(new_num_workers, new_rank) 45 | 46 | compressor = compressor if compression else 'none' 47 | compressor = compressors[compressor] 48 | is_sparse = compression 49 | 50 | logger.info('Broadcast parameters....') 51 | hvd.broadcast_parameters(trainer.net.state_dict(), root_rank=0) 52 | logger.info('Broadcast parameters finished....') 53 | 54 | norm_clip = None 55 | optimizer = dopt.DistributedOptimizer(trainer.optimizer, 
trainer.net.named_parameters(), compression=compressor, is_sparse=is_sparse, err_handler=_error_handler, layerwise_times=None, sigma_scale=sigma_scale, density=density, norm_clip=norm_clip, writer=writer) 56 | 57 | trainer.update_optimizer(optimizer) 58 | 59 | iters_per_epoch = trainer.get_num_of_training_samples() // (nworkers * batch_size * nsteps_update)  # floor division so range() below gets an int 60 | 61 | times = [] 62 | NUM_OF_DISPLAY = 100 63 | display = NUM_OF_DISPLAY if iters_per_epoch > NUM_OF_DISPLAY else max(1, iters_per_epoch-1) 64 | logger.info('Start training ....') 65 | for epoch in range(max_epochs): 66 | hidden = None 67 | if dnn == 'lstm': 68 | hidden = trainer.net.init_hidden() 69 | for i in range(iters_per_epoch): 70 | s = time.time() 71 | optimizer.zero_grad() 72 | for j in range(nsteps_update): 73 | if j < nsteps_update - 1 and nsteps_update > 1: 74 | optimizer.local = True 75 | else: 76 | optimizer.local = False 77 | if dnn == 'lstm': 78 | _, hidden = trainer.train(1, hidden=hidden) 79 | else: 80 | trainer.train(1) 81 | if dnn == 'lstm': 82 | optimizer.synchronize() 83 | torch.nn.utils.clip_grad_norm_(trainer.net.parameters(), 0.25) 84 | elif dnn == 'lstman4': 85 | optimizer.synchronize() 86 | torch.nn.utils.clip_grad_norm_(trainer.net.parameters(), 400) 87 | trainer.update_model() 88 | times.append(time.time()-s) 89 | if i % display == 0 and i > 0: 90 | time_per_iter = np.mean(times) 91 | logger.info('Time per iteration including communication: %f. Speed: %f images/s, current density: %f', time_per_iter, batch_size * nsteps_update / time_per_iter, optimizer.get_current_density()) 92 | times = [] 93 | optimizer.add_train_epoch() 94 | if settings.PROFILING_NORM: 95 | # For comparison purpose ===> 96 | fn = os.path.join(relative_path, 'gtopknorm-rank%d-epoch%d.npy' % (rank, epoch)) 97 | fn2 = os.path.join(relative_path, 'randknorm-rank%d-epoch%d.npy' % (rank, epoch)) 98 | fn3 = os.path.join(relative_path, 'upbound-rank%d-epoch%d.npy' % (rank, epoch)) 99 | fn5 = os.path.join(relative_path, 'densestd-rank%d-epoch%d.npy' % (rank, epoch)) 100 | arr = [] 101 | arr2 = [] 102 | arr3 = [] 103 | arr4 = [] 104 | arr5 = [] 105 | for gtopk_norm, randk_norm, upbound, xnorm, dense_std in optimizer._allreducer._profiling_norms: 106 | arr.append(gtopk_norm) 107 | arr2.append(randk_norm) 108 | arr3.append(upbound) 109 | arr4.append(xnorm) 110 | arr5.append(dense_std) 111 | arr = np.array(arr) 112 | arr2 = np.array(arr2) 113 | arr3 = np.array(arr3) 114 | arr4 = np.array(arr4) 115 | arr5 = np.array(arr5) 116 | logger.info('[rank:%d][%d] gtopk norm mean: %f, std: %f', rank, epoch, np.mean(arr), np.std(arr)) 117 | logger.info('[rank:%d][%d] randk norm mean: %f, std: %f', rank, epoch, np.mean(arr2), np.std(arr2)) 118 | logger.info('[rank:%d][%d] upbound norm mean: %f, std: %f', rank, epoch, np.mean(arr3), np.std(arr3)) 119 | logger.info('[rank:%d][%d] x norm mean: %f, std: %f', rank, epoch, np.mean(arr4), np.std(arr4)) 120 | logger.info('[rank:%d][%d] dense std mean: %f, std: %f', rank, epoch, np.mean(arr5), np.std(arr5)) 121 | np.save(fn, arr) 122 | np.save(fn2, arr2) 123 | np.save(fn3, arr3) 124 | np.save(fn5, arr5) 125 | # For comparison purpose <=== End 126 | optimizer._allreducer._profiling_norms = [] 127 | optimizer.stop() 128 | 129 | 130 | if __name__ == '__main__': 131 | parser = argparse.ArgumentParser(description="AllReduce trainer") 132 | parser.add_argument('--batch-size', type=int, default=32) 133 | parser.add_argument('--nsteps-update', type=int, default=1) 134 | parser.add_argument('--nworkers', type=int, default=1, help='Just 
for experiments, and it cannot be used in production') 135 | parser.add_argument('--nwpernode', type=int, default=1, help='Number of workers per node') 136 | parser.add_argument('--compression', dest='compression', action='store_true') 137 | parser.add_argument('--compressor', type=str, default='topk', choices=compressors.keys(), help='Specify the compressors if \'compression\' is open') 138 | parser.add_argument('--sigma-scale', type=float, default=2.5, help='Maximum sigma scaler for sparsification') 139 | parser.add_argument('--density', type=float, default=0.01, help='Density for sparsification') 140 | parser.add_argument('--dataset', type=str, default='imagenet', choices=_support_datasets, help='Specify the dataset for training') 141 | parser.add_argument('--dnn', type=str, default='resnet50', choices=_support_dnns, help='Specify the neural network for training') 142 | parser.add_argument('--data-dir', type=str, default='./data', help='Specify the data root path') 143 | parser.add_argument('--lr', type=float, default=0.1, help='Default learning rate') 144 | parser.add_argument('--max-epochs', type=int, default=90, help='Default maximum epochs to train') 145 | parser.add_argument('--pretrain', type=str, default=None, help='Specify the pretrain path') 146 | parser.set_defaults(compression=False) 147 | args = parser.parse_args() 148 | batch_size = args.batch_size * args.nsteps_update 149 | prefix = settings.PREFIX 150 | if args.compression: 151 | prefix = 'comp-' + args.compressor + '-' + prefix 152 | logdir = 'allreduce-%s/%s-n%d-bs%d-lr%.4f-ns%d-sg%.2f-ds%s' % (prefix, args.dnn, args.nworkers, batch_size, args.lr, args.nsteps_update, args.sigma_scale, str(args.density)) 153 | relative_path = './logs/%s'%logdir 154 | utils.create_path(relative_path) 155 | rank = 0 156 | rank = dopt.rank() 157 | hvd.init() 158 | if rank == 0: 159 | tb_runs = './runs/%s'%logdir 160 | writer = SummaryWriter(tb_runs) 161 | logfile = os.path.join(relative_path, settings.hostname+'-'+str(rank)+'.log') 162 | hdlr = logging.FileHandler(logfile) 163 | hdlr.setFormatter(formatter) 164 | logger.addHandler(hdlr) 165 | logger.info('Configurations: %s', args) 166 | 167 | logger.info('Interpreter: %s', sys.version) 168 | robust_ssgd(args.dnn, args.dataset, args.data_dir, args.nworkers, args.lr, args.batch_size, args.nsteps_update, args.max_epochs, args.compression, args.compressor, args.nwpernode, args.sigma_scale, args.pretrain, args.density, prefix) 169 | -------------------------------------------------------------------------------- /decoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # ---------------------------------------------------------------------------- 3 | # Copyright 2015-2016 Nervana Systems Inc. 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
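# The trainer above selects its compressor from compression.py, which is not part
# of this excerpt. As a rough sketch of what a local top-k step does (names here
# are illustrative, not the repo's API): keep only the k = density * n gradient
# entries with the largest magnitude and zero out the rest.
import torch

def topk_sparsify(grad, density=0.001):
    flat = grad.view(-1)
    k = max(1, int(flat.numel() * density))
    _, idx = torch.topk(flat.abs(), k)   # indices of the k largest-magnitude entries
    sparse = torch.zeros_like(flat)
    sparse[idx] = flat[idx]
    return sparse.view_as(grad), idx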
15 | # ---------------------------------------------------------------------------- 16 | # Modified to support pytorch Tensors 17 | 18 | import Levenshtein as Lev 19 | import torch 20 | from six.moves import xrange 21 | 22 | 23 | class Decoder(object): 24 | """ 25 | Basic decoder class from which all other decoders inherit. Implements several 26 | helper functions. Subclasses should implement the decode() method. 27 | 28 | Arguments: 29 | labels (string): mapping from integers to characters. 30 | blank_index (int, optional): index for the blank '_' character. Defaults to 0. 31 | space_index (int, optional): index for the space ' ' character. Defaults to 28. 32 | """ 33 | 34 | def __init__(self, labels, blank_index=0): 35 | # e.g. labels = "_'ABCDEFGHIJKLMNOPQRSTUVWXYZ#" 36 | self.labels = labels 37 | self.int_to_char = dict([(i, c) for (i, c) in enumerate(labels)]) 38 | self.blank_index = blank_index 39 | space_index = len(labels) # To prevent errors in decode, we add an out of bounds index for the space 40 | if ' ' in labels: 41 | space_index = labels.index(' ') 42 | self.space_index = space_index 43 | 44 | def wer(self, s1, s2): 45 | """ 46 | Computes the Word Error Rate, defined as the edit distance between the 47 | two provided sentences after tokenizing to words. 48 | Arguments: 49 | s1 (string): space-separated sentence 50 | s2 (string): space-separated sentence 51 | """ 52 | 53 | # build mapping of words to integers 54 | b = set(s1.split() + s2.split()) 55 | word2char = dict(zip(b, range(len(b)))) 56 | 57 | # map the words to a char array (Levenshtein packages only accepts 58 | # strings) 59 | w1 = [chr(word2char[w]) for w in s1.split()] 60 | w2 = [chr(word2char[w]) for w in s2.split()] 61 | 62 | return Lev.distance(''.join(w1), ''.join(w2)) 63 | 64 | def cer(self, s1, s2): 65 | """ 66 | Computes the Character Error Rate, defined as the edit distance. 
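# How wer() above sidesteps the string-only Levenshtein API: every distinct word is
# mapped to one private character, so a word-level edit distance reduces to a
# character-level one. A worked example:
#   s1 = 'the cat sat', s2 = 'the cat sat down'
#   word2char = {'the': 'a', 'cat': 'b', 'sat': 'c', 'down': 'd'}  (any 1-to-1 map works)
#   w1 = 'abc', w2 = 'abcd'  ->  Lev.distance(w1, w2) == 1, i.e. one inserted word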
67 | 
68 |         Arguments:
69 |             s1 (string): space-separated sentence
70 |             s2 (string): space-separated sentence
71 |         """
72 |         s1, s2 = s1.replace(' ', ''), s2.replace(' ', '')
73 |         return Lev.distance(s1, s2)
74 | 
75 |     def decode(self, probs, sizes=None):
76 |         """
77 |         Given a matrix of character probabilities, returns the decoder's
78 |         best guess of the transcription
79 | 
80 |         Arguments:
81 |             probs: Tensor of character probabilities, where probs[c,t]
82 |                    is the probability of character c at time t
83 |             sizes(optional): Size of each sequence in the mini-batch
84 |         Returns:
85 |             string: sequence of the model's best guess for the transcription
86 |         """
87 |         raise NotImplementedError
88 | 
89 | 
90 | class BeamCTCDecoder(Decoder):
91 |     def __init__(self, labels, lm_path=None, alpha=0, beta=0, cutoff_top_n=40, cutoff_prob=1.0, beam_width=100,
92 |                  num_processes=4, blank_index=0):
93 |         super(BeamCTCDecoder, self).__init__(labels)
94 |         try:
95 |             from ctcdecode import CTCBeamDecoder
96 |         except ImportError:
97 |             raise ImportError("BeamCTCDecoder requires the ctcdecode package.")
98 |         self._decoder = CTCBeamDecoder(labels, lm_path, alpha, beta, cutoff_top_n, cutoff_prob, beam_width,
99 |                                        num_processes, blank_index)
100 | 
101 |     def convert_to_strings(self, out, seq_len):
102 |         results = []
103 |         for b, batch in enumerate(out):
104 |             utterances = []
105 |             for p, utt in enumerate(batch):
106 |                 size = seq_len[b][p]
107 |                 if size > 0:
108 |                     transcript = ''.join(map(lambda x: self.int_to_char[x.item()], utt[0:size]))
109 |                 else:
110 |                     transcript = ''
111 |                 utterances.append(transcript)
112 |             results.append(utterances)
113 |         return results
114 | 
115 |     def convert_tensor(self, offsets, sizes):
116 |         results = []
117 |         for b, batch in enumerate(offsets):
118 |             utterances = []
119 |             for p, utt in enumerate(batch):
120 |                 size = sizes[b][p]
121 |                 if sizes[b][p] > 0:
122 |                     utterances.append(utt[0:size])
123 |                 else:
124 |                     utterances.append(torch.tensor([], dtype=torch.int))
125 |             results.append(utterances)
126 |         return results
127 | 
128 |     def decode(self, probs, sizes=None):
129 |         """
130 |         Decodes probability output using the ctcdecode package.
131 | Arguments: 132 | probs: Tensor of character probabilities, where probs[c,t] 133 | is the probability of character c at time t 134 | sizes: Size of each sequence in the mini-batch 135 | Returns: 136 | string: sequences of the model's best guess for the transcription 137 | """ 138 | probs = probs.cpu() 139 | out, scores, offsets, seq_lens = self._decoder.decode(probs, sizes) 140 | 141 | strings = self.convert_to_strings(out, seq_lens) 142 | offsets = self.convert_tensor(offsets, seq_lens) 143 | return strings, offsets 144 | 145 | 146 | class GreedyDecoder(Decoder): 147 | def __init__(self, labels, blank_index=0): 148 | super(GreedyDecoder, self).__init__(labels, blank_index) 149 | 150 | def convert_to_strings(self, sequences, sizes=None, remove_repetitions=False, return_offsets=False): 151 | """Given a list of numeric sequences, returns the corresponding strings""" 152 | strings = [] 153 | offsets = [] if return_offsets else None 154 | for x in xrange(len(sequences)): 155 | seq_len = sizes[x] if sizes is not None else len(sequences[x]) 156 | string, string_offsets = self.process_string(sequences[x], seq_len, remove_repetitions) 157 | strings.append([string]) # We only return one path 158 | if return_offsets: 159 | offsets.append([string_offsets]) 160 | if return_offsets: 161 | return strings, offsets 162 | else: 163 | return strings 164 | 165 | def process_string(self, sequence, size, remove_repetitions=False): 166 | string = '' 167 | offsets = [] 168 | for i in range(size): 169 | char = self.int_to_char[sequence[i].item()] 170 | if char != self.int_to_char[self.blank_index]: 171 | # if this char is a repetition and remove_repetitions=true, then skip 172 | if remove_repetitions and i != 0 and char == self.int_to_char[sequence[i - 1].item()]: 173 | pass 174 | elif char == self.labels[self.space_index]: 175 | string += ' ' 176 | offsets.append(i) 177 | else: 178 | string = string + char 179 | offsets.append(i) 180 | return string, torch.tensor(offsets, dtype=torch.int) 181 | 182 | def decode(self, probs, sizes=None): 183 | """ 184 | Returns the argmax decoding given the probability matrix. Removes 185 | repeated elements in the sequence, as well as blanks. 186 | 187 | Arguments: 188 | probs: Tensor of character probabilities from the network. 
Expected shape of batch x seq_length x output_dim 189 | sizes(optional): Size of each sequence in the mini-batch 190 | Returns: 191 | strings: sequences of the model's best guess for the transcription on inputs 192 | offsets: time step per character predicted 193 | """ 194 | _, max_probs = torch.max(probs, 2) 195 | strings, offsets = self.convert_to_strings(max_probs.view(max_probs.size(0), max_probs.size(1)), sizes, 196 | remove_repetitions=True, return_offsets=True) 197 | return strings, offsets 198 | -------------------------------------------------------------------------------- /distributed_optimizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | import threading 7 | import allreducer as ar 8 | import torch 9 | import torch.nn as nn 10 | import time 11 | from mpi4py import MPI 12 | from compression import NoneCompressor 13 | from settings import logger 14 | is_py2 = sys.version[0] == '2' 15 | if is_py2: 16 | import Queue 17 | else: 18 | import queue as Queue 19 | 20 | 21 | class _DistributedOptimizer(torch.optim.Optimizer): 22 | def __init__(self, params, named_parameters, compressor, is_sparse=True, err_handler=None, layerwise_times=None, sigma_scale=2.5, density=0.01, norm_clip=None, writer=None): 23 | super(self.__class__, self).__init__(params) 24 | self._compressor= compressor 25 | self._sparse = is_sparse 26 | self._layerwise_times = layerwise_times 27 | self._msg_queue = Queue.Queue() 28 | self._msg_queue2 = Queue.Queue() 29 | 30 | if named_parameters is not None: 31 | named_parameters = list(named_parameters) 32 | else: 33 | named_parameters = [] 34 | 35 | # make sure that named_parameters are tuples 36 | if any([not isinstance(p, tuple) for p in named_parameters]): 37 | raise ValueError('named_parameters should be a sequence of ' 38 | 'tuples (name, parameter), usually produced by ' 39 | 'model.named_parameters().') 40 | 41 | if len(named_parameters) > 0: 42 | self._parameter_names = {v: k for k, v 43 | in sorted(named_parameters)} 44 | else: 45 | self._parameter_names = {v: 'allreduce.noname.%s' % i 46 | for param_group in self.param_groups 47 | for i, v in enumerate(param_group['params'])} 48 | 49 | self._handles = {} 50 | self._grad_accs = [] 51 | self._requires_update = set() 52 | self._register_hooks() 53 | 54 | self._lock = threading.Lock() 55 | self._key_lock = threading.Lock() 56 | self.momentum_correction = False 57 | self._allreducer = ar.AllReducer(named_parameters, self._lock, self._key_lock, compressor, sparse=self._sparse, err_callback=err_handler, layerwise_times=layerwise_times, sigma_scale=sigma_scale, density=density, norm_clip=norm_clip, msg_queue=self._msg_queue, msg_queue2=self._msg_queue2, writer=writer) 58 | self.allreducer_thread = threading.Thread(name='allreducer', target=self._allreducer.run) 59 | self.allreducer_thread.start() 60 | self.local = False 61 | self._synced = False 62 | 63 | def _register_hooks(self): 64 | for param_group in self.param_groups: 65 | for p in param_group['params']: 66 | if p.requires_grad: 67 | p.grad = p.data.new(p.size()).zero_() 68 | self._requires_update.add(p) 69 | p_tmp = p.expand_as(p) 70 | grad_acc = p_tmp.grad_fn.next_functions[0][0] 71 | grad_acc.register_hook(self._make_hook(p)) 72 | self._grad_accs.append(grad_acc) 73 | 74 | def _make_hook(self, p): 75 | def hook(*ignore): 76 | assert p not in self._handles 77 | assert not 
p.grad.requires_grad
78 |             if not self.local:
79 |                 name = self._parameter_names.get(p)
80 |                 d_p = p.grad.data
81 |                 if self.momentum_correction:
82 |                     param_state = self.state[p]
83 |                     momentum = 0.9
84 |                     if 'momentum_buffer' not in param_state:
85 |                         param_state['momentum_buffer'] = torch.zeros_like(p.data)
86 |                     buf = param_state['momentum_buffer']
87 |                     buf.mul_(momentum).add_(d_p)
88 |                     d_p = buf
89 |                 self._handles[p] = self._allreducer.add_tensor(name, d_p)
90 |                 torch.cuda.synchronize()
91 |                 #if rank() == 0:
92 |                 #    logger.info('-->pushed time [%s]: %s, norm: %f', name, time.time(), p.grad.data.norm())
93 |                 self._msg_queue.put(name)
94 |         return hook
95 | 
96 |     def synchronize(self):
97 |         if not self._synced:
98 |             self._msg_queue2.get() # wait for allreducer
99 |             self._synced = True
100 |         for p, value in self._handles.items():
101 |             output = self._allreducer.get_result(value)
102 |             p.grad.data.set_(output.data)
103 |         self._handles.clear()
104 | 
105 |     def _step(self, closure=None):
106 |         """Performs a single optimization step.
107 |         Arguments:
108 |             closure (callable, optional): A closure that reevaluates the model
109 |                 and returns the loss.
110 |         """
111 |         loss = None
112 |         if closure is not None:
113 |             loss = closure()
114 | 
115 |         for group in self.param_groups:
116 |             weight_decay = group['weight_decay']
117 |             momentum = group['momentum']
118 |             dampening = group['dampening']
119 |             nesterov = group['nesterov']
120 | 
121 |             for p in group['params']:
122 |                 if p.grad is None:
123 |                     continue
124 |                 d_p = p.grad.data
125 |                 name = self._parameter_names.get(p)
126 |                 if weight_decay != 0:
127 |                     d_p.add_(weight_decay, p.data)
128 |                 #if name.find('bias') >= 0 or name.find('bn') >= 0:
129 |                 #    print('batch norm or bias detected, continue, %s' % name)
130 |                 if momentum != 0 and not self.momentum_correction:
131 |                     param_state = self.state[p]
132 |                     if 'momentum_buffer' not in param_state:
133 |                         buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
134 |                         buf.mul_(momentum).add_(d_p)
135 |                     else:
136 |                         buf = param_state['momentum_buffer']
137 |                         buf.mul_(momentum).add_(1 - dampening, d_p)
138 |                     if nesterov:
139 |                         d_p = d_p.add(momentum, buf)
140 |                     else:
141 |                         d_p = buf
142 |                 p.data.add_(-group['lr'], d_p)
143 |         return loss
144 | 
145 |     def _step_with_mc(self, closure=None):
146 |         """Performs a single optimization step with momentum correction.
147 |         Arguments:
148 |             closure (callable, optional): A closure that reevaluates the model
149 |                 and returns the loss.
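            Note (added for clarity): when ``momentum_correction`` is enabled, the
            backward hook built in ``_make_hook`` above already folds a momentum
            buffer into the gradient before it is compressed and all-reduced, in
            the spirit of the momentum-correction idea from the gradient-compression
            literature.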
150 |         """
151 |         loss = None
152 |         if closure is not None:
153 |             loss = closure()
154 | 
155 |         for group in self.param_groups:
156 |             weight_decay = group['weight_decay']
157 |             momentum = group['momentum']
158 |             dampening = group['dampening']
159 |             nesterov = group['nesterov']
160 | 
161 |             for p in group['params']:
162 |                 if p.grad is None:
163 |                     continue
164 |                 d_p = p.grad.data
165 |                 name = self._parameter_names.get(p)
166 |                 if weight_decay != 0:
167 |                     d_p.add_(weight_decay, p.data)
168 |                 if momentum != 0:
169 |                     param_state = self.state[p]
170 |                     if 'momentum_buffer' not in param_state:
171 |                         buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
172 |                         buf.mul_(momentum).add_(d_p)
173 |                     else:
174 |                         buf = param_state['momentum_buffer']
175 |                         buf.mul_(momentum).add_(1 - dampening, d_p)
176 |                     if nesterov:
177 |                         d_p = d_p.add(momentum, buf)
178 |                     else:
179 |                         d_p = buf
180 |                 p.data.add_(-group['lr'], d_p)
181 |         return loss
182 | 
183 |     def step(self, closure=None):
184 |         if not self.local:
185 |             self.synchronize()
186 |         ret = self._step(closure)
187 |         self._synced = False
188 |         return ret
189 | 
190 |     def stop(self):
191 |         self._allreducer.stop()
192 |         self._msg_queue.put('STOP')
193 | 
194 |     def add_train_epoch(self):
195 |         self._allreducer.train_epoch += 1
196 | 
197 |     def get_current_density(self):
198 |         return self._allreducer.get_current_density()
199 | 
200 | 
201 | def DistributedOptimizer(optimizer, named_parameters=None, compression=NoneCompressor, is_sparse=False, err_handler=None, layerwise_times=None, sigma_scale=2.5, density=0.1, norm_clip=None, writer=None):
202 |     cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
203 |                dict(_DistributedOptimizer.__dict__))
204 | 
205 |     return cls(optimizer.param_groups, named_parameters, compression, is_sparse, err_handler, layerwise_times, sigma_scale=sigma_scale, density=density, norm_clip=norm_clip, writer=writer)
206 | 
207 | def rank():
208 |     return MPI.COMM_WORLD.rank
209 | 
210 | def size():
211 |     return MPI.COMM_WORLD.size
212 | 
213 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /audio_data/data_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from tempfile import NamedTemporaryFile 4 | 5 | from torch.distributed import get_rank 6 | from torch.distributed import get_world_size 7 | from torch.utils.data.sampler import Sampler 8 | 9 | import librosa 10 | import numpy as np 11 | import scipy.signal 12 | import torch 13 | import torchaudio 14 | import math 15 | from torch.utils.data import DataLoader 16 | from torch.utils.data import Dataset 17 | 18 | windows = {'hamming': scipy.signal.hamming, 'hann': scipy.signal.hann, 'blackman': scipy.signal.blackman, 19 | 'bartlett': scipy.signal.bartlett} 20 | 21 | 22 | def load_audio(path): 23 | sound, _ = torchaudio.load(path, normalization=True) 24 | sound = sound.numpy() 25 | if len(sound.shape) > 1: 26 | if sound.shape[0] == 1: 27 | sound = sound.squeeze() 28 | else: 29 | sound = sound.mean(axis=1) # multiple channels, average 30 | return sound 31 | 32 | 33 | class AudioParser(object): 34 | def parse_transcript(self, transcript_path): 35 | """ 36 | :param transcript_path: Path where transcript is stored from the manifest file 37 | :return: Transcript in training/testing format 38 | """ 39 | raise NotImplementedError 40 | 41 | def parse_audio(self, audio_path): 42 | """ 43 | :param audio_path: Path where audio is stored from the manifest file 44 | :return: Audio in training/testing format 45 | """ 46 | raise NotImplementedError 47 | 48 | 49 | class NoiseInjection(object): 50 | def __init__(self, 51 | path=None, 52 | sample_rate=16000, 53 | noise_levels=(0, 0.5)): 54 | """ 55 | Adds noise to an input signal with specific SNR. Higher the noise level, the more noise added. 
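        The chosen noise clip is rescaled by noise_level * RMS(data) / RMS(noise)
        before being mixed in (see inject_noise_sample below), so noise_level
        roughly sets the noise-to-signal amplitude ratio.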
56 | Modified code from https://github.com/willfrey/audio/blob/master/torchaudio/transforms.py 57 | """ 58 | if not os.path.exists(path): 59 | print("Directory doesn't exist: {}".format(path)) 60 | raise IOError 61 | self.paths = path is not None and librosa.util.find_files(path) 62 | self.sample_rate = sample_rate 63 | self.noise_levels = noise_levels 64 | 65 | def inject_noise(self, data): 66 | noise_path = np.random.choice(self.paths) 67 | noise_level = np.random.uniform(*self.noise_levels) 68 | return self.inject_noise_sample(data, noise_path, noise_level) 69 | 70 | def inject_noise_sample(self, data, noise_path, noise_level): 71 | noise_len = get_audio_length(noise_path) 72 | data_len = len(data) / self.sample_rate 73 | noise_start = np.random.rand() * (noise_len - data_len) 74 | noise_end = noise_start + data_len 75 | noise_dst = audio_with_sox(noise_path, self.sample_rate, noise_start, noise_end) 76 | assert len(data) == len(noise_dst) 77 | noise_energy = np.sqrt(noise_dst.dot(noise_dst) / noise_dst.size) 78 | data_energy = np.sqrt(data.dot(data) / data.size) 79 | data += noise_level * noise_dst * data_energy / noise_energy 80 | return data 81 | 82 | 83 | class SpectrogramParser(AudioParser): 84 | def __init__(self, audio_conf, normalize=False, augment=False): 85 | """ 86 | Parses audio file into spectrogram with optional normalization and various augmentations 87 | :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds 88 | :param normalize(default False): Apply standard mean and deviation normalization to audio tensor 89 | :param augment(default False): Apply random tempo and gain perturbations 90 | """ 91 | super(SpectrogramParser, self).__init__() 92 | self.window_stride = audio_conf['window_stride'] 93 | self.window_size = audio_conf['window_size'] 94 | self.sample_rate = audio_conf['sample_rate'] 95 | self.window = windows.get(audio_conf['window'], windows['hamming']) 96 | self.normalize = normalize 97 | self.augment = augment 98 | self.noiseInjector = NoiseInjection(audio_conf['noise_dir'], self.sample_rate, 99 | audio_conf['noise_levels']) if audio_conf.get( 100 | 'noise_dir') is not None else None 101 | self.noise_prob = audio_conf.get('noise_prob') 102 | 103 | def parse_audio(self, audio_path): 104 | if self.augment: 105 | y = load_randomly_augmented_audio(audio_path, self.sample_rate) 106 | else: 107 | y = load_audio(audio_path) 108 | if self.noiseInjector: 109 | add_noise = np.random.binomial(1, self.noise_prob) 110 | if add_noise: 111 | y = self.noiseInjector.inject_noise(y) 112 | n_fft = int(self.sample_rate * self.window_size) 113 | win_length = n_fft 114 | hop_length = int(self.sample_rate * self.window_stride) 115 | # STFT 116 | D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, 117 | win_length=win_length, window=self.window) 118 | spect, phase = librosa.magphase(D) 119 | # S = log(S+1) 120 | spect = np.log1p(spect) 121 | spect = torch.FloatTensor(spect) 122 | if self.normalize: 123 | mean = spect.mean() 124 | std = spect.std() 125 | spect.add_(-mean) 126 | spect.div_(std) 127 | 128 | return spect 129 | 130 | def parse_transcript(self, transcript_path): 131 | raise NotImplementedError 132 | 133 | 134 | class SpectrogramDataset(Dataset, SpectrogramParser): 135 | def __init__(self, audio_conf, manifest_filepath, labels, normalize=False, augment=False): 136 | """ 137 | Dataset that loads tensors via a csv containing file paths to audio files and transcripts separated by 138 | a comma. 
Each new line is a different sample. Example below: 139 | 140 | /path/to/audio.wav,/path/to/audio.txt 141 | ... 142 | 143 | :param audio_conf: Dictionary containing the sample rate, window and the window length/stride in seconds 144 | :param manifest_filepath: Path to manifest csv as describe above 145 | :param labels: String containing all the possible characters to map to 146 | :param normalize: Apply standard mean and deviation normalization to audio tensor 147 | :param augment(default False): Apply random tempo and gain perturbations 148 | """ 149 | with open(manifest_filepath) as f: 150 | ids = f.readlines() 151 | ids = [x.strip().split(',') for x in ids] 152 | self.ids = ids 153 | self.size = len(ids) 154 | self.labels_map = dict([(labels[i], i) for i in range(len(labels))]) 155 | super(SpectrogramDataset, self).__init__(audio_conf, normalize, augment) 156 | 157 | def __getitem__(self, index): 158 | sample = self.ids[index] 159 | audio_path, transcript_path = sample[0], sample[1] 160 | spect = self.parse_audio(audio_path) 161 | transcript = self.parse_transcript(transcript_path) 162 | return spect, transcript 163 | 164 | def parse_transcript(self, transcript_path): 165 | # with open(transcript_path, 'r', encoding='utf8') as transcript_file: 166 | with open(transcript_path, 'r') as transcript_file: 167 | transcript = transcript_file.read().replace('\n', '') 168 | transcript = list(filter(None, [self.labels_map.get(x) for x in list(transcript)])) 169 | return transcript 170 | 171 | def __len__(self): 172 | return self.size 173 | 174 | 175 | def _collate_fn(batch): 176 | def func(p): 177 | return p[0].size(1) 178 | 179 | batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True) 180 | longest_sample = max(batch, key=func)[0] 181 | freq_size = longest_sample.size(0) 182 | minibatch_size = len(batch) 183 | max_seqlength = longest_sample.size(1) 184 | inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength) 185 | input_percentages = torch.FloatTensor(minibatch_size) 186 | target_sizes = torch.IntTensor(minibatch_size) 187 | targets = [] 188 | for x in range(minibatch_size): 189 | sample = batch[x] 190 | tensor = sample[0] 191 | target = sample[1] 192 | seq_length = tensor.size(1) 193 | inputs[x][0].narrow(1, 0, seq_length).copy_(tensor) 194 | input_percentages[x] = seq_length / float(max_seqlength) 195 | target_sizes[x] = len(target) 196 | targets.extend(target) 197 | targets = torch.IntTensor(targets) 198 | return inputs, targets, input_percentages, target_sizes 199 | 200 | 201 | class AudioDataLoader(DataLoader): 202 | def __init__(self, *args, **kwargs): 203 | """ 204 | Creates a data loader for AudioDatasets. 205 | """ 206 | super(AudioDataLoader, self).__init__(*args, **kwargs) 207 | self.collate_fn = _collate_fn 208 | 209 | 210 | class BucketingSampler(Sampler): 211 | def __init__(self, data_source, batch_size=1): 212 | """ 213 | Samples batches assuming they are in order of size to batch similarly sized samples together. 
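        Indices are pre-chunked into consecutive bins of batch_size; __iter__
        shuffles samples within each bin, and shuffle() permutes whole bins
        between epochs, so similarly sized utterances stay batched together.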
214 | """ 215 | super(BucketingSampler, self).__init__(data_source) 216 | self.data_source = data_source 217 | ids = list(range(0, len(data_source))) 218 | self.bins = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)] 219 | 220 | def __iter__(self): 221 | for ids in self.bins: 222 | np.random.shuffle(ids) 223 | yield ids 224 | 225 | def __len__(self): 226 | return len(self.bins) 227 | 228 | def shuffle(self, epoch): 229 | np.random.shuffle(self.bins) 230 | 231 | 232 | class DistributedBucketingSampler(Sampler): 233 | def __init__(self, data_source, batch_size=1, num_replicas=None, rank=None): 234 | """ 235 | Samples batches assuming they are in order of size to batch similarly sized samples together. 236 | """ 237 | super(DistributedBucketingSampler, self).__init__(data_source) 238 | if num_replicas is None: 239 | num_replicas = get_world_size() 240 | if rank is None: 241 | rank = get_rank() 242 | self.data_source = data_source 243 | self.ids = list(range(0, len(data_source))) 244 | self.batch_size = batch_size 245 | self.bins = [self.ids[i:i + batch_size] for i in range(0, len(self.ids), batch_size)] 246 | self.num_replicas = num_replicas 247 | self.rank = rank 248 | self.num_samples = int(math.ceil(len(self.bins) * 1.0 / self.num_replicas)) 249 | self.total_size = self.num_samples * self.num_replicas 250 | 251 | def __iter__(self): 252 | offset = self.rank 253 | # add extra samples to make it evenly divisible 254 | bins = self.bins + self.bins[:(self.total_size - len(self.bins))] 255 | assert len(bins) == self.total_size 256 | samples = bins[offset::self.num_replicas] # Get every Nth bin, starting from rank 257 | return iter(samples) 258 | 259 | def __len__(self): 260 | return self.num_samples 261 | 262 | def shuffle(self, epoch): 263 | # deterministically shuffle based on epoch 264 | g = torch.Generator() 265 | g.manual_seed(epoch) 266 | bin_ids = list(torch.randperm(len(self.bins), generator=g)) 267 | self.bins = [self.bins[i] for i in bin_ids] 268 | 269 | def set_epoch(self, epoch): 270 | self.epoch = epoch 271 | 272 | def get_audio_length(path): 273 | output = subprocess.check_output(['soxi -D \"%s\"' % path.strip()], shell=True) 274 | return float(output) 275 | 276 | 277 | def audio_with_sox(path, sample_rate, start_time, end_time): 278 | """ 279 | crop and resample the recording with sox and loads it. 280 | """ 281 | with NamedTemporaryFile(suffix=".wav") as tar_file: 282 | tar_filename = tar_file.name 283 | sox_params = "sox \"{}\" -r {} -c 1 -b 16 -e si {} trim {} ={} >/dev/null 2>&1".format(path, sample_rate, 284 | tar_filename, start_time, 285 | end_time) 286 | os.system(sox_params) 287 | y = load_audio(tar_filename) 288 | return y 289 | 290 | 291 | def augment_audio_with_sox(path, sample_rate, tempo, gain): 292 | """ 293 | Changes tempo and gain of the recording with sox and loads it. 294 | """ 295 | with NamedTemporaryFile(suffix=".wav") as augmented_file: 296 | augmented_filename = augmented_file.name 297 | sox_augment_params = ["tempo", "{:.3f}".format(tempo), "gain", "{:.3f}".format(gain)] 298 | sox_params = "sox \"{}\" -r {} -c 1 -b 16 -e si {} {} >/dev/null 2>&1".format(path, sample_rate, 299 | augmented_filename, 300 | " ".join(sox_augment_params)) 301 | os.system(sox_params) 302 | y = load_audio(augmented_filename) 303 | return y 304 | 305 | 306 | def load_randomly_augmented_audio(path, sample_rate=16000, tempo_range=(0.85, 1.15), 307 | gain_range=(-6, 8)): 308 | """ 309 | Picks tempo and gain uniformly, applies it to the utterance by using sox utility. 
310 | Returns the augmented utterance. 311 | """ 312 | low_tempo, high_tempo = tempo_range 313 | tempo_value = np.random.uniform(low=low_tempo, high=high_tempo) 314 | low_gain, high_gain = gain_range 315 | gain_value = np.random.uniform(low=low_gain, high=high_gain) 316 | audio = augment_audio_with_sox(path=path, sample_rate=sample_rate, 317 | tempo=tempo_value, gain=gain_value) 318 | return audio 319 | -------------------------------------------------------------------------------- /models/lstm_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | 4 | import math 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn.parameter import Parameter 11 | from torch.autograd import Variable 12 | 13 | supported_rnns = { 14 | 'lstm': nn.LSTM, 15 | 'rnn': nn.RNN, 16 | 'gru': nn.GRU 17 | } 18 | supported_rnns_inv = dict((v, k) for k, v in supported_rnns.items()) 19 | 20 | 21 | class SequenceWise(nn.Module): 22 | def __init__(self, module): 23 | """ 24 | Collapses input of dim T*N*H to (T*N)*H, and applies to a module. 25 | Allows handling of variable sequence lengths and minibatch sizes. 26 | :param module: Module to apply input to. 27 | """ 28 | super(SequenceWise, self).__init__() 29 | self.module = module 30 | 31 | def forward(self, x): 32 | t, n = x.size(0), x.size(1) 33 | x = x.view(t * n, -1) 34 | x = self.module(x) 35 | x = x.view(t, n, -1) 36 | return x 37 | 38 | def __repr__(self): 39 | tmpstr = self.__class__.__name__ + ' (\n' 40 | tmpstr += self.module.__repr__() 41 | tmpstr += ')' 42 | return tmpstr 43 | 44 | 45 | class MaskConv(nn.Module): 46 | def __init__(self, seq_module): 47 | """ 48 | Adds padding to the output of the module based on the given lengths. This is to ensure that the 49 | results of the model do not change when batch sizes change during inference. 50 | Input needs to be in the shape of (BxCxDxT) 51 | :param seq_module: The sequential module containing the conv stack. 
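    Positions beyond each sequence's true length are re-zeroed after every
    sub-module, so padding added for the longest utterance in a batch never
    leaks into the valid frames of shorter ones.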
52 | """ 53 | super(MaskConv, self).__init__() 54 | self.seq_module = seq_module 55 | 56 | def forward(self, x, lengths): 57 | """ 58 | :param x: The input of size BxCxDxT 59 | :param lengths: The actual length of each sequence in the batch 60 | :return: Masked output from the module 61 | """ 62 | for module in self.seq_module: 63 | x = module(x) 64 | mask = torch.ByteTensor(x.size()).fill_(0) 65 | if x.is_cuda: 66 | mask = mask.cuda() 67 | for i, length in enumerate(lengths): 68 | length = length.item() 69 | if (mask[i].size(2) - length) > 0: 70 | mask[i].narrow(2, length, mask[i].size(2) - length).fill_(1) 71 | x = x.masked_fill(mask, 0) 72 | return x, lengths 73 | 74 | 75 | class InferenceBatchSoftmax(nn.Module): 76 | def forward(self, input_): 77 | if not self.training: 78 | return F.softmax(input_, dim=-1) 79 | else: 80 | return input_ 81 | 82 | 83 | class BatchRNN(nn.Module): 84 | def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True): 85 | super(BatchRNN, self).__init__() 86 | self.input_size = input_size 87 | self.hidden_size = hidden_size 88 | self.bidirectional = bidirectional 89 | self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size)) if batch_norm else None 90 | self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size, 91 | bidirectional=bidirectional, bias=True) 92 | self.num_directions = 2 if bidirectional else 1 93 | 94 | def flatten_parameters(self): 95 | self.rnn.flatten_parameters() 96 | 97 | def forward(self, x, output_lengths): 98 | if self.batch_norm is not None: 99 | x = self.batch_norm(x) 100 | x = nn.utils.rnn.pack_padded_sequence(x, output_lengths) 101 | x, h = self.rnn(x) 102 | x, _ = nn.utils.rnn.pad_packed_sequence(x) 103 | if self.bidirectional: 104 | x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1) # (TxNxH*2) -> (TxNxH) by sum 105 | return x 106 | 107 | 108 | class Lookahead(nn.Module): 109 | # Wang et al 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks 110 | # input shape - sequence, batch, feature - TxNxH 111 | # output shape - same as input 112 | def __init__(self, n_features, context): 113 | # should we handle batch_first=True? 114 | super(Lookahead, self).__init__() 115 | self.n_features = n_features 116 | self.weight = Parameter(torch.Tensor(n_features, context + 1)) 117 | assert context > 0 118 | self.context = context 119 | self.register_parameter('bias', None) 120 | self.init_parameters() 121 | 122 | def init_parameters(self): # what's a better way initialiase this layer? 123 | stdv = 1. / math.sqrt(self.weight.size(1)) 124 | self.weight.data.uniform_(-stdv, stdv) 125 | 126 | def forward(self, input): 127 | seq_len = input.size(0) 128 | # pad the 0th dimension (T/sequence) with zeroes whose number = context 129 | # Once pytorch's padding functions have settled, should move to those. 
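        # Shape walk-through of the steps below (comment added for clarity):
        #   input:   (T, N, H)
        #   padded:  (T + context, N, H) after appending `context` zero frames
        #   windows: T slices of (context + 1, N, H), stacked to (T, C+1, N, H)
        #   output:  (T, N, H) after weighting by self.weight and summing over
        #            the context dimension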
130 | padding = torch.zeros(self.context, *(input.size()[1:])).type_as(input.data) 131 | x = torch.cat((input, Variable(padding)), 0) 132 | 133 | # add lookahead windows (with context+1 width) as a fourth dimension 134 | # for each seq-batch-feature combination 135 | x = [x[i:i + self.context + 1] for i in range(seq_len)] # TxLxNxH - sequence, context, batch, feature 136 | x = torch.stack(x) 137 | x = x.permute(0, 2, 3, 1) # TxNxHxL - sequence, batch, feature, context 138 | 139 | x = torch.mul(x, self.weight).sum(dim=3) 140 | return x 141 | 142 | def __repr__(self): 143 | return self.__class__.__name__ + '(' \ 144 | + 'n_features=' + str(self.n_features) \ 145 | + ', context=' + str(self.context) + ')' 146 | 147 | 148 | class DeepSpeech(nn.Module): 149 | def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layers=5, audio_conf=None, 150 | bidirectional=True, context=20): 151 | super(DeepSpeech, self).__init__() 152 | 153 | # model metadata needed for serialization/deserialization 154 | if audio_conf is None: 155 | audio_conf = {} 156 | self._version = '0.0.1' 157 | self._hidden_size = rnn_hidden_size 158 | self._hidden_layers = nb_layers 159 | self._rnn_type = rnn_type 160 | self._audio_conf = audio_conf or {} 161 | self._labels = labels 162 | self._bidirectional = bidirectional 163 | 164 | sample_rate = self._audio_conf.get("sample_rate", 16000) 165 | window_size = self._audio_conf.get("window_size", 0.02) 166 | num_classes = len(self._labels) 167 | 168 | self.conv = MaskConv(nn.Sequential( 169 | nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)), 170 | nn.BatchNorm2d(32), 171 | nn.Hardtanh(0, 20, inplace=True), 172 | nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)), 173 | nn.BatchNorm2d(32), 174 | nn.Hardtanh(0, 20, inplace=True) 175 | )) 176 | # Based on above convolutions and spectrogram size using conv formula (W - F + 2P)/ S+1 177 | rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1) 178 | rnn_input_size = int(math.floor(rnn_input_size + 2 * 20 - 41) / 2 + 1) 179 | rnn_input_size = int(math.floor(rnn_input_size + 2 * 10 - 21) / 2 + 1) 180 | rnn_input_size *= 32 181 | 182 | rnns = [] 183 | rnn = BatchRNN(input_size=rnn_input_size, hidden_size=rnn_hidden_size, rnn_type=rnn_type, 184 | bidirectional=bidirectional, batch_norm=False) 185 | rnns.append(('0', rnn)) 186 | for x in range(nb_layers - 1): 187 | rnn = BatchRNN(input_size=rnn_hidden_size, hidden_size=rnn_hidden_size, rnn_type=rnn_type, 188 | bidirectional=bidirectional) 189 | rnns.append(('%d' % (x + 1), rnn)) 190 | self.rnns = nn.Sequential(OrderedDict(rnns)) 191 | self.lookahead = nn.Sequential( 192 | # consider adding batch norm? 
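            # (comment added for clarity: this branch is built only for
            #  unidirectional models; a bidirectional RNN already sees future
            #  context, so self.lookahead is None in that case)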
193 | Lookahead(rnn_hidden_size, context=context), 194 | nn.Hardtanh(0, 20, inplace=True) 195 | ) if not bidirectional else None 196 | 197 | fully_connected = nn.Sequential( 198 | nn.BatchNorm1d(rnn_hidden_size), 199 | nn.Linear(rnn_hidden_size, num_classes, bias=False) 200 | ) 201 | self.fc = nn.Sequential( 202 | SequenceWise(fully_connected), 203 | ) 204 | self.inference_softmax = InferenceBatchSoftmax() 205 | 206 | def forward(self, x, lengths): 207 | lengths = lengths.cpu().int() 208 | output_lengths = self.get_seq_lens(lengths) 209 | x, _ = self.conv(x, output_lengths) 210 | 211 | sizes = x.size() 212 | x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # Collapse feature dimension 213 | x = x.transpose(1, 2).transpose(0, 1).contiguous() # TxNxH 214 | 215 | for rnn in self.rnns: 216 | x = rnn(x, output_lengths) 217 | 218 | if not self._bidirectional: # no need for lookahead layer in bidirectional 219 | x = self.lookahead(x) 220 | 221 | x = self.fc(x) 222 | x = x.transpose(0, 1) 223 | # identity in training mode, softmax in eval mode 224 | x = self.inference_softmax(x) 225 | return x, output_lengths 226 | 227 | def get_seq_lens(self, input_length): 228 | """ 229 | Given a 1D Tensor or Variable containing integer sequence lengths, return a 1D tensor or variable 230 | containing the size sequences that will be output by the network. 231 | :param input_length: 1D Tensor 232 | :return: 1D Tensor scaled by model 233 | """ 234 | seq_len = input_length 235 | for m in self.conv.modules(): 236 | if type(m) == nn.modules.conv.Conv2d: 237 | seq_len = ((seq_len + 2 * m.padding[1] - m.dilation[1] * (m.kernel_size[1] - 1) - 1) / m.stride[1] + 1) 238 | return seq_len.int() 239 | 240 | @classmethod 241 | def load_model(cls, path): 242 | package = torch.load(path, map_location=lambda storage, loc: storage) 243 | model = cls(rnn_hidden_size=package['hidden_size'], nb_layers=package['hidden_layers'], 244 | labels=package['labels'], audio_conf=package['audio_conf'], 245 | rnn_type=supported_rnns[package['rnn_type']], bidirectional=package.get('bidirectional', True)) 246 | model.load_state_dict(package['state_dict']) 247 | for x in model.rnns: 248 | x.flatten_parameters() 249 | return model 250 | 251 | @classmethod 252 | def load_model_package(cls, package): 253 | model = cls(rnn_hidden_size=package['hidden_size'], nb_layers=package['hidden_layers'], 254 | labels=package['labels'], audio_conf=package['audio_conf'], 255 | rnn_type=supported_rnns[package['rnn_type']], bidirectional=package.get('bidirectional', True)) 256 | model.load_state_dict(package['state_dict']) 257 | return model 258 | 259 | @staticmethod 260 | def serialize(model, optimizer=None, epoch=None, iteration=None, loss_results=None, 261 | cer_results=None, wer_results=None, avg_loss=None, meta=None): 262 | model = model.module if DeepSpeech.is_parallel(model) else model 263 | package = { 264 | 'version': model._version, 265 | 'hidden_size': model._hidden_size, 266 | 'hidden_layers': model._hidden_layers, 267 | 'rnn_type': supported_rnns_inv.get(model._rnn_type, model._rnn_type.__name__.lower()), 268 | 'audio_conf': model._audio_conf, 269 | 'labels': model._labels, 270 | 'state_dict': model.state_dict(), 271 | 'bidirectional': model._bidirectional 272 | } 273 | if optimizer is not None: 274 | package['optim_dict'] = optimizer.state_dict() 275 | if avg_loss is not None: 276 | package['avg_loss'] = avg_loss 277 | if epoch is not None: 278 | package['epoch'] = epoch + 1 # increment for readability 279 | if iteration is not None: 280 | 
            package['iteration'] = iteration
281 |         if loss_results is not None:
282 |             package['loss_results'] = loss_results
283 |             package['cer_results'] = cer_results
284 |             package['wer_results'] = wer_results
285 |         if meta is not None:
286 |             package['meta'] = meta
287 |         return package
288 | 
289 |     @staticmethod
290 |     def get_labels(model):
291 |         return model.module._labels if DeepSpeech.is_parallel(model) else model._labels
292 | 
293 |     @staticmethod
294 |     def get_param_size(model):
295 |         params = 0
296 |         for p in model.parameters():
297 |             tmp = 1
298 |             for x in p.size():
299 |                 tmp *= x
300 |             params += tmp
301 |         return params
302 | 
303 |     @staticmethod
304 |     def get_audio_conf(model):
305 |         return model.module._audio_conf if DeepSpeech.is_parallel(model) else model._audio_conf
306 | 
307 |     @staticmethod
308 |     def get_meta(model):
309 |         m = model.module if DeepSpeech.is_parallel(model) else model
310 |         meta = {
311 |             "version": m._version,
312 |             "hidden_size": m._hidden_size,
313 |             "hidden_layers": m._hidden_layers,
314 |             "rnn_type": supported_rnns_inv[m._rnn_type]
315 |         }
316 |         return meta
317 | 
318 |     @staticmethod
319 |     def is_parallel(model):
320 |         return isinstance(model, torch.nn.parallel.DataParallel) or \
321 |             isinstance(model, torch.nn.parallel.DistributedDataParallel)
322 | 
323 | 
324 | if __name__ == '__main__':
325 |     import os.path
326 |     import argparse
327 | 
328 |     parser = argparse.ArgumentParser(description='DeepSpeech model information')
329 |     parser.add_argument('--model-path', default='models/deepspeech_final.pth',
330 |                         help='Path to model file created by training')
331 |     args = parser.parse_args()
332 |     package = torch.load(args.model_path, map_location=lambda storage, loc: storage)
333 |     model = DeepSpeech.load_model(args.model_path)
334 | 
335 |     print("Model name: ", os.path.basename(args.model_path))
336 |     print("DeepSpeech version: ", model._version)
337 |     print("")
338 |     print("Recurrent Neural Network Properties")
339 |     print(" RNN Type: ", model._rnn_type.__name__.lower())
340 |     print(" RNN Layers: ", model._hidden_layers)
341 |     print(" RNN Size: ", model._hidden_size)
342 |     print(" Classes: ", len(model._labels))
343 |     print("")
344 |     print("Model Features")
345 |     print(" Labels: ", model._labels)
346 |     print(" Sample Rate: ", model._audio_conf.get("sample_rate", "n/a"))
347 |     print(" Window Type: ", model._audio_conf.get("window", "n/a"))
348 |     print(" Window Size: ", model._audio_conf.get("window_size", "n/a"))
349 |     print(" Window Stride: ", model._audio_conf.get("window_stride", "n/a"))
350 | 
351 |     if package.get('loss_results', None) is not None:
352 |         print("")
353 |         print("Training Information")
354 |         epochs = package['epoch']
355 |         print(" Epochs: ", epochs)
356 |         print(" Current Loss: {0:.3f}".format(package['loss_results'][epochs - 1]))
357 |         print(" Current CER: {0:.3f}".format(package['cer_results'][epochs - 1]))
358 |         print(" Current WER: {0:.3f}".format(package['wer_results'][epochs - 1]))
359 | 
360 |     if package.get('meta', None) is not None:
361 |         print("")
362 |         print("Additional Metadata")
363 |         for k, v in package['meta'].items():
364 |             print(" ", k, ": ", v)
365 | 
--------------------------------------------------------------------------------
/scripts/ijcai2019/plot_loss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import print_function
3 | from matplotlib import rcParams
4 | FONT_FAMILY='DejaVu Serif'
5 | rcParams["font.family"] = FONT_FAMILY
6 | from mpl_toolkits.axes_grid.inset_locator import inset_axes
7 | import matplotlib
8 | matplotlib.use("TkAgg")
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import datetime
12 | import itertools
13 | import utils as u
14 | #markers=['.','x','o','v','^','<','>','1','2','3','4','8','s','p','*']
15 | markers=[None]
16 | colors = ['b', 'g', 'r', 'm', 'y', 'k', 'orange', 'purple', 'olive']
17 | markeriter = itertools.cycle(markers)
18 | coloriter = itertools.cycle(colors)
19 | fixed_colors = {
20 |     'S-SGD': '#ff3300',
21 |     'ssgd': '#ff3300',
22 |     'gTopK': '#009900',
23 |     'blue': 'b',
24 |     0.001: 'C2',
25 |     0.002: 'C5',
26 |     0.00025: 'C3',
27 |     0.0001: 'C0',
28 |     0.00005: 'C1',
29 |     0.00001: 'C4',
30 | }
31 | 
32 | OUTPUTPATH='/tmp/ijcai2019'
33 | LOGHOME='/tmp/logs'
34 | 
35 | FONTSIZE=14
36 | HOSTNAME='localhost'
37 | num_batches_per_epoch = None
38 | global_max_epochs=150
39 | global_density=0.001
40 | #NFIGURES=4;NFPERROW=2
41 | NFIGURES=6;NFPERROW=2
42 | #NFIGURES=1;NFPERROW=1
43 | #FIGSIZE=(5*NFPERROW,3.8*NFIGURES/NFPERROW)
44 | PLOT_NORM=False
45 | PLOT_NORM=True
46 | if PLOT_NORM:
47 |     #FIGSIZE=(5*NFPERROW,3.1*NFIGURES/NFPERROW)
48 |     FIGSIZE=(5*NFPERROW,3.2*NFIGURES/NFPERROW)
49 | else:
50 |     #FIGSIZE=(5*NFPERROW,2.9*NFIGURES/NFPERROW)
51 |     FIGSIZE=(5*NFPERROW,3.0*NFIGURES/NFPERROW)
52 | 
53 | fig, group_axs = plt.subplots(NFIGURES/NFPERROW, NFPERROW,figsize=FIGSIZE)
54 | if NFIGURES > 1 and PLOT_NORM:
55 |     ax = None
56 |     group_axtwins = []
57 |     for i in range(NFIGURES/NFPERROW):
58 |         tmp = []
59 |         for a in group_axs[i]:
60 |             tmp.append(a.twinx())
61 |         group_axtwins.append(tmp)
62 |     global_index = 0
63 | else:
64 |     ax = group_axs
65 |     ax1 = ax
66 |     global_index = None
67 | ax2 = None
68 | 
69 | STANDARD_TITLES = {
70 |     'resnet20': 'ResNet-20',
71 |     'vgg16': 'VGG-16',
72 |     'alexnet': 'AlexNet',
73 |     'resnet50': 'ResNet-50',
74 |     'lstmptb': 'LSTM-PTB',
75 |     'lstm': 'LSTM-PTB',
76 |     'lstman4': 'LSTM-AN4'
77 | }
78 | 
79 | def get_real_title(title):
80 |     return STANDARD_TITLES.get(title, title)
81 | 
82 | def seconds_between_datetimestring(a, b):
83 |     a = datetime.datetime.strptime(a, '%Y-%m-%d %H:%M:%S')
84 |     b = datetime.datetime.strptime(b, '%Y-%m-%d %H:%M:%S')
85 |     delta = b - a
86 |     return delta.days*86400+delta.seconds
87 | sbd = seconds_between_datetimestring
88 | 
89 | def get_loss(line, isacc=False):
90 |     valid = line.find('val acc: ') > 0 if isacc else line.find('loss: ') > 0
91 |     if line.find('Epoch') > 0 and valid:
92 |         items = line.split(' ')
93 |         loss = float(items[-1])
94 |         t = line.split(' I')[0].split(',')[0]
95 |         t = datetime.datetime.strptime(t, '%Y-%m-%d %H:%M:%S')
96 |         return loss, t
97 | 
98 | def read_losses_from_log(logfile, isacc=False):
99 |     global num_batches_per_epoch
100 |     f = open(logfile)
101 |     losses = []
102 |     times = []
103 |     average_delays = []
104 |     lrs = []
105 |     i = 0
106 |     time0 = None
107 |     max_epochs = global_max_epochs
108 |     counter = 0
109 |     for line in f.readlines():
110 |         if line.find('num_batches_per_epoch: ') > 0:
111 |             num_batches_per_epoch = int(line[0:-1].split('num_batches_per_epoch:')[-1])
112 |         valid = line.find('val acc: ') > 0 if isacc else line.find('average loss: ') > 0
113 | 
114 | 
115 |         if line.find('Epoch') > 0 and valid:
116 |             t = line.split(' I')[0].split(',')[0]
117 |             t = datetime.datetime.strptime(t, '%Y-%m-%d %H:%M:%S')
118 |             if not time0:
119 |                 time0 = t
120 |         if line.find('lr: ') > 0:
121 |             try:
122 |                 lr = float(line.split(',')[-2].split('lr: ')[-1])
123 |                 lrs.append(lr)
124 |             except:
125 |                 pass
126 |         if line.find('average delay: ') > 0:
127 |             delay = int(line.split(':')[-1])
128 |             average_delays.append(delay)
129 |         loss, t = get_loss(line, isacc)
130 |         if loss and t:
131 |             counter += 1
132 |             losses.append(loss)
133 |             times.append(t)
134 |         if counter > max_epochs:
135 |             break
136 |     f.close()
137 |     if len(times) > 0:
138 |         t0 = time0 if time0 else times[0] #times[0]
139 |         for i in range(0, len(times)):
140 |             delta = times[i]- t0
141 |             times[i] = delta.days*86400+delta.seconds
142 |     return losses, times, average_delays, lrs
143 | 
144 | def read_norm_from_log(logfile):
145 |     f = open(logfile)
146 |     means = []
147 |     stds = []
148 |     for line in f.readlines():
149 |         if line.find('gtopk-dense norm mean') > 0:
150 |             items = line.split(',')
151 |             mean = float(items[-2].split(':')[-1])
152 |             std = float(items[-1].split(':')[-1])
153 |             means.append(mean)
154 |             stds.append(std)
155 |     print('means: ', means)
156 |     print('stds: ', stds)
157 |     return means, stds
158 | 
159 | def plot_loss(logfile, label, isacc=False, title='ResNet-20', fixed_color=None):
160 |     losses, times, average_delays, lrs = read_losses_from_log(logfile, isacc=isacc)
161 |     norm_means, norm_stds = read_norm_from_log(logfile)
162 | 
163 |     print('times: ', times)
164 |     print('losses: ', losses)
165 |     if len(average_delays) > 0:
166 |         delay = int(np.mean(average_delays))
167 |     else:
168 |         delay = 0
169 |     if delay > 0:
170 |         label = label + ' (delay=%d)' % delay
171 |     if isacc:
172 |         ax.set_ylabel('top-1 Validation Accuracy')
173 |     else:
174 |         ax.set_ylabel('training loss')
175 |     ax.set_title(get_real_title(title))
176 |     marker = markeriter.next()
177 |     if fixed_color:
178 |         color = fixed_color
179 |     else:
180 |         color = coloriter.next()
181 | 
182 |     iterations = np.arange(len(losses))
183 |     line = ax.plot(iterations, losses, label=label, marker=marker, markerfacecolor='none', color=color, linewidth=1)
184 |     if False and len(norm_means) > 0:
185 |         global ax2
186 |         if ax2 is None:
187 |             ax2 = ax.twinx()
188 |             ax2.set_ylabel('L2-Norm of gTopK-Dense')
189 |         ax2.plot(norm_means, label=label+' norms', color=color)
190 |     ax.set_xlabel('# of epochs')
191 |     if len(lrs) > 0:
192 |         lr_indexes = [0]
193 |         lr = lrs[0]
194 |         for i in range(len(lrs)):
195 |             clr = lrs[i]
196 |             if lr != clr:
197 |                 lr_indexes.append(i)
198 |                 lr = clr
199 |     u.update_fontsize(ax, FONTSIZE)
200 |     return line
201 | 
202 | 
203 | def plot_with_params(dnn, nworkers, bs, lr, hostname, legend, isacc=False, prefix='', title='ResNet-20', sparsity=None, nsupdate=None, sg=None, density=None, force_legend=False):
204 |     global global_density
205 |     global_density = density
206 |     postfix='5922'
207 |     color = None
208 |     if prefix.find('allreduce')>=0:
209 |         postfix='0'
210 |     elif prefix.find('single') >= 0:
211 |         postfix = None
212 |     if sparsity:
213 |         logfile = LOGHOME+'/%s/%s-n%d-bs%d-lr%.4f-s%.5f' % (prefix, dnn, nworkers, bs, lr, sparsity)
214 |     elif nsupdate:
215 |         logfile = LOGHOME+'/%s/%s-n%d-bs%d-lr%.4f-ns%d' % (prefix, dnn, nworkers, bs, lr, nsupdate)
216 |     else:
217 |         logfile = LOGHOME+'/%s/%s-n%d-bs%d-lr%.4f' % (prefix, dnn, nworkers, bs, lr)
218 |     if sg is not None:
219 |         logfile += '-sg%.2f' % sg
220 |     if density is not None:
221 |         logfile += '-ds%s' % str(density)
222 |         color = fixed_colors[density]
223 |     else:
224 |         color = fixed_colors['S-SGD']
225 |     if postfix is None:
226 |         logfile += '/%s.log' % (hostname)
227 |     else:
228 |         logfile += '/%s-%s.log' % (hostname, postfix)
| print('logfile: ', logfile) 230 | if force_legend: 231 | l = legend 232 | else: 233 | l = legend + ' (lr=%.4f, bs=%d, %d workers)' % (lr, bs, nworkers) 234 | line = plot_loss(logfile, l, isacc=isacc, title=dnn, fixed_color=color) 235 | return line 236 | 237 | def plot_group_norm_diff(): 238 | global ax 239 | networks = ['vgg16', 'resnet20', 'lstm', 'lstman4'] 240 | networks = ['vgg16', 'resnet20', 'alexnet', 'resnet50', 'lstm', 'lstman4'] 241 | for i, network in enumerate(networks): 242 | ax_row = i // NFPERROW 243 | ax_col = i % NFPERROW 244 | ax = group_axs[ax_row][ax_col] 245 | ax1 = group_axtwins[ax_row][ax_col] 246 | plts = plot_norm_diff(ax1, network) 247 | lines, labels = ax.get_legend_handles_labels() 248 | 249 | lines2, labels2 = ax1.get_legend_handles_labels() 250 | fig.legend(lines + lines2, labels + labels2, ncol=4, loc='upper center', fontsize=FONTSIZE, frameon=True) 251 | plt.subplots_adjust(bottom=0.09, left=0.08, right=0.90, top=0.88, wspace=0.49, hspace=0.42) 252 | plt.savefig('%s/multiple_normdiff.pdf' % OUTPUTPATH) 253 | 254 | def plot_norm_diff(lax=None, network=None, subfig=None): 255 | global global_index 256 | global global_max_epochs 257 | density = 0.001 258 | nsupdate = 1 259 | prefix = 'allreduce-comp-gtopk-baseline-gwarmup-dc1-model-ijcai2019' 260 | if network == 'lstm': 261 | bs = 100; lr = 30.0; epochs = 40 262 | elif network == 'lstman4': 263 | bs = 8; lr = 0.0002; epochs = 80 264 | elif network == 'resnet20': 265 | bs = 32; lr = 0.1; epochs = 140 266 | elif network == 'vgg16': 267 | bs = 128; lr = 0.1; epochs = 140 268 | elif network == 'alexnet': 269 | bs = 256; lr = 0.01; epochs = 40 270 | elif network == 'resnet50': 271 | nsupdate = 16 272 | bs = 512; lr = 0.01; epochs = 35 273 | global_max_epochs = epochs 274 | path = LOGHOME+'/%s/%s-n4-bs%d-lr%.4f-ns%d-sg1.50-ds%s' % (prefix, network, bs, lr, nsupdate, density) 275 | print(network, path) 276 | plts = [] 277 | if network == 'lstm': 278 | line = plot_with_params(network, 4, 100, 30.0, HOSTNAME, r'S-SGD loss', prefix='allreduce-baseline-gwarmup-dc1-model-ijcai2019', nsupdate=1, force_legend=True) 279 | plts.append(line) 280 | line = plot_with_params(network, 4, 100, 30.0, HOSTNAME, r'gTop-$k$ S-SGD loss', prefix='allreduce-comp-gtopk-baseline-gwarmup-dc1-model-ijcai2019', nsupdate=1, sg=1.5, density=density, force_legend=True) 281 | plts.append(line) 282 | elif network == 'resnet20': 283 | line = plot_with_params(network, 4, 32, lr, HOSTNAME, 'S-SGD loss', prefix='allreduce-baseline-gwarmup-dc1-model-ijcai2019', force_legend=True) 284 | plts.append(line) 285 | line = plot_with_params(network, 4, bs, lr, HOSTNAME, r'gTop-$k$ S-SGD loss', prefix='allreduce-comp-topk-baseline-gwarmup-dc1-model-ijcai2019', nsupdate=1, sg=1.5, density=density, force_legend=True) 286 | plts.append(line) 287 | 288 | elif network == 'vgg16': 289 | line = plot_with_params(network, 4, bs, lr, HOSTNAME, 'S-SGD loss', prefix='allreduce-baseline-gwarmup-dc1-model-ijcai2019', nsupdate=1, force_legend=True) 290 | plts.append(line) 291 | line = plot_with_params(network, 4, bs, lr, HOSTNAME, r'gTop-$k$ S-SGD loss', prefix=prefix, nsupdate=1, sg=1.5, density=density, force_legend=True) 292 | plts.append(line) 293 | elif network == 'lstman4': 294 | line = plot_with_params(network, 4, 8, 0.0002, HOSTNAME, 'S-SGD loss', prefix='allreduce-baseline-gwarmup-dc1-model-ijcai2019', nsupdate=1, force_legend=True) 295 | plts.append(line) 296 | line = plot_with_params(network,
4, 8, 0.0002, HOSTNAME, r'gTop-$k$ S-SGD loss', prefix='allreduce-comp-gtopk-baseline-gwarmup-dc1-model-ijcai2019', nsupdate=1, sg=1.5, density=density, force_legend=True) 297 | plts.append(line) 298 | elif network == 'resnet50': 299 | plts.append(plot_with_params(network, 4, 512, lr, HOSTNAME, 'S-SGD loss', prefix='allreduce-baseline-gwarmup-dc1-model-ijcai2019', nsupdate=nsupdate, force_legend=True)) 300 | line = plot_with_params(network, 4, 512, lr, HOSTNAME, r'gTop-$k$ S-SGD loss', prefix=prefix, nsupdate=nsupdate, sg=1.5, density=density, force_legend=True) 301 | plts.append(line) 302 | elif network == 'alexnet': 303 | plts.append(plot_with_params(network, 4, 256, lr, HOSTNAME, 'S-SGD loss', prefix='allreduce-baseline-gwarmup-dc1-model-ijcai2019', nsupdate=1, force_legend=True)) 304 | line = plot_with_params(network, 4, 256, lr, HOSTNAME, r'gTop-$k$ S-SGD loss', prefix=prefix, nsupdate=nsupdate, sg=1.5, density=density, force_legend=True) 305 | plts.append(line) 306 | arr = [] 307 | arr2 = [] 308 | for i in range(1, epochs+1): 309 | fn = '%s/gtopknorm-rank0-epoch%d.npy' % (path, i) 310 | fn2 = '%s/randknorm-rank0-epoch%d.npy' % (path, i) 311 | arr.append(np.mean(np.power(np.load(fn), 2)))  # mean squared gtopk norm for epoch i 312 | arr2.append(np.mean(np.power(np.load(fn2), 2)))  # mean squared randk norm for epoch i 313 | arr = np.array(arr) 314 | arr2 = np.array(arr2) 315 | cax = lax if lax is not None else ax1 316 | cax.plot(arr/arr2, label=r'$\delta$', color=fixed_colors['blue'], linewidth=1)  # per-epoch norm ratio 317 | cax.set_ylim(bottom=0.97, top=1.001) 318 | zero_x = np.arange(len(arr)) 319 | ones = np.ones_like(zero_x) 320 | cax.plot(zero_x, ones, ':', label='1 ref.', color='black', linewidth=1) 321 | if True or network.find('lstm') >= 0:  # always draw the zoomed inset 322 | subaxes = inset_axes(cax, 323 | width='50%', 324 | height='30%', 325 | bbox_to_anchor=(-0.04, 0, 1, 0.95), 326 | bbox_transform=cax.transAxes, 327 | loc='upper right') 328 | half = epochs // 2 329 | subx = np.arange(half, len(arr)) 330 | subaxes.plot(subx, (arr/arr2)[half:], color=fixed_colors['blue'], linewidth=1) 331 | subaxes.plot(subx, ones[half:], ':', color='black', linewidth=1) 332 | subaxes.set_ylim(bottom=subaxes.get_ylim()[0]) 333 | cax.set_xlabel('# of epochs') 334 | cax.set_ylabel(r'$\delta$') 335 | u.update_fontsize(cax, FONTSIZE) 336 | if global_index is not None: 337 | global_index += 1 338 | return plts 339 | 340 | 341 | def plot_group_lr_sensitivities(): 342 | def _plot_with_network(network): 343 | global global_max_epochs 344 | global global_density 345 | densities = [0.001, 0.00025, 0.0001, 0.00005] 346 | if network == 'vgg16': 347 | global_max_epochs = 140 348 | for density in densities: 349 | legend = r'$c$=%d' % (1/density) 350 | plot_with_params(network, 4, 128, 0.1, HOSTNAME, legend, prefix='allreduce-comp-gtopk-baseline-gwarmup-dc1-model-ijcai2019', nsupdate=1, sg=1.5, density=density, force_legend=True) 351 | elif network == 'resnet20': 352 | global_max_epochs = 140 353 | for density in densities: 354 | legend = r'$c$=%d' % (1/density) 355 | plot_with_params(network, 4, 32, 0.1, HOSTNAME, legend, prefix='allreduce-comp-gtopk-baseline-gwarmup-dc1-model-ijcai2019', nsupdate=1, sg=1.5, density=density, force_legend=True) 356 | elif network == 'lstm': 357 | global_max_epochs = 40 358 | for density in densities: 359 | legend = r'$c$=%d' % (1/density) 360 | plot_with_params(network, 4, 100, 30.0, HOSTNAME, legend, prefix='allreduce-comp-gtopk-baseline-gwarmup-dc1-model-ijcai2019', nsupdate=1, sg=1.5, density=density, force_legend=True) 361 | elif network == 'lstman4': 362 | global_max_epochs = 80 363 | for density in densities: 364 |
legend = r'$c$=%d' % (1/density) 365 | plot_with_params(network, 4, 8, 0.0002, HOSTNAME, legend, prefix='allreduce-comp-gtopk-baseline-gwarmup-dc1-model-ijcai2019', nsupdate=1, sg=1.5, density=density, force_legend=True) 366 | global ax 367 | networks = ['vgg16', 'resnet20', 'lstm', 'lstman4'] 368 | for i, network in enumerate(networks): 369 | ax_row = i // NFPERROW 370 | ax_col = i % NFPERROW 371 | ax = group_axs[ax_row][ax_col] 372 | _plot_with_network(network) 373 | ax.legend(ncol=2, loc='upper right', fontsize=FONTSIZE-2) 374 | plt.subplots_adjust(bottom=0.10, left=0.10, right=0.94, top=0.95, wspace=0.37, hspace=0.42) 375 | plt.savefig('%s/multiple_lrs.pdf' % OUTPUTPATH) 376 | 377 | 378 | if __name__ == '__main__': 379 | if PLOT_NORM: 380 | plot_group_norm_diff() 381 | else: 382 | plot_group_lr_sensitivities() 383 | plt.show() 384 | -------------------------------------------------------------------------------- /audio_data/an4_val_manifest.csv: -------------------------------------------------------------------------------- 1 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an443-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an443-mmxg-b.txt 2 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an392-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an392-mjwl-b.txt 3 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an440-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an440-mjgm-b.txt 4 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an402-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an402-mdms2-b.txt 5 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an416-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an416-fjlp-b.txt 6 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an409-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an409-fcaw-b.txt 7 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an442-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an442-mmxg-b.txt 8 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen3-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen3-mjwl-b.txt 9 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen3-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen3-mdms2-b.txt 10 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen3-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen3-fjlp-b.txt 11 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen2-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen2-mjgm-b.txt 12 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen3-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen3-mjgm-b.txt 13 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen3-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen3-mmxg-b.txt 14 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an400-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an400-miry-b.txt 15 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an422-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an422-menk-b.txt 16 |
/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen2-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen2-marh-b.txt 17 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an421-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an421-menk-b.txt 18 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen3-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen3-fvap-b.txt 19 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an436-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an436-mjgm-b.txt 20 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen6-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen6-fvap-b.txt 21 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen3-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen3-marh-b.txt 22 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen3-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen3-miry-b.txt 23 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen8-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen8-mjgm-b.txt 24 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an435-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an435-marh-b.txt 25 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen1-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen1-mjgm-b.txt 26 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an428-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an428-fvap-b.txt 27 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an426-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an426-fvap-b.txt 28 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an395-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an395-mjwl-b.txt 29 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen3-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen3-fcaw-b.txt 30 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen6-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen6-mjwl-b.txt 31 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen2-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen2-miry-b.txt 32 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an439-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an439-mjgm-b.txt 33 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an391-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an391-mjwl-b.txt 34 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen2-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen2-menk-b.txt 35 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen6-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen6-marh-b.txt 36 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen1-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen1-fvap-b.txt 37 | 
/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen6-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen6-mmxg-b.txt 38 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an432-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an432-marh-b.txt 39 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen3-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen3-menk-b.txt 40 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen2-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen2-mjwl-b.txt 41 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen7-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen7-fvap-b.txt 42 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen4-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen4-marh-b.txt 43 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen1-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen1-marh-b.txt 44 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen1-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen1-mjwl-b.txt 45 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen8-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen8-fvap-b.txt 46 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an418-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an418-fjlp-b.txt 47 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen6-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen6-menk-b.txt 48 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen1-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen1-fcaw-b.txt 49 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen6-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen6-fjlp-b.txt 50 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen8-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen8-mmxg-b.txt 51 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen2-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen2-fvap-b.txt 52 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen6-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen6-mjgm-b.txt 53 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen8-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen8-mjwl-b.txt 54 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an437-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an437-mjgm-b.txt 55 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen4-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen4-fjlp-b.txt 56 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an419-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an419-fjlp-b.txt 57 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an425-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an425-menk-b.txt 58 | 
/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen2-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen2-fjlp-b.txt 59 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen2-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen2-mmxg-b.txt 60 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen8-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen8-marh-b.txt 61 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen7-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen7-mjwl-b.txt 62 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an393-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an393-mjwl-b.txt 63 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an417-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an417-fjlp-b.txt 64 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen7-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen7-mjgm-b.txt 65 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen8-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen8-fjlp-b.txt 66 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen4-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen4-mjwl-b.txt 67 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen8-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen8-menk-b.txt 68 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen6-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen6-miry-b.txt 69 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an438-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an438-mjgm-b.txt 70 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an397-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an397-miry-b.txt 71 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen1-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen1-fjlp-b.txt 72 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen4-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen4-mjgm-b.txt 73 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen8-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen8-miry-b.txt 74 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an433-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an433-marh-b.txt 75 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen7-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen7-menk-b.txt 76 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen8-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen8-mdms2-b.txt 77 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen2-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen2-fcaw-b.txt 78 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an445-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an445-mmxg-b.txt 79 | 
/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen4-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen4-fvap-b.txt 80 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen8-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen8-fcaw-b.txt 81 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen6-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen6-fcaw-b.txt 82 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen4-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen4-fcaw-b.txt 83 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen5-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen5-mjwl-b.txt 84 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an404-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an404-mdms2-b.txt 85 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen7-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen7-marh-b.txt 86 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen1-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen1-mmxg-b.txt 87 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an430-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an430-fvap-b.txt 88 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen5-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen5-mmxg-b.txt 89 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen5-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen5-fvap-b.txt 90 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen1-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen1-miry-b.txt 91 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen1-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen1-menk-b.txt 92 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen2-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen2-mdms2-b.txt 93 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an427-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an427-fvap-b.txt 94 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen6-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen6-mdms2-b.txt 95 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an401-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an401-mdms2-b.txt 96 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an408-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an408-fcaw-b.txt 97 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen7-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen7-fjlp-b.txt 98 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an441-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an441-mmxg-b.txt 99 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen5-mjgm-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen5-mjgm-b.txt 100 | 
/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen5-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen5-marh-b.txt 101 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen7-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen7-miry-b.txt 102 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an410-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an410-fcaw-b.txt 103 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an423-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an423-menk-b.txt 104 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an431-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an431-marh-b.txt 105 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an429-fvap-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an429-fvap-b.txt 106 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen7-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen7-fcaw-b.txt 107 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an398-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an398-miry-b.txt 108 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an394-mjwl-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an394-mjwl-b.txt 109 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen4-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen4-menk-b.txt 110 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen4-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen4-miry-b.txt 111 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an403-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an403-mdms2-b.txt 112 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen7-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen7-mmxg-b.txt 113 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an424-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an424-menk-b.txt 114 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an396-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an396-miry-b.txt 115 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an406-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an406-fcaw-b.txt 116 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an407-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an407-fcaw-b.txt 117 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an444-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an444-mmxg-b.txt 118 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen5-menk-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen5-menk-b.txt 119 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an420-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an420-fjlp-b.txt 120 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an434-marh-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an434-marh-b.txt 121 | 
/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen1-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen1-mdms2-b.txt 122 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an399-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an399-miry-b.txt 123 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen4-mmxg-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen4-mmxg-b.txt 124 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen5-miry-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen5-miry-b.txt 125 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen5-fjlp-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen5-fjlp-b.txt 126 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/an405-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/an405-mdms2-b.txt 127 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen5-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen5-mdms2-b.txt 128 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen7-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen7-mdms2-b.txt 129 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen4-mdms2-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen4-mdms2-b.txt 130 | /home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/wav/cen5-fcaw-b.wav,/home/yxwang/proj/DGC_speech2text/data/an4_dataset/test/an4/txt/cen5-fcaw-b.txt 131 | --------------------------------------------------------------------------------