├── LICENSE
├── README.md
├── examples
│   ├── README.md
│   ├── export_i8.py
│   ├── linear_bfp16.py
│   ├── model_custom.py
│   └── model_inject.py
├── ruff.toml
├── tcast
│   ├── __init__.py
│   ├── cast.py
│   ├── datatype.py
│   ├── extension.py
│   ├── injector.py
│   ├── modules.py
│   ├── number.py
│   ├── scale.py
│   ├── utils.py
│   └── version.txt
└── tests
    ├── __init__.py
    ├── test_bfp.py
    ├── test_bfp_export.py
    ├── test_mx.py
    ├── test_torch.py
    └── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 AMD ROCm™ Software
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # TensorCast (tcast)
4 | 
5 | TensorCast is a casting/quantization library in development, based on PyTorch 2.2+.
6 | 
7 | The scope of TensorCast is defining datatypes and converting tensors between datatypes. A "datatype" is a number format
8 | specification combined with an optional scaling specification. A "cast" is the conversion of a tensor from one datatype
9 | to another. A cast can produce compressed tensors that pack values and scaling information (an *actual* cast) or
10 | regular torch tensors (a *virtual* cast, or "fake quantization"). In version 1 of TensorCast, only virtual casting is supported.
11 | 
12 | The focus of TensorCast is on the OCP MX datatypes described in the
13 | [OCP MX Formats Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf),
14 | as well as additional datatypes pertinent to AMD and such other types as are needed to support research in low precision
15 | for machine learning. This focus includes not only everything needed to describe and convert the datatypes, but also reference
16 | code in various forms that can be used to verify implementations elsewhere.
17 | 
18 | Contributors:
19 | 
20 | - Eric Dellinger [@ericd](mailto:eric.dellinger@amd.com)
21 | - Alireza Khodamoradi [@alirezak](mailto:alireza.khodamoradi@amd.com)
22 | 
23 | ## Structure
24 | 
25 | The primary data structures are defined in the classes [NumberSpec](#numberspec), [ScaleSpec](#scalespec), and [DataType](#datatype).
26 | The conversion operators are in the static class [Cast](#cast).
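As a quick orientation, the following is a minimal end-to-end sketch using the package-level API documented below; the datatype and tensor shape are arbitrary examples.

```python
import torch
import tcast

# build a datatype from a number spec and a scale spec, then virtual-cast a tensor
dtype = tcast.datatype("e2m3fnuz", "e8m0_t32", name="mxfp6e2_demo")
x = torch.randn(4, 64)
y = tcast.cast(x, dtype)  # same shape and torch dtype as x, values quantized
```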
27 | 
28 | ### NumberSpec
29 | 
30 | Number format specifications are wholly independent of scaling, and simply define the characteristics of floating point, integer,
31 | and unsigned integer formats. Some additional support to ease the conversion process is included.
32 | 
33 | #### Inherent Types
34 | 
35 | Four number categories are represented: floating point, signed integer, unsigned integer, and exponent.
36 | 
37 | A floating point format includes the normal attributes: exponent width, mantissa width, bias or maximum unbiased exponent, and
38 | handling of infinite numbers and NaN. All floating point numbers are signed, with an implicit bit and subnormal support. Three
39 | modes of inf/NaN handling are supported: *ieee*, *fn*, and *fnuz*. The *ieee* mode is the default, and follows the standard IEEE
40 | model of reserving the highest biased exponent value for infinite numbers and NaN. The *fn* mode (**f**inite + **n**an) does not
41 | represent inf, and uses the highest representable value (all bits excluding sign are ones) as the NaN value. The *fnuz* mode
42 | (**f**inite + **n**an represented as **u**nsigned **z**ero) is the LLVM/MLIR standard, in which negative
43 | zero indicates NaN and positive zero is zero.
44 | 
45 | Although unsigned floats and disabling subnormals are potential future features, neither is currently planned. There is, however,
46 | a special case for describing the power of two scale factors defined by OCP, which are technically unsigned integers
47 | but semantically are biased exponents. Therefore, a mantissa width of zero and an *ieee* mode indicate an unsigned,
48 | biased power of two OCP scale. This is the exponent type mentioned above.
49 | 
50 | Integers are defined simply by the number of bits and the presence or absence of a sign bit.
51 | 
52 | #### NumberSpec String Encoding
53 | 
54 | A `NumberSpec` is created using a string encoding. In TensorCast, string encodings are used to define numbers, scales, and datatypes
55 | as an alternative to args and kwargs, although support for construction via parameters is a potential addition. The encoding is
56 | generally EMB format for floats and [u]intK for integers. Exceptions are made for common types (e.g. *float32*, *bfloat16*). The EMB
57 | format is of the form "eXmY[bZ]", where X, Y, and Z are the exponent width, mantissa width, and bias. If the bias is not specified,
58 | the default is (2**(X-1) - 1). **A notable exception occurs with torch dtypes** `torch.float8_e5m2fnuz` and `torch.float8_e4m3fnuz`, in
59 | which the biases are 16 and 8 respectively. These correspond to the Graphcore/Nanoo representations, and in the TensorCast EMB format
60 | are defined as `e5m2b16fnuz` and `e4m3b8fnuz`.
61 | 
62 | > Note: flexibility is built in, but testing is limited so far.
63 | >
64 | > - *uintK* and *intK* implemented for 2 <= K <= 32, tested for K in [4, 8, 16]
65 | > - exponent *eXm0* implemented for 4 <= X <= 8, tested for X = 8
66 | > - *eXmYbZ* implemented for 1 <= X <= 8 and 0 <= Y <= 23, testing limited to standard and minifloats
67 | 
68 | A `NumberSpec` can alternatively be created using a `torch.dtype` or the string representation thereof.
69 | Since there are different existing naming conventions, the string decoder accepts but strips away any leading "torch." or "float8_".
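As an illustration of the two construction paths (the attribute values in the comments follow from the defaults above):

```python
import torch
import tcast

fp8 = tcast.number("e4m3fn")        # EMB string encoding
half = tcast.number(torch.float16)  # or a torch.dtype, or its string name
print(fp8.ebits, fp8.mbits, fp8.bias)  # 4 3 7 (default bias is 2**(4-1) - 1)
print(half.torch_dtype)                # torch.float16
```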
70 | 
71 | #### Auxiliary Data
72 | 
73 | During construction, the number spec calculates commonly used information such as *emax* and *emin*, as well as *bits*, *max*, *min*,
74 | *smallest_normal*, and *eps*, in the manner of `torch.finfo` and `torch.iinfo`. Another value, *midmax*, is midway between *max* and
75 | 2\*\*(*emax* + 1), and can be used for alternative power of two scale selection.
76 | 
77 | #### NumberSpec Implementation Notes
78 | 
79 | Signed integer values with a power of two scale are typically implemented as fixed point, with a sign, a single integer bit, and
80 | a fractional component that is *bits* - 2 wide. This can be (and is) represented as a normally biased float with a single exponent
81 | bit and *bits* - 2 mantissa bits. Arithmetically, the exponent bit acts as the integer bit. This facilitates casting, while leaving
82 | the actual storage format (floating point, 2's complement, or sign magnitude) as a platform-specific implementation detail. As a
83 | result, integer number specs have both `torch.finfo` and `torch.iinfo` values.
84 | 
85 | If the number specification is an exact match to a torch.dtype (regardless of whether a torch dtype or name was used to create the
86 | spec), that dtype will be accessible through the NumberSpec's torch_dtype attribute.
87 | 
88 | ### ScaleSpec
89 | 
90 | Scaling specifications must differentiate between tensor scales, channel scales, tile scales, subtile scales, and individual scales
91 | (i.e. value exponents). A *tile* in TensorCast is also known as a "block" or "group", but here the term "tile" is used, matching the
92 | Microxcaling ([paper](https://arxiv.org/pdf/2310.10537.pdf), [github](https://github.com/microsoft/microxcaling.git)) terminology
93 | from Microsoft, who developed the OCP MX formats and did earlier work with
94 | [MSFP](https://proceedings.neurips.cc/paper/2020/file/747e32ab0fea7fbd2ad9ec03daa3f840-Paper.pdf), a block floating point format
95 | that is implemented in an upcoming AIE device.
96 | 
97 | Microsoft also introduced the *subtile* in
98 | [With Shared Microexponents, A Little Shifting Goes a Long Way](https://arxiv.org/abs/2302.08007),
99 | where the initial version of MX (without individual exponents) shared one or more scale offset bits to preserve precision. These
100 | datatypes include MX9, MX6, and MX4 (also known as BFP Prime), which are implemented in an upcoming AIE device. Those MX types are
101 | planned for version 2 of TensorCast.
102 | 
103 | In TensorCast V1, tensor, channel, tile, and individual scales are supported, but the first three are mutually exclusive, and the tile
104 | is one dimensional. Two dimensional tiles and hierarchical tensor/tile/subtile scaling are scheduled for V2.
105 | 
106 | #### Types of Data and Scales
107 | 
108 | Unsigned integer data is generally asymmetric, meaning that there is a zero point in addition to the scale factor. The scale factor
109 | is some form of float, and the zero point can be float or int (the latter guarantees precise 0.0 representation for reduced precision
110 | scales, with a slight loss of SQNR). Integer and exponent number specs are not supported for unsigned int scales.
111 | 
112 | Signed integer data is generally symmetric around zero, dropping the highest magnitude negative value to avoid bias in the quantization.
113 | The scale numberspec can be either a float or an exponent. Allowing an integer or unsigned bias adjustment in addition to exponent
114 | types is being considered for V2.
115 | Support for unbalanced (asymmetric) scaling is not planned.
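The sketch below is schematic only; it is not tcast's internal code (which also casts the scale factors themselves to their own number specs; see [Cast](#cast)), but it shows how an asymmetric uint8 scale/zero-point pair and a symmetric int8 scale would be chosen for one tile:

```python
import torch

t = torch.randn(32)  # one tile of values

# asymmetric uint8: float scale plus integer zero point (exact 0.0 representation)
maxint = 2**8 - 1
tmin, tmax = t.min().clamp_max(0.0), t.max().clamp_min(0.0)
scale = (tmax - tmin) / maxint
zero = torch.round(-tmin / scale)
q = scale * (torch.clamp(torch.round(t / scale) + zero, 0, maxint) - zero)

# symmetric int8: drop -128 so the range is balanced around zero
smax = 2**7 - 1
sscale = t.abs().max() / smax
qs = sscale * torch.clamp(torch.round(t / sscale), -smax, smax)
```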
116 | 
117 | Floating point data has an inherent individual scale, but the tensor/channel/tile scale is restricted to exponent numspecs in V1.
118 | A floating point scale is planned for V2.
119 | 
120 | Unscaled data, such as bfloat16, has no scale spec in its datatype. For scaled data, the scale spec currently holds either a tensor,
121 | channel, or tile scale. A tensor scale is simply a scalar (or two scalars for unsigned data), and is specified by defining the
122 | number spec(s) for the scale with no tile specification. A channel scale is a tile scale in which the tile spans the tensor size
123 | in the dimension of the scale, and is specified with a tile size of zero. A tile scale has a tile size and the dimension of the
124 | tile. The dimension defaults to -1 (the last dimension of the tensor).
125 | 
126 | A limitation in V1 as of now is that padding of tensors is not implemented, so the tensor size in the specified dimension must be a
127 | multiple of the tile size.
128 | 
129 | #### ScaleSpec String Encoding
130 | 
131 | The components of a scale are the scale number spec, an optional zero point number spec, and an optional tile spec; the string encoding
132 | of a scale specification is the concatenation of the string encodings of the constituents, joined by underscores.
133 | 
134 | The number specs are defined above. The tile spec is of the form "tXdY", where X is the size of the tile, a power of two between 2
135 | and 1024 (or 0 for channel scaling), and Y is the dimension of the tile. If the dimension is -1, "dY" is omitted. However, for channel
136 | scaling the tile spec must be included, even if it is only "t0".
137 | 
138 | #### ScaleSpec Implementation Notes
139 | 
140 | Until 2D tiles, subtiles, hierarchical scaling, and compression are implemented in V2, ScaleSpec is pretty simple. There are methods
141 | for reshaping the tensor to make PyTorch-based scale discovery a bit more straightforward.
142 | 
143 | ### DataType
144 | 
145 | The datatype is simply a number specification and an optional scale specification. If no scale spec is provided, the datatype is
146 | unscaled. Unscaled integer types are not supported, although support may make sense for int16.
147 | 
148 | #### Predefined DataTypes
149 | 
150 | PyTorch has dtypes in the torch namespace; TensorCast has predefined dtypes in the tcast namespace.
151 | 
152 | These are standard datatypes that are expected to be commonly used. Unscaled dtypes include the standard torch floating point types,
153 | including float8 types if supported in your PyTorch installation. Also included are unscaled versions of the MXFP types: `e3m2fnuz`,
154 | `e2m3fnuz`, and `e2m1fnuz`.
155 | 
156 | Tensor scaled types include uint16, int16, uint8, and int8 as well as the MXFP 8 and 6-bit numberspecs. The naming convention for
157 | the dtype names is the numberspec and a scale indicator, encoded as "f" for float16, "b" for bfloat16, and "e" for exponent scales.
158 | The uint types have two such indicators, the second being for the zero point, and the zero point has an "i" to indicate an int8
159 | zero point number spec instead of the disallowed "e". Floating point dtypes all have the "e" designation.
160 | 
161 | Tile scaled predefined types are the MXFP and MXINT types: mxfp8e5, mxfp8e4, mxfp6e3, mxfp6e2, mxfp4e2, mxint8, and mxint4, all of
162 | which have a tile size of 32. Also included is `bfp16`, which is like mxint8 but with a block size of 8. Other tile scaled dtypes
163 | are the uint8 and uint4 variants of ff, bb, fi, bi with tile size 32 and int8/int4 with float16 and bfloat16 scales.
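As an illustration, `bits_per_value` on a DataType reports the effective storage cost once the shared scale is amortized over the tile; the arithmetic in the comments follows from the specs above.

```python
import tcast

tcast.mxfp4e2                          # predefined: e2m1fnuz data, e8m0_t32 scale
print(tcast.mxfp4e2.bits_per_value())  # (4 * 32 + 8) / 32 = 4.25
print(tcast.bfp16.bits_per_value())    # (8 * 8 + 8) / 8 = 9.0
```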
164 | 
165 | #### DataType String Encoding
166 | 
167 | A datatype string is the concatenation of the number spec and the scale spec. Construction, however, is done not with an overall
168 | string, but by passing the number spec, the optional scale spec, and an optional concise name (e.g. *mxfp4e2*) to the DataType
169 | constructor.
170 | 
171 | ### Cast
172 | 
173 | Cast is a static class that contains the PyTorch code to perform rounding, scaling, and quantization. When the torch extension is
174 | implemented, the cast class will be able to route the cast call to the appropriate implementation (e.g. python, cpu C++, gpu C++)
175 | based on a CastMode, tensor characteristics, and available kernels.
176 | 
177 | Public methods generally correspond to the API methods in the tcast namespace. Private methods include \_vcast, \_round, \_cast_unscaled,
178 | and \_safe_frexp.
179 | 
180 | ### Package Level API
181 | 
182 | The classes in tcast need not be used directly. An API wraps essential functionality.
183 | 
184 | #### initialize
185 | 
186 | The initialize function currently just sets the default roundmode and/or scalemode so that overrides in the cast
187 | calls are not necessary. This is optional. Soon, there will also be a default for ComputeMode, which will
188 | select between PyTorch ops, C++/CPU extension, or C++/HIP-CUDA extension.
189 | 
190 | ```python
191 | import tcast
192 | tcast.initialize(roundmode="even", scalemode="max")
193 | ```
194 | 
195 | #### number
196 | 
197 | This function, given a valid code string, returns a NumberSpec, which can then be used to create a DataType.
198 | 
199 | ```python
200 | import tcast
201 | nspec = tcast.number("e5m6") # fp12, an abbreviated version of fp16
202 | ```
203 | 
204 | #### scale
205 | 
206 | This function, given a valid code string, returns a ScaleSpec, which can then be used to create a DataType.
207 | 
208 | ```python
209 | import tcast
210 | sspec = tcast.scale("e8m0_t32") # power of 2 scaling on the last dimension with tile size 32
211 | ```
212 | 
213 | #### datatype
214 | 
215 | This function, given a number spec (NumberSpec or valid numspec code string), an optional scale (ScaleSpec or valid
216 | scale spec code string), and an optional name for the datatype, returns a DataType, which can be passed to a cast function.
217 | If the name is omitted, one is manufactured.
218 | 
219 | ```python
220 | import tcast
221 | nspec, sspec = tcast.number("e5m6"), tcast.scale("e8m0_t32")
222 | dtype = tcast.datatype(nspec, sspec, name="e5m6_e32")
223 | # or
224 | dtype = tcast.datatype("e5m6", "e8m0_t32", name="e5m6_e32")
225 | ```
226 | 
227 | #### cast
228 | 
229 | This is intended to be a universal interface to the Cast class, but will be supplemented by task-specific cast methods,
230 | such as `sparse`. Under the current virtual cast limitation, no scale data needs to be returned, and the only parameters
231 | needed are the input `torch.Tensor` and `DataType`, with optional overrides for roundmode and scalemode.
232 | 
233 | ```python
234 | import torch, tcast
235 | x = tcast.cast(
236 |     torch.randn(1024, 1024, device="cuda", dtype=torch.float16),
237 |     tcast.datatype("e5m6", "e8m0_t32", name="e5m6_e32"),
238 |     roundmode="nearest",
239 |     scalemode="auto"
240 | )
241 | ```
242 | 
243 | Many common datatypes are predefined, which simplifies the calls:
244 | 
245 | ```python
246 | import torch, tcast
247 | x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
248 | c = tcast.cast(x, tcast.mxfp6e2)
249 | ```
250 | 
251 | #### sparse
252 | 
253 | A simple sparsity function is provided that preserves the M highest magnitude values from N values in a tile along
254 | the specified dimension. In practical hardware terms, the dimension would be the inner dimension of a GEMM, and M and N
255 | would be mandated by the hardware platform. Clearly, sparsity has many variations, and magnitude may not be the best
256 | qualifier, but this is a start.
257 | 
258 | ```python
259 | import tcast
260 | s = tcast.sparse(x, 8, 4) # keep the 4 highest magnitude values in each tile of 8
261 | ```
262 | 
263 | ## Development Plan
264 | 
265 | The feature set planned for version 1 is:
266 | 
267 | - Virtual (“fake”) casting (torch.float32, torch.float16, torch.bfloat16 in/out)
268 | - Signed and unsigned integer specifications uint**K** and int**K** for K in [3, 16]
269 | - Floating point e**X**m**Y***infnan* for **X** in [1, 8], **Y** in [0, 16], *infnan* "fn", "fnuz", or none
270 | - Exponent types e**X**m0 for **X** in [4, 8] (biased power of two scale factors)
271 | - Unscaled floating point types
272 | - Tensor scaled floating point types with exponent scale
273 | - Tensor scaled unsigned integers with float scales and either float or int zero points
274 | - Tensor scaled signed integers with float or exponent scales
275 | - Single channel scaled types, as described above in tensor scaling
276 | - Single dimension tile scaled types, as described above; tile sizes are powers of two with exponents in [2, 10]
277 | - M of N sparsity within tiles or subtiles
278 | - round modes: nearest, even, zero, and stochastic
279 | - scale modes (exponent selection): max and midmax
280 | - PyTorch python operations for casting
281 | - *C++ (CPU) casting in PyTorch extension*
282 | - *C++ (HIP/CUDA) casting in PyTorch extension*
283 | 
284 | The feature set planned for version 2 is:
285 | 
286 | - Actual (compressed) casting
287 | - 2D tile specifications
288 | - 1D and 2D subtile specifications with scale offsets from tile scale
289 | - tile and subtile-specific number specifications with selection metadata ("multicast")
290 | - lookup table number specs
291 | - MSFP MX9/MX6/MX4 datatype support
292 | - hierarchical scaling (tensor + tile + subtile + individual exponents)
293 | 
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # Examples
4 | 
5 | A collection of examples showing how to use TensorCast.
6 | 
7 | ## Manual method
8 | The user has the most flexibility quantizing each tensor manually. [linear_bfp16](linear_bfp16.py) shows the use of the **tcast.cast** function to quantize a linear layer's weights directly, and its input and output via torch hooks.
9 | 
10 | ## Replacing torch modules with customized modules
11 | The user can replace torch modules with customized modules that include the hooks required to quantize a layer's weights, input, and output.
12 | [linear_bfp16](linear_bfp16.py) shows how to use the **tcast.TorchInjector** for a single **torch.nn.Linear** layer, and [model_custom](model_custom.py) shows an example for a torchvision model. In both examples, the user can pass a dictionary, **tcast_specs**, to describe the data types.
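The spec dictionary is keyed by tensor role; a minimal sketch (the keys shown are the ones the customized modules look for):

```python
import tcast

bfp16ebs8_t = tcast.DataType("int8", "e8m0_t8", "bfp16ebs8_t")
tcast_specs = {
    "weight_dtype": bfp16ebs8_t,
    "input_dtype": bfp16ebs8_t,
    "output_dtype": bfp16ebs8_t,
}
tcast.TorchInjector(tcast_specs)  # patches torch.nn.Linear and torch.nn.Conv2d
```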
13 | 
14 | ## Replacing model layers with customized layers
15 | The user can replace a model's layers with customized layers that include the hooks required to quantize weights, input, and output.
16 | [model_inject](model_inject.py) shows how to use the **tcast.MixedPrecisionInjector**. In this example, the user can pass a dictionary, **tcast_specs**, to describe the data types for each layer's weights, input, and output.
17 | 
18 | 
--------------------------------------------------------------------------------
/examples/export_i8.py:
--------------------------------------------------------------------------------
1 | import tcast
2 | import torch
3 | import numpy as np
4 | 
5 | def interleave_bfp16(mantissas, exponents, block):  # pack each block as [shared exponent, mantissas...]
6 |     num_blocks = int(np.ceil(mantissas.numel()/block))
7 |     array_size = num_blocks*(block+1)  # one extra int8 per block for the exponent
8 |     i8_array = np.zeros(array_size, dtype=np.int8)
9 |     indx = 0
10 |     for b in range(num_blocks):
11 |         e = int(exponents[b*block])  # e8m0 biased exponent, shared by the block
12 |         i8_array[indx] = np.array(e).astype(np.int8)
13 |         indx += 1
14 |         for v in range(block):
15 |             m = int(mantissas[b*block+v])
16 |             i8_array[indx] = m
17 |             indx += 1
18 |     return i8_array
19 | 
20 | if __name__ == "__main__":
21 |     block = 16
22 |     tensor = (torch.randint(-2048, 2048, (1, block))*torch.randn(1, block)).float()
23 |     tcast_dt = tcast.datatype("int8", "e8m0_t"+str(block), export=True)
24 |     tensor_tcast_d = tcast.cast(tensor, dtype=tcast_dt, roundmode="even")
25 |     tensor_tcast_m = tensor_tcast_d["x_export"].view(-1)
26 |     tensor_tcast_e = tensor_tcast_d["meta_export"].view(-1)
27 |     i8_array = interleave_bfp16(tensor_tcast_m, tensor_tcast_e, block)
28 |     print(tensor_tcast_d["x"])
29 |     print("values: ", end="")
30 |     exp = int(np.array(i8_array[0]).astype(np.uint8))  # recover the biased exponent stored in an int8 slot
31 |     for i in range(1, i8_array.size):
32 |         print(i8_array[i]*(2**(exp-127)), end=", ")
33 |     print("\n")
34 | 
--------------------------------------------------------------------------------
/examples/linear_bfp16.py:
--------------------------------------------------------------------------------
1 | import tcast
2 | import torch
3 | import copy
4 | 
5 | # Manual method using hooks
6 | def manual_inject(in_fp32, layer_fp32, tcast_specs):
7 |     layer_q = copy.deepcopy(layer_fp32)
8 | 
9 |     if 'weight_dtype' in tcast_specs:
10 |         with torch.no_grad():
11 |             layer_q.weight = torch.nn.parameter.Parameter(tcast.cast(layer_fp32.weight, dtype=tcast_specs['weight_dtype']))
12 | 
13 |     if 'input_dtype' in tcast_specs:
14 |         def forward_input_hook(module, input):
15 |             input[0].copy_(tcast.cast(input[0], dtype=tcast_specs['input_dtype']))
16 |         layer_q.register_forward_pre_hook(forward_input_hook)
17 | 
18 |     if 'output_dtype' in tcast_specs:
19 |         def forward_output_hook(module, input, output):
20 |             output.copy_(tcast.cast(output, dtype=tcast_specs['output_dtype']))
21 |         layer_q.register_forward_hook(forward_output_hook)
22 | 
23 |     return layer_q(in_fp32)
24 | 
25 | if __name__ == "__main__":
26 |     bfp16ebs8_t = tcast.DataType("int8", "e8m0_t8", "bfp16ebs8_t")
27 |     layer_fp32 = torch.nn.Linear(64, 64)
28 |     input_fp32 = torch.randn(64, 64)
29 | 
30 |     output_fp32 = layer_fp32(input_fp32)
31 | 
32 |     # Manual Method
33 |     tcast_specs = {}
34 |     output_bfp16_1 = manual_inject(input_fp32, layer_fp32, tcast_specs)
35 |     print(f"l2 norm error none: {torch.norm(output_fp32 - output_bfp16_1)}")
36 | 
37 |     tcast_specs = {'weight_dtype': bfp16ebs8_t}
38 |     output_bfp16_1 = manual_inject(input_fp32, layer_fp32, tcast_specs)
39 |     print(f"l2 norm error weights-only: {torch.norm(output_fp32 - output_bfp16_1)}")
40 | 
41 |     tcast_specs = {'weight_dtype': bfp16ebs8_t, 'input_dtype': bfp16ebs8_t}
42 |     output_bfp16_1 = manual_inject(input_fp32, layer_fp32, tcast_specs)
43 |     print(f"l2 norm error weights and input: {torch.norm(output_fp32 - output_bfp16_1)}")
44 | 
45 |     tcast_specs = {'weight_dtype': bfp16ebs8_t, 'input_dtype': bfp16ebs8_t, 'output_dtype': bfp16ebs8_t}
46 |     output_bfp16_1 = manual_inject(input_fp32, layer_fp32, tcast_specs)
47 |     print(f"l2 norm error weight, input, and output: {torch.norm(output_fp32 - output_bfp16_1)}")
48 | 
49 |     # Modify the pytorch modules
50 |     tcast.TorchInjector(tcast_specs)
51 |     layer_fp32_2 = torch.nn.Linear(64, 64)  # now a tcast-injected Linear
52 |     # same weights and biases
53 |     with torch.no_grad():
54 |         layer_fp32_2.weight = torch.nn.parameter.Parameter(tcast.cast(layer_fp32.weight, dtype=bfp16ebs8_t))
55 |         layer_fp32_2.bias = layer_fp32.bias
56 | 
57 |     output_bfp16_2 = layer_fp32_2(input_fp32)
58 |     print(f"Method2\nl2 norm error weight, input, and output: {torch.norm(output_fp32 - output_bfp16_2)}")
59 | 
--------------------------------------------------------------------------------
/examples/model_custom.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import argparse
4 | import os
5 | from torch.utils.data import DataLoader
6 | import numpy as np
7 | import tcast
8 | from tqdm import tqdm
9 | import pickle, random, time
10 | import timm
11 | 
12 | def set_seed(seed):
13 |     np.random.seed(seed)
14 |     torch.random.manual_seed(seed)
15 | 
16 | def seed_worker(worker_id):
17 |     worker_seed = torch.initial_seed() % 2**32
18 |     np.random.seed(worker_seed)
19 |     random.seed(worker_seed)
20 | 
21 | def imagenet_loader(args):
22 |     batch_size = args.batch_size
23 |     num_workers = args.num_worker
24 |     data_dir = args.data_dir
25 | 
26 |     # see https://pytorch.org/vision/stable/models.html for setting transform
27 |     transform = torchvision.transforms.Compose([
28 |         torchvision.transforms.Resize(256),
29 |         torchvision.transforms.CenterCrop(224),
30 |         torchvision.transforms.ToTensor(),
31 |         torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
32 |                                          std=[0.229, 0.224, 0.225])
33 |     ])
34 | 
35 |     train_ds = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'ILSVRC2012_img_train'),
36 |                                                 transform=transform)
37 | 
38 |     if not os.path.isfile(os.path.join(data_dir, 'wnid_to_label.pickle')):
39 |         with open(os.path.join(data_dir, 'wnid_to_label.pickle'), 'wb') as f:
40 |             pickle.dump(train_ds.class_to_idx, f)
41 | 
42 |     test_ds = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'ILSVRC2012_img_val'),
43 |                                                transform=transform)
44 |     g = torch.Generator()
45 |     g.manual_seed(0)
46 | 
47 |     train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=num_workers,
48 |                           worker_init_fn=seed_worker, generator=g)
49 |     test_dl = DataLoader(test_ds, min(batch_size, 1024), shuffle=False,
50 |                          num_workers=num_workers)
51 | 
52 |     return train_dl, test_dl
53 | 
54 | def test_accuracy(model, test_dl, device, topk=(1, )):
55 |     """
56 |     Compute top k accuracy on testing dataset
57 |     """
58 |     start = time.time()
59 |     model.to(device)
60 |     model.eval()
61 |     maxk = max(topk)
62 |     topk_count = np.zeros((len(topk), len(test_dl)))
63 | 
64 |     for j, (x_test, target) in enumerate(tqdm(test_dl, 
"Evaluation")): 65 | with torch.no_grad(): 66 | y_pred = model(x_test.to(device)) 67 | topk_pred = torch.topk(y_pred, maxk, dim=1).indices 68 | target = target.to(device).view(-1, 1).expand_as(topk_pred) 69 | correct_mat = (target == topk_pred) 70 | 71 | for i, k in enumerate(topk): 72 | topk_count[i, j] = correct_mat[:, :k].reshape(-1).sum().item() 73 | 74 | topk_accuracy = topk_count.sum(axis=1) / len(test_dl.dataset) 75 | model.cpu() 76 | end = time.time() 77 | print(f'Time taken for inference on {args.model} model is {end - start} seconds.') 78 | print(f'Top-1 accuracy for {args.model} model is {topk_accuracy[0]}.') 79 | print(f'Top-5 accuracy for {args.model} model is {topk_accuracy[1]}.') 80 | 81 | def get_model(args): 82 | if hasattr(torchvision.models, args.model): 83 | return getattr(torchvision.models, args.model)(pretrained=True) 84 | elif args.model in timm.list_models(): 85 | return timm.create_model(args.model, pretrained=True) 86 | else: 87 | raise ValueError(f"Model {args.model} not found in torchvision or timm.") 88 | 89 | if __name__ == '__main__': 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument( 92 | '--model', type=str, help='torchvision model to load; pass `resnet50`, `mobilenet_v3_large`, or `inception_v4`') 93 | parser.add_argument( 94 | '--seed', type=int, default=0, help='Seed for sampling the calibration data.') 95 | parser.add_argument( 96 | '--data-dir', type=str, help='imagenet directory') 97 | parser.add_argument( 98 | '--batch-size', default=64, type=int, help='eval batch size') 99 | parser.add_argument( 100 | '--num-worker', default=2, type=int, help='number of workers for loading dataset') 101 | 102 | args = parser.parse_args() 103 | if torch.cuda.is_available(): 104 | args.device = torch.device('cuda:0') 105 | else: 106 | args.device = torch.device('cpu') 107 | set_seed(args.seed) 108 | 109 | train_loader, test_loader = imagenet_loader(args) 110 | 111 | model = get_model(args) 112 | test_accuracy(model, test_loader, args.device, (1, 5)) 113 | 114 | # Using the custom layers in the model 115 | bfp16ebs8_t = tcast.DataType("int8", "e8m0_t8", "bfp16ebs8_t") 116 | tcast_specs = {'weight_dtype': bfp16ebs8_t, 'input_dtype': bfp16ebs8_t, 'output_dtype': bfp16ebs8_t} 117 | tcast.TorchInjector(tcast_specs) 118 | 119 | model_custom = get_model(args) 120 | test_accuracy(model_custom, test_loader, args.device, (1, 5)) -------------------------------------------------------------------------------- /examples/model_inject.py: -------------------------------------------------------------------------------- 1 | import tcast 2 | import torch 3 | import copy 4 | 5 | class NonTcastModel(torch.nn.Module): 6 | def __init__(self): 7 | super(NonTcastModel, self).__init__() 8 | self.conv = torch.nn.Conv2d(3, 16, 3) 9 | self.fc1 = torch.nn.Linear(16, 32) 10 | self.fc2 = torch.nn.Linear(32, 8) 11 | 12 | def forward(self, x): 13 | x = self.conv(x) 14 | x = self.fc1(x) 15 | x = self.fc2(x) 16 | return x 17 | 18 | 19 | 20 | if __name__ == "__main__": 21 | 22 | model = NonTcastModel() 23 | input_fp32 = torch.randn(3, 8, 18) 24 | output_fp32 = model(input_fp32) 25 | 26 | bfp16ebs8_t = tcast.DataType("int8", "e8m0_t8", "bfp16ebs8_t") 27 | bfp16ebs16_t = tcast.DataType("int8", "e8m0_t16", "bfp16ebs16_t") 28 | tcast_specs = {'fc1': {'weight_dtype': bfp16ebs8_t}, 'fc2': {'weight_dtype': bfp16ebs16_t}, 'conv': {'weight_dtype': bfp16ebs8_t}} 29 | 30 | model_mixed = tcast.MixedPrecisionInjector(model, tcast_specs) 31 | output_bfp16 = model_mixed(input_fp32) 32 | 33 | print("l2 
norm error: ", torch.norm(output_fp32 - output_bfp16).item()) 34 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | 2 | target-version = "py310" 3 | line-length = 130 4 | indent-width = 4 5 | 6 | 7 | [lint] 8 | select = ["E", "F", "I", "W"] 9 | extend-select = [ 10 | "E501", "UP", "D", "B" 11 | ] 12 | ignore = ["D102", "D107", "D105", "B028"] 13 | 14 | [lint.per-file-ignores] 15 | "__init__.py" = ["F401", "E402", "F403"] 16 | "**/{tests,docs,tools}/*" = ["E402"] 17 | 18 | [lint.pydocstyle] 19 | convention = "google" 20 | 21 | [format] 22 | quote-style = "double" 23 | indent-style = "space" 24 | skip-magic-trailing-comma = true 25 | line-ending = "auto" 26 | docstring-code-format = false 27 | docstring-code-line-length = "dynamic" 28 | 29 | [lint.isort] 30 | force-sort-within-sections = true 31 | known-first-party = ["cadre", "tcast"] 32 | -------------------------------------------------------------------------------- /tcast/__init__.py: -------------------------------------------------------------------------------- 1 | """TensorCast: Conversion and compression of arbitrary datatypes.""" 2 | # tcast/__init__.py: package 3 | 4 | from pathlib import Path 5 | 6 | import torch 7 | 8 | from .cast import Cast, ComputeMode, RoundMode, ScaleMode 9 | from .datatype import DataType 10 | from .extension import Extension 11 | from .number import NumberSpec 12 | from .scale import ScaleSpec 13 | from .utils import ( 14 | TensorCastInternalError, 15 | check_literal, 16 | is_float8_available, 17 | is_float8_fnuz_available, 18 | is_gpu_available, 19 | is_installed, 20 | is_power_of_2, 21 | printoptions, 22 | ) 23 | from .injector import TorchInjector, MixedPrecisionInjector 24 | 25 | __version__ = Path(__file__).with_name("version.txt").open().read().strip() 26 | 27 | 28 | def initialize( 29 | roundmode: RoundMode = None, 30 | scalemode: ScaleMode = None, 31 | compmode: ComputeMode = None, 32 | ext_create: bool = False, 33 | ext_name: str = None, 34 | ext_path: Path = None, 35 | ext_exec: bool = False, 36 | ext_cpu_only: bool = False, 37 | ext_verbose: bool = False, 38 | ): 39 | """For overriding default modes and/or customizing torch cpp_extension.""" 40 | if roundmode is not None: 41 | check_literal(roundmode, RoundMode) 42 | Cast.roundmode = roundmode 43 | if scalemode is not None: 44 | check_literal(scalemode, ScaleMode) 45 | Cast.scalemode = scalemode 46 | if compmode is not None or ext_create: 47 | if compmode: 48 | check_literal(compmode, ComputeMode) 49 | Cast.compmode = compmode 50 | if ext_create or (compmode is not None and compmode != "torch"): 51 | if Cast.extension is not None: 52 | if ext_create: 53 | raise RuntimeError("tcast extension has already been created.") 54 | else: 55 | Cast.extension = Extension(ext_name, ext_path, ext_exec, ext_cpu_only, ext_verbose) 56 | 57 | 58 | def number(code: str) -> NumberSpec: 59 | """Create a number spec from a string code.""" 60 | return NumberSpec(code) 61 | 62 | 63 | def scale(code: str) -> ScaleSpec: 64 | """Create a scale spec from a string code.""" 65 | return ScaleSpec(code) 66 | 67 | 68 | def datatype(nspec: str | NumberSpec, sspec: str | ScaleSpec = None, name: str = None, export: bool = False) -> DataType: 69 | """Create an implicitly scaled or unscaled datatype from a number spec code.""" 70 | return DataType(nspec, sspec, name, export) 71 | 72 | 73 | def cast(x: torch.Tensor, dtype: DataType, roundmode: RoundMode 
= None, scalemode: ScaleMode = None) -> torch.Tensor:
74 |     """Virtual cast a tensor to a scaled or unscaled datatype."""
75 |     return Cast.cast(x, dtype, roundmode, scalemode)
76 | 
77 | 
78 | def sparse(x: torch.Tensor, stile: int, dense: int, dim: int = -1) -> torch.Tensor:
79 |     """Apply M-of-N structured sparsity within tiles along the given dimension."""
80 |     return Cast.sparse(x, stile, dense, dim)
81 | 
82 | 
83 | #####
84 | ##### Predefined datatypes accessible as tcast.float32, tcast.mxfp8e4, etc
85 | ##### NOTE: bias defaults to 2^(ebits-1) - 1 unless overridden in the eXmYbZ descriptor.
86 | #####       Exceptions are external in torch.float8_e5m2fnuz and torch.float8_e4m3fnuz.
87 | #####
88 | 
89 | ### unscaled
90 | 
91 | # torch tensor dtypes
92 | float32 = DataType(torch.float32)
93 | float16 = DataType(torch.float16)
94 | bfloat16 = DataType(torch.bfloat16)
95 | if is_float8_available():
96 |     float8_e5m2 = DataType(torch.float8_e5m2)
97 |     float8_e4m3fn = DataType(torch.float8_e4m3fn)
98 | if is_float8_fnuz_available():
99 |     float8_e5m2fnuz = DataType(torch.float8_e5m2fnuz)  # bias is 16, nonstandard, matches MI300
100 |     float8_e4m3fnuz = DataType(torch.float8_e4m3fnuz)  # bias is 8, nonstandard, matches MI300
101 | 
102 | # 5-bit exponent
103 | e5m2 = DataType("e5m2")
104 | e5m2fnuz = DataType("e5m2fnuz")  # bias is 15, DOES NOT MATCH torch.float8_e5m2fnuz
105 | e5m2b16fnuz = DataType("e5m2b16fnuz")  # bias of 16, DOES MATCH torch.float8_e5m2fnuz and MI300
106 | binary8p3 = DataType("e5m2b16fnuz")  # IEEE P3109 bias 16 matches torch.float8_e5m2fnuz and MI300
107 | # 4-bit exponent
108 | e4m3fnuz = DataType("e4m3fnuz")  # bias is 7, DOES NOT MATCH torch.float8_e4m3fnuz
109 | e4m3fn = DataType("e4m3fn")
110 | e4m3b8fnuz = DataType("e4m3b8fnuz")  # bias is 8, DOES MATCH torch.float8_e4m3fnuz and MI300
111 | binary8p4 = DataType("e4m3b8fnuz")  # IEEE P3109 bias 8 matches torch.float8_e4m3fnuz and MI300
112 | # 3-bit exponent
113 | binary8p5 = DataType("e3m4b4fnuz")  # IEEE P3109 bias 4 consistent with other P3109 float8 types
114 | e3m3fnuz = DataType("e3m3fnuz")  # bias 3
115 | e3m2fnuz = DataType("e3m2fnuz")  # bias 3
116 | # 2-bit exponent
117 | e2m3fnuz = DataType("e2m3fnuz")  # bias 1
118 | e2m1fnuz = DataType("e2m1fnuz")  # bias 1
119 | 
120 | ### tensor scaled
121 | 
122 | uint16_ff = DataType("uint16", "float16_float16", "uint16_ff")
123 | uint16_bb = DataType("uint16", "bfloat16_bfloat16", "uint16_bb")
124 | int16_f = DataType("int16", "float16", "int16_f")
125 | int16_b = DataType("int16", "bfloat16", "int16_b")
126 | int16_e = DataType("int16", "e8m0", "int16_e")
127 | uint8_ff = DataType("uint8", "float16_float16", "uint8_ff")
128 | uint8_bb = DataType("uint8", "bfloat16_bfloat16", "uint8_bb")
129 | uint8_fi = DataType("uint8", "float16_int8", "uint8_fi")
130 | uint8_bi = DataType("uint8", "bfloat16_int8", "uint8_bi")
131 | int8_f = DataType("int8", "float16", "int8_f")
132 | int8_b = DataType("int8", "bfloat16", "int8_b")
133 | int8_e = DataType("int8", "e8m0", "int8_e")
134 | e5m2_e = DataType("e5m2", "e8m0", "e5m2_e")
135 | e5m2z_e = DataType("e5m2fnuz", "e8m0", "e5m2z_e")
136 | e4m3_e = DataType("e4m3fn", "e8m0", "e4m3_e")
137 | e4m3z_e = DataType("e4m3fnuz", "e8m0", "e4m3z_e")
138 | e3m2_e = DataType("e3m2fnuz", "e8m0", "e3m2_e")
139 | e2m3_e = DataType("e2m3fnuz", "e8m0", "e2m3_e")
140 | 
141 | # MX, tile size 32, exponent scale (BFP tile size 8)
142 | 
143 | mxfp8e5 = DataType("e5m2", "e8m0_t32", "mxfp8e5")
144 | mxfp8e4 = DataType("e4m3fn", "e8m0_t32", "mxfp8e4")
145 | mxfp6e3 = DataType("e3m2fnuz", "e8m0_t32", "mxfp6e3")
"e8m0_t32", "mxfp6e3") 146 | mxfp6e2 = DataType("e2m3fnuz", "e8m0_t32", "mxfp6e2") 147 | mxfp4e2 = DataType("e2m1fnuz", "e8m0_t32", "mxfp4e2") 148 | mxint8 = DataType("int8", "e8m0_t32", "mxint8") 149 | mxint4 = DataType("int4", "e8m0_t32", "mxint4") 150 | bfp16 = DataType("int8", "e8m0_t8", "bfp16") 151 | 152 | # Float-scaled integer, tile size 32 153 | 154 | uint4_ff32 = DataType("uint4", "float16_float16_t32", "uint4_ff32") 155 | uint4_bb32 = DataType("uint4", "bfloat16_bfloat16_t32", "uint4_bb32") 156 | uint4_fi32 = DataType("uint4", "float16_int8_t32", "uint4_fi32") 157 | uint4_bi32 = DataType("uint4", "bfloat16_int8_t32", "uint4_bi32") 158 | int4_f32 = DataType("int4", "float16_t32", "int4_f32") 159 | int4_b32 = DataType("int4", "bfloat16_t32", "int4_b32") 160 | # 1-bit exponent (integer) 161 | e1m6b1fnuz = DataType("e1m6b1fnuz", "e8m0_t8", "e1m6_e32") # bias overriden to 1, equivalent to int8 162 | e1m2b1fnuz = DataType("e1m2b1fnuz", "e8m0_t8", "e1m2_e32") # bias overriden to 1, equivalent to int4 163 | -------------------------------------------------------------------------------- /tcast/cast.py: -------------------------------------------------------------------------------- 1 | """TensorCast: Conversion and compression of arbitrary datatypes.""" 2 | # tcast/cast.py: casting methods 3 | 4 | from typing import ClassVar, Literal 5 | 6 | import torch 7 | 8 | from .datatype import DataType 9 | from .extension import Extension 10 | from .number import NumberSpec 11 | from .utils import check_literal 12 | 13 | RoundMode = Literal["even", "nearest", "zero", "stochastic"] 14 | ScaleMode = Literal["max", "auto"] 15 | ComputeMode = Literal["cpu", "gpu", "torch"] 16 | 17 | 18 | class Cast: 19 | """Static class with implementations in PyTorch.""" 20 | 21 | roundmode: ClassVar[RoundMode] = "even" # the current rounding mode 22 | scalemode: ClassVar[ScaleMode] = "max" # the current scale selection mode 23 | compmode: ClassVar[ComputeMode] = "torch" # use PyTorch operators for both CPU and GPU 24 | extension: ClassVar[Extension] = None 25 | 26 | @classmethod 27 | def _round(cls, x: torch.Tensor) -> torch.Tensor: 28 | if cls.roundmode == "stochastic": 29 | return torch.sign(x) * torch.trunc(torch.abs(x) + torch.rand_like(x)) 30 | if cls.roundmode == "even": 31 | return torch.round(x) 32 | if cls.roundmode == "nearest": 33 | # return torch.trunc(x + x.sign() * 0.5) torch thinks 0.4999999701976776123046875 + 0.5 is 1.0 34 | return torch.where(x.abs().frac() == 0.5, torch.trunc(x + x.sign() * 0.5), x.round()) 35 | return torch.trunc(x) 36 | 37 | @classmethod 38 | def _safe_frexp(cls, x: torch.Tensor) -> torch.Tensor: 39 | return x.float().add(torch.finfo(torch.float32).eps * (x == 0.0)).frexp().exponent 40 | 41 | @classmethod 42 | def _cast_unscaled(cls, x: torch.Tensor, nspec: NumberSpec) -> torch.Tensor: 43 | assert x.is_floating_point and nspec.is_float 44 | valexp = (cls._safe_frexp(x) - 1).clamp_min(nspec.emin) 45 | rscale = (nspec.mbits - valexp).exp2() 46 | x = cls._round(x * rscale).div(rscale).clamp(-nspec.maxfloat, nspec.maxfloat) 47 | return x 48 | 49 | @classmethod 50 | def _vcast(cls, x: torch.Tensor, dtype: DataType) -> torch.Tensor: 51 | """Virtual cast, atomic.""" 52 | if cls.compmode != "torch": 53 | if cls.extension is None: 54 | cls.extension = Extension() 55 | if cls.extension.has_operation("vcast", cls.compmode): 56 | return cls.extension.exec_operation(x, dtype, "vcast", cls.compmode) 57 | xtype = x.dtype 58 | x = x.clone() 59 | if dtype.is_unscaled: 60 | return 
cls._cast_unscaled(x, dtype.nspec).to(xtype)
61 |         assert dtype and dtype.sspec
62 |         dim = None if dtype.sspec.is_tensor else -1
63 |         x = dtype.sspec.reshape_tensor(x)
64 |         eps = torch.finfo(torch.float32).eps
65 |         if dtype.nspec.is_uint:
66 |             assert dtype.sspec.scale.is_float
67 |             tmin, tmax = torch.aminmax(x, dim=dim, keepdim=True)
68 |             tmin, tmax = tmin.clamp_max(0.0), tmax.clamp_min(0.0)
69 |             scale = cls._cast_unscaled((tmax - tmin) / dtype.nspec.maxint, dtype.sspec.scale)
70 |             if dtype.sspec.zero.is_float:
71 |                 zero = cls._cast_unscaled(-tmin, dtype.sspec.zero)
72 |                 x = scale * cls._round((x + zero) / scale.clamp_min(eps)).clamp(0, dtype.nspec.maxint) - zero
73 |             else:
74 |                 zero = cls._round(-tmin / scale)
75 |                 x = scale * ((cls._round(x / scale.clamp_min(eps)) + zero).clamp(0, dtype.nspec.maxint) - zero)
76 |         elif dtype.sspec.scale.is_float:
77 |             scale = (
78 |                 cls._cast_unscaled(1.0 / x.abs().amax(dim=dim, keepdim=True), dtype.sspec.scale)
79 |                 if dtype.nspec.is_float
80 |                 else cls._cast_unscaled(x.abs().amax(dim=dim, keepdim=True) / dtype.nspec.maxint, dtype.sspec.scale)
81 |             )
82 |             x = cls._round(x / scale.clamp_min(eps)).clamp(dtype.nspec.minint, dtype.nspec.maxint) * scale
83 |         else:
84 |             # get po2 scale (mx style) and scale the tensor into dtype-representable range
85 |             maxexp = cls._safe_frexp(x).amax(dim=dim, keepdim=True) - 1 - dtype.nspec.emax
86 |             if dtype.nspec.ebits > 1 and cls.scalemode == "auto":
87 |                 maxexp[(x * (-maxexp).exp2()).abs().amax(dim=dim) > dtype.nspec.midmax] += 1
88 |             nscale = (-maxexp).exp2()
89 |             x *= nscale  # scale x into range of the target dtype
90 |             valexp = (cls._safe_frexp(x) - 1).clamp_min(dtype.nspec.emin)  # get the independent exponents, clipped to emin
91 |             rscale = (dtype.nspec.mbits - valexp).exp2()
92 |             x = cls._round(x * rscale).div(rscale).clamp(-dtype.nspec.maxfloat, dtype.nspec.maxfloat)
93 |             x /= nscale
94 |         if dtype.export:
95 |             x_export = nscale * x * rscale
96 |             meta_export = dtype.sspec.scale.bias - torch.log2(nscale * rscale)
97 |             x = dtype.sspec.revert_tensor(x)
98 |             return {'x': x.to(xtype), 'x_export': x_export, 'meta_export': meta_export}
99 | 
100 |         x = dtype.sspec.revert_tensor(x)
101 |         return x.to(xtype)
102 | 
103 |     @classmethod
104 |     def sparse(cls, tensor: torch.Tensor, stile: int, dense: int, dim: int = -1) -> torch.Tensor:
105 |         """Simple structured sparsity, M of N, where M is dense values retained out of N."""
106 |         if tensor.shape[dim] % stile != 0:
107 |             raise NotImplementedError(
108 |                 f"Tensor dim {dim} ({tensor.shape[dim]}) must be evenly divisible by sparse tile size ({stile})"
109 |             )
110 |         assert dense > 0 and dense < stile
111 |         tshape = tensor.transpose(dim, -1).shape  # shape after moving the tiled dim last
112 |         t = tensor.clone().transpose(dim, -1).reshape(-1, stile)
113 |         idx = t.abs().argsort(dim=-1, descending=True)  # rank by magnitude, but keep signed values in t
114 |         premask = torch.full(t.shape, True, dtype=torch.bool, device=tensor.device)
115 |         mask = torch.empty_like(premask)
116 |         premask[..., :dense] = False
117 |         mask.scatter_(-1, idx, premask)
118 |         return t.masked_fill_(mask, 0.0).reshape(tshape).transpose(dim, -1)
119 | 
120 |     @classmethod
121 |     def cast(
122 |         cls,
123 |         x: torch.Tensor,
124 |         dtype: DataType,
125 |         roundmode: RoundMode = None,
126 |         scalemode: ScaleMode = None,
127 |         compmode: ComputeMode = None,
128 |     ) -> torch.Tensor:
129 |         """Generic cast interface."""
130 |         # currently not so generic as we are only doing virtual cast
131 |         # roundmode and scalemode are optional overrides
132 |         check_literal(roundmode, RoundMode, True)
133 |         check_literal(scalemode, ScaleMode, True)
134 | 
check_literal(compmode, ComputeMode, True) 135 | saveround, savescale, savecomp = cls.roundmode, cls.scalemode, cls.compmode 136 | cls.roundmode = roundmode if roundmode else saveround 137 | cls.scalemode = scalemode if scalemode else savescale 138 | cls.compmode = compmode if compmode else savecomp 139 | x = cls._vcast(x, dtype) 140 | cls.roundmode = saveround 141 | cls.scalemode = savescale 142 | cls.compmode = savecomp 143 | return x 144 | -------------------------------------------------------------------------------- /tcast/datatype.py: -------------------------------------------------------------------------------- 1 | """TensorCast: Conversion and compression of arbitrary datatypes.""" 2 | # tcast/datatype.py: combines number spec and scale spec to define a datatype 3 | 4 | from dataclasses import dataclass 5 | 6 | from .number import NumberSpec 7 | from .scale import ScaleSpec 8 | 9 | 10 | @dataclass 11 | class DataType: 12 | """Everything needed to define a scaled or unscaled datatype.""" 13 | 14 | _name: str = None 15 | nspec: NumberSpec = None 16 | sspec: ScaleSpec = None 17 | is_unscaled: bool = False 18 | is_tensor: bool = False 19 | is_channel: bool = False 20 | is_tile: bool = False 21 | export: bool = False 22 | 23 | def __init__(self, nspec: str | NumberSpec, sspec: str | ScaleSpec = None, name: str = None, export: bool = False): 24 | self.nspec = nspec if isinstance(nspec, NumberSpec) else NumberSpec(nspec) 25 | self._name = name 26 | self.sspec = sspec if isinstance(sspec, ScaleSpec) else ScaleSpec(sspec) if sspec else None 27 | self.is_unscaled = self.sspec is None 28 | self.is_tensor = self.sspec is not None and self.sspec.is_tensor 29 | self.is_channel = self.sspec is not None and self.sspec.is_channel 30 | self.is_tile = self.sspec is not None and self.sspec.is_tile 31 | assert int(self.is_unscaled) + int(self.is_tensor) + int(self.is_channel) + int(self.is_tile) == 1 32 | self._check() 33 | self.export = export 34 | 35 | def _check(self): 36 | prefix = f"DataType: '{self.name}'" 37 | if self.is_unscaled: 38 | if not self.nspec.is_float: 39 | raise ValueError(f"{prefix} only float data can be cast unscaled.") 40 | return 41 | if self.nspec.is_exponent: 42 | raise ValueError(f"{prefix} exponent number spec is only permitted as a scale.") 43 | if self.sspec.zero and not self.nspec.is_uint: 44 | raise ValueError(f"{prefix} zero spec in scale is incompatible with float or signed int data spec.") 45 | if self.nspec.is_uint and not self.sspec.zero: 46 | raise ValueError(f"{prefix} uint data requires a zero point.") 47 | if self.nspec.is_float and not self.sspec.scale.is_exponent: 48 | raise NotImplementedError(f"{prefix} only exponent scaling is supported for float data.") 49 | if self.nspec.is_int and not (self.sspec.scale.is_exponent or self.sspec.scale.is_float): 50 | raise ValueError(f"{prefix} int data requires either a float or exponent scale.") 51 | 52 | @property 53 | def name(self): 54 | if self._name: 55 | return self._name 56 | return str(self) 57 | 58 | def bits_per_value(self) -> float: 59 | """Given scaling metadata, how many bits per value?""" 60 | if not self.is_tile: 61 | return float(self.nspec.bits) 62 | bits = self.nspec.bits * self.sspec.tile + self.sspec.scale.bits + (self.sspec.zero.bits if self.sspec.zero else 0) 63 | return bits / self.sspec.tile 64 | 65 | def __str__(self): 66 | s = self.nspec.name 67 | if self.sspec: 68 | s += "_" + self.sspec.name 69 | return s 70 | 71 | @classmethod 72 | def valid(cls, ncode: str, scode: str = None) -> bool: 73 
| """Returns True if the code generates a valid datatype.""" 74 | try: 75 | cls(ncode, scode) 76 | return True 77 | except ValueError: 78 | return False 79 | -------------------------------------------------------------------------------- /tcast/extension.py: -------------------------------------------------------------------------------- 1 | """TensorCast: Conversion and compression of arbitrary datatypes.""" 2 | # tcast/extension.py: loads torch extension 3 | 4 | from collections.abc import Callable 5 | from pathlib import Path 6 | from types import ModuleType 7 | 8 | import torch 9 | from torch.utils.cpp_extension import load 10 | 11 | from .datatype import DataType 12 | from .utils import is_float8_available, is_float8_fnuz_available 13 | 14 | 15 | class Extension: 16 | """Wrapper for PyTorch extension.""" 17 | 18 | def __init__( 19 | self, extname: str = None, srcpath: Path = None, exec_only: bool = False, cpu_only: bool = False, verbose: bool = True 20 | ): 21 | if extname is None: 22 | extname = "tcast_extension" 23 | cpu_only = cpu_only or not torch.cuda.is_available() 24 | if isinstance(srcpath, str): 25 | srcpath = Path(str) 26 | elif not isinstance(srcpath, Path): 27 | srcpath = Path(__file__).parent.with_name("csrc") 28 | if not srcpath.is_dir(): 29 | raise RuntimeError(f"Extension: cannot find source path {str(srcpath)}") 30 | srcfiles = self.get_source_files(srcpath, cpu_only) 31 | if not srcfiles: 32 | raise RuntimeError(f"Extension: no source files (.cpp/.cxx/.c/.cu) found in {str(srcpath)}") 33 | cpu_flags = self.get_cpu_flags() 34 | if is_float8_available(): 35 | cpu_flags.append("-DFLOAT8_AVAILABLE_CPU") 36 | is_rocm = ( 37 | hasattr(torch.version, "hip") and torch.version.hip is not None and torch.utils.cpp_extension.ROCM_HOME is not None 38 | ) 39 | gpu_flags = self.get_gpu_flags(is_rocm, verbose) if not cpu_only else [] 40 | extension = self.load_extension(extname, srcfiles, cpu_flags, gpu_flags, exec_only, cpu_only, verbose) 41 | if isinstance(extension, ModuleType): 42 | self.extension, self.exec_path = extension, None 43 | print(f"Extension: loaded module {extension.__name__}") 44 | elif isinstance(extension, Path): 45 | self.extension, self.exec_path = None, extension 46 | print(f"Extension: lpath to executable is {str(extension)}") 47 | else: 48 | raise RuntimeError(f"Extension: failed to load, recieved {str(extension)}") 49 | 50 | def get_source_files(self, srcpath: Path, cpu_only: bool) -> list[Path]: 51 | """Get the source files. 
52 |         srcfiles = list(srcpath.glob("*.cpp")) + list(srcpath.glob("*.cxx")) + list(srcpath.glob("*.c"))
53 |         if not cpu_only:
54 |             srcfiles += list(srcpath.glob("*.cu"))
55 |         return srcfiles
56 | 
57 |     def get_gpu_flags(self, is_rocm: bool, verbose: bool = False) -> list[str]:
58 |         """Get any GPU flags we might need."""
59 |         if not is_rocm:
60 |             flags = ["-O4", "--gpu-architecture=native"]
61 |             if verbose:
62 |                 flags += ["--ptxas-options=-v", "-v"]
63 |         else:
64 |             flags = ["-O3"]
65 |         return flags
66 | 
67 |     def get_cpu_flags(self) -> list[str]:
68 |         """See what AVX is available, if any."""
69 |         flags, result = ["-O3", "-march=native"], None
70 |         try:
71 |             from cpuinfo import get_cpu_info
72 | 
73 |             info = get_cpu_info()
74 |             if info["vendor_id_raw"] == "AuthenticAMD":
75 |                 result = info["flags"]
76 |         except ImportError:
77 |             from subprocess import check_output
78 | 
79 |             result = check_output("lscpu", shell=True).decode("utf-8").strip().lower().split()
80 |         finally:
81 |             if result:
82 |                 flags += [f"-m{i}" for i in result if i.startswith("avx")]
83 |         return flags
84 | 
85 |     def load_extension(
86 |         self,
87 |         name: str,
88 |         srcfiles: list[Path],
89 |         cflags: list[str],
90 |         gflags: list[str],
91 |         exec_only: bool,
92 |         cpu_only: bool,
93 |         verbose: bool,
94 |     ) -> ModuleType | Path:
95 |         return load(
96 |             name=name,
97 |             sources=srcfiles,
98 |             extra_cflags=cflags,
99 |             extra_cuda_cflags=gflags,
100 |             verbose=verbose,
101 |             is_standalone=exec_only,
102 |             with_cuda=not cpu_only,
103 |             is_python_module=not exec_only,
104 |         )
105 | 
106 |     def list_operations(self) -> list[str]:
107 |         """Print the operations that are supported."""
108 |         ops = []
109 |         for k, v in self.__dict__.items():
110 |             if isinstance(v, Callable) and (k.endswith("_cpu") or k.endswith("_gpu")):
111 |                 print(k)
112 |                 ops.append(k)
113 |         return ops
114 | 
115 |     def has_operation(self, name: str, platform: str) -> bool:
116 |         """Return True if the extension exists and has the op for the requested cpu/gpu platform."""
117 |         if not self.extension:
118 |             return False
119 |         assert platform in ("cpu", "gpu")
120 |         return hasattr(self, f"{name}_{platform}")
121 | 
122 |     def exec_operation(self, tensor: torch.Tensor, dtype: DataType, name: str, platform: str, **kwargs) -> torch.Tensor:
123 |         """Run an operation that has (at least) a tensor."""
124 |         assert self.has_operation(name, platform)
125 |         tplatform = "cpu" if tensor.is_cpu else "gpu"
126 |         if tplatform != platform:
127 |             raise NotImplementedError(f"Extension: tensor is on {tplatform}, but op '{name}' is for {platform}")
128 |         return getattr(self, f"{name}_{platform}")(tensor, dtype, **kwargs)
129 | 
--------------------------------------------------------------------------------
/tcast/injector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tcast
3 | import copy
4 | from .modules import Linear, Conv2d
5 | 
6 | def TorchInjector(tcast_specs):
7 |     def torch_to_tcast_module(cls):
8 |         def __init__(self, *args, **kwargs):
9 |             cls.__init__(self, *args, tcast_specs=tcast_specs, **kwargs)
10 |         return type(f'{cls.__name__}_tcast', (cls,), {'__init__': __init__})
11 | 
12 |     for torchm, tcastm in SUPPORTED_MODULES.items():
13 |         torch.nn.__dict__[torchm] = torch_to_tcast_module(tcastm)
14 | 
15 | def MixedPrecisionInjector(model, tcast_specs):
16 |     model_mixed = copy.deepcopy(model)
17 |     for name, module in model_mixed.named_modules():
18 |         if isinstance(module, torch.nn.Linear):
19 |             if name in tcast_specs:
20 |                 setattr(model_mixed, name, Linear(module.in_features, module.out_features, module.bias is not None, tcast_specs=tcast_specs[name], pre_weights=module.weight, pre_bias=module.bias))  # setattr (not __dict__) so nn.Module registers the replacement as a submodule
21 |         elif isinstance(module, torch.nn.Conv2d):
22 |             if name in tcast_specs:
23 |                 setattr(model_mixed, name, Conv2d(module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.groups, module.bias is not None, module.padding_mode, tcast_specs=tcast_specs[name], pre_weights=module.weight, pre_bias=module.bias))
24 |     return model_mixed
25 | 
26 | SUPPORTED_MODULES = {
27 |     "Linear": Linear,
28 |     "Conv2d": Conv2d,
29 | }
30 | 
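A minimal usage sketch for the two injectors above; module names come from `model.named_modules()`, and the flat `Sequential` model here is an assumption for brevity (nested module names containing dots would need recursive traversal).

```python
import torch
import tcast

bfp16ebs8_t = tcast.DataType("int8", "e8m0_t8", "bfp16ebs8_t")

# per-layer specs keyed by module name, consumed by MixedPrecisionInjector
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))
model_q = tcast.MixedPrecisionInjector(model, {"0": {"weight_dtype": bfp16ebs8_t}})

# global specs: TorchInjector patches torch.nn, affecting layers created afterwards
tcast.TorchInjector({"weight_dtype": bfp16ebs8_t})
layer_q = torch.nn.Linear(8, 8)
```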
--------------------------------------------------------------------------------
/tcast/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tcast
3 | 
4 | class Linear(torch.nn.Linear):
5 |     def __init__(
6 |         self,
7 |         in_features,
8 |         out_features,
9 |         bias=True,
10 |         tcast_specs=None,
11 |         pre_weights=None,
12 |         pre_bias=None,
13 |     ):
14 |         super().__init__(
15 |             in_features,
16 |             out_features,
17 |             bias=bias)
18 | 
19 |         self.specs = tcast_specs if tcast_specs is not None else {}
20 | 
21 |         if pre_weights is not None:
22 |             with torch.no_grad():
23 |                 self.weight = pre_weights
24 |         if pre_bias is not None:
25 |             with torch.no_grad():
26 |                 self.bias = pre_bias
27 | 
28 |         if 'weight_dtype' in self.specs:
29 |             with torch.no_grad():
30 |                 self.weight = torch.nn.parameter.Parameter(tcast.cast(self.weight, dtype=self.specs['weight_dtype']))
31 | 
32 |         if 'bias_dtype' in self.specs:
33 |             with torch.no_grad():
34 |                 self.bias = torch.nn.parameter.Parameter(tcast.cast(self.bias, dtype=self.specs['bias_dtype']))
35 | 
36 |     def forward(self, inputs):
37 |         if 'input_dtype' in self.specs:
38 |             inputs = tcast.cast(inputs, dtype=self.specs['input_dtype'])
39 | 
40 |         if 'custom_accumulation' in self.specs:
41 |             # the following could be modified by a method.
42 |             outputs = torch.nn.functional.linear(
43 |                 inputs,
44 |                 self.weight,
45 |                 bias=self.bias,
46 |             )
47 |         else:
48 |             outputs = super().forward(inputs)
49 | 
50 |         if 'output_dtype' in self.specs:
51 |             outputs = tcast.cast(outputs, dtype=self.specs['output_dtype'])
52 | 
53 |         return outputs
54 | 
55 | 
56 | class Conv2d(torch.nn.Conv2d):
57 |     def __init__(
58 |         self,
59 |         in_channels,
60 |         out_channels,
61 |         kernel_size,
62 |         stride=1,
63 |         padding=0,
64 |         dilation=1,
65 |         groups=1,
66 |         bias=True,
67 |         padding_mode='zeros',
68 |         tcast_specs=None,
69 |         pre_weights=None,
70 |         pre_bias=None,
71 |     ):
72 | 
73 |         super(Conv2d, self).__init__(
74 |             in_channels,
75 |             out_channels,
76 |             kernel_size,
77 |             stride=stride,
78 |             padding=padding,
79 |             dilation=dilation,
80 |             groups=groups,
81 |             bias=bias,
82 |             padding_mode=padding_mode
83 |         )
84 | 
85 |         self.specs = tcast_specs if tcast_specs is not None else {}
86 | 
87 |         if pre_weights is not None:
88 |             with torch.no_grad():
89 |                 self.weight = pre_weights
90 |         if pre_bias is not None:
91 |             with torch.no_grad():
92 |                 self.bias = pre_bias
93 | 
94 |         if 'weight_dtype' in self.specs:
95 |             with torch.no_grad():
96 |                 self.weight = torch.nn.parameter.Parameter(tcast.cast(self.weight, dtype=self.specs['weight_dtype']))
97 | 
98 |         if 'bias_dtype' in self.specs:
99 |             with torch.no_grad():
100 |                 self.bias = torch.nn.parameter.Parameter(tcast.cast(self.bias, dtype=self.specs['bias_dtype']))
101 | 
102 |     def forward(self, inputs):
103 | 
104 |         if 'input_dtype' in self.specs:
105 |             inputs = tcast.cast(inputs, dtype=self.specs['input_dtype'])
106 | 
107 |         if 'custom_accumulation' in self.specs:
108 |             # the following could be modified by a method.
109 |             #return super()._conv_forward(inputs, self.weight, self.bias)
110 |             outputs = torch.nn.functional.conv2d(
111 |                 inputs,
112 |                 self.weight,
113 |                 bias=self.bias,
114 |                 stride=self.stride,
115 |                 padding=self.padding,
116 |                 dilation=self.dilation,
117 |                 groups=self.groups,
118 |             )
119 |         else:
120 |             outputs = super().forward(inputs)
121 | 
122 |         if 'output_dtype' in self.specs:
123 |             outputs = tcast.cast(outputs, dtype=self.specs['output_dtype'])
124 | 
125 |         return outputs
126 | 
--------------------------------------------------------------------------------
/tcast/number.py:
--------------------------------------------------------------------------------
1 | """TensorCast: Conversion and compression of arbitrary datatypes."""
2 | # tcast/number.py: number format specification
3 | 
4 | from dataclasses import dataclass
5 | import re
6 | from typing import Literal
7 | 
8 | import torch
9 | 
10 | from .utils import is_float8_available, is_float8_fnuz_available
11 | 
12 | InfNan = Literal["ieee", "fn", "fnuz"]
13 | 
14 | MX2NUMSPEC = dict(
15 |     mxfp8e5="e5m2",
16 |     mxfp8e4="e4m3fn",
17 |     mxfp6e3="e3m2fnuz",
18 |     mxfp6e2="e2m3fnuz",
19 |     mxfp4e2="e2m1fnuz",
20 |     mxint8="int8 or e1m6fnuz",
21 |     mxint4="int4 or e1m2fnuz",
22 |     bfp16="int8 or e1m6fnuz",
23 | )
24 | 
25 | MXUNSUPPORTED = ("mx9", "mx6", "mx3")
26 | 
27 | 
28 | @dataclass
29 | class NumberSpec:
30 |     """Specification for an unscaled number format."""
31 | 
32 |     _name: str = None
33 |     bits: int = None
34 |     ebits: int = None
35 |     mbits: int = None
36 |     bias: int = None
37 |     signed: bool = True
38 |     infnan: InfNan = "ieee"
39 |     is_float: bool = False
40 |     is_int: bool = False
41 |     is_uint: bool = False
42 |     is_exponent: bool = False
43 |     emax: int = None
44 |     emin: int = None
45 |     maxfloat: float = None
46 |     smallest_normal: float = None
47 |     smallest_subnormal: float = None
48 |     eps: float = None
49 |     midmax: float = None
50 |     maxint: int = None
51 |     minint: int = None
52 |     torch_dtype: torch.dtype = None
53 | 
54 |     def __init__(self, code: str | torch.dtype):
55 |         self._decode(code)
56 |         self._check()
57 | 
58 |     @property
59 |     def name(self) -> str:
60 |         """Returns the name. May be overloaded in a subclass."""
61 |         return self._name
62 | 
63 |     @property
64 |     def max(self) -> int | float:
65 |         """Returns maxfloat for floats, maxint for integers."""
66 |         return self.maxfloat if self.is_float else self.maxint
67 | 
68 |     @property
69 |     def min(self) -> int | float:
70 |         """Returns -maxfloat for floats, minint for integers."""
71 |         return -self.maxfloat if self.is_float else self.minint
72 | 
73 |     @property
74 |     def tiny(self) -> int | float:
75 |         """Returns smallest_normal, as in torch.finfo."""
76 |         return self.smallest_normal
77 | 
78 |     def get_number_line(self) -> list[float]:
79 |         """All possible values for this number specification."""
80 |         if self.bits > 8:
81 |             raise ValueError(f"NumberSpec: too many number line values for {self.bits} bits")
82 |         if not (self.is_float or self.ebits == 1):
83 |             raise ValueError("NumberSpec: number line must be for float or float-like numbers.")
84 |         # get the non-negative numbers then mirror for negatives, giving all 2^bits values, including 2 zeros
85 |         line = [i * self.smallest_subnormal for i in range(2**self.mbits)]  # subnormals
86 |         for e in range(self.emax - self.emin + 1):
87 |             line += [(self.smallest_normal + i * self.smallest_subnormal) * 2 ** e for i in range(2**self.mbits)]
88 |         return [-v for v in reversed(line)] + line
89 | 
90 |     def _decode(self, code: str | torch.dtype) -> None:
91 |         """Sets fields based on input code string."""
92 |         # 1. Handle the case of the spec defined by a torch.dtype
93 |         if isinstance(code, torch.dtype):
94 |             self.torch_dtype = code
95 |             code = str(code)
96 |         code = code.lower().removeprefix("torch.")
97 |         if ttype := getattr(torch, code, False):
98 |             if self.torch_dtype is None and isinstance(ttype, torch.dtype):
99 |                 self.torch_dtype = ttype
100 |         bias_hack = int(code.startswith("float8") and code.endswith("fnuz"))  # implicit non-standard bias for torch fnuz types
101 |         name = code = code.removeprefix("float8_")
102 |         # 2. Check for implicitly scaled datatypes
103 |         if name in MX2NUMSPEC:
104 |             tilesize = 8 if name.startswith("bfp") else 32
105 |             raise ValueError(
106 |                 f"\tNumberSpec: code '{name}' is a scaled datatype rather than a number format.\n"
107 |                 f"\tThe equivalent NumberSpec name is '{MX2NUMSPEC[name]}', to be used in conjunction\n"
108 |                 f"\twith a ScaleSpec name of 'e8m0_t{tilesize}' when creating the DataType."
109 |             )
110 |         elif name in MXUNSUPPORTED:
111 |             raise NotImplementedError(
112 |                 f"\tNumberSpec: code '{name}' is a scaled datatype rather than a number format.\n"
113 |                 f"\tMX types (a/k/a bfp prime) are not yet supported."
114 |             )
115 |         # 3. Handle float/bfloat/int/uint style string codes for widths > 8
116 |         if m := re.fullmatch(r"(float|bfloat|int|uint)(\d+)", name):
117 |             prefix, bits = m.group(1), int(m.group(2))
118 |             if prefix == "bfloat":
119 |                 self.ebits, self.mbits, self.bias = 8, bits - 9, 127
120 |             elif prefix == "float":
121 |                 if bits > 16:
122 |                     self.ebits, self.mbits, self.bias = 8, bits - 9, 127
123 |                 elif bits > 8:
124 |                     self.ebits, self.mbits, self.bias = 5, bits - 6, 15
125 |                 else:
126 |                     raise ValueError(f"NumberSpec: code '{name}': float8 and smaller formats require EMB format.")
127 |             elif prefix[0] == "u":
128 |                 self.ebits, self.mbits, self.bias, self.signed, self.infnan = 0, bits, None, False, None
129 |             else:
130 |                 self.ebits, self.mbits, self.bias, self.infnan = 1, bits - 2, 1, "fnuz"
131 |         # 4. Handle EMB style string codes
132 |         if self.mbits is None:
133 |             if m := re.fullmatch(r"e(\d+)m(\d+)(b\d+)?(fn|fnuz)?", name):
134 |                 self.ebits, self.mbits, self.bias = int(m.group(1)), int(m.group(2)), m.group(3)
135 |                 self.infnan = m.group(4) or "ieee"
136 |                 self.signed = not (self.infnan == "ieee" and self.mbits == 0)
137 |                 if self.bias is None:
138 |                     self.bias = 2 ** (self.ebits - 1) - 1 + bias_hack
139 |                 else:
140 |                     self.bias = int(self.bias[1:])
141 |         if self.ebits is None:
142 |             raise ValueError(f"NumberSpec: code {code} is not a valid format.")
143 |         self._name = name
144 | 
145 |         # 5. Fill in the remaining fields in the spec from ebits/mbits/signed/infnan
146 |         self.is_int = self.ebits == 1 and self.bias == 1 and self.signed and self.infnan == "fnuz"
147 |         self.is_float = not self.is_int and self.signed and self.infnan is not None
148 |         self.is_uint = self.bias is None and not self.signed and self.infnan is None
149 |         self.is_exponent = self.ebits > 0 and self.mbits == 0 and not self.signed and self.infnan == "ieee"
150 |         assert self.is_float or self.is_exponent or self.is_int or self.is_uint
151 |         self.bits = self.ebits + self.mbits + int(self.signed)
152 |         self.maxint = 2 ** (self.bits - int(self.signed)) - 1
153 |         self.minint = -self.maxint if self.signed else 0
154 |         if self.is_float or self.is_int:
155 |             self.emax = 2**self.ebits - 1 - self.bias - int(self.infnan == "ieee")
156 |             self.emin = 1 - self.bias
157 |             self.maxfloat = 2**self.emax * (2.0 - (1 + int(self.infnan == "fn")) * 2 ** (-self.mbits))
158 |             self.midmax = (2 ** (self.emax + 1) - self.maxfloat) / 2.0 + self.maxfloat
159 |             self.eps = 2**-self.mbits
160 |             self.smallest_normal = 2**self.emin
161 |             self.smallest_subnormal = self.smallest_normal * self.eps
162 | 
163 |         # 6. See if what we have matches a torch.dtype
164 |         if self.torch_dtype is None:
165 |             self.torch_dtype = self._find_torch_dtype()
166 | 
167 |     def _find_torch_dtype(self) -> torch.dtype | None:
168 |         if self.bits == 32 and self.ebits == 8 and self.mbits == 23 and self.bias == 127 and self.infnan == "ieee":
169 |             return torch.float32
170 |         if self.bits == 16 and self.ebits == 5 and self.mbits == 10 and self.bias == 15 and self.infnan == "ieee":
171 |             return torch.float16
172 |         if self.bits == 16 and self.ebits == 8 and self.mbits == 7 and self.bias == 127 and self.infnan == "ieee":
173 |             return torch.bfloat16
174 |         if self.bits == 8 and is_float8_available():
175 |             if self.ebits == 5 and self.mbits == 2 and self.bias == 15 and self.infnan == "ieee":
176 |                 return torch.float8_e5m2
177 |             if self.ebits == 4 and self.mbits == 3 and self.bias == 7 and self.infnan == "fn":
178 |                 return torch.float8_e4m3fn
179 |         if self.bits == 8 and is_float8_fnuz_available():
180 |             if self.ebits == 5 and self.mbits == 2 and self.bias == 16 and self.infnan == "fnuz":
181 |                 return torch.float8_e5m2fnuz
182 |             if self.ebits == 4 and self.mbits == 3 and self.bias == 8 and self.infnan == "fnuz":
183 |                 return torch.float8_e4m3fnuz
184 |         return None
185 | 
186 |     def _check(self) -> None:
187 |         # TODO(ericd): additional checks for bad/unsupported combinations of values that parsed correctly
188 |         if self.bits > 32:
189 |             raise NotImplementedError(f"NumberSpec: ({self.name}) bit widths > 32 are unsupported")
190 |         if not self.is_uint:
191 |             if self.ebits < 1 or self.ebits > 8:
192 |                 raise ValueError(f"NumberSpec: ({self.name}) ebits '{self.ebits}' needs to be in [1, 8]")
193 | 
194 |     @classmethod
195 |     def valid(cls, code: str | torch.dtype) -> bool:
196 |         """Checks validity without raising an exception."""
197 |         try:
198 |             cls(code)
199 |             return True
200 |         except (ValueError, NotImplementedError):
201 |             return False
202 | 
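To illustrate the decoder above, a few example codes and the fields they produce. This is a sketch; the values follow from the formulas in step 5:

from tcast.number import NumberSpec

fp8 = NumberSpec("e4m3fn")        # OCP FP8: bias 7, emax 8, maxfloat 448
nanoo = NumberSpec("e4m3b8fnuz")  # explicit bias 8, matches torch.float8_e4m3fnuz
i8 = NumberSpec("int8")           # treated as e1m6b1fnuz, so bits == 8
exp = NumberSpec("e8m0")          # OCP power-of-two scale (the exponent type)
assert fp8.maxfloat == 448.0 and i8.is_int and exp.is_exponent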
--------------------------------------------------------------------------------
/tcast/scale.py:
--------------------------------------------------------------------------------
1 | """TensorCast: Conversion and compression of arbitrary datatypes."""
2 | # tcast/scale.py: scaling format specification
3 | 
4 | from dataclasses import dataclass
5 | import re
6 | 
7 | import torch
8 | 
9 | from .number import NumberSpec
10 | 
11 | 
12 | @dataclass
13 | class ScaleSpec:
14 |     """Specifies scaling method for a given NumberSpec."""
15 | 
16 |     name: str = None
17 |     tile: int = None
18 |     dim: int = None
19 |     is_tensor: bool = False
20 |     is_channel: bool = False
21 |     is_tile: bool = False
22 |     scale: NumberSpec = None
23 |     zero: NumberSpec = None
24 |     shape: tuple[int] = None
25 | 
26 |     def __init__(self, code: str):
27 |         self._decode(code)
28 |         self._check()
29 | 
30 |     def _valid_tile(self, tile: int):
31 |         return tile in (0, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024)
32 | 
33 |     def _set_nspec(self, nspec: str):
34 |         if self.scale is None:
35 |             self.scale = NumberSpec(nspec)
36 |         elif self.zero is None:
37 |             self.zero = NumberSpec(nspec)
38 |         else:
39 |             raise ValueError(f"ScaleSpec: more than two NumberSpecs provided in string code '{self.name}'.")
40 | 
41 |     def _set_tile(self, tile: int, dim: int = None):
42 |         if not self._valid_tile(tile):
43 |             raise ValueError(f"ScaleSpec: '{tile}' is not a supported tile size.")
44 |         if self.tile is not None:
45 |             raise ValueError(f"ScaleSpec: second tile spec found in '{self.name}'")
46 |         self.tile = tile
47 |         self.dim = -1 if dim is None else dim
48 |         self.is_channel, self.is_tile = self.tile == 0, self.tile != 0
49 | 
50 |     def _decode(self, code: str) -> None:
51 |         """Sets fields based on input string code."""
52 |         self.name = code = code.lower()
53 |         # The string is one or two NumberSpec codes (scale and optional zero point) and
54 |         # an optional tile spec code, separated by underscores. The tilespec is tXdY, where X is the tile size
55 |         # and Y is the dimension of the tile. If omitted, this is a tensor scale; if present with X=0, it is a
56 |         # channel scale. If the dimension is omitted, it defaults to -1, or the last dimension of the tensor.
57 |         for segment in code.split("_"):
58 |             if NumberSpec.valid(segment):
59 |                 self._set_nspec(segment)
60 |             elif m := re.fullmatch(r"t(\d+)(d\d+)?", segment):
61 |                 if self.tile is not None:
62 |                     raise ValueError(f"ScaleSpec: spurious code segment '{segment}' found after tile spec.")
63 |                 self._set_tile(int(m.group(1)), int(m.group(2)[1:]) if m.group(2) else -1)
64 |             else:
65 |                 raise ValueError(f"ScaleSpec: '{segment}' is neither a valid number nor a valid tile specification.")
66 |         if self.tile is None:
67 |             self.is_tensor = True
68 |         if self.scale is None:
69 |             raise ValueError(f"ScaleSpec: scale spec '{code}' has no valid NumberSpec name for scale.")
70 | 
71 |     def _check(self):
72 |         # TODO(ericd): additional checks for bad/unsupported combinations of values that parsed correctly
73 |         prefix = f"ScaleSpec: '{self.name}'"
74 |         if not self.scale:
75 |             raise ValueError(f"{prefix} does not specify a scale.")
76 |         if self.zero:
77 |             if not self.scale.is_float:
78 |                 raise ValueError(f"{prefix} asymmetric scaling requires a float scale")
79 |             if self.zero.is_exponent or self.zero.is_uint:
80 |                 raise ValueError(f"{prefix} asymmetric scaling requires a float or int zero point type")
81 | 
82 |     def reshape_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
83 |         """Reshape and/or transpose tensor for scaling."""
84 |         if not self.is_tile:
85 |             return tensor
86 |         tensor = tensor.transpose(self.dim, -1)
87 |         self.shape = tensor.shape
88 |         return tensor.reshape(-1, self.tile)
89 | 
90 |     def revert_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
91 |         """Revert tensor shape after scaling."""
92 |         if not self.is_tile:
93 |             return tensor
94 |         assert self.shape
95 |         tensor = tensor.reshape(self.shape)
96 |         self.shape = None
97 |         return tensor.transpose(self.dim, -1)
98 | 
99 |     @classmethod
100 |     def valid(cls, code: str) -> bool:
101 |         """Checks validity without raising an exception."""
102 |         try:
103 |             cls(code)
104 |             return True
105 |         except (ValueError, NotImplementedError):
106 |             return False
107 | 
108 | 
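A few scale-spec codes as parsed by _decode above (a sketch following the tXdY grammar documented in the comment):

from tcast.scale import ScaleSpec

mx = ScaleSpec("e8m0_t32")              # shared power-of-two exponent per 32-element tile
ch = ScaleSpec("float16_t0")            # per-channel float16 scale (tile size 0)
asym = ScaleSpec("float16_int8_t32d0")  # float scale plus int8 zero point, tiled on dim 0
per_tensor = ScaleSpec("float32")       # no tile spec, so a tensor-level scale
assert mx.is_tile and mx.tile == 32 and mx.dim == -1 and per_tensor.is_tensor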
--------------------------------------------------------------------------------
/tcast/utils.py:
--------------------------------------------------------------------------------
1 | """TensorCast: Conversion and compression of arbitrary datatypes."""
2 | # tcast/utils.py: utility functions for tensorcast package
3 | 
4 | import importlib
5 | from typing import Any
6 | 
7 | import torch
8 | 
9 | 
10 | class TensorCastInternalError(RuntimeError):
11 |     """For internal errors being reported as such."""
12 | 
13 |     def __init__(self, error_message: str):
14 |         super().__init__(error_message)
15 | 
16 | 
17 | def is_installed(name: str) -> bool:
18 |     """Without importing, see if a package is present."""
19 |     return importlib.util.find_spec(name) is not None
20 | 
21 | 
22 | def is_gpu_available():
23 |     """Check to see if a GPU is present."""
24 |     return torch.cuda.is_available()
25 | 
26 | 
27 | def is_float8_available():
28 |     """Check to see if float8 is present in this version of PyTorch."""
29 |     return hasattr(torch, "float8_e4m3fn")
30 | 
31 | def is_float8_fnuz_available():
32 |     """Check to see if fnuz float8 is present in this version of PyTorch."""
33 |     return hasattr(torch, "float8_e4m3fnuz")
34 | 
35 | 
36 | def printoptions(precision: int = 8):
37 |     """Set PyTorch printoptions to something useful."""
38 |     torch.set_printoptions(precision=precision, sci_mode=False)
39 | 
40 | 
41 | def litvals(lit) -> tuple[Any]:
42 |     """Get a tuple of Literal values."""
43 |     return lit.__args__
44 | 
45 | 
46 | def check_literal(s: str, lit, none_ok=False) -> None:
47 |     """Check that a string is a literal from a given Literal."""
48 |     vals = litvals(lit)
49 |     if s not in vals and (s is not None or not none_ok):
50 |         raise ValueError(f"{s} is not a valid {lit.__name__} value: {','.join(vals)}")
51 |     return s
52 | 
53 | def is_power_of_2(n: int) -> bool:
54 |     """Check for a positive power of 2."""
55 |     return n > 0 and (n & (n - 1)) == 0
--------------------------------------------------------------------------------
/tcast/version.txt:
--------------------------------------------------------------------------------
1 | 0.1.2
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ROCm/tensorcast/8a0945bf9cb4a62622da3cc1898a18b969f7b31b/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_bfp.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tcast
3 | import torch
4 | from tests.utils import compare_2, tensor_to_bfp
5 | 
6 | @pytest.mark.parametrize("datatype", ['bfp16', 'bfp15', 'bfp14', 'bfp13'])
7 | @pytest.mark.parametrize("roundmode", ["even", "nearest"])
8 | @pytest.mark.parametrize("block_size", ["8", "16", "32"])
9 | 
10 | def test_bfp(datatype, roundmode, block_size):
11 |     tensor = torch.randn(16, 1024).float()
12 |     p1 = "int"+str(int(datatype[3:])-8)
13 |     p2 = "e8m0_t"+block_size
14 |     tcast_dt = tcast.datatype(p1, p2)
15 |     tensor_bfp = tensor_to_bfp(tensor, 1, tcast_dt, roundmode)
16 |     tensor_tcast = tcast.cast(tensor, dtype=tcast_dt, roundmode=roundmode)
17 |     compare_2(tensor_bfp, tensor_tcast)
18 | 
--------------------------------------------------------------------------------
/tests/test_bfp_export.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tcast
3 | import torch
4 | from tests.utils import compare_2, tensor_to_bfp
5 | from struct import pack, unpack
6 | import numpy as np
7 | 
8 | @pytest.mark.parametrize("datatype", ['bfp16', 'bfp15', 'bfp14', 'bfp13'])
9 | @pytest.mark.parametrize("roundmode", ["even", "nearest"])
10 | @pytest.mark.parametrize("block_size", ["8", "16", "32"])
11 | 
12 | def test_bfp(datatype, roundmode, block_size):
13 |     tensor = (torch.randint(-2048, 2048, (16, 1024))*torch.randn(16, 1024)).float()
14 |     p1 = "int"+str(int(datatype[3:])-8)
15 |     p2 = "e8m0_t"+block_size
16 |     tcast_dt = tcast.datatype(p1, p2, export=True)
17 |     tensor_bfp = tensor_to_bfp(tensor, 1, tcast_dt, roundmode)
18 |     tensor_tcast_d = tcast.cast(tensor, dtype=tcast_dt, roundmode=roundmode)
19 |     tensor_tcast = tensor_tcast_d["x"]
20 |     compare_2(tensor_bfp, tensor_tcast)
21 | 
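The bfp names used in these tests denote an integer element type plus a shared power-of-two exponent per tile; the same construction standalone (a sketch mirroring test_bfp above):

import torch
import tcast

# bfp16 with block size 8: int8 elements sharing one e8m0 exponent per tile
dt = tcast.datatype("int8", "e8m0_t8")
x = torch.randn(4, 64)
xq = tcast.cast(x, dtype=dt, roundmode="even")  # virtual cast back to float32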
--------------------------------------------------------------------------------
/tests/test_mx.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tcast
3 | import torch
4 | from tests.utils import compare_2
5 | try:
6 |     from mx.mx_ops import quantize_mx_op
7 |     from mx.elemwise_ops import quantize_elemwise_op
8 |     from mx.specs import MxSpecs
9 |     MX_AVAILABLE = True
10 | 
11 | except ImportError:
12 |     MX_AVAILABLE = False
13 | 
14 | @pytest.mark.parametrize("datatype", ['float16', 'bfloat16'])
15 | @pytest.mark.parametrize("roundmode", ["even", "nearest"])
16 | @pytest.mark.skipif(not MX_AVAILABLE, reason="MX library is not available. github.com/microsoft/microxcaling")
17 | 
18 | def test_mx_unscaled_datatypes(datatype, roundmode):
19 |     tensor = torch.randn(1024, 1024).float()
20 |     mx_specs = MxSpecs()
21 |     if datatype == 'float16':
22 |         mx_specs['fp'] = 16
23 |     elif datatype == 'bfloat16':
24 |         mx_specs['bfloat'] = 16
25 |     mx_specs['round'] = roundmode
26 |     tensor_mx = quantize_elemwise_op(tensor, mx_specs)
27 |     if 'fp' in datatype:
28 |         datatype = datatype[4:]
29 |     tcast_dt = tcast.datatype(datatype)
30 |     tensor_tcast = tcast.cast(tensor, dtype=tcast_dt, roundmode=roundmode)
31 |     compare_2(tensor_mx, tensor_tcast)
32 | 
33 | @pytest.mark.parametrize("datatype", ['int8', 'int4', 'fp8_e5m2', 'fp8_e4m3', 'fp6_e3m2', 'fp6_e2m3', 'fp4_e2m1'])
34 | @pytest.mark.parametrize("roundmode", ["even", "nearest"])
35 | @pytest.mark.skipif(not MX_AVAILABLE, reason="MX library is not available. github.com/microsoft/microxcaling")
36 | 
37 | def test_mx_scaled_datatypes(datatype, roundmode):
38 |     tensor = torch.randn(1024, 1024).float()
39 |     mx_specs = MxSpecs()
40 |     mx_specs['block_size'] = 32
41 |     mx_specs['round'] = roundmode
42 |     tensor_mx = quantize_mx_op(tensor, mx_specs, elem_format=datatype, axes=-1, round=roundmode)
43 |     if 'fp' in datatype:
44 |         datatype = datatype[4:]
45 |         if 'e4' in datatype:
46 |             datatype += 'fn'
47 |         elif 'e3' in datatype or 'e2' in datatype:
48 |             datatype += 'fnuz'
49 |     tcast_dt = tcast.datatype(datatype, "e8m0_t32")
50 |     tensor_tcast = tcast.cast(tensor, dtype=tcast_dt, roundmode=roundmode, scalemode="max")
51 |     compare_2(tensor_mx, tensor_tcast)
--------------------------------------------------------------------------------
/tests/test_torch.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import tcast
3 | import torch
4 | from tests.utils import compare_2
5 | 
6 | 
7 | @pytest.mark.parametrize("datatype", ['float16', 'bfloat16','float8_e5m2', 'float8_e5m2fnuz', 'float8_e4m3fn', 'float8_e4m3fnuz'])
8 | 
9 | def test_torch_datatypes(datatype):
10 |     tensor = torch.randn(1024, 1024).float()
11 |     if 'float8_' in datatype:
12 |         # check availability before getattr(torch, datatype) below, and only require fnuz support for fnuz types
13 |         if not tcast.utils.is_float8_available():
14 |             pytest.skip("Skipping because float8 is not available")
15 |         if 'fnuz' in datatype and not tcast.utils.is_float8_fnuz_available():
16 |             pytest.skip("Skipping because float8 fnuz is not available")
17 |     tensor_torch = tensor.to(getattr(torch, datatype))
18 |     tensor_torch = tensor_torch.float()
19 |     tcast_dt = tcast.datatype(datatype)
20 |     tensor_tcast = tcast.cast(tensor, dtype=tcast_dt, roundmode="even")
21 |     compare_2(tensor_torch, tensor_tcast)
22 | 
23 | 
24 | if __name__ == "__main__":
25 |     test_torch_datatypes('float8_e4m3fnuz')
26 | 
--------------------------------------------------------------------------------
/tests/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tcast.cast import RoundMode
3 | import struct
4 | import random
5 | 
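# Note on the helpers below: they operate directly on IEEE-754 float32 bit patterns,
# with bit 31 the sign, bits 30-23 the biased exponent (bias 127), and bits 22-0 the
# mantissa. round_func keeps the top `mbits` mantissa bits (the masks table) and
# rounds by adding half an ulp of the kept precision (the adds table), with an
# optional round-to-even tie correction.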
6 | def compare_2(tensor1, tensor2):
7 |     if torch.allclose(tensor1, tensor2):
8 |         return
9 |     else:
10 |         raise ValueError("compare_2: tensors do not match")
11 | 
12 | def float_to_bits(value_in_float):
13 |     s = struct.pack('@f', value_in_float)
14 |     return struct.unpack('@I', s)[0]
15 | 
16 | 
17 | def bits_to_float(value_in_uint):
18 |     s = struct.pack('@I', value_in_uint)
19 |     return struct.unpack('@f', s)[0]
20 | 
21 | 
22 | def get_leading_zeros(uival):
23 |     andv = 0x80000000
24 |     rval = 0
25 |     for i in range(31):
26 |         if uival & andv == 0:
27 |             rval += 1
28 |             andv = andv >> 1
29 |         else:
30 |             return rval
31 |     return rval
32 | 
33 | 
34 | def round_func(sign, mantisa, rmode, mbits, mantisaNotScaled, scale):
35 |     masks = [0x0, 0x400000, 0x600000, 0x700000, 0x780000, 0x7C0000, 0x7E0000, 0x7F0000, 0x7F8000, 0x7FC000, 0x7FE000]
36 |     adds = [0x0, 0x400000, 0x200000, 0x100000, 0x80000, 0x40000, 0x20000, 0x10000, 0x8000, 0x4000, 0x2000, 0x0]
37 |     opmasks = [0x00000000, 0xffffffff]
38 |     mask = masks[mbits]
39 |     add = adds[mbits]
40 |     if rmode == "zero":
41 |         return mantisa & mask
42 |     elif rmode == "even":
43 |         if (mantisa&(~mask) == adds[mbits+1]) and not (mantisa&add):
44 |             #print("EVEN rounding needed on scaled mantissa: %x, will use full mantissa mask: %x" %(mantisa,pow(2,(23 - (mbits - scale)))-1))
45 |             if (pow(2,(23 - (mbits - scale)))-1)&mantisaNotScaled != 0: #consider all relevant mantissa bits before scaling (including implicit one scaling) as in SQT
46 |                 #print("FORCE EVEN ROUNDING like in SQT GPUcore and torchCore")
47 |                 add = adds[mbits + 1]
48 |                 mantisa += add & opmasks[((mantisa & mask) != mask)]
49 |                 return mantisa & mask
50 |             else:
51 |                 return mantisa & mask
52 |         else:
53 |             add = adds[mbits + 1]
54 |             mantisa += add & opmasks[((mantisa & mask) != mask)]
55 |             return mantisa & mask
56 |     elif rmode == "nearest":
57 |         add = adds[mbits + 1]
58 |         mantisa += add & opmasks[((mantisa & mask) != mask)]
59 |         return mantisa & mask
60 |     else:
61 |         raise NotImplementedError(f"RoundMode {rmode} is not currently implemented in round_func")
62 | 
63 | 
64 | 
65 | def float_to_bfp(fval, max_exp, rmode, mbits):
66 |     bits = float_to_bits(fval) # bits in form of uint32
67 |     sign = bits & 0x80000000 # sign bit
68 |     exp = (bits & 0x7F800000) >> 23 # exponent
69 | 
70 |     if (exp == 0) and (max_exp !=0):
71 |         correctedExponent = 1
72 |     else:
73 |         correctedExponent = exp
74 | 
75 |     scale = max(0, max_exp - correctedExponent) # scale required to go to maxexp
76 |     mant = bits & 0x7FFFFF # mantissa bits
77 | 
78 |     #print("mant: %x, maxExp: %d, exp: %d, mbits: %d, scale: %d" % (mant,max_exp,exp,mbits,scale))
79 | 
80 |     if (exp == 0): #subnormal
81 |         mantScaled = mant >> scale # scale to max exponent
82 |         #print("subnorm mantissa before calling round: %x" %(mantScaled))
83 |         mant = round_func(sign, mantScaled, rmode, mbits, mant, scale) # rounding
84 |         if mant == 0:
85 |             qbits = sign | mant
86 |         else:
87 |             if (max_exp == 0):
88 |                 qbits = sign | mant
89 |             else:
90 |                 lziro = get_leading_zeros(mant<<9) # 9: 1bit sign and 8 bits exponent are not considered
91 |                 if (lziro == 0): #this is only possible when maxExp equals 1 and we scaled with 0
92 |                     qbits = sign | mant
93 |                     #print("subnormToREGULAR: need to keep it a SUBNORM")
94 |                 elif lziro > max_exp:
95 |                     qbits = sign
96 |                     #print("subnormToREGULAR: keep only sign so results in ZERO")
97 |                 elif lziro == max_exp:
98 |                     qbits = sign | mant
99 |                     #print("subnormToREGULAR: need to keep it a SUBNORM")
100 |                 else:
101 |                     mant = (mant << (lziro+1)) & 0x7FFFFF # scale back so implicit 1 is bit23 and remove it
102 |                     qbits = sign | ((max_exp-lziro)<<23) | mant
103 |                     #print("subnormToREGULAR: upgraded to regular")
104 |     else:
105 |         mantScaled = (mant >> 1) | 0x400000 # insert the implicit 1; since we are quantizing and bit0 is being eliminated, bit0 is dropped here - it could matter for rounding, which the tests will show
106 |         mantScaled = mantScaled >> scale # scale to max exponent
107 |         #print("mantissa before calling round: %x" %(mantScaled))
108 |         mant = round_func(sign, mantScaled, rmode, mbits, mant, scale) # rounding
109 |         if mant == 0:
110 |             qbits = sign | mant
111 |         else:
112 |             lziro = get_leading_zeros(mant<<9) # 9: 1bit sign and 8 bits exponent are not considered
113 |             if (lziro >= max_exp):
114 |                 mant = (mant << max_exp) & 0x7FFFFF
115 |                 qbits = sign | mant # make subnorm
116 |                 #print("regularToSUBNORM: need to make it a SUBNORM")
117 |             else:
118 |                 mant = (mant << (lziro+1)) & 0x7FFFFF # scale back so implicit 1 is bit23 and remove it
119 |                 qbits = sign | ((max_exp-lziro)<<23) | mant
120 | 
121 |     #print("qbits: %x, mant: %x, maxExp: %d, exp: %d, mbits: %d" % (qbits,mant,max_exp,exp,mbits))
122 | 
123 |     return bits_to_float(qbits)
124 | 
125 | 
126 | def block_to_bfp(tensor, dtype, rmode):
127 |     # input tensor is 1D and implicitly, the tensor size is our blocksize
128 |     assert tensor.ndim == 1
129 | 
130 |     #print(f"tensor: {tensor}")
131 | 
132 |     # shared exponent is the largest number's exponent
133 |     max_idx = tensor.abs().argmax()
134 |     bits = float_to_bits(tensor[max_idx].item())
135 |     max_exp = (bits >> 23) & 0xFF
136 |     tq = torch.zeros_like(tensor)
137 | 
138 |     for i in range(tensor.shape[0]):
139 |         tq[i] = float_to_bfp(tensor[i].item(), max_exp, rmode, dtype.nspec.mbits+1)
140 | 
141 |     return tq
142 | 
143 | def tensor_to_bfp(tensor, axis, dtype, rmode):
144 |     blocksize = dtype.sspec.tile
145 |     ftensor = tensor.nan_to_num(nan=0.0).float() # d0, d1, ... d_(axis), ..., dn
146 |     ftensor = ftensor.permute(*[i for i in range(axis)], *[i for i in range(axis+1,ftensor.ndim)], axis) # d0, d1, ..., dn, d_(axis)
147 |     ftensor = ftensor.unsqueeze(ftensor.ndim) # d0, d1, ..., dn, d_(axis), 1
148 |     ftensor = ftensor.reshape(*ftensor.shape[:-2], ftensor.shape[ftensor.ndim-2]//blocksize, blocksize) # d0, d1, ..., dn, d_(axis)/blocksize, blocksize
149 |     fshape = ftensor.shape
150 |     ftensor = ftensor.reshape(-1, blocksize)
151 |     for i in range(ftensor.shape[0]):
152 |         ftensor[i] = block_to_bfp(ftensor[i], dtype, rmode)
153 |     ftensor = ftensor.reshape(*fshape) # d0, d1, ..., dn, d_(axis)/blocksize, blocksize
154 |     ftensor = ftensor.reshape(*ftensor.shape[:-2], ftensor.shape[ftensor.ndim-2]*blocksize) # d0, d1, ..., dn, d_(axis)
155 |     ftensor = ftensor.permute(*[i for i in range(axis)], ftensor.ndim-1, *[i for i in range(axis,ftensor.ndim-1)]) # d0, d1, ..., d_(axis), ..., dn
156 |     return ftensor
157 | 
--------------------------------------------------------------------------------
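Taken together, a minimal end-to-end sketch using the entry points exercised by the tests (tcast.datatype and tcast.cast):

import torch
import tcast

# mxfp8e4-style datatype: e4m3fn elements with a shared e8m0 scale per 32-element tile
dt = tcast.datatype("e4m3fn", "e8m0_t32")
x = torch.randn(16, 1024)
xq = tcast.cast(x, dtype=dt, roundmode="even", scalemode="max")  # virtual cast
print((x - xq).abs().max())  # quantization error introduced by the cast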