├── .gitignore ├── LICENSE ├── README.md └── distributed_code ├── README.md ├── auto_extract.py ├── environments └── docker │ ├── base │ ├── .screenrc │ ├── .tmux.conf │ ├── Dockerfile │ ├── entrypoint.sh │ └── fix-permissions │ ├── docker-compose.yml │ └── pytorch-mpi │ └── Dockerfile ├── hostfile ├── main.py ├── parameters.py ├── pcode ├── __init__.py ├── create_dataset.py ├── create_metrics.py ├── create_model.py ├── create_optimizer.py ├── create_scheduler.py ├── datasets │ ├── __init__.py │ ├── loader │ │ ├── __init__.py │ │ ├── epsilon_or_rcv1_folder.py │ │ ├── imagenet_folder.py │ │ ├── preprocess_toolkit.py │ │ ├── serialize.py │ │ ├── svhn_folder.py │ │ └── utils.py │ ├── partition_data.py │ └── prepare_data.py ├── distributed_running_cv.py ├── distributed_running_nlp.py ├── models │ ├── __init__.py │ ├── densenet.py │ ├── lenet.py │ ├── mlp.py │ ├── resnet.py │ ├── rnn_lm.py │ ├── vgg.py │ └── wideresnet.py ├── optim │ ├── __init__.py │ ├── dgc.py │ ├── ef_sign_sgd.py │ ├── local_ef_sign_sgd.py │ ├── local_sgd.py │ ├── local_sign_sgd.py │ ├── sgd.py │ ├── sign_sgd.py │ └── utils.py ├── tools │ ├── __init__.py │ ├── plot.py │ ├── plot_utils.py │ └── show_results.py └── utils │ ├── __init__.py │ ├── auxiliary.py │ ├── checkpoint.py │ ├── communication.py │ ├── error_handler.py │ ├── logging.py │ ├── mathdict.py │ ├── op_files.py │ ├── op_paths.py │ ├── sparsification.py │ ├── stat_tracker.py │ ├── tensor_buffer.py │ ├── timer.py │ └── topology.py ├── run.py └── tmux_cluster ├── __init__.py ├── tmux.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.vscode 2 | *test.sh 3 | *Makefile 4 | *iccluster 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Don't Use Large Mini-batches, Use Local SGD 2 | We present here the code of the experimental parts of the paper [Don't 
Use Large Mini-batches, Use Local SGD](https://openreview.net/forum?id=B1eyO1BFPr). 3 | 4 | Abstract: 5 | Mini-batch stochastic gradient methods (SGD) are state of the art for distributed training of deep neural networks. 6 | Drastic increases in the mini-batch sizes have led to key efficiency and scalability gains in recent years. 7 | However, progress faces a major roadblock, as models trained with large batches often do not generalize well, i.e. they do not show good accuracy on new data. 8 | As a remedy, we propose post-local SGD and show that it significantly improves the generalization performance compared to large-batch training on standard benchmarks while enjoying the same efficiency (time-to-accuracy) and scalability. We further provide an extensive study of the communication efficiency vs. performance trade-offs associated with a host of local SGD variants. 9 | 10 | 11 | # Code usage 12 | We rely on `Docker` for our experimental environments. Please refer to the folder `distributed_code/environments/docker` for more details. 13 | 14 | The script below trains `ResNet-20` with `CIFAR-10`, as an example of the centralized training algorithm `(post-)local SGD`. 15 | For detailed instructions and more examples, please refer to the file `distributed_code/README.md`.
16 | ```bash 17 | OMP_NUM_THREADS=2 MKL_NUM_THREADS=2 $HOME/conda/envs/pytorch-py3.6/bin/python run.py \ 18 | --arch resnet20 --optimizer local_sgd \ 19 | --avg_model True --experiment demo --manual_seed 6 \ 20 | --data cifar10 --pin_memory True \ 21 | --batch_size 128 --base_batch_size 64 --num_workers 2 \ 22 | --num_epochs 300 --partition_data random --reshuffle_per_epoch True --stop_criteria epoch \ 23 | --n_mpi_process 16 --n_sub_process 1 --world 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 \ 24 | --on_cuda True --use_ipc False \ 25 | --lr 0.1 --lr_scaleup True --lr_warmup True --lr_warmup_epochs 5 \ 26 | --lr_scheduler MultiStepLR --lr_decay 0.1 --lr_milestones 150,225 \ 27 | --local_step 16 --turn_on_local_step_from 150 \ 28 | --weight_decay 1e-4 --use_nesterov True --momentum_factor 0.9 \ 29 | --hostfile hostfile --graph_topology complete --track_time True --display_tracked_time True \ 30 | --python_path $HOME/conda/envs/pytorch-py3.6/bin/python --mpi_path $HOME/.openmpi/ 31 | ``` 32 | 33 | # Reference 34 | If you use this code, please cite the following [paper](https://openreview.net/forum?id=B1eyO1BFPr): 35 | 36 | ``` 37 | @inproceedings{lin2020dont, 38 | title={Don't Use Large Mini-batches, Use Local {SGD}}, 39 | author={Tao Lin and Sebastian U. Stich and Kumar Kshitij Patel and Martin Jaggi}, 40 | booktitle={ICLR - International Conference on Learning Representations}, 41 | year={2020}, 42 | url={https://openreview.net/forum?id=B1eyO1BFPr} 43 | } 44 | ``` 45 | -------------------------------------------------------------------------------- /distributed_code/README.md: -------------------------------------------------------------------------------- 1 | # Getting started 2 | Our experiments heavily rely on `Docker` and `Kubernetes`. For the detailed experimental environment setup, please refer to the Dockerfiles under the `environments` folder. 3 | 4 | 5 | ## Use case of distributed training (centralized) 6 | A brief explanation of the arguments used in the code follows.
7 | * Arguments related to *distributed training*: 8 | * The `n_mpi_process` and `n_sub_process` indicate the number of MPI processes and the number of GPUs used by each MPI process. The data-parallel wrapper is adapted and applied locally within each MPI process. 9 | * Note that `batch_size` specifies the mini-batch size used by each GPU: when `n_sub_process > 1`, each MPI process uses a mini-batch of `batch_size * n_sub_process` samples, which the data-parallel wrapper splits across its GPUs. 10 | * The `world` describes the GPU topology of the distributed training, in terms of all GPUs used for the distributed training. 11 | * The `hostfile` from `mpi` specifies the physical location of the MPI processes. 12 | * We provide two use cases here: 13 | * `n_mpi_process=2`, `n_sub_process=1` and `world=0,0` indicates that two MPI processes are running on 2 GPUs with the same GPU id. It could be either a single GPU shared on one node or two GPUs on different nodes, where the exact configuration is determined by `hostfile`. 14 | * `n_mpi_process=2`, `n_sub_process=2` and `world=0,1,0,1` indicates that two MPI processes are running on 4 GPUs and each MPI process uses GPU id 0 and id 1 (on 2 nodes). 15 | * Arguments related to *communication compression*: 16 | * The `graph_topology` specifies the communication topology among the MPI processes; `complete` corresponds to centralized training. 17 | * The `optimizer` will decide the type of distributed training, e.g., centralized SGD, decentralized SGD. 18 | * The `comm_op` specifies the communication compressor we can use, e.g., `sign+norm`, `random-k`, `top-k`. 19 | * Arguments related to *learning*: 20 | * The `lr_scaleup`, `lr_warmup` and `lr_warmup_epochs` decide whether to scale up the learning rate and/or warm it up. For more details, please check `pcode/create_scheduler.py`. 21 | 22 | ### Examples 23 | The script below trains `ResNet-20` with `CIFAR-10`, as an example of centralized training algorithm `(post-)local SGD`.
24 | ```bash 25 | OMP_NUM_THREADS=2 MKL_NUM_THREADS=2 $HOME/conda/envs/pytorch-py3.6/bin/python run.py \ 26 | --arch resnet20 --optimizer local_sgd \ 27 | --avg_model True --experiment demo --manual_seed 6 \ 28 | --data cifar10 --pin_memory True \ 29 | --batch_size 128 --base_batch_size 64 --num_workers 2 \ 30 | --num_epochs 300 --partition_data random --reshuffle_per_epoch True --stop_criteria epoch \ 31 | --n_mpi_process 16 --n_sub_process 1 --world 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 \ 32 | --on_cuda True --use_ipc False \ 33 | --lr 0.1 --lr_scaleup True --lr_warmup True --lr_warmup_epochs 5 \ 34 | --lr_scheduler MultiStepLR --lr_decay 0.1 --lr_milestones 150,225 \ 35 | --local_step 32 --turn_on_local_step_from 150 \ 36 | --weight_decay 1e-4 --use_nesterov True --momentum_factor 0.9 \ 37 | --hostfile hostfile --graph_topology complete --track_time True --display_tracked_time True \ 38 | --python_path $HOME/conda/envs/pytorch-py3.6/bin/python --mpi_path $HOME/.openmpi/ 39 | ``` 40 | 41 | The script below trains `ResNet-20` with `CIFAR-10`, as an example of centralized training algorithm `post-local SGD` with `sign` based communication compressor. 
42 | ```bash 43 | OMP_NUM_THREADS=2 MKL_NUM_THREADS=2 $HOME/conda/envs/pytorch-py3.6/bin/python run.py \ 44 | --arch resnet20 --optimizer local_sign_sgd \ 45 | --avg_model True --experiment demo --manual_seed 6 \ 46 | --data cifar10 --pin_memory True \ 47 | --batch_size 128 --base_batch_size 128 --num_workers 2 \ 48 | --num_epochs 300 --partition_data random --reshuffle_per_epoch True --stop_criteria epoch \ 49 | --n_mpi_process 16 --n_sub_process 1 --world 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 \ 50 | --on_cuda True --use_ipc False \ 51 | --lr 0.01 --lr_scaleup False --lr_warmup False --lr_warmup_epochs 5 \ 52 | --lr_scheduler MultiStepLR --lr_decay 0.1 --lr_milestones 150,225 \ 53 | --local_step 16 --turn_on_local_step_from 150 \ 54 | --weight_decay 1e-4 --use_nesterov True --momentum_factor 0.9 \ 55 | --hostfile hostfile --graph_topology complete --track_time True --display_tracked_time True \ 56 | --python_path $HOME/conda/envs/pytorch-py3.6/bin/python --mpi_path $HOME/.openmpi/ 57 | ``` 58 | 59 | The script below trains `ResNet-20` with `CIFAR-10`, as an example of centralized training algorithm `post-local SGD` with `sign+norm` based communication compressor. 
60 | ```bash 61 | OMP_NUM_THREADS=2 MKL_NUM_THREADS=2 $HOME/conda/envs/pytorch-py3.6/bin/python run.py \ 62 | --arch resnet20 --optimizer local_ef_sign_sgd \ 63 | --avg_model True --experiment demo --manual_seed 6 \ 64 | --data cifar10 --pin_memory True \ 65 | --batch_size 128 --base_batch_size 128 --num_workers 2 \ 66 | --num_epochs 300 --partition_data random --reshuffle_per_epoch True --stop_criteria epoch \ 67 | --n_mpi_process 16 --n_sub_process 1 --world 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0 \ 68 | --on_cuda True --use_ipc False \ 69 | --lr 0.1 --lr_scaleup True --lr_warmup True --lr_warmup_epochs 5 \ 70 | --lr_scheduler MultiStepLR --lr_decay 0.1 --lr_milestones 150,225 \ 71 | --local_step 64 --turn_on_local_step_from 150 \ 72 | --weight_decay 1e-4 --use_nesterov True --momentum_factor 0.9 \ 73 | --hostfile hostfile --graph_topology complete --track_time True --display_tracked_time True \ 74 | --python_path $HOME/conda/envs/pytorch-py3.6/bin/python --mpi_path $HOME/.openmpi/ 75 | ``` 76 | 77 | The script example of `LSTM` on `wikitext2` for `SGD` follows: 78 | ```bash 79 | OMP_NUM_THREADS=2 MKL_NUM_THREADS=2 $HOME/conda/envs/pytorch-py3.6/bin/python run.py \ 80 | --arch rnn_lm --rnn_n_hidden 650 --rnn_n_layers 3 --rnn_bptt_len 30 \ 81 | --rnn_clip 0.4 --rnn_use_pretrained_emb False --rnn_tie_weights True --drop_rate 0.40 \ 82 | --optimizer sgd --avg_model True --experiment demo \ 83 | --data wikitext2 --pin_memory True \ 84 | --batch_size 32 --base_batch_size 24 --num_workers 2 \ 85 | --num_epochs 300 --partition_data random --reshuffle_per_epoch False --stop_criteria epoch \ 86 | --n_mpi_process 32 --n_sub_process 1 --world 0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1 \ 87 | --on_cuda True --use_ipc False --comm_device cuda \ 88 | --lr 2.5 --lr_scaleup True --lr_warmup True --lr_warmup_epochs 5 \ 89 | --lr_scheduler MultiStepLR --lr_decay 0.1 --lr_milestones 150,225 \ 90 | --weight_decay 0 --use_nesterov False --momentum_factor 0 \ 91 | 
--hostfile hostfile --graph_topology complete --track_time True --display_tracked_time True \ 92 | --python_path $HOME/conda/envs/pytorch-py3.6/bin/python --mpi_path $HOME/.openmpi/ 93 | ``` 94 | -------------------------------------------------------------------------------- /distributed_code/auto_extract.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import argparse 4 | 5 | import pcode.utils.op_files as op_files 6 | from pcode.tools.show_results import load_raw_info_from_experiments 7 | 8 | """parse and define arguments for different tasks.""" 9 | 10 | 11 | def get_args(): 12 | # feed them to the parser. 13 | parser = argparse.ArgumentParser(description='Extract results.') 14 | 15 | # add arguments. 16 | parser.add_argument('--in_dir', type=str) 17 | parser.add_argument('--out_name', type=str, default='summary.pickle') 18 | 19 | # parse aˇˇrgs. 20 | args = parser.parse_args() 21 | 22 | # an argument safety check. 23 | check_args(args) 24 | return args 25 | 26 | 27 | def check_args(args): 28 | assert args.in_dir is not None 29 | 30 | # define out path. 31 | args.out_path = os.path.join(args.in_dir, args.out_name) 32 | 33 | 34 | """write the results to path.""" 35 | 36 | 37 | def main(args): 38 | # save the parsed results to path. 39 | op_files.write_pickle( 40 | load_raw_info_from_experiments(args.in_dir), 41 | args.out_path) 42 | 43 | 44 | if __name__ == '__main__': 45 | args = get_args() 46 | 47 | main(args) 48 | -------------------------------------------------------------------------------- /distributed_code/environments/docker/base/.screenrc: -------------------------------------------------------------------------------- 1 | # the following two lines give a two-line status, with the current window highlighted 2 | hardstatus alwayslastline 3 | hardstatus string '%{= kG}[%{G}%H%? %1`%?%{g}][%= %{= kw}%-w%{+b yk} %n*%t%?(%u)%? 
%{-}%+w %=%{g}][%{B}%m/%d %{W}%C%A%{g}]' 4 | 5 | # huge scrollback buffer 6 | defscrollback 5000 7 | 8 | # no welcome message 9 | startup_message off 10 | 11 | # 256 colors 12 | attrcolor b ".I" 13 | termcapinfo xterm 'Co#256:AB=\E[48;5;%dm:AF=\E[38;5;%dm' 14 | defbce on 15 | 16 | # mouse tracking allows to switch region focus by clicking 17 | mousetrack on 18 | 19 | # default windows 20 | screen -t Shell1 1 bash 21 | screen -t Shell2 2 bash 22 | screen -t Python 3 python 23 | screen -t Media 4 bash 24 | select 0 25 | bind c screen 1 # window numbering starts at 1 not 0 26 | bind 0 select 10 27 | 28 | # get rid of silly xoff stuff 29 | bind s split 30 | 31 | # layouts 32 | layout autosave on 33 | layout new one 34 | select 1 35 | layout new two 36 | select 1 37 | split 38 | resize -v +8 39 | focus down 40 | select 4 41 | focus up 42 | layout new three 43 | select 1 44 | split 45 | resize -v +7 46 | focus down 47 | select 3 48 | split -v 49 | resize -h +10 50 | focus right 51 | select 4 52 | focus up 53 | 54 | layout attach one 55 | layout select one 56 | 57 | # navigating regions with Ctrl-arrows 58 | bindkey "^[[1;5D" focus left 59 | bindkey "^[[1;5C" focus right 60 | bindkey "^[[1;5A" focus up 61 | bindkey "^[[1;5B" focus down 62 | 63 | # switch windows with F3 (prev) and F4 (next) 64 | bindkey "^[OR" prev 65 | bindkey "^[OS" next 66 | 67 | # switch layouts with Ctrl+F3 (prev layout) and Ctrl+F4 (next) 68 | bindkey "^[O1;5R" layout prev 69 | bindkey "^[O1;5S" layout next 70 | 71 | # F2 puts Screen into resize mode. Resize regions using hjkl keys. 
72 | bindkey "^[OQ" eval "command -c rsz" # enter resize mode 73 | 74 | # use hjkl keys to resize regions 75 | bind -c rsz h eval "resize -h -5" "command -c rsz" 76 | bind -c rsz j eval "resize -v -5" "command -c rsz" 77 | bind -c rsz k eval "resize -v +5" "command -c rsz" 78 | bind -c rsz l eval "resize -h +5" "command -c rsz" 79 | 80 | # quickly switch between regions using tab and arrows 81 | bind -c rsz \t eval "focus" "command -c rsz" # Tab 82 | bind -c rsz -k kl eval "focus left" "command -c rsz" # Left 83 | bind -c rsz -k kr eval "focus right" "command -c rsz" # Right 84 | bind -c rsz -k ku eval "focus up" "command -c rsz" # Up 85 | bind -c rsz -k kd eval "focus down" "command -c rsz" # Down 86 | -------------------------------------------------------------------------------- /distributed_code/environments/docker/base/.tmux.conf: -------------------------------------------------------------------------------- 1 | # 0 is too far from ` ;) 2 | set -g base-index 1 3 | 4 | # Automatically set window title 5 | set-window-option -g automatic-rename on 6 | set-option -g set-titles on 7 | set-option -g mouse on 8 | 9 | #set -g default-terminal screen-256color 10 | set -g status-keys vi 11 | set -g history-limit 10000 12 | 13 | setw -g mode-keys vi 14 | setw -g monitor-activity on 15 | 16 | bind-key v split-window -h 17 | bind-key s split-window -v 18 | 19 | bind-key J resize-pane -D 5 20 | bind-key K resize-pane -U 5 21 | bind-key H resize-pane -L 5 22 | bind-key L resize-pane -R 5 23 | 24 | bind-key M-j resize-pane -D 25 | bind-key M-k resize-pane -U 26 | bind-key M-h resize-pane -L 27 | bind-key M-l resize-pane -R 28 | 29 | # Vim style pane selection 30 | bind h select-pane -L 31 | bind j select-pane -D 32 | bind k select-pane -U 33 | bind l select-pane -R 34 | 35 | # Use Alt-vim keys without prefix key to switch panes 36 | bind -n M-h select-pane -L 37 | bind -n M-j select-pane -D 38 | bind -n M-k select-pane -U 39 | bind -n M-l select-pane -R 40 | 41 | # Use 
Alt-arrow keys without prefix key to switch panes 42 | bind -n M-Left select-pane -L 43 | bind -n M-Right select-pane -R 44 | bind -n M-Up select-pane -U 45 | bind -n M-Down select-pane -D 46 | 47 | # Shift arrow to switch windows 48 | bind -n S-Left previous-window 49 | bind -n S-Right next-window 50 | 51 | # No delay for escape key press 52 | set -sg escape-time 0 53 | 54 | # Reload tmux config 55 | bind r source-file ~/.tmux.conf 56 | 57 | # THEME 58 | set -g status-bg black 59 | set -g status-fg white 60 | set -g window-status-current-bg white 61 | set -g window-status-current-fg black 62 | set -g window-status-current-attr bold 63 | set -g status-interval 60 64 | set -g status-left-length 30 65 | set -g status-left '#[fg=green](#S) #(whoami)' 66 | set -g status-right '#[fg=yellow]#(cut -d " " -f 1-3 /proc/loadavg)#[default] #[fg=white]%H:%M#[default]' 67 | -------------------------------------------------------------------------------- /distributed_code/environments/docker/base/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 2 | MAINTAINER Tao Lin 3 | 4 | 5 | # install some necessary tools. 
6 | RUN echo "deb http://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1604/x86_64 /" > /etc/apt/sources.list.d/nvidia-ml.list 7 | RUN apt-get update \ 8 | && apt-get install -y --no-install-recommends \ 9 | build-essential \ 10 | ca-certificates \ 11 | pkg-config \ 12 | software-properties-common 13 | RUN apt-get install -y \ 14 | inkscape \ 15 | jed \ 16 | libsm6 \ 17 | libxext-dev \ 18 | libxrender1 \ 19 | lmodern \ 20 | libcurl3-dev \ 21 | libfreetype6-dev \ 22 | libzmq3-dev \ 23 | libcupti-dev \ 24 | pkg-config \ 25 | libav-tools \ 26 | libjpeg-dev \ 27 | libpng-dev \ 28 | zlib1g-dev \ 29 | locales 30 | RUN apt-get install -y \ 31 | sudo \ 32 | rsync \ 33 | cmake \ 34 | g++ \ 35 | swig \ 36 | vim \ 37 | git \ 38 | curl \ 39 | wget \ 40 | unzip \ 41 | zsh \ 42 | git \ 43 | screen \ 44 | tmux \ 45 | openssh-server 46 | RUN apt-get update && \ 47 | apt-get install -y pciutils net-tools iputils-ping && \ 48 | apt-get install -y htop 49 | RUN add-apt-repository ppa:openjdk-r/ppa \ 50 | && apt-get update \ 51 | && apt-get install -y \ 52 | openjdk-7-jdk \ 53 | openjdk-7-jre-headless 54 | # install good vim. 55 | RUN curl http://j.mp/spf13-vim3 -L -o - | sh 56 | 57 | # configure environments. 58 | RUN echo "en_US.UTF-8 UTF-8" > /etc/locale.gen && locale-gen 59 | 60 | # configure user. 
61 | ENV SHELL=/bin/bash \ 62 | NB_USER=user \ 63 | NB_UID=1000 \ 64 | NB_GROUP=1000 \ 65 | NB_GID=1000 66 | ENV HOME=/home/$NB_USER 67 | 68 | ADD base/fix-permissions /usr/local/bin/fix-permissions 69 | RUN chmod +x /usr/local/bin/fix-permissions 70 | ADD base/entrypoint.sh /usr/local/bin/entrypoint.sh 71 | RUN chmod +x /usr/local/bin/entrypoint.sh 72 | RUN groupadd $NB_GROUP -g $NB_GID 73 | RUN useradd -m -s /bin/bash -N -u $NB_UID -g $NB_GID $NB_USER && \ 74 | echo "${NB_USER}:${NB_USER}" | chpasswd && \ 75 | usermod -aG sudo,adm,root ${NB_USER} && \ 76 | fix-permissions $HOME 77 | RUN echo 'user ALL=(ALL) NOPASSWD: ALL' | sudo EDITOR='tee -a' visudo 78 | 79 | # Default ssh config file that skips (yes/no) question when first login to the host 80 | RUN mkdir /var/run/sshd 81 | RUN sed -i "s/#PasswordAuthentication.*/PasswordAuthentication no/g" /etc/ssh/sshd_config \ 82 | && sed -i "s/#PermitRootLogin.*/PermitRootLogin yes/g" /etc/ssh/sshd_config \ 83 | && sed -ri 's/UsePAM yes/#UsePAM yes/g' /etc/ssh/sshd_config \ 84 | && sed -i "s/#AuthorizedKeysFile/AuthorizedKeysFile/g" /etc/ssh/sshd_config 85 | RUN /usr/bin/ssh-keygen -A 86 | 87 | ENV SSHDIR $HOME/.ssh 88 | RUN mkdir -p $SSHDIR \ 89 | && chmod go-w $HOME/ \ 90 | && chmod 700 $SSHDIR \ 91 | && touch $SSHDIR/authorized_keys \ 92 | && chmod 600 $SSHDIR/authorized_keys \ 93 | && chown -R ${NB_USER}:${NB_GROUP} ${SSHDIR} \ 94 | && chown -R ${NB_USER}:${NB_GROUP} /etc/ssh/* 95 | 96 | ###### switch to user and compile test example. 97 | USER ${NB_USER} 98 | RUN ssh-keygen -b 2048 -t rsa -f $SSHDIR/id_rsa -q -N "" 99 | RUN cat ${SSHDIR}/*.pub >> ${SSHDIR}/authorized_keys 100 | RUN echo "StrictHostKeyChecking no" > ${SSHDIR}/config 101 | 102 | # configure screen and tmux 103 | ADD base/.tmux.conf $HOME/ 104 | ADD base/.screenrc $HOME/ 105 | 106 | # expose port for ssh and start ssh service. 107 | EXPOSE 22 108 | # expose port for notebook. 109 | EXPOSE 8888 110 | # expose port for tensorboard. 
111 | EXPOSE 6666 112 | -------------------------------------------------------------------------------- /distributed_code/environments/docker/base/entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | sudo service ssh start 3 | exec "$@" 4 | -------------------------------------------------------------------------------- /distributed_code/environments/docker/base/fix-permissions: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # set permissions on a directory 3 | # after any installation, if a directory needs to be (human) user-writable, 4 | # run this script on it. 5 | # It will make everything in the directory owned by the group $NB_GID 6 | # and writable by that group. 7 | # Deployments that want to set a specific user id can preserve permissions 8 | # by adding the `--group-add users` line to `docker run`. 9 | 10 | # uses find to avoid touching files that already have the right permissions, 11 | # which would cause massive image explosion 12 | 13 | # right permissions are: 14 | # group=$NB_GID 15 | # AND permissions include group rwX (directory-execute) 16 | # AND directories have setuid,setgid bits set 17 | 18 | set -e 19 | 20 | for d in "$@"; do 21 | find "$d" \ 22 | ! \( \ 23 | -group "$NB_GID" \ 24 | -a -perm -g+rwX \ 25 | \) \ 26 | -exec chgrp "$NB_GID" {} \; \ 27 | -exec chmod g+rwX {} \; 28 | # setuid,setgid *on directories only* 29 | find "$d" \ 30 | \( \ 31 | -type d \ 32 | -a ! -perm -6000 \ 33 | \) \ 34 | -exec chmod +6000 {} \; 35 | done 36 | -------------------------------------------------------------------------------- /distributed_code/environments/docker/docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '2' 2 | services: 3 | base: 4 | build: 5 | context: . 6 | dockerfile: base/Dockerfile 7 | image: user/base 8 | pytorch-mpi: 9 | build: 10 | context: .
11 | dockerfile: pytorch-mpi/Dockerfile 12 | image: user/pytorch-mpi 13 | depends_on: 14 | - base 15 | -------------------------------------------------------------------------------- /distributed_code/environments/docker/pytorch-mpi/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM user/base 2 | 3 | USER $NB_USER 4 | WORKDIR $HOME 5 | 6 | # install openMPI 7 | RUN mkdir $HOME/.openmpi/ 8 | RUN wget https://www.open-mpi.org/software/ompi/v3.0/downloads/openmpi-3.0.0.tar.gz 9 | RUN gunzip -c openmpi-3.0.0.tar.gz | tar xf - \ 10 | && cd openmpi-3.0.0 \ 11 | && ./configure --prefix=$HOME/.openmpi/ --with-cuda \ 12 | && make all install 13 | 14 | ENV PATH $HOME/.openmpi/bin:$PATH 15 | ENV LD_LIBRARY_PATH $HOME/.openmpi/lib:$LD_LIBRARY_PATH 16 | 17 | # install conda 18 | ENV PYTHON_VERSION=3.6 19 | RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 20 | sh miniconda.sh -b -p $HOME/conda && \ 21 | rm ~/miniconda.sh 22 | RUN $HOME/conda/bin/conda update -n base conda 23 | RUN $HOME/conda/bin/conda create -y --name pytorch-py$PYTHON_VERSION python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include 24 | RUN $HOME/conda/bin/conda install --name pytorch-py$PYTHON_VERSION -c soumith magma-cuda100 25 | RUN $HOME/conda/bin/conda install --name pytorch-py$PYTHON_VERSION scikit-learn 26 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/pip install pytelegraf pymongo influxdb kubernetes jinja2 27 | ENV PATH $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin:$PATH 28 | 29 | # install pytorch, torchvision, torchtext. 
30 | RUN git clone --recursive https://github.com/pytorch/pytorch 31 | RUN cd pytorch && \ 32 | git checkout tags/v1.3.0 && \ 33 | git submodule sync && \ 34 | git submodule update --init && \ 35 | TORCH_CUDA_ARCH_LIST="3.5 3.7 5.2 6.0 6.1 7.0+PTX" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \ 36 | CMAKE_PREFIX_PATH="$(dirname $(which $HOME/conda/bin/conda))/../" \ 37 | pip install -v . 38 | RUN git clone https://github.com/pytorch/vision.git && cd vision && git checkout v0.4.0 && python setup.py install 39 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/pip install --upgrade git+https://github.com/pytorch/text 40 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/pip install spacy 41 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/python -m spacy download en 42 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/python -m spacy download de 43 | 44 | 45 | # install bit2byte. 46 | RUN git clone https://github.com/tvogels/signSGD-with-Majority-Vote.git && \ 47 | cd signSGD-with-Majority-Vote/main/bit2byte-extension/ && \ 48 | $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/python setup.py develop --user 49 | 50 | # install other python related softwares. 
51 | RUN $HOME/conda/bin/conda install --name pytorch-py$PYTHON_VERSION -y opencv protobuf 52 | RUN $HOME/conda/bin/conda install --name pytorch-py$PYTHON_VERSION -y networkx 53 | RUN $HOME/conda/bin/conda install --name pytorch-py$PYTHON_VERSION -y -c anaconda pandas 54 | RUN $HOME/conda/bin/conda install --name pytorch-py$PYTHON_VERSION -y -c conda-forge tabulate 55 | RUN $HOME/conda/envs/pytorch-py$PYTHON_VERSION/bin/pip install lmdb tensorboard_logger pyarrow msgpack msgpack_numpy mpi4py 56 | RUN $HOME/conda/bin/conda install --name pytorch-py$PYTHON_VERSION -c conda-forge python-blosc 57 | RUN $HOME/conda/bin/conda clean -ya 58 | -------------------------------------------------------------------------------- /distributed_code/hostfile: -------------------------------------------------------------------------------- 1 | localhost slots=32 -------------------------------------------------------------------------------- /distributed_code/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import datetime 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn as nn 8 | import torch.multiprocessing as mp 9 | 10 | from parameters import get_args 11 | 12 | import pcode.create_dataset as create_dataset 13 | import pcode.create_optimizer as create_optimizer 14 | import pcode.create_metrics as create_metrics 15 | import pcode.create_model as create_model 16 | import pcode.create_scheduler as create_scheduler 17 | 18 | import pcode.utils.topology as topology 19 | import pcode.utils.checkpoint as checkpoint 20 | import pcode.utils.op_paths as op_paths 21 | import pcode.utils.stat_tracker as stat_tracker 22 | import pcode.utils.logging as logging 23 | from pcode.utils.timer import Timer 24 | 25 | 26 | def init_distributed_world(conf, backend): 27 | if backend == "mpi": 28 | dist.init_process_group("mpi") 29 | elif backend == "nccl" or backend == "gloo": 30 | # init 
the process group. 31 | _tmp_path = os.path.join(conf.checkpoint, "tmp", conf.timestamp) 32 | op_paths.build_dirs(_tmp_path) 33 | 34 | dist_init_file = os.path.join(_tmp_path, "dist_init") 35 | 36 | torch.distributed.init_process_group( 37 | backend=backend, 38 | init_method="file://" + os.path.abspath(dist_init_file), 39 | timeout=datetime.timedelta(seconds=120), 40 | world_size=conf.n_mpi_process, 41 | rank=conf.local_rank, 42 | ) 43 | else: 44 | raise NotImplementedError 45 | 46 | 47 | def main(conf): 48 | try: 49 | init_distributed_world(conf, backend=conf.backend) 50 | conf.distributed = True and conf.n_mpi_process > 1 51 | except AttributeError as e: 52 | print(f"failed to init the distributed world: {e}.") 53 | conf.distributed = False 54 | 55 | # init the config. 56 | init_config(conf) 57 | 58 | # define the timer for different operations. 59 | # if we choose the `train_fast` mode, then we will not track the time. 60 | conf.timer = Timer( 61 | verbosity_level=1 if conf.track_time and not conf.train_fast else 0, 62 | log_fn=conf.logger.log_metric, 63 | on_cuda=conf.on_cuda, 64 | ) 65 | 66 | # create dataset. 67 | data_loader = create_dataset.define_dataset(conf, force_shuffle=True) 68 | 69 | # create model 70 | model = create_model.define_model(conf, data_loader=data_loader) 71 | 72 | # define the optimizer. 73 | optimizer = create_optimizer.define_optimizer(conf, model) 74 | 75 | # define the lr scheduler. 76 | scheduler = create_scheduler.Scheduler(conf, optimizer) 77 | 78 | # add model with data-parallel wrapper. 79 | if conf.graph.on_cuda: 80 | if conf.n_sub_process > 1: 81 | model = torch.nn.DataParallel(model, device_ids=conf.graph.device) 82 | 83 | # (optional) reload checkpoint 84 | try: 85 | checkpoint.maybe_resume_from_checkpoint(conf, model, optimizer, scheduler) 86 | except RuntimeError as e: 87 | conf.logger.log(f"Resume Error: {e}") 88 | conf.resumed = False 89 | 90 | # train amd evaluate model. 
91 | if "rnn_lm" in conf.arch: 92 | from pcode.distributed_running_nlp import train_and_validate 93 | 94 | # safety check. 95 | assert ( 96 | conf.n_sub_process == 1 97 | ), "our current data-parallel wrapper does not support RNN." 98 | 99 | # define the criterion and metrics. 100 | criterion = nn.CrossEntropyLoss(reduction="mean") 101 | criterion = criterion.cuda() if conf.graph.on_cuda else criterion 102 | metrics = create_metrics.Metrics( 103 | model.module if "DataParallel" == model.__class__.__name__ else model, 104 | task="language_modeling", 105 | ) 106 | 107 | # define the best_perf tracker, either empty or from the checkpoint. 108 | best_tracker = stat_tracker.BestPerf( 109 | best_perf=None if "best_perf" not in conf else conf.best_perf, 110 | larger_is_better=False, 111 | ) 112 | scheduler.set_best_tracker(best_tracker) 113 | 114 | # get train_and_validate_func 115 | train_and_validate_fn = train_and_validate 116 | else: 117 | from pcode.distributed_running_cv import train_and_validate 118 | 119 | # define the criterion and metrics. 120 | criterion = nn.CrossEntropyLoss(reduction="mean") 121 | criterion = criterion.cuda() if conf.graph.on_cuda else criterion 122 | metrics = create_metrics.Metrics( 123 | model.module if "DataParallel" == model.__class__.__name__ else model, 124 | task="classification", 125 | ) 126 | 127 | # define the best_perf tracker, either empty or from the checkpoint. 128 | best_tracker = stat_tracker.BestPerf( 129 | best_perf=None if "best_perf" not in conf else conf.best_perf, 130 | larger_is_better=True, 131 | ) 132 | scheduler.set_best_tracker(best_tracker) 133 | 134 | # get train_and_validate_func 135 | train_and_validate_fn = train_and_validate 136 | 137 | # save arguments to disk. 138 | checkpoint.save_arguments(conf) 139 | 140 | # start training. 
141 | train_and_validate_fn( 142 | conf, 143 | model=model, 144 | criterion=criterion, 145 | scheduler=scheduler, 146 | optimizer=optimizer, 147 | metrics=metrics, 148 | data_loader=data_loader, 149 | ) 150 | 151 | 152 | def init_config(conf): 153 | # define the graph for the computation. 154 | cur_rank = dist.get_rank() if conf.distributed else 0 155 | conf.graph = topology.define_graph_topology( 156 | graph_topology=conf.graph_topology, 157 | world=conf.world, 158 | n_mpi_process=conf.n_mpi_process, # the # of total main processes. 159 | n_sub_process=conf.n_sub_process, # the # of subprocess for each main process. 160 | comm_device=conf.comm_device, 161 | on_cuda=conf.on_cuda, 162 | rank=cur_rank, 163 | ) 164 | conf.is_centralized = conf.graph_topology == "complete" 165 | 166 | # re-configure batch_size if sub_process > 1. 167 | if conf.n_sub_process > 1: 168 | conf.batch_size = conf.batch_size * conf.n_sub_process 169 | 170 | # configure cuda related. 171 | if conf.graph.on_cuda: 172 | assert torch.cuda.is_available() 173 | torch.manual_seed(conf.manual_seed) 174 | torch.cuda.manual_seed(conf.manual_seed) 175 | torch.cuda.set_device(conf.graph.device[0]) 176 | torch.backends.cudnn.enabled = True 177 | torch.backends.cudnn.benchmark = True 178 | torch.backends.cudnn.deterministic = True if conf.train_fast else False 179 | 180 | # define checkpoint for logging. 181 | checkpoint.init_checkpoint(conf) 182 | 183 | # configure logger. 184 | conf.logger = logging.Logger(conf.checkpoint_dir) 185 | 186 | # display the arguments' info. 187 | logging.display_args(conf) 188 | 189 | 190 | if __name__ == "__main__": 191 | # parse the arguments. 192 | conf = get_args() 193 | 194 | # configure for multi-process training. 195 | if conf.optimizer == "parallel_choco": 196 | mp.set_start_method("forkserver", force=True) 197 | # mp.set_start_method("spawn", force=True) 198 | mp.set_sharing_strategy("file_system") 199 | 200 | # enter the training procedure. 
201 | main(conf) 202 | -------------------------------------------------------------------------------- /distributed_code/pcode/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/LocalSGD-Code/3d4811d01673af205a00176f5389ed008a1ddb37/distributed_code/pcode/__init__.py -------------------------------------------------------------------------------- /distributed_code/pcode/create_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import torch 4 | import torchtext 5 | 6 | from pcode.datasets.partition_data import DataPartitioner 7 | from pcode.datasets.prepare_data import get_dataset 8 | 9 | 10 | def load_data_batch(conf, _input, _target): 11 | """Load a mini-batch and record the loading time.""" 12 | if conf.graph.on_cuda: 13 | _input, _target = _input.cuda(), _target.cuda() 14 | return _input, _target 15 | 16 | 17 | def define_dataset(conf, force_shuffle=False): 18 | if "rnn_lm" in conf.arch: 19 | dataset = define_nlp_dataset(conf, force_shuffle) 20 | else: 21 | dataset = define_cv_dataset(conf, force_shuffle) 22 | print("Defined dataset.") 23 | return dataset 24 | 25 | 26 | """define loaders for different datasets.""" 27 | """nlp related task.""" 28 | 29 | 30 | def define_nlp_dataset(conf, force_shuffle): 31 | print("create {} dataset for rank {}".format(conf.data, conf.graph.rank)) 32 | # create dataset. 33 | TEXT, train, valid, _ = get_dataset(conf, conf.data, conf.data_dir) 34 | 35 | # Build vocb. 36 | # we can use some precomputed word embeddings, 37 | # e.g., GloVe vectors with 100, 200, and 300. 
38 | if conf.rnn_use_pretrained_emb: 39 | try: 40 | vectors = "glove.6B.{}d".format(conf.rnn_n_hidden) 41 | vectors_cache = os.path.join(conf.data_dir, ".vector_cache") 42 | except: 43 | vectors, vectors_cache = None, None 44 | else: 45 | vectors, vectors_cache = None, None 46 | TEXT.build_vocab(train, vectors=vectors, vectors_cache=vectors_cache) 47 | 48 | # Partition training data. 49 | train_loader, _ = torchtext.data.BPTTIterator.splits( 50 | (train, valid), 51 | batch_size=conf.batch_size * conf.graph.n_nodes, 52 | bptt_len=conf.rnn_bptt_len, 53 | device="cuda:{}".format(conf.graph.device[0]) if conf.graph.on_cuda else None, 54 | repeat=True, 55 | shuffle=force_shuffle or conf.reshuffle_per_epoch, 56 | ) 57 | _, val_loader = torchtext.data.BPTTIterator.splits( 58 | (train, valid), 59 | batch_size=conf.batch_size, 60 | bptt_len=conf.rnn_bptt_len, 61 | device="cuda:{}".format(conf.graph.device[0]) if conf.graph.on_cuda else None, 62 | shuffle=False, 63 | ) 64 | 65 | # get some stat. 66 | _get_nlp_data_stat(conf, train, valid, train_loader, val_loader) 67 | return {"TEXT": TEXT, "train_loader": train_loader, "val_loader": val_loader} 68 | 69 | 70 | def _get_nlp_data_stat(conf, train, valid, train_loader, val_loader): 71 | # configure the workload for each worker. 72 | # Note that: the training will access to the same # of samples (w/ or w/o partition). 73 | 74 | # the current implementation will always partition the data. 75 | conf.train_word_size = len(train.examples[0].text) 76 | conf.valid_word_size = len(valid.examples[0].text) 77 | 78 | conf.num_batches_train_per_device_per_epoch = len(train_loader) 79 | conf.num_whole_train_batches_per_worker = ( 80 | conf.num_batches_train_per_device_per_epoch * conf.num_epochs 81 | ) 82 | conf.num_warmup_train_batches_per_worker = ( 83 | conf.num_batches_train_per_device_per_epoch * conf.lr_warmup_epochs 84 | ) 85 | 86 | # when the training is controlled by the num_iterations. 
87 | conf.num_iterations_per_worker = conf.num_iterations // conf.graph.n_nodes 88 | 89 | # get the data statictics (on behalf of each worker) for val. 90 | conf.num_batches_val_per_device_per_epoch = len(val_loader) 91 | 92 | # define some parameters for training. 93 | print( 94 | "\nData Stat: we have {} epochs, \ 95 | {} mini-batches per device for training. \ 96 | {} mini-batches per device for val. \ 97 | The batch size: {}.".format( 98 | conf.num_epochs, 99 | conf.num_batches_train_per_device_per_epoch, 100 | conf.num_batches_val_per_device_per_epoch, 101 | conf.batch_size, 102 | ) 103 | ) 104 | 105 | 106 | """cv related task.""" 107 | 108 | 109 | def define_cv_dataset(conf, force_shuffle): 110 | print("Create dataset: {} for rank {}.".format(conf.data, conf.graph.rank)) 111 | train_loader = _define_cv_dataset( 112 | conf, 113 | partition_type=conf.partition_data, 114 | dataset_type="train", 115 | force_shuffle=force_shuffle, 116 | ) 117 | val_loader = _define_cv_dataset(conf, partition_type=None, dataset_type="test") 118 | 119 | _get_cv_data_stat(conf, train_loader, val_loader) 120 | return {"train_loader": train_loader, "val_loader": val_loader} 121 | 122 | 123 | def _define_cv_dataset(conf, partition_type, dataset_type, force_shuffle=False): 124 | """ Given a dataset, partition it. """ 125 | dataset = get_dataset(conf, conf.data, conf.data_dir, split=dataset_type) 126 | batch_size = conf.batch_size 127 | world_size = conf.graph.n_nodes 128 | 129 | # determine the data to load, 130 | # either the whole dataset, or a subset specified by partition_type. 
131 | if partition_type is not None and conf.distributed: 132 | partition_sizes = [1.0 / world_size for _ in range(world_size)] 133 | partition = DataPartitioner( 134 | conf, dataset, partition_sizes, partition_type=partition_type 135 | ) 136 | data_to_load = partition.use(conf.graph.rank) 137 | print("Data partition: partitioned data and use subdata.") 138 | else: 139 | data_to_load = dataset 140 | print("Data partition: used whole data.") 141 | 142 | # use Dataloader. 143 | data_loader = torch.utils.data.DataLoader( 144 | data_to_load, 145 | batch_size=batch_size, 146 | shuffle=force_shuffle or dataset_type == "train", 147 | num_workers=conf.num_workers, 148 | pin_memory=conf.pin_memory, 149 | drop_last=False, 150 | ) 151 | 152 | print( 153 | ( 154 | "Data stat: we have {} samples for {}, " 155 | + "load {} data for process (rank {}). " 156 | + "The batch size is {}, number of batches is {}." 157 | ).format( 158 | len(dataset), 159 | dataset_type, 160 | len(data_to_load), 161 | conf.graph.rank, 162 | batch_size, 163 | len(data_loader), 164 | ) 165 | ) 166 | return data_loader 167 | 168 | 169 | def _get_cv_data_stat(conf, train_loader, val_loader): 170 | # configure the workload for each worker. 171 | # Note that: the training will access to the same # of samples (w/ or w/o partition). 172 | 173 | # when it is w/ partition, then return the true local loader size. 174 | # when it is w/o partition, then return the local loader size / world size. 175 | conf.num_batches_train_per_device_per_epoch = ( 176 | len(train_loader) // conf.graph.n_nodes 177 | if conf.partition_data is None 178 | else len(train_loader) 179 | ) 180 | conf.num_whole_train_batches_per_worker = ( 181 | conf.num_batches_train_per_device_per_epoch * conf.num_epochs 182 | ) 183 | conf.num_warmup_train_batches_per_worker = ( 184 | conf.num_batches_train_per_device_per_epoch * conf.lr_warmup_epochs 185 | ) 186 | 187 | # when the training is controlled by the num_iterations. 
188 | conf.num_iterations_per_worker = conf.num_iterations // conf.graph.n_nodes 189 | 190 | # get the data statictics (on behalf of each worker) for val. 191 | conf.num_batches_val_per_device_per_epoch = len(val_loader) 192 | 193 | # define some parameters for training. 194 | print( 195 | "\nData Stat: we have {} epochs, \ 196 | {} mini-batches per device for training. \ 197 | {} mini-batches per device for val. \ 198 | The batch size: {}.".format( 199 | conf.num_epochs, 200 | conf.num_batches_train_per_device_per_epoch, 201 | conf.num_batches_val_per_device_per_epoch, 202 | conf.batch_size, 203 | ) 204 | ) 205 | -------------------------------------------------------------------------------- /distributed_code/pcode/create_metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | 4 | 5 | class Metrics(object): 6 | """""" 7 | 8 | def __init__(self, model, task="classification"): 9 | self.model = model 10 | self.task = task 11 | self.metric_names = None 12 | self.metrics_fn = self._infer() 13 | 14 | def evaluate(self, loss, output, target): 15 | return self.metrics_fn(loss, output, target) 16 | 17 | def _infer(self): 18 | if self.task == "classification": 19 | self.topks = (1, 5) if self.model.num_classes >= 5 else (1,) 20 | self.metric_names = ["top{}".format(topk) for topk in self.topks] 21 | return self._accuracy 22 | elif self.task == "language_modeling": 23 | self.metric_names = ["ppl"] 24 | return self._ppl 25 | else: 26 | raise NotImplementedError 27 | 28 | # some safety check. 
29 | assert self.metric_names is not None 30 | 31 | def _accuracy(self, loss, output, target): 32 | """Computes the precision@k for the specified values of k""" 33 | res = [] 34 | 35 | if len(self.topks) > 0: 36 | maxk = max(self.topks) 37 | batch_size = target.size(0) 38 | 39 | _, pred = output.topk(maxk, 1, True, True) 40 | pred = pred.t() 41 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 42 | 43 | for topk in self.topks: 44 | correct_k = correct[:topk].view(-1).float().sum(0, keepdim=True) 45 | res.append(correct_k.mul_(100.0 / batch_size).item()) 46 | else: 47 | res += [0] 48 | return res 49 | 50 | def _ppl(self, loss, output, target): 51 | return [math.exp(loss)] 52 | -------------------------------------------------------------------------------- /distributed_code/pcode/create_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.distributed as dist 4 | 5 | import pcode.models as models 6 | 7 | 8 | def define_model(conf, **kargs): 9 | if "rnn_lm" in conf.arch: 10 | return define_nlp_model(conf, TEXT=kargs["data_loader"]["TEXT"]) 11 | else: 12 | return define_cv_model(conf) 13 | 14 | 15 | """define loaders for different models.""" 16 | 17 | 18 | def define_cv_model(conf): 19 | if "wideresnet" in conf.arch: 20 | model = models.__dict__["wideresnet"](conf) 21 | elif "resnet" in conf.arch: 22 | model = models.__dict__["resnet"](conf) 23 | elif "densenet" in conf.arch: 24 | model = models.__dict__["densenet"](conf) 25 | elif "vgg" in conf.arch: 26 | model = models.__dict__["vgg"](conf) 27 | elif "lenet" in conf.arch: 28 | model = models.__dict__["lenet"](conf) 29 | else: 30 | model = models.__dict__[conf.arch](conf) 31 | 32 | if conf.graph.on_cuda: 33 | model = model.cuda() 34 | 35 | # get a consistent init model over the world. 36 | if conf.distributed: 37 | consistent_model(conf, model) 38 | 39 | # get the model stat info. 
40 | get_model_stat(conf, model) 41 | return model 42 | 43 | 44 | def define_nlp_model(conf, TEXT): 45 | print("=> creating model '{}'".format(conf.arch)) 46 | 47 | # get embdding size and num_tokens. 48 | weight_matrix = TEXT.vocab.vectors 49 | 50 | if weight_matrix is not None: 51 | conf.n_tokens, emb_size = weight_matrix.size(0), weight_matrix.size(1) 52 | else: 53 | conf.n_tokens, emb_size = len(TEXT.vocab), conf.rnn_n_hidden 54 | 55 | # create model. 56 | model = models.RNNLM( 57 | ntoken=conf.n_tokens, 58 | ninp=emb_size, 59 | nhid=conf.rnn_n_hidden, 60 | nlayers=conf.rnn_n_layers, 61 | tie_weights=conf.rnn_tie_weights, 62 | dropout=conf.drop_rate, 63 | weight_norm=conf.rnn_weight_norm, 64 | ) 65 | 66 | # init the model. 67 | if weight_matrix is not None: 68 | model.encoder.weight.data.copy_(weight_matrix) 69 | 70 | if conf.graph.on_cuda: 71 | model = model.cuda() 72 | 73 | # consistent the model. 74 | consistent_model(conf, model) 75 | get_model_stat(conf, model) 76 | return model 77 | 78 | 79 | """some utilities functions.""" 80 | 81 | 82 | def get_model_stat(conf, model): 83 | print( 84 | "=> creating model '{}. total params for process {}: {}M".format( 85 | conf.arch, 86 | conf.graph.rank, 87 | sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6, 88 | ) 89 | ) 90 | 91 | 92 | def consistent_model(conf, model): 93 | """it might because of MPI, the model for each process is not the same. 94 | 95 | This function is proposed to fix this issue, 96 | i.e., use the model (rank=0) as the global model. 
97 | """ 98 | print("consistent model for process (rank {})".format(conf.graph.rank)) 99 | cur_rank = conf.graph.rank 100 | for param in model.parameters(): 101 | param.data = param.data if cur_rank == 0 else param.data - param.data 102 | dist.all_reduce(param.data, op=dist.ReduceOp.SUM) 103 | -------------------------------------------------------------------------------- /distributed_code/pcode/create_optimizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from pcode.optim.sgd import SGD 3 | 4 | from pcode.optim.local_sgd import LocalSGD 5 | from pcode.optim.local_ef_sign_sgd import Local_EFSignSGD 6 | from pcode.optim.local_sign_sgd import Local_SignSGD 7 | from pcode.optim.sign_sgd import SignSGD 8 | from pcode.optim.ef_sign_sgd import EF_SignSGD 9 | from pcode.optim.dgc import DGC 10 | 11 | 12 | def define_optimizer(conf, model): 13 | # define the param to optimize. 14 | params = [ 15 | { 16 | "params": [value], 17 | "name": key, 18 | "weight_decay": conf.weight_decay if "bn" not in key else 0.0, 19 | "param_size": value.size(), 20 | "nelement": value.nelement(), 21 | } 22 | for key, value in model.named_parameters() 23 | ] 24 | 25 | # define the optimizer. 
26 | if conf.optimizer == "sgd": 27 | optim_class = SGD 28 | elif conf.optimizer == "dgc": 29 | optim_class = DGC 30 | elif conf.optimizer == "local_sgd": 31 | optim_class = LocalSGD 32 | elif conf.optimizer == "sign_sgd": 33 | optim_class = SignSGD 34 | elif conf.optimizer == "ef_sign_sgd": 35 | optim_class = EF_SignSGD 36 | elif conf.optimizer == "local_sign_sgd": 37 | optim_class = Local_SignSGD 38 | elif conf.optimizer == "local_ef_sign_sgd": 39 | optim_class = Local_EFSignSGD 40 | else: 41 | raise NotImplementedError 42 | 43 | optimizer = optim_class( 44 | params, 45 | lr=conf.lr, 46 | momentum=conf.momentum_factor, 47 | nesterov=conf.use_nesterov, 48 | conf=conf, 49 | ) 50 | return optimizer 51 | -------------------------------------------------------------------------------- /distributed_code/pcode/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/LocalSGD-Code/3d4811d01673af205a00176f5389ed008a1ddb37/distributed_code/pcode/datasets/__init__.py -------------------------------------------------------------------------------- /distributed_code/pcode/datasets/loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/LocalSGD-Code/3d4811d01673af205a00176f5389ed008a1ddb37/distributed_code/pcode/datasets/loader/__init__.py -------------------------------------------------------------------------------- /distributed_code/pcode/datasets/loader/epsilon_or_rcv1_folder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from pcode.datasets.loader.utils import LMDBPT 4 | 5 | 6 | def define_epsilon_or_rcv1_folder(root): 7 | print('load epsilon_or_rcv1 from lmdb: {}.'.format(root)) 8 | return LMDBPT(root, is_image=False) 9 | -------------------------------------------------------------------------------- 
/distributed_code/pcode/datasets/loader/imagenet_folder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torchvision.datasets as datasets 3 | 4 | from pcode.datasets.loader.preprocess_toolkit import get_transform 5 | from pcode.datasets.loader.utils import LMDBPT 6 | 7 | 8 | def define_imagenet_folder(name, root, flag, cuda=True): 9 | is_train = 'train' in root 10 | transform = get_transform(name, augment=is_train, color_process=False) 11 | 12 | if flag: 13 | print('load imagenet from lmdb: {}'.format(root)) 14 | return LMDBPT(root, transform=transform, is_image=True) 15 | else: 16 | print("load imagenet using pytorch's default dataloader.") 17 | return datasets.ImageFolder(root=root, 18 | transform=transform, 19 | target_transform=None) 20 | -------------------------------------------------------------------------------- /distributed_code/pcode/datasets/loader/preprocess_toolkit.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | 4 | import torch 5 | import torchvision.transforms as transforms 6 | 7 | 8 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 9 | 'std': [0.229, 0.224, 0.225]} 10 | __imagenet_pca = { 11 | 'eigval': torch.Tensor([0.2175, 0.0188, 0.0045]), 12 | 'eigvec': torch.Tensor([ 13 | [-0.5675, 0.7192, 0.4009], 14 | [-0.5808, -0.0045, -0.8140], 15 | [-0.5836, -0.6948, 0.4203], 16 | ]) 17 | } 18 | 19 | 20 | def scale_crop(input_size, scale_size=None, normalize=__imagenet_stats): 21 | t_list = [ 22 | transforms.CenterCrop(input_size), 23 | transforms.ToTensor() 24 | ] 25 | if normalize is not None: 26 | t_list += [transforms.Normalize(**normalize)] 27 | if scale_size != input_size: 28 | t_list = [transforms.Resize(scale_size)] + t_list 29 | return transforms.Compose(t_list) 30 | 31 | 32 | def scale_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 33 | t_list = [ 34 | 
transforms.RandomCrop(input_size), 35 | transforms.ToTensor(), 36 | ] 37 | if normalize is not None: 38 | t_list += [transforms.Normalize(**normalize)] 39 | if scale_size != input_size: 40 | t_list = [transforms.Resize(scale_size)] + t_list 41 | return transforms.Compose(t_list) 42 | 43 | 44 | def pad_random_crop(input_size, scale_size=None, normalize=__imagenet_stats): 45 | padding = int((scale_size - input_size) / 2) 46 | t_list = [ 47 | transforms.RandomCrop(input_size, padding=padding), 48 | transforms.RandomHorizontalFlip(), 49 | transforms.ToTensor(), 50 | ] 51 | if normalize is not None: 52 | t_list += [transforms.Normalize(**normalize)] 53 | return transforms.Compose(t_list) 54 | 55 | 56 | def inception_preproccess(input_size, normalize=__imagenet_stats): 57 | t_list = [ 58 | transforms.RandomResizedCrop(input_size), 59 | transforms.RandomHorizontalFlip(), 60 | transforms.ToTensor(), 61 | ] 62 | if normalize is not None: 63 | t_list += [transforms.Normalize(**normalize)] 64 | return transforms.Compose(t_list) 65 | 66 | 67 | def inception_color_preproccess(input_size, normalize=__imagenet_stats): 68 | t_list = [ 69 | transforms.RandomResizedCrop(input_size), 70 | transforms.RandomHorizontalFlip(), 71 | transforms.ToTensor(), 72 | ColorJitter( 73 | brightness=0.4, 74 | contrast=0.4, 75 | saturation=0.4, 76 | ), 77 | Lighting(0.1, __imagenet_pca['eigval'], __imagenet_pca['eigvec']), 78 | ] 79 | if normalize is not None: 80 | t_list += [transforms.Normalize(**normalize)] 81 | return transforms.Compose(t_list) 82 | 83 | 84 | def get_transform(name='imagenet', input_size=None, scale_size=None, 85 | normalize=None, augment=True, color_process=False): 86 | normalize = normalize or __imagenet_stats 87 | 88 | if 'imagenet' in name: 89 | scale_size = scale_size or (36 if 'downsampled' in name else 256) 90 | input_size = input_size or (32 if 'downsampled' in name else 224) 91 | 92 | if augment: 93 | if color_process: 94 | preprocess_fn = inception_color_preproccess 95 
| else: 96 | preprocess_fn = inception_preproccess 97 | return preprocess_fn(input_size, normalize=normalize) 98 | else: 99 | return scale_crop(input_size=input_size, 100 | scale_size=scale_size, normalize=normalize) 101 | elif 'cifar' in name: 102 | input_size = input_size or 32 103 | if augment: 104 | scale_size = scale_size or 40 105 | return pad_random_crop(input_size, scale_size=scale_size, 106 | normalize=normalize) 107 | else: 108 | scale_size = scale_size or 32 109 | return scale_crop(input_size=input_size, 110 | scale_size=scale_size, normalize=normalize) 111 | elif name == 'mnist': 112 | normalize = {'mean': [0.5], 'std': [0.5]} 113 | input_size = input_size or 28 114 | if augment: 115 | scale_size = scale_size or 32 116 | return pad_random_crop(input_size, scale_size=scale_size, 117 | normalize=normalize) 118 | else: 119 | scale_size = scale_size or 32 120 | return scale_crop(input_size=input_size, 121 | scale_size=scale_size, normalize=normalize) 122 | 123 | 124 | class Lighting(object): 125 | """Lighting noise(AlexNet - style PCA - based noise)""" 126 | 127 | def __init__(self, alphastd, eigval, eigvec): 128 | self.alphastd = alphastd 129 | self.eigval = eigval 130 | self.eigvec = eigvec 131 | 132 | def __call__(self, img): 133 | if self.alphastd == 0: 134 | return img 135 | 136 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 137 | rgb = self.eigvec.type_as(img).clone()\ 138 | .mul(alpha.view(1, 3).expand(3, 3))\ 139 | .mul(self.eigval.view(1, 3).expand(3, 3))\ 140 | .sum(1).squeeze() 141 | 142 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 143 | 144 | 145 | class Grayscale(object): 146 | 147 | def __call__(self, img): 148 | gs = img.clone() 149 | gs[0].mul_(0.299).add_(0.587, gs[1]).add_(0.114, gs[2]) 150 | gs[1].copy_(gs[0]) 151 | gs[2].copy_(gs[0]) 152 | return gs 153 | 154 | 155 | class Saturation(object): 156 | 157 | def __init__(self, var): 158 | self.var = var 159 | 160 | def __call__(self, img): 161 | gs = Grayscale()(img) 162 | 
alpha = random.uniform(0, self.var) 163 | return img.lerp(gs, alpha) 164 | 165 | 166 | class Brightness(object): 167 | 168 | def __init__(self, var): 169 | self.var = var 170 | 171 | def __call__(self, img): 172 | gs = img.new().resize_as_(img).zero_() 173 | alpha = random.uniform(0, self.var) 174 | return img.lerp(gs, alpha) 175 | 176 | 177 | class Contrast(object): 178 | 179 | def __init__(self, var): 180 | self.var = var 181 | 182 | def __call__(self, img): 183 | gs = Grayscale()(img) 184 | gs.fill_(gs.mean()) 185 | alpha = random.uniform(0, self.var) 186 | return img.lerp(gs, alpha) 187 | 188 | 189 | class RandomOrder(object): 190 | """ Composes several transforms together in random order. 191 | """ 192 | 193 | def __init__(self, transforms): 194 | self.transforms = transforms 195 | 196 | def __call__(self, img): 197 | if self.transforms is None: 198 | return img 199 | order = torch.randperm(len(self.transforms)) 200 | for i in order: 201 | img = self.transforms[i](img) 202 | return img 203 | 204 | 205 | class ColorJitter(RandomOrder): 206 | 207 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 208 | self.transforms = [] 209 | if brightness != 0: 210 | self.transforms.append(Brightness(brightness)) 211 | if contrast != 0: 212 | self.transforms.append(Contrast(contrast)) 213 | if saturation != 0: 214 | self.transforms.append(Saturation(saturation)) 215 | -------------------------------------------------------------------------------- /distributed_code/pcode/datasets/loader/serialize.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | __all__ = ['loads', 'dumps'] 6 | 7 | 8 | def create_dummy_func(func, dependency): 9 | """ 10 | When a dependency of a function is not available, 11 | create a dummy function which throws ImportError when used. 12 | Args: 13 | func (str): name of the function. 14 | dependency (str or list[str]): name(s) of the dependency. 
15 | Returns: 16 | function: a function object 17 | """ 18 | if isinstance(dependency, (list, tuple)): 19 | dependency = ','.join(dependency) 20 | 21 | def _dummy(*args, **kwargs): 22 | raise ImportError( 23 | "Cannot import '{}', therefore '{}' is not available".format( 24 | dependency, func) 25 | ) 26 | return _dummy 27 | 28 | 29 | def dumps_msgpack(obj): 30 | """ 31 | Serialize an object. 32 | Returns: 33 | Implementation-dependent bytes-like object 34 | """ 35 | return msgpack.dumps(obj, use_bin_type=True) 36 | 37 | 38 | def loads_msgpack(buf): 39 | """ 40 | Args: 41 | buf: the output of `dumps`. 42 | """ 43 | return msgpack.loads(buf, raw=False) 44 | 45 | 46 | def dumps_pyarrow(obj): 47 | """ 48 | Serialize an object. 49 | 50 | Returns: 51 | Implementation-dependent bytes-like object 52 | """ 53 | return pa.serialize(obj).to_buffer() 54 | 55 | 56 | def loads_pyarrow(buf): 57 | """ 58 | Args: 59 | buf: the output of `dumps`. 60 | """ 61 | return pa.deserialize(buf) 62 | 63 | 64 | try: 65 | # fixed in pyarrow 0.9: https://github.com/apache/arrow/pull/1223#issuecomment-359895666 66 | import pyarrow as pa 67 | except ImportError: 68 | pa = None 69 | dumps_pyarrow = create_dummy_func('dumps_pyarrow', ['pyarrow']) # noqa 70 | loads_pyarrow = create_dummy_func('loads_pyarrow', ['pyarrow']) # noqa 71 | 72 | try: 73 | import msgpack 74 | import msgpack_numpy 75 | msgpack_numpy.patch() 76 | except ImportError: 77 | assert pa is not None, "pyarrow is a dependency of tensorpack!" 
78 | loads_msgpack = create_dummy_func( # noqa 79 | 'loads_msgpack', ['msgpack', 'msgpack_numpy']) 80 | dumps_msgpack = create_dummy_func( # noqa 81 | 'dumps_msgpack', ['msgpack', 'msgpack_numpy']) 82 | 83 | if os.environ.get('TENSORPACK_SERIALIZE', 'msgpack') == 'msgpack': 84 | loads = loads_msgpack 85 | dumps = dumps_msgpack 86 | else: 87 | loads = loads_pyarrow 88 | dumps = dumps_pyarrow 89 | -------------------------------------------------------------------------------- /distributed_code/pcode/datasets/loader/svhn_folder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import print_function 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | import torch.utils.data as data 9 | from torchvision.datasets.utils import download_url, check_integrity 10 | 11 | 12 | def define_svhn_folder(root, is_train, transform, target_transform, download): 13 | return SVHN(root=root, 14 | is_train=is_train, 15 | transform=transform, 16 | target_transform=target_transform, 17 | is_download=download) 18 | 19 | 20 | class SVHN(data.Dataset): 21 | """`SVHN `_ Dataset. 22 | Note: The SVHN dataset assigns the label `10` to the digit `0`. 23 | However, in this Dataset, we assign the label `0` to the digit `0` 24 | to be compatible with PyTorch loss functions which 25 | expect the class labels to be in the range `[0, C-1]` 26 | 27 | Args: 28 | root (string): Root directory of dataset where directory 29 | ``SVHN`` exists. 30 | split (string): One of {'train', 'test', 'extra'}. 31 | Accordingly dataset is selected. 'extra' is Extra training set. 32 | transform (callable, optional): A function/transform that 33 | takes in an PIL image and returns a transformed version. 34 | E.g, ``transforms.RandomCrop`` 35 | target_transform (callable, optional): 36 | A function/transform that takes in the target and transforms it. 
37 | download (bool, optional): If true, 38 | downloads the dataset from the internet and 39 | puts it in root directory. If dataset is already downloaded, 40 | it is not downloaded again. 41 | """ 42 | url = "" 43 | filename = "" 44 | file_md5 = "" 45 | 46 | split_list = { 47 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat", 48 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"], 49 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat", 50 | "test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], 51 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 52 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} 53 | 54 | def __init__(self, root, is_train='train', 55 | transform=None, target_transform=None, is_download=False): 56 | self.root = os.path.expanduser(root) 57 | self.transform = transform 58 | self.target_transform = target_transform 59 | self.is_train = is_train # training set or test set or extra set 60 | self.is_download = is_download 61 | 62 | if self.is_train: 63 | tr_data = self.load_svhn_data('train') 64 | ex_data = self.load_svhn_data('extra') 65 | self.data, self.labels = self.build_training(tr_data, ex_data) 66 | else: 67 | self.data, self.labels = self.load_svhn_data('test') 68 | 69 | def load_svhn_data(self, data_type): 70 | url = self.split_list[data_type][0] 71 | filename = self.split_list[data_type][1] 72 | file_md5 = self.split_list[data_type][2] 73 | 74 | if self.is_download: 75 | self.download(url, filename, file_md5) 76 | 77 | if not self._check_integrity(data_type, filename): 78 | raise RuntimeError( 79 | 'Dataset not found or corrupted.' 
+ 80 | ' You can use download=True to download it') 81 | 82 | data, labels = self._load_svhn_data(filename) 83 | return data, labels 84 | 85 | def _load_svhn_data(self, filename): 86 | # import here rather than at top of file because this is 87 | # an optional dependency for torchvision 88 | import scipy.io as sio 89 | 90 | # reading(loading) mat file as array 91 | loaded_mat = sio.loadmat(os.path.join(self.root, filename)) 92 | 93 | data = loaded_mat['X'] 94 | # loading from the .mat file gives an np array of type np.uint8 95 | # converting to np.int64, so that we have a LongTensor after 96 | # the conversion from the numpy array 97 | # the squeeze is needed to obtain a 1D tensor 98 | labels = loaded_mat['y'].astype(np.int64).squeeze() 99 | 100 | # the svhn dataset assigns the class label "10" to the digit 0 101 | # this makes it inconsistent with several loss functions 102 | # which expect the class labels to be in the range [0, C-1] 103 | np.place(labels, labels == 10, 0) 104 | data = np.transpose(data, (3, 2, 0, 1)) 105 | return data, labels 106 | 107 | def build_training(self, tr_data, ex_data): 108 | def get_include_indices(total, exclude): 109 | return list(set(list(total)) - set(exclude)) 110 | 111 | def exclude_samples(data, size_per_class): 112 | images, labels = data 113 | exclude_indices = [] 114 | 115 | # get exclude indices. 116 | for label in range(min(labels), max(labels) + 1): 117 | matched_indices = np.where(labels == label)[0] 118 | # fix the choice to train data (do not use random.choice) 119 | exclude_index = matched_indices.tolist()[: size_per_class] 120 | exclude_indices += exclude_index 121 | 122 | # get include indices 123 | include_indices = get_include_indices( 124 | range(images.shape[0]), exclude_indices) 125 | images = images[include_indices, :, :, :] 126 | labels = labels[include_indices] 127 | return images, labels 128 | 129 | def build_train(tr_data, ex_data): 130 | # get indices to exclude. 
131 | selected_tr_images, selected_tr_labels = exclude_samples( 132 | tr_data, 400) 133 | selected_ex_images, selected_ex_labels = exclude_samples( 134 | ex_data, 200) 135 | images = np.concatenate([selected_tr_images, selected_ex_images]) 136 | labels = np.concatenate([selected_tr_labels, selected_ex_labels]) 137 | return images, labels 138 | return build_train(tr_data, ex_data) 139 | 140 | def __getitem__(self, index): 141 | """ 142 | Args: 143 | index (int): Index 144 | 145 | Returns: 146 | tuple: (image, target) where target is index of the target class. 147 | """ 148 | img, target = self.data[index], int(self.labels[index]) 149 | 150 | # doing this so that it is consistent with all other datasets 151 | # to return a PIL Image 152 | img = Image.fromarray(np.transpose(img, (1, 2, 0))) 153 | 154 | if self.transform is not None: 155 | img = self.transform(img) 156 | 157 | if self.target_transform is not None: 158 | target = self.target_transform(target) 159 | 160 | return img, target 161 | 162 | def __len__(self): 163 | return len(self.data) 164 | 165 | def _check_integrity(self, data_type, filename): 166 | root = self.root 167 | md5 = self.split_list[data_type][2] 168 | fpath = os.path.join(root, filename) 169 | return check_integrity(fpath, md5) 170 | 171 | def download(self, url, filename, file_md5): 172 | download_url(url, self.root, filename, file_md5) 173 | 174 | def __repr__(self): 175 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 176 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 177 | fmt_str += ' Split: {}\n'.format(self.is_train) 178 | fmt_str += ' Root Location: {}\n'.format(self.root) 179 | tmp = ' Transforms (if any): ' 180 | fmt_str += '{0}{1}\n'.format( 181 | tmp, 182 | self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)) 183 | ) 184 | tmp = ' Target Transforms (if any): ' 185 | fmt_str += '{0}{1}'.format( 186 | tmp, 187 | self.target_transform.__repr__().replace( 188 | '\n', '\n' + ' ' * len(tmp)) 189 | ) 190 
| return fmt_str 191 | -------------------------------------------------------------------------------- /distributed_code/pcode/datasets/loader/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import sys 4 | 5 | import lmdb 6 | import cv2 7 | import numpy as np 8 | from PIL import Image 9 | 10 | import torch.utils.data as data 11 | 12 | import pcode.datasets.loader.serialize as serialize 13 | 14 | 15 | if sys.version_info[0] == 2: 16 | import cPickle as pickle 17 | else: 18 | import pickle 19 | 20 | 21 | def be_ncwh_pt(x): 22 | return x.permute(0, 3, 1, 2) # pytorch is (n,c,w,h) 23 | 24 | 25 | def uint8_to_float(x): 26 | x = x.permute(0, 3, 1, 2) # pytorch is (n,c,w,h) 27 | return x.float() / 128.0 - 1.0 28 | 29 | 30 | class LMDBPT(data.Dataset): 31 | """A class to load the LMDB file for extreme large datasets. 32 | Args: 33 | root (string): Either root directory for the database files, 34 | or a absolute path pointing to the file. 35 | classes (string or list): One of {'train', 'val', 'test'} or a list of 36 | categories to load. e,g. ['bedroom_train', 'church_train']. 37 | transform (callable, optional): A function/transform that 38 | takes in an PIL image and returns a transformed version. 39 | E.g, ``transforms.RandomCrop`` 40 | target_transform (callable, optional): 41 | A function/transform that takes in the target and transforms it. 
42 | """ 43 | 44 | def __init__(self, root, transform=None, target_transform=None, is_image=True): 45 | self.root = os.path.expanduser(root) 46 | self.transform = transform 47 | self.target_transform = target_transform 48 | self.lmdb_files = self._get_valid_lmdb_files() 49 | 50 | # for each class, create an LSUNClassDataset 51 | self.dbs = [] 52 | for lmdb_file in self.lmdb_files: 53 | self.dbs.append( 54 | LMDBPTClass( 55 | root=lmdb_file, 56 | transform=transform, 57 | target_transform=target_transform, 58 | is_image=is_image, 59 | ) 60 | ) 61 | 62 | # build up indices. 63 | self.indices = np.cumsum([len(db) for db in self.dbs]) 64 | self.length = self.indices[-1] 65 | self._build_indices() 66 | 67 | def _get_valid_lmdb_files(self): 68 | """get valid lmdb based on given root.""" 69 | if not self.root.endswith(".lmdb"): 70 | files = [] 71 | for l in os.listdir(self.root): 72 | if "_" in l and "-lock" not in l: 73 | files.append(os.path.join(self.root, l)) 74 | return files 75 | else: 76 | return [self.root] 77 | 78 | def _build_indices(self): 79 | self.from_to_indices = enumerate(zip(self.indices[:-1], self.indices[1:])) 80 | 81 | def _get_matched_index(self, index): 82 | if len(list(self.from_to_indices)) == 0: 83 | return 0, index 84 | 85 | for ind, (from_index, to_index) in self.from_to_indices: 86 | if from_index <= index and index < to_index: 87 | return ind, index - from_index 88 | 89 | def __getitem__(self, index): 90 | block_index, item_index = self._get_matched_index(index) 91 | image, target = self.dbs[block_index][item_index] 92 | return image, target 93 | 94 | def __len__(self): 95 | return self.length 96 | 97 | def __repr__(self): 98 | fmt_str = "Dataset " + self.__class__.__name__ + "\n" 99 | fmt_str += " Number of datapoints: {}\n".format(self.__len__()) 100 | fmt_str += " Root Location: {}\n".format(self.root) 101 | tmp = " Transforms (if any): " 102 | fmt_str += "{0}{1}\n".format( 103 | tmp, self.transform.__repr__().replace("\n", "\n" + " " * 
len(tmp)) 104 | ) 105 | tmp = " Target Transforms (if any): " 106 | fmt_str += "{0}{1}".format( 107 | tmp, self.target_transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 108 | ) 109 | return fmt_str 110 | 111 | 112 | class LMDBPTClass(data.Dataset): 113 | def __init__(self, root, transform=None, target_transform=None, is_image=True): 114 | self.root = os.path.expanduser(root) 115 | self.transform = transform 116 | self.target_transform = target_transform 117 | self.is_image = is_image 118 | 119 | # init the placeholder for env and length. 120 | self.env = None 121 | self.length = self._get_tmp_length() 122 | 123 | def _open_lmdb(self): 124 | return lmdb.open( 125 | self.root, 126 | subdir=os.path.isdir(self.root), 127 | readonly=True, 128 | lock=False, 129 | readahead=False, 130 | # map_size=1099511627776 * 2, 131 | max_readers=1, 132 | meminit=False, 133 | ) 134 | 135 | def _get_tmp_length(self): 136 | env = lmdb.open( 137 | self.root, 138 | subdir=os.path.isdir(self.root), 139 | readonly=True, 140 | lock=False, 141 | readahead=False, 142 | # map_size=1099511627776 * 2, 143 | max_readers=1, 144 | meminit=False, 145 | ) 146 | with env.begin(write=False) as txn: 147 | length = txn.stat()["entries"] 148 | 149 | if txn.get(b"__keys__") is not None: 150 | length -= 1 151 | # clean everything. 
152 | del env 153 | return length 154 | 155 | def _get_length(self): 156 | with self.env.begin(write=False) as txn: 157 | self.length = txn.stat()["entries"] 158 | 159 | if txn.get(b"__keys__") is not None: 160 | self.length -= 1 161 | 162 | def _prepare_cache(self): 163 | cache_file = self.root + "_cache_" 164 | if os.path.isfile(cache_file): 165 | self.keys = pickle.load(open(cache_file, "rb")) 166 | else: 167 | with self.env.begin(write=False) as txn: 168 | self.keys = [key for key, _ in txn.cursor() if key != b"__keys__"] 169 | pickle.dump(self.keys, open(cache_file, "wb")) 170 | 171 | def _image_decode(self, x): 172 | image = cv2.imdecode(x, cv2.IMREAD_COLOR).astype("uint8") 173 | return Image.fromarray(image, "RGB") 174 | 175 | def __getitem__(self, index): 176 | if self.env is None: 177 | # # open lmdb env. 178 | self.env = self._open_lmdb() 179 | # # get file stats. 180 | # self._get_length() 181 | # # prepare cache_file 182 | self._prepare_cache() 183 | 184 | # setup. 185 | env = self.env 186 | with env.begin(write=False) as txn: 187 | bin_file = txn.get(self.keys[index]) 188 | 189 | image, target = serialize.loads(bin_file) 190 | 191 | if self.is_image: 192 | image = self._image_decode(image) 193 | 194 | if self.transform is not None: 195 | image = self.transform(image) 196 | if self.target_transform is not None: 197 | target = self.target_transform(target) 198 | return image, target 199 | 200 | def __len__(self): 201 | return self.length 202 | 203 | def __repr__(self): 204 | return self.__class__.__name__ + " (" + self.root + ")" 205 | -------------------------------------------------------------------------------- /distributed_code/pcode/datasets/partition_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | 4 | import torch 5 | import torch.distributed as dist 6 | 7 | 8 | class Partition(object): 9 | """ Dataset-like object, but only access a subset of it. 
""" 10 | 11 | def __init__(self, data, indices): 12 | self.data = data 13 | self.indices = indices 14 | 15 | def __len__(self): 16 | return len(self.indices) 17 | 18 | def __getitem__(self, index): 19 | data_idx = self.indices[index] 20 | return self.data[data_idx] 21 | 22 | 23 | class DataPartitioner(object): 24 | """ Partitions a dataset into different chuncks. """ 25 | 26 | def __init__(self, conf, data, partition_sizes, partition_type="random"): 27 | # prepare info. 28 | self.conf = conf 29 | self.data = data 30 | self.partition_sizes = partition_sizes 31 | self.partition_type = partition_type 32 | self.data_size = len(self.data) 33 | self.partitions = [] 34 | 35 | # get unshuffled indices. 36 | indices = [x for x in range(0, self.data_size)] 37 | 38 | # apply partition function. 39 | self.partition_indices(indices) 40 | 41 | def partition_indices(self, indices): 42 | indices = self._get_consistent_indices(indices) 43 | 44 | # partition indices. 45 | from_index = 0 46 | for partition_size in self.partition_sizes: 47 | to_index = from_index + int(partition_size * self.data_size) 48 | self.partitions.append(indices[from_index:to_index]) 49 | from_index = to_index 50 | 51 | def _get_consistent_indices(self, indices): 52 | if self.conf.graph.rank == 0: 53 | if self.partition_type == "random": 54 | # it will randomly shuffle the indices. 55 | random.shuffle(indices) 56 | elif self.partition_type == "sorted": 57 | # it will sort the indices based on the data label. 58 | indices = [ 59 | i[0] 60 | for i in sorted(enumerate(self.data.targets), key=lambda x: x[1]) 61 | ] 62 | 63 | # sync the indices over nodes. 
64 | indices = torch.IntTensor(indices) 65 | indices = indices.cuda() if self.conf.backend == "nccl" else indices 66 | group = dist.new_group(self.conf.graph.ranks) 67 | dist.broadcast(indices, src=0, group=group) 68 | indices = indices.cpu() if self.conf.backend == "nccl" else indices 69 | return list(indices) 70 | 71 | def use(self, partition_ind): 72 | return Partition(self.data, self.partitions[partition_ind]) 73 | -------------------------------------------------------------------------------- /distributed_code/pcode/datasets/prepare_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | import spacy 5 | from spacy.symbols import ORTH 6 | import torchtext 7 | import torchvision.datasets as datasets 8 | import torchvision.transforms as transforms 9 | 10 | 11 | from pcode.datasets.loader.imagenet_folder import define_imagenet_folder 12 | from pcode.datasets.loader.svhn_folder import define_svhn_folder 13 | from pcode.datasets.loader.epsilon_or_rcv1_folder import define_epsilon_or_rcv1_folder 14 | 15 | 16 | """the entry for classification tasks.""" 17 | 18 | 19 | def _get_cifar(name, root, split, transform, target_transform, download): 20 | is_train = split == "train" 21 | 22 | # decide normalize parameter. 23 | if name == "cifar10": 24 | dataset_loader = datasets.CIFAR10 25 | normalize = transforms.Normalize( 26 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) 27 | ) 28 | elif name == "cifar100": 29 | dataset_loader = datasets.CIFAR100 30 | normalize = transforms.Normalize( 31 | (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761) 32 | ) 33 | 34 | # decide data type. 
35 | if is_train: 36 | transform = transforms.Compose( 37 | [ 38 | transforms.RandomHorizontalFlip(), 39 | transforms.RandomCrop((32, 32), 4), 40 | transforms.ToTensor(), 41 | normalize, 42 | ] 43 | ) 44 | else: 45 | transform = transforms.Compose([transforms.ToTensor(), normalize]) 46 | return dataset_loader( 47 | root=root, 48 | train=is_train, 49 | transform=transform, 50 | target_transform=target_transform, 51 | download=download, 52 | ) 53 | 54 | 55 | def _get_mnist(root, split, transform, target_transform, download): 56 | is_train = split == "train" 57 | 58 | if is_train: 59 | transform = transforms.Compose( 60 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 61 | ) 62 | else: 63 | transform = transforms.Compose( 64 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 65 | ) 66 | return datasets.MNIST( 67 | root=root, 68 | train=is_train, 69 | transform=transform, 70 | target_transform=target_transform, 71 | download=download, 72 | ) 73 | 74 | 75 | def _get_stl10(root, split, transform, target_transform, download): 76 | return datasets.STL10( 77 | root=root, 78 | split=split, 79 | transform=transform, 80 | target_transform=target_transform, 81 | download=download, 82 | ) 83 | 84 | 85 | def _get_svhn(root, split, transform, target_transform, download): 86 | is_train = split == "train" 87 | 88 | transform = transforms.Compose( 89 | [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 90 | ) 91 | return define_svhn_folder( 92 | root=root, 93 | is_train=is_train, 94 | transform=transform, 95 | target_transform=target_transform, 96 | download=download, 97 | ) 98 | 99 | 100 | def _get_imagenet(conf, name, datasets_path, split): 101 | is_train = split == "train" 102 | root = ( 103 | os.path.join( 104 | datasets_path, "lmdb" if "downsampled" not in name else "lmdb_32x32" 105 | ) 106 | if conf.use_lmdb_data 107 | else datasets_path 108 | ) 109 | 110 | if is_train: 111 | root = os.path.join( 112 | root, 
"train{}".format("" if not conf.use_lmdb_data else ".lmdb") 113 | ) 114 | else: 115 | root = os.path.join( 116 | root, "val{}".format("" if not conf.use_lmdb_data else ".lmdb") 117 | ) 118 | return define_imagenet_folder( 119 | name=name, root=root, flag=conf.use_lmdb_data, cuda=conf.graph.on_cuda 120 | ) 121 | 122 | 123 | def _get_epsilon_or_rcv1(root, name, split): 124 | root = os.path.join(root, "{}_{}.lmdb".format(name, split)) 125 | return define_epsilon_or_rcv1_folder(root) 126 | 127 | 128 | """the entry for language modeling task.""" 129 | 130 | 131 | def _get_text(batch_first): 132 | spacy_en = spacy.load("en") 133 | spacy_en.tokenizer.add_special_case("", [{ORTH: ""}]) 134 | spacy_en.tokenizer.add_special_case("", [{ORTH: ""}]) 135 | spacy_en.tokenizer.add_special_case("", [{ORTH: ""}]) 136 | 137 | def spacy_tok(text): 138 | return [tok.text for tok in spacy_en.tokenizer(text)] 139 | 140 | TEXT = torchtext.data.Field(lower=True, tokenize=spacy_tok, batch_first=batch_first) 141 | return TEXT 142 | 143 | 144 | def _get_nlp_lm_dataset(name, datasets_path, batch_first): 145 | TEXT = _get_text(batch_first) 146 | 147 | # Load and split data. 148 | if "wikitext2" in name: 149 | train, valid, test = torchtext.datasets.WikiText2.splits( 150 | TEXT, root=datasets_path 151 | ) 152 | elif "ptb" in name: 153 | train, valid, test = torchtext.datasets.PennTreebank.splits( 154 | TEXT, root=datasets_path 155 | ) 156 | else: 157 | raise NotImplementedError 158 | return TEXT, train, valid, test 159 | 160 | 161 | """the entry for different supported dataset.""" 162 | 163 | 164 | def get_dataset( 165 | conf, 166 | name, 167 | datasets_path, 168 | split="train", 169 | transform=None, 170 | target_transform=None, 171 | download=True, 172 | ): 173 | # create data folder if it does not exist. 
174 | root = os.path.join(datasets_path, name) 175 | 176 | if name == "cifar10" or name == "cifar100": 177 | return _get_cifar(name, root, split, transform, target_transform, download) 178 | elif name == "svhn": 179 | return _get_svhn(root, split, transform, target_transform, download) 180 | elif name == "mnist": 181 | return _get_mnist(root, split, transform, target_transform, download) 182 | elif name == "stl10": 183 | return _get_stl10(root, split, transform, target_transform, download) 184 | elif "imagenet" in name: 185 | return _get_imagenet(conf, name, datasets_path, split) 186 | elif name == "epsilon": 187 | return _get_epsilon_or_rcv1(root, name, split) 188 | elif name == "rcv1": 189 | return _get_epsilon_or_rcv1(root, name, split) 190 | elif name == "wikitext2" or name == "ptb": 191 | return _get_nlp_lm_dataset(name, datasets_path, batch_first=False) 192 | else: 193 | raise NotImplementedError 194 | -------------------------------------------------------------------------------- /distributed_code/pcode/distributed_running_cv.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import gc 3 | import numpy as np 4 | import torch 5 | import torch.distributed as dist 6 | 7 | from pcode.create_dataset import define_dataset, load_data_batch 8 | 9 | from pcode.utils.checkpoint import save_to_checkpoint 10 | from pcode.utils.logging import ( 11 | display_training_stat, 12 | display_test_stat, 13 | dispaly_best_test_stat, 14 | ) 15 | from pcode.utils.stat_tracker import RuntimeTracker 16 | import pcode.utils.error_handler as error_handler 17 | 18 | # sys.excepthook = error_handler.global_except_hook 19 | 20 | 21 | def train_and_validate( 22 | conf, model, criterion, scheduler, optimizer, metrics, data_loader 23 | ): 24 | print("=>>>> start training and validation.") 25 | # define runtime stat tracker and start the training. 
26 | tracker_tr = RuntimeTracker( 27 | metrics_to_track=metrics.metric_names, on_cuda=conf.graph.on_cuda 28 | ) 29 | 30 | # get the timer. 31 | timer = conf.timer 32 | # break until finish expected full epoch training. 33 | print("=>>>> enter the training.\n") 34 | while True: 35 | dist.barrier() 36 | 37 | # configure local step. 38 | for _input, _target in data_loader["train_loader"]: 39 | model.train() 40 | 41 | # load data 42 | with timer("load_data", epoch=scheduler.epoch_): 43 | _input, _target = load_data_batch(conf, _input, _target) 44 | 45 | # inference and get current performance. 46 | with timer("forward_pass", epoch=scheduler.epoch_): 47 | optimizer.zero_grad() 48 | loss = inference(model, criterion, metrics, _input, _target, tracker_tr) 49 | 50 | with timer("backward_pass", epoch=scheduler.epoch_): 51 | loss.backward() 52 | 53 | with timer("sync_and_apply_grad", epoch=scheduler.epoch_): 54 | n_bits_to_transmit = optimizer.step(timer=timer, scheduler=scheduler) 55 | scheduler.step() 56 | 57 | # display the logging info. 58 | display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit) 59 | 60 | # finish one epoch training and to decide if we want to val our model. 61 | if scheduler.epoch_ % 1 == 0: 62 | if tracker_tr.stat["loss"].avg > 1e3 or np.isnan( 63 | tracker_tr.stat["loss"].avg 64 | ): 65 | print("\nThe process diverges!!!!!Early stop it.") 66 | error_handler.abort() 67 | 68 | # each worker finish one epoch training. 69 | do_validate( 70 | conf, model, optimizer, criterion, scheduler, metrics, data_loader 71 | ) 72 | 73 | # refresh the logging cache at the begining of each epoch. 74 | tracker_tr.reset() 75 | 76 | # determine if the training is finished. 77 | if scheduler.is_stop(): 78 | # save json. 79 | conf.logger.save_json() 80 | return 81 | 82 | # display tracking time. 
83 | if ( 84 | conf.graph.rank == 0 85 | and conf.display_tracked_time 86 | and scheduler.local_index % conf.summary_freq == 0 87 | ): 88 | print(timer.summary()) 89 | 90 | # reshuffle the data. 91 | if conf.reshuffle_per_epoch: 92 | print("\nReshuffle the dataset.") 93 | del data_loader 94 | gc.collect() 95 | data_loader = define_dataset(conf) 96 | 97 | 98 | def inference(model, criterion, metrics, _input, _target, tracker=None): 99 | """Inference on the given model and get loss and accuracy.""" 100 | output = model(_input) 101 | loss = criterion(output, _target) 102 | performance = metrics.evaluate(loss, output, _target) 103 | if tracker is not None: 104 | tracker.update_metrics([loss.item()] + performance, n_samples=_input.size(0)) 105 | return loss 106 | 107 | 108 | def do_validate(conf, model, optimizer, criterion, scheduler, metrics, data_loader): 109 | """Evaluate the model on the test dataset and save to the checkpoint.""" 110 | # wait until the whole group enters this function, and then evaluate. 111 | print("Enter validation phase.") 112 | performance = validate( 113 | conf, model, optimizer, criterion, scheduler, metrics, data_loader 114 | ) 115 | 116 | # remember best performance and display the val info. 117 | scheduler.best_tracker.update(performance[0], scheduler.epoch_) 118 | dispaly_best_test_stat(conf, scheduler) 119 | 120 | # save to the checkpoint. 
121 | if not conf.train_fast: 122 | save_to_checkpoint( 123 | conf, 124 | { 125 | "arch": conf.arch, 126 | "current_epoch": scheduler.epoch, 127 | "local_index": scheduler.local_index, 128 | "best_perf": scheduler.best_tracker.best_perf, 129 | "optimizer": optimizer.state_dict(), 130 | "state_dict": model.state_dict(), 131 | }, 132 | scheduler.best_tracker.is_best, 133 | dirname=conf.checkpoint_dir, 134 | filename="checkpoint.pth.tar", 135 | save_all=conf.save_all_models, 136 | ) 137 | print("Finished validation.") 138 | 139 | 140 | def validate( 141 | conf, 142 | model, 143 | optimizer, 144 | criterion, 145 | scheduler, 146 | metrics, 147 | data_loader, 148 | label="local_model", 149 | ): 150 | """A function for model evaluation.""" 151 | 152 | def _evaluate(_model, label): 153 | # define stat. 154 | tracker_te = RuntimeTracker( 155 | metrics_to_track=metrics.metric_names, on_cuda=conf.graph.on_cuda 156 | ) 157 | 158 | # switch to evaluation mode 159 | _model.eval() 160 | 161 | for _input, _target in data_loader["val_loader"]: 162 | # load data and check performance. 163 | _input, _target = load_data_batch(conf, _input, _target) 164 | 165 | with torch.no_grad(): 166 | inference(_model, criterion, metrics, _input, _target, tracker_te) 167 | 168 | # display the test stat. 169 | display_test_stat(conf, scheduler, tracker_te, label) 170 | 171 | # get global (mean) performance 172 | global_performance = tracker_te.evaluate_global_metrics() 173 | return global_performance 174 | 175 | # evaluate each local model on the validation dataset. 
176 | global_performance = _evaluate(model, label=label) 177 | return global_performance 178 | -------------------------------------------------------------------------------- /distributed_code/pcode/distributed_running_nlp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import torch 4 | 5 | from pcode.utils.checkpoint import save_to_checkpoint 6 | from pcode.utils.logging import ( 7 | display_training_stat, 8 | display_test_stat, 9 | dispaly_best_test_stat, 10 | ) 11 | from pcode.utils.stat_tracker import RuntimeTracker 12 | import pcode.utils.error_handler as error_handler 13 | from pcode.create_dataset import load_data_batch 14 | 15 | 16 | # sys.excepthook = error_handler.global_except_hook 17 | 18 | 19 | def train_and_validate( 20 | conf, model, criterion, scheduler, optimizer, metrics, data_loader 21 | ): 22 | print("=>>>> start training and validation.\n") 23 | 24 | # define runtime stat tracker and start the training. 25 | tracker_tr = RuntimeTracker(metrics_to_track=metrics.metric_names) 26 | 27 | # get the timer. 28 | timer = conf.timer 29 | 30 | # break until finish expected full epoch training. 31 | print("=>>>> enter the training.\n") 32 | while True: 33 | # init the hidden state. 34 | _hidden = ( 35 | model.module.init_hidden(conf.batch_size) 36 | if "DataParallel" == model.__class__.__name__ 37 | else model.init_hidden(conf.batch_size) 38 | ) 39 | 40 | # configure local step. 41 | for batch in data_loader["train_loader"]: 42 | model.train() 43 | 44 | # repackage the hidden. 
45 | _hidden = ( 46 | model.module.repackage_hidden(_hidden) 47 | if "DataParallel" == model.__class__.__name__ 48 | else model.repackage_hidden(_hidden) 49 | ) 50 | 51 | # load data 52 | with timer("load_data", epoch=scheduler.epoch_): 53 | _input = batch.text[ 54 | :, 55 | conf.graph.rank 56 | * conf.batch_size : (conf.graph.rank + 1) 57 | * conf.batch_size, 58 | ] 59 | _target = batch.target[ 60 | :, 61 | conf.graph.rank 62 | * conf.batch_size : (conf.graph.rank + 1) 63 | * conf.batch_size, 64 | ] 65 | _input, _target = load_data_batch(conf, _input, _target) 66 | 67 | # inference and get current performance. 68 | with timer("forward_pass", epoch=scheduler.epoch_): 69 | optimizer.zero_grad() 70 | loss, _hidden = inference( 71 | conf, 72 | model, 73 | criterion, 74 | metrics, 75 | _input, 76 | _target, 77 | _hidden, 78 | tracker_tr, 79 | ) 80 | 81 | with timer("backward_pass", epoch=scheduler.epoch_): 82 | loss.backward() 83 | 84 | with timer("sync_complete", epoch=scheduler.epoch_): 85 | # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs. 86 | torch.nn.utils.clip_grad_norm_(model.parameters(), conf.rnn_clip) 87 | n_bits_to_transmit = optimizer.step(timer=timer) 88 | scheduler.step() 89 | 90 | # display the logging info. 91 | display_training_stat(conf, scheduler, tracker_tr, n_bits_to_transmit) 92 | 93 | # finish one epoch training and to decide if we want to val our model. 94 | if scheduler.epoch_ % 1 == 0: 95 | if tracker_tr.stat["loss"].avg > 1e3 or np.isnan( 96 | tracker_tr.stat["loss"].avg 97 | ): 98 | print("\nThe process diverges!!!!!Early stop it.") 99 | error_handler.abort() 100 | 101 | # each worker finish one epoch training. 102 | do_validate( 103 | conf, model, optimizer, criterion, scheduler, metrics, data_loader 104 | ) 105 | 106 | # refresh the logging cache at the begining of each epoch. 107 | tracker_tr.reset() 108 | 109 | # determine if the training is finished. 
110 | if scheduler.is_stop(): 111 | conf.logger.save_json() 112 | return 113 | 114 | # display tracking time. 115 | if ( 116 | conf.graph.rank == 0 117 | and conf.display_tracked_time 118 | and scheduler.local_index % conf.summary_freq == 0 119 | ): 120 | print(timer.summary()) 121 | 122 | 123 | def inference(conf, model, criterion, metrics, _input, _target, _hidden, tracker=None): 124 | """Inference on the given model and get loss and accuracy.""" 125 | output, _hidden = model(_input, _hidden) 126 | loss = criterion(output.view(-1, conf.n_tokens), _target.contiguous().view(-1)) 127 | performance = metrics.evaluate(loss, output, _target) 128 | if tracker is not None: 129 | tracker.update_metrics([loss.item()] + performance, n_samples=_input.size(0)) 130 | return loss, _hidden 131 | 132 | 133 | def do_validate(conf, model, optimizer, criterion, scheduler, metrics, data_loader): 134 | """Evaluate the model on the test dataset and save to the checkpoint.""" 135 | # wait until the whole group enters this function, and then evaluate. 136 | performance = validate( 137 | conf, model, optimizer, criterion, scheduler, metrics, data_loader 138 | ) 139 | 140 | # remember best performance and display the val info. 141 | scheduler.best_tracker.update(performance[0], scheduler.epoch_) 142 | dispaly_best_test_stat(conf, scheduler) 143 | 144 | # save to the checkpoint. 
145 | save_to_checkpoint( 146 | conf, 147 | { 148 | "arch": conf.arch, 149 | "current_epoch": scheduler.epoch, 150 | "local_index": scheduler.local_index, 151 | "best_perf": scheduler.best_tracker.best_perf, 152 | "optimizer": optimizer.state_dict(), 153 | "state_dict": model.state_dict(), 154 | }, 155 | scheduler.best_tracker.is_best, 156 | dirname=conf.checkpoint_dir, 157 | filename="checkpoint.pth.tar", 158 | save_all=conf.save_all_models, 159 | ) 160 | print("Finished validation.") 161 | 162 | 163 | def validate( 164 | conf, 165 | model, 166 | optimizer, 167 | criterion, 168 | scheduler, 169 | metrics, 170 | data_loader, 171 | label="local_model", 172 | ): 173 | """A function for model evaluation.""" 174 | 175 | def _evaluate(_model, label): 176 | # define stat. 177 | tracker_te = RuntimeTracker(metrics_to_track=metrics.metric_names) 178 | 179 | # switch to evaluation mode 180 | _model.eval() 181 | 182 | # define hidden state for RNN. 183 | _hidden = ( 184 | model.module.init_hidden(conf.batch_size) 185 | if "DataParallel" == model.__class__.__name__ 186 | else model.init_hidden(conf.batch_size) 187 | ) 188 | 189 | for batch in data_loader["val_loader"]: 190 | # load data and check performance. 191 | _input, _target = batch.text, batch.target 192 | 193 | # repackage the hidden. 194 | _hidden = ( 195 | model.module.repackage_hidden(_hidden) 196 | if "DataParallel" == model.__class__.__name__ 197 | else model.repackage_hidden(_hidden) 198 | ) 199 | 200 | with torch.no_grad(): 201 | _, _hidden = inference( 202 | conf, 203 | _model, 204 | criterion, 205 | metrics, 206 | _input, 207 | _target, 208 | _hidden, 209 | tracker_te, 210 | ) 211 | 212 | # display the test stat. 213 | display_test_stat(conf, scheduler, tracker_te, label) 214 | 215 | # get global (mean) performance 216 | global_performance = tracker_te.evaluate_global_metrics() 217 | return global_performance 218 | 219 | # evaluate each local model on the validation dataset. 
220 | global_performance = _evaluate(model, label=label) 221 | return global_performance 222 | -------------------------------------------------------------------------------- /distributed_code/pcode/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .densenet import * 3 | from .wideresnet import * 4 | from .mlp import * 5 | from .vgg import * 6 | from .rnn_lm import * 7 | from .lenet import * 8 | -------------------------------------------------------------------------------- /distributed_code/pcode/models/densenet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | from collections import OrderedDict 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | __all__ = ['densenet'] 10 | 11 | 12 | class BasicLayer(nn.Module): 13 | def __init__(self, num_channels, growth_rate, drop_rate=0.0): 14 | super(BasicLayer, self).__init__() 15 | 16 | self.bn1 = nn.BatchNorm2d(num_channels) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.conv1 = nn.Conv2d( 19 | num_channels, growth_rate, kernel_size=3, padding=1, bias=False) 20 | self.droprate = drop_rate 21 | 22 | def forward(self, x): 23 | out = self.conv1(self.relu(self.bn1(x))) 24 | 25 | if self.droprate > 0: 26 | out = F.dropout(out, p=self.droprate, training=self.training) 27 | 28 | out = torch.cat((x, out), 1) 29 | return out 30 | 31 | 32 | class Bottleneck(nn.Module): 33 | def __init__(self, num_channels, growth_rate, drop_rate=0.0): 34 | super(Bottleneck, self).__init__() 35 | 36 | inter_channels = 4 * growth_rate 37 | self.bn1 = nn.BatchNorm2d(num_channels) 38 | self.relu = nn.ReLU(inplace=True) 39 | self.conv1 = nn.Conv2d( 40 | num_channels, inter_channels, kernel_size=1, bias=False) 41 | self.bn2 = nn.BatchNorm2d(inter_channels) 42 | self.conv2 = nn.Conv2d( 43 | inter_channels, growth_rate, kernel_size=3, padding=1, 
bias=False) 44 | self.droprate = drop_rate 45 | 46 | def forward(self, x): 47 | out = self.conv1(self.relu(self.bn1(x))) 48 | if self.droprate > 0: 49 | out = F.dropout( 50 | out, p=self.droprate, inplace=False, training=self.training) 51 | 52 | out = self.conv2(self.relu(self.bn2(out))) 53 | if self.droprate > 0: 54 | out = F.dropout( 55 | out, p=self.droprate, inplace=False, training=self.training) 56 | 57 | out = torch.cat((x, out), 1) 58 | return out 59 | 60 | 61 | class Transition(nn.Module): 62 | def __init__(self, num_channels, num_out_channels, drop_rate=0.0): 63 | super(Transition, self).__init__() 64 | self.bn1 = nn.BatchNorm2d(num_channels) 65 | self.relu = nn.ReLU(inplace=True) 66 | self.conv1 = nn.Conv2d( 67 | num_channels, num_out_channels, kernel_size=1, bias=False) 68 | 69 | self.droprate = drop_rate 70 | 71 | def forward(self, x): 72 | out = self.conv1(self.relu(self.bn1(x))) 73 | if self.droprate > 0: 74 | out = F.dropout( 75 | out, p=self.droprate, inplace=False, training=self.training) 76 | 77 | out = F.avg_pool2d(out, 2) 78 | return out 79 | 80 | 81 | class DenseNet(nn.Module): 82 | def __init__(self, dataset, 83 | net_depth, growth_rate, bc_mode, compression, drop_rate): 84 | super(DenseNet, self).__init__() 85 | 86 | # determine some fundamental configurations. 87 | self.dataset = dataset 88 | self.num_classes = self._decide_num_classes() 89 | is_small_inputs = 'imagenet' not in self.dataset 90 | self.avgpool_size = 8 if is_small_inputs else 7 91 | assert 0 < compression <= 1, 'compression should be between 0 and 1.' 92 | 93 | # determine block_config for different types of the data. 
94 | if is_small_inputs: 95 | num_blocks = 3 96 | num_layers_per_block = (net_depth - (num_blocks + 1)) // num_blocks 97 | 98 | if bc_mode: 99 | num_layers_per_block = num_layers_per_block // 2 100 | block_config = [num_layers_per_block] * num_blocks 101 | else: 102 | model_params = { 103 | 121: [6, 12, 24, 16], 104 | 169: [6, 12, 32, 32], 105 | 201: [6, 12, 48, 32], 106 | 264: [6, 12, 64, 48] 107 | } 108 | 109 | assert net_depth not in model_params.keys() 110 | block_config = model_params[net_depth] 111 | 112 | # init conv. 113 | num_channels = 2 * growth_rate 114 | if is_small_inputs: 115 | self.features = nn.Sequential( 116 | OrderedDict([ 117 | ('conv0', nn.Conv2d(3, num_channels, kernel_size=3, 118 | stride=1, padding=1, bias=False)) 119 | ]) 120 | ) 121 | else: 122 | self.features = nn.Sequential( 123 | OrderedDict([ 124 | ('conv0', nn.Conv2d(3, num_channels, kernel_size=7, 125 | stride=2, padding=3, bias=False)), 126 | ('norm0', nn.BatchNorm2d(num_channels)), 127 | ('relu0', nn.ReLU(inplace=True)), 128 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1, 129 | ceil_mode=False)) 130 | ]) 131 | ) 132 | 133 | # each denseblock 134 | for ind, num_layers in enumerate(block_config): 135 | block = self._make_dense( 136 | num_channels, growth_rate, num_layers, 137 | bc_mode, drop_rate) 138 | self.features.add_module('denseblock%d' % (ind + 1), block) 139 | 140 | num_channels += num_layers * growth_rate 141 | num_out_channels = int(math.floor(num_channels * compression)) 142 | 143 | # transition_layer 144 | if ind != len(block_config) - 1: 145 | trans = Transition(num_channels, num_out_channels, drop_rate) 146 | self.features.add_module('transition%d' % (ind + 1), trans) 147 | num_channels = num_out_channels 148 | 149 | # final batch norm 150 | self.features.add_module('norm_final', nn.BatchNorm2d(num_channels)) 151 | 152 | # Linear layer 153 | self.classifier = nn.Linear(num_channels, self.num_classes) 154 | 155 | # init weight. 
156 | self._weight_initialization() 157 | 158 | def _decide_num_classes(self): 159 | if self.dataset == 'cifar10' or self.dataset == 'svhn': 160 | return 10 161 | elif self.dataset == 'cifar100': 162 | return 100 163 | elif self.dataset == 'imagenet': 164 | return 1000 165 | 166 | def _weight_initialization(self): 167 | for m in self.modules(): 168 | if isinstance(m, nn.Conv2d): 169 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 170 | m.weight.data.normal_(0, math.sqrt(2. / n)) 171 | elif isinstance(m, nn.BatchNorm2d): 172 | m.weight.data.fill_(1) 173 | m.bias.data.zero_() 174 | elif isinstance(m, nn.Linear): 175 | m.bias.data.zero_() 176 | 177 | def _make_dense( 178 | self, num_channels, growth_rate, num_layers_per_block, 179 | bc_mode, drop_rate): 180 | layers = [] 181 | for _ in range(int(num_layers_per_block)): 182 | if bc_mode: 183 | layers.append( 184 | Bottleneck(num_channels, growth_rate, drop_rate)) 185 | else: 186 | layers.append( 187 | BasicLayer(num_channels, growth_rate, drop_rate)) 188 | num_channels += growth_rate 189 | return nn.Sequential(*layers) 190 | 191 | def forward(self, x): 192 | features = self.features(x) 193 | out = F.relu(features, inplace=True) 194 | out = F.avg_pool2d( 195 | out, kernel_size=self.avgpool_size).view(features.size(0), -1) 196 | out = self.classifier(out) 197 | return out 198 | 199 | 200 | def densenet(conf): 201 | net_depth = int(conf.arch.replace('densenet', '')) 202 | 203 | model = DenseNet( 204 | dataset=conf.data, net_depth=net_depth, 205 | growth_rate=conf.densenet_growth_rate, bc_mode=conf.densenet_bc_mode, 206 | compression=conf.densenet_compression, 207 | drop_rate=conf.drop_rate) 208 | return model 209 | -------------------------------------------------------------------------------- /distributed_code/pcode/models/lenet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from collections import OrderedDict 3 | 4 | import torch.nn as nn 
5 | 6 | __all__ = ["lenet"] 7 | 8 | 9 | class LeNet(nn.Module): 10 | """ 11 | Input - 3x32x32 12 | C1 - 6@28x28 (5x5 kernel) 13 | tanh 14 | S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling 15 | C3 - 16@10x10 (5x5 kernel) 16 | tanh 17 | S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling 18 | C5 - 120@1x1 (5x5 kernel) 19 | F6 - 84 20 | ReLU 21 | F7 - 10 (Output) 22 | """ 23 | 24 | def __init__(self, dataset="cifar10"): 25 | super(LeNet, self).__init__() 26 | 27 | # some init. 28 | self.dataset = dataset 29 | self.num_classes = self._decide_num_classes() 30 | 31 | # init layers. 32 | self.convnet = nn.Sequential( 33 | OrderedDict( 34 | [ 35 | ( 36 | "conv1", 37 | nn.Conv2d(self._decide_input_dim(), 6, kernel_size=(5, 5)), 38 | ), 39 | ("relu1", nn.ReLU()), 40 | ("s2", nn.MaxPool2d(kernel_size=(2, 2), stride=2)), 41 | ("conv3", nn.Conv2d(6, 16, kernel_size=(5, 5))), 42 | ("relu3", nn.ReLU()), 43 | ("s4", nn.MaxPool2d(kernel_size=(2, 2), stride=2)), 44 | ("conv5", nn.Conv2d(16, 120, kernel_size=(5, 5))), 45 | ("relu5", nn.ReLU()), 46 | ] 47 | ) 48 | ) 49 | 50 | self.fc = nn.Sequential( 51 | OrderedDict( 52 | [ 53 | ("fc6", nn.Linear(120, 84)), 54 | ("relu6", nn.ReLU()), 55 | ("fc7", nn.Linear(84, self.num_classes)), 56 | ] 57 | ) 58 | ) 59 | 60 | def forward(self, x): 61 | out = self.convnet(x) 62 | out = out.view(x.size(0), -1) 63 | out = self.fc(out) 64 | return out 65 | 66 | def _decide_num_classes(self): 67 | if ( 68 | self.dataset == "cifar10" 69 | or self.dataset == "svhn" 70 | or self.dataset == "mnist" 71 | ): 72 | return 10 73 | elif self.dataset == "cifar100": 74 | return 100 75 | elif self.dataset == "imagenet": 76 | return 1000 77 | 78 | def _decide_input_dim(self): 79 | if ( 80 | "cifar" in self.dataset 81 | or self.dataset == "svhn" 82 | or self.dataset == "imagenet" 83 | ): 84 | return 3 85 | elif "mnist" == self.dataset: 86 | return 1 87 | else: 88 | raise RuntimeError("incorrect input dim.") 89 | 90 | 91 | def lenet(conf): 92 | """Constructs a lenet 
model.""" 93 | return LeNet(dataset=conf.data) 94 | -------------------------------------------------------------------------------- /distributed_code/pcode/models/mlp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch.nn as nn 3 | 4 | 5 | __all__ = ['mlp'] 6 | 7 | 8 | class MLP(nn.Module): 9 | def __init__(self, dataset, num_layers, hidden_size, drop_rate): 10 | super(MLP, self).__init__() 11 | self.dataset = dataset 12 | 13 | # init 14 | self.num_layers = num_layers 15 | self.num_classes = self._decide_num_classes() 16 | input_size = self._decide_input_feature_size() 17 | 18 | # define layers. 19 | for i in range(1, self.num_layers + 1): 20 | in_features = input_size if i == 1 else hidden_size 21 | out_features = hidden_size 22 | 23 | layer = nn.Sequential( 24 | nn.Linear(in_features, out_features), 25 | nn.BatchNorm1d(out_features), 26 | nn.ReLU(), 27 | nn.Dropout(p=drop_rate)) 28 | setattr(self, 'layer{}'.format(i), layer) 29 | 30 | self.fc = nn.Linear(hidden_size, self.num_classes, bias=False) 31 | 32 | def _decide_num_classes(self): 33 | if self.dataset == 'cifar10': 34 | return 10 35 | elif self.dataset == 'cifar100': 36 | return 100 37 | 38 | def _decide_input_feature_size(self): 39 | if 'cifar' in self.dataset: 40 | return 32 * 32 * 3 41 | elif 'mnist' in self.dataset: 42 | return 28 * 28 43 | else: 44 | raise NotImplementedError 45 | 46 | def forward(self, x): 47 | out = x.view(x.size(0), -1) 48 | 49 | for i in range(1, self.num_layers + 1): 50 | out = getattr(self, 'layer{}'.format(i))(out) 51 | out = self.fc(out) 52 | return out 53 | 54 | 55 | def mlp(conf): 56 | return MLP( 57 | dataset=conf.data, num_layers=conf.mlp_num_layers, 58 | hidden_size=conf.mlp_hidden_size, drop_rate=conf.drop_rate) 59 | -------------------------------------------------------------------------------- /distributed_code/pcode/models/rnn_lm.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 5 | 6 | # LSTM language model: embedding -> dropout -> multi-layer LSTM -> dropout -> linear decoder over the ntoken vocabulary. 7 | class RNNLM(nn.Module): 8 | """Container module with an encoder, a recurrent module, and a decoder.""" 9 | 10 | def __init__( 11 | self, 12 | ntoken, 13 | ninp, 14 | nhid, 15 | nlayers, 16 | dropout=0.5, 17 | tie_weights=False, 18 | weight_norm=False, 19 | ): 20 | super(RNNLM, self).__init__() 21 | 22 | # define conf. 23 | self.nhid = nhid 24 | self.nlayers = nlayers 25 | # NOTE(review): the `weight_norm` argument is accepted but never used in this class. 26 | # define layers. 27 | self.drop = nn.Dropout(dropout) 28 | self.encoder = nn.Embedding(ntoken, ninp) 29 | self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout, batch_first=False)  # batch_first=False: the LSTM expects (seq_len, batch, ninp) inputs 30 | self.decoder = nn.Linear(nhid, ntoken) 31 | 32 | # Optionally tie weights as in: 33 | # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) 34 | # https://arxiv.org/abs/1608.05859 35 | # and 36 | # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al.
2016) 37 | # https://arxiv.org/abs/1611.01462 38 | if tie_weights: 39 | if nhid != ninp: 40 | raise ValueError( 41 | "When using the tied flag, nhid must be equal to emsize" 42 | ) 43 | self.decoder.weight = self.encoder.weight 44 | 45 | self.init_weights() 46 | 47 | def init_weights(self): 48 | initrange = 0.1 49 | self.encoder.weight.data.uniform_(-initrange, initrange) 50 | self.decoder.bias.data.zero_() 51 | self.decoder.weight.data.uniform_(-initrange, initrange) 52 | # forward: `input` holds token indices; returns logits shaped (output.size(0), output.size(1), ntoken) and the new hidden state. 53 | def forward(self, input, hidden): 54 | self.rnn.flatten_parameters() 55 | emb = self.drop(self.encoder(input)) 56 | output, hidden = self.rnn(emb, hidden) 57 | output = self.drop(output) 58 | decoded = self.decoder( 59 | output.view(output.size(0) * output.size(1), output.size(2)) 60 | ) 61 | return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden 62 | # init_hidden: zero (h_0, c_0), each (nlayers, bsz, nhid), created with new_zeros so they match the parameters' dtype/device. 63 | def init_hidden(self, bsz): 64 | weight = next(self.parameters()) 65 | return ( 66 | weight.new_zeros(self.nlayers, bsz, self.nhid), 67 | weight.new_zeros(self.nlayers, bsz, self.nhid), 68 | ) 69 | # repackage_hidden: recursively detaches a tensor or a (nested) tuple of tensors. 70 | def repackage_hidden(self, h): 71 | """Wraps hidden states in new Tensors, to detach them from their history.""" 72 | if isinstance(h, torch.Tensor): 73 | return h.detach() 74 | else: 75 | return tuple(self.repackage_hidden(v) for v in h) 76 | -------------------------------------------------------------------------------- /distributed_code/pcode/models/vgg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | 4 | import torch.nn as nn 5 | 6 | 7 | __all__ = ['vgg'] 8 | 9 | # Layer lists per configuration: ints are 3x3 conv output channels, 'M' is a 2x2 stride-2 max-pool. 10 | ARCHITECTURES = { 11 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 12 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 13 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 14 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 15 |
512, 512, 512, 512, 'M'], 16 | } 17 | 18 | # VGG with a 512-512-num_classes head; the head's 512 input features assume the five max-pools reduce the input to 1x1 (e.g. 32x32 inputs) — TODO confirm. 19 | class VGG(nn.Module): 20 | def __init__(self, nn_arch, dataset, use_bn=True): 21 | super(VGG, self).__init__() 22 | 23 | # init parameters. 24 | self.use_bn = use_bn 25 | self.nn_arch = nn_arch 26 | self.dataset = dataset 27 | self.num_classes = self._decide_num_classes() 28 | 29 | # init models. 30 | self.features = self._make_layers() 31 | self.classifier = nn.Sequential( 32 | nn.Dropout(), 33 | nn.Linear(512, 512), 34 | nn.ReLU(True), 35 | nn.Dropout(), 36 | nn.Linear(512, 512), 37 | nn.ReLU(True), 38 | nn.Linear(512, self.num_classes), 39 | ) 40 | 41 | # weight initialization. 42 | self._weight_initialization() 43 | 44 | def _weight_initialization(self): 45 | for m in self.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 48 | m.weight.data.normal_(0, math.sqrt(2. / n)) 49 | elif isinstance(m, nn.BatchNorm2d): 50 | m.weight.data.fill_(1) 51 | m.bias.data.zero_() 52 | 53 | def _decide_num_classes(self): 54 | if self.dataset == 'cifar10' or self.dataset == 'svhn': 55 | return 10 56 | elif self.dataset == 'cifar100': 57 | return 100 58 | else: 59 | raise ValueError('not allowed dataset.') 60 | # _make_layers: builds the conv stack for ARCHITECTURES[nn_arch], starting from 3 input channels. 61 | def _make_layers(self): 62 | layers = [] 63 | in_channels = 3 64 | for v in ARCHITECTURES[self.nn_arch]: 65 | if v == 'M': 66 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 67 | else: 68 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 69 | if self.use_bn: 70 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 71 | else: 72 | layers += [conv2d, nn.ReLU(inplace=True)] 73 | in_channels = v 74 | return nn.Sequential(*layers) 75 | 76 | def forward(self, x): 77 | x = self.features(x) 78 | x = x.view(x.size(0), -1) 79 | x = self.classifier(x) 80 | return x 81 | 82 | # Factory: '11'/'13'/'16'/'19' in conf.arch selects config A/B/D/E; 'bn' in conf.arch enables BatchNorm. 83 | def vgg(conf): 84 | use_bn = 'bn' in conf.arch 85 | dataset = conf.data 86 | 87 | if '11' in conf.arch: 88 | return VGG(nn_arch='A', dataset=dataset, use_bn=use_bn) 89 |
elif '13' in conf.arch: 90 | return VGG(nn_arch='B', dataset=dataset, use_bn=use_bn) 91 | elif '16' in conf.arch: 92 | return VGG(nn_arch='D', dataset=dataset, use_bn=use_bn) 93 | elif '19' in conf.arch: 94 | return VGG(nn_arch='E', dataset=dataset, use_bn=use_bn) 95 | else: 96 | raise NotImplementedError 97 | -------------------------------------------------------------------------------- /distributed_code/pcode/models/wideresnet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | __all__ = ['wideresnet'] 9 | 10 | # Pre-activation residual block (BN -> ReLU -> conv, twice); a strided 1x1 conv shortcut is used when in_planes != out_planes. 11 | class BasicBlock(nn.Module): 12 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0): 13 | super(BasicBlock, self).__init__() 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.relu1 = nn.ReLU(inplace=True) 16 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, 17 | stride=stride, padding=1, bias=False) 18 | self.bn2 = nn.BatchNorm2d(out_planes) 19 | self.relu2 = nn.ReLU(inplace=True) 20 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, 21 | stride=1, padding=1, bias=False) 22 | self.droprate = drop_rate 23 | self.equal_in_out = (in_planes == out_planes) 24 | self.conv_shortcut = (not self.equal_in_out) and nn.Conv2d( 25 | in_planes, out_planes, kernel_size=1, stride=stride, 26 | padding=0, bias=False) or None 27 | # forward: when planes differ, x itself is replaced by relu1(bn1(x)), so both conv1 and the 1x1 shortcut consume the pre-activated input. 28 | def forward(self, x): 29 | if not self.equal_in_out: 30 | x = self.relu1(self.bn1(x)) 31 | else: 32 | out = self.relu1(self.bn1(x)) 33 | out = self.relu2(self.bn2(self.conv1(out if self.equal_in_out else x))) 34 | if self.droprate > 0: 35 | out = F.dropout(out, p=self.droprate, training=self.training) 36 | out = self.conv2(out) 37 | return torch.add( 38 | x if self.equal_in_out else self.conv_shortcut(x), out) 39 | 40 | # NetworkBlock: nb_layers blocks; only the first one receives in_planes and `stride`, the rest map out_planes -> out_planes with stride 1. 41 | class NetworkBlock(nn.Module): 42 | def __init__( 43 | self, nb_layers, in_planes, out_planes, block,
stride, 44 | drop_rate=0.0): 45 | super(NetworkBlock, self).__init__() 46 | self.layer = self._make_layer( 47 | block, in_planes, out_planes, nb_layers, stride, drop_rate) 48 | 49 | def _make_layer( 50 | self, block, in_planes, out_planes, 51 | nb_layers, stride, drop_rate): 52 | layers = [] 53 | for i in range(nb_layers): 54 | layers.append( 55 | block(i == 0 and in_planes or out_planes, out_planes, 56 | i == 0 and stride or 1, 57 | drop_rate) 58 | ) 59 | return nn.Sequential(*layers) 60 | 61 | def forward(self, x): 62 | return self.layer(x) 63 | 64 | # WideResNet: (net_depth - 4) must be divisible by 6; three groups of (net_depth - 4) // 6 blocks with widths 16k, 32k, 64k (k = widen_factor). 65 | class WideResNet(nn.Module): 66 | def __init__(self, dataset, net_depth, widen_factor, drop_rate): 67 | super(WideResNet, self).__init__() 68 | 69 | # define fundamental parameters. 70 | self.dataset = dataset 71 | 72 | assert((net_depth - 4) % 6 == 0) 73 | num_channels = [16, 16 * widen_factor, 74 | 32 * widen_factor, 64 * widen_factor] 75 | num_blocks = (net_depth - 4) // 6 76 | block = BasicBlock 77 | self.num_classes = self._decide_num_classes() 78 | 79 | # 1st conv before any network block 80 | self.conv1 = nn.Conv2d(3, num_channels[0], kernel_size=3, stride=1, 81 | padding=1, bias=False) 82 | # 1st block 83 | self.block1 = NetworkBlock(num_blocks, 84 | num_channels[0], num_channels[1], 85 | block, 1, drop_rate) 86 | # 2nd block 87 | self.block2 = NetworkBlock(num_blocks, 88 | num_channels[1], num_channels[2], 89 | block, 2, drop_rate) 90 | # 3rd block 91 | self.block3 = NetworkBlock(num_blocks, 92 | num_channels[2], num_channels[3], 93 | block, 2, drop_rate) 94 | 95 | # global average pooling and classifier 96 | self.bn1 = nn.BatchNorm2d(num_channels[3]) 97 | self.relu = nn.ReLU(inplace=True) 98 | self.num_channels = num_channels[3] 99 | self.fc = nn.Linear(num_channels[3], self.num_classes) 100 | 101 | self._weight_initialization() 102 | 103 | def _decide_num_classes(self): 104 | if self.dataset == 'cifar10' or self.dataset == 'svhn': 105 | return 10 106 | elif self.dataset == 'cifar100': 107 | return 100
108 | elif 'imagenet' in self.dataset: 109 | return 1000 110 | 111 | def _weight_initialization(self): 112 | for m in self.modules(): 113 | if isinstance(m, nn.Conv2d): 114 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 115 | m.weight.data.normal_(0, math.sqrt(2. / n)) 116 | elif isinstance(m, nn.BatchNorm2d): 117 | m.weight.data.fill_(1) 118 | m.bias.data.zero_() 119 | elif isinstance(m, nn.Linear): 120 | m.bias.data.zero_() 121 | 122 | def forward(self, x): 123 | out = self.conv1(x) 124 | out = self.block1(out) 125 | out = self.block2(out) 126 | out = self.block3(out) 127 | out = self.relu(self.bn1(out)) 128 | out = F.avg_pool2d(out, 8)  # kernel 8 presumably assumes an 8x8 final map (32x32 inputs) — TODO confirm 129 | out = out.view(-1, self.num_channels) 130 | return self.fc(out) 131 | 132 | # Factory: parses the depth from conf.arch (e.g. 'wideresnet28' -> 28); only cifar*/svhn datasets are supported. 133 | def wideresnet(conf): 134 | net_depth = int(conf.arch.replace('wideresnet', '')) 135 | dataset = conf.data 136 | 137 | if 'cifar' in conf.data or 'svhn' in conf.data: 138 | model = WideResNet( 139 | dataset=dataset, net_depth=net_depth, 140 | widen_factor=conf.wideresnet_widen_factor, 141 | drop_rate=conf.drop_rate) 142 | return model 143 | else: 144 | raise NotImplementedError 145 | -------------------------------------------------------------------------------- /distributed_code/pcode/optim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/LocalSGD-Code/3d4811d01673af205a00176f5389ed008a1ddb37/distributed_code/pcode/optim/__init__.py -------------------------------------------------------------------------------- /distributed_code/pcode/optim/local_ef_sign_sgd.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from copy import deepcopy 3 | 4 | import torch 5 | from torch.optim.optimizer import Optimizer, required 6 | 7 | import pcode.optim.utils as utils 8 | import pcode.utils.communication as comm 9 | from pcode.utils.sparsification import get_n_bits 10 | from
pcode.utils.tensor_buffer import TensorBuffer 11 | 12 | 13 | class Local_EFSignSGD(Optimizer): 14 | def __init__( 15 | self, 16 | params, 17 | lr=required, 18 | momentum=0, 19 | dampening=0, 20 | weight_decay=0, 21 | nesterov=False, 22 | conf=None, 23 | model=None, 24 | ): 25 | defaults = dict( 26 | lr=lr, 27 | momentum=momentum, 28 | dampening=dampening, 29 | weight_decay=weight_decay, 30 | nesterov=nesterov, 31 | ) 32 | if nesterov and (momentum <= 0 or dampening != 0): 33 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 34 | super(Local_EFSignSGD, self).__init__(params, defaults) 35 | 36 | # store the whole training arguments. 37 | self.conf = conf 38 | self.rank = conf.graph.rank 39 | self.neighbors_info = conf.graph.get_neighborhood() 40 | self.local_step = conf.local_step 41 | self.turn_on_local_step_from_epoch = conf.turn_on_local_step_from 42 | 43 | # define the aggregator. 44 | self.world_aggregator = comm.get_aggregators( 45 | cur_rank=self.rank, 46 | world=conf.graph.ranks, 47 | neighbors_info=dict( 48 | (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks 49 | ), 50 | aggregator_type="centralized", 51 | ) 52 | 53 | # define sorted param names. 
54 | self.param_names = list( 55 | enumerate([group["name"] for group in self.param_groups]) 56 | ) 57 | 58 | # initialize the concensus 59 | self._init_consensus() 60 | self._init_memory() 61 | 62 | def _init_consensus(self): 63 | params, _ = comm.get_data( 64 | self.param_groups, self.param_names, is_get_grad=False 65 | ) 66 | self.consensus_params_tb = deepcopy(TensorBuffer(params)) 67 | 68 | def _init_memory(self): 69 | params, self.shapes = comm.get_data( 70 | self.param_groups, self.param_names, is_get_grad=False 71 | ) 72 | self.memory_tb = TensorBuffer(params) 73 | self.memory_tb.buffer = torch.zeros_like(self.memory_tb.buffer) 74 | 75 | def __setstate__(self, state): 76 | super(Local_EFSignSGD, self).__setstate__(state) 77 | for group in self.param_groups: 78 | group.setdefault("nesterov", False) 79 | 80 | def step(self, closure=None, **kargs): 81 | # do the local update steps. 82 | with kargs["timer"]("sync.local_update", epoch=self.conf.epoch_): 83 | utils.apply_gradient( 84 | self.param_groups, self.state, apply_grad_to_model=True 85 | ) 86 | 87 | # enter the global sync if it satisfies the condition. 88 | if ( 89 | self.conf.epoch_ < self.turn_on_local_step_from_epoch 90 | or self.conf.local_index % self.local_step == 0 91 | ): 92 | with kargs["timer"]("sync.get_params", epoch=self.conf.epoch_): 93 | # get parmas. 94 | params, _ = comm.get_data( 95 | self.param_groups, self.param_names, is_get_grad=False 96 | ) 97 | params_tb = TensorBuffer(params) 98 | 99 | # get the params difference w.r.t. previous synced model. 100 | local_scale, local_sign = [], [] 101 | for consensus_param, param, memory in zip( 102 | self.consensus_params_tb, params_tb, self.memory_tb 103 | ): 104 | # add memory to the model difference. 105 | memory.data.copy_(consensus_param - param + memory) 106 | # compress. 107 | _local_scale, _local_sign = scaled_sign(memory) 108 | # update memory. 
109 | memory.data.copy_(memory - _local_scale * _local_sign) 110 | # store local scales and local sign. 111 | local_scale.append(_local_scale) 112 | local_sign.append(_local_sign) 113 | 114 | # concat the update magnitude and directions. 115 | magnitudes_tb = TensorBuffer(local_scale) 116 | directions_tb = TensorBuffer(local_sign) 117 | 118 | # sync and decompress. 119 | with kargs["timer"]("sync.sync_and_decompress", epoch=self.conf.epoch_): 120 | # sync the directions. 121 | directions_tb.buffer = self.world_aggregator._agg( 122 | directions_tb.buffer, "avg", distributed=self.conf.distributed 123 | ) 124 | magnitudes_tb.buffer = self.world_aggregator._agg( 125 | magnitudes_tb.buffer, "avg", distributed=self.conf.distributed 126 | ) 127 | 128 | # unpack the synced info and update the consensus params. 129 | with kargs["timer"]("sync.update_consensus", epoch=self.conf.epoch_): 130 | for update_magnitude, update_direction, consensus_param in zip( 131 | magnitudes_tb, directions_tb, self.consensus_params_tb 132 | ): 133 | consensus_param.add_(-1.0, update_direction.mul(update_magnitude)) 134 | 135 | # consistent the local models by assigning the consensus params. 
136 | self.consensus_params_tb.unpack(params) 137 | n_bits = get_n_bits(directions_tb.buffer) + get_n_bits(magnitudes_tb.buffer) 138 | else: 139 | n_bits = 0 140 | return n_bits 141 | 142 | 143 | def scaled_sign(x, name=None): 144 | """ 145 | :param x: torch Tensor 146 | :return: The sign tensor scaled by it's L1 norm divided by the number of elements 147 | """ 148 | _scale = x.norm(p=1) / x.numel() 149 | _sign = torch.sign(x) 150 | 151 | return _scale, _sign 152 | -------------------------------------------------------------------------------- /distributed_code/pcode/optim/local_sgd.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from copy import deepcopy 3 | 4 | import torch 5 | from torch.optim.optimizer import Optimizer, required 6 | 7 | import pcode.optim.utils as utils 8 | import pcode.utils.communication as comm 9 | from pcode.utils.sparsification import get_n_bits 10 | from pcode.utils.tensor_buffer import TensorBuffer 11 | 12 | 13 | class LocalSGD(Optimizer): 14 | def __init__( 15 | self, 16 | params, 17 | lr=required, 18 | momentum=0, 19 | dampening=0, 20 | weight_decay=0, 21 | nesterov=False, 22 | conf=None, 23 | model=None, 24 | ): 25 | defaults = dict( 26 | lr=lr, 27 | momentum=momentum, 28 | dampening=dampening, 29 | weight_decay=weight_decay, 30 | nesterov=nesterov, 31 | ) 32 | if nesterov and (momentum <= 0 or dampening != 0): 33 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 34 | super(LocalSGD, self).__init__(params, defaults) 35 | 36 | # store the whole training arguments. 37 | self.conf = conf 38 | self.rank = conf.graph.rank 39 | self.neighbors_info = conf.graph.get_neighborhood() 40 | self.local_step = conf.local_step 41 | self.turn_on_local_step_from_epoch = conf.turn_on_local_step_from 42 | 43 | # define the aggregator. 
44 | self.world_aggregator = comm.get_aggregators( 45 | cur_rank=self.rank, 46 | world=conf.graph.ranks, 47 | neighbors_info=dict( 48 | (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks 49 | ), 50 | aggregator_type="centralized", 51 | ) 52 | 53 | # define sorted param names. 54 | self.param_names = list( 55 | enumerate([group["name"] for group in self.param_groups]) 56 | ) 57 | 58 | # initialize the concensus 59 | self._init_consensus() 60 | 61 | def _init_consensus(self): 62 | params, _ = comm.get_data( 63 | self.param_groups, self.param_names, is_get_grad=False 64 | ) 65 | self.consensus_params_tb = deepcopy(TensorBuffer(params)) 66 | 67 | def __setstate__(self, state): 68 | super(LocalSGD, self).__setstate__(state) 69 | for group in self.param_groups: 70 | group.setdefault("nesterov", False) 71 | 72 | def step(self, closure=None, **kargs): 73 | with kargs["timer"]("sync.local_update", epoch=self.conf.epoch_): 74 | utils.apply_gradient( 75 | self.param_groups, self.state, apply_grad_to_model=True 76 | ) 77 | 78 | with kargs["timer"]("sync.sync_and_update", epoch=self.conf.epoch_): 79 | # enter the global sync if it satisfies the condition. 80 | if ( 81 | self.conf.epoch_ < self.turn_on_local_step_from_epoch 82 | or self.conf.local_index % self.local_step == 0 83 | ): 84 | # get parmas. 85 | params, _ = comm.get_data( 86 | self.param_groups, self.param_names, is_get_grad=False 87 | ) 88 | params_tb = TensorBuffer(params) 89 | 90 | # get params_diff. 91 | param_diff = self.consensus_params_tb.buffer - params_tb.buffer 92 | # sync the directions. 93 | param_diff = self.world_aggregator._agg( 94 | param_diff, "avg", distributed=self.conf.distributed 95 | ) 96 | 97 | # unpack the synced info and update the consensus params. 98 | self.consensus_params_tb.buffer.add_(-1.0, param_diff) 99 | 100 | # consistent the local models by assigning the consensus params. 101 | self.consensus_params_tb.unpack(params) 102 | 103 | # Get n_bits to transmit. 
104 | n_bits = get_n_bits(param_diff) 105 | else: 106 | n_bits = 0 107 | return n_bits 108 | 109 | -------------------------------------------------------------------------------- /distributed_code/pcode/optim/local_sign_sgd.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from copy import deepcopy 3 | 4 | import torch 5 | from torch.optim.optimizer import Optimizer, required 6 | 7 | import pcode.utils.communication as comm 8 | from pcode.utils.sparsification import get_n_bits 9 | from pcode.utils.tensor_buffer import TensorBuffer 10 | 11 | 12 | class Local_SignSGD(Optimizer): 13 | def __init__( 14 | self, 15 | params, 16 | lr=required, 17 | momentum=0, 18 | dampening=0, 19 | weight_decay=0, 20 | nesterov=False, 21 | conf=None, 22 | model=None, 23 | ): 24 | defaults = dict( 25 | lr=lr, 26 | momentum=momentum, 27 | dampening=dampening, 28 | weight_decay=weight_decay, 29 | nesterov=nesterov, 30 | ) 31 | if nesterov and (momentum <= 0 or dampening != 0): 32 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 33 | super(Local_SignSGD, self).__init__(params, defaults) 34 | 35 | # store the whole training arguments. 36 | self.conf = conf 37 | self.rank = conf.graph.rank 38 | self.neighbors_info = conf.graph.get_neighborhood() 39 | self.local_step = conf.local_step 40 | self.turn_on_local_step_from_epoch = conf.turn_on_local_step_from 41 | 42 | # define the aggregator. 43 | self.world_aggregator = comm.get_aggregators( 44 | cur_rank=self.rank, 45 | world=conf.graph.ranks, 46 | neighbors_info=dict( 47 | (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks 48 | ), 49 | aggregator_type="centralized", 50 | ) 51 | 52 | # define sorted param names. 
53 | self.param_names = list( 54 | enumerate([group["name"] for group in self.param_groups]) 55 | ) 56 | 57 | # initialize the concensus 58 | self._init_consensus() 59 | 60 | def _init_consensus(self): 61 | params, _ = comm.get_data( 62 | self.param_groups, self.param_names, is_get_grad=False 63 | ) 64 | self.consensus_params_tb = deepcopy(TensorBuffer(params)) 65 | 66 | def __setstate__(self, state): 67 | super(Local_SignSGD, self).__setstate__(state) 68 | for group in self.param_groups: 69 | group.setdefault("nesterov", False) 70 | 71 | def step(self, closure=None, **kargs): 72 | # do the local update steps. 73 | with kargs["timer"]("sync.local_update", epoch=self.conf.epoch_): 74 | for group in self.param_groups: 75 | weight_decay = group["weight_decay"] 76 | momentum = group["momentum"] 77 | dampening = group["dampening"] 78 | nesterov = group["nesterov"] 79 | 80 | for p in group["params"]: 81 | # get param_state 82 | param_state = self.state[p] 83 | 84 | # get the gradient 85 | if p.grad is None: 86 | continue 87 | d_p = p.grad.data 88 | 89 | # add the weight decay and apply the momentum. 90 | if weight_decay != 0: 91 | d_p.add_(weight_decay, p.data) 92 | # apply the momentum. 93 | if momentum != 0: 94 | if "momentum_buffer" not in param_state: 95 | buf = param_state["momentum_buffer"] = torch.zeros_like( 96 | p.data 97 | ) 98 | buf.mul_(momentum).add_(d_p) 99 | else: 100 | buf = param_state["momentum_buffer"] 101 | buf.mul_(momentum).add_(1 - dampening, d_p) 102 | if nesterov: 103 | d_p = d_p.add(momentum, buf) 104 | else: 105 | d_p = buf 106 | 107 | # get the local sign and apply to the local model. 108 | p.data.add_(-group["lr"], torch.sign(d_p)) 109 | 110 | # enter the global sync if it satisfies the condition. 111 | if ( 112 | self.conf.epoch_ < self.turn_on_local_step_from_epoch 113 | or self.conf.local_index % self.local_step == 0 114 | ): 115 | with kargs["timer"]("sync.get_params", epoch=self.conf.epoch_): 116 | # get parmas. 
117 | params, _ = comm.get_data( 118 | self.param_groups, self.param_names, is_get_grad=False 119 | ) 120 | params_tb = TensorBuffer(params) 121 | 122 | # get the params difference w.r.t. previous synced model. 123 | local_scale, local_sign = [], [] 124 | for consensus_param, param in zip(self.consensus_params_tb, params_tb): 125 | _local_scale, _local_sign = scaled_sign(consensus_param - param) 126 | local_scale.append(_local_scale) 127 | local_sign.append(_local_sign) 128 | 129 | # concat the update magnitude and directions. 130 | magnitudes_tb = TensorBuffer(local_scale) 131 | directions_tb = TensorBuffer(local_sign) 132 | 133 | # sync and decompress. 134 | with kargs["timer"]("sync.sync_and_decompress", epoch=self.conf.epoch_): 135 | # sync the directions. 136 | directions_tb.buffer = self.world_aggregator._agg( 137 | directions_tb.buffer, "avg", distributed=self.conf.distributed 138 | ) 139 | magnitudes_tb.buffer = self.world_aggregator._agg( 140 | magnitudes_tb.buffer, "avg", distributed=self.conf.distributed 141 | ) 142 | 143 | # unpack the synced info and update the consensus params. 144 | with kargs["timer"]("sync.update_consensus", epoch=self.conf.epoch_): 145 | for update_magnitude, update_direction, consensus_param in zip( 146 | magnitudes_tb, directions_tb, self.consensus_params_tb 147 | ): 148 | consensus_param.add_(-1.0, update_direction.mul(update_magnitude)) 149 | 150 | # consistent the local models by assigning the consensus params. 
151 | self.consensus_params_tb.unpack(params) 152 | n_bits = get_n_bits(directions_tb.buffer) + get_n_bits(magnitudes_tb.buffer) 153 | else: 154 | n_bits = 0 155 | return n_bits 156 | 157 | 158 | def scaled_sign(x, name=None): 159 | """ 160 | :param x: torch Tensor 161 | :return: The sign tensor scaled by it's L1 norm divided by the number of elements 162 | """ 163 | _scale = x.norm(p=1) / x.numel() 164 | _sign = torch.sign(x) 165 | 166 | return _scale, _sign 167 | -------------------------------------------------------------------------------- /distributed_code/pcode/optim/sgd.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from torch.optim.optimizer import Optimizer, required 3 | 4 | import pcode.optim.utils as utils 5 | import pcode.utils.communication as comm 6 | from pcode.utils.sparsification import get_n_bits 7 | from pcode.utils.tensor_buffer import TensorBuffer 8 | 9 | 10 | class SGD(Optimizer): 11 | def __init__( 12 | self, 13 | params, 14 | lr=required, 15 | momentum=0, 16 | dampening=0, 17 | weight_decay=0, 18 | nesterov=False, 19 | conf=None, 20 | model=None, 21 | ): 22 | defaults = dict( 23 | lr=lr, 24 | momentum=momentum, 25 | dampening=dampening, 26 | weight_decay=weight_decay, 27 | nesterov=nesterov, 28 | ) 29 | if nesterov and (momentum <= 0 or dampening != 0): 30 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 31 | super(SGD, self).__init__(params, defaults) 32 | 33 | # store the whole training arguments. 34 | self.conf = conf 35 | self.rank = conf.graph.rank 36 | self.neighbors_info = conf.graph.get_neighborhood() 37 | 38 | # define the aggregator. 
39 | self.decentralized_aggregator = comm.get_aggregators( 40 | cur_rank=self.rank, 41 | world=conf.graph.ranks, 42 | neighbors_info=self.neighbors_info, 43 | aggregator_type="decentralized", 44 | ) 45 | self.world_aggregator = comm.get_aggregators( 46 | cur_rank=self.rank, 47 | world=conf.graph.ranks, 48 | neighbors_info=dict( 49 | (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks 50 | ), 51 | aggregator_type="centralized", 52 | ) 53 | 54 | # define reducer. 55 | self.backend = conf.backend 56 | 57 | # define sorted param names. 58 | self.param_names = list( 59 | enumerate([group["name"] for group in self.param_groups]) 60 | ) 61 | 62 | def __setstate__(self, state): 63 | super(SGD, self).__setstate__(state) 64 | for group in self.param_groups: 65 | group.setdefault("nesterov", False) 66 | 67 | def step(self, closure=None, **kargs): 68 | if self.conf.is_centralized: 69 | with kargs["timer"]("sync.get_data", epoch=self.conf.epoch_): 70 | # Get data. 71 | grads, _ = comm.get_data( 72 | self.param_groups, self.param_names, is_get_grad=True 73 | ) 74 | flatten_grads = TensorBuffer(grads) 75 | 76 | with kargs["timer"]("sync.sync", epoch=self.conf.epoch_): 77 | # Aggregate the gradients. 78 | flatten_grads.buffer = self.world_aggregator._agg( 79 | flatten_grads.buffer, op="avg", distributed=self.conf.distributed 80 | ) 81 | 82 | with kargs["timer"]("sync.unflatten_grad", epoch=self.conf.epoch_): 83 | # unflatten grads. 84 | flatten_grads.unpack(grads) 85 | 86 | with kargs["timer"]("sync.apply_grad", epoch=self.conf.epoch_): 87 | utils.apply_gradient( 88 | self.param_groups, self.state, apply_grad_to_model=True 89 | ) 90 | 91 | # Get n_bits to transmit. 
92 | n_bits = get_n_bits(flatten_grads.buffer) 93 | else: 94 | with kargs["timer"]("sync.apply_grad", epoch=self.conf.epoch_): 95 | utils.apply_gradient( 96 | self.param_groups, self.state, apply_grad_to_model=True 97 | ) 98 | 99 | with kargs["timer"]("sync.get_data", epoch=self.conf.epoch_): 100 | # first get and flatten all params. 101 | params, _ = comm.get_data( 102 | self.param_groups, self.param_names, is_get_grad=False 103 | ) 104 | flatten_params = TensorBuffer(params) 105 | 106 | with kargs["timer"]("sync.sync", epoch=self.conf.epoch_): 107 | # prepare the sync. 108 | if self.conf.comm_device == "cpu": 109 | flatten_params.buffer.cpu().detach_() 110 | 111 | # then sync. 112 | flatten_params.buffer = self.decentralized_aggregator._agg( 113 | flatten_params.buffer, op="weighted" 114 | ) 115 | 116 | with kargs["timer"]("sync.update_model", epoch=self.conf.epoch_): 117 | # finally unflatten. 118 | flatten_params.unpack(params) 119 | 120 | # Get n_bits to transmit. 121 | n_bits = get_n_bits(flatten_params.buffer) 122 | return n_bits 123 | -------------------------------------------------------------------------------- /distributed_code/pcode/optim/sign_sgd.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from copy import deepcopy 3 | 4 | import torch 5 | from torch.optim.optimizer import Optimizer, required 6 | 7 | import pcode.optim.utils as utils 8 | import pcode.utils.communication as comm 9 | from pcode.utils.sparsification import get_n_bits, SignCompressor 10 | from pcode.utils.tensor_buffer import TensorBuffer 11 | 12 | 13 | class SignSGD(Optimizer): 14 | def __init__( 15 | self, 16 | params, 17 | lr=required, 18 | momentum=0, 19 | dampening=0, 20 | weight_decay=0, 21 | nesterov=False, 22 | conf=None, 23 | model=None, 24 | ): 25 | defaults = dict( 26 | lr=lr, 27 | momentum=momentum, 28 | dampening=dampening, 29 | weight_decay=weight_decay, 30 | nesterov=nesterov, 31 | ) 32 | if nesterov 
and (momentum <= 0 or dampening != 0): 33 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 34 | super(SignSGD, self).__init__(params, defaults) 35 | 36 | # store the whole training arguments. 37 | self.conf = conf 38 | self.rank = conf.graph.rank 39 | self.neighbors_info = conf.graph.get_neighborhood() 40 | self.local_step = conf.local_step 41 | self.turn_on_local_step_from_epoch = conf.turn_on_local_step_from 42 | 43 | # define the aggregator. 44 | self.world_aggregator = comm.get_aggregators( 45 | cur_rank=self.rank, 46 | world=conf.graph.ranks, 47 | neighbors_info=dict( 48 | (rank, 1.0 / conf.graph.n_nodes) for rank in conf.graph.ranks 49 | ), 50 | aggregator_type="centralized", 51 | ) 52 | 53 | # define sorted param names. 54 | self.param_names = list( 55 | enumerate([group["name"] for group in self.param_groups]) 56 | ) 57 | 58 | # initialize the concensus 59 | self.compressor = ExactSignCompressor( 60 | rank=self.rank, 61 | world_size=len(conf.graph.ranks), 62 | majority_vote=conf.majority_vote, 63 | aggregator=self.world_aggregator, 64 | comm_op=conf.comm_op, 65 | comm_device=self.conf.comm_device, 66 | use_ipc=conf.use_ipc, 67 | ) 68 | 69 | def __setstate__(self, state): 70 | super(SignSGD, self).__setstate__(state) 71 | for group in self.param_groups: 72 | group.setdefault("nesterov", False) 73 | 74 | def step(self, closure=None, **kargs): 75 | # do the local update steps. 76 | with kargs["timer"]("sync.get_data", epoch=self.conf.epoch_): 77 | # get parmas. 78 | params, _ = comm.get_data( 79 | self.param_groups, self.param_names, is_get_grad=False 80 | ) 81 | params_tb = TensorBuffer(params) 82 | 83 | with kargs["timer"]("sync.apply_grad", epoch=self.conf.epoch_): 84 | # prepare the gradient (sign) 85 | utils.apply_gradient( 86 | self.param_groups, self.state, apply_grad_to_model=False 87 | ) 88 | # get grads. 
89 | grads, _ = comm.get_data( 90 | self.param_groups, self.param_names, is_get_grad=True 91 | ) 92 | grads_tb = TensorBuffer(grads) 93 | 94 | # enter the global sync if it satisfies the condition. 95 | # get the params difference w.r.t. previous synced model. 96 | with kargs["timer"]("sync.compress", epoch=self.conf.epoch_): 97 | sync_buffer = self.compressor.compress(grads_tb) 98 | 99 | # sync and decompress. 100 | with kargs["timer"]("sync.sync_and_decompress", epoch=self.conf.epoch_): 101 | self.compressor.sync(sync_buffer) 102 | synced_updates_tb = self.compressor.decompress(sync_buffer) 103 | 104 | # unpack the synced info and update the consensus params. 105 | with kargs["timer"]("sync.apply_grad", epoch=self.conf.epoch_): 106 | params_tb.buffer -= self.param_groups[0]["lr"] * synced_updates_tb.buffer 107 | params_tb.unpack(params) 108 | return sync_buffer["n_bits"] 109 | 110 | 111 | class ExactSignCompressor(object): 112 | def __init__( 113 | self, 114 | rank, 115 | world_size, 116 | majority_vote, 117 | aggregator, 118 | comm_op, 119 | comm_device, 120 | use_ipc, 121 | **kargs 122 | ): 123 | # assign the common hyper-parameters 124 | self.rank = rank 125 | self.world_size = world_size 126 | self.majority_vote = majority_vote 127 | self.aggregator_fn = aggregator 128 | self.comm_op = comm_op 129 | self.comm_device = comm_device 130 | self.use_ipc = use_ipc 131 | self.kargs = kargs 132 | self.compressor_fn = SignCompressor() 133 | 134 | def compress(self, grads_tb): 135 | # get the sign/magnitude for the tensor (to be transmitted). 136 | sync_buffer = dict() 137 | 138 | # concat the update magnitude and directions. 139 | signs, sign_size = self.compressor_fn.compress(grads_tb.buffer) 140 | 141 | # get n_bits to transmit. 142 | n_bits = get_n_bits(signs) 143 | 144 | # update shared dict. 
145 | sync_buffer["grads_tb"] = grads_tb 146 | sync_buffer["signs"] = signs 147 | sync_buffer["sign_size"] = sign_size 148 | sync_buffer["n_bits"] = n_bits 149 | return sync_buffer 150 | 151 | def sync(self, sync_buffer): 152 | # prepare sync. 153 | to_sync_signs = sync_buffer["signs"] 154 | if self.comm_device == "cpu": 155 | to_sync_signs = to_sync_signs.cpu().pin_memory() 156 | 157 | # sync. 158 | synced_signs, sync_req = self.aggregator_fn._agg( 159 | to_sync_signs, communication_scheme="all_gather", async_op=True 160 | ) 161 | 162 | # update sync_buffer. 163 | sync_buffer["sync_req"] = sync_req 164 | sync_buffer["synced_signs"] = synced_signs 165 | 166 | def decompress(self, sync_buffer): 167 | # wait the sync. 168 | self.aggregator_fn.complete_wait(sync_buffer["sync_req"]) 169 | 170 | # init placeholder. 171 | synced_updates_tb = deepcopy(sync_buffer["grads_tb"]) 172 | synced_updates_tb.buffer = torch.zeros_like(synced_updates_tb.buffer) 173 | 174 | # decompress and update. 175 | for rank in range(self.world_size): 176 | # get signs and build its tensorbuffer. 177 | synced_updates_tb.buffer += self.compressor_fn.uncompress( 178 | comm.recover_device( 179 | sync_buffer["synced_signs"][rank], 180 | device=sync_buffer["grads_tb"].buffer.device, 181 | ), 182 | sync_buffer["sign_size"], 183 | ) 184 | 185 | # average grad. 
186 | if self.majority_vote: 187 | synced_updates_tb.buffer = torch.sign(synced_updates_tb.buffer) 188 | else: 189 | synced_updates_tb.buffer /= self.world_size * 1.0 190 | return synced_updates_tb 191 | -------------------------------------------------------------------------------- /distributed_code/pcode/optim/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import threading 3 | 4 | import torch 5 | 6 | from pcode.utils.tensor_buffer import TensorBuffer 7 | import pcode.utils.communication as comm 8 | 9 | 10 | """common utilities""" 11 | 12 | 13 | def apply_gradient(param_groups, state, apply_grad_to_model=True): 14 | for group in param_groups: 15 | weight_decay = group["weight_decay"] 16 | momentum = group["momentum"] 17 | dampening = group["dampening"] 18 | nesterov = group["nesterov"] 19 | 20 | for p in group["params"]: 21 | if p.grad is None: 22 | continue 23 | d_p = p.grad.data 24 | 25 | # get param_state 26 | param_state = state[p] 27 | 28 | # add weight decay. 29 | if weight_decay != 0: 30 | d_p.add_(weight_decay, p.data) 31 | 32 | # apply the momentum. 33 | if momentum != 0: 34 | if "momentum_buffer" not in param_state: 35 | buf = param_state["momentum_buffer"] = torch.zeros_like(p.data) 36 | buf.mul_(momentum).add_(d_p) 37 | else: 38 | buf = param_state["momentum_buffer"] 39 | buf.mul_(momentum).add_(1 - dampening, d_p) 40 | if nesterov: 41 | d_p = d_p.add(momentum, buf) 42 | else: 43 | d_p = buf 44 | if apply_grad_to_model: 45 | p.data.add_(-group["lr"], d_p) 46 | else: 47 | p.grad.data = d_p 48 | 49 | 50 | def recover_params( 51 | param_groups, param_names, rank=None, neighbor_hat_params=None, get_hat_params=True 52 | ): 53 | # get flattened params. 
54 | params, _ = comm.get_data(param_groups, param_names, is_get_grad=False) 55 | flatten_params = TensorBuffer(params) 56 | 57 | if get_hat_params: 58 | assert neighbor_hat_params is not None and rank is not None 59 | # recover the hat_params. 60 | flatten_hat_params = TensorBuffer(params) 61 | flatten_hat_params.buffer.data[:] = neighbor_hat_params[rank].buffer 62 | return params, flatten_params, flatten_hat_params 63 | else: 64 | return params, flatten_params 65 | 66 | 67 | def update_params_from_neighbor( 68 | neighbor_hat_params, flatten_params, consensus_stepsize, self_rank 69 | ): 70 | flatten_params.buffer += consensus_stepsize * ( 71 | neighbor_hat_params["memory"].buffer - neighbor_hat_params[self_rank].buffer 72 | ) 73 | 74 | 75 | """utilities for parallel choco.""" 76 | 77 | 78 | class HelperThread(threading.Thread): 79 | def __init__(self, name, func, *args, **kargs): 80 | threading.Thread.__init__(self) 81 | self.name = name 82 | self.func = func 83 | 84 | # task-related. 85 | self.args = args 86 | self.kargs = kargs 87 | 88 | def run(self): 89 | self.func(**self.kargs) 90 | 91 | 92 | def join_thread(thread): 93 | if thread is None: 94 | return False 95 | thread.join() 96 | return True 97 | -------------------------------------------------------------------------------- /distributed_code/pcode/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/LocalSGD-Code/3d4811d01673af205a00176f5389ed008a1ddb37/distributed_code/pcode/tools/__init__.py -------------------------------------------------------------------------------- /distributed_code/pcode/tools/plot.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from operator import itemgetter 3 | import numpy as np 4 | 5 | from pcode.tools.show_results import reorder_records 6 | from pcode.tools.plot_utils import \ 7 | determine_color_and_lines, 
plot_one_case, \ 8 | smoothing_func, configure_figure, build_legend, groupby_indices 9 | 10 | 11 | """plot the curve in terms of time.""" 12 | 13 | 14 | def plot_curve_wrt_time( 15 | ax, records, 16 | x_wrt_sth, y_wrt_sth, xlabel, ylabel, title=None, markevery_list=None, 17 | is_smooth=True, smooth_space=100, l_subset=0.0, r_subset=1.0, 18 | reorder_record_item=None, remove_duplicate=True, legend=None, 19 | legend_loc='lower right', legend_ncol=2, bbox_to_anchor=[0, 0], 20 | ylimit_bottom=None, ylimit_top=None, use_log=False): 21 | """Each info consists of 22 | ['tr_loss', 'tr_top1', 'tr_time', 'te_top1', 'te_step', 'te_time']. 23 | """ 24 | # parse a list of records. 25 | num_records = len(records) 26 | distinct_conf_set = set() 27 | 28 | # re-order the records. 29 | if reorder_record_item is not None: 30 | records = reorder_records(records, based_on=reorder_record_item) 31 | 32 | for ind, (args, info) in enumerate(records): 33 | # build legend. 34 | _legend = build_legend(args, legend) 35 | if _legend in distinct_conf_set and remove_duplicate: 36 | continue 37 | else: 38 | distinct_conf_set.add(_legend) 39 | 40 | # determine the style of line, color and marker. 41 | line_style, color_style, mark_style = determine_color_and_lines( 42 | num_rows=num_records // 3, num_cols=3, ind=ind) 43 | 44 | if markevery_list is not None: 45 | mark_every = markevery_list[ind] 46 | else: 47 | mark_style = None 48 | mark_every = None 49 | 50 | # determine if we want to smooth the curve. 51 | if 'tr_step' in x_wrt_sth or 'tr_epoch' in x_wrt_sth: 52 | info['tr_step'] = list(range(1, 1 + len(info['tr_loss']))) 53 | if 'tr_epoch' == x_wrt_sth: 54 | x = info['tr_step'] 55 | x = [1.0 * _x / args['num_batches_train_per_device_per_epoch'] for _x in x] 56 | else: 57 | x = info[x_wrt_sth] 58 | if 'time' in x_wrt_sth: 59 | x = [(time - x[0]).seconds + 1 for time in x] 60 | y = info[y_wrt_sth] 61 | 62 | if is_smooth: 63 | x, y = smoothing_func(x, y, smooth_space) 64 | 65 | # only plot subtset. 
66 | _l_subset, _r_subset = int(len(x) * l_subset), int(len(x) * r_subset) 67 | _x = x[_l_subset: _r_subset] 68 | _y = y[_l_subset: _r_subset] 69 | 70 | # use log scale for y 71 | if use_log: 72 | _y = np.log(_y) 73 | 74 | # plot 75 | ax = plot_one_case( 76 | ax, x=_x, y=_y, 77 | label=_legend, 78 | line_style=line_style, color_style=color_style, 79 | mark_style=mark_style, mark_every=mark_every, 80 | remove_duplicate=remove_duplicate) 81 | 82 | ax.set_ylim(bottom=ylimit_bottom, top=ylimit_top) 83 | ax = configure_figure( 84 | ax, xlabel=xlabel, ylabel=ylabel, title=title, 85 | has_legend=legend is not None, 86 | legend_loc=legend_loc, legend_ncol=legend_ncol, 87 | bbox_to_anchor=bbox_to_anchor 88 | ) 89 | return ax 90 | 91 | 92 | def plot_by_global_minibatch_size( 93 | averaged_records, attributes, 94 | ax, xlabel='', ylabel='', title='', mark_size=60, 95 | legend_loc='lower right', legend_ncol=2, bbox_to_anchor=[0, 0]): 96 | 97 | def extract_values(df, attributes=None): 98 | # extract the value. 99 | values = [] 100 | raw_values = df.values.tolist() 101 | 102 | for raw_value in raw_values: 103 | tmp_value = dict( 104 | (attr, raw_value[attr_id]) 105 | for attr_id, attr in enumerate(attributes)) 106 | tmp_value['global_batch_size'] = tmp_value['batch_size'] * tmp_value['n_nodes'] 107 | tmp_value['top1_acc'] = raw_value[-4] 108 | values += [tmp_value] 109 | return values 110 | 111 | x_range = set() 112 | 113 | # extract all results. 114 | extracted_records = extract_values(averaged_records, attributes) 115 | extracted_records = [ 116 | (extracted_record['n_nodes'], extracted_record['learning_rate'], 117 | extracted_record['global_batch_size'], extracted_record['top1_acc']) 118 | for extracted_record in extracted_records] 119 | 120 | # extract best results. 
121 | grouped_records = [ 122 | g[1] for g in groupby_indices(extracted_records, itemgetter(0, 1, 2))] 123 | cleaned_records = [ 124 | max(grouped_record, key=lambda x: x[-1]) 125 | for grouped_record in grouped_records 126 | ] 127 | 128 | # re-group for final plot. 129 | grouped_records = [ 130 | g[1] for g in groupby_indices(cleaned_records, itemgetter(1)) 131 | ] 132 | 133 | for ind, info in enumerate(grouped_records): 134 | n_workers, lr = info[0][0], info[0][1] 135 | 136 | line_style, color_style, mark_style = determine_color_and_lines( 137 | num_rows=len(grouped_records) // 3, num_cols=3, ind=ind) 138 | 139 | x = [int(i[-2]) for i in info] 140 | x_range.update(x) 141 | y = [i[-1] for i in info] 142 | 143 | ax.scatter( 144 | x, y, c=color_style, s=mark_size, marker=mark_style, 145 | label='$lr={}$'.format(lr)) 146 | 147 | ax.set_xscale('log', basex=2) 148 | ax.set_xticks(list(x_range)) 149 | ax.set_xticklabels(list(x_range)) 150 | configure_figure( 151 | ax, xlabel=xlabel, ylabel=ylabel, title=title, 152 | legend_loc=legend_loc, legend_ncol=legend_ncol, 153 | bbox_to_anchor=bbox_to_anchor) 154 | -------------------------------------------------------------------------------- /distributed_code/pcode/tools/plot_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from matplotlib.lines import Line2D 4 | from itertools import groupby 5 | 6 | import seaborn as sns 7 | 8 | """operate x and y.""" 9 | 10 | 11 | def smoothing_func(x, y, smooth_length=10): 12 | def smoothing(end_index): 13 | # print(end_index) 14 | if end_index - smooth_length < 0: 15 | start_index = 0 16 | else: 17 | start_index = end_index - smooth_length 18 | 19 | data = y[start_index:end_index] 20 | if len(data) == 0: 21 | return y[start_index] 22 | else: 23 | return 1.0 * sum(data) / len(data) 24 | 25 | # smooth curve 26 | x_, y_ = [], [] 27 | 28 | for end_ind in range(0, len(x)): 29 | x_.append(x[end_ind]) 
30 | y_.append(smoothing(end_ind)) 31 | return x_, y_ 32 | 33 | 34 | def reject_outliers(data, threshold=3): 35 | return data[abs(data - np.mean(data)) < threshold * np.std(data)] 36 | 37 | 38 | def groupby_indices(results, grouper): 39 | """group by indices and select the subset parameters 40 | """ 41 | out = [] 42 | for key, group in groupby(sorted(results, key=grouper), grouper): 43 | group_item = list(group) 44 | out += [(key, group_item)] 45 | return out 46 | 47 | 48 | def find_same_num_sync(num_update_steps_and_local_step): 49 | list_of_num_sync = [ 50 | num_update_steps // local_step 51 | for num_update_steps, local_step in num_update_steps_and_local_step 52 | ] 53 | return min(list_of_num_sync) 54 | 55 | 56 | def sample_from_records(x, y, local_step, max_same_num_sync): 57 | # cut the records. 58 | if max_same_num_sync is not None: 59 | x = x[: local_step * max_same_num_sync] 60 | y = y[: local_step * max_same_num_sync] 61 | return x[::local_step], y[::local_step] 62 | 63 | 64 | def drop_first_few(x, y, num_drop): 65 | return x[num_drop:], y[num_drop:] 66 | 67 | 68 | def rebuild_runtime_record(times): 69 | times = [(time - times[0]).seconds + 1 for time in times] 70 | return times 71 | 72 | 73 | def add_communication_delay(times, local_step, delay_factor): 74 | """add communication delay to original time.""" 75 | return [ 76 | time + delay_factor * ((ind + 1) // local_step) 77 | for ind, time in enumerate(times) 78 | ] 79 | 80 | 81 | """plot style related.""" 82 | 83 | 84 | def determine_color_and_lines(num_rows, num_cols, ind): 85 | line_styles = ["-", "--", "-.", ":"] 86 | color_styles = [ 87 | "#377eb8", 88 | "#ff7f00", 89 | "#4daf4a", 90 | "#f781bf", 91 | "#a65628", 92 | "#984ea3", 93 | "#999999", 94 | "#e41a1c", 95 | "#dede00", 96 | ] 97 | 98 | num_line_styles = len(line_styles) 99 | num_color_styles = len(color_styles) 100 | total_num_combs = num_line_styles * num_color_styles 101 | 102 | assert total_num_combs > num_rows * num_cols 103 | 104 | if 
max(num_rows, num_cols) > max(num_line_styles, num_color_styles): 105 | row = ind // num_line_styles 106 | col = ind % num_line_styles 107 | # print('plot {}. case 1, row: {}, col: {}'.format(ind, row, col)) 108 | return line_styles[row], color_styles[col], Line2D.filled_markers[ind] 109 | 110 | denominator = max(num_rows, num_cols) 111 | row = ind // denominator 112 | col = ind % denominator 113 | # print('plot {}. case 2, row: {}, col: {}'.format(ind, row, col)) 114 | return line_styles[row], color_styles[col], Line2D.filled_markers[ind] 115 | 116 | 117 | def configure_figure( 118 | ax, 119 | xlabel, 120 | ylabel, 121 | title=None, 122 | has_legend=True, 123 | legend_loc="lower right", 124 | legend_ncol=2, 125 | bbox_to_anchor=[0, 0], 126 | ): 127 | if has_legend: 128 | ax.legend( 129 | loc=legend_loc, 130 | bbox_to_anchor=bbox_to_anchor, 131 | ncol=legend_ncol, 132 | shadow=True, 133 | fancybox=True, 134 | fontsize=20, 135 | ) 136 | 137 | ax.set_xlabel(xlabel, fontsize=24, labelpad=18) 138 | ax.set_ylabel(ylabel, fontsize=24, labelpad=18) 139 | 140 | if title is not None: 141 | ax.set_title(title, fontsize=24) 142 | ax.xaxis.set_tick_params(labelsize=22) 143 | ax.yaxis.set_tick_params(labelsize=22) 144 | return ax 145 | 146 | 147 | def plot_one_case( 148 | ax, 149 | label, 150 | line_style, 151 | color_style, 152 | mark_style, 153 | line_width=2.0, 154 | mark_every=5000, 155 | x=None, 156 | y=None, 157 | sns_plot=None, 158 | remove_duplicate=False, 159 | ): 160 | if sns_plot is not None and not remove_duplicate: 161 | ax = sns.lineplot( 162 | x="x", 163 | y="y", 164 | data=sns_plot, 165 | label=label, 166 | linewidth=line_width, 167 | linestyle=line_style, 168 | color=color_style, 169 | marker=mark_style, 170 | markevery=mark_every, 171 | markersize=16, 172 | ax=ax, 173 | ) 174 | elif sns_plot is not None and remove_duplicate: 175 | ax = sns.lineplot( 176 | x="x", 177 | y="y", 178 | data=sns_plot, 179 | label=label, 180 | linewidth=line_width, 181 | 
linestyle=line_style, 182 | color=color_style, 183 | marker=mark_style, 184 | markevery=mark_every, 185 | markersize=16, 186 | ax=ax, 187 | estimator=None, 188 | ) 189 | else: 190 | ax.plot( 191 | x, 192 | y, 193 | label=label, 194 | linewidth=line_width, 195 | linestyle=line_style, 196 | color=color_style, 197 | marker=mark_style, 198 | markevery=mark_every, 199 | markersize=16, 200 | ) 201 | return ax 202 | 203 | 204 | def build_legend(args, legend): 205 | legend = legend.split(",") 206 | 207 | my_legend = [] 208 | for _legend in legend: 209 | _legend_content = args[_legend] 210 | my_legend += [ 211 | "{}={}".format( 212 | _legend, 213 | list(_legend_content)[0] 214 | if "pandas" in str(type(_legend_content)) 215 | else _legend_content, 216 | ) 217 | ] 218 | return ", ".join(my_legend) 219 | 220 | -------------------------------------------------------------------------------- /distributed_code/pcode/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfml/LocalSGD-Code/3d4811d01673af205a00176f5389ed008a1ddb37/distributed_code/pcode/utils/__init__.py -------------------------------------------------------------------------------- /distributed_code/pcode/utils/auxiliary.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from copy import deepcopy 3 | from datetime import datetime 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def deepcopy_model(conf, model): 10 | # a dirty hack.... 
11 | tmp_model = deepcopy(model) 12 | if conf.track_model_aggregation: 13 | for tmp_para, para in zip(tmp_model.parameters(), model.parameters()): 14 | tmp_para.grad = para.grad.clone() 15 | return tmp_model 16 | 17 | 18 | def get_model_difference(model1, model2): 19 | list_of_tensors = [] 20 | for weight1, weight2 in zip(model1.parameters(), 21 | model2.parameters()): 22 | tensor = get_diff_weights(weight1, weight2) 23 | list_of_tensors.append(tensor) 24 | return list_to_vec(list_of_tensors).norm().item() 25 | 26 | 27 | def get_diff_weights(weights1, weights2): 28 | """ Produce a direction from 'weights1' to 'weights2'.""" 29 | if isinstance(weights1, list) and isinstance(weights2, list): 30 | return [w2 - w1 for (w1, w2) in zip(weights1, weights2)] 31 | elif isinstance(weights1, torch.Tensor) and isinstance(weights2, torch.Tensor): 32 | return weights2 - weights1 33 | else: 34 | raise NotImplementedError 35 | 36 | 37 | def get_diff_states(states1, states2): 38 | """ Produce a direction from 'states1' to 'states2'.""" 39 | return [ 40 | v2 - v1 41 | for (k1, v1), (k2, v2) in zip(states1.items(), states2.items()) 42 | ] 43 | 44 | 45 | def list_to_vec(weights): 46 | """ Concatnate a numpy list of weights of all layers into one torch vector. 
47 | """ 48 | v = [] 49 | direction = [d * np.float64(1.0) for d in weights] 50 | for w in direction: 51 | if isinstance(w, np.ndarray): 52 | w = torch.tensor(w) 53 | else: 54 | w = w.clone().detach() 55 | if w.dim() > 1: 56 | v.append(w.view(w.numel())) 57 | elif w.dim() == 1: 58 | v.append(w) 59 | return torch.cat(v) 60 | 61 | 62 | def str2time(string, pattern): 63 | """convert the string to the datetime.""" 64 | return datetime.strptime(string, pattern) 65 | 66 | 67 | def get_fullname(o): 68 | """get the full name of the class.""" 69 | return '%s.%s' % (o.__module__, o.__class__.__name__) 70 | 71 | 72 | def is_float(value): 73 | try: 74 | float(value) 75 | return True 76 | except: 77 | return False 78 | 79 | 80 | class dict2obj(object): 81 | def __init__(self, d): 82 | for a, b in d.items(): 83 | if isinstance(b, (list, tuple)): 84 | setattr(self, a, 85 | [dict2obj(x) if isinstance(x, dict) else x for x in b]) 86 | else: 87 | setattr(self, a, dict2obj(b) if isinstance(b, dict) else b) 88 | -------------------------------------------------------------------------------- /distributed_code/pcode/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import gc 3 | import shutil 4 | import time 5 | from os.path import join, isfile 6 | 7 | import torch 8 | 9 | import pcode.utils.logging as logging 10 | from pcode.utils.op_paths import build_dirs, remove_folder 11 | from pcode.utils.op_files import write_pickle 12 | 13 | 14 | def get_checkpoint_folder_name(conf): 15 | # get time_id 16 | time_id = str(int(time.time())) 17 | 18 | # get communication info. 
19 | if conf.comm_op is None: 20 | comm_info = "" 21 | elif "compress" in conf.comm_op: 22 | comm_info = "{}-{}_".format(conf.comm_op, conf.compress_ratio) 23 | comm_info += "warmup_epochs-{}".format(conf.compress_warmup_epochs) 24 | comm_info += "_mask_momentum" if conf.mask_momentum else "" 25 | comm_info += ( 26 | "_clip_grad-{}".format(conf.clip_grad_val) if conf.clip_grad else "" 27 | ) 28 | elif conf.comm_op == "quantize_qsgd": 29 | comm_info = "{}-{}_".format(conf.comm_op, conf.quantize_level) 30 | elif conf.comm_op == "sign": 31 | comm_info = "{}_".format(conf.comm_op) 32 | else: 33 | comm_info = "" 34 | 35 | # get optimizer info. 36 | if "choco" in conf.optimizer: 37 | optim_info = "{}-stepsize-{}".format(conf.optimizer, conf.consensus_stepsize) 38 | else: 39 | optim_info = "{}".format(conf.optimizer) 40 | 41 | # concat them together. 42 | return ( 43 | time_id 44 | + "_l2-{}_lr-{}_epochs-{}_batchsize-{}_basebatchsize-{}_num_mpi_process_{}_n_sub_process-{}_topology-{}_optim-{}_comm_info-{}".format( 45 | conf.weight_decay, 46 | conf.lr, 47 | conf.num_epochs, 48 | conf.batch_size, 49 | conf.base_batch_size, 50 | conf.n_mpi_process, 51 | conf.n_sub_process, 52 | conf.graph_topology, 53 | optim_info, 54 | comm_info, 55 | ) 56 | ) 57 | 58 | 59 | def init_checkpoint(conf): 60 | # init checkpoint dir. 61 | conf.checkpoint_root = join( 62 | conf.checkpoint, 63 | conf.data, 64 | conf.arch, 65 | conf.experiment if conf.experiment is not None else "", 66 | conf.timestamp, 67 | ) 68 | conf.checkpoint_dir = join(conf.checkpoint_root, str(conf.graph.rank)) 69 | if conf.save_some_models is not None: 70 | conf.save_some_models = conf.save_some_models.split(",") 71 | 72 | # if the directory does not exists, create them. 
73 | build_dirs(conf.checkpoint_dir) 74 | 75 | 76 | def _save_to_checkpoint(state, dirname, filename): 77 | checkpoint_path = join(dirname, filename) 78 | torch.save(state, checkpoint_path) 79 | return checkpoint_path 80 | 81 | 82 | def save_arguments(conf): 83 | # save the configure file to the checkpoint. 84 | if conf.graph.rank == 0: 85 | write_pickle(conf, path=join(conf.checkpoint_root, "arguments.pickle")) 86 | 87 | 88 | def save_to_checkpoint(conf, state, is_best, dirname, filename, save_all=False): 89 | # save full state. 90 | checkpoint_path = _save_to_checkpoint(state, dirname, filename) 91 | best_model_path = join(dirname, "model_best.pth.tar") 92 | if is_best: 93 | shutil.copyfile(checkpoint_path, best_model_path) 94 | if save_all: 95 | shutil.copyfile( 96 | checkpoint_path, 97 | join(dirname, "checkpoint_epoch_%s.pth.tar" % state["current_epoch"]), 98 | ) 99 | elif conf.save_some_models is not None: 100 | if str(state["current_epoch"]) in conf.save_some_models: 101 | shutil.copyfile( 102 | checkpoint_path, 103 | join(dirname, "checkpoint_epoch_%s.pth.tar" % state["current_epoch"]), 104 | ) 105 | 106 | 107 | def maybe_resume_from_checkpoint(conf, model, optimizer, scheduler): 108 | if conf.resume: 109 | if conf.checkpoint_index is not None: 110 | # reload model from a specific checkpoint index. 111 | checkpoint_index = "_epoch_" + conf.checkpoint_index 112 | else: 113 | # reload model from the latest checkpoint. 114 | checkpoint_index = "" 115 | checkpoint_path = join( 116 | conf.resume, 117 | str(conf.graph.rank), 118 | "checkpoint{}.pth.tar".format(checkpoint_index), 119 | ) 120 | print("try to load previous model from the path:{}".format(checkpoint_path)) 121 | 122 | if isfile(checkpoint_path): 123 | print( 124 | "=> loading checkpoint {} for {}".format(conf.resume, conf.graph.rank) 125 | ) 126 | 127 | # get checkpoint. 128 | checkpoint = torch.load(checkpoint_path, map_location="cpu") 129 | 130 | # restore some run-time info. 
131 | scheduler.update_from_checkpoint(checkpoint) 132 | 133 | # reset path for log. 134 | try: 135 | remove_folder(conf.checkpoint_root) 136 | except RuntimeError as e: 137 | print(f"ignore the error={e}") 138 | conf.checkpoint_root = conf.resume 139 | conf.checkpoint_dir = join(conf.resume, str(conf.graph.rank)) 140 | # restore model. 141 | model.load_state_dict(checkpoint["state_dict"]) 142 | # restore optimizer. 143 | optimizer.load_state_dict(checkpoint["optimizer"]) 144 | # logging. 145 | print( 146 | "=> loaded model from path '{}' checkpointed at (epoch {})".format( 147 | conf.resume, checkpoint["current_epoch"] 148 | ) 149 | ) 150 | # configure logger. 151 | conf.logger = logging.Logger(conf.checkpoint_dir) 152 | 153 | # try to solve memory issue. 154 | del checkpoint 155 | torch.cuda.empty_cache() 156 | gc.collect() 157 | return 158 | else: 159 | print("=> no checkpoint found at '{}'".format(conf.resume)) 160 | -------------------------------------------------------------------------------- /distributed_code/pcode/utils/error_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | 4 | 5 | def global_except_hook(exctype, value, traceback): 6 | import sys 7 | 8 | try: 9 | import mpi4py.MPI 10 | 11 | sys.stderr.write("\n*****************************************************\n") 12 | sys.stderr.write( 13 | "Uncaught exception was detected on rank {}. 
\n".format( 14 | mpi4py.MPI.COMM_WORLD.Get_rank() 15 | ) 16 | ) 17 | from traceback import print_exception 18 | 19 | print_exception(exctype, value, traceback) 20 | sys.stderr.write("*****************************************************\n\n\n") 21 | sys.stderr.write("\n") 22 | sys.stderr.write("Calling MPI_Abort() to shut down MPI processes...\n") 23 | sys.stderr.flush() 24 | finally: 25 | try: 26 | import mpi4py.MPI 27 | 28 | mpi4py.MPI.COMM_WORLD.Abort(1) 29 | except Exception as e: 30 | sys.stderr.write("*****************************************************\n") 31 | sys.stderr.write("Sorry, we failed to stop MPI, this process will hang.\n") 32 | sys.stderr.write("*****************************************************\n") 33 | sys.stderr.flush() 34 | raise e 35 | 36 | 37 | def abort(): 38 | try: 39 | import mpi4py.MPI 40 | 41 | mpi4py.MPI.COMM_WORLD.Abort(1) 42 | except Exception as e: 43 | sys.stderr.write("*****************************************************\n") 44 | sys.stderr.write("Sorry, we failed to stop MPI, this process will hang.\n") 45 | sys.stderr.write("*****************************************************\n") 46 | sys.stderr.flush() 47 | raise e 48 | -------------------------------------------------------------------------------- /distributed_code/pcode/utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import json 4 | import time 5 | import platform 6 | 7 | from pcode.utils.op_files import write_txt 8 | 9 | 10 | class Logger: 11 | """ 12 | Very simple prototype logger that will store the values to a JSON file 13 | """ 14 | 15 | def __init__(self, file_folder): 16 | """ 17 | :param filename: ending with .json 18 | :param auto_save: save the JSON file after every addition 19 | """ 20 | self.file_folder = file_folder 21 | self.file_json = os.path.join(file_folder, "log-1.json") 22 | self.file_txt = os.path.join(file_folder, "log.txt") 23 | self.values = [] 24 
| 25 | def log_metric(self, name, values, tags, display=False): 26 | """ 27 | Store a scalar metric 28 | 29 | :param name: measurement, like 'accuracy' 30 | :param values: dictionary, like { epoch: 3, value: 0.23 } 31 | :param tags: dictionary, like { split: train } 32 | """ 33 | self.values.append({"measurement": name, **values, **tags}) 34 | 35 | if display: 36 | print( 37 | "{name}: {values} ({tags})".format(name=name, values=values, tags=tags) 38 | ) 39 | 40 | def log(self, value): 41 | content = time.strftime("%Y-%m-%d %H:%M:%S") + "\t" + value 42 | print(content) 43 | self.save_txt(content) 44 | 45 | def save_json(self): 46 | """Save the internal memory to a file.""" 47 | with open(self.file_json, "w") as fp: 48 | json.dump(self.values, fp, indent=" ") 49 | 50 | if len(self.values) > 1e3: 51 | # reset 'values' and redirect the json file to other name. 52 | self.values = [] 53 | self.redirect_new_json() 54 | 55 | def save_txt(self, value): 56 | write_txt(value + "\n", self.file_txt, type="a") 57 | 58 | def redirect_new_json(self): 59 | """get the number of existing json files under the current folder.""" 60 | existing_json_files = [ 61 | file for file in os.listdir(self.file_folder) if "json" in file 62 | ] 63 | self.file_json = os.path.join( 64 | self.file_folder, "log-{}.json".format(len(existing_json_files) + 1) 65 | ) 66 | 67 | 68 | def display_args(conf): 69 | if conf.graph.rank == 0: 70 | print("\n\nparameters: ") 71 | for arg in vars(conf): 72 | print("\t" + str(arg) + "\t" + str(getattr(conf, arg))) 73 | 74 | print( 75 | "\n\nexperiment platform: rank {} on {} {}-{}".format( 76 | conf.graph.rank, 77 | platform.node(), 78 | "GPU" if conf.graph.on_cuda else "CPU", 79 | conf.graph.device, 80 | ) 81 | ) 82 | for name in [ 83 | "n_nodes", 84 | "world", 85 | "rank", 86 | "device", 87 | "on_cuda", 88 | "get_neighborhood", 89 | ]: 90 | print("\t{}: {}".format(name, getattr(conf.graph, name))) 91 | print("\n\n") 92 | 93 | 94 | def display_training_stat(conf, 
scheduler, tracker, n_bits_to_transmit): 95 | current_time = time.strftime("%Y-%m-%d %H:%M:%S") 96 | 97 | # display the runtime training information. 98 | conf.logger.log_metric( 99 | name="runtime", 100 | values={ 101 | "time": current_time, 102 | "rank": conf.graph.rank, 103 | "epoch": scheduler.epoch_, 104 | "local_index": scheduler.local_index, 105 | "n_bits_to_transmit": n_bits_to_transmit / 8 / (2 ** 20), 106 | **tracker(), 107 | }, 108 | tags={"split": "train"}, 109 | display=True, 110 | ) 111 | 112 | 113 | def display_test_stat(conf, scheduler, tracker, label="local"): 114 | current_time = time.strftime("%Y-%m-%d %H:%M:%S") 115 | 116 | # display the runtime training information. 117 | conf.logger.log_metric( 118 | name="runtime", 119 | values={ 120 | "time": current_time, 121 | "rank": conf.graph.rank, 122 | "epoch": scheduler.epoch_, 123 | **tracker(), 124 | }, 125 | tags={"split": "test", "type": label}, 126 | display=True, 127 | ) 128 | conf.logger.save_json() 129 | 130 | 131 | def dispaly_best_test_stat(conf, scheduler): 132 | current_time = time.strftime("%Y-%m-%d %H:%M:%S") 133 | 134 | conf.logger.log_metric( 135 | name="runtime", 136 | values={ 137 | "time": current_time, 138 | "rank": conf.graph.rank, 139 | "epoch": scheduler.epoch_, 140 | "best_perf": scheduler.best_tracker.best_perf, 141 | }, 142 | tags={"split": "test", "type": "local_model_avg"}, 143 | display=False, 144 | ) 145 | 146 | conf.logger.log( 147 | "best performance at local index {} \ 148 | (best epoch {:.3f}, current epoch {:.3f}): {}.".format( 149 | scheduler.local_index, 150 | scheduler.best_tracker.get_best_perf_loc(), 151 | scheduler.epoch_, 152 | scheduler.best_tracker.best_perf, 153 | ) 154 | ) 155 | -------------------------------------------------------------------------------- /distributed_code/pcode/utils/mathdict.py: -------------------------------------------------------------------------------- 1 | class MathDict(): 2 | def __init__(self, dictionary): 3 | 
self.dictionary = dictionary 4 | self.keys = set(dictionary.keys()) 5 | 6 | def __str__(self): 7 | return 'MathDict({})'.format(str(self.dictionary)) 8 | 9 | def __repr__(self): 10 | return 'MathDict({})'.format(repr(self.dictionary)) 11 | 12 | def map(self, mapfun): 13 | new_dict = {} 14 | for key in self.keys: 15 | new_dict[key] = mapfun(self.dictionary[key]) 16 | return MathDict(new_dict) 17 | 18 | def filter(self, condfun): 19 | new_dict = {} 20 | for key in self.keys: 21 | if condfun(key): 22 | new_dict[key] = self.dictionary[key] 23 | return MathDict(new_dict) 24 | 25 | def detach(self): 26 | for key in self.keys: 27 | self.dictionary[key] = self.dictionary[key].detach() 28 | 29 | def values(self): 30 | return self.dictionary.values() 31 | 32 | def items(self): 33 | return self.dictionary.items() 34 | 35 | 36 | def _mathdict_binary_op(operation): 37 | def op(self, other): 38 | new_dict = {} 39 | if isinstance(other, MathDict): 40 | assert other.keys == self.keys 41 | for key in self.keys: 42 | new_dict[key] = operation(self.dictionary[key], other.dictionary[key]) 43 | else: 44 | for key in self.keys: 45 | new_dict[key] = operation(self.dictionary[key], other) 46 | return MathDict(new_dict) 47 | return op 48 | 49 | 50 | def _mathdict_map_op(operation): 51 | def op(self, *args, **kwargs): 52 | new_dict = {} 53 | for key in self.keys: 54 | new_dict[key] = operation(self.dictionary[key], args, kwargs) 55 | return MathDict(new_dict) 56 | return op 57 | 58 | 59 | def _mathdict_binary_in_place_op(operation): 60 | def op(self, other): 61 | if isinstance(other, MathDict): 62 | assert other.keys == self.keys 63 | for key in self.keys: 64 | operation(self.dictionary, key, other.dictionary[key]) 65 | else: 66 | for key in self.keys: 67 | operation(self.dictionary, key, other) 68 | return self 69 | return op 70 | 71 | 72 | def _iadd(dict, key, b): 73 | dict[key] += b 74 | 75 | 76 | def _isub(dict, key, b): 77 | dict[key] -= b 78 | 79 | 80 | def _imul(dict, key, b): 81 | 
dict[key] *= b 82 | 83 | 84 | def _itruediv(dict, key, b): 85 | dict[key] /= b 86 | 87 | 88 | def _ifloordiv(dict, key, b): 89 | dict[key] //= b 90 | 91 | 92 | MathDict.__add__ = _mathdict_binary_op(lambda a, b: a + b) 93 | MathDict.__sub__ = _mathdict_binary_op(lambda a, b: a - b) 94 | MathDict.__rsub__ = _mathdict_binary_op(lambda a, b: b - a) 95 | MathDict.__mul__ = _mathdict_binary_op(lambda a, b: a * b) 96 | MathDict.__rmul__ = _mathdict_binary_op(lambda a, b: a * b) 97 | MathDict.__truediv__ = _mathdict_binary_op(lambda a, b: a / b) 98 | MathDict.__floordiv__ = _mathdict_binary_op(lambda a, b: a // b) 99 | MathDict.__getitem__ = _mathdict_map_op( 100 | lambda x, args, kwargs: x.__getitem__(*args, **kwargs)) 101 | MathDict.__iadd__ = _mathdict_binary_in_place_op(_iadd) 102 | MathDict.__isub__ = _mathdict_binary_in_place_op(_isub) 103 | MathDict.__imul__ = _mathdict_binary_in_place_op(_imul) 104 | MathDict.__itruediv__ = _mathdict_binary_in_place_op(_itruediv) 105 | MathDict.__ifloordiv__ = _mathdict_binary_in_place_op(_ifloordiv) 106 | -------------------------------------------------------------------------------- /distributed_code/pcode/utils/op_files.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Auxiliary functions that support for system.""" 3 | import os 4 | import json 5 | import pickle 6 | from os.path import exists 7 | from six.moves import cPickle 8 | 9 | 10 | """operate files.""" 11 | 12 | 13 | def read_text_withoutsplit(path): 14 | """read text file from path.""" 15 | with open(path, "r") as f: 16 | return f.read() 17 | 18 | 19 | def read_txt(path): 20 | """read text file from path.""" 21 | with open(path, "r") as f: 22 | return f.read().splitlines() 23 | 24 | 25 | def read_json(path): 26 | """read json file from path.""" 27 | with open(path, 'r') as f: 28 | return json.load(f) 29 | 30 | 31 | def write_txt(data, out_path, type="w"): 32 | """write the data to the txt file.""" 
33 | with open(out_path, type) as f: 34 | f.write(data) 35 | 36 | 37 | def load_pickle(path): 38 | """load data by pickle.""" 39 | with open(path, 'rb') as handle: 40 | return pickle.load(handle) 41 | 42 | 43 | def write_pickle(data, path): 44 | """dump file to dir.""" 45 | print("write --> data to path: {}\n".format(path)) 46 | with open(path, 'wb') as handle: 47 | pickle.dump(data, handle) 48 | 49 | 50 | def load_cpickle(path): 51 | """load data by pickle.""" 52 | with open(path, 'rb') as handle: 53 | return cPickle.load(handle) 54 | 55 | 56 | def write_cpickle(data, path): 57 | """dump file to dir.""" 58 | print("write --> data to path: {}\n".format(path)) 59 | with open(path, 'wb') as handle: 60 | cPickle.dump(data, handle) 61 | 62 | 63 | def output_string(data, path_output, delimiter='\n'): 64 | """join the string in a list and output them to a file.""" 65 | os.remove(path_output) if exists(path_output) else None 66 | 67 | for d in data: 68 | try: 69 | write_txt(d + delimiter, path_output, 'a') 70 | except: 71 | print(d) 72 | -------------------------------------------------------------------------------- /distributed_code/pcode/utils/op_paths.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import shutil 4 | 5 | 6 | def get_current_path(conf, rank): 7 | paths = conf.resume.split(',') 8 | splited_paths = map( 9 | lambda p: p.split('/')[-1].split('-')[: 1], paths) 10 | splited_paths_dict = dict([ 11 | (path, paths[ind]) for ind, path in enumerate(splited_paths)]) 12 | return splited_paths_dict[str(rank)] 13 | 14 | 15 | def build_dir(path, force): 16 | """build directory.""" 17 | if os.path.exists(path) and force: 18 | shutil.rmtree(path) 19 | os.mkdir(path) 20 | elif not os.path.exists(path): 21 | os.mkdir(path) 22 | return path 23 | 24 | 25 | def build_dirs(path): 26 | try: 27 | os.makedirs(path) 28 | except Exception as e: 29 | print(' encounter error: {}'.format(e)) 30 | 31 | 
32 | def remove_folder(path): 33 | try: 34 | shutil.rmtree(path) 35 | except Exception as e: 36 | print(' encounter error: {}'.format(e)) 37 | 38 | 39 | def list_files(root_path): 40 | dirs = os.listdir(root_path) 41 | return [os.path.join(root_path, path) for path in dirs] 42 | -------------------------------------------------------------------------------- /distributed_code/pcode/utils/sparsification.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | 7 | import bit2byte 8 | 9 | 10 | def get_n_bits(tensor): 11 | return 8 * tensor.nelement() * tensor.element_size() 12 | 13 | 14 | """define some general compressors, e.g., top_k, random_k, sign""" 15 | 16 | 17 | class SparsificationCompressor(object): 18 | def get_top_k(self, x, ratio): 19 | """it will sample the top 1-ratio of the samples.""" 20 | x_data = x.view(-1) 21 | x_len = x_data.nelement() 22 | top_k = max(1, int(x_len * (1 - ratio))) 23 | 24 | # get indices and the corresponding values 25 | if top_k == 1: 26 | _, selected_indices = torch.max(x_data.abs(), dim=0, keepdim=True) 27 | else: 28 | _, selected_indices = torch.topk( 29 | x_data.abs(), top_k, largest=True, sorted=False 30 | ) 31 | return x_data[selected_indices], selected_indices 32 | 33 | def get_mask(self, flatten_arr, indices): 34 | mask = torch.zeros_like(flatten_arr) 35 | mask[indices] = 1 36 | 37 | mask = mask.byte() 38 | return mask.float(), (~mask).float() 39 | 40 | def get_random_k(self, x, ratio, is_biased=True): 41 | """it will randomly sample the 1-ratio of the samples.""" 42 | # get tensor size. 43 | x_data = x.view(-1) 44 | x_len = x_data.nelement() 45 | top_k = max(1, int(x_len * (1 - ratio))) 46 | 47 | # random sample the k indices. 
48 | selected_indices = np.random.choice(x_len, top_k, replace=False) 49 | selected_indices = torch.LongTensor(selected_indices).to(x.device) 50 | 51 | if is_biased: 52 | return x_data[selected_indices], selected_indices 53 | else: 54 | return x_len / top_k * x_data[selected_indices], selected_indices 55 | 56 | def compress(self, arr, op, compress_ratio, is_biased): 57 | if "top_k" in op: 58 | values, indices = self.get_top_k(arr, compress_ratio) 59 | elif "random_k" in op: 60 | values, indices = self.get_random_k(arr, compress_ratio) 61 | else: 62 | raise NotImplementedError 63 | 64 | # n_bits = get_n_bits(values) + get_n_bits(indices) 65 | return values, indices 66 | 67 | def uncompress(self, values, indices, selected_shapes, original_shapes): 68 | # apply each param. 69 | sync_pointer = 0 70 | pointer = 0 71 | 72 | _q_values, _q_indices = [], [] 73 | for idx, n_sparse_value in enumerate(selected_shapes): 74 | # get value and indice for the current param. 75 | _q_value = values[sync_pointer : sync_pointer + n_sparse_value] 76 | _q_indice = pointer + indices[sync_pointer : sync_pointer + n_sparse_value] 77 | _q_values += [_q_value] 78 | _q_indices += [_q_indice] 79 | 80 | # update the pointers. 
81 | sync_pointer += n_sparse_value 82 | pointer += original_shapes[idx][1] 83 | return torch.cat(_q_values), torch.cat(_q_indices).long() 84 | 85 | 86 | class QuantizationCompressor(object): 87 | def get_qsgd(self, x, s, is_biased=False): 88 | norm = x.norm(p=2) 89 | level_float = s * x.abs() / norm 90 | previous_level = torch.floor(level_float) 91 | is_next_level = (torch.rand_like(x) < (level_float - previous_level)).float() 92 | new_level = previous_level + is_next_level 93 | 94 | scale = 1 95 | if is_biased: 96 | d = x.nelement() 97 | scale = 1.0 / (min(d / (s ** 2), math.sqrt(d) / s) + 1.0) 98 | return scale * torch.sign(x) * norm * new_level / s 99 | 100 | def qsgd_quantize_numpy(self, x, s, is_biased=False): 101 | """quantize the tensor x in d level on the absolute value coef wise""" 102 | norm = np.sqrt(np.sum(np.square(x))) 103 | level_float = s * np.abs(x) / norm 104 | previous_level = np.floor(level_float) 105 | is_next_level = np.random.rand(*x.shape) < (level_float - previous_level) 106 | new_level = previous_level + is_next_level 107 | 108 | scale = 1 109 | if is_biased: 110 | d = len(x) 111 | scale = 1.0 / (np.minimum(d / s ** 2, np.sqrt(d) / s) + 1.0) 112 | return scale * np.sign(x) * norm * new_level / s 113 | 114 | def compress(self, arr, op, quantize_level, is_biased): 115 | s = 2 ** quantize_level - 1 116 | values = self.get_qsgd(arr, s, is_biased) 117 | 118 | # n_bits = get_n_bits(values) * quantize_level / 32 119 | return values 120 | 121 | def uncompress(self, arr): 122 | return arr 123 | 124 | 125 | class SignCompressor(object): 126 | """Taken from https://github.com/PermiJW/signSGD-with-Majority-Vote""" 127 | 128 | def packing(self, src_tensor): 129 | src_tensor = torch.sign(src_tensor) 130 | src_tensor_size = src_tensor.size() 131 | src_tensor = src_tensor.view(-1) 132 | src_len = len(src_tensor) 133 | add_elm = 32 - (src_len % 32) 134 | if src_len % 32 == 0: 135 | add_elm = 0 136 | new_tensor = torch.zeros( 137 | [add_elm], 
dtype=torch.float32, device=src_tensor.device 138 | ) 139 | src_tensor = torch.cat((src_tensor, new_tensor), 0) 140 | src_tensor = src_tensor.view(32, -1) 141 | src_tensor = src_tensor.to(dtype=torch.int32) 142 | dst_tensor = bit2byte.packing(src_tensor) 143 | dst_tensor = dst_tensor.to(dtype=torch.int32) 144 | return dst_tensor, src_tensor_size 145 | 146 | def unpacking(self, src_tensor, src_tensor_size): 147 | src_element_num = self.element_num(src_tensor_size) 148 | add_elm = 32 - (src_element_num % 32) 149 | if src_element_num % 32 == 0: 150 | add_elm = 0 151 | src_tensor = src_tensor.int() 152 | new_tensor = torch.ones( 153 | src_element_num + add_elm, device=src_tensor.device, dtype=torch.int32 154 | ) 155 | new_tensor = new_tensor.view(32, -1) 156 | new_tensor = bit2byte.unpacking(src_tensor, new_tensor) 157 | new_tensor = new_tensor.view(-1) 158 | new_tensor = new_tensor[:src_element_num] 159 | new_tensor = new_tensor.view(src_tensor_size) 160 | new_tensor = -new_tensor.add_(-1) 161 | new_tensor = new_tensor.float() 162 | return new_tensor 163 | 164 | def majority_vote(self, src_tensor_list): 165 | voter_num = len(src_tensor_list) 166 | src_tensor = torch.stack(src_tensor_list) 167 | src_tensor = src_tensor.view(-1) 168 | full_size = 32 * len(src_tensor) 169 | new_tensor = torch.ones(full_size, device=src_tensor.device, dtype=torch.int32) 170 | new_tensor = new_tensor.view(32, -1) 171 | new_tensor = bit2byte.unpacking(src_tensor, new_tensor) 172 | new_tensor = -new_tensor.add_(-1) 173 | # sum 174 | new_tensor = new_tensor.permute(1, 0).contiguous().view(voter_num, -1) 175 | new_tensor = torch.sum(new_tensor, 0) 176 | new_tensor = new_tensor.view(-1, 32).permute(1, 0) 177 | new_tensor = torch.sign(new_tensor) 178 | new_tensor = bit2byte.packing(new_tensor) 179 | new_tensor = new_tensor.to(dtype=torch.int32) 180 | return new_tensor 181 | 182 | def element_num(self, size): 183 | num = 1 184 | for i in range(len(size)): 185 | num *= size[i] 186 | return num 187 
| 188 | def compress(self, src_tensor): 189 | return self.packing(src_tensor) 190 | 191 | def uncompress(self, src_tensor, src_tensor_size): 192 | dst_tensor = self.unpacking(src_tensor, src_tensor_size) 193 | return dst_tensor 194 | -------------------------------------------------------------------------------- /distributed_code/pcode/utils/stat_tracker.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from copy import deepcopy 3 | 4 | import torch 5 | 6 | from pcode.utils.communication import global_average 7 | 8 | 9 | class MaxMeter(object): 10 | """ 11 | Keeps track of the max of all the values that are 'add'ed 12 | """ 13 | 14 | def __init__(self): 15 | self.max = None 16 | 17 | def update(self, value): 18 | """ 19 | Add a value to the accumulator. 20 | :return: `true` if the provided value became the new max 21 | """ 22 | if self.max is None or value > self.max: 23 | self.max = deepcopy(value) 24 | return True 25 | else: 26 | return False 27 | 28 | def value(self): 29 | """Access the current running average""" 30 | return self.max 31 | 32 | 33 | class MinMeter(object): 34 | """ 35 | Keeps track of the max of all the values that are 'add'ed 36 | """ 37 | 38 | def __init__(self): 39 | self.min = None 40 | 41 | def update(self, value): 42 | """ 43 | Add a value to the accumulator. 
44 | :return: `true` if the provided value became the new max 45 | """ 46 | if self.min is None or value < self.min: 47 | self.min = deepcopy(value) 48 | return True 49 | else: 50 | return False 51 | 52 | def value(self): 53 | """Access the current running average""" 54 | return self.min 55 | 56 | 57 | class AverageMeter(object): 58 | """Computes and stores the average and current value""" 59 | 60 | def __init__(self): 61 | self.reset() 62 | 63 | def reset(self): 64 | self.val = 0 65 | self.avg = 0 66 | self.sum = 0 67 | self.max = -float("inf") 68 | self.min = float("inf") 69 | self.count = 0 70 | 71 | def update(self, val, n=1): 72 | self.val = val 73 | self.sum += val * n 74 | self.count += n 75 | self.avg = self.sum / self.count 76 | self.max = val if val > self.max else self.max 77 | self.min = val if val < self.min else self.min 78 | 79 | 80 | class RuntimeTracker(object): 81 | """Tracking the runtime stat for local training.""" 82 | 83 | def __init__(self, metrics_to_track=["top1"], on_cuda=True): 84 | self.metrics_to_track = metrics_to_track 85 | self.things_to_track = ["loss"] + metrics_to_track 86 | self.on_cuda = on_cuda 87 | self.reset() 88 | 89 | def reset(self): 90 | self.stat = dict((name, AverageMeter()) for name in self.things_to_track) 91 | 92 | def evaluate_global_metric(self, metric): 93 | return global_average( 94 | self.stat[metric].sum, self.stat[metric].count, on_cuda=self.on_cuda 95 | ).item() 96 | 97 | def evaluate_global_metrics(self): 98 | return [self.evaluate_global_metric(metric) for metric in self.metrics_to_track] 99 | 100 | def get_metrics_performance(self): 101 | return [self.stat[metric].avg for metric in self.metrics_to_track] 102 | 103 | def update_metrics(self, metric_stat, n_samples): 104 | for idx, thing in enumerate(self.things_to_track): 105 | self.stat[thing].update(metric_stat[idx], n_samples) 106 | 107 | def __call__(self): 108 | return dict((name, val.avg) for name, val in self.stat.items()) 109 | 110 | 111 | class 
BestPerf(object): 112 | def __init__(self, best_perf=None, larger_is_better=True): 113 | self.best_perf = best_perf 114 | self.cur_perf = None 115 | self.best_perf_locs = [] 116 | self.larger_is_better = larger_is_better 117 | 118 | # define meter 119 | self._define_meter() 120 | 121 | def _define_meter(self): 122 | self.meter = MaxMeter() if self.larger_is_better else MinMeter() 123 | 124 | def update(self, perf, perf_location): 125 | self.is_best = self.meter.update(perf) 126 | self.cur_perf = perf 127 | 128 | if self.is_best: 129 | self.best_perf = perf 130 | self.best_perf_locs += [perf_location] 131 | 132 | def get_best_perf_loc(self): 133 | return self.best_perf_locs[-1] if len(self.best_perf_locs) != 0 else None 134 | -------------------------------------------------------------------------------- /distributed_code/pcode/utils/tensor_buffer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from pcode.utils.communication import flatten 3 | 4 | 5 | class TensorBuffer: 6 | """ 7 | Packs multiple tensors into one flat buffer for efficient 8 | intra-worker communication. 
9 | """ 10 | 11 | def __init__(self, tensors, use_cuda=True): 12 | indices = [0] 13 | for tensor in tensors: 14 | new_end = indices[-1] + tensor.nelement() 15 | indices.append(new_end) 16 | 17 | self._start_idx = indices[:-1] 18 | self._end_idx = indices[1:] 19 | self._tensors_len = len(tensors) 20 | self._tensors_sizes = [x.size() for x in tensors] 21 | 22 | self.buffer = flatten(tensors, use_cuda=use_cuda) # copies 23 | 24 | def __getitem__(self, index): 25 | return self.buffer[self._start_idx[index] : self._end_idx[index]].view( 26 | self._tensors_sizes[index] 27 | ) 28 | 29 | def __len__(self): 30 | return self._tensors_len 31 | 32 | def is_cuda(self): 33 | return self.buffer.is_cuda 34 | 35 | def nelement(self): 36 | return self.buffer.nelement() 37 | 38 | def unpack(self, tensors): 39 | for tensor, entry in zip(tensors, self): 40 | tensor.data[:] = entry 41 | 42 | -------------------------------------------------------------------------------- /distributed_code/pcode/utils/timer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | from contextlib import contextmanager 4 | from io import StringIO 5 | 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class Timer: 11 | """ 12 | Timer for PyTorch code 13 | Comes in the form of a contextmanager: 14 | 15 | Example: 16 | >>> timer = Timer() 17 | ... for i in range(10): 18 | ... with timer("expensive operation"): 19 | ... x = torch.randn(100) 20 | ... 
print(timer.summary()) 21 | """ 22 | 23 | def __init__(self, verbosity_level=1, log_fn=None, skip_first=True, on_cuda=True): 24 | self.verbosity_level = verbosity_level 25 | self.log_fn = log_fn if log_fn is not None else self._default_log_fn 26 | self.skip_first = skip_first 27 | self.cuda_available = torch.cuda.is_available() and on_cuda 28 | 29 | self.reset() 30 | 31 | def reset(self): 32 | """Reset the timer""" 33 | self.totals = {} # Total time per label 34 | self.first_time = {} # First occurrence of a label (start time) 35 | self.last_time = {} # Last occurence of a label (end time) 36 | self.call_counts = {} # Number of times a label occurred 37 | 38 | @contextmanager 39 | def __call__(self, label, epoch=-1.0, verbosity=1): 40 | # Don't measure this if the verbosity level is too high 41 | if verbosity > self.verbosity_level: 42 | yield 43 | return 44 | 45 | # Measure the time 46 | self._cuda_sync() 47 | start = time.time() 48 | yield 49 | self._cuda_sync() 50 | end = time.time() 51 | 52 | # Update first and last occurrence of this label 53 | if label not in self.first_time: 54 | self.first_time[label] = start 55 | self.last_time[label] = end 56 | 57 | # Update the totals and call counts 58 | if label not in self.totals and self.skip_first: 59 | self.totals[label] = 0.0 60 | del self.first_time[label] 61 | self.call_counts[label] = 0 62 | elif label not in self.totals and not self.skip_first: 63 | self.totals[label] = end - start 64 | self.call_counts[label] = 1 65 | else: 66 | self.totals[label] += end - start 67 | self.call_counts[label] += 1 68 | 69 | if self.call_counts[label] > 0: 70 | # We will reduce the probability of logging a timing 71 | # linearly with the number of time we have seen it. 72 | # It will always be recorded in the totals, though. 
73 | if np.random.rand() < 1 / self.call_counts[label]: 74 | self.log_fn( 75 | "timer", {"epoch": epoch, "value": end - start}, {"event": label} 76 | ) 77 | 78 | def summary(self): 79 | """ 80 | Return a summary in string-form of all the timings recorded so far 81 | """ 82 | if len(self.totals) > 0: 83 | with StringIO() as buffer: 84 | total_avg_time = 0 85 | print("--- Timer summary ------------------------", file=buffer) 86 | print(" Event | Count | Average time | Frac.", file=buffer) 87 | for event_label in sorted(self.totals): 88 | total = self.totals[event_label] 89 | count = self.call_counts[event_label] 90 | if count == 0: 91 | continue 92 | avg_duration = total / count 93 | total_runtime = ( 94 | self.last_time[event_label] - self.first_time[event_label] 95 | ) 96 | runtime_percentage = 100 * total / total_runtime 97 | total_avg_time += avg_duration if "." not in event_label else 0 98 | print( 99 | f"- {event_label:30s} | {count:6d} | {avg_duration:11.5f}s | {runtime_percentage:5.1f}%", 100 | file=buffer, 101 | ) 102 | print("-------------------------------------------", file=buffer) 103 | event_label = "total_averaged_time" 104 | print( 105 | f"- {event_label:30s}| {count:6d} | {total_avg_time:11.5f}s |", 106 | file=buffer, 107 | ) 108 | print("-------------------------------------------", file=buffer) 109 | return buffer.getvalue() 110 | 111 | def _cuda_sync(self): 112 | """Finish all asynchronous GPU computations to get correct timings""" 113 | if self.cuda_available: 114 | torch.cuda.synchronize() 115 | 116 | def _default_log_fn(self, _, values, tags): 117 | label = tags["label"] 118 | epoch = values["epoch"] 119 | duration = values["value"] 120 | print(f"Timer: {label:30s} @ {epoch:4.1f} - {duration:8.5f}s") 121 | -------------------------------------------------------------------------------- /distributed_code/run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import re 3 | import 
import os

import pcode.utils.op_files as op_files
import parameters as para
import tmux_cluster.tmux as tx


def read_hostfile(file_path):
    """Parse a hostfile into a mapping from host to its slot count.

    Every meaningful line must have the form ``<host> slots=<n>``. Blank
    lines and lines starting with ``#`` are skipped.

    Args:
        file_path: path of the hostfile to read.

    Returns:
        dict mapping each host string to its slot count *as a string*
        (``map_slot`` converts it with ``int()``). If a host appears on
        several lines, the last line wins.

    Raises:
        ValueError: if a non-blank, non-comment line does not match the
            ``<host> slots=<n>`` pattern.
    """

    def _parse(line):
        matched_line = re.findall(r"^(.*?) slots=(.*?)$", line, re.DOTALL)
        if not matched_line:
            raise ValueError(
                "malformed hostfile line (expected '<host> slots=<n>'): "
                "{!r}".format(line)
            )
        return [x.strip() for x in matched_line[0]]

    # NOTE(review): assumes op_files.read_txt returns an iterable of str lines.
    lines = op_files.read_txt(file_path)

    # Skip blank lines (e.g. a trailing newline) and comments rather than
    # crashing on them with an IndexError.
    return dict(
        _parse(line)
        for line in lines
        if line.strip() and not line.strip().startswith("#")
    )


def map_slot(ip2slots):
    """Expand ``{host: slots}`` into a flat list with one entry per slot.

    Position ``r`` of the returned list is the host that rank ``r`` is placed
    on; hosts are filled in the mapping's iteration order.
    """
    return [ip for ip, slots in ip2slots.items() for _ in range(int(slots))]


def run_cmd(cmd):
    """Echo ``cmd`` and run it in a shell on the current machine."""
    print("\nRun the following cmd:\n" + cmd)
    os.system(cmd)


def get_random_port():
    """Return a TCP port number that was free at the time of the call.

    The port is obtained by binding to port 0 and letting the OS choose; the
    socket is closed before returning, so another process could still grab
    the port before the caller uses it.
    """
    import socket
    from contextlib import closing

    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # SO_REUSEADDR only has an effect when set *before* bind().
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]


def _format_flags(conf, replacement=None):
    """Render every attribute of ``conf`` as `` --key value `` CLI flags.

    Values in ``replacement`` override the corresponding ``conf`` values
    (only for keys that exist on ``conf``); attributes whose value is ``None``
    and that are not replaced are omitted.
    """
    replacement = replacement or {}
    parts = []
    for key, value in vars(conf).items():
        if key in replacement:
            parts.append(" --{} {} ".format(key, replacement[key]))
        elif value is not None:
            parts.append(" --{} {} ".format(key, value))
    return "".join(parts)


def build_nccl_script(conf, replacement=None):
    """Build the ``main.py`` argument string for the nccl/gloo backends.

    The interpreter is NOT included; the caller prefixes it.

    Args:
        conf: parsed arguments; all non-``None`` attributes become flags.
        replacement: optional dict of per-process overrides
            (e.g. ``{"local_rank": rank}``).
    """
    return " main.py " + _format_flags(conf, replacement)


def build_mpi_script(conf, replacement=None):
    """Build the full command line for the mpi backend.

    When ``conf.n_mpi_process > 1`` the python command is prefixed with an
    ``mpirun`` invocation that uses ``conf.hostfile``, ``conf.use_ipc``,
    ``conf.mpi_path`` and (if non-empty) ``conf.mpi_env`` as a ``-x``
    environment export.

    Args:
        conf: parsed arguments.
        replacement: optional dict of flag overrides (see ``_format_flags``).
    """
    if conf.n_mpi_process > 1:
        prefix_cmd = (
            f"mpirun -n {conf.n_mpi_process} --hostfile {conf.hostfile} "
            "--mca orte_base_help_aggregate 0 --mca btl_tcp_if_exclude docker0,lo "
            f"--mca btl_smcuda_use_cuda_ipc {1 if conf.use_ipc else 0} "
            f"--prefix {conf.mpi_path} "
        )
        if conf.mpi_env:
            prefix_cmd += f" -x {conf.mpi_env}"
    else:
        prefix_cmd = ""

    cmd = " {} main.py ".format(conf.python_path) + _format_flags(
        conf, replacement
    )
    return prefix_cmd + cmd


def create_job_on_nodes(conf, tasks):
    """Dispatch the per-host command lists in ``tasks``.

    Args:
        conf: parsed arguments; reads ``remote_exec`` and ``experiment``.
        tasks: dict mapping host -> list of shell commands for that host.

    The commands of one host are joined with `` & `` so that all but the
    last run in the background of a single shell line.

    If ``conf.remote_exec`` is false, or ``"localhost"`` is one of the hosts,
    only the first host's joined command is executed locally via
    ``run_cmd``. Otherwise each host gets its joined command run inside a
    tmux session created through ``tmux_cluster.tmux.Run``.
    """
    node_tasks = [(ip, " & ".join(_tasks)) for ip, _tasks in tasks.items()]

    if (not conf.remote_exec) or "localhost" in tasks:
        # NOTE(review): this runs node_tasks[0] (the first host inserted),
        # which is not necessarily the "localhost" entry, and the other
        # hosts' commands are not run -- confirm this is intended.
        run_cmd(node_tasks[0][1])
    else:
        print("\nrun the job on the remote host.\n")

        for ip, _tasks in node_tasks:
            tx.Run(name=f"{conf.experiment}", job_node=ip).make_job(
                job_name="job", task_scripts=[_tasks]
            )


def main_nccl_or_gloo(conf, ip2slot):
    """Launch one ``main.py`` process per rank for the nccl/gloo backends.

    Rank ``r`` is placed on host ``ip2slot[r]`` and receives
    ``--local_rank r``. With ``conf.clean_python`` each rank's command is
    replaced by ``pkill -9 python`` instead.

    Raises:
        ValueError: if ``conf.work_dir`` is unset, or if there are fewer
            hostfile slots than ``conf.n_mpi_process``.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if conf.work_dir is None:
        raise ValueError("conf.work_dir must be set for the nccl/gloo backends.")
    if conf.n_mpi_process > len(ip2slot):
        raise ValueError(
            "n_mpi_process={} exceeds the {} slot(s) listed in the hostfile.".format(
                conf.n_mpi_process, len(ip2slot)
            )
        )

    tasks = dict()
    for rank in range(conf.n_mpi_process):
        if conf.clean_python:
            cmd = "pkill -9 python"
        else:
            script = build_nccl_script(conf, replacement={"local_rank": rank})

            # build remote executable script.
            cmd = "cd {work_dir} && {env} {python_path} {script}".format(
                work_dir=conf.work_dir,
                env="",
                python_path=conf.python_path,
                script=script,
            )

        tasks.setdefault(ip2slot[rank], []).append(cmd)

        print(
            "build cmd ({rank}/{world_size}): \n{cmd}\n\n".format(
                rank=rank + 1, world_size=conf.n_mpi_process, cmd=cmd
            )
        )

    create_job_on_nodes(conf, tasks)


def main_mpi(conf, ip2slot):
    """Launch the mpi backend from the first host listed in the hostfile.

    A single command (``mpirun ...`` or, with ``conf.clean_python``,
    ``pkill -9 python``) is scheduled on ``ip2slot[0]``, prefixed with a
    ``cd`` into ``conf.work_dir`` when that is set.
    """
    if conf.clean_python:
        cmd = "pkill -9 python"
    else:
        cmd = build_mpi_script(conf)

    cd_prefix = (
        "cd {work_dir} && ".format(work_dir=conf.work_dir)
        if conf.work_dir is not None
        else ""
    )
    tasks = {ip2slot[0]: [cd_prefix + cmd]}

    create_job_on_nodes(conf, tasks)


if __name__ == "__main__":
    # parse the arguments.
    conf = para.get_args()

    # get ip and the corresponding # of slots.
    ip2slots = read_hostfile(conf.hostfile)
    ip2slot = map_slot(ip2slots)

    # run the main script.
    if conf.backend in ("nccl", "gloo"):
        main_nccl_or_gloo(conf, ip2slot)
    elif conf.backend == "mpi":
        main_mpi(conf, ip2slot)
    else:
        # Previously an unknown backend silently did nothing.
        raise ValueError("Unsupported backend: {!r}".format(conf.backend))

# ---------------------------------------------------------------------------
# /distributed_code/tmux_cluster/__init__.py
# https://raw.githubusercontent.com/epfml/LocalSGD-Code/3d4811d01673af205a00176f5389ed008a1ddb37/distributed_code/tmux_cluster/__init__.py
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# /distributed_code/tmux_cluster/tmux.py
# ---------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import os
from tmux_cluster.utils import ossystem
import tmux_cluster.utils as u
import shlex

TASKDIR_PREFIX = "/tmp/tasklogs"


def exec_on_node(cmds, host="localhost"):
    """Run ``cmds`` (a str or a list of str) on ``host`` via ``ossystem``.

    For any host other than ``"localhost"`` each command is wrapped as
    ``ssh <host> -t <quoted cmd>``.
    """

    def _decide_node(cmd):
        return cmd if host == "localhost" else f"ssh {host} -t {shlex.quote(cmd)}"

    cmds = (
        [_decide_node(cmd) for cmd in cmds]
        if isinstance(cmds, list)
        else _decide_node(cmds)
    )
    ossystem(cmds)


class Run(object):
    """A named group of tmux jobs that all execute on one node."""

    def __init__(self, name, job_node="localhost"):
        self.name = name
        self.jobs = []
        self.job_node = job_node

    def _session_name(self, job_name):
        """Return the tmux session name used for ``job_name`` in this run.

        tmux replaces ``.`` and ``:`` in session names with ``_`` (they are
        target separators), so do the same here; otherwise the
        ``kill-session``/``send-keys`` targets would not match the session.
        """
        return f"{self.name}-{job_name}".replace(".", "_").replace(":", "_")

    def make_job(self, job_name, task_scripts, run=True, **kwargs):
        """Create a tmux session with one window per script in ``task_scripts``.

        Any existing session of the same name is killed first. When ``run``
        is true, the session is created and the scripts' lines are typed into
        their windows on ``self.job_node``.

        Returns:
            The created ``Job``.
        """
        num_tasks = len(task_scripts)
        assert num_tasks > 0

        if kwargs:
            print("Warning: unused kwargs", kwargs)

        # Creating cmds
        cmds = []
        session_name = self._session_name(job_name)
        cmds.append(f"tmux kill-session -t {session_name}")

        windows = []
        for task_id in range(num_tasks):
            if task_id == 0:
                cmds.append(f"tmux new-session -s {session_name} -n {task_id} -d")
            else:
                cmds.append(f"tmux new-window -t {session_name} -n {task_id}")
            windows.append(f"{session_name}:{task_id}")

        job = Job(self, job_name, windows, task_scripts, self.job_node)
        job.make_tasks()
        self.jobs.append(job)
        if run:
            # Send only this job's keystrokes; earlier jobs were dispatched
            # when they were created and must not be re-run.
            # NOTE(review): with run=False the session-creation commands above
            # are discarded, so such a job is never started by this class.
            cmds += job.cmds
            exec_on_node(cmds, self.job_node)
        return job

    def attach_job(self):
        raise NotImplementedError

    def kill_jobs(self):
        """Kill the tmux sessions of every job created through this run."""
        cmds = [
            f"tmux kill-session -t {self._session_name(job.name)}"
            for job in self.jobs
        ]
        exec_on_node(cmds, self.job_node)


class Job(object):
    """One tmux session: an ordered set of windows, each running a script."""

    def __init__(self, run, name, windows, task_scripts, job_node):
        self._run = run
        self.name = name
        self.job_node = job_node
        self.windows = windows
        self.task_scripts = task_scripts
        self.tasks = []

    def make_tasks(self):
        """Create one ``Task`` per (window, script) pair."""
        for task_id, (window, script) in enumerate(
            zip(self.windows, self.task_scripts)
        ):
            self.tasks.append(
                Task(
                    window,
                    self,
                    task_id,
                    install_script=script,
                    task_node=self.job_node,
                )
            )

    def attach_tasks(self):
        raise NotImplementedError

    @property
    def cmds(self):
        """All ``tmux send-keys`` commands of this job's tasks, in order."""
        output = []
        for task in self.tasks:
            output += task.cmds
        return output


class Task(object):
    """Local tasks interact with tmux session.

    * session name is derived from job name, and window names are task ids.
    * no pane is used.
    * each non-empty, non-comment line of ``install_script`` becomes one
      ``tmux send-keys ... Enter`` command in ``self.cmds``.
    """

    def __init__(self, window, job, task_id, install_script, task_node):
        self.window = window
        self.job = job
        self.id = task_id
        self.install_script = install_script
        self.task_node = task_node

        # Path
        self.cmds = []
        self._run_counter = 0

        for line in install_script.split("\n"):
            self.run(line)

    def run(self, cmd):
        """Queue ``cmd`` to be typed into this task's tmux window."""
        self._run_counter += 1

        cmd = cmd.strip()
        if not cmd or cmd.startswith("#"):
            # ignore empty command lines
            # ignore commented out lines
            return

        modified_cmd = cmd
        self.cmds.append(
            f"tmux send-keys -t {self.window} {shlex.quote(modified_cmd)} Enter"
        )

    def upload(self, source_fn, target_fn="."):
        raise NotImplementedError()

    def download(self, source_fn, target_fn="."):
        raise NotImplementedError()

# ---------------------------------------------------------------------------
# /distributed_code/tmux_cluster/utils.py
# ---------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import yaml
import os
import time
from tqdm import tqdm
from argparse import Namespace


def ossystem(cmds):
    """Echo and run a shell command, or each command of a list in order.

    Raises:
        NotImplementedError: if ``cmds`` is neither a str nor a list.
    """
    if isinstance(cmds, str):
        print(f"\n=> {cmds}")
        os.system(cmds)
    elif isinstance(cmds, list):
        for cmd in tqdm(cmds):
            ossystem(cmd)
    else:
        raise NotImplementedError(
            "Cmds should be string or list of str. Got {}.".format(cmds)
        )


def environ(env):
    """Return the value of environment variable ``env``, or ``None``."""
    return os.getenv(env)


def load_yaml(file):
    """Load a YAML file with ``yaml.safe_load``."""
    with open(file, encoding="utf-8") as f:
        return yaml.safe_load(f)


def wait_for_file(fn, max_wait_sec=600, check_interval=0.02):
    """Block until path ``fn`` exists, polling every ``check_interval`` s.

    Raises:
        AssertionError: if ``fn`` has not appeared after ``max_wait_sec``
            seconds. It is raised explicitly (not via ``assert``) so the
            timeout still fires under ``python -O``.
    """
    # monotonic() is immune to wall-clock adjustments during the wait.
    deadline = time.monotonic() + max_wait_sec
    while not os.path.exists(fn):
        if time.monotonic() > deadline:
            raise AssertionError("Timeout %s exceeded" % (max_wait_sec))
        time.sleep(check_interval)


if __name__ == "__main__":
    ossystem("ls")