├── src └── owlite │ ├── core │ ├── __init__.py │ ├── cli │ │ ├── api │ │ │ ├── __init__.py │ │ │ └── login.py │ │ ├── commands │ │ │ ├── __init__.py │ │ │ ├── device_commands.py │ │ │ ├── url_commands.py │ │ │ └── user_commands.py │ │ ├── __init__.py │ │ ├── owlite_cli.py │ │ ├── url.py │ │ ├── device.py │ │ └── login.py │ ├── exceptions.py │ ├── cache │ │ ├── __init__.py │ │ ├── tokens.py │ │ ├── device.py │ │ ├── workspace.py │ │ ├── base_urls.py │ │ └── text.py │ ├── github_utils.py │ ├── constants.py │ ├── device_settings.py │ ├── settings.py │ └── logger.py │ ├── backend │ ├── onnx │ │ ├── __init__.py │ │ ├── onnx_op.py │ │ └── optimize.py │ ├── __init__.py │ ├── fx │ │ ├── __init__.py │ │ ├── types.py │ │ ├── passes │ │ │ ├── decompose_expm1.py │ │ │ ├── decompose_silu.py │ │ │ ├── eliminate_dummy_output.py │ │ │ ├── rewrite_pass.py │ │ │ ├── eliminate_nop_getitem.py │ │ │ ├── decompose_in_projection.py │ │ │ ├── eliminate_identity.py │ │ │ ├── fuse_consecutive_concats.py │ │ │ ├── rewrite_layernorms_functional.py │ │ │ ├── __init__.py │ │ │ ├── eliminate_explicit_getitem.py │ │ │ ├── connect_inplace_ops_to_users.py │ │ │ ├── fix_hard_coded_devices.py │ │ │ ├── utils.py │ │ │ ├── decompose_transformer.py │ │ │ ├── decompose_transformer_decoder.py │ │ │ ├── node_argument.py │ │ │ ├── decompose_transformer_encoder.py │ │ │ └── decompose_in_projection_packed.py │ │ ├── serialize.py │ │ ├── filter_warnings.py │ │ ├── trace.py │ │ ├── optimize.py │ │ ├── target.py │ │ └── node.py │ └── config.py │ ├── api │ ├── __init__.py │ ├── utils.py │ └── project.py │ ├── options │ ├── channel.py │ ├── __init__.py │ ├── quantization_options.py │ ├── tensor_type.py │ ├── onnx_export_options.py │ ├── dynamic_input_options.py │ ├── generic_type_checking.py │ ├── options_mixin.py │ └── options_dict.py │ ├── calib │ ├── __init__.py │ ├── percentile_calibrator.py │ └── minmax_calibrator.py │ ├── enums │ ├── forward_param_status.py │ ├── __init__.py │ ├── price_plan.py │ ├── model_status.py │ ├── dtype.py │ ├── runtime.py │ ├── benchmark_status.py │ ├── annotations.py │ ├── ptq_calibration_type.py │ ├── qat_backward_type.py │ └── target_dtype.py │ ├── nn │ ├── __init__.py │ ├── functions │ │ ├── ste_fp.py │ │ ├── __init__.py │ │ ├── clq.py │ │ ├── ste.py │ │ └── fake_fp_quantize.py │ └── modules │ │ ├── __init__.py │ │ ├── qmodule_mixins.py │ │ ├── qlinear.py │ │ ├── qconvbn.py │ │ └── granularity_mixin.py │ ├── __init__.py │ ├── compression.py │ └── calibrators.py ├── .gitignore ├── NOTICE ├── SECURITY.md ├── setup.py ├── pyproject.toml └── README.md /src/owlite/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/owlite/core/cli/api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/owlite/core/cli/commands/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/owlite/backend/onnx/__init__.py: -------------------------------------------------------------------------------- 1 | from .export import export 2 | from .optimize import optimize, optimize_path 3 | -------------------------------------------------------------------------------- /src/owlite/api/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .baseline import Baseline 2 | from .experiment import Experiment 3 | from .project import Project 4 | -------------------------------------------------------------------------------- /src/owlite/backend/__init__.py: -------------------------------------------------------------------------------- 1 | """Patches required for torch.compile and torch.onnx.export.""" 2 | 3 | from . import config, fx, onnx, patches 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | *.log 4 | *.out 5 | *.tar 6 | *.pth 7 | *.pt 8 | *.csv 9 | *.onnx 10 | *.bin 11 | *.engine 12 | __pycache__/ 13 | .DS_Store 14 | .vscode 15 | owlite.egg-info 16 | build 17 | -------------------------------------------------------------------------------- /src/owlite/options/channel.py: -------------------------------------------------------------------------------- 1 | from .options_mixin import OptionsMixin 2 | 3 | 4 | class Channel(OptionsMixin): 5 | """The channel axis and size of a tensor.""" 6 | 7 | axis: int 8 | size: int 9 | -------------------------------------------------------------------------------- /src/owlite/core/exceptions.py: -------------------------------------------------------------------------------- 1 | class LoginError(Exception): 2 | """Exception raised for login-related errors.""" 3 | 4 | 5 | class DeviceError(Exception): 6 | """Exception raised for device-related errors.""" 7 | -------------------------------------------------------------------------------- /src/owlite/calib/__init__.py: -------------------------------------------------------------------------------- 1 | from .entropy_calibrator import EntropyCalibrator 2 | from .minmax_calibrator import MinmaxCalibrator 3 | from .mse_calibrator import MSECalibrator 4 | from .percentile_calibrator import PercentileCalibrator 5 | -------------------------------------------------------------------------------- /src/owlite/enums/forward_param_status.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | 4 | class ForwardParamStatus(IntEnum): 5 | """The three possible statuses of an input parameter of the `forward` method after tracing.""" 6 | 7 | ALIVE = 0 8 | KEPT = 1 9 | PURGED = 2 10 | -------------------------------------------------------------------------------- /src/owlite/core/cache/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | OWLITE_CACHE_PATH = Path( 5 | os.getenv( 6 | "OWLITE_CACHE_DIR", 7 | os.path.join(os.path.expanduser("~"), ".cache", "owlite"), 8 | ) 9 | ).resolve() 10 | 11 | OWLITE_CACHE_PATH.mkdir(parents=True, exist_ok=True) 12 | -------------------------------------------------------------------------------- /src/owlite/enums/__init__.py: -------------------------------------------------------------------------------- 1 | from .annotations import get_before_validator, serialize_as_name 2 | from .model_status import ModelStatus 3 | from .price_plan import PricePlan 4 | from .ptq_calibration_type import PTQCalibrationType 5 | from .qat_backward_type import QATBackwardType 6 | from .runtime import Runtime 7 | from .target_dtype import TargetDType 8 | --------------------------------------------------------------------------------
/src/owlite/core/cache/tokens.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | 3 | 4 | class Tokens(BaseModel): 5 | """Represents tokens. 6 | 7 | Attributes: 8 | access_token (str): access token for OwLite login. 9 | refresh_token (str): refresh token for OwLite login. 10 | """ 11 | 12 | access_token: str 13 | refresh_token: str 14 | -------------------------------------------------------------------------------- /src/owlite/enums/price_plan.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | 4 | class PricePlan(IntEnum): 5 | """User's pricing plan.""" 6 | 7 | UNKNOWN = 0 8 | FREE = 1 9 | CLOUD = 2 10 | ON_PREM = 3 11 | 12 | @property 13 | def paid(self) -> bool: 14 | """Whether the plan is a paid plan.""" 15 | return self > PricePlan.FREE 16 | -------------------------------------------------------------------------------- /src/owlite/options/__init__.py: -------------------------------------------------------------------------------- 1 | from .channel import Channel 2 | from .compression_option import CompressionOptions 3 | from .dynamic_input_options import DynamicAxisOptions, DynamicInputOptions, DynamicRangeOptions 4 | from .fake_quantizer_options import FakeQuantizerOptions 5 | from .onnx_export_options import ONNXExportOptions 6 | from .quantization_options import GraphQuantizationOptions, NodeQuantizationOptions 7 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | OwLite, (C) 2022-2023 SqueezeBits, Inc. 2 | 3 | This product includes software developed at SqueezeBits, Inc. 4 | (https://squeezebits.com/). 5 | 6 | The OwLite project also contains unmodified and modified subcomponents with 7 | separate copyright notices and license terms. Your use of the source 8 | code for these subcomponents is subject to the terms and conditions 9 | of GNU Affero General Public License 3.0. 10 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Notification of Security Vulnerabilities 4 | Should you identify a security vulnerability within OwLite, please inform us via email at owlite-admin@squeezebits.com. 5 | We will investigate all valid reports and do our best to fix any issues quickly. 6 | 7 | Please include a detailed description and an example of the vulnerability in your report.
8 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph_checker import ( 2 | UnsupportedAutogradFunctionCallError, 3 | UnsupportedFunctionCallError, 4 | UnsupportedModuleCallError, 5 | ) 6 | from .serialize import serialize 7 | from .trace import symbolic_trace 8 | from .transforms import ( 9 | clip_narrow_range_weights, 10 | fuse_bn, 11 | fuse_bn_into_qlinear_with_quantized_bias, 12 | qconv_bn_to_qconvbn_with_int32bias, 13 | ) 14 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/types.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import torch 4 | from torch.fx.graph_module import GraphModule 5 | from torch.fx.node import Target 6 | from torch.nn.parallel import DataParallel, DistributedDataParallel 7 | 8 | TorchTarget = Target | type[torch.nn.Module] 9 | GraphModuleOrDataParallel = GraphModule | DataParallel | DistributedDataParallel 10 | Op = Literal["call_function", "call_method", "call_module", "get_attr"] 11 | -------------------------------------------------------------------------------- /src/owlite/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import ( 2 | FakeFPQuantizer, 3 | FakeINTQuantizer, 4 | FakePerChannelFPQuantizer, 5 | FakePerChannelINTQuantizer, 6 | FakePerTensorFPQuantizer, 7 | FakePerTensorINTQuantizer, 8 | FakeQuantizer, 9 | QConv1d, 10 | QConv2d, 11 | QConv3d, 12 | QConvBn1d, 13 | QConvBn2d, 14 | QConvBn3d, 15 | QLinear, 16 | disable_quantizers, 17 | enable_quantizers, 18 | ) 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | from setuptools.command.build_ext import build_ext 3 | 4 | 5 | class build_ext_(build_ext): 6 | def run(self): 7 | pass 8 | 9 | 10 | setuptools.setup( 11 | ext_modules=[ 12 | # just a dummy extension for platform-specific builds 13 | setuptools.Extension( 14 | name="owlite.capi", 15 | sources=[], 16 | ) 17 | ], 18 | cmdclass={"build_ext": build_ext_}, 19 | ) 20 | -------------------------------------------------------------------------------- /src/owlite/enums/model_status.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | 4 | class ModelStatus(IntEnum): 5 | """The enum for specifying a model's compression status, stored in `GraphModule.meta`. 6 | 7 | Attributes: 8 | TRACED: The model is traced, but not compressed. 9 | COMPRESSED: The model is compressed, but not calibrated. 10 | CALIBRATED: The model is calibrated with `owlite.calibrate`. 11 | """ 12 | 13 | TRACED = 0 14 | COMPRESSED = 1 15 | CALIBRATED = 2 16 | -------------------------------------------------------------------------------- /src/owlite/options/quantization_options.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: D205 2 | from .fake_quantizer_options import FakeQuantizerOptions 3 | from .options_dict import OptionsDict 4 | 5 | 6 | class NodeQuantizationOptions(OptionsDict[str, FakeQuantizerOptions]): 7 | """* Key (str): the input node index or a predefined key.
Numeric string keys are for inter-nodal modifications, 8 | while alphabetic string keys are for intra-nodal modifications. 9 | * Value (FakeQuantizerOptions): fake quantizer options. 10 | """ 11 | 12 | 13 | class GraphQuantizationOptions(OptionsDict[str, NodeQuantizationOptions]): 14 | """* Key (str): the name of an FX node. 15 | * Value (NodeQuantizationOptions): node quantization options. 16 | """ 17 | -------------------------------------------------------------------------------- /src/owlite/enums/dtype.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import numpy as np 4 | 5 | 6 | # pylint: disable=invalid-name 7 | class DType(Enum): 8 | """The enum for specifying available np.dtype objects.""" 9 | 10 | FLOAT16 = np.dtype("float16") 11 | FLOAT32 = np.dtype("float32") 12 | FLOAT64 = np.dtype("float64") 13 | 14 | UINT8 = np.dtype("uint8") 15 | INT8 = np.dtype("int8") 16 | UINT16 = np.dtype("uint16") 17 | INT16 = np.dtype("int16") 18 | UINT32 = np.dtype("uint32") 19 | INT32 = np.dtype("int32") 20 | UINT64 = np.dtype("uint64") 21 | INT64 = np.dtype("int64") 22 | 23 | COMPLEX64 = np.dtype("complex64") 24 | COMPLEX128 = np.dtype("complex128") 25 | 26 | BOOL = np.dtype("bool") 27 | -------------------------------------------------------------------------------- /src/owlite/core/github_utils.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | from .constants import OWLITE_GIT_REPO_URL 4 | 5 | 6 | def get_latest_version_from_github() -> str: 7 | """Retrieve the latest release version of the package from GitHub. 8 | 9 | The version is parsed from the URL that the GitHub 10 | "releases/latest" page redirects to. 11 | 12 | Returns: 13 | str: The latest release version if successful. 14 | 15 | Raises: 16 | requests.HTTPError: If the request to GitHub fails.
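Example (illustrative; the returned version string is hypothetical): >>> get_latest_version_from_github() # doctest: +SKIP '1.2.3'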
17 | """ 18 | api_url = f"{OWLITE_GIT_REPO_URL}/releases/latest" 19 | response = requests.get(api_url, timeout=15) 20 | 21 | if not response.ok: 22 | response.raise_for_status() 23 | latest_release_version = response.url.split("/")[-1][1:] 24 | return latest_release_version 25 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/decompose_expm1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.fx.node import Node 3 | 4 | from .rewrite_pass import RewritePass 5 | 6 | 7 | class DecomposeExpm1(RewritePass): 8 | """Decompose all occurrences of `torch.expm1(x)` into `torch.exp(x) - 1`.""" 9 | 10 | @classmethod 11 | def rewrite(cls, node: Node) -> dict[Node, Node]: 12 | if not (node.op == "call_function" and node.target is torch.expm1): 13 | return {} 14 | 15 | graph = node.graph 16 | input_node = node.all_input_nodes[0] 17 | with graph.inserting_before(node): 18 | exp_node = graph.call_function(torch.exp, args=(input_node,)) 19 | sub_node = graph.call_function(torch.sub, args=(exp_node, 1)) 20 | return {node: sub_node} 21 | -------------------------------------------------------------------------------- /src/owlite/enums/runtime.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | 4 | # pylint: disable=invalid-name 5 | class Runtime(IntEnum): 6 | """Runtimes supported by OwLite.""" 7 | 8 | Unknown = 0 9 | TensorRT = 1 10 | FuriosaSDK = 2 11 | RebelionSDK = 3 12 | QNN = 4 13 | 14 | @property 15 | def simulate_int32_bias(self) -> bool: 16 | """Whether or not this runtime requires int32 bias simulation.""" 17 | return self in (Runtime.FuriosaSDK, Runtime.QNN) 18 | 19 | @property 20 | def file_ext(self) -> str: 21 | """File extension of the runtime binary.""" 22 | match self.value: 23 | case Runtime.TensorRT: 24 | return "engine" 25 | 26 | case Runtime.QNN: 27 | return "qnn.bin" 28 | 29 | return "engine" 30 | -------------------------------------------------------------------------------- /src/owlite/backend/onnx/onnx_op.py: -------------------------------------------------------------------------------- 1 | from .op_schema import OpSchema, get_core_operator_schemas 2 | 3 | 4 | class ONNXOp: 5 | """Class representing each ONNX op allowing convenient access to its schema properties.""" 6 | 7 | schemas: dict[str, OpSchema] = get_core_operator_schemas() 8 | 9 | def __init__(self, name: str) -> None: 10 | self.name = name 11 | 12 | def __str__(self) -> str: 13 | return f"{self.name}" 14 | 15 | @property 16 | def is_valid(self) -> bool: 17 | """Check if the op exists in schemas.""" 18 | return self.name in ONNXOp.schemas 19 | 20 | @property 21 | def schema(self) -> OpSchema: 22 | """The full schema object of the op.
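Note that accessing this property for an op missing from `ONNXOp.schemas` raises a `KeyError`; check `is_valid` first.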
23 | 24 | Returns: 25 | OpSchema: the op schema 26 | """ 27 | return ONNXOp.schemas[self.name] 28 | -------------------------------------------------------------------------------- /src/owlite/enums/benchmark_status.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | 4 | class BenchmarkStatus(IntEnum): 5 | """Benchmark job status.""" 6 | 7 | IDLE = 0 8 | PRE_FETCHING = 1 9 | UPLOADING = 2 10 | BENCHMARKING = 3 11 | BENCHMARK_DONE = 4 12 | FETCHING_ERR = -1 13 | TIMEOUT_ERR = -2 14 | BENCHMARK_ERR = -3 15 | WEIGHT_GEN_ERR = -5 16 | STATUS_NOT_FOUND = -999 17 | 18 | @property 19 | def in_progress(self) -> bool: 20 | """Whether the status indicates that the benchmark is in progress.""" 21 | return self in ( 22 | BenchmarkStatus.PRE_FETCHING, 23 | BenchmarkStatus.UPLOADING, 24 | BenchmarkStatus.BENCHMARKING, 25 | ) 26 | 27 | @property 28 | def failed(self) -> bool: 29 | """Whether the status indicates that the benchmark has failed.""" 30 | return self.value < 0 31 | -------------------------------------------------------------------------------- /src/owlite/core/cache/device.py: -------------------------------------------------------------------------------- 1 | from pydantic import AliasChoices, BaseModel, ConfigDict, Field 2 | 3 | from ...enums.runtime import Runtime 4 | 5 | 6 | class Device(BaseModel): 7 | """Represents a device. 8 | 9 | Attributes: 10 | name (str): The name of the device. 11 | runtime (Runtime): The runtime associated with the device. 12 | runtime_extra (str | None): Extra information about the device determined at runtime. 13 | """ 14 | 15 | model_config = ConfigDict(extra="ignore") 16 | name: str = Field(validation_alias=AliasChoices("name", "device_name")) 17 | runtime: Runtime = Field(default=Runtime.TensorRT, validation_alias=AliasChoices("framework", "runtime")) 18 | runtime_extra: str | None = Field(default=None, exclude=True) 19 | 20 | def __str__(self) -> str: 21 | return f"{self.runtime_extra or self.name} [{self.runtime.name}]" 22 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/decompose_silu.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.fx.node import Node 4 | 5 | from ..node import get_target_module 6 | from .rewrite_pass import RewritePass 7 | 8 | 9 | class DecomposeSiLU(RewritePass): 10 | """Decompose all occurrences of `torch.nn.SiLU` and `F.silu` into sigmoid and mul node pairs.""" 11 | 12 | @classmethod 13 | def rewrite(cls, node: Node) -> dict[Node, Node]: 14 | if not ( 15 | isinstance(get_target_module(node), torch.nn.SiLU) or (node.op == "call_function" and node.target is F.silu) 16 | ): 17 | return {} 18 | 19 | graph = node.graph 20 | input_node = node.all_input_nodes[0] 21 | with graph.inserting_before(node): 22 | sigmoid_node = graph.call_function(F.sigmoid, args=(input_node,)) 23 | mul_node = graph.call_function(torch.mul, args=(input_node, sigmoid_node)) 24 | return {node: mul_node} 25 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/eliminate_dummy_output.py: -------------------------------------------------------------------------------- 1 | from torch.fx import GraphModule, Node 2 | from torch.fx.passes.infra.pass_base import PassBase, PassResult 3 | 4 | 5 | class EliminateDummyOutput(PassBase): 6 | """Eliminate dummy output node inserted by
`FixHardCodedDevices`.""" 7 | 8 | def call(self, graph_module: GraphModule) -> PassResult: 9 | """Eliminate dummy output node inserted by `FixHardCodedDevices`. 10 | 11 | Args: 12 | graph_module (GraphModule): the input graph module 13 | 14 | Returns: 15 | PassResult: the result of the pass 16 | """ 17 | nodes: list[Node] = [*graph_module.graph.nodes] 18 | modified = False 19 | for node in nodes: 20 | if not (node.op == "output" and graph_module.meta["canary_device_node"] in node.all_input_nodes): 21 | continue 22 | 23 | node.args = tuple(filter(lambda x: x != graph_module.meta["canary_device_node"], node.args)) 24 | modified = True 25 | break 26 | 27 | return PassResult(graph_module, modified) 28 | -------------------------------------------------------------------------------- /src/owlite/options/tensor_type.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import numpy as np 4 | 5 | from .options_mixin import OptionsMixin 6 | 7 | DType = Literal[ 8 | "float16", 9 | "float32", 10 | "float64", 11 | "uint8", 12 | "int8", 13 | "uint16", 14 | "int16", 15 | "uint32", 16 | "int32", 17 | "uint64", 18 | "int64", 19 | "complex64", 20 | "complex128", 21 | "bool", 22 | ] 23 | 24 | 25 | class TensorType(OptionsMixin): 26 | """The properties of a `torch.Tensor` object required for specifying its MLIR-style type.""" 27 | 28 | shape: tuple[int, ...] 29 | dtype: DType 30 | is_constant: bool 31 | 32 | @property 33 | def dim(self) -> int: 34 | """The number of dimensions.""" 35 | return len(self.shape) 36 | 37 | @property 38 | def numpy_dtype(self) -> np.dtype: 39 | """The numpy dtype corresponding to the `dtype` of this object.""" 40 | return np.dtype(self.dtype) 41 | 42 | def __repr__(self) -> str: 43 | header = "Constant" if self.is_constant else "Variable" 44 | return f"{header}(shape={self.shape}, dtype={self.dtype})" 45 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/serialize.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from tabulate import tabulate 4 | from torch.fx.graph_module import GraphModule 5 | 6 | from ...core.logger import log 7 | from ..utils import targetstr 8 | from .node import get_target_module 9 | 10 | 11 | def serialize(graph_module: GraphModule) -> str: 12 | """Serialize a model into textual form. 13 | 14 | Args: 15 | graph_module (GraphModule): the model to be serialized 16 | 17 | Returns: 18 | serialized (str): the serialized fx graph 19 | """ 20 | graph = graph_module.graph 21 | node_specs = [ 22 | [ 23 | n.op, 24 | n.name, 25 | targetstr(n.target) 26 | if n.op == "call_function" 27 | else get_target_module(n).__class__ 28 | if n.op == "call_module" 29 | else n.target, 30 | n.args, 31 | n.kwargs, 32 | ] 33 | for n in graph.nodes 34 | ] 35 | 36 | serialized = tabulate(node_specs, headers=["opcode", "name", "target", "args", "kwargs"]) 37 | 38 | if log.level <= logging.DEBUG: 39 | print(serialized) 40 | 41 | return serialized 42 | -------------------------------------------------------------------------------- /src/owlite/core/cache/workspace.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from pydantic import BaseModel 3 | from typing_extensions import Self 4 | 5 | from ...enums import PricePlan 6 | 7 | 8 | class Workspace(BaseModel): 9 | """Represents a workspace. 10 | 11 | Attributes: 12 | id (str): The ID of the workspace.
13 | name (str): The name of the workspace. 14 | plan (PricePlan): The price plan of the workspace. 15 | """ 16 | 17 | id: str 18 | name: str 19 | plan: PricePlan 20 | 21 | @classmethod 22 | def load(cls, workspace_id: str) -> Self: 23 | """Load the workspace with the given id. 24 | 25 | Args: 26 | workspace_id (str): The id of the workspace to load. 27 | 28 | Returns: 29 | Workspace: The loaded workspace. 30 | """ 31 | from ..api_base import MAIN_API_BASE # pylint: disable=import-outside-toplevel 32 | 33 | try: 34 | resp = MAIN_API_BASE.get(f"/workspaces/{workspace_id}") 35 | except requests.exceptions.HTTPError as e: 36 | raise e 37 | 38 | assert isinstance(resp, dict) 39 | 40 | return cls.model_validate(resp) 41 | -------------------------------------------------------------------------------- /src/owlite/enums/annotations.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from enum import Enum, IntEnum 3 | from typing import TypeVar 4 | 5 | # pylint: disable-next=invalid-name 6 | IntEnumType = TypeVar("IntEnumType", bound=IntEnum) 7 | 8 | 9 | def get_before_validator(int_enum_type: type[IntEnumType]) -> Callable[[int | str | IntEnumType], IntEnumType]: 10 | """Get a before validator for the given `IntEnum` subclass that converts an integer or string value to it. 11 | 12 | Args: 13 | int_enum_type (type[IntEnumType]): a subclass of `IntEnum`. 14 | 15 | Returns: 16 | Callable[[int | str | IntEnumType], IntEnumType]: the function that converts either an integer representing an 17 | enum value or a string representing the name of an enum member 18 | """ 19 | 20 | def preprocess(value: int | str | IntEnumType) -> IntEnumType: 21 | if isinstance(value, int_enum_type): 22 | return value 23 | if isinstance(value, str): 24 | return int_enum_type[value] 25 | return int_enum_type(value) 26 | 27 | return preprocess 28 | 29 | 30 | def serialize_as_name(enum: Enum) -> str: 31 | """Return the name of the given `Enum` object.""" 32 | return enum.name 33 | -------------------------------------------------------------------------------- /src/owlite/core/cli/__init__.py: -------------------------------------------------------------------------------- 1 | """BaseCLICommand class for using owlite.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from argparse import _SubParsersAction 5 | 6 | 7 | class BaseOwLiteCLICommand(ABC): 8 | """Abstract base class defining the structure for OwLite CLI commands.""" 9 | 10 | @staticmethod 11 | @abstractmethod 12 | def register_subcommand(parser: _SubParsersAction) -> None: 13 | """Abstract method to register subcommands. 14 | 15 | This method should be implemented by subclasses to register subcommands with the provided parser. 16 | 17 | Args: 18 | parser (_SubParsersAction): The ArgumentParser for registering subcommands. 19 | 20 | Raises: 21 | NotImplementedError: If the method is not implemented in a subclass. 22 | """ 23 | raise NotImplementedError() 24 | 25 | @abstractmethod 26 | def run(self) -> None: 27 | """Abstract method to execute the command logic. 28 | 29 | This method should be implemented by subclasses to define the logic executed when the command runs. 30 | 31 | Raises: 32 | NotImplementedError: If the method is not implemented in a subclass.
33 | """ 34 | raise NotImplementedError() 35 | -------------------------------------------------------------------------------- /src/owlite/core/cache/base_urls.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | 3 | from ..constants import ( 4 | NEST_URL, 5 | OWLITE_DOVE_API_BASE_URL, 6 | OWLITE_FRONT_BASE_URL, 7 | OWLITE_MAIN_API_BASE_URL, 8 | ) 9 | 10 | 11 | # pylint:disable=too-few-public-methods 12 | class BaseURLs(BaseModel): 13 | """Represents base urls. 14 | 15 | Attributes: 16 | FRONT (str): The url for OwLite front server. 17 | MAIN (str): The url for OwLite main server. 18 | DOVE (str): The url for OwLite Dove server. 19 | NEST (str): The url for OwLite Nest server. 20 | """ 21 | 22 | FRONT: str = Field(default=OWLITE_FRONT_BASE_URL) 23 | MAIN: str = Field(default=OWLITE_MAIN_API_BASE_URL) 24 | DOVE: str = Field(default=OWLITE_DOVE_API_BASE_URL) 25 | NEST: str = Field(default=NEST_URL) 26 | 27 | def set(self, name: str, url: str | None = None) -> None: 28 | """Set the named URL to the given value, or reset it to its default. 29 | 30 | Args: 31 | name (str): The name of the URL to set. 32 | url (str | None): The new address. If None, reset the URL to its default. 33 | """ 34 | if name in self.model_fields: 35 | if url: 36 | setattr(self, name, url) 37 | else: 38 | setattr(self, name, self.model_fields[name].default) 39 | else: 40 | raise ValueError(f"Invalid name: {name}") 41 | -------------------------------------------------------------------------------- /src/owlite/enums/ptq_calibration_type.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | from typing import TYPE_CHECKING 3 | 4 | from ..core.logger import log 5 | 6 | if TYPE_CHECKING: 7 | from ..calib.calibrator import Calibrator 8 | 9 | 10 | # pylint: disable=invalid-name 11 | class PTQCalibrationType(IntEnum): 12 | """The enum for specifying available Calibrator classes.""" 13 | 14 | absmax = 0 15 | percentile = 1 16 | mse = 2 17 | minmax = 3 18 | entropy = 4 19 | 20 | @property 21 | def calibrator_class(self) -> type["Calibrator"]: 22 | """The Calibrator class corresponding to this enum value.""" 23 | # pylint: disable-next=import-outside-toplevel 24 | from ..calib import ( 25 | EntropyCalibrator, 26 | MinmaxCalibrator, 27 | MSECalibrator, 28 | PercentileCalibrator, 29 | ) 30 | 31 | predefined_classes: dict[str, type[Calibrator]] = { 32 | "absmax": MinmaxCalibrator, 33 | "percentile": PercentileCalibrator, 34 | "mse": MSECalibrator, 35 | "minmax": MinmaxCalibrator, 36 | "entropy": EntropyCalibrator, 37 | } 38 | if self.name == "absmax": 39 | log.warning("`absmax` is deprecated and will be removed in a future release. Use `minmax` instead.") # UX 40 | return predefined_classes[self.name] 41 | -------------------------------------------------------------------------------- /src/owlite/options/onnx_export_options.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing_extensions import Self 3 | 4 | from ..core.cache.device import Device 5 | from ..core.logger import log 6 | from ..enums import Runtime 7 | 8 | 9 | class ONNXExportOptions(BaseModel): 10 | """Class handling options for ONNX export. 11 | 12 | OwLite internally exports the target model to ONNX during conversion or benchmarking. 13 | Users can set options for ONNX export using this class.
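Example (illustrative): `ONNXExportOptions(opset_version=13)` pins the exported ONNX graph to opset 13, which is also what `create` returns for Furiosa SDK devices below.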
14 | """ 15 | 16 | opset_version: int = Field(default=17) 17 | 18 | @classmethod 19 | def create(cls, device: Device) -> Self: 20 | """Create an ONNXExportOptions for export on `device`. 21 | 22 | Args: 23 | device (Device): The device to benchmark. 24 | 25 | Returns: 26 | Self: ONNX export options compatible with the device. 27 | """ 28 | match device.runtime: 29 | case Runtime.TensorRT: 30 | return cls() 31 | 32 | case Runtime.FuriosaSDK: 33 | log.info( 34 | "ONNX opset version will be automatically set to 13 for compatibility with Furiosa SDK" 35 | ) # UX 36 | return cls(opset_version=13) 37 | 38 | case Runtime.QNN: 39 | return cls() 40 | 41 | case _: 42 | log.warning("Unknown device, using default ONNX export options") # UX 43 | return cls() 44 | -------------------------------------------------------------------------------- /src/owlite/core/cli/owlite_cli.py: -------------------------------------------------------------------------------- 1 | """CLI script to manage OwLite commands using argparse.""" 2 | 3 | from argparse import ArgumentParser 4 | 5 | from ..constants import OWLITE_VERSION 6 | from ..exceptions import DeviceError, LoginError 7 | from ..logger import log 8 | from .commands.device_commands import DeviceCommands 9 | from .commands.url_commands import UrlCommands 10 | from .commands.user_commands import UserCommands 11 | 12 | 13 | def main() -> None: 14 | """Set up and run OwLite CLI commands.""" 15 | parser = ArgumentParser("owlite", usage="owlite <command> [<args>]") 16 | commands_parser = parser.add_subparsers(help="owlite command helpers") 17 | 18 | # Register commands 19 | UserCommands.register_subcommand(commands_parser) 20 | DeviceCommands.register_subcommand(commands_parser) 21 | UrlCommands.register_subcommand(commands_parser) 22 | 23 | parser.add_argument("--version", "-v", action="store_true", help="Display OwLite version") 24 | 25 | # pylint: disable-next=too-few-public-methods, missing-class-docstring 26 | class _Default: 27 | # pylint: disable-next=missing-function-docstring 28 | def run(self) -> None: 29 | parser.print_help() 30 | 31 | parser.set_defaults(func=lambda _: _Default()) 32 | args = parser.parse_args() 33 | 34 | if args.version: 35 | log.info(f"OwLite version {OWLITE_VERSION}") # UX 36 | return 37 | 38 | owlite_cli = args.func(args) 39 | try: 40 | owlite_cli.run() 41 | except (LoginError, DeviceError) as e: 42 | log.debug(e) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/rewrite_pass.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | from torch.fx import GraphModule 4 | from torch.fx.node import Node 5 | from torch.fx.passes.infra.pass_base import PassBase, PassResult 6 | 7 | 8 | class RewritePass(PassBase): 9 | """Abstract class for implementing a node-wise rewriting pass.""" 10 | 11 | def call(self, graph_module: GraphModule) -> PassResult: 12 | """Apply the `cls.rewrite` method across all nodes in the graph.
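For every node where `rewrite` returns a non-empty mapping, each existing node is replaced with its rewritten counterpart via `Node.replace_all_uses_with`, and dead code is eliminated whenever a replacement takes effect.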
13 | 14 | Args: 15 | graph_module (GraphModule): the input graph module 16 | 17 | Returns: 18 | PassResult: the result of the pass 19 | """ 20 | modified = False 21 | nodes = list((graph := graph_module.graph).nodes) 22 | for node in nodes: 23 | if replacement_map := self.rewrite(node): 24 | is_replaced = [ 25 | len(existing_node.replace_all_uses_with(rewritten_node)) > 0 26 | for existing_node, rewritten_node in replacement_map.items() 27 | ] 28 | modified_here = any(is_replaced) 29 | modified = modified or modified_here 30 | if modified_here: 31 | graph.eliminate_dead_code() 32 | 33 | return PassResult(graph_module, modified) 34 | 35 | @classmethod 36 | @abstractmethod 37 | def rewrite(cls, node: Node) -> dict[Node, Node]: 38 | """Rewrite the given node. 39 | 40 | Args: 41 | node (Node): a node to rewrite 42 | 43 | Returns: 44 | dict[Node, Node]: a dictionary mapping an existing node to its replacement. 45 | """ 46 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/eliminate_nop_getitem.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | import torch 4 | import torch.utils._pytree as pytree 5 | from torch.fx import GraphModule, Node 6 | from torch.fx.passes.infra.pass_base import PassBase, PassResult 7 | 8 | 9 | class EliminateNopGetitem(PassBase): 10 | """Eliminate function calls of operator.getitem with no effect.""" 11 | 12 | def call(self, graph_module: GraphModule) -> PassResult: 13 | """Eliminate function calls of operator.getitem with no effect. 14 | 15 | Args: 16 | graph_module (GraphModule): the input graph module 17 | 18 | Returns: 19 | PassResult: the result of the pass 20 | """ 21 | nodes: list[Node] = [*graph_module.graph.nodes] 22 | modified = False 23 | for node in nodes: 24 | if not ( 25 | node.op == "call_function" 26 | and node.target is operator.getitem 27 | and len(node.args) == 2 28 | and isinstance(x := node.args[0] if node.args else node.kwargs.get("input"), Node) 29 | and isinstance(example_value := x.meta.get("example_value"), torch.Tensor) 30 | and not any(isinstance(leaf, Node) for leaf in pytree.tree_flatten(node.args[1])[0]) 31 | and operator.getitem( # type: ignore[misc] 32 | torch.zeros(example_value.shape, dtype=example_value.dtype), node.args[1] 33 | ).shape 34 | == example_value.shape 35 | ): 36 | continue 37 | 38 | modified = modified or len(node.replace_all_uses_with(x)) > 0 39 | 40 | return PassResult(graph_module, modified) 41 | -------------------------------------------------------------------------------- /src/owlite/core/cache/text.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from packaging.version import Version 4 | 5 | from ..constants import OWLITE_SETTINGS_FORMAT_VERSION 6 | from ..logger import log 7 | from . import OWLITE_CACHE_PATH 8 | 9 | 10 | def write_text(path: Path, text: str) -> None: 11 | """Write text to a file specified by the given path. 12 | 13 | Args: 14 | path (Path): The path to the file. 15 | text (str): The text to write. 16 | """ 17 | versioned_text = f"{text}@v{OWLITE_SETTINGS_FORMAT_VERSION}" 18 | path.write_text(versioned_text, encoding="utf-8") 19 | 20 | 21 | def read_text(path: Path) -> str | None: 22 | """Read text from a file specified by the given path. 23 | 24 | Args: 25 | path (Path): The path to the file. 26 | 27 | Returns: 28 | str | None: Text read from the file or None if the file is not found. 
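Notes: The cached text is stored with an `@v<version>` suffix by `write_text`. A legacy file without the suffix is upgraded in place as version 1.0, and a cache whose major version is older than `OWLITE_SETTINGS_FORMAT_VERSION` raises a `RuntimeError`.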
29 | """ 30 | try: 31 | cached_text = path.read_text(encoding="utf-8") 32 | cached_text = cached_text.replace("\r", "").replace("\n", "").strip() 33 | except FileNotFoundError: 34 | return None 35 | 36 | if "@v" not in cached_text: 37 | text = cached_text 38 | version = "1.0" 39 | path.write_text(f"{cached_text}@v{version}", encoding="utf-8") 40 | else: 41 | text, version = cached_text.rsplit("@v", 1) 42 | if Version(version).major < OWLITE_SETTINGS_FORMAT_VERSION.major: 43 | log.error( 44 | f"The cache version ({Version(version)}) is not supported. " 45 | f"Please remove the cache file in {OWLITE_CACHE_PATH} and retry" 46 | ) # UX 47 | raise RuntimeError("Version is not supported") 48 | return text 49 | -------------------------------------------------------------------------------- /src/owlite/core/cli/commands/device_commands.py: -------------------------------------------------------------------------------- 1 | """Handles device-related commands in OwLite CLI using argparse.""" 2 | 3 | # pylint: disable=unnecessary-lambda, too-few-public-methods 4 | from argparse import Namespace, _SubParsersAction 5 | 6 | from .. import BaseOwLiteCLICommand 7 | from ..device import connect_device, disconnect_device 8 | 9 | 10 | class DeviceCommands(BaseOwLiteCLICommand): 11 | """Handle device-related commands in OwLite CLI.""" 12 | 13 | @staticmethod 14 | def register_subcommand(parser: _SubParsersAction) -> None: 15 | """Register subcommands for device-related operations. 16 | 17 | Args: 18 | parser (_SubParsersAction): The parser object to add subcommands to. 19 | """ 20 | device_parser = parser.add_parser("device", help="Device setting from owlite") 21 | device_parser.add_argument( 22 | "mode", 23 | choices=["connect", "disconnect"], 24 | help="Device setting command", 25 | ) 26 | 27 | device_parser.set_defaults(func=lambda args: DeviceCommand(args)) 28 | 29 | 30 | class DeviceCommand: 31 | """Handles device-specific commands in OwLite CLI.""" 32 | 33 | def __init__(self, args: Namespace) -> None: 34 | """Initialize the DeviceCommand. 35 | 36 | Args: 37 | args: Arguments passed to the command. 38 | """ 39 | self.args = args 40 | 41 | def run(self) -> None: 42 | """Execute the device-related operation based on the specified mode.""" 43 | if self.args.mode == "connect": 44 | connect_device() 45 | elif self.args.mode == "disconnect": 46 | disconnect_device() 47 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/decompose_in_projection.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch.fx.node import Node 3 | 4 | from .node_argument import NodeArgument 5 | from .rewrite_pass import RewritePass 6 | 7 | 8 | class InProjectionNodeArgument(NodeArgument): 9 | """The arguments of a "call_function" node with target `F._in_projection`.""" 10 | 11 | q: Node 12 | k: Node 13 | v: Node 14 | w_q: Node 15 | w_k: Node 16 | w_v: Node 17 | b_q: Node | None = None 18 | b_k: Node | None = None 19 | b_v: Node | None = None 20 | 21 | @classmethod 22 | def validate_node(cls, node: Node) -> bool: 23 | return ( 24 | # pylint: disable-next=protected-access 25 | node.op == "call_function" and node.target is F._in_projection # type: ignore 26 | ) 27 | 28 | 29 | class DecomposeInProjection(RewritePass): 30 | """Decompose all occurrences of `F._in_projection` into an equivalent subgraph.
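The single `F._in_projection` call is replaced with three independent `F.linear` calls computing the query, key and value projections.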
31 | 32 | Note: this rewrite pass is implemented based on torch>=2.3.1,<=2.4.0 33 | """ 34 | 35 | @classmethod 36 | def rewrite(cls, node: Node) -> dict[Node, Node]: 37 | if (arguments := InProjectionNodeArgument.extract_from(node)) is None: 38 | return {} 39 | 40 | graph = node.graph 41 | with graph.inserting_before(node): 42 | q = graph.call_function(F.linear, (arguments.q, arguments.w_q, arguments.b_q)) 43 | k = graph.call_function(F.linear, (arguments.k, arguments.w_k, arguments.b_k)) 44 | v = graph.call_function(F.linear, (arguments.v, arguments.w_v, arguments.b_v)) 45 | old_q, old_k, old_v = tuple(node.users) 46 | 47 | return {old_q: q, old_k: k, old_v: v} 48 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/eliminate_identity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.fx import GraphModule, Node 3 | from torch.fx.passes.infra.pass_base import PassBase, PassResult 4 | 5 | from ..node import get_target_module 6 | 7 | 8 | class EliminateIdentity(PassBase): 9 | """Eliminate module calls of torch.nn.Identity.""" 10 | 11 | def call(self, graph_module: GraphModule) -> PassResult: 12 | """Eliminate module calls of torch.nn.Identity. 13 | 14 | Args: 15 | graph_module (GraphModule): the input graph module 16 | 17 | Returns: 18 | PassResult: the result of the pass 19 | """ 20 | nodes: list[Node] = [*graph_module.graph.nodes] 21 | modified = False 22 | for node in nodes: 23 | if not (node.op == "call_module" and isinstance(get_target_module(node), torch.nn.Identity)): 24 | continue 25 | 26 | x = node.args[0] if node.args else node.kwargs.get("input") 27 | usages: dict[Node, int | str] = {} 28 | for user in node.users: 29 | if node in user.args: 30 | i = user.args.index(node) 31 | usages[user] = i 32 | continue 33 | for key, value in [*node.kwargs.items()]: 34 | if value is node: 35 | usages[user] = key 36 | break 37 | 38 | modified = modified or len(usages) > 0 39 | for user, index_or_key in usages.items(): 40 | if isinstance(index_or_key, int): 41 | user.args = user.args[:index_or_key] + (x,) + user.args[index_or_key + 1 :] 42 | continue 43 | user.kwargs[index_or_key] = x 44 | 45 | return PassResult(graph_module, modified) 46 | -------------------------------------------------------------------------------- /src/owlite/options/dynamic_input_options.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: D205 2 | from dataclasses import dataclass 3 | 4 | from .options_dict import OptionsDict 5 | from .options_mixin import OptionsMixin 6 | 7 | 8 | class DynamicAxisOptions(dict[str, int]): 9 | """Key (str): the name of an input tensor 10 | Value (int): the axis to be dynamic. 
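Example (illustrative; the tensor name is hypothetical): `DynamicAxisOptions({"images": 0})` marks axis 0 of the input named "images" as dynamic.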
11 | """ 12 | 13 | 14 | # pylint: disable=redefined-builtin 15 | @dataclass 16 | class DynamicRangeOptions(OptionsMixin): 17 | """Dynamic axis range setting for benchmark.""" 18 | 19 | min: int 20 | """The minimum size along the dynamic axis""" 21 | max: int 22 | """The maximum size along the dynamic axis""" 23 | opt: int 24 | """The size along the dynamic axis for optimizing latency""" 25 | test: int 26 | """The size along the dynamic axis for running benchmark""" 27 | 28 | def check_min(self, min: int) -> bool: 29 | """Min must be a positive integer.""" 30 | return 0 < min 31 | 32 | def check_max(self, max: int) -> bool: 33 | """Max must be greater than or equal to min.""" 34 | return self.min <= max 35 | 36 | def check_opt(self, opt: int) -> bool: 37 | """Opt must be between min and max (inclusive).""" 38 | return self.min <= opt <= self.max 39 | 40 | def check_test(self, test: int) -> bool: 41 | """Test must be between min and max (inclusive).""" 42 | return self.min <= test <= self.max 43 | 44 | # pylint: disable-next=missing-function-docstring 45 | def to_list(self) -> list[int]: 46 | return [self.min, self.opt, self.max, self.test] 47 | 48 | 49 | class DynamicInputOptions(OptionsDict[str, DynamicRangeOptions]): 50 | """Key (str): the name of an input tensor 51 | Value (DynamicRangeOptions): the dynamic range options for the input tensor when the engine executes. 52 | """ 53 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/fuse_consecutive_concats.py: -------------------------------------------------------------------------------- 1 | from types import EllipsisType 2 | 3 | import torch 4 | from torch.fx.node import Node 5 | 6 | from .node_argument import NodeArgument 7 | from .rewrite_pass import RewritePass 8 | 9 | 10 | class ConcatNodeArgument(NodeArgument): 11 | """The arguments of a "call_function" node with target `torch.cat`, `torch.concat` or `torch.concatenate`.""" 12 | 13 | tensors: tuple[Node, ...]
| list[Node] 14 | dim: int | str | EllipsisType | None 15 | out: Node | None = None 16 | 17 | @classmethod 18 | def validate_node(cls, node: Node) -> bool: 19 | return node.op == "call_function" and node.target in (torch.cat, torch.concat, torch.concatenate) 20 | 21 | 22 | class FuseConsecutiveConcats(RewritePass): 23 | """Fuse consecutive calls of torch.cat, torch.concat or torch.concatenate.""" 24 | 25 | @classmethod 26 | def rewrite(cls, node: Node) -> dict[Node, Node]: 27 | if (arguments := ConcatNodeArgument.extract_from(node)) is None: 28 | return {} 29 | 30 | graph = node.graph 31 | args: list[Node] = [] 32 | has_parent_concat_with_same_dim = False 33 | for parent in arguments.tensors: 34 | if ( 35 | parent_arguments := ConcatNodeArgument.extract_from(parent) 36 | ) is not None and parent_arguments.dim == arguments.dim: 37 | args.extend(parent_arguments.tensors) 38 | has_parent_concat_with_same_dim = True 39 | else: 40 | args.append(parent) 41 | if not has_parent_concat_with_same_dim: 42 | return {} 43 | 44 | with graph.inserting_before(node): 45 | fused_concat_node = graph.call_function(torch.cat, args=(tuple(args),), kwargs={"dim": arguments.dim}) 46 | return {node: fused_concat_node} 47 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/rewrite_layernorms_functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.fx.node import Node 4 | 5 | from ....core.logger import log 6 | from ..node import get_target_module 7 | from .rewrite_pass import RewritePass 8 | 9 | 10 | class RewriteLayerNormsFunctional(RewritePass): 11 | """Rewrite all occurrences of `torch.nn.LayerNorm` to `torch.nn.functional.layer_norm`.""" 12 | 13 | @classmethod 14 | def rewrite(cls, node: Node) -> dict[Node, Node]: 15 | graph = node.graph 16 | if not ( 17 | (graph_module := graph.owning_module) is not None 18 | and isinstance(layernorm := get_target_module(node), torch.nn.LayerNorm) 19 | ): 20 | return {} 21 | 22 | try: 23 | input_node = node.all_input_nodes[0] 24 | except IndexError: 25 | log.warning(f"LayerNorm node {node.name} has no input node: {node.format_node()}") 26 | return {} 27 | 28 | graph_module.register_parameter(weight_name := f"{node.target}_weight", layernorm.weight) 29 | graph_module.register_parameter(bias_name := f"{node.target}_bias", layernorm.bias) 30 | # pylint: disable=protected-access 31 | _ = graph_module._parameters.pop(f"{node.target}.weight", None) 32 | _ = graph_module._parameters.pop(f"{node.target}.bias", None) 33 | 34 | with graph.inserting_before(node): 35 | normalized_shape = list(layernorm.normalized_shape) 36 | weight_node = graph.get_attr(weight_name) 37 | bias_node = graph.get_attr(bias_name) 38 | layer_norm_node = graph.call_function( 39 | F.layer_norm, 40 | args=(input_node, normalized_shape, weight_node, bias_node), 41 | kwargs={"eps": layernorm.eps}, 42 | ) 43 | return {node: layer_norm_node} 44 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0", "packaging"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name="owlite" 7 | dynamic = ["version"] 8 | description = "OwLite - No-Code AI compression Toolkit" 9 | dependencies = [ 10 | "onnx>=1.15, <1.18", 11 | "torch>=2.2,<2.9", 12 | "onnxruntime>=1.17, <1.23", 
13 | "colored", 14 | "yacs", 15 | "tabulate", 16 | "requests", 17 | "tqdm", 18 | "pydantic", 19 | "lazy_imports", 20 | "numpy", 21 | ] 22 | authors = [ 23 | {name = "SqueezeBits.inc", email = "owlite@squeezebits.com"} 24 | ] 25 | maintainers = [ 26 | {name = "SqueezeBits.inc", email = "owlite@squeezebits.com"} 27 | ] 28 | requires-python = ">=3.10, <3.14" 29 | keywords=["torch", "onnx", "graph", "quantization"] 30 | classifiers=[ 31 | "Intended Audience :: Developers", 32 | "Intended Audience :: Education", 33 | "Intended Audience :: Science/Research", 34 | "License :: OSI Approved :: GNU Affero General Public License v3", 35 | "Programming Language :: Python :: 3", 36 | "Programming Language :: Python :: 3.10", 37 | "Programming Language :: Python :: 3.11", 38 | "Programming Language :: Python :: 3.12", 39 | "Programming Language :: Python :: 3.13", 40 | "Topic :: Scientific/Engineering", 41 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 42 | "Topic :: Software Development", 43 | "Topic :: Software Development :: Libraries", 44 | "Topic :: Software Development :: Libraries :: Python Modules", 45 | ] 46 | 47 | [project.urls] 48 | Repository = "https://github.com/SqueezeBits/owlite" 49 | Documentation = "https://squeezebits.gitbook.io/owlite/quick/readme" 50 | 51 | [project.scripts] 52 | owlite = "owlite.core.cli.owlite_cli:main" 53 | 54 | [tool.setuptools.dynamic] 55 | version = {attr = "owlite.core.constants.OWLITE_VERSION"} 56 | 57 | [tool.setuptools.packages.find] 58 | where = ["src"] 59 | exclude = ["owlite-capi*"] 60 | 61 | [tool.setuptools.package-data] 62 | owlite = ["*.so"] -------------------------------------------------------------------------------- /src/owlite/core/cli/url.py: -------------------------------------------------------------------------------- 1 | """OwLite Base URL Management Module. 2 | 3 | This module handles the caching of base URLs used in OwLite APIs. 4 | """ 5 | 6 | from ...core.logger import log 7 | from ..settings import OWLITE_SETTINGS 8 | from .device import disconnect_device 9 | 10 | URL_NAME_LIST = ["FRONT", "MAIN", "DOVE", "NEST"] 11 | 12 | 13 | def save_base_url(name: str, url: str) -> None: 14 | """Save the base URL for an API in the cache. 15 | 16 | Args: 17 | name (str): A name of a URL. 18 | url (str): URL to save. 19 | 20 | Raises: 21 | ValueError: If the API name is invalid. 22 | """ 23 | if name not in URL_NAME_LIST: 24 | log.error(f"Invalid API base name: '{name}'. Valid API base names are {URL_NAME_LIST}") # UX 25 | raise ValueError(f"Invalid value given to url name: {name}") 26 | 27 | if name == "NEST": 28 | disconnect_device() 29 | 30 | base_urls = OWLITE_SETTINGS.base_url 31 | base_urls.set(name, url) 32 | OWLITE_SETTINGS.base_url = base_urls 33 | 34 | log.info(f"The {name} API base is set to {url}") # UX 35 | 36 | 37 | def print_base_urls() -> None: 38 | """Print base url in cache.""" 39 | base_urls = OWLITE_SETTINGS.base_url 40 | url_list = "\n".join([f"{name} : {getattr(base_urls, name)}" for name in URL_NAME_LIST]) 41 | log.info(f"Base urls list\n{url_list}") # UX 42 | 43 | 44 | def delete_base_url(name: str) -> None: 45 | """Delete url in cache. 46 | 47 | Args: 48 | name (str): base url's name 49 | """ 50 | if name not in URL_NAME_LIST: 51 | log.error(f"Invalid API base name: '{name}'. 
Valid API base names are {URL_NAME_LIST}") # UX 52 | raise ValueError(f"Invalid value given to url name: {name}") 53 | 54 | if name == "NEST": 55 | disconnect_device() 56 | 57 | base_urls = OWLITE_SETTINGS.base_url 58 | base_urls.set(name) 59 | OWLITE_SETTINGS.base_url = base_urls 60 | log.info(f"Deleted the {name} API base") # UX 61 | -------------------------------------------------------------------------------- /src/owlite/enums/qat_backward_type.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from enum import IntEnum 3 | 4 | from .target_dtype import TargetDType 5 | 6 | 7 | # pylint: disable=invalid-name 8 | class QATBackwardType(IntEnum): 9 | """The enum for specifying available QAT backward functions.""" 10 | 11 | ste = 0 12 | clq = 1 13 | 14 | def function(self, target_dtype: TargetDType = TargetDType.int8) -> Callable: 15 | """Return the fake quantization function corresponding to this enum value with the specified target dtype. 16 | 17 | The supported `target_dtype` values are: 18 | - `TargetDType.int8` and `TargetDType.uint8` for integer quantization 19 | - `TargetDType.fp8_e4m3` for floating-point quantization 20 | 21 | The returned function depends on the `target_dtype` and the `QATBackwardType`. For 22 | example, if `target_dtype` is `TargetDType.int8` and `QATBackwardType` is "clq", the method returns 23 | `clq_function`. 24 | 25 | Raises: 26 | ValueError: If the `target_dtype` is not supported for the given `QATBackwardType`. 27 | 28 | Args: 29 | target_dtype (TargetDType, optional): The target data type for quantization. Defaults to `TargetDType.int8`. 30 | 31 | Returns: 32 | Callable: The function (the `apply` of a `torch.autograd.Function`) used for fake quantization. 33 | """ 34 | # pylint: disable-next=import-outside-toplevel 35 | from ..nn.functions import clq_function, fake_fp_quantize_ste_function, fake_quantize_ste_function 36 | 37 | if target_dtype in (TargetDType.int8, TargetDType.uint8): 38 | return { 39 | "clq": clq_function, 40 | "ste": fake_quantize_ste_function, 41 | }[self.name] 42 | if target_dtype in (TargetDType.fp8_e4m3,): 43 | return {"ste": fake_fp_quantize_ste_function}[self.name] 44 | 45 | raise ValueError(f"Invalid QATBackwardType({self.name}) for target dtype({target_dtype})") 46 | -------------------------------------------------------------------------------- /src/owlite/nn/functions/ste_fp.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=duplicate-code, unused-argument 2 | from typing import Any 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from .fake_fp_quantize import BaseFakeFPQuantizeFunction, fake_fp8_quantize 8 | 9 | 10 | # pylint: disable-next=abstract-method 11 | class FakeFPQuantizeSTEFunction(BaseFakeFPQuantizeFunction): 12 | r"""Fake FP8 quantizing function for QAT using STE (Straight-Through Estimator). 13 | 14 | For $$ quant\_min $$ <= `input` <= $$ quant\_max $$, the gradient passes straight through; 15 | otherwise the gradient is zero. 16 | 17 | When $$x$$ is the input of FakeQuantize,
18 | 19 | $$ 20 | \hat{x} = \text{FakeFPQuantize}(x) 21 | $$ 22 | """ 23 | 24 | @staticmethod # pylint: disable-next=arguments-differ, too-many-positional-arguments 25 | def forward( 26 | ctx: Any, 27 | inputs: Tensor, 28 | step_size: Tensor, 29 | zero_point: Tensor, 30 | grad_scale: float, # grad_scale is not used 31 | quant_min: float, 32 | quant_max: float, 33 | axis: int | None, 34 | ) -> Tensor: 35 | ctx.save_for_backward(inputs) 36 | lower_bound = quant_min * step_size 37 | upper_bound = quant_max * step_size 38 | ctx.other = lower_bound, upper_bound 39 | return fake_fp8_quantize(inputs, step_size, zero_point, quant_min=quant_min, quant_max=quant_max, axis=axis) 40 | 41 | @staticmethod 42 | def backward(ctx: Any, *grad_outputs: Any) -> Any: 43 | inputs = ctx.saved_tensors[0] 44 | grad_output = grad_outputs[0] 45 | lower_bound, upper_bound = ctx.other 46 | lower_bound = lower_bound.reshape([-1] + ([1] * (inputs.dim() - 1))) 47 | upper_bound = upper_bound.reshape([-1] + ([1] * (inputs.dim() - 1))) 48 | grad_inputs = torch.where(inputs.ge(lower_bound) * inputs.le(upper_bound), grad_output, 0) 49 | return grad_inputs, None, None, None, None, None, None 50 | 51 | 52 | fake_fp_quantize_ste_function = FakeFPQuantizeSTEFunction.apply 53 | -------------------------------------------------------------------------------- /src/owlite/core/cli/commands/url_commands.py: -------------------------------------------------------------------------------- 1 | """Handles URL-related commands in OwLite CLI using argparse.""" 2 | 3 | # pylint: disable=unnecessary-lambda, too-few-public-methods 4 | from argparse import Namespace, _SubParsersAction 5 | 6 | from .. import BaseOwLiteCLICommand 7 | from ..url import delete_base_url, print_base_urls, save_base_url 8 | 9 | 10 | class UrlCommands(BaseOwLiteCLICommand): 11 | """Handle URL-related commands in OwLite CLI.""" 12 | 13 | @staticmethod 14 | def register_subcommand(parser: _SubParsersAction) -> None: 15 | """Register subcommands for URL-related operations. 16 | 17 | Args: 18 | parser (_SubParsersAction): The parser object to add subcommands to. 19 | """ 20 | url_parser = parser.add_parser("url", help="Set OwLite API base") 21 | url_parser.add_argument( 22 | "mode", 23 | choices=["add", "rm", "ls"], 24 | help="URL setting command", 25 | ) 26 | 27 | url_parser.add_argument( 28 | "--name", 29 | "-n", 30 | type=str, 31 | help="API base name", 32 | ) 33 | url_parser.add_argument( 34 | "--url", 35 | "-u", 36 | type=str, 37 | help="API base url", 38 | ) 39 | 40 | url_parser.set_defaults(func=lambda args: UrlCommand(args)) 41 | 42 | 43 | class UrlCommand: 44 | """Handle URL-related commands in OwLite CLI.""" 45 | 46 | def __init__(self, args: Namespace) -> None: 47 | """Initialize the UrlCommand. 48 | 49 | Args: 50 | args: Arguments passed to the command.
51 | """ 52 | self.args = args 53 | 54 | def run(self) -> None: 55 | """Execute the specified URL-related operation.""" 56 | if self.args.mode == "add": 57 | save_base_url(self.args.name, self.args.url.rstrip("/")) 58 | elif self.args.mode == "rm": 59 | delete_base_url(self.args.name) 60 | elif self.args.mode == "ls": 61 | print_base_urls() 62 | -------------------------------------------------------------------------------- /src/owlite/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .fake_quantizer import ( 4 | FakeFPQuantizer, 5 | FakeINTQuantizer, 6 | FakePerChannelFPQuantizer, 7 | FakePerChannelINTQuantizer, 8 | FakePerTensorFPQuantizer, 9 | FakePerTensorINTQuantizer, 10 | FakeQuantizer, 11 | ) 12 | from .qconv import QConv1d, QConv2d, QConv3d 13 | from .qconvbn import QConvBn1d, QConvBn2d, QConvBn3d 14 | from .qlinear import QLinear 15 | from .qmodule_mixins import UnaryNeuralQModuleMixin 16 | 17 | 18 | def promote_to_qmodule( 19 | cls: type[torch.nn.Module], 20 | ) -> type[QConv1d] | type[QConv2d] | type[QConv3d] | type[QLinear] | None: 21 | """Convert a torch.nn.Module subclass to its quantized counterpart if one exists. 22 | 23 | Args: 24 | cls (type[torch.nn.Module]): a subclass of a torch.nn.Module 25 | 26 | Returns: 27 | type[QConv1d] | type[QConv2d] | type[QConv3d] | type[QLinear] | None: a quantized counterpart of the `cls` if 28 | one exists; None otherwise. 29 | """ 30 | quantized_class_map: dict[type[torch.nn.Module], type[QConv1d] | type[QConv2d] | type[QConv3d] | type[QLinear]] = { 31 | torch.nn.Conv1d: QConv1d, 32 | torch.nn.Conv2d: QConv2d, 33 | torch.nn.Conv3d: QConv3d, 34 | torch.nn.Linear: QLinear, 35 | } 36 | return quantized_class_map.get(cls, None) 37 | 38 | 39 | def enable_quantizers(module: torch.nn.Module) -> None: 40 | """Enable all fake quantizers in the module. 41 | 42 | Args: 43 | module (torch.nn.Module): The module containing fake quantizers to enable. 44 | """ 45 | for _, submodule in module.named_modules(): 46 | if isinstance(submodule, FakeQuantizer): 47 | submodule.enable() 48 | 49 | 50 | def disable_quantizers(module: torch.nn.Module) -> None: 51 | """Disable all fake quantizers in the module. 52 | 53 | Args: 54 | module (torch.nn.Module): The module containing fake quantizers to disable.
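55 | 56 | Example: 57 | A minimal sketch of toggling quantizers around an evaluation run (assuming `model` contains OwLite fake quantizers and that `evaluate` is a user-defined function): 58 | 59 | ```python 60 | disable_quantizers(model)  # measure the float baseline 61 | float_metric = evaluate(model) 62 | enable_quantizers(model)  # restore fake-quantized behavior 63 | quantized_metric = evaluate(model) 64 | ```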
65 | """ 66 | for _, submodule in module.named_modules(): 67 | if isinstance(submodule, FakeQuantizer): 68 | submodule.disable() 69 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from .connect_inplace_ops_to_users import ConnectInplaceOpsToUsers 4 | from .decompose_expm1 import DecomposeExpm1 5 | from .decompose_in_projection import DecomposeInProjection 6 | from .decompose_in_projection_packed import DecomposeInProjectionPacked 7 | from .decompose_multi_head_attention_forward import DecomposeMultiHeadAttentionForward 8 | from .decompose_multihead_attention import DecomposeMultiheadAttention 9 | from .decompose_scaled_dot_product_attention import DecomposeScaledDotProductAttention 10 | from .decompose_silu import DecomposeSiLU 11 | from .decompose_transformer import DecomposeTransformer 12 | from .decompose_transformer_decoder import DecomposeTransformerDecoder 13 | from .decompose_transformer_decoder_layer import DecomposeTransformerDecoderLayer 14 | from .decompose_transformer_encoder import DecomposeTransformerEncoder 15 | from .decompose_transformer_encoder_layer import DecomposeTransformerEncoderLayer 16 | from .eliminate_dummy_output import EliminateDummyOutput 17 | from .eliminate_explicit_getitem import EliminateExplicitGetitem 18 | from .eliminate_identity import EliminateIdentity 19 | from .eliminate_nop_getitem import EliminateNopGetitem 20 | from .fix_hard_coded_devices import FixHardCodedDevice 21 | from .fuse_consecutive_concats import FuseConsecutiveConcats 22 | from .rewrite_layernorms_functional import RewriteLayerNormsFunctional 23 | 24 | PassName = Literal[ 25 | "ConnectInplaceOpsToUsers", 26 | "DecomposeExpm1", 27 | "DecomposeInProjectionPacked", 28 | "DecomposeInProjection", 29 | "DecomposeMultiHeadAttentionForward", 30 | "DecomposeMultiheadAttention", 31 | "DecomposeScaledDotProductAttention", 32 | "DecomposeSiLU", 33 | "DecomposeTransformer", 34 | "DecomposeTransformerEncoder", 35 | "DecomposeTransformerEncoderLayer", 36 | "DecomposeTransformerDecoder", 37 | "DecomposeTransformerDecoderLayer", 38 | "EliminateDummyOutput", 39 | "EliminateExplicitGetitem", 40 | "EliminateIdentity", 41 | "EliminateNopGetitem", 42 | "FixHardCodedDevice", 43 | "FuseConsecutiveConcats", 44 | "RewriteLayerNormsFunctional", 45 | ] 46 | -------------------------------------------------------------------------------- /src/owlite/core/cli/device.py: -------------------------------------------------------------------------------- 1 | """Device Management Module. 2 | 3 | Provides functions to manage device connections. 4 | """ 5 | 6 | from ..cache.device import Device 7 | from ..device_settings import OWLITE_DEVICE_SETTINGS 8 | from ..exceptions import DeviceError 9 | from ..logger import log 10 | 11 | 12 | def connect_device() -> None: 13 | """Connect to a device from the selected device manager. 14 | 15 | Raises: 16 | DeviceError: If the given device index is invalid. 17 | """ 18 | devices = list(OWLITE_DEVICE_SETTINGS.devices.values()) 19 | _devices = "\n".join(f"{index}: {device}" for index, device in enumerate(devices)) 20 | log.info(f"Available devices:\n{_devices}") 21 | 22 | user_input = input("Enter the index of the device you want to connect to: ") # UX 23 | try: 24 | index = int(user_input) 25 | if index not in range(len(devices)): 26 | log.error( 27 | f"Index out of range.
Please choose the device index within the range [0, {len(devices) - 1}]" 28 | ) # UX 29 | raise DeviceError(f"Invalid index given: {index}") 30 | device = devices[index] 31 | except ValueError as e: 32 | log.error(f"Please provide a valid index within the range [0, {len(devices) - 1}]") # UX 33 | raise DeviceError(e) from e 34 | 35 | OWLITE_DEVICE_SETTINGS.connected = device 36 | 37 | log.info(f"Connected to the device '{OWLITE_DEVICE_SETTINGS.connected}'") 38 | 39 | 40 | def disconnect_device() -> None: 41 | """Disconnect the currently connected device, if any.""" 42 | OWLITE_DEVICE_SETTINGS.connected = None 43 | log.info("Successfully disconnected from the device") # UX 44 | 45 | 46 | def connect_to_first_available_device() -> Device: 47 | """Connect to the first available device in NEST. 48 | 49 | Returns: 50 | Device: the connected device 51 | """ 52 | if OWLITE_DEVICE_SETTINGS.connected: 53 | return OWLITE_DEVICE_SETTINGS.connected 54 | log.info("Connecting to the first device at NEST") # UX 55 | device = list(OWLITE_DEVICE_SETTINGS.devices.values())[0] 56 | OWLITE_DEVICE_SETTINGS.connected = device 57 | return device 58 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/eliminate_explicit_getitem.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | from torch.fx import GraphModule 4 | from torch.fx.node import Argument, Node 5 | from torch.fx.passes.infra.pass_base import PassBase, PassResult 6 | 7 | 8 | class EliminateExplicitGetitem(PassBase): 9 | """Eliminate function calls of operator.getitem on list / tuple / dict.""" 10 | 11 | def call(self, graph_module: GraphModule) -> PassResult: 12 | """Eliminate function calls of operator.getitem on list / tuple / dict.
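For example, a node computing `operator.getitem((a, b), 0)`, where `(a, b)` is a tuple of nodes embedded in the graph, is bypassed so that its users consume `a` directly.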
13 | 14 | Args: 15 | graph_module (GraphModule): the input graph module 16 | 17 | Returns: 18 | PassResult: the result of the pass 19 | """ 20 | nodes: list[Node] = [*graph_module.graph.nodes] 21 | modified = False 22 | for node in nodes: 23 | if not (node.op == "call_function" and node.target is operator.getitem and len(node.args) == 2): 24 | continue 25 | 26 | container = node.args[0] 27 | position = node.args[1] 28 | value: Argument 29 | if isinstance(container, dict) and isinstance(position, str): 30 | value = container[position] 31 | elif isinstance(container, list | tuple) and isinstance(position, int | slice): 32 | value = container[position] 33 | else: 34 | continue 35 | 36 | usages: dict[Node, int | str] = {} 37 | for user in node.users: 38 | if node in user.args: 39 | i = user.args.index(node) 40 | usages[user] = i 41 | continue 42 | for key, kwarg_value in user.kwargs.items(): 43 | if kwarg_value is node: 44 | usages[user] = key 45 | break 46 | 47 | modified = modified or len(usages) > 0 48 | for user, index_or_key in usages.items(): 49 | if isinstance(index_or_key, int): 50 | user.args = user.args[:index_or_key] + (value,) + user.args[index_or_key + 1 :] 51 | continue 52 | user.kwargs = {**user.kwargs, index_or_key: value}  # `Node.kwargs` is immutable, so assign a new dict 53 | 54 | return PassResult(graph_module, modified) 55 | -------------------------------------------------------------------------------- /src/owlite/enums/target_dtype.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | import torch 4 | 5 | 6 | # pylint: disable=invalid-name 7 | class TargetDType(IntEnum): 8 | """The enum for specifying the available target data types.""" 9 | 10 | fp16 = 0 11 | int8 = 1 12 | uint8 = 2 13 | fp8_e4m3 = 3 14 | 15 | @property 16 | def unsigned(self) -> bool: 17 | """Return True if the data type is unsigned, False otherwise. 18 | 19 | Returns: 20 | bool: True if the data type is unsigned, False otherwise. 21 | """ 22 | return self.name == "uint8" 23 | 24 | @property 25 | def precision(self) -> int: 26 | """Return the precision of the data type. 27 | 28 | Returns: 29 | int: The precision of the data type in bits. 30 | """ 31 | if self.name in ("int8", "uint8", "fp8_e4m3"): 32 | return 8 33 | return 16 34 | 35 | @property 36 | def torch_dtype(self) -> torch.dtype: 37 | """Returns the corresponding PyTorch data type. 38 | 39 | This property maps the target data type to its corresponding PyTorch data type. 40 | 41 | Returns: 42 | torch.dtype: The corresponding PyTorch data type. 43 | """ 44 | torch_dtypes: dict[str, torch.dtype] = { 45 | "fp16": torch.float16, 46 | "int8": torch.int8, 47 | "uint8": torch.uint8, 48 | "fp8_e4m3": torch.float8_e4m3fn, 49 | } 50 | return torch_dtypes[self.name] 51 | 52 | @classmethod 53 | def invert_signedness(cls, target_dtype: "TargetDType") -> "TargetDType | None": 54 | """Return the TargetDType with inverted signedness, if possible. 55 | 56 | This method inverts the signedness of the target data type, i.e., int8 becomes uint8 and vice versa. 57 | It returns None for fp16 and fp8_e4m3, as they have no counterpart with inverted signedness. 58 | 59 | Args: 60 | target_dtype (TargetDType): The target data type. 61 | 62 | Returns: 63 | TargetDType | None: The TargetDType with inverted signedness, or None if it's not possible.
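64 | 65 | Example (each result follows directly from the branches below): 66 | ```python 67 | TargetDType.invert_signedness(TargetDType.int8)      # TargetDType.uint8 68 | TargetDType.invert_signedness(TargetDType.uint8)     # TargetDType.int8 69 | TargetDType.invert_signedness(TargetDType.fp8_e4m3)  # None 70 | ```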
71 | """ 72 | if target_dtype == TargetDType.int8: 73 | return TargetDType.uint8 74 | if target_dtype == TargetDType.uint8: 75 | return TargetDType.int8 76 | return None 77 | -------------------------------------------------------------------------------- /src/owlite/core/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Literal 4 | 5 | from packaging.version import Version 6 | 7 | OWLITE_HOME_PATH = Path(os.getenv("OWLITE_HOME", os.path.join(os.getcwd(), "owlite"))).resolve() 8 | 9 | OWLITE_REPORT_URL = "https://squeezebits.zendesk.com/hc/en-us" 10 | 11 | OWLITE_FRONT_BASE_URL = "https://owlite.ai" 12 | OWLITE_MAIN_API_BASE_URL = "https://api.owlite.ai" 13 | OWLITE_DOVE_API_BASE_URL = "https://dove.owlite.ai" 14 | 15 | NEST_URL = "https://nest.owlite.ai" 16 | 17 | OWLITE_GIT_REPO_URL = "https://github.com/SqueezeBits/owlite" 18 | 19 | OWLITE_API_DEFAULT_TIMEOUT = int(os.environ.get("OWLITE_API_DEFAULT_TIMEOUT", "300")) 20 | 21 | OWLITE_CALIBRATOR_HISTOGRAM_SIZE = int(os.environ.get("OWLITE_CALIBRATOR_HISTOGRAM_SIZE", "2048")) 22 | OWLITE_CALIBRATION_ENABLE_GRAD = os.environ.get("OWLITE_CALIBRATION_ENABLE_GRAD", "0") == "1" 23 | 24 | FX_CONFIGURATION_FORMAT_VERSION = Version("1.3") 25 | OWLITE_SETTINGS_FORMAT_VERSION = Version("2.1") 26 | OWLITE_VERSION = Version("2.6.0") 27 | 28 | # pylint: disable-next=invalid-name 29 | SUPPORTED_QUALCOMM_DEVICES = Literal[ 30 | "SA7255P ADP", 31 | "SA8255 (Proxy)", 32 | "SA8295P ADP", 33 | "SA8650 (Proxy)", 34 | "SA8775 (Proxy)", 35 | "SA8775P ADP", 36 | "Snapdragon 8cx Gen 3 CRD", 37 | "Snapdragon X Elite CRD", 38 | "Snapdragon X Plus 8-Core CRD", 39 | "QCS6490 (Proxy)", 40 | "QCS8275 (Proxy)", 41 | "QCS8550 (Proxy)", 42 | "QCS9075 (Proxy)", 43 | "RB3 Gen 2 (Proxy)", 44 | "Samsung Galaxy A73 5G", 45 | "Samsung Galaxy S21", 46 | "Samsung Galaxy S21 (Family)", 47 | "Samsung Galaxy S21 Ultra", 48 | "Samsung Galaxy S21+", 49 | "Samsung Galaxy S22 (Family)", 50 | "Samsung Galaxy S22 5G", 51 | "Samsung Galaxy S22 Ultra 5G", 52 | "Samsung Galaxy S22+ 5G", 53 | "Samsung Galaxy S23", 54 | "Samsung Galaxy S23 (Family)", 55 | "Samsung Galaxy S23 Ultra", 56 | "Samsung Galaxy S23+", 57 | "Samsung Galaxy S24", 58 | "Samsung Galaxy S24 (Family)", 59 | "Samsung Galaxy S24 Ultra", 60 | "Samsung Galaxy S24+", 61 | "Samsung Galaxy S25", 62 | "Samsung Galaxy S25 (Family)", 63 | "Samsung Galaxy S25 Ultra", 64 | "Samsung Galaxy S25+", 65 | "Samsung Galaxy Tab S8", 66 | "Snapdragon 8 Elite QRD", 67 | "Xiaomi 12", 68 | "Xiaomi 12 (Family)", 69 | "Xiaomi 12 Pro", 70 | "QCS8450 (Proxy)", 71 | "XR2 Gen 2 (Proxy)", 72 | ] 73 | -------------------------------------------------------------------------------- /src/owlite/options/generic_type_checking.py: -------------------------------------------------------------------------------- 1 | from types import NoneType 2 | from typing import Any, Union, get_args, get_origin 3 | 4 | 5 | def generic_isinstance(obj: Any, type_hint: type | tuple[type]) -> bool: 6 | """Extend the built-in function `isinstance` for type hint checking.""" 7 | if isinstance(type_hint, tuple): 8 | return any(generic_isinstance(obj, t) for t in type_hint) 9 | 10 | origin_type = getattr(type_hint, "__origin__", None) 11 | if origin_type is None: 12 | return isinstance(obj, type_hint) 13 | value_types = get_args(type_hint) 14 | if origin_type is dict: 15 | value_type = value_types[1]  # `get_args` on dict[K, V] yields (K, V); check the values against V 16 | return isinstance(obj, origin_type) and all(generic_isinstance(x,
value_type) for x in obj.values()) 17 | if origin_type in (tuple, list): 18 | value_type = value_types[0] 19 | return isinstance(obj, origin_type) and all(generic_isinstance(x, value_type) for x in obj) 20 | if origin_type is Union: 21 | return generic_isinstance(obj, value_types) 22 | raise NotImplementedError(f"generic_isinstance for {type_hint} is not implemented.") 23 | 24 | 25 | def generic_issubclass(type_hint: type, superclass: type | tuple[type]) -> bool: 26 | """Extend the built-in function `issubclass` for type hint checking.""" 27 | if isinstance(superclass, tuple): 28 | return any(generic_issubclass(type_hint, s) for s in superclass) 29 | 30 | origin_type = getattr(type_hint, "__origin__", None) 31 | if origin_type is None: 32 | return issubclass(type_hint, superclass) 33 | if origin_type in (dict, tuple, list): 34 | return issubclass(origin_type, superclass)  # `issubclass` rejects parameterized generics like dict[str, int] 35 | if origin_type is Union: 36 | field_type_args = get_args(type_hint) 37 | return any(generic_issubclass(x, superclass) for x in field_type_args) 38 | raise NotImplementedError(f"generic_issubclass for {type_hint} is not implemented.") 39 | 40 | 41 | def is_optional(type_hint: type) -> bool: 42 | """Check if the type hint is wrapped with Optional.""" 43 | if get_origin(type_hint) is not Union: 44 | return False 45 | args = get_args(type_hint) 46 | return len(args) == 2 and NoneType in args 47 | 48 | 49 | def unwrap_optional(type_hint: type) -> type: 50 | """Unwrap the Optional from the type hint if it is wrapped.""" 51 | if not is_optional(type_hint): 52 | return type_hint 53 | return [arg for arg in get_args(type_hint) if arg is not NoneType][0] 54 | -------------------------------------------------------------------------------- /src/owlite/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from typing import TYPE_CHECKING 3 | 4 | from lazy_imports import LazyImporter 5 | 6 | from .core.constants import OWLITE_VERSION as __version__ # noqa: N811 7 | 8 | _import_structure = { 9 | "backend": [ 10 | "fx", 11 | "onnx", 12 | ], 13 | "calibrators": [ 14 | "CalibrationContext", 15 | "calibrate", 16 | ], 17 | "compression": ["compress"], 18 | "enums": [ 19 | "PTQCalibrationType", 20 | "QATBackwardType", 21 | "Runtime", 22 | ], 23 | "nn": [ 24 | "FakePerChannelINTQuantizer", 25 | "FakePerTensorINTQuantizer", 26 | "FakePerChannelFPQuantizer", 27 | "FakePerTensorFPQuantizer", 28 | "FakeINTQuantizer", 29 | "FakeQuantizer", 30 | "QConv1d", 31 | "QConv2d", 32 | "QConv3d", 33 | "QLinear", 34 | "QConvBn1d", 35 | "QConvBn2d", 36 | "QConvBn3d", 37 | ], 38 | "options": [ 39 | "Channel", 40 | "CompressionOptions", 41 | "DynamicAxisOptions", 42 | "DynamicInputOptions", 43 | "FakeQuantizerOptions", 44 | "GraphQuantizationOptions", 45 | "NodeQuantizationOptions", 46 | "ONNXExportOptions", 47 | ], 48 | "owlite": ["init"], 49 | } 50 | 51 | if TYPE_CHECKING: 52 | from .backend import fx, onnx 53 | from .calibrators import ( 54 | CalibrationContext, 55 | calibrate, 56 | ) 57 | from .compression import compress 58 | from .enums import PTQCalibrationType, QATBackwardType, Runtime 59 | from .nn import ( 60 | FakeINTQuantizer, 61 | FakePerChannelFPQuantizer, 62 | FakePerChannelINTQuantizer, 63 | FakePerTensorFPQuantizer, 64 | FakePerTensorINTQuantizer, 65 | FakeQuantizer, 66 | QConv1d, 67 | QConv2d, 68 | QConv3d, 69 | QConvBn1d, 70 | QConvBn2d, 71 | QConvBn3d, 72 | QLinear, 73 | ) 74 | from .options import ( 75 | Channel, 76 | CompressionOptions, 77 | DynamicAxisOptions, 78 | DynamicInputOptions, 79 |
FakeQuantizerOptions, 80 | GraphQuantizationOptions, 81 | NodeQuantizationOptions, 82 | ONNXExportOptions, 83 | ) 84 | from .owlite import init 85 | else: 86 | sys.modules[__name__] = LazyImporter( 87 | __name__, 88 | globals()["__file__"], 89 | _import_structure, 90 | extra_objects={"__version__": __version__}, 91 | ) 92 | -------------------------------------------------------------------------------- /src/owlite/backend/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | # Flag to disable automatic object monkey patching 4 | DISABLE_AUTO_PATCH = os.environ.get("OWLITE_DISABLE_AUTO_PATCH", "0") == "1" 5 | 6 | # Maximum iteration limit for FX transformations. 7 | FX_TRANSFORM_MAXIMUM_ITERATION = int(os.environ.get("OWLITE_FX_TRANSFORM_MAXIMUM_ITERATION", 100)) 8 | 9 | # Maximum iteration limit for ONNX transformations. 10 | ONNX_TRANSFORM_MAXIMUM_ITERATION = int(os.environ.get("OWLITE_ONNX_TRANSFORM_MAXIMUM_ITERATION", 100)) 11 | 12 | # All ONNX initialized tensors with size (in bytes) greater than or equal to this value 13 | # will be saved in the external data file. OwLite writes all initialized tensors to the external 14 | # data file by default. 15 | # Note that `onnxruntime.InferenceSession` might fail to load models larger than 2 GB 16 | # when all initializers are saved externally. In such cases, set this value to 1024. 17 | ONNX_EXTERNAL_DATA_SIZE_THRESHOLD = int(os.environ.get("OWLITE_ONNX_EXTERNAL_DATA_SIZE_THRESHOLD", 0)) 18 | 19 | # Run strict shape inference 20 | STRICT_ONNX_SHAPE_INFERENCE = os.environ.get("OWLITE_STRICT_ONNX_SHAPE_INFERENCE", "1") == "1" 21 | 22 | # Run strict invariance checking 23 | STRICT_ONNX_FUNCTIONALITY_CHECKING = os.environ.get("OWLITE_STRICT_ONNX_FUNCTIONALITY_CHECKING", "1") == "1" 24 | 25 | 26 | # [DISCLAIMER] Configurations below are deprecated and may be removed in later versions. 27 | 28 | # (used only for QNN runtime) ONNX operator types to save input parameters internally during onnx export. 29 | # Each list entry can be either 30 | # 1) an operator type as a string, or 31 | # 2) a tuple of an operator type string and a tuple of indices of the inputs to store internally 32 | # 33 | # When an index tuple is provided, the input parameters not included in the tuple will be stored externally.
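For example, the entry ("Gather", [1]) below keeps only input 1 (the indices tensor) internal; input 0 (the data tensor) may still be saved externally.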
34 | ONNX_OPS_TO_SAVE_PARAMETERS_INTERNALLY: list[tuple[str, list[int]] | str] = [ 35 | "Col2Im", 36 | "Compress", 37 | "ConstantOfShape", 38 | "CumSum", 39 | "Expand", 40 | ("Gather", [1]), 41 | "GatherElements", 42 | "GatherND", 43 | "GridSample", 44 | "Pad", 45 | "ReduceL1", 46 | "ReduceL2", 47 | "ReduceLogSum", 48 | "ReduceLogSumExp", 49 | "ReduceMax", 50 | "ReduceMean", 51 | "ReduceMin", 52 | "ReduceProd", 53 | "ReduceSum", 54 | "ReduceSumSquare", 55 | "Reshape", 56 | "Resize", 57 | "Scatter", 58 | "ScatterElements", 59 | "ScatterND", 60 | "Shape", 61 | "Slice", 62 | "Split", 63 | "Squeeze", 64 | "Tile", 65 | "TopK", 66 | "Unsqueeze", 67 | ] 68 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/filter_warnings.py: -------------------------------------------------------------------------------- 1 | import linecache 2 | import warnings 3 | from typing import TextIO 4 | 5 | from torch.jit._trace import TracerWarning 6 | 7 | 8 | class FilterWarningsCausedByPasses(warnings.catch_warnings): 9 | """Context to ignore the warnings caused by FX optimization passes.""" 10 | 11 | def __init__(self) -> None: 12 | super().__init__(record=False, module=warnings) 13 | 14 | def __enter__(self) -> None: 15 | warnings.showwarning = self.custom_showwarning 16 | return super().__enter__() 17 | 18 | def custom_showwarning( 19 | self, 20 | message: Warning | str, 21 | category: type[Warning], 22 | filename: str, 23 | lineno: int, 24 | file: TextIO | None = None, 25 | line: str | None = None, 26 | ) -> None: 27 | """Run `warnings.showwarning`, ignoring specific warnings caused by FX optimization passes. 28 | 29 | Args: 30 | message (Warning | str): the warning message to show 31 | category (type[Warning]): the warning category 32 | filename (str): the name of the file where this warning is caused 33 | lineno (int): the line number at which this warning is caused 34 | file (TextIO | None, optional): The file descriptor to write the warnings at. Defaults to None. 35 | line (str | None, optional): the line that caused this warning. Defaults to None. 36 | """ 37 | if not hasattr(self, "_showwarning"): 38 | return 39 | if ( 40 | category is TracerWarning 41 | and filename.startswith("<eval_with_key>") 42 | and "math_sqrt" in (line or linecache.getline(filename, lineno)) 43 | ): 44 | # The FX optimization passes `DecomposeMultiHeadAttentionForward` and `DecomposeScaledDotProductAttention` 45 | # explicitly add nodes calling the function `math.sqrt`, generating a number of identical tracer warnings 46 | # such as: "<eval_with_key>.38:800: TracerWarning: Converting a tensor to a Python float might cause the 47 | # trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a 48 | # constant in the future. This means that the trace might not generalize to other inputs!
49 | # sqrt_11 = math_sqrt(size_107); size_107 = None" 50 | return 51 | msg = warnings.WarningMessage(message, category, filename, lineno, file, line) 52 | # pylint: disable-next=protected-access 53 | warnings._showwarnmsg_impl(msg) # type: ignore[attr-defined] 54 | -------------------------------------------------------------------------------- /src/owlite/options/options_mixin.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from packaging.version import Version 4 | from pydantic import BaseModel, ConfigDict, model_validator 5 | from typing_extensions import Self 6 | 7 | from .options_dict import OptionsDict 8 | from .serializable import Serializable 9 | 10 | 11 | class OptionsMixin( 12 | BaseModel, Serializable, arbitrary_types_allowed=True, populate_by_name=True, validate_assignment=True 13 | ): 14 | """The Mixin-style base class for adding type-checking and custom value-checking features.""" 15 | 16 | model_config = ConfigDict( 17 | arbitrary_types_allowed=True, 18 | populate_by_name=True, 19 | validate_assignment=True, 20 | validate_default=True, 21 | validate_return=True, 22 | extra="forbid", 23 | ) 24 | 25 | @classmethod 26 | def deserialize(cls, d: dict | str | Self) -> Self: 27 | return cls.model_validate(cls._deserialize(d)) 28 | 29 | def serialize_as_json(self, version: Version | None = None) -> dict[str, Any]: 30 | cls = type(self) 31 | version = cls.truncate_version(version) 32 | 33 | exclude: set[str] = set() 34 | if isinstance((available_since := cls.__available_since__), dict) and isinstance(version, Version): 35 | exclude = exclude.union( 36 | field 37 | for field in self.model_fields 38 | if isinstance((min_version := available_since.get(field)), Version) and version < min_version 39 | ) 40 | d = self.model_dump(mode="json", exclude=exclude) 41 | if isinstance((cls_version := getattr(self, "__version__", None)), Version): 42 | d.update(version=str(version or cls_version)) 43 | return d 44 | 45 | @model_validator(mode="before") 46 | @classmethod 47 | def preprocess_options_dict(cls, values: Any) -> dict[str, Any]: 48 | """Preprocess all fields of type `OptionsDict` before passing them to Pydantic's type validator. 49 | 50 | Args: 51 | values (Any): a dictionary (or other raw value) containing values for initializing an instance of this class. 52 | 53 | Returns: 54 | dict[str, Any]: the dictionary in which every value for a field annotated with a subclass of 55 | `OptionsDict` has been deserialized. 56 | """ 57 | for name, field in cls.model_fields.items(): 58 | if not (isinstance(values, dict) and name in values and OptionsDict.issubclass(field.annotation)): 59 | continue 60 | values[name] = field.annotation.deserialize(values[name]) 61 | return values 62 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/connect_inplace_ops_to_users.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | from torch.fx import GraphModule 4 | from torch.fx.node import Node 5 | from torch.fx.passes.infra.pass_base import PassBase, PassResult 6 | 7 | 8 | class ConnectInplaceOpsToUsers(PassBase): 9 | """Connect `call_function(operator.setitem)` nodes to their user nodes.""" 10 | 11 | def call(self, graph_module: GraphModule) -> PassResult: 12 | """Connect `call_function(operator.setitem)` nodes to their user nodes.
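For example, in a graph traced from code like `x[i] = y; z = x + 1`, the use of `x` in `x + 1` is redirected to the `setitem` node so that the in-place update is visible to downstream users.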
13 | 14 | Args: 15 | graph_module (GraphModule): the input graph module 16 | 17 | Returns: 18 | PassResult: the result of the pass 19 | """ 20 | nodes = list(graph_module.graph.nodes) 21 | 22 | def replace_users_with_index_larger_than(*, target: Node, replacement: Node, index_lower_bound: int) -> bool: 23 | def is_index_greater_than_lower_bound(user: Node) -> bool: 24 | try: 25 | return nodes.index(user) > index_lower_bound 26 | except ValueError:  # `list.index` raises ValueError when `user` is not in `nodes` 27 | return False 28 | 29 | replaced_uses = target.replace_all_uses_with(replacement, delete_user_cb=is_index_greater_than_lower_bound) 30 | return len(replaced_uses) > 0 31 | 32 | modified = False 33 | for inplace_node_index, inplace_node in enumerate(nodes): 34 | if not ( 35 | is_inplace(inplace_node) 36 | and isinstance( 37 | ( 38 | parent_node := inplace_node.args[0] 39 | if len(inplace_node.args) > 0 40 | else inplace_node.kwargs.get("a") 41 | ), 42 | Node, 43 | ) 44 | ): 45 | continue 46 | 47 | is_any_user_replaced = replace_users_with_index_larger_than( 48 | target=parent_node, replacement=inplace_node, index_lower_bound=inplace_node_index 49 | ) 50 | modified = modified or is_any_user_replaced 51 | return PassResult(graph_module, modified) 52 | 53 | 54 | def is_inplace(node: Node) -> bool: 55 | """Check if the given node is calling an inplace function or method. 56 | 57 | Args: 58 | node (Node): a node 59 | 60 | Returns: 61 | bool: `True` if it is calling an inplace function or method, `False` otherwise 62 | """ 63 | if node.op == "call_function" and node.target in (operator.setitem,): 64 | return True 65 | if node.op == "call_method" and isinstance(node.target, str) and node.target.endswith("_"): 66 | return True 67 | return False 68 | -------------------------------------------------------------------------------- /src/owlite/api/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import requests 4 | from tqdm import tqdm 5 | from tqdm.utils import CallbackIOWrapper 6 | 7 | from ..core.logger import log 8 | 9 | 10 | def upload_file_to_url(file_path: str, dst_url: str) -> None: 11 | """Upload file to destination URL via http request. 12 | 13 | Args: 14 | file_path (str): path to file 15 | dst_url (str): url to upload 16 | 17 | Raises: 18 | FileNotFoundError: when the file does not exist at the given path 19 | 20 | HTTPError: when the request was not successful 21 | """ 22 | log.info(f"Uploading {file_path}") # UX 23 | 24 | if not os.path.exists(file_path): 25 | log.error(f"Cannot upload {file_path} as it is not found") # UX 26 | raise FileNotFoundError("File not found") 27 | 28 | total = os.path.getsize(file_path) 29 | with open(file_path, "rb") as file: 30 | with tqdm( 31 | total=total, 32 | unit="iB", 33 | unit_scale=True, 34 | unit_divisor=1024, 35 | ) as progress_bar: 36 | reader_wrapper = CallbackIOWrapper(progress_bar.update, file, "read") 37 | # pylint: disable-next=missing-timeout 38 | resp = requests.put(dst_url, data=reader_wrapper) 39 | if not resp.ok: 40 | resp.raise_for_status() 41 | 42 | log.info("Uploading done") # UX 43 | 44 | 45 | def download_file_from_url(file_url: str, path_to_save: str) -> None: 46 | """Download file from URL via http request. 47 | 48 | Note that this function will overwrite a file with the downloaded file content if a file already exists at the given path.
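For example (hypothetical URL and path), `download_file_from_url("https://example.com/model.onnx", "./model.onnx")` streams the file with a progress bar, overwriting `./model.onnx` if it already exists.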
49 | 50 | Args: 51 | file_url: URL of a file to download 52 | 53 | path_to_save: path to save downloaded file 54 | 55 | Raises: 56 | HTTPError: when the request was not successful 57 | """ 58 | log.info(f"Downloading file to {path_to_save}") # UX 59 | 60 | if os.path.exists(path_to_save): 61 | log.warning(f"The existing file at {path_to_save} will be overwritten") # UX 62 | 63 | resp = requests.get(file_url, stream=True) # pylint: disable=missing-timeout 64 | if not resp.ok:  # fail before overwriting the destination file with an error payload 65 | resp.raise_for_status() 66 | 67 | total = int(resp.headers.get("content-length", 0)) 68 | with ( 69 | open(path_to_save, "wb") as file, 70 | tqdm( 71 | total=total, 72 | unit="iB", 73 | unit_scale=True, 74 | unit_divisor=1024, 75 | ) as progress_bar, 76 | ): 77 | for data in resp.iter_content(chunk_size=1024): 78 | size = file.write(data) 79 | progress_bar.update(size) 80 | 81 | log.info("Downloading done") # UX 82 | -------------------------------------------------------------------------------- /src/owlite/core/cli/commands/user_commands.py: -------------------------------------------------------------------------------- 1 | """Handles user-related commands in OwLite CLI.""" 2 | 3 | # pylint: disable=unnecessary-lambda, too-few-public-methods 4 | from argparse import Namespace, _SubParsersAction 5 | 6 | from ...logger import log 7 | from ...settings import OWLITE_SETTINGS 8 | from .. import BaseOwLiteCLICommand 9 | from ..api.login import whoami 10 | from ..login import login, logout 11 | 12 | 13 | class UserCommands(BaseOwLiteCLICommand): 14 | """Handles user-related commands in OwLite CLI.""" 15 | 16 | @staticmethod 17 | def register_subcommand(parser: _SubParsersAction) -> None: 18 | """Register subcommands for user-related operations. 19 | 20 | Args: 21 | parser (_SubParsersAction): The parser object to add subcommands to. 22 | """ 23 | login_parser = parser.add_parser("login", help="Login to OwLite") 24 | login_parser.set_defaults(func=lambda args: LoginCommand(args)) 25 | whoami_parser = parser.add_parser("whoami", help="Display the current user's username") 26 | whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args)) 27 | logout_parser = parser.add_parser("logout", help="Log out") 28 | logout_parser.set_defaults(func=lambda args: LogoutCommand(args)) 29 | 30 | 31 | class BaseUserCommand: 32 | """Base class for user-related commands.""" 33 | 34 | def __init__(self, args: Namespace) -> None: 35 | """Initialize the BaseUserCommand. 36 | 37 | Args: 38 | args: Arguments passed to the command.
39 | """ 40 | self.args = args 41 | 42 | 43 | class LoginCommand(BaseUserCommand): 44 | """Handle the 'login' command.""" 45 | 46 | def run(self) -> None: 47 | """Execute the login operation.""" 48 | login() 49 | if OWLITE_SETTINGS.current_workspace is not None: 50 | log.info(f"Your Current Workspace: {OWLITE_SETTINGS.current_workspace.name}") # UX 51 | log.info("The OwLite Package operates within the selected workspace.") # UX 52 | 53 | 54 | class WhoamiCommand(BaseUserCommand): 55 | """Handle the 'whoami' command.""" 56 | 57 | def run(self) -> None: 58 | """Execute the whoami operation and print the username.""" 59 | userinfo = whoami() 60 | log.info(userinfo.name) # UX 61 | if OWLITE_SETTINGS.current_workspace is not None: 62 | log.info(f"Your Current Workspace: {OWLITE_SETTINGS.current_workspace.name}") # UX 63 | log.info("The OwLite Package operates within the selected workspace.") # UX 64 | 65 | 66 | class LogoutCommand(BaseUserCommand): 67 | """Handle the 'logout' command.""" 68 | 69 | def run(self) -> None: 70 | """Execute the logout operation.""" 71 | logout() 72 | -------------------------------------------------------------------------------- /src/owlite/nn/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: D205 2 | """Quantization Aware Training (QAT) is a technique that allows the model to learn the quantization error during the 3 | training process. QAT aims to minimize the loss of accuracy during the quantization process, thus making the model 4 | smaller and faster while maintaining as much of its accuracy as possible. OwLite makes QAT easier, requiring only 5 | minimal changes to existing training code. 6 | 7 | Please review the subdocuments for technical details. 8 | 9 | ## Usage 10 | 11 | To use QAT with OwLite, you can follow your standard training procedure, keeping in mind two aspects: 12 | 13 | * QAT is a process that needs to be performed after the 14 | [convert](https://squeezebits.gitbook.io/owlite/python-api/owlite.owlite.owlite/owlite.owlite.convert) stage, where 15 | you have applied the compression configuration in experiment mode using OwLite. 16 | * If the optimizer for training was set before calling the convert method, you should set the optimizer again with 17 | the new parameters of the converted model 18 | 19 | Please note that the model converted by OwLite has a fixed batch size. Therefore, you need to set `drop_last=True` 20 | when creating your [torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) 21 | object. 22 | 23 | For example: 24 | 25 | ```python 26 | DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 27 | batch_sampler=None, num_workers=0, collate_fn=None, 28 | pin_memory=False, drop_last=True, timeout=0, 29 | worker_init_fn=None, 30 | persistent_workers=False) 31 | ``` 32 | 33 | This ensures that the DataLoader will discard the last remaining batch if the dataset size is not divisible 34 | by the batch size. 35 | 36 | ## Tips for Better Results 37 | 38 | If you are getting unsatisfactory results from your training, consider adjusting the learning rate or the weight decay. 39 | Lowering the learning rate can help the model converge more smoothly while reducing the weight decay can help prevent 40 | the model from over-fitting. 41 | 42 | * **Adjust the Learning Rate**: If the training loss fluctuates, consider reducing the learning rate to stabilize 43 | the training of the compressed model.
In this way, the model learns more effectively, leading to better performance. 44 | 45 | * **Reduce Weight Decay**: Similarly, if the learning process is fluctuating, consider reducing the weight decay 46 | to stabilize the training of the compressed model. In this way, the model generalizes better for unseen data. 47 | """ 48 | 49 | from .clq import clq_function 50 | from .fake_fp_quantize import fake_fp8_quantize 51 | from .fake_quantize import FakeQuantizeSignature, fake_quantize 52 | from .ste import fake_quantize_ste_function, scaled_round_ste 53 | from .ste_fp import fake_fp_quantize_ste_function 54 | -------------------------------------------------------------------------------- /src/owlite/nn/modules/qmodule_mixins.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | from ...core.logger import log 6 | from .fake_quantizer import FakeQuantizer 7 | 8 | 9 | class UnaryNeuralQModuleMixin(ABC): 10 | """Mixin-class for fake-quantized counterparts of subclasses of `torch.nn.Module`. 11 | 12 | This mixin assumes that the base class has parameters named `weight` and `bias`, 13 | and that its `forward` method takes exactly one parameter other than `self`. 14 | Examples: `torch.nn.Conv1d`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, `torch.nn.Linear`. 15 | """ 16 | 17 | weight: torch.nn.Parameter 18 | bias: torch.nn.Parameter | None 19 | input_quantizer: FakeQuantizer | None 20 | weight_quantizer: FakeQuantizer | None 21 | 22 | @abstractmethod 23 | def _set_bias_to_zero(self) -> None: 24 | pass 25 | 26 | def clip_weight(self) -> None: 27 | """Clip the weights to the narrow range. 28 | 29 | If the weight quantizer exists and narrow range is True, clip the weight values to fit the narrow range. 30 | Otherwise, do nothing. 31 | 32 | Raises: 33 | RuntimeError: if called during tracing (`torch.jit.trace`).
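Note: with an 8-bit narrow-range symmetric weight quantizer (`quant_min = -127`, `quant_max = 127`), each weight is clipped to `[-127 * s, 127 * s]`, where `s` is the (possibly per-channel) step size.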
34 | """ 35 | if torch.jit.is_tracing(): 36 | log.error("Trying to clip the weight range of a module during tracing (torch.jit.trace)") 37 | log.error(self) 38 | raise RuntimeError("Trying to clip the weight range of a module during tracing (torch.jit.trace)") 39 | if self.weight_quantizer is None: 40 | return 41 | if not self.weight_quantizer.narrow_range: 42 | log.debug("Trying to clip the weight range of a module whose weight quantizer does not use narrow range.") 43 | log.debug(self.weight_quantizer) 44 | return 45 | # convert all step_size to be positive 46 | self.weight_quantizer.step_size.data = self.weight_quantizer.step_size.data.abs() 47 | 48 | clip_min_values = (self.weight_quantizer.quant_min) * self.weight_quantizer.step_size.data 49 | clip_max_values = (self.weight_quantizer.quant_max) * self.weight_quantizer.step_size.data 50 | shape = [-1, *[1 for _ in range(self.weight.dim() - 1)]] 51 | self.weight.data = self.weight.data.clip(clip_min_values.reshape(shape), clip_max_values.reshape(shape)) 52 | return 53 | 54 | def enable(self) -> None: 55 | """Enable the input and weight quantizers.""" 56 | if self.input_quantizer is not None: 57 | self.input_quantizer.enable() 58 | if self.weight_quantizer is not None: 59 | self.weight_quantizer.enable() 60 | 61 | def disable(self) -> None: 62 | """Disable the input and weight quantizers.""" 63 | if self.input_quantizer is not None: 64 | self.input_quantizer.disable() 65 | if self.weight_quantizer is not None: 66 | self.weight_quantizer.disable() 67 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/fix_hard_coded_devices.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | import torch 4 | from torch.fx import GraphModule 5 | from torch.fx.passes.infra.pass_base import PassBase, PassResult 6 | 7 | from ....core.logger import log 8 | 9 | 10 | class FixHardCodedDevice(PassBase): 11 | """Fix hard-coded devices to enable data parallel.""" 12 | 13 | def call(self, graph_module: GraphModule) -> PassResult: 14 | """Fix hard-coded devices to enable data parallel.
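For example, a traced call like `x.to("cuda:0")` or a `device="cuda:0"` keyword argument is rewritten to use the device of a canary buffer registered on the module, so that each `DataParallel` replica resolves its own device.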
15 | 16 | Args: 17 | graph_module (GraphModule): the input graph module 18 | 19 | Returns: 20 | PassResult: the result of the pass 21 | """ 22 | if hasattr(graph_module, "sqzb_module_device_canary"): 23 | return PassResult(graph_module, False) 24 | 25 | canary_tensor = torch.tensor(12.03, dtype=torch.float32, device=get_most_common_device(graph_module)) 26 | graph_module.register_buffer("sqzb_module_device_canary", canary_tensor) 27 | 28 | graph = graph_module.graph 29 | with graph.inserting_before(next(iter(graph_module.graph.nodes))): 30 | canary = graph.get_attr("sqzb_module_device_canary") 31 | canary_device = graph.call_function(getattr, (canary, "device")) 32 | graph_module.meta["canary_device_node"] = canary_device 33 | # hack to prevent this node from being cleaned up 34 | output_node = graph.find_nodes(op="output") 35 | assert len(output_node) == 1 36 | output_node = output_node[0] 37 | output_node.args = output_node.args + (canary_device,) 38 | 39 | for node in graph.nodes: 40 | if node.kwargs.get("device", None) is not None: 41 | kwargs = node.kwargs.copy() 42 | kwargs["device"] = canary_device 43 | node.kwargs = kwargs 44 | 45 | if ( 46 | node.op == "call_method" 47 | and node.target == "to" 48 | and len(node.args) == 2 49 | and isinstance(node.args[1], str) 50 | ): 51 | args = (node.args[0], canary_device) 52 | node.args = args 53 | 54 | return PassResult(graph_module, True) 55 | 56 | 57 | def get_most_common_device(model: torch.nn.Module) -> torch.device: 58 | """Find the most common device where the parameters of the model reside. 59 | 60 | Args: 61 | model (torch.nn.Module): a model 62 | 63 | Returns: 64 | torch.device: the most common device where the parameters of the model reside. 65 | """ 66 | counter = Counter(p.device for p in model.parameters()) 67 | if len(counter) == 0: 68 | return torch.device("cpu") 69 | if len(counter) > 1: 70 | log.warning(f"The model parameters reside on more than one device: {set(counter.elements())}") 71 | return counter.most_common(1)[0][0] 72 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/utils.py: -------------------------------------------------------------------------------- 1 | import builtins 2 | 3 | import torch.nn.functional as F 4 | from torch.fx.node import Node 5 | 6 | 7 | def call_canonical_mask( 8 | mask: Node | None, 9 | mask_name: str, 10 | other: Node | None, 11 | other_name: str, 12 | target: Node, 13 | check_other: bool = False, 14 | ) -> Node | None: 15 | """Call `F._canonical_mask` with tensors instead of types. 16 | 17 | At the code level, this is equivalent to 18 | ```python 19 | F._canonical_mask( 20 | mask=mask, 21 | mask_name=mask_name, 22 | other_type=other.dtype, 23 | other_name=other_name, 24 | target_type=target.dtype, 25 | check_other=check_other, 26 | ) 27 | ``` 28 | 29 | Args: 30 | mask (Node | None): a node generating a mask tensor 31 | mask_name (str): the `mask_name` parameter for the `F._canonical_mask` 32 | other (Node | None): a reference node for providing `other_type` parameter for the `F._canonical_mask` 33 | other_name (str): the `other_name` parameter for the `F._canonical_mask` 34 | target (Node): a reference node for providing `target_type` parameter for the `F._canonical_mask` 35 | check_other (bool, optional): the `check_other` parameter for the `F._canonical_mask`. Defaults to False.
36 | 37 | Returns: 38 | Node | None: the output node of the function call `F._canonical_mask` 39 | """ 40 | graph = target.graph 41 | if mask is None: 42 | return None 43 | other_type: Node | None = None 44 | if other is not None: 45 | other_type = graph.call_function(builtins.getattr, (other, "dtype")) 46 | target_type: Node | None = None 47 | if target is not None: 48 | target_type = graph.call_function(builtins.getattr, (target, "dtype")) 49 | return graph.call_function( 50 | # pylint: disable-next=protected-access 51 | F._canonical_mask, 52 | kwargs={ 53 | "mask": mask, 54 | "mask_name": mask_name, 55 | "other_type": other_type, 56 | "other_name": other_name, 57 | "target_type": target_type, 58 | "check_other": check_other, 59 | }, 60 | ) 61 | 62 | 63 | def inline_get_seq_len(src: Node, batch_first: bool) -> Node: 64 | """Inline the function call `torch.nn.modules.transformer._get_seq_len`. 65 | 66 | Args: 67 | src (Node): a node corresponding to the argument `src: Tensor` of `_get_seq_len` 68 | batch_first (bool): the argument `batch_first: bool` for `_get_seq_len` 69 | 70 | Returns: 71 | Node: the output node of the inlined function call `_get_seq_len` 72 | """ 73 | graph = src.graph 74 | # if not is_batched: src.shape == (S, E) 75 | # elif batch_first: src.shape == (N, S, E) 76 | # else: src.shape == (S, N, E) 77 | return graph.call_method("size", (src, 0 if batch_first else -2)) 78 | -------------------------------------------------------------------------------- /src/owlite/options/options_dict.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Generic, TypeVar, get_args 2 | 3 | from packaging.version import Version 4 | from pydantic import TypeAdapter 5 | from typing_extensions import Self 6 | 7 | from .serializable import Serializable 8 | 9 | K = TypeVar("K", bound=str) 10 | V = TypeVar("V") 11 | 12 | 13 | class OptionsDict(dict[K, V], Generic[K, V], Serializable): 14 | """A simple extension of python `dict` to hold Options as values.""" 15 | 16 | def __init__(self, d: dict[str, Any] | str | None = None): 17 | if self.key_type() is K or self.value_type() is V: # type: ignore[misc] 18 | raise TypeError("You must specify key and value types of OptionsDict.") 19 | # Required for checking if the value type is valid 20 | super(dict, self).__init__() 21 | if d is None: 22 | return 23 | for k, v in type(self).deserialize(d).items(): 24 | self[k] = v 25 | 26 | def update(self, d: dict[Any, Any]) -> None: # type: ignore[override] 27 | for key, value in d.items(): 28 | self[key] = value 29 | 30 | @classmethod 31 | def key_type(cls) -> type: 32 | """Get the key type of this dictionary.""" 33 | # pylint: disable-next=no-member 34 | return get_args(cls.__orig_bases__[0])[0] # type: ignore[attr-defined] 35 | 36 | @classmethod 37 | def value_type(cls) -> type: 38 | """Get the value type of this dictionary.""" 39 | # pylint: disable-next=no-member 40 | return get_args(cls.__orig_bases__[0])[1] # type: ignore[attr-defined] 41 | 42 | @classmethod 43 | def deserialize(cls, d: dict[str, Any] | str | Self) -> Self: 44 | d = cls._deserialize(d) 45 | 46 | options_dict = cls() 47 | for name, data in d.items(): 48 | options_dict[name] = data # type: ignore[index] 49 | 50 | return options_dict 51 | 52 | def serialize_as_json(self, version: Version | None = None) -> dict[str, Any]: 53 | ret = { 54 | name: ( 55 | options.serialize_as_json(version) 56 | if isinstance(options, Serializable) 57 | else 
TypeAdapter(type(self).value_type()).dump_python(options) # type: ignore[arg-type] 58 | ) 59 | for name, options in self.items() 60 | } 61 | if isinstance((cls_version := getattr(self, "__version__", None)), Version): 62 | ret.update(version=str(self.truncate_version(version) or cls_version)) 63 | return ret # type: ignore[return-value] 64 | 65 | def __setitem__(self, key: K, value: V) -> None: 66 | key = TypeAdapter(type(self).key_type()).validate_python(key) 67 | value_type = type(self).value_type() 68 | if Serializable.issubclass(value_type): 69 | value = value_type.deserialize(value) # type: ignore[assignment, arg-type] 70 | else: 71 | value = TypeAdapter(value_type).validate_python(value) 72 | return super().__setitem__(key, value) 73 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/trace.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=protected-access, too-many-statements 2 | import inspect 3 | from typing import Any 4 | 5 | import torch 6 | from torch.fx.graph_module import GraphModule 7 | from torch.nn.parallel import DataParallel, DistributedDataParallel 8 | 9 | from ...core.logger import log 10 | from ...enums import ModelStatus 11 | from ..signature import Signature 12 | from .graph_checker import validate_procedure_calls 13 | from .optimize import optimize 14 | from .passes import PassName 15 | 16 | 17 | # pylint: disable-next=too-many-locals 18 | def symbolic_trace( 19 | model: torch.nn.Module, *args: Any, **kwargs: Any 20 | ) -> GraphModule: 21 | """Symbolically trace the input `model` to convert it into a GraphModule. 22 | 23 | In order for the tracing to be successful, the `model` must be able to pass `torch.compile(model, fullgraph=True)`. 24 | 25 | Args: 26 | model (torch.nn.Module): a torch.nn.Module instance. 27 | *args: the example input(s) that would be passed to the model's forward method. 28 | **kwargs: the example input(s) that would be passed to the model's forward method. 29 | 30 | Raises: 31 | TypeError: if the `model` is not an instance of `torch.nn.Module` or is wrapped in `DataParallel` or `DistributedDataParallel`. 32 | RuntimeError: if the tracing fails. 33 | 34 | Returns: 35 | GraphModule: the converted GraphModule. 36 | """ 37 | # If `owlite.fx.symbolic_trace` is called on more than one model, the compilation caches created from one model 38 | # might cause an unexpected error for another. Furthermore, since OwLite currently doesn't support graph breaks 39 | # within a model, there's no need to keep the compilation caches. 40 | # Hence we always clear all caches before running the compilation. 41 | torch._dynamo.reset() 42 | torch.compiler.reset() 43 | 44 | given_type = type(model) 45 | if isinstance(model, DataParallel | DistributedDataParallel): 46 | log.error( 47 | f"{given_type} is not supported by symbolic trace; please use the 'module' attribute to unwrap the model " 48 | f"from {given_type}.
Try owlite.fx.symbolic_trace(model.module, ...)" 49 | ) 50 | raise TypeError(f"{given_type} is not supported by symbolic trace") 51 | 52 | if not isinstance(model, torch.nn.Module): 53 | raise TypeError(f"Expected torch.nn.Module instance but object of type {given_type} given: {model}") 54 | 55 | training_status = model.training 56 | 57 | original_signature = inspect.signature(model.forward) 58 | exporter = torch._dynamo.export(model, aten_graph=False, pre_dispatch=False, tracing_mode="real") 59 | graph_module = exporter(*args, **kwargs).graph_module 60 | 61 | graph_module.train(training_status) 62 | graph_module.meta["status"] = ModelStatus.TRACED 63 | graph_module_input_signature = Signature.from_module(graph_module, args, kwargs) 64 | graph_module_input_signature.warn_signature_change(dict(original_signature.parameters.items())) 65 | graph_module.meta["input_signature"] = graph_module_input_signature 66 | validate_procedure_calls(graph_module) 67 | 68 | _ = optimize(graph_module) 69 | 70 | return graph_module 71 | -------------------------------------------------------------------------------- /src/owlite/nn/functions/clq.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-argument 2 | from typing import Any 3 | 4 | import torch 5 | from torch import Tensor 6 | 7 | from .fake_quantize import BaseFakeINTQuantizeFunction, fake_quantize 8 | 9 | 10 | # mypy: disable-error-code=override 11 | # pylint: disable-next=abstract-method 12 | class CLQFunction(BaseFakeINTQuantizeFunction): 13 | r"""An implementation of QAT function using CLQ (Constrained Learned Quantization). 14 | 15 | In the **CLQ (Constrained Learned Quantization)** method, instead of using a fixed set of quantization levels, 16 | the scales are adapted during training to minimize the impact on model performance. The learnable step_size 17 | allows the model to be better adapted to the distribution of the input data. 18 | ### Gradient of step\_size.
19 | 20 | When $$x$$ is the input of FakeQuantize and $$s$$ is the step\_size of FakeQuantize, 21 | 22 | $$ 23 | \dfrac{\partial \hat{x}}{\partial s}= \begin{cases} \left( -\dfrac{x}{|s|}+\left\lceil{\dfrac{x}{|s|}} 24 | \right\rfloor \right) \cdot \text{sign}(s) & \text{if, } \text{quant\_min} < \dfrac{x}{|s|} < \text{quant\_max} 25 | \\ \\ \text{quant\_min} \cdot \text{sign}(s) &\text{if, }\dfrac{x}{|s|}\leq \text{quant\_min} \\ 26 | \\ \text{quant\_max}\cdot \text{sign}(s) &\text{if, } \dfrac{x}{|s|}\geq \text{quant\_max} \end{cases} 27 | $$ 28 | """ 29 | 30 | @staticmethod # pylint: disable-next=arguments-differ, too-many-positional-arguments 31 | def forward( 32 | ctx: Any, 33 | inputs: Tensor, 34 | step_size: Tensor, 35 | zero_point: Tensor, 36 | grad_scale: float, 37 | quant_min: int, 38 | quant_max: int, 39 | axis: int | None, 40 | ) -> Tensor: 41 | ctx.save_for_backward(inputs, step_size) 42 | ctx.other = grad_scale, quant_min, quant_max, axis 43 | return fake_quantize(inputs, step_size.abs(), zero_point, quant_min=quant_min, quant_max=quant_max, axis=axis) 44 | 45 | @staticmethod 46 | def backward(ctx: Any, *grad_outputs: Any) -> Any: 47 | inputs, step_size = ctx.saved_tensors 48 | grad_output = grad_outputs[0] 49 | step_size_abs = step_size.abs().reshape([-1] + ([1] * (inputs.dim() - 1))) 50 | grad_scale, quant_min, quant_max, axis = ctx.other 51 | affine_input = (inputs / step_size_abs).clip(quant_min, quant_max) 52 | between = affine_input.gt(quant_min) & affine_input.lt(quant_max) 53 | grad_step_size = ( 54 | torch.where(between, (torch.round(affine_input) - affine_input), affine_input) * grad_output * grad_scale 55 | ) 56 | grad_step_size = ( 57 | grad_step_size.sum(dim=tuple(range(1, inputs.dim())), keepdim=False) 58 | if axis is not None  # `axis is not None` marks per-channel quantization; note that `axis` may be 0 59 | else grad_step_size.sum().unsqueeze(dim=0) 60 | ) 61 | grad_step_size = grad_step_size * torch.where(step_size == 0.0, 1.0, step_size.sign()) 62 | grad_output = grad_output * between 63 | return grad_output, grad_step_size, None, None, None, None, None, None 64 | 65 | 66 | clq_function = CLQFunction.apply 67 | -------------------------------------------------------------------------------- /src/owlite/calib/percentile_calibrator.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING 2 | 3 | import torch 4 | from torch.utils.hooks import RemovableHandle 5 | 6 | from .histogram_calibrator import HistogramCalibrator 7 | 8 | if TYPE_CHECKING: 9 | from ..nn import FakeQuantizer 10 | 11 | 12 | class PercentileCalibrator(HistogramCalibrator): 13 | """Percentile Calibrator Class. 14 | 15 | This calibrator also utilizes the data's histogram. However, instead of minimizing an error metric, it employs a 16 | heuristic approach based on a pre-specified percentile. The value corresponding to the chosen percentile is set as 17 | the **maximum absolute value**, and the `step_size` is calculated accordingly. By tuning the percentile, the user 18 | can control the trade-off between quantization accuracy and outlier removal. 19 | 20 | Attributes: 21 | quantizer (`FakeQuantizer`): The `FakeQuantizer` module to be calibrated. 22 | percentile (`float`): The desired percentile value, ranging from 0 to 100. 23 | 24 | """ 25 | 26 | def __init__(self, quantizer: "FakeQuantizer", percentile: float): 27 | """Initialize the percentile calibrator. 28 | 29 | Args: 30 | quantizer (FakeQuantizer): The `FakeQuantizer` module to be calibrated.
31 | percentile (float): The desired percentile value, ranging from 0 to 100. 32 | 33 | Raises: 34 | ValueError: If the percentile is outside the valid range [0, 100]. 35 | """ 36 | super().__init__(quantizer) 37 | if percentile < 0 or percentile > 100: 38 | raise ValueError("percentile must be in range [0,100]") 39 | self.percentile = percentile 40 | 41 | def update(self) -> None: 42 | """Update step_size using `percentile`.""" 43 | super().update() 44 | assert isinstance(self.hook_handler, RemovableHandle) 45 | 46 | # cumsum_cuda_kernel does not have a deterministic implementation 47 | _deterministic_enable_status = torch.are_deterministic_algorithms_enabled() 48 | torch.use_deterministic_algorithms(False, warn_only=True) 49 | 50 | max_values = torch.empty_like(self.quantizer.step_size) 51 | min_values = torch.empty_like(self.quantizer.step_size) 52 | for chn, _ in enumerate(self.histograms): 53 | total = self.histograms[chn].data.sum() 54 | cdf = torch.cumsum(self.histograms[chn].data / total, 0) 55 | if self.quantizer.symmetric: 56 | idx = torch.searchsorted(cdf, self.percentile / 100) 57 | per_max = self.bin_edges[chn].data[idx] 58 | max_values[chn] = per_max 59 | else: 60 | min_idx = torch.searchsorted(cdf, (1 - self.percentile / 100) * 0.5) 61 | max_idx = torch.searchsorted(cdf, (1 + self.percentile / 100) * 0.5) 62 | max_values[chn] = self.bin_edges[chn][max_idx].data 63 | min_values[chn] = self.bin_edges[chn][min_idx].data 64 | 65 | self.update_fake_quantizer_param_with_max_min(max_values, min_values) 66 | 67 | # restore the deterministic algorithm setting to its original state 68 | torch.use_deterministic_algorithms(_deterministic_enable_status, warn_only=True) 69 | 70 | self.clear() 71 | -------------------------------------------------------------------------------- /src/owlite/core/cli/login.py: -------------------------------------------------------------------------------- 1 | """Module for OwLite Authentication. 2 | 3 | Includes functions for handling OwLite user authentication. 4 | """ 5 | 6 | import re 7 | from getpass import getpass 8 | 9 | from ..cache.tokens import Tokens 10 | from ..cache.workspace import Workspace 11 | from ..device_settings import OWLITE_DEVICE_SETTINGS 12 | from ..logger import log 13 | from ..settings import OWLITE_SETTINGS 14 | from .api.login import login as _login 15 | from .api.login import whoami 16 | 17 | 18 | def login() -> None: 19 | """Login to OwLite. 20 | 21 | Raises: 22 | HTTPError: When login request was not successful. 23 | """ 24 | 25 | def _is_valid_email(email: str) -> bool: 26 | """Check if the email is valid. 27 | 28 | Args: 29 | email (str): An email to check. 30 | 31 | Returns: 32 | bool: True if the given email is valid, False otherwise. 33 | """ 34 | regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b" 35 | if not re.fullmatch(regex, email): 36 | log.error("Invalid email provided") # UX 37 | return False 38 | return True 39 | 40 | def _is_valid_password(password: str) -> bool: 41 | r"""Check if the password is valid. 42 | 43 | Args: 44 | password (str): A password to check. 45 | 46 | Returns: 47 | bool: True if the given password is valid, False otherwise. 48 | """ 49 | allowed_specials = r"!@#$%^&*()_+\-=\[\]{}|~₩" 50 | regex = r"^(?=.*[!@#$%^&*()_+\-=\[\]{}|~₩])[A-Za-z0-9!@#$%^&*()_+\-=\[\]{}|~₩]{8,}$" 51 | if not re.match(regex, password): 52 | log.error( 53 | "The password does not meet the requirements. A valid password must contain at least eight characters, " 54 | "including one or more alphabetic, numeric, and special characters.
" 55 | f"Special characters must be chosen from {allowed_specials}" 56 | ) # UX 57 | return False 58 | return True 59 | 60 | logout(verbose=False) 61 | email = input("Enter your email: ") # UX 62 | if not _is_valid_email(email): 63 | return 64 | password = getpass("Enter your password: ") # UX 65 | if not _is_valid_password(password): 66 | return 67 | 68 | resp = _login(email, password) 69 | tokens = Tokens(**resp) 70 | OWLITE_SETTINGS.tokens = tokens 71 | 72 | userinfo = whoami() 73 | log.info(f"Logged in as {userinfo.name}") # UX 74 | 75 | OWLITE_SETTINGS.current_workspace = Workspace.load(userinfo.default_workspace_id) 76 | 77 | log.info(f"Your authentication token is saved at {OWLITE_SETTINGS.tokens_cache}") # UX 78 | log.debug(f"Saved tokens: \n\t\taccess token= '{tokens.access_token}'\n\t\trefresh token= '{tokens.refresh_token}'") 79 | 80 | 81 | def logout(verbose: bool = True) -> None: 82 | """Logout from OwLite, tokens are deleted from the machine. 83 | 84 | Args: 85 | verbose (bool): Whether to print log messages, defaults to True. 86 | """ 87 | OWLITE_SETTINGS.tokens = None 88 | OWLITE_DEVICE_SETTINGS.connected = None 89 | OWLITE_SETTINGS.current_workspace = None 90 | if verbose: 91 | log.info("Successfully logged out") # UX 92 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/decompose_transformer.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=duplicate-code 2 | import torch 3 | from torch.fx.node import Node 4 | 5 | from ..node import get_target_module 6 | from .node_argument import NodeArgument 7 | from .rewrite_pass import RewritePass 8 | 9 | 10 | class TransformerNodeArgument(NodeArgument): 11 | """The arguments of a "call_module" node with target module of type `torch.nn.Transformer`.""" 12 | 13 | src: Node 14 | tgt: Node 15 | src_mask: Node | None = None 16 | tgt_mask: Node | None = None 17 | memory_mask: Node | None = None 18 | src_key_padding_mask: Node | None = None 19 | tgt_key_padding_mask: Node | None = None 20 | memory_key_padding_mask: Node | None = None 21 | src_is_causal: bool | None = None 22 | tgt_is_causal: bool | None = None 23 | memory_is_causal: bool = False 24 | 25 | @classmethod 26 | def validate_node(cls, node: Node) -> bool: 27 | return node.op == "call_module" and isinstance(get_target_module(node), torch.nn.Transformer) 28 | 29 | 30 | class DecomposeTransformer(RewritePass): 31 | """Decompose all occurrences of `torch.nn.Transformer` by an equivalent subgraph. 32 | 33 | Note: this rewrite pass is implemented based on torch>=2.3.1,<=2.4.0 34 | """ 35 | 36 | @classmethod 37 | def rewrite(cls, node: Node) -> dict[Node, Node]: 38 | if (arguments := TransformerNodeArgument.extract_from(node)) is None: 39 | return {} 40 | 41 | if arguments.src_is_causal is None: 42 | raise NotImplementedError( 43 | "Found a `torch.nn.Transformer` layer forwarded with `src_is_causal=None`. " 44 | "OwLite cannot handle dynamic control flow triggered by `src_is_causal` detection. " 45 | "Please set its value to either `True` or `False`." 46 | ) # UX 47 | 48 | if arguments.tgt_is_causal is None: 49 | raise NotImplementedError( 50 | "Found a `torch.nn.Transformer` layer forwarded with `tgt_is_causal=None`. " 51 | "OwLite cannot handle dynamic control flow triggered by `tgt_is_causal` detection. " 52 | "Please set its value to either `True` or `False`." 
53 |             )  # UX
54 | 
55 |         graph = node.graph
56 |         with graph.inserting_before(node):
57 |             memory = graph.call_module(
58 |                 f"{node.target}.encoder",
59 |                 args=(arguments.src,),
60 |                 kwargs={
61 |                     "mask": arguments.src_mask,
62 |                     "src_key_padding_mask": arguments.src_key_padding_mask,
63 |                     "is_causal": arguments.src_is_causal,
64 |                 },
65 |             )
66 |             output = graph.call_module(
67 |                 f"{node.target}.decoder",
68 |                 args=(arguments.tgt, memory),
69 |                 kwargs={
70 |                     "tgt_mask": arguments.tgt_mask,
71 |                     "memory_mask": arguments.memory_mask,
72 |                     "tgt_key_padding_mask": arguments.tgt_key_padding_mask,
73 |                     "memory_key_padding_mask": arguments.memory_key_padding_mask,
74 |                     "tgt_is_causal": arguments.tgt_is_causal,
75 |                     "memory_is_causal": arguments.memory_is_causal,
76 |                 },
77 |             )
78 |         return {node: output}
79 | 
--------------------------------------------------------------------------------
/src/owlite/nn/modules/qlinear.py:
--------------------------------------------------------------------------------
 1 | # pylint: disable=not-callable
 2 | import torch
 3 | import torch.nn.functional as F
 4 | from torch import Tensor
 5 | 
 6 | from ...options import Channel, FakeQuantizerOptions
 7 | from .fake_quantizer import FakeQuantizer
 8 | from .qmodule_mixins import UnaryNeuralQModuleMixin
 9 | 
10 | 
11 | # mypy: disable-error-code=misc
12 | class QLinear(torch.nn.Linear, UnaryNeuralQModuleMixin):
13 |     r"""Applies a linear transformation with fake-quantized weight $$ A_q $$ to the incoming data: $$ y = xA_q^T + b $$.
14 | 
15 |     Additionally, fake-quantization is applicable to both the bias and the bias addition:
16 |     $$y = \text{quant}(xA_q^T) + \text{quant}(b)$$, where $$\text{quant}$$ represents the fake-quantize function.
17 |     The module copies the weights and biases from the original linear instance.
18 | 
19 |     Quantized linear layer inherited from
20 |     [torch.nn.Linear](https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/linear.py).
21 |     """
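
    # Illustrative usage (a hedged sketch, not from the original file; assumes a valid
    # `FakeQuantizerOptions` instance named `opts` constructed elsewhere):
    #
    #     fp32 = torch.nn.Linear(128, 64)
    #     qlinear = QLinear(fp32, weight_opts=opts)   # copies weight and bias from `fp32`
    #     y = qlinear(torch.randn(8, 128))            # forward pass with fake-quantized weight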
34 | """ 35 | super().__init__( 36 | linear.in_features, 37 | linear.out_features, 38 | linear.bias is not None, 39 | linear.weight.device, 40 | linear.weight.dtype, 41 | ) 42 | self.train(linear.training) 43 | self.input_quantizer: FakeQuantizer | None = None 44 | with torch.no_grad(): 45 | self.weight.copy_(linear.weight) 46 | if self.bias is not None: 47 | self.bias.copy_(linear.bias) 48 | channel = ( 49 | Channel(axis=0, size=self.out_features) if (weight_opts is not None and weight_opts.per_channel) else None 50 | ) 51 | self.weight_quantizer = FakeQuantizer.create(weight_opts, channel, narrow_range=True) 52 | if self.weight_quantizer is not None: 53 | self.weight_quantizer.to(self.weight.device) 54 | self.bias_quantizer: FakeQuantizer | None = None 55 | self.hidden_input_quantizer: FakeQuantizer | None = None 56 | 57 | def _set_bias_to_zero(self) -> None: 58 | self.bias = torch.nn.Parameter(torch.zeros(self.out_features).to(self.weight.device)) 59 | 60 | # pylint: disable=arguments-renamed, invalid-name 61 | def forward(self, inputs: Tensor) -> Tensor: 62 | """Forward with quantized weight if available.""" 63 | weight = self.weight_quantizer(self.weight) if self.weight_quantizer is not None else self.weight 64 | bias = ( 65 | self.bias_quantizer(self.bias) if self.bias_quantizer is not None and self.bias is not None else self.bias 66 | ) 67 | 68 | if self.hidden_input_quantizer is not None and bias is not None: 69 | x = F.linear(inputs, weight, None) 70 | x = self.hidden_input_quantizer(x) 71 | return torch.add(bias, x) 72 | return F.linear(inputs, weight, bias) 73 | -------------------------------------------------------------------------------- /src/owlite/nn/modules/qconvbn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn.modules.batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d, _BatchNorm 4 | 5 | from .fake_quantizer import FakeQuantizer 6 | from .qconv import QConv1d, QConv2d, QConv3d, _QConvNd 7 | 8 | 9 | class _QConvBnNd(torch.nn.Module): 10 | """Base class of quantized covolution with followed batchnorm layer.""" 11 | 12 | def __init__(self, qconv: _QConvNd, bn: _BatchNorm): 13 | super().__init__() 14 | self.qconv: _QConvNd = qconv 15 | self.bn: _BatchNorm | None = bn 16 | 17 | # pylint: disable=protected-access 18 | def forward(self, inputs: Tensor) -> Tensor: 19 | """Forward with quantized convolution with batch normalization.""" 20 | if self.bn is None: 21 | return self.qconv(inputs) 22 | conv_output = self.qconv._conv_forward(inputs, self.qconv._get_weight(), self.qconv.bias) 23 | bn_output = self.bn(conv_output) 24 | bn_output = self._folding_forward(bn_output) 25 | return bn_output 26 | 27 | # pylint: enable=protected-access 28 | def _folding_forward(self, inputs: Tensor) -> Tensor: 29 | if ( 30 | self.bn is None 31 | or not self.qconv.int32_bias 32 | or not FakeQuantizer.check_if_enabled(self.qconv.weight_quantizer) 33 | or not FakeQuantizer.check_if_enabled(self.qconv.input_quantizer) 34 | ): 35 | return inputs 36 | assert self.bn.running_mean is not None 37 | assert self.bn.running_var is not None 38 | # assume that inputs shape is (n,cin,*) and weight shape is (cout,cin,*kernel_shape) 39 | weight_shape = [1] * inputs.dim() 40 | weight_shape[1] = -1 41 | channel_dims = list(range(inputs.dim())) 42 | channel_dims.pop(1) 43 | qconv_bias = self.qconv.bias if self.qconv.bias is not None else torch.zeros_like(self.bn.bias) 44 | mean = inputs.mean(dim=channel_dims) if 
self.bn.training else self.bn.running_mean 45 | var = inputs.std(dim=channel_dims) ** 2 if self.bn.training else self.bn.running_var 46 | alpha = torch.rsqrt(var + self.bn.eps) * self.bn.weight 47 | beta = self.bn.bias - self.bn.weight * mean * torch.rsqrt(var + self.bn.eps) 48 | fused_scale = alpha * self.qconv.weight_quantizer.step_size * self.qconv.input_quantizer.step_size 49 | fused_bias = (alpha * qconv_bias + beta) / fused_scale 50 | output = inputs - ((fused_bias - fused_bias.round()) * (fused_scale)).detach().reshape(weight_shape) 51 | return output 52 | 53 | 54 | class QConvBn1d(_QConvBnNd): 55 | """This module sequentially calls the `QConv1d` and `BatchNorm1d` modules if they are available.""" 56 | 57 | def __init__(self, qconv: QConv1d, bn: BatchNorm1d): 58 | super().__init__(qconv, bn) 59 | self.qconv: QConv1d 60 | self.bn: BatchNorm1d | None 61 | 62 | 63 | class QConvBn2d(_QConvBnNd): 64 | """This module sequentially calls the `QConv2d` and `BatchNorm2d` modules if they are available.""" 65 | 66 | def __init__(self, qconv: QConv2d, bn: BatchNorm2d): 67 | super().__init__(qconv, bn) 68 | self.qconv: QConv2d 69 | self.bn: BatchNorm2d | None 70 | 71 | 72 | class QConvBn3d(_QConvBnNd): 73 | """This module sequentially calls the `QConv3d` and `BatchNorm3d` modules if they are available.""" 74 | 75 | def __init__(self, qconv: QConv3d, bn: BatchNorm3d): 76 | super().__init__(qconv, bn) 77 | self.qconv: QConv3d 78 | self.bn: BatchNorm3d | None 79 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/optimize.py: -------------------------------------------------------------------------------- 1 | from torch.fx import GraphModule 2 | from torch.fx.passes.infra.pass_base import PassResult 3 | from torch.fx.passes.infra.pass_manager import PassManager 4 | 5 | from ..config import FX_TRANSFORM_MAXIMUM_ITERATION 6 | from .passes import ( 7 | ConnectInplaceOpsToUsers, 8 | DecomposeExpm1, 9 | # DecomposeGELU, 10 | DecomposeInProjection, 11 | DecomposeInProjectionPacked, 12 | DecomposeMultiheadAttention, 13 | DecomposeMultiHeadAttentionForward, 14 | DecomposeScaledDotProductAttention, 15 | DecomposeSiLU, 16 | DecomposeTransformer, 17 | DecomposeTransformerDecoder, 18 | DecomposeTransformerDecoderLayer, 19 | DecomposeTransformerEncoder, 20 | DecomposeTransformerEncoderLayer, 21 | EliminateDummyOutput, 22 | EliminateExplicitGetitem, 23 | EliminateIdentity, 24 | EliminateNopGetitem, 25 | FixHardCodedDevice, 26 | FuseConsecutiveConcats, 27 | PassName, 28 | RewriteLayerNormsFunctional, 29 | ) 30 | 31 | 32 | def optimize( 33 | graph_module: GraphModule, 34 | skipped_optimizers: list[PassName] | None = None, 35 | ) -> PassResult: 36 | """Optimize the given graph module inplace. 37 | 38 | Args: 39 | graph_module (GraphModule): a graph module 40 | skipped_optimizers (list[PassName] | None, optional): the names of optimization passes to skip. 41 | Defaults to None. 42 | 43 | Returns: 44 | PassResult: the result of the transform 45 | """ 46 | result = get_pass_manager(skipped_optimizers)(graph_module) 47 | if result.modified: 48 | graph_module.graph.eliminate_dead_code() 49 | graph_module.graph.lint() 50 | graph_module.recompile() 51 | return result 52 | 53 | 54 | def get_pass_manager(skipped_optimizers: list[PassName] | None = None) -> PassManager: 55 | """Get pass manager. 56 | 57 | Args: 58 | skipped_optimizers (list[PassName] | None, optional): the names of optimization passes to skip. 59 | Defaults to None. 
60 | 
61 |     Returns:
62 |         PassManager: a pass manager
63 |     """
64 |     pass_manager = PassManager(steps=FX_TRANSFORM_MAXIMUM_ITERATION)
65 | 
66 |     functionality_fixes = (
67 |         ConnectInplaceOpsToUsers,
68 |         EliminateIdentity,
69 |         EliminateExplicitGetitem,
70 |         FixHardCodedDevice,
71 |     )
72 |     transformer_rewrite_passes = (
73 |         DecomposeTransformer,
74 |         DecomposeTransformerEncoder,
75 |         DecomposeTransformerEncoderLayer,
76 |         DecomposeTransformerDecoder,
77 |         DecomposeTransformerDecoderLayer,
78 |     )
79 |     mha_rewrite_passes = (
80 |         DecomposeMultiheadAttention,
81 |         DecomposeMultiHeadAttentionForward,
82 |         DecomposeInProjectionPacked,
83 |         DecomposeInProjection,
84 |         DecomposeScaledDotProductAttention,
85 |     )
86 |     other_rewrite_passes = (
87 |         DecomposeExpm1,
88 |         DecomposeSiLU,
89 |         EliminateNopGetitem,
90 |         FuseConsecutiveConcats,
91 |         RewriteLayerNormsFunctional,
92 |         EliminateDummyOutput,
93 |     )
94 | 
95 |     for fx_pass in (
96 |         *functionality_fixes,
97 |         *transformer_rewrite_passes,
98 |         *mha_rewrite_passes,
99 |         *other_rewrite_passes,
100 |     ):
101 |         if skipped_optimizers is not None and fx_pass.__name__ in skipped_optimizers:
102 |             continue
103 |         pass_manager.add_pass(fx_pass())
104 | 
105 |     return pass_manager
106 | 
--------------------------------------------------------------------------------
/src/owlite/backend/fx/passes/decompose_transformer_decoder.py:
--------------------------------------------------------------------------------
 1 | from functools import cached_property
 2 | 
 3 | import torch
 4 | from torch.fx.node import Node
 5 | 
 6 | from ..node import get_target_module
 7 | from .node_argument import NodeArgument
 8 | from .rewrite_pass import RewritePass
 9 | 
10 | 
11 | class TransformerDecoderNodeArgument(NodeArgument):
12 |     """The arguments of a "call_module" node with target module of type `torch.nn.TransformerDecoder`."""
13 | 
14 |     tgt: Node
15 |     memory: Node
16 |     tgt_mask: Node | None = None
17 |     memory_mask: Node | None = None
18 |     tgt_key_padding_mask: Node | None = None
19 |     memory_key_padding_mask: Node | None = None
20 |     tgt_is_causal: bool | None = None
21 |     memory_is_causal: bool = False
22 | 
23 |     @classmethod
24 |     def validate_node(cls, node: Node) -> bool:
25 |         return node.op == "call_module" and isinstance(get_target_module(node), torch.nn.TransformerDecoder)
26 | 
27 |     @cached_property
28 |     def module(self) -> torch.nn.TransformerDecoder:
29 |         """The `torch.nn.TransformerDecoder` layer called by the node."""
30 |         assert isinstance((m := get_target_module(self.node)), torch.nn.TransformerDecoder)
31 |         return m
32 | 
33 | 
34 | class DecomposeTransformerDecoder(RewritePass):
35 |     """Decompose all occurrences of `torch.nn.TransformerDecoder` by an equivalent subgraph.
36 | 
37 |     Note: this rewrite pass is implemented based on torch>=2.3.1,<=2.4.0
38 |     """
39 | 
40 |     @classmethod
41 |     def rewrite(cls, node: Node) -> dict[Node, Node]:
42 |         if (arguments := TransformerDecoderNodeArgument.extract_from(node)) is None:
43 |             return {}
44 | 
45 |         if arguments.tgt_is_causal is None:
46 |             raise NotImplementedError(
47 |                 "Found a `torch.nn.TransformerDecoder` layer forwarded with `tgt_is_causal=None`. "
48 |                 "OwLite cannot handle dynamic control flow triggered by `tgt_is_causal` detection. "
49 |                 "Please set its value to either `True` or `False`."
50 |             )  # UX
51 | 
52 |         self_layers = arguments.module.layers
53 |         # if not isinstance((first_layer := self_layers[0]), torch.nn.TransformerDecoderLayer):
54 |         #     return {}
55 | 
56 |         graph = node.graph
57 |         with graph.inserting_before(node):
58 |             # Note: `tgt_is_causal` must not be a Node ...
59 |             # seq_len = inline_get_seq_len(arguments.tgt, first_layer.self_attn.batch_first)
60 |             # tgt_is_causal = graph.call_function(
61 |             #     torch.nn.modules.transformer._detect_is_causal_mask,
62 |             #     (arguments.tgt_mask, arguments.tgt_is_causal, seq_len)
63 |             # )
64 | 
65 |             output = arguments.tgt
66 |             for i in range(len(self_layers)):
67 |                 output = graph.call_module(
68 |                     f"{node.target}.layers.{i}",
69 |                     args=(output, arguments.memory),
70 |                     kwargs={
71 |                         "tgt_mask": arguments.tgt_mask,
72 |                         "memory_mask": arguments.memory_mask,
73 |                         "tgt_key_padding_mask": arguments.tgt_key_padding_mask,
74 |                         "memory_key_padding_mask": arguments.memory_key_padding_mask,
75 |                         "tgt_is_causal": arguments.tgt_is_causal,
76 |                         "memory_is_causal": arguments.memory_is_causal,
77 |                     },
78 |                 )
79 | 
80 |             if arguments.module.norm is not None:
81 |                 output = graph.call_module(f"{node.target}.norm", (output,))
82 | 
83 |         return {node: output}
84 | 
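An illustrative aside (not part of the repository): the pass above inlines the same
computation that `torch.nn.TransformerDecoder` performs eagerly. A minimal sketch of
that equivalence, with arbitrary hyperparameters:

```python
import torch

layer = torch.nn.TransformerDecoderLayer(d_model=16, nhead=2, batch_first=True)
decoder = torch.nn.TransformerDecoder(layer, num_layers=2).eval()  # eval() disables dropout
tgt, memory = torch.randn(1, 4, 16), torch.randn(1, 6, 16)

# What DecomposeTransformerDecoder encodes in FX: one call per layer, then the norm.
out = tgt
for mod in decoder.layers:
    out = mod(out, memory, tgt_is_causal=False, memory_is_causal=False)
if decoder.norm is not None:
    out = decoder.norm(out)

torch.testing.assert_close(out, decoder(tgt, memory))  # same result, decomposed
```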
--------------------------------------------------------------------------------
/src/owlite/backend/fx/passes/node_argument.py:
--------------------------------------------------------------------------------
  1 | from abc import abstractmethod
  2 | from typing import TypeVar, overload
  3 | 
  4 | from pydantic import BaseModel, Field, ValidationError
  5 | from pydantic_core import PydanticUndefined
  6 | from torch.fx.node import Argument, Node
  7 | from typing_extensions import Self
  8 | 
  9 | from ....core.logger import log
 10 | from ...utils import nodestr
 11 | 
 12 | 
 13 | class NodeArgument(BaseModel):
 14 |     """Abstract base class for defining a specific node's arguments."""
 15 | 
 16 |     model_config = {
 17 |         "arbitrary_types_allowed": True,
 18 |         "validate_assignment": True,
 19 |         "validate_default": True,
 20 |     }
 21 |     node: Node = Field(exclude=True)
 22 | 
 23 |     @classmethod
 24 |     @abstractmethod
 25 |     def validate_node(cls, node: Node) -> bool:
 26 |         """Validate if the node is suitable for argument extraction.
 27 | 
 28 |         Args:
 29 |             node (Node): a node
 30 | 
 31 |         Returns:
 32 |             bool: `True` if it is suitable, `False` otherwise.
 33 |         """
 34 |         raise NotImplementedError(f"{cls.__name__}.validate_node is not implemented")
 35 | 
 36 |     @classmethod
 37 |     def extract_from(cls, node: Node) -> Self | None:
 38 |         """Extract arguments from the node.
 39 | 
 40 |         Args:
 41 |             node (Node): a node
 42 | 
 43 |         Returns:
 44 |             Self | None: the extracted arguments if succeeded, `None` otherwise.
 45 |         """
 46 |         if not cls.validate_node(node):
 47 |             return None
 48 |         try:
 49 |             arguments = {
 50 |                 name: get_argument(
 51 |                     node,
 52 |                     index - 1,
 53 |                     name,
 54 |                     default=None if field.default is PydanticUndefined else field.default,
 55 |                 )
 56 |                 for index, (name, field) in enumerate(cls.model_fields.items())
 57 |                 if index > 0  # should skip the `node`
 58 |             }
 59 |             return cls.model_validate({"node": node, **arguments})
 60 |         except ValidationError as e:
 61 |             log.warning(f"Incorrect arguments given to the node {nodestr(node)}: {arguments}. ({e})")
 62 |             return None
 63 | 
 64 | 
 65 | DefaultValue = TypeVar("DefaultValue")
 66 | 
 67 | 
 68 | @overload
 69 | def get_argument(
 70 |     node: Node,
 71 |     index_as_arg: int,
 72 |     name_as_kwarg: str,
 73 |     default: DefaultValue,
 74 | ) -> DefaultValue: ...
 75 | 
 76 | 
 77 | @overload
 78 | def get_argument(
 79 |     node: Node,
 80 |     index_as_arg: int,
 81 |     name_as_kwarg: str,
 82 | ) -> Argument: ...
 83 | 
 84 | 
 85 | def get_argument(
 86 |     node: Node,
 87 |     index_as_arg: int,
 88 |     name_as_kwarg: str,
 89 |     default: DefaultValue | None = None,
 90 | ) -> DefaultValue | Argument:
 91 |     """Get the node argument of the given node.
 92 | 
 93 |     Args:
 94 |         node (Node): a node
 95 |         index_as_arg (int): the index to look up when the node argument is given as a positional argument
 96 |         name_as_kwarg (str): the key to look up when the node argument is given as a keyword argument
 97 |         default (DefaultValue | None, optional): the default value when the node argument is not explicitly specified.
 98 |             Defaults to None.
 99 | 
100 |     Returns:
101 |         DefaultValue | Argument: the node argument if found or its default value.
102 |     """
103 |     return (
104 |         node.kwargs[name_as_kwarg]
105 |         if name_as_kwarg in node.kwargs
106 |         else (node.args[index_as_arg] if len(node.args) > index_as_arg else default)
107 |     )
108 | 
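An illustrative aside (not part of the repository): how `get_argument` above resolves
positional, keyword, and missing arguments, sketched with a small traced function:

```python
import torch
from torch.fx import symbolic_trace

gm = symbolic_trace(lambda x, w, b: torch.nn.functional.linear(x, w, bias=b))
node = next(n for n in gm.graph.nodes if n.op == "call_function")

get_argument(node, 0, "input")       # the `x` placeholder node (found positionally)
get_argument(node, 2, "bias")        # the `b` placeholder node (found as a keyword)
get_argument(node, 3, "scale", 1.0)  # 1.0 (absent, so the default is returned)
```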
--------------------------------------------------------------------------------
/src/owlite/backend/fx/passes/decompose_transformer_encoder.py:
--------------------------------------------------------------------------------
 1 | from functools import cached_property
 2 | 
 3 | import torch
 4 | from torch.fx.node import Node
 5 | 
 6 | from ..node import get_target_module
 7 | from .node_argument import NodeArgument
 8 | from .rewrite_pass import RewritePass
 9 | from .utils import call_canonical_mask
10 | 
11 | 
12 | class TransformerEncoderNodeArgument(NodeArgument):
13 |     """The arguments of a "call_module" node with target module of type `torch.nn.TransformerEncoder`."""
14 | 
15 |     src: Node
16 |     mask: Node | None = None
17 |     src_key_padding_mask: Node | None = None
18 |     is_causal: bool | None = None
19 | 
20 |     @classmethod
21 |     def validate_node(cls, node: Node) -> bool:
22 |         return node.op == "call_module" and isinstance(get_target_module(node), torch.nn.TransformerEncoder)
23 | 
24 |     @cached_property
25 |     def module(self) -> torch.nn.TransformerEncoder:
26 |         """The `torch.nn.TransformerEncoder` layer called by the node."""
27 |         assert isinstance((m := get_target_module(self.node)), torch.nn.TransformerEncoder)
28 |         return m
29 | 
30 | 
31 | class DecomposeTransformerEncoder(RewritePass):
32 |     """Decompose all occurrences of `torch.nn.TransformerEncoder` by an equivalent subgraph.
33 | 
34 |     Note: this rewrite pass is implemented based on torch>=2.3.1,<=2.4.0
35 |     """
36 | 
37 |     @classmethod
38 |     def rewrite(cls, node: Node) -> dict[Node, Node]:
39 |         if (arguments := TransformerEncoderNodeArgument.extract_from(node)) is None:
40 |             return {}
41 | 
42 |         if arguments.is_causal is None:
43 |             raise NotImplementedError(
44 |                 "Found a `torch.nn.TransformerEncoder` layer forwarded with `is_causal=None`. "
45 |                 "OwLite cannot handle dynamic control flow triggered by `is_causal` detection. "
46 |                 "Please set its value to either `True` or `False`."
47 | ) # UX 48 | 49 | graph = node.graph 50 | with graph.inserting_before(node): 51 | src_key_padding_mask = call_canonical_mask( 52 | mask=arguments.src_key_padding_mask, 53 | mask_name="src_key_padding_mask", 54 | other=arguments.mask, 55 | other_name="mask", 56 | target=arguments.src, 57 | ) 58 | 59 | mask = call_canonical_mask( 60 | mask=arguments.mask, 61 | mask_name="mask", 62 | other=None, 63 | other_name="", 64 | target=arguments.src, 65 | check_other=False, 66 | ) 67 | 68 | self_layers = arguments.module.layers 69 | # first_layer = self_layers[0] 70 | # batch_first = first_layer.self_attn.batch_first 71 | 72 | # Note: `is_causal` must not be a node ... 73 | # seq_len = inline_get_seq_len(arguments.src, batch_first) 74 | # is_causal = graph.call_function( 75 | # torch.nn.modules.transformer._detect_is_causal_mask, 76 | # (mask, arguments.is_causal, seq_len), 77 | # ) 78 | 79 | output = arguments.src 80 | for i in range(len(self_layers)): 81 | output = graph.call_module( 82 | f"{node.target}.layers.{i}", 83 | args=(output,), 84 | kwargs={ 85 | "src_mask": mask, 86 | "is_causal": arguments.is_causal, 87 | "src_key_padding_mask": src_key_padding_mask, 88 | }, 89 | ) 90 | 91 | if arguments.module.norm is not None: 92 | output = graph.call_module(f"{node.target}.norm", (output,)) 93 | 94 | return {node: output} 95 | -------------------------------------------------------------------------------- /src/owlite/core/cli/api/login.py: -------------------------------------------------------------------------------- 1 | """API wrapper module for login.""" 2 | 3 | import requests 4 | from pydantic import AliasChoices, BaseModel, ConfigDict, Field 5 | 6 | from ...api_base import APIBase 7 | from ...cache.workspace import Workspace 8 | from ...constants import OWLITE_API_DEFAULT_TIMEOUT 9 | from ...exceptions import LoginError 10 | from ...logger import log 11 | from ...settings import OWLITE_SETTINGS 12 | 13 | 14 | class UserInfo(BaseModel): 15 | """User Information.""" 16 | 17 | model_config = ConfigDict(extra="ignore") 18 | name: str = Field(validation_alias=AliasChoices("name", "username")) 19 | default_workspace_id: str = Field(validation_alias=AliasChoices("workspace", "default_workspace_id")) 20 | priority_queues_count: int = Field( 21 | validation_alias=AliasChoices("priority_queues_count", "monthly_benchmark_count") 22 | ) 23 | 24 | 25 | def login(email: str, password: str) -> dict[str, str]: 26 | """Attempt login with given email and password, returns dict of tokens if login was successful. 27 | 28 | Args: 29 | email (str): Email. 30 | password (str): Password. 31 | 32 | Raises: 33 | HTTPError: When login was not successful. 34 | 35 | Returns: 36 | dict[str, str]: A dictionary containing access token and refresh token. 37 | """ 38 | main_url = OWLITE_SETTINGS.base_url.MAIN 39 | front_url = OWLITE_SETTINGS.base_url.FRONT 40 | payload = {"username": email, "password": password} 41 | 42 | response = requests.post(f"{main_url}/login", data=payload, timeout=OWLITE_API_DEFAULT_TIMEOUT) 43 | resp = response.json() 44 | 45 | if not response.ok: 46 | if response.status_code == 401: 47 | login_failed_dict = { 48 | "User not found": ( 49 | "The email is not registered. Please check if your email is correct " 50 | f"or sign up at {front_url}/auth/login" 51 | ), 52 | "Incorrect password": "Incorrect password provided. 
Please check if your password is correct",
53 |         }
54 |         if resp and resp["detail"] in login_failed_dict:
55 |             log.error(login_failed_dict[resp["detail"]])  # UX
56 |             raise LoginError("Login failed")
57 | 
58 |         response.raise_for_status()
59 | 
60 |     assert isinstance(resp, dict)
61 |     return resp
62 | 
63 | 
64 | def whoami() -> UserInfo:
65 |     """Get the current user's information using the access token in the OwLite cache.
66 | 
67 |     Raises:
68 |         LoginError: When no saved login token is found.
69 |         HTTPError: When the request was not successful.
70 | 
71 |     Returns:
72 |         UserInfo: Information of the current user.
73 |     """
74 |     if OWLITE_SETTINGS.tokens is None:
75 |         log.error("Please log in using 'owlite login'. Account not found on this device")  # UX
76 |         raise LoginError("OwLite token not found")
77 | 
78 |     main_api = APIBase(OWLITE_SETTINGS.base_url.MAIN, "OWLITE_LOGIN_API")
79 |     try:
80 |         resp = main_api.post("/login/whoami")
81 |     except requests.exceptions.HTTPError as e:
82 |         if e.response is not None and e.response.status_code == 403:
83 |             raise LoginError("Not authenticated") from e
84 |         raise e
85 |     assert isinstance(resp, dict)
86 |     log.debug(f"whoami response: {resp}")
87 |     user_info = UserInfo.model_validate(resp)
88 |     log.debug(f"user info : {user_info}")
89 | 
90 |     if (workspace := OWLITE_SETTINGS.current_workspace) is None:
91 |         workspace_id = user_info.default_workspace_id
92 |         log.warning("No workspace selected. Automatically selecting the default one")  # UX
93 |     else:
94 |         workspace_id = workspace.id
95 |     OWLITE_SETTINGS.current_workspace = Workspace.load(workspace_id)
96 | 
97 |     return user_info
98 | 
--------------------------------------------------------------------------------
/src/owlite/nn/functions/ste.py:
--------------------------------------------------------------------------------
 1 | # pylint: disable=duplicate-code, unused-argument
 2 | from typing import Any
 3 | 
 4 | import torch
 5 | from torch import Tensor
 6 | from torch.autograd import Function
 7 | 
 8 | from ...core.logger import log
 9 | from .fake_quantize import BaseFakeINTQuantizeFunction, fake_quantize
10 | 
11 | 
12 | # mypy: disable-error-code=override
13 | # pylint: disable-next=abstract-method
14 | class ScaledRoundSTE(Function):
15 |     r"""A round function that uses the STE backward.
16 | 
17 |     The input is divided by the scale, rounded to the nearest integer, and multiplied by the scale again.
18 |     No gradient is propagated through the scale.
19 | 
20 |     $$
21 |     \text{output} = \lfloor \text{input} / \text{scale} \rceil * \text{scale}
22 |     $$
23 | 
24 |     """
25 | 
26 |     @staticmethod  # pylint: disable-next=arguments-differ
27 |     def forward(ctx: Any, inputs: Tensor, scale: Tensor | float = 1.0) -> Any:
28 |         rounded_input = (inputs / scale).round()
29 |         if (rounded_input > torch.iinfo(torch.int32).max).any() or (rounded_input < torch.iinfo(torch.int32).min).any():
30 |             rounded_input = torch.clamp(rounded_input, torch.iinfo(torch.int32).min, torch.iinfo(torch.int32).max)
31 |             log.debug_warning("Rounded input is out of bounds")
32 |         return rounded_input * scale
33 | 
34 |     @staticmethod
35 |     def backward(ctx: Any, *grad_outputs: Any) -> Any:
36 |         return grad_outputs[0], None
37 | 
38 | 
39 | # pylint: disable-next=abstract-method
40 | class FakeQuantizeSTEFunction(BaseFakeINTQuantizeFunction):
41 |     r"""Fake quantizing function for QAT using STE (Straight-Through Estimator).
42 | 
43 |     For $$ quant\_min $$ <= `input` <= $$ quant\_max $$, the gradient passes straight through;
44 |     otherwise the gradient is zero.
45 | 
46 |     In the **STE (Straight-Through Estimation)** method, the gradient of the round
47 |     function used in fake quantization is approximated as 1, and
48 |     backpropagation is performed based on this approximation. As a result,
49 |     the gradient of the input entering the fake quantizer is propagated as is
50 |     when it falls between $$ quant\_min $$ and $$ quant\_max $$, while gradients outside
51 |     this range become 0. However, since the gradient propagated to $$ step\_size $$
52 |     is 0, $$ step\_size $$ is fixed.
53 | 
54 |     When $$x$$ is the input of FakeQuantize:
55 | 
56 |     $$
57 |     \hat{x} = \text{FakeQuantize}(x)
58 |     $$
59 | 
60 | 
61 |     ![STE image](https://github.com/SqueezeBits/owlite/assets/116608095/2d0e071b-394c-4cd1-a68e-33b9a6e18ae6)
62 |     """
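
    # Illustrative note (a hedged sketch, not from the original file): with
    # step_size = 1.0, quant_min = -128 and quant_max = 127, an input element
    # x = 3.7 lies inside [quant_min * step_size, quant_max * step_size], so
    # backward() passes its incoming gradient through unchanged, while an
    # element x = 500.0 lies outside and has its gradient zeroed.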
63 | 
64 |     @staticmethod  # pylint: disable-next=arguments-differ, too-many-positional-arguments
65 |     def forward(
66 |         ctx: Any,
67 |         inputs: Tensor,
68 |         step_size: Tensor,
69 |         zero_point: Tensor,
70 |         grad_scale: float,  # grad_scale is not used
71 |         quant_min: int,
72 |         quant_max: int,
73 |         axis: int | None,
74 |     ) -> Tensor:
75 |         ctx.save_for_backward(inputs)
76 |         lower_bound = quant_min * step_size
77 |         upper_bound = quant_max * step_size
78 |         ctx.other = lower_bound, upper_bound
79 |         return fake_quantize(inputs, step_size, zero_point, quant_min=quant_min, quant_max=quant_max, axis=axis)
80 | 
81 |     @staticmethod
82 |     def backward(ctx: Any, *grad_outputs: Any) -> Any:
83 |         inputs = ctx.saved_tensors[0]
84 |         grad_output = grad_outputs[0]
85 |         lower_bound, upper_bound = ctx.other
86 |         lower_bound = lower_bound.reshape([-1] + ([1] * (inputs.dim() - 1)))
87 |         upper_bound = upper_bound.reshape([-1] + ([1] * (inputs.dim() - 1)))
88 |         grad_inputs = torch.where(inputs.ge(lower_bound) * inputs.le(upper_bound), grad_output, 0)
89 |         return grad_inputs, None, None, None, None, None, None, None
90 | 
91 | 
92 | fake_quantize_ste_function = FakeQuantizeSTEFunction.apply
93 | scaled_round_ste = ScaledRoundSTE.apply
94 | 
--------------------------------------------------------------------------------
/src/owlite/nn/modules/granularity_mixin.py:
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | 
 3 | import torch
 4 | 
 5 | from ...options.channel import Channel
 6 | 
 7 | 
 8 | class GranularityMixin(ABC):
 9 |     """Abstract base class for mixin classes that represent granularity of FakeQuantization.
10 | 
11 |     This class provides a common interface for per-channel and per-tensor granularity.
12 |     """
13 | 
14 |     @abstractmethod
15 |     def init_quantization_param(self, channel: Channel | None, zero_point_dtype: torch.dtype = torch.int32) -> None:
16 |         """Initialize the quantization parameters with the specified granularity.
17 | 
18 |         Args:
19 |             channel (Channel | None): The channel to initialize. If `None`, use per-tensor quantization.
20 |                 Otherwise, use per-channel quantization.
21 |             zero_point_dtype (torch.dtype): The data type of the zero point. Defaults to `torch.int32`.
22 |         """
23 | 
24 |     @property
25 |     @abstractmethod
26 |     def channel(self) -> Channel | None:
27 |         """Get the channel associated with the granularity.
28 | 
29 |         Returns:
30 |             Channel | None: The channel associated with the granularity, or `None` if per-tensor quantization is used.
31 |         """
32 | 
33 |     @property
34 |     def per_channel(self) -> bool:
35 |         """Check if per-channel quantization is used.
36 | 37 | Equivalent to `self.channel is not None`. 38 | 39 | Returns: 40 | bool: `True` if per-channel quantization is used, `False` otherwise. 41 | """ 42 | return self.channel is not None 43 | 44 | 45 | class PerChannelMixin(GranularityMixin): 46 | """Mixin class for per-channel granularity of FakeQuantization. 47 | 48 | This class provides the implementation for per-channel granularity. 49 | """ 50 | 51 | def init_quantization_param(self, channel: Channel | None, zero_point_dtype: torch.dtype = torch.int32) -> None: 52 | assert isinstance(self, torch.nn.Module) 53 | assert channel is not None 54 | self.register_buffer("_channel_axis", torch.tensor([channel.axis], dtype=torch.int32)) 55 | self.register_buffer("_channel_size", torch.tensor([channel.size], dtype=torch.int32)) 56 | self.step_size = torch.nn.Parameter(torch.ones(channel.size)) 57 | self.zero_point = torch.nn.Parameter( 58 | torch.zeros( 59 | channel.size, 60 | dtype=zero_point_dtype, 61 | ), 62 | requires_grad=False, 63 | ) 64 | 65 | @abstractmethod 66 | def as_per_tensor(self) -> "PerTensorMixin": 67 | """Convert the per-channel granularity to per-tensor granularity.""" 68 | 69 | @property 70 | def channel(self) -> Channel: 71 | return Channel(axis=self._channel_axis.item(), size=self._channel_size.item()) # type: ignore 72 | 73 | 74 | class PerTensorMixin(GranularityMixin): 75 | """Mixin class for per-tensor granularity of FakeQuantization. 76 | 77 | This class provides the implementation for per-tensor granularity, where all channels share 78 | the same quantization parameters. 79 | """ 80 | 81 | def init_quantization_param( 82 | self, channel: Channel | None = None, zero_point_dtype: torch.dtype = torch.int32 83 | ) -> None: 84 | assert channel is None 85 | self.step_size = torch.nn.Parameter(torch.ones(1)) 86 | self.zero_point = torch.nn.Parameter( 87 | torch.zeros( 88 | 1, 89 | dtype=zero_point_dtype, 90 | ), 91 | requires_grad=False, 92 | ) 93 | 94 | @abstractmethod 95 | def as_per_channel(self, channel: Channel) -> "PerChannelMixin": 96 | """Create a new fake per-channel quantizer with the same option (except for the `per_channel` value). 97 | 98 | The `step_size` and `zero_point` of the new fake per-channel quantizer is initialized with shape 99 | `(channel.size,)` filled with values in `self.step_size` and `self.zero_point`, respectively. 100 | """ 101 | 102 | @property 103 | def channel(self) -> None: 104 | return None 105 | -------------------------------------------------------------------------------- /src/owlite/core/device_settings.py: -------------------------------------------------------------------------------- 1 | from functools import cached_property 2 | from typing import get_args 3 | 4 | from ..enums.runtime import Runtime 5 | from .api_base import NEST_API_BASE 6 | from .cache import OWLITE_CACHE_PATH 7 | from .cache.device import Device 8 | from .cache.text import read_text, write_text 9 | from .constants import SUPPORTED_QUALCOMM_DEVICES 10 | from .exceptions import DeviceError, LoginError 11 | from .logger import log 12 | from .settings import OWLITE_SETTINGS 13 | 14 | 15 | class OwLiteDeviceSettings: 16 | """Handles device settings and cache management for OwLite. 17 | 18 | OwLiteDeviceSettings manages device settings and cache information for OwLite. 19 | It provides methods to retrieve, add, and remove device managers, as well as 20 | to set and retrieve connected devices. 21 | 22 | Attributes: 23 | connected_cache (Path): Path to store information about the connected device. 
24 | """ 25 | 26 | def __init__(self) -> None: 27 | """Initialize OwLite device settings. 28 | 29 | Initialize paths for OwLite cache directory to store information about 30 | the connected device. 31 | """ 32 | self.connected_cache = OWLITE_CACHE_PATH / "connected" 33 | 34 | @property 35 | def connected(self) -> Device | None: 36 | """Retrieve the connected device. 37 | 38 | Returns: 39 | Device | None: An instance representing the connected device, or None if no device is selected 40 | """ 41 | connected = read_text(self.connected_cache) 42 | if connected is None: 43 | return None 44 | return Device.model_validate_json(connected) 45 | 46 | @connected.setter 47 | def connected(self, connected: Device | None = None) -> None: 48 | """Connect to the device manager and selects a device or deletes a device setting from storage. 49 | 50 | Does not fail if the device does not exist. 51 | 52 | Args: 53 | connected (Device | None): The instance representing the connected device 54 | """ 55 | if connected: 56 | write_text(self.connected_cache, connected.model_dump_json()) 57 | else: 58 | self.connected_cache.unlink(missing_ok=True) 59 | 60 | @cached_property 61 | def devices(self) -> dict[str, Device]: 62 | """The dictionary of devices managed by NEST.""" 63 | tokens = OWLITE_SETTINGS.tokens 64 | if not tokens: 65 | log.error("Please login using 'owlite login' to connect to NEST devices") # UX 66 | raise LoginError("Not authenticated") 67 | 68 | if (workspace := OWLITE_SETTINGS.current_workspace) is None: 69 | log.error("No workspace selected. Please select a workspace") # UX 70 | raise RuntimeError("No workspace selected") 71 | 72 | try: 73 | resp = NEST_API_BASE.get( 74 | "/devices", 75 | params={"workspace_id": workspace.id}, 76 | ) 77 | except Exception as err: 78 | raise DeviceError(err) from err 79 | 80 | assert isinstance(resp, list) 81 | 82 | return {device["name"]: Device(**device) for device in resp} 83 | 84 | def get_device(self, name: str) -> Device: 85 | """Get the device by its name. 86 | 87 | Args: 88 | name (str): The name of a device to retrieve. 89 | 90 | Raises: 91 | RuntimeError: If no device with the provided name is found. 92 | 93 | Returns: 94 | Device: The device with the provided name. 95 | """ 96 | assert self.connected 97 | if self.connected.runtime == Runtime.QNN: 98 | assert name in get_args(SUPPORTED_QUALCOMM_DEVICES) 99 | return Device(name=self.connected.name, runtime=self.connected.runtime, runtime_extra=name) 100 | 101 | if name not in self.devices: 102 | log.error(f"No such device: {name}. Available devices are {', '.join(self.devices.keys())}") # UX 103 | raise RuntimeError("Device not found") 104 | return self.devices[name] 105 | 106 | 107 | OWLITE_DEVICE_SETTINGS = OwLiteDeviceSettings() 108 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![OwLite logo](https://github.com/SqueezeBits/owlite/assets/64083281/abaa3ad9-0c86-4a9c-9b8d-f54ed6d9524b) 2 | 3 |
 5 |   Website •
 6 |   Web UI •
 7 |   Key Features •
 8 |   Installation •
 9 |   Getting Started •
10 |   Contact
22 | 
23 | ## OwLite
24 | 
25 | * https://owlite.ai
26 | * OwLite is a low-code AI model compression toolkit for machine learning models.
27 | * Visualizes computational graphs, identifies bottlenecks, and optimizes latency and memory usage.
28 | * Also includes an auto-optimization feature and a device farm management system for evaluating optimized models.
29 | 
30 | ## Key features
31 | 
32 | #### **AI Model Visualization**
33 | 
34 | * You can visualize AI models using OwLite's editor function.
35 | * The GUI lets you understand the structure of the entire model at a glance,
36 |   while also providing detailed information about individual nodes.
37 | 
38 | #### **Quantization by Recommended Setting**
39 | 
40 | * SqueezeBits' engineers provide recommended quantization settings optimized for the model, based on their extensive experience with quantization.
41 | * This allows you to obtain a lightweight model while minimizing the accuracy drop.
42 | 
43 | #### **Quantization by Custom Setting**
44 | 
45 | * Based on the visualized model, you can apply quantization to each node directly.
46 | * This allows you to fine-tune the balance between performance and optimization.
47 | 
48 | #### **Latency Benchmark**
49 | 
50 | * You can perform latency benchmarks within OwLite. This allows you to easily compare the original model with the models you have edited, and decide which result to download.
51 | 
52 | ## **Installation**
53 | 
54 | Using pip (Recommended):
55 | ```bash
56 | pip install owlite --extra-index-url https://pypi.squeezebits.com/
57 | ```
58 | 
59 | ## Version Compatibilities
60 | 
61 | OwLite has been validated with the following combinations of versions. Versions not listed below may exhibit unexpected behavior.
62 | 
63 | ### v2.6.0
64 | 
65 | | torch | torchvision | diffusers | transformers |
66 | | :---: | :---------: | :-------: | :----------: |
67 | | 2.8.0 | 0.23.0 | 0.34.0 | 4.50.0 |
68 | | 2.7.1 | 0.22.1 | 0.33.0 | 4.48.0 |
69 | | 2.6.0 | 0.21.0 | 0.32.0 | 4.46.0 |
70 | | 2.5.1 | 0.20.1 | 0.30.0 | 4.44.0 |
71 | | 2.4.1 | 0.19.1 | 0.28.0 | 4.42.0 |
72 | | 2.3.1 | 0.18.1 | 0.26.0 | 4.40.0 |
73 | | 2.2.2 | 0.17.2 | 0.24.0 | 4.38.0 |
74 | 
75 | ## Getting Started
76 | 
77 | Please check [OwLite Documentation] for the user guide and troubleshooting examples.
78 | 
79 | Explore [OwLite Examples], a repository showcasing seamless PyTorch model compression into TensorRT engines. Easily integrate OwLite with minimal code changes and explore powerful compression results. (A minimal usage sketch is appended at the end of this README.)
80 | 
81 | ## Contact
82 | 
83 | Please contact [owlite-admin@squeezebits.com](mailto:owlite-admin@squeezebits.com) for any questions or suggestions.
84 | 
85 | 
88 | -------------------------------------------------------------------------------- /src/owlite/core/settings.py: -------------------------------------------------------------------------------- 1 | from .cache import OWLITE_CACHE_PATH 2 | from .cache.base_urls import BaseURLs 3 | from .cache.text import read_text, write_text 4 | from .cache.tokens import Tokens 5 | from .cache.workspace import Workspace 6 | from .logger import log 7 | 8 | 9 | class OwLiteSettings: 10 | """Handle OwLite settings including token management. 11 | 12 | OwLiteSettings manages tokens and URLs within the OwLite system. 13 | It provides methods to retrieve and store tokens for authentication. 14 | 15 | Attributes: 16 | token_cache (Path): Path to store token information. 17 | urls_cache (Path): Path to store URL information. 18 | """ 19 | 20 | def __init__(self) -> None: 21 | """Initialize OwLite settings. 22 | 23 | Initialize paths for OwLite cache directory to store tokens and URLs. 24 | """ 25 | self.tokens_cache = OWLITE_CACHE_PATH / "tokens" 26 | self.urls_cache = OWLITE_CACHE_PATH / "urls" 27 | self.current_workspace_cache = OWLITE_CACHE_PATH / "current_workspace" 28 | self.connected_cache = OWLITE_CACHE_PATH / "connected" 29 | 30 | @property 31 | def tokens(self) -> Tokens | None: 32 | """Retrieve tokens or None if they don't exist. 33 | 34 | Returns: 35 | Tokens | None: An instance of Tokens representing the access token and refresh token, 36 | or None if the tokens don't exist. 37 | """ 38 | read_tokens = read_text(self.tokens_cache) 39 | if not read_tokens: 40 | return None 41 | return Tokens.model_validate_json(read_tokens) 42 | 43 | @tokens.setter 44 | def tokens(self, new_tokens: Tokens | None) -> None: 45 | """Set new tokens or removes existing tokens. 46 | 47 | Args: 48 | new_tokens (Tokens | None): An instance of Tokens representing the new access token and refresh token. 49 | If None, existing tokens will be removed. 50 | """ 51 | if new_tokens: 52 | write_text(self.tokens_cache, new_tokens.model_dump_json()) 53 | log.debug(f"Saved access token='{new_tokens.access_token}' refresh token='{new_tokens.refresh_token}'") 54 | else: 55 | self.tokens_cache.unlink(missing_ok=True) 56 | 57 | @property 58 | def base_url(self) -> BaseURLs: 59 | """Retrieve base URLs. 60 | 61 | Returns the base URLs including FRONT, MAIN, and DOVE. 62 | If no custom URLs are set, it defaults to OwLite base URLs. 63 | 64 | Returns: 65 | BaseURLs: an instance of BaseURLs. 66 | """ 67 | base_urls = read_text(self.urls_cache) 68 | if not base_urls: 69 | return BaseURLs() 70 | return BaseURLs.model_validate_json(base_urls) 71 | 72 | @base_url.setter 73 | def base_url(self, base_urls: BaseURLs) -> None: 74 | """Set or remove custom base URLs. 75 | 76 | Args: 77 | base_urls (BaseURLs): An instance of BaseURLs to set or remove custom base URLs. 78 | 79 | Raises: 80 | ValueError: If the provided 'base_urls' instance is invalid or incomplete. 81 | """ 82 | write_text(self.urls_cache, base_urls.model_dump_json()) 83 | 84 | @property 85 | def current_workspace(self) -> Workspace | None: 86 | """Retrieve current workspace. 87 | 88 | Returns: 89 | Workspace | None: The workspace, or None if it doesn't exist. 90 | """ 91 | current_workspace = read_text(self.current_workspace_cache) 92 | if not current_workspace: 93 | return None 94 | return Workspace.model_validate_json(current_workspace) 95 | 96 | @current_workspace.setter 97 | def current_workspace(self, new_workspace: Workspace | None) -> None: 98 | """Set or remove current workspace. 
99 | 100 | Args: 101 | new_workspace (Workspace | None): The workspace to set or None to remove the current workspace. 102 | """ 103 | if new_workspace: 104 | write_text(self.current_workspace_cache, new_workspace.model_dump_json()) 105 | log.debug(f"Saved current workspace: {new_workspace}") 106 | else: 107 | self.current_workspace_cache.unlink(missing_ok=True) 108 | 109 | 110 | OWLITE_SETTINGS = OwLiteSettings() 111 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/passes/decompose_in_projection_packed.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | import torch.nn.functional as F 4 | from torch.fx.node import Node 5 | 6 | from .node_argument import NodeArgument 7 | from .rewrite_pass import RewritePass 8 | 9 | 10 | class InProjectionPackedNodeArgument(NodeArgument): 11 | """The arguments of a "call_function" node with target `F._in_projection_packed`.""" 12 | 13 | q: Node 14 | k: Node 15 | v: Node 16 | w: Node 17 | b: Node | None = None 18 | 19 | @classmethod 20 | def validate_node(cls, node: Node) -> bool: 21 | return ( 22 | # pylint: disable-next=protected-access 23 | node.op == "call_function" and node.target is F._in_projection_packed # type: ignore 24 | ) 25 | 26 | 27 | class DecomposeInProjectionPacked(RewritePass): 28 | """Decompose all occurrences of `F._in_projection_packed` by an equivalent subgraph. 29 | 30 | Note: this rewrite pass is implemented based on torch>=2.3.1,<=2.4.0 31 | """ 32 | 33 | @classmethod # pylint: disable-next=too-many-locals 34 | def rewrite(cls, node: Node) -> dict[Node, Node]: 35 | if (arguments := InProjectionPackedNodeArgument.extract_from(node)) is None: 36 | return {} 37 | 38 | graph = node.graph 39 | q, k, v, w, b = arguments.q, arguments.k, arguments.v, arguments.w, arguments.b 40 | with graph.inserting_before(node): 41 | embed_dim = graph.call_method("size", (q, -1)) 42 | if k is v: 43 | if q is k: 44 | # self-attention 45 | proj = graph.call_function(F.linear, (q, w, b)) 46 | # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as 47 | # chunk() 48 | proj = graph.call_method("unflatten", (proj, -1, (3, embed_dim))) 49 | proj = graph.call_method("unsqueeze", (proj, 0)) 50 | proj = graph.call_method("transpose", (proj, 0, -2)) 51 | proj = graph.call_method("squeeze", (proj, -2)) 52 | proj = graph.call_method("contiguous", (proj,)) 53 | new_q, new_k, new_v = (graph.call_function(operator.getitem, (proj, i)) for i in range(3)) 54 | else: 55 | # encoder-decoder attention 56 | embed_dim_x_2 = graph.call_function(operator.mul, (embed_dim, 2)) 57 | 58 | def call_method_split(x: Node) -> tuple[Node, Node]: 59 | x_splits = graph.call_method("split", (x, [embed_dim, embed_dim_x_2])) 60 | x_q, x_kv = (graph.call_function(operator.getitem, (x_splits, i)) for i in range(2)) 61 | return x_q, x_kv 62 | 63 | w_q, w_kv = call_method_split(w) 64 | if b is None: 65 | b_q = b_kv = None 66 | else: 67 | b_q, b_kv = call_method_split(b) 68 | q_proj = graph.call_function(F.linear, (q, w_q, b_q)) 69 | kv_proj = graph.call_function(F.linear, (k, w_kv, b_kv)) 70 | # reshape to 2, E and not E, 2 is deliberate for better memory coalescing and keeping same order as 71 | # chunk() 72 | kv_proj = graph.call_method("unflatten", (kv_proj, -1, (2, embed_dim))) 73 | kv_proj = graph.call_method("unsqueeze", (kv_proj, 0)) 74 | kv_proj = graph.call_method("transpose", (kv_proj, 0, -2)) 75 | kv_proj = 
graph.call_method("squeeze", (kv_proj, -2))
76 |                 kv_proj = graph.call_method("contiguous", (kv_proj,))
77 |                 new_q = q_proj
78 |                 new_k, new_v = (graph.call_function(operator.getitem, (kv_proj, i)) for i in range(2))
79 |         else:
80 | 
81 |             def call_method_chunk(x: Node) -> tuple[Node, Node, Node]:
82 |                 x_chunks = graph.call_method("chunk", (x, 3))
83 |                 x_q, x_k, x_v = (graph.call_function(operator.getitem, (x_chunks, i)) for i in range(3))
84 |                 return x_q, x_k, x_v
85 | 
86 |             w_q, w_k, w_v = call_method_chunk(w)
87 |             if b is None:
88 |                 b_q = b_k = b_v = None
89 |             else:
90 |                 b_q, b_k, b_v = call_method_chunk(b)
91 |             new_q = graph.call_function(F.linear, (q, w_q, b_q))
92 |             new_k = graph.call_function(F.linear, (k, w_k, b_k))
93 |             new_v = graph.call_function(F.linear, (v, w_v, b_v))
94 | 
95 |         old_q, old_k, old_v = tuple(node.users)
96 |         return {old_q: new_q, old_k: new_k, old_v: new_v}
97 | 
--------------------------------------------------------------------------------
/src/owlite/core/logger.py:
--------------------------------------------------------------------------------
  1 | # pylint: disable=protected-access
  2 | import logging
  3 | import os
  4 | from typing import Any, cast
  5 | 
  6 | DEBUG_WARNING = 15
  7 | ULTRA_VERBOSE = -10
  8 | 
  9 | 
 10 | # pylint: disable=missing-function-docstring, too-few-public-methods
 11 | class Logger(logging.Logger):
 12 |     """The Logger class whose level can only be set via the environment variable OWLITE_LOG_LEVEL."""
 13 | 
 14 |     ENV_VAR = "OWLITE_LOG_LEVEL"
 15 | 
 16 |     def ignore_warnings(self) -> Any:
 17 |         """Return a context manager that ignores warnings.
 18 | 
 19 |         with log.ignore_warnings():
 20 |             log.warning("this warning would be ignored")
 21 | 
 22 |         Returns:
 23 |             _WarningFilterContext instance
 24 |         """
 25 | 
 26 |         class _WarningFilterContext:
 27 |             def __init__(self, logger: logging.Logger) -> None:
 28 |                 self.logger = logger
 29 |                 self.warning_filter: logging.Filter | None = None
 30 | 
 31 |             def __enter__(self) -> logging.Logger:
 32 |                 class WarningFilter(logging.Filter):
 33 |                     """Class to filter warnings."""
 34 | 
 35 |                     def filter(self, record: logging.LogRecord) -> bool:
 36 |                         return record.levelno < DEBUG_WARNING
 37 | 
 38 |                 self.warning_filter = WarningFilter()
 39 |                 self.logger.addFilter(self.warning_filter)
 40 |                 return self.logger
 41 | 
 42 |             def __exit__(self, exc_type: type[BaseException], exc_val: BaseException, exc_tb: Any) -> None:
 43 |                 if self.warning_filter:
 44 |                     self.logger.removeFilter(self.warning_filter)
 45 | 
 46 |         return _WarningFilterContext(self)
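
    # Illustrative note (a hedged sketch, not from the original file): the custom
    # levels sit between the standard ones (DEBUG=10 < DEBUG_WARNING=15 < INFO=20),
    # so an invocation like `OWLITE_LOG_LEVEL=15 python train.py` would surface
    # debug warnings while still hiding plain DEBUG messages.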
 47 | 
 48 |     def debug_warning(self, msg: str, *args: Any, **kwargs: Any) -> None:
 49 |         if self.isEnabledFor(DEBUG_WARNING):
 50 |             self._log(DEBUG_WARNING, msg, args, **kwargs)
 51 | 
 52 |     # pylint: disable=access-member-before-definition, attribute-defined-outside-init
 53 |     @property
 54 |     def level(self) -> int:
 55 |         if hasattr(self, "_level"):
 56 |             return self._level
 57 |         level_from_env = os.getenv(Logger.ENV_VAR, None)
 58 |         if level_from_env is None:
 59 |             self._level = logging.INFO  # type: int
 60 |         elif all(c.isdigit() for c in level_from_env):
 61 |             self._level = int(level_from_env)
 62 |         else:
 63 |             self._level = logging._nameToLevel.get(level_from_env, logging.INFO)
 64 |         return self._level
 65 | 
 66 |     # pylint: disable=unused-argument
 67 |     @level.setter
 68 |     def level(self, value: Any) -> None:
 69 |         return
 70 | 
 71 | 
 72 | class OwLiteFormatter(logging.Formatter):
 73 |     """Custom log formatter for the OwLite application.
 74 | 
 75 |     This formatter customizes log messages by adding color-coded level names and an OwLite prefix.
 76 |     It uses ANSI escape codes for color representation.
 77 | 
 78 |     Args:
 79 |         format_str (str): Log format string.
 80 | 
 81 |     Attributes:
 82 |         FORMATS (dict): A dictionary containing ANSI escape codes for different log levels.
 83 |         reset (str): ANSI escape code to reset colors to default.
 84 |         owlite_prefix (str): Prefix for OwLite log messages.
 85 | 
 86 |     """
 87 | 
 88 |     def __init__(self, format_str: str) -> None:
 89 |         super().__init__(format_str)
 90 | 
 91 |     def format(self, record: logging.LogRecord) -> str:
 92 |         log_format = self.FORMATS.get(record.levelno, "")
 93 |         colored_levelname = f"{log_format}[{record.levelname}]{self.reset}"
 94 | 
 95 |         record.levelname = colored_levelname
 96 | 
 97 |         return f"{self.owlite_prefix}{super().format(record)}"
 98 | 
 99 |     FORMATS = {
100 |         logging.WARNING: "\x1b[38;2;255;212;0m",  # Yellow color for WARNING
101 |         logging.ERROR: "\x1b[38;2;255;40;40m",  # Red color for ERROR
102 |         logging.DEBUG: "\x1b[38;2;123;131;191m",  # Some shade of blue color for DEBUG
103 |         DEBUG_WARNING: "\x1b[38;2;175;0;215m",  # DarkViolet color for DEBUG WARNING
104 |     }
105 | 
106 |     reset = "\x1b[0m"
107 |     owlite_prefix = f"\x1b[38;2;238;120;31;1mOwLite {reset}"
108 | 
109 | 
110 | if "owlite" not in logging.getLogger().manager.loggerDict:
111 |     logging.addLevelName(DEBUG_WARNING, "DEBUG WARNING")
112 |     logging.addLevelName(ULTRA_VERBOSE, "ULTRA_VERBOSE")
113 | 
114 |     log = Logger("owlite")
115 | 
116 |     formatter = (
117 |         OwLiteFormatter("%(pathname)s:%(lineno)d %(levelname)s %(message)s")
118 |         if log.level <= ULTRA_VERBOSE
119 |         else OwLiteFormatter("%(levelname)s %(message)s")
120 |     )
121 |     stream_handler = logging.StreamHandler()
122 |     stream_handler.setFormatter(formatter)
123 | 
124 |     log.addHandler(stream_handler)
125 | else:
126 |     log = cast(Logger, logging.getLogger().manager.loggerDict["owlite"])
127 | 
--------------------------------------------------------------------------------
/src/owlite/api/project.py:
--------------------------------------------------------------------------------
 1 | import json
 2 | from dataclasses import dataclass, field
 3 | from typing import TYPE_CHECKING, Optional
 4 | 
 5 | from requests.exceptions import HTTPError
 6 | from typing_extensions import Self
 7 | 
 8 | from ..core.api_base import MAIN_API_BASE
 9 | from ..core.cache.device import Device
10 | from ..core.cache.workspace import Workspace
11 | from ..core.constants import OWLITE_FRONT_BASE_URL, OWLITE_HOME_PATH
12 | from ..core.logger import log
13 | from ..enums import Runtime
14 | 
15 | if TYPE_CHECKING:
16 |     from .baseline import Baseline
17 | 
18 | 
19 | @dataclass
20 | class Project:
21 |     """The OwLite project."""
22 | 
23 |     workspace: Workspace
24 |     id: str
25 |     name: str
26 |     baseline: Optional["Baseline"] = field(default=None)
27 | 
28 |     @property
29 |     def url(self) -> str:
30 |         """The URL to the project page."""
31 |         return f"{OWLITE_FRONT_BASE_URL}/project/detail?workspace_id={self.workspace.id}&project_id={self.id}"
32 | 
33 |     @property
34 |     def home(self) -> str:
35 |         """The directory path for writing outputs produced by this project."""
36 |         return str(OWLITE_HOME_PATH / self.name)
37 | 
38 |     @classmethod
39 |     def create(cls, workspace: Workspace, name: str, device: Device, description: str | None = None) -> Self:
40 |         """Create a new project.
41 | 
42 |         Args:
43 |             workspace (Workspace): The workspace to create the project in.
44 |             name (str): The name for the project to be created.
45 |             device (Device): The currently connected device.
46 | description (str | None, optional): Optional description for the project. Defaults to None. 47 | 48 | Raises: 49 | RuntimeError: When the project is not created for an unexpected reason. 50 | 51 | Returns: 52 | Project: The newly created project. 53 | """ 54 | if description is None: 55 | description = "" 56 | resp = MAIN_API_BASE.post( 57 | "/projects", 58 | json={ 59 | "workspace_id": workspace.id, 60 | "project_name": name, 61 | "description": description, 62 | "framework": device.runtime.value, 63 | }, 64 | ) 65 | 66 | if not (isinstance(resp, dict) and resp["name"] == name): 67 | raise RuntimeError(f"Failed to create project '{name}'") 68 | 69 | project = cls(workspace, resp["id"], resp["name"]) 70 | log.info(f"Created a new {project}") # UX 71 | return project 72 | 73 | @classmethod 74 | def load_or_create(cls, workspace: Workspace, name: str, device: Device, description: str | None = None) -> Self: 75 | """Load the existing project named `name` if found, creates a new one otherwise. 76 | 77 | Args: 78 | workspace (Workspace): The workspace to load the project in 79 | name (str): The name of the project to be loaded or created 80 | device (Device): The currently connected device. 81 | description (str | None, optional): Optional description that will be used only when a new project is 82 | created. Defaults to None. 83 | 84 | Raises: 85 | e (HTTPError): When an unexpected HTTP status code is returned. 86 | 87 | Returns: 88 | Project: the loaded or created project 89 | """ 90 | try: 91 | return cls.create(workspace, name, device, description) 92 | except HTTPError as e: 93 | if e.response is not None and e.response.status_code == 409: # the project already exists 94 | data = json.loads(e.response.content) 95 | assert ( 96 | existing_name := data["detail"]["existing_project_name"] 97 | ) == name, f"Project name mismatch: {existing_name} != {name}" # UX 98 | 99 | if (existing_framework := data["detail"]["existing_project_framework"]) != device.runtime.value: 100 | raise AssertionError( 101 | f"Project framework mismatch: {Runtime(existing_framework).name} != {device.runtime.name}" 102 | ) from None # UX 103 | 104 | project = cls(workspace, data["detail"]["existing_project_id"], name) 105 | log.info(f"Loaded the existing {project}") # UX 106 | return project 107 | 108 | if e.response is not None and e.response.status_code == 403: 109 | log.error( 110 | "You can create up to 2 Projects in a single Free Plan Workspace. " 111 | "In this execution, OwLite functions will not be executed. " 112 | "Please delete an existing Project or register it in an existing one." 113 | ) # UX 114 | raise e 115 | 116 | def __str__(self) -> str: 117 | return f"project '{self.name}'" 118 | -------------------------------------------------------------------------------- /src/owlite/backend/onnx/optimize.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=c-extension-no-member, broad-exception-caught 2 | 3 | import os 4 | 5 | import onnx 6 | from onnx import ModelProto 7 | 8 | from ... 
import capi # type: ignore[attr-defined] 9 | from ...core.logger import log 10 | from ..config import ( 11 | ONNX_EXTERNAL_DATA_SIZE_THRESHOLD, 12 | ONNX_TRANSFORM_MAXIMUM_ITERATION, 13 | ) 14 | 15 | 16 | def optimize( 17 | model_proto: ModelProto, 18 | *, 19 | max_num_iters: int = ONNX_TRANSFORM_MAXIMUM_ITERATION, 20 | input_names: list[str] | None = None, 21 | output_names: list[str] | None = None, 22 | skipped_optimizers: list[str] | None = None, 23 | ) -> ModelProto: 24 | """Apply graph-level optimization to the computation graph encapsulated by the given ONNX model proto. 25 | 26 | Args: 27 | model_proto (ModelProto): a model proto to optimize. 28 | max_num_iters (int, optional): the maximum number of iterations to apply the set of optimization passes. 29 | Defaults to ONNX_TRANSFORM_MAXIMUM_ITERATION, which can be set via the environment variable 30 | `OWLITE_ONNX_TRANSFORM_MAXIMUM_ITERATION`. 31 | input_names (list[str] | None, optional): the names of input tensors. 32 | Defaults to None. 33 | output_names (list[str] | None, optional): the names of output tensors. 34 | Defaults to None. 35 | skipped_optimizers (list[str] | None, optional): the names of optimization passes to skip. 36 | Defaults to None. 37 | 38 | Returns: 39 | ModelProto: the ONNX model proto containing the optimized graph. 40 | """ 41 | try: 42 | model_proto_bytes = capi.optimize( 43 | model_proto.SerializeToString(), 44 | input_names, 45 | output_names, 46 | skipped_optimizers, 47 | max_num_iters, 48 | ) 49 | return onnx.load_from_string(model_proto_bytes) 50 | except Exception as e: 51 | log.warning(f"Failed to optimize ONNX: {e}") 52 | return model_proto 53 | 54 | 55 | def optimize_path( 56 | input_path: str, 57 | output_path: str, 58 | *, 59 | size_threshold: int = ONNX_EXTERNAL_DATA_SIZE_THRESHOLD, 60 | max_num_iters: int = ONNX_TRANSFORM_MAXIMUM_ITERATION, 61 | input_names: list[str] | None = None, 62 | output_names: list[str] | None = None, 63 | skipped_optimizers: list[str] | None = None, 64 | ) -> str: 65 | """Apply graph-level optimization to the computation graph encapsulated by the given ONNX model proto. 66 | 67 | Same as `owlite.onnx.optimize` but involves file I/O. (Required for models larger than 2GB.) 68 | 69 | Args: 70 | input_path (str): the path to the input model proto file. 71 | output_path (str): the path to the output model proto file to be created. 72 | size_threshold (int, optional): the lower bound in bytes for the large tensors to save externally. 73 | Defaults to ONNX_EXTERNAL_DATA_SIZE_THRESHOLD, which can be set via the environment variable 74 | `OWLITE_ONNX_EXTERNAL_DATA_SIZE_THRESHOLD`. 75 | max_num_iters (int, optional): the maximum number of iterations to apply the set of optimization passes. 76 | Defaults to ONNX_TRANSFORM_MAXIMUM_ITERATION, which can be set via the environment variable 77 | `OWLITE_ONNX_TRANSFORM_MAXIMUM_ITERATION`. 78 | input_names (list[str] | None, optional): the names of input tensors. 79 | Defaults to None. 80 | output_names (list[str] | None, optional): the names of output tensors. 81 | Defaults to None. 82 | skipped_optimizers (list[str] | None, optional): the names of optimization passes to skip. 83 | Defaults to None. 84 | 85 | Returns: 86 | str: `input_path` if optimization fails, `output_path` otherwise. 
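Example:
    A minimal sketch of the intended usage (the file names below are hypothetical placeholders, not paths required by OwLite):

        saved_path = optimize_path("model.onnx", "model_opt.onnx")
        if saved_path == "model_opt.onnx":
            ...  # optimization succeeded; tensors above the size threshold, if any, were saved to "model_opt.bin"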
87 | """ 88 | input_path = os.path.abspath(input_path) 89 | output_path = os.path.abspath(output_path) 90 | if input_path == output_path: 91 | log.error("You must provide different input_path and output_path to `owlite.onnx.optimize_path`") # UX 92 | raise ValueError("Inplace ONNX optimization via file is not supported.") # UX 93 | try: 94 | output_prefix, _ = os.path.splitext(output_path) 95 | if os.path.isfile(external_data_path := f"{output_prefix}.bin"): 96 | log.warning(f"External data file at {external_data_path} will be overwritten.") # UX 97 | os.remove(external_data_path) 98 | location = os.path.basename(external_data_path) 99 | capi.optimize_path( 100 | input_path, 101 | output_path, 102 | location, 103 | size_threshold, 104 | input_names, 105 | output_names, 106 | skipped_optimizers, 107 | max_num_iters, 108 | ) 109 | return output_path 110 | except Exception as e: 111 | log.warning(f"Failed to optimize ONNX: {e}") 112 | return input_path 113 | -------------------------------------------------------------------------------- /src/owlite/calib/minmax_calibrator.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any 2 | 3 | import torch 4 | from torch.utils.hooks import RemovableHandle 5 | 6 | from ..core.logger import log 7 | from .calibrator import Calibrator 8 | 9 | if TYPE_CHECKING: 10 | from ..nn import FakeQuantizer 11 | 12 | 13 | class MinmaxCalibrator(Calibrator): 14 | r"""Minmax Calibrator Class. 15 | 16 | The min-max calibrator stores the **maximum value** and **minimum value** encountered in the passed data, 17 | utilizing these values as the quantization range. When the original data is represented by $$X$$, 18 | the `step_size` and `zero_point` are calculated as: 19 | 20 | $$ 21 | \text{step\_size}=\frac{\max_{x \in X}(x) - \min_{x \in X}(x)}{\text{quant\_max}-\text{quant\_min}} \\ 22 | \text{zero\_point} = - \frac{\min_{x \in X}(x)}{\text{step\_size}} + \text{quant\_min} 23 | $$ 24 | 25 | For symmetric quantization: 26 | 27 | $$ 28 | \text{step\_size}=\frac{\max_{x \in X}(|x|)}{\text{quant\_max}-\text{quant\_min}} \\ 29 | \text{zero\_point} = 0 30 | $$ 31 | 32 | Attributes: 33 | max_value (`torch.Tensor`, `optional`): maximum value of data passing through the quantizer. 34 | min_value (`torch.Tensor`, `optional`): minimum value of data passing through the quantizer.
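For a quick numeric illustration (a hedged sketch that simply applies the asymmetric formulas above; the numbers are made up): with quant_min = -128, quant_max = 127 and observed min(X) = -2.0, max(X) = 6.0,

    step_size  = (6.0 - (-2.0)) / (127 - (-128)) = 8 / 255 ≈ 0.0314
    zero_point = -(-2.0) / 0.0314 + (-128) ≈ -64.25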
35 | """ 36 | 37 | def __init__(self, quantizer: "FakeQuantizer"): 38 | super().__init__(quantizer) 39 | self.max_value: torch.Tensor | None = None 40 | self.min_value: torch.Tensor | None = None 41 | 42 | def prepare(self) -> RemovableHandle: 43 | # define forward hook function 44 | def minmax_forward_hook_func(module: "FakeQuantizer", inputs: tuple[Any, ...], output: Any) -> Any | None: 45 | """Forward hook function to collect the min/max values.""" 46 | calibrator = module.calibrator 47 | assert isinstance(calibrator, MinmaxCalibrator) 48 | assert self.check_calib_ready() 49 | 50 | if calibrator.max_value is None or calibrator.min_value is None: 51 | raise ValueError( 52 | "During calibration, calibration attributes should be initialized, but None was provided" 53 | ) 54 | _input = self.convert_to_tensor(inputs) 55 | if module.channel is not None: 56 | axis = module.channel.axis 57 | (other_dims := list(range(_input.dim()))).remove(axis) 58 | _input = _input.permute(axis, *other_dims) # move the channel dim to 0 59 | new_max = _input.reshape(_input.size()[0], -1).max(dim=1).values.clone() 60 | new_min = _input.reshape(_input.size()[0], -1).min(dim=1).values.clone() 61 | else: 62 | new_max = _input.max().clone() 63 | new_min = _input.min().clone() 64 | calibrator.max_value.data = torch.maximum( 65 | new_max.to(calibrator.max_value.device), calibrator.max_value 66 | ).data 67 | calibrator.min_value.data = torch.minimum( 68 | new_min.to(calibrator.min_value.device), calibrator.min_value 69 | ).data 70 | return output 71 | 72 | # ~define forward hook function 73 | 74 | if self.max_value is not None or self.min_value is not None: 75 | log.error( 76 | "The min-max attributes are already set before the calibration is prepared.\n" 77 | f"`max_value`: {self.max_value}\n`min_value`: {self.min_value}" 78 | ) 79 | raise ValueError("The min-max attributes are already set before the calibration is prepared") 80 | 81 | self.max_value = ( 82 | torch.ones_like(self.quantizer.step_size.data).to(self.quantizer.step_size.device) * torch.finfo().min 83 | ) 84 | self.min_value = ( 85 | torch.ones_like(self.quantizer.step_size.data).to(self.quantizer.step_size.device) * torch.finfo().max 86 | ) 87 | self.hook_handler = self.quantizer.register_forward_hook(minmax_forward_hook_func) 88 | return self.hook_handler 89 | 90 | def update(self) -> None: 91 | assert self.check_calib_ready() 92 | assert isinstance(self.hook_handler, RemovableHandle) 93 | if self.max_value is None or self.min_value is None: 94 | log.error(f"`max_value` : {self.max_value}") 95 | log.error(f"`min_value` : {self.min_value}") 96 | raise ValueError( 97 | "While preparing for calibration, calibration attributes should be initialized, but None was provided" 98 | ) 99 | if self.quantizer.symmetric: 100 | self.update_fake_quantizer_param_with_max_min(torch.max(self.max_value.abs(), self.min_value.abs())) 101 | else: 102 | self.update_fake_quantizer_param_with_max_min(self.max_value, self.min_value) 103 | 104 | # reset the "min_value" and "max_value" attributes to `None` 105 | self.max_value, self.min_value = None, None 106 | 107 | # remove registered forward_hook 108 | self.hook_handler.remove() 109 | self.hook_handler = None 110 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/target.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | import operator 3 | from itertools import product 4 | 5 | import torch 6 | from
torch.fx.node import Target as FXTarget 7 | 8 | from ..utils import camel_to_snake 9 | from .types import TorchTarget 10 | 11 | 12 | def torch_targets(fn_name: str) -> list[TorchTarget]: 13 | """Find all torch.* function targets with given fn_name. 14 | 15 | Args: 16 | fn_name (str): torch function name to find (e.g. "randn" for targeting torch.randn) 17 | 18 | Returns: 19 | list[TorchTarget]: list of all torch function targets 20 | """ 21 | inplace_fn_name = f"{fn_name}_" 22 | targets: list[TorchTarget] = [] 23 | if hasattr(torch, fn_name): 24 | targets.append(getattr(torch, fn_name)) 25 | if hasattr(torch, inplace_fn_name): 26 | targets.append(getattr(torch, inplace_fn_name)) 27 | if hasattr(torch.Tensor, fn_name): 28 | targets.append(fn_name) 29 | if hasattr(torch.Tensor, inplace_fn_name): 30 | targets.append(inplace_fn_name) 31 | return targets 32 | 33 | 34 | def functional_targets(op_name: str) -> list[FXTarget]: 35 | targets: list[FXTarget] = [] 36 | inplace_op_name = f"{op_name}_" 37 | # pylint: disable-next=protected-access 38 | for module, name in product((torch.nn.functional, torch._C._nn), (op_name, inplace_op_name)): 39 | if hasattr(module, name): 40 | targets.append(getattr(module, name)) 41 | for i in (1, 2, 3): 42 | op_name_nd = f"{name}{i}d" 43 | if hasattr(module, op_name_nd): 44 | targets.append(getattr(module, op_name_nd)) 45 | return targets 46 | 47 | 48 | def nn_targets(module_name: str) -> list[FXTarget]: 49 | targets: list[FXTarget] = [] 50 | if hasattr(torch.nn, module_name): 51 | targets.append(getattr(torch.nn, module_name)) 52 | 53 | for n in (1, 2, 3): 54 | module_name_nd = f"{module_name}{n}d" 55 | if hasattr(torch.nn, module_name_nd): 56 | targets.append(getattr(torch.nn, module_name_nd)) 57 | 58 | return targets 59 | 60 | 61 | def builtin_targets(op_name: str) -> list[FXTarget]: 62 | inplace_op_name = f"i{op_name}" 63 | targets: list[FXTarget] = [] 64 | if hasattr(operator, op_name): 65 | targets.append(getattr(operator, op_name)) 66 | if hasattr(operator, inplace_op_name): 67 | targets.append(getattr(operator, inplace_op_name)) 68 | return targets 69 | 70 | 71 | def all_torch_functions(op_name: str) -> list[FXTarget]: 72 | return torch_targets(op_name) + functional_targets(op_name) 73 | 74 | 75 | def all_torch_targets(op_name: str) -> list[FXTarget]: 76 | snake_op_name = camel_to_snake(op_name) 77 | return nn_targets(op_name) + all_torch_functions(snake_op_name) 78 | 79 | 80 | def all_targets(op_name: str) -> list[FXTarget]: 81 | snake_op_name = camel_to_snake(op_name) 82 | return builtin_targets(snake_op_name) + all_torch_targets(snake_op_name) 83 | 84 | 85 | ADD_TARGETS = (*all_targets("add"),) 86 | 87 | SUB_TARGETS = (*all_targets("sub"), *all_torch_targets("subtract")) 88 | 89 | MUL_TARGETS = (*all_targets("mul"), *all_torch_targets("multiply")) 90 | 91 | DIV_TARGETS = ( 92 | operator.truediv, 93 | operator.itruediv, 94 | operator.floordiv, 95 | operator.ifloordiv, 96 | *all_torch_targets("div"), 97 | *all_torch_targets("divide"), 98 | *all_torch_targets("floor_divide"), 99 | *all_torch_targets("true_divide"), 100 | ) 101 | 102 | MOD_TARGETS = (operator.mod, operator.imod, *torch_targets("remainder")) 103 | 104 | ARITHMETIC_TARGETS = (*ADD_TARGETS, *SUB_TARGETS, *MUL_TARGETS, *DIV_TARGETS, *MOD_TARGETS) 105 | 106 | CONSTANT_TARGETS = ( 107 | *torch_targets("zero_"), 108 | *torch_targets("new_tensor"), 109 | *torch_targets("empty"), 110 | *torch_targets("empty_like"), 111 | *torch_targets("new_empty"), 112 | *torch_targets("empty_strided"), 113 | 
*torch_targets("zeros"), 114 | *torch_targets("zeros_like"), 115 | *torch_targets("new_zeros"), 116 | *torch_targets("ones"), 117 | *torch_targets("ones_like"), 118 | *torch_targets("new_ones"), 119 | *torch_targets("full"), 120 | *torch_targets("full_like"), 121 | *torch_targets("new_full"), 122 | *torch_targets("arange"), 123 | *torch_targets("as_tensor"), 124 | *torch_targets("asarray"), 125 | *torch_targets("bartlett_window"), 126 | *torch_targets("eye"), 127 | *torch_targets("from_file"), 128 | *torch_targets("from_numpy"), 129 | *torch_targets("hamming_window"), 130 | *torch_targets("hann_window"), 131 | *torch_targets("kaiser_window"), 132 | *torch_targets("linspace"), 133 | *torch_targets("logspace"), 134 | *torch_targets("numel"), 135 | *torch_targets("scalar_tensor"), 136 | *torch_targets("size"), 137 | *torch_targets("_shape_as_tensor"), 138 | ) 139 | 140 | NONDETERMINISTIC_TARGETS = ( 141 | *torch_targets("bernoulli"), 142 | *torch_targets("normal"), 143 | *torch_targets("rand"), 144 | *torch_targets("rand_like"), 145 | *torch_targets("randn"), 146 | *torch_targets("randn_like"), 147 | *torch_targets("randint"), 148 | *torch_targets("randint_like"), 149 | ) 150 | -------------------------------------------------------------------------------- /src/owlite/compression.py: -------------------------------------------------------------------------------- 1 | r"""Calibration optimizes quantization parameters for minimizing error while preserving model accuracy. 2 | 3 | Quantization is a powerful technique used to reduce the storage and computational requirements of deep learning 4 | models. However, this reduction in precision can potentially hurt model accuracy. Calibration is a crucial step in 5 | quantization that helps mitigate this accuracy loss. 6 | 7 | Calibration involves measuring the distributions of the activations in the model and using this information 8 | to determine the optimal quantization parameters. This process involves: 9 | 10 | 1. Collecting data: A representative dataset, called the **calibration dataset**, is used to evaluate 11 | the trained floating-point model. 12 | 13 | 2. Analyzing data: Statistics about the activation or weight distributions are collected 14 | to understand how the data is spread across different values within each layer. 15 | 16 | 3. Selecting quantization parameters: These parameters, such as the quantization step\_size and zero\_point, 17 | are determined using one of several optimization objectives. 18 | The goal is to find the best balance between minimizing quantization error and preserving model accuracy. 19 | """ 20 | 21 | from torch.fx.graph_module import GraphModule 22 | from torch.fx.node import Node 23 | 24 | from .backend.fx.node import find_constant_nodes 25 | from .backend.fx.node_configurator import NodeConfigurator 26 | from .backend.fx.transforms import ( 27 | fuse_bn_into_qlinear_with_quantized_bias, 28 | fuse_bn_into_qmodule_with_per_tensor_quantizer, 29 | qconv_bn_to_qconvbn_with_int32bias, 30 | ) 31 | from .core.logger import log 32 | from .enums import ModelStatus 33 | from .nn import FakeQuantizer, enable_quantizers 34 | from .options.compression_option import CompressionOptions, FakeQuantizerConfig 35 | 36 | 37 | def compress(model: GraphModule, option: CompressionOptions) -> GraphModule: 38 | """Quantize the model with the specification described in options.
39 | 40 | This function inserts quantizers with the quantization options specified in `option`, 41 | substitutes the corresponding modules with their quantized counterparts, and performs post-processing. Since a linear 42 | module that quantizes its bias cannot have a following batch norm fused into it after quantization, the batch 43 | norm fusion is performed here. Then, quantizers with the same quantization option that correspond 44 | to the same tensor in the original model are fused. 45 | 46 | Args: 47 | model (GraphModule): The symbolic traced model to be compressed. 48 | option (CompressionOptions): The option required for compressing the model. 49 | 50 | Raises: 51 | TypeError: If model is not an instance of `GraphModule`. 52 | 53 | Returns: 54 | GraphModule: Compressed model. 55 | """ 56 | if not isinstance(model, GraphModule): 57 | raise TypeError("Only a GraphModule instance can be quantized with `owlite.quantize`") 58 | configure(model, option) 59 | fuse_bn_into_qmodule_with_per_tensor_quantizer(model) 60 | fuse_bn_into_qlinear_with_quantized_bias(model) 61 | qconv_bn_to_qconvbn_with_int32bias(model) 62 | enable_quantizers(model) 63 | return model 64 | 65 | 66 | def configure(graph_module: GraphModule, option: CompressionOptions) -> None: 67 | """Configure the input model to a quantized model based on the provided options. 68 | 69 | Args: 70 | graph_module (GraphModule): The model to be compressed. 71 | option (CompressionOptions): The option required for compressing the model. 72 | """ 73 | constants_key = "constants" 74 | if constants_key not in graph_module.meta: 75 | graph_module.meta[constants_key] = find_constant_nodes(graph_module.graph) 76 | add_fake_quantizers(graph_module, option.fake_quantizers) 77 | nodes: list[Node] = [*graph_module.graph.nodes] 78 | for node in nodes: 79 | if node_compression_option := option.node_compression_config.get(node.name): 80 | NodeConfigurator.configure(node, node_compression_option) 81 | graph_module.graph.lint() 82 | graph_module.graph.eliminate_dead_code() 83 | graph_module.recompile() 84 | graph_module.meta["status"] = ModelStatus.COMPRESSED 85 | try: 86 | graph_module.to(next(graph_module.parameters()).device) 87 | except StopIteration: 88 | pass 89 | 90 | 91 | def add_fake_quantizers(graph_module: GraphModule, fake_quantizer_config: FakeQuantizerConfig) -> None: 92 | """Add necessary fake quantizer submodules to the graph module according to the fake quantizer config.
93 | 94 | Args: 95 | graph_module (GraphModule): the graph module where new fake quantizer submodules are to be added 96 | fake_quantizer_config (FakeQuantizerConfig): the configurations for the fake quantizer submodules 97 | to be added 98 | """ 99 | for fake_quantizer_id, target, fake_quantizer_layout in fake_quantizer_config.named_items(): 100 | fake_quantizer = FakeQuantizer.create( 101 | fake_quantizer_layout.option, fake_quantizer_layout.channel, identification=fake_quantizer_id 102 | ) 103 | if fake_quantizer is None: 104 | log.debug_warning(f"Found vacuous layout: {fake_quantizer_layout}") 105 | continue 106 | if not graph_module.add_submodule(target, fake_quantizer): 107 | log.warning(f"Failed to add FakeQuantizer module: {target}") 108 | -------------------------------------------------------------------------------- /src/owlite/backend/fx/node.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | 3 | import torch 4 | from torch.fx.graph import Graph, Node 5 | 6 | from ..utils import log 7 | from .target import CONSTANT_TARGETS, NONDETERMINISTIC_TARGETS 8 | from .types import TorchTarget 9 | 10 | 11 | def find_placeholders(graph: Graph) -> list[Node]: 12 | """Find all placeholder nodes. 13 | 14 | Args: 15 | graph (Graph): the input graph 16 | 17 | Returns: 18 | list[Node]: the list of nodes whose op is "placeholder" 19 | """ 20 | return [*filter(lambda n: n.op == "placeholder", graph.nodes)] 21 | 22 | 23 | def find_the_output_node(graph: Graph) -> Node: 24 | """Find the unique output node in the graph. 25 | 26 | Args: 27 | graph (Graph): the input graph 28 | 29 | Raises: 30 | RuntimeError: if the graph has no output node or more than one output node. 31 | 32 | Returns: 33 | Node: the unique node whose op is "output" 34 | """ 35 | outputs = [*filter(lambda n: n.op == "output", graph.nodes)] 36 | if len(outputs) == 0: 37 | raise RuntimeError('torch.fx.Graph has no node whose op == "output"') 38 | 39 | if len(outputs) > 1: 40 | raise RuntimeError('torch.fx.Graph has more than one node whose op == "output"') 41 | 42 | return outputs[0] 43 | 44 | 45 | def find_constant_nodes(graph: Graph) -> list[Node]: 46 | """Find all constant-foldable nodes in the graph. 47 | 48 | Args: 49 | graph (Graph): the input graph 50 | 51 | Returns: 52 | list[Node]: the list containing all constant-foldable nodes in the graph.
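Example:
    A minimal sketch (assuming `gm` is a `torch.fx.GraphModule` you have already traced):

        constants = find_constant_nodes(gm.graph)
        non_constants = [n for n in gm.graph.nodes if n not in constants]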
53 | """ 54 | constant_nodes: list[Node] = [] 55 | non_constant_nodes: list[Node] = [] 56 | 57 | def is_constant_getattr_node(node: Node) -> bool: 58 | if node.op == "get_attr": 59 | return True 60 | if not (node.op == "call_function" and node.target is getattr): 61 | return False 62 | name = node.args[1] if len(node.args) > 1 else node.kwargs.get("name", None) 63 | return name not in ("T",) 64 | 65 | def is_constant(node: Node) -> bool: 66 | if node in constant_nodes: 67 | return True 68 | if node in non_constant_nodes: 69 | return False 70 | 71 | if node.op == "placeholder": 72 | non_constant_nodes.append(node) 73 | return False 74 | 75 | if node.target in NONDETERMINISTIC_TARGETS: 76 | non_constant_nodes.append(node) 77 | return False 78 | 79 | is_getattr_node = is_constant_getattr_node(node) 80 | is_constant_generating_node = node.op in ("call_function", "call_method") and node.target in CONSTANT_TARGETS 81 | is_missing_input_node = len(node.all_input_nodes) == 0 82 | if any((is_getattr_node, is_missing_input_node, is_constant_generating_node)): 83 | constant_nodes.append(node) 84 | return True 85 | 86 | result: bool = all(is_constant(input_node) for input_node in node.all_input_nodes) 87 | if result: 88 | constant_nodes.append(node) 89 | else: 90 | non_constant_nodes.append(node) 91 | return result 92 | 93 | for node in graph.nodes: 94 | _ = is_constant(node) 95 | 96 | return constant_nodes 97 | 98 | 99 | def is_output_adapter_node(node: Node) -> bool: 100 | """Check if the node is an output adapter node. 101 | 102 | Args: 103 | node (Node): the node to be checked 104 | 105 | Returns: 106 | bool: `True` if the node is an output adapter node, `False` otherwise. 107 | """ 108 | return ( 109 | inspect.isfunction(node.target) 110 | and getattr(node.target, "__name__", "") == "output_adapter" 111 | and getattr(inspect.getmodule(node.target), "__name__", "") == "owlite.backend.fx.trace" 112 | ) 113 | 114 | 115 | def get_target_module(node: Node) -> torch.nn.Module | None: 116 | """Find the module the node is targeting only when the node is a proper "call_module" node. 117 | 118 | Args: 119 | node (Node): a node. 120 | 121 | Returns: 122 | torch.nn.Module | None: the module that the node is pointing to if 123 | * its op is `"call_module"`; and 124 | * its target is a string; and 125 | * it belongs to a GraphModule instance 126 | Otherwise, `None` is returned 127 | """ 128 | if node.op != "call_module" or not isinstance(node.target, str): 129 | return None 130 | graph_module = node.graph.owning_module 131 | if graph_module is None: 132 | return None 133 | module: torch.nn.Module | None = None 134 | try: 135 | module = graph_module.get_submodule(node.target) 136 | except AttributeError as e: 137 | log.warning(e) 138 | return None 139 | return module 140 | 141 | 142 | def get_torch_target(node: Node) -> TorchTarget | None: 143 | """Get the PyTorch target module or function of the node. 144 | 145 | Args: 146 | node (Node): a node. 147 | 148 | Returns: 149 | TorchTarget | None: 150 | * If `node.op` is `"call_module"`, returns the class of the module instance it is targeting.
151 | * If `node.op` is either `"call_function"` or `"call_method"` and its target is from torch, 152 | returns `node.target` 153 | * Otherwise, returns `None` 154 | """ 155 | target_module = get_target_module(node) 156 | if target_module is not None: 157 | return type(target_module) 158 | if node.op == "call_method" and isinstance(node.target, str) and hasattr(torch.Tensor, node.target): 159 | return node.target 160 | if node.op == "call_function" and callable(node.target): 161 | return node.target 162 | return None 163 | -------------------------------------------------------------------------------- /src/owlite/nn/functions/fake_fp_quantize.py: -------------------------------------------------------------------------------- 1 | """This module provides a custom PyTorch function for fake FP8 quantization. 2 | 3 | The `BaseFakeFPQuantizeFunction` class is a PyTorch function that performs fake FP8 quantization. It takes in an input 4 | tensor, step size, zero point, quantization minimum, quantization maximum, and axis as inputs. The `symbolic` method 5 | defines the symbolic computation graph for the function, which checks if the quantization minimum and maximum are valid 6 | for FP8 quantization. If valid, it calls the `fp8_qdq_symbolic` function to perform the quantization and dequantization. 7 | 8 | The `fake_fp8_quantize` function performs the actual fake FP8 quantization. It takes in an input tensor, step size, 9 | zero point, quantization minimum, quantization maximum, and axis as inputs. It first adjusts the step size and zero 10 | point according to the axis, then performs the quantization by dividing the input by the step size and adding the zero 11 | point. The result is then clipped to the quantization minimum and maximum, and converted to the FP8 data type. 12 | Finally, the result is converted back to the original data type, subtracted by the zero point, and multiplied by the 13 | step size. 14 | 15 | The `fp8_qdq_symbolic` function defines the symbolic computation graph for the fake FP8 quantization. It takes in 16 | an input value, step size, zero point, and axis as inputs. It first casts the zero point to the FP8 data type, then 17 | performs the quantization using the `QuantizeLinear` operator. The result is then dequantized using the 18 | `DequantizeLinear` operator. 19 | 20 | Note: 21 | This implementation assumes that the quantization minimum and maximum are valid for FP8 quantization. 22 | It also assumes that the input tensor is a PyTorch tensor. 23 | """ 24 | 25 | # pylint: disable=unused-argument 26 | import torch 27 | from torch import Tensor, Value 28 | from torch._C._onnx import TensorProtoDataType 29 | from torch.autograd import Function 30 | from torch.onnx._internal import jit_utils 31 | 32 | 33 | # pylint: disable-next=abstract-method 34 | class BaseFakeFPQuantizeFunction(Function): 35 | """An autograd function that performs fake FP quantization. 36 | 37 | Static Methods: 38 | symbolic: Defines the symbolic computation graph for the function. 39 | """ 40 | 41 | @staticmethod 42 | @torch.onnx.symbolic_helper.parse_args("v", "v", "v", "none", "none", "none", "i") # type: ignore # pylint: disable-next=too-many-positional-arguments 43 | def symbolic( 44 | g: jit_utils.GraphContext, 45 | inputs: Value, 46 | step_size: Value, 47 | zero_point: Value, 48 | grad_scale: float, 49 | quant_min: float, 50 | quant_max: float, 51 | axis: int | None, 52 | ) -> Value | tuple[Value, ...]: 53 | r"""Define the symbolic computation graph for fake FP8 quantization. 
54 | 55 | Args: 56 | g (`jit_utils.GraphContext`): The graph context. 57 | inputs (`torch.Value`): A tensor to quantize. 58 | step_size (`torch.Value`): The quantization scale, determining the magnitude of each quantization interval. 59 | zero_point (`torch.Value`): The quantization zero\_point. It may be expressed as a float in the context of 60 | asymmetric quantization, while for symmetric quantization, it is fixed at 0. 61 | grad_scale (`float`): The gradient scale. 62 | quant_min (`float`): The lower bound of the quantized domain, specified as an integer. 63 | quant_max (`float`): The upper bound of the quantized domain, specified as an integer. 64 | axis (`int`, optional): Channel axis. Only used when `per_channel` is `True`. Defaults to 0. 65 | 66 | Returns: 67 | The output value. 68 | """ 69 | if (quant_min, quant_max) != (torch.finfo(torch.float8_e4m3fn).min, torch.finfo(torch.float8_e4m3fn).max): 70 | raise torch.onnx.errors.SymbolicValueError( 71 | "For fp quantizer's (quant_min, quant_max), ONNX allows only " 72 | f"({torch.finfo(torch.float8_e4m3fn).min}, {torch.finfo(torch.float8_e4m3fn).max}). " 73 | f"Got ({quant_min}, {quant_max})", 74 | inputs, 75 | ) 76 | zero_point = g.op("Cast", zero_point, to_i=TensorProtoDataType.FLOAT8E4M3FN) 77 | quantized = g.op("QuantizeLinear", inputs, step_size, zero_point, axis_i=axis) 78 | dequantized = g.op("DequantizeLinear", quantized, step_size, zero_point, axis_i=axis) 79 | return dequantized 80 | 81 | 82 | def fake_fp8_quantize( 83 | inputs: Tensor, 84 | step_size: Tensor, 85 | zero_point: Tensor, 86 | *, 87 | quant_min: float, 88 | quant_max: float, 89 | axis: int | None = None, 90 | ) -> torch.Tensor: 91 | """Perform fake FP8 quantization on an input tensor. 92 | 93 | Args: 94 | inputs (`torch.Tensor`): The input tensor. 95 | step_size (`torch.Tensor`): The step size. 96 | zero_point (`torch.Tensor`): The zero point. 97 | quant_min (`float`): The quantization minimum. 98 | quant_max (`float`): The quantization maximum. 99 | axis (`int`, optional): The axis. 100 | 101 | Returns: 102 | torch.Tensor: The fake-quantized tensor. 103 | """ 104 | if axis is not None: 105 | dimlist = [1] * inputs.dim() 106 | dimlist[axis] = -1 107 | step_size = step_size.reshape(dimlist) 108 | zero_point = zero_point.reshape(dimlist) 109 | out = (inputs / step_size) + zero_point 110 | out = out.clip(quant_min, quant_max) 111 | out = out.to(dtype=torch.float8_e4m3fn) 112 | out = out.to(dtype=inputs.dtype) - zero_point 113 | out = out * step_size 114 | return out 115 | -------------------------------------------------------------------------------- /src/owlite/calibrators.py: -------------------------------------------------------------------------------- 1 | from types import TracebackType 2 | 3 | import torch 4 | from torch.fx.graph_module import GraphModule 5 | from torch.nn.parallel import DataParallel, DistributedDataParallel 6 | from tqdm import tqdm 7 | 8 | from .backend.fx.types import GraphModuleOrDataParallel 9 | from .core.constants import OWLITE_CALIBRATION_ENABLE_GRAD 10 | from .core.logger import log 11 | from .enums import ModelStatus 12 | from .nn import FakeQuantizer 13 | 14 | 15 | def _prepare_for_calibration(model: GraphModuleOrDataParallel) -> None: 16 | """Prepare each fake quantizer's calibrator for calibration. 17 | 18 | Args: 19 | model(`GraphModuleOrDataParallel`): graph module to calibrate.
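Note:
    As the function body below shows, preparation disables every `FakeQuantizer` in the model and calls its
    calibrator's `prepare()`, which registers the forward hook that collects statistics during the
    subsequent forward passes.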
20 | """ 21 | log.info("Preparing for calibration") # UX 22 | for _, module in model.named_modules(remove_duplicate=True): 23 | if isinstance(module, FakeQuantizer): 24 | module.disable() 25 | module.calibrator.prepare() 26 | log.info("All fake quantizers in the model are now ready for calibration") # UX 27 | log.info("Calibrating the model") # UX 28 | 29 | 30 | def _update_fake_quantizers(model: GraphModuleOrDataParallel) -> None: 31 | """Calculate step size and zero point using the data collected by each calibrator and enable quantization. 32 | 33 | Args: 34 | model(`GraphModuleOrDataParallel`): model to calibrate. 35 | """ 36 | fake_quantizers = [m for m in model.modules() if isinstance(m, FakeQuantizer)] 37 | for module in tqdm(fake_quantizers, desc="Updating fake quantizers"): 38 | module.calibrator.update() 39 | if module.step_size.abs().max() <= 0: 40 | log.error( 41 | f"FakeQuantizer({module.id}): The step sizes are all zero. " 42 | "Make sure the data is fed to the quantizer correctly." 43 | ) 44 | continue 45 | if module.step_size.min() < 0: 46 | log.warning( 47 | f"FakeQuantizer({module.id}): The step size contains a negative number. " 48 | "Automatically changed to positive.", 49 | stacklevel=2, 50 | ) 51 | module.step_size.data = module.step_size.data.abs() 52 | module.enable() 53 | if isinstance(model, DataParallel | DistributedDataParallel) and isinstance(model.module, GraphModule): 54 | model.module.meta["status"] = ModelStatus.CALIBRATED 55 | elif isinstance(model, GraphModule): 56 | model.meta["status"] = ModelStatus.CALIBRATED 57 | else: 58 | log.warning( 59 | "It looks like the model provided to `owlite.convert` is contaminated or has not been created by the " 60 | "`OwLite.convert` method. The model might not be calibrated correctly." 61 | ) # UX 62 | return 63 | log.info("Updated fake quantizers. Calibration finished") # UX 64 | 65 | 66 | class CalibrationContext(torch.set_grad_enabled): 67 | """Context manager for calibration. 68 | 69 | CalibrationContext disables gradient calculation unless `OWLITE_CALIBRATION_ENABLE_GRAD` is set. 70 | """ 71 | 72 | def __init__(self, model: GraphModuleOrDataParallel): 73 | super().__init__(mode=OWLITE_CALIBRATION_ENABLE_GRAD) 74 | self.model = model 75 | 76 | def __enter__(self) -> GraphModuleOrDataParallel: # type: ignore[override] 77 | super().__enter__() 78 | _prepare_for_calibration(self.model) 79 | return self.model 80 | 81 | def __exit__( 82 | self, 83 | exc_type: type[BaseException] | None, 84 | exc_val: BaseException | None, 85 | exc_tb: TracebackType | None, 86 | ) -> None: 87 | if exc_type is None: 88 | _update_fake_quantizers(self.model) 89 | super().__exit__(exc_type, exc_val, exc_tb) 90 | 91 | 92 | def calibrate(model: GraphModuleOrDataParallel) -> CalibrationContext: 93 | """Calibration is performed using the supplied data within a `with` statement. 94 | 95 | `owlite.calibrate` performs Post-Training Quantization (PTQ) calibration on a model converted with 96 | `OwLite.convert`. This step is required to preserve the model's accuracy by carefully selecting the quantization 97 | hyperparameters (the scale and zero-point). PTQ calibration typically requires only a subset of the training data. 98 | 99 | Please review the 100 | [Calibrator](https://squeezebits.gitbook.io/owlite/python-api/owlite.calibrators/owlite.calib.calibrator) 101 | for technical details. 102 | 103 | Args: 104 | model(`GraphModuleOrDataParallel`): GraphModule or DataParallel model to calibrate.
105 | 106 | Returns: 107 | CalibrationContext 108 | 109 | ### Usage 110 | 111 | `owlite.calibrate` returns an `owlite.CalibrationContext` object that can be used with a `with` 112 | statement to perform calibration. The `CalibrationContext` prepares the model for calibration and updates 113 | the model's fake quantizers after calibration is complete. 114 | 115 | **Example** 116 | 117 | ```python 118 | with owlite.calibrate(model): 119 | for i, data in enumerate(train_loader): 120 | model(*data) # feed data to model and store information from it. 121 | # calculate fake quantizers step_sizes and zero_points 122 | 123 | # You should use the `model` outside of the block after the calibration 124 | torch.save(model.state_dict(), "model.pth") # the file path here is illustrative 125 | ``` 126 | 127 | In this example, `owlite.calibrate` creates an `owlite.CalibrationContext` 128 | that the `with` statement enters. The training data fetched from `train_loader` 129 | are then fed to the model to perform calibration. 130 | 131 | Note that you should continue writing your code outside of the `with` block since the fake quantizers 132 | in the model are updated as the `with` block exits. 133 | 134 | """ 135 | return CalibrationContext(model) 136 | --------------------------------------------------------------------------------
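To tie the calibration pieces above together, here is a hedged end-to-end sketch (not part of the source tree: `model` is assumed to be a `GraphModule` produced by OwLite's convert step, and `calib_loader` is a placeholder for a small representative data loader):

```python
import torch
import owlite

# `model` is assumed to be a GraphModule produced by OwLite's convert step;
# `owlite.calibrate` also accepts it wrapped in (Distributed)DataParallel.
model = torch.nn.parallel.DataParallel(model)

with owlite.calibrate(model):
    # Gradients are disabled in this block unless OWLITE_CALIBRATION_ENABLE_GRAD is set.
    for data in calib_loader:
        model(*data)  # forward passes feed the calibrators' forward hooks
# On exit, step_size/zero_point are computed and the fake quantizers are re-enabled.

torch.save(model.module.state_dict(), "calibrated.pth")  # the file path is illustrative
```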