├── requirements.txt
├── .gitignore
├── memstats.png
├── README.md
├── pytorchmemtracer
│   ├── ophooks
│   │   ├── _base_ophook.py
│   │   ├── _memtracer_ophook.py
│   │   └── __init__.py
│   └── __init__.py
├── setup.py
├── visualize.py
├── train.py
└── model.py

/requirements.txt:
--------------------------------------------------------------------------------
psutil
matplotlib
fire
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*__pycache__*
*.pkl
*.py[cod]
*$py.class
--------------------------------------------------------------------------------
/memstats.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/feifeibear/PyTorchMemTracer/HEAD/memstats.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# A Memory Tracer For PyTorch

OOM is a nightmare for PyTorch users, yet most users do not know the exact memory footprint of their training jobs.
This project plots how GPU memory usage changes over the course of training.
We record the peak GPU memory of each operator at the moment the operator finishes, where an operator is the forward (FWD) or backward (BWD) computation of a submodule.

## Usage
```
python train.py
python visualize.py memstats.pkl
```

![alt perf](./memstats.png "an example result")
--------------------------------------------------------------------------------
/pytorchmemtracer/ophooks/_base_ophook.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod
import torch


class BaseOpHook(ABC):
    """This class allows users to add customized operations
    before and after the execution of a PyTorch submodule."""

    def __init__(self):
        pass

    @abstractmethod
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    @abstractmethod
    def post_fwd_exec(self, module: torch.nn.Module, *args):
        pass

    @abstractmethod
    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        pass

    @abstractmethod
    def post_bwd_exec(self, module: torch.nn.Module, input):
        pass

    @abstractmethod
    def post_iter(self):
        pass
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages


def fetch_requirements(path):
    with open(path, "r") as fd:
        return [r.strip() for r in fd.readlines()]


require_list = fetch_requirements("requirements.txt")

setup(
    name="pytorchmemtracer",
    version="0.1.6",
    description="pytorchmemtracer",
    author="feifeibear",
    author_email="fangjiarui123@gmail.com",
    url="https://fangjiarui.github.io/",
    install_requires=require_list,
    setup_requires=require_list,
    packages=find_packages(),
    include_package_data=True,
    classifiers=[
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
    ],
)
--------------------------------------------------------------------------------
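`BaseOpHook` above is the extension point of the tracer: a hook receives callbacks around the forward and backward execution of every leaf submodule, plus one callback at the end of each iteration. As a rough sketch of how a custom hook could be written and registered (note that `PrintOpHook` is made up for illustration and is not part of this repository):

```python
import torch

from pytorchmemtracer.ophooks import BaseOpHook, register_ophooks_recursively


class PrintOpHook(BaseOpHook):
    """Hypothetical hook that only logs which submodule is executing."""

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        print(f"FWD PRE  {module.__class__.__name__}")

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        print(f"FWD POST {module.__class__.__name__}")

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        print(f"BWD PRE  {module.__class__.__name__}")

    def post_bwd_exec(self, module: torch.nn.Module, input):
        print(f"BWD POST {module.__class__.__name__}")

    def post_iter(self):
        print("iteration finished")


# register_ophooks_recursively(model, [PrintOpHook()]) attaches the callbacks
# to every leaf submodule of `model` that owns parameters or buffers.
```
--------------------------------------------------------------------------------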
/pytorchmemtracer/__init__.py:
--------------------------------------------------------------------------------
from .ophooks import *

__all__ = ["memtracer_wrapper"]


class Engine:
    """Thin wrapper that keeps the traced model and its op hooks together."""

    def __init__(self, model, ophook_list):
        self._ophook_list = ophook_list
        self._model = model

    def __call__(self, *args, **kwargs):
        return self._model(*args, **kwargs)

    def forward(self, *args, **kwargs):
        return self._model.forward(*args, **kwargs)

    def backward(self, loss):
        loss.backward()
        for ophook in self._ophook_list:
            ophook.post_iter()

    def save_results(self, filename):
        for ophook in self._ophook_list:
            ophook.save_results(filename)


def memtracer_wrapper(model):
    """Register a MemTracerOpHook on every leaf submodule and return the wrapped engine."""
    ophook_list = [MemTracerOpHook()]
    register_ophooks_recursively(model, ophook_list)
    engine = Engine(model, ophook_list)
    return engine
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
import pickle

import fire
import matplotlib.pyplot as plt


def visualize_memory(stats):
    mem_stats = [elem / 1e6 for elem in stats["mem_stats"]]
    start_time = min(stats["time_stamps"])
    time_stamps = [elem - start_time for elem in stats["time_stamps"]]

    plt.style.use("ggplot")
    plt.plot(time_stamps, mem_stats, label="gpu mem stats")

    plt.xlabel("time/s")
    plt.ylabel("memory/MB")
    plt.title("gpu mem stats")

    plt.savefig("memstats.png")


def visualize_profile(filename):
    # load the profile data dumped by Engine.save_results()
    with open(filename, "rb") as f:
        stats = pickle.load(f)
    visualize_memory(stats)


if __name__ == "__main__":
    fire.Fire(visualize_profile)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import torch

from model import SimpleModel, get_bert_data_loader
from pytorchmemtracer import memtracer_wrapper

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

BATCH_SIZE = 8
HIDDEN_DIM = 128
SEQ_LEN = 128


def model_func():
    return SimpleModel(
        hidden_dim=HIDDEN_DIM, seq_len=SEQ_LEN, is_ckp=False, is_share_param=True
    )


LR = 5e-5
BETAS = (0.9, 0.999)
EPS = 1e-6
WEIGHT_DECAY = 0

config = {
    # The same format as the optimizer config of DeepSpeed
    # https://www.deepspeed.ai/docs/config-json/#optimizer-parameters
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": LR,
            "betas": BETAS,
            "eps": EPS,
            "weight_decay": WEIGHT_DECAY,
            "use_hybrid_adam": True,
        },
    },
    "fp16": {
        "enabled": True,
        "loss_scale": 0,
        "initial_scale_power": 2 ** 3,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1,
    },
    "default_chunk_size": 1024,
    "use_fake_dist": False,
    "use_cpu_embedding": False,
}

torch.manual_seed(0)
model = model_func()

optim = torch.optim.Adam(
    model.parameters(), lr=LR, betas=BETAS, eps=EPS, weight_decay=WEIGHT_DECAY
)
model.to(device)

train_loader = get_bert_data_loader(BATCH_SIZE, 10000, SEQ_LEN, device, False)

# add this line for mem tracing
model = memtracer_wrapper(model)

for i, batch in enumerate(train_loader):
    optim.zero_grad()
    input_ids, labels = batch
    loss = model(input_ids, labels)
    # use the wrapper's backward() instead of loss.backward() so that the
    # post-iteration hooks are triggered
    model.backward(loss)
    optim.step()
    print(i, loss.item())
    if i == 10:
        break

model.save_results("memstats.pkl")
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.checkpoint import checkpoint
from torch.utils.data import SequentialSampler
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertEmbeddings


class Encoder(torch.nn.Module):
    def __init__(self, hidden_dim, is_ckp=False):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, 4 * hidden_dim),
            torch.nn.Linear(4 * hidden_dim, hidden_dim),
            torch.nn.Linear(hidden_dim, hidden_dim),
        )
        self.linear3 = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.Linear(hidden_dim, hidden_dim),
        )
        self.is_ckp = is_ckp

    def forward(self, x):
        h2 = self.linear1(x)
        if self.is_ckp:
            h3 = checkpoint(self.linear3, h2)
        else:
            h3 = self.linear3(h2)
        return h3


def get_data_loader(
    batch_size,
    total_samples,
    hidden_dim,
    device,
    data_type=torch.float,
    is_distributed=False,
):
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=data_type)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(
        hidden_dim
    )
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    if is_distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        sampler = SequentialSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=sampler
    )
    return train_loader


def get_bert_data_loader(
    batch_size, total_samples, sequence_length, device, is_distributed=False
):
    train_data = torch.randint(
        low=0,
        high=10,
        size=(total_samples, sequence_length),
        device=device,
        dtype=torch.long,
    )
    train_label = torch.zeros(total_samples, dtype=torch.long, device=device)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    if is_distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        sampler = SequentialSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=sampler
    )
    return train_loader


class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_dim, seq_len, is_ckp=False, is_share_param=False):
        super(SimpleModel, self).__init__()
        config = BertConfig()
        config.vocab_size = 25
        config.max_position_embeddings = seq_len
        config.hidden_size = hidden_dim
        self.embeddings_1 = BertEmbeddings(config)

        self._is_share_param = is_share_param
        if is_share_param:
            self.embeddings_2 = self.embeddings_1
        else:
            self.embeddings_2 = BertEmbeddings(config)
        self.encoder = Encoder(hidden_dim, is_ckp)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        h1 = self.embeddings_1(x)
        h2 = self.embeddings_2(x)
        h3 = h1 + h2
        h3 = self.encoder(h3)
        return self.cross_entropy_loss(h3[:, 0], y)
--------------------------------------------------------------------------------
/pytorchmemtracer/ophooks/_memtracer_ophook.py:
--------------------------------------------------------------------------------
import pickle
from abc import ABC, abstractmethod
from time import time

import torch

from . import BaseOpHook


class MemoryMonitor(ABC):
    """Base class for all types of memory monitor.
    All monitors should have a list called `time_stamps` and a list called `mem_stats`.
    """

    def __init__(self):
        self.time_stamps = []
        self.mem_stats = []

    def __len__(self):
        return len(self.mem_stats)

    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def finish(self):
        pass

    def state_dict(self):
        return {
            "time_stamps": self.time_stamps,
            "mem_stats": self.mem_stats,
        }

    def save(self, filename):
        with open(filename, "wb") as f:
            pickle.dump(self.state_dict(), f)


class SyncCudaMemoryMonitor(MemoryMonitor):
    """A synchronized CUDA memory monitor.
    It only records the maximum allocated CUDA memory between a start point and a finish point.
    """

    def start(self):
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    def finish(self):
        torch.cuda.synchronize()
        self.time_stamps.append(time())
        max_usage = torch.cuda.max_memory_allocated()
        self.mem_stats.append(max_usage)
        return max_usage


class MemTracerOpHook(BaseOpHook):
    def __init__(self):
        super().__init__()
        self.async_mem_monitor = SyncCudaMemoryMonitor()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        if module.training:
            # close the previous sampling window and open a new one for this submodule
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        if module.training:
            self.async_mem_monitor.finish()

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        assert isinstance(module, torch.nn.Module)
        if module.training:
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        assert isinstance(module, torch.nn.Module)
        if module.training:
            self.async_mem_monitor.finish()

    def pre_iter(self):
        pass

    def post_iter(self):
        # close the last open sampling window of the iteration
        self.async_mem_monitor.finish()

    def save_results(self, filename):
        self.async_mem_monitor.save(filename)

    def show_mem_stats(self):
        start_timestamp = min(self.async_mem_monitor.time_stamps)
        self.async_mem_monitor.time_stamps = [
            elem - start_timestamp for elem in self.async_mem_monitor.time_stamps
        ]
        min_mem_used = min(self.async_mem_monitor.mem_stats)
        self.async_mem_monitor.mem_stats = [
            elem - min_mem_used for elem in self.async_mem_monitor.mem_stats
        ]
        print(self.async_mem_monitor.time_stamps)
        print(self.async_mem_monitor.mem_stats)
--------------------------------------------------------------------------------
/pytorchmemtracer/ophooks/__init__.py:
--------------------------------------------------------------------------------
from ._base_ophook import BaseOpHook
from ._memtracer_ophook import MemTracerOpHook
import torch
from typing import List, Optional

__all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]


# Apply a torch.autograd.Function that calls a backward_function to the tensors in output.
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    if type(outputs) is tuple:
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_to_tensors_only(module, functional,
                                                    backward_function, output)
            touched_outputs.append(touched_output)
        return tuple(touched_outputs)
    elif type(outputs) is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    else:
        return outputs


class PreBackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        module.applied_pre_backward = False
        outputs = outputs.detach()
        return outputs

    @staticmethod
    def backward(ctx, *args):
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args


class PostBackwardFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        ctx.module = module
        output = output.detach()
        ctx.pre_backward_function = pre_backward_function
        return output

    @staticmethod
    def backward(ctx, *args):
        """
        Args:
            the activation grad of the next layer.
        Returns:
            the grad of the input activation.
        """
        ctx.pre_backward_function(ctx.module)
        return (None, None) + args


def register_ophooks_recursively(module: torch.nn.Module,
                                 ophook_list: Optional[List[BaseOpHook]] = None,
                                 name: str = ""):
    r"""Recursively register FWD/BWD pre- and post-hooks on all submodules of `module`."""
    assert isinstance(module, torch.nn.Module)
    has_children = False
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name)
        has_children = True

    # Early return for modules that own no parameters or buffers themselves
    # (i.e., everything they hold lives in their children).
    if (len(list(module.named_parameters(recurse=False))) == 0
            and len(list(module.named_buffers(recurse=False))) == 0):
        return

    # Hooks are only registered on leaf modules, so return if the module has children.
    if has_children:
        return

    if ophook_list is not None:
        for hook in ophook_list:
            assert isinstance(hook, BaseOpHook)

    def _pre_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.pre_fwd_exec(submodule, *args)

    def _post_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):
        def _run_before_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.pre_bwd_exec(submodule, inputs, output)

        return _apply_to_tensors_only(submodule, PreBackwardFunction,
                                      _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):
        def _run_after_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.post_bwd_exec(submodule, inputs)

        return _apply_to_tensors_only(submodule, PostBackwardFunction,
                                      _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)

    # The backward-side callbacks are attached indirectly: a forward hook wraps the
    # module output in PreBackwardFunction (whose backward() fires pre_bwd_exec), and a
    # forward pre-hook wraps the module input in PostBackwardFunction (whose backward()
    # fires post_bwd_exec).
    module.register_forward_hook(_pre_backward_module_hook)
    module.register_forward_pre_hook(_post_backward_module_hook)
--------------------------------------------------------------------------------
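Putting the pieces together, `memtracer_wrapper` registers a `MemTracerOpHook` on every leaf submodule and returns an `Engine` whose `backward()` also closes the last sampling window of the iteration. A minimal end-to-end sketch, assuming a CUDA device is available; `TinyNet` is a made-up stand-in for `SimpleModel`, and the real example lives in train.py and model.py:

```python
import torch

from pytorchmemtracer import memtracer_wrapper


class TinyNet(torch.nn.Module):
    """Hypothetical toy model that returns a loss, like SimpleModel in model.py."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(128, 128)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        return self.loss_fn(self.proj(x), y)


raw_model = TinyNet().cuda()
# build the optimizer on the raw model, then wrap it, as train.py does
optim = torch.optim.Adam(raw_model.parameters(), lr=1e-4)
model = memtracer_wrapper(raw_model)

for _ in range(3):
    x = torch.randn(8, 128, device="cuda")
    y = torch.randint(0, 128, (8,), device="cuda")
    optim.zero_grad()
    loss = model(x, y)        # Engine.__call__ -> traced forward
    model.backward(loss)      # loss.backward() plus the post_iter hooks
    optim.step()

model.save_results("memstats.pkl")  # then: python visualize.py memstats.pkl
```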