├── README.md ├── pyproject.toml ├── tests ├── __init__.py ├── config │ ├── bert.yaml │ └── resnet.yaml ├── download_example.py └── test_unit.py └── torchprep ├── __init__.py ├── distillation.py ├── format.py ├── fusion.py ├── main.py ├── profile.py ├── pruning.py ├── quantization.py ├── runtime.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Torchprep 2 | 3 | A CLI tool to prepare your Pytorch models for efficient inference. The only prerequisite is a model trained and saved with `torch.save(model_name, model_path)`. See `example.py` for an example. 4 | 5 | **Be warned**: `torchprep` is an experimental tool so expect bugs, deprecations and limitations. That said if you like the project and would like to improve it please open up a Github issue! 6 | 7 | ## Install from source 8 | 9 | Create a virtual environment 10 | 11 | ```sh 12 | apt-get install python3-venv 13 | python3 -m venv venv 14 | source venv/bin/activate 15 | ``` 16 | 17 | Install `poetry` 18 | 19 | ```sh 20 | sudo python3 -m pip install -U pip 21 | sudo python3 -m pip install -U setuptools 22 | pip install poetry 23 | ``` 24 | 25 | Install `torchprep` 26 | 27 | ```sh 28 | cd torchprep 29 | poetry install 30 | ``` 31 | 32 | ## Install from Pypi 33 | 34 | ```sh 35 | pip install torchprep 36 | ``` 37 | 38 | ## Usage 39 | 40 | ```sh 41 | torchprep quantize --help 42 | ``` 43 | 44 | ### Example 45 | 46 | ```sh 47 | # Install example dependencies 48 | pip install torchvision transformers 49 | 50 | # Download resnet and bert example 51 | python tests/download_example.py 52 | 53 | # quantize a cpu model with int8 on cpu and profile with a float tensor of shape [64,3,7,7] 54 | torchprep quantize models/resnet152.pt int8 55 | ``` 56 | 57 | ### Profile 58 | 59 | To profile a model you need to create a `yaml` file describing your model input shape. The YAML can accept multiple inputs 60 | 61 | ```yaml 62 | # restnet.yaml 63 | input: 64 | dtype: "int8" 65 | device: "cpu" 66 | shape: [16, 3, 7, 7] # the first element is the batch size 67 | ``` 68 | 69 | Then you can pass in the `yaml` file to `torchprep` 70 | 71 | ```sh 72 | # profile a model for a 100 iterations 73 | torchprep profile models/resnet152.pt --iterations 100 --device cpu --input-shape config/resnet.yaml 74 | 75 | # set omp threads to 1 to optimize cpu inference 76 | torchprep env --device cpu 77 | 78 | # Prune 30% of model weights 79 | torchprep prune models/resnet152.pt --prune-amount 0.3 80 | ``` 81 | 82 | 83 | ### Available commands 84 | 85 | 86 | ``` 87 | Usage: torchprep [OPTIONS] COMMAND [ARGS]... 88 | 89 | Options: 90 | --install-completion Install completion for the current shell. 91 | --show-completion Show completion for the current shell, to copy it or 92 | customize the installation. 93 | --help Show this message and exit. 94 | 95 | Commands: 96 | distill Create a smaller student model by setting a distillation... 97 | prune Zero out small model weights using l1 norm 98 | env-variables Set environment variables for optimized inference. 99 | fuse Supports optimizations including conv/bn fusion, dropout... 100 | profile Profile model latency 101 | quantize Quantize a saved torch model to a lower precision float... 102 | ``` 103 | 104 | ### Usage instructions for a command 105 | 106 | `torchprep --help` 107 | 108 | ``` 109 | Usage: torchprep quantize [OPTIONS] MODEL_PATH PRECISION:{int8|float16} 110 | 111 | Quantize a saved torch model to a lower precision float format to reduce its 112 | size and latency 113 | 114 | Arguments: 115 | MODEL_PATH [required] 116 | PRECISION:{int8|float16} [required] 117 | 118 | Options: 119 | --device [cpu|gpu] [default: Device.cpu] 120 | --input-shape TEXT Comma separated input tensor shape 121 | --help Show this message and exit. 122 | ``` 123 | 124 | ## Dev instructions 125 | 126 | ### Run tests 127 | 128 | ```sh 129 | pytest --disable-pytest-warnings 130 | ``` 131 | 132 | ### Create binaries 133 | 134 | To create binaries and test them out locally 135 | 136 | ```sh 137 | poetry build 138 | pip install --user /path/to/wheel 139 | ``` 140 | 141 | ### Upload to Pypi 142 | 143 | ```sh 144 | poetry config pypi-token.pypi 145 | poetry publish --build 146 | ``` 147 | 148 | ## Roadmap 149 | * [x] Supporting add custom model names and output paths 150 | * [x] Support multiple input tensors for models like BERT that expect a batch size and sequence length 151 | * [x] Support multiple input tensor types 152 | * [x] Print environment variables 153 | * [x] TensorRT 154 | * [x] IPEX 155 | 156 | ### Short term 157 | * [ ] Integrate into universal benchmark tool `serve/benchmarks` 158 | * [ ] Automatic distillation example: Reduce parameter count by 1/3 `torchprep distill model.pt 1/3` 159 | * [ ] Training aware optimizations 160 | 161 | ### Medium term 162 | * [ ] Get model input shape with type annotations - [solution exists in Python 3.11 only](https://github.com/pytorch/serve/issues/1505) 163 | * [ ] Automated release with github actions - low priority for now 164 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "torchprep" 3 | version = "0.2.0" 4 | description = "The easiest way to prepare Pytorch models for efficient inference" 5 | authors = ["Mark Saroufim "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.scripts] 9 | torchprep = "torchprep.main:app" 10 | 11 | [tool.poetry.dependencies] 12 | python = "^3.8" 13 | typer = {extras = ["all"], version = "^0.4.0"} 14 | torch = "^1.9.1" 15 | tqdm = "^4.62.3" 16 | PyYAML = "^6.0" 17 | 18 | [tool.poetry.dev-dependencies] 19 | pytest = "^5.2" 20 | torchvision = "^0.12.0" 21 | transformers = "^4.18.0" 22 | scalene = "^1.5.8" 23 | torch-tb-profiler = "^0.4.0" 24 | 25 | 26 | [build-system] 27 | requires = ["poetry-core>=1.0.0"] 28 | build-backend = "poetry.core.masonry.api" 29 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/msaroufim/torchprep/0bbbefb047b301db874acc117e052c2c18b53f06/tests/__init__.py -------------------------------------------------------------------------------- /tests/config/bert.yaml: -------------------------------------------------------------------------------- 1 | # Each model can have multiple inputs 2 | # Each input has a dtype, device and shape 3 | # Some shapes can be arbitrary like a positive integer for batch size 4 | 5 | # HuggingFace BERT model 6 | input1: # batch size 7 | dtype: "int32" 8 | device: "cpu" 9 | shape: [1,1] # the first element is the batch size and second element is the sequence length 10 | high : 10 # max value 11 | mode : "latency" 12 | -------------------------------------------------------------------------------- /tests/config/resnet.yaml: -------------------------------------------------------------------------------- 1 | # Each model can have multiple inputs 2 | # Each input has a dtype, device and shape 3 | # Some shapes can be arbitrary like a positive integer for batch size 4 | 5 | # restnet.yaml 6 | input: 7 | dtype: "float16" # Extend this to a list later, will need to print all permutations 8 | device: "cpu" 9 | shape: [16, 3, 7, 7] # the first element is the batch size 10 | high : 10 # max value -------------------------------------------------------------------------------- /tests/download_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 3 | import torchvision 4 | import os 5 | 6 | 7 | def download_bert(): 8 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 9 | model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased") 10 | torch.save(model, "models/bert.pt") 11 | 12 | 13 | def download_resnet(): 14 | model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet152', pretrained=True) 15 | torch.save(model, "models/resnet152.pt") 16 | 17 | 18 | def main(): 19 | if not os.path.exists("models"): 20 | os.makedirs("models") 21 | 22 | if len(os.listdir("models")) == 0: 23 | download_resnet() 24 | download_bert() 25 | else: 26 | print("models directory is not empty") 27 | 28 | 29 | if __name__ == "__main__": 30 | main() -------------------------------------------------------------------------------- /tests/test_unit.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torchprep import __version__ 5 | from torchprep.format import ( 6 | Precision, 7 | Profiler, 8 | materialize_tensors, 9 | parse_input_format, 10 | ) 11 | from torchprep.fusion import _fuse 12 | from torchprep.pruning import _prune 13 | from torchprep.quantization import _quantize 14 | from torchprep.utils import ToyNet, load_model, profile_model 15 | 16 | from .download_example import main 17 | 18 | 19 | # General tests 20 | def test_version(): 21 | assert __version__ == "0.2.0" 22 | 23 | 24 | def test_download(): 25 | main() 26 | assert len(os.listdir("models")) > 0 27 | 28 | 29 | def test_prune(): 30 | model_path = os.path.join( 31 | os.path.dirname(os.path.realpath(__file__)), "models", "resnet152.pt" 32 | ) 33 | pruned_model = _prune(model_path=model_path, prune_amount=0.3) 34 | assert isinstance(pruned_model, torch.nn.Module) 35 | 36 | 37 | def test_profile(): 38 | net = ToyNet() 39 | result = profile_model( 40 | model=net, 41 | custom_profiler=Profiler.nothing, 42 | input_tensors=[torch.randn(10)], 43 | label="toy_profile", 44 | iterations=100, 45 | ) 46 | assert len(result) == 3 47 | 48 | 49 | def test_quantization(): 50 | model_path = os.path.join( 51 | os.path.dirname(os.path.realpath(__file__)), "models", "resnet152.pt" 52 | ) 53 | quantized_model = _quantize(model_path=model_path, precision=Precision.float16) 54 | assert isinstance(quantized_model, torch.nn.Module) 55 | 56 | 57 | def test_fuse(): 58 | # TODO: Fusion needs to know the input shape 59 | model_path = os.path.join( 60 | os.path.dirname(os.path.realpath(__file__)), "models", "resnet152.pt" 61 | ) 62 | input_shape = os.path.join( 63 | os.path.dirname(os.path.realpath(__file__)), "config", "resnet.yaml" 64 | ) 65 | fused_model = _fuse(model_path=model_path, input_shape=input_shape) 66 | print(f"Type of fused model: {type(fused_model)}") 67 | if fused_model: 68 | assert isinstance(fused_model, torch.nn.Module) 69 | 70 | # Model is not torchscriptable 71 | assert True == True 72 | 73 | 74 | def test_format(): 75 | config_file = os.path.join( 76 | os.path.dirname(os.path.realpath(__file__)), "config", "resnet.yaml" 77 | ) 78 | tensors = materialize_tensors(parse_input_format(config_file)) 79 | assert tensors[0].shape == torch.Size([1, 3, 7, 7]) 80 | 81 | 82 | def test_multiple_input(): 83 | config_file = os.path.join( 84 | os.path.dirname(os.path.realpath(__file__)), "config", "bert.yaml" 85 | ) 86 | tensors = materialize_tensors(parse_input_format(config_file)) 87 | model_path = os.path.join( 88 | os.path.dirname(os.path.realpath(__file__)), "models", "bert.pt" 89 | ) 90 | model = load_model(model_path) 91 | profile_model(model, Profiler.nothing, tensors) 92 | 93 | assert True == True 94 | 95 | 96 | def test_runtime_export(): 97 | # if runtime is installed run test 98 | return NotImplemented 99 | -------------------------------------------------------------------------------- /torchprep/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.2.0' 2 | -------------------------------------------------------------------------------- /torchprep/distillation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | from typing import List 4 | from .format import Device 5 | 6 | def _distill(model_path : Path, device : Device = Device.cpu, parameter_scaling : int = 2, layer_scaling : int = None, profile : List[int] = None) -> torch.nn.Module: 7 | print(f"Coming soon") 8 | print("See this notebook for more information https://colab.research.google.com/drive/1RzQtprrHx8PokLQsFiQPAKzfn_DiTpDN?usp=sharing") -------------------------------------------------------------------------------- /torchprep/format.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Dict, List, Union 3 | 4 | import torch 5 | import yaml 6 | 7 | 8 | class Precision(Enum): 9 | int8 = "int8" 10 | float16 = "float16" 11 | 12 | 13 | class Device(str, Enum): 14 | cpu = "cpu" 15 | cuda = "cuda" 16 | 17 | 18 | class Profiler(str, Enum): 19 | nothing = "nothing" 20 | torchtbprofiler = "torchtbprofiler" 21 | scalene = "scalene" 22 | 23 | 24 | dtype_map = { 25 | # randn 26 | "float32": torch.float32, 27 | "float": torch.float, 28 | "float64": torch.float64, 29 | "half": torch.half, 30 | "float16": torch.float16, 31 | "bfloat16": torch.bfloat16, 32 | "complex64": torch.complex64, 33 | "complex128": torch.complex128, 34 | "cdouble": torch.cdouble, 35 | # randint 36 | "uint8": torch.uint8, 37 | "int8": torch.int8, 38 | "int16": torch.int16, 39 | "short": torch.short, 40 | "int32": torch.int32, 41 | "int": torch.int, 42 | "int64": torch.int64, 43 | "long": torch.long, 44 | "bool": torch.bool, 45 | "quint8": torch.qint8, 46 | "qint8": torch.qint8, 47 | # "qfint32" : torch.qfint32, 48 | # "qint4x2" : torch.quint4x2, 49 | } 50 | 51 | device_map = { 52 | "CPU": torch.device("cpu"), 53 | "cpu": torch.device("cpu"), 54 | "gpu": torch.device("cuda"), 55 | "GPU": torch.device("cuda"), 56 | "cuda": torch.device("cuda"), 57 | } 58 | 59 | # Helper function to serialize a dictionary 60 | # def freeze(d): 61 | # if isinstance(d, dict): 62 | # return frozenset((key, freeze(value)) for key, value in d.items()) 63 | # elif isinstance(d, list): 64 | # return tuple(freeze(value) for value in d) 65 | # return d 66 | 67 | 68 | def parse_input_format( 69 | filename: str = "example.yaml", 70 | ) -> Dict[str, Union[int, List[int]]]: 71 | with open(filename, "r") as f: 72 | try: 73 | parsed_yaml = yaml.safe_load(f) 74 | return parsed_yaml 75 | except yaml.YAMLError as exc: 76 | print(exc) 77 | 78 | 79 | def materialize_tensors(yaml_dict) -> List[torch.Tensor]: 80 | tensor_list = [] 81 | for key, value in yaml_dict.items(): 82 | # If a new input is found 83 | if key.startswith("input"): 84 | input_params = value 85 | 86 | for input_key, input_value in input_params.items(): 87 | if input_key == "shape": 88 | shape = input_value 89 | elif input_key == "dtype": 90 | dtype = input_value 91 | elif input_key == "device": 92 | device = input_value 93 | elif input_key == "high": 94 | high = input_value 95 | 96 | if dtype in [ 97 | "float32", 98 | "float", 99 | "float64", 100 | "half", 101 | "float16", 102 | "bfloat16", 103 | "complex64", 104 | "complex128", 105 | "cdouble", 106 | ]: 107 | x = ( 108 | torch.randn( 109 | *shape, dtype=dtype_map[dtype], device=device_map[device] 110 | ) 111 | * high 112 | ) 113 | 114 | elif dtype in [ 115 | "uint8", 116 | "int8", 117 | "int16", 118 | "short", 119 | "int32", 120 | "int", 121 | "int64", 122 | "long", 123 | "bool", 124 | "quint8", 125 | "qint8", 126 | ]: 127 | x = torch.randint( 128 | low=0, high=high, size=tuple(shape), dtype=dtype_map[dtype] 129 | ) 130 | else: 131 | print("dtype {dtype} is no supported") 132 | return 133 | 134 | tensor_list.append(x) 135 | return tensor_list 136 | -------------------------------------------------------------------------------- /torchprep/fusion.py: -------------------------------------------------------------------------------- 1 | #TODO: Add helpers with triton for optimized fused ops 2 | import torch 3 | from pathlib import Path 4 | from typing import Optional 5 | from .utils import load_model 6 | from .format import Device, materialize_tensors, parse_input_format 7 | 8 | 9 | def _fuse(model_path : Path, input_shape : Path, output_name : str = "fused_model.pt", device : Device = Device.cpu) -> Optional[torch.nn.Module]: 10 | model = load_model(model_path, device) 11 | input_tensors = materialize_tensors(parse_input_format(input_shape)) 12 | 13 | try: 14 | model = torch.jit.trace(model,input_tensors) 15 | except Exception as e: 16 | print(f"{model_path} is not torchscriptable") 17 | return 18 | 19 | optimized_model = torch.jit.optimize_for_inference(model) 20 | 21 | torch.save(optimized_model, output_name) 22 | return optimized_model -------------------------------------------------------------------------------- /torchprep/main.py: -------------------------------------------------------------------------------- 1 | import typer 2 | from pathlib import Path 3 | import torch 4 | from .format import Device, Profiler, Precision 5 | from typing import List, Optional 6 | 7 | from .distillation import _distill 8 | from .fusion import _fuse 9 | from .profile import _profile 10 | from .quantization import _quantize 11 | from .pruning import _prune 12 | from .runtime import _export_to_runtime, Runtime, _env 13 | 14 | app = typer.Typer() 15 | 16 | @app.command() 17 | def distill(model_path : Path, device : Device = Device.cpu, parameter_scaling : int = 2, layer_scaling : int = None) -> torch.nn.Module: 18 | """ 19 | [Coming soon]: Create a smaller student model by setting a distillation ratio and teach it how to behave exactly like your existing model 20 | """ 21 | return _distill(model_path, device, parameter_scaling, layer_scaling) 22 | 23 | @app.command() 24 | def fuse(model_path : Path, input_shape : Path, output_name : str = "fused_model.pt", device : Device = Device.cpu) -> Optional[torch.nn.Module]: 25 | """ 26 | Supports optimizations including conv/bn fusion, dropout removal and mkl layout optimizations 27 | Works only for models that are scriptable 28 | """ 29 | return _fuse(model_path, input_shape, output_name, device) 30 | 31 | @app.command() 32 | def prune(model_path : Path, output_name : str = "pruned_model.pt", prune_amount : float = typer.Option(default=0.3, help=" 0 < prune_amount < 1 Percentage of connections to prune"), device : Device = Device.cpu) -> torch.nn.Module: 33 | """ 34 | Zero out small model weights using l1 norm 35 | """ 36 | return _prune(model_path, output_name, prune_amount, device) 37 | 38 | @app.command() 39 | def quantize(model_path : Path, precision : Precision ,output_name : str = "quantized_model.pt", device : Device = Device.cpu) -> torch.nn.Module: 40 | """ 41 | Quantize a saved torch model to a lower precision float format to reduce its size and latency 42 | """ 43 | return _quantize(model_path, precision, output_name, device) 44 | 45 | @app.command 46 | def profile(model_path : Path, input_shape : Path, profiler : Profiler = Profiler.nothing, iterations : int = 100, device : Device = Device.cpu) -> List[float]: 47 | """ 48 | Profile model latency given an input yaml file 49 | """ 50 | return _profile(model_path, input_shape, profiler, iterations, device) 51 | 52 | @app.command() 53 | def export_to_runtime(model_path : Path, runtime : Runtime, input_shape : Path, device : Device, output_name : str = "optimized_model.pt"): 54 | """ 55 | [Not Tested]: Do not use 56 | Export your model to an optimized runtime for accelerated inference 57 | """ 58 | return _export_to_runtime(model_path, runtime, input_shape, device, output_name) 59 | 60 | @app.command() 61 | def env(device : Device = Device.cpu, omp_num_threads : int = 1, kmp_blocktime : int = 1) -> None: 62 | """ 63 | Set optimized environment variables 64 | """ 65 | return _env(device, omp_num_threads, kmp_blocktime) -------------------------------------------------------------------------------- /torchprep/profile.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import List 3 | from .format import materialize_tensors, parse_input_format, Device, Profiler 4 | from .utils import profile_model, load_model 5 | from enum import Enum 6 | 7 | def _profile(model_path : Path, input_shape : Path, profiler : Profiler = Profiler.nothing, iterations : int = 100, device : Device = Device.cpu) -> List[float]: 8 | if iterations < 100: 9 | print("Please set iterations > 100") 10 | return 11 | model = load_model(model_path, device) 12 | 13 | input_tensors = materialize_tensors(parse_input_format(input_shape)) 14 | 15 | return profile_model(model,profiler, input_tensors,model_path,iterations) -------------------------------------------------------------------------------- /torchprep/pruning.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | import torch.nn.utils.prune 5 | import typer 6 | 7 | from .format import Device 8 | from .utils import load_model 9 | 10 | 11 | def _prune( 12 | model_path: Path, 13 | output_name: str = "pruned_model.pt", 14 | prune_amount: float = typer.Option( 15 | default=0.3, help=" 0 < prune_amount < 1 Percentage of connections to prune" 16 | ), 17 | device: Device = Device.cpu, 18 | ) -> torch.nn.Module: 19 | model = load_model(model_path, device) 20 | 21 | for name, module in model.named_modules(): 22 | if ( 23 | isinstance(module, torch.nn.Conv2d) 24 | or isinstance(module, torch.nn.Linear) 25 | or isinstance(module, torch.nn.LSTM) 26 | ): 27 | torch.nn.utils.prune.l1_unstructured(module, "weight", prune_amount) 28 | 29 | torch.save(model, output_name) 30 | print("Saved prune model {output_name}") 31 | return model 32 | -------------------------------------------------------------------------------- /torchprep/quantization.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | 5 | from .format import Device, Precision 6 | from .utils import load_model 7 | 8 | 9 | def _quantize( 10 | model_path: Path, 11 | precision: Precision, 12 | output_name: str = "quantized_model.pt", 13 | device: Device = Device.cpu, 14 | ) -> torch.nn.Module: 15 | 16 | model = load_model(model_path, device) 17 | 18 | if device == Device.cpu: 19 | if precision == Precision.int8: 20 | dtype = torch.qint8 21 | elif precision == Precision.float16: 22 | dtype = torch.float16 23 | else: 24 | print("unsupported {precision}") 25 | return 26 | 27 | quantized_model = torch.quantization.quantize_dynamic( 28 | model, {torch.nn.LSTM, torch.nn.Linear, torch.nn.Conv2d}, dtype=dtype 29 | ) 30 | # TODO: Add AMP 31 | if device == Device.cuda: 32 | if precision == Precision.int8: 33 | print("int8 precision is not supported for GPUs, defaulting to float16") 34 | quantized_model = model.half() 35 | 36 | print("Model successfully quantized") 37 | 38 | torch.save(quantized_model, output_name) 39 | print(f"model {output_name} was saved") 40 | return quantized_model 41 | -------------------------------------------------------------------------------- /torchprep/runtime.py: -------------------------------------------------------------------------------- 1 | import torch 2 | # import torch2trt 3 | # import ipex 4 | # import ort 5 | import os 6 | from enum import Enum 7 | from pathlib import Path 8 | from .format import materialize_tensors, parse_input_format, Device 9 | from .utils import load_model 10 | 11 | 12 | class Runtime(str, Enum): 13 | ipex = "ipex" 14 | tensorrt = "tensorrt" 15 | fastertransformer = "fastertransformer" 16 | 17 | def _export_to_runtime(model_path : Path, runtime : Runtime, input_shape : Path, device : Device, output_name : str = "optimized_model.pt"): 18 | """ 19 | [Not Tested]: Do not use 20 | Export your model to an optimized runtime for accelerated inference 21 | """ 22 | model = load_model(model_path) 23 | input_tensors = materialize_tensors(parse_input_format(input_shape)) 24 | 25 | if runtime == Runtime.ipex: 26 | optimized_model = ipex.optimize(model) 27 | # elif runtime == Runtime.tensorrt: 28 | # optimized_model = torch2trt(model, input_tensors) 29 | # elif runtime == Runtime.ort: 30 | # options = ort.SessionOptions() 31 | # return ort.InferenceSession(model, options) 32 | 33 | torch.save(optimized_model, output_name) 34 | 35 | def _env(device : Device = Device.cpu, omp_num_threads : int = 1, kmp_blocktime : int = 1) -> None: 36 | """ 37 | [Experimental]: Set environment variables for optimized inference. Run this command on the machine where inference will happen! 38 | """ 39 | if device == Device.cpu: 40 | os.environ["OMP_NUM_THREADS"] = omp_num_threads 41 | os.environ["KMP_BLOCKTIME"] = kmp_blocktime 42 | else: 43 | print(f"support for architecture {device} coming soon") -------------------------------------------------------------------------------- /torchprep/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import time 4 | from typing import Dict, List 5 | 6 | import torch 7 | from torch import nn 8 | from tqdm import tqdm 9 | 10 | from .format import Profiler 11 | 12 | 13 | def load_model(model_path: str, device="cpu") -> torch.nn.Module: 14 | map_location = torch.device(device) 15 | model = torch.load(model_path, map_location=map_location) 16 | return model 17 | 18 | 19 | def print_size_of_model(model: torch.nn.Module, label: str = ""): 20 | torch.save(model.state_dict(), "temp.p") 21 | size = os.path.getsize("temp.p") 22 | print("model: ", label, ":", "Size (MB):", size / 1e6) 23 | os.remove("temp.p") 24 | return size 25 | 26 | 27 | def print_environment_variables() -> Dict[str, str]: 28 | print(os.environ) 29 | 30 | 31 | def profile_model( 32 | model: torch.nn.Module, 33 | custom_profiler: Profiler, 34 | input_tensors: List[torch.tensor], 35 | label: str = "model", 36 | iterations: int = 100, 37 | ) -> List[float]: 38 | print("Starting profile") 39 | print_size_of_model(model, label) 40 | 41 | if custom_profiler == Profiler.scalene: 42 | from scalene import scalene_profiler 43 | 44 | scalene_profiler.start() 45 | 46 | if custom_profiler == Profiler.torchtbprofiler: 47 | print("Torch tensorboard profiler not yet supported") 48 | 49 | print(f"input_tensors: {input_tensors}") 50 | 51 | warmup_iterations = iterations // 10 52 | for step in range(warmup_iterations): 53 | model(*input_tensors) 54 | 55 | durations = [] 56 | for step in tqdm(range(iterations)): 57 | tic = time.time() 58 | model(*input_tensors) 59 | toc = time.time() 60 | duration = toc - tic 61 | duration = math.trunc(duration * 1000) 62 | durations.append(duration) 63 | avg = sum(durations) / len(durations) 64 | min_latency = min(durations) 65 | max_latency = max(durations) 66 | print(f"Average latency for {label} is: {avg} ms") 67 | print(f"Min latency for {label} is: {min_latency} ms") 68 | print(f"Max p99 latency for {label} is: {max_latency} ms") 69 | 70 | if custom_profiler == Profiler.scalene: 71 | scalene_profiler.stop() 72 | 73 | return [avg, min_latency, max_latency] 74 | 75 | 76 | class ToyNet(nn.Module): 77 | def __init__(self): 78 | super(ToyNet, self).__init__() 79 | self.fc = nn.Linear(10, 1) 80 | 81 | def forward(self, x): 82 | return self.fc(x) 83 | --------------------------------------------------------------------------------